sql: add support for basic decimal queries (#77911) #602
|
@ -1283,6 +1283,7 @@ def test_api_list_formdata_string_filter(pub, local_user):
|
|||
formdef.fields = [
|
||||
fields.StringField(id=formdef.get_new_field_id(), label='String', varname='string'),
|
||||
fields.StringField(id='1', label='String2', varname='string2'),
|
||||
fields.StringField(id='2', label='String3', varname='string3'),
|
||||
]
|
||||
formdef.store()
|
||||
|
||||
|
@ -1294,12 +1295,14 @@ def test_api_list_formdata_string_filter(pub, local_user):
|
|||
formdata.data = {
|
||||
formdef.fields[0].id: 'FOO %s' % i,
|
||||
'1': '%s' % (9 + i),
|
||||
'2': '%.2f' % (3.2 + 0.8 * i),
|
||||
}
|
||||
if i == 3:
|
||||
# Empty values
|
||||
formdata.data = {
|
||||
formdef.fields[0].id: '',
|
||||
'1': '',
|
||||
'2': '',
|
||||
}
|
||||
if i == 4:
|
||||
# None values
|
||||
|
@ -1375,6 +1378,40 @@ def test_api_list_formdata_string_filter(pub, local_user):
|
|||
)
|
||||
assert len(resp.json) == result
|
||||
|
||||
# decimal numbers
|
||||
params = [
|
||||
('eq', '4', 1),
|
||||
('ne', '4', 3),
|
||||
('lt', '4', 1),
|
||||
('lte', '4', 2),
|
||||
('lt', '4.1', 2),
|
||||
('lte', '4.1', 2),
|
||||
('gt', '4', 1),
|
||||
('gt', '3.9', 2),
|
||||
('gte', '4', 2),
|
||||
('in', '4', 1),
|
||||
('in', '3.2|4', 2),
|
||||
('in', '4|42', 1),
|
||||
('in', '4|a', 0), # nothing, as all items are not numeric
|
||||
('in', '4.00|a', 1), # 1, as matching is thus done on strings
|
||||
('not_in', '4', 2),
|
||||
('not_in', '3.2|4', 1),
|
||||
('not_in', '3.2|42', 2),
|
||||
('absent', 'on', 2),
|
||||
('existing', 'on', 3),
|
||||
('between', '3.1|4.5', 2),
|
||||
('between', '3.3|4.5', 1),
|
||||
('between', '4.5|3.1', 2),
|
||||
]
|
||||
for operator, value, result in params:
|
||||
resp = get_app(pub).get(
|
||||
sign_uri(
|
||||
'/api/forms/test/list?filter-string3=%s&filter-string3-operator=%s' % (value, operator),
|
||||
user=local_user,
|
||||
)
|
||||
)
|
||||
assert len(resp.json) == result
|
||||
|
||||
|
||||
def test_api_list_formdata_text_filter(pub, local_user):
|
||||
pub.role_class.wipe()
|
||||
|
|
44
wcs/sql.py
44
wcs/sql.py
|
@ -173,7 +173,7 @@ def get_name_as_sql_identifier(name):
|
|||
return name
|
||||
|
||||
|
||||
def parse_clause(clause):
|
||||
def parse_clause(clause, cur=None):
|
||||
# returns a three-elements tuple with:
|
||||
# - a list of SQL 'WHERE' clauses
|
||||
# - a dict for query parameters
|
||||
|
@ -189,6 +189,7 @@ def parse_clause(clause):
|
|||
func_clauses = []
|
||||
where_clauses = []
|
||||
parameters = {}
|
||||
requires_set_lc_numeric = False
|
||||
for i, element in enumerate(clause):
|
||||
if callable(element):
|
||||
func_clauses.append(element)
|
||||
|
@ -206,9 +207,13 @@ def parse_clause(clause):
|
|||
clause[i] = sql_element
|
||||
where_clauses.append(sql_element.as_sql())
|
||||
parameters.update(sql_element.as_sql_param())
|
||||
requires_set_lc_numeric |= sql_element.requires_set_lc_numeric
|
||||
else:
|
||||
func_clauses.append(element.build_lambda())
|
||||
|
||||
if requires_set_lc_numeric and cur:
|
||||
cur.execute("SET lc_numeric = 'C'")
|
||||
|
||||
|
||||
if func_clauses:
|
||||
return (where_clauses, parameters, parse_storage_clause(func_clauses))
|
||||
else:
|
||||
|
@ -1640,7 +1645,7 @@ class SqlMixin:
|
|||
@guard_postgres
|
||||
def keys(cls, clause=None):
|
||||
conn, cur = get_connection_and_cursor()
|
||||
where_clauses, parameters, func_clause = parse_clause(clause)
|
||||
where_clauses, parameters, func_clause = parse_clause(clause, cur=cur)
|
||||
assert not func_clause
|
||||
sql_statement = 'SELECT id FROM %s' % cls._table_name
|
||||
if where_clauses:
|
||||
|
@ -1654,14 +1659,16 @@ class SqlMixin:
|
|||
@classmethod
|
||||
@guard_postgres
|
||||
def count(cls, clause=None):
|
||||
where_clauses, parameters, func_clause = parse_clause(clause)
|
||||
conn, cur = get_connection_and_cursor()
|
||||
where_clauses, parameters, func_clause = parse_clause(clause, cur=cur)
|
||||
if func_clause:
|
||||
# fallback to counting the result of a select()
|
||||
conn.commit()
|
||||
cur.close()
|
||||
return len(cls.select(clause))
|
||||
sql_statement = 'SELECT count(*) FROM %s' % cls._table_name
|
||||
if where_clauses:
|
||||
sql_statement += ' WHERE ' + ' AND '.join(where_clauses)
|
||||
conn, cur = get_connection_and_cursor()
|
||||
cur.execute(sql_statement, parameters)
|
||||
count = cur.fetchone()[0]
|
||||
conn.commit()
|
||||
|
@ -1671,15 +1678,17 @@ class SqlMixin:
|
|||
@classmethod
|
||||
@guard_postgres
|
||||
def exists(cls, clause=None):
|
||||
where_clauses, parameters, func_clause = parse_clause(clause)
|
||||
conn, cur = get_connection_and_cursor()
|
||||
where_clauses, parameters, func_clause = parse_clause(clause, cur=cur)
|
||||
if func_clause:
|
||||
# fallback to counting the result of a select()
|
||||
conn.commit()
|
||||
cur.close()
|
||||
return len(cls.select(clause))
|
||||
sql_statement = 'SELECT 1 FROM %s' % cls._table_name
|
||||
if where_clauses:
|
||||
sql_statement += ' WHERE ' + ' AND '.join(where_clauses)
|
||||
sql_statement += ' LIMIT 1'
|
||||
conn, cur = get_connection_and_cursor()
|
||||
try:
|
||||
cur.execute(sql_statement, parameters)
|
||||
except UndefinedTable:
|
||||
|
@ -1952,7 +1961,9 @@ class SqlMixin:
|
|||
', '.join(table_static_fields + cls.get_data_fields()),
|
||||
cls._table_name,
|
||||
)
|
||||
where_clauses, parameters, func_clause = parse_clause(clause)
|
||||
|
||||
conn, cur = get_connection_and_cursor()
|
||||
where_clauses, parameters, func_clause = parse_clause(clause, cur=cur)
|
||||
if where_clauses:
|
||||
sql_statement += ' WHERE ' + ' AND '.join(where_clauses)
|
||||
|
||||
|
@ -1966,7 +1977,6 @@ class SqlMixin:
|
|||
sql_statement += ' OFFSET %(offset)s'
|
||||
parameters['offset'] = offset
|
||||
|
||||
conn, cur = get_connection_and_cursor()
|
||||
with cur:
|
||||
cur.execute(sql_statement, parameters)
|
||||
conn.commit()
|
||||
|
@ -2026,7 +2036,7 @@ class SqlMixin:
|
|||
', '.join([column0] + columns[1:]),
|
||||
cls._table_name,
|
||||
)
|
||||
where_clauses, parameters, func_clause = parse_clause(clause)
|
||||
where_clauses, parameters, func_clause = parse_clause(clause, cur=cur)
|
||||
assert not func_clause
|
||||
if where_clauses:
|
||||
sql_statement += ' WHERE ' + ' AND '.join(where_clauses)
|
||||
|
@ -2178,7 +2188,7 @@ class SqlMixin:
|
|||
sql_statement = '''DELETE FROM %s''' % cls._table_name
|
||||
parameters = {}
|
||||
if clause:
|
||||
where_clauses, parameters, dummy = parse_clause(clause)
|
||||
where_clauses, parameters, dummy = parse_clause(clause, cur=cur)
|
||||
if where_clauses:
|
||||
sql_statement += ' WHERE ' + ' AND '.join(where_clauses)
|
||||
cur.execute(sql_statement, parameters)
|
||||
|
@ -2190,7 +2200,7 @@ class SqlMixin:
|
|||
def get_sorted_ids(cls, order_by, clause=None):
|
||||
conn, cur = get_connection_and_cursor()
|
||||
sql_statement = 'SELECT id FROM %s' % cls._table_name
|
||||
where_clauses, parameters, func_clause = parse_clause(clause)
|
||||
where_clauses, parameters, func_clause = parse_clause(clause, cur=cur)
|
||||
assert not func_clause
|
||||
if where_clauses:
|
||||
sql_statement += ' WHERE ' + ' AND '.join(where_clauses)
|
||||
|
@ -2439,7 +2449,7 @@ class SqlDataMixin(SqlMixin):
|
|||
if not where:
|
||||
where = []
|
||||
where.append(Equal('id', self.id))
|
||||
where_clauses, parameters, dummy = parse_clause(where)
|
||||
where_clauses, parameters, dummy = parse_clause(where, cur=cur)
|
||||
column_names = list(sql_dict.keys())
|
||||
sql_dict.update(parameters)
|
||||
sql_statement = '''UPDATE %s SET %s WHERE %s RETURNING id''' % (
|
||||
|
@ -2681,7 +2691,7 @@ class SqlDataMixin(SqlMixin):
|
|||
def get_ids_with_indexed_value(cls, index, value, auto_fallback=True, clause=None):
|
||||
cur = get_connection_and_cursor()[1]
|
||||
|
||||
where_clauses, parameters, func_clause = parse_clause(clause)
|
||||
where_clauses, parameters, func_clause = parse_clause(clause, cur=cur)
|
||||
assert not func_clause
|
||||
|
||||
if isinstance(value, int):
|
||||
|
@ -3669,7 +3679,7 @@ class Snapshot(SqlMixin, wcs.snapshots.Snapshot):
|
|||
clause = [Contains('object_type', object_types)]
|
||||
if user is not None:
|
||||
clause.append(Equal('user_id', str(user.id)))
|
||||
where_clauses, parameters, dummy = parse_clause(clause)
|
||||
where_clauses, parameters, dummy = parse_clause(clause, cur=cur)
|
||||
|
||||
sql_statement = 'SELECT object_type, object_id, MAX(timestamp) AS m FROM snapshots'
|
||||
sql_statement += ' WHERE ' + ' AND '.join(where_clauses)
|
||||
|
@ -3693,7 +3703,7 @@ class Snapshot(SqlMixin, wcs.snapshots.Snapshot):
|
|||
conn, cur = get_connection_and_cursor()
|
||||
|
||||
clause = [Contains('object_type', object_types)]
|
||||
where_clauses, parameters, dummy = parse_clause(clause)
|
||||
where_clauses, parameters, dummy = parse_clause(clause, cur=cur)
|
||||
sql_statement = 'SELECT COUNT(*) FROM (SELECT object_type, object_id FROM snapshots'
|
||||
sql_statement += ' WHERE ' + ' AND '.join(where_clauses)
|
||||
sql_statement += ' GROUP BY object_type, object_id) AS s'
|
||||
|
@ -4790,7 +4800,7 @@ def get_actionable_counts(user_roles):
|
|||
Intersects('actions_roles_array', user_roles),
|
||||
Null('anonymised'),
|
||||
]
|
||||
where_clauses, parameters, dummy = parse_clause(criterias)
|
||||
where_clauses, parameters, dummy = parse_clause(criterias, cur=cur)
|
||||
statement = '''SELECT formdef_id, COUNT(*)
|
||||
FROM wcs_all_forms
|
||||
WHERE %s
|
||||
|
@ -4811,7 +4821,7 @@ def get_total_counts(user_roles):
|
|||
Intersects('concerned_roles_array', user_roles),
|
||||
Null('anonymised'),
|
||||
]
|
||||
where_clauses, parameters, dummy = parse_clause(criterias)
|
||||
where_clauses, parameters, dummy = parse_clause(criterias, cur=cur)
|
||||
statement = '''SELECT formdef_id, COUNT(*)
|
||||
FROM wcs_all_forms
|
||||
WHERE %s
|
||||
|
|
|
@ -15,6 +15,7 @@
|
|||
# along with this program; if not, see <http://www.gnu.org/licenses/>.
|
||||
|
||||
import datetime
|
||||
import decimal
|
||||
import re
|
||||
import time
|
||||
|
||||
|
@ -36,6 +37,8 @@ def get_field_id(field):
|
|||
|
||||
|
||||
class Criteria(wcs.qommon.storage.Criteria):
|
||||
requires_set_lc_numeric = False
|
||||
|
||||
def __init__(self, attribute, value, **kwargs):
|
||||
self.attribute = attribute
|
||||
if '->' not in attribute:
|
||||
|
@ -52,12 +55,23 @@ class Criteria(wcs.qommon.storage.Criteria):
|
|||
return value
|
||||
|
||||
def as_sql(self):
|
||||
def is_numeric_type(x):
|
||||
return isinstance(x, (int, float, decimal.Decimal)) and not isinstance(x, bool)
|
||||
|
||||
value_is_list_of_int = (
|
||||
isinstance(self.value, list)
|
||||
and self.value
|
||||
and isinstance(self.value[0], int) # all elements are of the same type
|
||||
)
|
||||
value_is_list_of_numeric = (
|
||||
isinstance(self.value, list)
|
||||
and self.value
|
||||
and is_numeric_type(self.value[0]) # all elements are of the same type
|
||||
)
|
||||
value_is_int = isinstance(self.value, int) or value_is_list_of_int
|
||||
value_is_numeric = is_numeric_type(self.value) or value_is_list_of_numeric
|
||||
if value_is_numeric:
|
||||
self.requires_set_lc_numeric = True
|
||||
|
||||
if self.field and getattr(self.field, 'block_field', None):
|
||||
# eq: EXISTS (SELECT 1 FROM jsonb_array_elements(BLOCK->'data') AS datas(aa) WHERE aa->>'FOOBAR' = 'value')
|
||||
|
@ -72,12 +86,21 @@ class Criteria(wcs.qommon.storage.Criteria):
|
|||
# ne: NOT EXISTS (SELECT 1 FROM jsonb_array_elements(BLOCK->'data') AS datas(aa) WHERE aa->>'FOOBAR' = 'value')
|
||||
# note: aa->>'FOOBAR' can be written with an integer or bool cast
|
||||
attribute = "aa->>'%s'" % self.field.id
|
||||
if self.field.key in ['item', 'string'] and value_is_int:
|
||||
if self.field.key == 'string' and value_is_numeric:
|
||||
# decimal cast of db values
|
||||
attribute = (
|
||||
"(CASE WHEN %s~E'^\\\\d{1,9}$' THEN (%s)::int "
|
||||
" WHEN %s~E'^\\\\d{0,9}[\\\\.,]\\\\d{1,9}$' "
|
||||
" THEN to_number(replace(%s, ',', '.'), '999999999D999999999')::numeric "
|
||||
' ELSE NULL END)' % (attribute, attribute, attribute, attribute)
|
||||
)
|
||||
elif self.field.key in ['item', 'string'] and value_is_int:
|
||||
# integer cast of db values
|
||||
attribute = "(CASE WHEN %s~E'^\\\\d{1,9}$' THEN (%s)::int ELSE NULL END)" % (
|
||||
attribute,
|
||||
attribute,
|
||||
)
|
||||
|
||||
elif self.field.key == 'bool':
|
||||
# bool cast of db values
|
||||
attribute = '(%s)::bool' % attribute
|
||||
|
@ -140,6 +163,14 @@ class Criteria(wcs.qommon.storage.Criteria):
|
|||
if self.field:
|
||||
if self.field.key == 'computed':
|
||||
attribute = "%s->>'data'" % self.attribute
|
||||
elif self.field.key == 'string' and value_is_numeric:
|
||||
# decimal cast of db values
|
||||
attribute = (
|
||||
"(CASE WHEN %s~E'^\\\\d{1,9}$' THEN (%s)::int "
|
||||
" WHEN %s~E'^\\\\d{0,9}[\\\\.,]\\\\d{1,9}$' "
|
||||
" THEN to_number(replace(%s, ',', '.'), '999999999D999999999')::numeric "
|
||||
' ELSE NULL END)' % (attribute, attribute, attribute, attribute)
|
||||
)
|
||||
elif self.field.key in ['item', 'string'] and value_is_int:
|
||||
# integer cast of db values
|
||||
attribute = "(CASE WHEN %s~E'^\\\\d{1,9}$' THEN %s::int ELSE NULL END)" % (
|
||||
|
@ -293,7 +324,9 @@ class Or(Criteria):
|
|||
def as_sql(self):
|
||||
if not self.criterias:
|
||||
return '( FALSE )'
|
||||
return '( %s )' % ' OR '.join([x.as_sql() for x in self.criterias])
|
||||
expression = '( %s )' % ' OR '.join([x.as_sql() for x in self.criterias])
|
||||
self.requires_set_lc_numeric = any(x.requires_set_lc_numeric for x in self.criterias)
|
||||
return expression
|
||||
|
||||
def as_sql_param(self):
|
||||
d = {}
|
||||
|
@ -307,7 +340,9 @@ class Or(Criteria):
|
|||
|
||||
class And(Or):
|
||||
def as_sql(self):
|
||||
return '( %s )' % ' AND '.join([x.as_sql() for x in self.criterias])
|
||||
expression = '( %s )' % ' AND '.join([x.as_sql() for x in self.criterias])
|
||||
self.requires_set_lc_numeric = any(x.requires_set_lc_numeric for x in self.criterias)
|
||||
return expression
|
||||
|
||||
|
||||
class Not(Criteria):
|
||||
|
@ -317,7 +352,9 @@ class Not(Criteria):
|
|||
self.criteria = sql_element
|
||||
|
||||
def as_sql(self):
|
||||
return 'NOT ( %s )' % self.criteria.as_sql()
|
||||
expression = 'NOT ( %s )' % self.criteria.as_sql()
|
||||
self.requires_set_lc_numeric = self.criteria.requires_set_lc_numeric
|
||||
return expression
|
||||
|
||||
def as_sql_param(self):
|
||||
return self.criteria.as_sql_param()
|
||||
|
|
|
@ -14,6 +14,7 @@
|
|||
# You should have received a copy of the GNU General Public License
|
||||
# along with this program; if not, see <http://www.gnu.org/licenses/>.
|
||||
|
||||
import decimal
|
||||
import warnings
|
||||
|
||||
from django.utils import formats
|
||||
|
@ -325,6 +326,29 @@ class LazyFormDefObjectsManager:
|
|||
return str(val)
|
||||
return val
|
||||
|
||||
def check_numeric(val):
|
||||
orig_val = val
|
||||
if isinstance(val, str):
|
||||
if '_' in val:
|
||||
# do not consider _ a valid character in numbers
|
||||
# (unlike python where it can be used to group digits)
|
||||
return val
|
||||
val = val.replace(',', '.') # replace , by . for French users comfort
|
||||
if val.startswith('0') and not val.startswith('0.'):
|
||||
# do not consider numbers starting with 0 as numerics (unless the next
|
||||
# character is the decimal separator), to avoid strings like phone numbers
|
||||
# to be considered as numbers
|
||||
return orig_val
|
||||
try:
|
||||
# cast to decimal so it can be used with numerical operators
|
||||
numeric_value = decimal.Decimal(val)
|
||||
if -(2**31) <= numeric_value < 2**31:
|
||||
# (limit to 32bits to match postgresql integer range)
|
||||
return numeric_value
|
||||
except (ValueError, TypeError, decimal.InvalidOperation):
|
||||
pass
|
||||
return orig_val
|
||||
|
||||
def convert_value(value, field):
|
||||
if field.convert_value_from_anything and value is not Ellipsis:
|
||||
try:
|
||||
|
@ -378,12 +402,18 @@ class LazyFormDefObjectsManager:
|
|||
value = value[0]
|
||||
if field.key in ['string', 'item', 'items']:
|
||||
if isinstance(value, list):
|
||||
value = [check_int(v) for v in value]
|
||||
# make sure all elements are of the same type
|
||||
if not all(isinstance(v, int) for v in value):
|
||||
value = [str(v) for v in value]
|
||||
else:
|
||||
numeric_values = [check_numeric(v) for v in value]
|
||||
int_values = [check_int(v) for v in value]
|
||||
if field.key == 'string' and all(isinstance(v, decimal.Decimal) for v in numeric_values):
|
||||
value = numeric_values
|
||||
elif all(isinstance(v, int) for v in int_values):
|
||||
value = int_values
|
||||
elif field.key in ['item', 'items']:
|
||||
value = check_int(value)
|
||||
else: # string
|
||||
value = check_numeric(value)
|
||||
if not isinstance(value, decimal.Decimal): # fallback to maybe integer
|
||||
value = check_int(value)
|
||||
|
||||
return value
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Il y a un peu de complexité pour poser ça uniquement quand on en a besoin; mais peut-être que faire le SET de manière systématique aurait un coût négligeable, @pducroquet ?