# wcs/wcs/sql.py (1733 lines, 63 KiB, Python)
# w.c.s. - web application for online forms
# Copyright (C) 2005-2012 Entr'ouvert
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 2 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, see <http://www.gnu.org/licenses/>.
import psycopg2
import datetime
import time
import cPickle
from quixote import get_publisher
import qommon
from qommon.storage import _take, parse_clause as parse_storage_clause
from qommon import get_cfg
import wcs.categories
import wcs.formdata
import wcs.tracking_code
import wcs.users
# mapping of form field types to the SQL column type used to store their
# value; None means the field holds no data and gets no column at all.
# Field types not listed here fall back to 'varchar' (see the
# SQL_TYPE_MAPPING.get(field.key, 'varchar') call sites).
SQL_TYPE_MAPPING = {
    'title': None,
    'subtitle': None,
    'comment': None,
    'page': None,
    'text': 'text',
    'bool': 'boolean',
    'file': 'bytea',
    'date': 'date',
    'items': 'text[]',
    'table': 'text[][]',
    'table-select': 'text[][]',
    'tablerows': 'text[][]',
    # mapping of dicts
    'ranked-items': 'text[][]',
    'password': 'text[][]',
}
class Criteria(qommon.storage.Criteria):
    """Base class for criterias rendered as parameterized SQL clauses."""

    def __init__(self, attribute, value, **kwargs):
        # extra keyword arguments (coming from the generic storage
        # criteria's __dict__) are accepted but unused
        self.attribute = attribute
        self.value = value

    def as_sql(self):
        # id(self.value) yields a parameter name unique to this criteria,
        # so several criterias can share a single parameters dict
        return '%s %s %%(c%s)s' % (self.attribute, self.sql_op, id(self.value))

    def as_sql_param(self):
        param = self.value
        if isinstance(param, time.struct_time):
            # adapt struct_time values to something psycopg2 understands
            param = datetime.datetime.fromtimestamp(time.mktime(param))
        return {'c%s' % id(self.value): param}
class Less(Criteria):
    sql_op = '<'


class Greater(Criteria):
    sql_op = '>'


class Equal(Criteria):
    sql_op = '='

    def as_sql(self):
        # an empty list/tuple cannot be compared with '='; an empty array
        # column has a NULL ARRAY_LENGTH, test that instead
        if self.value in ([], ()):
            return 'ARRAY_LENGTH(%s, 1) IS NULL' % self.attribute
        return super(Equal, self).as_sql()


class LessOrEqual(Criteria):
    sql_op = '<='


class GreaterOrEqual(Criteria):
    sql_op = '>='


class NotEqual(Criteria):
    sql_op = '!='
class Contains(Criteria):
    sql_op = 'IN'

    def as_sql(self):
        return '%s %s %%(c%s)s' % (self.attribute, self.sql_op, id(self.value))

    def as_sql_param(self):
        # psycopg2 adapts a tuple to the parenthesised list IN expects
        return {'c%s' % id(self.value): tuple(self.value)}


class NotContains(Contains):
    sql_op = 'NOT IN'
class NotNull(Criteria):
    """Unary criteria matching rows where the attribute is not NULL."""

    sql_op = 'IS NOT NULL'

    def __init__(self, attribute):
        # unary operator: there is no value to compare against
        self.attribute = attribute

    def as_sql_param(self):
        # no value, thus no query parameter
        return {}

    def as_sql(self):
        return '%s %s' % (self.attribute, self.sql_op)
class Or(Criteria):
    """Disjunction of criterias, each mapped to its SQL counterpart."""

    def __init__(self, criterias):
        # replace each generic storage criteria by the SQL class of the
        # same name defined in this module
        self.criterias = [
            globals().get(element.__class__.__name__)(**element.__dict__)
            for element in criterias]

    def as_sql(self):
        return '( %s )' % ' OR '.join([x.as_sql() for x in self.criterias])

    def as_sql_param(self):
        params = {}
        for criteria in self.criterias:
            params.update(criteria.as_sql_param())
        return params
class And(Criteria):
    """Conjunction of criterias, each mapped to its SQL counterpart."""

    def __init__(self, criterias):
        # replace each generic storage criteria by the SQL class of the
        # same name defined in this module
        self.criterias = [
            globals().get(element.__class__.__name__)(**element.__dict__)
            for element in criterias]

    def as_sql(self):
        return '( %s )' % ' AND '.join([x.as_sql() for x in self.criterias])

    def as_sql_param(self):
        params = {}
        for criteria in self.criterias:
            params.update(criteria.as_sql_param())
        return params
class Intersects(Criteria):
    # matches when the array column shares at least one element with the
    # given sequence (SQL && operator)
    def as_sql(self):
        if not self.value:
            # empty value: only match empty (NULL length) arrays
            return 'ARRAY_LENGTH(%s, 1) IS NULL' % self.attribute
        else:
            return '%s && %%(c%s)s' % (self.attribute, id(self.value))

    def as_sql_param(self):
        # psycopg2 adapts a list to a PostgreSQL array
        return {'c%s' % id(self.value): list(self.value)}
class ILike(Criteria):
    # case-insensitive substring search (SQL ILIKE '%value%')
    def __init__(self, attribute, value, **kwargs):
        super(ILike, self).__init__(attribute, value, **kwargs)
        # NOTE(review): '%' and '_' inside the value are not escaped and
        # will act as LIKE wildcards — confirm this is intended
        self.value = '%' + self.value + '%'

    def as_sql(self):
        return '%s ILIKE %%(c%s)s' % (self.attribute, id(self.value))
def get_name_as_sql_identifier(name):
    """Turn an arbitrary name into a string safe to use as a SQL identifier."""
    name = qommon.misc.simplify(name)
    for forbidden_char in '<>|{}!?^*+/=\'':
        name = name.replace(forbidden_char, '')
    return name.replace('-', '_')
def parse_clause(clause):
    """Split a clause into SQL fragments and a residual Python callable.

    Returns a three-elements tuple:
    - a list of SQL 'WHERE' clauses,
    - a dict of query parameters,
    - a callable, or None if all clauses were translated to SQL.
    """
    if clause is None:
        return ([], None, None)
    if callable(clause):
        # already a callable, nothing to translate
        return ([], None, clause)
    where_clauses = []
    func_clauses = []
    parameters = {}
    for element in clause:
        if callable(element):
            func_clauses.append(element)
            continue
        sql_class = globals().get(element.__class__.__name__)
        if sql_class is None:
            # no SQL counterpart in this module: keep it as a Python
            # filter to be applied on the selected objects
            func_clauses.append(element.build_lambda())
            continue
        sql_element = sql_class(**element.__dict__)
        where_clauses.append(sql_element.as_sql())
        parameters.update(sql_element.as_sql_param())
    if func_clauses:
        return (where_clauses, parameters, parse_storage_clause(func_clauses))
    return (where_clauses, parameters, None)
def get_connection(new=False):
    """Return the psycopg2 connection cached on the publisher.

    With new=True any existing connection is closed first and connection
    errors are raised instead of silently leaving the connection unset.
    """
    if new:
        cleanup_connection()
    publisher = get_publisher()
    if getattr(publisher, 'pgconn', None) is None:
        # only keep settings that have an actual value
        postgresql_cfg = dict(
            (k, v) for k, v in get_cfg('postgresql', {}).items() if v)
        try:
            publisher.pgconn = psycopg2.connect(**postgresql_cfg)
        except psycopg2.Error:
            if new:
                raise
            publisher.pgconn = None
    return publisher.pgconn
def cleanup_connection():
    """Close and forget the connection cached on the publisher, if any."""
    publisher = get_publisher()
    if getattr(publisher, 'pgconn', None) is not None:
        publisher.pgconn.close()
        publisher.pgconn = None
def get_connection_and_cursor(new=False):
    """Return a (connection, cursor) pair, reconnecting once if the cached
    connection went stale."""
    conn = get_connection(new=new)
    try:
        cur = conn.cursor()
    except psycopg2.InterfaceError:
        # postgresql may have been restarted in between; get a fresh
        # connection and retry
        conn = get_connection(new=True)
        cur = conn.cursor()
    return (conn, cur)
