sql: add support for basic decimal queries (#77911) #602

Closed
fpeters wants to merge 1 commit from wip/77911-decimal-queries into main
4 changed files with 140 additions and 26 deletions

View File

@ -1283,6 +1283,7 @@ def test_api_list_formdata_string_filter(pub, local_user):
formdef.fields = [
fields.StringField(id=formdef.get_new_field_id(), label='String', varname='string'),
fields.StringField(id='1', label='String2', varname='string2'),
fields.StringField(id='2', label='String3', varname='string3'),
]
formdef.store()
@ -1294,12 +1295,14 @@ def test_api_list_formdata_string_filter(pub, local_user):
formdata.data = {
formdef.fields[0].id: 'FOO %s' % i,
'1': '%s' % (9 + i),
'2': '%.2f' % (3.2 + 0.8 * i),
}
if i == 3:
# Empty values
formdata.data = {
formdef.fields[0].id: '',
'1': '',
'2': '',
}
if i == 4:
# None values
@ -1375,6 +1378,40 @@ def test_api_list_formdata_string_filter(pub, local_user):
)
assert len(resp.json) == result
# decimal numbers
params = [
('eq', '4', 1),
('ne', '4', 3),
('lt', '4', 1),
('lte', '4', 2),
('lt', '4.1', 2),
('lte', '4.1', 2),
('gt', '4', 1),
('gt', '3.9', 2),
('gte', '4', 2),
('in', '4', 1),
('in', '3.2|4', 2),
('in', '4|42', 1),
('in', '4|a', 0), # nothing, as all items are not numeric
('in', '4.00|a', 1), # 1, as matching is thus done on strings
('not_in', '4', 2),
('not_in', '3.2|4', 1),
('not_in', '3.2|42', 2),
('absent', 'on', 2),
('existing', 'on', 3),
('between', '3.1|4.5', 2),
('between', '3.3|4.5', 1),
('between', '4.5|3.1', 2),
]
for operator, value, result in params:
resp = get_app(pub).get(
sign_uri(
'/api/forms/test/list?filter-string3=%s&filter-string3-operator=%s' % (value, operator),
user=local_user,
)
)
assert len(resp.json) == result
def test_api_list_formdata_text_filter(pub, local_user):
pub.role_class.wipe()

View File

@ -173,7 +173,7 @@ def get_name_as_sql_identifier(name):
return name
def parse_clause(clause):
def parse_clause(clause, cur=None):
# returns a three-elements tuple with:
# - a list of SQL 'WHERE' clauses
# - a dict for query parameters
@ -189,6 +189,7 @@ def parse_clause(clause):
func_clauses = []
where_clauses = []
parameters = {}
requires_set_lc_numeric = False
for i, element in enumerate(clause):
if callable(element):
func_clauses.append(element)
@ -206,9 +207,13 @@ def parse_clause(clause):
clause[i] = sql_element
where_clauses.append(sql_element.as_sql())
parameters.update(sql_element.as_sql_param())
requires_set_lc_numeric |= sql_element.requires_set_lc_numeric
else:
func_clauses.append(element.build_lambda())
if requires_set_lc_numeric and cur:
cur.execute("SET lc_numeric = 'C'")
Review

Il y a un peu de complexité pour poser ça uniquement quand on en a besoin; mais peut-être que faire le SET de manière systématique aurait un coût négligeable, @pducroquet ?

Il y a un peu de complexité pour poser ça uniquement quand on en a besoin; mais peut-être que faire le SET de manière systématique aurait un coût négligeable, @pducroquet ?
if func_clauses:
return (where_clauses, parameters, parse_storage_clause(func_clauses))
else:
@ -1640,7 +1645,7 @@ class SqlMixin:
@guard_postgres
def keys(cls, clause=None):
conn, cur = get_connection_and_cursor()
where_clauses, parameters, func_clause = parse_clause(clause)
where_clauses, parameters, func_clause = parse_clause(clause, cur=cur)
assert not func_clause
sql_statement = 'SELECT id FROM %s' % cls._table_name
if where_clauses:
@ -1654,14 +1659,16 @@ class SqlMixin:
@classmethod
@guard_postgres
def count(cls, clause=None):
where_clauses, parameters, func_clause = parse_clause(clause)
conn, cur = get_connection_and_cursor()
where_clauses, parameters, func_clause = parse_clause(clause, cur=cur)
if func_clause:
# fallback to counting the result of a select()
conn.commit()
cur.close()
return len(cls.select(clause))
sql_statement = 'SELECT count(*) FROM %s' % cls._table_name
if where_clauses:
sql_statement += ' WHERE ' + ' AND '.join(where_clauses)
conn, cur = get_connection_and_cursor()
cur.execute(sql_statement, parameters)
count = cur.fetchone()[0]
conn.commit()
@ -1671,15 +1678,17 @@ class SqlMixin:
@classmethod
@guard_postgres
def exists(cls, clause=None):
where_clauses, parameters, func_clause = parse_clause(clause)
conn, cur = get_connection_and_cursor()
where_clauses, parameters, func_clause = parse_clause(clause, cur=cur)
if func_clause:
# fallback to counting the result of a select()
conn.commit()
cur.close()
return len(cls.select(clause))
sql_statement = 'SELECT 1 FROM %s' % cls._table_name
if where_clauses:
sql_statement += ' WHERE ' + ' AND '.join(where_clauses)
sql_statement += ' LIMIT 1'
conn, cur = get_connection_and_cursor()
try:
cur.execute(sql_statement, parameters)
except UndefinedTable:
@ -1952,7 +1961,9 @@ class SqlMixin:
', '.join(table_static_fields + cls.get_data_fields()),
cls._table_name,
)
where_clauses, parameters, func_clause = parse_clause(clause)
conn, cur = get_connection_and_cursor()
where_clauses, parameters, func_clause = parse_clause(clause, cur=cur)
if where_clauses:
sql_statement += ' WHERE ' + ' AND '.join(where_clauses)
@ -1966,7 +1977,6 @@ class SqlMixin:
sql_statement += ' OFFSET %(offset)s'
parameters['offset'] = offset
conn, cur = get_connection_and_cursor()
with cur:
cur.execute(sql_statement, parameters)
conn.commit()
@ -2026,7 +2036,7 @@ class SqlMixin:
', '.join([column0] + columns[1:]),
cls._table_name,
)
where_clauses, parameters, func_clause = parse_clause(clause)
where_clauses, parameters, func_clause = parse_clause(clause, cur=cur)
assert not func_clause
if where_clauses:
sql_statement += ' WHERE ' + ' AND '.join(where_clauses)
@ -2178,7 +2188,7 @@ class SqlMixin:
sql_statement = '''DELETE FROM %s''' % cls._table_name
parameters = {}
if clause:
where_clauses, parameters, dummy = parse_clause(clause)
where_clauses, parameters, dummy = parse_clause(clause, cur=cur)
if where_clauses:
sql_statement += ' WHERE ' + ' AND '.join(where_clauses)
cur.execute(sql_statement, parameters)
@ -2190,7 +2200,7 @@ class SqlMixin:
def get_sorted_ids(cls, order_by, clause=None):
conn, cur = get_connection_and_cursor()
sql_statement = 'SELECT id FROM %s' % cls._table_name
where_clauses, parameters, func_clause = parse_clause(clause)
where_clauses, parameters, func_clause = parse_clause(clause, cur=cur)
assert not func_clause
if where_clauses:
sql_statement += ' WHERE ' + ' AND '.join(where_clauses)
@ -2439,7 +2449,7 @@ class SqlDataMixin(SqlMixin):
if not where:
where = []
where.append(Equal('id', self.id))
where_clauses, parameters, dummy = parse_clause(where)
where_clauses, parameters, dummy = parse_clause(where, cur=cur)
column_names = list(sql_dict.keys())
sql_dict.update(parameters)
sql_statement = '''UPDATE %s SET %s WHERE %s RETURNING id''' % (
@ -2681,7 +2691,7 @@ class SqlDataMixin(SqlMixin):
def get_ids_with_indexed_value(cls, index, value, auto_fallback=True, clause=None):
cur = get_connection_and_cursor()[1]
where_clauses, parameters, func_clause = parse_clause(clause)
where_clauses, parameters, func_clause = parse_clause(clause, cur=cur)
assert not func_clause
if isinstance(value, int):
@ -3669,7 +3679,7 @@ class Snapshot(SqlMixin, wcs.snapshots.Snapshot):
clause = [Contains('object_type', object_types)]
if user is not None:
clause.append(Equal('user_id', str(user.id)))
where_clauses, parameters, dummy = parse_clause(clause)
where_clauses, parameters, dummy = parse_clause(clause, cur=cur)
sql_statement = 'SELECT object_type, object_id, MAX(timestamp) AS m FROM snapshots'
sql_statement += ' WHERE ' + ' AND '.join(where_clauses)
@ -3693,7 +3703,7 @@ class Snapshot(SqlMixin, wcs.snapshots.Snapshot):
conn, cur = get_connection_and_cursor()
clause = [Contains('object_type', object_types)]
where_clauses, parameters, dummy = parse_clause(clause)
where_clauses, parameters, dummy = parse_clause(clause, cur=cur)
sql_statement = 'SELECT COUNT(*) FROM (SELECT object_type, object_id FROM snapshots'
sql_statement += ' WHERE ' + ' AND '.join(where_clauses)
sql_statement += ' GROUP BY object_type, object_id) AS s'
@ -4790,7 +4800,7 @@ def get_actionable_counts(user_roles):
Intersects('actions_roles_array', user_roles),
Null('anonymised'),
]
where_clauses, parameters, dummy = parse_clause(criterias)
where_clauses, parameters, dummy = parse_clause(criterias, cur=cur)
statement = '''SELECT formdef_id, COUNT(*)
FROM wcs_all_forms
WHERE %s
@ -4811,7 +4821,7 @@ def get_total_counts(user_roles):
Intersects('concerned_roles_array', user_roles),
Null('anonymised'),
]
where_clauses, parameters, dummy = parse_clause(criterias)
where_clauses, parameters, dummy = parse_clause(criterias, cur=cur)
statement = '''SELECT formdef_id, COUNT(*)
FROM wcs_all_forms
WHERE %s

View File

@ -15,6 +15,7 @@
# along with this program; if not, see <http://www.gnu.org/licenses/>.
import datetime
import decimal
import re
import time
@ -36,6 +37,8 @@ def get_field_id(field):
class Criteria(wcs.qommon.storage.Criteria):
requires_set_lc_numeric = False
def __init__(self, attribute, value, **kwargs):
self.attribute = attribute
if '->' not in attribute:
@ -52,12 +55,23 @@ class Criteria(wcs.qommon.storage.Criteria):
return value
def as_sql(self):
def is_numeric_type(x):
    """Tell whether *x* is an int, float or Decimal, but not a bool."""
    if isinstance(x, bool):
        # bool is a subclass of int; it must not be treated as a number here
        return False
    return isinstance(x, (int, float, decimal.Decimal))
value_is_list_of_int = (
isinstance(self.value, list)
and self.value
and isinstance(self.value[0], int) # all elements are of the same type
)
value_is_list_of_numeric = (
isinstance(self.value, list)
and self.value
and is_numeric_type(self.value[0]) # all elements are of the same type
)
value_is_int = isinstance(self.value, int) or value_is_list_of_int
value_is_numeric = is_numeric_type(self.value) or value_is_list_of_numeric
if value_is_numeric:
self.requires_set_lc_numeric = True
if self.field and getattr(self.field, 'block_field', None):
# eq: EXISTS (SELECT 1 FROM jsonb_array_elements(BLOCK->'data') AS datas(aa) WHERE aa->>'FOOBAR' = 'value')
@ -72,12 +86,21 @@ class Criteria(wcs.qommon.storage.Criteria):
# ne: NOT EXISTS (SELECT 1 FROM jsonb_array_elements(BLOCK->'data') AS datas(aa) WHERE aa->>'FOOBAR' = 'value')
# note: aa->>'FOOBAR' can be written with an integer or bool cast
attribute = "aa->>'%s'" % self.field.id
if self.field.key in ['item', 'string'] and value_is_int:
if self.field.key == 'string' and value_is_numeric:
# decimal cast of db values
attribute = (
"(CASE WHEN %s~E'^\\\\d{1,9}$' THEN (%s)::int "
" WHEN %s~E'^\\\\d{0,9}[\\\\.,]\\\\d{1,9}$' "
" THEN to_number(replace(%s, ',', '.'), '999999999D999999999')::numeric "
' ELSE NULL END)' % (attribute, attribute, attribute, attribute)
)
elif self.field.key in ['item', 'string'] and value_is_int:
# integer cast of db values
attribute = "(CASE WHEN %s~E'^\\\\d{1,9}$' THEN (%s)::int ELSE NULL END)" % (
attribute,
attribute,
)
elif self.field.key == 'bool':
# bool cast of db values
attribute = '(%s)::bool' % attribute
@ -140,6 +163,14 @@ class Criteria(wcs.qommon.storage.Criteria):
if self.field:
if self.field.key == 'computed':
attribute = "%s->>'data'" % self.attribute
elif self.field.key == 'string' and value_is_numeric:
# decimal cast of db values
attribute = (
"(CASE WHEN %s~E'^\\\\d{1,9}$' THEN (%s)::int "
" WHEN %s~E'^\\\\d{0,9}[\\\\.,]\\\\d{1,9}$' "
" THEN to_number(replace(%s, ',', '.'), '999999999D999999999')::numeric "
' ELSE NULL END)' % (attribute, attribute, attribute, attribute)
)
elif self.field.key in ['item', 'string'] and value_is_int:
# integer cast of db values
attribute = "(CASE WHEN %s~E'^\\\\d{1,9}$' THEN %s::int ELSE NULL END)" % (
@ -293,7 +324,9 @@ class Or(Criteria):
def as_sql(self):
if not self.criterias:
return '( FALSE )'
return '( %s )' % ' OR '.join([x.as_sql() for x in self.criterias])
expression = '( %s )' % ' OR '.join([x.as_sql() for x in self.criterias])
self.requires_set_lc_numeric = any(x.requires_set_lc_numeric for x in self.criterias)
return expression
def as_sql_param(self):
d = {}
@ -307,7 +340,9 @@ class Or(Criteria):
class And(Or):
def as_sql(self):
return '( %s )' % ' AND '.join([x.as_sql() for x in self.criterias])
expression = '( %s )' % ' AND '.join([x.as_sql() for x in self.criterias])
self.requires_set_lc_numeric = any(x.requires_set_lc_numeric for x in self.criterias)
return expression
class Not(Criteria):
@ -317,7 +352,9 @@ class Not(Criteria):
self.criteria = sql_element
def as_sql(self):
return 'NOT ( %s )' % self.criteria.as_sql()
expression = 'NOT ( %s )' % self.criteria.as_sql()
self.requires_set_lc_numeric = self.criteria.requires_set_lc_numeric
return expression
def as_sql_param(self):
return self.criteria.as_sql_param()

View File

@ -14,6 +14,7 @@
# You should have received a copy of the GNU General Public License
# along with this program; if not, see <http://www.gnu.org/licenses/>.
import decimal
import warnings
from django.utils import formats
@ -325,6 +326,29 @@ class LazyFormDefObjectsManager:
return str(val)
return val
def check_numeric(val):
    """Return *val* as a ``decimal.Decimal`` when it is a plain number.

    The original value is returned unchanged when it is not considered
    numeric, so callers can test the result with
    ``isinstance(result, decimal.Decimal)``.

    A value is rejected (returned as is) when:

    - it is a ``bool`` (``bool`` is a subclass of ``int`` but is not a number
      for filtering purposes);
    - it is a string containing ``_`` (Python accepts it as a digit group
      separator, users do not mean that);
    - it is a string starting with ``0`` not followed by the decimal
      separator (phone numbers, codes with leading zeroes, ...);
    - it cannot be parsed by ``decimal.Decimal``;
    - it is not finite or falls outside the signed 32 bits range.

    In strings, ``,`` is accepted as decimal separator in addition to ``.``.
    """
    if isinstance(val, bool):
        # without this guard True/False would become Decimal(1)/Decimal(0)
        return val

    candidate = val
    if isinstance(val, str):
        if '_' in val:
            # do not consider _ a valid character in numbers
            # (unlike python where it can be used to group digits)
            return val
        candidate = val.replace(',', '.')  # replace , by . for French users comfort
        if candidate.startswith('0') and not candidate.startswith('0.'):
            # do not consider numbers starting with 0 as numerics (unless the next
            # character is the decimal separator), to avoid strings like phone numbers
            # to be considered as numbers
            return val
    elif isinstance(val, float):
        # Decimal(3.2) would give the exact binary expansion
        # (3.20000000000000017763...), go through repr() to get Decimal('3.2')
        candidate = repr(val)

    try:
        # cast to decimal so it can be used with numerical operators
        numeric_value = decimal.Decimal(candidate)
    except (ValueError, TypeError, decimal.InvalidOperation):
        return val

    # NaN and infinities are not usable in comparisons; and the range is
    # limited to 32 bits to match postgresql integer range
    if numeric_value.is_finite() and -(2**31) <= numeric_value < 2**31:
        return numeric_value
    return val
def convert_value(value, field):
if field.convert_value_from_anything and value is not Ellipsis:
try:
@ -378,12 +402,18 @@ class LazyFormDefObjectsManager:
value = value[0]
if field.key in ['string', 'item', 'items']:
if isinstance(value, list):
value = [check_int(v) for v in value]
# make sure all elements are of the same type
if not all(isinstance(v, int) for v in value):
value = [str(v) for v in value]
else:
numeric_values = [check_numeric(v) for v in value]
int_values = [check_int(v) for v in value]
if field.key == 'string' and all(isinstance(v, decimal.Decimal) for v in numeric_values):
value = numeric_values
elif all(isinstance(v, int) for v in int_values):
value = int_values
elif field.key in ['item', 'items']:
value = check_int(value)
else: # string
value = check_numeric(value)
if not isinstance(value, decimal.Decimal): # fallback to maybe integer
value = check_int(value)
return value