322 lines
12 KiB
Python
322 lines
12 KiB
Python
# authentic2 - versatile identity manager
|
|
# Copyright (C) 2010-2019 Entr'ouvert
|
|
#
|
|
# This program is free software: you can redistribute it and/or modify it
|
|
# under the terms of the GNU Affero General Public License as published
|
|
# by the Free Software Foundation, either version 3 of the License, or
|
|
# (at your option) any later version.
|
|
#
|
|
# This program is distributed in the hope that it will be useful,
|
|
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
|
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
|
# GNU Affero General Public License for more details.
|
|
#
|
|
# You should have received a copy of the GNU Affero General Public License
|
|
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
|
|
|
import datetime
|
|
import urllib.parse
|
|
|
|
import requests
|
|
from django.shortcuts import get_object_or_404
|
|
from django.utils.timezone import utc
|
|
from django.utils.translation import gettext as _
|
|
from jwcrypto.common import JWException, base64url_encode, json_decode
|
|
from jwcrypto.jwk import JWK
|
|
from jwcrypto.jwt import JWT
|
|
|
|
from authentic2.a2_rbac.utils import get_default_ou
|
|
from authentic2.models import Attribute
|
|
from authentic2.utils.cache import GlobalCache
|
|
from authentic2.utils.template import Template
|
|
|
|
from . import models
|
|
|
|
# Timeout handed to GlobalCache for the cached lookups of this module
# (get_attributes, get_provider, get_provider_by_issuer).
# NOTE(review): unit presumably seconds — confirm against GlobalCache.
TIMEOUT = 1
|
|
|
|
|
|
@GlobalCache(timeout=TIMEOUT)
def get_attributes():
    """Return a QuerySet over every ``Attribute`` (``Attribute.objects.all()``).

    Wrapped in ``GlobalCache(timeout=TIMEOUT)``, presumably so repeated calls
    within that window reuse the same result.
    """
    every_attribute = Attribute.objects.all()
    return every_attribute
|
|
|
|
|
|
@GlobalCache(timeout=TIMEOUT)
def get_provider(pk):
    """Return the ``OIDCProvider`` whose primary key is *pk*.

    Looked up with ``get_object_or_404``, so an unknown *pk* raises ``Http404``.
    The result is wrapped in ``GlobalCache(timeout=TIMEOUT)``.
    """
    from . import models

    provider = get_object_or_404(models.OIDCProvider, pk=pk)
    return provider
|
|
|
|
|
|
@GlobalCache(timeout=TIMEOUT)
def get_provider_by_issuer(issuer):
    """Return the ``OIDCProvider`` whose ``issuer`` field equals *issuer*.

    Its ``claim_mappings`` are prefetched. Lookup errors from ``QuerySet.get``
    (``DoesNotExist`` / ``MultipleObjectsReturned``) propagate to the caller.
    The result is wrapped in ``GlobalCache(timeout=TIMEOUT)``.
    """
    from . import models

    providers = models.OIDCProvider.objects.prefetch_related('claim_mappings')
    return providers.get(issuer=issuer)
|
|
|
|
|
|
def parse_id_token(encoded, provider):
    """Verify the signature of an encoded ID token and return its claims as a dict.

    The token is first parsed without verification to read its JOSE header.
    It is then deserialized again with a key chosen from the header ``alg``:

    * RS256/384/512 and ES256/384/512: the key of ``provider.jwkset`` matching
      the header ``kid``;
    * HS256/384/512: a symmetric key built from ``provider.client_secret``.

    May raise any subclass of jwcrypto.common.JWException from the initial,
    unverified parse (e.g. a malformed token).

    :raises IDTokenError: if ``alg`` is missing or unsupported, if no key of
        the key set matches ``kid``, or if verification fails with a
        JWException.
    """
    jwt = JWT()
    jwt.deserialize(encoded, None)
    header = jwt.token.jose_header

    alg = header.get('alg')
    try:
        if alg in ('RS256', 'RS384', 'RS512', 'ES256', 'ES384', 'ES512'):
            kid = header.get('kid', None)
            jwkset = provider.jwkset
            # NOTE(review): guard in case a provider has no key set; treated as "key not found".
            key = jwkset.get_key(kid=kid) if jwkset is not None else None
            if not key:
                raise IDTokenError(_('Key ID %r not in key set') % kid)
        elif alg in ('HS256', 'HS384', 'HS512'):
            key = JWK(kty='oct', k=base64url_encode(provider.client_secret.encode('utf-8')))
        else:
            # Any other algorithm (including "none" or a missing alg) is rejected
            # explicitly instead of reaching deserialize() with no key bound.
            raise IDTokenError(_('Unsupported ID token signature algorithm: %r') % alg)

        jwt = JWT()
        jwt.deserialize(encoded, key)
    except JWException as e:
        raise IDTokenError(_('Error during token parsing: %s') % e) from e
    return json_decode(jwt.claims)