def get_formdef_table_name(formdef):
    """Return the formdata table name for a formdef, computing and storing
    it on the formdef on first use."""
    # PostgreSQL limits identifier length to 63 bytes
    #
    #   The system uses no more than NAMEDATALEN-1 bytes of an identifier;
    #   longer names can be written in commands, but they will be truncated.
    #   By default, NAMEDATALEN is 64 so the maximum identifier length is
    #   63 bytes. If this limit is problematic, it can be raised by changing
    #   the NAMEDATALEN constant in src/include/pg_config_manual.h.
    #
    # as we have to know our table names, we crop the names here, and to an
    # extent that allows suffixes (like _evolution) to be added.
    assert formdef.id is not None
    if hasattr(formdef, 'table_name') and formdef.table_name:
        return formdef.table_name
    formdef.table_name = 'formdata_%s_%s' % (formdef.id,
            get_name_as_sql_identifier(formdef.url_name)[:30])
    # persist the computed name so it never changes afterwards
    formdef.store()
    return formdef.table_name
def guard_postgres(func):
    """Decorator rolling back the current transaction on any psycopg2 error.

    The exception is re-raised after the rollback; the rollback only keeps
    the shared connection usable for subsequent queries.
    """
    import functools

    # functools.wraps keeps the decorated function's name/docstring, which
    # helps debugging and introspection
    @functools.wraps(func)
    def f(*args, **kwargs):
        try:
            return func(*args, **kwargs)
        except psycopg2.Error:
            get_connection().rollback()
            raise
    return f
@guard_postgres
def do_formdef_tables(formdef, conn=None, cur=None, rebuild_views=False, rebuild_global_views=False):
    """Create or migrate the SQL tables backing a formdef.

    Creates the formdata table and its _evolutions companion table when
    missing, then adds/drops columns to match the current formdef fields.
    Returns a list of follow-up action names ('rebuild_security',
    'do_tracking_code_table') for the caller to perform.
    """
    if formdef.id is None:
        return []
    own_conn = False
    if not conn:
        own_conn = True
        conn, cur = get_connection_and_cursor()
    table_name = get_formdef_table_name(formdef)
    cur.execute('''SELECT COUNT(*) FROM information_schema.tables
                   WHERE table_name = %s''', (table_name,))
    if cur.fetchone()[0] == 0:
        # table does not exist yet, create it along with its evolutions
        # companion table
        cur.execute('''CREATE TABLE %s (id serial PRIMARY KEY,
                                        user_id varchar,
                                        user_hash varchar,
                                        receipt_time timestamp,
                                        anonymised timestamptz,
                                        status varchar,
                                        page_no varchar,
                                        workflow_data bytea,
                                        id_display varchar
                                        )''' % table_name)
        cur.execute('''CREATE TABLE %s_evolutions (id serial PRIMARY KEY,
                                        who varchar,
                                        status varchar,
                                        time timestamp,
                                        comment text,
                                        parts bytea,
                                        formdata_id integer REFERENCES %s (id) ON DELETE CASCADE)''' % (
            table_name, table_name))
    cur.execute('''SELECT column_name FROM information_schema.columns
                   WHERE table_name = %s''', (table_name,))
    existing_fields = set([x[0] for x in cur.fetchall()])
    needed_fields = set(['id', 'user_id', 'user_hash', 'receipt_time',
            'status', 'workflow_data', 'id_display', 'fts', 'page_no',
            'anonymised', 'workflow_roles', 'workflow_roles_array',
            'concerned_roles_array', 'tracking_code',
            'actions_roles_array', 'backoffice_submission',
            'submission_context'])

    # migrations of the static columns
    if not 'fts' in existing_fields:
        # full text search
        cur.execute('''ALTER TABLE %s ADD COLUMN fts tsvector''' % table_name)
        cur.execute('''CREATE INDEX %s_fts ON %s USING gin(fts)''' % (
            table_name, table_name))
    if not 'workflow_roles' in existing_fields:
        cur.execute('''ALTER TABLE %s ADD COLUMN workflow_roles bytea''' % table_name)
        cur.execute('''ALTER TABLE %s ADD COLUMN workflow_roles_array text[]''' % table_name)
    if not 'concerned_roles_array' in existing_fields:
        cur.execute('''ALTER TABLE %s ADD COLUMN concerned_roles_array text[]''' % table_name)
    if not 'actions_roles_array' in existing_fields:
        cur.execute('''ALTER TABLE %s ADD COLUMN actions_roles_array text[]''' % table_name)
    if not 'page_no' in existing_fields:
        cur.execute('''ALTER TABLE %s ADD COLUMN page_no varchar''' % table_name)
    if not 'anonymised' in existing_fields:
        cur.execute('''ALTER TABLE %s ADD COLUMN anonymised timestamptz''' % table_name)
    if not 'tracking_code' in existing_fields:
        cur.execute('''ALTER TABLE %s ADD COLUMN tracking_code varchar''' % table_name)
    if not 'backoffice_submission' in existing_fields:
        cur.execute('''ALTER TABLE %s ADD COLUMN backoffice_submission boolean''' % table_name)
    if not 'submission_context' in existing_fields:
        cur.execute('''ALTER TABLE %s ADD COLUMN submission_context bytea''' % table_name)

    # add new fields
    for field in formdef.fields:
        assert field.id is not None
        sql_type = SQL_TYPE_MAPPING.get(field.key, 'varchar')
        if sql_type is None:
            # fields such as titles or comments hold no data
            continue
        needed_fields.add('f%s' % field.id)
        if 'f%s' % field.id not in existing_fields:
            cur.execute('''ALTER TABLE %s ADD COLUMN %s %s''' % (
                table_name,
                'f%s' % field.id,
                sql_type))
        if field.store_display_value:
            needed_fields.add('f%s_display' % field.id)
            if 'f%s_display' % field.id not in existing_fields:
                cur.execute('''ALTER TABLE %s ADD COLUMN %s varchar''' % (
                    table_name, 'f%s_display' % field.id))
        if field.store_structured_value:
            needed_fields.add('f%s_structured' % field.id)
            if 'f%s_structured' % field.id not in existing_fields:
                cur.execute('''ALTER TABLE %s ADD COLUMN %s bytea''' % (
                    table_name, 'f%s_structured' % field.id))

    # delete obsolete fields
    for field in (existing_fields - needed_fields):
        cur.execute('''ALTER TABLE %s DROP COLUMN %s CASCADE''' % (table_name, field))

    if rebuild_views or len(existing_fields - needed_fields):
        # views may have been dropped when dropping columns, so we recreate
        # them even if not asked to.
        redo_views(conn, cur, formdef, rebuild_global_views=rebuild_global_views)

    if own_conn:
        conn.commit()
        cur.close()

    actions = []
    if not 'concerned_roles_array' in existing_fields:
        actions.append('rebuild_security')
    elif not 'actions_roles_array' in existing_fields:
        actions.append('rebuild_security')
    if not 'tracking_code' in existing_fields:
        # if tracking code has just been added to the table we need to make
        # sure the tracking code table does exist.
        actions.append('do_tracking_code_table')
    return actions
@guard_postgres
def do_user_table():
    """Create or migrate the users table.

    Static columns come first, then one column per custom user field as
    defined in the settings (UserFieldsFormDef).
    """
    conn, cur = get_connection_and_cursor()
    table_name = 'users'
    cur.execute('''SELECT COUNT(*) FROM information_schema.tables
                   WHERE table_name = %s''', (table_name,))
    if cur.fetchone()[0] == 0:
        cur.execute('''CREATE TABLE %s (id serial PRIMARY KEY,
                                        name varchar,
                                        email varchar,
                                        roles text[],
                                        is_admin bool,
                                        anonymous bool,
                                        name_identifiers text[],
                                        lasso_dump text,
                                        last_seen timestamp)''' % table_name)
    cur.execute('''SELECT column_name FROM information_schema.columns
                   WHERE table_name = %s''', (table_name,))
    existing_fields = set([x[0] for x in cur.fetchall()])
    needed_fields = set(['id', 'name', 'email', 'roles', 'is_admin',
            'anonymous', 'name_identifiers',
            'lasso_dump', 'last_seen'])

    # add columns for the custom user fields
    from admin.settings import UserFieldsFormDef
    formdef = UserFieldsFormDef()
    for field in formdef.fields:
        sql_type = SQL_TYPE_MAPPING.get(field.key, 'varchar')
        if sql_type is None:
            # fields such as titles or comments hold no data
            continue
        needed_fields.add('f%s' % field.id)
        if 'f%s' % field.id not in existing_fields:
            cur.execute('''ALTER TABLE %s ADD COLUMN %s %s''' % (
                table_name,
                'f%s' % field.id,
                sql_type))
        if field.store_display_value:
            needed_fields.add('f%s_display' % field.id)
            if 'f%s_display' % field.id not in existing_fields:
                cur.execute('''ALTER TABLE %s ADD COLUMN %s varchar''' % (
                    table_name, 'f%s_display' % field.id))
        if field.store_structured_value:
            needed_fields.add('f%s_structured' % field.id)
            if 'f%s_structured' % field.id not in existing_fields:
                cur.execute('''ALTER TABLE %s ADD COLUMN %s bytea''' % (
                    table_name, 'f%s_structured' % field.id))

    # delete obsolete fields
    for field in (existing_fields - needed_fields):
        cur.execute('''ALTER TABLE %s DROP COLUMN %s''' % (table_name, field))
    conn.commit()

    try:
        cur.execute('CREATE INDEX users_name_idx ON users (name)')
        conn.commit()
    except psycopg2.ProgrammingError:
        # the index probably already exists; cancel the failed transaction
        conn.rollback()
    cur.close()
# decorator added for consistency with the other do_*_table functions, so a
# failed DDL statement rolls the transaction back instead of leaving the
# shared connection unusable
@guard_postgres
def do_tracking_code_table():
    """Create or migrate the table mapping tracking codes to formdata."""
    conn, cur = get_connection_and_cursor()
    table_name = 'tracking_codes'
    cur.execute('''SELECT COUNT(*) FROM information_schema.tables
                   WHERE table_name = %s''', (table_name,))
    if cur.fetchone()[0] == 0:
        cur.execute('''CREATE TABLE %s (id varchar PRIMARY KEY,
                                        formdef_id varchar,
                                        formdata_id varchar)''' % table_name)
    cur.execute('''SELECT column_name FROM information_schema.columns
                   WHERE table_name = %s''', (table_name,))
    existing_fields = set([x[0] for x in cur.fetchall()])
    needed_fields = set(['id', 'formdef_id', 'formdata_id'])
    # delete obsolete fields
    for field in (existing_fields - needed_fields):
        cur.execute('''ALTER TABLE %s DROP COLUMN %s''' % (table_name, field))
    conn.commit()
    cur.close()
@guard_postgres
def do_meta_table(conn=None, cur=None, insert_current_sql_level=True):
    """Create the wcs_meta key/value table if it does not exist.

    On creation a 'sql_level' row is inserted, either with the current
    SQL_LEVEL or, with insert_current_sql_level=False, with 0.
    """
    own_conn = False
    if not conn:
        own_conn = True
        conn, cur = get_connection_and_cursor()
    cur.execute('''SELECT COUNT(*) FROM information_schema.tables
                   WHERE table_name = %s''', ('wcs_meta',))
    if cur.fetchone()[0] == 0:
        cur.execute('''CREATE TABLE wcs_meta (id serial PRIMARY KEY,
                                              key varchar,
                                              value varchar)''')
        if insert_current_sql_level:
            sql_level = SQL_LEVEL
        else:
            sql_level = 0
        cur.execute('''INSERT INTO wcs_meta (id, key, value)
                       VALUES (DEFAULT, %s, %s)''', ('sql_level', str(sql_level)))
    if own_conn:
        conn.commit()
        cur.close()
@guard_postgres
def redo_views(conn, cur, formdef, rebuild_global_views=False):
    """Drop then recreate the views of a formdef.

    Does nothing when views are disabled by the 'postgresql_views' site
    option or when the formdef has no id yet.
    """
    if get_publisher().get_site_option('postgresql_views') == 'false':
        return
    if formdef.id is None:
        return
    drop_views(formdef, conn, cur)
    do_views(formdef, conn, cur, rebuild_global_views=rebuild_global_views)
@guard_postgres
def drop_views(formdef, conn, cur):
    """Drop the view of the given formdef (or all form views when formdef
    is None), along with the global views built on top of them."""
    # the global views depend on the form views, remove them first
    drop_global_views(conn, cur)
    if formdef:
        # only the view of that particular formdef
        pattern = 'wcs\_view\_%s\_%%' % formdef.id
    else:
        # if there's no formdef specified, remove all form views
        pattern = 'wcs\_view\_%'
    cur.execute('''SELECT table_name FROM information_schema.views
                   WHERE table_name LIKE %s''', (pattern,))
    # collect all names before issuing the DROP statements
    view_names = [row[0] for row in cur.fetchall()]
    for view_name in view_names:
        cur.execute('''DROP VIEW IF EXISTS %s''' % view_name)
def get_view_fields(formdef):
    """Return the (expression, alias) pairs common to every formdef view."""
    fields = [
        # literal columns identifying the formdef the rows belong to
        ("int '%s'" % (formdef.category_id or 0), 'category_id'),
        ("int '%s'" % (formdef.id or 0), 'formdef_id'),
    ]
    # static columns, exposed under their own name
    for column in ('id', 'user_id', 'user_hash', 'receipt_time', 'status', 'id_display'):
        fields.append((column, column))
    return fields
@guard_postgres
def do_views(formdef, conn, cur, rebuild_global_views=True):
    """Create the SQL view exposing the formdata of a formdef.

    Data columns are named after the field varname (or, failing that, the
    field id and label); duplicate names get a numeric suffix.
    """
    # create new view
    table_name = get_formdef_table_name(formdef)
    view_name = 'wcs_view_%s_%s' % (formdef.id,
            get_name_as_sql_identifier(formdef.url_name)[:40])
    view_fields = get_view_fields(formdef)
    column_names = {}
    for field in formdef.fields:
        field_key = 'f%s' % field.id
        if field.type in ('page', 'title', 'subtitle', 'comment'):
            # no data for those fields, no view column
            continue
        if field.varname:
            # the variable should be fine as is but we pass it through
            # get_name_as_sql_identifier nevertheless, to be extra sure it
            # doesn't contain invalid characters.
            field_name = get_name_as_sql_identifier(field.varname)[:50]
        else:
            field_name = 'f%s_%s' % (field.id, get_name_as_sql_identifier(field.label))
            field_name = field_name[:50]
        if field_name in column_names:
            # it may happen that the same varname is used on multiple fields
            # (for example in the case of conditional pages), in that situation
            # we suffix the field name with an index count
            while field_name in column_names:
                column_names[field_name] += 1
                field_name = '%s_%s' % (field_name, column_names[field_name])
        column_names[field_name] = 1
        view_fields.append((field_key, field_name))
        if field.store_display_value:
            field_key = 'f%s_display' % field.id
            view_fields.append((field_key, field_name + '_display'))
    # status_history: array of the statuses the formdata went through
    view_fields.append(('''ARRAY(SELECT status FROM %s_evolutions '''\
            ''' WHERE %s.id = %s_evolutions.formdata_id'''\
            ''' ORDER BY %s_evolutions.time)''' % ((table_name,) * 4),
            'status_history'))
    # add a is_at_endpoint column, dynamically created against the endpoint status.
    endpoint_status = formdef.workflow.get_endpoint_status()
    view_fields.append(('''(SELECT status = ANY(ARRAY[[%s]]::text[]))''' % \
            ', '.join(["'wf-%s'" % x.id for x in endpoint_status]),
            '''is_at_endpoint'''))
    view_fields.append(('concerned_roles_array', 'concerned_roles_array'))
    view_fields.append(('actions_roles_array', 'actions_roles_array'))
    view_fields.append(('fts', 'fts'))
    fields_list = ', '.join(['%s AS %s' % x for x in view_fields])
    cur.execute('''CREATE VIEW %s AS SELECT %s FROM %s''' % (
        view_name, fields_list, table_name))
    if rebuild_global_views:
        do_global_views(conn, cur)  # recreate global views
def drop_global_views(conn, cur):
    """Drop the per-category views and the wcs_all_forms global view."""
    cur.execute('''SELECT table_name FROM information_schema.views
                   WHERE table_name LIKE %s''', ('wcs\_category\_%',))
    # collect all names before issuing the DROP statements
    view_names = [row[0] for row in cur.fetchall()]
    for view_name in view_names:
        cur.execute('''DROP VIEW IF EXISTS %s''' % view_name)
    cur.execute('''DROP VIEW IF EXISTS wcs_all_forms''')
def do_global_views(conn, cur):
    """(Re)create the wcs_all_forms global view and the per-category views."""
    # recreate global views
    view_names = []
    cur.execute('''SELECT table_name FROM information_schema.views
                   WHERE table_name LIKE %s''', ('wcs\_view\_%',))
    while True:
        row = cur.fetchone()
        if row is None:
            break
        view_names.append(row[0])
    if not view_names:
        return
    from wcs.formdef import FormDef
    # a blank formdef is enough to get the list of columns every form view
    # has in common
    fake_formdef = FormDef()
    common_fields = get_view_fields(fake_formdef)
    common_fields.append(('concerned_roles_array', 'concerned_roles_array'))
    common_fields.append(('actions_roles_array', 'actions_roles_array'))
    common_fields.append(('fts', 'fts'))
    common_fields.append(('is_at_endpoint', 'is_at_endpoint'))
    # wcs_all_forms is the UNION of all individual form views, reduced to
    # their common columns
    union = ' UNION '.join(['''SELECT %s FROM %s''' % (
        ', '.join([y[1] for y in common_fields]), x) for x in view_names])
    cur.execute('''CREATE VIEW wcs_all_forms AS %s''' % union)
    for category in wcs.categories.Category.select():
        name = get_name_as_sql_identifier(category.name)[:40]
        cur.execute('''CREATE VIEW wcs_category_%s AS SELECT * from wcs_all_forms
                       WHERE category_id = %s''' % (name, category.id))
class SqlMixin(object):
    """Mixin providing SQL-backed implementations of the storage API.

    Subclasses are expected to define _table_name, _table_static_fields,
    get_data_fields() and _row2ob().
    """
    _table_name = None

    @guard_postgres
    def keys(cls):
        """Return the list of all object ids."""
        conn, cur = get_connection_and_cursor()
        sql_statement = 'SELECT id FROM %s' % cls._table_name
        cur.execute(sql_statement)
        ids = [x[0] for x in cur.fetchall()]
        conn.commit()
        cur.close()
        return ids
    keys = classmethod(keys)

    @guard_postgres
    def count(cls, clause=None):
        """Return the number of objects matching the given clause."""
        where_clauses, parameters, func_clause = parse_clause(clause)
        if func_clause:
            # fallback to counting the result of a select()
            return len(cls.select(clause))
        sql_statement = 'SELECT count(*) FROM %s' % cls._table_name
        if where_clauses:
            sql_statement += ' WHERE ' + ' AND '.join(where_clauses)
        conn, cur = get_connection_and_cursor()
        cur.execute(sql_statement, parameters)
        count = cur.fetchone()[0]
        conn.commit()
        cur.close()
        return count
    count = classmethod(count)

    @guard_postgres
    def get_with_indexed_value(cls, index, value, ignore_errors = False):
        """Return the list of objects whose `index` column equals `value`."""
        conn, cur = get_connection_and_cursor()
        sql_statement = '''SELECT %s
                             FROM %s
                            WHERE %s = %%(value)s''' % (
            ', '.join([x[0] for x in cls._table_static_fields]
                + cls.get_data_fields()),
            cls._table_name,
            index)
        cur.execute(sql_statement, {'value': str(value)})
        objects = []
        while True:
            row = cur.fetchone()
            if row is None:
                break
            objects.append(cls._row2ob(row))
        conn.commit()
        cur.close()
        if ignore_errors:
            # _row2ob() may return None for broken rows; filter them out
            objects = (x for x in objects if x is not None)
        return list(objects)
    get_with_indexed_value = classmethod(get_with_indexed_value)

    @guard_postgres
    def get(cls, id, ignore_errors=False, ignore_migration=False):
        """Return the object with the given id; with ignore_errors=True,
        None is returned instead of raising KeyError."""
        if id is None:
            if ignore_errors:
                return None
            else:
                raise KeyError()
        conn, cur = get_connection_and_cursor()
        sql_statement = '''SELECT %s
                             FROM %s
                            WHERE id = %%(id)s''' % (
            ', '.join([x[0] for x in cls._table_static_fields]
                + cls.get_data_fields()),
            cls._table_name)
        cur.execute(sql_statement, {'id': str(id)})
        row = cur.fetchone()
        if row is None:
            cur.close()
            if ignore_errors:
                return None
            raise KeyError()
        cur.close()
        return cls._row2ob(row)
    get = classmethod(get)

    @guard_postgres
    def get_ids(cls, ids, ignore_errors=False, keep_order=False):
        """Return the objects with the given ids.

        With keep_order=True the objects are returned in the order of the
        ids argument (missing ids are silently skipped).
        """
        if not ids:
            return []
        conn, cur = get_connection_and_cursor()
        sql_statement = '''SELECT %s
                             FROM %s
                            WHERE id IN (%s)''' % (
            ', '.join([x[0] for x in cls._table_static_fields]
                + cls.get_data_fields()),
            cls._table_name,
            ','.join([str(x) for x in ids]))
        cur.execute(sql_statement)
        objects = cls.get_objects(cur)
        conn.commit()
        cur.close()
        if ignore_errors:
            objects = (x for x in objects if x is not None)
        if keep_order:
            objects_dict = {}
            for object in objects:
                objects_dict[object.id] = object
            objects = [objects_dict[x] for x in ids if objects_dict.get(x)]
        return list(objects)
    get_ids = classmethod(get_ids)

    def get_objects(cls, cur, ignore_errors=False):
        """Consume the cursor and return the rows as objects."""
        objects = []
        while True:
            row = cur.fetchone()
            if row is None:
                break
            objects.append(cls._row2ob(row))
        return objects
    get_objects = classmethod(get_objects)

    @guard_postgres
    def select(cls, clause=None, order_by=None, ignore_errors=False, limit=None, offset=None):
        """Return the objects matching the given clause.

        order_by is a column name, prefixed with '-' for descending order.
        Criterias with no SQL counterpart are applied in Python after the
        query, and limit/offset are then applied on the filtered result.
        """
        conn, cur = get_connection_and_cursor()
        sql_statement = '''SELECT %s
                             FROM %s''' % (
            ', '.join([x[0] for x in cls._table_static_fields]
                + cls.get_data_fields()),
            cls._table_name)
        where_clauses, parameters, func_clause = parse_clause(clause)
        if where_clauses:
            sql_statement += ' WHERE ' + ' AND '.join(where_clauses)
        if order_by:
            if order_by.startswith('-'):
                order_by = order_by[1:]
                sql_statement += ' ORDER BY %s DESC' % order_by
            else:
                sql_statement += ' ORDER BY %s' % order_by
        if not func_clause:
            # limit/offset can only be pushed to SQL when no further
            # filtering will happen in Python
            if limit:
                sql_statement += ' LIMIT %s' % limit
            if offset:
                sql_statement += ' OFFSET %s' % offset
        cur.execute(sql_statement, parameters)
        objects = cls.get_objects(cur)
        conn.commit()
        cur.close()
        if ignore_errors:
            objects = (x for x in objects if x is not None)
        if func_clause:
            objects = (x for x in objects if func_clause(x))
            if limit or offset:
                objects = _take(objects, limit, offset)
        return list(objects)
    select = classmethod(select)

    def get_sql_dict_from_data(self, data, formdef):
        """Map a formdata `data` dict to a dict of SQL column values."""
        sql_dict = {}
        for field in formdef.fields:
            sql_type = SQL_TYPE_MAPPING.get(field.key, 'varchar')
            if sql_type is None:
                # field holds no data, nothing to store
                continue
            value = data.get(field.id)
            if value is not None:
                if field.key in ('ranked-items', 'password'):
                    # turn {'poire': 2, 'abricot': 1, 'pomme': 3} into an array
                    value = [[x, unicode(y).encode('utf-8')] for x, y in value.items()]
                elif sql_type == 'varchar':
                    assert isinstance(value, basestring)
                elif sql_type == 'date':
                    assert type(value) is time.struct_time
                    value = datetime.datetime(value.tm_year, value.tm_mon, value.tm_mday)
                elif sql_type == 'bytea':
                    # arbitrary values are pickled into the bytea column
                    value = bytearray(cPickle.dumps(value))
                elif sql_type == 'boolean':
                    pass
            sql_dict['f%s' % field.id] = value
            if field.store_display_value:
                sql_dict['f%s_display' % field.id] = data.get('%s_display' % field.id)
            if field.store_structured_value:
                sql_dict['f%s_structured' % field.id] = bytearray(
                    cPickle.dumps(data.get('%s_structured' % field.id)))
        return sql_dict

    def _row2obdata(cls, row, formdef):
        """Rebuild the formdata `data` dict from the data columns of a row,
        starting right after the static columns."""
        obdata = {}
        i = len(cls._table_static_fields)
        for field in formdef.fields:
            sql_type = SQL_TYPE_MAPPING.get(field.key, 'varchar')
            if sql_type is None:
                # field holds no data, no column was created for it
                continue
            value = row[i]
            if value:
                if field.key == 'ranked-items':
                    # reverse of the array mapping done when storing
                    d = {}
                    for data, rank in value:
                        d[data] = int(rank)
                    value = d
                elif field.key == 'password':
                    d = {}
                    for fmt, val in value:
                        d[fmt] = unicode(val, 'utf-8')
                    value = d
                if sql_type == 'date':
                    value = value.timetuple()
                elif sql_type == 'bytea':
                    value = cPickle.loads(str(value))
            obdata[field.id] = value
            i += 1
            if field.store_display_value:
                value = row[i]
                obdata['%s_display' % field.id] = value
                i += 1
            if field.store_structured_value:
                value = row[i]
                if value is not None:
                    obdata['%s_structured' % field.id] = cPickle.loads(str(value))
                    if obdata['%s_structured' % field.id] is None:
                        del obdata['%s_structured' % field.id]
                i += 1
        return obdata
    _row2obdata = classmethod(_row2obdata)

    @guard_postgres
    def remove_object(cls, id):
        """Delete the row with the given id."""
        conn, cur = get_connection_and_cursor()
        sql_statement = '''DELETE FROM %s
                            WHERE id = %%(id)s''' % cls._table_name
        cur.execute(sql_statement, {'id': str(id)})
        conn.commit()
        cur.close()
    remove_object = classmethod(remove_object)

    @guard_postgres
    def wipe(cls):
        """Delete all rows of the table."""
        conn, cur = get_connection_and_cursor()
        sql_statement = '''DELETE FROM %s''' % cls._table_name
        # NOTE(review): the parameters dict is useless here — the statement
        # has no placeholder and `id` refers to the builtin; left as-is.
        cur.execute(sql_statement, {'id': str(id)})
        conn.commit()
        cur.close()
    wipe = classmethod(wipe)

    @guard_postgres
    def get_sorted_ids(cls, order_by):
        """Return all ids, sorted on the given column ('-' prefix for
        descending order)."""
        conn, cur = get_connection_and_cursor()
        sql_statement = 'SELECT id FROM %s' % cls._table_name
        if order_by.startswith('-'):
            order_by = order_by[1:]
            sql_statement += ' ORDER BY %s DESC' % order_by
        else:
            sql_statement += ' ORDER BY %s' % order_by
        cur.execute(sql_statement)
        ids = [x[0] for x in cur.fetchall()]
        conn.commit()
        cur.close()
        return ids
    get_sorted_ids = classmethod(get_sorted_ids)
class SqlFormData(SqlMixin, wcs.formdata.FormData):
    """FormData variant stored in the formdef's PostgreSQL table."""
    _names = None  # make sure StorableObject methods fail
    _formdef = None
    # static columns of the formdata tables, in selection order
    _table_static_fields = [
        ('id', 'serial'),
        ('user_id', 'varchar'),
        ('user_hash', 'varchar'),
        ('receipt_time', 'timestamp'),
        ('status', 'varchar'),
        ('page_no', 'varchar'),
        ('anonymised', 'timestamptz'),
        ('workflow_data', 'bytea'),
        ('id_display', 'varchar'),
        ('workflow_roles', 'bytea'),
        ('workflow_roles_array', 'text[]'),
        ('concerned_roles_array', 'text[]'),
        ('actions_roles_array', 'text[]'),
        ('tracking_code', 'varchar'),
        ('backoffice_submission', 'boolean'),
        ('submission_context', 'bytea'),
    ]

    def __init__(self, id=None):
        self.id = id
        self.data = {}

    # per-instance cache of the evolutions, loaded on demand
    _evolution = None
    @guard_postgres
    def get_evolution(self):
        """Return the list of Evolution objects, loading them from the
        _evolutions table on first access."""
        if self._evolution:
            return self._evolution
        if not self.id:
            # not stored yet: no evolution rows can exist
            self._evolution = []
            return self._evolution
        conn, cur = get_connection_and_cursor()
        sql_statement = '''SELECT id, who, status, time, comment, parts FROM %s_evolutions
                            WHERE formdata_id = %%(id)s
                            ORDER BY id''' % self._table_name
        cur.execute(sql_statement, {'id': self.id})
        self._evolution = []
        while True:
            row = cur.fetchone()
            if row is None:
                break
            self._evolution.append(self._row2evo(row))
        conn.commit()
        cur.close()
        return self._evolution
def _row2evo(cls, row):
o = wcs.formdata.Evolution()
o._sql_id, o.who, o.status, o.time, o.comment = tuple(row[:5])
if o.time:
o.time = o.time.timetuple()
if row[5]:
o.parts = cPickle.loads(str(row[5]))
return o
_row2evo = classmethod(_row2evo)
    def set_evolution(self, value):
        # setter counterpart of get_evolution(), backing the property below
        self._evolution = value

    evolution = property(get_evolution, set_evolution)
@guard_postgres
def load_all_evolutions(cls, values):
# Typically formdata.evolution is loaded on-demand (see above
# property()) and this is fine to minimize queries, especially when
# dealing with a single formdata. However in some places (to compute
# statistics for example) it is sometimes useful to access .evolution
# on a serie of formdata and in that case, it's more efficient to
# optimize the process loading all evolutions in a single batch query.
object_dict = dict([(x.id, x) for x in values if x.id and x._evolution is None])
if not object_dict:
return
conn, cur = get_connection_and_cursor()
sql_statement = '''SELECT id, who, status, time, comment, parts, formdata_id
FROM %s_evolutions''' % cls._table_name
sql_statement += ''' WHERE formdata_id IN %(object_ids)s ORDER BY id'''
cur.execute(sql_statement, {'object_ids': tuple(object_dict.keys())})
for value in values:
value._evolution = []
while True:
row = cur.fetchone()
if row is None:
break
_sql_id, who, status, time, comment, parts, formdata_id = tuple(row[:7])
formdata = object_dict.get(formdata_id)
if not formdata:
continue
formdata._evolution.append(formdata._row2evo(row))
conn.commit()
cur.close()
load_all_evolutions = classmethod(load_all_evolutions)
@guard_postgres
def store(self):
sql_dict = {
'user_id': self.user_id,
'user_hash': self.user_hash,
'status': self.status,
'page_no': self.page_no,
'workflow_data': bytearray(cPickle.dumps(self.workflow_data)),
'id_display': self.id_display,
'anonymised': self.anonymised,
'tracking_code': self.tracking_code,
'backoffice_submission': self.backoffice_submission,
'submission_context': self.submission_context,
}
if self.receipt_time:
sql_dict['receipt_time'] = datetime.datetime.fromtimestamp(time.mktime(self.receipt_time)),
else:
sql_dict['receipt_time'] = None
if self.workflow_roles:
sql_dict['workflow_roles'] = bytearray(cPickle.dumps(self.workflow_roles))
sql_dict['workflow_roles_array'] = [str(x) for x in self.workflow_roles.values() if x is not None]
else:
sql_dict['workflow_roles'] = None
sql_dict['workflow_roles_array'] = None
if self.submission_context:
sql_dict['submission_context'] = bytearray(cPickle.dumps(self.submission_context))
else:
sql_dict['submission_context'] = None
sql_dict['concerned_roles_array'] = [str(x) for x in self.concerned_roles if x]
sql_dict['actions_roles_array'] = [str(x) for x in self.actions_roles if x]
sql_dict.update(self.get_sql_dict_from_data(self.data, self._formdef))
conn, cur = get_connection_and_cursor()
if not self.id:
column_names = sql_dict.keys()
sql_statement = '''INSERT INTO %s (id, %s)
VALUES (DEFAULT, %s)
RETURNING id''' % (
self._table_name,
', '.join(column_names),
', '.join(['%%(%s)s' % x for x in column_names]))
cur.execute(sql_statement, sql_dict)
self.id = cur.fetchone()[0]
else:
column_names = sql_dict.keys()
sql_dict['id'] = self.id
sql_statement = '''UPDATE %s SET %s WHERE id = %%(id)s RETURNING id''' % (
self._table_name,
', '.join(['%s = %%(%s)s' % (x,x) for x in column_names]))
cur.execute(sql_statement, sql_dict)
if cur.fetchone() is None:
column_names = sql_dict.keys()
sql_statement = '''INSERT INTO %s (%s) VALUES (%s) RETURNING id''' % (
self._table_name,
', '.join(column_names),
', '.join(['%%(%s)s' % x for x in column_names]))
cur.execute(sql_statement, sql_dict)
self.id = cur.fetchone()[0]
if self._evolution:
for evo in self._evolution:
sql_dict = {}
if hasattr(evo, '_sql_id'):
sql_dict.update({'id': evo._sql_id})
sql_statement = '''UPDATE %s_evolutions SET
comment = %%(comment)s,
parts = %%(parts)s
WHERE id = %%(id)s
RETURNING id''' % self._table_name
else:
sql_statement = '''INSERT INTO %s_evolutions (
id, who, status,
time, comment, parts,
formdata_id)
VALUES (DEFAULT, %%(who)s, %%(status)s,
%%(time)s, %%(comment)s,
%%(parts)s, %%(formdata_id)s)
RETURNING id''' % self._table_name
sql_dict.update({
'who': evo.who,
'status': evo.status,
'time': datetime.datetime.fromtimestamp(time.mktime(evo.time)),
'comment': evo.comment,
'formdata_id': self.id,
})
if evo.parts:
sql_dict['parts'] = bytearray(cPickle.dumps(evo.parts))
else:
sql_dict['parts'] = None
cur.execute(sql_statement, sql_dict)
evo._sql_id = cur.fetchone()[0]
fts_strings = [str(self.id)]
if self.tracking_code:
fts_strings.append(self.tracking_code)
for field in self._formdef.fields:
if not self.data.get(field.id):
continue
value = None
if field.key in ('string', 'text', 'email'):
value = self.data.get(field.id)
elif field.key in ('item', 'items'):
value = self.data.get('%s_display' % field.id)
if value:
if isinstance(value, basestring):
fts_strings.append(value)
elif type(value) in (tuple, list):
fts_strings.extend(value)
if self._evolution:
for evo in self._evolution:
if evo.comment:
fts_strings.append(evo.comment)
sql_statement = '''UPDATE %s SET fts = to_tsvector( %%(fts)s)
WHERE id = %%(id)s''' % self._table_name
cur.execute(sql_statement, {'id': self.id, 'fts': ' '.join(fts_strings)})
conn.commit()
cur.close()
def _row2ob(cls, row):
o = cls()
for static_field, value in zip(cls._table_static_fields,
tuple(row[:len(cls._table_static_fields)])):
setattr(o, static_field[0], value)
if o.receipt_time:
o.receipt_time = o.receipt_time.timetuple()
if o.workflow_data:
o.workflow_data = cPickle.loads(str(o.workflow_data))
if o.workflow_roles:
o.workflow_roles = cPickle.loads(str(o.workflow_roles))
if o.submission_context:
o.submission_context = cPickle.loads(str(o.submission_context))
o.data = cls._row2obdata(row, cls._formdef)
return o
_row2ob = classmethod(_row2ob)
def get_data_fields(cls):
data_fields = []
for field in cls._formdef.fields:
sql_type = SQL_TYPE_MAPPING.get(field.key, 'varchar')
if sql_type is None:
continue
data_fields.append('f%s' % field.id)
if field.store_display_value:
data_fields.append('f%s_display' % field.id)
if field.store_structured_value:
data_fields.append('f%s_structured' % field.id)
return data_fields
get_data_fields = classmethod(get_data_fields)
@guard_postgres
def get(cls, id, ignore_errors=False, ignore_migration=False):
if id is None:
if ignore_errors:
return None
else:
raise KeyError()
else:
try:
int(id)
except ValueError:
raise KeyError()
conn, cur = get_connection_and_cursor()
potential_comma = ', '
if not cls.get_data_fields():
potential_comma = ''
sql_statement = '''SELECT %s
%s
%s
FROM %s
WHERE id = %%(id)s''' % (
', '.join([x[0] for x in cls._table_static_fields]),
potential_comma,
', '.join(cls.get_data_fields()),
cls._table_name)
cur.execute(sql_statement, {'id': str(id)})
row = cur.fetchone()
if row is None:
cur.close()
if ignore_errors:
return None
raise KeyError()
cur.close()
return cls._row2ob(row)
get = classmethod(get)
@guard_postgres
def get_ids_with_indexed_value(cls, index, value, auto_fallback=True):
conn, cur = get_connection_and_cursor()
if type(value) is int:
value = str(value)
if '%s_array' % index in [x[0] for x in cls._table_static_fields]:
sql_statement = '''SELECT id FROM %s WHERE %%(value)s = ANY (%s_array)''' % (
cls._table_name,
index)
else:
sql_statement = '''SELECT id FROM %s WHERE %s = %%(value)s''' % (
cls._table_name,
index)
cur.execute(sql_statement, {'value': value})
all_ids = [x[0] for x in cur.fetchall()]
cur.close()
return all_ids
get_ids_with_indexed_value = classmethod(get_ids_with_indexed_value)
@guard_postgres
def get_ids_from_query(cls, query):
conn, cur = get_connection_and_cursor()
sql_statement = '''SELECT id FROM %s
WHERE fts @@ plainto_tsquery(%%(value)s)''' % cls._table_name
cur.execute(sql_statement, {'value': query})
all_ids = [x[0] for x in cur.fetchall()]
cur.close()
return all_ids
get_ids_from_query = classmethod(get_ids_from_query)
@guard_postgres
def fix_sequences(cls):
conn, cur = get_connection_and_cursor()
for table_name in (cls._table_name, '%s_evolutions' % cls._table_name):
sql_statement = '''select max(id) from %s''' % table_name
cur.execute(sql_statement)
max_id = cur.fetchone()[0]
if max_id:
sql_statement = '''ALTER SEQUENCE %s RESTART %s''' % (
'%s_id_seq' % table_name, max_id+1)
cur.execute(sql_statement)
conn.commit()
cur.close()
fix_sequences = classmethod(fix_sequences)
def rebuild_security(cls):
formdatas = cls.select()
conn, cur = get_connection_and_cursor()
for formdata in formdatas:
sql_statement = '''UPDATE %s
SET concerned_roles_array = %%(roles)s,
actions_roles_array = %%(actions_roles)s
WHERE id = %%(id)s''' % cls._table_name
cur.execute(sql_statement, {
'id': formdata.id,
'roles': [str(x) for x in formdata.concerned_roles if x],
'actions_roles': [str(x) for x in formdata.actions_roles if x]})
conn.commit()
cur.close()
rebuild_security = classmethod(rebuild_security)
def do_tracking_code_table(cls):
do_tracking_code_table()
do_tracking_code_table = classmethod(do_tracking_code_table)
class SqlUser(SqlMixin, wcs.users.User):
    """User account backed by the 'users' SQL table.

    Static columns are declared in _table_static_fields; extra columns
    (f<id>, f<id>_display, f<id>_structured) store values of the user
    formdef fields.
    """
    _table_name = 'users'
    _table_static_fields = [
        ('id', 'serial'),
        ('name', 'varchar'),
        ('email', 'varchar'),
        ('roles', 'varchar[]'),
        ('is_admin', 'bool'),
        ('anonymous', 'bool'),
        ('name_identifiers', 'varchar[]'),
        ('lasso_dump', 'text'),
        ('last_seen', 'timestamp')
    ]
    id = None

    def __init__(self, name=None):
        self.name = name
        self.name_identifiers = []
        self.roles = []

    @guard_postgres
    def store(self):
        """Insert or update the user row; sets self.id on first insert."""
        sql_dict = {
            'name': self.name,
            'email': self.email,
            'roles': self.roles,
            'is_admin': self.is_admin,
            'anonymous': self.anonymous,
            'name_identifiers': self.name_identifiers,
            'lasso_dump': self.lasso_dump,
            'last_seen': None,
        }
        if self.last_seen:
            # convert the epoch timestamp to a datetime object; a stray
            # trailing comma here used to wrap the value in a 1-tuple,
            # which is not a valid timestamp parameter
            sql_dict['last_seen'] = datetime.datetime.fromtimestamp(self.last_seen)
        sql_dict.update(self.get_sql_dict_from_data(self.form_data, self.get_formdef()))
        conn, cur = get_connection_and_cursor()
        if not self.id:
            column_names = sql_dict.keys()
            sql_statement = '''INSERT INTO %s (id, %s)
                               VALUES (DEFAULT, %s)
                               RETURNING id''' % (
                    self._table_name,
                    ', '.join(column_names),
                    ', '.join(['%%(%s)s' % x for x in column_names]))
            cur.execute(sql_statement, sql_dict)
            self.id = cur.fetchone()[0]
        else:
            # column_names is snapshotted before 'id' is added, so the
            # SET list does not include the primary key
            column_names = sql_dict.keys()
            sql_dict['id'] = self.id
            sql_statement = '''UPDATE %s SET %s WHERE id = %%(id)s RETURNING id''' % (
                    self._table_name,
                    ', '.join(['%s = %%(%s)s' % (x, x) for x in column_names]))
            cur.execute(sql_statement, sql_dict)
            if cur.fetchone() is None:
                # row had a preset id but does not exist yet: insert it,
                # keeping the explicit id
                column_names = sql_dict.keys()
                sql_statement = '''INSERT INTO %s (%s) VALUES (%s)''' % (
                        self._table_name,
                        ', '.join(column_names),
                        ', '.join(['%%(%s)s' % x for x in column_names]))
                cur.execute(sql_statement, sql_dict)
        conn.commit()
        cur.close()

    def _row2ob(cls, row):
        """Build a SqlUser object from a database row."""
        o = cls()
        (o.id, o.name, o.email, o.roles, o.is_admin, o.anonymous,
         o.name_identifiers, o.lasso_dump,
         o.last_seen) = tuple(row[:9])
        if o.last_seen:
            # stored as a timestamp, exposed as an epoch float
            o.last_seen = time.mktime(o.last_seen.timetuple())
        if o.roles:
            o.roles = [str(x) for x in o.roles]
        o.form_data = cls._row2obdata(row, cls.get_formdef())
        return o
    _row2ob = classmethod(_row2ob)

    def get_data_fields(cls):
        """Return the SQL column names holding user formdef field values."""
        data_fields = []
        for field in cls.get_formdef().fields:
            sql_type = SQL_TYPE_MAPPING.get(field.key, 'varchar')
            if sql_type is None:
                # fields mapped to None (title, comment, ...) have no column
                continue
            data_fields.append('f%s' % field.id)
            if field.store_display_value:
                data_fields.append('f%s_display' % field.id)
            if field.store_structured_value:
                data_fields.append('f%s_structured' % field.id)
        return data_fields
    get_data_fields = classmethod(get_data_fields)

    @guard_postgres
    def fix_sequences(cls):
        """Restart the users id sequence just after the current max id."""
        conn, cur = get_connection_and_cursor()
        sql_statement = '''select max(id) from %s''' % cls._table_name
        cur.execute(sql_statement)
        max_id = cur.fetchone()[0]
        if max_id is not None:
            sql_statement = '''ALTER SEQUENCE %s_id_seq RESTART %s''' % (
                    cls._table_name, max_id + 1)
            cur.execute(sql_statement)
        conn.commit()
        cur.close()
    fix_sequences = classmethod(fix_sequences)

    @guard_postgres
    def get_users_with_name_identifier(cls, name_identifier):
        """Return users having the given name identifier."""
        conn, cur = get_connection_and_cursor()
        sql_statement = '''SELECT %s
                             FROM %s
                            WHERE %%(value)s = ANY(name_identifiers)''' % (
                ', '.join([x[0] for x in cls._table_static_fields]
                          + cls.get_data_fields()),
                cls._table_name)
        cur.execute(sql_statement, {'value': name_identifier})
        objects = cls.get_objects(cur)
        conn.commit()
        cur.close()
        return objects
    get_users_with_name_identifier = classmethod(get_users_with_name_identifier)

    @guard_postgres
    def get_users_with_email(cls, email):
        """Return users having the given email address."""
        conn, cur = get_connection_and_cursor()
        sql_statement = '''SELECT %s
                             FROM %s
                            WHERE email = %%(value)s''' % (
                ', '.join([x[0] for x in cls._table_static_fields]
                          + cls.get_data_fields()),
                cls._table_name)
        cur.execute(sql_statement, {'value': email})
        objects = cls.get_objects(cur)
        conn.commit()
        cur.close()
        return objects
    get_users_with_email = classmethod(get_users_with_email)

    @guard_postgres
    def get_users_with_role(cls, role_id):
        """Return users having the given role id."""
        conn, cur = get_connection_and_cursor()
        sql_statement = '''SELECT %s
                             FROM %s
                            WHERE %%(value)s = ANY(roles)''' % (
                ', '.join([x[0] for x in cls._table_static_fields]
                          + cls.get_data_fields()),
                cls._table_name)
        cur.execute(sql_statement, {'value': str(role_id)})
        objects = cls.get_objects(cur)
        conn.commit()
        cur.close()
        return objects
    get_users_with_role = classmethod(get_users_with_role)
