oidc authn: verify id token signature (#31862)
This commit is contained in:
parent
e472246f3c
commit
4cc45665b7
|
@ -36,14 +36,15 @@ from . import models, utils
|
|||
|
||||
|
||||
class OIDCBackend(ModelBackend):
|
||||
def authenticate(self, request, access_token=None, id_token=None, nonce=None):
|
||||
def authenticate(self, request, access_token=None, id_token=None, nonce=None, provider=None):
|
||||
logger = logging.getLogger(__name__)
|
||||
if id_token is None:
|
||||
if None in (id_token, provider):
|
||||
return
|
||||
original_id_token = id_token
|
||||
try:
|
||||
id_token = utils.IDToken(id_token)
|
||||
except ValueError as e:
|
||||
id_token.deserialize(provider)
|
||||
except utils.IDTokenError as e:
|
||||
logger.warning(u'auth_oidc: invalid id_token %r: %s', id_token, e)
|
||||
return None
|
||||
|
||||
|
@ -98,10 +99,6 @@ class OIDCBackend(ModelBackend):
|
|||
id_token.azp, provider.client_id)
|
||||
return None
|
||||
|
||||
if id_token.exp < now():
|
||||
logger.warning(u'auth_oidc: id_token expired %s', id_token.exp)
|
||||
return None
|
||||
|
||||
if provider.max_auth_age:
|
||||
if not id_token.iat:
|
||||
logger.warning('auth_oidc: provider configured for fresh authentication but iat is '
|
||||
|
|
|
@ -15,9 +15,6 @@
|
|||
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
|
||||
import datetime
|
||||
import base64
|
||||
import json
|
||||
|
||||
import requests
|
||||
|
||||
from django.utils import six
|
||||
|
@ -30,6 +27,11 @@ from authentic2.decorators import GlobalCache
|
|||
from authentic2.models import Attribute
|
||||
from authentic2.a2_rbac.utils import get_default_ou
|
||||
|
||||
from jwcrypto.jwt import JWT, JWTMissingKey
|
||||
from jwcrypto.jwk import JWK
|
||||
from jwcrypto.common import (JWException, InvalidJWAAlgorithm, json_decode,
|
||||
base64url_encode)
|
||||
|
||||
from . import models
|
||||
|
||||
TIMEOUT = 1
|
||||
|
@ -66,49 +68,27 @@ def get_provider_by_issuer(issuer):
|
|||
return models.OIDCProvider.objects.prefetch_related('claim_mappings').get(issuer=issuer)
|
||||
|
||||
|
||||
def base64url_decode(input):
|
||||
rem = len(input) % 4
|
||||
if rem > 0:
|
||||
input += b'=' * (4 - rem)
|
||||
return base64.urlsafe_b64decode(input)
|
||||
def parse_id_token(encoded, provider):
|
||||
""" May raise any subclass of jwcrypto.common.JWException """
|
||||
jwt = JWT()
|
||||
jwt.deserialize(encoded, None)
|
||||
header = jwt.token.jose_header
|
||||
|
||||
if header['alg'] in ('RS256', 'RS384', 'RS512'):
|
||||
key = provider.jwkset.get_key(kid=header.get('kid'))
|
||||
if not key:
|
||||
raise JWTMissingKey(_('Unknown RSA key identifier %s for provider %s')
|
||||
% (header.get('kid'), provider))
|
||||
elif header['alg'] in ('HS256', 'HS384', 'HS512'):
|
||||
key = JWK(kty='oct', k=base64url_encode(
|
||||
provider.client_secret.encode('utf-8')))
|
||||
else:
|
||||
raise InvalidJWAAlgorithm(
|
||||
_('Unsupported %s signature algorithm') % header['alg'])
|
||||
|
||||
def parse_id_token(id_token):
|
||||
try:
|
||||
id_token = str(id_token)
|
||||
except UnicodeDecodeError:
|
||||
raise ValueError('invalid characters in id_token')
|
||||
payload = id_token.split('.')
|
||||
if len(payload) == 5:
|
||||
raise ValueError('encrypted IDToken is unsupported')
|
||||
if len(payload) != 3:
|
||||
raise ValueError('IDToken does not have three parts, %d found' % len(payload))
|
||||
try:
|
||||
headers = base64url_decode(payload[0])
|
||||
except TypeError as e:
|
||||
raise ValueError('header is not base64 decodable: %s' % e)
|
||||
try:
|
||||
headers = json.loads(headers)
|
||||
except ValueError:
|
||||
raise ValueError('cannot JSON decode headers')
|
||||
if not isinstance(headers, dict):
|
||||
raise ValueError('JOSE header is not a dict %r' % headers)
|
||||
if 'typ' in headers and headers.get('typ') != 'JWT':
|
||||
raise ValueError('JOSE type is not JWT: %s' % headers)
|
||||
try:
|
||||
payload = base64url_decode(payload[1])
|
||||
except TypeError as e:
|
||||
raise ValueError('payload is not base64 decodable: %s' % e)
|
||||
try:
|
||||
payload = json.loads(payload)
|
||||
except ValueError as e:
|
||||
raise ValueError('invalid JSON payload: %s' % e)
|
||||
if not isinstance(payload, dict):
|
||||
raise ValueError('JOSE payload is not a dict %r' % payload)
|
||||
# FIXME : really check signature !!!
|
||||
if 'alg' not in headers or headers['alg'] is None or headers['alg'] == 'none':
|
||||
raise ValueError('unsigned token: %s' % headers)
|
||||
return payload
|
||||
jwt = JWT()
|
||||
jwt.deserialize(encoded, key)
|
||||
return json_decode(jwt.claims)
|
||||
|
||||
|
||||
REQUIRED_ID_TOKEN_KEYS = set(['iss', 'sub', 'aud', 'exp', 'iat'])
|
||||
|
@ -131,52 +111,55 @@ def parse_timestamp(tstamp):
|
|||
return datetime.datetime.fromtimestamp(tstamp, utc)
|
||||
|
||||
|
||||
class IDToken(str):
|
||||
class IDTokenError(ValueError):
|
||||
pass
|
||||
|
||||
|
||||
class IDToken(object):
|
||||
auth_time = None
|
||||
nonce = None
|
||||
|
||||
def __new__(cls, encoded):
|
||||
return str.__new__(cls, encoded)
|
||||
|
||||
def __init__(self, encoded):
|
||||
decoded = parse_id_token(encoded)
|
||||
if not decoded:
|
||||
raise ValueError('invalid id_token')
|
||||
if not isinstance(encoded, (six.binary_type, six.string_types)):
|
||||
raise IDTokenError(
|
||||
_('Encoded ID Token must be either binary or string data'))
|
||||
self._encoded = encoded
|
||||
|
||||
def deserialize(self, provider):
|
||||
try:
|
||||
decoded = parse_id_token(self._encoded, provider)
|
||||
if not decoded:
|
||||
raise JWException(_('invalid id_token'))
|
||||
except JWException as e:
|
||||
raise IDTokenError(e)
|
||||
keys = set(decoded.keys())
|
||||
# check fields are ok
|
||||
if not keys.issuperset(REQUIRED_ID_TOKEN_KEYS):
|
||||
raise ValueError('missing field: %s' % (REQUIRED_ID_TOKEN_KEYS - keys))
|
||||
raise IDTokenError(
|
||||
_('missing field: %s') % (REQUIRED_ID_TOKEN_KEYS - keys))
|
||||
for key in keys:
|
||||
if key == 'aud':
|
||||
if not isinstance(decoded['aud'], (six.text_type, list)):
|
||||
raise ValueError('invalid aud value: %r' % decoded['aud'])
|
||||
if isinstance(decoded['aud'], list) and not all(isinstance(v, six.text_type) for v in
|
||||
decoded['aud']):
|
||||
raise ValueError('invalid aud value: %r' % decoded['aud'])
|
||||
elif key == 'amr':
|
||||
if key == 'amr':
|
||||
if not isinstance(decoded['amr'], list):
|
||||
raise ValueError('invalid amr value: %s' % decoded['amr'])
|
||||
raise IDTokenError(
|
||||
_('invalid amr value: %s') % decoded['amr'])
|
||||
if not all(isinstance(v, six.text_type) for v in decoded['amr']):
|
||||
raise ValueError('invalid amr value: %s' % decoded['amr'])
|
||||
raise IDTokenError(
|
||||
_('invalid amr value: %s') % decoded['amr'])
|
||||
elif key in KEY_TYPES:
|
||||
if not isinstance(decoded[key], KEY_TYPES[key]):
|
||||
raise ValueError('invalid %s value: %s' % (key, decoded[key]))
|
||||
raise IDTokenError(
|
||||
_('invalid %s value: %s') % (key, decoded[key]))
|
||||
self.iss = decoded.pop('iss')
|
||||
self.sub = decoded.pop('sub')
|
||||
self.aud = decoded.pop('aud')
|
||||
try:
|
||||
self.exp = parse_timestamp(decoded.pop('exp'))
|
||||
except ValueError as e:
|
||||
raise ValueError('invalid exp value: %s' % e)
|
||||
try:
|
||||
self.iat = parse_timestamp(decoded.pop('iat'))
|
||||
except ValueError as e:
|
||||
raise ValueError('invalid iat value: %s' % e)
|
||||
self.exp = parse_timestamp(decoded.pop('exp'))
|
||||
self.iat = parse_timestamp(decoded.pop('iat'))
|
||||
if 'auth_time' in decoded:
|
||||
try:
|
||||
self.auth_time = parse_timestamp(decoded.pop('auth_time'))
|
||||
except ValueError as e:
|
||||
raise ValueError('invalid auth_time value: %s' % e)
|
||||
raise IDTokenError(
|
||||
_('invalid auth_time value: %s') % e)
|
||||
self.nonce = decoded.pop('nonce', None)
|
||||
self.acr = decoded.pop('acr', None)
|
||||
self.azp = decoded.pop('azp', None)
|
||||
|
|
|
@ -198,7 +198,7 @@ class LoginCallback(View):
|
|||
return self.continue_to_next_url()
|
||||
logger.info(u'got token response %s', result)
|
||||
access_token = result.get('access_token')
|
||||
user = authenticate(request, access_token=access_token, nonce=nonce, id_token=result['id_token'])
|
||||
user = authenticate(request, access_token=access_token, nonce=nonce, id_token=result['id_token'], provider=provider)
|
||||
if user:
|
||||
# remember last tokens for logout
|
||||
tokens = request.session.setdefault('auth_oidc', {}).setdefault('tokens', [])
|
||||
|
|
|
@ -24,6 +24,8 @@ import random
|
|||
|
||||
from jwcrypto.jwk import JWKSet, JWK
|
||||
from jwcrypto.jwt import JWT
|
||||
from jwcrypto.jws import JWS, InvalidJWSObject
|
||||
from jwcrypto.common import base64url_encode, base64url_decode, json_encode
|
||||
|
||||
from httmock import urlmatch, HTTMock
|
||||
|
||||
|
@ -35,73 +37,70 @@ from django.utils.six.moves.urllib import parse as urlparse
|
|||
|
||||
from django_rbac.utils import get_ou_model
|
||||
|
||||
from authentic2_auth_oidc.utils import (base64url_decode, parse_id_token, IDToken, get_providers,
|
||||
has_providers, register_issuer)
|
||||
from authentic2_auth_oidc.utils import (parse_id_token, IDToken, get_providers,
|
||||
has_providers, register_issuer, IDTokenError)
|
||||
from authentic2_auth_oidc.models import OIDCProvider, OIDCClaimMapping
|
||||
from authentic2.models import AttributeValue
|
||||
from authentic2.utils import timestamp_from_datetime, last_authentication_event
|
||||
from authentic2.a2_rbac.utils import get_default_ou
|
||||
from authentic2.crypto import base64url_encode
|
||||
|
||||
import utils
|
||||
|
||||
pytestmark = pytest.mark.django_db
|
||||
|
||||
|
||||
def test_base64url_decode():
|
||||
with pytest.raises(TypeError):
|
||||
with pytest.raises(ValueError):
|
||||
base64url_decode('x')
|
||||
base64url_decode('aa')
|
||||
|
||||
header = 'eyJhbGciOiJSUzI1NiIsImtpZCI6IjFlOWdkazcifQ'
|
||||
payload = ('ewogImlzcyI6ICJodHRw'
|
||||
'Oi8vc2VydmVyLmV4YW1wbGUuY29tIiwKICJzdWIiOiAiMjQ4Mjg5NzYxMDAxIiw'
|
||||
'KICJhdWQiOiAiczZCaGRSa3F0MyIsCiAibm9uY2UiOiAibi0wUzZfV3pBMk1qIi'
|
||||
'wKICJleHAiOiAxMzExMjgxOTcwLAogImlhdCI6IDEzMTEyODA5NzAKfQ')
|
||||
signature = ('ggW8hZ'
|
||||
'1EuVLuxNuuIJKX_V8a_OMXzR0EHR9R6jgdqrOOF4daGU96Sr_P6qJp6IcmD3HP9'
|
||||
'9Obi1PRs-cwh3LO-p146waJ8IhehcwL7F09JdijmBqkvPeB2T9CJNqeGpe-gccM'
|
||||
'g4vfKjkM8FcGvnzZUN4_KSP0aAp1tOJ1zZwgjxqGByKHiOtX7TpdQyHE5lcMiKP'
|
||||
'XfEIQILVq0pc_E2DzL7emopWoaoZTF_m0_N0YzFC6g6EJbOEoRoSK5hoDalrcvR'
|
||||
'YLSrQAZZKflyuVCyixEoV9GfNQC3_osjzw2PAithfubEEBLuVVk4XUVrWOLrLl0'
|
||||
'nx7RkKU8NXNHq-rvKMzqg')
|
||||
header_rsa_decoded = {'alg': 'RS256', 'kid': '1e9gdk7'}
|
||||
header_hmac_decoded = {'alg': 'HS256'}
|
||||
payload_decoded = {
|
||||
'sub': '248289761001',
|
||||
'iss': 'http://server.example.com',
|
||||
'aud': 's6BhdRkqt3',
|
||||
'nonce': 'n-0S6_WzA2Mj',
|
||||
'iat': 1311280970,
|
||||
'exp': 1311281970,
|
||||
'exp': 2201094278,
|
||||
}
|
||||
header_rsa = 'eyJhbGciOiJSUzI1NiIsImtpZCI6IjFlOWdkazcifQ'
|
||||
header_hmac = 'eyJhbGciOiJIUzI1NiJ9'
|
||||
payload = ('eyJhdWQiOiJzNkJoZFJrcXQzIiwiZXhwIjoyMjAxMDk0Mjc4LCJpYXQiOjEzMTEyOD'
|
||||
'A5NzAsImlzcyI6Imh0dHA6Ly9zZXJ2ZXIuZXhhbXBsZS5jb20iLCJub25jZSI6Im4t'
|
||||
'MFM2X1d6QTJNaiIsInN1YiI6IjI0ODI4OTc2MTAwMSJ9')
|
||||
|
||||
|
||||
def test_parse_id_token():
|
||||
# example taken from https://tools.ietf.org/html/rfc7519#section-3.1
|
||||
assert parse_id_token('%s.%s.%s' % (header, payload, signature)) == payload_decoded
|
||||
with pytest.raises(ValueError):
|
||||
parse_id_token('x%s.%s.%s' % (header, payload, signature))
|
||||
with pytest.raises(ValueError):
|
||||
parse_id_token('%s.%s.%s' % ('$', payload, signature))
|
||||
with pytest.raises(ValueError):
|
||||
parse_id_token('%s.x%s.%s' % (header, payload, signature))
|
||||
with pytest.raises(ValueError):
|
||||
parse_id_token('%s.%s.%s' % (header, '$', signature))
|
||||
# signagure is currently ignored
|
||||
assert parse_id_token('%s.%s.x%s' % (header, payload, signature)) == payload_decoded
|
||||
assert parse_id_token('%s.%s.%s' % (header, payload, '-')) == payload_decoded
|
||||
def test_parse_id_token(code, oidc_provider, oidc_provider_jwkset, header,
|
||||
signature):
|
||||
with oidc_provider_mock(oidc_provider, oidc_provider_jwkset, code):
|
||||
with pytest.raises(InvalidJWSObject):
|
||||
parse_id_token('x%s.%s.%s' % (header, payload, signature), oidc_provider)
|
||||
with pytest.raises(InvalidJWSObject):
|
||||
parse_id_token('%s.%s.%s' % ('$', payload, signature), oidc_provider)
|
||||
with pytest.raises(InvalidJWSObject):
|
||||
parse_id_token('%s.x%s.%s' % (header, payload, signature), oidc_provider)
|
||||
with pytest.raises(InvalidJWSObject):
|
||||
parse_id_token('%s.%s.%s' % (header, '$', signature), oidc_provider)
|
||||
with pytest.raises(InvalidJWSObject):
|
||||
parse_id_token('%s.%s.%s' % (header, payload, '-'), oidc_provider)
|
||||
assert parse_id_token('%s.%s.%s' % (header, payload, signature), oidc_provider)
|
||||
|
||||
|
||||
def test_idtoken():
|
||||
def test_idtoken(oidc_provider, header, signature):
|
||||
token = IDToken('%s.%s.%s' % (header, payload, signature))
|
||||
token.deserialize(oidc_provider)
|
||||
assert token.sub == payload_decoded['sub']
|
||||
assert token.iss == payload_decoded['iss']
|
||||
assert token.aud == payload_decoded['aud']
|
||||
assert token.nonce == payload_decoded['nonce']
|
||||
assert token.iat == datetime.datetime(2011, 7, 21, 20, 42, 50, tzinfo=utc)
|
||||
assert token.exp == datetime.datetime(2011, 7, 21, 20, 59, 30, tzinfo=utc)
|
||||
assert token.exp == datetime.datetime(2039, 10, 1, 15, 4, 38, tzinfo=utc)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def oidc_provider_jwkset():
|
||||
key = JWK.generate(kty='RSA', size=512)
|
||||
key = JWK.generate(kty='RSA', size=512, kid='1e9gdk7')
|
||||
jwkset = JWKSet()
|
||||
jwkset.add(key)
|
||||
return jwkset
|
||||
|
@ -131,12 +130,12 @@ def oidc_provider(request, db, oidc_provider_jwkset):
|
|||
provider = OIDCProvider.objects.create(
|
||||
ou=get_default_ou(),
|
||||
name='OIDIDP',
|
||||
issuer='https://idp.example.com/',
|
||||
authorization_endpoint='https://idp.example.com/authorize',
|
||||
token_endpoint='https://idp.example.com/token',
|
||||
end_session_endpoint='https://idp.example.com/logout',
|
||||
userinfo_endpoint='https://idp.example.com/user_info',
|
||||
token_revocation_endpoint='https://idp.example.com/revoke',
|
||||
issuer='http://server.example.com',
|
||||
authorization_endpoint='https://server.example.com/authorize',
|
||||
token_endpoint='https://server.example.com/token',
|
||||
end_session_endpoint='https://server.example.com/logout',
|
||||
userinfo_endpoint='https://server.example.com/user_info',
|
||||
token_revocation_endpoint='https://server.example.com/revoke',
|
||||
max_auth_age=10,
|
||||
strategy=OIDCProvider.STRATEGY_CREATE,
|
||||
jwkset_json=jwkset,
|
||||
|
@ -182,6 +181,28 @@ def code():
|
|||
return 'xxxx'
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def header(oidc_provider):
|
||||
if oidc_provider.idtoken_algo == OIDCProvider.ALGO_RSA:
|
||||
return header_rsa
|
||||
elif oidc_provider.idtoken_algo == OIDCProvider.ALGO_HMAC:
|
||||
return header_hmac
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def signature(oidc_provider):
|
||||
if oidc_provider.idtoken_algo == OIDCProvider.ALGO_RSA:
|
||||
key = oidc_provider.jwkset.get_key(kid='1e9gdk7')
|
||||
header_decoded = header_rsa_decoded
|
||||
elif oidc_provider.idtoken_algo == OIDCProvider.ALGO_HMAC:
|
||||
key = JWK(kty='oct', k=base64url_encode(
|
||||
oidc_provider.client_secret.encode('utf-8')))
|
||||
header_decoded = header_hmac_decoded
|
||||
jws = JWS(payload=json_encode(payload_decoded))
|
||||
jws.add_signature(key=key, protected=header_decoded)
|
||||
return json.loads(jws.serialize())['signature']
|
||||
|
||||
|
||||
def oidc_provider_mock(oidc_provider, oidc_provider_jwkset, code, extra_id_token=None,
|
||||
extra_user_info=None, sub='john.doe', nonce=None):
|
||||
token_endpoint = urlparse.urlparse(oidc_provider.token_endpoint)
|
||||
|
@ -204,7 +225,7 @@ def oidc_provider_mock(oidc_provider, oidc_provider_jwkset, code, extra_id_token
|
|||
id_token.update(extra_id_token)
|
||||
|
||||
if oidc_provider.idtoken_algo == OIDCProvider.ALGO_RSA:
|
||||
jwt = JWT(header={'alg': 'RS256'},
|
||||
jwt = JWT(header={'alg': 'RS256', 'kid': '1e9gdk7'},
|
||||
claims=id_token)
|
||||
jwt.make_signed_token(list(oidc_provider_jwkset['keys'])[0])
|
||||
else:
|
||||
|
@ -331,7 +352,7 @@ def test_sso(app, caplog, code, oidc_provider, oidc_provider_jwkset, login_url,
|
|||
with oidc_provider_mock(oidc_provider, oidc_provider_jwkset, code,
|
||||
extra_id_token={'iat': 1}):
|
||||
response = app.get(login_callback_url, params={'code': code, 'state': query['state']})
|
||||
with utils.check_log(caplog, 'id_token expired'):
|
||||
with utils.check_log(caplog, 'invalid id_token %r'):
|
||||
with oidc_provider_mock(oidc_provider, oidc_provider_jwkset, code,
|
||||
extra_id_token={'exp': 1}):
|
||||
response = app.get(login_callback_url, params={'code': code, 'state': query['state']})
|
||||
|
@ -395,7 +416,7 @@ def test_sso(app, caplog, code, oidc_provider, oidc_provider_jwkset, login_url,
|
|||
with utils.check_log(caplog, 'revoked token from OIDC'):
|
||||
with oidc_provider_mock(oidc_provider, oidc_provider_jwkset, code):
|
||||
response = response.click(href='logout')
|
||||
assert response.location.startswith('https://idp.example.com/logout?')
|
||||
assert response.location.startswith('https://server.example.com/logout?')
|
||||
|
||||
|
||||
def test_show_on_login_page(app, oidc_provider):
|
||||
|
@ -461,7 +482,7 @@ def test_strategy_find_uuid(app, caplog, code, oidc_provider, oidc_provider_jwks
|
|||
with utils.check_log(caplog, 'revoked token from OIDC'):
|
||||
with oidc_provider_mock(oidc_provider, oidc_provider_jwkset, code, nonce=nonce):
|
||||
response = response.click(href='logout')
|
||||
assert response.location.startswith('https://idp.example.com/logout?')
|
||||
assert response.location.startswith('https://server.example.com/logout?')
|
||||
|
||||
|
||||
def test_register_issuer(db, app, caplog, oidc_provider_jwkset):
|
||||
|
@ -489,7 +510,7 @@ def test_register_issuer(db, app, caplog, oidc_provider_jwkset):
|
|||
openid_configuration=oidc_conf)
|
||||
|
||||
|
||||
def test_required_keys(db, caplog):
|
||||
def test_required_keys(db, oidc_provider, header, signature, caplog):
|
||||
erroneous_payload = base64url_encode(json.dumps({
|
||||
'sub': '248289761001',
|
||||
'iss': 'http://server.example.com',
|
||||
|
@ -498,6 +519,7 @@ def test_required_keys(db, caplog):
|
|||
'extra_stuff': 'hi there', # Wrong claim
|
||||
}))
|
||||
|
||||
with pytest.raises(ValueError) as e:
|
||||
with pytest.raises(IDTokenError):
|
||||
with utils.check_log(caplog, 'missing field'):
|
||||
IDToken('{}.{}.{}'.format(header, erroneous_payload, signature))
|
||||
token = IDToken('{}.{}.{}'.format(header, erroneous_payload, signature))
|
||||
token.deserialize(oidc_provider)
|
||||
|
|
Loading…
Reference in New Issue