fix concurrency error when creating new users (fixes #9965)
UserSAMLIdentifier is retrieved using get_or_create() first, and if is new we proceed with the creation of the new user, otherwise we delete the temporaru user we created use the one attached to the existing UserSAMLIdentifier.
This commit is contained in:
parent
359a2f4be0
commit
e641c6ec96
|
@ -1,4 +1,5 @@
|
|||
import logging
|
||||
import uuid
|
||||
|
||||
from django.core.exceptions import PermissionDenied
|
||||
from django.contrib import auth
|
||||
|
@ -53,16 +54,24 @@ class DefaultAdapter(object):
|
|||
issuer = saml_attributes['issuer']
|
||||
try:
|
||||
return User.objects.get(saml_identifiers__name_id=name_id,
|
||||
saml_identifiers__issuer=issuer)
|
||||
saml_identifiers__issuer=issuer)
|
||||
except User.DoesNotExist:
|
||||
if not utils.get_setting(idp, 'PROVISION'):
|
||||
self.logger.warning('provisionning disabled, login refused')
|
||||
return None
|
||||
username = self.format_username(idp, saml_attributes)
|
||||
if not username:
|
||||
self.logger.warning('could not build a username, login refused')
|
||||
return None
|
||||
user = User(username=username)
|
||||
user.save()
|
||||
self.provision_name_id(user, idp, saml_attributes)
|
||||
user = User.objects.create(username=uuid.uuid4().hex[:30])
|
||||
saml_id, created = models.UserSAMLIdentifier.objects.get_or_create(
|
||||
name_id=name_id, issuer=issuer, defaults={'user': user})
|
||||
if created:
|
||||
user.username = username
|
||||
user.save()
|
||||
else:
|
||||
user.delete()
|
||||
user = saml_id.user
|
||||
return user
|
||||
|
||||
def provision(self, user, idp, saml_attributes):
|
||||
|
@ -70,12 +79,6 @@ class DefaultAdapter(object):
|
|||
self.provision_superuser(user, idp, saml_attributes)
|
||||
self.provision_groups(user, idp, saml_attributes)
|
||||
|
||||
def provision_name_id(self, user, idp, saml_attributes):
|
||||
models.UserSAMLIdentifier.objects.get_or_create(
|
||||
user=user,
|
||||
issuer=saml_attributes['issuer'],
|
||||
name_id=saml_attributes['name_id_content'])
|
||||
|
||||
def provision_attribute(self, user, idp, saml_attributes):
|
||||
realm = utils.get_setting(idp, 'REALM')
|
||||
attribute_mapping = utils.get_setting(idp, 'ATTRIBUTE_MAPPING')
|
||||
|
|
|
@ -0,0 +1,12 @@
|
|||
import pytest
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def concurrency(settings):
|
||||
'''Select a level of concurrency based on the db, sqlite3 is less robust
|
||||
thant postgres due to its transaction lock timeout of 5 seconds.
|
||||
'''
|
||||
if 'sqlite' in settings.DATABASES['default']['ENGINE']:
|
||||
return 20
|
||||
else:
|
||||
return 100
|
|
@ -1,7 +1,9 @@
|
|||
import threading
|
||||
import pytest
|
||||
|
||||
from django.conf import settings
|
||||
from django.contrib import auth
|
||||
from django.db import connection
|
||||
|
||||
from mellon.adapters import DefaultAdapter
|
||||
|
||||
|
@ -45,6 +47,27 @@ def test_lookup_user(settings):
|
|||
assert user is None
|
||||
assert User.objects.count() == 0
|
||||
|
||||
|
||||
def test_lookup_user_transaction(transactional_db, concurrency):
|
||||
adapter = DefaultAdapter()
|
||||
N = 30
|
||||
def map_threads(f, l):
|
||||
threads = []
|
||||
for i in l:
|
||||
threads.append(threading.Thread(target=f, args=(i,)))
|
||||
threads[-1].start()
|
||||
for thread in threads:
|
||||
thread.join()
|
||||
users = []
|
||||
|
||||
def f(i):
|
||||
users.append(adapter.lookup_user(idp, saml_attributes))
|
||||
connection.close()
|
||||
map_threads(f, range(concurrency))
|
||||
assert len(users) == concurrency
|
||||
assert len(set(user.pk for user in users)) == 1
|
||||
|
||||
|
||||
def test_provision(settings):
|
||||
settings.MELLON_GROUP_ATTRIBUTE = 'group'
|
||||
User = auth.get_user_model()
|
||||
|
|
|
@ -4,6 +4,9 @@ DATABASES = {
|
|||
'default': {
|
||||
'ENGINE': 'django.db.backends.sqlite3',
|
||||
'NAME': 'mellon.sqlite3',
|
||||
'TEST': {
|
||||
'NAME': 'mellon-test.sqlite',
|
||||
},
|
||||
}
|
||||
}
|
||||
DEBUG = True
|
||||
|
|
Loading…
Reference in New Issue