class TrackingCode(SqlMixin, wcs.tracking_code.TrackingCode):
    """Tracking code stored in the 'tracking_codes' SQL table."""
    _table_name = 'tracking_codes'
    _table_static_fields = [
        ('id', 'varchar'),
        ('formdef_id', 'varchar'),
        ('formdata_id', 'varchar'),
    ]
    id = None

    @guard_postgres
    def store(self):
        """Insert or update the tracking code row.

        On first store a fresh code is drawn with get_new_id(); as codes
        may collide with existing rows, the insert is retried with a new
        code until it succeeds.
        """
        sql_dict = {
            'id': self.id,
            'formdef_id': self.formdef_id,
            'formdata_id': self.formdata_id
        }
        conn, cur = get_connection_and_cursor()
        if not self.id:
            column_names = sql_dict.keys()
            sql_dict['id'] = self.get_new_id()
            sql_statement = '''INSERT INTO %s (%s)
                               VALUES (%s)
                               RETURNING id''' % (
                    self._table_name,
                    ', '.join(column_names),
                    ', '.join(['%%(%s)s' % x for x in column_names]))
            while True:
                try:
                    cur.execute(sql_statement, sql_dict)
                except psycopg2.IntegrityError:
                    # code already taken, roll back and draw another one
                    conn.rollback()
                    sql_dict['id'] = self.get_new_id()
                else:
                    break
            self.id = cur.fetchone()[0]
        else:
            column_names = sql_dict.keys()
            sql_dict['id'] = self.id
            sql_statement = '''UPDATE %s SET %s WHERE id = %%(id)s RETURNING id''' % (
                    self._table_name,
                    ', '.join(['%s = %%(%s)s' % (x, x) for x in column_names]))
            cur.execute(sql_statement, sql_dict)
            if cur.fetchone() is None:
                # updating a code that does not exist is a programming error
                raise AssertionError()
        conn.commit()
        cur.close()

    def _row2ob(cls, row):
        """Build a TrackingCode object from a database row."""
        code = cls()
        (code.id, code.formdef_id, code.formdata_id) = tuple(row[:3])
        return code
    _row2ob = classmethod(_row2ob)

    def get_data_fields(cls):
        """Tracking codes carry no additional data columns."""
        return []
    get_data_fields = classmethod(get_data_fields)
class classproperty(object):
    """Descriptor exposing a read-only, property-like value on a class.

    Unlike a regular property the wrapped function receives the owning
    class, whether the attribute is read on the class or on an instance.
    """

    def __init__(self, f):
        self.f = f

    def __get__(self, obj, owner):
        # the instance (obj) is deliberately ignored
        return self.f(owner)
class AnyFormData(SqlMixin):
    """Read access to formdatas of every formdef, via the global view.

    Rows of the wcs_all_forms view are turned into objects of the data
    class matching their formdef.
    """
    _table_name = 'wcs_all_forms'
    # cache for the computed static fields; accessed through name
    # mangling (_AnyFormData__table_static_fields) from the methods below
    __table_static_fields = []
    _formdef_cache = {}

    @classproperty
    def _table_static_fields(cls):
        # computed lazily (it needs the wcs.formdef module) then cached
        if not cls.__table_static_fields:
            from wcs.formdef import FormDef
            fake_formdef = FormDef()
            common_fields = get_view_fields(fake_formdef)
            cls.__table_static_fields = [(x[1], x[0]) for x in common_fields]
        return cls.__table_static_fields

    @classmethod
    def get_data_fields(cls):
        """The global view exposes no per-formdef data columns."""
        return []

    @classmethod
    def get_objects(cls, *args, **kwargs):
        # reset the formdef cache for each query, so changes to formdefs
        # are taken into account
        cls._formdef_cache = {}
        return super(AnyFormData, cls).get_objects(*args, **kwargs)

    @classmethod
    def _row2ob(cls, row):
        """Build a formdata object, of its real class, from a view row."""
        formdef_id = row[1]
        from wcs.formdef import FormDef
        formdef = cls._formdef_cache.setdefault(formdef_id, FormDef.get(formdef_id))
        obj = formdef.data_class()()
        static_fields = cls._table_static_fields
        for (attr_name, _), value in zip(static_fields, row[:len(static_fields)]):
            setattr(obj, attr_name, value)
        return obj
