adapters: prevent collision in provision_groups() (fixes #9327)

Assiging related m2m fields provokes a bulk insert which is not safe with
respect to concurrent writes, we replace this by use of get_or_create() and
delete() on the through model of the User.groups field.
This commit is contained in:
Benjamin Dauvergne 2015-12-14 16:39:05 +01:00
parent 78762accf7
commit e18dd7c7e5
2 changed files with 26 additions and 7 deletions

View File

@ -6,9 +6,11 @@ from django.contrib.auth.models import Group
from . import utils, app_settings, models
log = logging.getLogger(__name__)
class DefaultAdapter(object):
def __init__(self, *args, **kwargs):
self.logger = logging.getLogger(__name__)
def get_idp(self, entity_id):
'''Find the first IdP definition matching entity_id'''
for idp in app_settings.IDENTITY_PROVIDERS:
@ -36,12 +38,12 @@ class DefaultAdapter(object):
username = unicode(username_template).format(
realm=realm, attributes=saml_attributes, idp=idp)[:30]
except ValueError:
log.error('invalid username template %r'. username_template)
self.logger.error(u'invalid username template %r', username_template)
except (AttributeError, KeyError, IndexError), e:
log.error('invalid reference in username template %r: %s',
self.logger.error(u'invalid reference in username template %r: %s',
username_template, e)
except Exception, e:
log.exception('unknown error when formatting username')
self.logger.exception(u'unknown error when formatting username')
else:
return username
@ -82,9 +84,9 @@ class DefaultAdapter(object):
try:
value = unicode(tpl).format(realm=realm, attributes=saml_attributes, idp=idp)
except ValueError:
log.warning('invalid attribute mapping template %r', tpl)
self.logger.warning(u'invalid attribute mapping template %r', tpl)
except (AttributeError, KeyError, IndexError, ValueError), e:
log.warning('invalid reference in attribute mapping template %r: %s', tpl, e)
self.logger.warning(u'invalid reference in attribute mapping template %r: %s', tpl, e)
else:
attribute_set = True
model_field = user._meta.get_field(field)
@ -118,6 +120,7 @@ class DefaultAdapter(object):
user.save()
def provision_groups(self, user, idp, saml_attributes):
User = user.__class__
group_attribute = utils.get_setting(idp, 'GROUP_ATTRIBUTE')
create_group = utils.get_setting(idp, 'CREATE_GROUP')
if group_attribute in saml_attributes:
@ -134,4 +137,11 @@ class DefaultAdapter(object):
except Group.DoesNotExist:
continue
groups.append(group)
user.groups = groups
for group in Group.objects.filter(pk__in=[g.pk for g in groups]).exclude(user=user):
self.logger.info(u'adding group %s (%s) to user %s (%s)', group, group.pk, user, user.pk)
User.groups.through.objects.get_or_create(group=group, user=user)
qs = User.groups.through.objects.exclude(group__pk__in=[g.pk for g in groups]).filter(user=user)
for rel in qs:
self.logger.info(u'removing group %s (%s) from user %s (%s)', rel.group,
rel.group.pk, rel.user, rel.user.pk)
qs.delete()

View File

@ -16,6 +16,7 @@ saml_attributes = {
'first_name': ['Foo'],
'last_name': ['Bar'],
'is_superuser': ['true'],
'group': ['GroupA', 'GroupB', 'GroupC'],
}
def test_format_username(settings):
@ -45,6 +46,7 @@ def test_lookup_user(settings):
assert User.objects.count() == 0
def test_provision(settings):
settings.MELLON_GROUP_ATTRIBUTE = 'group'
User = auth.get_user_model()
adapter = DefaultAdapter()
settings.MELLON_ATTRIBUTE_MAPPING = {
@ -59,6 +61,13 @@ def test_provision(settings):
assert user.last_name == 'Bar'
assert user.email == 'test@example.net'
assert user.is_superuser == False
assert user.groups.count() == 3
assert set(user.groups.values_list('name', flat=True)) == set(saml_attributes['group'])
saml_attributes2 = saml_attributes.copy()
saml_attributes2['group'] = ['GroupB', 'GroupC']
adapter.provision(user, idp, saml_attributes2)
assert user.groups.count() == 2
assert set(user.groups.values_list('name', flat=True)) == set(saml_attributes2['group'])
User.objects.all().delete()
settings.MELLON_SUPERUSER_MAPPING = {