sql: check id parameter passed to .get() method (#42827)

This commit is contained in:
Frédéric Péters 2020-05-13 13:22:01 +02:00
parent c34a405504
commit 6065ab4d39
2 changed files with 18 additions and 11 deletions

View File

@ -128,12 +128,16 @@ def test_sql_get_missing():
data_class = formdef.data_class(mode='sql')
with pytest.raises(KeyError):
data_class.get(123456)
with pytest.raises(KeyError):
data_class.get('xxx')
@postgresql
def test_sql_get_missing_ignore_errors():
data_class = formdef.data_class(mode='sql')
assert data_class.get(123456, ignore_errors=True) is None
assert data_class.get('xxx', ignore_errors=True) is None
assert data_class.get(None, ignore_errors=True) is None
def check_sql_field(no, value):

View File

@ -1021,6 +1021,7 @@ def do_global_views(conn, cur):
class SqlMixin(object):
_table_name = None
_numerical_id = True
@classmethod
@guard_postgres
@ -1095,11 +1096,14 @@ class SqlMixin(object):
@classmethod
@guard_postgres
def get(cls, id, ignore_errors=False, ignore_migration=False):
if id is None:
if ignore_errors:
return None
else:
raise KeyError()
if cls._numerical_id or id is None:
try:
int(id)
except (TypeError, ValueError):
if ignore_errors and id is None:
return None
else:
raise KeyError()
conn, cur = get_connection_and_cursor()
sql_statement = '''SELECT %s
@ -1658,16 +1662,13 @@ class SqlDataMixin(SqlMixin):
@classmethod
@guard_postgres
def get(cls, id, ignore_errors=False, ignore_migration=False):
if id is None:
try:
int(id)
except (TypeError, ValueError):
if ignore_errors:
return None
else:
raise KeyError()
else:
try:
int(id)
except ValueError:
raise KeyError()
conn, cur = get_connection_and_cursor()
fields = cls.get_data_fields()
@ -1993,6 +1994,7 @@ class Session(SqlMixin, wcs.sessions.BasicSession):
('id', 'varchar'),
('session_data', 'bytea'),
]
_numerical_id = False
@classmethod
@guard_postgres
@ -2084,6 +2086,7 @@ class TrackingCode(SqlMixin, wcs.tracking_code.TrackingCode):
('formdef_id', 'varchar'),
('formdata_id', 'varchar'),
]
_numerical_id = False
id = None