From 4cc45665b7ff5cc4c0cb1f2709609ad793cdccc3 Mon Sep 17 00:00:00 2001 From: Paul Marillonnet Date: Fri, 9 Aug 2019 15:31:27 +0200 Subject: [PATCH] oidc authn: verify id token signature (#31862) --- src/authentic2_auth_oidc/backends.py | 11 +-- src/authentic2_auth_oidc/utils.py | 125 ++++++++++++--------------- src/authentic2_auth_oidc/views.py | 2 +- tests/test_auth_oidc.py | 116 +++++++++++++++---------- 4 files changed, 128 insertions(+), 126 deletions(-) diff --git a/src/authentic2_auth_oidc/backends.py b/src/authentic2_auth_oidc/backends.py index 5a20f24bf..83d99e66a 100644 --- a/src/authentic2_auth_oidc/backends.py +++ b/src/authentic2_auth_oidc/backends.py @@ -36,14 +36,15 @@ from . import models, utils class OIDCBackend(ModelBackend): - def authenticate(self, request, access_token=None, id_token=None, nonce=None): + def authenticate(self, request, access_token=None, id_token=None, nonce=None, provider=None): logger = logging.getLogger(__name__) - if id_token is None: + if None in (id_token, provider): return original_id_token = id_token try: id_token = utils.IDToken(id_token) - except ValueError as e: + id_token.deserialize(provider) + except utils.IDTokenError as e: logger.warning(u'auth_oidc: invalid id_token %r: %s', id_token, e) return None @@ -98,10 +99,6 @@ class OIDCBackend(ModelBackend): id_token.azp, provider.client_id) return None - if id_token.exp < now(): - logger.warning(u'auth_oidc: id_token expired %s', id_token.exp) - return None - if provider.max_auth_age: if not id_token.iat: logger.warning('auth_oidc: provider configured for fresh authentication but iat is ' diff --git a/src/authentic2_auth_oidc/utils.py b/src/authentic2_auth_oidc/utils.py index 7cb419d24..41ec787f7 100644 --- a/src/authentic2_auth_oidc/utils.py +++ b/src/authentic2_auth_oidc/utils.py @@ -15,9 +15,6 @@ # along with this program. If not, see . import datetime -import base64 -import json - import requests from django.utils import six @@ -30,6 +27,11 @@ from authentic2.decorators import GlobalCache from authentic2.models import Attribute from authentic2.a2_rbac.utils import get_default_ou +from jwcrypto.jwt import JWT, JWTMissingKey +from jwcrypto.jwk import JWK +from jwcrypto.common import (JWException, InvalidJWAAlgorithm, json_decode, + base64url_encode) + from . import models TIMEOUT = 1 @@ -66,49 +68,27 @@ def get_provider_by_issuer(issuer): return models.OIDCProvider.objects.prefetch_related('claim_mappings').get(issuer=issuer) -def base64url_decode(input): - rem = len(input) % 4 - if rem > 0: - input += b'=' * (4 - rem) - return base64.urlsafe_b64decode(input) +def parse_id_token(encoded, provider): + """ May raise any subclass of jwcrypto.common.JWException """ + jwt = JWT() + jwt.deserialize(encoded, None) + header = jwt.token.jose_header + if header['alg'] in ('RS256', 'RS384', 'RS512'): + key = provider.jwkset.get_key(kid=header.get('kid')) + if not key: + raise JWTMissingKey(_('Unknown RSA key identifier %s for provider %s') + % (header.get('kid'), provider)) + elif header['alg'] in ('HS256', 'HS384', 'HS512'): + key = JWK(kty='oct', k=base64url_encode( + provider.client_secret.encode('utf-8'))) + else: + raise InvalidJWAAlgorithm( + _('Unsupported %s signature algorithm') % header['alg']) -def parse_id_token(id_token): - try: - id_token = str(id_token) - except UnicodeDecodeError: - raise ValueError('invalid characters in id_token') - payload = id_token.split('.') - if len(payload) == 5: - raise ValueError('encrypted IDToken is unsupported') - if len(payload) != 3: - raise ValueError('IDToken does not have three parts, %d found' % len(payload)) - try: - headers = base64url_decode(payload[0]) - except TypeError as e: - raise ValueError('header is not base64 decodable: %s' % e) - try: - headers = json.loads(headers) - except ValueError: - raise ValueError('cannot JSON decode headers') - if not isinstance(headers, dict): - raise ValueError('JOSE header is not a dict %r' % headers) - if 'typ' in headers and headers.get('typ') != 'JWT': - raise ValueError('JOSE type is not JWT: %s' % headers) - try: - payload = base64url_decode(payload[1]) - except TypeError as e: - raise ValueError('payload is not base64 decodable: %s' % e) - try: - payload = json.loads(payload) - except ValueError as e: - raise ValueError('invalid JSON payload: %s' % e) - if not isinstance(payload, dict): - raise ValueError('JOSE payload is not a dict %r' % payload) - # FIXME : really check signature !!! - if 'alg' not in headers or headers['alg'] is None or headers['alg'] == 'none': - raise ValueError('unsigned token: %s' % headers) - return payload + jwt = JWT() + jwt.deserialize(encoded, key) + return json_decode(jwt.claims) REQUIRED_ID_TOKEN_KEYS = set(['iss', 'sub', 'aud', 'exp', 'iat']) @@ -131,52 +111,55 @@ def parse_timestamp(tstamp): return datetime.datetime.fromtimestamp(tstamp, utc) -class IDToken(str): +class IDTokenError(ValueError): + pass + + +class IDToken(object): auth_time = None nonce = None - def __new__(cls, encoded): - return str.__new__(cls, encoded) - def __init__(self, encoded): - decoded = parse_id_token(encoded) - if not decoded: - raise ValueError('invalid id_token') + if not isinstance(encoded, (six.binary_type, six.string_types)): + raise IDTokenError( + _('Encoded ID Token must be either binary or string data')) + self._encoded = encoded + + def deserialize(self, provider): + try: + decoded = parse_id_token(self._encoded, provider) + if not decoded: + raise JWException(_('invalid id_token')) + except JWException as e: + raise IDTokenError(e) keys = set(decoded.keys()) # check fields are ok if not keys.issuperset(REQUIRED_ID_TOKEN_KEYS): - raise ValueError('missing field: %s' % (REQUIRED_ID_TOKEN_KEYS - keys)) + raise IDTokenError( + _('missing field: %s') % (REQUIRED_ID_TOKEN_KEYS - keys)) for key in keys: - if key == 'aud': - if not isinstance(decoded['aud'], (six.text_type, list)): - raise ValueError('invalid aud value: %r' % decoded['aud']) - if isinstance(decoded['aud'], list) and not all(isinstance(v, six.text_type) for v in - decoded['aud']): - raise ValueError('invalid aud value: %r' % decoded['aud']) - elif key == 'amr': + if key == 'amr': if not isinstance(decoded['amr'], list): - raise ValueError('invalid amr value: %s' % decoded['amr']) + raise IDTokenError( + _('invalid amr value: %s') % decoded['amr']) if not all(isinstance(v, six.text_type) for v in decoded['amr']): - raise ValueError('invalid amr value: %s' % decoded['amr']) + raise IDTokenError( + _('invalid amr value: %s') % decoded['amr']) elif key in KEY_TYPES: if not isinstance(decoded[key], KEY_TYPES[key]): - raise ValueError('invalid %s value: %s' % (key, decoded[key])) + raise IDTokenError( + _('invalid %s value: %s') % (key, decoded[key])) self.iss = decoded.pop('iss') self.sub = decoded.pop('sub') self.aud = decoded.pop('aud') - try: - self.exp = parse_timestamp(decoded.pop('exp')) - except ValueError as e: - raise ValueError('invalid exp value: %s' % e) - try: - self.iat = parse_timestamp(decoded.pop('iat')) - except ValueError as e: - raise ValueError('invalid iat value: %s' % e) + self.exp = parse_timestamp(decoded.pop('exp')) + self.iat = parse_timestamp(decoded.pop('iat')) if 'auth_time' in decoded: try: self.auth_time = parse_timestamp(decoded.pop('auth_time')) except ValueError as e: - raise ValueError('invalid auth_time value: %s' % e) + raise IDTokenError( + _('invalid auth_time value: %s') % e) self.nonce = decoded.pop('nonce', None) self.acr = decoded.pop('acr', None) self.azp = decoded.pop('azp', None) diff --git a/src/authentic2_auth_oidc/views.py b/src/authentic2_auth_oidc/views.py index 1e71a3f37..e85eb3edb 100644 --- a/src/authentic2_auth_oidc/views.py +++ b/src/authentic2_auth_oidc/views.py @@ -198,7 +198,7 @@ class LoginCallback(View): return self.continue_to_next_url() logger.info(u'got token response %s', result) access_token = result.get('access_token') - user = authenticate(request, access_token=access_token, nonce=nonce, id_token=result['id_token']) + user = authenticate(request, access_token=access_token, nonce=nonce, id_token=result['id_token'], provider=provider) if user: # remember last tokens for logout tokens = request.session.setdefault('auth_oidc', {}).setdefault('tokens', []) diff --git a/tests/test_auth_oidc.py b/tests/test_auth_oidc.py index dbb55eef1..5fc4ef3fc 100644 --- a/tests/test_auth_oidc.py +++ b/tests/test_auth_oidc.py @@ -24,6 +24,8 @@ import random from jwcrypto.jwk import JWKSet, JWK from jwcrypto.jwt import JWT +from jwcrypto.jws import JWS, InvalidJWSObject +from jwcrypto.common import base64url_encode, base64url_decode, json_encode from httmock import urlmatch, HTTMock @@ -35,73 +37,70 @@ from django.utils.six.moves.urllib import parse as urlparse from django_rbac.utils import get_ou_model -from authentic2_auth_oidc.utils import (base64url_decode, parse_id_token, IDToken, get_providers, - has_providers, register_issuer) +from authentic2_auth_oidc.utils import (parse_id_token, IDToken, get_providers, + has_providers, register_issuer, IDTokenError) from authentic2_auth_oidc.models import OIDCProvider, OIDCClaimMapping from authentic2.models import AttributeValue from authentic2.utils import timestamp_from_datetime, last_authentication_event from authentic2.a2_rbac.utils import get_default_ou -from authentic2.crypto import base64url_encode import utils +pytestmark = pytest.mark.django_db + def test_base64url_decode(): - with pytest.raises(TypeError): + with pytest.raises(ValueError): base64url_decode('x') base64url_decode('aa') -header = 'eyJhbGciOiJSUzI1NiIsImtpZCI6IjFlOWdkazcifQ' -payload = ('ewogImlzcyI6ICJodHRw' - 'Oi8vc2VydmVyLmV4YW1wbGUuY29tIiwKICJzdWIiOiAiMjQ4Mjg5NzYxMDAxIiw' - 'KICJhdWQiOiAiczZCaGRSa3F0MyIsCiAibm9uY2UiOiAibi0wUzZfV3pBMk1qIi' - 'wKICJleHAiOiAxMzExMjgxOTcwLAogImlhdCI6IDEzMTEyODA5NzAKfQ') -signature = ('ggW8hZ' - '1EuVLuxNuuIJKX_V8a_OMXzR0EHR9R6jgdqrOOF4daGU96Sr_P6qJp6IcmD3HP9' - '9Obi1PRs-cwh3LO-p146waJ8IhehcwL7F09JdijmBqkvPeB2T9CJNqeGpe-gccM' - 'g4vfKjkM8FcGvnzZUN4_KSP0aAp1tOJ1zZwgjxqGByKHiOtX7TpdQyHE5lcMiKP' - 'XfEIQILVq0pc_E2DzL7emopWoaoZTF_m0_N0YzFC6g6EJbOEoRoSK5hoDalrcvR' - 'YLSrQAZZKflyuVCyixEoV9GfNQC3_osjzw2PAithfubEEBLuVVk4XUVrWOLrLl0' - 'nx7RkKU8NXNHq-rvKMzqg') +header_rsa_decoded = {'alg': 'RS256', 'kid': '1e9gdk7'} +header_hmac_decoded = {'alg': 'HS256'} payload_decoded = { 'sub': '248289761001', 'iss': 'http://server.example.com', 'aud': 's6BhdRkqt3', 'nonce': 'n-0S6_WzA2Mj', 'iat': 1311280970, - 'exp': 1311281970, + 'exp': 2201094278, } +header_rsa = 'eyJhbGciOiJSUzI1NiIsImtpZCI6IjFlOWdkazcifQ' +header_hmac = 'eyJhbGciOiJIUzI1NiJ9' +payload = ('eyJhdWQiOiJzNkJoZFJrcXQzIiwiZXhwIjoyMjAxMDk0Mjc4LCJpYXQiOjEzMTEyOD' + 'A5NzAsImlzcyI6Imh0dHA6Ly9zZXJ2ZXIuZXhhbXBsZS5jb20iLCJub25jZSI6Im4t' + 'MFM2X1d6QTJNaiIsInN1YiI6IjI0ODI4OTc2MTAwMSJ9') -def test_parse_id_token(): - # example taken from https://tools.ietf.org/html/rfc7519#section-3.1 - assert parse_id_token('%s.%s.%s' % (header, payload, signature)) == payload_decoded - with pytest.raises(ValueError): - parse_id_token('x%s.%s.%s' % (header, payload, signature)) - with pytest.raises(ValueError): - parse_id_token('%s.%s.%s' % ('$', payload, signature)) - with pytest.raises(ValueError): - parse_id_token('%s.x%s.%s' % (header, payload, signature)) - with pytest.raises(ValueError): - parse_id_token('%s.%s.%s' % (header, '$', signature)) - # signagure is currently ignored - assert parse_id_token('%s.%s.x%s' % (header, payload, signature)) == payload_decoded - assert parse_id_token('%s.%s.%s' % (header, payload, '-')) == payload_decoded +def test_parse_id_token(code, oidc_provider, oidc_provider_jwkset, header, + signature): + with oidc_provider_mock(oidc_provider, oidc_provider_jwkset, code): + with pytest.raises(InvalidJWSObject): + parse_id_token('x%s.%s.%s' % (header, payload, signature), oidc_provider) + with pytest.raises(InvalidJWSObject): + parse_id_token('%s.%s.%s' % ('$', payload, signature), oidc_provider) + with pytest.raises(InvalidJWSObject): + parse_id_token('%s.x%s.%s' % (header, payload, signature), oidc_provider) + with pytest.raises(InvalidJWSObject): + parse_id_token('%s.%s.%s' % (header, '$', signature), oidc_provider) + with pytest.raises(InvalidJWSObject): + parse_id_token('%s.%s.%s' % (header, payload, '-'), oidc_provider) + assert parse_id_token('%s.%s.%s' % (header, payload, signature), oidc_provider) -def test_idtoken(): +def test_idtoken(oidc_provider, header, signature): token = IDToken('%s.%s.%s' % (header, payload, signature)) + token.deserialize(oidc_provider) assert token.sub == payload_decoded['sub'] assert token.iss == payload_decoded['iss'] assert token.aud == payload_decoded['aud'] assert token.nonce == payload_decoded['nonce'] assert token.iat == datetime.datetime(2011, 7, 21, 20, 42, 50, tzinfo=utc) - assert token.exp == datetime.datetime(2011, 7, 21, 20, 59, 30, tzinfo=utc) + assert token.exp == datetime.datetime(2039, 10, 1, 15, 4, 38, tzinfo=utc) @pytest.fixture def oidc_provider_jwkset(): - key = JWK.generate(kty='RSA', size=512) + key = JWK.generate(kty='RSA', size=512, kid='1e9gdk7') jwkset = JWKSet() jwkset.add(key) return jwkset @@ -131,12 +130,12 @@ def oidc_provider(request, db, oidc_provider_jwkset): provider = OIDCProvider.objects.create( ou=get_default_ou(), name='OIDIDP', - issuer='https://idp.example.com/', - authorization_endpoint='https://idp.example.com/authorize', - token_endpoint='https://idp.example.com/token', - end_session_endpoint='https://idp.example.com/logout', - userinfo_endpoint='https://idp.example.com/user_info', - token_revocation_endpoint='https://idp.example.com/revoke', + issuer='http://server.example.com', + authorization_endpoint='https://server.example.com/authorize', + token_endpoint='https://server.example.com/token', + end_session_endpoint='https://server.example.com/logout', + userinfo_endpoint='https://server.example.com/user_info', + token_revocation_endpoint='https://server.example.com/revoke', max_auth_age=10, strategy=OIDCProvider.STRATEGY_CREATE, jwkset_json=jwkset, @@ -182,6 +181,28 @@ def code(): return 'xxxx' +@pytest.fixture +def header(oidc_provider): + if oidc_provider.idtoken_algo == OIDCProvider.ALGO_RSA: + return header_rsa + elif oidc_provider.idtoken_algo == OIDCProvider.ALGO_HMAC: + return header_hmac + + +@pytest.fixture +def signature(oidc_provider): + if oidc_provider.idtoken_algo == OIDCProvider.ALGO_RSA: + key = oidc_provider.jwkset.get_key(kid='1e9gdk7') + header_decoded = header_rsa_decoded + elif oidc_provider.idtoken_algo == OIDCProvider.ALGO_HMAC: + key = JWK(kty='oct', k=base64url_encode( + oidc_provider.client_secret.encode('utf-8'))) + header_decoded = header_hmac_decoded + jws = JWS(payload=json_encode(payload_decoded)) + jws.add_signature(key=key, protected=header_decoded) + return json.loads(jws.serialize())['signature'] + + def oidc_provider_mock(oidc_provider, oidc_provider_jwkset, code, extra_id_token=None, extra_user_info=None, sub='john.doe', nonce=None): token_endpoint = urlparse.urlparse(oidc_provider.token_endpoint) @@ -204,7 +225,7 @@ def oidc_provider_mock(oidc_provider, oidc_provider_jwkset, code, extra_id_token id_token.update(extra_id_token) if oidc_provider.idtoken_algo == OIDCProvider.ALGO_RSA: - jwt = JWT(header={'alg': 'RS256'}, + jwt = JWT(header={'alg': 'RS256', 'kid': '1e9gdk7'}, claims=id_token) jwt.make_signed_token(list(oidc_provider_jwkset['keys'])[0]) else: @@ -331,7 +352,7 @@ def test_sso(app, caplog, code, oidc_provider, oidc_provider_jwkset, login_url, with oidc_provider_mock(oidc_provider, oidc_provider_jwkset, code, extra_id_token={'iat': 1}): response = app.get(login_callback_url, params={'code': code, 'state': query['state']}) - with utils.check_log(caplog, 'id_token expired'): + with utils.check_log(caplog, 'invalid id_token %r'): with oidc_provider_mock(oidc_provider, oidc_provider_jwkset, code, extra_id_token={'exp': 1}): response = app.get(login_callback_url, params={'code': code, 'state': query['state']}) @@ -395,7 +416,7 @@ def test_sso(app, caplog, code, oidc_provider, oidc_provider_jwkset, login_url, with utils.check_log(caplog, 'revoked token from OIDC'): with oidc_provider_mock(oidc_provider, oidc_provider_jwkset, code): response = response.click(href='logout') - assert response.location.startswith('https://idp.example.com/logout?') + assert response.location.startswith('https://server.example.com/logout?') def test_show_on_login_page(app, oidc_provider): @@ -461,7 +482,7 @@ def test_strategy_find_uuid(app, caplog, code, oidc_provider, oidc_provider_jwks with utils.check_log(caplog, 'revoked token from OIDC'): with oidc_provider_mock(oidc_provider, oidc_provider_jwkset, code, nonce=nonce): response = response.click(href='logout') - assert response.location.startswith('https://idp.example.com/logout?') + assert response.location.startswith('https://server.example.com/logout?') def test_register_issuer(db, app, caplog, oidc_provider_jwkset): @@ -489,7 +510,7 @@ def test_register_issuer(db, app, caplog, oidc_provider_jwkset): openid_configuration=oidc_conf) -def test_required_keys(db, caplog): +def test_required_keys(db, oidc_provider, header, signature, caplog): erroneous_payload = base64url_encode(json.dumps({ 'sub': '248289761001', 'iss': 'http://server.example.com', @@ -498,6 +519,7 @@ def test_required_keys(db, caplog): 'extra_stuff': 'hi there', # Wrong claim })) - with pytest.raises(ValueError) as e: + with pytest.raises(IDTokenError): with utils.check_log(caplog, 'missing field'): - IDToken('{}.{}.{}'.format(header, erroneous_payload, signature)) + token = IDToken('{}.{}.{}'.format(header, erroneous_payload, signature)) + token.deserialize(oidc_provider)