storage: add declarative clauses to select() (#5931)
This commit is contained in:
parent
8fcee2f876
commit
251be76550
|
@ -12,6 +12,7 @@ from wcs import formdef, publisher, fields
|
|||
from wcs.formdef import FormDef
|
||||
from wcs.formdata import Evolution
|
||||
from wcs import sql
|
||||
import wcs.qommon.storage as st
|
||||
|
||||
sys.modules['formdef'] = formdef
|
||||
|
||||
|
@ -477,3 +478,24 @@ def test_sql_table_add_and_remove_fields():
|
|||
data_class = test_formdef.data_class(mode='sql')
|
||||
data_class.select()
|
||||
|
||||
|
||||
@postgresql
|
||||
def test_sql_table_select():
|
||||
test_formdef = FormDef()
|
||||
test_formdef.name = 'table select'
|
||||
test_formdef.fields = []
|
||||
test_formdef.store()
|
||||
data_class = test_formdef.data_class(mode='sql')
|
||||
assert data_class.count() == 0
|
||||
|
||||
for i in range(50):
|
||||
t = data_class()
|
||||
t.store()
|
||||
|
||||
assert data_class.count() == 50
|
||||
assert len(data_class.select()) == 50
|
||||
|
||||
assert len(data_class.select(lambda x: x.id < 26)) == 25
|
||||
assert len(data_class.select([st.Less('id', 26)])) == 25
|
||||
assert len(data_class.select([st.Less('id', 25), st.GreaterOrEqual('id', 10)])) == 15
|
||||
assert len(data_class.select([st.Less('id', 25), st.GreaterOrEqual('id', 10), lambda x: x.id >= 15])) == 10
|
||||
|
|
|
@ -8,6 +8,7 @@ import pytest
|
|||
from quixote import cleanup
|
||||
from wcs import publisher
|
||||
from wcs.qommon.storage import StorableObject
|
||||
import wcs.qommon.storage as st
|
||||
|
||||
def setup_module(module):
|
||||
cleanup()
|
||||
|
@ -187,3 +188,18 @@ def test_get_with_indexed_value_dict_changes():
|
|||
|
||||
tests = Foobar.get_with_indexed_value('dict_value', '2')
|
||||
assert len(tests) == 2
|
||||
|
||||
|
||||
def test_select():
|
||||
Foobar.wipe()
|
||||
|
||||
for x in range(1, 51):
|
||||
test = Foobar()
|
||||
test.unique_value = x
|
||||
test.store()
|
||||
|
||||
assert len(Foobar.select()) == 50
|
||||
|
||||
assert len(Foobar.select(lambda x: x.unique_value < 26)) == 25
|
||||
assert len(Foobar.select([st.Less('unique_value', 26)])) == 25
|
||||
assert len(Foobar.select([st.Less('unique_value', 25), st.GreaterOrEqual('unique_value', 10)])) == 15
|
||||
|
|
|
@ -14,6 +14,7 @@
|
|||
# You should have received a copy of the GNU General Public License
|
||||
# along with this program; if not, see <http://www.gnu.org/licenses/>.
|
||||
|
||||
import operator
|
||||
import os
|
||||
import time
|
||||
import pickle
|
||||
|
@ -73,6 +74,51 @@ def atomic_write(path, content, async=False):
|
|||
else:
|
||||
doit()
|
||||
|
||||
|
||||
class Criteria(object):
|
||||
def __init__(self, attribute, value):
|
||||
self.attribute = attribute
|
||||
self.value = value
|
||||
|
||||
def build_lambda(self):
|
||||
return lambda x: self.op(getattr(x, self.attribute), self.value)
|
||||
|
||||
|
||||
class Less(Criteria):
|
||||
op = operator.lt
|
||||
|
||||
class Greater(Criteria):
|
||||
op = operator.gt
|
||||
|
||||
class Equal(Criteria):
|
||||
op = operator.eq
|
||||
|
||||
class LessOrEqual(Criteria):
|
||||
op = operator.le
|
||||
|
||||
class GreaterOrEqual(Criteria):
|
||||
op = operator.ge
|
||||
|
||||
|
||||
def parse_clause(clause):
|
||||
# creates a callable out of a clause
|
||||
# (attribute, operator, value)
|
||||
|
||||
if callable(clause): # already a callable
|
||||
return clause
|
||||
|
||||
def combine_callables(x1, x2):
|
||||
return lambda x: x1(x) and x2(x)
|
||||
|
||||
func = lambda x: True
|
||||
for element in clause:
|
||||
if callable(element):
|
||||
func = combine_callables(func, element)
|
||||
else:
|
||||
func = combine_callables(func, element.build_lambda())
|
||||
return func
|
||||
|
||||
|
||||
class StorageIndexException(Exception):
|
||||
pass
|
||||
|
||||
|
@ -113,14 +159,14 @@ class StorableObject(object):
|
|||
return len(cls.keys())
|
||||
count = classmethod(count)
|
||||
|
||||
|
||||
def select(cls, clause = None, order_by = None, ignore_errors = False, limit = None):
|
||||
keys = cls.keys()
|
||||
objects = (cls.get(k, ignore_errors = ignore_errors) for k in keys)
|
||||
if ignore_errors:
|
||||
objects = (x for x in objects if x is not None)
|
||||
if clause:
|
||||
objects = (x for x in objects if clause(x))
|
||||
clause_function = parse_clause(clause)
|
||||
objects = (x for x in objects if clause_function(x))
|
||||
if order_by:
|
||||
order_by = str(order_by)
|
||||
if order_by[0] == '-':
|
||||
|
|
75
wcs/sql.py
75
wcs/sql.py
|
@ -21,7 +21,7 @@ import cPickle
|
|||
|
||||
from quixote import get_publisher
|
||||
import qommon
|
||||
from qommon.storage import _take
|
||||
from qommon.storage import _take, parse_clause as parse_storage_clause
|
||||
from qommon import get_cfg
|
||||
|
||||
import wcs.formdata
|
||||
|
@ -45,6 +45,34 @@ SQL_TYPE_MAPPING = {
|
|||
'password': 'text[][]',
|
||||
}
|
||||
|
||||
|
||||
class Criteria(qommon.storage.Criteria):
|
||||
def __init__(self, attribute, value, **kwargs):
|
||||
self.attribute = attribute
|
||||
self.value = value
|
||||
|
||||
def as_sql(self):
|
||||
return '%s %s %%(c%s)s' % (self.attribute, self.sql_op, id(self.value))
|
||||
|
||||
def as_sql_param(self):
|
||||
return {'c%s' % id(self.value): self.value}
|
||||
|
||||
class Less(Criteria):
|
||||
sql_op = '<'
|
||||
|
||||
class Greater(Criteria):
|
||||
sql_op = '>'
|
||||
|
||||
class Equal(Criteria):
|
||||
sql_op = '='
|
||||
|
||||
class LessOrEqual(Criteria):
|
||||
sql_op = '<='
|
||||
|
||||
class GreaterOrEqual(Criteria):
|
||||
sql_op = '>='
|
||||
|
||||
|
||||
def get_name_as_sql_identifier(name):
|
||||
name = qommon.misc.simplify(name)
|
||||
for char in '<>|{}!?^*+/\'': # forbidden chars
|
||||
|
@ -53,6 +81,41 @@ def get_name_as_sql_identifier(name):
|
|||
return name
|
||||
|
||||
|
||||
def parse_clause(clause):
|
||||
# returns a three-elements tuple with:
|
||||
# - a list of SQL 'WHERE' clauses
|
||||
# - a dict for query parameters
|
||||
# - a callable, or None if all clauses have been successfully translated
|
||||
|
||||
if clause is None:
|
||||
return ([], None, None)
|
||||
|
||||
if callable(clause): # already a callable
|
||||
return ([], None, clause)
|
||||
|
||||
# create 'WHERE' clauses
|
||||
func_clauses = []
|
||||
where_clauses = []
|
||||
parameters = {}
|
||||
for element in clause:
|
||||
if callable(element):
|
||||
func_clauses.append(element)
|
||||
else:
|
||||
sql_class = globals().get(element.__class__.__name__)
|
||||
if sql_class:
|
||||
sql_element = sql_class(**element.__dict__)
|
||||
where_clauses.append(sql_element.as_sql())
|
||||
parameters.update(sql_element.as_sql_param())
|
||||
else:
|
||||
func_clauses.append(element.build_lambda())
|
||||
|
||||
if func_clauses:
|
||||
return (where_clauses, parameters, parse_storage_clause(func_clauses))
|
||||
else:
|
||||
return (where_clauses, parameters, None)
|
||||
|
||||
|
||||
|
||||
def get_connection(new=False):
|
||||
if new and hasattr(get_publisher(), 'pgconn') and get_publisher().pgconn is not None:
|
||||
get_publisher().pgconn.close()
|
||||
|
@ -496,15 +559,19 @@ class SqlMixin:
|
|||
', '.join([x[0] for x in cls._table_static_fields]
|
||||
+ cls.get_data_fields()),
|
||||
cls._table_name)
|
||||
cur.execute(sql_statement)
|
||||
where_clauses, parameters, func_clause = parse_clause(clause)
|
||||
if where_clauses:
|
||||
sql_statement += ' WHERE ' + ' AND '.join(where_clauses)
|
||||
|
||||
cur.execute(sql_statement, parameters)
|
||||
objects = cls.get_objects(cur)
|
||||
conn.commit()
|
||||
cur.close()
|
||||
|
||||
if ignore_errors:
|
||||
objects = (x for x in objects if x is not None)
|
||||
if clause:
|
||||
objects = (x for x in objects if clause(x))
|
||||
if func_clause:
|
||||
objects = (x for x in objects if func_clause(x))
|
||||
if order_by:
|
||||
order_by = str(order_by)
|
||||
if order_by[0] == '-':
|
||||
|
|
Loading…
Reference in New Issue