# authentic/src/authentic2_auth_oidc/utils.py
# authentic2 - versatile identity manager
# Copyright (C) 2010-2019 Entr'ouvert
#
# This program is free software: you can redistribute it and/or modify it
# under the terms of the GNU Affero General Public License as published
# by the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
import datetime
import requests
from django.utils import six
from django.utils.timezone import utc
from django.shortcuts import get_object_or_404
from django.utils.translation import ugettext as _
from django.utils.six.moves.urllib import parse as urlparse
from authentic2.decorators import GlobalCache
from authentic2.models import Attribute
from authentic2.a2_rbac.utils import get_default_ou
from jwcrypto.jwt import JWT, JWTMissingKey
from jwcrypto.jwk import JWK
from jwcrypto.common import (JWException, InvalidJWAAlgorithm, json_decode,
base64url_encode)
from . import models
# Cache timeout handed to GlobalCache by the provider/attribute lookup helpers
# of this module (NOTE(review): unit is defined by GlobalCache, presumably seconds).
TIMEOUT = 1
@GlobalCache(timeout=5, kwargs=['shown'])
def get_providers(shown=None):
    """Return the configured OIDC providers.

    :param shown: when not ``None``, keep only the providers whose ``show``
        field equals this value (NOTE(review): ``kwargs=['shown']`` presumably
        makes this argument part of the GlobalCache key)
    :returns: a queryset of ``models.OIDCProvider``
    """
    providers = models.OIDCProvider.objects.all()
    if shown is None:
        return providers
    return providers.filter(show=shown)
@GlobalCache(timeout=TIMEOUT)
def get_attributes():
    """Return every authentic2 ``Attribute`` (result cached by GlobalCache)."""
    attributes = Attribute.objects.all()
    return attributes
@GlobalCache(timeout=TIMEOUT)
def get_provider(pk):
    """Return the OIDC provider with primary key ``pk``.

    :raises django.http.Http404: if no such provider exists (raised by
        ``get_object_or_404``)
    """
    provider_model = models.OIDCProvider
    return get_object_or_404(provider_model, pk=pk)
@GlobalCache(timeout=TIMEOUT)
def has_providers():
    """Tell whether at least one OIDC provider has its ``show`` flag set."""
    shown_providers = models.OIDCProvider.objects.filter(show=True)
    return shown_providers.exists()
@GlobalCache(timeout=TIMEOUT)
def get_provider_by_issuer(issuer):
    """Return the OIDC provider whose ``issuer`` field equals ``issuer``.

    Its ``claim_mappings`` relation is prefetched.

    :raises models.OIDCProvider.DoesNotExist: if no provider matches
    """
    providers = models.OIDCProvider.objects.prefetch_related('claim_mappings')
    return providers.get(issuer=issuer)
def parse_id_token(encoded, provider):
    """Verify an encoded ID token and return its decoded claims.

    The token is first deserialized without a key, only to read its (not yet
    trusted) JOSE header. The ``alg`` header then selects the verification key:

    * ``RS256``/``RS384``/``RS512``: the key of ``provider.jwkset`` whose
      identifier matches the ``kid`` header;
    * ``HS256``/``HS384``/``HS512``: a symmetric key built from
      ``provider.client_secret``.

    The token is then deserialized a second time with that key.

    May raise any subclass of jwcrypto.common.JWException, notably
    ``JWTMissingKey`` for an unknown ``kid`` and ``InvalidJWAAlgorithm`` for a
    missing or unsupported ``alg`` header.

    :param encoded: the serialized JWT
    :param provider: the OIDC provider expected to have signed the token
    :returns: the claims, as returned by ``json_decode``
    """
    jwt = JWT()
    jwt.deserialize(encoded, None)
    header = jwt.token.jose_header
    # .get(): a token without "alg" must be rejected with a JWException
    # (InvalidJWAAlgorithm below), as documented, and not with a KeyError
    # that callers catching JWException would miss.
    alg = header.get('alg')
    if alg in ('RS256', 'RS384', 'RS512'):
        kid = header.get('kid')
        key = provider.jwkset.get_key(kid=kid)
        if not key:
            raise JWTMissingKey(_('Unknown RSA key identifier %s for provider %s')
                                % (kid, provider))
    elif alg in ('HS256', 'HS384', 'HS512'):
        # a JWK of type "oct" carries its secret base64url-encoded in "k"
        key = JWK(kty='oct', k=base64url_encode(
            provider.client_secret.encode('utf-8')))
    else:
        raise InvalidJWAAlgorithm(
            _('Unsupported %s signature algorithm') % alg)
    jwt = JWT()
    jwt.deserialize(encoded, key)
    return json_decode(jwt.claims)
