2854 lines
107 KiB
Python
2854 lines
107 KiB
Python
# w.c.s. - web application for online forms
|
|
# Copyright (C) 2005-2012 Entr'ouvert
|
|
#
|
|
# This program is free software; you can redistribute it and/or modify
|
|
# it under the terms of the GNU General Public License as published by
|
|
# the Free Software Foundation; either version 2 of the License, or
|
|
# (at your option) any later version.
|
|
#
|
|
# This program is distributed in the hope that it will be useful,
|
|
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
|
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
|
# GNU General Public License for more details.
|
|
#
|
|
# You should have received a copy of the GNU General Public License
|
|
# along with this program; if not, see <http://www.gnu.org/licenses/>.
|
|
|
|
import copy
import datetime
import functools
import re
import time
import unicodedata

import psycopg2
import psycopg2.extensions
import psycopg2.extras

try:
    import cPickle as pickle
except ImportError:
    import pickle

from django.utils import six
from django.utils.encoding import force_bytes, force_text
from django.utils.six import BytesIO

from quixote import get_publisher

from . import qommon
from wcs.qommon import force_str, PICKLE_KWARGS
from .qommon.storage import _take, deep_bytes2str, parse_clause as parse_storage_clause
from .qommon.substitution import invalidate_substitution_cache
from .qommon import get_cfg
from .qommon.upload_storage import PicklableUpload
from .qommon.misc import strftime
from .publisher import UnpicklerClass

import wcs.categories
import wcs.carddata
import wcs.custom_views
import wcs.formdata
import wcs.snapshots
import wcs.tracking_code
import wcs.users
|
|
# enable psycopg2 unicode mode, this will fetch postgresql varchar/text columns
# as unicode objects
psycopg2.extensions.register_type(psycopg2.extensions.UNICODE)
psycopg2.extensions.register_type(psycopg2.extensions.UNICODEARRAY)

# automatically adapt dictionaries into json fields
psycopg2.extensions.register_adapter(dict, psycopg2.extras.Json)
|
|
|
|
|
|
# Mapping of form field types to SQL column types; a value of None means the
# field holds no data and gets no column.  Field types absent from this
# mapping default to 'varchar' (see callers using .get(key, 'varchar')).
SQL_TYPE_MAPPING = {
    # display-only fields, no stored data
    'title': None,
    'subtitle': None,
    'comment': None,
    'page': None,
    'text': 'text',
    'bool': 'boolean',
    'file': 'bytea',
    'date': 'date',
    'items': 'text[]',
    'table': 'text[][]',
    'table-select': 'text[][]',
    'tablerows': 'text[][]',
    # mapping of dicts
    'ranked-items': 'text[][]',
    'password': 'text[][]',
    # field block
    'block': 'jsonb',
}
|
|
|
|
|
|
def pickle_loads(value):
    """Unpickle *value* (bytes or a DB buffer object) into a Python object,
    converting any embedded bytes back to native strings."""
    if hasattr(value, 'tobytes'):
        # psycopg2 returns bytea columns as memoryview-like buffers
        value = value.tobytes()
    stream = BytesIO(force_bytes(value))
    loaded = UnpicklerClass(stream, **PICKLE_KWARGS).load()
    return deep_bytes2str(loaded)
|
|
|
|
|
|
class Criteria(qommon.storage.Criteria):
    """SQL counterpart of the storage Criteria classes.

    Renders a WHERE-clause fragment (as_sql) plus the matching query
    parameters (as_sql_param); parameters are keyed on id(self.value) so
    several criterias can be combined without name clashes.
    """

    def __init__(self, attribute, value, **kwargs):
        # dashes are not valid in SQL identifiers
        self.attribute = attribute.replace('-', '_')
        self.value = value

    def as_sql(self):
        return '%s %s %%(c%s)s' % (self.attribute, self.sql_op, id(self.value))

    def as_sql_param(self):
        value = self.value
        if isinstance(value, time.struct_time):
            # psycopg2 cannot adapt struct_time; convert to datetime
            value = datetime.datetime.fromtimestamp(time.mktime(value))
        return {'c%s' % id(self.value): value}
|
|
|
|
|
|
class Less(Criteria):
    # SQL strict "less than" comparison.
    sql_op = '<'
|
|
|
|
|
|
class Greater(Criteria):
    # SQL strict "greater than" comparison.
    sql_op = '>'
|
|
|
|
|
|
class Equal(Criteria):
    """Equality test; empty sequences are matched against NULL/empty arrays."""

    sql_op = '='

    def as_sql(self):
        if self.value not in ([], ()):
            return super(Equal, self).as_sql()
        # '=' cannot compare with an empty sequence; an empty array column is
        # reported by ARRAY_LENGTH as NULL
        return 'ARRAY_LENGTH(%s, 1) IS NULL' % self.attribute
|
|
|
|
|
|
class LessOrEqual(Criteria):
    # SQL "less than or equal" comparison.
    sql_op = '<='
|
|
|
|
|
|
class GreaterOrEqual(Criteria):
    # SQL "greater than or equal" comparison.
    sql_op = '>='
|
|
|
|
|
|
class NotEqual(Criteria):
    # SQL inequality comparison.
    sql_op = '!='
|
|
|
|
|
|
class Contains(Criteria):
    """Matches when the attribute equals one of the given values (SQL IN)."""

    sql_op = 'IN'

    def as_sql(self):
        if not self.value:
            # "IN ()" is invalid SQL; an empty candidate list matches nothing
            return 'FALSE'
        placeholder = '%%(c%s)s' % id(self.value)
        return '%s %s %s' % (self.attribute, self.sql_op, placeholder)

    def as_sql_param(self):
        # psycopg2 adapts tuples into SQL value lists
        return {'c%s' % id(self.value): tuple(self.value)}
|
|
|
|
|
|
class NotContains(Contains):
    """Matches when the attribute equals none of the given values (NOT IN)."""

    sql_op = 'NOT IN'

    def as_sql(self):
        if self.value:
            return super(NotContains, self).as_sql()
        # nothing is excluded by an empty list: always true
        return 'TRUE'
|
|
|
|
|
|
class NotNull(Criteria):
    """Matches rows where the attribute is not NULL."""

    sql_op = 'IS NOT NULL'

    def __init__(self, attribute):
        # no value involved; attribute kept as given (no dash mangling here)
        self.attribute = attribute

    def as_sql(self):
        return ' '.join((self.attribute, self.sql_op))

    def as_sql_param(self):
        # a NULL test takes no query parameter
        return {}
|
|
|
|
|
|
class Null(Criteria):
    """Matches rows where the attribute is NULL."""

    sql_op = 'IS NULL'

    def __init__(self, attribute, **kwargs):
        # no value involved; attribute kept as given (no dash mangling here)
        self.attribute = attribute

    def as_sql(self):
        return ' '.join((self.attribute, self.sql_op))

    def as_sql_param(self):
        # a NULL test takes no query parameter
        return {}
|
|
|
|
|
|
class Or(Criteria):
    """Disjunction of criterias; an empty disjunction renders as FALSE."""

    def __init__(self, criterias, **kwargs):
        self.criterias = []
        for element in criterias:
            # map each storage criteria onto its SQL counterpart defined in
            # this module, keeping the original attributes
            sql_class = globals().get(element.__class__.__name__)
            self.criterias.append(sql_class(**element.__dict__))

    def as_sql(self):
        if not self.criterias:
            return '( FALSE )'
        return '( %s )' % ' OR '.join(x.as_sql() for x in self.criterias)

    def as_sql_param(self):
        params = {}
        for criteria in self.criterias:
            params.update(criteria.as_sql_param())
        return params
|
|
|
|
|
|
class And(Criteria):
    """Conjunction of criterias; an empty conjunction renders as TRUE."""

    def __init__(self, criterias, **kwargs):
        self.criterias = []
        for element in criterias:
            # map each storage criteria onto its SQL counterpart defined in
            # this module, keeping the original attributes
            sql_class = globals().get(element.__class__.__name__)
            sql_element = sql_class(**element.__dict__)
            self.criterias.append(sql_element)

    def as_sql(self):
        if not self.criterias:
            # an empty conjunction is vacuously true; without this guard the
            # join below would render '(  )', which is invalid SQL (Or has
            # the symmetric guard returning '( FALSE )')
            return '( TRUE )'
        return '( %s )' % ' AND '.join([x.as_sql() for x in self.criterias])

    def as_sql_param(self):
        d = {}
        for criteria in self.criterias:
            d.update(criteria.as_sql_param())
        return d
|
|
|
|
|
|
class Intersects(Criteria):
    """Matches when the array attribute shares an element with the value."""

    def as_sql(self):
        if self.value:
            # '&&' is the PostgreSQL array-overlap operator
            return '%s && %%(c%s)s' % (self.attribute, id(self.value))
        # empty value: only match NULL/empty arrays
        return 'ARRAY_LENGTH(%s, 1) IS NULL' % self.attribute

    def as_sql_param(self):
        return {'c%s' % id(self.value): list(self.value)}
|
|
|
|
|
|
class ILike(Criteria):
    """Case-insensitive substring match (SQL ILIKE)."""

    def __init__(self, attribute, value, **kwargs):
        super(ILike, self).__init__(attribute, value, **kwargs)
        # wrap in wildcards so any substring matches
        self.value = '%%%s%%' % self.value

    def as_sql(self):
        return '%s ILIKE %%(c%s)s' % (self.attribute, id(self.value))
|
|
|
|
|
|
class FtsMatch(Criteria):
    """Full text search criteria, matching against the 'fts' tsvector column."""

    def __init__(self, value):
        self.value = self.get_fts_value(value)

    @classmethod
    def get_fts_value(cls, value):
        # idiom fix: classmethod first parameter is conventionally 'cls', not
        # 'self'.  Strip accents so the search is accent-insensitive.
        return unicodedata.normalize('NFKD', value).encode('ascii', 'ignore').decode('ascii')

    def as_sql(self):
        return 'fts @@ plainto_tsquery(%%(c%s)s)' % id(self.value)
|
|
|
|
|
|
def get_name_as_sql_identifier(name):
    """Turn *name* into a string safe to use as a SQL identifier."""
    name = qommon.misc.simplify(name)
    for forbidden in '<>|{}!?^*+/=\'':
        name = name.replace(forbidden, '')
    return name.replace('-', '_')
|
|
|
|
|
|
def get_field_id(field):
    """Return the SQL column name for *field*, e.g. 'f12_3' for id '12-3'."""
    return 'f%s' % str(field.id).replace('-', '_').lower()
|
|
|
|
|
|
def parse_clause(clause):
    """Translate a storage clause into its SQL form.

    Returns a three-elements tuple with:
     - a list of SQL 'WHERE' clauses
     - a dict for query parameters
     - a callable, or None if all clauses have been successfully translated
    """

    if clause is None:
        return ([], {}, None)

    if callable(clause):  # already a callable
        return ([], {}, clause)

    # create 'WHERE' clauses
    func_clauses = []
    where_clauses = []
    parameters = {}
    for element in clause:
        if callable(element):
            func_clauses.append(element)
        else:
            # look up the SQL counterpart (same class name, defined in this
            # module) of the storage criteria
            sql_class = globals().get(element.__class__.__name__)
            if sql_class:
                sql_element = sql_class(**element.__dict__)
                where_clauses.append(sql_element.as_sql())
                parameters.update(sql_element.as_sql_param())
            else:
                # no SQL translation available; fall back to evaluating the
                # criteria in Python
                func_clauses.append(element.build_lambda())

    if func_clauses:
        return (where_clauses, parameters, parse_storage_clause(func_clauses))
    else:
        return (where_clauses, parameters, None)
|
|
|
|
|
|
def str_encode(value):
    """Recursively traverse lists; non-list values pass through unchanged."""
    if not isinstance(value, list):
        return value
    return [str_encode(item) for item in value]
|
|
|
|
|
|
def site_unicode(value):
    # Decode *value* to text using the publisher's configured site charset.
    return force_text(value, get_publisher().site_charset)
|
|
|
|
|
|
def get_connection(new=False):
    """Return the publisher's PostgreSQL connection, creating it on demand.

    With new=True any cached connection is discarded first, and connection
    errors are raised instead of being swallowed.  The connection (or None on
    failure when new=False) is cached on the publisher as `pgconn`; on
    success the server version is cached as `pg_version`.
    """
    if new:
        cleanup_connection()
    if not hasattr(get_publisher(), 'pgconn') or get_publisher().pgconn is None:
        # build the connection parameters from site configuration, skipping
        # unset values so psycopg2/libpq defaults apply
        postgresql_cfg = {}
        for param in ('database', 'user', 'password', 'host', 'port'):
            value = get_cfg('postgresql', {}).get(param)
            if value:
                postgresql_cfg[param] = value
        try:
            get_publisher().pgconn = psycopg2.connect(**postgresql_cfg)
        except psycopg2.Error:
            if new:
                # caller explicitly asked for a fresh connection; surface the
                # failure
                raise
            get_publisher().pgconn = None
        else:
            cur = get_publisher().pgconn.cursor()
            cur.execute('SHOW server_version_num')
            get_publisher().pg_version = int(cur.fetchone()[0])
            cur.close()
    return get_publisher().pgconn
|
|
|
|
|
|
def cleanup_connection():
    """Close and forget the publisher's cached PostgreSQL connection."""
    publisher = get_publisher()
    if getattr(publisher, 'pgconn', None) is not None:
        publisher.pgconn.close()
        publisher.pgconn = None
|
|
|
|
|
|
def get_connection_and_cursor(new=False):
    """Return a (connection, cursor) pair, reconnecting once if needed."""
    connection = get_connection(new=new)
    try:
        cursor = connection.cursor()
    except psycopg2.InterfaceError:
        # maybe postgresql was restarted in between; retry with a fresh
        # connection
        connection = get_connection(new=True)
        cursor = connection.cursor()
    return (connection, cursor)
|
|
|
|
|
|
def get_formdef_table_name(formdef):
    """Return (and memoize on *formdef*) the SQL table name for its data.

    PostgreSQL limits identifiers to 63 bytes (NAMEDATALEN - 1); longer
    names are silently truncated.  As table names must be predictable, the
    url_name part is cropped to 30 characters here, leaving room for
    suffixes such as _evolutions.
    """
    assert formdef.id is not None
    if getattr(formdef, 'table_name', None):
        return formdef.table_name
    cropped = get_name_as_sql_identifier(formdef.url_name)[:30]
    formdef.table_name = '%s_%s_%s' % (formdef.data_sql_prefix, formdef.id, cropped)
    # persist the computed name so it stays stable even if url_name changes
    formdef.store()
    return formdef.table_name
|
|
|
|
|
|
def get_formdef_new_id(id_start):
    """Return the first id >= *id_start* with no leftover formdata table."""
    conn, cur = get_connection_and_cursor()
    candidate = id_start
    query = '''SELECT COUNT(*) FROM information_schema.tables
                WHERE table_schema = 'public'
                  AND table_name LIKE %s'''
    while True:
        cur.execute(query, ('formdata\\_%s\\_%%' % candidate,))
        if cur.fetchone()[0] == 0:
            break
        candidate += 1
    conn.commit()
    cur.close()
    return candidate