|
|
|
|
|
|
def resolve_claim_mappings(provider, context, id_token=None, user_info=None):
    """Compute ``(attribute, value, verified)`` triples from a provider's claim mappings.

    Mappings whose target attribute is disabled are ignored. For the others,
    the claim source is:

    * *context* when neither *id_token* nor *user_info* is given;
    * otherwise *id_token* if the mapping is flagged ``idtoken_claim``, else
      *user_info*.

    A claim containing ``{{`` or ``{%`` is rendered as a template against
    *context*. Any other claim is read from the source, and the mapping is
    skipped when the claim is absent. A mapping is also skipped when its
    source is empty or ``None``.

    ``verified`` is True when the mapping is ``ALWAYS_VERIFIED``, or when it is
    ``VERIFIED_CLAIM`` and ``<claim>_verified`` is truthy in the source.

    :returns: list of ``(attribute, value, verified)`` tuples.
    """
    mappings = []
    # Materialise the names once as a set. Membership tests on the raw
    # values_list QuerySet scan every cached row for each mapping.
    disabled_attrs = set(Attribute.all_objects.filter(disabled=True).values_list('name', flat=True))
    for claim_mapping in provider.claim_mappings.all():
        attribute = claim_mapping.attribute
        if attribute in disabled_attrs:
            continue
        claim = claim_mapping.claim
        if id_token is None and user_info is None:
            source = context
        elif claim_mapping.idtoken_claim:
            source = id_token
        else:
            source = user_info
        is_template = '{{' in claim or '{%' in claim
        # Skip when the source is empty, or when a plain (non-template) claim is
        # missing from it. Template claims render from *context* instead.
        if not source or (not is_template and claim not in source):
            continue
        verified = False
        if is_template:
            value = Template(claim).render(context=context)
        else:
            value = source.get(claim)
        if claim_mapping.verified == models.OIDCClaimMapping.VERIFIED_CLAIM:
            verified = bool(source.get(claim + '_verified', False))
        if claim_mapping.verified == models.OIDCClaimMapping.ALWAYS_VERIFIED:
            verified = True
        mappings.append((attribute, value, verified))
    return mappings
|
|
|
|
|
|
# Claims every ID token must carry; IDToken.deserialize rejects a token missing any of them.
REQUIRED_ID_TOKEN_KEYS = {'iss', 'sub', 'aud', 'exp', 'iat'}
# Expected Python type of known ID token claims, checked by IDToken.deserialize.
# Optional claims (those not in REQUIRED_ID_TOKEN_KEYS) whose value is None or ''
# skip this type check.
KEY_TYPES = {
    'iss': str,
    'sub': str,
    # exp, iat and auth_time are converted to datetimes with parse_timestamp
    'exp': int,
    'iat': int,
    'auth_time': int,
    'nonce': str,
    'acr': str,
    'azp': str,
    # amr is validated separately (list of strings) in IDToken.deserialize;
    # aud is not listed here since, per OIDC Core, it may be a string or an array of strings.
}
|
|
|
|
|
|
def parse_timestamp(tstamp):
    """Convert integer seconds since the epoch into an aware UTC ``datetime``.

    ``datetime.timezone.utc`` is used directly because Django's
    ``django.utils.timezone.utc`` alias is deprecated (removed in Django 5.0).

    :raises ValueError: if *tstamp* is not an ``int`` (``bool`` is rejected
        even though it subclasses ``int``), or if it is outside the range the
        platform can represent.
    """
    # bool subclasses int, but a JSON true/false is not a timestamp.
    if not isinstance(tstamp, int) or isinstance(tstamp, bool):
        raise ValueError('%s' % tstamp)
    try:
        return datetime.datetime.fromtimestamp(tstamp, datetime.timezone.utc)
    except (OverflowError, OSError) as e:
        # fromtimestamp signals out-of-range values with these rather than
        # ValueError; normalise so callers only need to catch ValueError.
        raise ValueError('%s' % tstamp) from e
