sql: use a server cursor to iterate over rows (#54242)

This commit is contained in:
Frédéric Péters 2021-05-24 22:04:32 +02:00
parent 374a282488
commit 471c4a32dc
1 changed files with 21 additions and 8 deletions

View File

@ -21,6 +21,7 @@ import json
import re
import time
import unicodedata
import uuid
import psycopg2
import psycopg2.extensions
@ -332,21 +333,27 @@ def site_unicode(value):
return force_text(value, get_publisher().site_charset)
def get_connection(new=False):
if new:
def get_connection(new=False, isolate=False):
if new and not isolate:
cleanup_connection()
if not hasattr(get_publisher(), 'pgconn') or get_publisher().pgconn is None:
if isolate or not getattr(get_publisher(), 'pgconn', None):
postgresql_cfg = {}
for param in ('database', 'user', 'password', 'host', 'port'):
value = get_cfg('postgresql', {}).get(param)
if value:
postgresql_cfg[param] = value
try:
get_publisher().pgconn = psycopg2.connect(**postgresql_cfg)
pgconn = psycopg2.connect(**postgresql_cfg)
except psycopg2.Error:
if new:
raise
get_publisher().pgconn = None
pgconn = None
if isolate:
return pgconn
get_publisher().pgconn = pgconn
return get_publisher().pgconn
@ -1386,6 +1393,7 @@ class SqlMixin:
_table_name = None
_numerical_id = True
_table_select_skipped_fields = []
_iterate_on_server = True
@classmethod
@guard_postgres
@ -1629,7 +1637,11 @@ class SqlMixin:
sql_statement += ' OFFSET %(offset)s'
parameters['offset'] = offset
conn, cur = get_connection_and_cursor()
if cls._iterate_on_server:
conn = get_connection(isolate=True)
cur = conn.cursor(name='select_iterator_%s' % uuid.uuid4())
else:
conn, cur = get_connection_and_cursor()
cur.execute(sql_statement, parameters)
try:
for object in cls.get_objects(cur, iterator=True):
@ -1639,8 +1651,8 @@ class SqlMixin:
continue
yield object
finally:
conn.commit()
cur.close()
conn.commit()
@classmethod
@guard_postgres
@ -2268,7 +2280,7 @@ class SqlDataMixin(SqlMixin):
@classmethod
def rebuild_security(cls):
formdatas = cls.select(order_by='id')
formdatas = cls.select(order_by='id', iterator=True)
conn, cur = get_connection_and_cursor()
for formdata in formdatas:
sql_statement = (
@ -3074,6 +3086,7 @@ class AnyFormData(SqlMixin):
_table_name = 'wcs_all_forms'
__table_static_fields = []
_formdef_cache = {}
_iterate_on_server = False
@classproperty
def _table_static_fields(self):