This repository has been archived on 2023-02-21. You can view files and clone it, but cannot push or open issues or pull requests.
mandaye/mandaye/backends/sql.py

163 lines
4.8 KiB
Python

import logging
from datetime import datetime

from mandaye.db import sql_session
from mandaye.models import IDPUser, SPUser, ServiceProvider
class ManagerIDPUserSQL:
    """Lookup/creation/deletion helpers for ``IDPUser`` rows via ``sql_session()``."""

    # The original code used a module-level ``logger`` that is never
    # defined in this module; a named stdlib logger avoids the NameError.
    _logger = logging.getLogger(__name__)

    @staticmethod
    def get(unique_id, idp_id='default'):
        """Return the ``IDPUser`` matching ``unique_id`` and ``idp_id``.

        Returns None when no row matches.
        Raises RuntimeError when more than one row matches.
        """
        # Fetch at most two rows: enough to detect a duplicate without
        # loading every match.  (The original filtered on the literal
        # 'default' instead of the idp_id argument, and called len() on a
        # Query object, which raises TypeError.)
        idp_users = (
            sql_session().query(IDPUser)
            .filter_by(unique_id=unique_id, idp_id=idp_id)
            .limit(2)
            .all()
        )
        if len(idp_users) > 1:
            ManagerIDPUserSQL._logger.critical(
                'ManagerIDPUserSQL.get %s not unique', unique_id)
            # NOTE(review): the original raised MandayeException, which is
            # not imported in this module; a builtin exception is used so
            # this error path cannot itself fail with NameError.
            raise RuntimeError(
                'ManagerIDPUserSQL.get : %s is not unique' % unique_id)
        return idp_users[0] if idp_users else None

    @staticmethod
    def create(unique_id, idp_id='default'):
        """Build a new ``IDPUser``, add it to the session and return it.

        The session is NOT committed here; call ``save()`` to persist.
        """
        idp_user = IDPUser(unique_id=unique_id, idp_id=idp_id)
        sql_session().add(idp_user)
        return idp_user

    @staticmethod
    def get_or_create(unique_id, idp_id='default'):
        """Return the existing ``IDPUser`` for the pair, or create one."""
        idp_user = ManagerIDPUserSQL.get(unique_id, idp_id)
        if idp_user is not None:
            return idp_user
        return ManagerIDPUserSQL.create(unique_id, idp_id)

    @staticmethod
    def delete(idp_user):
        """Delete ``idp_user`` and commit the session."""
        sql_session().delete(idp_user)
        sql_session().commit()

    @staticmethod
    def save():
        """Commit pending changes in the current SQL session."""
        sql_session().commit()
