# w.c.s. - web application for online forms
# Copyright (C) 2005-2012 Entr'ouvert
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 2 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, see <http://www.gnu.org/licenses/>.

import psycopg2
import psycopg2.extensions
import datetime
import time
import re
import cPickle

from quixote import get_publisher

import qommon
from qommon.storage import _take, parse_clause as parse_storage_clause
from qommon.substitution import invalidate_substitution_cache
from qommon import get_cfg

import wcs.categories
import wcs.formdata
import wcs.sessions
import wcs.tracking_code
import wcs.users

# enable psycopg2 unicode mode, this will fetch postgresql varchar/text
# columns as unicode objects
psycopg2.extensions.register_type(psycopg2.extensions.UNICODE)
psycopg2.extensions.register_type(psycopg2.extensions.UNICODEARRAY)


SQL_TYPE_MAPPING = {
    'title': None,
    'subtitle': None,
    'comment': None,
    'page': None,
    'text': 'text',
    'bool': 'boolean',
    'file': 'bytea',
    'date': 'date',
    'items': 'text[]',
    'table': 'text[][]',
    'table-select': 'text[][]',
    'tablerows': 'text[][]',
    # mapping of dicts
    'ranked-items': 'text[][]',
    'password': 'text[][]',
}


class Criteria(qommon.storage.Criteria):
    def __init__(self, attribute, value, **kwargs):
        self.attribute = attribute.replace('-', '_')
        self.value = value

    def as_sql(self):
        return '%s %s %%(c%s)s' % (self.attribute, self.sql_op, id(self.value))

    def as_sql_param(self):
        if isinstance(self.value, time.struct_time):
            value = datetime.datetime.fromtimestamp(time.mktime(self.value))
        else:
            value = self.value
        return {'c%s' % id(self.value): value}


class Less(Criteria):
    sql_op = '<'


class Greater(Criteria):
    sql_op = '>'


class Equal(Criteria):
    sql_op = '='

    def as_sql(self):
        if self.value in ([], ()):
            return 'ARRAY_LENGTH(%s, 1) IS NULL' % self.attribute
        return super(Equal, self).as_sql()


class LessOrEqual(Criteria):
    sql_op = '<='


class GreaterOrEqual(Criteria):
    sql_op = '>='


class NotEqual(Criteria):
    sql_op = '!='


class Contains(Criteria):
    sql_op = 'IN'

    def as_sql(self):
        if not self.value:
            return 'FALSE'
        return '%s %s %%(c%s)s' % (self.attribute, self.sql_op, id(self.value))

    def as_sql_param(self):
        return {'c%s' % id(self.value): tuple(self.value)}


class NotContains(Contains):
    sql_op = 'NOT IN'

    def as_sql(self):
        if not self.value:
            return 'TRUE'
        return super(NotContains, self).as_sql()


class NotNull(Criteria):
    sql_op = 'IS NOT NULL'

    def __init__(self, attribute):
        self.attribute = attribute

    def as_sql(self):
        return '%s %s' % (self.attribute, self.sql_op)

    def as_sql_param(self):
        return {}


class Null(Criteria):
    sql_op = 'IS NULL'

    def __init__(self, attribute):
        self.attribute = attribute

    def as_sql(self):
        return '%s %s' % (self.attribute, self.sql_op)

    def as_sql_param(self):
        return {}


class Or(Criteria):
    def __init__(self, criterias, **kwargs):
        self.criterias = []
        for element in criterias:
            sql_class = globals().get(element.__class__.__name__)
            sql_element = sql_class(**element.__dict__)
            self.criterias.append(sql_element)

    def as_sql(self):
        return '( %s )' % ' OR '.join([x.as_sql() for x in self.criterias])

    def as_sql_param(self):
        d = {}
        for criteria in self.criterias:
            d.update(criteria.as_sql_param())
        return d


class And(Criteria):
    def __init__(self, criterias, **kwargs):
        self.criterias = []
        for element in criterias:
            sql_class = globals().get(element.__class__.__name__)
            sql_element = sql_class(**element.__dict__)
            self.criterias.append(sql_element)

    def as_sql(self):
        return '( %s )' % ' AND '.join([x.as_sql() for x in self.criterias])

    def as_sql_param(self):
        d = {}
        for criteria in self.criterias:
            d.update(criteria.as_sql_param())
        return d


class Intersects(Criteria):
    def as_sql(self):
        if not self.value:
            return 'ARRAY_LENGTH(%s, 1) IS NULL' % self.attribute
        else:
            return '%s && %%(c%s)s' % (self.attribute, id(self.value))

    def as_sql_param(self):
        return {'c%s' % id(self.value): list(self.value)}


class ILike(Criteria):
    def __init__(self, attribute, value, **kwargs):
        super(ILike, self).__init__(attribute, value, **kwargs)
        self.value = '%' + self.value + '%'

    def as_sql(self):
        return '%s ILIKE %%(c%s)s' % (self.attribute, id(self.value))


class FtsMatch(Criteria):
    def __init__(self, value):
        self.value = qommon.misc.simplify(value, space=' ')

    def as_sql(self):
        return 'fts @@ plainto_tsquery(%%(c%s)s)' % id(self.value)


def get_name_as_sql_identifier(name):
    name = qommon.misc.simplify(name)
    for char in '<>|{}!?^*+/=\'':  # forbidden chars
        name = name.replace(char, '')
    name = name.replace('-', '_')
    return name


def get_field_id(field):
    return 'f' + str(field.id).replace('-', '_').lower()


def parse_clause(clause):
    # returns a three-elements tuple with:
    #  - a list of SQL 'WHERE' clauses
    #  - a dict for query parameters
    #  - a callable, or None if all clauses have been successfully
    #    translated
    if clause is None:
        return ([], None, None)
    if callable(clause):  # already a callable
        return ([], None, clause)
    # create 'WHERE' clauses
    func_clauses = []
    where_clauses = []
    parameters = {}
    for element in clause:
        if callable(element):
            func_clauses.append(element)
        else:
            sql_class = globals().get(element.__class__.__name__)
            if sql_class:
                sql_element = sql_class(**element.__dict__)
                where_clauses.append(sql_element.as_sql())
                parameters.update(sql_element.as_sql_param())
            else:
                func_clauses.append(element.build_lambda())
    if func_clauses:
        return (where_clauses, parameters, parse_storage_clause(func_clauses))
    else:
        return (where_clauses, parameters, None)

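
# Illustration (hypothetical session; the c<id> parameter keys depend on
# object identity): parse_clause() turns storage criteria into SQL fragments
# plus a parameter dict:
#
#   where_clauses, parameters, func_clause = parse_clause(
#       [Equal('status', 'wf-new')])
#   # where_clauses -> ['status = %(c140268...)s']
#   # parameters    -> {'c140268...': 'wf-new'}
#   # func_clause   -> None, as every criteria was translated to SQL
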

def str_encode(value):
    if isinstance(value, list):
        return [str_encode(x) for x in value]
    if not isinstance(value, unicode):
        return value
    return value.encode(get_publisher().site_charset)


def site_unicode(value):
    if not isinstance(value, basestring):
        value = unicode(value)
    if isinstance(value, unicode):
        return value
    return unicode(value, get_publisher().site_charset)


def get_connection(new=False):
    if new:
        cleanup_connection()
    if not hasattr(get_publisher(), 'pgconn') or get_publisher().pgconn is None:
        postgresql_cfg = {}
        for param in ('database', 'user', 'password', 'host', 'port'):
            value = get_cfg('postgresql', {}).get(param)
            if value:
                postgresql_cfg[param] = value
        try:
            get_publisher().pgconn = psycopg2.connect(**postgresql_cfg)
        except psycopg2.Error:
            if new:
                raise
            get_publisher().pgconn = None
        else:
            cur = get_publisher().pgconn.cursor()
            cur.execute('SHOW server_version_num')
            get_publisher().pg_version = int(cur.fetchone()[0])
            cur.close()
    return get_publisher().pgconn


def cleanup_connection():
    if hasattr(get_publisher(), 'pgconn') and get_publisher().pgconn is not None:
        get_publisher().pgconn.close()
        get_publisher().pgconn = None


def get_connection_and_cursor(new=False):
    conn = get_connection(new=new)
    try:
        cur = conn.cursor()
    except psycopg2.InterfaceError:
        # maybe postgresql was restarted in between
        conn = get_connection(new=True)
        cur = conn.cursor()
    return (conn, cur)


def get_formdef_table_name(formdef):
    # PostgreSQL limits identifier length to 63 bytes
    #
    #   The system uses no more than NAMEDATALEN-1 bytes of an identifier;
    #   longer names can be written in commands, but they will be truncated.
    #   By default, NAMEDATALEN is 64 so the maximum identifier length is
    #   63 bytes. If this limit is problematic, it can be raised by changing
    #   the NAMEDATALEN constant in src/include/pg_config_manual.h.
    #
    # as we have to know our table names, we crop the names here, and to an
    # extent that allows suffixes (like _evolutions) to be added.
    assert formdef.id is not None
    if hasattr(formdef, 'table_name') and formdef.table_name:
        return formdef.table_name
    formdef.table_name = 'formdata_%s_%s' % (
        formdef.id, get_name_as_sql_identifier(formdef.url_name)[:30])
    formdef.store()
    return formdef.table_name

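
# Illustration (hypothetical formdef): with id 12 and url_name
# 'building-permit-application-long-slug', the slug part is simplified and
# cropped to 30 characters, giving the table name
# 'formdata_12_building_permit_application_lo'; its companion
# '..._evolutions' table then still fits within the 63-byte identifier
# limit mentioned above.
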

def get_formdef_new_id(id_start):
    new_id = id_start
    conn, cur = get_connection_and_cursor()
    while True:
        cur.execute('''SELECT COUNT(*) FROM information_schema.tables
                        WHERE table_schema = 'public'
                          AND table_name LIKE %s''',
                    ('formdata\\_%s\\_%%' % new_id,))
        if cur.fetchone()[0] == 0:
            break
        new_id += 1
    conn.commit()
    cur.close()
    return new_id


def formdef_wipe():
    conn, cur = get_connection_and_cursor()
    cur.execute('''SELECT table_name FROM information_schema.tables
                    WHERE table_schema = 'public'
                      AND table_name LIKE %s''',
                ('formdata\\_%%\\_%%',))
    for table_name in [x[0] for x in cur.fetchall()]:
        cur.execute('''DROP TABLE %s CASCADE''' % table_name)
    conn.commit()
    cur.close()


def get_formdef_view_name(formdef):
    return 'wcs_view_%s_%s' % (
        formdef.id, get_name_as_sql_identifier(formdef.url_name)[:40])


def guard_postgres(func):
    def f(*args, **kwargs):
        try:
            return func(*args, **kwargs)
        except psycopg2.Error:
            get_connection().rollback()
            raise
    return f


@guard_postgres
def do_formdef_tables(formdef, conn=None, cur=None, rebuild_views=False,
                      rebuild_global_views=True):
    if formdef.id is None:
        return []

    if getattr(formdef, 'fields', None) is Ellipsis:
        # don't touch tables for lightweight objects
        return []

    own_conn = False
    if not conn:
        own_conn = True
        conn, cur = get_connection_and_cursor()

    table_name = get_formdef_table_name(formdef)

    cur.execute('''SELECT COUNT(*) FROM information_schema.tables
                    WHERE table_schema = 'public'
                      AND table_name = %s''', (table_name,))
    if cur.fetchone()[0] == 0:
        cur.execute('''CREATE TABLE %s (id serial PRIMARY KEY,
                                        user_id varchar,
                                        receipt_time timestamp,
                                        anonymised timestamptz,
                                        status varchar,
                                        page_no varchar,
                                        workflow_data bytea,
                                        id_display varchar
                                        )''' % table_name)
        cur.execute('''CREATE TABLE %s_evolutions (id serial PRIMARY KEY,
                                        who varchar,
                                        status varchar,
                                        time timestamp,
                                        last_jump_datetime timestamp,
                                        comment text,
                                        parts bytea,
                                        formdata_id integer REFERENCES %s (id) ON DELETE CASCADE)''' % (
                        table_name, table_name))
        do_formdef_indexes(formdef, created=True, conn=conn, cur=cur)

    cur.execute('''SELECT column_name FROM information_schema.columns
                    WHERE table_schema = 'public'
                      AND table_name = %s''', (table_name,))
    existing_fields = set([x[0] for x in cur.fetchall()])

    needed_fields = set(['id', 'user_id', 'receipt_time', 'status',
                         'workflow_data', 'id_display', 'fts', 'page_no',
                         'anonymised', 'workflow_roles',
                         'workflow_roles_array', 'concerned_roles_array',
                         'tracking_code', 'actions_roles_array',
                         'backoffice_submission', 'submission_context',
                         'submission_channel', 'criticality_level',
                         'last_update_time', 'digest', 'user_label'])

    # migrations
    if not 'fts' in existing_fields:
        # full text search
        cur.execute('''ALTER TABLE %s ADD COLUMN fts tsvector''' % table_name)
        cur.execute('''CREATE INDEX %s_fts ON %s USING gin(fts)''' % (
            table_name, table_name))

    if not 'workflow_roles' in existing_fields:
        cur.execute('''ALTER TABLE %s ADD COLUMN workflow_roles bytea''' % table_name)
        cur.execute('''ALTER TABLE %s ADD COLUMN workflow_roles_array text[]''' % table_name)

    if not 'concerned_roles_array' in existing_fields:
        cur.execute('''ALTER TABLE %s ADD COLUMN concerned_roles_array text[]''' % table_name)

    if not 'actions_roles_array' in existing_fields:
        cur.execute('''ALTER TABLE %s ADD COLUMN actions_roles_array text[]''' % table_name)

    if not 'page_no' in existing_fields:
        cur.execute('''ALTER TABLE %s ADD COLUMN page_no varchar''' % table_name)

    if not 'anonymised' in existing_fields:
        cur.execute('''ALTER TABLE %s ADD COLUMN anonymised timestamptz''' % table_name)

    if not 'tracking_code' in existing_fields:
        cur.execute('''ALTER TABLE %s ADD COLUMN tracking_code varchar''' % table_name)

    if not 'backoffice_submission' in existing_fields:
        cur.execute('''ALTER TABLE %s ADD COLUMN backoffice_submission boolean''' % table_name)

    if not 'submission_context' in existing_fields:
        cur.execute('''ALTER TABLE %s ADD COLUMN submission_context bytea''' % table_name)

    if not 'submission_channel' in existing_fields:
        cur.execute('''ALTER TABLE %s ADD COLUMN submission_channel varchar''' % table_name)

    if not 'criticality_level' in existing_fields:
        cur.execute('''ALTER TABLE %s ADD COLUMN criticality_level integer NOT NULL DEFAULT(0)''' % table_name)

    if not 'last_update_time' in existing_fields:
        cur.execute('''ALTER TABLE %s ADD COLUMN last_update_time timestamp''' % table_name)

    if not 'digest' in existing_fields:
        cur.execute('''ALTER TABLE %s ADD COLUMN digest varchar''' % table_name)

    if not 'user_label' in existing_fields:
        cur.execute('''ALTER TABLE %s ADD COLUMN user_label varchar''' % table_name)

    # add new fields
    for field in formdef.get_all_fields():
        assert field.id is not None
        sql_type = SQL_TYPE_MAPPING.get(field.key, 'varchar')
        if sql_type is None:
            continue
        needed_fields.add(get_field_id(field))
        if get_field_id(field) not in existing_fields:
            cur.execute('''ALTER TABLE %s ADD COLUMN %s %s''' % (
                table_name, get_field_id(field), sql_type))
        if field.store_display_value:
            needed_fields.add('%s_display' % get_field_id(field))
            if '%s_display' % get_field_id(field) not in existing_fields:
                cur.execute('''ALTER TABLE %s ADD COLUMN %s varchar''' % (
                    table_name, '%s_display' % get_field_id(field)))
        if field.store_structured_value:
            needed_fields.add('%s_structured' % get_field_id(field))
            if '%s_structured' % get_field_id(field) not in existing_fields:
                cur.execute('''ALTER TABLE %s ADD COLUMN %s bytea''' % (
                    table_name, '%s_structured' % get_field_id(field)))

    for field in (formdef.geolocations or {}).keys():
        column_name = 'geoloc_%s' % field
        needed_fields.add(column_name)
        if column_name not in existing_fields:
            cur.execute('''ALTER TABLE %s ADD COLUMN %s %s''' % (
                table_name, column_name, 'POINT'))

    # delete obsolete fields
    for field in (existing_fields - needed_fields):
        cur.execute('''ALTER TABLE %s DROP COLUMN %s CASCADE''' % (
            table_name, field))

    # migrations on _evolutions table
    cur.execute('''SELECT column_name FROM information_schema.columns
                    WHERE table_schema = 'public'
                      AND table_name = '%s_evolutions' ''' % table_name)
    evo_existing_fields = set([x[0] for x in cur.fetchall()])
    if 'last_jump_datetime' not in evo_existing_fields:
        cur.execute('''ALTER TABLE %s_evolutions ADD COLUMN last_jump_datetime timestamp''' % table_name)

    if rebuild_views or len(existing_fields - needed_fields):
        # views may have been dropped when dropping columns, so we recreate
        # them even if not asked to.
        redo_views(conn, cur, formdef, rebuild_global_views=rebuild_global_views)

    if own_conn:
        conn.commit()
        cur.close()

    actions = []
    if not 'concerned_roles_array' in existing_fields:
        actions.append('rebuild_security')
    elif not 'actions_roles_array' in existing_fields:
        actions.append('rebuild_security')
    if not 'tracking_code' in existing_fields:
        # if tracking code has just been added to the table we need to make
        # sure the tracking code table does exist.
        actions.append('do_tracking_code_table')

    return actions


def do_formdef_indexes(formdef, created, conn, cur, concurrently=False):
    table_name = get_formdef_table_name(formdef)
    evolutions_table_name = table_name + '_evolutions'
    existing_indexes = set()
    if not created:
        cur.execute('''SELECT indexname FROM pg_indexes
                        WHERE schemaname = 'public'
                          AND tablename IN (%s, %s)''', (
                              table_name, evolutions_table_name))
        existing_indexes = set([x[0] for x in cur.fetchall()])
    create_index = 'CREATE INDEX'
    if concurrently:
        create_index = 'CREATE INDEX CONCURRENTLY'
    if not evolutions_table_name + '_fid' in existing_indexes:
        cur.execute('''%s %s_fid ON %s (formdata_id)''' % (
            create_index, evolutions_table_name, evolutions_table_name))

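
# Note on the concurrently=True path above: PostgreSQL refuses to run
# CREATE INDEX CONCURRENTLY inside a transaction block, so callers choosing
# that option are expected to provide a connection in autocommit mode.
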

@guard_postgres
def do_user_table():
    conn, cur = get_connection_and_cursor()
    table_name = 'users'

    cur.execute('''SELECT COUNT(*) FROM information_schema.tables
                    WHERE table_schema = 'public'
                      AND table_name = %s''', (table_name,))
    if cur.fetchone()[0] == 0:
        cur.execute('''CREATE TABLE %s (id serial PRIMARY KEY,
                                        name varchar,
                                        ascii_name varchar,
                                        email varchar,
                                        roles text[],
                                        is_admin bool,
                                        anonymous bool,
                                        verified_fields text[],
                                        name_identifiers text[],
                                        lasso_dump text,
                                        last_seen timestamp)''' % table_name)
    cur.execute('''SELECT column_name FROM information_schema.columns
                    WHERE table_schema = 'public'
                      AND table_name = %s''', (table_name,))
    existing_fields = set([x[0] for x in cur.fetchall()])

    needed_fields = set(['id', 'name', 'email', 'roles', 'is_admin',
                         'anonymous', 'name_identifiers', 'verified_fields',
                         'lasso_dump', 'last_seen', 'fts', 'ascii_name'])

    from admin.settings import UserFieldsFormDef
    formdef = UserFieldsFormDef()
    for field in formdef.get_all_fields():
        sql_type = SQL_TYPE_MAPPING.get(field.key, 'varchar')
        if sql_type is None:
            continue
        needed_fields.add(get_field_id(field))
        if get_field_id(field) not in existing_fields:
            cur.execute('''ALTER TABLE %s ADD COLUMN %s %s''' % (
                table_name, get_field_id(field), sql_type))
        if field.store_display_value:
            needed_fields.add('%s_display' % get_field_id(field))
            if '%s_display' % get_field_id(field) not in existing_fields:
                cur.execute('''ALTER TABLE %s ADD COLUMN %s varchar''' % (
                    table_name, '%s_display' % get_field_id(field)))
        if field.store_structured_value:
            needed_fields.add('%s_structured' % get_field_id(field))
            if '%s_structured' % get_field_id(field) not in existing_fields:
                cur.execute('''ALTER TABLE %s ADD COLUMN %s bytea''' % (
                    table_name, '%s_structured' % get_field_id(field)))

    # migrations
    if not 'fts' in existing_fields:
        # full text search
        cur.execute('''ALTER TABLE %s ADD COLUMN fts tsvector''' % table_name)
        cur.execute('''CREATE INDEX %s_fts ON %s USING gin(fts)''' % (
            table_name, table_name))

    if not 'verified_fields' in existing_fields:
        cur.execute('ALTER TABLE %s ADD COLUMN verified_fields text[]' % table_name)

    if not 'ascii_name' in existing_fields:
        cur.execute('ALTER TABLE %s ADD COLUMN ascii_name varchar' % table_name)

    # delete obsolete fields
    for field in (existing_fields - needed_fields):
        cur.execute('''ALTER TABLE %s DROP COLUMN %s''' % (table_name, field))

    conn.commit()

    try:
        cur.execute('CREATE INDEX users_name_idx ON users (name)')
        conn.commit()
    except psycopg2.ProgrammingError:
        conn.rollback()

    cur.close()


def do_tracking_code_table():
    conn, cur = get_connection_and_cursor()
    table_name = 'tracking_codes'

    cur.execute('''SELECT COUNT(*) FROM information_schema.tables
                    WHERE table_schema = 'public'
                      AND table_name = %s''', (table_name,))
    if cur.fetchone()[0] == 0:
        cur.execute('''CREATE TABLE %s (id varchar PRIMARY KEY,
                                        formdef_id varchar,
                                        formdata_id varchar)''' % table_name)
    cur.execute('''SELECT column_name FROM information_schema.columns
                    WHERE table_schema = 'public'
                      AND table_name = %s''', (table_name,))
    existing_fields = set([x[0] for x in cur.fetchall()])

    needed_fields = set(['id', 'formdef_id', 'formdata_id'])

    # delete obsolete fields
    for field in (existing_fields - needed_fields):
        cur.execute('''ALTER TABLE %s DROP COLUMN %s''' % (table_name, field))

    conn.commit()
    cur.close()


def do_session_table():
    conn, cur = get_connection_and_cursor()
    table_name = 'sessions'

    cur.execute('''SELECT COUNT(*) FROM information_schema.tables
                    WHERE table_schema = 'public'
                      AND table_name = %s''', (table_name,))
    if cur.fetchone()[0] == 0:
        cur.execute('''CREATE TABLE %s (id varchar PRIMARY KEY,
                                        session_data bytea,
                                        name_identifier varchar,
                                        visiting_objects_keys varchar[]
                                        )''' % table_name)
    cur.execute('''SELECT column_name FROM information_schema.columns
                    WHERE table_schema = 'public'
                      AND table_name = %s''', (table_name,))
    existing_fields = set([x[0] for x in cur.fetchall()])

    needed_fields = set(['id', 'session_data', 'name_identifier',
                         'visiting_objects_keys', 'last_update_time'])

    # migrations
    if not 'last_update_time' in existing_fields:
        cur.execute('''ALTER TABLE %s ADD COLUMN last_update_time timestamp DEFAULT NOW()''' % table_name)
        cur.execute('''CREATE INDEX %s_ts ON %s (last_update_time)''' % (
            table_name, table_name))

    # delete obsolete fields
    for field in (existing_fields - needed_fields):
        cur.execute('''ALTER TABLE %s DROP COLUMN %s''' % (table_name, field))

    conn.commit()
    cur.close()


@guard_postgres
def do_meta_table(conn=None, cur=None, insert_current_sql_level=True):
    own_conn = False
    if not conn:
        own_conn = True
        conn, cur = get_connection_and_cursor()

    cur.execute('''SELECT COUNT(*) FROM information_schema.tables
                    WHERE table_schema = 'public'
                      AND table_name = %s''', ('wcs_meta',))
    if cur.fetchone()[0] == 0:
        cur.execute('''CREATE TABLE wcs_meta (id serial PRIMARY KEY,
                                              key varchar,
                                              value varchar)''')
        if insert_current_sql_level:
            sql_level = SQL_LEVEL
        else:
            sql_level = 0
        cur.execute('''INSERT INTO wcs_meta (id, key, value)
                       VALUES (DEFAULT, %s, %s)''', (
                           'sql_level', str(sql_level)))

    if own_conn:
        conn.commit()
        cur.close()


@guard_postgres
def redo_views(conn, cur, formdef, rebuild_global_views=False):
    if get_publisher().get_site_option('postgresql_views') == 'false':
        return
    if formdef.id is None:
        return

    drop_views(formdef, conn, cur)
    do_views(formdef, conn, cur, rebuild_global_views=rebuild_global_views)


@guard_postgres
def drop_views(formdef, conn, cur):
    # remove the global views
    drop_global_views(conn, cur)

    if formdef:
        # remove the form view itself
        cur.execute('''SELECT table_name FROM information_schema.views
                        WHERE table_schema = 'public'
                          AND table_name LIKE %s''',
                    ('wcs\\_view\\_%s\\_%%' % formdef.id,))
    else:
        # if there's no formdef specified, remove all form views
        cur.execute('''SELECT table_name FROM information_schema.views
                        WHERE table_schema = 'public'
                          AND table_name LIKE %s''',
                    ('wcs\\_view\\_%',))

    view_names = []
    while True:
        row = cur.fetchone()
        if row is None:
            break
        view_names.append(row[0])

    for view_name in view_names:
        cur.execute('''DROP VIEW IF EXISTS %s''' % view_name)


def get_view_fields(formdef):
    view_fields = []
    view_fields.append(("int '%s'" % (formdef.category_id or 0), 'category_id'))
    view_fields.append(("int '%s'" % (formdef.id or 0), 'formdef_id'))
    for field in ('id', 'user_id', 'receipt_time', 'status', 'id_display',
                  'submission_channel', 'backoffice_submission',
                  'last_update_time', 'digest', 'user_label'):
        view_fields.append((field, field))
    return view_fields


@guard_postgres
def do_views(formdef, conn, cur, rebuild_global_views=True):
    # create new view
    table_name = get_formdef_table_name(formdef)
    view_name = get_formdef_view_name(formdef)
    view_fields = get_view_fields(formdef)

    column_names = {}
    for field in formdef.get_all_fields():
        field_key = get_field_id(field)
        if field.type in ('page', 'title', 'subtitle', 'comment'):
            continue
        if field.varname:
            # the variable should be fine as is but we pass it through
            # get_name_as_sql_identifier nevertheless, to be extra sure it
            # doesn't contain invalid characters.
            field_name = 'f_%s' % get_name_as_sql_identifier(field.varname)[:50]
        else:
            field_name = '%s_%s' % (get_field_id(field),
                                    get_name_as_sql_identifier(field.label))
            field_name = field_name[:50]
        if field_name in column_names:
            # it may happen that the same varname is used on multiple fields
            # (for example in the case of conditional pages), in that
            # situation we suffix the field name with an index count
            while field_name in column_names:
                column_names[field_name] += 1
                field_name = '%s_%s' % (field_name, column_names[field_name])
        column_names[field_name] = 1
        view_fields.append((field_key, field_name))
        if field.store_display_value:
            field_key = '%s_display' % get_field_id(field)
            view_fields.append((field_key, field_name + '_display'))

    view_fields.append(('''ARRAY(SELECT status FROM %s_evolutions '''
                        ''' WHERE %s.id = %s_evolutions.formdata_id'''
                        ''' ORDER BY %s_evolutions.time)''' % (
                            (table_name,) * 4),
                        'status_history'))

    # add a is_at_endpoint column, dynamically created against the endpoint
    # status.
    endpoint_status = formdef.workflow.get_endpoint_status()
    view_fields.append(('''(SELECT status = ANY(ARRAY[[%s]]::text[]))''' %
                        ', '.join(["'wf-%s'" % x.id for x in endpoint_status]),
                        '''is_at_endpoint'''))

    # [CRITICALITY_1] Add criticality_level, computed relative to levels in
    # the given workflow, so all higher criticalities are sorted first.
    # This is reverted when loading the formdata back, in [CRITICALITY_2]
    levels = len(formdef.workflow.criticality_levels or [0])
    view_fields.append(('''(criticality_level - %d)''' % levels,
                        '''criticality_level'''))

    view_fields.append((
        cur.mogrify('(SELECT text %s)', (formdef.name,)), 'formdef_name'))

    view_fields.append(('''(SELECT name FROM users
                             WHERE users.id = CAST(user_id AS INTEGER))''',
                        'user_name'))

    view_fields.append(('concerned_roles_array', 'concerned_roles_array'))
    view_fields.append(('actions_roles_array', 'actions_roles_array'))
    view_fields.append(('fts', 'fts'))

    if formdef.geolocations and 'base' in formdef.geolocations:
        # default geolocation is in the 'base' key; we have to unstructure
        # the field as the POINT type of postgresql cannot be used directly,
        # as it doesn't have an equality operator.
        view_fields.append(('geoloc_base[0]', 'geoloc_base_x'))
        view_fields.append(('geoloc_base[1]', 'geoloc_base_y'))
    else:
        view_fields.append(('NULL::real', 'geoloc_base_x'))
        view_fields.append(('NULL::real', 'geoloc_base_y'))

    fields_list = ', '.join(['%s AS %s' % x for x in view_fields])
    cur.execute('''CREATE VIEW %s AS SELECT %s FROM %s''' % (
        view_name, fields_list, table_name))

    if rebuild_global_views:
        do_global_views(conn, cur)  # recreate global views

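
# Worked example for [CRITICALITY_1] (hypothetical workflow with 3
# criticality levels): a formdata stored with criticality_level = 2 is
# exposed by the view as 2 - 3 = -1.  Top levels of workflows with a
# different number of levels also map to -1, so the column stays comparable
# across formdefs in wcs_all_forms; [CRITICALITY_2] in AnyFormData._row2ob
# restores the stored value as 3 + (-1) = 2.
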

def drop_global_views(conn, cur):
    cur.execute('''SELECT table_name FROM information_schema.views
                    WHERE table_schema = 'public'
                      AND table_name LIKE %s''',
                ('wcs\\_category\\_%',))
    view_names = []
    while True:
        row = cur.fetchone()
        if row is None:
            break
        view_names.append(row[0])

    for view_name in view_names:
        cur.execute('''DROP VIEW IF EXISTS %s''' % view_name)

    if get_publisher().pg_version >= 90400:
        # drop materialized view that may have been created by a previous
        # version.
        cur.execute('''DROP MATERIALIZED VIEW IF EXISTS wcs_materialized_all_forms''')

    cur.execute('''DROP VIEW IF EXISTS wcs_all_forms''')


def do_global_views(conn, cur):
    # recreate global views
    from wcs.formdef import FormDef
    view_names = [get_formdef_view_name(x)
                  for x in FormDef.select(ignore_migration=True)]
    cur.execute('''SELECT table_name FROM information_schema.views
                    WHERE table_schema = 'public'
                      AND table_name LIKE %s''',
                ('wcs\\_view\\_%',))
    existing_views = set()
    while True:
        row = cur.fetchone()
        if row is None:
            break
        existing_views.add(row[0])
    view_names = existing_views.intersection(view_names)
    if not view_names:
        return

    cur.execute('''DROP VIEW IF EXISTS wcs_all_forms CASCADE''')

    fake_formdef = FormDef()
    common_fields = get_view_fields(fake_formdef)
    common_fields.append(('concerned_roles_array', 'concerned_roles_array'))
    common_fields.append(('actions_roles_array', 'actions_roles_array'))
    common_fields.append(('fts', 'fts'))
    common_fields.append(('is_at_endpoint', 'is_at_endpoint'))
    common_fields.append(('formdef_name', 'formdef_name'))
    common_fields.append(('user_name', 'user_name'))
    common_fields.append(('criticality_level', 'criticality_level'))
    common_fields.append(('geoloc_base_x', 'geoloc_base_x'))
    common_fields.append(('geoloc_base_y', 'geoloc_base_y'))

    union = ' UNION '.join(['''SELECT %s FROM %s''' % (
        ', '.join([y[1] for y in common_fields]), x) for x in view_names])
    cur.execute('''CREATE VIEW wcs_all_forms AS %s''' % union)

    for category in wcs.categories.Category.select():
        name = get_name_as_sql_identifier(category.url_name)[:40]
        cur.execute('''CREATE VIEW wcs_category_%s AS SELECT
                        * from wcs_all_forms WHERE category_id = %s''' % (
                            name, category.id))


class SqlMixin(object):
    _table_name = None

    @classmethod
    @guard_postgres
    def keys(cls):
        conn, cur = get_connection_and_cursor()
        sql_statement = 'SELECT id FROM %s' % cls._table_name
        cur.execute(sql_statement)
        ids = [x[0] for x in cur.fetchall()]
        conn.commit()
        cur.close()
        return ids

    @classmethod
    @guard_postgres
    def count(cls, clause=None):
        where_clauses, parameters, func_clause = parse_clause(clause)
        if func_clause:
            # fallback to counting the result of a select()
            return len(cls.select(clause))
        sql_statement = 'SELECT count(*) FROM %s' % cls._table_name
        if where_clauses:
            sql_statement += ' WHERE ' + ' AND '.join(where_clauses)
        conn, cur = get_connection_and_cursor()
        cur.execute(sql_statement, parameters)
        count = cur.fetchone()[0]
        conn.commit()
        cur.close()
        return count

    @classmethod
    @guard_postgres
    def get_with_indexed_value(cls, index, value, ignore_errors=False):
        conn, cur = get_connection_and_cursor()
        sql_statement = '''SELECT %s
                             FROM %s
                            WHERE %s = %%(value)s''' % (
                                ', '.join([x[0] for x in cls._table_static_fields]
                                          + cls.get_data_fields()),
                                cls._table_name,
                                index)
        cur.execute(sql_statement, {'value': str(value)})
        try:
            while True:
                row = cur.fetchone()
                if row is None:
                    break
                ob = cls._row2ob(row)
                if ignore_errors and ob is None:
                    continue
                yield ob
        finally:
            conn.commit()
            cur.close()

    @classmethod
    @guard_postgres
    def get_ids_from_query(cls, query):
        conn, cur = get_connection_and_cursor()
        sql_statement = '''SELECT id FROM %s
                            WHERE fts @@ plainto_tsquery(%%(value)s)''' % cls._table_name
        cur.execute(sql_statement, {'value': qommon.misc.simplify(query, ' ')})
        all_ids = [x[0] for x in cur.fetchall()]
        cur.close()
        return all_ids

    @classmethod
    @guard_postgres
    def get(cls, id, ignore_errors=False, ignore_migration=False):
        if id is None:
            if ignore_errors:
                return None
            else:
                raise KeyError()
        conn, cur = get_connection_and_cursor()
        sql_statement = '''SELECT %s
                             FROM %s
                            WHERE id = %%(id)s''' % (
                                ', '.join([x[0] for x in cls._table_static_fields]
                                          + cls.get_data_fields()),
                                cls._table_name)
        cur.execute(sql_statement, {'id': str(id)})
        row = cur.fetchone()
        if row is None:
            cur.close()
            if ignore_errors:
                return None
            raise KeyError()
        cur.close()
        return cls._row2ob(row)

    @classmethod
    @guard_postgres
    def get_ids(cls, ids, ignore_errors=False, keep_order=False):
        if not ids:
            return []
        conn, cur = get_connection_and_cursor()
        sql_statement = '''SELECT %s
                             FROM %s
                            WHERE id IN (%s)''' % (
                                ', '.join([x[0] for x in cls._table_static_fields]
                                          + cls.get_data_fields()),
                                cls._table_name,
                                ','.join([str(x) for x in ids]))
        cur.execute(sql_statement)
        objects = cls.get_objects(cur)
        conn.commit()
        cur.close()
        if ignore_errors:
            objects = (x for x in objects if x is not None)
        if keep_order:
            objects_dict = {}
            for object in objects:
                objects_dict[object.id] = object
            objects = [objects_dict[x] for x in ids if objects_dict.get(x)]
        return list(objects)

    @classmethod
    def get_objects_iterator(cls, cur, ignore_errors=False):
        while True:
            row = cur.fetchone()
            if row is None:
                break
            yield cls._row2ob(row)

    @classmethod
    def get_objects(cls, cur, ignore_errors=False, iterator=False):
        generator = cls.get_objects_iterator(cur=cur, ignore_errors=ignore_errors)
        if iterator:
            return generator
        return list(generator)

    @classmethod
    @guard_postgres
    def select_iterator(cls, clause=None, order_by=None, ignore_errors=False,
                        limit=None, offset=None):
        sql_statement = '''SELECT %s FROM %s''' % (
            ', '.join([x[0] for x in cls._table_static_fields]
                      + cls.get_data_fields()),
            cls._table_name)
        where_clauses, parameters, func_clause = parse_clause(clause)
        if where_clauses:
            sql_statement += ' WHERE ' + ' AND '.join(where_clauses)
        if order_by:
            if order_by.startswith('-'):
                order_by = order_by[1:]
                sql_statement += ' ORDER BY %s DESC' % order_by.replace('-', '_')
            else:
                sql_statement += ' ORDER BY %s' % order_by.replace('-', '_')
        if not func_clause:
            if limit:
                sql_statement += ' LIMIT %s' % limit
            if offset:
                sql_statement += ' OFFSET %s' % offset
        conn, cur = get_connection_and_cursor()
        cur.execute(sql_statement, parameters)
        try:
            for object in cls.get_objects(cur, iterator=True):
                if object is None:
                    continue
                if func_clause and not func_clause(object):
                    continue
                yield object
        finally:
            conn.commit()
            cur.close()

    @classmethod
    @guard_postgres
    def select(cls, clause=None, order_by=None, ignore_errors=False,
               limit=None, offset=None, iterator=False):
        objects = cls.select_iterator(clause=clause, order_by=order_by,
                                      ignore_errors=ignore_errors,
                                      limit=limit, offset=offset)
        where_clauses, parameters, func_clause = parse_clause(clause)
        if func_clause and (limit or offset):
            objects = _take(objects, limit, offset)
        if iterator:
            return objects
        return list(objects)

    def get_sql_dict_from_data(self, data, formdef):
        sql_dict = {}
        for field in formdef.get_all_fields():
            sql_type = SQL_TYPE_MAPPING.get(field.key, 'varchar')
            if sql_type is None:
                continue
            value = data.get(field.id)
            if value is not None:
                if field.key in ('ranked-items', 'password'):
                    # turn {'poire': 2, 'abricot': 1, 'pomme': 3} into an array
                    value = [[x, site_unicode(y).encode('utf-8')]
                             for x, y in value.items()]
                elif sql_type == 'varchar':
                    assert isinstance(value, basestring)
                elif sql_type == 'date':
                    assert type(value) is time.struct_time
                    value = datetime.datetime(value.tm_year, value.tm_mon,
                                              value.tm_mday)
                elif sql_type == 'bytea':
                    value = bytearray(cPickle.dumps(value, protocol=2))
                elif sql_type == 'boolean':
                    pass
            sql_dict[get_field_id(field)] = value
            if field.store_display_value:
                sql_dict['%s_display' % get_field_id(field)] = data.get('%s_display' % field.id)
            if field.store_structured_value:
                sql_dict['%s_structured' % get_field_id(field)] = bytearray(
                    cPickle.dumps(data.get('%s_structured' % field.id), protocol=2))
        return sql_dict

    @classmethod
    def _row2obdata(cls, row, formdef):
        obdata = {}
        i = len(cls._table_static_fields)
        if formdef.geolocations:
            i += len(formdef.geolocations.keys())
        for field in formdef.get_all_fields():
            sql_type = SQL_TYPE_MAPPING.get(field.key, 'varchar')
            if sql_type is None:
                continue
            value = row[i]
            if value is not None:
                value = str_encode(value)
                if field.key == 'ranked-items':
                    d = {}
                    for data, rank in value:
                        try:
                            d[data] = int(rank)
                        except ValueError:
                            d[data] = rank
                    value = d
                elif field.key == 'password':
                    d = {}
                    for fmt, val in value:
                        d[fmt] = unicode(val, 'utf-8')
                    value = d
                if sql_type == 'date':
                    value = value.timetuple()
                elif sql_type == 'bytea':
                    value = cPickle.loads(str(value))
            obdata[field.id] = value
            i += 1
            if field.store_display_value:
                value = str_encode(row[i])
                obdata['%s_display' % field.id] = value
                i += 1
            if field.store_structured_value:
                value = row[i]
                if value is not None:
                    obdata['%s_structured' % field.id] = cPickle.loads(str(value))
                    if obdata['%s_structured' % field.id] is None:
                        del obdata['%s_structured' % field.id]
                i += 1
        return obdata

    @classmethod
    @guard_postgres
    def remove_object(cls, id):
        conn, cur = get_connection_and_cursor()
        sql_statement = '''DELETE FROM %s
                            WHERE id = %%(id)s''' % cls._table_name
        cur.execute(sql_statement, {'id': str(id)})
        conn.commit()
        cur.close()

    @classmethod
    @guard_postgres
    def wipe(cls):
        conn, cur = get_connection_and_cursor()
        sql_statement = '''DELETE FROM %s''' % cls._table_name
        cur.execute(sql_statement)
        conn.commit()
        cur.close()

    @classmethod
    @guard_postgres
    def get_sorted_ids(cls, order_by):
        conn, cur = get_connection_and_cursor()
        sql_statement = 'SELECT id FROM %s' % cls._table_name
        if order_by.startswith('-'):
            order_by = order_by[1:]
            sql_statement += ' ORDER BY %s DESC' % order_by.replace('-', '_')
        else:
            sql_statement += ' ORDER BY %s' % order_by.replace('-', '_')
        cur.execute(sql_statement)
        ids = [x[0] for x in cur.fetchall()]
        conn.commit()
        cur.close()
        return ids

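
# Illustration (hypothetical call): subclasses combine the criteria classes
# above with SqlMixin.select(), e.g.
#
#   formdef.data_class().select(
#       [Equal('status', 'wf-new'), NotNull('receipt_time')],
#       order_by='-receipt_time', limit=10)
#
# Criteria lacking a SQL translation end up in func_clause and are filtered
# in Python, in which case limit/offset are applied afterwards by _take().
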

class SqlFormData(SqlMixin, wcs.formdata.FormData):
    _names = None  # make sure StorableObject methods fail
    _formdef = None
    _table_static_fields = [
        ('id', 'serial'),
        ('user_id', 'varchar'),
        ('receipt_time', 'timestamp'),
        ('status', 'varchar'),
        ('page_no', 'varchar'),
        ('anonymised', 'timestamptz'),
        ('workflow_data', 'bytea'),
        ('id_display', 'varchar'),
        ('workflow_roles', 'bytea'),
        ('workflow_roles_array', 'text[]'),
        ('concerned_roles_array', 'text[]'),
        ('actions_roles_array', 'text[]'),
        ('tracking_code', 'varchar'),
        ('backoffice_submission', 'boolean'),
        ('submission_context', 'bytea'),
        ('submission_channel', 'varchar'),
        ('criticality_level', 'int'),
        ('last_update_time', 'timestamp'),
        ('digest', 'varchar'),
        ('user_label', 'varchar'),
    ]

    def __init__(self, id=None):
        self.id = id
        self.data = {}

    _evolution = None

    @guard_postgres
    def get_evolution(self):
        if self._evolution:
            return self._evolution
        if not self.id:
            self._evolution = []
            return self._evolution
        conn, cur = get_connection_and_cursor()
        sql_statement = '''SELECT id, who, status, time, last_jump_datetime,
                                  comment, parts FROM %s_evolutions
                            WHERE formdata_id = %%(id)s
                            ORDER BY id''' % self._table_name
        cur.execute(sql_statement, {'id': self.id})
        self._evolution = []
        while True:
            row = cur.fetchone()
            if row is None:
                break
            self._evolution.append(self._row2evo(row, formdata=self))
        conn.commit()
        cur.close()
        return self._evolution

    @classmethod
    def _row2evo(cls, row, formdata):
        o = wcs.formdata.Evolution(formdata)
        o._sql_id, o.who, o.status, o.time, o.last_jump_datetime, o.comment = [
            str_encode(x) for x in tuple(row[:6])]
        if o.time:
            o.time = o.time.timetuple()
        if row[6]:
            o.parts = cPickle.loads(str(row[6]))
        return o

    def set_evolution(self, value):
        self._evolution = value

    evolution = property(get_evolution, set_evolution)

    @classmethod
    @guard_postgres
    def load_all_evolutions(cls, values):
        # Typically formdata.evolution is loaded on-demand (see above
        # property()) and this is fine to minimize queries, especially when
        # dealing with a single formdata.  However in some places (to
        # compute statistics for example) it is sometimes useful to access
        # .evolution on a series of formdata and in that case, it's more
        # efficient to optimize the process, loading all evolutions in a
        # single batch query.
        object_dict = dict([(x.id, x) for x in values
                            if x.id and x._evolution is None])
        if not object_dict:
            return
        conn, cur = get_connection_and_cursor()
        sql_statement = '''SELECT id, who, status, time, last_jump_datetime,
                                  comment, parts, formdata_id
                             FROM %s_evolutions''' % cls._table_name
        sql_statement += ''' WHERE formdata_id IN %(object_ids)s ORDER BY id'''
        cur.execute(sql_statement, {'object_ids': tuple(object_dict.keys())})

        for value in values:
            value._evolution = []

        while True:
            row = cur.fetchone()
            if row is None:
                break
            (_sql_id, who, status, time, last_jump_datetime, comment, parts,
             formdata_id) = tuple(row[:8])
            formdata = object_dict.get(formdata_id)
            if not formdata:
                continue
            formdata._evolution.append(formdata._row2evo(row, formdata))

        conn.commit()
        cur.close()

    @guard_postgres
    @invalidate_substitution_cache
    def store(self):
        sql_dict = {
            'user_id': self.user_id,
            'status': self.status,
            'page_no': self.page_no,
            'workflow_data': bytearray(cPickle.dumps(self.workflow_data, protocol=2)),
            'id_display': self.id_display,
            'anonymised': self.anonymised,
            'tracking_code': self.tracking_code,
            'backoffice_submission': self.backoffice_submission,
            'submission_context': self.submission_context,
            'submission_channel': self.submission_channel,
            'criticality_level': self.criticality_level,
        }
        if self.last_update_time:
            sql_dict['last_update_time'] = datetime.datetime.fromtimestamp(
                time.mktime(self.last_update_time))
        else:
            sql_dict['last_update_time'] = None
        if self.receipt_time:
            sql_dict['receipt_time'] = datetime.datetime.fromtimestamp(
                time.mktime(self.receipt_time))
        else:
            sql_dict['receipt_time'] = None
        if self.workflow_roles:
            sql_dict['workflow_roles'] = bytearray(
                cPickle.dumps(self.workflow_roles, protocol=2))
            sql_dict['workflow_roles_array'] = [
                str(x) for x in self.workflow_roles.values() if x is not None]
        else:
            sql_dict['workflow_roles'] = None
            sql_dict['workflow_roles_array'] = None
        if self.submission_context:
            sql_dict['submission_context'] = bytearray(
                cPickle.dumps(self.submission_context, protocol=2))
        else:
            sql_dict['submission_context'] = None

        for field in (self._formdef.geolocations or {}).keys():
            value = (self.geolocations or {}).get(field)
            if value:
                value = '(%.6f, %.6f)' % (value.get('lon'), value.get('lat'))
            sql_dict['geoloc_%s' % field] = value

        sql_dict['concerned_roles_array'] = [
            str(x) for x in self.concerned_roles if x]
        sql_dict['actions_roles_array'] = [
            str(x) for x in self.actions_roles if x]
        sql_dict.update(self.get_sql_dict_from_data(self.data, self._formdef))

        conn, cur = get_connection_and_cursor()
        if not self.id:
            column_names = sql_dict.keys()
            sql_statement = '''INSERT INTO %s (id, %s)
                               VALUES (DEFAULT, %s)
                            RETURNING id''' % (
                                self._table_name,
                                ', '.join(column_names),
                                ', '.join(['%%(%s)s' % x for x in column_names]))
            cur.execute(sql_statement, sql_dict)
            self.id = cur.fetchone()[0]
        else:
            column_names = sql_dict.keys()
            sql_dict['id'] = self.id
            sql_statement = '''UPDATE %s SET %s WHERE id = %%(id)s RETURNING id''' % (
                self._table_name,
                ', '.join(['%s = %%(%s)s' % (x, x) for x in column_names]))
            cur.execute(sql_statement, sql_dict)
            if cur.fetchone() is None:
                column_names = sql_dict.keys()
                sql_statement = '''INSERT INTO %s (%s) VALUES (%s) RETURNING id''' % (
                    self._table_name,
                    ', '.join(column_names),
                    ', '.join(['%%(%s)s' % x for x in column_names]))
                cur.execute(sql_statement, sql_dict)
                self.id = cur.fetchone()[0]

        if self.set_auto_fields():
            sql_statement = '''UPDATE %s
                                  SET id_display = %%(id_display)s,
                                      digest = %%(digest)s,
                                      user_label = %%(user_label)s
                                WHERE id = %%(id)s''' % self._table_name
            cur.execute(sql_statement, {
                'id': self.id,
                'id_display': self.id_display,
                'digest': self.digest,
                'user_label': self.user_label,
            })

        if self._evolution:
            for evo in self._evolution:
                sql_dict = {}
                if hasattr(evo, '_sql_id'):
                    sql_dict.update({'id': evo._sql_id})
                    sql_statement = '''UPDATE %s_evolutions
                                          SET who = %%(who)s,
                                              time = %%(time)s,
                                              last_jump_datetime = %%(last_jump_datetime)s,
                                              status = %%(status)s,
                                              comment = %%(comment)s,
                                              parts = %%(parts)s
                                        WHERE id = %%(id)s
                                    RETURNING id''' % self._table_name
                else:
                    sql_statement = '''INSERT INTO %s_evolutions (
                                               id, who, status, time,
                                               last_jump_datetime,
                                               comment, parts, formdata_id)
                                       VALUES (DEFAULT, %%(who)s, %%(status)s,
                                               %%(time)s,
                                               %%(last_jump_datetime)s,
                                               %%(comment)s, %%(parts)s,
                                               %%(formdata_id)s)
                                    RETURNING id''' % self._table_name
                sql_dict.update({
                    'who': evo.who,
                    'status': evo.status,
                    'time': datetime.datetime.fromtimestamp(time.mktime(evo.time)),
                    'last_jump_datetime': evo.last_jump_datetime,
                    'comment': evo.comment,
                    'formdata_id': self.id,
                })
                if evo.parts:
                    sql_dict['parts'] = bytearray(cPickle.dumps(evo.parts, protocol=2))
                else:
                    sql_dict['parts'] = None
                cur.execute(sql_statement, sql_dict)
                evo._sql_id = cur.fetchone()[0]

        fts_strings = [str(self.id), self.get_display_id()]
        fts_strings.append(self._formdef.name)
        if self.tracking_code:
            fts_strings.append(self.tracking_code)
        for field in self._formdef.get_all_fields():
            if not self.data.get(field.id):
                continue
            value = None
            if field.key in ('string', 'text', 'email'):
                value = self.data.get(field.id)
            elif field.key in ('item', 'items'):
                value = self.data.get('%s_display' % field.id)
            if value:
                if isinstance(value, basestring):
                    fts_strings.append(value)
                elif type(value) in (tuple, list):
                    fts_strings.extend(value)
        if self._evolution:
            for evo in self._evolution:
                if evo.comment:
                    fts_strings.append(evo.comment)
                for part in evo.parts or []:
                    if hasattr(part, 'view'):
                        html_part = part.view()
                        if html_part:
                            fts_strings.append(qommon.misc.html2text(html_part))
        user = self.get_user()
        if user:
            fts_strings.append(user.get_display_name())
        sql_statement = '''UPDATE %s SET fts = to_tsvector(%%(fts)s)
                            WHERE id = %%(id)s''' % self._table_name
        cur.execute(sql_statement, {
            'id': self.id,
            'fts': qommon.misc.simplify(' '.join(fts_strings), space=' ')})

        conn.commit()
        cur.close()

    @classmethod
    def _row2ob(cls, row):
        o = cls()
        for static_field, value in zip(cls._table_static_fields,
                                       tuple(row[:len(cls._table_static_fields)])):
            setattr(o, static_field[0], str_encode(value))
        if o.receipt_time:
            o.receipt_time = o.receipt_time.timetuple()
        if o.workflow_data:
            o.workflow_data = cPickle.loads(str(o.workflow_data))
        if o.workflow_roles:
            o.workflow_roles = cPickle.loads(str(o.workflow_roles))
        if o.submission_context:
            o.submission_context = cPickle.loads(str(o.submission_context))
        o.geolocations = {}
        for i, field in enumerate((cls._formdef.geolocations or {}).keys()):
            value = row[len(cls._table_static_fields) + i]
            if not value:
                continue
            m = re.match(r"\(([^)]+),([^)]+)\)", value)
            o.geolocations[field] = {'lon': float(m.group(1)),
                                     'lat': float(m.group(2))}
        o.data = cls._row2obdata(row, cls._formdef)
        del o._last_update_time
        return o

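    # Illustration: geolocation points travel as PostgreSQL point literals,
    # e.g. {'lon': 2.352200, 'lat': 48.856600} is written by store() as
    # '(2.352200, 48.856600)' and parsed back by the regular expression in
    # _row2ob() above.
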
    @classmethod
    def get_data_fields(cls):
        data_fields = ['geoloc_%s' % x
                       for x in (cls._formdef.geolocations or {}).keys()]
        for field in cls._formdef.get_all_fields():
            sql_type = SQL_TYPE_MAPPING.get(field.key, 'varchar')
            if sql_type is None:
                continue
            data_fields.append(get_field_id(field))
            if field.store_display_value:
                data_fields.append('%s_display' % get_field_id(field))
            if field.store_structured_value:
                data_fields.append('%s_structured' % get_field_id(field))
        return data_fields

    @classmethod
    @guard_postgres
    def get(cls, id, ignore_errors=False, ignore_migration=False):
        if id is None:
            if ignore_errors:
                return None
            else:
                raise KeyError()
        else:
            try:
                int(id)
            except ValueError:
                raise KeyError()
        conn, cur = get_connection_and_cursor()
        fields = cls.get_data_fields()
        potential_comma = ', '
        if not fields:
            potential_comma = ''

        sql_statement = '''SELECT %s %s %s
                             FROM %s
                            WHERE id = %%(id)s''' % (
                                ', '.join([x[0] for x in cls._table_static_fields]),
                                potential_comma,
                                ', '.join(fields),
                                cls._table_name)
        cur.execute(sql_statement, {'id': str(id)})
        row = cur.fetchone()
        if row is None:
            cur.close()
            if ignore_errors:
                return None
            raise KeyError()
        cur.close()
        return cls._row2ob(row)

    @classmethod
    @guard_postgres
    def get_ids_with_indexed_value(cls, index, value, auto_fallback=True):
        conn, cur = get_connection_and_cursor()
        if type(value) is int:
            value = str(value)
        if '%s_array' % index in [x[0] for x in cls._table_static_fields]:
            sql_statement = '''SELECT id FROM %s
                                WHERE %%(value)s = ANY (%s_array)''' % (
                                    cls._table_name, index)
        else:
            sql_statement = '''SELECT id FROM %s
                                WHERE %s = %%(value)s''' % (
                                    cls._table_name, index)
        cur.execute(sql_statement, {'value': value})
        all_ids = [x[0] for x in cur.fetchall()]
        cur.close()
        return all_ids

    @classmethod
    @guard_postgres
    def fix_sequences(cls):
        conn, cur = get_connection_and_cursor()
        for table_name in (cls._table_name, '%s_evolutions' % cls._table_name):
            sql_statement = '''select max(id) from %s''' % table_name
            cur.execute(sql_statement)
            max_id = cur.fetchone()[0]
            if max_id:
                sql_statement = '''ALTER SEQUENCE %s RESTART %s''' % (
                    '%s_id_seq' % table_name, max_id + 1)
                cur.execute(sql_statement)
        conn.commit()
        cur.close()

    @classmethod
    def rebuild_security(cls):
        formdatas = cls.select(order_by='id')
        conn, cur = get_connection_and_cursor()
        for formdata in formdatas:
            sql_statement = '''UPDATE %s
                                  SET concerned_roles_array = %%(roles)s,
                                      actions_roles_array = %%(actions_roles)s
                                WHERE id = %%(id)s''' % cls._table_name
            with get_publisher().substitutions.temporary_feed(formdata):
                # formdata is already added to sources list in individual
                # {concerned,actions}_roles but adding it first here will
                # allow cached values to be reused between the properties.
                cur.execute(sql_statement, {
                    'id': formdata.id,
                    'roles': [str(x) for x in formdata.concerned_roles if x],
                    'actions_roles': [str(x) for x in formdata.actions_roles if x]})
        conn.commit()
        cur.close()

    @classmethod
    @guard_postgres
    def wipe(cls, drop=False):
        conn, cur = get_connection_and_cursor()
        if drop:
            cur.execute('''DROP TABLE %s_evolutions CASCADE''' % cls._table_name)
            cur.execute('''DROP TABLE %s CASCADE''' % cls._table_name)
        else:
            cur.execute('''DELETE FROM %s_evolutions''' % cls._table_name)
            cur.execute('''DELETE FROM %s''' % cls._table_name)
        conn.commit()
        cur.close()

    @classmethod
    def do_tracking_code_table(cls):
        do_tracking_code_table()

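
# Design note: store() above first tries UPDATE ... RETURNING id and only
# falls back to INSERT when no row matched, so formdata carrying an
# explicit id (imported data, for instance) can be written without a
# separate existence check.
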

class SqlUser(SqlMixin, wcs.users.User):
    _table_name = 'users'
    _table_static_fields = [
        ('id', 'serial'),
        ('name', 'varchar'),
        ('email', 'varchar'),
        ('roles', 'varchar[]'),
        ('is_admin', 'bool'),
        ('anonymous', 'bool'),
        ('name_identifiers', 'varchar[]'),
        ('verified_fields', 'varchar[]'),
        ('lasso_dump', 'text'),
        ('last_seen', 'timestamp'),
        ('ascii_name', 'varchar'),
    ]

    id = None

    def __init__(self, name=None):
        self.name = name
        self.name_identifiers = []
        self.verified_fields = []
        self.roles = []

    @guard_postgres
    @invalidate_substitution_cache
    def store(self):
        sql_dict = {
            'name': self.name,
            'ascii_name': self.ascii_name,
            'email': self.email,
            'roles': self.roles,
            'is_admin': self.is_admin,
            'anonymous': self.anonymous,
            'name_identifiers': self.name_identifiers,
            'verified_fields': self.verified_fields,
            'lasso_dump': self.lasso_dump,
            'last_seen': None,
        }
        if self.last_seen:
            sql_dict['last_seen'] = datetime.datetime.fromtimestamp(self.last_seen)

        user_formdef = self.get_formdef()
        if not self.form_data:
            self.form_data = {}
        sql_dict.update(self.get_sql_dict_from_data(self.form_data, user_formdef))

        conn, cur = get_connection_and_cursor()
        if not self.id:
            column_names = sql_dict.keys()
            sql_statement = '''INSERT INTO %s (id, %s)
                               VALUES (DEFAULT, %s)
                            RETURNING id''' % (
                                self._table_name,
                                ', '.join(column_names),
                                ', '.join(['%%(%s)s' % x for x in column_names]))
            cur.execute(sql_statement, sql_dict)
            self.id = cur.fetchone()[0]
        else:
            column_names = sql_dict.keys()
            sql_dict['id'] = self.id
            sql_statement = '''UPDATE %s SET %s WHERE id = %%(id)s RETURNING id''' % (
                self._table_name,
                ', '.join(['%s = %%(%s)s' % (x, x) for x in column_names]))
            cur.execute(sql_statement, sql_dict)
            if cur.fetchone() is None:
                column_names = sql_dict.keys()
                sql_statement = '''INSERT INTO %s (%s) VALUES (%s)''' % (
                    self._table_name,
                    ', '.join(column_names),
                    ', '.join(['%%(%s)s' % x for x in column_names]))
                cur.execute(sql_statement, sql_dict)

        fts_strings = []
        if self.name:
            fts_strings.append(self.name)
            fts_strings.append(self.ascii_name)
        if self.email:
            fts_strings.append(self.email)
        if user_formdef and user_formdef.fields:
            for field in user_formdef.fields:
                if not self.form_data.get(field.id):
                    continue
                value = None
                if field.key in ('string', 'text', 'email'):
                    value = self.form_data.get(field.id)
                elif field.key in ('item', 'items'):
                    value = self.form_data.get('%s_display' % field.id)
                if value:
                    if isinstance(value, basestring):
                        fts_strings.append(value)
                    elif type(value) in (tuple, list):
                        fts_strings.extend(value)
        sql_statement = '''UPDATE %s SET fts = to_tsvector(%%(fts)s)
                            WHERE id = %%(id)s''' % self._table_name
        cur.execute(sql_statement, {'id': self.id, 'fts': ' '.join(fts_strings)})

        conn.commit()
        cur.close()

    @classmethod
    def _row2ob(cls, row):
        o = cls()
        (o.id, o.name, o.email, o.roles, o.is_admin, o.anonymous,
         o.name_identifiers, o.verified_fields, o.lasso_dump,
         o.last_seen, ascii_name) = [str_encode(x) for x in tuple(row[:11])]
        if o.last_seen:
            o.last_seen = time.mktime(o.last_seen.timetuple())
        if o.roles:
            o.roles = [str(x) for x in o.roles]
        o.form_data = cls._row2obdata(row, cls.get_formdef())
        return o

    @classmethod
    def get_data_fields(cls):
        data_fields = []
        for field in cls.get_formdef().get_all_fields():
            sql_type = SQL_TYPE_MAPPING.get(field.key, 'varchar')
            if sql_type is None:
                continue
            data_fields.append(get_field_id(field))
            if field.store_display_value:
                data_fields.append('%s_display' % get_field_id(field))
            if field.store_structured_value:
                data_fields.append('%s_structured' % get_field_id(field))
        return data_fields

    @classmethod
    @guard_postgres
    def fix_sequences(cls):
        conn, cur = get_connection_and_cursor()
        sql_statement = '''select max(id) from %s''' % cls._table_name
        cur.execute(sql_statement)
        max_id = cur.fetchone()[0]
        if max_id is not None:
            sql_statement = '''ALTER SEQUENCE %s_id_seq RESTART %s''' % (
                cls._table_name, max_id + 1)
            cur.execute(sql_statement)
        conn.commit()
        cur.close()

    @classmethod
    @guard_postgres
    def get_users_with_name_identifier(cls, name_identifier):
        conn, cur = get_connection_and_cursor()
        sql_statement = '''SELECT %s
                             FROM %s
                            WHERE %%(value)s = ANY(name_identifiers)''' % (
                                ', '.join([x[0] for x in cls._table_static_fields]
                                          + cls.get_data_fields()),
                                cls._table_name)
        cur.execute(sql_statement, {'value': name_identifier})
        objects = cls.get_objects(cur)
        conn.commit()
        cur.close()
        return objects

    @classmethod
    @guard_postgres
    def get_users_with_email(cls, email):
        conn, cur = get_connection_and_cursor()
        sql_statement = '''SELECT %s
                             FROM %s
                            WHERE email = %%(value)s''' % (
                                ', '.join([x[0] for x in cls._table_static_fields]
                                          + cls.get_data_fields()),
                                cls._table_name)
        cur.execute(sql_statement, {'value': email})
        objects = cls.get_objects(cur)
        conn.commit()
        cur.close()
        return objects

    @classmethod
    @guard_postgres
    def get_users_with_role(cls, role_id):
        conn, cur = get_connection_and_cursor()
        sql_statement = '''SELECT %s
                             FROM %s
                            WHERE %%(value)s = ANY(roles)''' % (
                                ', '.join([x[0] for x in cls._table_static_fields]
                                          + cls.get_data_fields()),
                                cls._table_name)
        cur.execute(sql_statement, {'value': str(role_id)})
        objects = cls.get_objects(cur)
        conn.commit()
        cur.close()
        return objects


class Session(SqlMixin, wcs.sessions.BasicSession):
    _table_name = 'sessions'
    _table_static_fields = [
        ('id', 'varchar'),
        ('session_data', 'bytea'),
    ]

    @classmethod
    @guard_postgres
    def select_recent(cls, seconds=30*60, **kwargs):
        clause = [GreaterOrEqual('last_update_time',
                                 datetime.datetime.now()
                                 - datetime.timedelta(seconds=seconds))]
        return cls.select(clause=clause, **kwargs)

    @guard_postgres
    def store(self):
        sql_dict = {
            'id': self.id,
            'session_data': bytearray(cPickle.dumps(self.__dict__, protocol=2)),
            # the other fields are stored to run optimized SELECT() against
            # the table, they are ignored when loading the data.
            'name_identifier': self.name_identifier,
            'visiting_objects_keys': self.visiting_objects.keys()
                if getattr(self, 'visiting_objects', None) else None,
            'last_update_time': datetime.datetime.now(),
        }
        conn, cur = get_connection_and_cursor()
        column_names = sql_dict.keys()
        sql_statement = '''UPDATE %s SET %s WHERE id = %%(id)s RETURNING id''' % (
            self._table_name,
            ', '.join(['%s = %%(%s)s' % (x, x) for x in column_names]))
        cur.execute(sql_statement, sql_dict)
        if cur.fetchone() is None:
            sql_statement = '''INSERT INTO %s (%s) VALUES (%s)''' % (
                self._table_name,
                ', '.join(column_names),
                ', '.join(['%%(%s)s' % x for x in column_names]))
            cur.execute(sql_statement, sql_dict)
        conn.commit()
        cur.close()

    @classmethod
    def _row2ob(cls, row):
        o = cls.__new__(cls)
        o.id = str_encode(row[0])
        session_data = cPickle.loads(str(row[1]))
        for k, v in session_data.items():
            setattr(o, k, v)
        return o

    @classmethod
    def get_sessions_for_saml(cls, name_identifier=Ellipsis, *args, **kwargs):
        conn, cur = get_connection_and_cursor()
        sql_statement = '''SELECT %s
                             FROM %s
                            WHERE name_identifier = %%(value)s''' % (
                                ', '.join([x[0] for x in cls._table_static_fields]),
                                cls._table_name)
        cur.execute(sql_statement, {'value': name_identifier})
        objects = cls.get_objects(cur)
        conn.commit()
        cur.close()
        return objects

    @classmethod
    def get_sessions_with_visited_object(cls, object_key):
        conn, cur = get_connection_and_cursor()
        sql_statement = '''SELECT %s
                             FROM %s
                            WHERE %%(value)s = ANY(visiting_objects_keys)
                              AND last_update_time > (now() - interval '30 minutes')
                            ''' % (
                                ', '.join([x[0] for x in cls._table_static_fields]
                                          + cls.get_data_fields()),
                                cls._table_name)
        cur.execute(sql_statement, {'value': object_key})
        objects = cls.get_objects(cur)
        conn.commit()
        cur.close()
        return objects

    @classmethod
    def get_data_fields(cls):
        return []


class TrackingCode(SqlMixin, wcs.tracking_code.TrackingCode):
    _table_name = 'tracking_codes'
    _table_static_fields = [
        ('id', 'varchar'),
        ('formdef_id', 'varchar'),
        ('formdata_id', 'varchar'),
    ]

    id = None

    @classmethod
    def get(cls, id, **kwargs):
        return super(TrackingCode, cls).get(id.upper(), **kwargs)

    @guard_postgres
    @invalidate_substitution_cache
    def store(self):
        sql_dict = {
            'id': self.id,
            'formdef_id': self.formdef_id,
            'formdata_id': self.formdata_id,
        }

        conn, cur = get_connection_and_cursor()
        if not self.id:
            column_names = sql_dict.keys()
            sql_dict['id'] = self.get_new_id()
            sql_statement = '''INSERT INTO %s (%s) VALUES (%s) RETURNING id''' % (
                self._table_name,
                ', '.join(column_names),
                ', '.join(['%%(%s)s' % x for x in column_names]))
            while True:
                try:
                    cur.execute(sql_statement, sql_dict)
                except psycopg2.IntegrityError:
                    conn.rollback()
                    sql_dict['id'] = self.get_new_id()
                else:
                    break
            self.id = str_encode(cur.fetchone()[0])
        else:
            column_names = sql_dict.keys()
            sql_dict['id'] = self.id
            sql_statement = '''UPDATE %s SET %s WHERE id = %%(id)s RETURNING id''' % (
                self._table_name,
                ', '.join(['%s = %%(%s)s' % (x, x) for x in column_names]))
            cur.execute(sql_statement, sql_dict)
            if cur.fetchone() is None:
                raise AssertionError()
        conn.commit()
        cur.close()

    @classmethod
    def _row2ob(cls, row):
        o = cls()
        (o.id, o.formdef_id, o.formdata_id) = [
            str_encode(x) for x in tuple(row[:3])]
        return o

    @classmethod
    def get_data_fields(cls):
        return []

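
# Design note: TrackingCode.store() leans on the primary key to detect
# collisions between randomly drawn codes; on IntegrityError it rolls back,
# draws a new id from get_new_id() and retries until the INSERT goes
# through.
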

class classproperty(object):
    def __init__(self, f):
        self.f = f

    def __get__(self, obj, owner):
        return self.f(owner)


class AnyFormData(SqlMixin):
    _table_name = 'wcs_all_forms'
    __table_static_fields = []
    _formdef_cache = {}

    @classproperty
    def _table_static_fields(cls):
        if cls.__table_static_fields:
            return cls.__table_static_fields
        from wcs.formdef import FormDef
        fake_formdef = FormDef()
        common_fields = get_view_fields(fake_formdef)
        cls.__table_static_fields = [(x[1], x[0]) for x in common_fields]
        cls.__table_static_fields.append(('criticality_level', 'criticality_level'))
        cls.__table_static_fields.append(('geoloc_base_x', 'geoloc_base_x'))
        cls.__table_static_fields.append(('geoloc_base_y', 'geoloc_base_y'))
        cls.__table_static_fields.append(('concerned_roles_array',
                                          'concerned_roles_array'))
        return cls.__table_static_fields

    @classmethod
    def get_data_fields(cls):
        return []

    @classmethod
    def get_objects(cls, *args, **kwargs):
        cls._formdef_cache = {}
        return super(AnyFormData, cls).get_objects(*args, **kwargs)

    @classmethod
    def _row2ob(cls, row):
        formdef_id = row[1]
        from wcs.formdef import FormDef
        formdef = cls._formdef_cache.setdefault(formdef_id, FormDef.get(formdef_id))
        o = formdef.data_class()()
        for static_field, value in zip(cls._table_static_fields,
                                       tuple(row[:len(cls._table_static_fields)])):
            setattr(o, static_field[0], str_encode(value))
        # [CRITICALITY_2] transform criticality_level back to the expected
        # range (see [CRITICALITY_1])
        levels = len(formdef.workflow.criticality_levels or [0])
        o.criticality_level = levels + o.criticality_level
        # convert back unstructured geolocation to the 'native' formdata
        # format.
        if o.geoloc_base_x is not None:
            o.geolocations = {'base': {'lon': o.geoloc_base_x,
                                       'lat': o.geoloc_base_y}}
        return o


def get_period_query(period_start=None, period_end=None, criterias=None,
                     parameters=None):
    clause = [NotNull('receipt_time')]
    table_name = 'wcs_all_forms'
    if criterias:
        for criteria in criterias:
            if criteria.__class__.__name__ == 'Equal' and \
                    criteria.attribute == 'formdef_id':
                # if there's a formdef_id specified, switch to using the
                # specific table so we have access to all fields
                from wcs.formdef import FormDef
                table_name = get_formdef_table_name(FormDef.get(criteria.value))
                continue
            clause.append(criteria)
    if period_start:
        clause.append(GreaterOrEqual('receipt_time', period_start))
    if period_end:
        clause.append(LessOrEqual('receipt_time', period_end))
    where_clauses, params, func_clause = parse_clause(clause)
    parameters.update(params)
    statement = ' FROM %s ' % table_name
    statement += ' WHERE ' + ' AND '.join(where_clauses)
    return statement

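
# Illustration (hypothetical parameters): with a period and no formdef_id
# criteria, get_period_query() yields a fragment along the lines of
#
#   ' FROM wcs_all_forms  WHERE receipt_time IS NOT NULL
#     AND receipt_time >= %(c...)s AND receipt_time <= %(c...)s'
#
# which the statistics helpers below prefix with their SELECT clause.
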
@guard_postgres
def get_weekday_totals(period_start=None, period_end=None, criterias=None):
    conn, cur = get_connection_and_cursor()
    statement = '''SELECT DATE_PART('dow', receipt_time) AS weekday, COUNT(*)'''
    parameters = {}
    statement += get_period_query(period_start, period_end, criterias, parameters)
    statement += ' GROUP BY weekday ORDER BY weekday'
    cur.execute(statement, parameters)
    result = cur.fetchall()
    result = [(int(x), y) for x, y in result]
    # fill in missing days with zeroes
    coverage = [x[0] for x in result]
    for weekday in range(7):
        if weekday not in coverage:
            result.append((weekday, 0))
    result.sort()
    conn.commit()
    cur.close()
    return result


@guard_postgres
def get_hour_totals(period_start=None, period_end=None, criterias=None):
    conn, cur = get_connection_and_cursor()
    statement = '''SELECT DATE_PART('hour', receipt_time) AS hour, COUNT(*)'''
    parameters = {}
    statement += get_period_query(period_start, period_end, criterias, parameters)
    statement += ' GROUP BY hour ORDER BY hour'
    cur.execute(statement, parameters)
    result = cur.fetchall()
    result = [(int(x), y) for x, y in result]
    # fill in missing hours with zeroes
    coverage = [x[0] for x in result]
    for hour in range(24):
        if hour not in coverage:
            result.append((hour, 0))
    result.sort()
    conn.commit()
    cur.close()
    return result


@guard_postgres
def get_monthly_totals(period_start=None, period_end=None, criterias=None):
    conn, cur = get_connection_and_cursor()
    statement = '''SELECT DATE_TRUNC('month', receipt_time) AS month, COUNT(*)'''
    parameters = {}
    statement += get_period_query(period_start, period_end, criterias, parameters)
    statement += ' GROUP BY month ORDER BY month'
    cur.execute(statement, parameters)
    raw_result = cur.fetchall()
    result = [('%d-%02d' % x.timetuple()[:2], y) for x, y in raw_result]
    if result:
        # fill in missing months with zeroes
        coverage = [x[0] for x in result]
        current_month = raw_result[0][0]
        last_month = raw_result[-1][0]
        while current_month < last_month:
            label = '%d-%02d' % current_month.timetuple()[:2]
            if label not in coverage:
                result.append((label, 0))
            # jump to the first day of the next month
            current_month = current_month + datetime.timedelta(days=31)
            current_month = current_month - datetime.timedelta(days=current_month.day - 1)
    result.sort()
    conn.commit()
    cur.close()
    return result


@guard_postgres
def get_yearly_totals(period_start=None, period_end=None, criterias=None):
    conn, cur = get_connection_and_cursor()
    statement = '''SELECT DATE_TRUNC('year', receipt_time) AS year, COUNT(*)'''
    parameters = {}
    statement += get_period_query(period_start, period_end, criterias, parameters)
    statement += ' GROUP BY year ORDER BY year'
    cur.execute(statement, parameters)
    raw_result = cur.fetchall()
    result = [(str(x.year), y) for x, y in raw_result]
    if result:
        # fill in missing years with zeroes
        coverage = [x[0] for x in result]
        current_year = raw_result[0][0]
        last_year = raw_result[-1][0]
        while current_year < last_year:
            label = str(current_year.year)
            if label not in coverage:
                result.append((label, 0))
            current_year = current_year + datetime.timedelta(days=366)
    result.sort()
    conn.commit()
    cur.close()
    return result


SQL_LEVEL = 32


def migrate_global_views(conn, cur):
    # look up the columns of the existing view to decide whether it predates
    # the introduction of formdef_id and must be rebuilt
    cur.execute('''SELECT column_name FROM information_schema.columns
                    WHERE table_schema = 'public'
                      AND table_name = %s''', ('wcs_all_forms',))
    existing_fields = set([x[0] for x in cur.fetchall()])
    if 'formdef_id' not in existing_fields:
        drop_global_views(conn, cur)
        do_global_views(conn, cur)


@guard_postgres
def get_sql_level(conn, cur):
    do_meta_table(conn, cur, insert_current_sql_level=False)
    cur.execute('''SELECT value FROM wcs_meta WHERE key = %s''', ('sql_level',))
    sql_level = int(cur.fetchone()[0])
    return sql_level
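
# wcs_meta is a plain key/value table; besides 'sql_level' it holds the
# 'reindex_<name>' flags driven by is_reindex_needed() and set_reindex()
# below.  An illustrative peek at its content (sketch, values are examples):
#
#     cur.execute('SELECT key, value FROM wcs_meta')
#     print dict(cur.fetchall())
#     # e.g. {'sql_level': '32', 'reindex_user': 'done'}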
@guard_postgres
def is_reindex_needed(index, conn, cur):
    do_meta_table(conn, cur, insert_current_sql_level=False)
    key_name = 'reindex_%s' % index
    cur.execute('''SELECT value FROM wcs_meta WHERE key = %s''', (key_name,))
    row = cur.fetchone()
    if row is None:
        cur.execute('''INSERT INTO wcs_meta (id, key, value)
                       VALUES (DEFAULT, %s, %s)''', (key_name, 'no'))
        return False
    return row[0] == 'needed'


@guard_postgres
def set_reindex(index, value, conn=None, cur=None):
    own_conn = False
    if not conn:
        own_conn = True
        conn, cur = get_connection_and_cursor()
    do_meta_table(conn, cur, insert_current_sql_level=False)
    key_name = 'reindex_%s' % index
    cur.execute('''SELECT value FROM wcs_meta WHERE key = %s''', (key_name,))
    row = cur.fetchone()
    if row is None:
        cur.execute('''INSERT INTO wcs_meta (id, key, value)
                       VALUES (DEFAULT, %s, %s)''', (key_name, value))
    else:
        cur.execute('''UPDATE wcs_meta SET value = %s WHERE key = %s''', (
            value, key_name))
    if own_conn:
        conn.commit()
        cur.close()


def migrate_views(conn, cur):
    drop_views(None, conn, cur)
    from wcs.formdef import FormDef
    for formdef in FormDef.select():
        # make sure all formdefs have up-to-date views
        do_formdef_tables(formdef, conn=conn, cur=cur,
                          rebuild_views=True, rebuild_global_views=False)
    migrate_global_views(conn, cur)
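
# Every schema change bumps SQL_LEVEL; migrate() below compares the level
# stored in wcs_meta against the current constant and replays only the steps
# in between, so it stays cheap to run on every upgrade.  In pseudo-code:
#
#     if get_sql_level(conn, cur) < SQL_LEVEL:
#         # ... apply the missing numbered steps, then record SQL_LEVEL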
@guard_postgres
def migrate():
    conn, cur = get_connection_and_cursor()
    sql_level = get_sql_level(conn, cur)
    if sql_level < 0:
        # fake code to help in testing the error code path
        raise RuntimeError()
    if sql_level < 1:
        # 1: introduction of tracking_code table
        do_tracking_code_table()
    if sql_level < 31:
        # 2: introduction of formdef_id in views
        # 5: add concerned_roles_array, is_at_endpoint and fts to views
        # 7: add backoffice_submission to tables
        # 8: add submission_context to tables
        # 9: add last_update_time to views
        # 10: add submission_channel to tables
        # 11: add formdef_name and user_name to views
        # 13: add backoffice_submission to views
        # 14: add criticality_level to tables & views
        # 15: add geolocation to formdata
        # 19: add geolocation to views
        # 20: remove user hash stuff
        # 22: rebuild views
        # 26: add digest to formdata
        # 27: add last_jump_datetime in evolutions tables
        # 31: add user_label to formdata
        migrate_views(conn, cur)
    if sql_level < 21:
        # 3: introduction of _structured for user fields
        # 4: removal of identification_token
        # 12: (first part) add fts to users
        # 16: add verified_fields to users
        # 21: (first part) add ascii_name to users
        do_user_table()
    if sql_level < 6:
        # 6: add actions_roles_array to tables and views
        from wcs.formdef import FormDef
        migrate_views(conn, cur)
        for formdef in FormDef.select():
            formdef.data_class().rebuild_security()
    if sql_level < 23:
        # 12: (second part), store fts in existing rows
        # 21: (second part), store ascii_name of users
        # 23: (first part), use misc.simplify() over full text queries
        set_reindex('user', 'needed', conn=conn, cur=cur)
    if sql_level < 31:
        # 17: store last_update_time in tables
        # 18: add user name to full-text search index
        # 21: (third part), add user ascii_names to full-text index
        # 23: (second part) use misc.simplify() over full text queries
        # 28: add display id and formdef name to full-text index
        # 29: add evolution parts to full-text index
        # 31: add user_label to formdata
        set_reindex('formdata', 'needed', conn=conn, cur=cur)
    if sql_level < 24:
        from wcs.formdef import FormDef
        # 24: add index on evolution(formdata_id)
        for formdef in FormDef.select():
            do_formdef_indexes(formdef, created=False, conn=conn, cur=cur)
    if sql_level < 32:
        # 25: create session_table
        # 32: add last_update_time column to session table
        do_session_table()
    if sql_level < 30:
        # 30: actually remove evo.who on anonymised formdatas
        from wcs.formdef import FormDef
        for formdef in FormDef.select():
            for formdata in formdef.data_class().select_iterator(
                    clause=[NotNull('anonymised')]):
                if formdata.evolution:
                    for evo in formdata.evolution:
                        evo.who = None
                    formdata.store()

    cur.execute('''UPDATE wcs_meta SET value = %s WHERE key = %s''', (
        str(SQL_LEVEL), 'sql_level'))
    conn.commit()
    cur.close()


@guard_postgres
def reindex():
    conn, cur = get_connection_and_cursor()

    if is_reindex_needed('user', conn=conn, cur=cur):
        for user in SqlUser.select(iterator=True):
            user.store()
        set_reindex('user', 'done', conn=conn, cur=cur)

    from wcs.formdef import FormDef
    if is_reindex_needed('formdata', conn=conn, cur=cur):
        # load and store all formdatas, to recompute their indexed columns
        for formdef in FormDef.select():
            for formdata in formdef.data_class().select(iterator=True):
                formdata.store()
        set_reindex('formdata', 'done', conn=conn, cur=cur)

    conn.commit()
    cur.close()
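
# Typical upgrade sequence, as an illustrative sketch (the calling context
# is hypothetical; this module only provides the two entry points):
#
#     migrate()    # replay any missing schema migration steps
#     reindex()    # rebuild whatever set_reindex() flagged as 'needed'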