adapters: factorize user linking (#33739)

This commit is contained in:
Benjamin Dauvergne 2019-06-06 10:59:28 +02:00 committed by Thomas NOEL
parent e0c1f5b43c
commit f2908b2ef3
1 changed files with 21 additions and 14 deletions

View File

@ -14,6 +14,8 @@ from django.utils.encoding import force_text
from . import utils, app_settings, models
User = auth.get_user_model()
class UserCreationError(Exception):
pass
@ -108,7 +110,6 @@ class DefaultAdapter(object):
user.save()
def lookup_user(self, idp, saml_attributes):
User = auth.get_user_model()
transient_federation_attribute = utils.get_setting(idp, 'TRANSIENT_FEDERATION_ATTRIBUTE')
if saml_attributes['name_id_format'] == lasso.SAML2_NAME_IDENTIFIER_FORMAT_TRANSIENT:
if (transient_federation_attribute
@ -137,22 +138,29 @@ class DefaultAdapter(object):
return None
user = self.create_user(User)
nameid_user = self._link_user(idp, saml_attributes, issuer, name_id, user)
if user != nameid_user:
self.logger.info('looked up user %s with name_id %s from issuer %s',
nameid_user, name_id, issuer)
user.delete()
return nameid_user
try:
self.finish_create_user(idp, saml_attributes, nameid_user)
except UserCreationError:
nameid_user.delete()
return None
self.logger.info('created new user %s with name_id %s from issuer %s',
nameid_user, name_id, issuer)
return nameid_user
def _link_user(self, idp, saml_attributes, issuer, name_id, user):
saml_id, created = models.UserSAMLIdentifier.objects.get_or_create(
name_id=name_id, issuer=issuer, defaults={'user': user})
if created:
try:
self.finish_create_user(idp, saml_attributes, user)
except UserCreationError:
user.delete()
return None
self.logger.info('created new user %s with name_id %s from issuer %s',
user, name_id, issuer)
return user
else:
user.delete()
user = saml_id.user
self.logger.info('looked up user %s with name_id %s from issuer %s',
user, name_id, issuer)
return user
return saml_id.user
def provision(self, user, idp, saml_attributes):
self.provision_attribute(user, idp, saml_attributes)
@ -215,7 +223,6 @@ class DefaultAdapter(object):
user.save()
def provision_groups(self, user, idp, saml_attributes):
User = user.__class__
group_attribute = utils.get_setting(idp, 'GROUP_ATTRIBUTE')
create_group = utils.get_setting(idp, 'CREATE_GROUP')
if group_attribute in saml_attributes: