From f2908b2ef30d249a496b441235f5e7fad3887cd7 Mon Sep 17 00:00:00 2001 From: Benjamin Dauvergne Date: Thu, 6 Jun 2019 10:59:28 +0200 Subject: [PATCH] adapters: factorize user linking (#33739) --- mellon/adapters.py | 35 +++++++++++++++++++++-------------- 1 file changed, 21 insertions(+), 14 deletions(-) diff --git a/mellon/adapters.py b/mellon/adapters.py index 44f2fd0..a79932b 100644 --- a/mellon/adapters.py +++ b/mellon/adapters.py @@ -14,6 +14,8 @@ from django.utils.encoding import force_text from . import utils, app_settings, models +User = auth.get_user_model() + class UserCreationError(Exception): pass @@ -108,7 +110,6 @@ class DefaultAdapter(object): user.save() def lookup_user(self, idp, saml_attributes): - User = auth.get_user_model() transient_federation_attribute = utils.get_setting(idp, 'TRANSIENT_FEDERATION_ATTRIBUTE') if saml_attributes['name_id_format'] == lasso.SAML2_NAME_IDENTIFIER_FORMAT_TRANSIENT: if (transient_federation_attribute @@ -137,22 +138,29 @@ class DefaultAdapter(object): return None user = self.create_user(User) + nameid_user = self._link_user(idp, saml_attributes, issuer, name_id, user) + if user != nameid_user: + self.logger.info('looked up user %s with name_id %s from issuer %s', + nameid_user, name_id, issuer) + user.delete() + return nameid_user + + try: + self.finish_create_user(idp, saml_attributes, nameid_user) + except UserCreationError: + nameid_user.delete() + return None + self.logger.info('created new user %s with name_id %s from issuer %s', + nameid_user, name_id, issuer) + return nameid_user + + def _link_user(self, idp, saml_attributes, issuer, name_id, user): saml_id, created = models.UserSAMLIdentifier.objects.get_or_create( name_id=name_id, issuer=issuer, defaults={'user': user}) if created: - try: - self.finish_create_user(idp, saml_attributes, user) - except UserCreationError: - user.delete() - return None - self.logger.info('created new user %s with name_id %s from issuer %s', - user, name_id, issuer) + return user else: - user.delete() - user = saml_id.user - self.logger.info('looked up user %s with name_id %s from issuer %s', - user, name_id, issuer) - return user + return saml_id.user def provision(self, user, idp, saml_attributes): self.provision_attribute(user, idp, saml_attributes) @@ -215,7 +223,6 @@ class DefaultAdapter(object): user.save() def provision_groups(self, user, idp, saml_attributes): - User = user.__class__ group_attribute = utils.get_setting(idp, 'GROUP_ATTRIBUTE') create_group = utils.get_setting(idp, 'CREATE_GROUP') if group_attribute in saml_attributes: