sql: add a way to select() over different formdefs (#8179)
This commit is contained in:
parent
30ef526b60
commit
cca1cf925b
|
@ -862,3 +862,58 @@ def test_views_fts():
|
|||
|
||||
cur.execute('''SELECT COUNT(*) FROM wcs_all_forms WHERE fts @@ plainto_tsquery(%s)''', ('bar',))
|
||||
assert bool(cur.fetchone()[0] == 1)
|
||||
|
||||
@postgresql
|
||||
def test_select_any_formdata():
|
||||
drop_formdef_tables()
|
||||
conn, cur = sql.get_connection_and_cursor()
|
||||
|
||||
now = datetime.datetime.now()
|
||||
|
||||
cnt = 0
|
||||
for i in range(5):
|
||||
formdef = FormDef()
|
||||
formdef.name = 'test any %d' % i
|
||||
formdef.fields = []
|
||||
formdef.store()
|
||||
|
||||
data_class = formdef.data_class(mode='sql')
|
||||
for j in range(20):
|
||||
formdata = data_class()
|
||||
formdata.just_created()
|
||||
formdata.user_id = '%s' % ((i+j)%11)
|
||||
# set receipt_time to make sure all entries are unique.
|
||||
formdata.receipt_time = (now + datetime.timedelta(seconds=cnt)).timetuple()
|
||||
formdata.status = ['wf-new', 'wf-accepted', 'wf-rejected', 'wf-finished'][(i+j)%4]
|
||||
formdata.store()
|
||||
cnt += 1
|
||||
|
||||
# test generic select
|
||||
objects = sql.AnyFormData.select()
|
||||
assert len(objects) == 100
|
||||
|
||||
# make sure valid formdefs are used
|
||||
assert len([x for x in objects if x.formdef.name == 'test any 0']) == 20
|
||||
assert len([x for x in objects if x.formdef.name == 'test any 1']) == 20
|
||||
|
||||
# test sorting
|
||||
objects = sql.AnyFormData.select(order_by='receipt_time')
|
||||
assert len(objects) == 100
|
||||
|
||||
objects2 = sql.AnyFormData.select(order_by='-receipt_time')
|
||||
assert [(x.formdef_id, x.id) for x in objects2] == list(reversed(
|
||||
[(x.formdef_id, x.id) for x in objects]))
|
||||
|
||||
# test clauses
|
||||
objects2 = sql.AnyFormData.select([st.Equal('user_id', '0')])
|
||||
assert len(objects2) == len([x for x in objects if x.user_id == '0'])
|
||||
|
||||
objects2 = sql.AnyFormData.select([st.Equal('is_at_endpoint', True)])
|
||||
assert len(objects2) == len([x for x in objects if x.status in ('wf-rejected', 'wf-finished')])
|
||||
|
||||
# test offset/limit
|
||||
objects2 = sql.AnyFormData.select(order_by='receipt_time', limit=10, offset=0)
|
||||
assert [(x.formdef_id, x.id) for x in objects2] == [(x.formdef_id, x.id) for x in objects][:10]
|
||||
|
||||
objects2 = sql.AnyFormData.select(order_by='receipt_time', limit=10, offset=20)
|
||||
assert [(x.formdef_id, x.id) for x in objects2] == [(x.formdef_id, x.id) for x in objects][20:30]
|
||||
|
|
44
wcs/sql.py
44
wcs/sql.py
|
@ -1469,6 +1469,50 @@ class TrackingCode(SqlMixin, wcs.tracking_code.TrackingCode):
|
|||
get_data_fields = classmethod(get_data_fields)
|
||||
|
||||
|
||||
class classproperty(object):
|
||||
def __init__(self, f):
|
||||
self.f = f
|
||||
|
||||
def __get__(self, obj, owner):
|
||||
return self.f(owner)
|
||||
|
||||
|
||||
class AnyFormData(SqlMixin):
|
||||
_table_name = 'wcs_all_forms'
|
||||
__table_static_fields = []
|
||||
_formdef_cache = {}
|
||||
|
||||
@classproperty
|
||||
def _table_static_fields(cls):
|
||||
if cls.__table_static_fields:
|
||||
return cls.__table_static_fields
|
||||
from wcs.formdef import FormDef
|
||||
fake_formdef = FormDef()
|
||||
common_fields = get_view_fields(fake_formdef)
|
||||
cls.__table_static_fields = [(x[1], x[0]) for x in common_fields]
|
||||
return cls.__table_static_fields
|
||||
|
||||
@classmethod
|
||||
def get_data_fields(cls):
|
||||
return []
|
||||
|
||||
@classmethod
|
||||
def get_objects(cls, *args, **kwargs):
|
||||
cls._formdef_cache = {}
|
||||
return super(AnyFormData, cls).get_objects(*args, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def _row2ob(cls, row):
|
||||
formdef_id = row[1]
|
||||
from wcs.formdef import FormDef
|
||||
formdef = cls._formdef_cache.setdefault(formdef_id, FormDef.get(formdef_id))
|
||||
o = formdef.data_class()()
|
||||
for static_field, value in zip(cls._table_static_fields,
|
||||
tuple(row[:len(cls._table_static_fields)])):
|
||||
setattr(o, static_field[0], value)
|
||||
return o
|
||||
|
||||
|
||||
def get_period_query(period_start=None, period_end=None, criterias=None, parameters=None):
|
||||
clause = [NotNull('receipt_time')]
|
||||
table_name = 'wcs_all_forms'
|
||||
|
|
Loading…
Reference in New Issue