storage: add declarative clauses to select() (#5931)

This commit is contained in:
Frédéric Péters 2014-11-13 21:05:05 +01:00
parent 8fcee2f876
commit 251be76550
4 changed files with 157 additions and 6 deletions

View File

@ -12,6 +12,7 @@ from wcs import formdef, publisher, fields
from wcs.formdef import FormDef
from wcs.formdata import Evolution
from wcs import sql
import wcs.qommon.storage as st
sys.modules['formdef'] = formdef
@ -477,3 +478,24 @@ def test_sql_table_add_and_remove_fields():
data_class = test_formdef.data_class(mode='sql')
data_class.select()
@postgresql
def test_sql_table_select():
test_formdef = FormDef()
test_formdef.name = 'table select'
test_formdef.fields = []
test_formdef.store()
data_class = test_formdef.data_class(mode='sql')
assert data_class.count() == 0
for i in range(50):
t = data_class()
t.store()
assert data_class.count() == 50
assert len(data_class.select()) == 50
assert len(data_class.select(lambda x: x.id < 26)) == 25
assert len(data_class.select([st.Less('id', 26)])) == 25
assert len(data_class.select([st.Less('id', 25), st.GreaterOrEqual('id', 10)])) == 15
assert len(data_class.select([st.Less('id', 25), st.GreaterOrEqual('id', 10), lambda x: x.id >= 15])) == 10

View File

@ -8,6 +8,7 @@ import pytest
from quixote import cleanup
from wcs import publisher
from wcs.qommon.storage import StorableObject
import wcs.qommon.storage as st
def setup_module(module):
cleanup()
@ -187,3 +188,18 @@ def test_get_with_indexed_value_dict_changes():
tests = Foobar.get_with_indexed_value('dict_value', '2')
assert len(tests) == 2
def test_select():
Foobar.wipe()
for x in range(1, 51):
test = Foobar()
test.unique_value = x
test.store()
assert len(Foobar.select()) == 50
assert len(Foobar.select(lambda x: x.unique_value < 26)) == 25
assert len(Foobar.select([st.Less('unique_value', 26)])) == 25
assert len(Foobar.select([st.Less('unique_value', 25), st.GreaterOrEqual('unique_value', 10)])) == 15

View File

@ -14,6 +14,7 @@
# You should have received a copy of the GNU General Public License
# along with this program; if not, see <http://www.gnu.org/licenses/>.
import operator
import os
import time
import pickle
@ -73,6 +74,51 @@ def atomic_write(path, content, async=False):
else:
doit()
class Criteria(object):
def __init__(self, attribute, value):
self.attribute = attribute
self.value = value
def build_lambda(self):
return lambda x: self.op(getattr(x, self.attribute), self.value)
class Less(Criteria):
op = operator.lt
class Greater(Criteria):
op = operator.gt
class Equal(Criteria):
op = operator.eq
class LessOrEqual(Criteria):
op = operator.le
class GreaterOrEqual(Criteria):
op = operator.ge
def parse_clause(clause):
# creates a callable out of a clause
# (attribute, operator, value)
if callable(clause): # already a callable
return clause
def combine_callables(x1, x2):
return lambda x: x1(x) and x2(x)
func = lambda x: True
for element in clause:
if callable(element):
func = combine_callables(func, element)
else:
func = combine_callables(func, element.build_lambda())
return func
class StorageIndexException(Exception):
pass
@ -113,14 +159,14 @@ class StorableObject(object):
return len(cls.keys())
count = classmethod(count)
def select(cls, clause = None, order_by = None, ignore_errors = False, limit = None):
keys = cls.keys()
objects = (cls.get(k, ignore_errors = ignore_errors) for k in keys)
if ignore_errors:
objects = (x for x in objects if x is not None)
if clause:
objects = (x for x in objects if clause(x))
clause_function = parse_clause(clause)
objects = (x for x in objects if clause_function(x))
if order_by:
order_by = str(order_by)
if order_by[0] == '-':

View File

@ -21,7 +21,7 @@ import cPickle
from quixote import get_publisher
import qommon
from qommon.storage import _take
from qommon.storage import _take, parse_clause as parse_storage_clause
from qommon import get_cfg
import wcs.formdata
@ -45,6 +45,34 @@ SQL_TYPE_MAPPING = {
'password': 'text[][]',
}
class Criteria(qommon.storage.Criteria):
def __init__(self, attribute, value, **kwargs):
self.attribute = attribute
self.value = value
def as_sql(self):
return '%s %s %%(c%s)s' % (self.attribute, self.sql_op, id(self.value))
def as_sql_param(self):
return {'c%s' % id(self.value): self.value}
class Less(Criteria):
sql_op = '<'
class Greater(Criteria):
sql_op = '>'
class Equal(Criteria):
sql_op = '='
class LessOrEqual(Criteria):
sql_op = '<='
class GreaterOrEqual(Criteria):
sql_op = '>='
def get_name_as_sql_identifier(name):
name = qommon.misc.simplify(name)
for char in '<>|{}!?^*+/\'': # forbidden chars
@ -53,6 +81,41 @@ def get_name_as_sql_identifier(name):
return name
def parse_clause(clause):
# returns a three-elements tuple with:
# - a list of SQL 'WHERE' clauses
# - a dict for query parameters
# - a callable, or None if all clauses have been successfully translated
if clause is None:
return ([], None, None)
if callable(clause): # already a callable
return ([], None, clause)
# create 'WHERE' clauses
func_clauses = []
where_clauses = []
parameters = {}
for element in clause:
if callable(element):
func_clauses.append(element)
else:
sql_class = globals().get(element.__class__.__name__)
if sql_class:
sql_element = sql_class(**element.__dict__)
where_clauses.append(sql_element.as_sql())
parameters.update(sql_element.as_sql_param())
else:
func_clauses.append(element.build_lambda())
if func_clauses:
return (where_clauses, parameters, parse_storage_clause(func_clauses))
else:
return (where_clauses, parameters, None)
def get_connection(new=False):
if new and hasattr(get_publisher(), 'pgconn') and get_publisher().pgconn is not None:
get_publisher().pgconn.close()
@ -496,15 +559,19 @@ class SqlMixin:
', '.join([x[0] for x in cls._table_static_fields]
+ cls.get_data_fields()),
cls._table_name)
cur.execute(sql_statement)
where_clauses, parameters, func_clause = parse_clause(clause)
if where_clauses:
sql_statement += ' WHERE ' + ' AND '.join(where_clauses)
cur.execute(sql_statement, parameters)
objects = cls.get_objects(cur)
conn.commit()
cur.close()
if ignore_errors:
objects = (x for x in objects if x is not None)
if clause:
objects = (x for x in objects if clause(x))
if func_clause:
objects = (x for x in objects if func_clause(x))
if order_by:
order_by = str(order_by)
if order_by[0] == '-':