storage: add parameter to sort get_ids results (#47878)

This commit is contained in:
Frédéric Péters 2020-11-16 19:13:51 +01:00
parent 93452519fd
commit ef03231782
2 changed files with 49 additions and 42 deletions

View File

@ -329,6 +329,35 @@ class StorableObject(object):
return len(cls.select(clause))
return len(cls.keys())
@classmethod
def sort_results(cls, objects, order_by):
if not order_by:
return objects
order_by = str(order_by)
if order_by[0] == '-':
reverse = True
order_by = order_by[1:]
else:
reverse = False
# only list can be sorted
objects = list(objects)
if order_by == 'id':
key_function = lambda x: lax_int(x.id)
elif order_by == 'name':
# proper collation should be done but it's messy to get working
# on all systems so we go the cheap and almost ok way.
from .misc import simplify
key_function = lambda x: simplify(x.name)
elif order_by.endswith('_time'):
typed_none = time.gmtime(-10**10) # 1653
key_function = lambda x: getattr(x, order_by) or typed_none
else:
key_function = lambda x: getattr(x, order_by)
objects.sort(key=key_function)
if reverse:
objects.reverse()
return objects
@classmethod
def select(cls, clause=None, order_by=None, ignore_errors=False,
ignore_migration=False, limit=None, offset=None, iterator=False, **kwargs):
@ -341,30 +370,7 @@ class StorableObject(object):
if clause:
clause_function = parse_clause(clause)
objects = (x for x in objects if clause_function(x))
if order_by:
order_by = str(order_by)
if order_by[0] == '-':
reverse = True
order_by = order_by[1:]
else:
reverse = False
# only list can be sorted
objects = list(objects)
if order_by == 'id':
key_function = lambda x: lax_int(x.id)
elif order_by == 'name':
# proper collation should be done but it's messy to get working
# on all systems so we go the cheap and almost ok way.
from .misc import simplify
key_function = lambda x: simplify(x.name)
elif order_by.endswith('_time'):
typed_none = time.gmtime(-10**10) # 1653
key_function = lambda x: getattr(x, order_by) or typed_none
else:
key_function = lambda x: getattr(x, order_by)
objects.sort(key=key_function)
if reverse:
objects.reverse()
objects = cls.sort_results(objects, order_by)
if limit or offset:
objects = _take(objects, limit, offset)
return list(objects)
@ -418,13 +424,13 @@ class StorableObject(object):
**kwargs)
@classmethod
def get_ids(cls, ids, ignore_errors=False, **kwargs):
def get_ids(cls, ids, ignore_errors=False, order_by=None, **kwargs):
objects = []
for x in ids:
obj = cls.get(x, ignore_errors=ignore_errors)
if obj is not None:
objects.append(obj)
return objects
return cls.sort_results(objects, order_by)
@classmethod
def get_on_index(cls, id, index, ignore_errors=False, ignore_migration=False):

View File

@ -1189,7 +1189,7 @@ class SqlMixin(object):
@classmethod
@guard_postgres
def get_ids(cls, ids, ignore_errors=False, keep_order=False, fields=None):
def get_ids(cls, ids, ignore_errors=False, keep_order=False, fields=None, order_by=None):
if not ids:
return []
tables = [cls._table_name]
@ -1232,6 +1232,7 @@ class SqlMixin(object):
' '.join(tables),
cls._table_name,
','.join([str(x) for x in ids]))
sql_statement += cls.get_order_by_clause(order_by)
cur.execute(sql_statement)
objects = cls.get_objects(cur, extra_fields=extra_fields)
conn.commit()
@ -1263,6 +1264,19 @@ class SqlMixin(object):
return generator
return list(generator)
@classmethod
def get_order_by_clause(cls, order_by):
if not order_by:
return ''
# [SEC_ORDER] security note: it is not possible to use
# prepared statements for ORDER BY clauses, therefore input
# is controlled beforehand (see misc.get_order_by_or_400).
if order_by.startswith('-'):
order_by = order_by[1:]
return ' ORDER BY %s DESC' % order_by.replace('-', '_')
else:
return ' ORDER BY %s' % order_by.replace('-', '_')
@classmethod
@guard_postgres
def select_iterator(cls, clause=None, order_by=None, ignore_errors=False,
@ -1278,15 +1292,7 @@ class SqlMixin(object):
if where_clauses:
sql_statement += ' WHERE ' + ' AND '.join(where_clauses)
if order_by:
# [SEC_ORDER] security note: it is not possible to use
# prepared statements for ORDER BY clauses, therefore input
# is controlled beforehand (see misc.get_order_by_or_400).
if order_by.startswith('-'):
order_by = order_by[1:]
sql_statement += ' ORDER BY %s DESC' % order_by.replace('-', '_')
else:
sql_statement += ' ORDER BY %s' % order_by.replace('-', '_')
sql_statement += cls.get_order_by_clause(order_by)
if not func_clause:
if limit:
@ -1467,12 +1473,7 @@ class SqlMixin(object):
assert not func_clause
if where_clauses:
sql_statement += ' WHERE ' + ' AND '.join(where_clauses)
# security note, refer to [SEC_ORDER]
if order_by.startswith('-'):
order_by = order_by[1:]
sql_statement += ' ORDER BY %s DESC' % order_by.replace('-', '_')
else:
sql_statement += ' ORDER BY %s' % order_by.replace('-', '_')
sql_statement += cls.get_order_by_clause(order_by)
cur.execute(sql_statement, parameters)
ids = [x[0] for x in cur.fetchall()]
conn.commit()