|
|
|
|
|
|
class IDTokenError(ValueError):
    """Raised when an ID token cannot be parsed, verified or validated."""
|
|
|
|
|
|
class IDToken:
    """Wrapper around an encoded OpenID Connect ID token.

    Build it from the encoded token (``str`` or ``bytes``), then call
    :meth:`deserialize` with the issuing provider. This verifies the token
    through ``parse_id_token``, validates the claims and exposes them:
    ``iss``, ``sub``, ``aud``, ``exp``, ``iat``, ``auth_time``, ``nonce``,
    ``acr`` and ``azp`` become attributes (``exp``/``iat``/``auth_time`` as
    datetimes from ``parse_timestamp``), and the remaining claims stay in
    :attr:`extra`.

    ``in``, ``[]`` and :meth:`get` look a key up in the instance attributes
    first, then in :attr:`extra`.
    """

    auth_time = None
    nonce = None

    def __init__(self, encoded):
        if not isinstance(encoded, (bytes, str)):
            raise IDTokenError(_('Encoded ID Token must be either binary or string data'))
        self._encoded = encoded

    def as_dict(self, provider):
        """Return the verified claims as a dict, without the claim checks of :meth:`deserialize`."""
        return parse_id_token(self._encoded, provider)

    def deserialize(self, provider):
        """Verify the token with *provider* and populate this instance from its claims.

        :raises IDTokenError: if verification fails (see ``parse_id_token``),
            if the payload is not a non-empty JSON object, if a required claim
            is missing, or if a claim has an invalid type or value.
        """
        decoded = parse_id_token(self._encoded, provider)
        # The payload must be a JSON object; anything else cannot hold claims.
        if not decoded or not isinstance(decoded, dict):
            raise IDTokenError(_('invalid id_token'))
        keys = set(decoded.keys())
        # check fields are ok
        if not keys.issuperset(REQUIRED_ID_TOKEN_KEYS):
            raise IDTokenError(_('missing field: %s') % (REQUIRED_ID_TOKEN_KEYS - keys))
        for key in keys:
            value = decoded[key]
            if key == 'amr':
                if not isinstance(value, list) or not all(isinstance(v, str) for v in value):
                    raise IDTokenError(_('invalid amr value: %s') % value)
            elif key == 'aud':
                # OIDC Core: aud is a single string or an array of strings.
                if not isinstance(value, str) and not (
                    isinstance(value, list) and all(isinstance(v, str) for v in value)
                ):
                    raise IDTokenError(_('invalid %(key)s value: %(value)s') % {'key': key, 'value': value})
            elif key in KEY_TYPES:
                if key not in REQUIRED_ID_TOKEN_KEYS and (value is None or value == ''):
                    # for optional keys ignore null and empty string values,
                    # even if specification says it should not happen.
                    # https://openid.net/specs/openid-connect-core-1_0.html#rfc.section.5.3.2
                    continue
                expected_type = KEY_TYPES[key]
                # bool subclasses int, but a JSON true/false is never a valid timestamp.
                if not isinstance(value, expected_type) or (expected_type is int and isinstance(value, bool)):
                    raise IDTokenError(_('invalid %(key)s value: %(value)s') % {'key': key, 'value': value})
        self.iss = decoded.pop('iss')
        self.sub = decoded.pop('sub')
        self.aud = decoded.pop('aud')
        self.exp = self._parse_time('exp', decoded.pop('exp'))
        self.iat = self._parse_time('iat', decoded.pop('iat'))
        # auth_time is read with get(), not pop(), so it also stays in extra.
        auth_time = decoded.get('auth_time')
        if auth_time:
            self.auth_time = self._parse_time('auth_time', auth_time)
        self.nonce = decoded.pop('nonce', None)
        self.acr = decoded.pop('acr', None)
        self.azp = decoded.pop('azp', None)
        self.extra = decoded

    @staticmethod
    def _parse_time(key, value):
        """Convert a timestamp claim with ``parse_timestamp``, reporting failures as IDTokenError."""
        try:
            return parse_timestamp(value)
        except (ValueError, OverflowError, OSError) as e:
            # out-of-range timestamps may surface as OverflowError/OSError from datetime
            raise IDTokenError(_('invalid %(key)s value: %(value)s') % {'key': key, 'value': e}) from e

    def __contains__(self, key):
        if key in self.__dict__:
            return True
        # extra only exists once deserialize() has run
        return key in self.__dict__.get('extra', {})

    def __getitem__(self, key):
        if key in self.__dict__:
            return self.__dict__[key]
        extra = self.__dict__.get('extra', {})
        if key in extra:
            return extra[key]
        raise KeyError(key)

    def get(self, key, default=None):
        try:
            return self[key]
        except KeyError:
            return default
