sql: add a way to select() over different formdefs (#8179)

This commit is contained in:
Frédéric Péters 2015-09-06 17:16:43 +02:00
parent 30ef526b60
commit cca1cf925b
2 changed files with 99 additions and 0 deletions

View File

@ -862,3 +862,58 @@ def test_views_fts():
cur.execute('''SELECT COUNT(*) FROM wcs_all_forms WHERE fts @@ plainto_tsquery(%s)''', ('bar',))
assert bool(cur.fetchone()[0] == 1)
@postgresql
def test_select_any_formdata():
drop_formdef_tables()
conn, cur = sql.get_connection_and_cursor()
now = datetime.datetime.now()
cnt = 0
for i in range(5):
formdef = FormDef()
formdef.name = 'test any %d' % i
formdef.fields = []
formdef.store()
data_class = formdef.data_class(mode='sql')
for j in range(20):
formdata = data_class()
formdata.just_created()
formdata.user_id = '%s' % ((i+j)%11)
# set receipt_time to make sure all entries are unique.
formdata.receipt_time = (now + datetime.timedelta(seconds=cnt)).timetuple()
formdata.status = ['wf-new', 'wf-accepted', 'wf-rejected', 'wf-finished'][(i+j)%4]
formdata.store()
cnt += 1
# test generic select
objects = sql.AnyFormData.select()
assert len(objects) == 100
# make sure valid formdefs are used
assert len([x for x in objects if x.formdef.name == 'test any 0']) == 20
assert len([x for x in objects if x.formdef.name == 'test any 1']) == 20
# test sorting
objects = sql.AnyFormData.select(order_by='receipt_time')
assert len(objects) == 100
objects2 = sql.AnyFormData.select(order_by='-receipt_time')
assert [(x.formdef_id, x.id) for x in objects2] == list(reversed(
[(x.formdef_id, x.id) for x in objects]))
# test clauses
objects2 = sql.AnyFormData.select([st.Equal('user_id', '0')])
assert len(objects2) == len([x for x in objects if x.user_id == '0'])
objects2 = sql.AnyFormData.select([st.Equal('is_at_endpoint', True)])
assert len(objects2) == len([x for x in objects if x.status in ('wf-rejected', 'wf-finished')])
# test offset/limit
objects2 = sql.AnyFormData.select(order_by='receipt_time', limit=10, offset=0)
assert [(x.formdef_id, x.id) for x in objects2] == [(x.formdef_id, x.id) for x in objects][:10]
objects2 = sql.AnyFormData.select(order_by='receipt_time', limit=10, offset=20)
assert [(x.formdef_id, x.id) for x in objects2] == [(x.formdef_id, x.id) for x in objects][20:30]

View File

@ -1469,6 +1469,50 @@ class TrackingCode(SqlMixin, wcs.tracking_code.TrackingCode):
get_data_fields = classmethod(get_data_fields)
class classproperty(object):
def __init__(self, f):
self.f = f
def __get__(self, obj, owner):
return self.f(owner)
class AnyFormData(SqlMixin):
_table_name = 'wcs_all_forms'
__table_static_fields = []
_formdef_cache = {}
@classproperty
def _table_static_fields(cls):
if cls.__table_static_fields:
return cls.__table_static_fields
from wcs.formdef import FormDef
fake_formdef = FormDef()
common_fields = get_view_fields(fake_formdef)
cls.__table_static_fields = [(x[1], x[0]) for x in common_fields]
return cls.__table_static_fields
@classmethod
def get_data_fields(cls):
return []
@classmethod
def get_objects(cls, *args, **kwargs):
cls._formdef_cache = {}
return super(AnyFormData, cls).get_objects(*args, **kwargs)
@classmethod
def _row2ob(cls, row):
formdef_id = row[1]
from wcs.formdef import FormDef
formdef = cls._formdef_cache.setdefault(formdef_id, FormDef.get(formdef_id))
o = formdef.data_class()()
for static_field, value in zip(cls._table_static_fields,
tuple(row[:len(cls._table_static_fields)])):
setattr(o, static_field[0], value)
return o
def get_period_query(period_start=None, period_end=None, criterias=None, parameters=None):
clause = [NotNull('receipt_time')]
table_name = 'wcs_all_forms'