sql: check id parameter passed to .get() method (#42827)
This commit is contained in:
parent
c34a405504
commit
6065ab4d39
|
@ -128,12 +128,16 @@ def test_sql_get_missing():
|
||||||
data_class = formdef.data_class(mode='sql')
|
data_class = formdef.data_class(mode='sql')
|
||||||
with pytest.raises(KeyError):
|
with pytest.raises(KeyError):
|
||||||
data_class.get(123456)
|
data_class.get(123456)
|
||||||
|
with pytest.raises(KeyError):
|
||||||
|
data_class.get('xxx')
|
||||||
|
|
||||||
|
|
||||||
@postgresql
|
@postgresql
|
||||||
def test_sql_get_missing_ignore_errors():
|
def test_sql_get_missing_ignore_errors():
|
||||||
data_class = formdef.data_class(mode='sql')
|
data_class = formdef.data_class(mode='sql')
|
||||||
assert data_class.get(123456, ignore_errors=True) is None
|
assert data_class.get(123456, ignore_errors=True) is None
|
||||||
|
assert data_class.get('xxx', ignore_errors=True) is None
|
||||||
|
assert data_class.get(None, ignore_errors=True) is None
|
||||||
|
|
||||||
|
|
||||||
def check_sql_field(no, value):
|
def check_sql_field(no, value):
|
||||||
|
|
25
wcs/sql.py
25
wcs/sql.py
|
@ -1021,6 +1021,7 @@ def do_global_views(conn, cur):
|
||||||
|
|
||||||
class SqlMixin(object):
|
class SqlMixin(object):
|
||||||
_table_name = None
|
_table_name = None
|
||||||
|
_numerical_id = True
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@guard_postgres
|
@guard_postgres
|
||||||
|
@ -1095,11 +1096,14 @@ class SqlMixin(object):
|
||||||
@classmethod
|
@classmethod
|
||||||
@guard_postgres
|
@guard_postgres
|
||||||
def get(cls, id, ignore_errors=False, ignore_migration=False):
|
def get(cls, id, ignore_errors=False, ignore_migration=False):
|
||||||
if id is None:
|
if cls._numerical_id or id is None:
|
||||||
if ignore_errors:
|
try:
|
||||||
return None
|
int(id)
|
||||||
else:
|
except (TypeError, ValueError):
|
||||||
raise KeyError()
|
if ignore_errors and id is None:
|
||||||
|
return None
|
||||||
|
else:
|
||||||
|
raise KeyError()
|
||||||
conn, cur = get_connection_and_cursor()
|
conn, cur = get_connection_and_cursor()
|
||||||
|
|
||||||
sql_statement = '''SELECT %s
|
sql_statement = '''SELECT %s
|
||||||
|
@ -1658,16 +1662,13 @@ class SqlDataMixin(SqlMixin):
|
||||||
@classmethod
|
@classmethod
|
||||||
@guard_postgres
|
@guard_postgres
|
||||||
def get(cls, id, ignore_errors=False, ignore_migration=False):
|
def get(cls, id, ignore_errors=False, ignore_migration=False):
|
||||||
if id is None:
|
try:
|
||||||
|
int(id)
|
||||||
|
except (TypeError, ValueError):
|
||||||
if ignore_errors:
|
if ignore_errors:
|
||||||
return None
|
return None
|
||||||
else:
|
else:
|
||||||
raise KeyError()
|
raise KeyError()
|
||||||
else:
|
|
||||||
try:
|
|
||||||
int(id)
|
|
||||||
except ValueError:
|
|
||||||
raise KeyError()
|
|
||||||
conn, cur = get_connection_and_cursor()
|
conn, cur = get_connection_and_cursor()
|
||||||
|
|
||||||
fields = cls.get_data_fields()
|
fields = cls.get_data_fields()
|
||||||
|
@ -1993,6 +1994,7 @@ class Session(SqlMixin, wcs.sessions.BasicSession):
|
||||||
('id', 'varchar'),
|
('id', 'varchar'),
|
||||||
('session_data', 'bytea'),
|
('session_data', 'bytea'),
|
||||||
]
|
]
|
||||||
|
_numerical_id = False
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@guard_postgres
|
@guard_postgres
|
||||||
|
@ -2084,6 +2086,7 @@ class TrackingCode(SqlMixin, wcs.tracking_code.TrackingCode):
|
||||||
('formdef_id', 'varchar'),
|
('formdef_id', 'varchar'),
|
||||||
('formdata_id', 'varchar'),
|
('formdata_id', 'varchar'),
|
||||||
]
|
]
|
||||||
|
_numerical_id = False
|
||||||
|
|
||||||
id = None
|
id = None
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue