# wcs/wcs/sql.py
# w.c.s. - web application for online forms
# Copyright (C) 2005-2012 Entr'ouvert
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 2 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, see <http://www.gnu.org/licenses/>.
import copy
import datetime
import io
import json
import re
import time
import psycopg2
import psycopg2.extensions
import psycopg2.extras
import unidecode
from psycopg2.sql import SQL, Identifier, Literal
try:
import cPickle as pickle
except ImportError:
import pickle
from django.utils.encoding import force_bytes, force_text
from django.utils.timezone import now
from quixote import get_publisher
import wcs.carddata
import wcs.categories
import wcs.custom_views
import wcs.formdata
import wcs.logged_errors
import wcs.qommon.tokens
import wcs.roles
import wcs.snapshots
import wcs.tracking_code
import wcs.users
from wcs.qommon import PICKLE_KWARGS, force_str
from . import qommon
from .publisher import UnpicklerClass
from .qommon import _, get_cfg
from .qommon.misc import JSONEncoder, strftime
from .qommon.storage import NothingToUpdate, _take, deep_bytes2str
from .qommon.storage import parse_clause as parse_storage_clause
from .qommon.substitution import invalidate_substitution_cache
from .qommon.upload_storage import PicklableUpload
# enable psycogp2 unicode mode, this will fetch postgresql varchar/text columns
# as unicode objects
psycopg2.extensions.register_type(psycopg2.extensions.UNICODE)
psycopg2.extensions.register_type(psycopg2.extensions.UNICODEARRAY)
# automatically adapt dictionaries into json fields
psycopg2.extensions.register_adapter(dict, psycopg2.extras.Json)
# mapping from w.c.s. field type keys to PostgreSQL column types;
# a None value means the field is presentation-only and gets no column;
# field types missing from this mapping default to 'varchar' at the call sites
SQL_TYPE_MAPPING = {
    'title': None,
    'subtitle': None,
    'comment': None,
    'page': None,
    'text': 'text',
    'bool': 'boolean',
    'file': 'bytea',
    'date': 'date',
    'items': 'text[]',
    'table': 'text[][]',
    'table-select': 'text[][]',
    'tablerows': 'text[][]',
    # mapping of dicts
    'ranked-items': 'text[][]',
    'password': 'text[][]',
    # field block
    'block': 'jsonb',
    # computed data field
    'computed': 'jsonb',
}
def like_escape(value):
    """Escape SQL LIKE/ILIKE special characters (backslash, _ and %) in *value*."""
    # backslash must be escaped first so it doesn't double-escape the others
    for special in ('\\', '_', '%'):
        value = value.replace(special, '\\' + special)
    return value
def pickle_loads(value):
    """Unpickle a value read from a bytea column into a Python object."""
    # bytea columns may come back as psycopg2 memoryview/buffer objects
    if hasattr(value, 'tobytes'):
        value = value.tobytes()
    # NOTE(review): unpickling database content; this is only safe as long as
    # the bytea columns are written exclusively by this application.
    obj = UnpicklerClass(io.BytesIO(force_bytes(value)), **PICKLE_KWARGS).load()
    # recursively convert legacy bytes keys/values to str
    obj = deep_bytes2str(obj)
    return obj
class Criteria(qommon.storage.Criteria):
    """SQL counterpart of the storage Criteria classes.

    as_sql() returns a SQL fragment using pyformat %(cN)s placeholders and
    as_sql_param() the matching parameters dict; placeholder names are derived
    from id(self.value) so fragments from several criterias can be combined
    into a single query.
    """

    def __init__(self, attribute, value, **kwargs):
        # dashes appear in field varnames but are not valid in SQL identifiers
        self.attribute = attribute.replace('-', '_')
        self.value = value
        # optional field object, used to adapt the SQL to the field type
        self.field = kwargs.get('field')

    def as_sql(self):
        if self.field and getattr(self.field, 'block_field', None):
            # field living inside a block; look it up in the block's jsonb data:
            # eq: EXISTS (SELECT 1 FROM jsonb_array_elements(BLOCK->'data') AS datas(aa) WHERE aa->>'FOOBAR' = 'value')
            # lt: EXISTS (SELECT 1 FROM jsonb_array_elements(BLOCK->'data') AS datas(aa) WHERE aa->>'FOOBAR' < 'value')
            # lte: EXISTS (SELECT 1 FROM jsonb_array_elements(BLOCK->'data') AS datas(aa) WHERE aa->>'FOOBAR' <= 'value')
            # gt: EXISTS (SELECT 1 FROM jsonb_array_elements(BLOCK->'data') AS datas(aa) WHERE aa->>'FOOBAR' > 'value')
            # gte: EXISTS (SELECT 1 FROM jsonb_array_elements(BLOCK->'data') AS datas(aa) WHERE aa->>'FOOBAR' >= 'value')
            # with a NOT EXISTS and the opposite operator:
            # ne: NOT EXISTS (SELECT 1 FROM jsonb_array_elements(BLOCK->'data') AS datas(aa) WHERE aa->>'FOOBAR' = 'value')
            # note: aa->>'FOOBAR' can be written with an integer or bool cast
            attribute = "aa->>'%s'" % self.field.id
            if self.field.key in ['item', 'string'] and isinstance(self.value, int):
                # integer cast of db values
                attribute = "(CASE WHEN %s~E'^\\\\d{1,9}$' THEN (%s)::int ELSE NULL END)" % (
                    attribute,
                    attribute,
                )
            elif self.field.key == 'bool':
                # bool cast of db values
                attribute = '(%s)::bool' % attribute
            return "%s(SELECT 1 FROM jsonb_array_elements(%s->'data') AS datas(aa) WHERE %s %s %%(c%s)s)" % (
                getattr(self, 'sql_exists', 'EXISTS'),
                get_field_id(self.field.block_field),
                attribute,
                getattr(self, 'sql_op_exists', self.sql_op),
                id(self.value),
            )
        attribute = self.attribute
        if self.field and self.field.key == 'items':
            # eq: 'value' = ANY (ITEMS)
            # ne: 'value' != ALL (ITEMS)
            # with reversed operator:
            # lt: 'value' > ANY (ITEMS)
            # lte: 'value' >= ANY (ITEMS)
            # gt: 'value' < ANY (ITEMS)
            # gte: 'value' <= ANY (ITEMS)
            # note: ITEMS is written with an integer cast or with a COALESCE expression
            if isinstance(self.value, int):
                # integer cast of db values
                attribute = (
                    "CASE WHEN array_to_string(%s, '')~E'^\\\\d+$' THEN %s::int[] ELSE ARRAY[]::int[] END"
                    % (attribute, attribute)
                )
            else:
                # for none values
                attribute = "COALESCE(%s, ARRAY[]::text[])" % attribute
            return '%%(c%s)s %s %s (%s)' % (
                id(self.value),
                getattr(self, 'sql_reversed_op', self.sql_op),
                getattr(self, 'sql_array_op', 'ANY'),
                attribute,
            )
        if self.field and self.field.key == 'computed':
            # computed fields keep their value under the 'data' key of the jsonb
            attribute = "%s->>'data'" % self.attribute
        elif self.field and self.field.key in ['item', 'string'] and isinstance(self.value, int):
            # integer cast of db values
            attribute = "(CASE WHEN %s~E'^\\\\d{1,9}$' THEN %s::int ELSE NULL END)" % (attribute, attribute)
        return '%s %s %%(c%s)s' % (attribute, self.sql_op, id(self.value))

    def as_sql_param(self):
        """Return the {placeholder: value} dict matching as_sql()."""
        if isinstance(self.value, datetime.date) and not isinstance(self.value, datetime.datetime):
            value = self.value.strftime('%Y-%m-%d')
        elif isinstance(self.value, time.struct_time):
            value = datetime.datetime.fromtimestamp(time.mktime(self.value))
        else:
            value = self.value
        return {'c%s' % id(self.value): value}
class Less(Criteria):
    """'<' comparison; operator is reversed for items-array clauses."""

    sql_op = '<'
    sql_reversed_op = '>'
class Greater(Criteria):
    """'>' comparison; operator is reversed for items-array clauses."""

    sql_op = '>'
    sql_reversed_op = '<'
class Equal(Criteria):
    """'=' comparison; equality to an empty sequence matches NULL/empty arrays."""

    sql_op = '='

    def as_sql(self):
        # an empty list/tuple can only match an array column with no elements
        if self.value in ([], ()):
            return 'ARRAY_LENGTH(%s, 1) IS NULL' % self.attribute
        return super().as_sql()
class LessOrEqual(Criteria):
    """'<=' comparison; operator is reversed for items-array clauses."""

    sql_op = '<='
    sql_reversed_op = '>='
class GreaterOrEqual(Criteria):
    """'>=' comparison; operator is reversed for items-array clauses."""

    sql_op = '>='
    sql_reversed_op = '<='
class NotEqual(Criteria):
    """'!=' comparison that also matches NULL values."""

    sql_op = '!='
    # in case of items field, we want to write this clause:
    # 'value' != ALL (ITEMS)
    sql_array_op = 'ALL'
    # in case of block field, we want to write this clause:
    # NOT EXISTS (SELECT 1 FROM jsonb_array_elements(BLOCK->'data') AS datas(aa) WHERE aa->>'FOOBAR' = 'value')
    # and not:
    # EXISTS (SELECT 1 FROM jsonb_array_elements(BLOCK->'data') AS datas(aa) WHERE aa->>'FOOBAR' != 'value')
    sql_exists = 'NOT EXISTS'
    sql_op_exists = '='

    def as_sql(self):
        if self.field and getattr(self.field, 'block_field', None):
            return super().as_sql()
        # in SQL, NULL != 'value' evaluates to NULL (not true); add an
        # explicit IS NULL test so unset columns are considered different
        return "(%s is NULL OR %s)" % (self.attribute, super().as_sql())
class StrictNotEqual(Criteria):
    """'!=' comparison without the NULL-matching behaviour of NotEqual."""

    sql_op = '!='
class Contains(Criteria):
    """Membership test: attribute IN (value, ...)."""

    sql_op = 'IN'

    def as_sql(self):
        # 'IN ()' is not valid SQL; an empty candidate set can never match
        return super().as_sql() if self.value else 'FALSE'

    def as_sql_param(self):
        # IN requires a tuple for psycopg2 adaptation
        return {'c%s' % id(self.value): tuple(self.value)}
class NotContains(Contains):
    """Negated membership test: attribute NOT IN (value, ...)."""

    sql_op = 'NOT IN'

    def as_sql(self):
        # nothing is contained in an empty set, so every row matches
        return super().as_sql() if self.value else 'TRUE'
class ArrayContains(Contains):
    """Array containment (@>): the column array contains all given values."""

    sql_op = '@>'

    def as_sql_param(self):
        # unlike Contains, pass the value unconverted (a list adapts to a
        # postgres array, not an IN tuple)
        return {'c%s' % id(self.value): self.value}
class NotNull(Criteria):
    """IS NOT NULL check on a column; takes no comparison value."""

    sql_op = 'IS NOT NULL'

    def __init__(self, attribute, **kwargs):
        # accept (and ignore) extra keyword arguments such as 'field', for
        # consistency with Null and because parse_clause()/Or/And re-create
        # criteria instances from a storage criteria's full __dict__
        self.attribute = attribute

    def as_sql(self):
        return '%s %s' % (self.attribute, self.sql_op)

    def as_sql_param(self):
        # no placeholder is emitted by as_sql()
        return {}
class Null(Criteria):
    """IS NULL check on a column; takes no comparison value."""

    sql_op = 'IS NULL'

    def __init__(self, attribute, **kwargs):
        self.attribute = attribute

    def as_sql(self):
        return ' '.join((self.attribute, self.sql_op))

    def as_sql_param(self):
        # no placeholder is emitted by as_sql()
        return {}
class Or(Criteria):
    """Disjunction of sub-criterias; an empty disjunction matches nothing."""

    def __init__(self, criterias, **kwargs):
        # map each storage criteria to the SQL class of the same name
        # defined in this module
        self.criterias = [
            globals().get(element.__class__.__name__)(**element.__dict__) for element in criterias
        ]

    def as_sql(self):
        if not self.criterias:
            return '( FALSE )'
        return '( %s )' % ' OR '.join(x.as_sql() for x in self.criterias)

    def as_sql_param(self):
        params = {}
        for criteria in self.criterias:
            params.update(criteria.as_sql_param())
        return params
class And(Criteria):
    """Conjunction of sub-criterias."""

    def __init__(self, criterias, **kwargs):
        # map each storage criteria to the SQL class of the same name
        # defined in this module
        self.criterias = [
            globals().get(element.__class__.__name__)(**element.__dict__) for element in criterias
        ]

    def as_sql(self):
        return '( %s )' % ' AND '.join(x.as_sql() for x in self.criterias)

    def as_sql_param(self):
        params = {}
        for criteria in self.criterias:
            params.update(criteria.as_sql_param())
        return params
class Not(Criteria):
    """Negation of a single sub-criteria."""

    def __init__(self, criteria, **kwargs):
        # map the storage criteria to the SQL class of the same name
        sql_class = globals().get(criteria.__class__.__name__)
        self.criteria = sql_class(**criteria.__dict__)

    def as_sql(self):
        return 'NOT ( %s )' % self.criteria.as_sql()

    def as_sql_param(self):
        return self.criteria.as_sql_param()
class Intersects(Criteria):
    """Array overlap (&&) with the given values."""

    def as_sql(self):
        if not self.value:
            # nothing overlaps an empty set; only match NULL/empty arrays
            return 'ARRAY_LENGTH(%s, 1) IS NULL' % self.attribute
        return '%s && %%(c%s)s' % (self.attribute, id(self.value))

    def as_sql_param(self):
        # a list adapts to a postgres array
        return {'c%s' % id(self.value): list(self.value)}
class ILike(Criteria):
    """Case-insensitive substring match: attribute ILIKE '%value%'."""

    def __init__(self, attribute, value, **kwargs):
        super().__init__(attribute, value, **kwargs)
        # escape LIKE metacharacters, then wrap in wildcards
        self.value = '%%%s%%' % like_escape(self.value)

    def as_sql(self):
        return '%s ILIKE %%(c%s)s' % (self.attribute, id(self.value))
class FtsMatch(Criteria):
    """Full-text search match against the 'fts' tsvector column."""

    def __init__(self, value):
        # note: no attribute parameter, the 'fts' column name is hardcoded
        # in as_sql()
        self.value = self.get_fts_value(value)

    @classmethod
    def get_fts_value(cls, value):
        # strip accents from the query (presumably the indexed content is
        # unaccented the same way — confirm against the fts update code)
        return unidecode.unidecode(value)

    def as_sql(self):
        return 'fts @@ plainto_tsquery(%%(c%s)s)' % id(self.value)
class ElementEqual(Criteria):
    """Equality against one key of a JSON column (attribute->>'key')."""

    def __init__(self, attribute, key, value):
        super().__init__(attribute, value)
        self.key = key

    def as_sql(self):
        return "%s->>'%s' = %%(c%s)s" % (self.attribute, self.key, id(self.value))
class ElementILike(ElementEqual):
    """Substring match against one key of a JSON column (ILIKE '%value%')."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # escape LIKE metacharacters, then wrap in wildcards
        self.value = '%' + like_escape(self.value) + '%'

    def as_sql(self):
        return "%s->>'%s' ILIKE %%(c%s)s" % (self.attribute, self.key, id(self.value))
class ElementIntersects(ElementEqual):
    """Test that the JSON array at attribute->key intersects the given value(s)."""

    def as_sql(self):
        if not self.value:
            return 'FALSE'
        # normalise the value to a list; note this rebinds self.value and so
        # changes id(self.value) — as_sql_param() must be called after
        # as_sql() for the placeholder names to match (parse_clause, Or and
        # And do call them in that order)
        if not isinstance(self.value, (tuple, list, set)):
            self.value = [self.value]
        else:
            self.value = list(self.value)
        return "EXISTS(SELECT 1 FROM jsonb_array_elements_text(%s->'%s') foo WHERE foo = ANY(%%(c%s)s))" % (
            self.attribute,
            self.key,
            id(self.value),
        )
def get_name_as_sql_identifier(name):
    """Simplify *name* into a string usable as a SQL identifier."""
    name = qommon.misc.simplify(name)
    # single pass: drop characters forbidden in identifiers and map dashes
    # to underscores
    return name.translate(str.maketrans('-', '_', '<>|{}!?^*+/=\''))
def get_field_id(field):
    """Return the SQL column name for a field: 'f' + its id, normalised."""
    return 'f%s' % str(field.id).replace('-', '_').lower()
def parse_clause(clause):
    # returns a three-elements tuple with:
    #  - a list of SQL 'WHERE' clauses
    #  - a dict for query parameters
    #  - a callable, or None if all clauses have been successfully translated
    if clause is None:
        return ([], {}, None)
    if callable(clause):  # already a callable
        return ([], {}, clause)
    # create 'WHERE' clauses
    func_clauses = []
    where_clauses = []
    parameters = {}
    for element in clause:
        if callable(element):
            func_clauses.append(element)
        else:
            # look up the SQL criteria class with the same name as the
            # storage criteria, defined in this module
            sql_class = globals().get(element.__class__.__name__)
            if sql_class:
                sql_element = sql_class(**element.__dict__)
                where_clauses.append(sql_element.as_sql())
                parameters.update(sql_element.as_sql_param())
            else:
                # no SQL translation; fall back to in-Python filtering
                func_clauses.append(element.build_lambda())
    if func_clauses:
        return (where_clauses, parameters, parse_storage_clause(func_clauses))
    else:
        return (where_clauses, parameters, None)
def str_encode(value):
    """Recursively rebuild nested lists; any other value passes through unchanged."""
    if not isinstance(value, list):
        return value
    return [str_encode(item) for item in value]
def site_unicode(value):
    """Decode *value* to text using the site's configured charset."""
    return force_text(value, get_publisher().site_charset)
def get_connection(new=False):
    """Return the publisher's shared psycopg2 connection, creating it on demand.

    With new=True any existing connection is closed first, and connection
    errors propagate; otherwise a failed connection attempt stores and
    returns None.
    """
    if new:
        cleanup_connection()
    if not getattr(get_publisher(), 'pgconn', None):
        postgresql_cfg = {}
        for param in ('database', 'user', 'password', 'host', 'port'):
            value = get_cfg('postgresql', {}).get(param)
            if value:
                postgresql_cfg[param] = value
        if 'database' in postgresql_cfg:
            # the configuration uses 'database' but psycopg2 expects 'dbname'
            postgresql_cfg['dbname'] = postgresql_cfg.pop('database')
        try:
            pgconn = psycopg2.connect(**postgresql_cfg)
        except psycopg2.Error:
            if new:
                raise
            pgconn = None
        # cache the connection (or the failure marker) on the publisher
        get_publisher().pgconn = pgconn
    return get_publisher().pgconn
def cleanup_connection():
    """Close and forget the publisher's PostgreSQL connection, if any."""
    publisher = get_publisher()
    if getattr(publisher, 'pgconn', None) is not None:
        publisher.pgconn.close()
        publisher.pgconn = None
def get_connection_and_cursor(new=False):
    """Return a (connection, cursor) pair, reconnecting if the connection is stale."""
    conn = get_connection(new=new)
    try:
        cur = conn.cursor()
    except psycopg2.InterfaceError:
        # may be postgresql was restarted in between
        conn = get_connection(new=True)
        cur = conn.cursor()
    return (conn, cur)
def get_formdef_table_name(formdef):
    """Return (computing and persisting it if needed) the formdef's table name."""
    # PostgreSQL limits identifier length to 63 bytes
    #
    #   The system uses no more than NAMEDATALEN-1 bytes of an identifier;
    #   longer names can be written in commands, but they will be truncated.
    #   By default, NAMEDATALEN is 64 so the maximum identifier length is
    #   63 bytes. If this limit is problematic, it can be raised by changing
    #   the NAMEDATALEN constant in src/include/pg_config_manual.h.
    #
    # as we have to know our table names, we crop the names here, and to an
    # extent that allows suffixes (like _evolution) to be added.
    assert formdef.id is not None
    if hasattr(formdef, 'table_name') and formdef.table_name:
        return formdef.table_name
    formdef.table_name = '%s_%s_%s' % (
        formdef.data_sql_prefix,
        formdef.id,
        get_name_as_sql_identifier(formdef.url_name)[:30],
    )
    # persist the computed name so it stays stable even if url_name changes
    formdef.store(object_only=True)
    return formdef.table_name
def get_formdef_trigger_function_name(formdef):
    """Return the name of the formdef's trigger function (prefix_id_trigger_fn)."""
    assert formdef.id is not None
    return '{0.data_sql_prefix}_{0.id}_trigger_fn'.format(formdef)
def get_formdef_trigger_name(formdef):
    """Return the name of the formdef's trigger (prefix_id_trigger)."""
    assert formdef.id is not None
    return '{0.data_sql_prefix}_{0.id}_trigger'.format(formdef)
def get_formdef_new_id(id_start):
    """Return the first id >= id_start for which no formdata table exists yet."""
    new_id = id_start
    conn, cur = get_connection_and_cursor()
    while True:
        cur.execute(
            '''SELECT COUNT(*) FROM information_schema.tables
                 WHERE table_schema = 'public'
                   AND table_name LIKE %s''',
            ('formdata\\_%s\\_%%' % new_id,),
        )
        if cur.fetchone()[0] == 0:
            break
        new_id += 1
    conn.commit()
    cur.close()
    return new_id
def get_carddef_new_id(id_start):
    """Return the first id >= id_start for which no carddata table exists yet."""
    new_id = id_start
    conn, cur = get_connection_and_cursor()
    while True:
        cur.execute(
            '''SELECT COUNT(*) FROM information_schema.tables
                 WHERE table_schema = 'public'
                   AND table_name LIKE %s''',
            ('carddata\\_%s\\_%%' % new_id,),
        )
        if cur.fetchone()[0] == 0:
            break
        new_id += 1
    conn.commit()
    cur.close()
    return new_id
