goin mixin
This commit is contained in:
parent
ee2d853b5c
commit
cfb321fac7
470
wcs/sql.py
470
wcs/sql.py
|
@ -126,10 +126,179 @@ def do_user_table():
|
|||
cur.close()
|
||||
|
||||
|
||||
class SqlFormData(wcs.formdata.FormData):
|
||||
class SqlMixin:
|
||||
def keys(cls):
|
||||
conn = get_connection()
|
||||
cur = conn.cursor()
|
||||
sql_statement = 'SELECT id FROM %s' % cls._table_name
|
||||
cur.execute(sql_statement)
|
||||
ids = [x[0] for x in cur.fetchall()]
|
||||
conn.commit()
|
||||
cur.close()
|
||||
return ids
|
||||
keys = classmethod(keys)
|
||||
|
||||
def count(cls):
|
||||
conn = get_connection()
|
||||
cur = conn.cursor()
|
||||
sql_statement = 'SELECT count(*) FROM %s' % cls._table_name
|
||||
cur.execute(sql_statement)
|
||||
count = cur.fetchone()[0]
|
||||
conn.commit()
|
||||
cur.close()
|
||||
return count
|
||||
count = classmethod(count)
|
||||
|
||||
def get_with_indexed_value(cls, index, value, ignore_errors = False):
|
||||
conn = get_connection()
|
||||
cur = conn.cursor()
|
||||
sql_statement = '''SELECT %s, %s
|
||||
FROM %s
|
||||
WHERE %s = %%(value)s''' % (
|
||||
', '.join([x[0] for x in cls._table_static_fields]
|
||||
+ cls.get_data_fields()),
|
||||
cls._table_name,
|
||||
index)
|
||||
cur.execute(sql_statement, {'value': str(value)})
|
||||
objects = []
|
||||
while True:
|
||||
row = cur.fetchone()
|
||||
if row is None:
|
||||
break
|
||||
objects.append(cls._row2ob(row))
|
||||
|
||||
conn.commit()
|
||||
cur.close()
|
||||
|
||||
if ignore_errors:
|
||||
objects = (x for x in objects if x is not None)
|
||||
|
||||
return list(objects)
|
||||
get_with_indexed_value = classmethod(get_with_indexed_value)
|
||||
|
||||
def get(cls, id, ignore_errors=False, ignore_migration=False):
|
||||
if id is None:
|
||||
if ignore_errors:
|
||||
return None
|
||||
else:
|
||||
raise KeyError()
|
||||
conn = get_connection()
|
||||
cur = conn.cursor()
|
||||
|
||||
sql_statement = '''SELECT %s
|
||||
FROM %s
|
||||
WHERE id = %%(id)s''' % (
|
||||
', '.join([x[0] for x in cls._table_static_fields]
|
||||
+ cls.get_data_fields()),
|
||||
cls._table_name)
|
||||
cur.execute(sql_statement, {'id': str(id)})
|
||||
row = cur.fetchone()
|
||||
if row is None:
|
||||
cur.close()
|
||||
raise KeyError()
|
||||
cur.close()
|
||||
return cls._row2ob(row)
|
||||
get = classmethod(get)
|
||||
|
||||
def select(cls, clause = None, order_by = None, ignore_errors = False, limit = None):
|
||||
conn = get_connection()
|
||||
cur = conn.cursor()
|
||||
sql_statement = '''SELECT %s
|
||||
FROM %s''' % (
|
||||
', '.join([x[0] for x in cls._table_static_fields]
|
||||
+ cls.get_data_fields()),
|
||||
cls._table_name)
|
||||
cur.execute(sql_statement)
|
||||
objects = []
|
||||
while True:
|
||||
row = cur.fetchone()
|
||||
if row is None:
|
||||
break
|
||||
objects.append(cls._row2ob(row))
|
||||
conn.commit()
|
||||
cur.close()
|
||||
|
||||
if ignore_errors:
|
||||
objects = (x for x in objects if x is not None)
|
||||
if clause:
|
||||
objects = (x for x in objects if clause(x))
|
||||
if order_by:
|
||||
order_by = str(order_by)
|
||||
if order_by[0] == '-':
|
||||
reverse = True
|
||||
order_by = order_by[1:]
|
||||
else:
|
||||
reverse = False
|
||||
# only list can be sorted
|
||||
objects = list(objects)
|
||||
objects.sort(lambda x,y: cmp(getattr(x, order_by), getattr(y, order_by)))
|
||||
if reverse:
|
||||
objects.reverse()
|
||||
if limit:
|
||||
objects = _take(objects, limit)
|
||||
return list(objects)
|
||||
select = classmethod(select)
|
||||
|
||||
def get_sql_dict_from_data(self, data, formdef):
|
||||
sql_dict = {}
|
||||
columns = data.keys()
|
||||
for field in formdef.fields:
|
||||
sql_type = SQL_TYPE_MAPPING.get(field.type, 'varchar')
|
||||
if sql_type is None:
|
||||
continue
|
||||
value = self.data.get(field.id)
|
||||
if value is not None:
|
||||
if field.type == 'ranked-items':
|
||||
# turn {'poire': 2, 'abricot': 1, 'pomme': 3} into an array
|
||||
value = [[x, str(y)] for x, y in value.items()]
|
||||
elif sql_type == 'varchar':
|
||||
pass
|
||||
elif sql_type == 'date':
|
||||
value = datetime.datetime.fromtimestamp(time.mktime(value))
|
||||
elif sql_type == 'bytea':
|
||||
value = bytearray(cPickle.dumps(value))
|
||||
elif sql_type == 'boolean':
|
||||
pass
|
||||
sql_dict['f%s' % field.id] = value
|
||||
return sql_dict
|
||||
|
||||
def _row2obdata(cls, row, formdef):
|
||||
obdata = {}
|
||||
i = len(cls._table_static_fields)
|
||||
for field in formdef.fields:
|
||||
sql_type = SQL_TYPE_MAPPING.get(field.type, 'varchar')
|
||||
if sql_type is None:
|
||||
continue
|
||||
value = row[i]
|
||||
if value:
|
||||
if field.type == 'ranked-items':
|
||||
d = {}
|
||||
for data, rank in value:
|
||||
d[data] = int(rank)
|
||||
value = d
|
||||
if sql_type == 'date':
|
||||
value = value.timetuple()
|
||||
elif sql_type == 'bytea':
|
||||
value = cPickle.loads(str(value))
|
||||
obdata[field.id] = value
|
||||
i += 1
|
||||
return obdata
|
||||
_row2obdata = classmethod(_row2obdata)
|
||||
|
||||
|
||||
|
||||
class SqlFormData(SqlMixin, wcs.formdata.FormData):
|
||||
_names = None # make sure StorableObject methods fail
|
||||
_formdef = None
|
||||
|
||||
_table_static_fields = [
|
||||
('id', 'serial'),
|
||||
('user_id', 'varchar'),
|
||||
('user_hash', 'varchar'),
|
||||
('receipt_time', 'timestamp'),
|
||||
('status', 'varchar')
|
||||
]
|
||||
|
||||
def __init__(self, id=None):
|
||||
self.id = id
|
||||
|
||||
|
@ -169,32 +338,6 @@ class SqlFormData(wcs.formdata.FormData):
|
|||
|
||||
evolution = property(get_evolution, set_evolution)
|
||||
|
||||
def get_with_indexed_value(cls, index, value, ignore_errors = False):
|
||||
conn = get_connection()
|
||||
cur = conn.cursor()
|
||||
sql_statement = '''SELECT id, user_id, user_hash, receipt_time, status, %s
|
||||
FROM %s
|
||||
WHERE %s = %%(value)s''' % (
|
||||
', '.join(cls.get_data_fields()),
|
||||
cls._table_name,
|
||||
index)
|
||||
cur.execute(sql_statement, {'value': str(value)})
|
||||
objects = []
|
||||
while True:
|
||||
row = cur.fetchone()
|
||||
if row is None:
|
||||
break
|
||||
objects.append(cls._row2ob(row))
|
||||
|
||||
conn.commit()
|
||||
cur.close()
|
||||
|
||||
if ignore_errors:
|
||||
objects = (x for x in objects if x is not None)
|
||||
|
||||
return list(objects)
|
||||
get_with_indexed_value = classmethod(get_with_indexed_value)
|
||||
|
||||
def store(self):
|
||||
sql_dict = {
|
||||
'user_id': self.user_id,
|
||||
|
@ -202,25 +345,7 @@ class SqlFormData(wcs.formdata.FormData):
|
|||
'receipt_time': datetime.datetime.fromtimestamp(time.mktime(self.receipt_time)),
|
||||
'status': self.status
|
||||
}
|
||||
columns = self.data.keys()
|
||||
for field in self._formdef.fields:
|
||||
sql_type = SQL_TYPE_MAPPING.get(field.type, 'varchar')
|
||||
if sql_type is None:
|
||||
continue
|
||||
value = self.data.get(field.id)
|
||||
if value is not None:
|
||||
if field.type == 'ranked-items':
|
||||
# turn {'poire': 2, 'abricot': 1, 'pomme': 3} into an array
|
||||
value = [[x, str(y)] for x, y in value.items()]
|
||||
elif sql_type == 'varchar':
|
||||
pass
|
||||
elif sql_type == 'date':
|
||||
value = datetime.datetime.fromtimestamp(time.mktime(value))
|
||||
elif sql_type == 'bytea':
|
||||
value = bytearray(cPickle.dumps(value))
|
||||
elif sql_type == 'boolean':
|
||||
pass
|
||||
sql_dict['f%s' % field.id] = value
|
||||
sql_dict.update(self.get_sql_dict_from_data(self.data, self._formdef))
|
||||
conn = get_connection()
|
||||
cur = conn.cursor()
|
||||
if not self.id:
|
||||
|
@ -271,52 +396,12 @@ class SqlFormData(wcs.formdata.FormData):
|
|||
conn.commit()
|
||||
cur.close()
|
||||
|
||||
def keys(cls):
|
||||
conn = get_connection()
|
||||
cur = conn.cursor()
|
||||
sql_statement = 'SELECT id FROM %s' % cls._table_name
|
||||
cur.execute(sql_statement)
|
||||
ids = [x[0] for x in cur.fetchall()]
|
||||
conn.commit()
|
||||
cur.close()
|
||||
return ids
|
||||
keys = classmethod(keys)
|
||||
|
||||
def count(cls):
|
||||
conn = get_connection()
|
||||
cur = conn.cursor()
|
||||
sql_statement = 'SELECT count(*) FROM %s' % cls._table_name
|
||||
cur.execute(sql_statement)
|
||||
count = cur.fetchone()[0]
|
||||
conn.commit()
|
||||
cur.close()
|
||||
return count
|
||||
count = classmethod(count)
|
||||
|
||||
def _row2ob(cls, row):
|
||||
o = cls()
|
||||
o.id, o.user_id, o.user_hash, o.receipt_time, o.status = tuple(row[:5])
|
||||
if o.receipt_time:
|
||||
o.receipt_time = o.receipt_time.timetuple()
|
||||
o.data = {}
|
||||
i = 5
|
||||
for field in cls._formdef.fields:
|
||||
sql_type = SQL_TYPE_MAPPING.get(field.type, 'varchar')
|
||||
if sql_type is None:
|
||||
continue
|
||||
value = row[i]
|
||||
if value:
|
||||
if field.type == 'ranked-items':
|
||||
d = {}
|
||||
for data, rank in value:
|
||||
d[data] = int(rank)
|
||||
value = d
|
||||
if sql_type == 'date':
|
||||
value = value.timetuple()
|
||||
elif sql_type == 'bytea':
|
||||
value = cPickle.loads(str(value))
|
||||
o.data[field.id] = value
|
||||
i += 1
|
||||
o.data = cls._row2obdata(row, cls._formdef)
|
||||
return o
|
||||
_row2ob = classmethod(_row2ob)
|
||||
|
||||
|
@ -354,47 +439,21 @@ class SqlFormData(wcs.formdata.FormData):
|
|||
get = classmethod(get)
|
||||
|
||||
|
||||
def select(cls, clause = None, order_by = None, ignore_errors = False, limit = None):
|
||||
conn = get_connection()
|
||||
cur = conn.cursor()
|
||||
sql_statement = '''SELECT id, user_id, user_hash, receipt_time, status, %s
|
||||
FROM %s''' % (
|
||||
', '.join(cls.get_data_fields()),
|
||||
cls._table_name)
|
||||
cur.execute(sql_statement)
|
||||
objects = []
|
||||
while True:
|
||||
row = cur.fetchone()
|
||||
if row is None:
|
||||
break
|
||||
objects.append(cls._row2ob(row))
|
||||
conn.commit()
|
||||
cur.close()
|
||||
|
||||
if ignore_errors:
|
||||
objects = (x for x in objects if x is not None)
|
||||
if clause:
|
||||
objects = (x for x in objects if clause(x))
|
||||
if order_by:
|
||||
order_by = str(order_by)
|
||||
if order_by[0] == '-':
|
||||
reverse = True
|
||||
order_by = order_by[1:]
|
||||
else:
|
||||
reverse = False
|
||||
# only list can be sorted
|
||||
objects = list(objects)
|
||||
objects.sort(lambda x,y: cmp(getattr(x, order_by), getattr(y, order_by)))
|
||||
if reverse:
|
||||
objects.reverse()
|
||||
if limit:
|
||||
objects = _take(objects, limit)
|
||||
return list(objects)
|
||||
select = classmethod(select)
|
||||
|
||||
|
||||
class SqlUser(wcs.users.User):
|
||||
class SqlUser(SqlMixin, wcs.users.User):
|
||||
_table_name = 'users'
|
||||
_table_static_fields = [
|
||||
('id', 'serial'),
|
||||
('name', 'varchar'),
|
||||
('email', 'varchar'),
|
||||
('roles', 'varchar[]'),
|
||||
('is_admin', 'bool'),
|
||||
('anonymous', 'bool'),
|
||||
('name_identifiers', 'varchar[]'),
|
||||
('identification_token', 'varchar'),
|
||||
('lasso_dump', 'text'),
|
||||
('last_seen', 'timestamp')
|
||||
]
|
||||
|
||||
id = None
|
||||
|
||||
def __init__(self, name=None):
|
||||
|
@ -402,34 +461,6 @@ class SqlUser(wcs.users.User):
|
|||
self.name_identifiers = []
|
||||
self.roles = []
|
||||
|
||||
def get_with_indexed_value(cls, index, value, ignore_errors = False):
|
||||
conn = get_connection()
|
||||
cur = conn.cursor()
|
||||
sql_statement = '''SELECT id, name, email, roles, is_admin, anonymous,
|
||||
name_identifiers, identification_token,
|
||||
lasso_dump, last_seen, %s
|
||||
FROM %s
|
||||
WHERE %s = %%(value)s''' % (
|
||||
', '.join(cls.get_data_fields()),
|
||||
cls._table_name,
|
||||
index)
|
||||
cur.execute(sql_statement, {'value': str(value)})
|
||||
objects = []
|
||||
while True:
|
||||
row = cur.fetchone()
|
||||
if row is None:
|
||||
break
|
||||
objects.append(cls._row2ob(row))
|
||||
|
||||
conn.commit()
|
||||
cur.close()
|
||||
|
||||
if ignore_errors:
|
||||
objects = (x for x in objects if x is not None)
|
||||
|
||||
return list(objects)
|
||||
get_with_indexed_value = classmethod(get_with_indexed_value)
|
||||
|
||||
def store(self):
|
||||
sql_dict = {
|
||||
'name': self.name,
|
||||
|
@ -445,25 +476,8 @@ class SqlUser(wcs.users.User):
|
|||
if self.last_seen:
|
||||
sql_dict['last_seen'] = datetime.datetime.fromtimestamp(self.last_seen),
|
||||
|
||||
columns = self.form_data.keys()
|
||||
for field in self.get_formdef().fields:
|
||||
sql_type = SQL_TYPE_MAPPING.get(field.type, 'varchar')
|
||||
if sql_type is None:
|
||||
continue
|
||||
value = self.form_data.get(field.id)
|
||||
if value is not None:
|
||||
if field.type == 'ranked-items':
|
||||
# turn {'poire': 2, 'abricot': 1, 'pomme': 3} into an array
|
||||
value = [[x, str(y)] for x, y in value.items()]
|
||||
elif sql_type == 'varchar':
|
||||
pass
|
||||
elif sql_type == 'date':
|
||||
value = datetime.datetime.fromtimestamp(time.mktime(value))
|
||||
elif sql_type == 'bytea':
|
||||
value = bytearray(cPickle.dumps(value))
|
||||
elif sql_type == 'boolean':
|
||||
pass
|
||||
sql_dict['f%s' % field.id] = value
|
||||
sql_dict.update(self.get_sql_dict_from_data(self.form_data, self.get_formdef()))
|
||||
|
||||
conn = get_connection()
|
||||
cur = conn.cursor()
|
||||
if not self.id:
|
||||
|
@ -505,28 +519,6 @@ class SqlUser(wcs.users.User):
|
|||
conn.commit()
|
||||
cur.close()
|
||||
|
||||
def keys(cls):
|
||||
conn = get_connection()
|
||||
cur = conn.cursor()
|
||||
sql_statement = 'SELECT id FROM %s' % cls._table_name
|
||||
cur.execute(sql_statement)
|
||||
ids = [x[0] for x in cur.fetchall()]
|
||||
conn.commit()
|
||||
cur.close()
|
||||
return ids
|
||||
keys = classmethod(keys)
|
||||
|
||||
def count(cls):
|
||||
conn = get_connection()
|
||||
cur = conn.cursor()
|
||||
sql_statement = 'SELECT count(*) FROM %s' % cls._table_name
|
||||
cur.execute(sql_statement)
|
||||
count = cur.fetchone()[0]
|
||||
conn.commit()
|
||||
cur.close()
|
||||
return count
|
||||
count = classmethod(count)
|
||||
|
||||
def _row2ob(cls, row):
|
||||
o = cls()
|
||||
(o.id, o.name, o.email, o.roles, o.is_admin, o.anonymous,
|
||||
|
@ -534,25 +526,7 @@ class SqlUser(wcs.users.User):
|
|||
o.last_seen) = tuple(row[:10])
|
||||
if o.last_seen:
|
||||
o.last_seen = time.mktime(o.last_seen.timetuple())
|
||||
o.form_data = {}
|
||||
i = 10
|
||||
for field in cls.get_formdef().fields:
|
||||
sql_type = SQL_TYPE_MAPPING.get(field.type, 'varchar')
|
||||
if sql_type is None:
|
||||
continue
|
||||
value = row[i]
|
||||
if value:
|
||||
if field.type == 'ranked-items':
|
||||
d = {}
|
||||
for data, rank in value:
|
||||
d[data] = int(rank)
|
||||
value = d
|
||||
if sql_type == 'date':
|
||||
value = value.timetuple()
|
||||
elif sql_type == 'bytea':
|
||||
value = cPickle.loads(str(value))
|
||||
o.form_data[field.id] = value
|
||||
i += 1
|
||||
o.form_data = cls._row2obdata(row, cls.get_formdef())
|
||||
return o
|
||||
_row2ob = classmethod(_row2ob)
|
||||
|
||||
|
@ -566,69 +540,3 @@ class SqlUser(wcs.users.User):
|
|||
return data_fields
|
||||
get_data_fields = classmethod(get_data_fields)
|
||||
|
||||
def get(cls, id, ignore_errors=False, ignore_migration=False):
|
||||
if id is None:
|
||||
if ignore_errors:
|
||||
return None
|
||||
else:
|
||||
raise KeyError()
|
||||
conn = get_connection()
|
||||
cur = conn.cursor()
|
||||
|
||||
sql_statement = '''SELECT id, name, email, roles, is_admin, anonymous,
|
||||
name_identifiers, identification_token,
|
||||
lasso_dump, last_seen, %s
|
||||
FROM %s
|
||||
WHERE id = %%(id)s''' % (
|
||||
', '.join(cls.get_data_fields()),
|
||||
cls._table_name)
|
||||
cur.execute(sql_statement, {'id': str(id)})
|
||||
row = cur.fetchone()
|
||||
if row is None:
|
||||
cur.close()
|
||||
raise KeyError()
|
||||
cur.close()
|
||||
return cls._row2ob(row)
|
||||
get = classmethod(get)
|
||||
|
||||
|
||||
def select(cls, clause = None, order_by = None, ignore_errors = False, limit = None):
|
||||
conn = get_connection()
|
||||
cur = conn.cursor()
|
||||
sql_statement = '''SELECT id, name, email, roles, is_admin, anonymous,
|
||||
name_identifiers, identification_token,
|
||||
lasso_dump, last_seen, %s
|
||||
FROM %s''' % (
|
||||
', '.join(cls.get_data_fields()),
|
||||
cls._table_name)
|
||||
cur.execute(sql_statement)
|
||||
objects = []
|
||||
while True:
|
||||
row = cur.fetchone()
|
||||
if row is None:
|
||||
break
|
||||
objects.append(cls._row2ob(row))
|
||||
conn.commit()
|
||||
cur.close()
|
||||
|
||||
if ignore_errors:
|
||||
objects = (x for x in objects if x is not None)
|
||||
if clause:
|
||||
objects = (x for x in objects if clause(x))
|
||||
if order_by:
|
||||
order_by = str(order_by)
|
||||
if order_by[0] == '-':
|
||||
reverse = True
|
||||
order_by = order_by[1:]
|
||||
else:
|
||||
reverse = False
|
||||
# only list can be sorted
|
||||
objects = list(objects)
|
||||
objects.sort(lambda x,y: cmp(getattr(x, order_by), getattr(y, order_by)))
|
||||
if reverse:
|
||||
objects.reverse()
|
||||
if limit:
|
||||
objects = _take(objects, limit)
|
||||
return list(objects)
|
||||
select = classmethod(select)
|
||||
|
||||
|
|
Loading…
Reference in New Issue