class ManagerSPUserSQL:
    """Lookup/creation/deletion helpers for ``SPUser`` rows via ``sql_session()``."""

    # The original code used a module-level ``logger`` that is never
    # defined in this module; a named stdlib logger avoids the NameError.
    _logger = logging.getLogger(__name__)

    @staticmethod
    def get(login, idp_user, service_provider):
        """Return the ``SPUser`` with this login for the given IDP user and
        service provider, or None if there is none.
        """
        # filter_by() applies to the last joined entity, so the original
        # join(IDPUser).join(ServiceProvider).filter_by(login=...) looked
        # the keys up on ServiceProvider.  Filtering on SPUser's own
        # attributes needs no join.
        # NOTE(review): assumes SPUser exposes ``idp_user`` and
        # ``service_provider`` relationship attributes, as the original
        # filter keywords implied.
        return (
            sql_session().query(SPUser)
            .filter_by(login=login,
                       idp_user=idp_user,
                       service_provider=service_provider)
            .first()
        )

    @staticmethod
    def get_by_id(id):
        """Return the ``SPUser`` whose ``id`` equals ``id``, or None."""
        # The original ``filter(id==id)`` compared the argument with itself
        # (always True) and so returned an arbitrary row.
        return sql_session().query(SPUser).filter(SPUser.id == id).first()

    @staticmethod
    def get_last_connected(idp_user, service_provider):
        """Return the most recently connected ``SPUser`` for the given IDP
        user and service provider, or None.
        """
        return (
            sql_session().query(SPUser)
            .filter_by(idp_user=idp_user, service_provider=service_provider)
            .order_by(SPUser.last_connection.desc())
            .first()
        )

    @staticmethod
    def get_sp_users(idp_unique_id, service_provider_name):
        """Return all ``SPUser`` rows linked to the IDP user with
        ``idp_unique_id`` and the service provider named
        ``service_provider_name``, most recent ``last_connection`` first.
        """
        return (
            sql_session().query(SPUser)
            .join(IDPUser)
            .join(ServiceProvider)
            .filter(IDPUser.unique_id == idp_unique_id)
            .filter(ServiceProvider.name == service_provider_name)
            .order_by(SPUser.last_connection.desc())
            .all()
        )

    @staticmethod
    def create(login, post_values, idp_user, service_provider):
        """Create, commit and return a new ``SPUser``."""
        # The original passed an undefined ``idp_id`` and returned
        # ``idp_user`` instead of the created row.
        sp_user = SPUser(
            login=login,
            post_values=post_values,
            idp_user=idp_user,
            service_provider=service_provider,
        )
        ManagerSPUserSQL._logger.info(
            'New association: %s with %s on site %s',
            login, idp_user.unique_id, service_provider.name)
        sql_session().add(sp_user)
        sql_session().commit()
        return sp_user

    @staticmethod
    def get_or_create(login, post_values, idp_user, service_provider):
        """Return the matching ``SPUser``, creating (and committing) it if
        it does not exist yet.
        """
        sp_user = ManagerSPUserSQL.get(login, idp_user, service_provider)
        if sp_user:
            return sp_user
        return ManagerSPUserSQL.create(login, post_values,
                                       idp_user, service_provider)

    @staticmethod
    def all():
        """Return every ``SPUser`` row."""
        return sql_session().query(SPUser).all()

    @staticmethod
    def delete(sp_user):
        """Delete ``sp_user`` and commit the session."""
        ManagerSPUserSQL._logger.debug('Disassociate account %s',
                                       sp_user.login)
        sql_session().delete(sp_user)
        sql_session().commit()

    @staticmethod
    def save():
        """Commit pending changes in the current SQL session."""
        sql_session().commit()
class ManagerServiceProviderSQL:
    """Lookup/creation/deletion helpers for ``ServiceProvider`` rows via
    ``sql_session()``.
    """

    # The original code used a module-level ``logger`` that is never
    # defined in this module; a named stdlib logger avoids the NameError.
    _logger = logging.getLogger(__name__)

    @staticmethod
    def get(name):
        """Return the first ``ServiceProvider`` named ``name``, or None."""
        # A Query object is always truthy, so the original ``if sp:`` test
        # never took its else-branch; first() already returns None when
        # nothing matches.
        return sql_session().query(ServiceProvider).filter_by(name=name).first()

    @staticmethod
    def create(name):
        """Create, commit and return a new ``ServiceProvider``."""
        ManagerServiceProviderSQL._logger.info(
            'Add %s service provider into the database', name)
        sp = ServiceProvider(name=name)
        sql_session().add(sp)
        sql_session().commit()
        return sp

    @staticmethod
    def get_or_create(name):
        """Return the ``ServiceProvider`` named ``name``, creating (and
        committing) it if it does not exist yet.
        """
        sp = ManagerServiceProviderSQL.get(name)
        if sp:
            return sp
        return ManagerServiceProviderSQL.create(name)

    @staticmethod
    def delete(service_provider):
        """Delete ``service_provider`` and commit the session."""
        ManagerServiceProviderSQL._logger.debug(
            'Delete service provider %s', service_provider.name)
        sql_session().delete(service_provider)
        sql_session().commit()

    @staticmethod
    def save():
        """Commit pending changes in the current SQL session."""
        sql_session().commit()
# Public aliases bound to the SQL implementations defined above.
# The original assigned ManagerServiceProvider twice and never defined
# ManagerIDPUser; each manager now gets exactly one alias.
ManagerIDPUser = ManagerIDPUserSQL
ManagerSPUser = ManagerSPUserSQL
ManagerServiceProvider = ManagerServiceProviderSQL