diff --git a/tests/test_sql.py b/tests/test_sql.py index 6fc1c6fce..e37ab6b42 100644 --- a/tests/test_sql.py +++ b/tests/test_sql.py @@ -128,12 +128,16 @@ def test_sql_get_missing(): data_class = formdef.data_class(mode='sql') with pytest.raises(KeyError): data_class.get(123456) + with pytest.raises(KeyError): + data_class.get('xxx') @postgresql def test_sql_get_missing_ignore_errors(): data_class = formdef.data_class(mode='sql') assert data_class.get(123456, ignore_errors=True) is None + assert data_class.get('xxx', ignore_errors=True) is None + assert data_class.get(None, ignore_errors=True) is None def check_sql_field(no, value): diff --git a/wcs/sql.py b/wcs/sql.py index 5286c30de..a722cc2a2 100644 --- a/wcs/sql.py +++ b/wcs/sql.py @@ -1021,6 +1021,7 @@ def do_global_views(conn, cur): class SqlMixin(object): _table_name = None + _numerical_id = True @classmethod @guard_postgres @@ -1095,11 +1096,14 @@ class SqlMixin(object): @classmethod @guard_postgres def get(cls, id, ignore_errors=False, ignore_migration=False): - if id is None: - if ignore_errors: - return None - else: - raise KeyError() + if cls._numerical_id or id is None: + try: + int(id) + except (TypeError, ValueError): + if ignore_errors and id is None: + return None + else: + raise KeyError() conn, cur = get_connection_and_cursor() sql_statement = '''SELECT %s @@ -1658,16 +1662,13 @@ class SqlDataMixin(SqlMixin): @classmethod @guard_postgres def get(cls, id, ignore_errors=False, ignore_migration=False): - if id is None: + try: + int(id) + except (TypeError, ValueError): if ignore_errors: return None else: raise KeyError() - else: - try: - int(id) - except ValueError: - raise KeyError() conn, cur = get_connection_and_cursor() fields = cls.get_data_fields() @@ -1993,6 +1994,7 @@ class Session(SqlMixin, wcs.sessions.BasicSession): ('id', 'varchar'), ('session_data', 'bytea'), ] + _numerical_id = False @classmethod @guard_postgres @@ -2084,6 +2086,7 @@ class TrackingCode(SqlMixin, wcs.tracking_code.TrackingCode): ('formdef_id', 'varchar'), ('formdata_id', 'varchar'), ] + _numerical_id = False id = None