auth_oidc: factorize claim mapping resolution (#72418)

This commit is contained in:
Paul Marillonnet 2023-01-31 14:13:56 +01:00
parent ee7274b887
commit 4c03402f17
3 changed files with 43 additions and 38 deletions

View File

@ -34,7 +34,6 @@ from authentic2.a2_rbac.models import OrganizationalUnit
from authentic2.models import Lock
from authentic2.utils import hooks
from authentic2.utils.crypto import base64url_encode
from authentic2.utils.template import Template
from . import models, utils
@ -152,7 +151,6 @@ class OIDCBackend(ModelBackend):
user_ou = provider.ou
user_info = None
save_user = False
mappings = []
context = id_token_content.copy()
need_user_info = False
for claim_mapping in provider.claim_mappings.all():
@ -182,30 +180,11 @@ class OIDCBackend(ModelBackend):
logger.debug('auth_oidc: user_info content %s', user_info)
context.update(user_info or {})
for claim_mapping in provider.claim_mappings.all():
claim = claim_mapping.claim
if claim_mapping.idtoken_claim:
source = id_token
else:
source = user_info
if not source or claim not in source and not ('{{' in claim or '{%' in claim):
continue
verified = False
attribute = claim_mapping.attribute
if '{{' in claim or '{%' in claim:
template = Template(claim)
value = template.render(context=context)
# xxx missing verification logic for templated claims
else:
value = source.get(claim)
if claim_mapping.verified == models.OIDCClaimMapping.VERIFIED_CLAIM:
verified = bool(source.get(claim + '_verified', False))
mappings = utils.resolve_claim_mappings(provider, context, id_token, user_info)
for attribute, value, dummy in mappings:
if attribute == 'ou__slug' and value in ou_map:
user_ou = ou_map[value]
continue
if claim_mapping.verified == models.OIDCClaimMapping.ALWAYS_VERIFIED:
verified = True
mappings.append((attribute, value, verified))
break
# check for required claims
for claim_mapping in provider.claim_mappings.all():
@ -393,7 +372,7 @@ class OIDCBackend(ModelBackend):
# new style attributes
for attribute, value, verified in mappings:
if attribute in ('username', 'email'):
if attribute in ('username', 'email', 'ou__slug'):
continue
if attribute in ('first_name', 'last_name') and not verified:
continue

View File

@ -36,9 +36,9 @@ from authentic2.apps.authenticators.models import (
BaseAuthenticator,
)
from authentic2.utils.misc import make_url
from authentic2.utils.template import Template, validate_template
from authentic2.utils.template import validate_template
from . import managers
from . import managers, utils
if django.VERSION < (3, 1):
from django.contrib.postgres.fields.jsonb import JSONField # noqa pylint: disable=ungrouped-imports
@ -334,25 +334,21 @@ class OIDCProvider(BaseAuthenticator):
except OIDCAccount.MultipleObjectsReturned:
continue
had_changes = False
for claim in self.claim_mappings.all():
if '{{' in claim.claim or '{%' in claim.claim:
template = Template(claim.claim)
attribute_value = template.render(context=user_dict)
else:
attribute_value = user_dict.get(claim.claim)
mappings = utils.resolve_claim_mappings(self, user_dict)
for attribute, value, dummy in mappings:
try:
old_attribute_value = getattr(account.user, claim.attribute)
old_attribute_value = getattr(account.user, attribute)
except AttributeError:
try:
old_attribute_value = getattr(account.user.attributes, claim.attribute)
old_attribute_value = getattr(account.user.attributes, attribute)
except AttributeError:
old_attribute_value = None
if old_attribute_value == attribute_value:
if old_attribute_value == value:
continue
had_changes = True
setattr(account.user, claim.attribute, attribute_value)
setattr(account.user, attribute, value)
try:
setattr(account.user.attributes, claim.attribute, attribute_value)
setattr(account.user.attributes, attribute, value)
except AttributeError:
pass
if had_changes:

View File

@ -28,6 +28,9 @@ from jwcrypto.jwt import JWT
from authentic2.a2_rbac.utils import get_default_ou
from authentic2.models import Attribute
from authentic2.utils.cache import GlobalCache
from authentic2.utils.template import Template
from . import models
TIMEOUT = 1
@ -73,6 +76,33 @@ def parse_id_token(encoded, provider):
return json_decode(jwt.claims)
def resolve_claim_mappings(provider, context, id_token=None, user_info=None):
mappings = []
for claim_mapping in provider.claim_mappings.all():
claim = claim_mapping.claim
if id_token is None and user_info is None:
source = context
elif claim_mapping.idtoken_claim:
source = id_token
else:
source = user_info
if not source or claim not in source and not ('{{' in claim or '{%' in claim):
continue
verified = False
attribute = claim_mapping.attribute
if '{{' in claim or '{%' in claim:
template = Template(claim)
value = template.render(context=context)
else:
value = source.get(claim)
if claim_mapping.verified == models.OIDCClaimMapping.VERIFIED_CLAIM:
verified = bool(source.get(claim + '_verified', False))
if claim_mapping.verified == models.OIDCClaimMapping.ALWAYS_VERIFIED:
verified = True
mappings.append((attribute, value, verified))
return mappings
REQUIRED_ID_TOKEN_KEYS = {'iss', 'sub', 'aud', 'exp', 'iat'}
KEY_TYPES = {
'iss': str,