drop and rename issuer field (#56819)

This commit is contained in:
Benjamin Dauvergne 2021-09-14 21:34:50 +02:00
parent a851b5b2ca
commit 73bfa476ef
8 changed files with 121 additions and 24 deletions

View File

@ -36,7 +36,7 @@ from django.utils.encoding import force_text
from django.utils.six.moves.urllib.parse import urlparse
from django.utils.translation import ugettext as _
from . import app_settings, models, utils
from . import app_settings, models, models_utils, utils
User = auth.get_user_model()
@ -325,14 +325,14 @@ class DefaultAdapter:
return None
else:
name_id = saml_attributes['name_id_content']
issuer = saml_attributes['issuer']
entity_id = saml_attributes['issuer']
try:
saml_identifier = models.UserSAMLIdentifier.objects.select_related('user').get(
name_id=name_id, issuer=issuer
name_id=name_id, issuer=models_utils.get_issuer(entity_id)
)
user = saml_identifier.user
user.saml_identifier = saml_identifier
logger.info('mellon: looked up user %s with name_id %s from issuer %s', user, name_id, issuer)
logger.info('mellon: looked up user %s with name_id %s from issuer %s', user, name_id, entity_id)
return user
except models.UserSAMLIdentifier.DoesNotExist:
pass
@ -347,10 +347,10 @@ class DefaultAdapter:
created = True
user = self.create_user(User)
nameid_user = self._link_user(idp, saml_attributes, issuer, name_id, user)
nameid_user = self._link_user(idp, saml_attributes, entity_id, name_id, user)
if user != nameid_user:
logger.info(
'mellon: looked up user %s with name_id %s from issuer %s', nameid_user, name_id, issuer
'mellon: looked up user %s with name_id %s from issuer %s', nameid_user, name_id, entity_id
)
if created:
user.delete()
@ -363,7 +363,7 @@ class DefaultAdapter:
user.delete()
return None
logger.info(
'mellon: created new user %s with name_id %s from issuer %s', nameid_user, name_id, issuer
'mellon: created new user %s with name_id %s from issuer %s', nameid_user, name_id, entity_id
)
return nameid_user
@ -455,9 +455,9 @@ class DefaultAdapter:
)
return None
def _link_user(self, idp, saml_attributes, issuer, name_id, user):
def _link_user(self, idp, saml_attributes, entity_id, name_id, user):
saml_id, created = models.UserSAMLIdentifier.objects.get_or_create(
name_id=name_id, issuer=issuer, defaults={'user': user}
name_id=name_id, issuer=models_utils.get_issuer(entity_id), defaults={'user': user}
)
if created:
user.saml_identifier = saml_id

View File

@ -0,0 +1,30 @@
# Generated by Django 2.2.19 on 2021-09-14 19:31
from django.db import migrations
class Migration(migrations.Migration):
dependencies = [
('mellon', '0004_migrate_issuer'),
]
operations = [
migrations.AlterUniqueTogether(
name='usersamlidentifier',
unique_together={('issuer_fk', 'name_id')},
),
migrations.RemoveField(
model_name='usersamlidentifier',
name='issuer',
),
migrations.RenameField(
model_name='usersamlidentifier',
old_name='issuer_fk',
new_name='issuer',
),
migrations.AlterUniqueTogether(
name='usersamlidentifier',
unique_together={('issuer', 'name_id')},
),
]

View File

@ -28,12 +28,9 @@ class UserSAMLIdentifier(models.Model):
related_name='saml_identifiers',
on_delete=models.CASCADE,
)
issuer = models.TextField(verbose_name=_('Issuer'), null=True)
name_id = models.TextField(verbose_name=_('SAML identifier'))
created = models.DateTimeField(verbose_name=_('created'), auto_now_add=True)
issuer_fk = models.ForeignKey(
'mellon.Issuer', verbose_name=_('Issuer'), null=True, on_delete=models.CASCADE
)
issuer = models.ForeignKey('mellon.Issuer', verbose_name=_('Issuer'), null=True, on_delete=models.CASCADE)
class Meta:
verbose_name = _('user SAML identifier')

30
mellon/models_utils.py Normal file
View File

@ -0,0 +1,30 @@
# django-mellon - SAML2 authentication for Django
# Copyright (C) 2014-2019 Entr'ouvert
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as
# published by the Free Software Foundation, either version 3 of the
# License, or (at your option) any later version.
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
from . import models, utils
def get_issuer(entity_id):
idp = utils.get_idp(entity_id)
slug = idp.get('SLUG')
if slug:
issuer = models.Issuer.objects.filter(slug=slug).first()
# migrate issuer entity_id based on the slug
if issuer and issuer.entity_id != entity_id:
issuer.entity_id = entity_id
issuer.save()
if not slug or not issuer:
issuer, created = models.Issuer.objects.update_or_create(entity_id=entity_id, defaults={'slug': slug})
return issuer

View File

@ -212,7 +212,7 @@ def make_session_dump(lasso_name_id, indexes):
name_qualifier = lasso_name_id.nameQualifier and force_text(lasso_name_id.nameQualifier)
sp_name_qualifier = lasso_name_id.spNameQualifier and force_text(lasso_name_id.spNameQualifier)
for index in indexes:
issuer = index.saml_identifier.issuer
issuer = index.saml_identifier.issuer.entity_id
session_infos.append(
{
'entity_id': issuer,

View File

@ -32,7 +32,6 @@ from django.db import transaction
from django.http import Http404, HttpResponse, HttpResponseForbidden, HttpResponseRedirect
from django.shortcuts import render, resolve_url
from django.urls import reverse
from django.utils import six
from django.utils.encoding import force_str, force_text
from django.utils.http import urlencode
from django.utils.translation import ugettext as _
@ -41,7 +40,7 @@ from django.views.generic import View
from django.views.generic.base import RedirectView
from requests.exceptions import RequestException
from . import app_settings, models, utils
from . import app_settings, models, models_utils, utils
RETRY_LOGIN_COOKIE = 'MELLON_RETRY_LOGIN'
@ -244,7 +243,7 @@ class LoginView(ProfileMixin, LogMixin, View):
content = self.get_attribute_value(at, attribute_value)
if content is not None:
values.append(content)
attributes['issuer'] = login.remoteProviderId
entity_id = attributes['issuer'] = login.remoteProviderId
in_response_to = login.response.inResponseTo
if in_response_to:
attributes['nonce'] = request.session.get('mellon-nonce-%s' % in_response_to)
@ -280,7 +279,9 @@ class LoginView(ProfileMixin, LogMixin, View):
return response
def authenticate(self, request, login, attributes):
user = auth.authenticate(request=request, saml_attributes=attributes)
user = auth.authenticate(
request=request, issuer=models_utils.get_issuer(attributes['issuer']), saml_attributes=attributes
)
next_url = self.get_next_url(default=resolve_url(settings.LOGIN_REDIRECT_URL))
if user is not None:
if user.is_active:
@ -598,7 +599,7 @@ class LogoutView(ProfileMixin, LogMixin, View):
def post(self, request, *args, **kwargs):
return self.idp_logout(request, force_str(request.body), 'soap')
def logout(self, request, issuer, saml_user, session_indexes, indexes, mode):
def logout(self, request, saml_user, session_indexes, indexes, mode):
session_keys = set(indexes.values_list('session_key', flat=True))
indexes.delete()
@ -647,14 +648,15 @@ class LogoutView(ProfileMixin, LogMixin, View):
except lasso.Error as e:
return HttpResponseBadRequest('error processing logout request: %r' % e)
issuer = force_text(logout.remoteProviderId)
entity_id = force_text(logout.remoteProviderId)
session_indexes = {force_text(sessionIndex) for sessionIndex in logout.request.sessionIndexes}
saml_identifier = (
models.UserSAMLIdentifier.objects.filter(
name_id=force_text(logout.nameIdentifier.content), issuer=issuer
name_id=force_text(logout.nameIdentifier.content),
issuer=models_utils.get_issuer(entity_id),
)
.select_related('user')
.select_related('user', 'issuer')
.first()
)
@ -680,7 +682,6 @@ class LogoutView(ProfileMixin, LogMixin, View):
self.log.info('full logout requested, no sessionIndexes')
self.logout(
request,
issuer=issuer,
saml_user=name_id_user,
session_indexes=session_indexes,
indexes=indexes,

View File

@ -92,10 +92,12 @@ def test_lookup_user(settings, idp, saml_attributes):
assert User.objects.count() == 0
def test_lookup_user_transaction(transactional_db, concurrency, idp, saml_attributes):
def test_lookup_user_transaction(transactional_db, concurrency, idp, saml_attributes, settings):
adapter = DefaultAdapter()
p = ThreadPool(concurrency)
settings.MELLON_IDENTITY_PROVIDERS = [idp]
if connection.vendor == 'postgresql':
with connection.cursor() as c:
c.execute('SHOW max_connections')

View File

@ -20,6 +20,8 @@ from unittest import mock
import lasso
from xml_utils import assert_xml_constraints
from mellon.models import Issuer
from mellon.models_utils import get_issuer
from mellon.utils import create_metadata, flatten_datetime, iso8601_to_datetime
from mellon.views import check_next_url
@ -199,3 +201,38 @@ def test_check_next_url(rf):
assert not check_next_url(rf.get('/'), 'https://example.invalid/')
# default hostname is testserver
assert check_next_url(rf.get('/'), 'http://testserver/ok/')
def test_get_issuer_entity_id_migration(db, settings, metadata):
entity_id1 = 'http://idp5/metadata'
entity_id2 = 'http://idp6/metadata'
settings.MELLON_IDENTITY_PROVIDERS = [
{
'METADATA': metadata,
},
]
issuer1 = get_issuer(entity_id1)
assert issuer1.entity_id == entity_id1
assert issuer1.slug is None
settings.MELLON_IDENTITY_PROVIDERS = [
{
'METADATA': metadata,
'SLUG': 'idp',
},
]
issuer2 = get_issuer(entity_id1)
assert issuer2.id == issuer1.id
assert issuer2.entity_id == entity_id1
assert issuer2.slug == 'idp'
settings.MELLON_IDENTITY_PROVIDERS = [
{
'METADATA': metadata.replace(entity_id1, entity_id2),
'SLUG': 'idp',
},
]
issuer3 = get_issuer(entity_id2)
assert issuer3.id == issuer1.id
assert issuer3.entity_id == entity_id2
assert issuer3.slug == 'idp'