def formdef_wipe():
    """Drop all formdata tables and truncate the global wcs_all_forms table."""
    conn, cur = get_connection_and_cursor()
    cur.execute(
        '''SELECT table_name FROM information_schema.tables
             WHERE table_schema = 'public'
               AND table_name LIKE %s''',
        ('formdata\\_%%\\_%%',),
    )
    for table_name in [x[0] for x in cur.fetchall()]:
        cur.execute('DELETE FROM %s' % table_name)  # Force trigger execution
        cur.execute('''DROP TABLE %s CASCADE''' % table_name)
    cur.execute("SELECT relkind FROM pg_class WHERE relname = 'wcs_all_forms'")
    row = cur.fetchone()
    # only do the delete if wcs_all_forms is a table and not still a view
    if row is not None and row[0] == 'r':
        cur.execute('TRUNCATE wcs_all_forms')
    conn.commit()
    cur.close()
def carddef_wipe():
    """Drop all carddata tables."""
    conn, cur = get_connection_and_cursor()
    cur.execute(
        '''SELECT table_name FROM information_schema.tables
             WHERE table_schema = 'public'
               AND table_name LIKE %s''',
        ('carddata\\_%%\\_%%',),
    )
    for table_name in [x[0] for x in cur.fetchall()]:
        cur.execute('DELETE FROM %s' % table_name)  # Force trigger execution
        cur.execute('''DROP TABLE %s CASCADE''' % table_name)
    conn.commit()
    cur.close()
def get_formdef_view_name(formdef):
    """Return the SQL view name for a formdef (cropped for NAMEDATALEN)."""
    if formdef.data_sql_prefix == 'formdata':
        prefix = 'wcs_view'
    else:
        prefix = 'wcs_%s_view' % formdef.data_sql_prefix
    return '%s_%s_%s' % (prefix, formdef.id, get_name_as_sql_identifier(formdef.url_name)[:40])
def guard_postgres(func):
    """Decorator rolling back the current transaction on any psycopg2 error.

    The error is re-raised after the rollback; the rollback only keeps the
    shared connection usable for subsequent queries.
    """
    import functools

    @functools.wraps(func)  # preserve name/docstring of the wrapped function
    def f(*args, **kwargs):
        try:
            return func(*args, **kwargs)
        except psycopg2.Error:
            get_connection().rollback()
            raise

    return f
@guard_postgres
def do_formdef_tables(formdef, conn=None, cur=None, rebuild_views=False, rebuild_global_views=True):
    """Create or migrate the SQL tables backing *formdef*.

    Ensures the formdata table and its _evolutions table exist, adds missing
    columns (static columns plus one column per form field), drops obsolete
    columns, keeps the sync trigger and views up to date, and returns a list
    of follow-up action names ('rebuild_security', 'do_tracking_code_table')
    the caller is expected to perform.
    """
    if formdef.id is None:
        return []
    if getattr(formdef, 'fields', None) is Ellipsis:
        # don't touch tables for lightweight objects
        return []
    if getattr(formdef, 'snapshot_object', None):
        # don't touch tables for snapshot objects
        return []
    own_conn = False
    if not conn:
        own_conn = True
        conn, cur = get_connection_and_cursor()
    table_name = get_formdef_table_name(formdef)
    new_table = False
    cur.execute(
        '''SELECT COUNT(*) FROM information_schema.tables
             WHERE table_schema = 'public'
               AND table_name = %s''',
        (table_name,),
    )
    if cur.fetchone()[0] == 0:
        cur.execute(
            '''CREATE TABLE %s (id serial PRIMARY KEY,
                                user_id varchar,
                                receipt_time timestamp,
                                anonymised timestamptz,
                                status varchar,
                                page_no varchar,
                                workflow_data bytea,
                                id_display varchar
                                )'''
            % table_name
        )
        cur.execute(
            '''CREATE TABLE %s_evolutions (id serial PRIMARY KEY,
                                           who varchar,
                                           status varchar,
                                           time timestamp,
                                           last_jump_datetime timestamp,
                                           comment text,
                                           parts bytea,
                                           formdata_id integer REFERENCES %s (id) ON DELETE CASCADE)'''
            % (table_name, table_name)
        )
        new_table = True
    # get the columns that currently exist in the table
    cur.execute(
        '''SELECT column_name FROM information_schema.columns
             WHERE table_schema = 'public'
               AND table_name = %s''',
        (table_name,),
    )
    existing_fields = {x[0] for x in cur.fetchall()}
    # static columns every formdata table must have
    needed_fields = {
        'id',
        'user_id',
        'receipt_time',
        'status',
        'workflow_data',
        'id_display',
        'fts',
        'page_no',
        'anonymised',
        'workflow_roles',
        # workflow_merged_roles_dict combines workflow_roles from formdef and
        # formdata and is used to filter on function assignment.
        'workflow_merged_roles_dict',
        # workflow_roles_array is created from workflow_roles to be used in
        # get_ids_with_indexed_value
        'workflow_roles_array',
        'concerned_roles_array',
        'tracking_code',
        'actions_roles_array',
        'backoffice_submission',
        'submission_context',
        'submission_agent_id',
        'submission_channel',
        'criticality_level',
        'last_update_time',
        'digests',
        'user_label',
        'prefilling_data',
    }
    # migrations
    if 'fts' not in existing_fields:
        # full text search
        cur.execute('''ALTER TABLE %s ADD COLUMN fts tsvector''' % table_name)
        cur.execute('''CREATE INDEX %s_fts ON %s USING gin(fts)''' % (table_name, table_name))
    if 'workflow_roles' not in existing_fields:
        cur.execute('''ALTER TABLE %s ADD COLUMN workflow_roles bytea''' % table_name)
        cur.execute('''ALTER TABLE %s ADD COLUMN workflow_roles_array text[]''' % table_name)
    if 'workflow_merged_roles_dict' not in existing_fields:
        cur.execute('''ALTER TABLE %s ADD COLUMN workflow_merged_roles_dict jsonb''' % table_name)
    if 'concerned_roles_array' not in existing_fields:
        cur.execute('''ALTER TABLE %s ADD COLUMN concerned_roles_array text[]''' % table_name)
    if 'actions_roles_array' not in existing_fields:
        cur.execute('''ALTER TABLE %s ADD COLUMN actions_roles_array text[]''' % table_name)
    if 'page_no' not in existing_fields:
        cur.execute('''ALTER TABLE %s ADD COLUMN page_no varchar''' % table_name)
    if 'anonymised' not in existing_fields:
        cur.execute('''ALTER TABLE %s ADD COLUMN anonymised timestamptz''' % table_name)
    if 'tracking_code' not in existing_fields:
        cur.execute('''ALTER TABLE %s ADD COLUMN tracking_code varchar''' % table_name)
    if 'backoffice_submission' not in existing_fields:
        cur.execute('''ALTER TABLE %s ADD COLUMN backoffice_submission boolean''' % table_name)
    if 'submission_context' not in existing_fields:
        cur.execute('''ALTER TABLE %s ADD COLUMN submission_context bytea''' % table_name)
    if 'submission_agent_id' not in existing_fields:
        cur.execute('''ALTER TABLE %s ADD COLUMN submission_agent_id varchar''' % table_name)
    if 'submission_channel' not in existing_fields:
        cur.execute('''ALTER TABLE %s ADD COLUMN submission_channel varchar''' % table_name)
    if 'criticality_level' not in existing_fields:
        cur.execute(
            '''ALTER TABLE %s ADD COLUMN criticality_level integer NOT NULL DEFAULT(0)''' % table_name
        )
    if 'last_update_time' not in existing_fields:
        cur.execute('''ALTER TABLE %s ADD COLUMN last_update_time timestamp''' % table_name)
    if 'digests' not in existing_fields:
        cur.execute('''ALTER TABLE %s ADD COLUMN digests jsonb''' % table_name)
    if 'user_label' not in existing_fields:
        cur.execute('''ALTER TABLE %s ADD COLUMN user_label varchar''' % table_name)
    if 'prefilling_data' not in existing_fields:
        cur.execute('''ALTER TABLE %s ADD COLUMN prefilling_data bytea''' % table_name)
    # add new fields, one column per form field (plus optional _display and
    # _structured companion columns)
    for field in formdef.get_all_fields():
        assert field.id is not None
        sql_type = SQL_TYPE_MAPPING.get(field.key, 'varchar')
        if sql_type is None:
            # presentation-only field, no storage column
            continue
        needed_fields.add(get_field_id(field))
        if get_field_id(field) not in existing_fields:
            cur.execute('''ALTER TABLE %s ADD COLUMN %s %s''' % (table_name, get_field_id(field), sql_type))
        if field.store_display_value:
            needed_fields.add('%s_display' % get_field_id(field))
            if '%s_display' % get_field_id(field) not in existing_fields:
                cur.execute(
                    '''ALTER TABLE %s ADD COLUMN %s varchar'''
                    % (table_name, '%s_display' % get_field_id(field))
                )
        if field.store_structured_value:
            needed_fields.add('%s_structured' % get_field_id(field))
            if '%s_structured' % get_field_id(field) not in existing_fields:
                cur.execute(
                    '''ALTER TABLE %s ADD COLUMN %s bytea'''
                    % (table_name, '%s_structured' % get_field_id(field))
                )
    # one POINT column per geolocation key
    for field in (formdef.geolocations or {}).keys():
        column_name = 'geoloc_%s' % field
        needed_fields.add(column_name)
        if column_name not in existing_fields:
            cur.execute('ALTER TABLE %s ADD COLUMN %s %s' '' % (table_name, column_name, 'POINT'))
    # delete obsolete fields
    for field in existing_fields - needed_fields:
        cur.execute('''ALTER TABLE %s DROP COLUMN %s CASCADE''' % (table_name, field))
    if formdef.data_sql_prefix == 'formdata':
        # only formdata (not carddata) is mirrored into wcs_all_forms
        recreate_trigger(formdef, cur, conn)
    # migrations on _evolutions table
    cur.execute(
        '''SELECT column_name FROM information_schema.columns
             WHERE table_schema = 'public'
               AND table_name = '%s_evolutions'
        '''
        % table_name
    )
    evo_existing_fields = {x[0] for x in cur.fetchall()}
    if 'last_jump_datetime' not in evo_existing_fields:
        cur.execute('''ALTER TABLE %s_evolutions ADD COLUMN last_jump_datetime timestamp''' % table_name)
    if rebuild_views or len(existing_fields - needed_fields):
        # views may have been dropped when dropping columns, so we recreate
        # them even if not asked to.
        redo_views(conn, cur, formdef, rebuild_global_views=rebuild_global_views)
    if new_table:
        do_formdef_indexes(formdef, created=True, conn=conn, cur=cur)
    if own_conn:
        conn.commit()
        cur.close()
    actions = []
    if 'concerned_roles_array' not in existing_fields:
        actions.append('rebuild_security')
    elif 'actions_roles_array' not in existing_fields:
        actions.append('rebuild_security')
    if 'tracking_code' not in existing_fields:
        # if tracking code has just been added to the table we need to make
        # sure the tracking code table does exist.
        actions.append('do_tracking_code_table')
    return actions
def recreate_trigger(formdef, cur, conn):
    """(Re)create the trigger keeping wcs_all_forms in sync with the formdata table.

    The trigger function is only replaced when its stored source (pg_proc)
    differs from the expected code, and the trigger itself is only created
    when missing.
    """
    # recreate the trigger function, just so it's uptodate
    table_name = get_formdef_table_name(formdef)
    category_value = formdef.category_id
    geoloc_base_x_query = "NULL"
    geoloc_base_y_query = "NULL"
    if formdef.geolocations and 'base' in formdef.geolocations:
        # default geolocation is in the 'base' key; we have to unstructure the
        # field as the POINT type of postgresql cannot be used directly as it
        # doesn't have an equality operator.
        geoloc_base_x_query = "NEW.geoloc_base[0]"
        geoloc_base_y_query = "NEW.geoloc_base[1]"
    if formdef.category_id is None:
        category_value = "NULL"
    criticality_levels = len(formdef.workflow.criticality_levels or [0])
    endpoint_status = formdef.workflow.get_endpoint_status()
    endpoint_status_filter = ", ".join(["'wf-%s'" % x.id for x in endpoint_status])
    if endpoint_status_filter == "":
        # not the prettiest in town, but will do fine for now.
        endpoint_status_filter = "'xxxx'"
    # quote the formdef name so it can be embedded in the PL/pgSQL source
    formed_name_quotedstring = psycopg2.extensions.QuotedString(formdef.name)
    formed_name_quotedstring.encoding = 'utf8'
    formdef_name = formed_name_quotedstring.getquoted().decode()
    trigger_code = '''
BEGIN
    IF TG_OP = 'DELETE' THEN
        DELETE FROM wcs_all_forms WHERE formdef_id = {formdef_id} AND id = OLD.id;
        RETURN OLD;
    ELSEIF TG_OP = 'INSERT' THEN
        INSERT INTO wcs_all_forms VALUES (
            {category_id},
            {formdef_id},
            NEW.id,
            NEW.user_id,
            NEW.receipt_time,
            NEW.status,
            NEW.id_display,
            NEW.submission_agent_id,
            NEW.submission_channel,
            NEW.backoffice_submission,
            NEW.last_update_time,
            NEW.digests,
            NEW.user_label,
            NEW.concerned_roles_array,
            NEW.actions_roles_array,
            NEW.fts,
            NEW.status IN ({endpoint_status}),
            {formdef_name},
            (SELECT name FROM users WHERE users.id = CAST(NEW.user_id AS INTEGER)),
            NEW.criticality_level - {criticality_levels},
            {geoloc_base_x},
            {geoloc_base_y},
            NEW.anonymised);
        RETURN NEW;
    ELSE
        UPDATE wcs_all_forms SET
                user_id = NEW.user_id,
                receipt_time = NEW.receipt_time,
                status = NEW.status,
                id_display = NEW.id_display,
                submission_agent_id = NEW.submission_agent_id,
                submission_channel = NEW.submission_channel,
                backoffice_submission = NEW.backoffice_submission,
                last_update_time = NEW.last_update_time,
                digests = NEW.digests,
                user_label = NEW.user_label,
                concerned_roles_array = NEW.concerned_roles_array,
                actions_roles_array = NEW.actions_roles_array,
                fts = NEW.fts,
                is_at_endpoint = NEW.status IN ({endpoint_status}),
                formdef_name = {formdef_name},
                user_name = (SELECT name FROM users WHERE users.id = CAST(NEW.user_id AS INTEGER)),
                criticality_level = NEW.criticality_level - {criticality_levels},
                geoloc_base_x = {geoloc_base_x},
                geoloc_base_y = {geoloc_base_y},
                anonymised = NEW.anonymised
            WHERE formdef_id = {formdef_id} AND id = OLD.id;
        RETURN NEW;
    END IF;
END;
'''.format(
        category_id=category_value,  # always valued ? need to handle null otherwise.
        formdef_id=formdef.id,
        geoloc_base_x=geoloc_base_x_query,
        geoloc_base_y=geoloc_base_y_query,
        formdef_name=formdef_name,
        criticality_levels=criticality_levels,
        endpoint_status=endpoint_status_filter,
    )
    # compare with the currently installed function source to avoid a
    # needless CREATE OR REPLACE
    cur.execute(
        '''SELECT prosrc FROM pg_proc
            WHERE proname = '%s'
        '''
        % get_formdef_trigger_function_name(formdef)
    )
    function_row = cur.fetchone()
    if function_row is None or function_row[0] != trigger_code:
        cur.execute(
            '''
            CREATE OR REPLACE FUNCTION {trg_fn_name}()
            RETURNS trigger
            LANGUAGE plpgsql
            AS $${code}$$;
            '''.format(
                trg_fn_name=get_formdef_trigger_function_name(formdef),
                code=trigger_code,
            )
        )
    trg_name = get_formdef_trigger_name(formdef)
    cur.execute(
        '''SELECT 1 FROM pg_trigger
            WHERE tgrelid = '%s'::regclass
              AND tgname = '%s'
        '''
        % (table_name, trg_name)
    )
    if len(cur.fetchall()) == 0:
        # compatibility note: to support postgresql<11 we use PROCEDURE and not FUNCTION
        cur.execute(
            '''CREATE TRIGGER {trg_name} AFTER INSERT OR UPDATE OR DELETE
                ON {table_name}
                FOR EACH ROW EXECUTE PROCEDURE {trg_fn_name}();
            '''.format(
                trg_fn_name=get_formdef_trigger_function_name(formdef),
                table_name=table_name,
                trg_name=trg_name,
            )
        )
def do_formdef_indexes(formdef, created, conn, cur, concurrently=False):
    """Create missing indexes on a formdef's data and evolutions tables.

    With created=True the tables are known to be new and the lookup of
    existing indexes is skipped; concurrently=True uses CREATE INDEX
    CONCURRENTLY so writes are not blocked while indexing.
    """
    table_name = get_formdef_table_name(formdef)
    evolutions_table_name = table_name + '_evolutions'
    existing_indexes = set()
    if not created:
        cur.execute(
            '''SELECT indexname
                 FROM pg_indexes
                WHERE schemaname = 'public'
                  AND tablename IN (%s, %s)''',
            (table_name, evolutions_table_name),
        )
        existing_indexes = {x[0] for x in cur.fetchall()}
    create_index = 'CREATE INDEX'
    if concurrently:
        create_index = 'CREATE INDEX CONCURRENTLY'
    if evolutions_table_name + '_fid' not in existing_indexes:
        cur.execute(
            '''%s %s_fid ON %s (formdata_id)''' % (create_index, evolutions_table_name, evolutions_table_name)
        )
    for attr in ('receipt_time', 'anonymised', 'user_id', 'status'):
        if table_name + '_' + attr + '_idx' not in existing_indexes:
            cur.execute(
                '%(create_index)s %(table_name)s_%(attr)s_idx ON %(table_name)s (%(attr)s)'
                % {'create_index': create_index, 'table_name': table_name, 'attr': attr}
            )
    # gin indexes for the role arrays; IF NOT EXISTS spares a lookup here
    for attr in ('concerned_roles_array', 'actions_roles_array'):
        idx_name = 'idx_' + attr + '_' + table_name
        cur.execute(
            '%(create_index)s IF NOT EXISTS %(idx_name)s ON %(table_name)s USING gin (%(attr)s)'
            % {'create_index': create_index, 'table_name': table_name, 'idx_name': idx_name, 'attr': attr}
        )
@guard_postgres
def do_user_table():
    """Create or migrate the 'users' table (static columns plus one column
    per user formdef field)."""
    conn, cur = get_connection_and_cursor()
    table_name = 'users'
    cur.execute(
        '''SELECT COUNT(*) FROM information_schema.tables
             WHERE table_schema = 'public'
               AND table_name = %s''',
        (table_name,),
    )
    if cur.fetchone()[0] == 0:
        cur.execute(
            '''CREATE TABLE %s (id serial PRIMARY KEY,
                                name varchar,
                                ascii_name varchar,
                                email varchar,
                                roles text[],
                                is_active bool,
                                is_admin bool,
                                anonymous bool,
                                verified_fields text[],
                                name_identifiers text[],
                                lasso_dump text,
                                last_seen timestamp,
                                deleted_timestamp timestamp
                                )'''
            % table_name
        )
    cur.execute(
        '''SELECT column_name FROM information_schema.columns
             WHERE table_schema = 'public'
               AND table_name = %s''',
        (table_name,),
    )
    existing_fields = {x[0] for x in cur.fetchall()}
    # static columns the users table must have
    needed_fields = {
        'id',
        'name',
        'email',
        'roles',
        'is_admin',
        'anonymous',
        'name_identifiers',
        'verified_fields',
        'lasso_dump',
        'last_seen',
        'fts',
        'ascii_name',
        'deleted_timestamp',
        'is_active',
    }
    # local import — presumably to avoid an import loop at module load; confirm
    from wcs.admin.settings import UserFieldsFormDef

    formdef = UserFieldsFormDef()
    for field in formdef.get_all_fields():
        sql_type = SQL_TYPE_MAPPING.get(field.key, 'varchar')
        if sql_type is None:
            # presentation-only field, no storage column
            continue
        needed_fields.add(get_field_id(field))
        if get_field_id(field) not in existing_fields:
            cur.execute('''ALTER TABLE %s ADD COLUMN %s %s''' % (table_name, get_field_id(field), sql_type))
        if field.store_display_value:
            needed_fields.add('%s_display' % get_field_id(field))
            if '%s_display' % get_field_id(field) not in existing_fields:
                cur.execute(
                    '''ALTER TABLE %s ADD COLUMN %s varchar'''
                    % (table_name, '%s_display' % get_field_id(field))
                )
        if field.store_structured_value:
            needed_fields.add('%s_structured' % get_field_id(field))
            if '%s_structured' % get_field_id(field) not in existing_fields:
                cur.execute(
                    '''ALTER TABLE %s ADD COLUMN %s bytea'''
                    % (table_name, '%s_structured' % get_field_id(field))
                )
    # migrations
    if 'fts' not in existing_fields:
        # full text search
        cur.execute('''ALTER TABLE %s ADD COLUMN fts tsvector''' % table_name)
        cur.execute('''CREATE INDEX %s_fts ON %s USING gin(fts)''' % (table_name, table_name))
    if 'verified_fields' not in existing_fields:
        cur.execute('ALTER TABLE %s ADD COLUMN verified_fields text[]' % table_name)
    if 'ascii_name' not in existing_fields:
        cur.execute('ALTER TABLE %s ADD COLUMN ascii_name varchar' % table_name)
    if 'deleted_timestamp' not in existing_fields:
        cur.execute('ALTER TABLE %s ADD COLUMN deleted_timestamp timestamp' % table_name)
    if 'is_active' not in existing_fields:
        cur.execute('ALTER TABLE %s ADD COLUMN is_active bool DEFAULT TRUE' % table_name)
        # users soft-deleted before this column existed become inactive
        cur.execute('UPDATE %s SET is_active = FALSE WHERE deleted_timestamp IS NOT NULL' % table_name)
    # delete obsolete fields
    for field in existing_fields - needed_fields:
        cur.execute('''ALTER TABLE %s DROP COLUMN %s''' % (table_name, field))
    conn.commit()
    try:
        cur.execute('CREATE INDEX users_name_idx ON users (name)')
        conn.commit()
    except psycopg2.ProgrammingError:
        # presumably the index already exists; roll back the failed statement
        conn.rollback()
    cur.close()