|
|
|
|
|
|
def get_carddef_new_id(id_start):
    """Return the first id >= *id_start* with no leftover carddata table."""
    conn, cur = get_connection_and_cursor()
    candidate = id_start
    query = '''SELECT COUNT(*) FROM information_schema.tables
                WHERE table_schema = 'public'
                  AND table_name LIKE %s'''
    while True:
        cur.execute(query, ('carddata\\_%s\\_%%' % candidate,))
        if cur.fetchone()[0] == 0:
            break
        candidate += 1
    conn.commit()
    cur.close()
    return candidate
|
|
|
|
|
|
def formdef_wipe():
    """Drop every formdata table (and dependents, via CASCADE)."""
    conn, cur = get_connection_and_cursor()
    cur.execute('''SELECT table_name FROM information_schema.tables
                    WHERE table_schema = 'public'
                      AND table_name LIKE %s''', ('formdata\\_%%\\_%%',))
    table_names = [row[0] for row in cur.fetchall()]
    for table_name in table_names:
        cur.execute('''DROP TABLE %s CASCADE''' % table_name)
    conn.commit()
    cur.close()
|
|
|
|
|
|
def carddef_wipe():
    """Drop every carddata table (and dependents, via CASCADE)."""
    conn, cur = get_connection_and_cursor()
    cur.execute('''SELECT table_name FROM information_schema.tables
                    WHERE table_schema = 'public'
                      AND table_name LIKE %s''', ('carddata\\_%%\\_%%',))
    table_names = [row[0] for row in cur.fetchall()]
    for table_name in table_names:
        cur.execute('''DROP TABLE %s CASCADE''' % table_name)
    conn.commit()
    cur.close()
|
|
|
|
|
|
def get_formdef_view_name(formdef):
    """Return the SQL view name for *formdef* data."""
    if formdef.data_sql_prefix != 'formdata':
        prefix = 'wcs_%s_view' % formdef.data_sql_prefix
    else:
        prefix = 'wcs_view'
    # crop the url_name part to keep within PostgreSQL's identifier limit
    suffix = get_name_as_sql_identifier(formdef.url_name)[:40]
    return '%s_%s_%s' % (prefix, formdef.id, suffix)
|
|
|
|
|
|
def guard_postgres(func):
    """Decorator rolling back the current transaction on any psycopg2 error.

    The exception is re-raised after the rollback; without the rollback the
    connection would stay in an aborted state and reject further queries.
    """
    @functools.wraps(func)  # preserve the wrapped function's metadata
    def f(*args, **kwargs):
        try:
            return func(*args, **kwargs)
        except psycopg2.Error:
            get_connection().rollback()
            raise
    return f
|
|
|
|
|
|
@guard_postgres
def do_formdef_tables(formdef, conn=None, cur=None, rebuild_views=False, rebuild_global_views=True):
    """Create or migrate the data table (and its _evolutions table) for *formdef*.

    Returns a list of follow-up action names the caller should perform
    ('rebuild_security', 'do_tracking_code_table').  When conn/cur are not
    given, a connection is opened, committed and closed locally.
    """
    if formdef.id is None:
        return []

    if getattr(formdef, 'fields', None) is Ellipsis:
        # don't touch tables for lightweight objects
        return []

    own_conn = False
    if not conn:
        own_conn = True
        conn, cur = get_connection_and_cursor()

    table_name = get_formdef_table_name(formdef)

    cur.execute('''SELECT COUNT(*) FROM information_schema.tables
                    WHERE table_schema = 'public'
                      AND table_name = %s''', (table_name,))
    if cur.fetchone()[0] == 0:
        # brand new formdef: create the data table and its companion
        # _evolutions table, then the indexes
        cur.execute('''CREATE TABLE %s (id serial PRIMARY KEY,
                                        user_id varchar,
                                        receipt_time timestamp,
                                        anonymised timestamptz,
                                        status varchar,
                                        page_no varchar,
                                        workflow_data bytea,
                                        id_display varchar
                                        )''' % table_name)
        cur.execute('''CREATE TABLE %s_evolutions (id serial PRIMARY KEY,
                                        who varchar,
                                        status varchar,
                                        time timestamp,
                                        last_jump_datetime timestamp,
                                        comment text,
                                        parts bytea,
                                        formdata_id integer REFERENCES %s (id) ON DELETE CASCADE)''' % (
            table_name, table_name))
        do_formdef_indexes(formdef, created=True, conn=conn, cur=cur)

    cur.execute('''SELECT column_name FROM information_schema.columns
                    WHERE table_schema = 'public'
                      AND table_name = %s''', (table_name,))
    existing_fields = set([x[0] for x in cur.fetchall()])

    # columns that must exist regardless of the form's fields
    needed_fields = set(['id', 'user_id', 'receipt_time',
                         'status', 'workflow_data', 'id_display', 'fts', 'page_no',
                         'anonymised', 'workflow_roles', 'workflow_roles_array',
                         'concerned_roles_array', 'tracking_code',
                         'actions_roles_array', 'backoffice_submission',
                         'submission_context', 'submission_agent_id', 'submission_channel',
                         'criticality_level', 'last_update_time',
                         'digest', 'user_label'])

    # migrations: add columns that appeared in later versions of the schema
    if not 'fts' in existing_fields:
        # full text search
        cur.execute('''ALTER TABLE %s ADD COLUMN fts tsvector''' % table_name)
        cur.execute('''CREATE INDEX %s_fts ON %s USING gin(fts)''' % (
            table_name, table_name))

    if not 'workflow_roles' in existing_fields:
        cur.execute('''ALTER TABLE %s ADD COLUMN workflow_roles bytea''' % table_name)
        cur.execute('''ALTER TABLE %s ADD COLUMN workflow_roles_array text[]''' % table_name)
    if not 'concerned_roles_array' in existing_fields:
        cur.execute('''ALTER TABLE %s ADD COLUMN concerned_roles_array text[]''' % table_name)
    if not 'actions_roles_array' in existing_fields:
        cur.execute('''ALTER TABLE %s ADD COLUMN actions_roles_array text[]''' % table_name)

    if not 'page_no' in existing_fields:
        cur.execute('''ALTER TABLE %s ADD COLUMN page_no varchar''' % table_name)

    if not 'anonymised' in existing_fields:
        cur.execute('''ALTER TABLE %s ADD COLUMN anonymised timestamptz''' % table_name)

    if not 'tracking_code' in existing_fields:
        cur.execute('''ALTER TABLE %s ADD COLUMN tracking_code varchar''' % table_name)

    if not 'backoffice_submission' in existing_fields:
        cur.execute('''ALTER TABLE %s ADD COLUMN backoffice_submission boolean''' % table_name)

    if not 'submission_context' in existing_fields:
        cur.execute('''ALTER TABLE %s ADD COLUMN submission_context bytea''' % table_name)

    if 'submission_agent_id' not in existing_fields:
        cur.execute('''ALTER TABLE %s ADD COLUMN submission_agent_id varchar''' % table_name)

    if not 'submission_channel' in existing_fields:
        cur.execute('''ALTER TABLE %s ADD COLUMN submission_channel varchar''' % table_name)

    if not 'criticality_level' in existing_fields:
        cur.execute('''ALTER TABLE %s ADD COLUMN criticality_level integer NOT NULL DEFAULT(0)''' % table_name)

    if not 'last_update_time' in existing_fields:
        cur.execute('''ALTER TABLE %s ADD COLUMN last_update_time timestamp''' % table_name)

    if not 'digest' in existing_fields:
        cur.execute('''ALTER TABLE %s ADD COLUMN digest varchar''' % table_name)

    if not 'user_label' in existing_fields:
        cur.execute('''ALTER TABLE %s ADD COLUMN user_label varchar''' % table_name)

    # add new fields (one column per form field, plus optional _display and
    # _structured companion columns)
    for field in formdef.get_all_fields():
        assert field.id is not None
        sql_type = SQL_TYPE_MAPPING.get(field.key, 'varchar')
        if sql_type is None:
            # display-only field, no storage
            continue
        needed_fields.add(get_field_id(field))
        if get_field_id(field) not in existing_fields:
            cur.execute('''ALTER TABLE %s ADD COLUMN %s %s''' % (
                table_name,
                get_field_id(field),
                sql_type))
        if field.store_display_value:
            needed_fields.add('%s_display' % get_field_id(field))
            if '%s_display' % get_field_id(field) not in existing_fields:
                cur.execute('''ALTER TABLE %s ADD COLUMN %s varchar''' % (
                    table_name,
                    '%s_display' % get_field_id(field)))
        if field.store_structured_value:
            needed_fields.add('%s_structured' % get_field_id(field))
            if '%s_structured' % get_field_id(field) not in existing_fields:
                cur.execute('''ALTER TABLE %s ADD COLUMN %s bytea''' % (
                    table_name,
                    '%s_structured' % get_field_id(field)))

    # one POINT column per geolocation key
    for field in (formdef.geolocations or {}).keys():
        column_name = 'geoloc_%s' % field
        needed_fields.add(column_name)
        if column_name not in existing_fields:
            cur.execute('ALTER TABLE %s ADD COLUMN %s %s''' % (
                table_name, column_name, 'POINT'))

    # delete obsolete fields
    for field in (existing_fields - needed_fields):
        cur.execute('''ALTER TABLE %s DROP COLUMN %s CASCADE''' % (table_name, field))

    # migrations on _evolutions table
    cur.execute('''SELECT column_name FROM information_schema.columns
                    WHERE table_schema = 'public'
                      AND table_name = '%s_evolutions'
                ''' % table_name)
    evo_existing_fields = set([x[0] for x in cur.fetchall()])
    if 'last_jump_datetime' not in evo_existing_fields:
        cur.execute('''ALTER TABLE %s_evolutions ADD COLUMN last_jump_datetime timestamp''' % table_name)

    if rebuild_views or len(existing_fields - needed_fields):
        # views may have been dropped when dropping columns, so we recreate
        # them even if not asked to.
        redo_views(conn, cur, formdef, rebuild_global_views=rebuild_global_views)

    if own_conn:
        conn.commit()
        cur.close()

    actions = []
    if not 'concerned_roles_array' in existing_fields:
        actions.append('rebuild_security')
    elif not 'actions_roles_array' in existing_fields:
        actions.append('rebuild_security')
    if not 'tracking_code' in existing_fields:
        # if tracking code has just been added to the table we need to make
        # sure the tracking code table does exist.
        actions.append('do_tracking_code_table')

    return actions
|
|
|
|
|
|
def do_formdef_indexes(formdef, created, conn, cur, concurrently=False):
    """Create the indexes for a formdef data table (and its _evolutions table).

    When created=True (tables just created) the existing-index lookup is
    skipped.  With concurrently=True, indexes are built with CREATE INDEX
    CONCURRENTLY to avoid locking out writes.
    """
    table_name = get_formdef_table_name(formdef)
    evolutions_table_name = table_name + '_evolutions'
    existing_indexes = set()
    if not created:
        cur.execute('''SELECT indexname
                         FROM pg_indexes
                        WHERE schemaname = 'public'
                          AND tablename IN (%s, %s)''', (table_name, evolutions_table_name))
        existing_indexes = set([x[0] for x in cur.fetchall()])

    create_index = 'CREATE INDEX'
    if concurrently:
        create_index = 'CREATE INDEX CONCURRENTLY'

    # index to speed up the evolutions -> formdata join
    if not evolutions_table_name + '_fid' in existing_indexes:
        cur.execute('''%s %s_fid ON %s (formdata_id)''' % (
            create_index, evolutions_table_name, evolutions_table_name))

    # single-column indexes on frequently filtered/sorted attributes
    for attr in ('receipt_time', 'anonymised', 'user_id'):
        if not table_name + '_' + attr + '_idx' in existing_indexes:
            cur.execute('%(create_index)s %(table_name)s_%(attr)s_idx ON %(table_name)s (%(attr)s)' % {
                'create_index': create_index,
                'table_name': table_name,
                'attr': attr})
|
|
|
|
|
|
@guard_postgres
def do_user_table():
    """Create or migrate the users table, mirroring the user formdef fields.

    Adds one column per user field (plus optional _display/_structured
    companions), applies schema migrations, and drops obsolete columns.
    """
    conn, cur = get_connection_and_cursor()
    table_name = 'users'

    cur.execute('''SELECT COUNT(*) FROM information_schema.tables
                    WHERE table_schema = 'public'
                      AND table_name = %s''', (table_name,))
    if cur.fetchone()[0] == 0:
        cur.execute('''CREATE TABLE %s (id serial PRIMARY KEY,
                                        name varchar,
                                        ascii_name varchar,
                                        email varchar,
                                        roles text[],
                                        is_active bool,
                                        is_admin bool,
                                        anonymous bool,
                                        verified_fields text[],
                                        name_identifiers text[],
                                        lasso_dump text,
                                        last_seen timestamp,
                                        deleted_timestamp timestamp
                                        )''' % table_name)
    cur.execute('''SELECT column_name FROM information_schema.columns
                    WHERE table_schema = 'public'
                      AND table_name = %s''', (table_name,))
    existing_fields = set([x[0] for x in cur.fetchall()])

    needed_fields = set(['id', 'name', 'email', 'roles', 'is_admin',
                         'anonymous', 'name_identifiers', 'verified_fields',
                         'lasso_dump', 'last_seen', 'fts', 'ascii_name',
                         'deleted_timestamp', 'is_active'])

    # local import to avoid a circular dependency at module load time
    from wcs.admin.settings import UserFieldsFormDef
    formdef = UserFieldsFormDef()

    # one column per custom user field
    for field in formdef.get_all_fields():
        sql_type = SQL_TYPE_MAPPING.get(field.key, 'varchar')
        if sql_type is None:
            # display-only field, no storage
            continue
        needed_fields.add(get_field_id(field))
        if get_field_id(field) not in existing_fields:
            cur.execute('''ALTER TABLE %s ADD COLUMN %s %s''' % (
                table_name,
                get_field_id(field),
                sql_type))
        if field.store_display_value:
            needed_fields.add('%s_display' % get_field_id(field))
            if '%s_display' % get_field_id(field) not in existing_fields:
                cur.execute('''ALTER TABLE %s ADD COLUMN %s varchar''' % (
                    table_name, '%s_display' % get_field_id(field)))
        if field.store_structured_value:
            needed_fields.add('%s_structured' % get_field_id(field))
            if '%s_structured' % get_field_id(field) not in existing_fields:
                cur.execute('''ALTER TABLE %s ADD COLUMN %s bytea''' % (
                    table_name, '%s_structured' % get_field_id(field)))

    # migrations: columns added in later versions of the schema
    if not 'fts' in existing_fields:
        # full text search
        cur.execute('''ALTER TABLE %s ADD COLUMN fts tsvector''' % table_name)
        cur.execute('''CREATE INDEX %s_fts ON %s USING gin(fts)''' % (
            table_name, table_name))

    if not 'verified_fields' in existing_fields:
        cur.execute('ALTER TABLE %s ADD COLUMN verified_fields text[]' % table_name)

    if not 'ascii_name' in existing_fields:
        cur.execute('ALTER TABLE %s ADD COLUMN ascii_name varchar' % table_name)

    if 'deleted_timestamp' not in existing_fields:
        cur.execute('ALTER TABLE %s ADD COLUMN deleted_timestamp timestamp' % table_name)

    if 'is_active' not in existing_fields:
        cur.execute('ALTER TABLE %s ADD COLUMN is_active bool DEFAULT TRUE' % table_name)
        # previously-deleted users were marked via deleted_timestamp only
        cur.execute('UPDATE %s SET is_active = FALSE WHERE deleted_timestamp IS NOT NULL' % table_name)

    # delete obsolete fields
    for field in (existing_fields - needed_fields):
        cur.execute('''ALTER TABLE %s DROP COLUMN %s''' % (table_name, field))

    conn.commit()

    try:
        cur.execute('CREATE INDEX users_name_idx ON users (name)')
        conn.commit()
    except psycopg2.ProgrammingError:
        # index already exists; roll back the failed statement
        conn.rollback()

    cur.close()
|
|
|
|
|
|
def do_tracking_code_table():
    """Create or migrate the tracking_codes table."""
    conn, cur = get_connection_and_cursor()
    table_name = 'tracking_codes'

    cur.execute('''SELECT COUNT(*) FROM information_schema.tables
                    WHERE table_schema = 'public'
                      AND table_name = %s''', (table_name,))
    if cur.fetchone()[0] == 0:
        cur.execute('''CREATE TABLE %s (id varchar PRIMARY KEY,
                                        formdef_id varchar,
                                        formdata_id varchar)''' % table_name)
    cur.execute('''SELECT column_name FROM information_schema.columns
                    WHERE table_schema = 'public'
                      AND table_name = %s''', (table_name,))
    existing_fields = {row[0] for row in cur.fetchall()}

    needed_fields = {'id', 'formdef_id', 'formdata_id'}

    # drop obsolete columns
    for field in existing_fields - needed_fields:
        cur.execute('''ALTER TABLE %s DROP COLUMN %s''' % (table_name, field))

    conn.commit()
    cur.close()
|
|
|
|
|
|
def do_session_table():
    """Create or migrate the sessions table."""
    conn, cur = get_connection_and_cursor()
    table_name = 'sessions'

    cur.execute('''SELECT COUNT(*) FROM information_schema.tables
                    WHERE table_schema = 'public'
                      AND table_name = %s''', (table_name,))
    if cur.fetchone()[0] == 0:
        cur.execute('''CREATE TABLE %s (id varchar PRIMARY KEY,
                                        session_data bytea,
                                        name_identifier varchar,
                                        visiting_objects_keys varchar[]
                                        )''' % table_name)
    cur.execute('''SELECT column_name FROM information_schema.columns
                    WHERE table_schema = 'public'
                      AND table_name = %s''', (table_name,))
    existing_fields = {row[0] for row in cur.fetchall()}

    needed_fields = {'id', 'session_data', 'name_identifier',
                     'visiting_objects_keys', 'last_update_time'}

    # migration: last_update_time column and its index were added later
    if 'last_update_time' not in existing_fields:
        cur.execute('''ALTER TABLE %s ADD COLUMN last_update_time timestamp DEFAULT NOW()''' % table_name)
        cur.execute('''CREATE INDEX %s_ts ON %s (last_update_time)''' % (
            table_name, table_name))

    # drop obsolete columns
    for field in existing_fields - needed_fields:
        cur.execute('''ALTER TABLE %s DROP COLUMN %s''' % (table_name, field))

    conn.commit()
    cur.close()
|
|
|
|
|
|
def do_custom_views_table():
    """Create or migrate the custom_views table."""
    conn, cur = get_connection_and_cursor()
    table_name = 'custom_views'

    cur.execute('''SELECT COUNT(*) FROM information_schema.tables
                    WHERE table_schema = 'public'
                      AND table_name = %s''', (table_name,))
    if cur.fetchone()[0] == 0:
        cur.execute('''CREATE TABLE %s (id varchar PRIMARY KEY,
                                        title varchar,
                                        slug varchar,
                                        user_id varchar,
                                        visibility varchar,
                                        formdef_type varchar,
                                        formdef_id varchar,
                                        order_by varchar,
                                        columns jsonb,
                                        filters jsonb
                                        )''' % table_name)
    cur.execute('''SELECT column_name FROM information_schema.columns
                    WHERE table_schema = 'public'
                      AND table_name = %s''', (table_name,))
    existing_fields = {row[0] for row in cur.fetchall()}

    # NOTE(review): CustomView is presumably the SQL-backed class defined
    # further down this module — confirm it resolves at call time.
    needed_fields = {entry[0] for entry in CustomView._table_static_fields}

    # drop obsolete columns
    for field in existing_fields - needed_fields:
        cur.execute('''ALTER TABLE %s DROP COLUMN %s''' % (table_name, field))

    conn.commit()
    cur.close()
|
|
|
|
|
|
def do_snapshots_table():
    """Create or migrate the snapshots table."""
    conn, cur = get_connection_and_cursor()
    table_name = 'snapshots'

    cur.execute('''SELECT COUNT(*) FROM information_schema.tables
                    WHERE table_schema = 'public'
                      AND table_name = %s''', (table_name,))
    if cur.fetchone()[0] == 0:
        cur.execute('''CREATE TABLE %s (id SERIAL,
                                        object_type VARCHAR,
                                        object_id VARCHAR,
                                        timestamp TIMESTAMP WITH TIME ZONE,
                                        user_id VARCHAR,
                                        comment TEXT,
                                        serialization TEXT,
                                        label VARCHAR
                                        )''' % table_name)
    cur.execute('''SELECT column_name FROM information_schema.columns
                    WHERE table_schema = 'public'
                      AND table_name = %s''', (table_name,))
    existing_fields = {row[0] for row in cur.fetchall()}

    # NOTE(review): Snapshot is presumably the SQL-backed class defined
    # further down this module — confirm it resolves at call time.
    needed_fields = {entry[0] for entry in Snapshot._table_static_fields}

    # drop obsolete columns
    for field in existing_fields - needed_fields:
        cur.execute('''ALTER TABLE %s DROP COLUMN %s''' % (table_name, field))

    conn.commit()
    cur.close()
|
|
|
|
|
|
@guard_postgres
def do_meta_table(conn=None, cur=None, insert_current_sql_level=True):
    """Create the wcs_meta key/value table if missing.

    On creation a 'sql_level' row is inserted: the current SQL_LEVEL
    (defined elsewhere in this module) when insert_current_sql_level=True,
    0 otherwise (forcing migrations to run).  When conn/cur are not given,
    a connection is opened, committed and closed locally.
    """
    own_conn = False
    if not conn:
        own_conn = True
        conn, cur = get_connection_and_cursor()

    cur.execute('''SELECT COUNT(*) FROM information_schema.tables
                    WHERE table_schema = 'public'
                      AND table_name = %s''', ('wcs_meta',))
    if cur.fetchone()[0] == 0:
        cur.execute('''CREATE TABLE wcs_meta (id serial PRIMARY KEY,
                                              key varchar,
                                              value varchar)''')
        if insert_current_sql_level:
            sql_level = SQL_LEVEL
        else:
            sql_level = 0
        cur.execute('''INSERT INTO wcs_meta (id, key, value)
                       VALUES (DEFAULT, %s, %s)''', ('sql_level', str(sql_level)))

    if own_conn:
        conn.commit()
        cur.close()
|
|
|
|
|
|
@guard_postgres
def redo_views(conn, cur, formdef, rebuild_global_views=False):
    """Drop then recreate the SQL view(s) for *formdef*."""
    if get_publisher().get_site_option('postgresql_views') == 'false':
        # site is configured without database views
        return

    if formdef.id is not None:
        drop_views(formdef, conn, cur)
        do_views(formdef, conn, cur, rebuild_global_views=rebuild_global_views)
|
|
|
|
|
|
@guard_postgres
def drop_views(formdef, conn, cur):
    """Drop the view(s) for *formdef*, or every form view when it is None,
    along with the global views."""
    # remove the global views
    drop_global_views(conn, cur)

    if formdef:
        # remove the form view itself
        view_prefix = 'wcs\\_view\\_%s\\_%%' % formdef.id
        if formdef.data_sql_prefix != 'formdata':
            view_prefix = 'wcs\\_%s\\_view\\_%s\\_%%' % (formdef.data_sql_prefix, formdef.id)
        cur.execute('''SELECT table_name FROM information_schema.views
                        WHERE table_schema = 'public'
                          AND table_name LIKE %s''', (view_prefix,))
    else:
        # if there's no formdef specified, remove all form views
        cur.execute('''SELECT table_name FROM information_schema.views
                        WHERE table_schema = 'public'
                          AND table_name LIKE %s''', ('wcs\\_view\\_%',))

    view_names = [row[0] for row in cur.fetchall()]

    for view_name in view_names:
        cur.execute('''DROP VIEW IF EXISTS %s''' % view_name)
|
|
|
|
|
|
def get_view_fields(formdef):
    """Return the static (sql expression, column name) pairs for a formdef view."""
    view_fields = [
        ("int '%s'" % (formdef.category_id or 0), 'category_id'),
        ("int '%s'" % (formdef.id or 0), 'formdef_id'),
    ]
    static_columns = ('id', 'user_id', 'receipt_time', 'status',
                      'id_display', 'submission_agent_id', 'submission_channel',
                      'backoffice_submission', 'last_update_time', 'digest',
                      'user_label')
    view_fields.extend((column, column) for column in static_columns)
    return view_fields
|
|
|
|
|
|
@guard_postgres
def do_views(formdef, conn, cur, rebuild_global_views=True):
    """Create the SQL view exposing the given formdef's data.

    The view flattens the formdef table (one column per field, named after
    the field varname or label) and adds computed columns (status history,
    endpoint flag, criticality, geolocation, user name, ...).
    """
    # create new view
    table_name = get_formdef_table_name(formdef)
    view_name = get_formdef_view_name(formdef)
    view_fields = get_view_fields(formdef)

    # track view column names already used, to deduplicate them
    column_names = {}
    for field in formdef.get_all_fields():
        field_key = get_field_id(field)
        if field.type in ('page', 'title', 'subtitle', 'comment'):
            # presentation-only fields carry no data, skip them
            continue
        if field.varname:
            # the variable should be fine as is but we pass it through
            # get_name_as_sql_identifier nevertheless, to be extra sure it
            # doesn't contain invalid characters.
            field_name = 'f_%s' % get_name_as_sql_identifier(field.varname)[:50]
        else:
            field_name = '%s_%s' % (get_field_id(field), get_name_as_sql_identifier(field.label))
            field_name = field_name[:50]
        if field_name in column_names:
            # it may happen that the same varname is used on multiple fields
            # (for example in the case of conditional pages), in that situation
            # we suffix the field name with an index count
            while field_name in column_names:
                column_names[field_name] += 1
                field_name = '%s_%s' % (field_name, column_names[field_name])
        column_names[field_name] = 1
        view_fields.append((field_key, field_name))
        if field.store_display_value:
            # also expose the human-readable value alongside the raw one
            field_key = '%s_display' % get_field_id(field)
            view_fields.append((field_key, field_name + '_display'))

    # status_history: ordered array of all statuses this formdata went through
    view_fields.append(('''ARRAY(SELECT status FROM %s_evolutions '''\
                        ''' WHERE %s.id = %s_evolutions.formdata_id'''\
                        ''' ORDER BY %s_evolutions.time)''' % ((table_name,) * 4),
                        'status_history'))

    # add a is_at_endpoint column, dynamically created against the endpoint
    # statuses of the workflow.
    endpoint_status = formdef.workflow.get_endpoint_status()
    view_fields.append(('''(SELECT status = ANY(ARRAY[[%s]]::text[]))''' % \
                            ', '.join(["'wf-%s'" % x.id for x in endpoint_status]),
                        '''is_at_endpoint'''))

    # [CRITICALITY_1] Add criticality_level, computed relative to levels in
    # the given workflow, so all higher criticalites are sorted first. This is
    # reverted when loading the formdata back, in [CRITICALITY_2]
    levels = len(formdef.workflow.criticality_levels or [0])
    view_fields.append(('''(criticality_level - %d)''' % levels,
                        '''criticality_level'''))

    # the formdef name is inlined as a constant; mogrify is used so it is
    # properly quoted/escaped by psycopg2.
    view_fields.append((
        cur.mogrify('(SELECT text %s)', (formdef.name,)), 'formdef_name'))

    view_fields.append(('''(SELECT name FROM users
                            WHERE users.id = CAST(user_id AS INTEGER))''', 'user_name'))

    view_fields.append(('concerned_roles_array', 'concerned_roles_array'))
    view_fields.append(('actions_roles_array', 'actions_roles_array'))
    view_fields.append(('fts', 'fts'))

    if formdef.geolocations and 'base' in formdef.geolocations:
        # default geolocation is in the 'base' key; we have to unstructure the
        # field as the POINT type of postgresql cannot be used directly as it
        # doesn't have an equality operator.
        view_fields.append(('geoloc_base[0]', 'geoloc_base_x'))
        view_fields.append(('geoloc_base[1]', 'geoloc_base_y'))
    else:
        # no base geolocation on this formdef; still expose the columns so
        # all views share the same shape (needed for the UNION in the
        # global view).
        view_fields.append(('NULL::real', 'geoloc_base_x'))
        view_fields.append(('NULL::real', 'geoloc_base_y'))
    view_fields.append(('anonymised', 'anonymised'))

    # mogrify() may have produced bytes, hence force_text on expressions
    fields_list = ', '.join(['%s AS %s' % (force_text(x), force_text(y)) for (x, y) in view_fields])

    cur.execute('''CREATE VIEW %s AS SELECT %s FROM %s''' % (
        view_name, fields_list, table_name))

    if rebuild_global_views:
        do_global_views(conn, cur)  # recreate global views
def drop_global_views(conn, cur):
    """Drop the per-category views, the wcs_all_forms view, and the legacy
    materialized variant left behind by older versions."""
    cur.execute('''SELECT table_name FROM information_schema.views
                    WHERE table_schema = 'public'
                      AND table_name LIKE %s''', ('wcs\\_category\\_%',))
    view_names = [row[0] for row in cur.fetchall()]

    for view_name in view_names:
        # identifiers straight out of information_schema, safe to interpolate
        cur.execute('''DROP VIEW IF EXISTS %s''' % view_name)

    if get_publisher().pg_version >= 90400:
        # drop materialized view that may have been created by a previous
        # version.
        cur.execute('''DROP MATERIALIZED VIEW IF EXISTS wcs_materialized_all_forms''')

    cur.execute('''DROP VIEW IF EXISTS wcs_all_forms''')
