# wcs/wcs/sql.py
# w.c.s. - web application for online forms
# Copyright (C) 2005-2012 Entr'ouvert
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 2 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, see <http://www.gnu.org/licenses/>.
import copy
import datetime
import decimal
import hashlib
import io
import itertools
import json
import os
import pickle
import re
import secrets
import shutil
import time
import uuid
from contextlib import ContextDecorator
import psycopg2
import psycopg2.extensions
import psycopg2.extras
from django.utils.encoding import force_bytes, force_str
from django.utils.timezone import localtime, make_aware, now
from psycopg2.errors import UndefinedTable # noqa pylint: disable=no-name-in-module
from psycopg2.sql import SQL, Identifier, Literal
from quixote import get_publisher
import wcs.carddata
import wcs.categories
import wcs.custom_views
import wcs.formdata
import wcs.logged_errors
import wcs.qommon.tokens
import wcs.roles
import wcs.snapshots
import wcs.sql_criterias
import wcs.tracking_code
import wcs.users
from . import qommon
from .carddef import CardDef
from .formdef import FormDef
from .publisher import UnpicklerClass
from .qommon import _, get_cfg
from .qommon.misc import JSONEncoder, is_ascii_digit, strftime
from .qommon.storage import NothingToUpdate, _take, classonlymethod
from .qommon.storage import parse_clause as parse_storage_clause
from .qommon.substitution import invalidate_substitution_cache
from .qommon.upload_storage import PicklableUpload
from .sql_criterias import * # noqa pylint: disable=wildcard-import,unused-wildcard-import
# enable psycopg2 unicode mode, this will fetch postgresql varchar/text columns
# as unicode objects
psycopg2.extensions.register_type(psycopg2.extensions.UNICODE)
psycopg2.extensions.register_type(psycopg2.extensions.UNICODEARRAY)
# automatically adapt dictionaries into json fields
psycopg2.extensions.register_adapter(dict, psycopg2.extras.Json)
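# illustrative sketch (not part of the original module): with the adapter
# registered above, a dict passed as a query parameter is serialized to JSON
# automatically; table and column names below are hypothetical:
#
#   cur.execute('INSERT INTO some_table (context) VALUES (%s)',
#               ({'key': 'value'},))  # stores {"key": "value"} in a jsonb column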
SQL_TYPE_MAPPING = {
'title': None,
'subtitle': None,
'comment': None,
'page': None,
'text': 'text',
'bool': 'boolean',
'numeric': 'numeric',
'file': 'bytea',
'date': 'date',
'items': 'text[]',
'table': 'text[][]',
'table-select': 'text[][]',
'tablerows': 'text[][]',
    # dicts, stored as arrays of key/value pairs
'ranked-items': 'text[][]',
'password': 'text[][]',
# field block
'block': 'jsonb',
# computed data field
'computed': 'jsonb',
}
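# illustrative note: a None value means the field holds no data and gets no
# database column at all (e.g. 'title'); field keys missing from the mapping
# fall back to 'varchar' (see do_formdef_tables below), e.g.:
#
#   SQL_TYPE_MAPPING.get('items', 'varchar')   # -> 'text[]'
#   SQL_TYPE_MAPPING.get('string', 'varchar')  # -> 'varchar' (fallback)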
class LoggingCursor(psycopg2.extensions.cursor):
    # keep track of (the number of) queries, for tests, cron logging and usage summary.
queries = None
queries_count = 0
queries_log_function = None
def execute(self, query, vars=None):
LoggingCursor.queries_count += 1
if self.queries_log_function:
self.queries_log_function(query)
if self.queries is not None:
self.queries.append(query)
return super().execute(query, vars)
class WcsPgConnection(psycopg2.extensions.connection):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.cursor_factory = LoggingCursor
self._wcs_in_transaction = False
self._wcs_savepoints = []
class Atomic(ContextDecorator):
"""
Inspired by django Atomic
"""
def __init__(self):
pass
def start_transaction(self):
# get the conn
conn = get_connection()
cursor = conn.cursor()
# if already in txn, start a savepoint
if conn._wcs_in_transaction:
import _thread
savepoint_name = '%s_%s' % (_thread.get_ident(), len(conn._wcs_savepoints))
cursor.execute("SAVEPOINT \"%s\";" % savepoint_name)
conn._wcs_savepoints.append(savepoint_name)
else:
conn._wcs_in_transaction = True
conn.autocommit = False
def rollback(self):
conn = get_connection()
cursor = conn.cursor()
# rollback transaction, or rollback savepoint (and release the savepoint, it won't be used anymore)
if len(conn._wcs_savepoints) == 0:
conn.rollback()
conn._wcs_in_transaction = False
conn.autocommit = True
else:
last_savepoint = conn._wcs_savepoints.pop()
cursor.execute("ROLLBACK TO SAVEPOINT \"%s\";" % last_savepoint)
cursor.execute("RELEASE SAVEPOINT \"%s\";" % last_savepoint)
def commit(self):
conn = get_connection()
cursor = conn.cursor()
# commit transaction, or release savepoint
if len(conn._wcs_savepoints) == 0:
conn.commit()
conn._wcs_in_transaction = False
conn.autocommit = True
else:
last_savepoint = conn._wcs_savepoints.pop()
cursor.execute("RELEASE SAVEPOINT \"%s\";" % last_savepoint)
def __enter__(self):
self.start_transaction()
return self
def __exit__(self, exc_type, exc_value, traceback):
if exc_type is None:
self.commit()
else:
self.rollback()
def partial_commit(self):
self.commit()
self.start_transaction()
def atomic(f=None):
if f is None:
return Atomic()
else:
return Atomic()(f)
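# usage sketch (illustrative): atomic() mirrors django.db.transaction.atomic
# and can be used both as a context manager and as a decorator:
#
#   with atomic():
#       formdata.store()  # both stores are committed (or rolled back) together
#       other_formdata.store()
#
#   @atomic
#   def do_something():
#       ...  # runs in a transaction, or in a savepoint when nested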
class LazyEvolutionList(list):
def __init__(self, dump):
self.dump = dump
def _load(self):
try:
dump = super().__getattribute__('dump')
except AttributeError:
pass
else:
super().__setitem__(slice(0), get_publisher().unpickler_class(io.BytesIO(dump)).load())
del self.dump
def __getattribute__(self, name):
super().__getattribute__('_load')()
return super().__getattribute__(name)
def __bool__(self):
self._load()
return bool(len(self))
def __iter__(self):
self._load()
return super().__iter__()
def __len__(self):
self._load()
return super().__len__()
def __reversed__(self):
self._load()
return super().__reversed__()
def __str__(self):
self._load()
return super().__str__()
def __getitem__(self, index):
self._load()
return super().__getitem__(index)
def __setitem__(self, index, value):
self._load()
return super().__setitem__(index, value)
def __delitem__(self, index):
self._load()
return super().__delitem__(index)
    def __iadd__(self, values):
        self._load()
        return super().__iadd__(values)
def __repr__(self):
self._load()
return super().__repr__()
def __contains__(self, value):
self._load()
return super().__contains__(value)
def __reduce__(self):
return (list, (), None, iter(self))
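# behaviour sketch (illustrative): the pickled dump given to the constructor
# is only deserialized on first real access, e.g.:
#
#   evolutions = LazyEvolutionList(dump)  # no unpickling happens here
#   evolutions[0]                         # triggers _load(), then delegates to list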
def pickle_loads(value):
if hasattr(value, 'tobytes'):
value = value.tobytes()
return UnpicklerClass(io.BytesIO(force_bytes(value))).load()
def get_name_as_sql_identifier(name):
name = qommon.misc.simplify(name)
for char in '<>|{}!?^*+/=\'': # forbidden chars
name = name.replace(char, '')
name = name.replace('-', '_')
return name
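# illustrative example, assuming qommon.misc.simplify() lowercases and
# hyphenates its input:
#
#   get_name_as_sql_identifier('Demande de place')  # -> 'demande_de_place'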
def parse_clause(clause):
# returns a three-elements tuple with:
# - a list of SQL 'WHERE' clauses
# - a dict for query parameters
# - a callable, or None if all clauses have been successfully translated
if clause is None:
return ([], {}, None)
if callable(clause): # already a callable
return ([], {}, clause)
# create 'WHERE' clauses
func_clauses = []
where_clauses = []
parameters = {}
for i, element in enumerate(clause):
if callable(element):
func_clauses.append(element)
else:
            sql_class = getattr(wcs.sql_criterias, element.__class__.__name__, None)
            if sql_class:
if isinstance(element, wcs.sql_criterias.Criteria):
# already SQL
sql_element = element
else:
# criteria from wcs.qommon.storage, replace it with its SQL variant
# (keep extra _label for display in backoffice)
sql_element = sql_class(**element.__dict__)
sql_element._label = getattr(element, '_label', None)
clause[i] = sql_element
where_clauses.append(sql_element.as_sql())
parameters.update(sql_element.as_sql_param())
else:
func_clauses.append(element.build_lambda())
if func_clauses:
return (where_clauses, parameters, parse_storage_clause(func_clauses))
else:
return (where_clauses, parameters, None)
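# usage sketch (illustrative): criterias with a SQL variant become WHERE
# fragments plus parameters, anything else ends up in the callable part:
#
#   where, params, func = parse_clause([Equal('status', 'wf-new')])
#   # where  -> ['status = %(c12345)s']  (parameter naming is illustrative)
#   # params -> {'c12345': 'wf-new'}
#   # func   -> None, as everything could be translated to SQL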
def str_encode(value):
if isinstance(value, list):
return [str_encode(x) for x in value]
return value
def get_connection(new=False):
if new:
cleanup_connection()
publisher = get_publisher()
if not getattr(publisher, 'pgconn', None):
postgresql_cfg = {}
for param in ('database', 'user', 'password', 'host', 'port'):
value = get_cfg('postgresql', {}).get(param)
if value:
postgresql_cfg[param] = value
if 'database' in postgresql_cfg:
postgresql_cfg['dbname'] = postgresql_cfg.pop('database')
postgresql_cfg['application_name'] = getattr(publisher, 'sql_application_name', None)
try:
pgconn = psycopg2.connect(connection_factory=WcsPgConnection, **postgresql_cfg)
pgconn.autocommit = True
except psycopg2.Error:
if new:
raise
pgconn = None
publisher.pgconn = pgconn
return publisher.pgconn
def cleanup_connection():
if hasattr(get_publisher(), 'pgconn') and get_publisher().pgconn is not None:
get_publisher().pgconn.close()
get_publisher().pgconn = None
def get_connection_and_cursor(new=False):
conn = get_connection(new=new)
try:
cur = conn.cursor()
except psycopg2.InterfaceError:
        # maybe postgresql was restarted in between
conn = get_connection(new=True)
cur = conn.cursor()
return (conn, cur)
def get_formdef_table_name(formdef):
# PostgreSQL limits identifier length to 63 bytes
#
# The system uses no more than NAMEDATALEN-1 bytes of an identifier;
# longer names can be written in commands, but they will be truncated.
# By default, NAMEDATALEN is 64 so the maximum identifier length is
# 63 bytes. If this limit is problematic, it can be raised by changing
# the NAMEDATALEN constant in src/include/pg_config_manual.h.
#
# as we have to know our table names, we crop the names here, and to an
# extent that allows suffixes (like _evolution) to be added.
assert formdef.id is not None
if hasattr(formdef, 'table_name') and formdef.table_name:
return formdef.table_name
formdef.table_name = '%s_%s_%s' % (
formdef.data_sql_prefix,
formdef.id,
get_name_as_sql_identifier(formdef.url_name)[:30],
)
if not formdef.is_readonly():
formdef.store(object_only=True)
return formdef.table_name
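# illustrative example: a formdef with id 12 and url_name 'request-for-access'
# gets the table 'formdata_12_request_for_access'; the url_name part is
# cropped to 30 characters so suffixes like _evolutions still fit within
# PostgreSQL's 63-byte identifier limit.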
def get_formdef_trigger_function_name(formdef):
assert formdef.id is not None
return '%s_%s_trigger_fn' % (formdef.data_sql_prefix, formdef.id)
def get_formdef_trigger_name(formdef):
assert formdef.id is not None
return '%s_%s_trigger' % (formdef.data_sql_prefix, formdef.id)
def get_formdef_new_id(id_start):
new_id = id_start
_, cur = get_connection_and_cursor()
while True:
cur.execute(
'''SELECT COUNT(*) FROM information_schema.tables
WHERE table_schema = 'public'
AND table_name LIKE %s''',
('formdata\\_%s\\_%%' % new_id,),
)
if cur.fetchone()[0] == 0:
break
new_id += 1
cur.close()
return new_id
def get_carddef_new_id(id_start):
new_id = id_start
_, cur = get_connection_and_cursor()
while True:
cur.execute(
'''SELECT COUNT(*) FROM information_schema.tables
WHERE table_schema = 'public'
AND table_name LIKE %s''',
('carddata\\_%s\\_%%' % new_id,),
)
if cur.fetchone()[0] == 0:
break
new_id += 1
cur.close()
return new_id
def formdef_wipe():
_, cur = get_connection_and_cursor()
cur.execute(
'''SELECT table_name FROM information_schema.tables
WHERE table_schema = 'public'
AND table_name LIKE %s''',
('formdata\\_%%\\_%%',),
)
for table_name in [x[0] for x in cur.fetchall()]:
cur.execute('DELETE FROM %s' % table_name) # Force trigger execution
cur.execute('''DROP TABLE %s CASCADE''' % table_name)
cur.execute("SELECT relkind FROM pg_class WHERE relname = 'wcs_all_forms'")
row = cur.fetchone()
# only do the delete if wcs_all_forms is a table and not still a view
if row is not None and row[0] == 'r':
cur.execute('TRUNCATE wcs_all_forms')
cur.close()
def carddef_wipe():
_, cur = get_connection_and_cursor()
cur.execute(
'''SELECT table_name FROM information_schema.tables
WHERE table_schema = 'public'
AND table_name LIKE %s''',
('carddata\\_%%\\_%%',),
)
for table_name in [x[0] for x in cur.fetchall()]:
cur.execute('DELETE FROM %s' % table_name) # Force trigger execution
cur.execute('''DROP TABLE %s CASCADE''' % table_name)
cur.close()
def get_formdef_view_name(formdef):
prefix = 'wcs_view'
if formdef.data_sql_prefix != 'formdata':
prefix = 'wcs_%s_view' % formdef.data_sql_prefix
return '%s_%s_%s' % (prefix, formdef.id, get_name_as_sql_identifier(formdef.url_name)[:40])
def do_formdef_tables(formdef, conn=None, cur=None, rebuild_views=False, rebuild_global_views=True):
if formdef.id is None:
return []
if getattr(formdef, 'fields', None) is Ellipsis:
# don't touch tables for lightweight objects
return []
if getattr(formdef, 'snapshot_object', None):
# don't touch tables for snapshot objects
return []
own_conn = False
if not conn:
own_conn = True
conn, cur = get_connection_and_cursor()
with atomic():
table_name = get_formdef_table_name(formdef)
cur.execute(
'''SELECT COUNT(*) FROM information_schema.tables
WHERE table_schema = 'public'
AND table_name = %s''',
(table_name,),
)
if cur.fetchone()[0] == 0:
cur.execute(
'''CREATE TABLE %s (id serial PRIMARY KEY,
user_id varchar,
receipt_time timestamptz,
anonymised timestamptz,
status varchar,
page_no varchar,
workflow_data bytea,
id_display varchar
)'''
% table_name
)
cur.execute(
'''CREATE TABLE %s_evolutions (id serial PRIMARY KEY,
who varchar,
status varchar,
time timestamptz,
last_jump_datetime timestamptz,
comment text,
parts bytea,
formdata_id integer REFERENCES %s (id) ON DELETE CASCADE)'''
% (table_name, table_name)
)
# make sure the table will not be changed while we work on it
cur.execute('LOCK TABLE %s;' % table_name)
cur.execute(
'''SELECT column_name, data_type FROM information_schema.columns
WHERE table_schema = 'public'
AND table_name = %s''',
(table_name,),
)
existing_field_types = {x[0]: x[1] for x in cur.fetchall()}
existing_fields = set(existing_field_types.keys())
needed_fields = {x[0] for x in formdef.data_class()._table_static_fields}
needed_fields.add('fts')
# migrations
if 'fts' not in existing_fields:
# full text search, column and index
cur.execute('''ALTER TABLE %s ADD COLUMN fts tsvector''' % table_name)
if 'criticality_level' not in existing_fields:
            # criticality level, with default value
existing_fields.add('criticality_level')
cur.execute(
'''ALTER TABLE %s ADD COLUMN criticality_level integer NOT NULL DEFAULT(0)''' % table_name
)
# generic migration for new columns
for field_name, field_type in formdef.data_class()._table_static_fields:
if field_name not in existing_fields:
cur.execute('''ALTER TABLE %s ADD COLUMN %s %s''' % (table_name, field_name, field_type))
# store datetimes with timezone
if existing_field_types.get('receipt_time') not in (None, 'timestamp with time zone'):
cur.execute(f'ALTER TABLE {table_name} ALTER COLUMN receipt_time SET DATA TYPE timestamptz')
if existing_field_types.get('last_update_time') not in (None, 'timestamp with time zone'):
cur.execute(f'ALTER TABLE {table_name} ALTER COLUMN last_update_time SET DATA TYPE timestamptz')
# add new fields
field_integrity_errors = {}
for field in formdef.get_all_fields():
assert field.id is not None
sql_type = SQL_TYPE_MAPPING.get(field.key, 'varchar')
if sql_type is None:
continue
needed_fields.add(get_field_id(field))
if get_field_id(field) not in existing_fields:
cur.execute(
'''ALTER TABLE %s ADD COLUMN %s %s''' % (table_name, get_field_id(field), sql_type)
)
else:
existing_type = existing_field_types.get(get_field_id(field))
# map to names returned in data_type column
expected_type = {
'varchar': 'character varying',
'text[]': 'ARRAY',
'text[][]': 'ARRAY',
}.get(sql_type) or sql_type
if existing_type != expected_type:
field_integrity_errors[str(field.id)] = {'got': existing_type, 'expected': expected_type}
if field.store_display_value:
needed_fields.add('%s_display' % get_field_id(field))
if '%s_display' % get_field_id(field) not in existing_fields:
cur.execute(
'''ALTER TABLE %s ADD COLUMN %s varchar'''
% (table_name, '%s_display' % get_field_id(field))
)
if field.store_structured_value:
needed_fields.add('%s_structured' % get_field_id(field))
if '%s_structured' % get_field_id(field) not in existing_fields:
cur.execute(
'''ALTER TABLE %s ADD COLUMN %s bytea'''
% (table_name, '%s_structured' % get_field_id(field))
)
if (field_integrity_errors or None) != formdef.sql_integrity_errors:
formdef.sql_integrity_errors = field_integrity_errors
formdef.store(object_only=True)
for field in (formdef.geolocations or {}).keys():
column_name = 'geoloc_%s' % field
needed_fields.add(column_name)
if column_name not in existing_fields:
                cur.execute('ALTER TABLE %s ADD COLUMN %s %s' % (table_name, column_name, 'POINT'))
# delete obsolete fields
for field in existing_fields - needed_fields:
cur.execute('''ALTER TABLE %s DROP COLUMN %s CASCADE''' % (table_name, field))
if formdef.data_sql_prefix == 'formdata':
recreate_trigger(formdef, cur, conn)
# migrations on _evolutions table
cur.execute(
'''SELECT column_name, data_type FROM information_schema.columns
WHERE table_schema = 'public'
AND table_name = '%s_evolutions'
'''
% table_name
)
evo_existing_fields = {x[0]: x[1] for x in cur.fetchall()}
if 'last_jump_datetime' not in evo_existing_fields:
cur.execute(
'''ALTER TABLE %s_evolutions ADD COLUMN last_jump_datetime timestamptz''' % table_name
)
if evo_existing_fields.get('time') not in (None, 'timestamp with time zone'):
cur.execute(f'ALTER TABLE {table_name}_evolutions ALTER COLUMN time SET DATA TYPE timestamptz')
if evo_existing_fields.get('last_jump_datetime') not in (None, 'timestamp with time zone'):
cur.execute(
f'ALTER TABLE {table_name}_evolutions ALTER COLUMN last_jump_datetime SET DATA TYPE timestamptz'
)
if rebuild_views or len(existing_fields - needed_fields):
# views may have been dropped when dropping columns, so we recreate
# them even if not asked to.
redo_views(conn, cur, formdef, rebuild_global_views=rebuild_global_views)
do_formdef_indexes(formdef, cur=cur)
if own_conn:
cur.close()
actions = []
if 'concerned_roles_array' not in existing_fields:
actions.append('rebuild_security')
elif 'actions_roles_array' not in existing_fields:
actions.append('rebuild_security')
if 'tracking_code' not in existing_fields:
# if tracking code has just been added to the table we need to make
# sure the tracking code table does exist.
actions.append('do_tracking_code_table')
return actions
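# usage sketch (illustrative): callers are expected to run the returned
# migration actions afterwards, e.g.:
#
#   for action in do_formdef_tables(formdef):
#       if action == 'do_tracking_code_table':
#           do_tracking_code_table()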
def recreate_trigger(formdef, cur, conn):
    # recreate the trigger function, just so it's up to date
table_name = get_formdef_table_name(formdef)
category_value = formdef.category_id
geoloc_base_x_query = 'NULL'
geoloc_base_y_query = 'NULL'
if formdef.geolocations and 'base' in formdef.geolocations:
        # default geolocation is in the 'base' key; we have to unstructure the
        # field as the POINT type of postgresql cannot be used directly since it
        # doesn't have an equality operator.
geoloc_base_x_query = 'NEW.geoloc_base[0]'
geoloc_base_y_query = 'NEW.geoloc_base[1]'
if formdef.category_id is None:
category_value = 'NULL'
criticality_levels = len(formdef.workflow.criticality_levels or [0])
endpoint_status = formdef.workflow.get_endpoint_status()
endpoint_status_filter = ', '.join(["'wf-%s'" % x.id for x in endpoint_status])
if endpoint_status_filter == '':
# not the prettiest in town, but will do fine for now.
endpoint_status_filter = "'xxxx'"
formed_name_quotedstring = psycopg2.extensions.QuotedString(formdef.name)
formed_name_quotedstring.encoding = 'utf8'
formdef_name = formed_name_quotedstring.getquoted().decode()
trigger_code = '''
BEGIN
IF TG_OP = 'DELETE' THEN
DELETE FROM wcs_all_forms WHERE formdef_id = {formdef_id} AND id = OLD.id;
RETURN OLD;
ELSEIF TG_OP = 'INSERT' THEN
INSERT INTO wcs_all_forms VALUES (
{category_id},
{formdef_id},
NEW.id,
NEW.user_id,
NEW.receipt_time,
NEW.status,
NEW.id_display,
NEW.submission_agent_id,
NEW.submission_channel,
NEW.backoffice_submission,
NEW.last_update_time,
NEW.digests,
NEW.user_label,
NEW.concerned_roles_array,
NEW.actions_roles_array,
NEW.fts,
NEW.status IN ({endpoint_status}),
{formdef_name},
(SELECT name FROM users WHERE users.id = CAST(NEW.user_id AS INTEGER)),
NEW.criticality_level - {criticality_levels},
{geoloc_base_x},
{geoloc_base_y},
NEW.anonymised,
NEW.statistics_data,
NEW.relations_data);
RETURN NEW;
ELSE
UPDATE wcs_all_forms SET
user_id = NEW.user_id,
receipt_time = NEW.receipt_time,
status = NEW.status,
id_display = NEW.id_display,
submission_agent_id = NEW.submission_agent_id,
submission_channel = NEW.submission_channel,
backoffice_submission = NEW.backoffice_submission,
last_update_time = NEW.last_update_time,
digests = NEW.digests,
user_label = NEW.user_label,
concerned_roles_array = NEW.concerned_roles_array,
actions_roles_array = NEW.actions_roles_array,
fts = NEW.fts,
is_at_endpoint = NEW.status IN ({endpoint_status}),
formdef_name = {formdef_name},
user_name = (SELECT name FROM users WHERE users.id = CAST(NEW.user_id AS INTEGER)),
criticality_level = NEW.criticality_level - {criticality_levels},
geoloc_base_x = {geoloc_base_x},
geoloc_base_y = {geoloc_base_y},
anonymised = NEW.anonymised,
statistics_data = NEW.statistics_data,
relations_data = NEW.relations_data
WHERE formdef_id = {formdef_id} AND id = OLD.id;
RETURN NEW;
END IF;
END;
'''.format(
        category_id=category_value,  # 'NULL' literal when formdef.category_id is None (see above)
formdef_id=formdef.id,
geoloc_base_x=geoloc_base_x_query,
geoloc_base_y=geoloc_base_y_query,
formdef_name=formdef_name,
criticality_levels=criticality_levels,
endpoint_status=endpoint_status_filter,
)
cur.execute(
'''SELECT prosrc FROM pg_proc
WHERE proname = '%s'
'''
% get_formdef_trigger_function_name(formdef)
)
function_row = cur.fetchone()
if function_row is None or function_row[0] != trigger_code:
cur.execute(
'''
CREATE OR REPLACE FUNCTION {trg_fn_name}()
RETURNS trigger
LANGUAGE plpgsql
AS $${code}$$;
'''.format(
trg_fn_name=get_formdef_trigger_function_name(formdef),
code=trigger_code,
)
)
trg_name = get_formdef_trigger_name(formdef)
cur.execute(
'''SELECT 1 FROM pg_trigger
WHERE tgrelid = '%s'::regclass
AND tgname = '%s'
'''
% (table_name, trg_name)
)
if len(cur.fetchall()) == 0:
# compatibility note: to support postgresql<11 we use PROCEDURE and not FUNCTION
cur.execute(
'''CREATE TRIGGER {trg_name} AFTER INSERT OR UPDATE OR DELETE
ON {table_name}
FOR EACH ROW EXECUTE PROCEDURE {trg_fn_name}();
'''.format(
trg_fn_name=get_formdef_trigger_function_name(formdef),
table_name=table_name,
trg_name=trg_name,
)
)
def do_formdef_indexes(formdef, cur, concurrently=False):
table_name = get_formdef_table_name(formdef)
evolutions_table_name = table_name + '_evolutions'
if concurrently:
create_index = 'CREATE INDEX CONCURRENTLY IF NOT EXISTS'
else:
create_index = 'CREATE INDEX IF NOT EXISTS'
cur.execute(f'{create_index} {evolutions_table_name}_fid ON {evolutions_table_name} (formdata_id, id)')
cur.execute(f'{create_index} {table_name}_fts ON {table_name} USING gin(fts)')
attrs = ['receipt_time', 'anonymised', 'user_id', 'status']
if isinstance(formdef, CardDef):
attrs.append('id_display')
for attr in attrs:
cur.execute(f'{create_index} {table_name}_{attr}_idx ON {table_name} ({attr})')
for attr in ('concerned_roles_array', 'actions_roles_array', 'workflow_roles_array'):
idx_name = 'idx_' + attr + '_' + table_name
cur.execute(f'{create_index} {idx_name} ON {table_name} USING gin ({attr})')
def do_user_table():
_, cur = get_connection_and_cursor()
table_name = 'users'
cur.execute(
'''SELECT COUNT(*) FROM information_schema.tables
WHERE table_schema = 'public'
AND table_name = %s''',
(table_name,),
)
if cur.fetchone()[0] == 0:
cur.execute(
'''CREATE TABLE %s (id serial PRIMARY KEY,
name varchar,
ascii_name varchar,
email varchar,
roles text[],
is_active bool,
is_admin bool,
verified_fields text[],
name_identifiers text[],
lasso_dump text,
last_seen timestamp,
deleted_timestamp timestamp,
preferences jsonb,
test_uuid varchar
)'''
% table_name
)
cur.execute(
'''SELECT column_name FROM information_schema.columns
WHERE table_schema = 'public'
AND table_name = %s''',
(table_name,),
)
existing_fields = {x[0] for x in cur.fetchall()}
needed_fields = {
'id',
'name',
'email',
'roles',
'is_admin',
'name_identifiers',
'verified_fields',
'lasso_dump',
'last_seen',
'fts',
'ascii_name',
'deleted_timestamp',
'is_active',
'preferences',
'test_uuid',
}
from wcs.admin.settings import UserFieldsFormDef
formdef = UserFieldsFormDef()
for field in formdef.get_all_fields():
sql_type = SQL_TYPE_MAPPING.get(field.key, 'varchar')
if sql_type is None:
continue
needed_fields.add(get_field_id(field))
if get_field_id(field) not in existing_fields:
cur.execute('''ALTER TABLE %s ADD COLUMN %s %s''' % (table_name, get_field_id(field), sql_type))
if field.store_display_value:
needed_fields.add('%s_display' % get_field_id(field))
if '%s_display' % get_field_id(field) not in existing_fields:
cur.execute(
'''ALTER TABLE %s ADD COLUMN %s varchar'''
% (table_name, '%s_display' % get_field_id(field))
)
if field.store_structured_value:
needed_fields.add('%s_structured' % get_field_id(field))
if '%s_structured' % get_field_id(field) not in existing_fields:
cur.execute(
'''ALTER TABLE %s ADD COLUMN %s bytea'''
% (table_name, '%s_structured' % get_field_id(field))
)
# migrations
if 'fts' not in existing_fields:
# full text search
cur.execute('''ALTER TABLE %s ADD COLUMN fts tsvector''' % table_name)
if 'verified_fields' not in existing_fields:
cur.execute('ALTER TABLE %s ADD COLUMN verified_fields text[]' % table_name)
if 'ascii_name' not in existing_fields:
cur.execute('ALTER TABLE %s ADD COLUMN ascii_name varchar' % table_name)
if 'deleted_timestamp' not in existing_fields:
cur.execute('ALTER TABLE %s ADD COLUMN deleted_timestamp timestamp' % table_name)
if 'is_active' not in existing_fields:
cur.execute('ALTER TABLE %s ADD COLUMN is_active bool DEFAULT TRUE' % table_name)
cur.execute('UPDATE %s SET is_active = FALSE WHERE deleted_timestamp IS NOT NULL' % table_name)
if 'preferences' not in existing_fields:
cur.execute('ALTER TABLE %s ADD COLUMN preferences jsonb' % table_name)
if 'test_uuid' not in existing_fields:
cur.execute('ALTER TABLE %s ADD COLUMN test_uuid varchar' % table_name)
# delete obsolete fields
for field in existing_fields - needed_fields:
cur.execute('''ALTER TABLE %s DROP COLUMN %s''' % (table_name, field))
SqlUser.do_indexes(cur)
cur.close()
def do_role_table():
_, cur = get_connection_and_cursor()
table_name = 'roles'
cur.execute(
'''SELECT COUNT(*) FROM information_schema.tables
WHERE table_schema = 'public'
AND table_name = %s''',
(table_name,),
)
if cur.fetchone()[0] == 0:
cur.execute(
'''CREATE TABLE %s (id VARCHAR PRIMARY KEY,
name VARCHAR,
uuid UUID,
slug VARCHAR UNIQUE,
internal BOOLEAN,
details VARCHAR,
emails VARCHAR[],
emails_to_members BOOLEAN,
allows_backoffice_access BOOLEAN)'''
% table_name
)
cur.execute('ALTER TABLE roles ALTER COLUMN uuid TYPE VARCHAR')
cur.execute(
'''SELECT column_name FROM information_schema.columns
WHERE table_schema = 'public'
AND table_name = %s''',
(table_name,),
)
existing_fields = {x[0] for x in cur.fetchall()}
needed_fields = {x[0] for x in Role._table_static_fields}
# delete obsolete fields
for field in existing_fields - needed_fields:
cur.execute('''ALTER TABLE %s DROP COLUMN %s''' % (table_name, field))
cur.close()
def migrate_legacy_roles():
# store old pickle roles in SQL
for role_id in wcs.roles.Role.keys():
role = wcs.roles.Role.get(role_id)
role.__class__ = Role
role.store()
def do_tracking_code_table():
_, cur = get_connection_and_cursor()
table_name = 'tracking_codes'
cur.execute(
'''SELECT COUNT(*) FROM information_schema.tables
WHERE table_schema = 'public'
AND table_name = %s''',
(table_name,),
)
if cur.fetchone()[0] == 0:
cur.execute(
'''CREATE TABLE %s (id varchar PRIMARY KEY,
formdef_id varchar,
formdata_id varchar)'''
% table_name
)
cur.execute(
'''SELECT column_name FROM information_schema.columns
WHERE table_schema = 'public'
AND table_name = %s''',
(table_name,),
)
existing_fields = {x[0] for x in cur.fetchall()}
needed_fields = {'id', 'formdef_id', 'formdata_id'}
# delete obsolete fields
for field in existing_fields - needed_fields:
cur.execute('''ALTER TABLE %s DROP COLUMN %s''' % (table_name, field))
cur.close()
def do_session_table():
_, cur = get_connection_and_cursor()
table_name = 'sessions'
cur.execute(
'''SELECT COUNT(*) FROM information_schema.tables
WHERE table_schema = 'public'
AND table_name = %s''',
(table_name,),
)
if cur.fetchone()[0] == 0:
cur.execute(
'''CREATE TABLE %s (id varchar PRIMARY KEY,
session_data bytea,
name_identifier varchar,
visiting_objects_keys varchar[]
)'''
% table_name
)
cur.execute(
'''SELECT column_name FROM information_schema.columns
WHERE table_schema = 'public'
AND table_name = %s''',
(table_name,),
)
existing_fields = {x[0] for x in cur.fetchall()}
needed_fields = {'id', 'session_data', 'name_identifier', 'visiting_objects_keys', 'last_update_time'}
# migrations
if 'last_update_time' not in existing_fields:
cur.execute('''ALTER TABLE %s ADD COLUMN last_update_time timestamp DEFAULT NOW()''' % table_name)
# delete obsolete fields
for field in existing_fields - needed_fields:
cur.execute('''ALTER TABLE %s DROP COLUMN %s''' % (table_name, field))
Session.do_indexes(cur)
cur.close()
def do_transient_data_table():
_, cur = get_connection_and_cursor()
table_name = TransientData._table_name
cur.execute(
'''SELECT COUNT(*) FROM information_schema.tables
WHERE table_schema = 'public'
AND table_name = %s''',
(table_name,),
)
if cur.fetchone()[0] == 0:
cur.execute(
'''CREATE TABLE %s (id varchar PRIMARY KEY,
session_id VARCHAR REFERENCES sessions(id) ON DELETE CASCADE,
data bytea,
last_update_time timestamptz
)'''
% table_name
)
cur.execute(
'''SELECT column_name FROM information_schema.columns
WHERE table_schema = 'public'
AND table_name = %s''',
(table_name,),
)
existing_fields = {x[0] for x in cur.fetchall()}
needed_fields = {x[0] for x in TransientData._table_static_fields}
# delete obsolete fields
for field in existing_fields - needed_fields:
cur.execute('''ALTER TABLE %s DROP COLUMN %s''' % (table_name, field))
cur.close()
def do_custom_views_table():
_, cur = get_connection_and_cursor()
table_name = 'custom_views'
cur.execute(
'''SELECT COUNT(*) FROM information_schema.tables
WHERE table_schema = 'public'
AND table_name = %s''',
(table_name,),
)
if cur.fetchone()[0] == 0:
cur.execute(
'''CREATE TABLE %s (id varchar PRIMARY KEY,
title varchar,
slug varchar,
user_id varchar,
visibility varchar,
formdef_type varchar,
formdef_id varchar,
is_default boolean,
order_by varchar,
group_by varchar,
columns jsonb,
filters jsonb
)'''
% table_name
)
cur.execute(
'''SELECT column_name FROM information_schema.columns
WHERE table_schema = 'public'
AND table_name = %s''',
(table_name,),
)
existing_fields = {x[0] for x in cur.fetchall()}
needed_fields = {x[0] for x in CustomView._table_static_fields}
# migrations
if 'is_default' not in existing_fields:
cur.execute('''ALTER TABLE %s ADD COLUMN is_default boolean DEFAULT FALSE''' % table_name)
if 'role_id' not in existing_fields:
cur.execute('''ALTER TABLE %s ADD COLUMN role_id VARCHAR''' % table_name)
if 'group_by' not in existing_fields:
cur.execute('''ALTER TABLE %s ADD COLUMN group_by VARCHAR''' % table_name)
# delete obsolete fields
for field in existing_fields - needed_fields:
cur.execute('''ALTER TABLE %s DROP COLUMN %s''' % (table_name, field))
CustomView.do_indexes(cur)
cur.close()
def do_snapshots_table():
_, cur = get_connection_and_cursor()
table_name = 'snapshots'
cur.execute(
'''SELECT COUNT(*) FROM information_schema.tables
WHERE table_schema = 'public'
AND table_name = %s''',
(table_name,),
)
if cur.fetchone()[0] == 0:
cur.execute(
'''CREATE TABLE %s (id SERIAL PRIMARY KEY,
object_type VARCHAR,
object_id VARCHAR,
timestamp TIMESTAMP WITH TIME ZONE,
user_id VARCHAR,
comment TEXT,
serialization TEXT,
patch TEXT,
label VARCHAR,
test_result_id INTEGER,
application_slug VARCHAR,
application_version VARCHAR
)'''
% table_name
)
cur.execute(
'''SELECT column_name FROM information_schema.columns
WHERE table_schema = 'public'
AND table_name = %s''',
(table_name,),
)
existing_fields = {x[0] for x in cur.fetchall()}
# generic migration for new columns
for field_name, field_type in Snapshot._table_static_fields:
if field_name not in existing_fields:
cur.execute('''ALTER TABLE %s ADD COLUMN %s %s''' % (table_name, field_name, field_type))
needed_fields = {x[0] for x in Snapshot._table_static_fields}
# delete obsolete fields
for field in existing_fields - needed_fields:
cur.execute('''ALTER TABLE %s DROP COLUMN %s''' % (table_name, field))
Snapshot.do_indexes(cur)
cur.close()
def do_loggederrors_table():
_, cur = get_connection_and_cursor()
table_name = 'loggederrors'
cur.execute(
'''SELECT COUNT(*) FROM information_schema.tables
WHERE table_schema = 'public'
AND table_name = %s''',
(table_name,),
)
if cur.fetchone()[0] == 0:
cur.execute(
'''CREATE TABLE %s (id SERIAL PRIMARY KEY,
kind VARCHAR,
tech_id VARCHAR UNIQUE,
summary VARCHAR,
formdef_class VARCHAR,
formdata_id VARCHAR,
formdef_id VARCHAR,
workflow_id VARCHAR,
status_id VARCHAR,
status_item_id VARCHAR,
expression VARCHAR,
expression_type VARCHAR,
context JSONB,
traceback TEXT,
exception_class VARCHAR,
exception_message VARCHAR,
occurences_count INTEGER,
first_occurence_timestamp TIMESTAMP WITH TIME ZONE,
latest_occurence_timestamp TIMESTAMP WITH TIME ZONE
)'''
% table_name
)
cur.execute(
'''SELECT column_name FROM information_schema.columns
WHERE table_schema = 'public'
AND table_name = %s''',
(table_name,),
)
existing_fields = {x[0] for x in cur.fetchall()}
needed_fields = {x[0] for x in LoggedError._table_static_fields}
# migrations
if 'kind' not in existing_fields:
cur.execute('''ALTER TABLE %s ADD COLUMN kind VARCHAR''' % table_name)
if 'context' not in existing_fields:
cur.execute('''ALTER TABLE %s ADD COLUMN context JSONB''' % table_name)
# delete obsolete fields
for field in existing_fields - needed_fields:
cur.execute('''ALTER TABLE %s DROP COLUMN %s''' % (table_name, field))
LoggedError.do_indexes(cur)
cur.close()
def do_tokens_table():
_, cur = get_connection_and_cursor()
table_name = Token._table_name
cur.execute(
'''SELECT COUNT(*) FROM information_schema.tables
WHERE table_schema = 'public'
AND table_name = %s''',
(table_name,),
)
if cur.fetchone()[0] == 0:
cur.execute(
'''CREATE TABLE %s (id VARCHAR PRIMARY KEY,
type VARCHAR,
expiration TIMESTAMPTZ,
context JSONB
)'''
% table_name
)
cur.execute(
'''SELECT column_name FROM information_schema.columns
WHERE table_schema = 'public'
AND table_name = %s''',
(table_name,),
)
existing_fields = {x[0] for x in cur.fetchall()}
needed_fields = {x[0] for x in Token._table_static_fields}
# delete obsolete fields
for field in existing_fields - needed_fields:
cur.execute('''ALTER TABLE %s DROP COLUMN %s''' % (table_name, field))
cur.close()
def migrate_legacy_tokens():
# store old pickle tokens in SQL
for token_id in wcs.qommon.tokens.Token.keys():
try:
token = wcs.qommon.tokens.Token.get(token_id)
except KeyError:
continue
except AttributeError:
# old python2 tokens:
# AttributeError: module 'builtins' has no attribute 'unicode'
wcs.qommon.tokens.Token.remove_object(token_id)
continue
token.__class__ = Token
token.store()
def do_meta_table(conn=None, cur=None, insert_current_sql_level=True):
own_conn = False
if not conn:
own_conn = True
conn, cur = get_connection_and_cursor()
cur.execute(
'''SELECT COUNT(t.*), COUNT(i.*) FROM information_schema.tables t
LEFT JOIN pg_indexes i
ON (i.schemaname, i.tablename, i.indexname) = (table_schema, table_name, %s)
WHERE table_schema = 'public'
AND table_name = %s''',
(
'wcs_meta_key',
'wcs_meta',
),
)
info_row = cur.fetchone()
table_exists = info_row[0] > 0
index_exists = info_row[1] > 0
if not table_exists:
cur.execute(
'''CREATE TABLE wcs_meta (id serial PRIMARY KEY,
key varchar,
value varchar,
created_at timestamptz DEFAULT NOW(),
updated_at timestamptz DEFAULT NOW())'''
)
cur.execute('CREATE UNIQUE INDEX IF NOT EXISTS wcs_meta_key ON wcs_meta (key)')
if insert_current_sql_level:
sql_level = SQL_LEVEL[0]
else:
sql_level = 0
cur.execute(
'''INSERT INTO wcs_meta (id, key, value)
VALUES (DEFAULT, %s, %s)''',
('sql_level', str(sql_level)),
)
else:
cur.execute(
'''SELECT column_name FROM information_schema.columns
WHERE table_schema = 'public'
AND table_name = %s''',
('wcs_meta',),
)
existing_fields = {x[0] for x in cur.fetchall()}
if 'created_at' not in existing_fields:
cur.execute('''ALTER TABLE wcs_meta ADD COLUMN created_at timestamptz DEFAULT NOW()''')
if 'updated_at' not in existing_fields:
cur.execute('''ALTER TABLE wcs_meta ADD COLUMN updated_at timestamptz DEFAULT NOW()''')
if not index_exists:
cur.execute('CREATE UNIQUE INDEX IF NOT EXISTS wcs_meta_key ON wcs_meta (key)')
if own_conn:
cur.close()
def redo_views(conn, cur, formdef, rebuild_global_views=False):
if formdef.id is None:
return
if get_publisher().has_site_option('sql-create-formdef-views'):
drop_views(formdef, conn, cur)
do_views(formdef, conn, cur, rebuild_global_views=rebuild_global_views)
def drop_views(formdef, conn, cur):
# remove the global views
drop_global_views(conn, cur)
view_names = []
if formdef:
# remove the form view itself
view_prefix = 'wcs\\_view\\_%s\\_%%' % formdef.id
if formdef.data_sql_prefix != 'formdata':
view_prefix = 'wcs\\_%s\\_view\\_%s\\_%%' % (formdef.data_sql_prefix, formdef.id)
cur.execute(
'''SELECT table_name FROM information_schema.views
WHERE table_schema = 'public'
AND table_name LIKE %s''',
(view_prefix,),
)
else:
# if there's no formdef specified, remove all form & card views
cur.execute(
'''SELECT table_name FROM information_schema.views
WHERE table_schema = 'public'
AND table_name LIKE %s''',
('wcs\\_view\\_%',),
)
while True:
row = cur.fetchone()
if row is None:
break
view_names.append(row[0])
cur.execute(
'''SELECT table_name FROM information_schema.views
WHERE table_schema = 'public'
AND table_name LIKE %s''',
('wcs\\_carddata\\_view\\_%',),
)
while True:
row = cur.fetchone()
if row is None:
break
view_names.append(row[0])
for view_name in view_names:
cur.execute('''DROP VIEW IF EXISTS %s''' % view_name)
def get_view_fields(formdef):
view_fields = []
view_fields.append(("int '%s'" % (formdef.category_id or 0), 'category_id'))
view_fields.append(("int '%s'" % (formdef.id or 0), 'formdef_id'))
for field in (
'id',
'user_id',
'receipt_time',
'status',
'id_display',
'submission_agent_id',
'submission_channel',
'backoffice_submission',
'last_update_time',
'digests',
'user_label',
):
view_fields.append((field, field))
return view_fields
def do_views(formdef, conn, cur, rebuild_global_views=True):
# create new view
table_name = get_formdef_table_name(formdef)
view_name = get_formdef_view_name(formdef)
view_fields = get_view_fields(formdef)
column_names = {}
for field in formdef.get_all_fields():
field_key = get_field_id(field)
if field.is_no_data_field:
continue
if field.varname:
# the variable should be fine as is but we pass it through
# get_name_as_sql_identifier nevertheless, to be extra sure it
# doesn't contain invalid characters.
field_name = 'f_%s' % get_name_as_sql_identifier(field.varname)[:50]
else:
field_name = '%s_%s' % (get_field_id(field), get_name_as_sql_identifier(field.label))
field_name = field_name[:50]
if field_name in column_names:
# it may happen that the same varname is used on multiple fields
# (for example in the case of conditional pages), in that situation
# we suffix the field name with an index count
while field_name in column_names:
column_names[field_name] += 1
field_name = '%s_%s' % (field_name, column_names[field_name])
column_names[field_name] = 1
view_fields.append((field_key, field_name))
if field.store_display_value:
field_key = '%s_display' % get_field_id(field)
view_fields.append((field_key, field_name + '_display'))
view_fields.append(
(
'''ARRAY(SELECT status FROM %s_evolutions '''
''' WHERE %s.id = %s_evolutions.formdata_id'''
''' ORDER BY %s_evolutions.time)''' % ((table_name,) * 4),
'status_history',
)
)
    # add an is_at_endpoint column, dynamically created against the endpoint statuses.
endpoint_status = formdef.workflow.get_endpoint_status()
view_fields.append(
(
'''(SELECT status = ANY(ARRAY[[%s]]::text[]))'''
% ', '.join(["'wf-%s'" % x.id for x in endpoint_status]),
'''is_at_endpoint''',
)
)
    # [CRITICALITY_1] Add criticality_level, computed relative to levels in
    # the given workflow, so all higher criticalities are sorted first. This is
    # reverted when loading the formdata back, in [CRITICALITY_2]
levels = len(formdef.workflow.criticality_levels or [0])
view_fields.append(('''(criticality_level - %d)''' % levels, '''criticality_level'''))
view_fields.append((cur.mogrify('(SELECT text %s)', (formdef.name,)), 'formdef_name'))
view_fields.append(
(
'''(SELECT name FROM users
WHERE users.id = CAST(user_id AS INTEGER))''',
'user_name',
)
)
view_fields.append(('concerned_roles_array', 'concerned_roles_array'))
view_fields.append(('actions_roles_array', 'actions_roles_array'))
view_fields.append(('fts', 'fts'))
if formdef.geolocations and 'base' in formdef.geolocations:
        # default geolocation is in the 'base' key; we have to unstructure the
        # field as the POINT type of postgresql cannot be used directly since it
        # doesn't have an equality operator.
view_fields.append(('geoloc_base[0]', 'geoloc_base_x'))
view_fields.append(('geoloc_base[1]', 'geoloc_base_y'))
else:
view_fields.append(('NULL::real', 'geoloc_base_x'))
view_fields.append(('NULL::real', 'geoloc_base_y'))
view_fields.append(('anonymised', 'anonymised'))
fields_list = ', '.join(['%s AS %s' % (force_str(x), force_str(y)) for (x, y) in view_fields])
cur.execute('''CREATE VIEW %s AS SELECT %s FROM %s''' % (view_name, fields_list, table_name))
if rebuild_global_views:
do_global_views(conn, cur) # recreate global views
def drop_global_views(conn, cur):
cur.execute(
'''SELECT table_name FROM information_schema.views
WHERE table_schema = 'public'
AND table_name LIKE %s''',
('wcs\\_category\\_%',),
)
view_names = []
while True:
row = cur.fetchone()
if row is None:
break
view_names.append(row[0])
for view_name in view_names:
cur.execute('''DROP VIEW IF EXISTS %s''' % view_name)
def update_global_view_formdef_category(formdef):
_, cur = get_connection_and_cursor()
with cur:
cur.execute(
'''UPDATE wcs_all_forms set category_id = %s WHERE formdef_id = %s''',
(formdef.category_id, formdef.id),
)
def do_global_views(conn, cur):
# recreate global views
# XXX TODO: make me dynamic, please ?
cur.execute(
"""CREATE TABLE IF NOT EXISTS wcs_all_forms (
category_id integer,
formdef_id integer NOT NULL,
id integer NOT NULL,
user_id character varying,
receipt_time timestamp with time zone,
status character varying,
id_display character varying,
submission_agent_id character varying,
submission_channel character varying,
backoffice_submission boolean,
last_update_time timestamp with time zone,
digests jsonb,
user_label character varying,
concerned_roles_array text[],
actions_roles_array text[],
fts tsvector,
is_at_endpoint boolean,
formdef_name text,
user_name character varying,
criticality_level integer,
geoloc_base_x double precision,
geoloc_base_y double precision,
anonymised timestamp with time zone,
statistics_data jsonb,
relations_data jsonb
, PRIMARY KEY(formdef_id, id)
)"""
)
create_index = 'CREATE INDEX IF NOT EXISTS'
for attr in ('receipt_time', 'anonymised', 'user_id', 'status'):
cur.execute(f'{create_index} wcs_all_forms_{attr} ON wcs_all_forms ({attr})')
for attr in ('fts', 'concerned_roles_array', 'actions_roles_array'):
cur.execute(f'{create_index} wcs_all_forms_{attr} ON wcs_all_forms USING gin({attr})')
cur.execute(
f'''{create_index} wcs_all_forms_actions_roles_live ON wcs_all_forms
USING gin(actions_roles_array) WHERE (anonymised IS NULL AND is_at_endpoint = false)'''
)
# make sure the table will not be changed while we work on it
with atomic():
cur.execute('LOCK TABLE wcs_all_forms;')
cur.execute(
'''SELECT column_name, data_type FROM information_schema.columns
WHERE table_schema = 'public'
AND table_name = %s''',
('wcs_all_forms',),
)
existing_fields = {x[0]: x[1] for x in cur.fetchall()}
if 'statistics_data' not in existing_fields:
cur.execute('ALTER TABLE wcs_all_forms ADD COLUMN statistics_data jsonb')
if 'relations_data' not in existing_fields:
cur.execute('ALTER TABLE wcs_all_forms ADD COLUMN relations_data jsonb')
if existing_fields.get('receipt_time') not in (None, 'timestamp with time zone'):
cur.execute('ALTER TABLE wcs_all_forms ALTER COLUMN receipt_time SET DATA TYPE timestamptz')
if existing_fields.get('last_update_time') not in (None, 'timestamp with time zone'):
cur.execute('ALTER TABLE wcs_all_forms ALTER COLUMN last_update_time SET DATA TYPE timestamptz')
clean_global_views(conn, cur)
for category in wcs.categories.Category.select():
name = get_name_as_sql_identifier(category.url_name)[:40]
cur.execute(
'''CREATE OR REPLACE VIEW wcs_category_%s AS SELECT * from wcs_all_forms
WHERE category_id = %s'''
% (name, category.id)
)
def clean_global_views(conn, cur):
    # purge any dead data
valid_ids = [int(i) for i in FormDef.keys()]
if valid_ids:
cur.execute('DELETE FROM wcs_all_forms WHERE NOT formdef_id = ANY(%s)', (valid_ids,))
else:
cur.execute('TRUNCATE wcs_all_forms')
def init_global_table(conn=None, cur=None):
own_conn = False
if not conn:
own_conn = True
conn, cur = get_connection_and_cursor()
cur.execute("SELECT relkind FROM pg_class WHERE relname = 'wcs_all_forms';")
rows = cur.fetchall()
if len(rows) != 0:
if rows[0][0] == 'v':
# force wcs_all_forms table creation
cur.execute('DROP VIEW IF EXISTS wcs_all_forms CASCADE;')
else:
assert rows[0][0] == 'r'
cur.execute('DROP TABLE wcs_all_forms CASCADE;')
do_global_views(conn, cur)
# now copy all data into the table
for formdef in FormDef.select():
category_value = formdef.category_id
if formdef.category_id is None:
category_value = 'NULL'
geoloc_base_x_query = 'NULL'
geoloc_base_y_query = 'NULL'
if formdef.geolocations and 'base' in formdef.geolocations:
            # default geolocation is in the 'base' key; we have to unstructure the
            # field as the POINT type of postgresql cannot be used directly since it
            # doesn't have an equality operator.
geoloc_base_x_query = 'geoloc_base[0]'
geoloc_base_y_query = 'geoloc_base[1]'
criticality_levels = len(formdef.workflow.criticality_levels or [0])
endpoint_status = formdef.workflow.get_endpoint_status()
endpoint_status_filter = ', '.join(["'wf-%s'" % x.id for x in endpoint_status])
if endpoint_status_filter == '':
# not the prettiest in town, but will do fine for now.
endpoint_status_filter = "'xxxx'"
formed_name_quotedstring = psycopg2.extensions.QuotedString(formdef.name)
formed_name_quotedstring.encoding = 'utf8'
formdef_name = formed_name_quotedstring.getquoted().decode()
cur.execute(
"""
INSERT INTO wcs_all_forms
SELECT
{category_id},
{formdef_id},
id,
user_id,
receipt_time,
status,
id_display,
submission_agent_id,
submission_channel,
backoffice_submission,
last_update_time,
digests,
user_label,
concerned_roles_array,
actions_roles_array,
fts,
status IN ({endpoint_status}),
{formdef_name},
(SELECT name FROM users WHERE users.id = CAST(user_id AS INTEGER)),
criticality_level - {criticality_levels},
{geoloc_base_x},
{geoloc_base_y},
anonymised,
statistics_data,
relations_data
FROM {table_name}
ON CONFLICT DO NOTHING;
""".format(
table_name=get_formdef_table_name(formdef),
                category_id=category_value,  # 'NULL' literal when formdef.category_id is None (see above)
formdef_id=formdef.id,
geoloc_base_x=geoloc_base_x_query,
geoloc_base_y=geoloc_base_y_query,
formdef_name=formdef_name,
criticality_levels=criticality_levels,
endpoint_status=endpoint_status_filter,
)
)
if own_conn:
cur.close()
class SqlMixin:
_table_name = None
_numerical_id = True
_table_select_skipped_fields = []
_has_id = True
_sql_indexes = None
@classmethod
def do_indexes(cls, cur, concurrently=False):
if concurrently:
create_index = 'CREATE INDEX CONCURRENTLY IF NOT EXISTS'
else:
create_index = 'CREATE INDEX IF NOT EXISTS'
for index in cls.get_sql_indexes():
cur.execute(f'{create_index} {index}')
@classmethod
def get_sql_indexes(cls):
return cls._sql_indexes or []
@classmethod
def keys(cls, clause=None):
_, cur = get_connection_and_cursor()
where_clauses, parameters, func_clause = parse_clause(clause)
assert not func_clause
sql_statement = 'SELECT id FROM %s' % cls._table_name
if where_clauses:
sql_statement += ' WHERE ' + ' AND '.join(where_clauses)
cur.execute(sql_statement, parameters)
ids = [x[0] for x in cur.fetchall()]
cur.close()
return ids
@classmethod
def count(cls, clause=None):
where_clauses, parameters, func_clause = parse_clause(clause)
if func_clause:
# fallback to counting the result of a select()
return len(cls.select(clause))
sql_statement = 'SELECT count(*) FROM %s' % cls._table_name
if where_clauses:
sql_statement += ' WHERE ' + ' AND '.join(where_clauses)
_, cur = get_connection_and_cursor()
cur.execute(sql_statement, parameters)
count = cur.fetchone()[0]
cur.close()
return count
@classmethod
def exists(cls, clause=None):
where_clauses, parameters, func_clause = parse_clause(clause)
        if func_clause:
            # fall back to checking the result of a select()
            return bool(cls.select(clause))
sql_statement = 'SELECT 1 FROM %s' % cls._table_name
if where_clauses:
sql_statement += ' WHERE ' + ' AND '.join(where_clauses)
sql_statement += ' LIMIT 1'
_, cur = get_connection_and_cursor()
try:
cur.execute(sql_statement, parameters)
except UndefinedTable:
result = False
else:
check = cur.fetchone()
result = check is not None
cur.close()
return result
@classmethod
def get_ids_from_query(cls, query):
_, cur = get_connection_and_cursor()
sql_statement = (
'''SELECT id FROM %s
WHERE fts @@ plainto_tsquery(%%(value)s)'''
% cls._table_name
)
cur.execute(sql_statement, {'value': FtsMatch.get_fts_value(query)})
all_ids = [x[0] for x in cur.fetchall()]
cur.close()
return all_ids
@classmethod
def get(cls, id, ignore_errors=False, ignore_migration=False, column=None):
if column is None and (cls._numerical_id or id is None):
try:
if not (0 < int(str(id)) < 2**31) or not is_ascii_digit(str(id)):
# avoid NumericValueOutOfRange and _ in digits
raise TypeError()
except (TypeError, ValueError):
if ignore_errors and (cls._numerical_id or id is None):
return None
else:
raise KeyError()
_, cur = get_connection_and_cursor()
sql_statement = '''SELECT %s
FROM %s
WHERE %s = %%(value)s''' % (
', '.join([x[0] for x in cls._table_static_fields] + cls.get_data_fields()),
cls._table_name,
column or 'id',
)
cur.execute(sql_statement, {'value': str(id)})
row = cur.fetchone()
if row is None:
cur.close()
if ignore_errors:
return None
raise KeyError()
cur.close()
return cls._row2ob(row)
@classmethod
def get_on_index(cls, value, index, ignore_errors=False, **kwargs):
return cls.get(value, ignore_errors=ignore_errors, column=index)
@classmethod
def get_ids(cls, ids, ignore_errors=False, keep_order=False, fields=None, order_by=None):
if not ids:
return []
tables = [cls._table_name]
columns = [
'%s.%s' % (cls._table_name, column_name)
for column_name in [x[0] for x in cls._table_static_fields] + cls.get_data_fields()
]
extra_fields = []
if fields:
# look for relations
for field in fields:
if not getattr(field, 'is_related_field', False):
continue
if field.parent_field_id == 'user-label':
# relation to user table
carddef_table_alias = 'users'
carddef_table_decl = (
'LEFT JOIN users ON (CAST(%s.user_id AS INTEGER) = users.id)' % cls._table_name
)
else:
carddef_data_table_name = get_formdef_table_name(field.carddef)
carddef_table_alias = 't%s' % id(field.carddef)
if field.carddef.id_template:
carddef_table_decl = 'LEFT JOIN %s AS %s ON (%s.%s = %s.id_display)' % (
carddef_data_table_name,
carddef_table_alias,
cls._table_name,
get_field_id(field.parent_field),
carddef_table_alias,
)
else:
carddef_table_decl = 'LEFT JOIN %s AS %s ON (CAST(%s.%s AS INTEGER) = %s.id)' % (
carddef_data_table_name,
carddef_table_alias,
cls._table_name,
get_field_id(field.parent_field),
carddef_table_alias,
)
if carddef_table_decl not in tables:
tables.append(carddef_table_decl)
column_field_id = field.get_column_field_id()
columns.append('%s.%s' % (carddef_table_alias, column_field_id))
if field.store_display_value:
columns.append('%s.%s_display' % (carddef_table_alias, column_field_id))
if field.store_structured_value:
columns.append('%s.%s_structured' % (carddef_table_alias, column_field_id))
extra_fields.append(field)
_, cur = get_connection_and_cursor()
if cls._numerical_id:
ids_str = ', '.join([str(x) for x in ids])
else:
ids_str = ', '.join(["'%s'" % x for x in ids])
sql_statement = '''SELECT %s
FROM %s
WHERE %s.id IN (%s)''' % (
', '.join(columns),
' '.join(tables),
cls._table_name,
ids_str,
)
sql_statement += cls.get_order_by_clause(order_by)
cur.execute(sql_statement)
objects = cls.get_objects(cur, extra_fields=extra_fields)
cur.close()
if ignore_errors:
objects = (x for x in objects if x is not None)
if keep_order:
objects_dict = {}
for object in objects:
objects_dict[object.id] = object
objects = [objects_dict[x] for x in ids if objects_dict.get(x)]
return list(objects)
@classmethod
def get_objects_iterator(cls, cur, ignore_errors=False, extra_fields=None):
while True:
row = cur.fetchone()
if row is None:
break
yield cls._row2ob(row, extra_fields=extra_fields)
@classmethod
def get_objects(cls, cur, ignore_errors=False, iterator=False, extra_fields=None):
generator = cls.get_objects_iterator(cur=cur, ignore_errors=ignore_errors, extra_fields=extra_fields)
if iterator:
return generator
return list(generator)
@classmethod
def get_order_by_clause(cls, order_by):
if not order_by:
return ''
def _get_order_by_part(part):
# [SEC_ORDER] security note: it is not possible to use
# prepared statements for ORDER BY clauses, therefore input
# is controlled beforehand (see misc.get_order_by_or_400).
direction = 'ASC'
if part.startswith('-'):
part = part[1:]
direction = 'DESC'
if '->' in part:
# sort on field of block field: f42->'data'->0->>'bf13e4d8a8-fb08-4808-b5ae-02d6247949b9'
# or on digest (digests->>'default'); make sure all parts have their
# dashes changed to underscores
parts = part.split('->')
part = '%s->%s' % (parts[0].replace('-', '_'), '->'.join(parts[1:]))
else:
part = part.replace('-', '_')
fields = ['formdef_name', 'user_name'] # global view fields
fields.extend([x[0] for x in cls._table_static_fields])
fields.extend(cls.get_data_fields())
if part.split('->')[0] not in fields:
# for a sort on field of block field, just check the existence of the block field
return None, None
return part, direction
if not isinstance(order_by, list):
order_by = [order_by]
ordering = []
for part in order_by:
order, direction = _get_order_by_part(part)
if order is None:
continue
ordering.append(f'{order} {direction}')
if not ordering:
return ''
return ' ORDER BY %s' % ', '.join(ordering)
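    # usage sketch (illustrative), for a class whose table has a
    # receipt_time column:
    #
    #   cls.get_order_by_clause('-receipt_time')   # -> ' ORDER BY receipt_time DESC'
    #   cls.get_order_by_clause('unknown-column')  # -> '' (unknown parts are skipped)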
@classmethod
def has_key(cls, id):
_, cur = get_connection_and_cursor()
sql_statement = 'SELECT EXISTS(SELECT 1 FROM %s WHERE id = %%s)' % cls._table_name
with cur:
cur.execute(sql_statement, (id,))
result = cur.fetchall()[0][0]
return result
@classmethod
def select_iterator(
cls,
clause=None,
order_by=None,
ignore_errors=False,
limit=None,
offset=None,
itersize=None,
):
table_static_fields = [
x[0] if x[0] not in cls._table_select_skipped_fields else 'NULL AS %s' % x[0]
for x in cls._table_static_fields
]
def retrieve():
for object in cls.get_objects(cur, iterator=True):
if object is None:
continue
if func_clause and not func_clause(object):
continue
yield object
if itersize and cls._has_id:
# this case concerns almost all data tables: formdata, card, users, roles
sql_statement = '''SELECT id FROM %s''' % cls._table_name
else:
# this case concerns aggregated views like wcs_all_forms (class
# AnyFormData) which does not have a surrogate key id column
sql_statement = '''SELECT %s FROM %s''' % (
', '.join(table_static_fields + cls.get_data_fields()),
cls._table_name,
)
where_clauses, parameters, func_clause = parse_clause(clause)
if where_clauses:
sql_statement += ' WHERE ' + ' AND '.join(where_clauses)
sql_statement += cls.get_order_by_clause(order_by)
if not func_clause:
if limit:
sql_statement += ' LIMIT %(limit)s'
parameters['limit'] = limit
if offset:
sql_statement += ' OFFSET %(offset)s'
parameters['offset'] = offset
_, cur = get_connection_and_cursor()
with cur:
cur.execute(sql_statement, parameters)
if itersize and cls._has_id:
sql_id_statement = '''SELECT %s FROM %s WHERE id IN %%s''' % (
', '.join(table_static_fields + cls.get_data_fields()),
cls._table_name,
)
sql_id_statement += cls.get_order_by_clause(order_by)
ids = [row[0] for row in cur]
while ids:
cur.execute(sql_id_statement, [tuple(ids[:itersize])])
yield from retrieve()
ids = ids[itersize:]
else:
yield from retrieve()
@classmethod
def select(
cls,
clause=None,
order_by=None,
ignore_errors=False,
limit=None,
offset=None,
iterator=False,
itersize=None,
):
if iterator and not itersize:
itersize = 200
objects = cls.select_iterator(
clause=clause,
order_by=order_by,
ignore_errors=ignore_errors,
limit=limit,
offset=offset,
itersize=itersize,
)
func_clause = parse_clause(clause)[2]
if func_clause and (limit or offset):
objects = _take(objects, limit, offset)
if iterator:
return objects
return list(objects)
@classmethod
def select_distinct(cls, columns, clause=None, first_field_alias=None):
# note that this method returns unicode strings.
column0 = columns[0]
if first_field_alias:
column0 = '%s as %s' % (column0, first_field_alias)
_, cur = get_connection_and_cursor()
sql_statement = 'SELECT DISTINCT ON (%s) %s FROM %s' % (
columns[0],
', '.join([column0] + columns[1:]),
cls._table_name,
)
where_clauses, parameters, func_clause = parse_clause(clause)
assert not func_clause
if where_clauses:
sql_statement += ' WHERE ' + ' AND '.join(where_clauses)
sql_statement += ' ORDER BY %s' % (first_field_alias or columns[0])
cur.execute(sql_statement, parameters)
values = cur.fetchall()
cur.close()
return values
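# Sketch (column names are illustrative): fetch distinct values of a
# column, together with a second column, ordered by the first:
#
#   rows = cls.select_distinct(['status', 'id_display'],
#                              clause=[NotNull('status')])
#   # rows is a list of (status, id_display) tuples, ordered by status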
def get_sql_dict_from_data(self, data, formdef):
sql_dict = {}
for field in formdef.get_all_fields():
sql_type = SQL_TYPE_MAPPING.get(field.key, 'varchar')
if sql_type is None:
continue
value = data.get(field.id)
if value is not None:
if field.key in ('ranked-items', 'password'):
# turn {'poire': 2, 'abricot': 1, 'pomme': 3} into an array
value = [[force_str(x), force_str(y)] for x, y in value.items()]
elif field.key == 'computed':
# embed value in a dict, so it's never necessary to cast the
# value for postgresql
value = {'data': json.loads(JSONEncoder().encode(value)), '@type': 'computed-data'}
elif sql_type == 'varchar':
assert isinstance(value, str)
elif sql_type == 'date':
assert isinstance(value, time.struct_time)
value = datetime.datetime(value.tm_year, value.tm_mon, value.tm_mday)
elif sql_type == 'bytea':
value = bytearray(pickle.dumps(value, protocol=2))
elif sql_type == 'jsonb' and isinstance(value, dict) and value.get('schema'):
# block field, adapt date/field values
value = copy.deepcopy(value)
for field_id, field_type in value.get('schema').items():
if field_type not in ('date', 'file', 'numeric'):
continue
for entry in value.get('data') or []:
subvalue = entry.get(field_id)
if subvalue and field_type == 'date':
entry[field_id] = strftime('%Y-%m-%d', subvalue)
elif subvalue and field_type == 'file':
entry[field_id] = subvalue.__getstate__()
elif subvalue is not None and field_type == 'numeric':
entry[field_id] = str(subvalue)
elif sql_type == 'boolean':
pass
sql_dict[get_field_id(field)] = value
if field.store_display_value:
sql_dict['%s_display' % get_field_id(field)] = data.get('%s_display' % field.id)
if field.store_structured_value:
sql_dict['%s_structured' % get_field_id(field)] = bytearray(
pickle.dumps(data.get('%s_structured' % field.id), protocol=2)
)
return sql_dict
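# Example of the resulting mapping (hypothetical field ids): a 'date'
# field is stored as a datetime, a 'ranked-items' dict is flattened to a
# text[][] array, and a computed value is wrapped in a typed dict:
#
#   {'f12': datetime.datetime(2024, 5, 1, 0, 0),
#    'f13': [['poire', '2'], ['abricot', '1']],
#    'f14': {'data': 42, '@type': 'computed-data'}}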
@classmethod
def _col2obdata(cls, row, i, field):
obdata = {}
field_key = field.key
if field_key == 'related-field':
field_key = field.related_field.key
sql_type = SQL_TYPE_MAPPING.get(field_key, 'varchar')
if sql_type is None:
return ({}, i)
value = row[i]
if value is not None:
value = str_encode(value)
if field.key == 'ranked-items':
d = {}
for data, rank in value:
try:
d[data] = int(rank)
except ValueError:
d[data] = rank
value = d
elif field.key == 'password':
d = {}
for fmt, val in value:
d[fmt] = force_str(val)
value = d
elif field.key == 'computed':
if not isinstance(value, dict):
raise ValueError(
'bad data %s (type %s) in computed field %s' % (value, type(value), field.id)
)
if value.get('@type') == 'computed-data':
value = value.get('data')
if sql_type == 'date':
value = value.timetuple()
elif sql_type == 'bytea':
value = pickle_loads(value)
elif sql_type == 'jsonb' and isinstance(value, dict) and value.get('schema'):
# block field, adapt date/field values
for field_id, field_type in value.get('schema').items():
if field_type not in ('date', 'file', 'numeric'):
continue
for entry in value.get('data') or []:
subvalue = entry.get(field_id)
if subvalue and field_type == 'date':
entry[field_id] = time.strptime(subvalue, '%Y-%m-%d')
elif subvalue and field_type == 'file':
entry[field_id] = PicklableUpload.__new__(PicklableUpload)
entry[field_id].__setstate__(subvalue)
elif subvalue and field_type == 'numeric':
entry[field_id] = decimal.Decimal(subvalue)
obdata[field.id] = value
i += 1
if field.store_display_value:
value = str_encode(row[i])
obdata['%s_display' % field.id] = value
i += 1
if field.store_structured_value:
value = row[i]
if value is not None:
obdata['%s_structured' % field.id] = pickle_loads(value)
if obdata['%s_structured' % field.id] is None:
del obdata['%s_structured' % field.id]
i += 1
return (obdata, i)
@classmethod
def _row2obdata(cls, row, formdef):
obdata = {}
i = len(cls._table_static_fields)
if formdef.geolocations:
i += len(formdef.geolocations.keys())
for field in formdef.get_all_fields():
coldata, i = cls._col2obdata(row, i, field)
obdata.update(coldata)
return obdata
@classmethod
def remove_object(cls, id):
_, cur = get_connection_and_cursor()
sql_statement = (
'''DELETE FROM %s
WHERE id = %%(id)s'''
% cls._table_name
)
cur.execute(sql_statement, {'id': str(id)})
cur.close()
@classonlymethod
def wipe(cls, drop=False, clause=None):
_, cur = get_connection_and_cursor()
sql_statement = '''DELETE FROM %s''' % cls._table_name
parameters = {}
if clause:
where_clauses, parameters, dummy = parse_clause(clause)
if where_clauses:
sql_statement += ' WHERE ' + ' AND '.join(where_clauses)
cur.execute(sql_statement, parameters)
cur.close()
@classmethod
def get_sorted_ids(cls, order_by, clause=None):
_, cur = get_connection_and_cursor()
sql_statement = 'SELECT id FROM %s' % cls._table_name
where_clauses, parameters, func_clause = parse_clause(clause)
assert not func_clause
if where_clauses:
sql_statement += ' WHERE ' + ' AND '.join(where_clauses)
if order_by == 'rank':
try:
fts = [x for x in clause if not callable(x) and x.__class__.__name__ == 'FtsMatch'][0]
except IndexError:
pass
else:
sql_statement += ' ORDER BY ts_rank(fts, plainto_tsquery(%%(c%s)s)) DESC' % id(fts.value)
else:
sql_statement += cls.get_order_by_clause(order_by)
cur.execute(sql_statement, parameters)
ids = [x[0] for x in cur.fetchall()]
cur.close()
return ids
class SqlDataMixin(SqlMixin):
_names = None # make sure StorableObject methods fail
_formdef = None
_table_static_fields = [
('id', 'serial'),
('user_id', 'varchar'),
('receipt_time', 'timestamptz'),
('status', 'varchar'),
('page_no', 'varchar'),
('page_id', 'varchar'),
('anonymised', 'timestamptz'),
('workflow_data', 'bytea'),
('prefilling_data', 'bytea'),
('id_display', 'varchar'),
('workflow_roles', 'bytea'),
# workflow_merged_roles_dict combines workflow_roles from formdef and
# formdata and is used to filter on function assignment.
('workflow_merged_roles_dict', 'jsonb'),
# workflow_roles_array is created from workflow_roles to be used in
# get_ids_with_indexed_value
('workflow_roles_array', 'text[]'),
('concerned_roles_array', 'text[]'),
('actions_roles_array', 'text[]'),
('tracking_code', 'varchar'),
('backoffice_submission', 'boolean'),
('submission_context', 'bytea'),
('submission_agent_id', 'varchar'),
('submission_channel', 'varchar'),
('criticality_level', 'int'),
('last_update_time', 'timestamptz'),
('digests', 'jsonb'),
('user_label', 'varchar'),
('auto_geoloc', 'point'),
('statistics_data', 'jsonb'),
('relations_data', 'jsonb'),
]
def __init__(self, id=None):
self.id = id
self.data = {}
self._has_changed_digest = False
_evolution = None
def get_evolution(self):
if self._evolution is not None:
return self._evolution
if not self.id:
self._evolution = []
return self._evolution
_, cur = get_connection_and_cursor()
sql_statement = (
'''SELECT id, who, status, time, last_jump_datetime,
comment, parts FROM %s_evolutions
WHERE formdata_id = %%(id)s
ORDER BY id'''
% self._table_name
)
cur.execute(sql_statement, {'id': self.id})
self._evolution = []
while True:
row = cur.fetchone()
if row is None:
break
self._evolution.append(self._row2evo(row, formdata=self))
cur.close()
return self._evolution
@classmethod
def _row2evo(cls, row, formdata):
o = wcs.formdata.Evolution(formdata)
o._sql_id, o.who, o.status, o.time, o.last_jump_datetime, o.comment = (
str_encode(x) for x in tuple(row[:6])
)
if row[6]:
o.parts = LazyEvolutionList(row[6])
return o
def set_evolution(self, value):
self._evolution = value
evolution = property(get_evolution, set_evolution)
@classmethod
def load_all_evolutions(cls, values):
# Typically formdata.evolution is loaded on demand (see the property()
# above) and this is fine to minimize queries, especially when dealing
# with a single formdata. However in some places (for example when
# computing statistics) it is useful to access .evolution on a series
# of formdata, and in that case it's more efficient to load all
# evolutions in a single batch query.
object_dict = {x.id: x for x in values if x.id and x._evolution is None}
if not object_dict:
return
_, cur = get_connection_and_cursor()
sql_statement = (
'''SELECT id, who, status, time, last_jump_datetime,
comment, parts, formdata_id
FROM %s_evolutions'''
% cls._table_name
)
sql_statement += ''' WHERE formdata_id IN %(object_ids)s ORDER BY id'''
cur.execute(sql_statement, {'object_ids': tuple(object_dict.keys())})
# only reset objects that will be (re)filled from the batch query
for value in object_dict.values():
value._evolution = []
while True:
row = cur.fetchone()
if row is None:
break
formdata_id = row[7]
formdata = object_dict.get(formdata_id)
if not formdata:
continue
formdata._evolution.append(formdata._row2evo(row, formdata))
cur.close()
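# Usage sketch (names are illustrative): batch-load evolutions before
# iterating, to avoid one query per formdata:
#
#   formdatas = data_class.select([Equal('status', 'wst-accepted')])
#   data_class.load_all_evolutions(formdatas)
#   for formdata in formdatas:
#       last_status_change = formdata.evolution[-1].time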
@classmethod
def get_resolution_times(
cls, start_status, end_statuses, period_start=None, period_end=None, extra_criterias=None
):
criterias = [StrictNotEqual('f.status', 'draft')]
if period_start:
criterias.append(GreaterOrEqual('f.receipt_time', period_start))
if period_end:
criterias.append(Less('f.receipt_time', period_end))
if extra_criterias:
def alter_criteria(criteria):
# change attributes to point to the formdata table (f)
if hasattr(criteria, 'attribute'):
criteria.attribute = f'f.{criteria.attribute}'
elif hasattr(criteria, 'criteria'): # Not()
alter_criteria(criteria.criteria)
elif hasattr(criteria, 'criterias'): # Or()
for c in criteria.criterias:
alter_criteria(c)
for criteria in extra_criterias:
altered_criteria = copy.deepcopy(criteria)
alter_criteria(altered_criteria)
criterias.append(altered_criteria)
where_clauses, params, dummy = parse_clause(criterias)
params.update(
{
'start_status': start_status,
'end_statuses': tuple(end_statuses),
}
)
table_name = cls._table_name
sql_statement = f'''
SELECT
f.id,
MIN(end_evo.time) - MIN(start_evo.time) as res_time
FROM {table_name} f
JOIN {table_name}_evolutions start_evo ON start_evo.formdata_id = f.id AND start_evo.status = %(start_status)s
JOIN {table_name}_evolutions end_evo ON end_evo.formdata_id = f.id AND end_evo.status IN %(end_statuses)s
WHERE {' AND '.join(where_clauses)}
GROUP BY f.id
ORDER BY res_time
'''
_, cur = get_connection_and_cursor()
with cur:
cur.execute(sql_statement, params)
results = cur.fetchall()
# row[1] will have the resolution time as computed by postgresql
return [row[1].total_seconds() for row in results if row[1].total_seconds() >= 0]
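# Sketch (hypothetical status identifiers): compute the median time
# between reception and closure; the returned list is sorted ascending:
#
#   times = data_class.get_resolution_times(
#       'wf-received', ('wf-accepted', 'wf-rejected'),
#       period_start=datetime.datetime(2024, 1, 1))
#   median = times[len(times) // 2] if times else None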
def _set_auto_fields(self, cur):
changed_auto_fields = self.set_auto_fields()
if changed_auto_fields:
self._has_changed_digest = 'digests' in changed_auto_fields
sql_statement = (
'''UPDATE %s
SET id_display = %%(id_display)s,
digests = %%(digests)s,
user_label = %%(user_label)s,
statistics_data = %%(statistics_data)s,
relations_data = %%(relations_data)s
WHERE id = %%(id)s'''
% self._table_name
)
cur.execute(
sql_statement,
{
'id': self.id,
'id_display': self.id_display,
'digests': self.digests,
'user_label': self.user_label,
'statistics_data': self.statistics_data,
'relations_data': self.relations_data,
},
)
@invalidate_substitution_cache
def store(self, where=None):
sql_dict = {
'user_id': self.user_id,
'status': self.status,
'page_no': self.page_no,
'workflow_data': self.workflow_data,
'id_display': self.id_display,
'anonymised': self.anonymised,
'tracking_code': self.tracking_code,
'backoffice_submission': self.backoffice_submission,
'submission_context': self.submission_context,
'prefilling_data': self.prefilling_data,
'submission_agent_id': self.submission_agent_id,
'submission_channel': self.submission_channel,
'criticality_level': self.criticality_level,
'workflow_merged_roles_dict': self.workflow_merged_roles_dict,
'statistics_data': self.statistics_data or {},
'relations_data': self.relations_data or {},
}
if self._evolution is not None and hasattr(self, '_last_update_time'):
# if evolution was loaded it may have been modified, and last update time
# should then be refreshed.
delattr(self, '_last_update_time')
sql_dict['last_update_time'] = self.last_update_time
sql_dict['receipt_time'] = self.receipt_time
if self.workflow_roles:
sql_dict['workflow_roles_array'] = []
for x in self.workflow_roles.values():
if isinstance(x, list):
sql_dict['workflow_roles_array'].extend(x)
elif x:
sql_dict['workflow_roles_array'].append(str(x))
else:
sql_dict['workflow_roles_array'] = None
if hasattr(self, 'uuid'):
sql_dict['uuid'] = self.uuid
if hasattr(self, 'page_id'):
sql_dict['page_id'] = self.page_id
for attr in ('workflow_data', 'workflow_roles', 'submission_context', 'prefilling_data'):
if getattr(self, attr):
sql_dict[attr] = bytearray(pickle.dumps(getattr(self, attr), protocol=2))
else:
sql_dict[attr] = None
for field in (self._formdef.geolocations or {}).keys():
value = (self.geolocations or {}).get(field)
if value:
value = '(%.6f, %.6f)' % (value.get('lon'), value.get('lat'))
sql_dict['geoloc_%s' % field] = value
sql_dict['concerned_roles_array'] = [str(x) for x in self.concerned_roles if x]
sql_dict['actions_roles_array'] = [str(x) for x in self.actions_roles if x]
auto_geoloc_value = self.get_auto_geoloc()
if auto_geoloc_value:
auto_geoloc_value = '(%.6f, %.6f)' % (auto_geoloc_value.get('lon'), auto_geoloc_value.get('lat'))
sql_dict['auto_geoloc'] = auto_geoloc_value
sql_dict.update(self.get_sql_dict_from_data(self.data, self._formdef))
_, cur = get_connection_and_cursor()
if not self.id:
column_names = sql_dict.keys()
sql_statement = '''INSERT INTO %s (id, %s)
VALUES (DEFAULT, %s)
RETURNING id''' % (
self._table_name,
', '.join(column_names),
', '.join(['%%(%s)s' % x for x in column_names]),
)
cur.execute(sql_statement, sql_dict)
self.id = cur.fetchone()[0]
else:
if not where:
where = []
where.append(Equal('id', self.id))
where_clauses, parameters, dummy = parse_clause(where)
column_names = list(sql_dict.keys())
sql_dict.update(parameters)
sql_statement = '''UPDATE %s SET %s WHERE %s RETURNING id''' % (
self._table_name,
', '.join(['%s = %%(%s)s' % (x, x) for x in column_names]),
' AND '.join(where_clauses),
)
cur.execute(sql_statement, sql_dict)
if cur.fetchone() is None:
if len(where) > 1:
# abort if nothing was modified and there were extra where clauses
raise NothingToUpdate()
# this was a request to save a new row with a preset id (for example
# for data migration)
sql_dict['id'] = self.id
column_names.append('id')
sql_statement = '''INSERT INTO %s (%s) VALUES (%s) RETURNING id''' % (
self._table_name,
', '.join(column_names),
', '.join(['%%(%s)s' % x for x in column_names]),
)
cur.execute(sql_statement, sql_dict)
self.id = cur.fetchone()[0]
self._set_auto_fields(cur)
self.clean_live_evolution_items()
if self._evolution:
idx = 0
if not getattr(self, '_store_all_evolution', False):
# skip all evolutions that already have an _sql_id;
# it's still possible, for debugging purposes and special needs,
# to store them all by setting formdata._store_all_evolution = True
for idx, evo in enumerate(self._evolution):
if not hasattr(evo, '_sql_id'):
break
# now we can save all after this idx
for evo in self._evolution[idx:]:
sql_dict = {}
if hasattr(evo, '_sql_id'):
sql_dict.update({'id': evo._sql_id})
sql_statement = (
'''UPDATE %s_evolutions SET
who = %%(who)s,
time = %%(time)s,
last_jump_datetime = %%(last_jump_datetime)s,
status = %%(status)s,
comment = %%(comment)s,
parts = %%(parts)s
WHERE id = %%(id)s
RETURNING id'''
% self._table_name
)
else:
sql_statement = (
'''INSERT INTO %s_evolutions (
id, who, status,
time, last_jump_datetime,
comment, parts,
formdata_id)
VALUES (DEFAULT, %%(who)s, %%(status)s,
%%(time)s, %%(last_jump_datetime)s,
%%(comment)s,
%%(parts)s, %%(formdata_id)s)
RETURNING id'''
% self._table_name
)
sql_dict.update(
{
'who': evo.who,
'status': evo.status,
'time': evo.time,
'last_jump_datetime': evo.last_jump_datetime,
'comment': evo.comment,
'formdata_id': self.id,
}
)
if evo.parts:
sql_dict['parts'] = bytearray(pickle.dumps(evo.parts, protocol=2))
else:
sql_dict['parts'] = None
cur.execute(sql_statement, sql_dict)
evo._sql_id = cur.fetchone()[0]
fts_strings = {'A': set(), 'B': set(), 'C': set(), 'D': set()}
fts_strings['A'].add(str(self.id))
fts_strings['A'].add(self.get_display_id())
fts_strings['C'].add(self._formdef.name)
if self.tracking_code:
fts_strings['A'].add(self.tracking_code)
def get_all_fields():
for field in self._formdef.get_all_fields():
if field.key == 'block' and self.data.get(field.id):
for data in self.data[field.id].get('data'):
try:
for subfield in field.block.fields:
yield subfield, data
except KeyError:
# block doesn't exist anymore
break
else:
yield field, self.data
for field, data in get_all_fields():
if not data.get(field.id):
continue
value = None
if field.key in ('string', 'text', 'email', 'item', 'items'):
value = field.get_fts_value(data)
if value:
weight = 'C'
if field.include_in_listing:
weight = 'B'
if isinstance(value, str) and len(value) < 10000:
# avoid overlong strings, typically base64-encoded values
fts_strings[weight].add(value)
# normalize values that look like phone numbers, because phone
# numbers are normalized by the FTS match criterion
if len(value) < 30 and value != normalize_phone_number_for_fts_if_needed(value):
# use weight 'D' to give preference to fields with the phone number validation
fts_strings['D'].add(normalize_phone_number_for_fts_if_needed(value))
elif type(value) in (tuple, list):
for val in value:
fts_strings[weight].add(val)
if self._evolution:
for evo in self._evolution:
if evo.comment:
fts_strings['D'].add(evo.comment)
for part in evo.parts or []:
fts_strings['D'].add(part.render_for_fts() if part.render_for_fts else '')
user = self.get_user()
if user:
fts_strings['A'].add(user.get_display_name())
fts_parts = []
parameters = {'id': self.id}
for weight, strings in fts_strings.items():
# assemble strings
value = ' '.join([force_str(x) for x in strings if x])
fts_parts.append("setweight(to_tsvector(%%(fts%s)s), '%s')" % (weight, weight))
parameters['fts%s' % weight] = FtsMatch.get_fts_value(str(value))
sql_statement = '''UPDATE %s SET fts = %s
WHERE id = %%(id)s''' % (
self._table_name,
' || '.join(fts_parts) or "''",
)
cur.execute(sql_statement, parameters)
cur.close()
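# Note on the FTS weights used above: PostgreSQL's default ts_rank()
# weights rank 'A' highest and 'D' lowest, so identifiers and tracking
# codes ('A') outrank listed field contents ('B'), other fields ('C')
# and evolution comments ('D'); see get_sorted_ids(order_by='rank')
# above for the matching query side.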
@classmethod
def _row2ob(cls, row, extra_fields=None):
o = cls()
for static_field, value in zip(cls._table_static_fields, tuple(row[: len(cls._table_static_fields)])):
setattr(o, static_field[0], str_encode(value))
for attr in ('workflow_data', 'workflow_roles', 'submission_context', 'prefilling_data'):
if getattr(o, attr):
setattr(o, attr, pickle_loads(getattr(o, attr)))
o.geolocations = {}
for i, field in enumerate((cls._formdef.geolocations or {}).keys()):
value = row[len(cls._table_static_fields) + i]
if not value:
continue
m = re.match(r'\(([^)]+),([^)]+)\)', value)
o.geolocations[field] = {'lon': float(m.group(1)), 'lat': float(m.group(2))}
o.data = cls._row2obdata(row, cls._formdef)
if extra_fields:
# extra fields are tucked at the end;
# count the number of extra columns
count = (
len(extra_fields)
+ len([x for x in extra_fields if x.store_display_value])
+ len([x for x in extra_fields if x.store_structured_value])
)
i = len(row) - count
for field in extra_fields:
coldata, i = cls._col2obdata(row, i, field)
o.data.update(coldata)
return o
@classmethod
def get_data_fields(cls):
data_fields = ['geoloc_%s' % x for x in (cls._formdef.geolocations or {}).keys()]
for field in cls._formdef.get_all_fields():
sql_type = SQL_TYPE_MAPPING.get(field.key, 'varchar')
if sql_type is None:
continue
data_fields.append(get_field_id(field))
if field.store_display_value:
data_fields.append('%s_display' % get_field_id(field))
if field.store_structured_value:
data_fields.append('%s_structured' % get_field_id(field))
return data_fields
@classmethod
def get(cls, id, ignore_errors=False, ignore_migration=False):
try:
if not (0 < int(str(id)) < 2**31) or not is_ascii_digit(str(id)):
# avoid NumericValueOutOfRange (id must fit in a 32-bit serial) and '_' separators accepted by int()
raise TypeError()
except (TypeError, ValueError):
if ignore_errors:
return None
else:
raise KeyError()
_, cur = get_connection_and_cursor()
fields = cls.get_data_fields()
potential_comma = ', ' if fields else ''
sql_statement = '''SELECT %s
%s
%s
FROM %s
WHERE id = %%(id)s''' % (
', '.join([x[0] for x in cls._table_static_fields]),
potential_comma,
', '.join(fields),
cls._table_name,
)
cur.execute(sql_statement, {'id': str(id)})
row = cur.fetchone()
if row is None:
cur.close()
if ignore_errors:
return None
raise KeyError()
cur.close()
return cls._row2ob(row)
@classmethod
def get_ids_with_indexed_value(cls, index, value, auto_fallback=True, clause=None):
_, cur = get_connection_and_cursor()
where_clauses, parameters, func_clause = parse_clause(clause)
assert not func_clause
if isinstance(value, int):
value = str(value)
if '%s_array' % index in [x[0] for x in cls._table_static_fields]:
sql_statement = '''SELECT id FROM %s WHERE %s_array @> ARRAY[%%(value)s]''' % (
cls._table_name,
index,
)
else:
sql_statement = '''SELECT id FROM %s WHERE %s = %%(value)s''' % (cls._table_name, index)
if where_clauses:
sql_statement += ' AND ' + ' AND '.join(where_clauses)
else:
parameters = {}
parameters.update({'value': value})
cur.execute(sql_statement, parameters)
all_ids = [x[0] for x in cur.fetchall()]
cur.close()
return all_ids
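# Sketch (values are illustrative): both forms resolve to indexed
# lookups, array membership when a <index>_array column exists, plain
# equality otherwise:
#
#   ids = cls.get_ids_with_indexed_value('workflow_roles', '42')  # @> ARRAY['42']
#   ids = cls.get_ids_with_indexed_value('user_id', 42)           # user_id = '42'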
@classmethod
def get_order_by_clause(cls, order_by):
if hasattr(order_by, 'id'):
# form field, convert to its column name
attribute = order_by
order_by = get_field_id(attribute)
if attribute.store_display_value:
order_by = order_by + '_display'
return super().get_order_by_clause(order_by)
@classmethod
def rebuild_security(cls, update_all=False):
formdatas = cls.select(order_by='id', iterator=True)
_, cur = get_connection_and_cursor()
with atomic() as atomic_context:
for i, formdata in enumerate(formdatas):
# don't update all formdata before committing;
# that would make us hold locks for much longer than required
if i % 100 == 0:
atomic_context.partial_commit()
if not update_all:
sql_statement = (
'''UPDATE %s
SET concerned_roles_array = %%(roles)s,
actions_roles_array = %%(actions_roles)s,
workflow_merged_roles_dict = %%(workflow_merged_roles_dict)s
WHERE id = %%(id)s
AND (concerned_roles_array <> %%(roles)s OR
actions_roles_array <> %%(actions_roles)s OR
workflow_merged_roles_dict <> %%(workflow_merged_roles_dict)s)'''
% cls._table_name
)
else:
sql_statement = (
'''UPDATE %s
SET concerned_roles_array = %%(roles)s,
actions_roles_array = %%(actions_roles)s,
workflow_merged_roles_dict = %%(workflow_merged_roles_dict)s
WHERE id = %%(id)s'''
% cls._table_name
)
with get_publisher().substitutions.temporary_feed(formdata):
# formdata is already added to sources list in individual
# {concerned,actions}_roles but adding it first here will
# allow cached values to be reused between the properties.
cur.execute(
sql_statement,
{
'id': formdata.id,
'roles': [str(x) for x in formdata.concerned_roles if x],
'actions_roles': [str(x) for x in formdata.actions_roles if x],
'workflow_merged_roles_dict': formdata.workflow_merged_roles_dict,
},
)
cur.close()
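# Usage sketch: refresh the security columns after a workflow or role
# change, committing every 100 rows (see partial_commit() above):
#
#   formdef.data_class().rebuild_security()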
@classonlymethod
def wipe(cls, drop=False):
_, cur = get_connection_and_cursor()
if drop:
cur.execute('''DROP TABLE %s_evolutions CASCADE''' % cls._table_name)
cur.execute('''DELETE FROM %s''' % cls._table_name) # force trigger execution first.
cur.execute('''DROP TABLE %s CASCADE''' % cls._table_name)
else:
cur.execute('''DELETE FROM %s_evolutions''' % cls._table_name)
cur.execute('''DELETE FROM %s''' % cls._table_name)
cur.close()
@classmethod
def do_tracking_code_table(cls):
do_tracking_code_table()
class SqlFormData(SqlDataMixin, wcs.formdata.FormData):
pass
class SqlCardData(SqlDataMixin, wcs.carddata.CardData):
_table_static_fields = SqlDataMixin._table_static_fields + [
('uuid', 'uuid UNIQUE NOT NULL DEFAULT gen_random_uuid()')
]
def store(self, *args, **kwargs):
if self.uuid is None:
self.uuid = str(uuid.uuid4())
is_new_card = bool(not self.id)
super().store(*args, **kwargs)
if self._has_changed_digest and not is_new_card:
self.update_related()
class SqlUser(SqlMixin, wcs.users.User):
_table_name = 'users'
_table_static_fields = [
('id', 'serial'),
('name', 'varchar'),
('email', 'varchar'),
('roles', 'varchar[]'),
('is_admin', 'bool'),
('name_identifiers', 'varchar[]'),
('verified_fields', 'varchar[]'),
('lasso_dump', 'text'),
('last_seen', 'timestamp'),
('ascii_name', 'varchar'),
('deleted_timestamp', 'timestamp'),
('is_active', 'bool'),
('preferences', 'jsonb'),
('test_uuid', 'varchar'),
]
_sql_indexes = [
'users_name_idx ON users (name)',
'users_name_identifiers_idx ON users USING gin(name_identifiers)',
'users_fts ON users USING gin(fts)',
'users_roles_idx ON users USING gin(roles)',
]
id = None
def __init__(self, name=None):
self.name = name
self.name_identifiers = []
self.verified_fields = []
self.roles = []
@classmethod
def select(cls, clause=None, **kwargs):
has_explicit_test_user_filter = bool(
isinstance(clause, list)
and any(x.attribute == 'test_uuid' for x in clause if hasattr(x, 'attribute'))
)
if not has_explicit_test_user_filter:
clause = clause or []
if callable(clause):
clause = [clause]
clause.append(Null('test_uuid'))
return super().select(clause=clause, **kwargs)
@invalidate_substitution_cache
def store(self):
sql_dict = {
'name': self.name,
'ascii_name': self.ascii_name,
'email': self.email,
'roles': self.roles,
'is_admin': self.is_admin,
'name_identifiers': self.name_identifiers,
'verified_fields': self.verified_fields,
'lasso_dump': self.lasso_dump,
'last_seen': None,
'deleted_timestamp': self.deleted_timestamp,
'is_active': self.is_active,
'preferences': self.preferences,
'test_uuid': self.test_uuid,
}
if self.last_seen:
sql_dict['last_seen'] = datetime.datetime.fromtimestamp(self.last_seen)
user_formdef = self.get_formdef()
if not self.form_data:
self.form_data = {}
sql_dict.update(self.get_sql_dict_from_data(self.form_data, user_formdef))
_, cur = get_connection_and_cursor()
if not self.id:
column_names = sql_dict.keys()
sql_statement = '''INSERT INTO %s (id, %s)
VALUES (DEFAULT, %s)
RETURNING id''' % (
self._table_name,
', '.join(column_names),
', '.join(['%%(%s)s' % x for x in column_names]),
)
cur.execute(sql_statement, sql_dict)
self.id = cur.fetchone()[0]
else:
column_names = sql_dict.keys()
sql_dict['id'] = self.id
sql_statement = '''UPDATE %s SET %s WHERE id = %%(id)s RETURNING id''' % (
self._table_name,
', '.join(['%s = %%(%s)s' % (x, x) for x in column_names]),
)
cur.execute(sql_statement, sql_dict)
if cur.fetchone() is None:
column_names = sql_dict.keys()
sql_statement = '''INSERT INTO %s (%s) VALUES (%s)''' % (
self._table_name,
', '.join(column_names),
', '.join(['%%(%s)s' % x for x in column_names]),
)
cur.execute(sql_statement, sql_dict)
fts_strings = []
if self.name:
fts_strings.append(('A', self.name))
fts_strings.append(('A', self.ascii_name))
if self.email:
fts_strings.append(('B', self.email))
if user_formdef and user_formdef.fields:
for field in user_formdef.fields:
if not self.form_data.get(field.id):
continue
value = None
if field.key in ('string', 'text', 'email'):
value = self.form_data.get(field.id)
elif field.key in ('item', 'items'):
value = self.form_data.get('%s_display' % field.id)
if value:
if isinstance(value, str):
fts_strings.append(('B', value))
elif type(value) in (tuple, list):
for val in value:
fts_strings.append(('B', val))
fts_parts = []
parameters = {'id': self.id}
for i, (weight, value) in enumerate(fts_strings):
fts_parts.append("setweight(to_tsvector(%%(fts%s)s), '%s')" % (i, weight))
parameters['fts%s' % i] = FtsMatch.get_fts_value(value)
sql_statement = '''UPDATE %s SET fts = %s
WHERE id = %%(id)s''' % (
self._table_name,
' || '.join(fts_parts) or "''",
)
cur.execute(sql_statement, parameters)
if hasattr(self, '_name_in_db') and self._name_in_db != self.name:
# update wcs_all_forms rows with name change
sql_statement = 'UPDATE wcs_all_forms SET user_name = %(user_name)s WHERE user_id = %(user_id)s'
cur.execute(sql_statement, {'user_id': str(self.id), 'user_name': self.name})
cur.close()
@classmethod
def _row2ob(cls, row, **kwargs):
o = cls()
(
o.id,
o.name,
o.email,
o.roles,
o.is_admin,
o.name_identifiers,
o.verified_fields,
o.lasso_dump,
o.last_seen,
ascii_name, # XXX what's this? pylint: disable=unused-variable
o.deleted_timestamp,
o.is_active,
o.preferences,
o.test_uuid,
) = row[: len(cls._table_static_fields)]
if o.last_seen:
o.last_seen = time.mktime(o.last_seen.timetuple())
if o.roles:
o.roles = [str(x) for x in o.roles]
o.form_data = cls._row2obdata(row, cls.get_formdef())
o._name_in_db = o.name # keep track of stored name
return o
@classmethod
def get_data_fields(cls):
data_fields = []
for field in cls.get_formdef().get_all_fields():
sql_type = SQL_TYPE_MAPPING.get(field.key, 'varchar')
if sql_type is None:
continue
data_fields.append(get_field_id(field))
if field.store_display_value:
data_fields.append('%s_display' % get_field_id(field))
if field.store_structured_value:
data_fields.append('%s_structured' % get_field_id(field))
return data_fields
@classmethod
def get_formdef_keepalive_user_uuids(cls):
_, cur = get_connection_and_cursor()
sql_statement = '''SELECT name_identifiers
FROM users
WHERE deleted_timestamp IS NULL
AND name_identifiers IS NOT NULL
AND CAST(users.id AS VARCHAR) IN (
SELECT user_id
FROM wcs_all_forms
WHERE is_at_endpoint = false)
'''
cur.execute(sql_statement)
uuids = []
for row in cur.fetchall():
uuids.extend(row[0])
cur.close()
return uuids
@classmethod
def get_reference_ids(cls):
'''Retrieve ids of users referenced in some carddata or formdata.'''
from wcs.carddef import CardDef
from wcs.formdef import FormDef
referenced_ids = set()
_, cur = get_connection_and_cursor()
for objectdef in CardDef.select() + FormDef.select():
data_class = objectdef.data_class()
# referenced in form/card data.user_id
sql_statement = (
'SELECT CAST(data.user_id AS INTEGER) FROM %(table)s AS data WHERE data.user_id IS NOT NULL'
% {
'table': data_class._table_name,
}
)
cur.execute(sql_statement)
referenced_ids.update(user_id for user_id, in cur.fetchall())
# referenced in form/card data_evolution.who
sql_statement = '''SELECT CAST(evolution.who AS INTEGER)
FROM %(table)s AS evolution
WHERE evolution.who != '_submitter'
''' % {
'table': '%s_evolutions' % data_class._table_name,
}
cur.execute(sql_statement)
referenced_ids.update(user_id for user_id, in cur.fetchall())
# referenced in form/card data.workflow_roles_array
sql_statement = '''SELECT CAST(SUBSTRING(workflow_role.workflow_role FROM 7) AS INTEGER)
FROM %(table)s AS data, UNNEST(data.workflow_roles_array) AS workflow_role
WHERE SUBSTRING(workflow_role.workflow_role FROM 1 FOR 6) = '_user:' ''' % {
# users will be referenced as "_user:<user id>" entries in
# workflow_roles_array, filter on values starting with "_user:"
# (FROM 1 FOR 6) and extract the id part (FROM 7).
'table': data_class._table_name,
}
cur.execute(sql_statement)
referenced_ids.update(user_id for user_id, in cur.fetchall())
cur.close()
return referenced_ids
@classmethod
def get_to_delete_ids(cls):
'''Retrieve ids of users which are deleted on the IdP and are no longer referenced by any form or card.'''
# fetch marked as deleted users
_, cur = get_connection_and_cursor()
sql_statement = 'SELECT users.id FROM users WHERE users.deleted_timestamp IS NOT NULL'
cur.execute(sql_statement)
deleted_ids = {user_id for user_id, in cur.fetchall()}
cur.close()
to_delete_ids = deleted_ids.difference(cls.get_reference_ids())
return to_delete_ids
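# Sketch of the intended cleanup flow (the deletion call is illustrative):
#
#   for user_id in SqlUser.get_to_delete_ids():
#       SqlUser.remove_object(user_id)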
class Role(SqlMixin, wcs.roles.Role):
_table_name = 'roles'
_table_static_fields = [
('id', 'varchar'),
('name', 'varchar'),
('uuid', 'varchar'),
('slug', 'varchar'),
('internal', 'boolean'),
('details', 'varchar'),
('emails', 'varchar[]'),
('emails_to_members', 'boolean'),
('allows_backoffice_access', 'boolean'),
]
_numerical_id = False
@classmethod
def get(cls, id, ignore_errors=False, ignore_migration=False, column=None):
o = super().get(id, ignore_errors=ignore_errors, ignore_migration=ignore_migration, column=column)
if o and not ignore_migration:
if o.migrate():
o.store()
return o
def store(self):
if self.slug is None:
# set slug if it's not yet there
self.slug = self.get_new_slug()
sql_dict = {
'id': self.id,
'name': self.name,
'uuid': self.uuid,
'slug': self.slug,
'internal': self.internal,
'details': self.details,
'emails': self.emails,
'emails_to_members': self.emails_to_members,
'allows_backoffice_access': self.allows_backoffice_access,
}
conn, cur = get_connection_and_cursor()
column_names = sql_dict.keys()
if not self.id:
sql_dict['id'] = self.get_new_id()
sql_statement = '''INSERT INTO %s (%s)
VALUES (%s)
RETURNING id''' % (
self._table_name,
', '.join(column_names),
', '.join(['%%(%s)s' % x for x in column_names]),
)
while True:
try:
cur.execute(sql_statement, sql_dict)
except psycopg2.IntegrityError:
conn.rollback()
sql_dict['id'] = self.get_new_id()
else:
break
self.id = str_encode(cur.fetchone()[0])
else:
sql_statement = '''INSERT INTO %s (%s) VALUES (%s) ON CONFLICT(id) DO UPDATE SET %s''' % (
self._table_name,
', '.join(column_names),
', '.join(['%%(%s)s' % x for x in column_names]),
', '.join(['%s = %%(%s)s' % (x, x) for x in column_names]),
)
cur.execute(sql_statement, sql_dict)
cur.close()
self.adjust_permissions()
@classmethod
def get_data_fields(cls):
return []
@classmethod
def _row2ob(cls, row, **kwargs):
o = cls()
for field, value in zip(cls._table_static_fields, tuple(row)):
if field[1] in ('varchar', 'varchar[]'):
setattr(o, field[0], str_encode(value))
else:
setattr(o, field[0], value)
return o
class TransientData(SqlMixin):
# table to keep some transient submission data and form tokens out of the session object
_table_name = 'transient_data'
_table_static_fields = [
('id', 'varchar'),
('session_id', 'varchar'),
('data', 'bytea'),
('last_update_time', 'timestamptz'),
]
_numerical_id = False
def __init__(self, id, session_id, data):
self.id = id
self.session_id = session_id
self.data = data
def store(self):
sql_dict = {
'id': self.id,
'session_id': self.session_id,
'data': bytearray(pickle.dumps(self.data, protocol=2)) if self.data is not None else None,
'last_update_time': now(),
}
_, cur = get_connection_and_cursor()
column_names = sql_dict.keys()
sql_statement = '''INSERT INTO %s (%s) VALUES (%s)
ON CONFLICT(id) DO UPDATE SET %s''' % (
self._table_name,
', '.join(column_names),
', '.join(['%%(%s)s' % x for x in column_names]),
', '.join(['%s = %%(%s)s' % (x, x) for x in column_names]),
)
try:
cur.execute(sql_statement, sql_dict)
except psycopg2.IntegrityError as e:
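# a foreign key violation on transient_data_session_id_fkey means the
# session row was deleted concurrently; in that case the transient
# data is simply dropped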
if 'transient_data_session_id_fkey' not in str(e):
raise
cur.close()
@classmethod
def _row2ob(cls, row, **kwargs):
o = cls.__new__(cls)
o.id = str_encode(row[0])
o.session_id = row[1]
o.data = pickle_loads(row[2]) if row[2] else None
return o
@classmethod
def get_data_fields(cls):
return []
class Session(SqlMixin, wcs.sessions.BasicSession):
_table_name = 'sessions'
_table_static_fields = [
('id', 'varchar'),
('session_data', 'bytea'),
]
_numerical_id = False
_sql_indexes = [
'sessions_ts ON sessions (last_update_time)',
]
@classmethod
def select_recent_with_visits(cls, seconds=30 * 60, **kwargs):
clause = [
GreaterOrEqual('last_update_time', datetime.datetime.now() - datetime.timedelta(seconds=seconds)),
NotNull('visiting_objects_keys'),
]
return cls.select(clause=clause, **kwargs)
@classmethod
def clean(cls):
current_time = time.time()
last_usage_limit = current_time - 3 * 86400
creation_limit = current_time - 30 * 86400
last_update_limit = datetime.datetime.now() - datetime.timedelta(days=3)
session_ids = []
for session in cls.select_iterator(clause=[Less('last_update_time', last_update_limit)]):
if session._access_time < last_usage_limit or session._creation_time < creation_limit:
session_ids.append(session.id)
cls.wipe(clause=[Contains('id', session_ids)])
def store(self):
if self.message:
# escape lazy gettext
self.message = (self.message[0], str(self.message[1]))
# store transient data
for v in (self.magictokens or {}).values():
v.store()
# force it to be empty in the stored copy, to make sure there's no leftover direct usage
session_data = copy.copy(self.__dict__)
session_data['magictokens'] = None
sql_dict = {
'id': self.id,
'session_data': bytearray(pickle.dumps(session_data, protocol=2)),
# the other fields are stored to run optimized SELECTs against the
# table; they are ignored when loading the data.
'name_identifier': self.name_identifier,
'visiting_objects_keys': list(self.visiting_objects.keys())
if getattr(self, 'visiting_objects', None)
else None,
'last_update_time': datetime.datetime.now(),
}
_, cur = get_connection_and_cursor()
column_names = sql_dict.keys()
sql_statement = '''INSERT INTO %s (%s) VALUES (%s)
ON CONFLICT (id) DO UPDATE SET %s''' % (
self._table_name,
', '.join(column_names),
', '.join(['%%(%s)s' % x for x in column_names]),
', '.join(['%s = %%(%s)s' % (x, x) for x in column_names]),
)
cur.execute(sql_statement, sql_dict)
cur.close()
@classmethod
def _row2ob(cls, row, **kwargs):
o = cls.__new__(cls)
o.id = str_encode(row[0])
session_data = pickle_loads(row[1])
for k, v in session_data.items():
setattr(o, k, v)
if o.magictokens:
# migration, obsolete storage of magictokens in session
for k, v in o.magictokens.items():
o.add_magictoken(k, v)
o.magictokens = None
o.store()
return o
@classmethod
def get_sessions_for_saml(cls, name_identifier=Ellipsis, *args, **kwargs):
_, cur = get_connection_and_cursor()
sql_statement = '''SELECT %s
FROM %s
WHERE name_identifier = %%(value)s''' % (
', '.join([x[0] for x in cls._table_static_fields]),
cls._table_name,
)
cur.execute(sql_statement, {'value': name_identifier})
objects = cls.get_objects(cur)
cur.close()
return objects
@classmethod
def get_sessions_with_visited_object(cls, object_key):
_, cur = get_connection_and_cursor()
sql_statement = '''SELECT %s
FROM %s
WHERE %%(value)s = ANY(visiting_objects_keys)
AND last_update_time > (now() - interval '30 minutes')
''' % (
', '.join([x[0] for x in cls._table_static_fields] + cls.get_data_fields()),
cls._table_name,
)
cur.execute(sql_statement, {'value': object_key})
objects = cls.get_objects(cur)
cur.close()
return objects
@classmethod
def get_data_fields(cls):
return []
def add_magictoken(self, token, data):
assert self.id
super().add_magictoken(token, data)
self.magictokens[token] = TransientData(id=token, session_id=self.id, data=data)
self.magictokens[token].store()
def get_by_magictoken(self, token, default=None):
if not self.magictokens:
self.magictokens = {}
try:
if token not in self.magictokens:
self.magictokens[token] = TransientData.select(
[Equal('session_id', self.id), Equal('id', token)]
)[0]
return self.magictokens[token].data
except IndexError:
return default
def remove_magictoken(self, token):
super().remove_magictoken(token)
TransientData.remove_object(token)
def create_form_token(self):
token = TransientData(id=secrets.token_urlsafe(16), session_id=self.id, data=None)
token.store()
return token.id
def has_form_token(self, token):
return TransientData.exists([Equal('id', token)])
def remove_form_token(self, token):
TransientData.remove_object(token)
def create_token(self, usage, context):
context['session_id'] = self.id
context['usage'] = usage
token_id = hashlib.sha1(repr(context).encode()).hexdigest()
try:
token = self.get_token(usage, token_id)
except KeyError:
token = TransientData(id=token_id, session_id=self.id, data=context)
token.store()
return token
def get_token(self, usage, token_id):
tokens = TransientData.select([Equal('id', token_id), Equal('session_id', self.id)])
if not tokens or tokens[0].data.get('usage') != usage: # missing token or usage mismatch
raise KeyError(token_id)
return tokens[0]
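# Round-trip sketch (usage and context values are illustrative):
# create_token() derives a deterministic id from the context (the same
# context yields the same token), and get_token() checks the usage tag:
#
#   token = session.create_token('unsubscribe', {'email': 'x@example.net'})
#   session.get_token('unsubscribe', token.id)   # returns the token
#   session.get_token('other-usage', token.id)   # raises KeyError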
class TrackingCode(SqlMixin, wcs.tracking_code.TrackingCode):
_table_name = 'tracking_codes'
_table_static_fields = [
('id', 'varchar'),
('formdef_id', 'varchar'),
('formdata_id', 'varchar'),
]
_numerical_id = False
id = None
@classmethod
def get(cls, id, **kwargs):
return super().get(id.upper(), **kwargs)
@invalidate_substitution_cache
def store(self):
sql_dict = {'id': self.id, 'formdef_id': self.formdef_id, 'formdata_id': self.formdata_id}
conn, cur = get_connection_and_cursor()
if not self.id:
column_names = sql_dict.keys()
sql_dict['id'] = self.get_new_id()
sql_statement = '''INSERT INTO %s (%s)
VALUES (%s)
RETURNING id''' % (
self._table_name,
', '.join(column_names),
', '.join(['%%(%s)s' % x for x in column_names]),
)
while True:
try:
cur.execute(sql_statement, sql_dict)
except psycopg2.IntegrityError:
conn.rollback()
sql_dict['id'] = self.get_new_id()
else:
break
self.id = str_encode(cur.fetchone()[0])
else:
column_names = sql_dict.keys()
sql_dict['id'] = self.id
sql_statement = '''INSERT INTO %s (%s)
VALUES (%s)
ON CONFLICT ON CONSTRAINT tracking_codes_pkey
DO UPDATE
SET %s
RETURNING id''' % (
self._table_name,
', '.join(column_names),
', '.join(['%%(%s)s' % x for x in column_names]),
', '.join(['%s = %%(%s)s' % (x, x) for x in column_names]),
)
cur.execute(sql_statement, sql_dict)
if cur.fetchone() is None:
raise AssertionError()
cur.close()
@classmethod
def _row2ob(cls, row, **kwargs):
o = cls()
(o.id, o.formdef_id, o.formdata_id) = (str_encode(x) for x in tuple(row[:3]))
return o
@classmethod
def get_data_fields(cls):
return []
class CustomView(SqlMixin, wcs.custom_views.CustomView):
_table_name = 'custom_views'
_table_static_fields = [
('id', 'varchar'),
('title', 'varchar'),
('slug', 'varchar'),
('user_id', 'varchar'),
('role_id', 'varchar'),
('visibility', 'varchar'),
('formdef_type', 'varchar'),
('formdef_id', 'varchar'),
('is_default', 'boolean'),
('order_by', 'varchar'),
('group_by', 'varchar'),
('columns', 'jsonb'),
('filters', 'jsonb'),
]
_sql_indexes = [
'custom_views_formdef_type_id ON custom_views (formdef_type, formdef_id)',
]
@invalidate_substitution_cache
def store(self):
self.ensure_slug()
sql_dict = {
'id': self.id,
'title': self.title,
'slug': self.slug,
'user_id': self.user_id,
'role_id': self.role_id,
'visibility': self.visibility,
'formdef_type': self.formdef_type,
'formdef_id': self.formdef_id,
'is_default': self.is_default,
'order_by': self.order_by,
'group_by': self.group_by,
'columns': self.columns,
'filters': self.filters,
}
conn, cur = get_connection_and_cursor()
if not self.id:
column_names = sql_dict.keys()
sql_dict['id'] = self.get_new_id()
sql_statement = '''INSERT INTO %s (%s)
VALUES (%s)
RETURNING id''' % (
self._table_name,
', '.join(column_names),
', '.join(['%%(%s)s' % x for x in column_names]),
)
while True:
try:
cur.execute(sql_statement, sql_dict)
except psycopg2.IntegrityError:
conn.rollback()
sql_dict['id'] = self.get_new_id()
else:
break
self.id = str_encode(cur.fetchone()[0])
else:
column_names = sql_dict.keys()
sql_dict['id'] = self.id
sql_statement = '''UPDATE %s SET %s WHERE id = %%(id)s RETURNING id''' % (
self._table_name,
', '.join(['%s = %%(%s)s' % (x, x) for x in column_names]),
)
cur.execute(sql_statement, sql_dict)
if cur.fetchone() is None:
raise AssertionError()
cur.close()
@classmethod
def _row2ob(cls, row, **kwargs):
o = cls()
for field, value in zip(cls._table_static_fields, tuple(row)):
if field[1] == 'varchar':
setattr(o, field[0], str_encode(value))
elif field[1] in ('jsonb', 'boolean'):
setattr(o, field[0], value)
return o
@classmethod
def get_data_fields(cls):
return []
class Snapshot(SqlMixin, wcs.snapshots.Snapshot):
_table_name = 'snapshots'
_table_static_fields = [
('id', 'serial'),
('object_type', 'varchar'),
('object_id', 'varchar'),
('timestamp', 'timestamptz'),
('user_id', 'varchar'),
('comment', 'text'),
('serialization', 'text'),
('patch', 'text'),
('label', 'varchar'),
('test_result_id', 'integer'),
('application_slug', 'varchar'),
('application_version', 'varchar'),
]
_table_select_skipped_fields = ['serialization', 'patch']
_sql_indexes = [
'snapshots_object_by_date ON snapshots (object_type, object_id, timestamp DESC)',
]
@invalidate_substitution_cache
def store(self):
sql_dict = {x: getattr(self, x) for x, y in self._table_static_fields}
_, cur = get_connection_and_cursor()
if not self.id:
column_names = [x for x in sql_dict.keys() if x != 'id']
sql_statement = '''INSERT INTO %s (%s)
VALUES (%s)
RETURNING id''' % (
self._table_name,
', '.join(column_names),
', '.join(['%%(%s)s' % x for x in column_names]),
)
cur.execute(sql_statement, sql_dict)
self.id = cur.fetchone()[0]
else:
column_names = sql_dict.keys()
sql_statement = '''UPDATE %s SET %s WHERE id = %%(id)s RETURNING id''' % (
self._table_name,
', '.join(['%s = %%(%s)s' % (x, x) for x in column_names]),
)
cur.execute(sql_statement, sql_dict)
if cur.fetchone() is None:
raise AssertionError()
cur.close()
@classmethod
def select_object_history(cls, obj, clause=None):
return cls.select(
[Equal('object_type', obj.xml_root_node), Equal('object_id', str(obj.id))] + (clause or []),
order_by='-timestamp',
)
def is_from_object(self, obj):
return self.object_type == obj.xml_root_node and self.object_id == str(obj.id)
@classmethod
def _row2ob(cls, row, **kwargs):
o = cls()
for field, value in zip(cls._table_static_fields, tuple(row)):
if field[1] in ('serial', 'timestamptz', 'integer'):
setattr(o, field[0], value)
elif field[1] in ('varchar', 'text'):
setattr(o, field[0], str_encode(value))
return o
@classmethod
def get_data_fields(cls):
return []
@classmethod
def get_latest(cls, object_type, object_id, complete=False, max_timestamp=None):
_, cur = get_connection_and_cursor()
sql_statement = '''SELECT id FROM snapshots
WHERE object_type = %%(object_type)s
AND object_id = %%(object_id)s
%s
%s
ORDER BY timestamp DESC
LIMIT 1''' % (
'AND serialization IS NOT NULL' if complete else '',
'AND timestamp <= %(max_timestamp)s' if max_timestamp else '',
)
cur.execute(
sql_statement,
{'object_type': object_type, 'object_id': str(object_id), 'max_timestamp': max_timestamp},
)
row = cur.fetchone()
cur.close()
if row is None:
return None
return cls.get(row[0])
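# Sketch: fetch the most recent snapshot holding a full serialization of
# an object (object_type values follow xml_root_node, e.g. 'formdef'):
#
#   snapshot = Snapshot.get_latest('formdef', formdef.id, complete=True)
#   # returns None when no matching snapshot exists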
@classmethod
def _get_recent_changes(cls, object_types, user=None, limit=5, offset=0):
_, cur = get_connection_and_cursor()
clause = [Contains('object_type', object_types)]
if user is not None:
clause.append(Equal('user_id', str(user.id)))
where_clauses, parameters, dummy = parse_clause(clause)
sql_statement = 'SELECT object_type, object_id, MAX(timestamp) AS m FROM snapshots'
sql_statement += ' WHERE ' + ' AND '.join(where_clauses)
sql_statement += ' GROUP BY object_type, object_id ORDER BY m DESC'
if limit:
sql_statement += ' LIMIT %(limit)s'
parameters['limit'] = limit
if offset:
sql_statement += ' OFFSET %(offset)s'
parameters['offset'] = offset
cur.execute(sql_statement, parameters)
result = cur.fetchall()
cur.close()
return result
@classmethod
def count_recent_changes(cls, object_types):
_, cur = get_connection_and_cursor()
clause = [Contains('object_type', object_types)]
where_clauses, parameters, dummy = parse_clause(clause)
sql_statement = 'SELECT COUNT(*) FROM (SELECT object_type, object_id FROM snapshots'
sql_statement += ' WHERE ' + ' AND '.join(where_clauses)
sql_statement += ' GROUP BY object_type, object_id) AS s'
cur.execute(sql_statement, parameters)
count = cur.fetchone()[0]
cur.close()
return count
class LoggedError(SqlMixin, wcs.logged_errors.LoggedError):
_table_name = 'loggederrors'
_table_static_fields = [
('id', 'serial'),
('kind', 'varchar'),
('tech_id', 'varchar'),
('summary', 'varchar'),
('formdef_class', 'varchar'),
('formdata_id', 'varchar'),
('formdef_id', 'varchar'),
('workflow_id', 'varchar'),
('status_id', 'varchar'),
('status_item_id', 'varchar'),
('expression', 'varchar'),
('expression_type', 'varchar'),
('context', 'jsonb'),
('traceback', 'text'),
('exception_class', 'varchar'),
('exception_message', 'varchar'),
('occurences_count', 'integer'),
('first_occurence_timestamp', 'timestamptz'),
('latest_occurence_timestamp', 'timestamptz'),
]
_sql_indexes = [
'loggederrors_formdef_id_idx ON loggederrors (formdef_id)',
'loggederrors_workflow_id_idx ON loggederrors (workflow_id)',
]
@invalidate_substitution_cache
def store(self):
sql_dict = {x: getattr(self, x) for x, y in self._table_static_fields}
conn, cur = get_connection_and_cursor()
error = self
if not self.id:
existing_errors = list(self.select([Equal('tech_id', self.tech_id)]))
if not existing_errors:
column_names = [x for x in sql_dict.keys() if x != 'id']
sql_statement = '''INSERT INTO %s (%s)
VALUES (%s)
RETURNING id''' % (
self._table_name,
', '.join(column_names),
', '.join(['%%(%s)s' % x for x in column_names]),
)
try:
cur.execute(sql_statement, sql_dict)
self.id = cur.fetchone()[0]
except psycopg2.IntegrityError:
# tech_id already used?
conn.rollback()
existing_errors = list(self.select([Equal('tech_id', self.tech_id)]))
if existing_errors:
error = existing_errors[0]
error.record_new_occurence(self)
else:
column_names = sql_dict.keys()
sql_statement = '''UPDATE %s SET %s WHERE id = %%(id)s RETURNING id''' % (
self._table_name,
', '.join(['%s = %%(%s)s' % (x, x) for x in column_names]),
)
cur.execute(sql_statement, sql_dict)
assert cur.fetchone() is not None, 'LoggedError id not found'
cur.close()
return error
@classmethod
def _row2ob(cls, row, **kwargs):
o = cls()
for field, value in zip(cls._table_static_fields, tuple(row)):
if field[1] in ('varchar', 'text'):
setattr(o, field[0], str_encode(value))
else:
setattr(o, field[0], value)
return o
@classmethod
def get_data_fields(cls):
return []
class Token(SqlMixin, wcs.qommon.tokens.Token):
_table_name = 'tokens'
_table_static_fields = [
('id', 'varchar'),
('type', 'varchar'),
('expiration', 'timestamptz'),
('context', 'jsonb'),
]
_numerical_id = False
def store(self):
sql_dict = {
'id': self.id,
'type': self.type,
'expiration': self.expiration,
'context': self.context,
}
conn, cur = get_connection_and_cursor()
column_names = sql_dict.keys()
if not self.id:
sql_dict['id'] = self.get_new_id()
sql_statement = '''INSERT INTO %s (%s)
VALUES (%s)
RETURNING id''' % (
self._table_name,
', '.join(column_names),
', '.join(['%%(%s)s' % x for x in column_names]),
)
while True:
try:
cur.execute(sql_statement, sql_dict)
except psycopg2.IntegrityError:
conn.rollback()
sql_dict['id'] = self.get_new_id()
else:
break
self.id = cur.fetchone()[0]
else:
sql_statement = '''UPDATE %s SET %s WHERE id = %%(id)s RETURNING id''' % (
self._table_name,
', '.join(['%s = %%(%s)s' % (x, x) for x in column_names]),
)
cur.execute(sql_statement, sql_dict)
if cur.fetchone() is None:
sql_statement = '''INSERT INTO %s (%s) VALUES (%s)''' % (
self._table_name,
', '.join(column_names),
', '.join(['%%(%s)s' % x for x in column_names]),
)
cur.execute(sql_statement, sql_dict)
cur.close()
@classmethod
def get_data_fields(cls):
return []
@classmethod
def _row2ob(cls, row, **kwargs):
o = cls()
for field, value in zip(cls._table_static_fields, tuple(row)):
setattr(o, field[0], value)
o.expiration_check()
return o
class TranslatableMessage(SqlMixin):
_table_name = 'translatable_messages'
_table_static_fields = [
('id', 'serial'),
('string', 'varchar'),
('context', 'varchar'),
('locations', 'varchar[]'),
('last_update_time', 'timestamptz'),
('translatable', 'boolean'),
]
_sql_indexes = [
'translatable_messages_fts ON translatable_messages USING gin(fts)',
]
id = None
@classmethod
def do_table(cls, conn=None, cur=None):
conn, cur = get_connection_and_cursor()
table_name = cls._table_name
cur.execute(
'''SELECT COUNT(*) FROM information_schema.tables
WHERE table_schema = 'public'
AND table_name = %s''',
(table_name,),
)
if cur.fetchone()[0] == 0:
cur.execute(
'''CREATE TABLE %s (id SERIAL,
string VARCHAR,
context VARCHAR,
locations VARCHAR[],
last_update_time TIMESTAMPTZ,
translatable BOOLEAN DEFAULT TRUE,
fts TSVECTOR
)'''
% table_name
)
cur.execute(
'''SELECT column_name FROM information_schema.columns
WHERE table_schema = 'public'
AND table_name = %s''',
(table_name,),
)
existing_fields = {x[0] for x in cur.fetchall()}
if 'translatable' not in existing_fields:
cur.execute('''ALTER TABLE %s ADD COLUMN translatable BOOLEAN DEFAULT(TRUE)''' % table_name)
# add columns for translations
for field in cls.get_data_fields():
if field not in existing_fields:
cur.execute('ALTER TABLE %s ADD COLUMN %s VARCHAR' % (table_name, field))
cls.do_indexes(cur)
cur.close()
@classmethod
def get_data_fields(cls):
languages = get_cfg('language', {}).get('languages') or []
return ['string_%s' % x for x in languages]
def store(self):
sql_dict = {x[0]: getattr(self, x[0], None) for x in self._table_static_fields if x[0] != 'id'}
sql_dict.update({x: getattr(self, x) for x in self.get_data_fields() if hasattr(self, x)})
_, cur = get_connection_and_cursor()
column_names = list(sql_dict.keys())
sql_dict['fts'] = FtsMatch.get_fts_value(self.string)
if not self.id:
sql_statement = '''INSERT INTO %s (id, %s, fts)
VALUES (DEFAULT, %s, TO_TSVECTOR(%%(fts)s))
RETURNING id''' % (
self._table_name,
', '.join(column_names),
', '.join(['%%(%s)s' % x for x in column_names]),
)
cur.execute(sql_statement, sql_dict)
self.id = cur.fetchone()[0]
else:
sql_dict['id'] = self.id
sql_statement = '''UPDATE %s SET %s, fts = TO_TSVECTOR(%%(fts)s)
WHERE id = %%(id)s RETURNING id''' % (
self._table_name,
', '.join(['%s = %%(%s)s' % (x, x) for x in column_names]),
)
cur.execute(sql_statement, sql_dict)
cur.close()
@classmethod
def _row2ob(cls, row, **kwargs):
o = cls.__new__(cls)
for attr, value in zip([x[0] for x in cls._table_static_fields] + cls.get_data_fields(), row):
setattr(o, attr, value)
return o
@classmethod
def load_as_catalog(cls, language):
_, cur = get_connection_and_cursor()
sql_statement = 'SELECT context, string, string_%s FROM %s WHERE translatable = TRUE' % (
language,
cls._table_name,
)
cur.execute(sql_statement)
catalog = {(x[0], x[1]): x[2] for x in cur.fetchall()}
cur.close()
return catalog
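# Sketch (assuming 'fr' is among the configured languages): build a
# gettext-like catalog mapping (context, source string) to translation:
#
#   catalog = TranslatableMessage.load_as_catalog('fr')
#   translated = catalog.get((None, 'Submit'))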
class TestDef(SqlMixin):
_table_name = 'testdef'
_table_static_fields = [
('id', 'serial'),
('name', 'varchar'),
('object_type', 'varchar'),
('object_id', 'varchar'),
('data', 'jsonb'),
('is_in_backoffice', 'boolean'),
('expected_error', 'varchar'),
('user_uuid', 'varchar'),
('agent_id', 'varchar'),
]
id = None
@classmethod
def do_table(cls, conn=None, cur=None):
conn, cur = get_connection_and_cursor()
table_name = cls._table_name
cur.execute(
'''SELECT COUNT(*) FROM information_schema.tables
WHERE table_schema = 'public'
AND table_name = %s''',
(table_name,),
)
if cur.fetchone()[0] == 0:
cur.execute(
'''CREATE TABLE %s (id SERIAL PRIMARY KEY,
name varchar,
object_type varchar NOT NULL,
object_id varchar NOT NULL,
data jsonb,
is_in_backoffice boolean NOT NULL DEFAULT FALSE,
expected_error varchar,
user_uuid varchar,
agent_id varchar
)'''
% table_name
)
cur.execute(
'''SELECT column_name FROM information_schema.columns
WHERE table_schema = 'public'
AND table_name = %s''',
(table_name,),
)
existing_fields = {x[0] for x in cur.fetchall()}
if 'is_in_backoffice' not in existing_fields:
cur.execute(
'''ALTER TABLE %s ADD COLUMN is_in_backoffice boolean NOT NULL DEFAULT FALSE''' % table_name
)
if 'expected_error' not in existing_fields:
cur.execute('''ALTER TABLE %s ADD COLUMN expected_error varchar''' % table_name)
if 'agent_id' not in existing_fields:
cur.execute('''ALTER TABLE %s ADD COLUMN agent_id varchar''' % table_name)
if 'user_uuid' not in existing_fields:
cur.execute('''ALTER TABLE %s ADD COLUMN user_uuid varchar''' % table_name)
# delete obsolete fields
needed_fields = {x[0] for x in TestDef._table_static_fields}
for field in existing_fields - needed_fields:
cur.execute('''ALTER TABLE %s DROP COLUMN %s''' % (table_name, field))
cur.close()
def store(self):
sql_dict = {x[0]: getattr(self, x[0], None) for x in self._table_static_fields if x[0] != 'id'}
_, cur = get_connection_and_cursor()
column_names = list(sql_dict.keys())
if not self.id:
sql_statement = '''INSERT INTO %s (id, %s)
VALUES (DEFAULT, %s)
RETURNING id''' % (
self._table_name,
', '.join(column_names),
', '.join(['%%(%s)s' % x for x in column_names]),
)
cur.execute(sql_statement, sql_dict)
self.id = cur.fetchone()[0]
else:
sql_dict['id'] = self.id
sql_statement = '''UPDATE %s SET %s WHERE id = %%(id)s RETURNING id''' % (
self._table_name,
', '.join(['%s = %%(%s)s' % (x, x) for x in column_names]),
)
cur.execute(sql_statement, sql_dict)
cur.close()
@classmethod
def _row2ob(cls, row, **kwargs):
o = cls.__new__(cls)
for attr, value in zip([x[0] for x in cls._table_static_fields], row):
setattr(o, attr, value)
return o
@classmethod
def get_data_fields(cls):
return []
@classmethod
def migrate_legacy(cls):
for testdef in TestDef.select():
if testdef.data and 'expected_error' in testdef.data:
testdef.expected_error = testdef.data['expected_error']
del testdef.data['expected_error']
testdef.store()
if testdef.data and testdef.data.get('user'):
cls.create_and_link_test_users(testdef)
@staticmethod
def create_and_link_test_users(testdef):
from wcs.testdef import TestDef
try:
user = get_publisher().user_class.get(testdef.data['user']['id'])
except KeyError:
return
user, _ = TestDef.get_or_create_test_user(user)
testdef.user_uuid = user.test_uuid
del testdef.data['user']
testdef.store()
class TestResult(SqlMixin):
_table_name = 'test_result'
_table_static_fields = [
('id', 'serial'),
('object_type', 'varchar'),
('object_id', 'varchar'),
('timestamp', 'timestamptz'),
('success', 'boolean'),
('reason', 'varchar'),
('results', 'jsonb[]'),
]
_table_select_skipped_fields = ['results']
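    # the results column can be large; it is not loaded on select by default
    # (migrate_legacy() below re-enables it temporarily)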
id = None
@classmethod
def do_table(cls, conn=None, cur=None):
conn, cur = get_connection_and_cursor()
table_name = cls._table_name
cur.execute(
'''SELECT COUNT(*) FROM information_schema.tables
WHERE table_schema = 'public'
AND table_name = %s''',
(table_name,),
)
if cur.fetchone()[0] == 0:
cur.execute(
'''CREATE TABLE %s (id SERIAL PRIMARY KEY,
object_type varchar NOT NULL,
object_id varchar NOT NULL,
timestamp timestamptz,
success boolean NOT NULL,
reason varchar NOT NULL,
results jsonb[]
)'''
% table_name
)
cur.close()
def store(self):
sql_dict = {x[0]: getattr(self, x[0], None) for x in self._table_static_fields if x[0] != 'id'}
_, cur = get_connection_and_cursor()
column_names = list(sql_dict.keys())
column_values = []
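        # the results placeholder gets an explicit ::jsonb[] cast so the
        # adapted value matches the column type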
for name in column_names:
value = '%%(%s)s' % name
if name == 'results':
value += '::jsonb[]'
column_values.append(value)
if not self.id:
sql_statement = '''INSERT INTO %s (id, %s)
VALUES (DEFAULT, %s)
RETURNING id''' % (
self._table_name,
', '.join(column_names),
', '.join(column_values),
)
cur.execute(sql_statement, sql_dict)
self.id = cur.fetchone()[0]
else:
sql_dict['id'] = self.id
sql_statement = '''UPDATE %s SET %s WHERE id = %%(id)s RETURNING id''' % (
self._table_name,
', '.join(['%s = %s' % (x, y) for x, y in zip(column_names, column_values)]),
)
cur.execute(sql_statement, sql_dict)
cur.close()
@classmethod
def _row2ob(cls, row, **kwargs):
o = cls.__new__(cls)
for attr, value in zip([x[0] for x in cls._table_static_fields], row):
setattr(o, attr, value)
return o
@classmethod
def get_data_fields(cls):
return []
@classmethod
def migrate_legacy(cls):
skipped_fields = cls._table_select_skipped_fields.copy()
cls._table_select_skipped_fields = []
for test_result in cls.select():
store = False
            for result in test_result.results or []:
if 'details' not in result:
result['details'] = {
'recorded_errors': result.pop('recorded_errors', []),
'missing_required_fields': result.pop('missing_required_fields', []),
'workflow_test_action_uuid': None,
'form_status': None,
}
store = True
if store:
test_result.store()
cls._table_select_skipped_fields = skipped_fields
class WorkflowTrace(SqlMixin):
_table_name = 'workflow_traces'
_table_static_fields = [
('id', 'serial'),
('formdef_type', 'varchar'),
('formdef_id', 'integer'),
('formdata_id', 'integer'),
('status_id', 'varchar'),
('event', 'varchar'),
('event_args', 'jsonb'),
('timestamp', 'timestamptz'),
('action_item_key', 'varchar'),
('action_item_id', 'varchar'),
]
_sql_indexes = [
'workflow_traces_idx ON workflow_traces (formdef_type, formdef_id, formdata_id)',
]
id = None
formdef_type = None
formdef_id = None
formdata_id = None
status_id = None
timestamp = None
event = None
event_args = None
action_item_key = None
action_item_id = None
def __init__(self, formdata, event=None, event_args=None, action=None):
self.timestamp = localtime()
self.formdef_type = formdata.formdef.xml_root_node
self.formdef_id = formdata.formdef.id
self.formdata_id = formdata.id
self.status_id = formdata.status
self.event = event
self.event_args = event_args
if action:
self.action_item_key = action.key
self.action_item_id = action.id
@classmethod
def do_table(cls, conn=None, cur=None):
conn, cur = get_connection_and_cursor()
table_name = cls._table_name
cur.execute(
'''SELECT COUNT(*) FROM information_schema.tables
WHERE table_schema = 'public'
AND table_name = %s''',
(table_name,),
)
if cur.fetchone()[0] == 0:
cur.execute(
'''CREATE TABLE %s (id SERIAL PRIMARY KEY,
formdef_type varchar NOT NULL,
formdef_id integer NOT NULL,
formdata_id integer NOT NULL,
status_id varchar,
event varchar,
event_args jsonb,
timestamp timestamptz,
action_item_key varchar,
action_item_id varchar
)'''
% table_name
)
cls.do_indexes(cur)
cur.close()
def store(self):
sql_dict = {x[0]: getattr(self, x[0], None) for x in self._table_static_fields if x[0] != 'id'}
_, cur = get_connection_and_cursor()
column_names = list(sql_dict.keys())
if not self.id:
sql_statement = '''INSERT INTO %s (id, %s)
VALUES (DEFAULT, %s)
RETURNING id''' % (
self._table_name,
', '.join(column_names),
', '.join(['%%(%s)s' % x for x in column_names]),
)
cur.execute(sql_statement, sql_dict)
self.id = cur.fetchone()[0]
else:
sql_dict['id'] = self.id
sql_statement = '''UPDATE %s SET %s WHERE id = %%(id)s RETURNING id''' % (
self._table_name,
', '.join(['%s = %%(%s)s' % (x, x) for x in column_names]),
)
cur.execute(sql_statement, sql_dict)
cur.close()
@classmethod
def _row2ob(cls, row, **kwargs):
o = cls.__new__(cls)
for attr, value in zip([x[0] for x in cls._table_static_fields], row):
setattr(o, attr, value)
return o
@classmethod
def get_data_fields(cls):
return []
@classmethod
def migrate_legacy(cls):
from wcs.carddef import CardDef
from wcs.formdef import FormDef
from wcs.workflows import ActionsTracingEvolutionPart
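        # drafts are skipped as they have not entered the workflow and have
        # no history to migrate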
criterias = [StrictNotEqual('status', 'draft')]
for formdef in itertools.chain(FormDef.select(), CardDef.select()):
for formdata in formdef.data_class().select_iterator(criterias, itersize=200):
status_id = None
changed = False
for evo in formdata.evolution or []:
status_id = evo.status or status_id
for part in evo.parts or []:
if not isinstance(part, ActionsTracingEvolutionPart):
continue
changed = True
trace = cls(formdata=formdata)
trace.event = part.event
trace.event_args = {}
if part.external_workflow_id:
trace.event_args = {
'external_workflow_id': part.external_workflow_id,
'external_status_id': part.external_status_id,
'external_item_id': part.external_item_id,
}
if part.event_args:
if trace.event in ('api-post-edit-action', 'edit-action', 'timeout-jump'):
trace.event_args = {'action_item_id': part.event_args[0]}
elif trace.event in (
'global-api-trigger',
'global-external-workflow',
'global-interactive-action',
):
trace.event_args = {'global_action_id': part.event_args[0]}
elif trace.event in ('global-action-timeout',):
if isinstance(part.event_args[0], tuple):
                                    # unwrap the tuple stored nested by an old bug
part.event_args = part.event_args[0]
trace.event_args = {
'global_action_id': part.event_args[0],
'global_trigger_id': part.event_args[1],
}
elif trace.event in ('workflow-created',):
trace.event_args['display_id'] = part.event_args[0]
trace.status_id = status_id
trace.timestamp = evo.time
trace.store()
for action in part.actions or []:
trace = cls(formdata=formdata)
trace.timestamp = make_aware(action[0], is_dst=True)
trace.status_id = status_id
trace.action_item_key = action[1]
trace.action_item_id = action[2]
trace.store()
if changed and evo.parts:
evo.parts = [x for x in evo.parts if not isinstance(x, ActionsTracingEvolutionPart)]
if changed:
formdata._store_all_evolution = True
formdata.store()
class Audit(SqlMixin):
_table_name = 'audit'
_table_static_fields = [
('id', 'bigserial'),
('timestamp', 'timestamptz'),
('action', 'varchar'),
('url', 'varchar'),
('user_id', 'varchar'),
('object_type', 'varchar'),
('object_id', 'varchar'),
('data_id', 'int'),
('extra_data', 'jsonb'),
('frozen', 'jsonb'), # plain copy of user email, object name and slug
]
_sql_indexes = [
'audit_id_idx ON audit USING btree (id)',
]
id = None
@classmethod
def do_table(cls):
_, cur = get_connection_and_cursor()
table_name = cls._table_name
cur.execute(
'''SELECT COUNT(*) FROM information_schema.tables
WHERE table_schema = 'public'
AND table_name = %s''',
(table_name,),
)
if cur.fetchone()[0] == 0:
cur.execute(
'''CREATE TABLE %s (id BIGSERIAL,
timestamp TIMESTAMP WITH TIME ZONE,
action VARCHAR,
url VARCHAR,
user_id VARCHAR,
object_type VARCHAR,
object_id VARCHAR,
data_id INTEGER,
extra_data JSONB,
frozen JSONB
)'''
% table_name
)
cur.execute(
'''SELECT column_name FROM information_schema.columns
WHERE table_schema = 'public'
AND table_name = %s''',
(table_name,),
)
existing_fields = {x[0] for x in cur.fetchall()}
needed_fields = {x[0] for x in Audit._table_static_fields}
# delete obsolete fields
for field in existing_fields - needed_fields:
cur.execute('''ALTER TABLE %s DROP COLUMN %s''' % (table_name, field))
cls.do_indexes(cur)
cur.close()
def store(self):
if self.id:
# do not allow updates
raise AssertionError()
        sql_dict = {x: getattr(self, x, None) for x, dummy in self._table_static_fields}
_, cur = get_connection_and_cursor()
column_names = [x for x in sql_dict.keys() if x != 'id']
sql_statement = '''INSERT INTO %s (%s)
VALUES (%s)
RETURNING id''' % (
self._table_name,
', '.join(column_names),
', '.join(['%%(%s)s' % x for x in column_names]),
)
cur.execute(sql_statement, sql_dict)
self.id = cur.fetchone()[0]
cur.close()
@classmethod
def _row2ob(cls, row, **kwargs):
o = cls()
for field, value in zip(cls._table_static_fields, tuple(row)):
setattr(o, field[0], value)
return o
@classmethod
def get_data_fields(cls):
return []
@classmethod
def get_first_id(cls, clause=None):
_, cur = get_connection_and_cursor()
sql_statement = 'SELECT id FROM audit'
where_clauses, parameters, func_clause = parse_clause(clause)
assert not func_clause
if where_clauses:
sql_statement += ' WHERE ' + ' AND '.join(where_clauses)
sql_statement += ' ORDER BY id LIMIT 1'
cur.execute(sql_statement, parameters)
try:
first_id = cur.fetchall()[0][0]
except IndexError:
first_id = 0
cur.close()
return first_id
class Application(SqlMixin):
_table_name = 'applications'
_table_static_fields = [
('id', 'serial'),
('slug', 'varchar'),
('name', 'varchar'),
('description', 'text'),
('documentation_url', 'varchar'),
('icon', 'bytea'),
('version_number', 'varchar'),
('version_notes', 'text'),
('editable', 'boolean'),
('visible', 'boolean'),
('created_at', 'timestamptz'),
('updated_at', 'timestamptz'),
]
_sql_indexes = [
'applications_slug ON applications (slug)',
]
id = None
@classmethod
def do_table(cls):
_, cur = get_connection_and_cursor()
table_name = cls._table_name
cur.execute(
'''SELECT COUNT(*) FROM information_schema.tables
WHERE table_schema = 'public'
AND table_name = %s''',
(table_name,),
)
if cur.fetchone()[0] == 0:
cur.execute(
'''CREATE TABLE %s (id SERIAL PRIMARY KEY,
slug VARCHAR NOT NULL,
name VARCHAR NOT NULL,
description TEXT,
documentation_url VARCHAR,
icon BYTEA,
version_number VARCHAR NOT NULL,
version_notes TEXT,
editable BOOLEAN,
visible BOOLEAN,
created_at TIMESTAMP WITH TIME ZONE NOT NULL,
updated_at TIMESTAMP WITH TIME ZONE NOT NULL
)'''
% table_name
)
cls.do_indexes(cur)
cur.close()
def store(self):
sql_dict = {x[0]: getattr(self, x[0], None) for x in self._table_static_fields if x[0] != 'id'}
sql_dict['updated_at'] = localtime()
if self.icon:
sql_dict['icon'] = bytearray(pickle.dumps(self.icon, protocol=2))
_, cur = get_connection_and_cursor()
column_names = list(sql_dict.keys())
if not self.id:
sql_dict['created_at'] = sql_dict['updated_at']
sql_statement = '''INSERT INTO %s (id, %s)
VALUES (DEFAULT, %s)
RETURNING id''' % (
self._table_name,
', '.join(column_names),
', '.join(['%%(%s)s' % x for x in column_names]),
)
cur.execute(sql_statement, sql_dict)
self.id = cur.fetchone()[0]
else:
sql_dict['id'] = self.id
sql_statement = '''UPDATE %s SET %s WHERE id = %%(id)s RETURNING id''' % (
self._table_name,
', '.join(['%s = %%(%s)s' % (x, x) for x in column_names]),
)
cur.execute(sql_statement, sql_dict)
cur.close()
@classmethod
def _row2ob(cls, row, **kwargs):
o = cls.__new__(cls)
for field, value in zip(cls._table_static_fields, tuple(row)):
            if value and field[1] == 'bytea':
value = pickle_loads(value)
setattr(o, field[0], value)
return o
@classmethod
def get_data_fields(cls):
return []
class ApplicationElement(SqlMixin):
_table_name = 'application_elements'
_table_static_fields = [
('id', 'serial'),
('application_id', 'integer'),
('object_type', 'varchar'),
('object_id', 'varchar'),
('created_at', 'timestamptz'),
('updated_at', 'timestamptz'),
]
_sql_indexes = [
'application_elements_object_idx ON application_elements (object_type, object_id)',
]
id = None
@classmethod
def do_table(cls):
_, cur = get_connection_and_cursor()
table_name = cls._table_name
cur.execute(
'''SELECT COUNT(*) FROM information_schema.tables
WHERE table_schema = 'public'
AND table_name = %s''',
(table_name,),
)
if cur.fetchone()[0] == 0:
cur.execute(
'''CREATE TABLE %s (id SERIAL PRIMARY KEY,
application_id INTEGER NOT NULL,
object_type varchar NOT NULL,
object_id varchar NOT NULL,
created_at TIMESTAMP WITH TIME ZONE NOT NULL,
updated_at TIMESTAMP WITH TIME ZONE NOT NULL
)'''
% table_name
)
cls.do_indexes(cur)
cur.execute(
'''SELECT COUNT(*) FROM information_schema.constraint_column_usage
WHERE table_name = %s
AND constraint_name=%s''',
(table_name, '%s_unique' % table_name),
)
if cur.fetchone()[0] == 0:
cur.execute(
'ALTER TABLE %s ADD CONSTRAINT %s_unique UNIQUE (application_id, object_type, object_id)'
% (table_name, table_name)
)
cur.close()
def store(self):
sql_dict = {x[0]: getattr(self, x[0], None) for x in self._table_static_fields if x[0] != 'id'}
sql_dict['updated_at'] = localtime()
_, cur = get_connection_and_cursor()
column_names = list(sql_dict.keys())
if not self.id:
sql_dict['created_at'] = sql_dict['updated_at']
sql_statement = '''INSERT INTO %s (id, %s)
VALUES (DEFAULT, %s)
RETURNING id''' % (
self._table_name,
', '.join(column_names),
', '.join(['%%(%s)s' % x for x in column_names]),
)
cur.execute(sql_statement, sql_dict)
self.id = cur.fetchone()[0]
else:
sql_dict['id'] = self.id
sql_statement = '''UPDATE %s SET %s WHERE id = %%(id)s RETURNING id''' % (
self._table_name,
', '.join(['%s = %%(%s)s' % (x, x) for x in column_names]),
)
cur.execute(sql_statement, sql_dict)
cur.close()
@classmethod
def _row2ob(cls, row, **kwargs):
o = cls.__new__(cls)
for attr, value in zip([x[0] for x in cls._table_static_fields], row):
setattr(o, attr, value)
return o
@classmethod
def get_data_fields(cls):
return []
class classproperty:
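    # minimal descriptor providing a read-only property on the class itself;
    # accessing the attribute calls f with the class as argument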
def __init__(self, f):
self.f = f
def __get__(self, obj, owner):
return self.f(owner)
class AnyFormData(SqlMixin):
_table_name = 'wcs_all_forms'
_formdef_cache = {}
_has_id = False
@classproperty
def _table_static_fields(self):
        # nb: the cache attribute must not be a dunder name: an assignment to
        # self.__table_static_fields would be name-mangled while the string
        # given to hasattr() would not be, and the cache would never be hit
        if not hasattr(self, '_cached_table_static_fields'):
            fake_formdef = FormDef()
            common_fields = get_view_fields(fake_formdef)
            fields = [(x[1], x[0]) for x in common_fields]
            fields.append(('criticality_level', 'criticality_level'))
            fields.append(('geoloc_base_x', 'geoloc_base_x'))
            fields.append(('geoloc_base_y', 'geoloc_base_y'))
            fields.append(('concerned_roles_array', 'concerned_roles_array'))
            fields.append(('anonymised', 'anonymised'))
            self._cached_table_static_fields = fields
        return self._cached_table_static_fields
@classmethod
def get_data_fields(cls):
return []
@classmethod
def get_objects(cls, *args, **kwargs):
cls._formdef_cache = {}
return super().get_objects(*args, **kwargs)
@classmethod
def _row2ob(cls, row, **kwargs):
formdef_id = row[1]
formdef = cls._formdef_cache.setdefault(formdef_id, FormDef.get(formdef_id))
o = formdef.data_class()()
for static_field, value in zip(cls._table_static_fields, tuple(row[: len(cls._table_static_fields)])):
setattr(o, static_field[0], str_encode(value))
# [CRITICALITY_2] transform criticality_level back to the expected
# range (see [CRITICALITY_1])
levels = len(formdef.workflow.criticality_levels or [0])
o.criticality_level = levels + o.criticality_level
        # convert the unstructured geolocation back to the 'native' formdata format.
if o.geoloc_base_x is not None:
o.geolocations = {'base': {'lon': o.geoloc_base_x, 'lat': o.geoloc_base_y}}
# do not allow storing those partial objects
o.store = None
return o
@classmethod
def load_all_evolutions(cls, formdatas):
classes = {}
for formdata in formdatas:
if formdata._table_name not in classes:
classes[formdata._table_name] = []
classes[formdata._table_name].append(formdata)
        for table_formdatas in classes.values():
            table_formdatas[0].load_all_evolutions(table_formdatas)
def get_period_query(
period_start=None, include_start=True, period_end=None, include_end=True, criterias=None, parameters=None
):
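    # `parameters` must be a dict; it is filled in place with the query
    # parameters matching the returned SQL fragment, which looks like
    # " FROM wcs_all_forms WHERE receipt_time IS NOT NULL AND ..."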
clause = [NotNull('receipt_time')]
table_name = 'wcs_all_forms'
if criterias:
formdef_class = FormDef
for criteria in criterias:
if criteria.__class__.__name__ == 'Equal' and criteria.attribute == 'formdef_klass':
formdef_class = criteria.value
continue
if (
formdef_class
and criteria.__class__.__name__ == 'Equal'
and criteria.attribute == 'formdef_id'
):
# if there's a formdef_id specified, switch to using the
# specific table so we have access to all fields
table_name = get_formdef_table_name(formdef_class.get(criteria.value))
continue
clause.append(criteria)
if period_start:
if include_start:
clause.append(GreaterOrEqual('receipt_time', period_start))
else:
clause.append(Greater('receipt_time', period_start))
if period_end:
if include_end:
clause.append(LessOrEqual('receipt_time', period_end))
else:
clause.append(Less('receipt_time', period_end))
where_clauses, params, dummy = parse_clause(clause)
parameters.update(params)
statement = ' FROM %s ' % table_name
statement += ' WHERE ' + ' AND '.join(where_clauses)
return statement
class SearchableFormDef(SqlMixin):
_table_name = 'searchable_formdefs'
_sql_indexes = [
'searchable_formdefs_fts ON searchable_formdefs USING gin(fts)',
]
@classmethod
@atomic
def do_table(cls):
_, cur = get_connection_and_cursor()
cur.execute(
'''SELECT COUNT(*) FROM information_schema.tables
WHERE table_schema = 'public'
AND table_name = %s''',
(cls._table_name,),
)
if cur.fetchone()[0] == 0:
cur.execute(
'''CREATE TABLE %s (id SERIAL PRIMARY KEY,
object_type VARCHAR,
object_id VARCHAR,
timestamp TIMESTAMPTZ DEFAULT NOW(),
fts TSVECTOR)
'''
% cls._table_name
)
cur.execute(
'ALTER TABLE %s ADD CONSTRAINT %s_unique UNIQUE (object_type, object_id)'
% (cls._table_name, cls._table_name)
)
cls.do_indexes(cur)
cur.close()
from wcs.carddef import CardDef
from wcs.formdef import FormDef
for objectdef in itertools.chain(
CardDef.select(ignore_errors=True), FormDef.select(ignore_errors=True)
):
cls.update(obj=objectdef)
@classmethod
def update(cls, obj=None, removed_obj_type=None, removed_obj_id=None):
_, cur = get_connection_and_cursor()
if removed_obj_id:
cur.execute(
'DELETE FROM searchable_formdefs WHERE object_type = %s AND object_id = %s',
(removed_obj_type, removed_obj_id),
)
else:
cur.execute(
'''INSERT INTO searchable_formdefs (object_type, object_id, fts)
VALUES (%(object_type)s, %(object_id)s,
setweight(to_tsvector(%(fts_a)s), 'A') ||
setweight(to_tsvector(%(fts_b)s), 'B') ||
setweight(to_tsvector(%(fts_c)s), 'C'))
ON CONFLICT(object_type, object_id) DO UPDATE
SET fts = setweight(to_tsvector(%(fts_a)s), 'A') ||
setweight(to_tsvector(%(fts_b)s), 'B') ||
setweight(to_tsvector(%(fts_c)s), 'C'),
timestamp = NOW()
''',
{
'object_type': obj.xml_root_node,
'object_id': obj.id,
'fts_a': FtsMatch.get_fts_value(obj.name),
'fts_b': FtsMatch.get_fts_value(obj.description or ''),
'fts_c': FtsMatch.get_fts_value(obj.keywords or ''),
},
)
cur.close()
@classmethod
def search(cls, obj_type, string):
_, cur = get_connection_and_cursor()
cur.execute(
'SELECT object_id FROM searchable_formdefs WHERE fts @@ plainto_tsquery(%s)',
(FtsMatch.get_fts_value(string),),
)
ids = [x[0] for x in cur.fetchall()]
cur.close()
return ids
def get_time_aggregate_query(time_interval, query, group_by, function='DATE_TRUNC', group_by_clause=None):
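    # builds an aggregate statement; for instance with time_interval='month'
    # and no group_by this yields something like:
    #   SELECT DATE_TRUNC('month', receipt_time) AS month, COUNT(*)
    #   FROM ... GROUP BY month ORDER BY month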
statement = f"SELECT {function}('{time_interval}', receipt_time) AS {time_interval}, "
if group_by:
if group_by_clause:
statement += group_by_clause
else:
statement += '%s, ' % group_by
statement += 'COUNT(*) '
statement += query
aggregate_fields = time_interval
if group_by:
aggregate_fields += ', %s' % group_by
statement += f' GROUP BY {aggregate_fields} ORDER BY {aggregate_fields}'
return statement
def get_actionable_counts(user_roles):
_, cur = get_connection_and_cursor()
criterias = [
Equal('is_at_endpoint', False),
Intersects('actions_roles_array', user_roles),
Null('anonymised'),
]
where_clauses, parameters, dummy = parse_clause(criterias)
statement = '''SELECT formdef_id, COUNT(*)
FROM wcs_all_forms
WHERE %s
GROUP BY formdef_id''' % ' AND '.join(
where_clauses
)
cur.execute(statement, parameters)
counts = {str(x): y for x, y in cur.fetchall()}
cur.close()
return counts
def get_total_counts(user_roles):
_, cur = get_connection_and_cursor()
criterias = [
Intersects('concerned_roles_array', user_roles),
Null('anonymised'),
]
where_clauses, parameters, dummy = parse_clause(criterias)
statement = '''SELECT formdef_id, COUNT(*)
FROM wcs_all_forms
WHERE %s
GROUP BY formdef_id''' % ' AND '.join(
where_clauses
)
cur.execute(statement, parameters)
counts = {str(x): y for x, y in cur.fetchall()}
cur.close()
return counts
def get_weekday_totals(
period_start=None, period_end=None, criterias=None, group_by=None, group_by_clause=None
):
__, cur = get_connection_and_cursor()
parameters = {}
statement = get_period_query(
period_start=period_start, period_end=period_end, criterias=criterias, parameters=parameters
)
statement = get_time_aggregate_query(
'dow', statement, group_by, function='DATE_PART', group_by_clause=group_by_clause
)
cur.execute(statement, parameters)
result = cur.fetchall()
result = [(int(x[0]), *x[1:]) for x in result]
coverage = [x[0] for x in result]
for weekday in range(7):
if weekday not in coverage:
result.append((weekday, 0))
result.sort(key=lambda x: x[0])
# add labels,
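    # (DATE_PART('dow', ...) numbers days from 0 = Sunday to 6 = Saturday)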
weekday_names = [
_('Sunday'),
_('Monday'),
_('Tuesday'),
_('Wednesday'),
_('Thursday'),
_('Friday'),
_('Saturday'),
]
result = [(weekday_names[x[0]], *x[1:]) for x in result]
# and move Sunday last
result = result[1:] + [result[0]]
cur.close()
return result
def get_formdef_totals(period_start=None, period_end=None, criterias=None):
_, cur = get_connection_and_cursor()
statement = '''SELECT formdef_id, COUNT(*)'''
parameters = {}
statement += get_period_query(
period_start=period_start, period_end=period_end, criterias=criterias, parameters=parameters
)
statement += ' GROUP BY formdef_id'
cur.execute(statement, parameters)
result = cur.fetchall()
result = [(int(x), y) for x, y in result]
cur.close()
return result
def get_global_totals(
period_start=None, period_end=None, criterias=None, group_by=None, group_by_clause=None
):
_, cur = get_connection_and_cursor()
statement = 'SELECT '
if group_by:
if group_by_clause:
statement += group_by_clause
else:
statement += f'{group_by}, '
statement += 'COUNT(*) '
parameters = {}
statement += get_period_query(
period_start=period_start, period_end=period_end, criterias=criterias, parameters=parameters
)
if group_by:
statement += f' GROUP BY {group_by} ORDER BY {group_by}'
cur.execute(statement, parameters)
result = cur.fetchall()
if not group_by:
result = [('', result[0][0])]
cur.close()
return result
def get_hour_totals(period_start=None, period_end=None, criterias=None, group_by=None, group_by_clause=None):
_, cur = get_connection_and_cursor()
parameters = {}
statement = get_period_query(
period_start=period_start, period_end=period_end, criterias=criterias, parameters=parameters
)
statement = get_time_aggregate_query(
'hour', statement, group_by, function='DATE_PART', group_by_clause=group_by_clause
)
cur.execute(statement, parameters)
result = cur.fetchall()
result = [(int(x[0]), *x[1:]) for x in result]
coverage = [x[0] for x in result]
for hour in range(24):
if hour not in coverage:
result.append((hour, 0))
result.sort(key=lambda x: x[0])
cur.close()
return result
def get_monthly_totals(
period_start=None,
period_end=None,
criterias=None,
group_by=None,
group_by_clause=None,
):
_, cur = get_connection_and_cursor()
parameters = {}
statement = get_period_query(
period_start=period_start, period_end=period_end, criterias=criterias, parameters=parameters
)
statement = get_time_aggregate_query('month', statement, group_by, group_by_clause=group_by_clause)
cur.execute(statement, parameters)
raw_result = cur.fetchall()
result = [('%d-%02d' % x[0].timetuple()[:2], *x[1:]) for x in raw_result]
if result:
coverage = [x[0] for x in result]
current_month = raw_result[0][0]
last_month = raw_result[-1][0]
while current_month < last_month:
label = '%d-%02d' % current_month.timetuple()[:2]
if label not in coverage:
result.append((label, 0))
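            # advance to the first day of the next month: overshoot past the
            # month end, then snap back to day 1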
current_month = current_month + datetime.timedelta(days=31)
current_month = current_month - datetime.timedelta(days=current_month.day - 1)
result.sort(key=lambda x: x[0])
cur.close()
return result
def get_yearly_totals(
period_start=None, period_end=None, criterias=None, group_by=None, group_by_clause=None
):
_, cur = get_connection_and_cursor()
parameters = {}
statement = get_period_query(
period_start=period_start, period_end=period_end, criterias=criterias, parameters=parameters
)
statement = get_time_aggregate_query('year', statement, group_by, group_by_clause=group_by_clause)
cur.execute(statement, parameters)
raw_result = cur.fetchall()
result = [(str(x[0].year), *x[1:]) for x in raw_result]
if result:
coverage = [x[0] for x in result]
current_year = raw_result[0][0]
last_year = raw_result[-1][0]
while current_year < last_year:
label = str(current_year.year)
if label not in coverage:
result.append((label, 0))
current_year = current_year + datetime.timedelta(days=366)
result.sort(key=lambda x: x[0])
cur.close()
return result
def get_period_total(
period_start=None, include_start=True, period_end=None, include_end=True, criterias=None
):
_, cur = get_connection_and_cursor()
statement = '''SELECT COUNT(*)'''
parameters = {}
statement += get_period_query(
period_start=period_start,
include_start=include_start,
period_end=period_end,
include_end=include_end,
criterias=criterias,
parameters=parameters,
)
cur.execute(statement, parameters)
result = int(cur.fetchone()[0])
cur.close()
return result
# latest migration, number + description (the description is not used
# programmatically but guarantees a git conflict if two migrations are
# added separately with the same number)
SQL_LEVEL = (107, 'add test_uuid column to users table')
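# a new migration bumps the level, e.g. SQL_LEVEL = (108, 'some description'),
# and adds a matching `if sql_level < 108:` block in migrate() below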
def migrate_global_views(conn, cur):
drop_global_views(conn, cur)
do_global_views(conn, cur)
def get_sql_level(conn, cur):
do_meta_table(conn, cur, insert_current_sql_level=False)
cur.execute('''SELECT value FROM wcs_meta WHERE key = %s''', ('sql_level',))
sql_level = int(cur.fetchone()[0])
return sql_level
def get_cron_status():
conn, cur = get_connection_and_cursor()
do_meta_table(conn, cur, insert_current_sql_level=False)
key = 'cron-status-%s' % get_publisher().tenant.hostname
cur.execute('SELECT value, updated_at FROM wcs_meta WHERE key = %s', (key,))
row = cur.fetchone()
cur.close()
return tuple(row) if row else (None, now())
@atomic
def get_and_update_cron_status():
conn, cur = get_connection_and_cursor()
do_meta_table(conn, cur, insert_current_sql_level=False)
key = 'cron-status-%s' % get_publisher().tenant.hostname
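    # SELECT ... FOR UPDATE takes a row lock that is held until the end of
    # the surrounding atomic block, so concurrent cron runners serialize here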
cur.execute('SELECT value, created_at FROM wcs_meta WHERE key = %s FOR UPDATE', (key,))
row = cur.fetchone()
timestamp = now()
if row is None:
cur.execute("INSERT INTO wcs_meta (key, value) VALUES (%s, 'running') ON CONFLICT DO NOTHING", (key,))
if cur.rowcount != 1:
            # the insert failed, so another process created the row meanwhile;
            # assume the cron is already running
status = 'running'
else:
status = 'done'
elif row[0] in ('done', 'needed'): # (needed is legacy)
cur.execute(
"""UPDATE wcs_meta
SET value = 'running', created_at = NOW(), updated_at = NOW()
WHERE key = %s""",
(key,),
)
status, timestamp = 'done', row[1]
else:
status, timestamp = row
cur.close()
return (status, timestamp)
def mark_cron_status(status):
_, cur = get_connection_and_cursor()
key = 'cron-status-%s' % get_publisher().tenant.hostname
cur.execute('UPDATE wcs_meta SET value = %s, updated_at = NOW() WHERE key = %s', (status, key))
cur.close()
def is_reindex_needed(index, conn, cur):
do_meta_table(conn, cur, insert_current_sql_level=False)
key_name = 'reindex_%s' % index
cur.execute('''SELECT value FROM wcs_meta WHERE key = %s''', (key_name,))
row = cur.fetchone()
if row is None:
cur.execute(
'''INSERT INTO wcs_meta (id, key, value)
VALUES (DEFAULT, %s, %s)''',
(key_name, 'no'),
)
return False
return row[0] == 'needed'
def set_reindex(index, value, conn=None, cur=None):
own_conn = False
if not conn:
own_conn = True
conn, cur = get_connection_and_cursor()
do_meta_table(conn, cur, insert_current_sql_level=False)
key_name = 'reindex_%s' % index
cur.execute('''SELECT value FROM wcs_meta WHERE key = %s''', (key_name,))
row = cur.fetchone()
if row is None:
cur.execute(
'''INSERT INTO wcs_meta (id, key, value)
VALUES (DEFAULT, %s, %s)''',
(key_name, value),
)
else:
if row[0] != value:
cur.execute(
'''UPDATE wcs_meta SET value = %s, updated_at = NOW() WHERE key = %s''', (value, key_name)
)
if own_conn:
cur.close()
def migrate_views(conn, cur):
drop_views(None, conn, cur)
for formdef in FormDef.select() + CardDef.select():
# make sure all formdefs have up-to-date views
do_formdef_tables(formdef, conn=conn, cur=cur, rebuild_views=True, rebuild_global_views=False)
migrate_global_views(conn, cur)
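# migrate() compares the sql_level stored in wcs_meta with SQL_LEVEL and
# replays every migration step in between; data-heavy migrations are only
# flagged here (through set_reindex()) and actually run from reindex()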
def migrate():
conn, cur = get_connection_and_cursor()
sql_level = get_sql_level(conn, cur)
if sql_level < 0:
# fake code to help in testing the error code path.
raise RuntimeError()
if sql_level < 1: # 1: introduction of tracking_code table
do_tracking_code_table()
if sql_level < 93:
# 42: create snapshots table
# 54: add patch column
# 63: add index
# 83: add test_result table
# 93: add application columns in snapshot table
do_snapshots_table()
if sql_level < 50:
# 49: store Role in SQL
# 50: switch role uuid column to varchar
do_role_table()
migrate_legacy_roles()
if sql_level < 106:
# 47: store LoggedErrors in SQL
# 48: remove acked attribute from LoggedError
# 53: add kind column to logged_errors table
# 106: add context column to logged_errors table
do_loggederrors_table()
if sql_level < 107:
# 3: introduction of _structured for user fields
# 4: removal of identification_token
# 12: (first part) add fts to users
# 16: add verified_fields to users
# 21: (first part) add ascii_name to users
# 39: add deleted_timestamp
# 40: add is_active to users
# 65: index users(name_identifiers)
# 85: remove anonymous column
# 94: add preferences column to users table
# 107: add test_uuid column to users table
do_user_table()
if sql_level < 32:
# 25: create session_table
# 32: add last_update_time column to session table
do_session_table()
if sql_level < 64:
# 64: add transient data table
do_transient_data_table()
if sql_level < 92:
# 37: create custom_views table
# 44: add is_default column to custom_views table
# 66: index the formdef_id column
# 90: add role_id to custom views
# 92: add group_by column to custom views
do_custom_views_table()
if sql_level < 67:
# 57: store tokens in SQL
# 67: re-migrate legacy tokens
do_tokens_table()
migrate_legacy_tokens()
if sql_level < 100:
# 68: multilinguism
# 79: add translatable column to TranslatableMessage table
# 100: always create translation messages table
TranslatableMessage.do_table()
if sql_level < 107:
# 72: add testdef table
# 87: add testdef is_in_backoffice column
# 88: add testdef expected_error column
# 103: drop testdef slug column
# 104: add testdef agent_id column
# 107: add test_uuid column to users table
TestDef.do_table()
if sql_level < 95:
# 95: add a searchable_formdefs table
SearchableFormDef.do_table()
if sql_level < 107:
# 88: add testdef expected_error column
# 107: add test_uuid column to users table
set_reindex('testdef', 'needed', conn=conn, cur=cur)
if sql_level < 76:
# 75: migrate to dedicated workflow traces table
# 76: add index to workflow traces table
WorkflowTrace.do_table()
if sql_level < 78:
# 78: add audit table
Audit.do_table()
if sql_level < 89:
# 83: add test_result table
# 89: rerun creation of test results table
TestResult.do_table()
if sql_level < 105:
# 105: change test result json structure
set_reindex('test_result', 'needed', conn=conn, cur=cur)
if sql_level < 84:
# 84: add application tables
Application.do_table()
ApplicationElement.do_table()
if sql_level < 52:
# 2: introduction of formdef_id in views
# 5: add concerned_roles_array, is_at_endpoint and fts to views
# 7: add backoffice_submission to tables
# 8: add submission_context to tables
# 9: add last_update_time to views
# 10: add submission_channel to tables
# 11: add formdef_name and user_name to views
# 13: add backoffice_submission to views
# 14: add criticality_level to tables & views
# 15: add geolocation to formdata
# 19: add geolocation to views
# 20: remove user hash stuff
# 22: rebuild views
# 26: add digest to formdata
# 27: add last_jump_datetime in evolutions tables
# 31: add user_label to formdata
# 33: add anonymised field to global view
# 38: extract submission_agent_id to its own column
# 43: add prefilling_data to formdata
# 52: store digests on formdata and carddata
migrate_views(conn, cur)
if sql_level < 6:
# 6: add actions_roles_array to tables and views
migrate_views(conn, cur)
for formdef in FormDef.select():
formdef.data_class().rebuild_security()
if sql_level < 62:
# 12: (second part), store fts in existing rows
# 21: (second part), store ascii_name of users
# 23: (first part), use misc.simplify() over full text queries
# 61: use setweight on formdata & user indexation
# 62: use setweight on formdata & user indexation (reapply)
set_reindex('user', 'needed', conn=conn, cur=cur)
if sql_level < 96:
# 17: store last_update_time in tables
# 18: add user name to full-text search index
# 21: (third part), add user ascii_names to full-text index
# 23: (second part) use misc.simplify() over full text queries
# 28: add display id and formdef name to full-text index
# 29: add evolution parts to full-text index
# 31: add user_label to formdata
# 38: extract submission_agent_id to its own column
# 41: update full text normalization
# 51: add index on formdata blockdef fields
# 55: update full text normalisation (switch to unidecode)
# 58: add workflow_merged_roles_dict as a jsonb column with
# combined formdef and formdata value.
# 61: use setweight on formdata & user indexation
# 62: use setweight on formdata & user indexation (reapply)
# 96: change to fts normalization
set_reindex('formdata', 'needed', conn=conn, cur=cur)
if sql_level < 99:
# 24: add index on evolution(formdata_id)
# 35: add indexes on formdata(receipt_time) and formdata(anonymised)
# 36: add index on formdata(user_id)
# 45 & 46: add index on formdata(status)
# 56: add GIN indexes to concerned_roles_array and actions_roles_array
        # 74: (late migration) change evolution index to be on (formdata_id, id)
# 97&98: add index on carddata/id_display
# 99: add more indexes
set_reindex('sqlindexes', 'needed', conn=conn, cur=cur)
if sql_level < 30:
# 30: actually remove evo.who on anonymised formdatas
for formdef in FormDef.select():
for formdata in formdef.data_class().select_iterator(clause=[NotNull('anonymised')]):
if formdata.evolution:
for evo in formdata.evolution:
evo.who = None
formdata.store()
if sql_level < 52:
# 52: store digests on formdata and carddata
for formdef in FormDef.select() + CardDef.select():
if not formdef.digest_templates:
continue
for formdata in formdef.data_class().select_iterator():
formdata._set_auto_fields(cur) # build digests
if sql_level < 102:
# 58: add workflow_merged_roles_dict as a jsonb column with
# combined formdef and formdata value.
# 69: add auto_geoloc field to form/card tables
# 80: add jsonb column to hold statistics data
# 91: add jsonb column to hold relations data
# 102: switch formdata datetime columns to timestamptz
drop_views(None, conn, cur)
for formdef in FormDef.select() + CardDef.select():
do_formdef_tables(formdef, rebuild_views=False, rebuild_global_views=False)
migrate_views(conn, cur)
if sql_level < 102:
# 81: add statistics data column to wcs_all_forms
# 82: add statistics data column to wcs_all_forms, for real
# 99: add more indexes
# 102: switch formdata datetime columns to timestamptz
migrate_global_views(conn, cur)
if sql_level < 60:
# 59: switch wcs_all_forms to a trigger-maintained table
# 60: rebuild triggers
init_global_table(conn, cur)
for formdef in FormDef.select():
do_formdef_tables(formdef, rebuild_views=False, rebuild_global_views=False)
if sql_level < 71:
# 71: python datasource migration
set_reindex('python_ds_migration', 'needed', conn=conn, cur=cur)
if sql_level < 73:
# 73: form tokens to db
        # form tokens use the existing tokens table; this "migration" only removes old files.
form_tokens_dir = os.path.join(get_publisher().app_dir, 'form_tokens')
if os.path.exists(form_tokens_dir):
shutil.rmtree(form_tokens_dir, ignore_errors=True)
if sql_level < 75:
# 75 (part 2): migrate to dedicated workflow traces table
set_reindex('workflow_traces_migration', 'needed', conn=conn, cur=cur)
if sql_level < 77:
# 77: use token table for nonces
        # nonces use the existing tokens table; this "migration" only removes old files.
nonces_dir = os.path.join(get_publisher().app_dir, 'nonces')
if os.path.exists(nonces_dir):
shutil.rmtree(nonces_dir, ignore_errors=True)
if sql_level < 86:
# 86: add uuid to cards
for formdef in CardDef.select():
do_formdef_tables(formdef, rebuild_views=False, rebuild_global_views=False)
if sql_level < 101:
# 101: add page_id to formdatas
for formdef in FormDef.select() + CardDef.select():
do_formdef_tables(formdef, rebuild_views=False, rebuild_global_views=False)
if sql_level != SQL_LEVEL[0]:
cur.execute(
'''UPDATE wcs_meta SET value = %s, updated_at=NOW() WHERE key = %s''',
(str(SQL_LEVEL[0]), 'sql_level'),
)
cur.close()
def reindex():
conn, cur = get_connection_and_cursor()
if is_reindex_needed('sqlindexes', conn=conn, cur=cur):
for klass in (
SqlUser,
Session,
CustomView,
Snapshot,
LoggedError,
TranslatableMessage,
WorkflowTrace,
Audit,
Application,
ApplicationElement,
):
klass.do_indexes(cur, concurrently=True)
for formdef in FormDef.select() + CardDef.select():
do_formdef_indexes(formdef, cur=cur, concurrently=True)
set_reindex('sqlindexes', 'done', conn=conn, cur=cur)
if is_reindex_needed('user', conn=conn, cur=cur):
for user in SqlUser.select(iterator=True):
user.store()
set_reindex('user', 'done', conn=conn, cur=cur)
if is_reindex_needed('formdata', conn=conn, cur=cur):
# load and store all formdatas
for formdef in FormDef.select() + CardDef.select():
for formdata in formdef.data_class().select(iterator=True):
try:
formdata.migrate()
formdata.store()
except Exception as e:
print('error reindexing %s (%r)' % (formdata, e))
set_reindex('formdata', 'done', conn=conn, cur=cur)
if is_reindex_needed('python_ds_migration', conn=conn, cur=cur):
        # migrate python datasources
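        # legacy python datasources evaluated an expression returning lists
        # of tuples or strings, e.g. [('1', 'One'), ('2', 'Two')]; they are
        # converted to the jsonvalue format, e.g.
        # [{'id': '1', 'text': 'One'}, {'id': '2', 'text': 'Two'}]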
def migrate_value(value):
try:
# noqa pylint: disable=eval-used
value = eval(value)
except Exception:
return
try:
if len(value) == 0:
return []
except TypeError:
return
if isinstance(value[0], (list, tuple)):
if len(value[0]) >= 3:
return [{'id': x[0], 'text': x[1], 'key': x[2]} for x in value]
elif len(value[0]) == 2:
return [{'id': x[0], 'text': x[1]} for x in value]
elif len(value[0]) == 1:
return [{'id': x[0], 'text': x[0]} for x in value]
return value
elif isinstance(value[0], str):
return [{'id': x, 'text': x} for x in value]
elif isinstance(value[0], dict):
if all(str(x.get('id', '')) and x.get('text') for x in value):
return value
def migrate_field(field):
if not getattr(field, 'data_source', None):
return
data_source_type = field.data_source.get('type')
if data_source_type != 'formula':
return
value = migrate_value(field.data_source.get('value'))
if value is not None:
field.data_source['type'] = 'jsonvalue'
field.data_source['value'] = json.dumps(value)
return True
for formdef in itertools.chain(FormDef.select(), CardDef.select()):
changed = False
for field in formdef.fields or []:
changed |= bool(migrate_field(field))
if changed:
formdef.store(comment=_('Automatic update'), snapshot_store_user=False)
from wcs.workflows import Workflow
for workflow in Workflow.select():
changed = False
if workflow.backoffice_fields_formdef and workflow.backoffice_fields_formdef.fields:
for field in workflow.backoffice_fields_formdef.fields:
changed |= bool(migrate_field(field))
if changed:
workflow.store(
migration_update=True, comment=_('Automatic update'), snapshot_store_user=False
)
from wcs.data_sources import NamedDataSource
for datasource in NamedDataSource.select():
data_source = datasource.data_source or {}
            if data_source.get('type') != 'formula':
continue
value = migrate_value(data_source.get('value'))
if value is not None:
datasource.data_source['type'] = 'jsonvalue'
datasource.data_source['value'] = json.dumps(value)
datasource.store(comment=_('Automatic update'), snapshot_store_user=False)
set_reindex('python_ds_migration', 'done', conn=conn, cur=cur)
if is_reindex_needed('workflow_traces_migration', conn=conn, cur=cur):
WorkflowTrace.migrate_legacy()
set_reindex('workflow_traces_migration', 'done', conn=conn, cur=cur)
if is_reindex_needed('testdef', conn=conn, cur=cur):
TestDef.migrate_legacy()
set_reindex('testdef', 'done', conn=conn, cur=cur)
if is_reindex_needed('test_result', conn=conn, cur=cur):
TestResult.migrate_legacy()
set_reindex('test_result', 'done', conn=conn, cur=cur)
cur.close()
def formdef_remap_statuses(formdef, mapping):
table_name = get_formdef_table_name(formdef)
evolutions_table_name = table_name + '_evolutions'
unmapped_status_suffix = str(formdef.workflow_id or 'default')
# build the case expression
status_cases = []
for old_id, new_id in mapping.items():
status_cases.append(
SQL('WHEN status = {old_status} THEN {new_status}').format(
old_status=Literal(old_id), new_status=Literal(new_id)
)
)
case_expression = SQL(
'(CASE WHEN status IS NULL THEN NULL '
'{status_cases} '
        # keep statuses already marked as invalid
'WHEN status LIKE {pattern} THEN status '
# mark unknown statuses as invalid
'ELSE (status || {suffix}) END)'
).format(
status_cases=SQL('').join(status_cases),
pattern=Literal('%-invalid-%'),
suffix=Literal('-invalid-' + unmapped_status_suffix),
)
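    # the generated expression looks roughly like:
    #   (CASE WHEN status IS NULL THEN NULL
    #         WHEN status = 'wf-old' THEN 'wf-new' ...
    #         WHEN status LIKE '%-invalid-%' THEN status
    #         ELSE (status || '-invalid-<workflow id>') END)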
_, cur = get_connection_and_cursor()
# update formdatas statuses
cur.execute(
SQL('UPDATE {table_name} SET status = {case_expression} WHERE status <> {draft_status}').format(
table_name=Identifier(table_name), case_expression=case_expression, draft_status=Literal('draft')
)
)
# update evolutions statuses
cur.execute(
SQL('UPDATE {table_name} SET status = {case_expression}').format(
table_name=Identifier(evolutions_table_name), case_expression=case_expression
)
)
cur.close()