sql: prefetch evolutions in user forms API (#38903)

This commit is contained in:
Frédéric Péters 2020-01-12 13:47:48 +01:00
parent bbaa2d18be
commit cbebd7558f
3 changed files with 49 additions and 0 deletions

View File

@ -1350,6 +1350,38 @@ def test_select_any_formdata():
objects2 = sql.AnyFormData.select(order_by='receipt_time', limit=10, offset=20)
assert [(x.formdef_id, x.id) for x in objects2] == [(x.formdef_id, x.id) for x in objects][20:30]
@postgresql
def test_load_all_evolutions_on_any_formdata():
drop_formdef_tables()
conn, cur = sql.get_connection_and_cursor()
now = datetime.datetime.now()
cnt = 0
for i in range(5):
formdef = FormDef()
formdef.name = 'test any %d' % i
formdef.fields = []
formdef.store()
data_class = formdef.data_class(mode='sql')
for j in range(20):
formdata = data_class()
formdata.just_created()
formdata.user_id = '%s' % ((i+j)%11)
# set receipt_time to make sure all entries are unique.
formdata.receipt_time = (now + datetime.timedelta(seconds=cnt)).timetuple()
formdata.status = ['wf-new', 'wf-accepted', 'wf-rejected', 'wf-finished'][(i+j)%4]
formdata.store()
cnt += 1
objects = sql.AnyFormData.select()
assert len(objects) == 100
assert len([x for x in objects if x._evolution is None]) == 100
sql.AnyFormData.load_all_evolutions(objects)
assert len([x for x in objects if x._evolution is not None]) == 100
@postgresql
def test_geoloc_in_global_view():
drop_formdef_tables()

View File

@ -689,6 +689,12 @@ class ApiUserDirectory(Directory):
# ignore confidential forms
forms = [x for x in forms if x.readable or not x.formdef.skip_from_360_view]
if get_publisher().is_using_postgresql() and not get_request().form.get('full') == 'on':
# prefetch evolutions to avoid individual loads when computing
# formdata.get_visible_status().
from wcs import sql
sql.AnyFormData.load_all_evolutions(forms)
include_drafts = include_drafts or get_query_flag('include-drafts')
result = []
for form in forms:

View File

@ -2131,6 +2131,17 @@ class AnyFormData(SqlMixin):
o.geolocations = {'base': {'lon': o.geoloc_base_x, 'lat': o.geoloc_base_y}}
return o
@classmethod
@guard_postgres
def load_all_evolutions(cls, formdatas):
classes = {}
for formdata in formdatas:
if not formdata._table_name in classes:
classes[formdata._table_name] = []
classes[formdata._table_name].append(formdata)
for formdatas in classes.values():
formdatas[0].load_all_evolutions(formdatas)
def get_period_query(period_start=None, period_end=None, criterias=None, parameters=None):
clause = [NotNull('receipt_time')]