def do_global_views(conn, cur):
    """Recreate the cross-formdef views.

    Builds wcs_all_forms as a UNION ALL over every existing per-formdef
    view (restricted to the columns common to all of them), then one
    wcs_category_<name> view per category filtering on category_id.
    """
    # recreate global views
    from wcs.formdef import FormDef
    view_names = [get_formdef_view_name(x) for x in FormDef.select(ignore_migration=True)]

    # only keep formdef views that actually exist in the database
    cur.execute('''SELECT table_name FROM information_schema.views
                    WHERE table_schema = 'public'
                      AND table_name LIKE %s''', ('wcs\\_view\\_%',))
    existing_views = set()
    while True:
        row = cur.fetchone()
        if row is None:
            break
        existing_views.add(row[0])

    view_names = existing_views.intersection(view_names)
    if not view_names:
        return

    # CASCADE also removes the dependent per-category views
    cur.execute('''DROP VIEW IF EXISTS wcs_all_forms CASCADE''')

    # use an empty formdef to get the list of columns shared by all views
    fake_formdef = FormDef()
    common_fields = get_view_fields(fake_formdef)
    common_fields.append(('concerned_roles_array', 'concerned_roles_array'))
    common_fields.append(('actions_roles_array', 'actions_roles_array'))
    common_fields.append(('fts', 'fts'))
    common_fields.append(('is_at_endpoint', 'is_at_endpoint'))
    common_fields.append(('formdef_name', 'formdef_name'))
    common_fields.append(('user_name', 'user_name'))
    common_fields.append(('criticality_level', 'criticality_level'))
    common_fields.append(('geoloc_base_x', 'geoloc_base_x'))
    common_fields.append(('geoloc_base_y', 'geoloc_base_y'))
    common_fields.append(('anonymised', 'anonymised'))

    union = ' UNION ALL '.join(['''SELECT %s FROM %s''' % (
        ', '.join([y[1] for y in common_fields]), x) for x in view_names])
    cur.execute('''CREATE VIEW wcs_all_forms AS %s''' % union)

    for category in wcs.categories.Category.select():
        # category url_name is sanitized and truncated to stay a valid
        # postgresql identifier
        name = get_name_as_sql_identifier(category.url_name)[:40]
        cur.execute('''CREATE VIEW wcs_category_%s AS SELECT * from wcs_all_forms
                        WHERE category_id = %s''' % (name, category.id))
class SqlMixin(object):
    """Base mixin providing PostgreSQL-backed storage primitives.

    Concrete subclasses are expected to provide:
      - _table_name: name of the backing table,
      - _table_static_fields: list of (column name, sql type) tuples,
      - get_data_fields(): names of the extra per-subclass data columns,
      - _row2ob(): turning a result row into an object instance.
    """
    _table_name = None
    # when True, ids are integers and get() validates them early
    _numerical_id = True

    @classmethod
    @guard_postgres
    def keys(cls, clause=None):
        """Return the list of ids of rows matching the given clause."""
        conn, cur = get_connection_and_cursor()
        where_clauses, parameters, func_clause = parse_clause(clause)
        # a Python-side filter function cannot be applied to an id-only query
        assert not func_clause
        sql_statement = 'SELECT id FROM %s' % cls._table_name
        if where_clauses:
            sql_statement += ' WHERE ' + ' AND '.join(where_clauses)
        cur.execute(sql_statement, parameters)
        ids = [x[0] for x in cur.fetchall()]
        conn.commit()
        cur.close()
        return ids

    @classmethod
    @guard_postgres
    def count(cls, clause=None):
        """Return the number of rows matching the given clause."""
        where_clauses, parameters, func_clause = parse_clause(clause)
        if func_clause:
            # fallback to counting the result of a select()
            return len(cls.select(clause))
        sql_statement = 'SELECT count(*) FROM %s' % cls._table_name
        if where_clauses:
            sql_statement += ' WHERE ' + ' AND '.join(where_clauses)
        conn, cur = get_connection_and_cursor()
        cur.execute(sql_statement, parameters)
        count = cur.fetchone()[0]
        conn.commit()
        cur.close()
        return count

    @classmethod
    @guard_postgres
    def get_with_indexed_value(cls, index, value, ignore_errors = False):
        """Yield objects whose column `index` equals `value`.

        NOTE(review): `index` is interpolated as a column name, it must not
        come from untrusted input.
        """
        conn, cur = get_connection_and_cursor()
        sql_statement = '''SELECT %s
                             FROM %s
                            WHERE %s = %%(value)s''' % (
                                ', '.join([x[0] for x in cls._table_static_fields]
                                          + cls.get_data_fields()),
                                cls._table_name,
                                index)
        cur.execute(sql_statement, {'value': str(value)})
        try:
            while True:
                row = cur.fetchone()
                if row is None:
                    break
                ob = cls._row2ob(row)
                if ignore_errors and ob is None:
                    continue
                yield ob
        finally:
            # the finally block makes sure the cursor is closed even when
            # the caller abandons the generator early
            conn.commit()
            cur.close()

    @classmethod
    @guard_postgres
    def get_ids_from_query(cls, query):
        """Return ids of rows matching the full-text search query."""
        conn, cur = get_connection_and_cursor()

        sql_statement = '''SELECT id FROM %s
                            WHERE fts @@ plainto_tsquery(%%(value)s)''' % cls._table_name
        cur.execute(sql_statement, {'value': FtsMatch.get_fts_value(query)})
        all_ids = [x[0] for x in cur.fetchall()]
        cur.close()
        return all_ids

    @classmethod
    @guard_postgres
    def get(cls, id, ignore_errors=False, ignore_migration=False):
        """Return the object with the given id.

        Raises KeyError (or returns None when ignore_errors is set) when
        the id is invalid or does not exist.
        """
        if cls._numerical_id or id is None:
            try:
                int(id)
            except (TypeError, ValueError):
                if ignore_errors and id is None:
                    return None
                else:
                    raise KeyError()
        conn, cur = get_connection_and_cursor()

        sql_statement = '''SELECT %s
                             FROM %s
                            WHERE id = %%(id)s''' % (
                                ', '.join([x[0] for x in cls._table_static_fields]
                                          + cls.get_data_fields()),
                                cls._table_name)
        cur.execute(sql_statement, {'id': str(id)})
        row = cur.fetchone()
        if row is None:
            cur.close()
            if ignore_errors:
                return None
            raise KeyError()
        cur.close()
        return cls._row2ob(row)

    @classmethod
    @guard_postgres
    def get_ids(cls, ids, ignore_errors=False, keep_order=False, fields=None):
        """Return the objects with the given ids.

        When `fields` contains related fields (pointing to a card or the
        user table), the referenced tables are LEFT JOINed and the related
        values are fetched in the same query.
        """
        if not ids:
            return []
        tables = [cls._table_name]
        columns = ['%s.%s' % (cls._table_name, column_name) for column_name in
                   [x[0] for x in cls._table_static_fields] + cls.get_data_fields()]
        extra_fields = []
        if fields:
            # look for relations
            for field in fields:
                if not getattr(field, 'is_related_field', False):
                    continue
                if field.parent_field_id == 'user-label':
                    # relation to user table
                    carddef_table_alias = 'users'
                    carddef_table_decl = 'LEFT JOIN users ON (CAST(%s.user_id AS INTEGER) = users.id)' % cls._table_name
                else:
                    # relation to a carddef table; alias it on the carddef
                    # object identity so joins are unique per carddef
                    carddef_dataclass = field.carddef.data_class()
                    carddef_table_alias = 't%s' % id(field.carddef)
                    carddef_table_decl = 'LEFT JOIN %s AS %s ON (CAST(%s.%s AS INTEGER) = %s.id)' % (
                        carddef_dataclass._table_name,
                        carddef_table_alias,
                        cls._table_name,
                        get_field_id(field.parent_field),
                        carddef_table_alias)

                if carddef_table_decl not in tables:
                    tables.append(carddef_table_decl)

                column_field_id = get_field_id(field.related_field)
                if field.related_field.store_display_value:
                    column_field_id += '_display'
                columns.append('%s.%s' % (carddef_table_alias, column_field_id))
                extra_fields.append(field.id)

        conn, cur = get_connection_and_cursor()
        sql_statement = '''SELECT %s
                             FROM %s
                            WHERE %s.id IN (%s)''' % (
                                ', '.join(columns),
                                ' '.join(tables),
                                cls._table_name,
                                ','.join([str(x) for x in ids]))
        cur.execute(sql_statement)
        objects = cls.get_objects(cur, extra_fields=extra_fields)
        conn.commit()
        cur.close()
        if ignore_errors:
            objects = (x for x in objects if x is not None)
        if keep_order:
            # reorder results to match the order of the requested ids
            objects_dict = {}
            for object in objects:
                objects_dict[object.id] = object
            objects = [objects_dict[x] for x in ids if objects_dict.get(x)]
        return list(objects)

    @classmethod
    def get_objects_iterator(cls, cur, ignore_errors=False, extra_fields=None):
        """Yield objects built from the rows of the given cursor."""
        while True:
            row = cur.fetchone()
            if row is None:
                break
            yield cls._row2ob(row, extra_fields=extra_fields)

    @classmethod
    def get_objects(cls, cur, ignore_errors=False, iterator=False, extra_fields=None):
        """Return objects from the cursor rows, as a list or a generator."""
        generator = cls.get_objects_iterator(
            cur=cur,
            ignore_errors=ignore_errors,
            extra_fields=extra_fields)
        if iterator:
            return generator
        return list(generator)

    @classmethod
    @guard_postgres
    def select_iterator(cls, clause=None, order_by=None, ignore_errors=False,
                        limit=None, offset=None):
        """Yield objects matching the clause, with optional ordering/paging.

        When the clause includes a Python-side filter function, limit and
        offset cannot be applied in SQL (they are handled by select()).
        """
        sql_statement = '''SELECT %s
                             FROM %s''' % (
                                ', '.join([x[0] for x in cls._table_static_fields]
                                          + cls.get_data_fields()),
                                cls._table_name)
        where_clauses, parameters, func_clause = parse_clause(clause)
        if where_clauses:
            sql_statement += ' WHERE ' + ' AND '.join(where_clauses)

        if order_by:
            # [SEC_ORDER] security note: it is not possible to use
            # prepared statements for ORDER BY clauses, therefore input
            # is controlled beforehand (see misc.get_order_by_or_400).
            if order_by.startswith('-'):
                order_by = order_by[1:]
                sql_statement += ' ORDER BY %s DESC' % order_by.replace('-', '_')
            else:
                sql_statement += ' ORDER BY %s' % order_by.replace('-', '_')

        if not func_clause:
            if limit:
                sql_statement += ' LIMIT %(limit)s'
                parameters['limit'] = limit
            if offset:
                sql_statement += ' OFFSET %(offset)s'
                parameters['offset'] = offset

        conn, cur = get_connection_and_cursor()
        cur.execute(sql_statement, parameters)
        try:
            for object in cls.get_objects(cur, iterator=True):
                if object is None:
                    continue
                if func_clause and not func_clause(object):
                    continue
                yield object
        finally:
            # close the cursor even if the generator is abandoned early
            conn.commit()
            cur.close()

    @classmethod
    @guard_postgres
    def select(cls, clause=None, order_by=None, ignore_errors=False, limit=None, offset=None,
               iterator=False):
        """Return objects matching the clause, as a list or an iterator."""
        objects = cls.select_iterator(clause=clause, order_by=order_by,
                                      ignore_errors=ignore_errors,
                                      limit=limit, offset=offset)
        where_clauses, parameters, func_clause = parse_clause(clause)
        if func_clause and (limit or offset):
            # limit/offset could not be applied in SQL, do it in Python
            objects = _take(objects, limit, offset)
        if iterator:
            return objects
        return list(objects)

    @classmethod
    @guard_postgres
    def select_distinct(cls, columns, clause=None):
        """Return rows of distinct values for columns[0], ordered by it."""
        # do note this method returns unicode strings.
        conn, cur = get_connection_and_cursor()
        sql_statement = 'SELECT DISTINCT ON (%s) %s FROM %s' % (columns[0], ', '.join(columns), cls._table_name)
        where_clauses, parameters, func_clause = parse_clause(clause)
        assert not func_clause
        if where_clauses:
            sql_statement += ' WHERE ' + ' AND '.join(where_clauses)
        # DISTINCT ON requires matching leading ORDER BY column
        sql_statement += ' ORDER BY %s' % columns[0]
        cur.execute(sql_statement, parameters)
        values = [x for x in cur.fetchall()]
        conn.commit()
        cur.close()
        return values

    def get_sql_dict_from_data(self, data, formdef):
        """Convert a formdata `data` dict into column name -> SQL value.

        Mirrors _row2obdata(); python values are adapted to the SQL types
        from SQL_TYPE_MAPPING (pickled blobs for bytea, json-safe dicts
        for block fields, ...).
        """
        sql_dict = {}
        for field in formdef.get_all_fields():
            sql_type = SQL_TYPE_MAPPING.get(field.key, 'varchar')
            if sql_type is None:
                # field type with no SQL storage
                continue
            value = data.get(field.id)
            if value is not None:
                if field.key in ('ranked-items', 'password'):
                    # turn {'poire': 2, 'abricot': 1, 'pomme': 3} into an array
                    value = [[force_str(x), force_str(y)] for x, y in value.items()]
                elif sql_type == 'varchar':
                    assert isinstance(value, six.string_types)
                elif sql_type == 'date':
                    assert type(value) is time.struct_time
                    value = datetime.datetime(value.tm_year, value.tm_mon, value.tm_mday)
                elif sql_type == 'bytea':
                    value = bytearray(pickle.dumps(value, protocol=2))
                elif sql_type == 'jsonb' and value.get('schema'):
                    # block field, adapt date/field values
                    value = copy.deepcopy(value)
                    for field_id, field_type in value.get('schema').items():
                        if field_type not in ('date', 'file'):
                            continue
                        for entry in value.get('data') or []:
                            subvalue = entry.get(field_id)
                            if subvalue and field_type == 'date':
                                entry[field_id] = strftime('%Y-%m-%d', subvalue)
                            elif subvalue and field_type == 'file':
                                entry[field_id] = subvalue.__getstate__()
                elif sql_type == 'boolean':
                    pass
            sql_dict[get_field_id(field)] = value
            if field.store_display_value:
                sql_dict['%s_display' % get_field_id(field)] = data.get('%s_display' % field.id)
            if field.store_structured_value:
                sql_dict['%s_structured' % get_field_id(field)] = bytearray(
                    pickle.dumps(data.get('%s_structured' % field.id), protocol=2))
        return sql_dict

    @classmethod
    def _row2obdata(cls, row, formdef):
        """Convert trailing columns of a result row back into a data dict.

        Inverse of get_sql_dict_from_data(); `i` walks the row past the
        static fields (and geolocation columns) then one or more columns
        per field.
        """
        obdata = {}
        i = len(cls._table_static_fields)
        if formdef.geolocations:
            # geolocation columns come right after the static fields
            i += len(formdef.geolocations.keys())
        for field in formdef.get_all_fields():
            sql_type = SQL_TYPE_MAPPING.get(field.key, 'varchar')
            if sql_type is None:
                continue
            value = row[i]
            if value is not None:
                value = str_encode(value)
                if field.key == 'ranked-items':
                    # rebuild the {item: rank} mapping from the stored array
                    d = {}
                    for data, rank in value:
                        try:
                            d[data] = int(rank)
                        except ValueError:
                            d[data] = rank
                    value = d
                elif field.key == 'password':
                    d = {}
                    for fmt, val in value:
                        d[fmt] = force_str(val)
                    value = d
                if sql_type == 'date':
                    value = value.timetuple()
                elif sql_type == 'bytea':
                    value = pickle_loads(value)
                elif sql_type == 'jsonb' and value.get('schema'):
                    # block field, adapt date/field values
                    for field_id, field_type in value.get('schema').items():
                        if field_type not in ('date', 'file'):
                            continue
                        for entry in value.get('data') or []:
                            subvalue = entry.get(field_id)
                            if subvalue and field_type == 'date':
                                entry[field_id] = time.strptime(subvalue, '%Y-%m-%d')
                            elif subvalue and field_type == 'file':
                                # rebuild the upload object without calling
                                # its constructor
                                entry[field_id] = PicklableUpload.__new__(PicklableUpload)
                                entry[field_id].__setstate__(subvalue)

            obdata[field.id] = value
            i += 1
            if field.store_display_value:
                value = str_encode(row[i])
                obdata['%s_display' % field.id] = value
                i += 1
            if field.store_structured_value:
                value = row[i]
                if value is not None:
                    obdata['%s_structured' % field.id] = pickle_loads(value)
                    if obdata['%s_structured' % field.id] is None:
                        del obdata['%s_structured' % field.id]
                i += 1

        return obdata

    @classmethod
    @guard_postgres
    def remove_object(cls, id):
        """Delete the row with the given id."""
        conn, cur = get_connection_and_cursor()
        sql_statement = '''DELETE FROM %s
                            WHERE id = %%(id)s''' % cls._table_name
        cur.execute(sql_statement, {'id': str(id)})
        conn.commit()
        cur.close()

    @classmethod
    @guard_postgres
    def wipe(cls):
        """Delete all rows of the table."""
        conn, cur = get_connection_and_cursor()
        sql_statement = '''DELETE FROM %s''' % cls._table_name
        # NOTE(review): the statement has no placeholder; the parameter dict
        # (built from the builtin id()) looks like a copy/paste leftover and
        # is ignored by psycopg2 — confirm and drop it.
        cur.execute(sql_statement, {'id': str(id)})
        conn.commit()
        cur.close()

    @classmethod
    @guard_postgres
    def get_sorted_ids(cls, order_by, clause=None):
        """Return all matching ids, sorted according to order_by."""
        conn, cur = get_connection_and_cursor()
        sql_statement = 'SELECT id FROM %s' % cls._table_name
        where_clauses, parameters, func_clause = parse_clause(clause)
        assert not func_clause
        if where_clauses:
            sql_statement += ' WHERE ' + ' AND '.join(where_clauses)
        # security note, refer to [SEC_ORDER]
        if order_by.startswith('-'):
            order_by = order_by[1:]
            sql_statement += ' ORDER BY %s DESC' % order_by.replace('-', '_')
        else:
            sql_statement += ' ORDER BY %s' % order_by.replace('-', '_')
        cur.execute(sql_statement, parameters)
        ids = [x[0] for x in cur.fetchall()]
        conn.commit()
        cur.close()
        return ids
