storage: add parameter to sort get_ids results (#47878)
This commit is contained in:
parent
93452519fd
commit
ef03231782
|
@ -329,6 +329,35 @@ class StorableObject(object):
|
|||
return len(cls.select(clause))
|
||||
return len(cls.keys())
|
||||
|
||||
@classmethod
|
||||
def sort_results(cls, objects, order_by):
|
||||
if not order_by:
|
||||
return objects
|
||||
order_by = str(order_by)
|
||||
if order_by[0] == '-':
|
||||
reverse = True
|
||||
order_by = order_by[1:]
|
||||
else:
|
||||
reverse = False
|
||||
# only list can be sorted
|
||||
objects = list(objects)
|
||||
if order_by == 'id':
|
||||
key_function = lambda x: lax_int(x.id)
|
||||
elif order_by == 'name':
|
||||
# proper collation should be done but it's messy to get working
|
||||
# on all systems so we go the cheap and almost ok way.
|
||||
from .misc import simplify
|
||||
key_function = lambda x: simplify(x.name)
|
||||
elif order_by.endswith('_time'):
|
||||
typed_none = time.gmtime(-10**10) # 1653
|
||||
key_function = lambda x: getattr(x, order_by) or typed_none
|
||||
else:
|
||||
key_function = lambda x: getattr(x, order_by)
|
||||
objects.sort(key=key_function)
|
||||
if reverse:
|
||||
objects.reverse()
|
||||
return objects
|
||||
|
||||
@classmethod
|
||||
def select(cls, clause=None, order_by=None, ignore_errors=False,
|
||||
ignore_migration=False, limit=None, offset=None, iterator=False, **kwargs):
|
||||
|
@ -341,30 +370,7 @@ class StorableObject(object):
|
|||
if clause:
|
||||
clause_function = parse_clause(clause)
|
||||
objects = (x for x in objects if clause_function(x))
|
||||
if order_by:
|
||||
order_by = str(order_by)
|
||||
if order_by[0] == '-':
|
||||
reverse = True
|
||||
order_by = order_by[1:]
|
||||
else:
|
||||
reverse = False
|
||||
# only list can be sorted
|
||||
objects = list(objects)
|
||||
if order_by == 'id':
|
||||
key_function = lambda x: lax_int(x.id)
|
||||
elif order_by == 'name':
|
||||
# proper collation should be done but it's messy to get working
|
||||
# on all systems so we go the cheap and almost ok way.
|
||||
from .misc import simplify
|
||||
key_function = lambda x: simplify(x.name)
|
||||
elif order_by.endswith('_time'):
|
||||
typed_none = time.gmtime(-10**10) # 1653
|
||||
key_function = lambda x: getattr(x, order_by) or typed_none
|
||||
else:
|
||||
key_function = lambda x: getattr(x, order_by)
|
||||
objects.sort(key=key_function)
|
||||
if reverse:
|
||||
objects.reverse()
|
||||
objects = cls.sort_results(objects, order_by)
|
||||
if limit or offset:
|
||||
objects = _take(objects, limit, offset)
|
||||
return list(objects)
|
||||
|
@ -418,13 +424,13 @@ class StorableObject(object):
|
|||
**kwargs)
|
||||
|
||||
@classmethod
|
||||
def get_ids(cls, ids, ignore_errors=False, **kwargs):
|
||||
def get_ids(cls, ids, ignore_errors=False, order_by=None, **kwargs):
|
||||
objects = []
|
||||
for x in ids:
|
||||
obj = cls.get(x, ignore_errors=ignore_errors)
|
||||
if obj is not None:
|
||||
objects.append(obj)
|
||||
return objects
|
||||
return cls.sort_results(objects, order_by)
|
||||
|
||||
@classmethod
|
||||
def get_on_index(cls, id, index, ignore_errors=False, ignore_migration=False):
|
||||
|
|
33
wcs/sql.py
33
wcs/sql.py
|
@ -1189,7 +1189,7 @@ class SqlMixin(object):
|
|||
|
||||
@classmethod
|
||||
@guard_postgres
|
||||
def get_ids(cls, ids, ignore_errors=False, keep_order=False, fields=None):
|
||||
def get_ids(cls, ids, ignore_errors=False, keep_order=False, fields=None, order_by=None):
|
||||
if not ids:
|
||||
return []
|
||||
tables = [cls._table_name]
|
||||
|
@ -1232,6 +1232,7 @@ class SqlMixin(object):
|
|||
' '.join(tables),
|
||||
cls._table_name,
|
||||
','.join([str(x) for x in ids]))
|
||||
sql_statement += cls.get_order_by_clause(order_by)
|
||||
cur.execute(sql_statement)
|
||||
objects = cls.get_objects(cur, extra_fields=extra_fields)
|
||||
conn.commit()
|
||||
|
@ -1263,6 +1264,19 @@ class SqlMixin(object):
|
|||
return generator
|
||||
return list(generator)
|
||||
|
||||
@classmethod
|
||||
def get_order_by_clause(cls, order_by):
|
||||
if not order_by:
|
||||
return ''
|
||||
# [SEC_ORDER] security note: it is not possible to use
|
||||
# prepared statements for ORDER BY clauses, therefore input
|
||||
# is controlled beforehand (see misc.get_order_by_or_400).
|
||||
if order_by.startswith('-'):
|
||||
order_by = order_by[1:]
|
||||
return ' ORDER BY %s DESC' % order_by.replace('-', '_')
|
||||
else:
|
||||
return ' ORDER BY %s' % order_by.replace('-', '_')
|
||||
|
||||
@classmethod
|
||||
@guard_postgres
|
||||
def select_iterator(cls, clause=None, order_by=None, ignore_errors=False,
|
||||
|
@ -1278,15 +1292,7 @@ class SqlMixin(object):
|
|||
if where_clauses:
|
||||
sql_statement += ' WHERE ' + ' AND '.join(where_clauses)
|
||||
|
||||
if order_by:
|
||||
# [SEC_ORDER] security note: it is not possible to use
|
||||
# prepared statements for ORDER BY clauses, therefore input
|
||||
# is controlled beforehand (see misc.get_order_by_or_400).
|
||||
if order_by.startswith('-'):
|
||||
order_by = order_by[1:]
|
||||
sql_statement += ' ORDER BY %s DESC' % order_by.replace('-', '_')
|
||||
else:
|
||||
sql_statement += ' ORDER BY %s' % order_by.replace('-', '_')
|
||||
sql_statement += cls.get_order_by_clause(order_by)
|
||||
|
||||
if not func_clause:
|
||||
if limit:
|
||||
|
@ -1467,12 +1473,7 @@ class SqlMixin(object):
|
|||
assert not func_clause
|
||||
if where_clauses:
|
||||
sql_statement += ' WHERE ' + ' AND '.join(where_clauses)
|
||||
# security note, refer to [SEC_ORDER]
|
||||
if order_by.startswith('-'):
|
||||
order_by = order_by[1:]
|
||||
sql_statement += ' ORDER BY %s DESC' % order_by.replace('-', '_')
|
||||
else:
|
||||
sql_statement += ' ORDER BY %s' % order_by.replace('-', '_')
|
||||
sql_statement += cls.get_order_by_clause(order_by)
|
||||
cur.execute(sql_statement, parameters)
|
||||
ids = [x[0] for x in cur.fetchall()]
|
||||
conn.commit()
|
||||
|
|
Loading…
Reference in New Issue