def get_period_query(period_start=None, period_end=None, criterias=None, parameters=None):
    """Build the FROM/WHERE part of a statistics query.

    The ``parameters`` dict is updated in place with the SQL parameters
    referenced by the returned clause.  When an Equal criteria on
    formdef_id is given, the formdef-specific table is targeted instead
    of the global view, giving access to all of its fields.
    """
    clause = [NotNull('receipt_time')]
    table_name = 'wcs_all_forms'
    for criteria in (criterias or []):
        if criteria.__class__.__name__ == 'Equal' and \
                criteria.attribute == 'formdef_id':
            # if there's a formdef_id specified, switch to using the
            # specific table so we have access to all fields
            from wcs.formdef import FormDef
            table_name = get_formdef_table_name(FormDef.get(criteria.value))
            continue
        clause.append(criteria)
    if period_start:
        clause.append(GreaterOrEqual('receipt_time', period_start))
    if period_end:
        clause.append(LessOrEqual('receipt_time', period_end))
    where_clauses, params, func_clause = parse_clause(clause)
    parameters.update(params)
    return ' FROM %s ' % table_name + ' WHERE ' + ' AND '.join(where_clauses)
@guard_postgres
def get_weekday_totals(period_start=None, period_end=None, criterias=None):
    """Return [(weekday, count), ...] with all 7 weekdays present.

    Weekday numbering follows PostgreSQL DATE_PART('dow'): 0 is Sunday.
    Weekdays without any formdata get a count of 0.
    """
    conn, cur = get_connection_and_cursor()
    statement = '''SELECT DATE_PART('dow', receipt_time) AS weekday, COUNT(*)'''
    parameters = {}
    statement += get_period_query(period_start, period_end, criterias, parameters)
    statement += ' GROUP BY weekday ORDER BY weekday'
    cur.execute(statement, parameters)
    result = [(int(weekday), total) for weekday, total in cur.fetchall()]
    # fill in missing weekdays with a zero count
    seen = [weekday for weekday, total in result]
    result.extend([(weekday, 0) for weekday in range(7) if weekday not in seen])
    result.sort()
    conn.commit()
    cur.close()
    return result
@guard_postgres
def get_hour_totals(period_start=None, period_end=None, criterias=None):
    """Return [(hour, count), ...] with all 24 hours present.

    Hours without any formdata get a count of 0.
    """
    conn, cur = get_connection_and_cursor()
    statement = '''SELECT DATE_PART('hour', receipt_time) AS hour, COUNT(*)'''
    parameters = {}
    statement += get_period_query(period_start, period_end, criterias, parameters)
    statement += ' GROUP BY hour ORDER BY hour'
    cur.execute(statement, parameters)
    result = [(int(hour), total) for hour, total in cur.fetchall()]
    # fill in missing hours with a zero count
    seen = [hour for hour, total in result]
    result.extend([(hour, 0) for hour in range(24) if hour not in seen])
    result.sort()
    conn.commit()
    cur.close()
    return result
@guard_postgres
def get_monthly_totals(period_start=None, period_end=None, criterias=None):
    """Return [('YYYY-MM', count), ...] over the covered period.

    Months between the first and last row that have no formdata are
    included with a count of 0.
    """
    conn, cur = get_connection_and_cursor()
    statement = '''SELECT DATE_TRUNC('month', receipt_time) AS month, COUNT(*) '''
    parameters = {}
    statement += get_period_query(period_start, period_end, criterias, parameters)
    statement += ' GROUP BY month ORDER BY month'
    cur.execute(statement, parameters)
    raw_result = cur.fetchall()
    result = [('%d-%02d' % month.timetuple()[:2], total) for month, total in raw_result]
    if result:
        seen_labels = [label for label, total in result]
        # walk from the first to the last month, adding zero counts for
        # months without any row
        current_month = raw_result[0][0]
        last_month = raw_result[-1][0]
        while current_month < last_month:
            label = '%d-%02d' % current_month.timetuple()[:2]
            if label not in seen_labels:
                result.append((label, 0))
            # advance to the first day of the following month
            current_month = current_month + datetime.timedelta(days=31)
            current_month = current_month - datetime.timedelta(days=current_month.day - 1)
        result.sort()
    conn.commit()
    cur.close()
    return result
@guard_postgres
def get_yearly_totals(period_start=None, period_end=None, criterias=None):
    """Return [('YYYY', count), ...] over the covered period.

    Years between the first and last row that have no formdata are
    included with a count of 0.
    """
    conn, cur = get_connection_and_cursor()
    statement = '''SELECT DATE_TRUNC('year', receipt_time) AS year, COUNT(*)'''
    parameters = {}
    statement += get_period_query(period_start, period_end, criterias, parameters)
    statement += ' GROUP BY year ORDER BY year'
    cur.execute(statement, parameters)
    raw_result = cur.fetchall()
    result = [(str(year.year), total) for year, total in raw_result]
    if result:
        seen_labels = [label for label, total in result]
        current_year = raw_result[0][0]
        last_year = raw_result[-1][0]
        while current_year < last_year:
            label = str(current_year.year)
            if label not in seen_labels:
                result.append((label, 0))
            # adding 366 days always lands in the following year
            current_year = current_year + datetime.timedelta(days=366)
        result.sort()
    conn.commit()
    cur.close()
    return result
# current version of the SQL schema; the installed version is recorded in
# the wcs_meta table and migrate() runs the steps needed to catch up
SQL_LEVEL = 8
def migrate_global_views(conn, cur):
    """Recreate the global views when the wcs_all_forms view is outdated.

    The check looks for the formdef_id column, which is missing from
    views created before SQL level 2.
    """
    # query column names from information_schema.columns; the previous
    # query against information_schema.tables returned a COUNT(*), so the
    # 'formdef_id' membership test below was always false and the views
    # were dropped and recreated on every call
    cur.execute('''SELECT column_name FROM information_schema.columns
                    WHERE table_name = %s''', ('wcs_all_forms',))
    existing_fields = set([x[0] for x in cur.fetchall()])
    if 'formdef_id' not in existing_fields:
        drop_global_views(conn, cur)
        do_global_views(conn, cur)
@guard_postgres
def get_sql_level(conn, cur):
    """Return the schema version recorded in the wcs_meta table."""
    # make sure the meta table exists, without touching its content
    do_meta_table(conn, cur, insert_current_sql_level=False)
    cur.execute('''SELECT value FROM wcs_meta WHERE key = %s''', ('sql_level', ))
    return int(cur.fetchone()[0])
def migrate_views(conn, cur):
    """Drop then recreate every formdef view, and the global views."""
    drop_views(None, conn, cur)
    from wcs.formdef import FormDef
    # make sure all formdefs have up-to-date views
    for formdef in FormDef.select():
        do_formdef_tables(formdef, conn=conn, cur=cur, rebuild_views=True)
    migrate_global_views(conn, cur)
@guard_postgres
def migrate():
    """Bring the SQL schema up to date with SQL_LEVEL.

    Reads the installed level from the wcs_meta table, runs every
    migration step above it (the steps are cumulative and must stay in
    order), then records SQL_LEVEL back into wcs_meta.
    """
    conn, cur = get_connection_and_cursor()
    sql_level = get_sql_level(conn, cur)
    if sql_level < 0:
        # fake code to help in testing the error code path.
        raise RuntimeError()
    if sql_level < 1: # 1: introduction of tracking_code table
        do_tracking_code_table()
    if sql_level < 2: # 2: introduction of formdef_id in views
        migrate_views(conn, cur)
    if sql_level < 4:
        # 3: introduction of _structured for user fields
        # 4: removal of identification_token
        do_user_table()
    if sql_level < 5:
        # 5: add concerned_roles_array, is_at_endpoint and fts to views
        migrate_views(conn, cur)
    if sql_level < 6:
        # 6: add actions_roles_array to tables and views
        from wcs.formdef import FormDef
        migrate_views(conn, cur)
        for formdef in FormDef.select():
            formdef.data_class().rebuild_security()
    if sql_level < 8:
        # 7: add backoffice_submission to tables and views
        # 8: add submission_context to tables
        migrate_views(conn, cur)
    # record the new schema level
    cur.execute('''UPDATE wcs_meta SET value = %s WHERE key = %s''', (
        str(SQL_LEVEL), 'sql_level'))
    conn.commit()
    cur.close()