|
|
|
|
|
|
# Discovery metadata keys that register_issuer requires in an OpenID Connect
# configuration document. Note: OpenID Connect Discovery 1.0 only RECOMMENDS
# userinfo_endpoint, but this module requires it.
OPENID_CONFIGURATION_REQUIRED = {
    'issuer',
    'authorization_endpoint',
    'token_endpoint',
    'jwks_uri',
    'response_types_supported',
    'subject_types_supported',
    'id_token_signing_alg_values_supported',
    'userinfo_endpoint',
}
|
|
|
|
|
|
def check_https(url):
    """Return True if *url* is a string whose scheme is ``https`` (case-insensitive).

    Non-string values (``None``, numbers, lists, bytes, ...) return False
    instead of failing inside ``urllib.parse``. A malformed discovery document
    is therefore reported as a validation error rather than crashing.
    """
    if not isinstance(url, str):
        return False
    return urllib.parse.urlparse(url).scheme == 'https'
|
|
|
|
|
|
def register_issuer(
    name, client_id, client_secret, issuer=None, openid_configuration=None, verify=True, timeout=None, ou=None
):
    """Create or update an ``OIDCProvider`` from an OpenID Connect discovery document.

    :param name: display name of the provider.
    :param client_id: client identifier registered at the provider.
    :param client_secret: client secret. It is also the HMAC key used by
        ``parse_id_token`` for HS* signed ID tokens.
    :param issuer: issuer URL. Its ``/.well-known/openid-configuration`` is
        fetched when *openid_configuration* is not given.
    :param openid_configuration: an already-loaded discovery document (dict).
    :param verify: ``verify`` argument passed to ``requests.get``.
    :param timeout: ``timeout`` argument passed to both ``requests.get`` calls.
    :param ou: organizational unit; defaults to ``get_default_ou()``.
    :returns: the created or updated ``OIDCProvider``. An existing provider
        with the same issuer is updated in place.
    :raises ValueError: if neither *issuer* nor *openid_configuration* is
        given, if a document cannot be fetched or is invalid, if the code flow
        is unsupported, or if no supported ID token signing algorithm is
        advertised.
    """
    from . import models

    if not issuer and not openid_configuration:
        raise ValueError(_('an issuer URL or an OpenID Connect configuration is required'))

    if issuer and not openid_configuration:
        openid_configuration_url = get_openid_configuration_url(issuer)
        try:
            response = requests.get(openid_configuration_url, verify=verify, timeout=timeout)
            response.raise_for_status()
        except requests.RequestException as e:
            raise ValueError(
                _('Unable to reach the OpenID Connect configuration for %(issuer)s: %(error)s')
                % {
                    'issuer': issuer,
                    'error': e,
                }
            ) from e

    try:
        openid_configuration = openid_configuration or response.json()
        if not isinstance(openid_configuration, dict):
            raise ValueError(_('MUST be a dictionnary'))
        keys = set(openid_configuration.keys())
        if not keys >= OPENID_CONFIGURATION_REQUIRED:
            raise ValueError(_('missing keys %s') % (OPENID_CONFIGURATION_REQUIRED - keys))
        for key in ['issuer', 'authorization_endpoint', 'token_endpoint', 'jwks_uri', 'userinfo_endpoint']:
            if not check_https(openid_configuration[key]):
                raise ValueError(
                    _('%(key)s is not an https:// URL; %(value)s')
                    % {
                        'key': key,
                        'value': openid_configuration[key],
                    }
                )
    except ValueError as e:
        # The message uses named placeholders, so it must be formatted with a mapping.
        raise ValueError(
            _('Invalid OpenID Connect configuration for %(issuer)s: %(error)s') % {'issuer': issuer, 'error': e}
        ) from e

    if 'code' not in openid_configuration['response_types_supported']:
        raise ValueError(_('authorization code flow is unsupported: code response type is unsupported'))

    jwks_uri = openid_configuration['jwks_uri']
    try:
        # Use the caller's timeout here as well; timeout=None would let a slow
        # JWKS endpoint block forever.
        response = requests.get(jwks_uri, verify=verify, timeout=timeout)
        response.raise_for_status()
    except requests.RequestException as e:
        raise ValueError(
            _('Unable to reach the OpenID Connect JWKSet for %(issuer)s: %(url)s %(error)s')
            % {
                'issuer': issuer,
                'url': jwks_uri,
                'error': e,
            }
        ) from e
    try:
        jwkset_json = response.json()
    except ValueError as e:
        raise ValueError(_('Invalid JSKSet document: %s') % e) from e

    try:
        old_pk = models.OIDCProvider.objects.get(issuer=openid_configuration['issuer']).pk
    except models.OIDCProvider.DoesNotExist:
        old_pk = None

    signing_algs = set(openid_configuration['id_token_signing_alg_values_supported'])
    # Preference order: RSA, then HMAC, then EC.
    if {'RS256', 'RS384', 'RS512'} & signing_algs:
        idtoken_algo = models.OIDCProvider.ALGO_RSA
    elif {'HS256', 'HS384', 'HS512'} & signing_algs:
        idtoken_algo = models.OIDCProvider.ALGO_HMAC
    elif {'ES256', 'ES384', 'ES512'} & signing_algs:
        idtoken_algo = models.OIDCProvider.ALGO_EC
    else:
        raise ValueError(
            _('no common algorithm found for signing idtokens: %s')
            % openid_configuration['id_token_signing_alg_values_supported']
        )

    # Only a literal JSON true enables the claims request parameter.
    claims_parameter_supported = openid_configuration.get('claims_parameter_supported') is True

    kwargs = dict(
        ou=ou or get_default_ou(),
        name=name,
        client_id=client_id,
        client_secret=client_secret,
        issuer=openid_configuration['issuer'],
        authorization_endpoint=openid_configuration['authorization_endpoint'],
        token_endpoint=openid_configuration['token_endpoint'],
        userinfo_endpoint=openid_configuration['userinfo_endpoint'],
        jwkset_json=jwkset_json,
        idtoken_algo=idtoken_algo,
        strategy=models.OIDCProvider.STRATEGY_CREATE,
        claims_parameter_supported=claims_parameter_supported,
    )
    if old_pk is not None:
        # QuerySet.update() writes directly, without calling save() on the instance.
        models.OIDCProvider.objects.filter(pk=old_pk).update(**kwargs)
        return models.OIDCProvider.objects.get(pk=old_pk)
    return models.OIDCProvider.objects.create(**kwargs)
|
|
|
|
|
|
def get_openid_configuration_url(issuer):
    """Return the OpenID Connect discovery URL for *issuer*.

    Trailing slashes are stripped from the issuer path before
    ``/.well-known/openid-configuration`` is appended, so ``https://idp/`` and
    ``https://idp`` map to the same URL.

    :raises ValueError: if *issuer* does not use the ``https`` scheme, has a
        query string or fragment, or has no host.
    """
    parsed = urllib.parse.urlparse(issuer)
    if parsed.query or parsed.fragment or parsed.scheme != 'https':
        raise ValueError(
            _('invalid issuer URL, it must use the https:// scheme and not have a query or fragment')
        )
    if not parsed.netloc:
        # e.g. "https:/path" would otherwise produce a discovery URL with an empty host
        raise ValueError(_('invalid issuer URL, it must contain a host name'))
    issuer = urllib.parse.urlunparse((parsed.scheme, parsed.netloc, parsed.path.rstrip('/'), '', '', ''))
    return issuer + '/.well-known/openid-configuration'
|