def do_role_table(concurrently=False):
    """Create or migrate the 'roles' table.

    NOTE(review): the 'concurrently' parameter is currently unused here.
    """
    conn, cur = get_connection_and_cursor()
    table_name = 'roles'
    cur.execute(
        '''SELECT COUNT(*) FROM information_schema.tables
             WHERE table_schema = 'public'
               AND table_name = %s''',
        (table_name,),
    )
    if cur.fetchone()[0] == 0:
        cur.execute(
            '''CREATE TABLE %s (id VARCHAR PRIMARY KEY,
                                name VARCHAR,
                                uuid UUID,
                                slug VARCHAR UNIQUE,
                                internal BOOLEAN,
                                details VARCHAR,
                                emails VARCHAR[],
                                emails_to_members BOOLEAN,
                                allows_backoffice_access BOOLEAN)'''
            % table_name
        )
    # migration: the uuid column used to be of type UUID; relax it to VARCHAR
    # (idempotent, executed unconditionally on every call)
    cur.execute('ALTER TABLE roles ALTER COLUMN uuid TYPE VARCHAR')
    cur.execute(
        '''SELECT column_name FROM information_schema.columns
             WHERE table_schema = 'public'
               AND table_name = %s''',
        (table_name,),
    )
    existing_fields = {x[0] for x in cur.fetchall()}
    # columns are driven by the SQL Role class (defined elsewhere in this module)
    needed_fields = {x[0] for x in Role._table_static_fields}
    # delete obsolete fields
    for field in existing_fields - needed_fields:
        cur.execute('''ALTER TABLE %s DROP COLUMN %s''' % (table_name, field))
    conn.commit()
    cur.close()
def migrate_legacy_roles():
    """Copy roles from the legacy pickle storage into the SQL 'roles' table."""
    legacy_class = wcs.roles.Role
    for legacy_id in legacy_class.keys():
        legacy_role = legacy_class.get(legacy_id)
        # rebind the instance to the SQL-backed class so store() writes
        # the role into postgresql instead of the pickle storage
        legacy_role.__class__ = Role
        legacy_role.store()
def do_tracking_code_table():
    """Create or migrate the 'tracking_codes' SQL table.

    Creates the table when missing and drops columns that are not part of
    the expected (id, formdef_id, formdata_id) schema.
    """
    conn, cur = get_connection_and_cursor()
    table_name = 'tracking_codes'
    cur.execute(
        '''SELECT COUNT(*) FROM information_schema.tables
                 WHERE table_schema = 'public'
                   AND table_name = %s''',
        (table_name,),
    )
    if cur.fetchone()[0] == 0:
        cur.execute(
            '''CREATE TABLE %s (id varchar PRIMARY KEY,
                                formdef_id varchar,
                                formdata_id varchar)'''
            % table_name
        )
    cur.execute(
        '''SELECT column_name FROM information_schema.columns
                 WHERE table_schema = 'public'
                   AND table_name = %s''',
        (table_name,),
    )
    existing_fields = {x[0] for x in cur.fetchall()}
    needed_fields = {'id', 'formdef_id', 'formdata_id'}
    # delete obsolete fields
    for field in existing_fields - needed_fields:
        cur.execute('''ALTER TABLE %s DROP COLUMN %s''' % (table_name, field))
    conn.commit()
    cur.close()
def do_session_table():
    """Create or migrate the 'sessions' SQL table.

    Creates the table when missing, adds the last_update_time column (and
    its index) introduced by a later version, and drops obsolete columns.
    """
    conn, cur = get_connection_and_cursor()
    table_name = 'sessions'
    cur.execute(
        '''SELECT COUNT(*) FROM information_schema.tables
                 WHERE table_schema = 'public'
                   AND table_name = %s''',
        (table_name,),
    )
    if cur.fetchone()[0] == 0:
        cur.execute(
            '''CREATE TABLE %s (id varchar PRIMARY KEY,
                                session_data bytea,
                                name_identifier varchar,
                                visiting_objects_keys varchar[]
                                )'''
            % table_name
        )
    cur.execute(
        '''SELECT column_name FROM information_schema.columns
                 WHERE table_schema = 'public'
                   AND table_name = %s''',
        (table_name,),
    )
    existing_fields = {x[0] for x in cur.fetchall()}
    needed_fields = {'id', 'session_data', 'name_identifier', 'visiting_objects_keys', 'last_update_time'}
    # migrations
    if 'last_update_time' not in existing_fields:
        # index on last_update_time is used by session expiration cleanup
        cur.execute('''ALTER TABLE %s ADD COLUMN last_update_time timestamp DEFAULT NOW()''' % table_name)
        cur.execute('''CREATE INDEX %s_ts ON %s (last_update_time)''' % (table_name, table_name))
    # delete obsolete fields
    for field in existing_fields - needed_fields:
        cur.execute('''ALTER TABLE %s DROP COLUMN %s''' % (table_name, field))
    conn.commit()
    cur.close()
def do_transient_data_table():
    """Create or migrate the transient data SQL table.

    Rows are tied to sessions via an ON DELETE CASCADE foreign key, so they
    disappear along with their session.  Columns not declared in
    TransientData._table_static_fields are dropped.
    """
    conn, cur = get_connection_and_cursor()
    table_name = TransientData._table_name
    cur.execute(
        '''SELECT COUNT(*) FROM information_schema.tables
                 WHERE table_schema = 'public'
                   AND table_name = %s''',
        (table_name,),
    )
    if cur.fetchone()[0] == 0:
        cur.execute(
            '''CREATE TABLE %s (id varchar PRIMARY KEY,
                                session_id VARCHAR REFERENCES sessions(id) ON DELETE CASCADE,
                                data bytea,
                                last_update_time timestamptz
                                )'''
            % table_name
        )
    cur.execute(
        '''SELECT column_name FROM information_schema.columns
                 WHERE table_schema = 'public'
                   AND table_name = %s''',
        (table_name,),
    )
    existing_fields = {x[0] for x in cur.fetchall()}
    needed_fields = {x[0] for x in TransientData._table_static_fields}
    # delete obsolete fields
    for field in existing_fields - needed_fields:
        cur.execute('''ALTER TABLE %s DROP COLUMN %s''' % (table_name, field))
    conn.commit()
    cur.close()
def do_custom_views_table():
    """Create or migrate the 'custom_views' SQL table.

    Creates the table when missing, adds the is_default column introduced
    by a later version, and drops columns no longer declared in
    CustomView._table_static_fields.
    """
    conn, cur = get_connection_and_cursor()
    table_name = 'custom_views'
    cur.execute(
        '''SELECT COUNT(*) FROM information_schema.tables
                 WHERE table_schema = 'public'
                   AND table_name = %s''',
        (table_name,),
    )
    if cur.fetchone()[0] == 0:
        cur.execute(
            '''CREATE TABLE %s (id varchar PRIMARY KEY,
                                title varchar,
                                slug varchar,
                                user_id varchar,
                                visibility varchar,
                                formdef_type varchar,
                                formdef_id varchar,
                                is_default boolean,
                                order_by varchar,
                                columns jsonb,
                                filters jsonb
                                )'''
            % table_name
        )
    cur.execute(
        '''SELECT column_name FROM information_schema.columns
                 WHERE table_schema = 'public'
                   AND table_name = %s''',
        (table_name,),
    )
    existing_fields = {x[0] for x in cur.fetchall()}
    needed_fields = {x[0] for x in CustomView._table_static_fields}
    # migrations
    if 'is_default' not in existing_fields:
        cur.execute('''ALTER TABLE %s ADD COLUMN is_default boolean DEFAULT FALSE''' % table_name)
    # delete obsolete fields
    for field in existing_fields - needed_fields:
        cur.execute('''ALTER TABLE %s DROP COLUMN %s''' % (table_name, field))
    conn.commit()
    cur.close()
def do_snapshots_table():
    """Create or migrate the 'snapshots' SQL table.

    Creates the table when missing, adds the patch column introduced by a
    later version, drops obsolete columns, and ensures the primary key and
    the (object_type, object_id, timestamp DESC) lookup index exist.
    """
    conn, cur = get_connection_and_cursor()
    table_name = 'snapshots'
    cur.execute(
        '''SELECT COUNT(*) FROM information_schema.tables
                 WHERE table_schema = 'public'
                   AND table_name = %s''',
        (table_name,),
    )
    if cur.fetchone()[0] == 0:
        cur.execute(
            '''CREATE TABLE %s (id SERIAL,
                                object_type VARCHAR,
                                object_id VARCHAR,
                                timestamp TIMESTAMP WITH TIME ZONE,
                                user_id VARCHAR,
                                comment TEXT,
                                serialization TEXT,
                                patch TEXT,
                                label VARCHAR
                                )'''
            % table_name
        )
    cur.execute(
        '''SELECT column_name FROM information_schema.columns
                 WHERE table_schema = 'public'
                   AND table_name = %s''',
        (table_name,),
    )
    existing_fields = {x[0] for x in cur.fetchall()}
    # migrations
    if 'patch' not in existing_fields:
        cur.execute('''ALTER TABLE %s ADD COLUMN patch TEXT''' % table_name)
    needed_fields = {x[0] for x in Snapshot._table_static_fields}
    # delete obsolete fields
    for field in existing_fields - needed_fields:
        cur.execute('''ALTER TABLE %s DROP COLUMN %s''' % (table_name, field))
    # add indexes
    cur.execute(
        '''SELECT indexname
             FROM pg_indexes
            WHERE schemaname = 'public'
              AND tablename = %s''',
        (table_name,),
    )
    existing_indexes = {x[0] for x in cur.fetchall()}
    if ('%s_pkey' % table_name) not in existing_indexes:
        cur.execute('''ALTER TABLE %s ADD PRIMARY KEY (id)''' % table_name)
    if ('%s_object_by_date' % table_name) not in existing_indexes:
        # used to fetch the latest snapshots of a given object
        cur.execute(
            '''CREATE INDEX %s_object_by_date ON %s(object_type, object_id, timestamp DESC)'''
            % (table_name, table_name)
        )
    conn.commit()
    cur.close()
def do_loggederrors_table(concurrently=False):
    """Create or migrate the 'loggederrors' SQL table.

    Creates the table when missing, adds the kind column introduced by a
    later version, drops columns no longer declared in
    LoggedError._table_static_fields, and ensures lookup indexes exist on
    formdef_id and workflow_id.

    When `concurrently` is True the indexes are created with CREATE INDEX
    CONCURRENTLY, to avoid locking the table on live sites.
    """
    conn, cur = get_connection_and_cursor()
    table_name = 'loggederrors'
    cur.execute(
        '''SELECT COUNT(*) FROM information_schema.tables
                 WHERE table_schema = 'public'
                   AND table_name = %s''',
        (table_name,),
    )
    if cur.fetchone()[0] == 0:
        cur.execute(
            '''CREATE TABLE %s (id SERIAL PRIMARY KEY,
                                kind VARCHAR,
                                tech_id VARCHAR UNIQUE,
                                summary VARCHAR,
                                formdef_class VARCHAR,
                                formdata_id VARCHAR,
                                formdef_id VARCHAR,
                                workflow_id VARCHAR,
                                status_id VARCHAR,
                                status_item_id VARCHAR,
                                expression VARCHAR,
                                expression_type VARCHAR,
                                traceback TEXT,
                                exception_class VARCHAR,
                                exception_message VARCHAR,
                                occurences_count INTEGER,
                                first_occurence_timestamp TIMESTAMP WITH TIME ZONE,
                                latest_occurence_timestamp TIMESTAMP WITH TIME ZONE
                                )'''
            % table_name
        )
    cur.execute(
        '''SELECT column_name FROM information_schema.columns
                 WHERE table_schema = 'public'
                   AND table_name = %s''',
        (table_name,),
    )
    existing_fields = {x[0] for x in cur.fetchall()}
    needed_fields = {x[0] for x in LoggedError._table_static_fields}
    # migrations
    if 'kind' not in existing_fields:
        cur.execute('''ALTER TABLE %s ADD COLUMN kind VARCHAR''' % table_name)
    # delete obsolete fields
    for field in existing_fields - needed_fields:
        cur.execute('''ALTER TABLE %s DROP COLUMN %s''' % (table_name, field))
    create_index = 'CREATE INDEX'
    if concurrently:
        create_index = 'CREATE INDEX CONCURRENTLY'
    # build indexes (the redundant `existing_indexes = set()` initialisation
    # that used to precede this query has been removed; the set is always
    # rebuilt from pg_indexes)
    cur.execute(
        '''SELECT indexname
             FROM pg_indexes
            WHERE schemaname = 'public'
              AND tablename = %s''',
        (table_name,),
    )
    existing_indexes = {x[0] for x in cur.fetchall()}
    for attr in ('formdef_id', 'workflow_id'):
        if table_name + '_' + attr + '_idx' not in existing_indexes:
            cur.execute(
                '%(create_index)s %(table_name)s_%(attr)s_idx ON %(table_name)s (%(attr)s)'
                % {'create_index': create_index, 'table_name': table_name, 'attr': attr}
            )
    conn.commit()
    cur.close()
def do_tokens_table(concurrently=False):
    """Create or migrate the tokens SQL table.

    Creates the table when missing and drops columns no longer declared in
    Token._table_static_fields.  `concurrently` is accepted for signature
    consistency with the other do_*_table functions; it is unused here.
    """
    conn, cur = get_connection_and_cursor()
    table_name = Token._table_name
    cur.execute(
        '''SELECT COUNT(*) FROM information_schema.tables
                 WHERE table_schema = 'public'
                   AND table_name = %s''',
        (table_name,),
    )
    if cur.fetchone()[0] == 0:
        cur.execute(
            '''CREATE TABLE %s (id VARCHAR PRIMARY KEY,
                                type VARCHAR,
                                expiration TIMESTAMPTZ,
                                context JSONB
                                )'''
            % table_name
        )
    cur.execute(
        '''SELECT column_name FROM information_schema.columns
                 WHERE table_schema = 'public'
                   AND table_name = %s''',
        (table_name,),
    )
    existing_fields = {x[0] for x in cur.fetchall()}
    needed_fields = {x[0] for x in Token._table_static_fields}
    # delete obsolete fields
    for field in existing_fields - needed_fields:
        cur.execute('''ALTER TABLE %s DROP COLUMN %s''' % (table_name, field))
    conn.commit()
    cur.close()
def migrate_legacy_tokens():
    """Copy tokens from the legacy pickle storage into the SQL tokens table.

    Tokens that disappeared in the meantime are skipped; tokens that cannot
    be unpickled (python2-era payloads) are removed from the legacy storage.
    """
    legacy_class = wcs.qommon.tokens.Token
    for legacy_id in legacy_class.keys():
        try:
            legacy_token = legacy_class.get(legacy_id)
        except KeyError:
            continue
        except AttributeError:
            # old python2 tokens:
            # AttributeError: module 'builtins' has no attribute 'unicode'
            legacy_class.remove_object(legacy_id)
            continue
        # rebind the instance to the SQL-backed class so store() writes
        # the token into postgresql instead of the pickle storage
        legacy_token.__class__ = Token
        legacy_token.store()
@guard_postgres
def do_meta_table(conn=None, cur=None, insert_current_sql_level=True):
    """Create or migrate the wcs_meta key/value table.

    On first creation a 'sql_level' row is inserted, holding the current
    SQL_LEVEL (or 0 when insert_current_sql_level is False, so migrations
    will run from scratch).  When conn/cur are not provided, a connection is
    opened, committed and closed locally.
    """
    own_conn = False
    if not conn:
        own_conn = True
        conn, cur = get_connection_and_cursor()
    cur.execute(
        '''SELECT COUNT(*) FROM information_schema.tables
                 WHERE table_schema = 'public'
                   AND table_name = %s''',
        ('wcs_meta',),
    )
    table_exists = cur.fetchone()[0] > 0
    if not table_exists:
        cur.execute(
            '''CREATE TABLE wcs_meta (id serial PRIMARY KEY,
                                      key varchar,
                                      value varchar,
                                      created_at timestamptz DEFAULT NOW(),
                                      updated_at timestamptz DEFAULT NOW())'''
        )
        if insert_current_sql_level:
            sql_level = SQL_LEVEL[0]
        else:
            sql_level = 0
        cur.execute(
            '''INSERT INTO wcs_meta (id, key, value)
                    VALUES (DEFAULT, %s, %s)''',
            ('sql_level', str(sql_level)),
        )
    else:
        # migrations: timestamp columns were added in a later version
        cur.execute(
            '''SELECT column_name FROM information_schema.columns
                     WHERE table_schema = 'public'
                       AND table_name = %s''',
            ('wcs_meta',),
        )
        existing_fields = {x[0] for x in cur.fetchall()}
        if 'created_at' not in existing_fields:
            cur.execute('''ALTER TABLE wcs_meta ADD COLUMN created_at timestamptz DEFAULT NOW()''')
        if 'updated_at' not in existing_fields:
            cur.execute('''ALTER TABLE wcs_meta ADD COLUMN updated_at timestamptz DEFAULT NOW()''')
    if own_conn:
        conn.commit()
        cur.close()
@guard_postgres
def redo_views(conn, cur, formdef, rebuild_global_views=False):
    """Recreate the SQL view(s) of formdef by dropping then rebuilding them.

    Does nothing for a formdef that has not been assigned an id yet.
    """
    if formdef.id is None:
        return
    drop_views(formdef, conn, cur)
    do_views(formdef, conn, cur, rebuild_global_views=rebuild_global_views)
@guard_postgres
def drop_views(formdef, conn, cur):
    """Drop the SQL view(s) of formdef, or every form view when formdef is None.

    The per-category global views are removed first, as they are built on
    top of the per-formdef views.
    """
    # remove the global views
    drop_global_views(conn, cur)
    if formdef:
        # target the view of this single formdef; carddefs and other data
        # classes use a distinct per-prefix view naming scheme
        if formdef.data_sql_prefix != 'formdata':
            like_pattern = 'wcs\\_%s\\_view\\_%s\\_%%' % (formdef.data_sql_prefix, formdef.id)
        else:
            like_pattern = 'wcs\\_view\\_%s\\_%%' % formdef.id
    else:
        # no formdef specified: target all form views
        like_pattern = 'wcs\\_view\\_%'
    cur.execute(
        '''SELECT table_name FROM information_schema.views
                 WHERE table_schema = 'public'
                   AND table_name LIKE %s''',
        (like_pattern,),
    )
    # collect names first, then drop, so we don't mix fetches and DDL
    doomed_views = [row[0] for row in cur.fetchall()]
    for doomed_view in doomed_views:
        cur.execute('''DROP VIEW IF EXISTS %s''' % doomed_view)
def get_view_fields(formdef):
    """Return the (sql expression, column name) pairs common to all formdef views.

    The first two entries expose the formdef's category id and id as int
    literals (0 when unset); the rest map table columns to themselves.
    """
    shared_columns = (
        'id',
        'user_id',
        'receipt_time',
        'status',
        'id_display',
        'submission_agent_id',
        'submission_channel',
        'backoffice_submission',
        'last_update_time',
        'digests',
        'user_label',
    )
    view_fields = [
        ("int '%s'" % (formdef.category_id or 0), 'category_id'),
        ("int '%s'" % (formdef.id or 0), 'formdef_id'),
    ]
    view_fields.extend((column, column) for column in shared_columns)
    return view_fields
