diff --git a/mellon/adapters.py b/mellon/adapters.py index 0f5b98e..44e9193 100644 --- a/mellon/adapters.py +++ b/mellon/adapters.py @@ -331,9 +331,20 @@ class DefaultAdapter: name_id = saml_attributes['name_id_content'] entity_id = saml_attributes['issuer'] try: + to_update = { + 'nid_format': saml_attributes['name_id_format'], + 'nid_name_qualifier': saml_attributes.get('name_id_name_qualifier'), + 'nid_sp_name_qualifier': saml_attributes.get('name_id_sp_name_qualifier'), + 'nid_sp_provided_id': saml_attributes.get('name_id_sp_provided_id'), + } saml_identifier = models.UserSAMLIdentifier.objects.select_related('user').get( name_id=name_id, issuer=models_utils.get_issuer(entity_id) ) + # nid_* attributes are new, we must update them if they are not initialized, eventually + for key in to_update: + if getattr(saml_identifier, key) != to_update[key]: + models.UserSAMLIdentifier.objects.filter(pk=saml_identifier.pk).update(**to_update) + break user = saml_identifier.user user.saml_identifier = saml_identifier logger.info('mellon: looked up user %s with name_id %s from issuer %s', user, name_id, entity_id) @@ -463,17 +474,25 @@ class DefaultAdapter: name_id_content = saml_attributes['name_id_content'] if saml_attributes['name_id_format'] == lasso.SAML2_NAME_IDENTIFIER_FORMAT_TRANSIENT: name_id_content = saml_attributes['transient_name_id_content'] + to_update = { + 'nid_format': saml_attributes['name_id_format'], + 'nid_name_qualifier': saml_attributes.get('name_id_name_qualifier'), + 'nid_sp_name_qualifier': saml_attributes.get('name_id_sp_name_qualifier'), + 'nid_sp_provided_id': saml_attributes.get('name_id_sp_provided_id'), + } saml_id, created = models.UserSAMLIdentifier.objects.get_or_create( name_id=name_id_content, issuer=models_utils.get_issuer(saml_attributes['issuer']), defaults={ 'user': user, - 'nid_format': saml_attributes['name_id_format'], - 'nid_name_qualifier': saml_attributes.get('name_id_name_qualifier'), - 'nid_sp_name_qualifier': saml_attributes.get('name_id_sp_name_qualifier'), - 'nid_sp_provided_id': saml_attributes.get('name_id_sp_provided_id'), + **to_update, }, ) + # nid_* attributes are new, we must update them eventually + for key in to_update: + if getattr(saml_id, key) != to_update[key]: + models.UserSAMLIdentifier.objects.filter(pk=saml_id.pk).update(**to_update) + break if created: user.saml_identifier = saml_id return user diff --git a/tests/test_sso_slo.py b/tests/test_sso_slo.py index 3fedc93..b2b87b4 100644 --- a/tests/test_sso_slo.py +++ b/tests/test_sso_slo.py @@ -894,3 +894,31 @@ def test_sso_slo_token(db, app, rf, idp, caplog, django_user_model, freezer): assert len(caplog.records) == 0, 'logout failed' assert response.location == '/somepath/' assert models.SessionIndex.objects.count() == 0 + + +def test_sso_slo_update_of_new_fields(db, app, idp, caplog, sp_settings): + response = app.get('/login/') + url, body, relay_state = idp.process_authn_request_redirect(response['Location']) + response = app.post( + reverse('mellon_login'), params={'SAMLResponse': body, 'RelayState': relay_state} + ).follow() + # violent logout + app.session.flush() + + # remove existing fields + models.UserSAMLIdentifier.objects.all().update( + nid_format=None, nid_name_qualifier=None, nid_sp_name_qualifier=None, nid_sp_provided_id=None + ) + + response = app.get('/login/') + url, body, relay_state = idp.process_authn_request_redirect(response['Location']) + response = app.post( + reverse('mellon_login'), params={'SAMLResponse': body, 'RelayState': relay_state} + ).follow() + + # check logout works + response = app.get('/logout/') + url = idp.process_logout_request_redirect(response.location) + caplog.clear() + response = app.get(url) + assert len(caplog.records) == 0, 'logout failed'