class SqlDataMixin(SqlMixin):
|
|
_names = None # make sure StorableObject methods fail
|
|
_formdef = None
|
|
|
|
_table_static_fields = [
|
|
('id', 'serial'),
|
|
('user_id', 'varchar'),
|
|
('receipt_time', 'timestamp'),
|
|
('status', 'varchar'),
|
|
('page_no', 'varchar'),
|
|
('anonymised', 'timestamptz'),
|
|
('workflow_data', 'bytea'),
|
|
('id_display', 'varchar'),
|
|
('workflow_roles', 'bytea'),
|
|
('workflow_roles_array', 'text[]'),
|
|
('concerned_roles_array', 'text[]'),
|
|
('actions_roles_array', 'text[]'),
|
|
('tracking_code', 'varchar'),
|
|
('backoffice_submission', 'boolean'),
|
|
('submission_context', 'bytea'),
|
|
('submission_agent_id', 'varchar'),
|
|
('submission_channel', 'varchar'),
|
|
('criticality_level', 'int'),
|
|
('last_update_time', 'timestamp'),
|
|
('digest', 'varchar'),
|
|
('user_label', 'varchar'),
|
|
]
|
|
|
|
def __init__(self, id=None):
|
|
self.id = id
|
|
self.data = {}
|
|
|
|
_evolution = None
|
|
@guard_postgres
|
|
def get_evolution(self):
|
|
if self._evolution:
|
|
return self._evolution
|
|
if not self.id:
|
|
self._evolution = []
|
|
return self._evolution
|
|
conn, cur = get_connection_and_cursor()
|
|
sql_statement = '''SELECT id, who, status, time, last_jump_datetime,
|
|
comment, parts FROM %s_evolutions
|
|
WHERE formdata_id = %%(id)s
|
|
ORDER BY id''' % self._table_name
|
|
cur.execute(sql_statement, {'id': self.id})
|
|
self._evolution = []
|
|
while True:
|
|
row = cur.fetchone()
|
|
if row is None:
|
|
break
|
|
self._evolution.append(self._row2evo(row, formdata=self))
|
|
conn.commit()
|
|
cur.close()
|
|
return self._evolution
|
|
|
|
@classmethod
|
|
def _row2evo(cls, row, formdata):
|
|
o = wcs.formdata.Evolution(formdata)
|
|
o._sql_id, o.who, o.status, o.time, o.last_jump_datetime, o.comment = [
|
|
str_encode(x) for x in tuple(row[:6])]
|
|
if o.time:
|
|
o.time = o.time.timetuple()
|
|
if row[6]:
|
|
o.parts = pickle_loads(row[6])
|
|
return o
|
|
|
|
def set_evolution(self, value):
|
|
self._evolution = value
|
|
|
|
evolution = property(get_evolution, set_evolution)
|
|
|
|
@classmethod
|
|
@guard_postgres
|
|
def load_all_evolutions(cls, values):
|
|
# Typically formdata.evolution is loaded on-demand (see above
|
|
# property()) and this is fine to minimize queries, especially when
|
|
# dealing with a single formdata. However in some places (to compute
|
|
# statistics for example) it is sometimes useful to access .evolution
|
|
# on a serie of formdata and in that case, it's more efficient to
|
|
# optimize the process loading all evolutions in a single batch query.
|
|
object_dict = dict([(x.id, x) for x in values if x.id and x._evolution is None])
|
|
if not object_dict:
|
|
return
|
|
conn, cur = get_connection_and_cursor()
|
|
sql_statement = '''SELECT id, who, status, time, last_jump_datetime,
|
|
comment, parts, formdata_id
|
|
FROM %s_evolutions''' % cls._table_name
|
|
sql_statement += ''' WHERE formdata_id IN %(object_ids)s ORDER BY id'''
|
|
cur.execute(sql_statement, {'object_ids': tuple(object_dict.keys())})
|
|
|
|
for value in values:
|
|
value._evolution = []
|
|
|
|
while True:
|
|
row = cur.fetchone()
|
|
if row is None:
|
|
break
|
|
_sql_id, who, status, time, last_jump_datetime, comment, parts, formdata_id = tuple(row[:8])
|
|
formdata = object_dict.get(formdata_id)
|
|
if not formdata:
|
|
continue
|
|
formdata._evolution.append(formdata._row2evo(row, formdata))
|
|
|
|
conn.commit()
|
|
cur.close()
|
|
|
|
@guard_postgres
|
|
@invalidate_substitution_cache
|
|
def store(self):
|
|
sql_dict = {
|
|
'user_id': self.user_id,
|
|
'status': self.status,
|
|
'page_no': self.page_no,
|
|
'workflow_data': bytearray(pickle.dumps(self.workflow_data, protocol=2)),
|
|
'id_display': self.id_display,
|
|
'anonymised': self.anonymised,
|
|
'tracking_code': self.tracking_code,
|
|
'backoffice_submission': self.backoffice_submission,
|
|
'submission_context': self.submission_context,
|
|
'submission_agent_id': self.submission_agent_id,
|
|
'submission_channel': self.submission_channel,
|
|
'criticality_level': self.criticality_level,
|
|
}
|
|
if self.last_update_time:
|
|
sql_dict['last_update_time'] = datetime.datetime.fromtimestamp(time.mktime(self.last_update_time))
|
|
else:
|
|
sql_dict['last_update_time'] = None
|
|
if self.receipt_time:
|
|
sql_dict['receipt_time'] = datetime.datetime.fromtimestamp(time.mktime(self.receipt_time))
|
|
else:
|
|
sql_dict['receipt_time'] = None
|
|
if self.workflow_roles:
|
|
sql_dict['workflow_roles'] = bytearray(pickle.dumps(self.workflow_roles, protocol=2))
|
|
sql_dict['workflow_roles_array'] = [str(x) for x in self.workflow_roles.values() if x is not None]
|
|
else:
|
|
sql_dict['workflow_roles'] = None
|
|
sql_dict['workflow_roles_array'] = None
|
|
if self.submission_context:
|
|
sql_dict['submission_context'] = bytearray(pickle.dumps(self.submission_context, protocol=2))
|
|
else:
|
|
sql_dict['submission_context'] = None
|
|
|
|
for field in (self._formdef.geolocations or {}).keys():
|
|
value = (self.geolocations or {}).get(field)
|
|
if value:
|
|
value = '(%.6f, %.6f)' % (value.get('lon'), value.get('lat'))
|
|
sql_dict['geoloc_%s' % field] = value
|
|
|
|
sql_dict['concerned_roles_array'] = [str(x) for x in self.concerned_roles if x]
|
|
sql_dict['actions_roles_array'] = [str(x) for x in self.actions_roles if x]
|
|
|
|
sql_dict.update(self.get_sql_dict_from_data(self.data, self._formdef))
|
|
conn, cur = get_connection_and_cursor()
|
|
if not self.id:
|
|
column_names = sql_dict.keys()
|
|
sql_statement = '''INSERT INTO %s (id, %s)
|
|
VALUES (DEFAULT, %s)
|
|
RETURNING id''' % (
|
|
self._table_name,
|
|
', '.join(column_names),
|
|
', '.join(['%%(%s)s' % x for x in column_names]))
|
|
cur.execute(sql_statement, sql_dict)
|
|
self.id = cur.fetchone()[0]
|
|
else:
|
|
column_names = sql_dict.keys()
|
|
sql_dict['id'] = self.id
|
|
sql_statement = '''UPDATE %s SET %s WHERE id = %%(id)s RETURNING id''' % (
|
|
self._table_name,
|
|
', '.join(['%s = %%(%s)s' % (x,x) for x in column_names]))
|
|
cur.execute(sql_statement, sql_dict)
|
|
if cur.fetchone() is None:
|
|
column_names = sql_dict.keys()
|
|
sql_statement = '''INSERT INTO %s (%s) VALUES (%s) RETURNING id''' % (
|
|
self._table_name,
|
|
', '.join(column_names),
|
|
', '.join(['%%(%s)s' % x for x in column_names]))
|
|
cur.execute(sql_statement, sql_dict)
|
|
self.id = cur.fetchone()[0]
|
|
|
|
if self.set_auto_fields():
|
|
sql_statement = '''UPDATE %s
|
|
SET id_display = %%(id_display)s,
|
|
digest = %%(digest)s,
|
|
user_label = %%(user_label)s
|
|
WHERE id = %%(id)s''' % self._table_name
|
|
cur.execute(sql_statement, {
|
|
'id': self.id,
|
|
'id_display': self.id_display,
|
|
'digest': self.digest,
|
|
'user_label': self.user_label,
|
|
})
|
|
|
|
if self._evolution:
|
|
for evo in self._evolution:
|
|
sql_dict = {}
|
|
if hasattr(evo, '_sql_id'):
|
|
sql_dict.update({'id': evo._sql_id})
|
|
sql_statement = '''UPDATE %s_evolutions SET
|
|
who = %%(who)s,
|
|
time = %%(time)s,
|
|
last_jump_datetime = %%(last_jump_datetime)s,
|
|
status = %%(status)s,
|
|
comment = %%(comment)s,
|
|
parts = %%(parts)s
|
|
WHERE id = %%(id)s
|
|
RETURNING id''' % self._table_name
|
|
else:
|
|
sql_statement = '''INSERT INTO %s_evolutions (
|
|
id, who, status,
|
|
time, last_jump_datetime,
|
|
comment, parts,
|
|
formdata_id)
|
|
VALUES (DEFAULT, %%(who)s, %%(status)s,
|
|
%%(time)s, %%(last_jump_datetime)s,
|
|
%%(comment)s,
|
|
%%(parts)s, %%(formdata_id)s)
|
|
RETURNING id''' % self._table_name
|
|
sql_dict.update({
|
|
'who': evo.who,
|
|
'status': evo.status,
|
|
'time': datetime.datetime.fromtimestamp(time.mktime(evo.time)),
|
|
'last_jump_datetime': evo.last_jump_datetime,
|
|
'comment': evo.comment,
|
|
'formdata_id': self.id,
|
|
})
|
|
if evo.parts:
|
|
sql_dict['parts'] = bytearray(pickle.dumps(evo.parts, protocol=2))
|
|
else:
|
|
sql_dict['parts'] = None
|
|
cur.execute(sql_statement, sql_dict)
|
|
evo._sql_id = cur.fetchone()[0]
|
|
|
|
fts_strings = [str(self.id), self.get_display_id()]
|
|
fts_strings.append(self._formdef.name)
|
|
if self.tracking_code:
|
|
fts_strings.append(self.tracking_code)
|
|
for field in self._formdef.get_all_fields():
|
|
if not self.data.get(field.id):
|
|
continue
|
|
value = None
|
|
if field.key in ('string', 'text', 'email'):
|
|
value = self.data.get(field.id)
|
|
elif field.key in ('item', 'items'):
|
|
value = self.data.get('%s_display' % field.id)
|
|
if value:
|
|
if isinstance(value, six.string_types):
|
|
fts_strings.append(value)
|
|
elif type(value) in (tuple, list):
|
|
fts_strings.extend(value)
|
|
if self._evolution:
|
|
for evo in self._evolution:
|
|
if evo.comment:
|
|
fts_strings.append(evo.comment)
|
|
for part in evo.parts or []:
|
|
if hasattr(part, 'view'):
|
|
html_part = part.view()
|
|
if html_part:
|
|
fts_strings.append(qommon.misc.html2text(html_part))
|
|
user = self.get_user()
|
|
if user:
|
|
fts_strings.append(user.get_display_name())
|
|
|
|
sql_statement = '''UPDATE %s SET fts = to_tsvector( %%(fts)s)
|
|
WHERE id = %%(id)s''' % self._table_name
|
|
cur.execute(sql_statement, {
|
|
'id': self.id,
|
|
'fts': FtsMatch.get_fts_value(' '.join(fts_strings)),
|
|
})
|
|
|
|
conn.commit()
|
|
cur.close()
|
|
|
|
@classmethod
|
|
def _row2ob(cls, row, extra_fields=None):
|
|
o = cls()
|
|
for static_field, value in zip(cls._table_static_fields,
|
|
tuple(row[:len(cls._table_static_fields)])):
|
|
setattr(o, static_field[0], str_encode(value))
|
|
if o.receipt_time:
|
|
o.receipt_time = o.receipt_time.timetuple()
|
|
if o.workflow_data:
|
|
o.workflow_data = pickle_loads(o.workflow_data)
|
|
if o.workflow_roles:
|
|
o.workflow_roles = pickle_loads(o.workflow_roles)
|
|
if o.submission_context:
|
|
o.submission_context = pickle_loads(o.submission_context)
|
|
|
|
o.geolocations = {}
|
|
for i, field in enumerate((cls._formdef.geolocations or {}).keys()):
|
|
value = row[len(cls._table_static_fields)+i]
|
|
if not value:
|
|
continue
|
|
m = re.match(r"\(([^)]+),([^)]+)\)", value)
|
|
o.geolocations[field] = {'lon': float(m.group(1)),
|
|
'lat': float(m.group(2))}
|
|
|
|
o.data = cls._row2obdata(row, cls._formdef)
|
|
if extra_fields:
|
|
# extra fields are tuck at the end
|
|
for i, field_id in enumerate(reversed(extra_fields)):
|
|
o.data[field_id] = row[-(i + 1)]
|
|
pass
|
|
del o._last_update_time
|
|
return o
|
|
|
|
@classmethod
|
|
def get_data_fields(cls):
|
|
data_fields = ['geoloc_%s' % x for x in (cls._formdef.geolocations or {}).keys()]
|
|
for field in cls._formdef.get_all_fields():
|
|
sql_type = SQL_TYPE_MAPPING.get(field.key, 'varchar')
|
|
if sql_type is None:
|
|
continue
|
|
data_fields.append(get_field_id(field))
|
|
if field.store_display_value:
|
|
data_fields.append('%s_display' % get_field_id(field))
|
|
if field.store_structured_value:
|
|
data_fields.append('%s_structured' % get_field_id(field))
|
|
return data_fields
|
|
|
|
@classmethod
|
|
@guard_postgres
|
|
def get(cls, id, ignore_errors=False, ignore_migration=False):
|
|
try:
|
|
int(id)
|
|
except (TypeError, ValueError):
|
|
if ignore_errors:
|
|
return None
|
|
else:
|
|
raise KeyError()
|
|
conn, cur = get_connection_and_cursor()
|
|
|
|
fields = cls.get_data_fields()
|
|
|
|
potential_comma = ', '
|
|
if not fields:
|
|
potential_comma = ''
|
|
|
|
sql_statement = '''SELECT %s
|
|
%s
|
|
%s
|
|
FROM %s
|
|
WHERE id = %%(id)s''' % (
|
|
', '.join([x[0] for x in cls._table_static_fields]),
|
|
potential_comma,
|
|
', '.join(fields),
|
|
cls._table_name)
|
|
cur.execute(sql_statement, {'id': str(id)})
|
|
row = cur.fetchone()
|
|
if row is None:
|
|
cur.close()
|
|
if ignore_errors:
|
|
return None
|
|
raise KeyError()
|
|
cur.close()
|
|
return cls._row2ob(row)
|
|
|
|
@classmethod
|
|
@guard_postgres
|
|
def get_ids_with_indexed_value(cls, index, value, auto_fallback=True, clause=None):
|
|
conn, cur = get_connection_and_cursor()
|
|
|
|
where_clauses, parameters, func_clause = parse_clause(clause)
|
|
assert not func_clause
|
|
|
|
if type(value) is int:
|
|
value = str(value)
|
|
|
|
if '%s_array' % index in [x[0] for x in cls._table_static_fields]:
|
|
sql_statement = '''SELECT id FROM %s WHERE %%(value)s = ANY (%s_array)''' % (
|
|
cls._table_name,
|
|
index)
|
|
else:
|
|
sql_statement = '''SELECT id FROM %s WHERE %s = %%(value)s''' % (
|
|
cls._table_name,
|
|
index)
|
|
|
|
if where_clauses:
|
|
sql_statement += ' AND ' + ' AND '.join(where_clauses)
|
|
else:
|
|
parameters = {}
|
|
|
|
parameters.update({'value': value})
|
|
cur.execute(sql_statement, parameters)
|
|
all_ids = [x[0] for x in cur.fetchall()]
|
|
cur.close()
|
|
return all_ids
|
|
|
|
@classmethod
|
|
@guard_postgres
|
|
def fix_sequences(cls):
|
|
conn, cur = get_connection_and_cursor()
|
|
|
|
for table_name in (cls._table_name, '%s_evolutions' % cls._table_name):
|
|
sql_statement = '''select max(id) from %s''' % table_name
|
|
cur.execute(sql_statement)
|
|
max_id = cur.fetchone()[0]
|
|
if max_id:
|
|
sql_statement = '''ALTER SEQUENCE %s RESTART %s''' % (
|
|
'%s_id_seq' % table_name, max_id+1)
|
|
cur.execute(sql_statement)
|
|
|
|
conn.commit()
|
|
cur.close()
|
|
|
|
    @classmethod
    def rebuild_security(cls):
        """Recompute the role-based access columns of every formdata.

        Iterates over all formdatas (ordered by id) and rewrites the
        concerned_roles_array and actions_roles_array columns from the
        formdata's role properties (which may depend on workflow state).
        """
        formdatas = cls.select(order_by='id')
        conn, cur = get_connection_and_cursor()
        for formdata in formdatas:
            sql_statement = '''UPDATE %s
                                  SET concerned_roles_array = %%(roles)s,
                                      actions_roles_array = %%(actions_roles)s
                                WHERE id = %%(id)s''' % cls._table_name
            with get_publisher().substitutions.temporary_feed(formdata):
                # formdata is already added to sources list in individual
                # {concerned,actions}_roles but adding it first here will
                # allow cached values to be reused between the properties.
                cur.execute(sql_statement, {
                    'id': formdata.id,
                    # falsy role identifiers are dropped
                    'roles': [str(x) for x in formdata.concerned_roles if x],
                    'actions_roles': [str(x) for x in formdata.actions_roles if x]})
        conn.commit()
        cur.close()
|
|
|
|
@classmethod
|
|
@guard_postgres
|
|
def wipe(cls, drop=False):
|
|
conn, cur = get_connection_and_cursor()
|
|
if drop:
|
|
cur.execute('''DROP TABLE %s_evolutions CASCADE''' % cls._table_name)
|
|
cur.execute('''DROP TABLE %s CASCADE''' % cls._table_name)
|
|
else:
|
|
cur.execute('''DELETE FROM %s_evolutions''' % cls._table_name)
|
|
cur.execute('''DELETE FROM %s''' % cls._table_name)
|
|
conn.commit()
|
|
cur.close()
|
|
|
|
    @classmethod
    def do_tracking_code_table(cls):
        # delegate to the module-level function that creates/updates the
        # tracking_codes table.
        do_tracking_code_table()
|
|
|
|
|
|
class SqlFormData(SqlDataMixin, wcs.formdata.FormData):
    """FormData stored in PostgreSQL (storage behaviour from SqlDataMixin)."""
    pass
|
|
|
|
|
|
class SqlCardData(SqlDataMixin, wcs.carddata.CardData):
    """CardData stored in PostgreSQL (storage behaviour from SqlDataMixin)."""
    pass
|
|
|
|
|
|
class SqlUser(SqlMixin, wcs.users.User):
    """User account backed by the PostgreSQL "users" table.

    Static columns are listed in _table_static_fields; the custom user
    fields (defined by the user formdef) are stored in extra columns
    handled by the SqlMixin machinery.
    """
    _table_name = 'users'
    _table_static_fields = [
        ('id', 'serial'),
        ('name', 'varchar'),
        ('email', 'varchar'),
        ('roles', 'varchar[]'),
        ('is_admin', 'bool'),
        ('anonymous', 'bool'),
        ('name_identifiers', 'varchar[]'),
        ('verified_fields', 'varchar[]'),
        ('lasso_dump', 'text'),
        ('last_seen', 'timestamp'),
        ('ascii_name', 'varchar'),
        ('deleted_timestamp', 'timestamp'),
        ('is_active', 'bool'),
    ]

    id = None

    def __init__(self, name=None):
        self.name = name
        self.name_identifiers = []
        self.verified_fields = []
        self.roles = []

    @guard_postgres
    @invalidate_substitution_cache
    def store(self):
        """Insert or update the user row and refresh its full-text index.

        A failed UPDATE (no row returned) falls back to an INSERT keeping
        the existing id.
        """
        sql_dict = {
            'name': self.name,
            'ascii_name': self.ascii_name,
            'email': self.email,
            'roles': self.roles,
            'is_admin': self.is_admin,
            'anonymous': self.anonymous,
            'name_identifiers': self.name_identifiers,
            'verified_fields': self.verified_fields,
            'lasso_dump': self.lasso_dump,
            'last_seen': None,
            'deleted_timestamp': self.deleted_timestamp,
            'is_active': self.is_active,
        }
        if self.last_seen:
            # last_seen is kept in memory as an epoch timestamp; convert it
            # for the timestamp column.  (bugfix: a stray trailing comma used
            # to turn this value into a 1-tuple)
            sql_dict['last_seen'] = datetime.datetime.fromtimestamp(self.last_seen)

        user_formdef = self.get_formdef()
        if not self.form_data:
            self.form_data = {}
        sql_dict.update(self.get_sql_dict_from_data(self.form_data, user_formdef))

        conn, cur = get_connection_and_cursor()
        if not self.id:
            column_names = sql_dict.keys()
            sql_statement = '''INSERT INTO %s (id, %s)
                               VALUES (DEFAULT, %s)
                               RETURNING id''' % (
                                   self._table_name,
                                   ', '.join(column_names),
                                   ', '.join(['%%(%s)s' % x for x in column_names]))
            cur.execute(sql_statement, sql_dict)
            self.id = cur.fetchone()[0]
        else:
            column_names = sql_dict.keys()
            sql_dict['id'] = self.id
            sql_statement = '''UPDATE %s SET %s WHERE id = %%(id)s RETURNING id''' % (
                    self._table_name,
                    ', '.join(['%s = %%(%s)s' % (x, x) for x in column_names]))
            cur.execute(sql_statement, sql_dict)
            if cur.fetchone() is None:
                # no row was updated; insert the user with its known id
                column_names = sql_dict.keys()
                sql_statement = '''INSERT INTO %s (%s) VALUES (%s)''' % (
                        self._table_name,
                        ', '.join(column_names),
                        ', '.join(['%%(%s)s' % x for x in column_names]))
                cur.execute(sql_statement, sql_dict)

        # collect the strings fed to the full-text search index: identity
        # attributes plus textual custom fields (display value for items).
        fts_strings = []
        if self.name:
            fts_strings.append(self.name)
            fts_strings.append(self.ascii_name)
        if self.email:
            fts_strings.append(self.email)
        if user_formdef and user_formdef.fields:
            for field in user_formdef.fields:
                if not self.form_data.get(field.id):
                    continue
                value = None
                if field.key in ('string', 'text', 'email'):
                    value = self.form_data.get(field.id)
                elif field.key in ('item', 'items'):
                    value = self.form_data.get('%s_display' % field.id)
                if value:
                    if isinstance(value, six.string_types):
                        fts_strings.append(value)
                    elif type(value) in (tuple, list):
                        fts_strings.extend(value)
        sql_statement = '''UPDATE %s SET fts = to_tsvector( %%(fts)s)
                            WHERE id = %%(id)s''' % self._table_name
        cur.execute(sql_statement, {'id': self.id, 'fts': ' '.join(fts_strings)})

        conn.commit()
        cur.close()

    @classmethod
    def _row2ob(cls, row, **kwargs):
        """Build a SqlUser from a database row."""
        o = cls()
        # ascii_name is derived from name and deliberately left unbound here
        (o.id, o.name, o.email, o.roles, o.is_admin, o.anonymous,
         o.name_identifiers, o.verified_fields, o.lasso_dump,
         o.last_seen, ascii_name, o.deleted_timestamp, o.is_active) = [
                 str_encode(x) for x in tuple(row[:13])]
        if o.last_seen:
            # back to an epoch timestamp, as kept in memory
            o.last_seen = time.mktime(o.last_seen.timetuple())
        if o.roles:
            o.roles = [str(x) for x in o.roles]
        o.form_data = cls._row2obdata(row, cls.get_formdef())
        return o

    @classmethod
    def get_data_fields(cls):
        """Return the column names holding the custom user fields (plus
        their _display/_structured companions when stored)."""
        data_fields = []
        for field in cls.get_formdef().get_all_fields():
            sql_type = SQL_TYPE_MAPPING.get(field.key, 'varchar')
            if sql_type is None:
                # field type with no SQL storage
                continue
            data_fields.append(get_field_id(field))
            if field.store_display_value:
                data_fields.append('%s_display' % get_field_id(field))
            if field.store_structured_value:
                data_fields.append('%s_structured' % get_field_id(field))
        return data_fields

    @classmethod
    @guard_postgres
    def fix_sequences(cls):
        """Realign the users id sequence with the current maximum id."""
        conn, cur = get_connection_and_cursor()

        sql_statement = '''select max(id) from %s''' % cls._table_name
        cur.execute(sql_statement)
        max_id = cur.fetchone()[0]
        if max_id is not None:
            sql_statement = '''ALTER SEQUENCE %s_id_seq RESTART %s''' % (
                    cls._table_name, max_id+1)
            cur.execute(sql_statement)

        conn.commit()
        cur.close()

    @classmethod
    @guard_postgres
    def get_users_with_name_identifier(cls, name_identifier):
        """Return the non-deleted users bearing the given name identifier."""
        conn, cur = get_connection_and_cursor()

        sql_statement = '''SELECT %s
                             FROM %s
                            WHERE deleted_timestamp IS NULL
                              AND %%(value)s = ANY(name_identifiers)''' % (
                                  ', '.join([x[0] for x in cls._table_static_fields]
                                            + cls.get_data_fields()),
                                  cls._table_name)
        cur.execute(sql_statement, {'value': name_identifier})
        objects = cls.get_objects(cur)
        conn.commit()
        cur.close()

        return objects

    @classmethod
    @guard_postgres
    def get_users_with_email(cls, email):
        """Return the non-deleted users with the given email address."""
        conn, cur = get_connection_and_cursor()

        sql_statement = '''SELECT %s
                             FROM %s
                            WHERE deleted_timestamp IS NULL
                              AND email = %%(value)s''' % (
                                  ', '.join([x[0] for x in cls._table_static_fields]
                                            + cls.get_data_fields()),
                                  cls._table_name)
        cur.execute(sql_statement, {'value': email})
        objects = cls.get_objects(cur)
        conn.commit()
        cur.close()

        return objects

    @classmethod
    @guard_postgres
    def get_users_with_role(cls, role_id):
        """Return the non-deleted users holding the given role."""
        conn, cur = get_connection_and_cursor()

        sql_statement = '''SELECT %s
                             FROM %s
                            WHERE deleted_timestamp IS NULL
                              AND %%(value)s = ANY(roles)''' % (
                                  ', '.join([x[0] for x in cls._table_static_fields]
                                            + cls.get_data_fields()),
                                  cls._table_name)
        cur.execute(sql_statement, {'value': str(role_id)})
        objects = cls.get_objects(cur)
        conn.commit()
        cur.close()

        return objects
|
|
|
|
|
|
class Session(SqlMixin, wcs.sessions.BasicSession):
    """Web session stored in PostgreSQL (table "sessions").

    The full session state is pickled into the session_data column; a few
    extra columns (name_identifier, visiting_objects_keys,
    last_update_time) are duplicated only to allow optimized SELECTs, and
    are ignored when loading.
    """
    _table_name = 'sessions'
    _table_static_fields = [
        ('id', 'varchar'),
        ('session_data', 'bytea'),
    ]
    # session ids are opaque strings, not serials
    _numerical_id = False

    @classmethod
    @guard_postgres
    def select_recent(cls, seconds=30*60, **kwargs):
        """Return sessions updated within the last ``seconds`` seconds."""
        clause = [GreaterOrEqual('last_update_time', datetime.datetime.now() - datetime.timedelta(seconds=seconds))]
        return cls.select(clause=clause, **kwargs)

    @guard_postgres
    def store(self):
        """Insert or update the session row."""
        sql_dict = {
            'id': self.id,
            'session_data': bytearray(pickle.dumps(self.__dict__, protocol=2)),
            # the other fields are stored to run optimized SELECT() against the
            # table, they are ignored when loading the data.
            'name_identifier': self.name_identifier,
            # bugfix: give getattr() a default so sessions lacking the
            # visiting_objects attribute do not raise AttributeError.
            'visiting_objects_keys': list(self.visiting_objects.keys()) if getattr(self, 'visiting_objects', None) else None,
            'last_update_time': datetime.datetime.now(),
        }

        conn, cur = get_connection_and_cursor()
        column_names = sql_dict.keys()
        sql_statement = '''UPDATE %s SET %s WHERE id = %%(id)s RETURNING id''' % (
                self._table_name,
                ', '.join(['%s = %%(%s)s' % (x, x) for x in column_names]))
        cur.execute(sql_statement, sql_dict)
        if cur.fetchone() is None:
            # no existing row, insert one
            sql_statement = '''INSERT INTO %s (%s) VALUES (%s)''' % (
                    self._table_name,
                    ', '.join(column_names),
                    ', '.join(['%%(%s)s' % x for x in column_names]))
            cur.execute(sql_statement, sql_dict)

        conn.commit()
        cur.close()

    @classmethod
    def _row2ob(cls, row, **kwargs):
        """Rebuild a Session by unpickling session_data into a bare instance.

        bugfix: the id used to be assigned to the class (cls.id) instead of
        the instance, polluting shared class state.
        """
        o = cls.__new__(cls)
        o.id = str_encode(row[0])
        session_data = pickle_loads(row[1])
        for k, v in session_data.items():
            setattr(o, k, v)
        return o

    @classmethod
    def get_sessions_for_saml(cls, name_identifier=Ellipsis, *args, **kwargs):
        """Return the sessions bound to the given SAML name identifier."""
        conn, cur = get_connection_and_cursor()

        sql_statement = '''SELECT %s
                             FROM %s
                            WHERE name_identifier = %%(value)s''' % (
                                  ', '.join([x[0] for x in cls._table_static_fields]),
                                  cls._table_name)
        cur.execute(sql_statement, {'value': name_identifier})
        objects = cls.get_objects(cur)
        conn.commit()
        cur.close()

        return objects

    @classmethod
    def get_sessions_with_visited_object(cls, object_key):
        """Return recently-active (last 30 minutes) sessions currently
        visiting the given object."""
        conn, cur = get_connection_and_cursor()

        sql_statement = '''SELECT %s
                             FROM %s
                            WHERE %%(value)s = ANY(visiting_objects_keys)
                              AND last_update_time > (now() - interval '30 minutes')
                            ''' % (
                                  ', '.join([x[0] for x in cls._table_static_fields]
                                            + cls.get_data_fields()),
                                  cls._table_name)
        cur.execute(sql_statement, {'value': object_key})
        objects = cls.get_objects(cur)
        conn.commit()
        cur.close()

        return objects

    @classmethod
    def get_data_fields(cls):
        # no extra data columns besides the static fields
        return []
|
|
|
|
|
|
class TrackingCode(SqlMixin, wcs.tracking_code.TrackingCode):
    """Tracking code stored in PostgreSQL (table "tracking_codes").

    A tracking code is a short random identifier pointing at a single
    formdata (formdef_id / formdata_id pair).
    """
    _table_name = 'tracking_codes'
    _table_static_fields = [
        ('id', 'varchar'),
        ('formdef_id', 'varchar'),
        ('formdata_id', 'varchar'),
    ]
    # ids are random strings, not serials
    _numerical_id = False

    id = None

    @classmethod
    def get(cls, id, **kwargs):
        # codes are stored uppercased; normalize the lookup value
        return super(TrackingCode, cls).get(id.upper(), **kwargs)

    @guard_postgres
    @invalidate_substitution_cache
    def store(self):
        """Insert or update the tracking code row.

        On insertion a new random id is drawn and the INSERT retried for as
        long as it collides with an existing code (IntegrityError).
        """
        sql_dict = {
            'id': self.id,
            'formdef_id': self.formdef_id,
            'formdata_id': self.formdata_id
        }

        conn, cur = get_connection_and_cursor()
        if not self.id:
            column_names = sql_dict.keys()
            sql_dict['id'] = self.get_new_id()
            sql_statement = '''INSERT INTO %s (%s)
                               VALUES (%s)
                               RETURNING id''' % (
                                   self._table_name,
                                   ', '.join(column_names),
                                   ', '.join(['%%(%s)s' % x for x in column_names]))
            while True:
                try:
                    cur.execute(sql_statement, sql_dict)
                except psycopg2.IntegrityError:
                    # id collision: roll back and retry with another id
                    conn.rollback()
                    sql_dict['id'] = self.get_new_id()
                else:
                    break
            self.id = str_encode(cur.fetchone()[0])
        else:
            column_names = sql_dict.keys()
            sql_dict['id'] = self.id
            sql_statement = '''UPDATE %s SET %s WHERE id = %%(id)s RETURNING id''' % (
                    self._table_name,
                    ', '.join(['%s = %%(%s)s' % (x,x) for x in column_names]))
            cur.execute(sql_statement, sql_dict)
            if cur.fetchone() is None:
                # updating an id that does not exist is a programming error
                raise AssertionError()

        conn.commit()
        cur.close()

    @classmethod
    def _row2ob(cls, row, **kwargs):
        """Build a TrackingCode from a database row."""
        o = cls()
        (o.id, o.formdef_id, o.formdata_id) = [str_encode(x) for x in tuple(row[:3])]
        return o

    @classmethod
    def get_data_fields(cls):
        # no extra data columns besides the static fields
        return []
|
|
|
|
|
|
class CustomView(SqlMixin, wcs.custom_views.CustomView):
    """Custom listing view stored in PostgreSQL (table "custom_views")."""
    _table_name = 'custom_views'
    _table_static_fields = [
        ('id', 'varchar'),
        ('title', 'varchar'),
        ('slug', 'varchar'),
        ('user_id', 'varchar'),
        ('visibility', 'varchar'),
        ('formdef_type', 'varchar'),
        ('formdef_id', 'varchar'),
        ('order_by', 'varchar'),
        ('columns', 'jsonb'),
        ('filters', 'jsonb'),
    ]

    @guard_postgres
    @invalidate_substitution_cache
    def store(self):
        """Insert or update the custom view row.

        On insertion a new random id is drawn and the INSERT retried for as
        long as it collides with an existing row (IntegrityError).
        """
        self.ensure_slug()
        sql_dict = {
            'id': self.id,
            'title': self.title,
            'slug': self.slug,
            'user_id': self.user_id,
            'visibility': self.visibility,
            'formdef_type': self.formdef_type,
            'formdef_id': self.formdef_id,
            'order_by': self.order_by,
            'columns': self.columns,
            'filters': self.filters,
        }

        conn, cur = get_connection_and_cursor()
        if not self.id:
            column_names = sql_dict.keys()
            sql_dict['id'] = self.get_new_id()
            sql_statement = '''INSERT INTO %s (%s)
                               VALUES (%s)
                               RETURNING id''' % (
                                   self._table_name,
                                   ', '.join(column_names),
                                   ', '.join(['%%(%s)s' % x for x in column_names]))
            while True:
                try:
                    cur.execute(sql_statement, sql_dict)
                except psycopg2.IntegrityError:
                    # id collision: roll back and retry with another id
                    conn.rollback()
                    sql_dict['id'] = self.get_new_id()
                else:
                    break
            self.id = str_encode(cur.fetchone()[0])
        else:
            column_names = sql_dict.keys()
            sql_dict['id'] = self.id
            sql_statement = '''UPDATE %s SET %s WHERE id = %%(id)s RETURNING id''' % (
                    self._table_name,
                    ', '.join(['%s = %%(%s)s' % (x, x) for x in column_names]))
            cur.execute(sql_statement, sql_dict)
            if cur.fetchone() is None:
                # updating an id that does not exist is a programming error
                raise AssertionError()

        conn.commit()
        cur.close()

    @classmethod
    def _row2ob(cls, row, **kwargs):
        """Build a CustomView from a database row."""
        o = cls()
        for field, value in zip(cls._table_static_fields, tuple(row)):
            if field[1] == 'varchar':
                setattr(o, field[0], str_encode(value))
            elif field[1] == 'jsonb':
                # jsonb columns are already decoded by psycopg2
                setattr(o, field[0], value)
        return o

    @classmethod
    def get_data_fields(cls):
        # no extra data columns besides the static fields
        return []
|
|
|
|
|
|
class Snapshot(SqlMixin, wcs.snapshots.Snapshot):
    """Object snapshot stored in PostgreSQL (table "snapshots").

    A snapshot keeps a serialization of an administered object (identified
    by object_type/object_id) at a point in time.
    """
    _table_name = 'snapshots'
    _table_static_fields = [
        ('id', 'serial'),
        ('object_type', 'varchar'),
        ('object_id', 'varchar'),
        ('timestamp', 'timestamptz'),
        ('user_id', 'varchar'),
        ('comment', 'text'),
        ('serialization', 'text'),
        ('label', 'varchar'),
    ]

    @guard_postgres
    @invalidate_substitution_cache
    def store(self):
        """Insert (no id yet) or update (existing id) the snapshot row."""
        sql_dict = {x: getattr(self, x) for x, y in self._table_static_fields}

        conn, cur = get_connection_and_cursor()
        if not self.id:
            # let the serial sequence assign the id
            column_names = [x for x in sql_dict.keys() if x != 'id']
            sql_statement = '''INSERT INTO %s (%s)
                               VALUES (%s)
                               RETURNING id''' % (
                                   self._table_name,
                                   ', '.join(column_names),
                                   ', '.join(['%%(%s)s' % x for x in column_names]))
            cur.execute(sql_statement, sql_dict)
            self.id = cur.fetchone()[0]
        else:
            column_names = sql_dict.keys()
            sql_statement = '''UPDATE %s SET %s WHERE id = %%(id)s RETURNING id''' % (
                    self._table_name,
                    ', '.join(['%s = %%(%s)s' % (x, x) for x in column_names]))
            cur.execute(sql_statement, sql_dict)
            if cur.fetchone() is None:
                # updating an id that does not exist is a programming error
                raise AssertionError()

        conn.commit()
        cur.close()

    @classmethod
    def select_object_history(cls, obj, clause=None):
        """Return the snapshots of ``obj``, most recent first."""
        return cls.select([
            Equal('object_type', obj.xml_root_node),
            Equal('object_id', obj.id)] + (clause or []),
            order_by='-timestamp')

    @classmethod
    def _row2ob(cls, row, **kwargs):
        """Build a Snapshot from a database row."""
        o = cls()
        for field, value in zip(cls._table_static_fields, tuple(row)):
            if field[1] in ('serial', 'timestamptz'):
                setattr(o, field[0], value)
            elif field[1] in ('varchar', 'text'):
                setattr(o, field[0], str_encode(value))
        return o

    @classmethod
    def get_data_fields(cls):
        # no extra data columns besides the static fields
        return []

    @classmethod
    def get_latest(cls, object_type, object_id):
        """Return the most recent snapshot for the given object, or None."""
        conn, cur = get_connection_and_cursor()
        sql_statement = '''SELECT id FROM snapshots
                            WHERE object_type = %(object_type)s
                              AND object_id = %(object_id)s
                         ORDER BY timestamp DESC
                            LIMIT 1'''
        cur.execute(sql_statement, {'object_type': object_type, 'object_id': object_id})
        row = cur.fetchone()
        conn.commit()
        cur.close()
        if row is None:
            return None
        return cls.get(row[0])
|
|
|
|
|
|
class classproperty(object):
    """Descriptor implementing a read-only property computed on the class.

    Unlike the builtin property, the wrapped getter receives the owner
    class, even when the attribute is accessed through an instance.
    """

    def __init__(self, f):
        # keep the getter; it is invoked on every attribute access
        self.f = f

    def __get__(self, obj, owner):
        # ignore the instance, always bind to the owner class
        return self.f(owner)
|
|
|
|
|
|
class AnyFormData(SqlMixin):
    """Read-only access to formdatas of every formdef, through the global
    wcs_all_forms view; rows are materialized as instances of the matching
    formdef's data class.
    """
    _table_name = 'wcs_all_forms'
    # class-private (name-mangled) cache behind the _table_static_fields
    # classproperty below
    __table_static_fields = []
    _formdef_cache = {}

    @classproperty
    def _table_static_fields(cls):
        # computed once from the view definition, then memoized
        if cls.__table_static_fields:
            return cls.__table_static_fields
        from wcs.formdef import FormDef
        fake_formdef = FormDef()
        common_fields = get_view_fields(fake_formdef)
        cls.__table_static_fields = [(x[1], x[0]) for x in common_fields]
        cls.__table_static_fields.append(('criticality_level', 'criticality_level'))
        cls.__table_static_fields.append(('geoloc_base_x', 'geoloc_base_x'))
        cls.__table_static_fields.append(('geoloc_base_y', 'geoloc_base_y'))
        cls.__table_static_fields.append(('concerned_roles_array', 'concerned_roles_array'))
        cls.__table_static_fields.append(('anonymised', 'anonymised'))
        return cls.__table_static_fields

    @classmethod
    def get_data_fields(cls):
        # the global view only exposes the common (static) columns
        return []

    @classmethod
    def get_objects(cls, *args, **kwargs):
        # reset the per-query formdef cache used by _row2ob()
        cls._formdef_cache = {}
        return super(AnyFormData, cls).get_objects(*args, **kwargs)

    @classmethod
    def _row2ob(cls, row, **kwargs):
        # second column of the view is the formdef id
        formdef_id = row[1]
        from wcs.formdef import FormDef
        formdef = cls._formdef_cache.setdefault(formdef_id, FormDef.get(formdef_id))
        o = formdef.data_class()()
        for static_field, value in zip(cls._table_static_fields,
                                       tuple(row[:len(cls._table_static_fields)])):
            setattr(o, static_field[0], str_encode(value))
        # [CRITICALITY_2] transform criticality_level back to the expected
        # range (see [CRITICALITY_1])
        levels = len(formdef.workflow.criticality_levels or [0])
        o.criticality_level = levels + o.criticality_level
        # convert back unstructured geolocation to the 'native' formdata format.
        if o.geoloc_base_x is not None:
            o.geolocations = {'base': {'lon': o.geoloc_base_x, 'lat': o.geoloc_base_y}}
        return o

    @classmethod
    @guard_postgres
    def load_all_evolutions(cls, formdatas):
        """Batch-load the evolutions of formdatas, grouped by their table."""
        classes = {}
        for formdata in formdatas:
            if not formdata._table_name in classes:
                classes[formdata._table_name] = []
            classes[formdata._table_name].append(formdata)
        for formdatas in classes.values():
            # any formdata of the group can drive the batch load
            formdatas[0].load_all_evolutions(formdatas)
|
|
|
|
|
|
def get_period_query(period_start=None, period_end=None, criterias=None, parameters=None):
    """Build the FROM/WHERE part of a statistics query.

    ``parameters`` must be a dict despite its None default; it is updated
    in place with the SQL parameters of the generated clauses.  When an
    Equal('formdef_id', ...) criteria is present the query targets that
    formdef's own table (giving access to all of its columns) instead of
    the global wcs_all_forms view.
    """
    clause = [NotNull('receipt_time')]
    table_name = 'wcs_all_forms'
    if criterias:
        for criteria in criterias:
            if criteria.__class__.__name__ == 'Equal' and \
                    criteria.attribute == 'formdef_id':
                # if there's a formdef_id specified, switch to using the
                # specific table so we have access to all fields
                from wcs.formdef import FormDef
                table_name = get_formdef_table_name(FormDef.get(criteria.value))
                continue
            clause.append(criteria)
    if period_start:
        clause.append(GreaterOrEqual('receipt_time', period_start))
    if period_end:
        clause.append(LessOrEqual('receipt_time', period_end))
    where_clauses, params, func_clause = parse_clause(clause)
    parameters.update(params)
    statement = ' FROM %s ' % table_name
    statement += ' WHERE ' + ' AND '.join(where_clauses)
    return statement
|
|
|
|
|
|
@guard_postgres
def get_actionable_counts(user_roles):
    """Return {formdef_id: count} of open formdatas the roles can act on."""
    conn, cur = get_connection_and_cursor()
    criterias = [
        Equal('is_at_endpoint', False),
        Intersects('actions_roles_array', user_roles),
    ]
    where_clauses, parameters, func_clause = parse_clause(criterias)
    statement = '''SELECT formdef_id, COUNT(*)
                     FROM wcs_all_forms
                    WHERE %s
                 GROUP BY formdef_id''' % ' AND '.join(where_clauses)
    cur.execute(statement, parameters)
    counts = {}
    for formdef_id, count in cur.fetchall():
        counts[str(formdef_id)] = count
    conn.commit()
    cur.close()
    return counts
|
|
|
|
|
|
@guard_postgres
def get_total_counts(user_roles):
    """Return {formdef_id: count} of formdatas concerning the given roles."""
    conn, cur = get_connection_and_cursor()
    criterias = [Intersects('concerned_roles_array', user_roles)]
    where_clauses, parameters, func_clause = parse_clause(criterias)
    statement = '''SELECT formdef_id, COUNT(*)
                     FROM wcs_all_forms
                    WHERE %s
                 GROUP BY formdef_id''' % ' AND '.join(where_clauses)
    cur.execute(statement, parameters)
    counts = {}
    for formdef_id, count in cur.fetchall():
        counts[str(formdef_id)] = count
    conn.commit()
    cur.close()
    return counts
|
|
|
|
|
|
@guard_postgres
def get_weekday_totals(period_start=None, period_end=None, criterias=None):
    """Return (weekday, count) pairs over the period (dow per PostgreSQL,
    0 = Sunday); weekdays with no formdata are reported with a 0 count.

    Cleanups: removed a stray ``''`` literal accidentally concatenated to
    the GROUP BY clause, and used the ``not in`` operator.
    """
    conn, cur = get_connection_and_cursor()
    statement = '''SELECT DATE_PART('dow', receipt_time) AS weekday, COUNT(*)'''
    parameters = {}
    statement += get_period_query(period_start, period_end, criterias, parameters)
    statement += ' GROUP BY weekday ORDER BY weekday'
    cur.execute(statement, parameters)

    result = cur.fetchall()
    result = [(int(x), y) for x, y in result]
    # zero-fill the weekdays missing from the result
    coverage = [x[0] for x in result]
    for weekday in range(7):
        if weekday not in coverage:
            result.append((weekday, 0))
    result.sort()

    conn.commit()
    cur.close()

    return result
|
|
|
|
|
|
@guard_postgres
def get_hour_totals(period_start=None, period_end=None, criterias=None):
    """Return (hour, count) pairs over the period, zero-filled so all 24
    hours are present."""
    conn, cur = get_connection_and_cursor()
    parameters = {}
    statement = '''SELECT DATE_PART('hour', receipt_time) AS hour, COUNT(*)'''
    statement += get_period_query(period_start, period_end, criterias, parameters)
    statement += ' GROUP BY hour ORDER BY hour'
    cur.execute(statement, parameters)

    totals = [(int(hour), count) for hour, count in cur.fetchall()]

    # zero-fill the hours missing from the result
    seen_hours = set(hour for hour, count in totals)
    totals.extend((hour, 0) for hour in range(24) if hour not in seen_hours)
    totals.sort()

    conn.commit()
    cur.close()

    return totals
|
|
|
|
|
|
@guard_postgres
def get_monthly_totals(period_start=None, period_end=None, criterias=None):
    """Return ('YYYY-MM', count) pairs over the period; months between the
    first and last formdata with no formdata of their own get a 0 count.

    Cleanups: removed a stray ``''`` literal accidentally concatenated to
    the GROUP BY clause, and used the ``not in`` operator.
    """
    conn, cur = get_connection_and_cursor()
    statement = '''SELECT DATE_TRUNC('month', receipt_time) AS month, COUNT(*) '''
    parameters = {}
    statement += get_period_query(period_start, period_end, criterias, parameters)
    statement += ' GROUP BY month ORDER BY month'
    cur.execute(statement, parameters)

    raw_result = cur.fetchall()
    result = [('%d-%02d' % x.timetuple()[:2], y) for x, y in raw_result]
    if result:
        # zero-fill: walk month by month from the first to the last one,
        # adding a (label, 0) pair for every month absent from the result.
        coverage = [x[0] for x in result]
        current_month = raw_result[0][0]
        last_month = raw_result[-1][0]
        while current_month < last_month:
            label = '%d-%02d' % current_month.timetuple()[:2]
            if label not in coverage:
                result.append((label, 0))
            # jump to the first day of the following month
            current_month = current_month + datetime.timedelta(days=31)
            current_month = current_month - datetime.timedelta(days=current_month.day-1)
        result.sort()

    conn.commit()
    cur.close()

    return result
|
|
|
|
|
|
@guard_postgres
def get_yearly_totals(period_start=None, period_end=None, criterias=None):
    """Return ('YYYY', count) pairs over the period; years between the
    first and last formdata with no formdata of their own get a 0 count."""
    conn, cur = get_connection_and_cursor()
    parameters = {}
    statement = '''SELECT DATE_TRUNC('year', receipt_time) AS year, COUNT(*)'''
    statement += get_period_query(period_start, period_end, criterias, parameters)
    statement += ' GROUP BY year ORDER BY year'
    cur.execute(statement, parameters)

    raw_result = cur.fetchall()
    result = [(str(x.year), y) for x, y in raw_result]
    if result:
        labels = [x[0] for x in result]
        cursor_year = raw_result[0][0]
        final_year = raw_result[-1][0]
        # zero-fill the years missing between the first and the last one
        while cursor_year < final_year:
            label = str(cursor_year.year)
            if label not in labels:
                result.append((label, 0))
            cursor_year = cursor_year + datetime.timedelta(days=366)
        result.sort()

    conn.commit()
    cur.close()

    return result
|
|
|
|
|
|
# current schema version; migrate() applies the pending migration steps and
# records this value in the wcs_meta table.
SQL_LEVEL = 42
|
|
|
|
|
|
def migrate_global_views(conn, cur):
    """Rebuild the global wcs_all_forms view(s) when the formdef_id column
    is missing.

    bugfix: the previous query counted rows in information_schema.tables,
    so existing_fields held a row count instead of column names and the
    views were dropped and rebuilt on every call; list the actual columns
    from information_schema.columns instead.
    """
    cur.execute('''SELECT column_name FROM information_schema.columns
                    WHERE table_schema = 'public'
                      AND table_name = %s''', ('wcs_all_forms',))
    existing_fields = set([x[0] for x in cur.fetchall()])
    if 'formdef_id' not in existing_fields:
        drop_global_views(conn, cur)
        do_global_views(conn, cur)
|
|
|
|
|
|
@guard_postgres
def get_sql_level(conn, cur):
    """Return the schema version recorded in the wcs_meta table."""
    # make sure the meta table exists, without touching the recorded level
    do_meta_table(conn, cur, insert_current_sql_level=False)
    cur.execute('''SELECT value FROM wcs_meta WHERE key = %s''', ('sql_level', ))
    return int(cur.fetchone()[0])
|
|
|
|
|
|
@guard_postgres
def is_reindex_needed(index, conn, cur):
    """Tell whether the given index ('user', 'formdata', ...) is flagged
    for reindexing in wcs_meta; initialise the flag to 'no' when absent."""
    do_meta_table(conn, cur, insert_current_sql_level=False)
    key_name = 'reindex_%s' % index
    cur.execute('''SELECT value FROM wcs_meta WHERE key = %s''', (key_name, ))
    row = cur.fetchone()
    if row is not None:
        return row[0] == 'needed'
    # unknown flag: record it as not needed
    cur.execute('''INSERT INTO wcs_meta (id, key, value)
                    VALUES (DEFAULT, %s, %s)''', (key_name, 'no'))
    return False
|
|
|
|
|
|
@guard_postgres
def set_reindex(index, value, conn=None, cur=None):
    """Record the reindex flag for the given index in wcs_meta.

    When no connection/cursor is given, one is opened, committed and
    closed locally.
    """
    own_conn = not conn
    if own_conn:
        conn, cur = get_connection_and_cursor()
    do_meta_table(conn, cur, insert_current_sql_level=False)
    key_name = 'reindex_%s' % index
    cur.execute('''SELECT value FROM wcs_meta WHERE key = %s''', (key_name, ))
    if cur.fetchone() is not None:
        cur.execute('''UPDATE wcs_meta SET value = %s WHERE key = %s''', (
            value, key_name))
    else:
        cur.execute('''INSERT INTO wcs_meta (id, key, value)
                        VALUES (DEFAULT, %s, %s)''', (key_name, value))
    if own_conn:
        conn.commit()
        cur.close()
|
|
|
|
|
|
def migrate_views(conn, cur):
    """Drop then recreate the per-formdef views, and the global views."""
    drop_views(None, conn, cur)
    from wcs.formdef import FormDef
    from wcs.carddef import CardDef
    # make sure all formdefs have up-to-date views
    for formdef in FormDef.select() + CardDef.select():
        do_formdef_tables(formdef, conn=conn, cur=cur, rebuild_views=True, rebuild_global_views=False)
    migrate_global_views(conn, cur)
|
|
|
|
|
|
@guard_postgres
def migrate():
    """Apply all pending SQL schema migrations.

    The recorded schema version is read from wcs_meta; every migration
    step whose introduction level is newer than the recorded level is
    applied, then the recorded level is bumped to SQL_LEVEL.
    """
    conn, cur = get_connection_and_cursor()
    sql_level = get_sql_level(conn, cur)
    if sql_level < 0:
        # fake code to help in testing the error code path.
        raise RuntimeError()
    if sql_level < 1: # 1: introduction of tracking_code table
        do_tracking_code_table()
    if sql_level < 38:
        # 2: introduction of formdef_id in views
        # 5: add concerned_roles_array, is_at_endpoint and fts to views
        # 7: add backoffice_submission to tables
        # 8: add submission_context to tables
        # 9: add last_update_time to views
        # 10: add submission_channel to tables
        # 11: add formdef_name and user_name to views
        # 13: add backoffice_submission to views
        # 14: add criticality_level to tables & views
        # 15: add geolocation to formdata
        # 19: add geolocation to views
        # 20: remove user hash stuff
        # 22: rebuild views
        # 26: add digest to formdata
        # 27: add last_jump_datetime in evolutions tables
        # 31: add user_label to formdata
        # 33: add anonymised field to global view
        # 38: extract submission_agent_id to its own column
        migrate_views(conn, cur)
    if sql_level < 40:
        # 3: introduction of _structured for user fields
        # 4: removal of identification_token
        # 12: (first part) add fts to users
        # 16: add verified_fields to users
        # 21: (first part) add ascii_name to users
        # 39: add deleted_timestamp
        # 40: add is_active to users
        do_user_table()
    if sql_level < 6:
        # 6: add actions_roles_array to tables and views
        from wcs.formdef import FormDef
        migrate_views(conn, cur)
        for formdef in FormDef.select():
            formdef.data_class().rebuild_security()
    if sql_level < 23:
        # 12: (second part), store fts in existing rows
        # 21: (second part), store ascii_name of users
        # 23: (first part), use misc.simplify() over full text queries
        set_reindex('user', 'needed', conn=conn, cur=cur)
    if sql_level < 41:
        # 17: store last_update_time in tables
        # 18: add user name to full-text search index
        # 21: (third part), add user ascii_names to full-text index
        # 23: (second part) use misc.simplify() over full text queries
        # 28: add display id and formdef name to full-text index
        # 29: add evolution parts to full-text index
        # 31: add user_label to formdata
        # 38: extract submission_agent_id to its own column
        # 41: update full text normalization
        set_reindex('formdata', 'needed', conn=conn, cur=cur)
    if sql_level < 36:
        from wcs.formdef import FormDef
        # 24: add index on evolution(formdata_id)
        # 35: add indexes on formdata(receipt_time) and formdata(anonymised)
        # 36: add index on formdata(user_id)
        for formdef in FormDef.select():
            do_formdef_indexes(formdef, created=False, conn=conn, cur=cur)
    if sql_level < 32:
        # 25: create session_table
        # 32: add last_update_time column to session table
        do_session_table()
    if sql_level < 37:
        # 37: create custom_views table
        do_custom_views_table()
    if sql_level < 30:
        # 30: actually remove evo.who on anonymised formdatas
        from wcs.formdef import FormDef
        for formdef in FormDef.select():
            for formdata in formdef.data_class().select_iterator(clause=[NotNull('anonymised')]):
                if formdata.evolution:
                    for evo in formdata.evolution:
                        evo.who = None
                    formdata.store()
    if sql_level < 42:
        # 42: create snapshots table
        do_snapshots_table()

    # record the now-current schema level
    cur.execute('''UPDATE wcs_meta SET value = %s WHERE key = %s''', (
        str(SQL_LEVEL), 'sql_level'))

    conn.commit()
    cur.close()
|
|
|
|
|
|
@guard_postgres
def reindex():
    """Recompute the denormalised/indexed data flagged as stale in wcs_meta."""
    conn, cur = get_connection_and_cursor()

    if is_reindex_needed('user', conn=conn, cur=cur):
        # re-storing a user refreshes its derived columns (fts, ...)
        for user in SqlUser.select(iterator=True):
            user.store()
        set_reindex('user', 'done', conn=conn, cur=cur)

    from wcs.formdef import FormDef
    if is_reindex_needed('formdata', conn=conn, cur=cur):
        # load and store all formdatas
        for formdef in FormDef.select():
            for formdata in formdef.data_class().select(iterator=True):
                formdata.migrate()
                formdata.store()
        set_reindex('formdata', 'done', conn=conn, cur=cur)

    conn.commit()
    cur.close()
|