@guard_postgres
def do_views(formdef, conn, cur, rebuild_global_views=True):
    """Create the SQL view exposing formdef data with readable column names.

    Columns are derived from field varnames (or labels), suffixed with a
    counter on collisions, and completed with computed columns
    (status_history, is_at_endpoint, criticality_level, formdef_name,
    user_name, geolocation, fts...).  When rebuild_global_views is True the
    per-category aggregate views are recreated as well.
    """
    # create new view
    table_name = get_formdef_table_name(formdef)
    view_name = get_formdef_view_name(formdef)
    view_fields = get_view_fields(formdef)
    column_names = {}
    for field in formdef.get_all_fields():
        field_key = get_field_id(field)
        # presentation-only fields have no data column
        if field.type in ('page', 'title', 'subtitle', 'comment'):
            continue
        if field.varname:
            # the variable should be fine as is but we pass it through
            # get_name_as_sql_identifier nevertheless, to be extra sure it
            # doesn't contain invalid characters.
            field_name = 'f_%s' % get_name_as_sql_identifier(field.varname)[:50]
        else:
            field_name = '%s_%s' % (get_field_id(field), get_name_as_sql_identifier(field.label))
            field_name = field_name[:50]
        if field_name in column_names:
            # it may happen that the same varname is used on multiple fields
            # (for example in the case of conditional pages), in that situation
            # we suffix the field name with an index count
            while field_name in column_names:
                column_names[field_name] += 1
                field_name = '%s_%s' % (field_name, column_names[field_name])
        column_names[field_name] = 1
        view_fields.append((field_key, field_name))
        if field.store_display_value:
            field_key = '%s_display' % get_field_id(field)
            view_fields.append((field_key, field_name + '_display'))
    view_fields.append(
        (
            '''ARRAY(SELECT status FROM %s_evolutions '''
            ''' WHERE %s.id = %s_evolutions.formdata_id'''
            ''' ORDER BY %s_evolutions.time)''' % ((table_name,) * 4),
            'status_history',
        )
    )
    # add a is_at_endpoint column, dynamically created against the endpoint status.
    endpoint_status = formdef.workflow.get_endpoint_status()
    view_fields.append(
        (
            '''(SELECT status = ANY(ARRAY[[%s]]::text[]))'''
            % ', '.join(["'wf-%s'" % x.id for x in endpoint_status]),
            '''is_at_endpoint''',
        )
    )
    # [CRITICALITY_1] Add criticality_level, computed relative to levels in
    # the given workflow, so all higher criticalites are sorted first. This is
    # reverted when loading the formdata back, in [CRITICALITY_2]
    levels = len(formdef.workflow.criticality_levels or [0])
    view_fields.append(('''(criticality_level - %d)''' % levels, '''criticality_level'''))
    # mogrify() quotes the formdef name safely into the SQL expression
    view_fields.append((cur.mogrify('(SELECT text %s)', (formdef.name,)), 'formdef_name'))
    view_fields.append(
        (
            '''(SELECT name FROM users
                 WHERE users.id = CAST(user_id AS INTEGER))''',
            'user_name',
        )
    )
    view_fields.append(('concerned_roles_array', 'concerned_roles_array'))
    view_fields.append(('actions_roles_array', 'actions_roles_array'))
    view_fields.append(('fts', 'fts'))
    if formdef.geolocations and 'base' in formdef.geolocations:
        # default geolocation is in the 'base' key; we have to unstructure the
        # field as the POINT type of postgresql cannot be used directly as it
        # doesn't have an equality operator.
        view_fields.append(('geoloc_base[0]', 'geoloc_base_x'))
        view_fields.append(('geoloc_base[1]', 'geoloc_base_y'))
    else:
        view_fields.append(('NULL::real', 'geoloc_base_x'))
        view_fields.append(('NULL::real', 'geoloc_base_y'))
    view_fields.append(('anonymised', 'anonymised'))
    fields_list = ', '.join(['%s AS %s' % (force_text(x), force_text(y)) for (x, y) in view_fields])
    cur.execute('''CREATE VIEW %s AS SELECT %s FROM %s''' % (view_name, fields_list, table_name))
    if rebuild_global_views:
        do_global_views(conn, cur)  # recreate global views
def drop_global_views(conn, cur):
    """Drop every per-category aggregate view (wcs_category_*)."""
    cur.execute(
        '''SELECT table_name FROM information_schema.views
                 WHERE table_schema = 'public'
                   AND table_name LIKE %s''',
        ('wcs\\_category\\_%',),
    )
    # collect names first, then drop, so we don't mix fetches and DDL
    doomed_views = [row[0] for row in cur.fetchall()]
    for doomed_view in doomed_views:
        cur.execute('''DROP VIEW IF EXISTS %s''' % doomed_view)
def do_global_views(conn, cur):
    """Create the wcs_all_forms aggregation table, its indexes and the
    per-category wcs_category_* views built on top of it.

    Stale rows (referring to deleted formdefs) are purged via
    clean_global_views() before the category views are (re)created.
    """
    # recreate global views
    # XXX TODO: make me dynamic, please ?
    cur.execute(
        """CREATE TABLE IF NOT EXISTS wcs_all_forms (
            category_id integer,
            formdef_id integer NOT NULL,
            id integer NOT NULL,
            user_id character varying,
            receipt_time timestamp without time zone,
            status character varying,
            id_display character varying,
            submission_agent_id character varying,
            submission_channel character varying,
            backoffice_submission boolean,
            last_update_time timestamp without time zone,
            digests jsonb,
            user_label character varying,
            concerned_roles_array text[],
            actions_roles_array text[],
            fts tsvector,
            is_at_endpoint boolean,
            formdef_name text,
            user_name character varying,
            criticality_level integer,
            geoloc_base_x double precision,
            geoloc_base_y double precision,
            anonymised timestamp with time zone
            , PRIMARY KEY(formdef_id, id)
        )"""
    )
    # gin index for full-text search
    cur.execute(
        '''CREATE INDEX IF NOT EXISTS %s_fts ON %s USING gin(fts)''' % ("wcs_all_forms", "wcs_all_forms")
    )
    # btree indexes on common filter columns
    for attr in ('receipt_time', 'anonymised', 'user_id', 'status'):
        cur.execute(
            '''CREATE INDEX IF NOT EXISTS %s_%s ON %s (%s)''' % ("wcs_all_forms", attr, "wcs_all_forms", attr)
        )
    # gin indexes on the role arrays, used for permission filtering
    for attr in ('concerned_roles_array', 'actions_roles_array'):
        cur.execute(
            '''CREATE INDEX IF NOT EXISTS %s_%s ON %s USING gin (%s)'''
            % ("wcs_all_forms", attr, "wcs_all_forms", attr)
        )
    clean_global_views(conn, cur)
    for category in wcs.categories.Category.select():
        name = get_name_as_sql_identifier(category.url_name)[:40]
        cur.execute(
            '''CREATE OR REPLACE VIEW wcs_category_%s AS SELECT * from wcs_all_forms
                    WHERE category_id = %s'''
            % (name, category.id)
        )
def clean_global_views(conn, cur):
    """Purge wcs_all_forms rows referring to form definitions that no longer exist."""
    from .formdef import FormDef

    # Purge of any dead data
    known_formdef_ids = [int(formdef_id) for formdef_id in FormDef.keys()]
    if not known_formdef_ids:
        # no formdef at all: empty the aggregation table altogether
        cur.execute('TRUNCATE wcs_all_forms')
        return
    cur.execute('DELETE FROM wcs_all_forms WHERE NOT formdef_id = ANY(%s)', (known_formdef_ids,))
def init_global_table(conn=None, cur=None):
    """(Re)build the wcs_all_forms aggregation table from scratch.

    Drops the legacy wcs_all_forms view (or the existing table), recreates
    the table through do_global_views(), then bulk-copies every formdata row
    of every formdef into it.  When conn/cur are not provided, a connection
    is opened, committed and closed locally.
    """
    from .formdef import FormDef

    own_conn = False
    if not conn:
        own_conn = True
        conn, cur = get_connection_and_cursor()
    # relkind: 'v' = view (legacy layout), 'r' = plain table
    cur.execute("SELECT relkind FROM pg_class WHERE relname = 'wcs_all_forms';")
    rows = cur.fetchall()
    if len(rows) != 0:
        if rows[0][0] == 'v':
            # force wcs_all_forms table creation
            cur.execute('DROP VIEW IF EXISTS wcs_all_forms CASCADE;')
        else:
            assert rows[0][0] == 'r'
            cur.execute('DROP TABLE wcs_all_forms CASCADE;')
    do_global_views(conn, cur)
    # now copy all data into the table
    for formdef in FormDef.select():
        category_value = formdef.category_id
        if formdef.category_id is None:
            category_value = "NULL"
        geoloc_base_x_query = "NULL"
        geoloc_base_y_query = "NULL"
        if formdef.geolocations and 'base' in formdef.geolocations:
            # default geolocation is in the 'base' key; we have to unstructure the
            # field as the POINT type of postgresql cannot be used directly as it
            # doesn't have an equality operator.
            geoloc_base_x_query = "geoloc_base[0]"
            geoloc_base_y_query = "geoloc_base[1]"
        criticality_levels = len(formdef.workflow.criticality_levels or [0])
        endpoint_status = formdef.workflow.get_endpoint_status()
        endpoint_status_filter = ", ".join(["'wf-%s'" % x.id for x in endpoint_status])
        if endpoint_status_filter == "":
            # not the prettiest in town, but will do fine for now.
            endpoint_status_filter = "'xxxx'"
        # quote the formdef name so it can be embedded as a SQL literal
        formed_name_quotedstring = psycopg2.extensions.QuotedString(formdef.name)
        formed_name_quotedstring.encoding = 'utf8'
        formdef_name = formed_name_quotedstring.getquoted().decode()
        cur.execute(
            """
            INSERT INTO wcs_all_forms
            SELECT
                {category_id},
                {formdef_id},
                id,
                user_id,
                receipt_time,
                status,
                id_display,
                submission_agent_id,
                submission_channel,
                backoffice_submission,
                last_update_time,
                digests,
                user_label,
                concerned_roles_array,
                actions_roles_array,
                fts,
                status IN ({endpoint_status}),
                {formdef_name},
                (SELECT name FROM users WHERE users.id = CAST(user_id AS INTEGER)),
                criticality_level - {criticality_levels},
                {geoloc_base_x},
                {geoloc_base_y},
                anonymised
            FROM {table_name}
            ON CONFLICT DO NOTHING;
            """.format(
                table_name=get_formdef_table_name(formdef),
                category_id=category_value,  # always valued ? need to handle null otherwise.
                formdef_id=formdef.id,
                geoloc_base_x=geoloc_base_x_query,
                geoloc_base_y=geoloc_base_y_query,
                formdef_name=formdef_name,
                criticality_levels=criticality_levels,
                endpoint_status=endpoint_status_filter,
            )
        )
    if own_conn:
        conn.commit()
        cur.close()
class SqlMixin:
_table_name = None
_numerical_id = True
_table_select_skipped_fields = []
_has_id = True
@classmethod
@guard_postgres
def keys(cls, clause=None):
conn, cur = get_connection_and_cursor()
where_clauses, parameters, func_clause = parse_clause(clause)
assert not func_clause
sql_statement = 'SELECT id FROM %s' % cls._table_name
if where_clauses:
sql_statement += ' WHERE ' + ' AND '.join(where_clauses)
cur.execute(sql_statement, parameters)
ids = [x[0] for x in cur.fetchall()]
conn.commit()
cur.close()
return ids
@classmethod
@guard_postgres
def count(cls, clause=None):
where_clauses, parameters, func_clause = parse_clause(clause)
if func_clause:
# fallback to counting the result of a select()
return len(cls.select(clause))
sql_statement = 'SELECT count(*) FROM %s' % cls._table_name
if where_clauses:
sql_statement += ' WHERE ' + ' AND '.join(where_clauses)
conn, cur = get_connection_and_cursor()
cur.execute(sql_statement, parameters)
count = cur.fetchone()[0]
conn.commit()
cur.close()
return count
@classmethod
@guard_postgres
def exists(cls, clause=None):
where_clauses, parameters, func_clause = parse_clause(clause)
if func_clause:
# fallback to counting the result of a select()
return len(cls.select(clause))
sql_statement = 'SELECT 1 FROM %s' % cls._table_name
if where_clauses:
sql_statement += ' WHERE ' + ' AND '.join(where_clauses)
sql_statement += ' LIMIT 1'
conn, cur = get_connection_and_cursor()
try:
cur.execute(sql_statement, parameters)
except psycopg2.errors.UndefinedTable:
result = False
else:
check = cur.fetchone()
result = check is not None
conn.commit()
cur.close()
return result
@classmethod
@guard_postgres
def get_with_indexed_value(
cls, index, value, ignore_errors=False, order_by=None, limit=None, offset=None
):
conn, cur = get_connection_and_cursor()
sql_statement = '''SELECT %s
FROM %s
WHERE %s = %%(value)s''' % (
', '.join([x[0] for x in cls._table_static_fields] + cls.get_data_fields()),
cls._table_name,
index,
)
sql_statement += cls.get_order_by_clause(order_by)
parameters = {'value': str(value)}
if limit:
sql_statement += ' LIMIT %(limit)s'
parameters['limit'] = limit
if offset:
sql_statement += ' OFFSET %(offset)s'
parameters['offset'] = offset
cur.execute(sql_statement, parameters)
try:
while True:
row = cur.fetchone()
if row is None:
break
ob = cls._row2ob(row)
if ignore_errors and ob is None:
continue
yield ob
finally:
conn.commit()
cur.close()
@classmethod
@guard_postgres
def get_ids_from_query(cls, query):
cur = get_connection_and_cursor()[1]
sql_statement = (
'''SELECT id FROM %s
WHERE fts @@ plainto_tsquery(%%(value)s)'''
% cls._table_name
)
cur.execute(sql_statement, {'value': FtsMatch.get_fts_value(query)})
all_ids = [x[0] for x in cur.fetchall()]
cur.close()
return all_ids
@classmethod
@guard_postgres
def get(cls, id, ignore_errors=False, ignore_migration=False, column=None):
if column is None and (cls._numerical_id or id is None):
try:
int(id)
except (TypeError, ValueError):
if ignore_errors and id is None:
return None
else:
raise KeyError()
cur = get_connection_and_cursor()[1]
sql_statement = '''SELECT %s
FROM %s
WHERE %s = %%(value)s''' % (
', '.join([x[0] for x in cls._table_static_fields] + cls.get_data_fields()),
cls._table_name,
column or 'id',
)
cur.execute(sql_statement, {'value': str(id)})
row = cur.fetchone()
if row is None:
cur.close()
if ignore_errors:
return None
raise KeyError()
cur.close()
return cls._row2ob(row)
@classmethod
@guard_postgres
def get_on_index(cls, value, index, ignore_errors=False, **kwargs):
return cls.get(value, ignore_errors=ignore_errors, column=index)
@classmethod
@guard_postgres
def get_ids(cls, ids, ignore_errors=False, keep_order=False, fields=None, order_by=None):
if not ids:
return []
tables = [cls._table_name]
columns = [
'%s.%s' % (cls._table_name, column_name)
for column_name in [x[0] for x in cls._table_static_fields] + cls.get_data_fields()
]
extra_fields = []
if fields:
# look for relations
for field in fields:
if not getattr(field, 'is_related_field', False):
continue
if field.parent_field_id == 'user-label':
# relation to user table
carddef_table_alias = 'users'
carddef_table_decl = (
'LEFT JOIN users ON (CAST(%s.user_id AS INTEGER) = users.id)' % cls._table_name
)
else:
carddef_data_table_name = get_formdef_table_name(field.carddef)
carddef_table_alias = 't%s' % id(field.carddef)
carddef_table_decl = 'LEFT JOIN %s AS %s ON (CAST(%s.%s AS INTEGER) = %s.id)' % (
carddef_data_table_name,
carddef_table_alias,
cls._table_name,
get_field_id(field.parent_field),
carddef_table_alias,
)
if carddef_table_decl not in tables:
tables.append(carddef_table_decl)
column_field_id = field.get_column_field_id()
columns.append('%s.%s' % (carddef_table_alias, column_field_id))
extra_fields.append(field.id)
conn, cur = get_connection_and_cursor()
if cls._numerical_id:
ids_str = ', '.join([str(x) for x in ids])
else:
ids_str = ', '.join(["'%s'" % x for x in ids])
sql_statement = '''SELECT %s
FROM %s
WHERE %s.id IN (%s)''' % (
', '.join(columns),
' '.join(tables),
cls._table_name,
ids_str,
)
sql_statement += cls.get_order_by_clause(order_by)
cur.execute(sql_statement)
objects = cls.get_objects(cur, extra_fields=extra_fields)
conn.commit()
cur.close()
if ignore_errors:
objects = (x for x in objects if x is not None)
if keep_order:
objects_dict = {}
for object in objects:
objects_dict[object.id] = object
objects = [objects_dict[x] for x in ids if objects_dict.get(x)]
return list(objects)
@classmethod
def get_objects_iterator(cls, cur, ignore_errors=False, extra_fields=None):
while True:
row = cur.fetchone()
if row is None:
break
yield cls._row2ob(row, extra_fields=extra_fields)
@classmethod
def get_objects(cls, cur, ignore_errors=False, iterator=False, extra_fields=None):
generator = cls.get_objects_iterator(cur=cur, ignore_errors=ignore_errors, extra_fields=extra_fields)
if iterator:
return generator
return list(generator)
@classmethod
def get_order_by_clause(cls, order_by):
if not order_by:
return ''
# [SEC_ORDER] security note: it is not possible to use
# prepared statements for ORDER BY clauses, therefore input
# is controlled beforehand (see misc.get_order_by_or_400).
direction = 'ASC'
if order_by.startswith('-'):
order_by = order_by[1:]
direction = 'DESC'
if '->' in order_by:
# sort on field of block field: f42->'data'->0->>'bf13e4d8a8-fb08-4808-b5ae-02d6247949b9'
parts = order_by.split('->')
order_by = '%s->%s' % (parts[0].replace('-', '_'), '->'.join(parts[1:]))
else:
order_by = order_by.replace('-', '_')
fields = ['formdef_name', 'user_name'] # global view fields
fields.extend([x[0] for x in cls._table_static_fields])
fields.extend(cls.get_data_fields())
if order_by.split('->')[0] not in fields:
# for a sort on field of block field, just check the existence of the block field
return ''
return ' ORDER BY %s %s' % (order_by, direction)
@classmethod
@guard_postgres
def has_key(cls, id):
conn, cur = get_connection_and_cursor()
sql_statement = 'SELECT EXISTS(SELECT 1 FROM %s WHERE id = %%s)' % cls._table_name
with cur:
cur.execute(sql_statement, (id,))
result = cur.fetchall()[0][0]
conn.commit()
return result
@classmethod
@guard_postgres
def select_iterator(
cls,
clause=None,
order_by=None,
ignore_errors=False,
limit=None,
offset=None,
itersize=None,
):
table_static_fields = [
x[0] if x[0] not in cls._table_select_skipped_fields else 'NULL AS %s' % x[0]
for x in cls._table_static_fields
]
def retrieve():
for object in cls.get_objects(cur, iterator=True):
if object is None:
continue
if func_clause and not func_clause(object):
continue
yield object
if itersize and cls._has_id:
# this case concerns almost all data tables: formdata, card, users, roles
sql_statement = '''SELECT id FROM %s''' % cls._table_name
else:
# this case concerns aggregated views like wcs_all_forms (class
# AnyFormData) which does not have a surrogate key id column
sql_statement = '''SELECT %s FROM %s''' % (
', '.join(table_static_fields + cls.get_data_fields()),
cls._table_name,
)
where_clauses, parameters, func_clause = parse_clause(clause)
if where_clauses:
sql_statement += ' WHERE ' + ' AND '.join(where_clauses)
sql_statement += cls.get_order_by_clause(order_by)
if not func_clause:
if limit:
sql_statement += ' LIMIT %(limit)s'
parameters['limit'] = limit
if offset:
sql_statement += ' OFFSET %(offset)s'
parameters['offset'] = offset
conn, cur = get_connection_and_cursor()
with cur:
cur.execute(sql_statement, parameters)
conn.commit()
if itersize and cls._has_id:
sql_id_statement = '''SELECT %s FROM %s WHERE id IN %%s''' % (
', '.join(table_static_fields + cls.get_data_fields()),
cls._table_name,
)
sql_id_statement += cls.get_order_by_clause(order_by)
ids = [row[0] for row in cur]
while ids:
cur.execute(sql_id_statement, [tuple(ids[:itersize])])
conn.commit()
yield from retrieve()
ids = ids[itersize:]
else:
yield from retrieve()
@classmethod
@guard_postgres
def select(
cls,
clause=None,
order_by=None,
ignore_errors=False,
limit=None,
offset=None,
iterator=False,
itersize=None,
):
if iterator and not itersize:
itersize = 200
objects = cls.select_iterator(
clause=clause,
order_by=order_by,
ignore_errors=ignore_errors,
limit=limit,
offset=offset,
)
func_clause = parse_clause(clause)[2]
if func_clause and (limit or offset):
objects = _take(objects, limit, offset)
if iterator:
return objects
return list(objects)
@classmethod
@guard_postgres
def select_distinct(cls, columns, clause=None, first_field_alias=None):
# do note this method returns unicode strings.
column0 = columns[0]
if first_field_alias:
column0 = '%s as %s' % (column0, first_field_alias)
conn, cur = get_connection_and_cursor()
sql_statement = 'SELECT DISTINCT ON (%s) %s FROM %s' % (
columns[0],
', '.join([column0] + columns[1:]),
cls._table_name,
)
where_clauses, parameters, func_clause = parse_clause(clause)
assert not func_clause
if where_clauses:
sql_statement += ' WHERE ' + ' AND '.join(where_clauses)
sql_statement += ' ORDER BY %s' % (first_field_alias or columns[0])
cur.execute(sql_statement, parameters)
values = [x for x in cur.fetchall()]
conn.commit()
cur.close()
return values
def get_sql_dict_from_data(self, data, formdef):
sql_dict = {}
for field in formdef.get_all_fields():
sql_type = SQL_TYPE_MAPPING.get(field.key, 'varchar')
if sql_type is None:
continue
value = data.get(field.id)
if value is not None:
if field.key in ('ranked-items', 'password'):
# turn {'poire': 2, 'abricot': 1, 'pomme': 3} into an array
value = [[force_str(x), force_str(y)] for x, y in value.items()]
elif field.key == 'computed':
if value is not None:
# embed value in a dict, so it's never necessary to cast the
# value for postgresql
value = {'data': json.loads(JSONEncoder().encode(value)), '@type': 'computed-data'}
elif sql_type == 'varchar':
assert isinstance(value, str)
elif sql_type == 'date':
assert isinstance(value, time.struct_time)
value = datetime.datetime(value.tm_year, value.tm_mon, value.tm_mday)
elif sql_type == 'bytea':
value = bytearray(pickle.dumps(value, protocol=2))
elif sql_type == 'jsonb' and isinstance(value, dict) and value.get('schema'):
# block field, adapt date/field values
value = copy.deepcopy(value)
for field_id, field_type in value.get('schema').items():
if field_type not in ('date', 'file'):
continue
for entry in value.get('data') or []:
subvalue = entry.get(field_id)
if subvalue and field_type == 'date':
entry[field_id] = strftime('%Y-%m-%d', subvalue)
elif subvalue and field_type == 'file':
entry[field_id] = subvalue.__getstate__()
elif sql_type == 'boolean':
pass
sql_dict[get_field_id(field)] = value
if field.store_display_value:
sql_dict['%s_display' % get_field_id(field)] = data.get('%s_display' % field.id)
if field.store_structured_value:
sql_dict['%s_structured' % get_field_id(field)] = bytearray(
pickle.dumps(data.get('%s_structured' % field.id), protocol=2)
)
return sql_dict
@classmethod
def _row2obdata(cls, row, formdef):
obdata = {}
i = len(cls._table_static_fields)
if formdef.geolocations:
i += len(formdef.geolocations.keys())
for field in formdef.get_all_fields():
sql_type = SQL_TYPE_MAPPING.get(field.key, 'varchar')
if sql_type is None:
continue
value = row[i]
if value is not None:
value = str_encode(value)
if field.key == 'ranked-items':
d = {}
for data, rank in value:
try:
d[data] = int(rank)
except ValueError:
d[data] = rank
value = d
elif field.key == 'password':
d = {}
for fmt, val in value:
d[fmt] = force_str(val)
value = d
elif field.key == 'computed':
if isinstance(value, dict) and value.get('@type') == 'computed-data':
value = value.get('data')
if sql_type == 'date':
value = value.timetuple()
elif sql_type == 'bytea':
value = pickle_loads(value)
elif sql_type == 'jsonb' and isinstance(value, dict) and value.get('schema'):
# block field, adapt date/field values
for field_id, field_type in value.get('schema').items():
if field_type not in ('date', 'file'):
continue
for entry in value.get('data') or []:
subvalue = entry.get(field_id)
if subvalue and field_type == 'date':
entry[field_id] = time.strptime(subvalue, '%Y-%m-%d')
elif subvalue and field_type == 'file':
entry[field_id] = PicklableUpload.__new__(PicklableUpload)
entry[field_id].__setstate__(subvalue)
obdata[field.id] = value
i += 1
if field.store_display_value:
value = str_encode(row[i])
obdata['%s_display' % field.id] = value
i += 1
if field.store_structured_value:
value = row[i]
if value is not None:
obdata['%s_structured' % field.id] = pickle_loads(value)
if obdata['%s_structured' % field.id] is None:
del obdata['%s_structured' % field.id]
i += 1
return obdata
@classmethod
@guard_postgres
def remove_object(cls, id):
    """Delete the row with the given id from this class' table."""
    conn, cur = get_connection_and_cursor()
    statement = 'DELETE FROM %s WHERE id = %%(id)s' % cls._table_name
    cur.execute(statement, {'id': str(id)})
    conn.commit()
    cur.close()
@classmethod
@guard_postgres
def wipe(cls, drop=False, clause=None):
    """Delete rows from the table, optionally restricted by ``clause``.

    ``drop`` is accepted for API compatibility and ignored here.
    """
    conn, cur = get_connection_and_cursor()
    statement = '''DELETE FROM %s''' % cls._table_name
    params = {}
    if clause:
        where_clauses, params, dummy = parse_clause(clause)
        if where_clauses:
            statement = '%s WHERE %s' % (statement, ' AND '.join(where_clauses))
    cur.execute(statement, params)
    conn.commit()
    cur.close()
@classmethod
@guard_postgres
def get_sorted_ids(cls, order_by, clause=None):
    """Return ids of matching rows, sorted.

    ``order_by`` is a column name (or a form field, see
    get_order_by_clause()), or the special value 'rank' to sort by
    fulltext relevance when the clause holds an FtsMatch criterion.
    """
    conn, cur = get_connection_and_cursor()
    sql_statement = 'SELECT id FROM %s' % cls._table_name
    where_clauses, parameters, func_clause = parse_clause(clause)
    assert not func_clause
    if where_clauses:
        sql_statement += ' WHERE ' + ' AND '.join(where_clauses)
    if order_by == 'rank':
        try:
            fts = [x for x in clause if not callable(x) and x.__class__.__name__ == 'FtsMatch'][0]
        except IndexError:
            # no fulltext criterion in the clause, no ranking possible
            pass
        else:
            # parse_clause() registered the search terms in parameters under
            # a key derived from id(fts.value); reuse it for ts_rank()
            sql_statement += ' ORDER BY ts_rank(fts, plainto_tsquery(%%(c%s)s)) DESC' % id(fts.value)
    else:
        sql_statement += cls.get_order_by_clause(order_by)
    cur.execute(sql_statement, parameters)
    ids = [x[0] for x in cur.fetchall()]
    conn.commit()
    cur.close()
    return ids
class SqlDataMixin(SqlMixin):
    """SQL storage mixin for form/card data (one table per formdef).

    Form field values are stored in dynamically managed extra columns;
    history items live in a companion <table>_evolutions table.
    """

    _names = None  # make sure StorableObject methods fail
    _formdef = None

    # static columns common to all formdata/carddata tables
    _table_static_fields = [
        ('id', 'serial'),
        ('user_id', 'varchar'),
        ('receipt_time', 'timestamp'),
        ('status', 'varchar'),
        ('page_no', 'varchar'),
        ('anonymised', 'timestamptz'),
        ('workflow_data', 'bytea'),
        ('prefilling_data', 'bytea'),
        ('id_display', 'varchar'),
        ('workflow_roles', 'bytea'),
        ('workflow_merged_roles_dict', 'jsonb'),
        ('workflow_roles_array', 'text[]'),
        ('concerned_roles_array', 'text[]'),
        ('actions_roles_array', 'text[]'),
        ('tracking_code', 'varchar'),
        ('backoffice_submission', 'boolean'),
        ('submission_context', 'bytea'),
        ('submission_agent_id', 'varchar'),
        ('submission_channel', 'varchar'),
        ('criticality_level', 'int'),
        ('last_update_time', 'timestamp'),
        ('digests', 'jsonb'),
        ('user_label', 'varchar'),
    ]

    def __init__(self, id=None):
        self.id = id
        self.data = {}
_evolution = None  # lazily loaded cache of evolution items

@guard_postgres
def get_evolution(self):
    """Return the evolution (history) items, loading them on first access."""
    if self._evolution is not None:
        return self._evolution
    if not self.id:
        # not stored yet, no history
        self._evolution = []
        return self._evolution
    conn, cur = get_connection_and_cursor()
    sql_statement = (
        '''SELECT id, who, status, time, last_jump_datetime,
                  comment, parts FROM %s_evolutions
            WHERE formdata_id = %%(id)s
            ORDER BY id'''
        % self._table_name
    )
    cur.execute(sql_statement, {'id': self.id})
    self._evolution = []
    while True:
        row = cur.fetchone()
        if row is None:
            break
        self._evolution.append(self._row2evo(row, formdata=self))
    conn.commit()
    cur.close()
    return self._evolution

@classmethod
def _row2evo(cls, row, formdata):
    """Build an Evolution object from an _evolutions table row."""
    o = wcs.formdata.Evolution(formdata)
    o._sql_id, o.who, o.status, o.time, o.last_jump_datetime, o.comment = (
        str_encode(x) for x in tuple(row[:6])
    )
    if o.time:
        # stored as datetime, exposed as time.struct_time
        o.time = o.time.timetuple()
    if row[6]:
        # parts are stored pickled
        o.parts = pickle_loads(row[6])
    return o

def set_evolution(self, value):
    self._evolution = value

# expose lazy loading through a regular attribute access
evolution = property(get_evolution, set_evolution)
@classmethod
@guard_postgres
def load_all_evolutions(cls, values):
    """Load the evolutions of a batch of formdata in a single query.

    Typically formdata.evolution is loaded on-demand (see the property
    above) and this is fine to minimize queries, especially when dealing
    with a single formdata. However in some places (to compute statistics
    for example) it is sometimes useful to access .evolution on a serie
    of formdata and in that case, it's more efficient to load all
    evolutions in a single batch query.
    """
    # only consider objects that are stored and not yet loaded
    object_dict = {x.id: x for x in values if x.id and x._evolution is None}
    if not object_dict:
        return
    conn, cur = get_connection_and_cursor()
    sql_statement = (
        '''SELECT id, who, status, time, last_jump_datetime,
                  comment, parts, formdata_id
             FROM %s_evolutions'''
        % cls._table_name
    )
    sql_statement += ''' WHERE formdata_id IN %(object_ids)s ORDER BY id'''
    cur.execute(sql_statement, {'object_ids': tuple(object_dict.keys())})
    # initialize only the objects being loaded; the previous code reset
    # _evolution on every passed object, discarding already-loaded
    # histories of objects excluded from the query.
    for formdata in object_dict.values():
        formdata._evolution = []
    while True:
        row = cur.fetchone()
        if row is None:
            break
        # formdata_id is the last selected column
        formdata = object_dict.get(row[7])
        if not formdata:
            continue
        formdata._evolution.append(formdata._row2evo(row, formdata))
    conn.commit()
    cur.close()
@guard_postgres
def _set_auto_fields(self, cur):
    """Persist automatically computed fields (id_display, digests, user_label)."""
    if not self.set_auto_fields():
        return
    statement = (
        'UPDATE %s SET id_display = %%(id_display)s, '
        'digests = %%(digests)s, '
        'user_label = %%(user_label)s '
        'WHERE id = %%(id)s' % self._table_name
    )
    values = {
        'id': self.id,
        'id_display': self.id_display,
        'digests': self.digests,
        'user_label': self.user_label,
    }
    cur.execute(statement, values)
@guard_postgres
@invalidate_substitution_cache
def store(self, where=None):
    """Insert or update the formdata row, its evolutions and fulltext index.

    ``where`` may hold extra criterias for the UPDATE; if they match
    nothing a NothingToUpdate exception is raised.
    """
    # 1. assemble column values from object attributes
    sql_dict = {
        'user_id': self.user_id,
        'status': self.status,
        'page_no': self.page_no,
        'workflow_data': self.workflow_data,
        'id_display': self.id_display,
        'anonymised': self.anonymised,
        'tracking_code': self.tracking_code,
        'backoffice_submission': self.backoffice_submission,
        'submission_context': self.submission_context,
        'prefilling_data': self.prefilling_data,
        'submission_agent_id': self.submission_agent_id,
        'submission_channel': self.submission_channel,
        'criticality_level': self.criticality_level,
        'workflow_merged_roles_dict': self.workflow_merged_roles_dict,
    }
    # times are kept as time.struct_time on the object, stored as datetime
    if self.last_update_time:
        sql_dict['last_update_time'] = datetime.datetime.fromtimestamp(time.mktime(self.last_update_time))
    else:
        sql_dict['last_update_time'] = None
    if self.receipt_time:
        sql_dict['receipt_time'] = datetime.datetime.fromtimestamp(time.mktime(self.receipt_time))
    else:
        sql_dict['receipt_time'] = None
    # flatten the workflow_roles dict values into a single array column
    if self.workflow_roles:
        sql_dict['workflow_roles_array'] = []
        for x in self.workflow_roles.values():
            if isinstance(x, list):
                sql_dict['workflow_roles_array'].extend(x)
            elif x:
                sql_dict['workflow_roles_array'].append(str(x))
    else:
        sql_dict['workflow_roles_array'] = None
    # pickle the opaque attributes into bytea columns
    for attr in ('workflow_data', 'workflow_roles', 'submission_context', 'prefilling_data'):
        if getattr(self, attr):
            sql_dict[attr] = bytearray(pickle.dumps(getattr(self, attr), protocol=2))
        else:
            sql_dict[attr] = None
    # geolocations are stored as postgresql point literals '(lon, lat)'
    for field in (self._formdef.geolocations or {}).keys():
        value = (self.geolocations or {}).get(field)
        if value:
            value = '(%.6f, %.6f)' % (value.get('lon'), value.get('lat'))
        sql_dict['geoloc_%s' % field] = value
    sql_dict['concerned_roles_array'] = [str(x) for x in self.concerned_roles if x]
    sql_dict['actions_roles_array'] = [str(x) for x in self.actions_roles if x]
    sql_dict.update(self.get_sql_dict_from_data(self.data, self._formdef))

    # 2. INSERT or UPDATE the main row
    conn, cur = get_connection_and_cursor()
    if not self.id:
        column_names = sql_dict.keys()
        sql_statement = '''INSERT INTO %s (id, %s)
                           VALUES (DEFAULT, %s)
                           RETURNING id''' % (
            self._table_name,
            ', '.join(column_names),
            ', '.join(['%%(%s)s' % x for x in column_names]),
        )
        cur.execute(sql_statement, sql_dict)
        self.id = cur.fetchone()[0]
    else:
        if not where:
            where = []
        where.append(Equal('id', self.id))
        where_clauses, parameters, dummy = parse_clause(where)
        column_names = list(sql_dict.keys())
        sql_dict.update(parameters)
        sql_statement = '''UPDATE %s SET %s WHERE %s RETURNING id''' % (
            self._table_name,
            ', '.join(['%s = %%(%s)s' % (x, x) for x in column_names]),
            ' AND '.join(where_clauses),
        )
        cur.execute(sql_statement, sql_dict)
        if cur.fetchone() is None:
            if len(where) > 1:
                # abort if nothing was modified and there were extra where clauses
                raise NothingToUpdate()
            # this has been a request to save a new line with a preset id (for example
            # for data migration)
            sql_dict['id'] = self.id
            column_names.append('id')
            sql_statement = '''INSERT INTO %s (%s) VALUES (%s) RETURNING id''' % (
                self._table_name,
                ', '.join(column_names),
                ', '.join(['%%(%s)s' % x for x in column_names]),
            )
            cur.execute(sql_statement, sql_dict)
            self.id = cur.fetchone()[0]
    self._set_auto_fields(cur)
    self.clean_live_evolution_items()

    # 3. store evolution items; rows with a _sql_id are updated, new
    # ones (all items after the first one lacking a _sql_id) are inserted
    if self._evolution:
        # skip all the evolution that already have an _sql_id
        for idx, evo in enumerate(self._evolution):
            if not hasattr(evo, '_sql_id'):
                break
        # now we can save all after this idx
        for evo in self._evolution[idx:]:
            sql_dict = {}
            if hasattr(evo, '_sql_id'):
                sql_dict.update({'id': evo._sql_id})
                sql_statement = (
                    '''UPDATE %s_evolutions SET
                              who = %%(who)s,
                              time = %%(time)s,
                              last_jump_datetime = %%(last_jump_datetime)s,
                              status = %%(status)s,
                              comment = %%(comment)s,
                              parts = %%(parts)s
                        WHERE id = %%(id)s
                        RETURNING id'''
                    % self._table_name
                )
            else:
                sql_statement = (
                    '''INSERT INTO %s_evolutions (
                              id, who, status,
                              time, last_jump_datetime,
                              comment, parts,
                              formdata_id)
                       VALUES (DEFAULT, %%(who)s, %%(status)s,
                               %%(time)s, %%(last_jump_datetime)s,
                               %%(comment)s,
                               %%(parts)s, %%(formdata_id)s)
                       RETURNING id'''
                    % self._table_name
                )
            sql_dict.update(
                {
                    'who': evo.who,
                    'status': evo.status,
                    'time': datetime.datetime.fromtimestamp(time.mktime(evo.time)),
                    'last_jump_datetime': evo.last_jump_datetime,
                    'comment': evo.comment,
                    'formdata_id': self.id,
                }
            )
            if evo.parts:
                sql_dict['parts'] = bytearray(pickle.dumps(evo.parts, protocol=2))
            else:
                sql_dict['parts'] = None
            cur.execute(sql_statement, sql_dict)
            evo._sql_id = cur.fetchone()[0]

    # 4. rebuild the fulltext index; weights go from A (most relevant:
    # id, tracking code, user name) to D (history comments and parts)
    fts_strings = {'A': [], 'B': [], 'C': [], 'D': []}
    fts_strings['A'].append(str(self.id))
    fts_strings['A'].append(self.get_display_id())
    fts_strings['C'].append(self._formdef.name)
    if self.tracking_code:
        fts_strings['A'].append(self.tracking_code)

    def get_all_fields():
        # yield (field, containing data dict) pairs, descending into blocks
        for field in self._formdef.get_all_fields():
            if field.key == 'block' and self.data.get(field.id):
                for data in self.data[field.id].get('data'):
                    try:
                        for subfield in field.block.fields:
                            yield subfield, data
                    except KeyError:
                        # block doesn't exist anymore
                        break
            else:
                data = self.data
                yield field, self.data

    for field, data in get_all_fields():
        if not data.get(field.id):
            continue
        value = None
        if field.key in ('string', 'text', 'email'):
            value = data.get(field.id)
        elif field.key in ('item', 'items'):
            # index the human readable value
            value = data.get('%s_display' % field.id)
        if value:
            weight = 'C'
            if field.include_in_listing:
                weight = 'B'
            if isinstance(value, str) and len(value) < 10000:
                # avoid overlong strings, typically base64-encoded values
                fts_strings[weight].append(value)
            elif type(value) in (tuple, list):
                for val in value:
                    fts_strings[weight].append(val)
    if self._evolution:
        for evo in self._evolution:
            if evo.comment:
                fts_strings['D'].append(evo.comment)
            for part in evo.parts or []:
                # NOTE(review): assumes every part object has a
                # render_for_fts attribute — confirm; a part lacking it
                # would raise AttributeError here.
                fts_strings['D'].append(part.render_for_fts() if part.render_for_fts else '')
    user = self.get_user()
    if user:
        fts_strings['A'].append(user.get_display_name())
    fts_parts = []
    parameters = {'id': self.id}
    for weight, strings in fts_strings.items():
        # assemble strings
        value = ' '.join([force_text(x) for x in strings if x])
        fts_parts.append("setweight(to_tsvector(%%(fts%s)s), '%s')" % (weight, weight))
        parameters['fts%s' % weight] = FtsMatch.get_fts_value(str(value))
    sql_statement = '''UPDATE %s SET fts = %s
                        WHERE id = %%(id)s''' % (
        self._table_name,
        ' || '.join(fts_parts) or "''",
    )
    cur.execute(sql_statement, parameters)
    conn.commit()
    cur.close()
@classmethod
def _row2ob(cls, row, extra_fields=None):
    """Build a formdata object from a SQL row.

    ``extra_fields`` are additional selected columns, appended at the
    end of the row in order.
    """
    o = cls()
    for static_field, value in zip(cls._table_static_fields, tuple(row[: len(cls._table_static_fields)])):
        setattr(o, static_field[0], str_encode(value))
    if o.receipt_time:
        # stored as datetime, exposed as time.struct_time
        o.receipt_time = o.receipt_time.timetuple()
    for attr in ('workflow_data', 'workflow_roles', 'submission_context', 'prefilling_data'):
        if getattr(o, attr):
            setattr(o, attr, pickle_loads(getattr(o, attr)))
    o.geolocations = {}
    for i, field in enumerate((cls._formdef.geolocations or {}).keys()):
        value = row[len(cls._table_static_fields) + i]
        if not value:
            continue
        # point columns come back as '(lon,lat)' strings
        # NOTE(review): assumes the stored value always matches this
        # pattern — m would be None (AttributeError) otherwise; confirm.
        m = re.match(r"\(([^)]+),([^)]+)\)", value)
        o.geolocations[field] = {'lon': float(m.group(1)), 'lat': float(m.group(2))}
    o.data = cls._row2obdata(row, cls._formdef)
    if extra_fields:
        # extra fields are tuck at the end
        for i, field_id in enumerate(reversed(extra_fields)):
            o.data[field_id] = row[-(i + 1)]
    del o._last_update_time
    return o

@classmethod
def get_data_fields(cls):
    """Return the data column names (geolocations then field columns)."""
    data_fields = ['geoloc_%s' % x for x in (cls._formdef.geolocations or {}).keys()]
    for field in cls._formdef.get_all_fields():
        sql_type = SQL_TYPE_MAPPING.get(field.key, 'varchar')
        if sql_type is None:
            # field type with no SQL storage
            continue
        data_fields.append(get_field_id(field))
        if field.store_display_value:
            data_fields.append('%s_display' % get_field_id(field))
        if field.store_structured_value:
            data_fields.append('%s_structured' % get_field_id(field))
    return data_fields
@classmethod
@guard_postgres
def get(cls, id, ignore_errors=False, ignore_migration=False):
    """Load a single formdata by its numeric id.

    Raises KeyError (or returns None with ignore_errors=True) when the
    id is not numeric or matches no row.
    """
    try:
        int(str(id))
    except (TypeError, ValueError):
        if ignore_errors:
            return None
        raise KeyError()
    cur = get_connection_and_cursor()[1]
    columns = [x[0] for x in cls._table_static_fields]
    columns.extend(cls.get_data_fields())
    sql_statement = 'SELECT %s FROM %s WHERE id = %%(id)s' % (
        ', '.join(columns),
        cls._table_name,
    )
    cur.execute(sql_statement, {'id': str(id)})
    row = cur.fetchone()
    cur.close()
    if row is None:
        if ignore_errors:
            return None
        raise KeyError()
    return cls._row2ob(row)
@classmethod
@guard_postgres
def get_ids_with_indexed_value(cls, index, value, auto_fallback=True, clause=None):
    """Return ids of rows where column ``index`` matches ``value``.

    When an ``<index>_array`` static column exists, matching is done
    with the array containment operator; otherwise plain equality.
    ``auto_fallback`` is accepted for API compatibility and unused here.
    """
    cur = get_connection_and_cursor()[1]
    where_clauses, parameters, func_clause = parse_clause(clause)
    assert not func_clause
    if isinstance(value, int):
        value = str(value)
    if '%s_array' % index in [x[0] for x in cls._table_static_fields]:
        sql_statement = '''SELECT id FROM %s WHERE %s_array @> ARRAY[%%(value)s]''' % (
            cls._table_name,
            index,
        )
    else:
        sql_statement = '''SELECT id FROM %s WHERE %s = %%(value)s''' % (cls._table_name, index)
    if where_clauses:
        sql_statement += ' AND ' + ' AND '.join(where_clauses)
    else:
        # no extra criterias kept, drop parameters from parse_clause()
        parameters = {}
    parameters.update({'value': value})
    cur.execute(sql_statement, parameters)
    all_ids = [x[0] for x in cur.fetchall()]
    cur.close()
    return all_ids

