add dedicated method to get users with a given name identifier (#4951)

This commit is contained in:
Frédéric Péters 2014-06-12 11:25:29 +02:00
parent 448450c581
commit 6e21e2bc40
5 changed files with 85 additions and 2 deletions

View File

@ -378,6 +378,24 @@ def test_get_users_with_role():
assert len(sql.SqlUser.get_users_with_role(1)) == 1
@postgresql
def test_get_users_with_name_identifier():
sql.SqlUser.wipe()
user = sql.SqlUser()
user.name = 'Pierre'
user.name_identifiers = ['foo']
user.store()
user_id = user.id
user = sql.SqlUser()
user.name = 'Papier'
user.store()
assert len(sql.SqlUser.get_users_with_name_identifier('foo')) == 1
assert sql.SqlUser.get_users_with_name_identifier('foo')[0].name == 'Pierre'
@postgresql
def test_urlname_change():
global formef

43
tests/test_users.py Normal file
View File

@ -0,0 +1,43 @@
import datetime
import os
import random
import shutil
import sys
import tempfile
import time
from quixote import cleanup
from wcs import publisher
def setup_module(module):
cleanup()
global pub
publisher.WcsPublisher.APP_DIR = tempfile.mkdtemp()
pub = publisher.WcsPublisher.create_publisher()
def teardown_module(module):
shutil.rmtree(pub.APP_DIR)
def test_get_users_with_name_identifier():
pub.user_class.wipe()
user = pub.user_class()
user.name = 'Pierre'
user.name_identifiers = ['foo']
user.store()
user_id = user.id
user = pub.user_class()
user.name = 'Papier'
user.store()
assert len(pub.user_class.get_users_with_name_identifier('foo')) == 1
assert pub.user_class.get_users_with_name_identifier('foo')[0].name == 'Pierre'

View File

@ -428,7 +428,7 @@ class Saml2Directory(Directory):
session.name_identifier = ni
else:
ni = name_id
nis = list(get_publisher().user_class.select(lambda x: ni in x.name_identifiers))
nis = list(get_publisher().user_class.get_users_with_name_identifier(ni))
if nis:
user = nis[0]
else:
@ -706,7 +706,7 @@ class Saml2Directory(Directory):
session = get_session()
user = None
ni = manage.nameIdentifier.content
nis = list(get_publisher().user_class.select(lambda x: ni in x.name_identifiers))
nis = list(get_publisher().user_class.get_users_with_name_identifier(ni))
if nis:
user = nis[0]
nis = nis[1:]

View File

@ -992,6 +992,24 @@ class SqlUser(SqlMixin, wcs.users.User):
cur.close()
fix_sequences = classmethod(fix_sequences)
@guard_postgres
def get_users_with_name_identifier(cls, name_identifier):
conn, cur = get_connection_and_cursor()
sql_statement = '''SELECT %s
FROM %s
WHERE %%(value)s = ANY(name_identifiers)''' % (
', '.join([x[0] for x in cls._table_static_fields]
+ cls.get_data_fields()),
cls._table_name)
cur.execute(sql_statement, {'value': name_identifier})
objects = cls.get_objects(cur)
conn.commit()
cur.close()
return objects
get_users_with_name_identifier = classmethod(get_users_with_name_identifier)
@guard_postgres
def get_users_with_role(cls, role_id):
conn, cur = get_connection_and_cursor()

View File

@ -137,6 +137,10 @@ class User(StorableObject):
return users_with_role
get_users_with_role = classmethod(get_users_with_role)
def get_users_with_name_identifier(cls, name_identifier):
return cls.select(lambda x: name_identifier in x.name_identifiers)
get_users_with_name_identifier = classmethod(get_users_with_name_identifier)
def get_substitution_variables(self, prefix='session_'):
d = {
prefix+'user': self,