# Claims that IDToken.deserialize requires in every ID token.
REQUIRED_ID_TOKEN_KEYS = {'iss', 'sub', 'aud', 'exp', 'iat'}

# Python type expected for each standard claim when it is present, checked by
# IDToken.deserialize. "aud" and "amr" are deliberately absent: their values
# may be JSON arrays, so a single type does not describe them.
KEY_TYPES = dict(
    iss=six.text_type,
    sub=six.text_type,
    exp=int,
    iat=int,
    auth_time=int,
    nonce=six.text_type,
    acr=six.text_type,
    azp=six.text_type,
)
def parse_timestamp(tstamp):
    """Convert an integer epoch timestamp (JWT NumericDate) to an aware datetime.

    :param tstamp: seconds since the epoch; must be an ``int``. Booleans are
        rejected even though ``bool`` is a subclass of ``int``.
    :returns: a ``datetime.datetime`` in the ``utc`` timezone
    :raises ValueError: if ``tstamp`` is not an integer, or is out of the range
        supported by ``datetime``/the platform
    """
    # a JSON true/false decodes to a bool, which isinstance(..., int) accepts;
    # it must not silently become 1970-01-01T00:00:01Z
    if isinstance(tstamp, bool) or not isinstance(tstamp, int):
        raise ValueError('%s' % tstamp)
    try:
        return datetime.datetime.fromtimestamp(tstamp, utc)
    except (OverflowError, OSError):
        # huge timestamps overflow the platform time_t: report them as
        # ValueError like every other invalid value, since that is what
        # callers catch
        raise ValueError('%s' % tstamp)
class IDTokenError(ValueError):
    """Raised when an ID token is malformed, fails verification or has invalid claims."""
class IDToken(object):
    """An OpenID Connect ID token, verified and validated by :meth:`deserialize`.

    After ``deserialize()`` the claims ``iss``, ``sub``, ``aud``, ``exp``,
    ``iat``, ``nonce``, ``acr``, ``azp`` (and ``auth_time`` when present) are
    instance attributes; ``exp``/``iat``/``auth_time`` are converted with
    ``parse_timestamp``. Every other claim is kept in ``extra``. Claims can
    also be read through ``token[name]``, ``name in token`` and
    ``token.get(name)``.
    """

    # defaults for optional claims; auth_time is only set on the instance
    # when the token carries it
    auth_time = None
    nonce = None

    # instance attributes that hold claims; the mapping accessors only look
    # these up, so internal attributes (_encoded, extra) can neither leak
    # nor shadow a claim of the same name
    _CLAIM_ATTRIBUTES = frozenset(
        ['iss', 'sub', 'aud', 'exp', 'iat', 'auth_time', 'nonce', 'acr', 'azp'])

    def __init__(self, encoded):
        """Store the encoded token; nothing is decoded until deserialize().

        :raises IDTokenError: if ``encoded`` is neither bytes nor a string
        """
        if not isinstance(encoded, (six.binary_type, six.string_types)):
            raise IDTokenError(
                _('Encoded ID Token must be either binary or string data'))
        self._encoded = encoded

    def deserialize(self, provider):
        """Verify the token with ``provider`` and load its claims.

        :param provider: OIDC provider handed to ``parse_id_token`` to select
            the verification key
        :raises IDTokenError: if the token cannot be decoded or verified, its
            payload is not a JSON object, a required claim is missing, or a
            claim has an invalid type or value
        """
        try:
            decoded = parse_id_token(self._encoded, provider)
            if not decoded:
                raise JWException(_('invalid id_token'))
        except JWException as e:
            raise IDTokenError(e)
        # the payload may be any JSON value; only an object holds claims
        if not isinstance(decoded, dict):
            raise IDTokenError(_('invalid id_token'))
        keys = set(decoded)
        if not keys.issuperset(REQUIRED_ID_TOKEN_KEYS):
            raise IDTokenError(
                _('missing field: %s') % (REQUIRED_ID_TOKEN_KEYS - keys))
        for key in keys:
            value = decoded[key]
            if key == 'amr':
                # amr: array of strings
                if not isinstance(value, list):
                    raise IDTokenError(
                        _('invalid amr value: %s') % value)
                if not all(isinstance(v, six.text_type) for v in value):
                    raise IDTokenError(
                        _('invalid amr value: %s') % value)
            elif key == 'aud':
                # aud: a single string or an array of strings
                audiences = value if isinstance(value, list) else [value]
                if not all(isinstance(v, six.text_type) for v in audiences):
                    raise IDTokenError(
                        _('invalid %s value: %s') % (key, value))
            elif key in KEY_TYPES:
                if not isinstance(value, KEY_TYPES[key]):
                    raise IDTokenError(
                        _('invalid %s value: %s') % (key, value))
        self.iss = decoded.pop('iss')
        self.sub = decoded.pop('sub')
        self.aud = decoded.pop('aud')
        # parse_timestamp raises ValueError for out-of-range values; report
        # it as IDTokenError like auth_time below
        for key in ('exp', 'iat'):
            try:
                setattr(self, key, parse_timestamp(decoded.pop(key)))
            except ValueError as e:
                raise IDTokenError(
                    _('invalid %s value: %s') % (key, e))
        if 'auth_time' in decoded:
            try:
                self.auth_time = parse_timestamp(decoded.pop('auth_time'))
            except ValueError as e:
                raise IDTokenError(
                    _('invalid auth_time value: %s') % e)
        self.nonce = decoded.pop('nonce', None)
        self.acr = decoded.pop('acr', None)
        self.azp = decoded.pop('azp', None)
        self.extra = decoded

    def __contains__(self, key):
        """Tell whether claim ``key`` is available on this token."""
        if key in self._CLAIM_ATTRIBUTES and key in self.__dict__:
            return True
        # extra only exists once deserialize() has run
        return key in getattr(self, 'extra', {})

    def __getitem__(self, key):
        """Return claim ``key``.

        :raises KeyError: if the token has no such claim
        """
        if key in self._CLAIM_ATTRIBUTES and key in self.__dict__:
            return self.__dict__[key]
        extra = getattr(self, 'extra', {})
        if key in extra:
            return extra[key]
        raise KeyError(key)

    def get(self, key, default=None):
        """Return claim ``key``, or ``default`` if the token has no such claim."""
        try:
            return self[key]
        except KeyError:
            return default