@classmethod
def get_order_by_clause(cls, order_by):
    """Accept a form field as order_by, mapping it to its column name."""
    if hasattr(order_by, 'id'):
        # form field, convert to its column name
        attribute = order_by
        order_by = get_field_id(attribute)
        if attribute.store_display_value:
            # sort on the human readable value
            order_by = order_by + '_display'
    return super().get_order_by_clause(order_by)
@classmethod
@guard_postgres
def fix_sequences(cls):
    """Reset id sequences after rows were inserted with explicit ids."""
    conn, cur = get_connection_and_cursor()
    evolutions_table = '%s_evolutions' % cls._table_name
    for table_name in (cls._table_name, evolutions_table):
        cur.execute('select max(id) from %s' % table_name)
        max_id = cur.fetchone()[0]
        if max_id:
            # restart the sequence right after the highest known id
            cur.execute('ALTER SEQUENCE %s_id_seq RESTART %s' % (table_name, max_id + 1))
    conn.commit()
    cur.close()
@classmethod
def rebuild_security(cls, update_all=False):
    """Recompute security related columns on all formdata.

    With update_all=False rows are only rewritten when the computed
    values actually differ, limiting writes and lock contention.
    """
    formdatas = cls.select(order_by='id', iterator=True)
    conn, cur = get_connection_and_cursor()
    for i, formdata in enumerate(formdatas):
        # don't update all formdata before commiting
        # this will make us hold locks for much longer than required
        if i % 100 == 0:
            conn.commit()
        if not update_all:
            sql_statement = (
                '''UPDATE %s
                      SET concerned_roles_array = %%(roles)s,
                          actions_roles_array = %%(actions_roles)s,
                          workflow_merged_roles_dict = %%(workflow_merged_roles_dict)s
                    WHERE id = %%(id)s
                      AND (concerned_roles_array <> %%(roles)s OR
                           actions_roles_array <> %%(actions_roles)s OR
                           workflow_merged_roles_dict <> %%(workflow_merged_roles_dict)s)'''
                % cls._table_name
            )
        else:
            sql_statement = (
                '''UPDATE %s
                      SET concerned_roles_array = %%(roles)s,
                          actions_roles_array = %%(actions_roles)s,
                          workflow_merged_roles_dict = %%(workflow_merged_roles_dict)s
                    WHERE id = %%(id)s'''
                % cls._table_name
            )
        with get_publisher().substitutions.temporary_feed(formdata):
            # formdata is already added to sources list in individual
            # {concerned,actions}_roles but adding it first here will
            # allow cached values to be reused between the properties.
            cur.execute(
                sql_statement,
                {
                    'id': formdata.id,
                    'roles': [str(x) for x in formdata.concerned_roles if x],
                    'actions_roles': [str(x) for x in formdata.actions_roles if x],
                    'workflow_merged_roles_dict': formdata.workflow_merged_roles_dict,
                },
            )
    conn.commit()
    cur.close()
@classmethod
@guard_postgres
def wipe(cls, drop=False):
    """Delete all rows; with drop=True also drop the tables themselves."""
    conn, cur = get_connection_and_cursor()
    if drop:
        cur.execute('''DROP TABLE %s_evolutions CASCADE''' % cls._table_name)
        cur.execute('''DELETE FROM %s''' % cls._table_name)  # force trigger execution first.
        cur.execute('''DROP TABLE %s CASCADE''' % cls._table_name)
    else:
        cur.execute('''DELETE FROM %s_evolutions''' % cls._table_name)
        cur.execute('''DELETE FROM %s''' % cls._table_name)
    conn.commit()
    cur.close()

@classmethod
def do_tracking_code_table(cls):
    # ensure the shared tracking codes table exists
    do_tracking_code_table()
class SqlFormData(SqlDataMixin, wcs.formdata.FormData):
    # form data stored in SQL (one table per formdef)
    pass

class SqlCardData(SqlDataMixin, wcs.carddata.CardData):
    # card data stored in SQL (one table per carddef)
    pass
class SqlUser(SqlMixin, wcs.users.User):
    """User account stored in SQL (table ``users``)."""

    _table_name = 'users'
    _table_static_fields = [
        ('id', 'serial'),
        ('name', 'varchar'),
        ('email', 'varchar'),
        ('roles', 'varchar[]'),
        ('is_admin', 'bool'),
        ('anonymous', 'bool'),
        ('name_identifiers', 'varchar[]'),
        ('verified_fields', 'varchar[]'),
        ('lasso_dump', 'text'),
        ('last_seen', 'timestamp'),
        ('ascii_name', 'varchar'),
        ('deleted_timestamp', 'timestamp'),
        ('is_active', 'bool'),
    ]

    id = None

    def __init__(self, name=None):
        self.name = name
        self.name_identifiers = []
        self.verified_fields = []
        self.roles = []
@guard_postgres
@invalidate_substitution_cache
def store(self, skip_global_forms_table_update=False):
    """Insert or update the user row and refresh its fulltext index.

    skip_global_forms_table_update: when True, do not propagate a
    potential name change to the wcs_all_forms summary table.
    """
    sql_dict = {
        'name': self.name,
        'ascii_name': self.ascii_name,
        'email': self.email,
        'roles': self.roles,
        'is_admin': self.is_admin,
        'anonymous': self.anonymous,
        'name_identifiers': self.name_identifiers,
        'verified_fields': self.verified_fields,
        'lasso_dump': self.lasso_dump,
        'last_seen': None,
        'deleted_timestamp': self.deleted_timestamp,
        'is_active': self.is_active,
    }
    if self.last_seen:
        # self.last_seen is a Unix timestamp, store it as a datetime
        # (bug fix: a trailing comma previously wrapped it in a 1-tuple,
        # sending a tuple instead of a datetime to the timestamp column)
        sql_dict['last_seen'] = datetime.datetime.fromtimestamp(self.last_seen)
    user_formdef = self.get_formdef()
    if not self.form_data:
        self.form_data = {}
    sql_dict.update(self.get_sql_dict_from_data(self.form_data, user_formdef))
    conn, cur = get_connection_and_cursor()
    if not self.id:
        column_names = sql_dict.keys()
        sql_statement = '''INSERT INTO %s (id, %s)
                           VALUES (DEFAULT, %s)
                           RETURNING id''' % (
            self._table_name,
            ', '.join(column_names),
            ', '.join(['%%(%s)s' % x for x in column_names]),
        )
        cur.execute(sql_statement, sql_dict)
        self.id = cur.fetchone()[0]
    else:
        column_names = sql_dict.keys()
        sql_dict['id'] = self.id
        sql_statement = '''UPDATE %s SET %s WHERE id = %%(id)s RETURNING id''' % (
            self._table_name,
            ', '.join(['%s = %%(%s)s' % (x, x) for x in column_names]),
        )
        cur.execute(sql_statement, sql_dict)
        if cur.fetchone() is None:
            # no row matched: insert a new row with the preset id
            column_names = sql_dict.keys()
            sql_statement = '''INSERT INTO %s (%s) VALUES (%s)''' % (
                self._table_name,
                ', '.join(column_names),
                ', '.join(['%%(%s)s' % x for x in column_names]),
            )
            cur.execute(sql_statement, sql_dict)
    # rebuild the fulltext index, weighting names above other values
    fts_strings = []
    if self.name:
        fts_strings.append(('A', self.name))
        fts_strings.append(('A', self.ascii_name))
    if self.email:
        fts_strings.append(('B', self.email))
    if user_formdef and user_formdef.fields:
        for field in user_formdef.fields:
            if not self.form_data.get(field.id):
                continue
            value = None
            if field.key in ('string', 'text', 'email'):
                value = self.form_data.get(field.id)
            elif field.key in ('item', 'items'):
                # index the human readable value
                value = self.form_data.get('%s_display' % field.id)
            if value:
                if isinstance(value, str):
                    fts_strings.append(('B', value))
                elif type(value) in (tuple, list):
                    for val in value:
                        fts_strings.append(('B', val))
    fts_parts = []
    parameters = {'id': self.id}
    for i, (weight, value) in enumerate(fts_strings):
        fts_parts.append("setweight(to_tsvector(%%(fts%s)s), '%s')" % (i, weight))
        parameters['fts%s' % i] = FtsMatch.get_fts_value(value)
    sql_statement = '''UPDATE %s SET fts = %s
                        WHERE id = %%(id)s''' % (
        self._table_name,
        ' || '.join(fts_parts) or "''",
    )
    cur.execute(sql_statement, parameters)
    if not skip_global_forms_table_update:
        # update wcs_all_forms rows with potential name change
        sql_statement = 'UPDATE wcs_all_forms SET user_name = %(user_name)s WHERE user_id = %(user_id)s'
        cur.execute(sql_statement, {'user_id': str(self.id), 'user_name': self.name})
    conn.commit()
    cur.close()
@classmethod
def _row2ob(cls, row, **kwargs):
    """Build a SqlUser from a SQL row (static columns then form data)."""
    o = cls()
    (
        o.id,
        o.name,
        o.email,
        o.roles,
        o.is_admin,
        o.anonymous,
        o.name_identifiers,
        o.verified_fields,
        o.lasso_dump,
        o.last_seen,
        ascii_name,  # XXX what's this ? pylint: disable=unused-variable
        o.deleted_timestamp,
        o.is_active,
    ) = (str_encode(x) for x in tuple(row[:13]))
    if o.last_seen:
        # stored as datetime, exposed as a Unix timestamp
        o.last_seen = time.mktime(o.last_seen.timetuple())
    if o.roles:
        o.roles = [str(x) for x in o.roles]
    o.form_data = cls._row2obdata(row, cls.get_formdef())
    return o

@classmethod
def get_data_fields(cls):
    """Return column names for the user formdef fields."""
    data_fields = []
    for field in cls.get_formdef().get_all_fields():
        sql_type = SQL_TYPE_MAPPING.get(field.key, 'varchar')
        if sql_type is None:
            # field type with no SQL storage
            continue
        data_fields.append(get_field_id(field))
        if field.store_display_value:
            data_fields.append('%s_display' % get_field_id(field))
        if field.store_structured_value:
            data_fields.append('%s_structured' % get_field_id(field))
    return data_fields
@classmethod
@guard_postgres
def fix_sequences(cls):
    """Reset the users id sequence to run past the current max id."""
    conn, cur = get_connection_and_cursor()
    cur.execute('''select max(id) from %s''' % cls._table_name)
    max_id = cur.fetchone()[0]
    if max_id is not None:
        restart_statement = '''ALTER SEQUENCE %s_id_seq RESTART %s''' % (
            cls._table_name,
            max_id + 1,
        )
        cur.execute(restart_statement)
    conn.commit()
    cur.close()
@classmethod
@guard_postgres
def get_users_with_name_identifier(cls, name_identifier):
    """Return non-deleted users holding the given SAML name identifier."""
    conn, cur = get_connection_and_cursor()
    sql_statement = '''SELECT %s
                         FROM %s
                        WHERE deleted_timestamp IS NULL
                          AND %%(value)s = ANY(name_identifiers)''' % (
        ', '.join([x[0] for x in cls._table_static_fields] + cls.get_data_fields()),
        cls._table_name,
    )
    cur.execute(sql_statement, {'value': name_identifier})
    objects = cls.get_objects(cur)
    conn.commit()
    cur.close()
    return objects

@classmethod
@guard_postgres
def get_users_with_email(cls, email):
    """Return non-deleted users with the given email address."""
    conn, cur = get_connection_and_cursor()
    sql_statement = '''SELECT %s
                         FROM %s
                        WHERE deleted_timestamp IS NULL
                          AND email = %%(value)s''' % (
        ', '.join([x[0] for x in cls._table_static_fields] + cls.get_data_fields()),
        cls._table_name,
    )
    cur.execute(sql_statement, {'value': email})
    objects = cls.get_objects(cur)
    conn.commit()
    cur.close()
    return objects

@classmethod
@guard_postgres
def get_users_with_role(cls, role_id):
    """Return non-deleted users holding the given role."""
    conn, cur = get_connection_and_cursor()
    sql_statement = '''SELECT %s
                         FROM %s
                        WHERE deleted_timestamp IS NULL
                          AND %%(value)s = ANY(roles)''' % (
        ', '.join([x[0] for x in cls._table_static_fields] + cls.get_data_fields()),
        cls._table_name,
    )
    cur.execute(sql_statement, {'value': str(role_id)})
    objects = cls.get_objects(cur)
    conn.commit()
    cur.close()
    return objects
class Role(SqlMixin, wcs.roles.Role):
    """Role stored in SQL; ids are textual, not sequence-generated."""

    _table_name = 'roles'
    _table_static_fields = [
        ('id', 'varchar'),
        ('name', 'varchar'),
        ('uuid', 'varchar'),
        ('slug', 'varchar'),
        ('internal', 'boolean'),
        ('details', 'varchar'),
        ('emails', 'varchar[]'),
        ('emails_to_members', 'boolean'),
        ('allows_backoffice_access', 'boolean'),
    ]
    _numerical_id = False

    @classmethod
    def get(cls, id, ignore_errors=False, ignore_migration=False, column=None):
        """Load a role, applying (and persisting) pending data migrations."""
        o = super().get(id, ignore_errors=ignore_errors, ignore_migration=ignore_migration, column=column)
        if o and not ignore_migration:
            if o.migrate():
                o.store()
        return o
@guard_postgres
def store(self):
    """Insert or update the role row.

    New roles get a generated id; on id collision (IntegrityError) the
    insert is retried with a fresh id.
    """
    if self.slug is None:
        # set slug if it's not yet there
        self.slug = self.get_new_slug()
    sql_dict = {
        'id': self.id,
        'name': self.name,
        'uuid': self.uuid,
        'slug': self.slug,
        'internal': self.internal,
        'details': self.details,
        'emails': self.emails,
        'emails_to_members': self.emails_to_members,
        'allows_backoffice_access': self.allows_backoffice_access,
    }
    conn, cur = get_connection_and_cursor()
    column_names = sql_dict.keys()
    if not self.id:
        sql_dict['id'] = self.get_new_id()
        sql_statement = '''INSERT INTO %s (%s)
                           VALUES (%s)
                           RETURNING id''' % (
            self._table_name,
            ', '.join(column_names),
            ', '.join(['%%(%s)s' % x for x in column_names]),
        )
        while True:
            try:
                cur.execute(sql_statement, sql_dict)
            except psycopg2.IntegrityError:
                # id collision, try again with a new id
                conn.rollback()
                sql_dict['id'] = self.get_new_id()
            else:
                break
        self.id = str_encode(cur.fetchone()[0])
    else:
        sql_statement = '''UPDATE %s SET %s WHERE id = %%(id)s RETURNING id''' % (
            self._table_name,
            ', '.join(['%s = %%(%s)s' % (x, x) for x in column_names]),
        )
        cur.execute(sql_statement, sql_dict)
        if cur.fetchone() is None:
            # no row matched, insert with the preset id
            sql_statement = '''INSERT INTO %s (%s) VALUES (%s)''' % (
                self._table_name,
                ', '.join(column_names),
                ', '.join(['%%(%s)s' % x for x in column_names]),
            )
            cur.execute(sql_statement, sql_dict)
    conn.commit()
    cur.close()
@classmethod
def get_data_fields(cls):
    # roles have no dynamic data columns
    return []

@classmethod
def _row2ob(cls, row, **kwargs):
    """Build a Role instance from a SQL row."""
    o = cls()
    for (attr, sql_type), value in zip(cls._table_static_fields, tuple(row)):
        # only text-ish columns need decoding
        decoded = str_encode(value) if sql_type in ('varchar', 'varchar[]') else value
        setattr(o, attr, decoded)
    return o
class TransientData(SqlMixin):
    # table to keep some transient submission data out of global session dictionary
    _table_name = 'transient_data'
    _table_static_fields = [
        ('id', 'varchar'),
        ('session_id', 'varchar'),
        ('data', 'bytea'),
        ('last_update_time', 'timestamptz'),
    ]
    _numerical_id = False

    def __init__(self, id, session_id, data):
        # id: the magictoken value; session_id: owning session;
        # data: any picklable payload
        self.id = id
        self.session_id = session_id
        self.data = data
@guard_postgres
def store(self):
    """Insert or update (upsert on id) this transient data row."""
    sql_dict = {
        'id': self.id,
        'session_id': self.session_id,
        'data': bytearray(pickle.dumps(self.data, protocol=2)),
        'last_update_time': now(),
    }
    conn, cur = get_connection_and_cursor()
    columns = list(sql_dict.keys())
    placeholders = ', '.join(['%%(%s)s' % x for x in columns])
    assignments = ', '.join(['%s = %%(%s)s' % (x, x) for x in columns])
    statement = '''INSERT INTO %s (%s) VALUES (%s)
                   ON CONFLICT(id) DO UPDATE SET %s''' % (
        self._table_name,
        ', '.join(columns),
        placeholders,
        assignments,
    )
    cur.execute(statement, sql_dict)
    conn.commit()
    cur.close()
@classmethod
def _row2ob(cls, row, **kwargs):
    """Build a TransientData object from a SQL row (bypassing __init__)."""
    o = cls.__new__(cls)
    o.id = str_encode(row[0])
    o.session_id = row[1]
    o.data = pickle_loads(row[2])
    return o

@classmethod
def get_data_fields(cls):
    # no dynamic data columns
    return []

@classmethod
@guard_postgres
def remove_for_session(cls, session_id):
    """Delete all transient data rows attached to the given session."""
    conn, cur = get_connection_and_cursor()
    sql_statement = 'DELETE FROM %s WHERE ' % cls._table_name
    sql_statement += 'session_id = %s'
    cur.execute(sql_statement, (session_id,))
    conn.commit()
    cur.close()
class Session(SqlMixin, wcs.sessions.BasicSession):
    """Web session stored in SQL; attributes live pickled in session_data."""

    _table_name = 'sessions'
    _table_static_fields = [
        ('id', 'varchar'),
        ('session_data', 'bytea'),
    ]
    _numerical_id = False

    @classmethod
    @guard_postgres
    def select_recent(cls, seconds=30 * 60, **kwargs):
        """Select sessions updated within the last ``seconds`` seconds."""
        clause = [
            GreaterOrEqual('last_update_time', datetime.datetime.now() - datetime.timedelta(seconds=seconds))
        ]
        return cls.select(clause=clause, **kwargs)
@guard_postgres
def store(self):
    """Save the session: pickled attributes plus a few query-only columns."""
    if self.message:
        # escape lazy gettext
        self.message = (self.message[0], str(self.message[1]))
    # store transient data
    for v in (self.magictokens or {}).values():
        v.store()
    # force to be empty, to make sure there's no leftover direct usage
    session_data = copy.copy(self.__dict__)
    session_data['magictokens'] = None
    sql_dict = {
        'id': self.id,
        'session_data': bytearray(pickle.dumps(session_data, protocol=2)),
        # the other fields are stored to run optimized SELECT() against the
        # table, they are ignored when loading the data.
        'name_identifier': self.name_identifier,
        # NOTE(review): getattr() without a default raises AttributeError if
        # visiting_objects is ever missing — presumably always set; confirm.
        'visiting_objects_keys': list(self.visiting_objects.keys())
        if getattr(self, 'visiting_objects')
        else None,
        'last_update_time': datetime.datetime.now(),
    }
    conn, cur = get_connection_and_cursor()
    column_names = sql_dict.keys()
    sql_statement = '''UPDATE %s SET %s WHERE id = %%(id)s RETURNING id''' % (
        self._table_name,
        ', '.join(['%s = %%(%s)s' % (x, x) for x in column_names]),
    )
    cur.execute(sql_statement, sql_dict)
    if cur.fetchone() is None:
        # no existing row, insert a new one
        sql_statement = '''INSERT INTO %s (%s) VALUES (%s)''' % (
            self._table_name,
            ', '.join(column_names),
            ', '.join(['%%(%s)s' % x for x in column_names]),
        )
        cur.execute(sql_statement, sql_dict)
    conn.commit()
    cur.close()
@classmethod
def _row2ob(cls, row, **kwargs):
    """Rebuild a session object from its pickled session_data column."""
    o = cls.__new__(cls)
    o.id = str_encode(row[0])
    session_data = pickle_loads(row[1])
    for k, v in session_data.items():
        setattr(o, k, v)
    if o.magictokens:
        # migration, obsolete storage of magictokens in session
        # (add_magictoken replaces each value with a stored TransientData;
        # keys are unchanged so iterating while assigning is safe)
        for k, v in o.magictokens.items():
            o.add_magictoken(k, v)
        o.magictokens = None
        o.store()
    return o
@classmethod
def get_sessions_for_saml(cls, name_identifier=Ellipsis, *args, **kwargs):
    """Return sessions bound to the given SAML name identifier.

    NOTE(review): the Ellipsis default looks like a "not provided"
    sentinel but is passed to SQL as-is when omitted — presumably
    callers always provide a value; confirm.
    """
    conn, cur = get_connection_and_cursor()
    sql_statement = '''SELECT %s
                         FROM %s
                        WHERE name_identifier = %%(value)s''' % (
        ', '.join([x[0] for x in cls._table_static_fields]),
        cls._table_name,
    )
    cur.execute(sql_statement, {'value': name_identifier})
    objects = cls.get_objects(cur)
    conn.commit()
    cur.close()
    return objects

@classmethod
def get_sessions_with_visited_object(cls, object_key):
    """Return sessions that visited the object in the last 30 minutes."""
    conn, cur = get_connection_and_cursor()
    sql_statement = '''SELECT %s
                         FROM %s
                        WHERE %%(value)s = ANY(visiting_objects_keys)
                          AND last_update_time > (now() - interval '30 minutes')
                    ''' % (
        ', '.join([x[0] for x in cls._table_static_fields] + cls.get_data_fields()),
        cls._table_name,
    )
    cur.execute(sql_statement, {'value': object_key})
    objects = cls.get_objects(cur)
    conn.commit()
    cur.close()
    return objects

@classmethod
def get_data_fields(cls):
    # sessions have no dynamic data columns
    return []
