diff --git a/mellon/adapters.py b/mellon/adapters.py index f23fe68..8f7018d 100644 --- a/mellon/adapters.py +++ b/mellon/adapters.py @@ -6,9 +6,11 @@ from django.contrib.auth.models import Group from . import utils, app_settings, models -log = logging.getLogger(__name__) class DefaultAdapter(object): + def __init__(self, *args, **kwargs): + self.logger = logging.getLogger(__name__) + def get_idp(self, entity_id): '''Find the first IdP definition matching entity_id''' for idp in app_settings.IDENTITY_PROVIDERS: @@ -36,12 +38,12 @@ class DefaultAdapter(object): username = unicode(username_template).format( realm=realm, attributes=saml_attributes, idp=idp)[:30] except ValueError: - log.error('invalid username template %r'. username_template) + self.logger.error(u'invalid username template %r', username_template) except (AttributeError, KeyError, IndexError), e: - log.error('invalid reference in username template %r: %s', + self.logger.error(u'invalid reference in username template %r: %s', username_template, e) except Exception, e: - log.exception('unknown error when formatting username') + self.logger.exception(u'unknown error when formatting username') else: return username @@ -82,9 +84,9 @@ class DefaultAdapter(object): try: value = unicode(tpl).format(realm=realm, attributes=saml_attributes, idp=idp) except ValueError: - log.warning('invalid attribute mapping template %r', tpl) + self.logger.warning(u'invalid attribute mapping template %r', tpl) except (AttributeError, KeyError, IndexError, ValueError), e: - log.warning('invalid reference in attribute mapping template %r: %s', tpl, e) + self.logger.warning(u'invalid reference in attribute mapping template %r: %s', tpl, e) else: attribute_set = True model_field = user._meta.get_field(field) @@ -118,6 +120,7 @@ class DefaultAdapter(object): user.save() def provision_groups(self, user, idp, saml_attributes): + User = user.__class__ group_attribute = utils.get_setting(idp, 'GROUP_ATTRIBUTE') create_group = utils.get_setting(idp, 'CREATE_GROUP') if group_attribute in saml_attributes: @@ -134,4 +137,11 @@ class DefaultAdapter(object): except Group.DoesNotExist: continue groups.append(group) - user.groups = groups + for group in Group.objects.filter(pk__in=[g.pk for g in groups]).exclude(user=user): + self.logger.info(u'adding group %s (%s) to user %s (%s)', group, group.pk, user, user.pk) + User.groups.through.objects.get_or_create(group=group, user=user) + qs = User.groups.through.objects.exclude(group__pk__in=[g.pk for g in groups]).filter(user=user) + for rel in qs: + self.logger.info(u'removing group %s (%s) from user %s (%s)', rel.group, + rel.group.pk, rel.user, rel.user.pk) + qs.delete() diff --git a/tests/test_default_adapter.py b/tests/test_default_adapter.py index 5fc9b31..d84f8f0 100644 --- a/tests/test_default_adapter.py +++ b/tests/test_default_adapter.py @@ -16,6 +16,7 @@ saml_attributes = { 'first_name': ['Foo'], 'last_name': ['Bar'], 'is_superuser': ['true'], + 'group': ['GroupA', 'GroupB', 'GroupC'], } def test_format_username(settings): @@ -45,6 +46,7 @@ def test_lookup_user(settings): assert User.objects.count() == 0 def test_provision(settings): + settings.MELLON_GROUP_ATTRIBUTE = 'group' User = auth.get_user_model() adapter = DefaultAdapter() settings.MELLON_ATTRIBUTE_MAPPING = { @@ -59,6 +61,13 @@ def test_provision(settings): assert user.last_name == 'Bar' assert user.email == 'test@example.net' assert user.is_superuser == False + assert user.groups.count() == 3 + assert set(user.groups.values_list('name', flat=True)) == set(saml_attributes['group']) + saml_attributes2 = saml_attributes.copy() + saml_attributes2['group'] = ['GroupB', 'GroupC'] + adapter.provision(user, idp, saml_attributes2) + assert user.groups.count() == 2 + assert set(user.groups.values_list('name', flat=True)) == set(saml_attributes2['group']) User.objects.all().delete() settings.MELLON_SUPERUSER_MAPPING = {