# Keys an OpenID Connect discovery document must contain for register_issuer
# to accept it.
OPENID_CONFIGURATION_REQUIRED = {
    'issuer',
    'authorization_endpoint',
    'token_endpoint',
    'jwks_uri',
    'response_types_supported',
    'subject_types_supported',
    'id_token_signing_alg_values_supported',
    'userinfo_endpoint',
}
def check_https(url):
    """Return True if ``url`` parses as a URL whose scheme is ``https``.

    The value may come straight from a JSON document, so anything that
    cannot be parsed returns False instead of raising. That includes
    non-string values and malformed netlocs such as an unterminated IPv6
    literal.
    """
    try:
        scheme = urlparse.urlparse(url).scheme
    except (AttributeError, TypeError, ValueError):
        return False
    return scheme == 'https'
def register_issuer(name, issuer=None, openid_configuration=None, verify=True, timeout=None,
                    ou=None):
    """Create or update an ``OIDCProvider`` from an OpenID Connect discovery document.

    :param name: name of the provider
    :param issuer: issuer URL; when ``openid_configuration`` is not given, the
        discovery document is fetched from ``get_openid_configuration_url(issuer)``
    :param openid_configuration: an already retrieved discovery document
    :param verify: passed to ``requests.get`` for every HTTP request
    :param timeout: passed to ``requests.get`` for every HTTP request (discovery
        document and JWKSet)
    :param ou: organizational unit of the provider, ``get_default_ou()`` if None
    :returns: the created provider, or the updated one if a provider with the
        same issuer already exists
    :raises ValueError: if neither ``issuer`` nor ``openid_configuration`` is
        given, a document cannot be fetched or is invalid, the code flow is not
        supported, or no supported id_token signature algorithm is advertised
    """
    from . import models

    # without this check the function would fail later with an
    # UnboundLocalError on ``response``
    if not issuer and not openid_configuration:
        raise ValueError(_('an issuer or an OpenID Connect configuration is required'))

    if not openid_configuration:
        openid_configuration_url = get_openid_configuration_url(issuer)
        try:
            response = requests.get(openid_configuration_url, verify=verify, timeout=timeout)
            response.raise_for_status()
        except requests.RequestException as e:
            raise ValueError(_('Unable to reach the OpenID Connect configuration for %(issuer)s: '
                               '%(error)s') % {
                'issuer': issuer,
                'error': e,
            })

    try:
        # response.json() raises ValueError on a non-JSON body
        openid_configuration = openid_configuration or response.json()
        if not isinstance(openid_configuration, dict):
            raise ValueError(_('MUST be a dictionnary'))
        keys = set(openid_configuration.keys())
        if not keys >= OPENID_CONFIGURATION_REQUIRED:
            raise ValueError(_('missing keys %s') % (OPENID_CONFIGURATION_REQUIRED - keys))
        for key in ['issuer', 'authorization_endpoint', 'token_endpoint', 'jwks_uri',
                    'userinfo_endpoint']:
            if not check_https(openid_configuration[key]):
                raise ValueError(_('%(key)s is not an https:// URL; %(value)s') % {
                    'key': key,
                    'value': openid_configuration[key],
                })
    except ValueError as e:
        # the message uses named placeholders: it must be formatted with a
        # mapping, a tuple raises TypeError
        raise ValueError(_('Invalid OpenID Connect configuration for %(issuer)s: '
                           '%(error)s') % {
            'issuer': issuer,
            'error': e,
        })

    if 'code' not in openid_configuration['response_types_supported']:
        raise ValueError(_('authorization code flow is unsupported: code response type is '
                           'unsupported'))

    jwks_uri = openid_configuration['jwks_uri']
    try:
        # honour the caller's timeout here too (it was previously forced to None)
        response = requests.get(jwks_uri, verify=verify, timeout=timeout)
        response.raise_for_status()
    except requests.RequestException as e:
        raise ValueError(_('Unable to reach the OpenID Connect JWKSet for %(issuer)s: '
                           '%(url)s %(error)s') % {
            'issuer': issuer,
            'url': jwks_uri,
            'error': e,
        })
    try:
        jwkset_json = response.json()
    except ValueError as e:
        raise ValueError(_('Invalid JSKSet document: %s') % e)

    try:
        old_pk = models.OIDCProvider.objects.get(issuer=openid_configuration['issuer']).pk
    except models.OIDCProvider.DoesNotExist:
        old_pk = None

    # prefer RSA signatures over HMAC when the provider supports both
    supported_algs = set(openid_configuration['id_token_signing_alg_values_supported'])
    if set(['RS256', 'RS384', 'RS512']) & supported_algs:
        idtoken_algo = models.OIDCProvider.ALGO_RSA
    elif set(['HS256', 'HS384', 'HS512']) & supported_algs:
        idtoken_algo = models.OIDCProvider.ALGO_HMAC
    else:
        raise ValueError(_('no common algorithm found for signing idtokens: %s') %
                         openid_configuration['id_token_signing_alg_values_supported'])

    # only an explicit JSON true enables the claims parameter
    claims_parameter_supported = openid_configuration.get('claims_parameter_supported') is True
    kwargs = dict(
        ou=ou or get_default_ou(),
        name=name,
        issuer=openid_configuration['issuer'],
        authorization_endpoint=openid_configuration['authorization_endpoint'],
        token_endpoint=openid_configuration['token_endpoint'],
        userinfo_endpoint=openid_configuration['userinfo_endpoint'],
        jwkset_json=jwkset_json,
        idtoken_algo=idtoken_algo,
        strategy=models.OIDCProvider.STRATEGY_CREATE,
        claims_parameter_supported=claims_parameter_supported)
    if old_pk:
        models.OIDCProvider.objects.filter(pk=old_pk).update(**kwargs)
        return models.OIDCProvider.objects.get(pk=old_pk)
    else:
        return models.OIDCProvider.objects.create(**kwargs)
def get_openid_configuration_url(issuer):
    """Return the URL of the discovery document published by ``issuer``.

    The URL is the issuer with trailing slashes removed, followed by
    ``/.well-known/openid-configuration``.

    :raises ValueError: if the issuer does not use the https scheme, has no
        host, or has a query or fragment component
    """
    # urlsplit (unlike urlparse) keeps ";params" inside the path, so they are
    # not silently dropped from the issuer when rebuilding it
    parsed = urlparse.urlsplit(issuer)
    if (parsed.query or parsed.fragment or parsed.scheme != 'https'
            or not parsed.netloc):
        raise ValueError(_('invalid issuer URL, it must use the https:// scheme and not have a '
                           'query or fragment'))
    issuer = urlparse.urlunsplit((parsed.scheme, parsed.netloc, parsed.path.rstrip('/'), '', ''))
    return issuer + '/.well-known/openid-configuration'