def add_magictoken(self, token, data):
    """Attach data to a magictoken, persisting it in transient_data."""
    assert self.id
    super().add_magictoken(token, data)
    # replace the plain value with a stored TransientData row
    self.magictokens[token] = TransientData(id=token, session_id=self.id, data=data)
    self.magictokens[token].store()

def get_by_magictoken(self, token, default=None):
    """Return the data attached to a magictoken, or ``default``."""
    if not self.magictokens:
        self.magictokens = {}
    try:
        if token not in self.magictokens:
            # lazy-load from the transient_data table
            self.magictokens[token] = TransientData.select(
                [Equal('session_id', self.id), Equal('id', token)]
            )[0]
        return self.magictokens[token].data
    except IndexError:
        # no row stored for this token
        return default

def remove_magictoken(self, token):
    """Remove a magictoken from the session and from storage."""
    super().remove_magictoken(token)
    TransientData.remove_object(token)
class TrackingCode(SqlMixin, wcs.tracking_code.TrackingCode):
    """Tracking code mapping a short code to a formdef/formdata pair."""

    _table_name = 'tracking_codes'
    _table_static_fields = [
        ('id', 'varchar'),
        ('formdef_id', 'varchar'),
        ('formdata_id', 'varchar'),
    ]
    _numerical_id = False
    id = None

    @classmethod
    def get(cls, id, **kwargs):
        # codes are stored upper-case; accept lower-case user input
        return super().get(id.upper(), **kwargs)
@guard_postgres
@invalidate_substitution_cache
def store(self):
    """Insert or update the tracking code row.

    New codes are generated and retried on collision; codes with a
    preset id are upserted via INSERT ... ON CONFLICT.
    """
    sql_dict = {'id': self.id, 'formdef_id': self.formdef_id, 'formdata_id': self.formdata_id}
    conn, cur = get_connection_and_cursor()
    if not self.id:
        column_names = sql_dict.keys()
        sql_dict['id'] = self.get_new_id()
        sql_statement = '''INSERT INTO %s (%s)
                           VALUES (%s)
                           RETURNING id''' % (
            self._table_name,
            ', '.join(column_names),
            ', '.join(['%%(%s)s' % x for x in column_names]),
        )
        while True:
            try:
                cur.execute(sql_statement, sql_dict)
            except psycopg2.IntegrityError:
                # generated code already exists, try another one
                conn.rollback()
                sql_dict['id'] = self.get_new_id()
            else:
                break
        self.id = str_encode(cur.fetchone()[0])
    else:
        column_names = sql_dict.keys()
        sql_dict['id'] = self.id
        sql_statement = '''INSERT INTO %s (%s)
                           VALUES (%s)
                           ON CONFLICT ON CONSTRAINT tracking_codes_pkey
                           DO UPDATE
                           SET %s
                           RETURNING id''' % (
            self._table_name,
            ', '.join(column_names),
            ', '.join(['%%(%s)s' % x for x in column_names]),
            ', '.join(['%s = %%(%s)s' % (x, x) for x in column_names]),
        )
        cur.execute(sql_statement, sql_dict)
        if cur.fetchone() is None:
            # the upsert always returns an id, anything else is a bug
            raise AssertionError()
    conn.commit()
    cur.close()
@classmethod
def _row2ob(cls, row, **kwargs):
    """Build a TrackingCode from an (id, formdef_id, formdata_id) row."""
    o = cls()
    o.id = str_encode(row[0])
    o.formdef_id = str_encode(row[1])
    o.formdata_id = str_encode(row[2])
    return o

@classmethod
def get_data_fields(cls):
    # tracking codes have no dynamic data columns
    return []
class CustomView(SqlMixin, wcs.custom_views.CustomView):
    """User- or role-visible custom listing view stored in SQL."""

    _table_name = 'custom_views'
    _table_static_fields = [
        ('id', 'varchar'),
        ('title', 'varchar'),
        ('slug', 'varchar'),
        ('user_id', 'varchar'),
        ('visibility', 'varchar'),
        ('formdef_type', 'varchar'),
        ('formdef_id', 'varchar'),
        ('is_default', 'boolean'),
        ('order_by', 'varchar'),
        ('columns', 'jsonb'),
        ('filters', 'jsonb'),
    ]
@guard_postgres
@invalidate_substitution_cache
def store(self):
self.ensure_slug()
sql_dict = {
'id': self.id,
'title': self.title,
'slug': self.slug,
'user_id': self.user_id,
'visibility': self.visibility,
'formdef_type': self.formdef_type,
'formdef_id': self.formdef_id,
'is_default': self.is_default,
'order_by': self.order_by,
'columns': self.columns,
'filters': self.filters,
}
conn, cur = get_connection_and_cursor()
if not self.id:
column_names = sql_dict.keys()
sql_dict['id'] = self.get_new_id()
sql_statement = '''INSERT INTO %s (%s)
VALUES (%s)
RETURNING id''' % (
self._table_name,
', '.join(column_names),
', '.join(['%%(%s)s' % x for x in column_names]),
)
while True:
try:
cur.execute(sql_statement, sql_dict)
except psycopg2.IntegrityError:
conn.rollback()
sql_dict['id'] = self.get_new_id()
else:
break
self.id = str_encode(cur.fetchone()[0])
else:
column_names = sql_dict.keys()
sql_dict['id'] = self.id
sql_statement = '''UPDATE %s SET %s WHERE id = %%(id)s RETURNING id''' % (
self._table_name,
', '.join(['%s = %%(%s)s' % (x, x) for x in column_names]),
)
cur.execute(sql_statement, sql_dict)
if cur.fetchone() is None:
raise AssertionError()
conn.commit()
cur.close()
@classmethod
def _row2ob(cls, row, **kwargs):
o = cls()
for field, value in zip(cls._table_static_fields, tuple(row)):
if field[1] == 'varchar':
setattr(o, field[0], str_encode(value))
elif field[1] in ('jsonb', 'boolean'):
setattr(o, field[0], value)
return o
@classmethod
def get_data_fields(cls):
return []
class Snapshot(SqlMixin, wcs.snapshots.Snapshot):
    """SQL storage for object snapshots (version history).

    A snapshot references an object by (object_type, object_id) and carries
    either a full serialization or a patch (see the serialization/patch
    columns).
    """

    _table_name = 'snapshots'
    _table_static_fields = [
        ('id', 'serial'),
        ('object_type', 'varchar'),
        ('object_id', 'varchar'),
        ('timestamp', 'timestamptz'),
        ('user_id', 'varchar'),
        ('comment', 'text'),
        ('serialization', 'text'),
        ('patch', 'text'),
        ('label', 'varchar'),
    ]
    # potentially large columns, excluded from default SELECTs
    _table_select_skipped_fields = ['serialization', 'patch']

    @guard_postgres
    @invalidate_substitution_cache
    def store(self):
        """Insert (no id yet) or update (existing id) the snapshot row."""
        sql_dict = {x: getattr(self, x) for x, y in self._table_static_fields}
        conn, cur = get_connection_and_cursor()
        if not self.id:
            # id is a serial column; let the database allocate it
            column_names = [x for x in sql_dict.keys() if x != 'id']
            sql_statement = '''INSERT INTO %s (%s)
                                   VALUES (%s)
                                RETURNING id''' % (
                self._table_name,
                ', '.join(column_names),
                ', '.join(['%%(%s)s' % x for x in column_names]),
            )
            cur.execute(sql_statement, sql_dict)
            self.id = cur.fetchone()[0]
        else:
            column_names = sql_dict.keys()
            sql_statement = '''UPDATE %s SET %s WHERE id = %%(id)s RETURNING id''' % (
                self._table_name,
                ', '.join(['%s = %%(%s)s' % (x, x) for x in column_names]),
            )
            cur.execute(sql_statement, sql_dict)
            if cur.fetchone() is None:
                # an id was set but no row matched the UPDATE
                raise AssertionError()
        conn.commit()
        cur.close()

    @classmethod
    def select_object_history(cls, obj, clause=None):
        """Return the snapshots of obj, most recent first."""
        return cls.select(
            [Equal('object_type', obj.xml_root_node), Equal('object_id', obj.id)] + (clause or []),
            order_by='-timestamp',
        )

    def is_from_object(self, obj):
        # True when this snapshot was taken from the given object
        return self.object_type == obj.xml_root_node and self.object_id == obj.id

    @classmethod
    def _row2ob(cls, row, **kwargs):
        """Build a Snapshot from a database row."""
        o = cls()
        for field, value in zip(cls._table_static_fields, tuple(row)):
            if field[1] in ('serial', 'timestamptz'):
                setattr(o, field[0], value)
            elif field[1] in ('varchar', 'text'):
                setattr(o, field[0], str_encode(value))
        return o

    @classmethod
    def get_data_fields(cls):
        # no dynamic data columns beyond the static fields
        return []

    @classmethod
    def get_latest(cls, object_type, object_id, complete=False, max_timestamp=None):
        """Return the most recent snapshot of an object, or None.

        With complete=True only snapshots carrying a full serialization are
        considered; max_timestamp limits the search to snapshots at or
        before that moment.
        """
        conn, cur = get_connection_and_cursor()
        sql_statement = '''SELECT id FROM snapshots
                            WHERE object_type = %%(object_type)s
                              AND object_id = %%(object_id)s
                              %s
                              %s
                            ORDER BY timestamp DESC
                            LIMIT 1''' % (
            'AND serialization IS NOT NULL' if complete else '',
            'AND timestamp <= %(max_timestamp)s' if max_timestamp else '',
        )
        # psycopg2 only uses referenced placeholders, so passing
        # max_timestamp when the clause is absent is harmless
        cur.execute(
            sql_statement,
            {'object_type': object_type, 'object_id': object_id, 'max_timestamp': max_timestamp},
        )
        row = cur.fetchone()
        conn.commit()
        cur.close()
        if row is None:
            return None
        return cls.get(row[0])

    @classmethod
    def _get_recent_changes(cls, object_types, user=None, limit=5, offset=0):
        """Return (object_type, object_id, latest timestamp) rows, newest first."""
        conn, cur = get_connection_and_cursor()
        clause = [Contains('object_type', object_types)]
        if user is not None:
            clause.append(Equal('user_id', str(user.id)))
        where_clauses, parameters, dummy = parse_clause(clause)
        sql_statement = 'SELECT object_type, object_id, MAX(timestamp) AS m FROM snapshots'
        sql_statement += ' WHERE ' + ' AND '.join(where_clauses)
        sql_statement += ' GROUP BY object_type, object_id ORDER BY m DESC'
        if limit:
            sql_statement += ' LIMIT %(limit)s'
            parameters['limit'] = limit
        if offset:
            sql_statement += ' OFFSET %(offset)s'
            parameters['offset'] = offset
        cur.execute(sql_statement, parameters)
        result = cur.fetchall()
        conn.commit()
        cur.close()
        return result

    @classmethod
    def count_recent_changes(cls, object_types):
        """Count distinct (object_type, object_id) pairs having snapshots."""
        conn, cur = get_connection_and_cursor()
        clause = [Contains('object_type', object_types)]
        where_clauses, parameters, dummy = parse_clause(clause)
        sql_statement = 'SELECT COUNT(*) FROM (SELECT object_type, object_id FROM snapshots'
        sql_statement += ' WHERE ' + ' AND '.join(where_clauses)
        sql_statement += ' GROUP BY object_type, object_id) AS s'
        cur.execute(sql_statement, parameters)
        count = cur.fetchone()[0]
        conn.commit()
        cur.close()
        return count
class LoggedError(SqlMixin, wcs.logged_errors.LoggedError):
    """SQL storage for logged errors, deduplicated on tech_id."""

    _table_name = 'loggederrors'
    _table_static_fields = [
        ('id', 'serial'),
        ('kind', 'varchar'),
        ('tech_id', 'varchar'),
        ('summary', 'varchar'),
        ('formdef_class', 'varchar'),
        ('formdata_id', 'varchar'),
        ('formdef_id', 'varchar'),
        ('workflow_id', 'varchar'),
        ('status_id', 'varchar'),
        ('status_item_id', 'varchar'),
        ('expression', 'varchar'),
        ('expression_type', 'varchar'),
        ('traceback', 'text'),
        ('exception_class', 'varchar'),
        ('exception_message', 'varchar'),
        ('occurences_count', 'integer'),
        ('first_occurence_timestamp', 'timestamptz'),
        ('latest_occurence_timestamp', 'timestamptz'),
    ]

    @guard_postgres
    @invalidate_substitution_cache
    def store(self):
        """Store the error and return the stored object.

        Errors are deduplicated on tech_id: when an error with the same
        tech_id already exists its occurrence counter is updated instead of
        inserting a new row, and that existing object is returned.
        """
        sql_dict = {x: getattr(self, x) for x, y in self._table_static_fields}
        conn, cur = get_connection_and_cursor()
        error = self
        if not self.id:
            existing_errors = list(self.get_with_indexed_value('tech_id', self.tech_id))
            if not existing_errors:
                # id is a serial column; let the database allocate it
                column_names = [x for x in sql_dict.keys() if x != 'id']
                sql_statement = '''INSERT INTO %s (%s)
                                       VALUES (%s)
                                    RETURNING id''' % (
                    self._table_name,
                    ', '.join(column_names),
                    ', '.join(['%%(%s)s' % x for x in column_names]),
                )
                try:
                    cur.execute(sql_statement, sql_dict)
                    self.id = cur.fetchone()[0]
                except psycopg2.IntegrityError:
                    # tech_id already used ? (lost a race with a concurrent
                    # insert of the same error; reload it below)
                    conn.rollback()
                    existing_errors = list(self.get_with_indexed_value('tech_id', self.tech_id))
            if existing_errors:
                # same error already recorded: bump its occurrence data
                error = existing_errors[0]
                error.record_new_occurence(self)
        else:
            column_names = sql_dict.keys()
            sql_statement = '''UPDATE %s SET %s WHERE id = %%(id)s RETURNING id''' % (
                self._table_name,
                ', '.join(['%s = %%(%s)s' % (x, x) for x in column_names]),
            )
            cur.execute(sql_statement, sql_dict)
            assert cur.fetchone() is not None, 'LoggedError id not found'
        conn.commit()
        cur.close()
        return error

    @classmethod
    def _row2ob(cls, row, **kwargs):
        """Build a LoggedError from a database row."""
        o = cls()
        for field, value in zip(cls._table_static_fields, tuple(row)):
            if field[1] in ('varchar', 'text'):
                setattr(o, field[0], str_encode(value))
            else:
                setattr(o, field[0], value)
        return o

    @classmethod
    def get_data_fields(cls):
        # no dynamic data columns beyond the static fields
        return []

    @classmethod
    @guard_postgres
    def fix_sequences(cls):
        """Realign the id sequence with the current maximum id.

        Useful after rows were inserted with explicit ids (e.g. imports).
        """
        conn, cur = get_connection_and_cursor()
        sql_statement = '''select max(id) from %s''' % cls._table_name
        cur.execute(sql_statement)
        max_id = cur.fetchone()[0]
        if max_id is not None:
            sql_statement = '''ALTER SEQUENCE %s_id_seq RESTART %s''' % (cls._table_name, max_id + 1)
            cur.execute(sql_statement)
        conn.commit()
        cur.close()
class Token(SqlMixin, wcs.qommon.tokens.Token):
    """SQL storage for tokens (random id, type, expiration, JSON context)."""

    _table_name = 'tokens'
    _table_static_fields = [
        ('id', 'varchar'),
        ('type', 'varchar'),
        ('expiration', 'timestamptz'),
        ('context', 'jsonb'),
    ]
    # ids are random strings, not values from a numeric sequence
    _numerical_id = False

    @guard_postgres
    def store(self):
        """Insert the token (drawing new random ids on collision) or update it."""
        sql_dict = {
            'id': self.id,
            'type': self.type,
            'expiration': self.expiration,
            'context': self.context,
        }
        conn, cur = get_connection_and_cursor()
        column_names = sql_dict.keys()
        if not self.id:
            sql_dict['id'] = self.get_new_id()
            sql_statement = '''INSERT INTO %s (%s)
                                   VALUES (%s)
                                RETURNING id''' % (
                self._table_name,
                ', '.join(column_names),
                ', '.join(['%%(%s)s' % x for x in column_names]),
            )
            # an IntegrityError means the random id is already taken; retry
            # with a fresh one until the insert goes through
            while True:
                try:
                    cur.execute(sql_statement, sql_dict)
                except psycopg2.IntegrityError:
                    conn.rollback()
                    sql_dict['id'] = self.get_new_id()
                else:
                    break
            self.id = cur.fetchone()[0]
        else:
            sql_statement = '''UPDATE %s SET %s WHERE id = %%(id)s RETURNING id''' % (
                self._table_name,
                ', '.join(['%s = %%(%s)s' % (x, x) for x in column_names]),
            )
            cur.execute(sql_statement, sql_dict)
            if cur.fetchone() is None:
                # no row with this id anymore (e.g. it was purged); insert it
                # back with the same id
                sql_statement = '''INSERT INTO %s (%s) VALUES (%s)''' % (
                    self._table_name,
                    ', '.join(column_names),
                    ', '.join(['%%(%s)s' % x for x in column_names]),
                )
                cur.execute(sql_statement, sql_dict)
        conn.commit()
        cur.close()

    @classmethod
    def get_data_fields(cls):
        # no dynamic data columns beyond the static fields
        return []

    @classmethod
    def _row2ob(cls, row, **kwargs):
        """Build a Token from a database row, running the expiration check."""
        o = cls()
        for field, value in zip(cls._table_static_fields, tuple(row)):
            setattr(o, field[0], value)
        # run the expiration check on load (see wcs.qommon.tokens.Token)
        o.expiration_check()
        return o
class classproperty:
    """Descriptor exposing a computed, read-only attribute on a class.

    Unlike the builtin property, the wrapped function receives the owner
    class (never an instance), so the value is available on plain attribute
    access without instantiation.
    """

    def __init__(self, getter):
        # the wrapped function, called with the owner class on access
        self.f = getter

    def __get__(self, obj, owner):
        # the instance (obj) is deliberately ignored; the getter always
        # operates on the owner class
        return self.f(owner)
class AnyFormData(SqlMixin):
    """Read-only access to rows of the cross-formdef wcs_all_forms table.

    Rows are turned back into formdata objects of their original formdef
    class (see _row2ob), so query results may mix instances of different
    data classes.
    """

    _table_name = 'wcs_all_forms'
    _formdef_cache = {}
    _has_id = False

    @classproperty
    def _table_static_fields(self):
        # self is the owner class here (classproperty).  The attribute
        # accesses below are name-mangled to _AnyFormData__table_static_fields;
        # the previous code checked hasattr(self, '__table_static_fields')
        # with the unmangled literal, which never matched, so the column list
        # was rebuilt on every access.  Check the mangled name so the cache
        # is actually used.
        if not hasattr(self, '_AnyFormData__table_static_fields'):
            from wcs.formdef import FormDef

            # columns common to all formdefs, computed from a blank formdef
            fake_formdef = FormDef()
            common_fields = get_view_fields(fake_formdef)
            self.__table_static_fields = [(x[1], x[0]) for x in common_fields]
            self.__table_static_fields.append(('criticality_level', 'criticality_level'))
            self.__table_static_fields.append(('geoloc_base_x', 'geoloc_base_x'))
            self.__table_static_fields.append(('geoloc_base_y', 'geoloc_base_y'))
            self.__table_static_fields.append(('concerned_roles_array', 'concerned_roles_array'))
            self.__table_static_fields.append(('anonymised', 'anonymised'))
        return self.__table_static_fields

    @classmethod
    def get_data_fields(cls):
        # the aggregate table has no dynamic data columns
        return []

    @classmethod
    def get_objects(cls, *args, **kwargs):
        # start each query with a fresh formdef cache
        cls._formdef_cache = {}
        return super().get_objects(*args, **kwargs)

    @classmethod
    def _row2ob(cls, row, **kwargs):
        """Build a formdata object of the row's original formdef class."""
        formdef_id = row[1]
        from wcs.formdef import FormDef

        formdef = cls._formdef_cache.setdefault(formdef_id, FormDef.get(formdef_id))
        o = formdef.data_class()()
        for static_field, value in zip(cls._table_static_fields, tuple(row[: len(cls._table_static_fields)])):
            setattr(o, static_field[0], str_encode(value))
        # [CRITICALITY_2] transform criticality_level back to the expected
        # range (see [CRITICALITY_1])
        levels = len(formdef.workflow.criticality_levels or [0])
        o.criticality_level = levels + o.criticality_level
        # convert back unstructured geolocation to the 'native' formdata format.
        if o.geoloc_base_x is not None:
            o.geolocations = {'base': {'lon': o.geoloc_base_x, 'lat': o.geoloc_base_y}}
        return o

    @classmethod
    @guard_postgres
    def load_all_evolutions(cls, formdatas):
        """Batch-load evolutions, grouping formdatas by their storage table."""
        classes = {}
        for formdata in formdatas:
            if formdata._table_name not in classes:
                classes[formdata._table_name] = []
            classes[formdata._table_name].append(formdata)
        for formdatas in classes.values():
            formdatas[0].load_all_evolutions(formdatas)
def get_period_query(
    period_start=None, include_start=True, period_end=None, include_end=True, criterias=None, parameters=None
):
    """Build the 'FROM ... WHERE ...' part of a statistics query.

    :param parameters: dict filled in place with the query parameters;
        despite the None default it must be provided by the caller (a None
        value would fail on the update() call below).
    :param criterias: optional criteria objects; two pseudo-criterias are
        recognized and consumed here: Equal('formdef_klass', ...) selects
        the formdef class to use (it must appear before formdef_id to take
        effect) and Equal('formdef_id', ...) switches the query from the
        global wcs_all_forms table to that formdef's own table.
    """
    clause = [NotNull('receipt_time')]
    table_name = 'wcs_all_forms'
    if criterias:
        from wcs.formdef import FormDef

        formdef_class = FormDef
        for criteria in criterias:
            if criteria.__class__.__name__ == 'Equal' and criteria.attribute == 'formdef_klass':
                # pseudo-criteria carrying the formdef class (FormDef/CardDef)
                formdef_class = criteria.value
                continue
            if criteria.__class__.__name__ == 'Equal' and criteria.attribute == 'formdef_id':
                # if there's a formdef_id specified, switch to using the
                # specific table so we have access to all fields
                table_name = get_formdef_table_name(formdef_class.get(criteria.value))
                continue
            clause.append(criteria)
    if period_start:
        if include_start:
            clause.append(GreaterOrEqual('receipt_time', period_start))
        else:
            clause.append(Greater('receipt_time', period_start))
    if period_end:
        if include_end:
            clause.append(LessOrEqual('receipt_time', period_end))
        else:
            clause.append(Less('receipt_time', period_end))
    where_clauses, params, dummy = parse_clause(clause)
    parameters.update(params)
    statement = ' FROM %s ' % table_name
    statement += ' WHERE ' + ' AND '.join(where_clauses)
    return statement
def get_time_aggregate_query(time_interval, query, group_by, function='DATE_TRUNC'):
    """Wrap a 'FROM ... WHERE ...' fragment into a time-aggregated COUNT query.

    time_interval and function are trusted identifiers interpolated directly
    into the SQL text; query is the fragment built by get_period_query().
    An optional group_by column is added to both the selected columns and the
    GROUP BY/ORDER BY lists.
    """
    selected = f"{function}('{time_interval}', receipt_time) AS {time_interval}"
    aggregate_fields = time_interval if not group_by else f'{time_interval}, {group_by}'
    statement = f'SELECT {selected}, '
    if group_by:
        statement += f'{group_by}, '
    statement += 'COUNT(*) '
    statement += query
    statement += f' GROUP BY {aggregate_fields} ORDER BY {aggregate_fields}'
    return statement
@guard_postgres
def get_actionable_counts(user_roles):
    """Return {formdef_id: count} of open formdatas actionable by the roles."""
    conn, cur = get_connection_and_cursor()
    criterias = [
        Equal('is_at_endpoint', False),
        Intersects('actions_roles_array', user_roles),
        Null('anonymised'),
    ]
    where_clauses, parameters, dummy = parse_clause(criterias)
    statement = '''SELECT formdef_id, COUNT(*)
                     FROM wcs_all_forms
                    WHERE %s
                 GROUP BY formdef_id''' % ' AND '.join(where_clauses)
    cur.execute(statement, parameters)
    counts = {}
    for formdef_id, total in cur.fetchall():
        counts[str(formdef_id)] = total
    conn.commit()
    cur.close()
    return counts
@guard_postgres
def get_total_counts(user_roles):
    """Return {formdef_id: count} of all formdatas concerning the roles."""
    conn, cur = get_connection_and_cursor()
    criterias = [
        Intersects('concerned_roles_array', user_roles),
        Null('anonymised'),
    ]
    where_clauses, parameters, dummy = parse_clause(criterias)
    statement = '''SELECT formdef_id, COUNT(*)
                     FROM wcs_all_forms
                    WHERE %s
                 GROUP BY formdef_id''' % ' AND '.join(where_clauses)
    cur.execute(statement, parameters)
    counts = {}
    for formdef_id, total in cur.fetchall():
        counts[str(formdef_id)] = total
    conn.commit()
    cur.close()
    return counts
@guard_postgres
def get_weekday_totals(period_start=None, period_end=None, criterias=None, group_by=None):
    """Return per-weekday counts as (weekday label, count, ...) rows.

    Rows are ordered Monday first (Sunday is moved last); weekdays with no
    match are filled in with a zero count.
    """
    conn, cur = get_connection_and_cursor()
    parameters = {}
    statement = get_period_query(
        period_start=period_start, period_end=period_end, criterias=criterias, parameters=parameters
    )
    # DATE_PART('dow', ...) yields 0 (Sunday) through 6 (Saturday)
    statement = get_time_aggregate_query('dow', statement, group_by, function='DATE_PART')
    cur.execute(statement, parameters)
    result = cur.fetchall()
    result = [(int(x[0]), *x[1:]) for x in result]
    coverage = [x[0] for x in result]
    for weekday in range(7):
        if weekday not in coverage:
            # NOTE(review): filler rows are always 2-tuples while group_by
            # adds extra columns to the real rows — confirm callers cope
            result.append((weekday, 0))
    result.sort(key=lambda x: x[0])
    # add labels,
    weekday_names = [
        _('Sunday'),
        _('Monday'),
        _('Tuesday'),
        _('Wednesday'),
        _('Thursday'),
        _('Friday'),
        _('Saturday'),
    ]
    result = [(weekday_names[x[0]], *x[1:]) for x in result]
    # and move Sunday last
    result = result[1:] + [result[0]]
    conn.commit()
    cur.close()
    return result
@guard_postgres
def get_formdef_totals(period_start=None, period_end=None, criterias=None):
    """Return (formdef_id, count) pairs over the given period."""
    conn, cur = get_connection_and_cursor()
    parameters = {}
    statement = '''SELECT formdef_id, COUNT(*)''' + get_period_query(
        period_start=period_start, period_end=period_end, criterias=criterias, parameters=parameters
    )
    statement += ' GROUP BY formdef_id'
    cur.execute(statement, parameters)
    rows = cur.fetchall()
    conn.commit()
    cur.close()
    return [(int(formdef_id), total) for formdef_id, total in rows]
@guard_postgres
def get_hour_totals(period_start=None, period_end=None, criterias=None, group_by=None):
    """Return per-hour counts as (hour, count, ...) rows, 0-filled for 0-23."""
    conn, cur = get_connection_and_cursor()
    parameters = {}
    statement = get_period_query(
        period_start=period_start, period_end=period_end, criterias=criterias, parameters=parameters
    )
    # DATE_PART('hour', ...) yields 0 through 23
    statement = get_time_aggregate_query('hour', statement, group_by, function='DATE_PART')
    cur.execute(statement, parameters)
    result = cur.fetchall()
    result = [(int(x[0]), *x[1:]) for x in result]
    coverage = [x[0] for x in result]
    for hour in range(24):
        if hour not in coverage:
            # NOTE(review): filler rows are always 2-tuples while group_by
            # adds extra columns to the real rows — confirm callers cope
            result.append((hour, 0))
    result.sort(key=lambda x: x[0])
    conn.commit()
    cur.close()
    return result
@guard_postgres
def get_monthly_totals(
    period_start=None,
    period_end=None,
    criterias=None,
    group_by=None,
):
    """Return per-month counts as ('YYYY-MM', count, ...) rows.

    Months between the first and last seen months with no match are filled
    in with a zero count.
    """
    conn, cur = get_connection_and_cursor()
    parameters = {}
    statement = get_period_query(
        period_start=period_start, period_end=period_end, criterias=criterias, parameters=parameters
    )
    statement = get_time_aggregate_query('month', statement, group_by)
    cur.execute(statement, parameters)
    raw_result = cur.fetchall()
    # DATE_TRUNC('month', ...) returns a datetime on the first of the month
    result = [('%d-%02d' % x[0].timetuple()[:2], *x[1:]) for x in raw_result]
    if result:
        coverage = [x[0] for x in result]
        current_month = raw_result[0][0]
        last_month = raw_result[-1][0]
        while current_month < last_month:
            label = '%d-%02d' % current_month.timetuple()[:2]
            if label not in coverage:
                result.append((label, 0))
            # advance to the first day of the next month: +31 days always
            # lands in the following month, then snap back to day 1
            current_month = current_month + datetime.timedelta(days=31)
            current_month = current_month - datetime.timedelta(days=current_month.day - 1)
        result.sort(key=lambda x: x[0])
    conn.commit()
    cur.close()
    return result
@guard_postgres
def get_yearly_totals(period_start=None, period_end=None, criterias=None, group_by=None):
    """Return per-year counts as ('YYYY', count, ...) rows.

    Years between the first and last seen years with no match are filled in
    with a zero count.
    """
    conn, cur = get_connection_and_cursor()
    parameters = {}
    statement = get_period_query(
        period_start=period_start, period_end=period_end, criterias=criterias, parameters=parameters
    )
    statement = get_time_aggregate_query('year', statement, group_by)
    cur.execute(statement, parameters)
    raw_result = cur.fetchall()
    # DATE_TRUNC('year', ...) returns a datetime on January 1st
    result = [(str(x[0].year), *x[1:]) for x in raw_result]
    if result:
        coverage = [x[0] for x in result]
        current_year = raw_result[0][0]
        last_year = raw_result[-1][0]
        while current_year < last_year:
            label = str(current_year.year)
            if label not in coverage:
                result.append((label, 0))
            # jump 366 days ahead, enough to land in the next year
            current_year = current_year + datetime.timedelta(days=366)
        result.sort(key=lambda x: x[0])
    conn.commit()
    cur.close()
    return result
@guard_postgres
def get_period_total(
    period_start=None, include_start=True, period_end=None, include_end=True, criterias=None
):
    """Return the total number of formdatas matching the period and criterias."""
    conn, cur = get_connection_and_cursor()
    parameters = {}
    statement = '''SELECT COUNT(*)''' + get_period_query(
        period_start=period_start,
        include_start=include_start,
        period_end=period_end,
        include_end=include_end,
        criterias=criterias,
        parameters=parameters,
    )
    cur.execute(statement, parameters)
    total = int(cur.fetchone()[0])
    conn.commit()
    cur.close()
    return total
# latest migration, number + description (the description is not used
# programmatically but guarantees a git conflict if two migrations are
# separately added with the same number)
SQL_LEVEL = (64, 'add transient data table')
def migrate_global_views(conn, cur):
    """Rebuild the global wcs_all_forms views if their schema is outdated.

    The views are recreated when they lack the formdef_id column that was
    introduced later on.
    """
    # list the columns of the existing wcs_all_forms relation; the previous
    # query (SELECT COUNT(*) FROM information_schema.tables) returned a row
    # count, not column names, so the formdef_id check below always failed
    # and the views were rebuilt unconditionally.
    cur.execute(
        '''SELECT column_name FROM information_schema.columns
            WHERE table_schema = 'public'
              AND table_name = %s''',
        ('wcs_all_forms',),
    )
    existing_fields = {x[0] for x in cur.fetchall()}
    if 'formdef_id' not in existing_fields:
        # pre-formdef_id schema: drop and recreate with the current definition
        drop_global_views(conn, cur)
        do_global_views(conn, cur)
@guard_postgres
def get_sql_level(conn, cur):
    """Return the schema migration level recorded in the wcs_meta table."""
    # make sure the meta table exists before querying it
    do_meta_table(conn, cur, insert_current_sql_level=False)
    cur.execute('''SELECT value FROM wcs_meta WHERE key = %s''', ('sql_level',))
    return int(cur.fetchone()[0])
@guard_postgres
def is_reindex_needed(index, conn, cur):
    """Tell whether the given index ('user', 'formdata', ...) must be rebuilt.

    As a side effect, a missing marker row is created with value 'no'.
    """
    do_meta_table(conn, cur, insert_current_sql_level=False)
    key_name = 'reindex_%s' % index
    cur.execute('''SELECT value FROM wcs_meta WHERE key = %s''', (key_name,))
    row = cur.fetchone()
    if row is not None:
        return row[0] == 'needed'
    # first time this index is seen: record it as not needing a reindex
    cur.execute(
        '''INSERT INTO wcs_meta (id, key, value)
                VALUES (DEFAULT, %s, %s)''',
        (key_name, 'no'),
    )
    return False
@guard_postgres
def set_reindex(index, value, conn=None, cur=None):
    """Set the reindex marker for the given index, creating the row if needed.

    When no connection is passed one is opened, committed and closed here.
    """
    own_conn = not conn
    if own_conn:
        conn, cur = get_connection_and_cursor()
    do_meta_table(conn, cur, insert_current_sql_level=False)
    key_name = 'reindex_%s' % index
    cur.execute('''SELECT value FROM wcs_meta WHERE key = %s''', (key_name,))
    row = cur.fetchone()
    if row is None:
        cur.execute(
            '''INSERT INTO wcs_meta (id, key, value)
                    VALUES (DEFAULT, %s, %s)''',
            (key_name, value),
        )
    elif row[0] != value:
        # only touch the row (and updated_at) when the value actually changes
        cur.execute(
            '''UPDATE wcs_meta SET value = %s, updated_at = NOW() WHERE key = %s''', (value, key_name)
        )
    if own_conn:
        conn.commit()
        cur.close()
def migrate_views(conn, cur):
    """Drop then recreate every formdef/carddef view and the global views."""
    drop_views(None, conn, cur)
    from wcs.carddef import CardDef
    from wcs.formdef import FormDef

    # make sure all formdefs have up-to-date views
    for formdef in FormDef.select() + CardDef.select():
        do_formdef_tables(formdef, conn=conn, cur=cur, rebuild_views=True, rebuild_global_views=False)
    migrate_global_views(conn, cur)
@guard_postgres
def migrate():
    """Apply pending database migrations, up to SQL_LEVEL.

    The stored sql_level is compared against the historical level numbers
    listed in the comments; each guarded block re-runs the (idempotent)
    creation/migration helpers covering the levels it mentions.  The order
    of the blocks matters (tables are created before views are rebuilt).
    """
    conn, cur = get_connection_and_cursor()
    sql_level = get_sql_level(conn, cur)
    if sql_level < 0:
        # fake code to help in testing the error code path.
        raise RuntimeError()
    if sql_level < 1:  # 1: introduction of tracking_code table
        do_tracking_code_table()
    if sql_level < 63:
        # 42: create snapshots table
        # 54: add patch column
        # 63: add index
        do_snapshots_table()
    if sql_level < 50:
        # 49: store Role in SQL
        # 50: switch role uuid column to varchar
        do_role_table()
        migrate_legacy_roles()
    if sql_level < 53:
        # 47: store LoggedErrors in SQL
        # 48: remove acked attribute from LoggedError
        # 53: add kind column to logged_errors table
        do_loggederrors_table()
    if sql_level < 40:
        # 3: introduction of _structured for user fields
        # 4: removal of identification_token
        # 12: (first part) add fts to users
        # 16: add verified_fields to users
        # 21: (first part) add ascii_name to users
        # 39: add deleted_timestamp
        # 40: add is_active to users
        do_user_table()
    if sql_level < 64:
        # 64: add transient data table
        do_transient_data_table()
    if sql_level < 32:
        # 25: create session_table
        # 32: add last_update_time column to session table
        do_session_table()
    if sql_level < 44:
        # 37: create custom_views table
        # 44: add is_default column to custom_views table
        do_custom_views_table()
    if sql_level < 57:
        # 57: store tokens in SQL
        do_tokens_table()
        migrate_legacy_tokens()
    if sql_level < 52:
        # 2: introduction of formdef_id in views
        # 5: add concerned_roles_array, is_at_endpoint and fts to views
        # 7: add backoffice_submission to tables
        # 8: add submission_context to tables
        # 9: add last_update_time to views
        # 10: add submission_channel to tables
        # 11: add formdef_name and user_name to views
        # 13: add backoffice_submission to views
        # 14: add criticality_level to tables & views
        # 15: add geolocation to formdata
        # 19: add geolocation to views
        # 20: remove user hash stuff
        # 22: rebuild views
        # 26: add digest to formdata
        # 27: add last_jump_datetime in evolutions tables
        # 31: add user_label to formdata
        # 33: add anonymised field to global view
        # 38: extract submission_agent_id to its own column
        # 43: add prefilling_data to formdata
        # 52: store digests on formdata and carddata
        migrate_views(conn, cur)
    if sql_level < 6:
        # 6: add actions_roles_array to tables and views
        from wcs.formdef import FormDef

        migrate_views(conn, cur)
        for formdef in FormDef.select():
            formdef.data_class().rebuild_security()
    if sql_level < 62:
        # 12: (second part), store fts in existing rows
        # 21: (second part), store ascii_name of users
        # 23: (first part), use misc.simplify() over full text queries
        # 61: use setweight on formdata & user indexation
        # 62: use setweight on formdata & user indexation (reapply)
        set_reindex('user', 'needed', conn=conn, cur=cur)
    if sql_level < 62:
        # 17: store last_update_time in tables
        # 18: add user name to full-text search index
        # 21: (third part), add user ascii_names to full-text index
        # 23: (second part) use misc.simplify() over full text queries
        # 28: add display id and formdef name to full-text index
        # 29: add evolution parts to full-text index
        # 31: add user_label to formdata
        # 38: extract submission_agent_id to its own column
        # 41: update full text normalization
        # 51: add index on formdata blockdef fields
        # 55: update full text normalisation (switch to unidecode)
        # 61: use setweight on formdata & user indexation
        # 62: use setweight on formdata & user indexation (reapply)
        set_reindex('formdata', 'needed', conn=conn, cur=cur)
    if sql_level < 56:
        from wcs.carddef import CardDef
        from wcs.formdef import FormDef

        # 24: add index on evolution(formdata_id)
        # 35: add indexes on formdata(receipt_time) and formdata(anonymised)
        # 36: add index on formdata(user_id)
        # 45 & 46: add index on formdata(status)
        # 56: add GIN indexes to concerned_roles_array and actions_roles_array
        for formdef in FormDef.select() + CardDef.select():
            do_formdef_indexes(formdef, created=False, conn=conn, cur=cur)
    if sql_level < 30:
        # 30: actually remove evo.who on anonymised formdatas
        from wcs.formdef import FormDef

        for formdef in FormDef.select():
            for formdata in formdef.data_class().select_iterator(clause=[NotNull('anonymised')]):
                if formdata.evolution:
                    for evo in formdata.evolution:
                        evo.who = None
                    formdata.store()
    if sql_level < 52:
        # 52: store digests on formdata and carddata
        from wcs.carddef import CardDef
        from wcs.formdef import FormDef

        for formdef in FormDef.select() + CardDef.select():
            if not formdef.digest_templates:
                continue
            for formdata in formdef.data_class().select_iterator():
                formdata._set_auto_fields(cur)  # build digests
    if sql_level < 58:
        # 58: add workflow_merged_roles_dict as a jsonb column with
        # combined formdef and formdata value.
        from wcs.carddef import CardDef
        from wcs.formdef import FormDef

        for formdef in FormDef.select() + CardDef.select():
            do_formdef_tables(formdef, rebuild_views=False, rebuild_global_views=False)
        migrate_views(conn, cur)
        set_reindex('formdata', 'needed', conn=conn, cur=cur)
    if sql_level < 60:
        # 59: switch wcs_all_forms to a trigger-maintained table
        # 60: rebuild triggers
        from wcs.formdef import FormDef

        init_global_table(conn, cur)
        for formdef in FormDef.select():
            do_formdef_tables(formdef, rebuild_views=False, rebuild_global_views=False)
    if sql_level != SQL_LEVEL[0]:
        # record that the schema is now at the current level
        cur.execute(
            '''UPDATE wcs_meta SET value = %s, updated_at=NOW() WHERE key = %s''',
            (str(SQL_LEVEL[0]), 'sql_level'),
        )
    conn.commit()
    cur.close()
@guard_postgres
def reindex():
    """Re-store users and formdatas flagged as needing a reindex."""
    conn, cur = get_connection_and_cursor()

    if is_reindex_needed('user', conn=conn, cur=cur):
        # re-storing a user refreshes its indexed columns
        for user in SqlUser.select(iterator=True):
            user.store()
        set_reindex('user', 'done', conn=conn, cur=cur)

    from wcs.carddef import CardDef
    from wcs.formdef import FormDef

    if is_reindex_needed('formdata', conn=conn, cur=cur):
        # load and store all formdatas
        for formdef in FormDef.select() + CardDef.select():
            for formdata in formdef.data_class().select(iterator=True):
                try:
                    formdata.migrate()
                    formdata.store()
                except Exception as e:
                    # keep going; a single broken formdata must not abort
                    # the whole reindex
                    print('error reindexing %s (%r)' % (formdata, e))
        set_reindex('formdata', 'done', conn=conn, cur=cur)

    conn.commit()
    cur.close()
@guard_postgres
def formdef_remap_statuses(formdef, mapping):
    """Rewrite stored statuses of a formdef after a workflow change.

    :param mapping: dict of {old status id: new status id}; statuses absent
        from the mapping are marked invalid by appending an
        '-invalid-<workflow id>' suffix.  Both the formdata table and its
        evolutions table are updated.
    """
    table_name = get_formdef_table_name(formdef)
    evolutions_table_name = table_name + '_evolutions'
    unmapped_status_suffix = str(formdef.workflow_id or 'default')
    # build the case expression
    status_cases = []
    for old_id, new_id in mapping.items():
        status_cases.append(
            SQL('WHEN status = {old_status} THEN {new_status}').format(
                old_status=Literal(old_id), new_status=Literal(new_id)
            )
        )
    case_expression = SQL(
        '(CASE WHEN status IS NULL THEN NULL '
        '{status_cases} '
        # keep statuses already marked as invalid
        'WHEN status LIKE {pattern} THEN status '
        # mark unknown statuses as invalid
        'ELSE (status || {suffix}) END)'
    ).format(
        status_cases=SQL('').join(status_cases),
        pattern=Literal('%-invalid-%'),
        suffix=Literal('-invalid-' + unmapped_status_suffix),
    )
    conn, cur = get_connection_and_cursor()
    # update formdatas statuses (drafts are left untouched)
    cur.execute(
        SQL('UPDATE {table_name} SET status = {case_expression} WHERE status <> {draft_status}').format(
            table_name=Identifier(table_name), case_expression=case_expression, draft_status=Literal('draft')
        )
    )
    # update evolutions statuses
    cur.execute(
        SQL('UPDATE {table_name} SET status = {case_expression}').format(
            table_name=Identifier(evolutions_table_name), case_expression=case_expression
        )
    )
    conn.commit()
    cur.close()