sql: use batch iteration on ids instead of named cursors (#58013)

Named cursors imposed the use of isolated connections and were misused
resulting in reading using one SQL query by row (because of the use of
.fetchone() with cursors). This commit revert to the behaviour of one
connection per request and reading full SQL statement results at a time
without using cursors.
This commit is contained in:
Benjamin Dauvergne 2021-10-20 10:49:44 +02:00 committed by Frédéric Péters
parent 44a1e2f0e6
commit 5389956f22
3 changed files with 70 additions and 36 deletions

View File

@ -431,6 +431,7 @@ class StorableObject:
limit=None,
offset=None,
iterator=False,
itersize=None,
**kwargs,
):
# iterator: only for compatibility with sql select()

View File

@ -20,7 +20,6 @@ import io
import json
import re
import time
import uuid
import psycopg2
import psycopg2.extensions
@ -365,11 +364,11 @@ def site_unicode(value):
return force_text(value, get_publisher().site_charset)
def get_connection(new=False, isolate=False):
if new and not isolate:
def get_connection(new=False):
if new:
cleanup_connection()
if isolate or not getattr(get_publisher(), 'pgconn', None):
if not getattr(get_publisher(), 'pgconn', None):
postgresql_cfg = {}
for param in ('database', 'user', 'password', 'host', 'port'):
value = get_cfg('postgresql', {}).get(param)
@ -380,11 +379,9 @@ def get_connection(new=False, isolate=False):
try:
pgconn = psycopg2.connect(**postgresql_cfg)
except psycopg2.Error:
if new or isolate:
if new:
raise
pgconn = None
if isolate:
return pgconn
get_publisher().pgconn = pgconn
@ -1433,7 +1430,7 @@ class SqlMixin:
_table_name = None
_numerical_id = True
_table_select_skipped_fields = []
_iterate_on_server = True
_has_id = True
@classmethod
@guard_postgres
@ -1653,16 +1650,38 @@ class SqlMixin:
@classmethod
@guard_postgres
def select_iterator(cls, clause=None, order_by=None, ignore_errors=False, limit=None, offset=None):
def select_iterator(
cls,
clause=None,
order_by=None,
ignore_errors=False,
limit=None,
offset=None,
itersize=None,
):
table_static_fields = [
x[0] if x[0] not in cls._table_select_skipped_fields else 'NULL AS %s' % x[0]
for x in cls._table_static_fields
]
sql_statement = '''SELECT %s
FROM %s''' % (
', '.join(table_static_fields + cls.get_data_fields()),
cls._table_name,
)
def retrieve():
for object in cls.get_objects(cur, iterator=True):
if object is None:
continue
if func_clause and not func_clause(object):
continue
yield object
if itersize and cls._has_id:
# this case concerns almost all data tables: formdata, card, users, roles
sql_statement = '''SELECT id FROM %s''' % cls._table_name
else:
# this case concerns aggregated views like wcs_all_forms (class
# AnyFormData) which does not have a surrogate key id column
sql_statement = '''SELECT %s FROM %s''' % (
', '.join(table_static_fields + cls.get_data_fields()),
cls._table_name,
)
where_clauses, parameters, func_clause = parse_clause(clause)
if where_clauses:
sql_statement += ' WHERE ' + ' AND '.join(where_clauses)
@ -1677,31 +1696,45 @@ class SqlMixin:
sql_statement += ' OFFSET %(offset)s'
parameters['offset'] = offset
if cls._iterate_on_server:
conn = get_connection(isolate=True)
cur = conn.cursor(name='select_iterator_%s' % uuid.uuid4())
else:
conn, cur = get_connection_and_cursor()
cur.execute(sql_statement, parameters)
try:
for object in cls.get_objects(cur, iterator=True):
if object is None:
continue
if func_clause and not func_clause(object):
continue
yield object
finally:
cur.close()
conn, cur = get_connection_and_cursor()
with cur:
cur.execute(sql_statement, parameters)
conn.commit()
if cls._iterate_on_server:
# close isolated connection
conn.close()
if itersize and cls._has_id:
sql_id_statement = '''SELECT %s FROM %s WHERE id IN %%s''' % (
', '.join(table_static_fields + cls.get_data_fields()),
cls._table_name,
)
sql_id_statement += cls.get_order_by_clause(order_by)
ids = [row[0] for row in cur]
while ids:
cur.execute(sql_id_statement, [tuple(ids[:itersize])])
conn.commit()
yield from retrieve()
ids = ids[itersize:]
else:
yield from retrieve()
@classmethod
@guard_postgres
def select(cls, clause=None, order_by=None, ignore_errors=False, limit=None, offset=None, iterator=False):
def select(
cls,
clause=None,
order_by=None,
ignore_errors=False,
limit=None,
offset=None,
iterator=False,
itersize=None,
):
if iterator and not itersize:
itersize = 200
objects = cls.select_iterator(
clause=clause, order_by=order_by, ignore_errors=ignore_errors, limit=limit, offset=offset
clause=clause,
order_by=order_by,
ignore_errors=ignore_errors,
limit=limit,
offset=offset,
)
func_clause = parse_clause(clause)[2]
if func_clause and (limit or offset):
@ -3149,7 +3182,7 @@ class classproperty:
class AnyFormData(SqlMixin):
_table_name = 'wcs_all_forms'
_formdef_cache = {}
_iterate_on_server = False
_has_id = False
@classproperty
def _table_static_fields(self):

View File

@ -350,7 +350,7 @@ def _apply_timeouts(publisher, **kwargs):
(datetime.datetime.now() - datetime.timedelta(seconds=delay)).timetuple(),
),
]
formdatas = formdata_class.select_iterator(criterias, ignore_errors=True)
formdatas = formdata_class.select_iterator(criterias, ignore_errors=True, itersize=200)
else:
formdatas = formdata_class.get_with_indexed_value('status', status_id, ignore_errors=True)