oidc authn: verify id token signature (#31862)

This commit is contained in:
Paul Marillonnet 2019-08-09 15:31:27 +02:00
parent e472246f3c
commit 4cc45665b7
4 changed files with 128 additions and 126 deletions

View File

@ -36,14 +36,15 @@ from . import models, utils
class OIDCBackend(ModelBackend):
def authenticate(self, request, access_token=None, id_token=None, nonce=None):
def authenticate(self, request, access_token=None, id_token=None, nonce=None, provider=None):
logger = logging.getLogger(__name__)
if id_token is None:
if None in (id_token, provider):
return
original_id_token = id_token
try:
id_token = utils.IDToken(id_token)
except ValueError as e:
id_token.deserialize(provider)
except utils.IDTokenError as e:
logger.warning(u'auth_oidc: invalid id_token %r: %s', id_token, e)
return None
@ -98,10 +99,6 @@ class OIDCBackend(ModelBackend):
id_token.azp, provider.client_id)
return None
if id_token.exp < now():
logger.warning(u'auth_oidc: id_token expired %s', id_token.exp)
return None
if provider.max_auth_age:
if not id_token.iat:
logger.warning('auth_oidc: provider configured for fresh authentication but iat is '

View File

@ -15,9 +15,6 @@
# along with this program. If not, see <http://www.gnu.org/licenses/>.
import datetime
import base64
import json
import requests
from django.utils import six
@ -30,6 +27,11 @@ from authentic2.decorators import GlobalCache
from authentic2.models import Attribute
from authentic2.a2_rbac.utils import get_default_ou
from jwcrypto.jwt import JWT, JWTMissingKey
from jwcrypto.jwk import JWK
from jwcrypto.common import (JWException, InvalidJWAAlgorithm, json_decode,
base64url_encode)
from . import models
TIMEOUT = 1
@ -66,49 +68,27 @@ def get_provider_by_issuer(issuer):
return models.OIDCProvider.objects.prefetch_related('claim_mappings').get(issuer=issuer)
def base64url_decode(input):
rem = len(input) % 4
if rem > 0:
input += b'=' * (4 - rem)
return base64.urlsafe_b64decode(input)
def parse_id_token(encoded, provider):
""" May raise any subclass of jwcrypto.common.JWException """
jwt = JWT()
jwt.deserialize(encoded, None)
header = jwt.token.jose_header
if header['alg'] in ('RS256', 'RS384', 'RS512'):
key = provider.jwkset.get_key(kid=header.get('kid'))
if not key:
raise JWTMissingKey(_('Unknown RSA key identifier %s for provider %s')
% (header.get('kid'), provider))
elif header['alg'] in ('HS256', 'HS384', 'HS512'):
key = JWK(kty='oct', k=base64url_encode(
provider.client_secret.encode('utf-8')))
else:
raise InvalidJWAAlgorithm(
_('Unsupported %s signature algorithm') % header['alg'])
def parse_id_token(id_token):
try:
id_token = str(id_token)
except UnicodeDecodeError:
raise ValueError('invalid characters in id_token')
payload = id_token.split('.')
if len(payload) == 5:
raise ValueError('encrypted IDToken is unsupported')
if len(payload) != 3:
raise ValueError('IDToken does not have three parts, %d found' % len(payload))
try:
headers = base64url_decode(payload[0])
except TypeError as e:
raise ValueError('header is not base64 decodable: %s' % e)
try:
headers = json.loads(headers)
except ValueError:
raise ValueError('cannot JSON decode headers')
if not isinstance(headers, dict):
raise ValueError('JOSE header is not a dict %r' % headers)
if 'typ' in headers and headers.get('typ') != 'JWT':
raise ValueError('JOSE type is not JWT: %s' % headers)
try:
payload = base64url_decode(payload[1])
except TypeError as e:
raise ValueError('payload is not base64 decodable: %s' % e)
try:
payload = json.loads(payload)
except ValueError as e:
raise ValueError('invalid JSON payload: %s' % e)
if not isinstance(payload, dict):
raise ValueError('JOSE payload is not a dict %r' % payload)
# FIXME : really check signature !!!
if 'alg' not in headers or headers['alg'] is None or headers['alg'] == 'none':
raise ValueError('unsigned token: %s' % headers)
return payload
jwt = JWT()
jwt.deserialize(encoded, key)
return json_decode(jwt.claims)
REQUIRED_ID_TOKEN_KEYS = set(['iss', 'sub', 'aud', 'exp', 'iat'])
@ -131,52 +111,55 @@ def parse_timestamp(tstamp):
return datetime.datetime.fromtimestamp(tstamp, utc)
class IDToken(str):
class IDTokenError(ValueError):
pass
class IDToken(object):
auth_time = None
nonce = None
def __new__(cls, encoded):
return str.__new__(cls, encoded)
def __init__(self, encoded):
decoded = parse_id_token(encoded)
if not decoded:
raise ValueError('invalid id_token')
if not isinstance(encoded, (six.binary_type, six.string_types)):
raise IDTokenError(
_('Encoded ID Token must be either binary or string data'))
self._encoded = encoded
def deserialize(self, provider):
try:
decoded = parse_id_token(self._encoded, provider)
if not decoded:
raise JWException(_('invalid id_token'))
except JWException as e:
raise IDTokenError(e)
keys = set(decoded.keys())
# check fields are ok
if not keys.issuperset(REQUIRED_ID_TOKEN_KEYS):
raise ValueError('missing field: %s' % (REQUIRED_ID_TOKEN_KEYS - keys))
raise IDTokenError(
_('missing field: %s') % (REQUIRED_ID_TOKEN_KEYS - keys))
for key in keys:
if key == 'aud':
if not isinstance(decoded['aud'], (six.text_type, list)):
raise ValueError('invalid aud value: %r' % decoded['aud'])
if isinstance(decoded['aud'], list) and not all(isinstance(v, six.text_type) for v in
decoded['aud']):
raise ValueError('invalid aud value: %r' % decoded['aud'])
elif key == 'amr':
if key == 'amr':
if not isinstance(decoded['amr'], list):
raise ValueError('invalid amr value: %s' % decoded['amr'])
raise IDTokenError(
_('invalid amr value: %s') % decoded['amr'])
if not all(isinstance(v, six.text_type) for v in decoded['amr']):
raise ValueError('invalid amr value: %s' % decoded['amr'])
raise IDTokenError(
_('invalid amr value: %s') % decoded['amr'])
elif key in KEY_TYPES:
if not isinstance(decoded[key], KEY_TYPES[key]):
raise ValueError('invalid %s value: %s' % (key, decoded[key]))
raise IDTokenError(
_('invalid %s value: %s') % (key, decoded[key]))
self.iss = decoded.pop('iss')
self.sub = decoded.pop('sub')
self.aud = decoded.pop('aud')
try:
self.exp = parse_timestamp(decoded.pop('exp'))
except ValueError as e:
raise ValueError('invalid exp value: %s' % e)
try:
self.iat = parse_timestamp(decoded.pop('iat'))
except ValueError as e:
raise ValueError('invalid iat value: %s' % e)
self.exp = parse_timestamp(decoded.pop('exp'))
self.iat = parse_timestamp(decoded.pop('iat'))
if 'auth_time' in decoded:
try:
self.auth_time = parse_timestamp(decoded.pop('auth_time'))
except ValueError as e:
raise ValueError('invalid auth_time value: %s' % e)
raise IDTokenError(
_('invalid auth_time value: %s') % e)
self.nonce = decoded.pop('nonce', None)
self.acr = decoded.pop('acr', None)
self.azp = decoded.pop('azp', None)

View File

@ -198,7 +198,7 @@ class LoginCallback(View):
return self.continue_to_next_url()
logger.info(u'got token response %s', result)
access_token = result.get('access_token')
user = authenticate(request, access_token=access_token, nonce=nonce, id_token=result['id_token'])
user = authenticate(request, access_token=access_token, nonce=nonce, id_token=result['id_token'], provider=provider)
if user:
# remember last tokens for logout
tokens = request.session.setdefault('auth_oidc', {}).setdefault('tokens', [])

View File

@ -24,6 +24,8 @@ import random
from jwcrypto.jwk import JWKSet, JWK
from jwcrypto.jwt import JWT
from jwcrypto.jws import JWS, InvalidJWSObject
from jwcrypto.common import base64url_encode, base64url_decode, json_encode
from httmock import urlmatch, HTTMock
@ -35,73 +37,70 @@ from django.utils.six.moves.urllib import parse as urlparse
from django_rbac.utils import get_ou_model
from authentic2_auth_oidc.utils import (base64url_decode, parse_id_token, IDToken, get_providers,
has_providers, register_issuer)
from authentic2_auth_oidc.utils import (parse_id_token, IDToken, get_providers,
has_providers, register_issuer, IDTokenError)
from authentic2_auth_oidc.models import OIDCProvider, OIDCClaimMapping
from authentic2.models import AttributeValue
from authentic2.utils import timestamp_from_datetime, last_authentication_event
from authentic2.a2_rbac.utils import get_default_ou
from authentic2.crypto import base64url_encode
import utils
pytestmark = pytest.mark.django_db
def test_base64url_decode():
with pytest.raises(TypeError):
with pytest.raises(ValueError):
base64url_decode('x')
base64url_decode('aa')
header = 'eyJhbGciOiJSUzI1NiIsImtpZCI6IjFlOWdkazcifQ'
payload = ('ewogImlzcyI6ICJodHRw'
'Oi8vc2VydmVyLmV4YW1wbGUuY29tIiwKICJzdWIiOiAiMjQ4Mjg5NzYxMDAxIiw'
'KICJhdWQiOiAiczZCaGRSa3F0MyIsCiAibm9uY2UiOiAibi0wUzZfV3pBMk1qIi'
'wKICJleHAiOiAxMzExMjgxOTcwLAogImlhdCI6IDEzMTEyODA5NzAKfQ')
signature = ('ggW8hZ'
'1EuVLuxNuuIJKX_V8a_OMXzR0EHR9R6jgdqrOOF4daGU96Sr_P6qJp6IcmD3HP9'
'9Obi1PRs-cwh3LO-p146waJ8IhehcwL7F09JdijmBqkvPeB2T9CJNqeGpe-gccM'
'g4vfKjkM8FcGvnzZUN4_KSP0aAp1tOJ1zZwgjxqGByKHiOtX7TpdQyHE5lcMiKP'
'XfEIQILVq0pc_E2DzL7emopWoaoZTF_m0_N0YzFC6g6EJbOEoRoSK5hoDalrcvR'
'YLSrQAZZKflyuVCyixEoV9GfNQC3_osjzw2PAithfubEEBLuVVk4XUVrWOLrLl0'
'nx7RkKU8NXNHq-rvKMzqg')
header_rsa_decoded = {'alg': 'RS256', 'kid': '1e9gdk7'}
header_hmac_decoded = {'alg': 'HS256'}
payload_decoded = {
'sub': '248289761001',
'iss': 'http://server.example.com',
'aud': 's6BhdRkqt3',
'nonce': 'n-0S6_WzA2Mj',
'iat': 1311280970,
'exp': 1311281970,
'exp': 2201094278,
}
header_rsa = 'eyJhbGciOiJSUzI1NiIsImtpZCI6IjFlOWdkazcifQ'
header_hmac = 'eyJhbGciOiJIUzI1NiJ9'
payload = ('eyJhdWQiOiJzNkJoZFJrcXQzIiwiZXhwIjoyMjAxMDk0Mjc4LCJpYXQiOjEzMTEyOD'
'A5NzAsImlzcyI6Imh0dHA6Ly9zZXJ2ZXIuZXhhbXBsZS5jb20iLCJub25jZSI6Im4t'
'MFM2X1d6QTJNaiIsInN1YiI6IjI0ODI4OTc2MTAwMSJ9')
def test_parse_id_token():
# example taken from https://tools.ietf.org/html/rfc7519#section-3.1
assert parse_id_token('%s.%s.%s' % (header, payload, signature)) == payload_decoded
with pytest.raises(ValueError):
parse_id_token('x%s.%s.%s' % (header, payload, signature))
with pytest.raises(ValueError):
parse_id_token('%s.%s.%s' % ('$', payload, signature))
with pytest.raises(ValueError):
parse_id_token('%s.x%s.%s' % (header, payload, signature))
with pytest.raises(ValueError):
parse_id_token('%s.%s.%s' % (header, '$', signature))
# signagure is currently ignored
assert parse_id_token('%s.%s.x%s' % (header, payload, signature)) == payload_decoded
assert parse_id_token('%s.%s.%s' % (header, payload, '-')) == payload_decoded
def test_parse_id_token(code, oidc_provider, oidc_provider_jwkset, header,
signature):
with oidc_provider_mock(oidc_provider, oidc_provider_jwkset, code):
with pytest.raises(InvalidJWSObject):
parse_id_token('x%s.%s.%s' % (header, payload, signature), oidc_provider)
with pytest.raises(InvalidJWSObject):
parse_id_token('%s.%s.%s' % ('$', payload, signature), oidc_provider)
with pytest.raises(InvalidJWSObject):
parse_id_token('%s.x%s.%s' % (header, payload, signature), oidc_provider)
with pytest.raises(InvalidJWSObject):
parse_id_token('%s.%s.%s' % (header, '$', signature), oidc_provider)
with pytest.raises(InvalidJWSObject):
parse_id_token('%s.%s.%s' % (header, payload, '-'), oidc_provider)
assert parse_id_token('%s.%s.%s' % (header, payload, signature), oidc_provider)
def test_idtoken():
def test_idtoken(oidc_provider, header, signature):
token = IDToken('%s.%s.%s' % (header, payload, signature))
token.deserialize(oidc_provider)
assert token.sub == payload_decoded['sub']
assert token.iss == payload_decoded['iss']
assert token.aud == payload_decoded['aud']
assert token.nonce == payload_decoded['nonce']
assert token.iat == datetime.datetime(2011, 7, 21, 20, 42, 50, tzinfo=utc)
assert token.exp == datetime.datetime(2011, 7, 21, 20, 59, 30, tzinfo=utc)
assert token.exp == datetime.datetime(2039, 10, 1, 15, 4, 38, tzinfo=utc)
@pytest.fixture
def oidc_provider_jwkset():
key = JWK.generate(kty='RSA', size=512)
key = JWK.generate(kty='RSA', size=512, kid='1e9gdk7')
jwkset = JWKSet()
jwkset.add(key)
return jwkset
@ -131,12 +130,12 @@ def oidc_provider(request, db, oidc_provider_jwkset):
provider = OIDCProvider.objects.create(
ou=get_default_ou(),
name='OIDIDP',
issuer='https://idp.example.com/',
authorization_endpoint='https://idp.example.com/authorize',
token_endpoint='https://idp.example.com/token',
end_session_endpoint='https://idp.example.com/logout',
userinfo_endpoint='https://idp.example.com/user_info',
token_revocation_endpoint='https://idp.example.com/revoke',
issuer='http://server.example.com',
authorization_endpoint='https://server.example.com/authorize',
token_endpoint='https://server.example.com/token',
end_session_endpoint='https://server.example.com/logout',
userinfo_endpoint='https://server.example.com/user_info',
token_revocation_endpoint='https://server.example.com/revoke',
max_auth_age=10,
strategy=OIDCProvider.STRATEGY_CREATE,
jwkset_json=jwkset,
@ -182,6 +181,28 @@ def code():
return 'xxxx'
@pytest.fixture
def header(oidc_provider):
if oidc_provider.idtoken_algo == OIDCProvider.ALGO_RSA:
return header_rsa
elif oidc_provider.idtoken_algo == OIDCProvider.ALGO_HMAC:
return header_hmac
@pytest.fixture
def signature(oidc_provider):
if oidc_provider.idtoken_algo == OIDCProvider.ALGO_RSA:
key = oidc_provider.jwkset.get_key(kid='1e9gdk7')
header_decoded = header_rsa_decoded
elif oidc_provider.idtoken_algo == OIDCProvider.ALGO_HMAC:
key = JWK(kty='oct', k=base64url_encode(
oidc_provider.client_secret.encode('utf-8')))
header_decoded = header_hmac_decoded
jws = JWS(payload=json_encode(payload_decoded))
jws.add_signature(key=key, protected=header_decoded)
return json.loads(jws.serialize())['signature']
def oidc_provider_mock(oidc_provider, oidc_provider_jwkset, code, extra_id_token=None,
extra_user_info=None, sub='john.doe', nonce=None):
token_endpoint = urlparse.urlparse(oidc_provider.token_endpoint)
@ -204,7 +225,7 @@ def oidc_provider_mock(oidc_provider, oidc_provider_jwkset, code, extra_id_token
id_token.update(extra_id_token)
if oidc_provider.idtoken_algo == OIDCProvider.ALGO_RSA:
jwt = JWT(header={'alg': 'RS256'},
jwt = JWT(header={'alg': 'RS256', 'kid': '1e9gdk7'},
claims=id_token)
jwt.make_signed_token(list(oidc_provider_jwkset['keys'])[0])
else:
@ -331,7 +352,7 @@ def test_sso(app, caplog, code, oidc_provider, oidc_provider_jwkset, login_url,
with oidc_provider_mock(oidc_provider, oidc_provider_jwkset, code,
extra_id_token={'iat': 1}):
response = app.get(login_callback_url, params={'code': code, 'state': query['state']})
with utils.check_log(caplog, 'id_token expired'):
with utils.check_log(caplog, 'invalid id_token %r'):
with oidc_provider_mock(oidc_provider, oidc_provider_jwkset, code,
extra_id_token={'exp': 1}):
response = app.get(login_callback_url, params={'code': code, 'state': query['state']})
@ -395,7 +416,7 @@ def test_sso(app, caplog, code, oidc_provider, oidc_provider_jwkset, login_url,
with utils.check_log(caplog, 'revoked token from OIDC'):
with oidc_provider_mock(oidc_provider, oidc_provider_jwkset, code):
response = response.click(href='logout')
assert response.location.startswith('https://idp.example.com/logout?')
assert response.location.startswith('https://server.example.com/logout?')
def test_show_on_login_page(app, oidc_provider):
@ -461,7 +482,7 @@ def test_strategy_find_uuid(app, caplog, code, oidc_provider, oidc_provider_jwks
with utils.check_log(caplog, 'revoked token from OIDC'):
with oidc_provider_mock(oidc_provider, oidc_provider_jwkset, code, nonce=nonce):
response = response.click(href='logout')
assert response.location.startswith('https://idp.example.com/logout?')
assert response.location.startswith('https://server.example.com/logout?')
def test_register_issuer(db, app, caplog, oidc_provider_jwkset):
@ -489,7 +510,7 @@ def test_register_issuer(db, app, caplog, oidc_provider_jwkset):
openid_configuration=oidc_conf)
def test_required_keys(db, caplog):
def test_required_keys(db, oidc_provider, header, signature, caplog):
erroneous_payload = base64url_encode(json.dumps({
'sub': '248289761001',
'iss': 'http://server.example.com',
@ -498,6 +519,7 @@ def test_required_keys(db, caplog):
'extra_stuff': 'hi there', # Wrong claim
}))
with pytest.raises(ValueError) as e:
with pytest.raises(IDTokenError):
with utils.check_log(caplog, 'missing field'):
IDToken('{}.{}.{}'.format(header, erroneous_payload, signature))
token = IDToken('{}.{}.{}'.format(header, erroneous_payload, signature))
token.deserialize(oidc_provider)