drop and rename issuer field (#56819)
This commit is contained in:
parent
a851b5b2ca
commit
73bfa476ef
|
@ -36,7 +36,7 @@ from django.utils.encoding import force_text
|
||||||
from django.utils.six.moves.urllib.parse import urlparse
|
from django.utils.six.moves.urllib.parse import urlparse
|
||||||
from django.utils.translation import ugettext as _
|
from django.utils.translation import ugettext as _
|
||||||
|
|
||||||
from . import app_settings, models, utils
|
from . import app_settings, models, models_utils, utils
|
||||||
|
|
||||||
User = auth.get_user_model()
|
User = auth.get_user_model()
|
||||||
|
|
||||||
|
@ -325,14 +325,14 @@ class DefaultAdapter:
|
||||||
return None
|
return None
|
||||||
else:
|
else:
|
||||||
name_id = saml_attributes['name_id_content']
|
name_id = saml_attributes['name_id_content']
|
||||||
issuer = saml_attributes['issuer']
|
entity_id = saml_attributes['issuer']
|
||||||
try:
|
try:
|
||||||
saml_identifier = models.UserSAMLIdentifier.objects.select_related('user').get(
|
saml_identifier = models.UserSAMLIdentifier.objects.select_related('user').get(
|
||||||
name_id=name_id, issuer=issuer
|
name_id=name_id, issuer=models_utils.get_issuer(entity_id)
|
||||||
)
|
)
|
||||||
user = saml_identifier.user
|
user = saml_identifier.user
|
||||||
user.saml_identifier = saml_identifier
|
user.saml_identifier = saml_identifier
|
||||||
logger.info('mellon: looked up user %s with name_id %s from issuer %s', user, name_id, issuer)
|
logger.info('mellon: looked up user %s with name_id %s from issuer %s', user, name_id, entity_id)
|
||||||
return user
|
return user
|
||||||
except models.UserSAMLIdentifier.DoesNotExist:
|
except models.UserSAMLIdentifier.DoesNotExist:
|
||||||
pass
|
pass
|
||||||
|
@ -347,10 +347,10 @@ class DefaultAdapter:
|
||||||
created = True
|
created = True
|
||||||
user = self.create_user(User)
|
user = self.create_user(User)
|
||||||
|
|
||||||
nameid_user = self._link_user(idp, saml_attributes, issuer, name_id, user)
|
nameid_user = self._link_user(idp, saml_attributes, entity_id, name_id, user)
|
||||||
if user != nameid_user:
|
if user != nameid_user:
|
||||||
logger.info(
|
logger.info(
|
||||||
'mellon: looked up user %s with name_id %s from issuer %s', nameid_user, name_id, issuer
|
'mellon: looked up user %s with name_id %s from issuer %s', nameid_user, name_id, entity_id
|
||||||
)
|
)
|
||||||
if created:
|
if created:
|
||||||
user.delete()
|
user.delete()
|
||||||
|
@ -363,7 +363,7 @@ class DefaultAdapter:
|
||||||
user.delete()
|
user.delete()
|
||||||
return None
|
return None
|
||||||
logger.info(
|
logger.info(
|
||||||
'mellon: created new user %s with name_id %s from issuer %s', nameid_user, name_id, issuer
|
'mellon: created new user %s with name_id %s from issuer %s', nameid_user, name_id, entity_id
|
||||||
)
|
)
|
||||||
return nameid_user
|
return nameid_user
|
||||||
|
|
||||||
|
@ -455,9 +455,9 @@ class DefaultAdapter:
|
||||||
)
|
)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def _link_user(self, idp, saml_attributes, issuer, name_id, user):
|
def _link_user(self, idp, saml_attributes, entity_id, name_id, user):
|
||||||
saml_id, created = models.UserSAMLIdentifier.objects.get_or_create(
|
saml_id, created = models.UserSAMLIdentifier.objects.get_or_create(
|
||||||
name_id=name_id, issuer=issuer, defaults={'user': user}
|
name_id=name_id, issuer=models_utils.get_issuer(entity_id), defaults={'user': user}
|
||||||
)
|
)
|
||||||
if created:
|
if created:
|
||||||
user.saml_identifier = saml_id
|
user.saml_identifier = saml_id
|
||||||
|
|
|
@ -0,0 +1,30 @@
|
||||||
|
# Generated by Django 2.2.19 on 2021-09-14 19:31
|
||||||
|
|
||||||
|
from django.db import migrations
|
||||||
|
|
||||||
|
|
||||||
|
class Migration(migrations.Migration):
|
||||||
|
|
||||||
|
dependencies = [
|
||||||
|
('mellon', '0004_migrate_issuer'),
|
||||||
|
]
|
||||||
|
|
||||||
|
operations = [
|
||||||
|
migrations.AlterUniqueTogether(
|
||||||
|
name='usersamlidentifier',
|
||||||
|
unique_together={('issuer_fk', 'name_id')},
|
||||||
|
),
|
||||||
|
migrations.RemoveField(
|
||||||
|
model_name='usersamlidentifier',
|
||||||
|
name='issuer',
|
||||||
|
),
|
||||||
|
migrations.RenameField(
|
||||||
|
model_name='usersamlidentifier',
|
||||||
|
old_name='issuer_fk',
|
||||||
|
new_name='issuer',
|
||||||
|
),
|
||||||
|
migrations.AlterUniqueTogether(
|
||||||
|
name='usersamlidentifier',
|
||||||
|
unique_together={('issuer', 'name_id')},
|
||||||
|
),
|
||||||
|
]
|
|
@ -28,12 +28,9 @@ class UserSAMLIdentifier(models.Model):
|
||||||
related_name='saml_identifiers',
|
related_name='saml_identifiers',
|
||||||
on_delete=models.CASCADE,
|
on_delete=models.CASCADE,
|
||||||
)
|
)
|
||||||
issuer = models.TextField(verbose_name=_('Issuer'), null=True)
|
|
||||||
name_id = models.TextField(verbose_name=_('SAML identifier'))
|
name_id = models.TextField(verbose_name=_('SAML identifier'))
|
||||||
created = models.DateTimeField(verbose_name=_('created'), auto_now_add=True)
|
created = models.DateTimeField(verbose_name=_('created'), auto_now_add=True)
|
||||||
issuer_fk = models.ForeignKey(
|
issuer = models.ForeignKey('mellon.Issuer', verbose_name=_('Issuer'), null=True, on_delete=models.CASCADE)
|
||||||
'mellon.Issuer', verbose_name=_('Issuer'), null=True, on_delete=models.CASCADE
|
|
||||||
)
|
|
||||||
|
|
||||||
class Meta:
|
class Meta:
|
||||||
verbose_name = _('user SAML identifier')
|
verbose_name = _('user SAML identifier')
|
||||||
|
|
|
@ -0,0 +1,30 @@
|
||||||
|
# django-mellon - SAML2 authentication for Django
|
||||||
|
# Copyright (C) 2014-2019 Entr'ouvert
|
||||||
|
# This program is free software: you can redistribute it and/or modify
|
||||||
|
# it under the terms of the GNU Affero General Public License as
|
||||||
|
# published by the Free Software Foundation, either version 3 of the
|
||||||
|
# License, or (at your option) any later version.
|
||||||
|
|
||||||
|
# This program is distributed in the hope that it will be useful,
|
||||||
|
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||||
|
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||||
|
# GNU Affero General Public License for more details.
|
||||||
|
|
||||||
|
# You should have received a copy of the GNU Affero General Public License
|
||||||
|
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||||
|
|
||||||
|
from . import models, utils
|
||||||
|
|
||||||
|
|
||||||
|
def get_issuer(entity_id):
|
||||||
|
idp = utils.get_idp(entity_id)
|
||||||
|
slug = idp.get('SLUG')
|
||||||
|
if slug:
|
||||||
|
issuer = models.Issuer.objects.filter(slug=slug).first()
|
||||||
|
# migrate issuer entity_id based on the slug
|
||||||
|
if issuer and issuer.entity_id != entity_id:
|
||||||
|
issuer.entity_id = entity_id
|
||||||
|
issuer.save()
|
||||||
|
if not slug or not issuer:
|
||||||
|
issuer, created = models.Issuer.objects.update_or_create(entity_id=entity_id, defaults={'slug': slug})
|
||||||
|
return issuer
|
|
@ -212,7 +212,7 @@ def make_session_dump(lasso_name_id, indexes):
|
||||||
name_qualifier = lasso_name_id.nameQualifier and force_text(lasso_name_id.nameQualifier)
|
name_qualifier = lasso_name_id.nameQualifier and force_text(lasso_name_id.nameQualifier)
|
||||||
sp_name_qualifier = lasso_name_id.spNameQualifier and force_text(lasso_name_id.spNameQualifier)
|
sp_name_qualifier = lasso_name_id.spNameQualifier and force_text(lasso_name_id.spNameQualifier)
|
||||||
for index in indexes:
|
for index in indexes:
|
||||||
issuer = index.saml_identifier.issuer
|
issuer = index.saml_identifier.issuer.entity_id
|
||||||
session_infos.append(
|
session_infos.append(
|
||||||
{
|
{
|
||||||
'entity_id': issuer,
|
'entity_id': issuer,
|
||||||
|
|
|
@ -32,7 +32,6 @@ from django.db import transaction
|
||||||
from django.http import Http404, HttpResponse, HttpResponseForbidden, HttpResponseRedirect
|
from django.http import Http404, HttpResponse, HttpResponseForbidden, HttpResponseRedirect
|
||||||
from django.shortcuts import render, resolve_url
|
from django.shortcuts import render, resolve_url
|
||||||
from django.urls import reverse
|
from django.urls import reverse
|
||||||
from django.utils import six
|
|
||||||
from django.utils.encoding import force_str, force_text
|
from django.utils.encoding import force_str, force_text
|
||||||
from django.utils.http import urlencode
|
from django.utils.http import urlencode
|
||||||
from django.utils.translation import ugettext as _
|
from django.utils.translation import ugettext as _
|
||||||
|
@ -41,7 +40,7 @@ from django.views.generic import View
|
||||||
from django.views.generic.base import RedirectView
|
from django.views.generic.base import RedirectView
|
||||||
from requests.exceptions import RequestException
|
from requests.exceptions import RequestException
|
||||||
|
|
||||||
from . import app_settings, models, utils
|
from . import app_settings, models, models_utils, utils
|
||||||
|
|
||||||
RETRY_LOGIN_COOKIE = 'MELLON_RETRY_LOGIN'
|
RETRY_LOGIN_COOKIE = 'MELLON_RETRY_LOGIN'
|
||||||
|
|
||||||
|
@ -244,7 +243,7 @@ class LoginView(ProfileMixin, LogMixin, View):
|
||||||
content = self.get_attribute_value(at, attribute_value)
|
content = self.get_attribute_value(at, attribute_value)
|
||||||
if content is not None:
|
if content is not None:
|
||||||
values.append(content)
|
values.append(content)
|
||||||
attributes['issuer'] = login.remoteProviderId
|
entity_id = attributes['issuer'] = login.remoteProviderId
|
||||||
in_response_to = login.response.inResponseTo
|
in_response_to = login.response.inResponseTo
|
||||||
if in_response_to:
|
if in_response_to:
|
||||||
attributes['nonce'] = request.session.get('mellon-nonce-%s' % in_response_to)
|
attributes['nonce'] = request.session.get('mellon-nonce-%s' % in_response_to)
|
||||||
|
@ -280,7 +279,9 @@ class LoginView(ProfileMixin, LogMixin, View):
|
||||||
return response
|
return response
|
||||||
|
|
||||||
def authenticate(self, request, login, attributes):
|
def authenticate(self, request, login, attributes):
|
||||||
user = auth.authenticate(request=request, saml_attributes=attributes)
|
user = auth.authenticate(
|
||||||
|
request=request, issuer=models_utils.get_issuer(attributes['issuer']), saml_attributes=attributes
|
||||||
|
)
|
||||||
next_url = self.get_next_url(default=resolve_url(settings.LOGIN_REDIRECT_URL))
|
next_url = self.get_next_url(default=resolve_url(settings.LOGIN_REDIRECT_URL))
|
||||||
if user is not None:
|
if user is not None:
|
||||||
if user.is_active:
|
if user.is_active:
|
||||||
|
@ -598,7 +599,7 @@ class LogoutView(ProfileMixin, LogMixin, View):
|
||||||
def post(self, request, *args, **kwargs):
|
def post(self, request, *args, **kwargs):
|
||||||
return self.idp_logout(request, force_str(request.body), 'soap')
|
return self.idp_logout(request, force_str(request.body), 'soap')
|
||||||
|
|
||||||
def logout(self, request, issuer, saml_user, session_indexes, indexes, mode):
|
def logout(self, request, saml_user, session_indexes, indexes, mode):
|
||||||
session_keys = set(indexes.values_list('session_key', flat=True))
|
session_keys = set(indexes.values_list('session_key', flat=True))
|
||||||
indexes.delete()
|
indexes.delete()
|
||||||
|
|
||||||
|
@ -647,14 +648,15 @@ class LogoutView(ProfileMixin, LogMixin, View):
|
||||||
except lasso.Error as e:
|
except lasso.Error as e:
|
||||||
return HttpResponseBadRequest('error processing logout request: %r' % e)
|
return HttpResponseBadRequest('error processing logout request: %r' % e)
|
||||||
|
|
||||||
issuer = force_text(logout.remoteProviderId)
|
entity_id = force_text(logout.remoteProviderId)
|
||||||
session_indexes = {force_text(sessionIndex) for sessionIndex in logout.request.sessionIndexes}
|
session_indexes = {force_text(sessionIndex) for sessionIndex in logout.request.sessionIndexes}
|
||||||
|
|
||||||
saml_identifier = (
|
saml_identifier = (
|
||||||
models.UserSAMLIdentifier.objects.filter(
|
models.UserSAMLIdentifier.objects.filter(
|
||||||
name_id=force_text(logout.nameIdentifier.content), issuer=issuer
|
name_id=force_text(logout.nameIdentifier.content),
|
||||||
|
issuer=models_utils.get_issuer(entity_id),
|
||||||
)
|
)
|
||||||
.select_related('user')
|
.select_related('user', 'issuer')
|
||||||
.first()
|
.first()
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -680,7 +682,6 @@ class LogoutView(ProfileMixin, LogMixin, View):
|
||||||
self.log.info('full logout requested, no sessionIndexes')
|
self.log.info('full logout requested, no sessionIndexes')
|
||||||
self.logout(
|
self.logout(
|
||||||
request,
|
request,
|
||||||
issuer=issuer,
|
|
||||||
saml_user=name_id_user,
|
saml_user=name_id_user,
|
||||||
session_indexes=session_indexes,
|
session_indexes=session_indexes,
|
||||||
indexes=indexes,
|
indexes=indexes,
|
||||||
|
|
|
@ -92,10 +92,12 @@ def test_lookup_user(settings, idp, saml_attributes):
|
||||||
assert User.objects.count() == 0
|
assert User.objects.count() == 0
|
||||||
|
|
||||||
|
|
||||||
def test_lookup_user_transaction(transactional_db, concurrency, idp, saml_attributes):
|
def test_lookup_user_transaction(transactional_db, concurrency, idp, saml_attributes, settings):
|
||||||
adapter = DefaultAdapter()
|
adapter = DefaultAdapter()
|
||||||
p = ThreadPool(concurrency)
|
p = ThreadPool(concurrency)
|
||||||
|
|
||||||
|
settings.MELLON_IDENTITY_PROVIDERS = [idp]
|
||||||
|
|
||||||
if connection.vendor == 'postgresql':
|
if connection.vendor == 'postgresql':
|
||||||
with connection.cursor() as c:
|
with connection.cursor() as c:
|
||||||
c.execute('SHOW max_connections')
|
c.execute('SHOW max_connections')
|
||||||
|
|
|
@ -20,6 +20,8 @@ from unittest import mock
|
||||||
import lasso
|
import lasso
|
||||||
from xml_utils import assert_xml_constraints
|
from xml_utils import assert_xml_constraints
|
||||||
|
|
||||||
|
from mellon.models import Issuer
|
||||||
|
from mellon.models_utils import get_issuer
|
||||||
from mellon.utils import create_metadata, flatten_datetime, iso8601_to_datetime
|
from mellon.utils import create_metadata, flatten_datetime, iso8601_to_datetime
|
||||||
from mellon.views import check_next_url
|
from mellon.views import check_next_url
|
||||||
|
|
||||||
|
@ -199,3 +201,38 @@ def test_check_next_url(rf):
|
||||||
assert not check_next_url(rf.get('/'), 'https://example.invalid/')
|
assert not check_next_url(rf.get('/'), 'https://example.invalid/')
|
||||||
# default hostname is testserver
|
# default hostname is testserver
|
||||||
assert check_next_url(rf.get('/'), 'http://testserver/ok/')
|
assert check_next_url(rf.get('/'), 'http://testserver/ok/')
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_issuer_entity_id_migration(db, settings, metadata):
|
||||||
|
entity_id1 = 'http://idp5/metadata'
|
||||||
|
entity_id2 = 'http://idp6/metadata'
|
||||||
|
settings.MELLON_IDENTITY_PROVIDERS = [
|
||||||
|
{
|
||||||
|
'METADATA': metadata,
|
||||||
|
},
|
||||||
|
]
|
||||||
|
issuer1 = get_issuer(entity_id1)
|
||||||
|
assert issuer1.entity_id == entity_id1
|
||||||
|
assert issuer1.slug is None
|
||||||
|
|
||||||
|
settings.MELLON_IDENTITY_PROVIDERS = [
|
||||||
|
{
|
||||||
|
'METADATA': metadata,
|
||||||
|
'SLUG': 'idp',
|
||||||
|
},
|
||||||
|
]
|
||||||
|
issuer2 = get_issuer(entity_id1)
|
||||||
|
assert issuer2.id == issuer1.id
|
||||||
|
assert issuer2.entity_id == entity_id1
|
||||||
|
assert issuer2.slug == 'idp'
|
||||||
|
|
||||||
|
settings.MELLON_IDENTITY_PROVIDERS = [
|
||||||
|
{
|
||||||
|
'METADATA': metadata.replace(entity_id1, entity_id2),
|
||||||
|
'SLUG': 'idp',
|
||||||
|
},
|
||||||
|
]
|
||||||
|
issuer3 = get_issuer(entity_id2)
|
||||||
|
assert issuer3.id == issuer1.id
|
||||||
|
assert issuer3.entity_id == entity_id2
|
||||||
|
assert issuer3.slug == 'idp'
|
||||||
|
|
Loading…
Reference in New Issue