auth_oidc: factorize claim mapping resolution (#72418)
This commit is contained in:
parent
ee7274b887
commit
4c03402f17
|
@ -34,7 +34,6 @@ from authentic2.a2_rbac.models import OrganizationalUnit
|
|||
from authentic2.models import Lock
|
||||
from authentic2.utils import hooks
|
||||
from authentic2.utils.crypto import base64url_encode
|
||||
from authentic2.utils.template import Template
|
||||
|
||||
from . import models, utils
|
||||
|
||||
|
@ -152,7 +151,6 @@ class OIDCBackend(ModelBackend):
|
|||
user_ou = provider.ou
|
||||
user_info = None
|
||||
save_user = False
|
||||
mappings = []
|
||||
context = id_token_content.copy()
|
||||
need_user_info = False
|
||||
for claim_mapping in provider.claim_mappings.all():
|
||||
|
@ -182,30 +180,11 @@ class OIDCBackend(ModelBackend):
|
|||
logger.debug('auth_oidc: user_info content %s', user_info)
|
||||
context.update(user_info or {})
|
||||
|
||||
for claim_mapping in provider.claim_mappings.all():
|
||||
claim = claim_mapping.claim
|
||||
if claim_mapping.idtoken_claim:
|
||||
source = id_token
|
||||
else:
|
||||
source = user_info
|
||||
if not source or claim not in source and not ('{{' in claim or '{%' in claim):
|
||||
continue
|
||||
verified = False
|
||||
attribute = claim_mapping.attribute
|
||||
if '{{' in claim or '{%' in claim:
|
||||
template = Template(claim)
|
||||
value = template.render(context=context)
|
||||
# xxx missing verification logic for templated claims
|
||||
else:
|
||||
value = source.get(claim)
|
||||
if claim_mapping.verified == models.OIDCClaimMapping.VERIFIED_CLAIM:
|
||||
verified = bool(source.get(claim + '_verified', False))
|
||||
mappings = utils.resolve_claim_mappings(provider, context, id_token, user_info)
|
||||
for attribute, value, dummy in mappings:
|
||||
if attribute == 'ou__slug' and value in ou_map:
|
||||
user_ou = ou_map[value]
|
||||
continue
|
||||
if claim_mapping.verified == models.OIDCClaimMapping.ALWAYS_VERIFIED:
|
||||
verified = True
|
||||
mappings.append((attribute, value, verified))
|
||||
break
|
||||
|
||||
# check for required claims
|
||||
for claim_mapping in provider.claim_mappings.all():
|
||||
|
@ -393,7 +372,7 @@ class OIDCBackend(ModelBackend):
|
|||
|
||||
# new style attributes
|
||||
for attribute, value, verified in mappings:
|
||||
if attribute in ('username', 'email'):
|
||||
if attribute in ('username', 'email', 'ou__slug'):
|
||||
continue
|
||||
if attribute in ('first_name', 'last_name') and not verified:
|
||||
continue
|
||||
|
|
|
@ -36,9 +36,9 @@ from authentic2.apps.authenticators.models import (
|
|||
BaseAuthenticator,
|
||||
)
|
||||
from authentic2.utils.misc import make_url
|
||||
from authentic2.utils.template import Template, validate_template
|
||||
from authentic2.utils.template import validate_template
|
||||
|
||||
from . import managers
|
||||
from . import managers, utils
|
||||
|
||||
if django.VERSION < (3, 1):
|
||||
from django.contrib.postgres.fields.jsonb import JSONField # noqa pylint: disable=ungrouped-imports
|
||||
|
@ -334,25 +334,21 @@ class OIDCProvider(BaseAuthenticator):
|
|||
except OIDCAccount.MultipleObjectsReturned:
|
||||
continue
|
||||
had_changes = False
|
||||
for claim in self.claim_mappings.all():
|
||||
if '{{' in claim.claim or '{%' in claim.claim:
|
||||
template = Template(claim.claim)
|
||||
attribute_value = template.render(context=user_dict)
|
||||
else:
|
||||
attribute_value = user_dict.get(claim.claim)
|
||||
mappings = utils.resolve_claim_mappings(self, user_dict)
|
||||
for attribute, value, dummy in mappings:
|
||||
try:
|
||||
old_attribute_value = getattr(account.user, claim.attribute)
|
||||
old_attribute_value = getattr(account.user, attribute)
|
||||
except AttributeError:
|
||||
try:
|
||||
old_attribute_value = getattr(account.user.attributes, claim.attribute)
|
||||
old_attribute_value = getattr(account.user.attributes, attribute)
|
||||
except AttributeError:
|
||||
old_attribute_value = None
|
||||
if old_attribute_value == attribute_value:
|
||||
if old_attribute_value == value:
|
||||
continue
|
||||
had_changes = True
|
||||
setattr(account.user, claim.attribute, attribute_value)
|
||||
setattr(account.user, attribute, value)
|
||||
try:
|
||||
setattr(account.user.attributes, claim.attribute, attribute_value)
|
||||
setattr(account.user.attributes, attribute, value)
|
||||
except AttributeError:
|
||||
pass
|
||||
if had_changes:
|
||||
|
|
|
@ -28,6 +28,9 @@ from jwcrypto.jwt import JWT
|
|||
from authentic2.a2_rbac.utils import get_default_ou
|
||||
from authentic2.models import Attribute
|
||||
from authentic2.utils.cache import GlobalCache
|
||||
from authentic2.utils.template import Template
|
||||
|
||||
from . import models
|
||||
|
||||
TIMEOUT = 1
|
||||
|
||||
|
@ -73,6 +76,33 @@ def parse_id_token(encoded, provider):
|
|||
return json_decode(jwt.claims)
|
||||
|
||||
|
||||
def resolve_claim_mappings(provider, context, id_token=None, user_info=None):
|
||||
mappings = []
|
||||
for claim_mapping in provider.claim_mappings.all():
|
||||
claim = claim_mapping.claim
|
||||
if id_token is None and user_info is None:
|
||||
source = context
|
||||
elif claim_mapping.idtoken_claim:
|
||||
source = id_token
|
||||
else:
|
||||
source = user_info
|
||||
if not source or claim not in source and not ('{{' in claim or '{%' in claim):
|
||||
continue
|
||||
verified = False
|
||||
attribute = claim_mapping.attribute
|
||||
if '{{' in claim or '{%' in claim:
|
||||
template = Template(claim)
|
||||
value = template.render(context=context)
|
||||
else:
|
||||
value = source.get(claim)
|
||||
if claim_mapping.verified == models.OIDCClaimMapping.VERIFIED_CLAIM:
|
||||
verified = bool(source.get(claim + '_verified', False))
|
||||
if claim_mapping.verified == models.OIDCClaimMapping.ALWAYS_VERIFIED:
|
||||
verified = True
|
||||
mappings.append((attribute, value, verified))
|
||||
return mappings
|
||||
|
||||
|
||||
REQUIRED_ID_TOKEN_KEYS = {'iss', 'sub', 'aud', 'exp', 'iat'}
|
||||
KEY_TYPES = {
|
||||
'iss': str,
|
||||
|
|
Loading…
Reference in New Issue