authentic/src/authentic2_idp_oidc/utils.py

275 lines
8.9 KiB
Python

# authentic2 - versatile identity manager
# Copyright (C) 2010-2019 Entr'ouvert
#
# This program is free software: you can redistribute it and/or modify it
# under the terms of the GNU Affero General Public License as published
# by the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
import base64
import hashlib
import json
import urllib.parse
import uuid
from django.conf import settings
from django.core.exceptions import ImproperlyConfigured
from django.utils.encoding import force_bytes, force_text
from jwcrypto.jwk import JWK, InvalidJWKValue, JWKSet
from jwcrypto.jwt import JWT
from authentic2 import crypto, hooks
from authentic2.attributes_ng.engine import get_attributes
from authentic2.decorators import GlobalCache
from authentic2.utils.template import Template
from . import app_settings
def base64url(content):
    '''Encode ``content`` (bytes) in unpadded base64url and return bytes.

    Standard base64 output only ever carries ``=`` padding at its end, so
    trimming the right side is enough to remove it.
    '''
    encoded = base64.urlsafe_b64encode(content)
    return encoded.rstrip(b'=')
def get_jwkset():
    '''Load the A2_IDP_OIDC_JWKSET setting as a jwcrypto ``JWKSet``.

    The setting value (``app_settings.JWKSET``) is serialized to JSON and
    parsed with ``JWKSet.from_json()``.

    :returns: the parsed ``JWKSet``, guaranteed to contain at least one key.
    :raises ImproperlyConfigured: the setting cannot be serialized to JSON,
        is rejected by jwcrypto as an invalid JWK value, or holds no key.
        The underlying exception is chained as ``__cause__``.
    '''
    try:
        serialized = json.dumps(app_settings.JWKSET)
    except Exception as e:
        # any failure reading/serializing the setting is a configuration error
        raise ImproperlyConfigured('invalid setting A2_IDP_OIDC_JWKSET: %s' % e) from e
    try:
        jwkset = JWKSet.from_json(serialized)
    except InvalidJWKValue as e:
        raise ImproperlyConfigured('invalid setting A2_IDP_OIDC_JWKSET: %s' % e) from e
    if len(jwkset['keys']) < 1:
        raise ImproperlyConfigured('empty A2_IDP_OIDC_JWKSET')
    return jwkset
def get_first_sig_key_by_type(kty=None):
    '''Return the first A2_IDP_OIDC_JWKSET key of type ``kty`` usable for signing.

    A key qualifies when its ``kty`` parameter equals ``kty`` and its ``use``
    parameter is either absent or ``'sig'``.

    :param kty: JWK key type, e.g. ``'RSA'`` or ``'EC'``.
    :returns: the matching JWK, or None when ``kty`` is falsy or no key matches.
    '''
    if not kty:
        return None
    # NOTE(review): depending on the jwcrypto version, jwkset['keys'] may be an
    # unordered set, so "first" might not be stable when several keys of the
    # same type are configured -- confirm.
    for candidate in get_jwkset()['keys']:
        params = candidate._params
        if params['kty'] == kty and params.get('use') in (None, 'sig'):
            return candidate
    return None
def get_first_rsa_sig_key():
    '''Return the first RSA signature key of A2_IDP_OIDC_JWKSET, or None.'''
    return get_first_sig_key_by_type(kty='RSA')
def get_first_ec_sig_key():
    '''Return the first EC signature key of A2_IDP_OIDC_JWKSET, or None.'''
    return get_first_sig_key_by_type(kty='EC')
def make_idtoken(client, claims):
    '''Make a serialized, signed JWT targeted for this client.

    The signature algorithm follows ``client.idtoken_algo``:

    - ``ALGO_HMAC``: HS256, keyed with the client secret;
    - ``ALGO_RSA``: RS256, with the first RSA signature key of A2_IDP_OIDC_JWKSET;
    - ``ALGO_EC``: ES256, with the first EC signature key of A2_IDP_OIDC_JWKSET.

    For RS256/ES256 the key id is copied into the ``kid`` header when the key
    has one, so relying parties can pick the matching key from the JWKS.

    :param client: the OIDC client the token is issued for.
    :param claims: the claims to sign, handed as-is to jwcrypto's ``JWT``.
    :returns: the result of ``JWT.serialize()``.
    :raises ImproperlyConfigured: no suitable signature key is configured.
    :raises NotImplementedError: ``client.idtoken_algo`` is not supported.
    '''
    if client.idtoken_algo == client.ALGO_HMAC:
        header = {'alg': 'HS256'}
        # an 'oct' JWK carries its secret base64url-encoded without padding
        k = base64url(client.client_secret.encode('utf-8'))
        jwk = JWK(kty='oct', k=force_text(k))
    elif client.idtoken_algo == client.ALGO_RSA:
        header = {'alg': 'RS256'}
        jwk = get_first_rsa_sig_key()
        # check for a missing key *before* dereferencing it, otherwise an
        # AttributeError on None hides the configuration error
        if jwk is None:
            raise ImproperlyConfigured('no RSA key for signature operation in A2_IDP_OIDC_JWKSET')
        # "kid" must be a string (RFC 7515), so omit it rather than emit null
        if jwk.key_id:
            header['kid'] = jwk.key_id
    elif client.idtoken_algo == client.ALGO_EC:
        header = {'alg': 'ES256'}
        jwk = get_first_ec_sig_key()
        if jwk is None:
            raise ImproperlyConfigured('no EC key for signature operation in A2_IDP_OIDC_JWKSET')
        # advertise the key id like the RSA branch does
        if jwk.key_id:
            header['kid'] = jwk.key_id
    else:
        raise NotImplementedError
    jwt = JWT(header=header, claims=claims)
    jwt.make_signed_token(jwk)
    return jwt.serialize()
def scope_set(data):
    '''Convert a whitespace-separated scope string into a set of scopes.

    ``str.split()`` without arguments already discards surrounding whitespace
    and empty items, so each element is a bare scope name.
    '''
    return set(data.split())
def clean_words(data):
    '''Normalize a whitespace-separated word list.

    Words are sorted and joined with single spaces; duplicates are kept.
    '''
    words = data.split()
    words.sort()
    return ' '.join(words)
def url_domain(url):
    '''Return the host part of ``url``, without credentials or port.

    ``user:password@`` prefixes are dropped and IPv6 literals are returned
    without their brackets (``http://[::1]:8000/`` -> ``'::1'``). Case is
    preserved and an empty string is returned when the URL has no network
    location.
    '''
    netloc = urllib.parse.urlparse(url).netloc
    # userinfo may itself contain ':' and '@', the host follows the last '@'
    host = netloc.rpartition('@')[2]
    if host.startswith('['):
        # bracketed IPv6 literal: the port, if any, follows the closing ']'
        return host[1:].partition(']')[0]
    return host.split(':')[0]
def make_sub(client, user):
    '''Compute the ``sub`` claim of ``user`` according to the client's policy.

    - pairwise policies (reversible or not): delegated to make_pairwise_sub();
    - ``POLICY_UUID``: the user's uuid as text;
    - ``POLICY_EMAIL``: the user's email.

    :raises NotImplementedError: unknown ``client.identifier_policy``.
    '''
    policy = client.identifier_policy
    if policy in (client.POLICY_PAIRWISE, client.POLICY_PAIRWISE_REVERSIBLE):
        return make_pairwise_sub(client, user)
    if policy == client.POLICY_UUID:
        return force_text(user.uuid)
    if policy == client.POLICY_EMAIL:
        return user.email
    raise NotImplementedError
def make_pairwise_sub(client, user):
    '''Make a pairwise ``sub`` for ``user`` as seen by ``client``.

    ``POLICY_PAIRWISE`` yields a one-way hash, ``POLICY_PAIRWISE_REVERSIBLE``
    an encrypted identifier.

    :raises NotImplementedError: the client policy is not a pairwise one.
    '''
    policy = client.identifier_policy
    if policy == client.POLICY_PAIRWISE:
        return make_pairwise_unreversible_sub(client, user)
    if policy == client.POLICY_PAIRWISE_REVERSIBLE:
        return make_pairwise_reversible_sub(client, user)
    raise NotImplementedError('unknown pairwise client.identifier_policy %s' % policy)
def make_pairwise_unreversible_sub(client, user):
    '''Return a non-reversible pairwise ``sub`` for ``user`` and ``client``.

    The value is the standard (not url-safe) base64 encoding of
    SHA-256(sector identifier + user uuid + SECRET_KEY), returned as text.
    '''
    material = ''.join((client.get_sector_identifier(), str(user.uuid), settings.SECRET_KEY))
    digest = hashlib.sha256(material.encode('utf-8')).digest()
    return base64.b64encode(digest).decode('utf-8')
def make_pairwise_reversible_sub(client, user):
    '''Return the reversible pairwise ``sub`` of ``user`` for ``client``.

    Thin wrapper over make_pairwise_reversible_sub_from_uuid() using
    ``user.uuid``; may return None when that uuid cannot be parsed.
    '''
    user_uuid = user.uuid
    return make_pairwise_reversible_sub_from_uuid(client, user_uuid)
def make_pairwise_reversible_sub_from_uuid(client, user_uuid):
    '''Return a reversible pairwise ``sub`` for the user identified by ``user_uuid``.

    The 16 raw bytes of the UUID are encrypted with
    ``crypto.aes_base64url_deterministic_encrypt`` keyed by SECRET_KEY and
    the client's sector identifier; reverse_pairwise_sub() performs the
    matching decryption.

    :param user_uuid: a UUID string or a ``uuid.UUID`` instance.
    :returns: the sub as text, or None when ``user_uuid`` is not a valid UUID
        (including None).
    '''
    try:
        # str() lets uuid.UUID instances through and turns None or other
        # types into text that fails parsing with ValueError, instead of
        # escaping as TypeError/AttributeError
        identifier = uuid.UUID(str(user_uuid)).bytes
    except ValueError:
        return None
    sector_identifier = client.get_sector_identifier()
    return crypto.aes_base64url_deterministic_encrypt(
        settings.SECRET_KEY.encode('utf-8'), identifier, sector_identifier
    ).decode('utf-8')
def reverse_pairwise_sub(client, sub):
    '''Decrypt a reversible pairwise ``sub`` issued for ``client``.

    Uses SECRET_KEY and the client's sector identifier. Returns whatever
    ``crypto.aes_base64url_deterministic_decrypt`` returns (presumably the raw
    UUID bytes -- confirm), or None when it raises ``crypto.DecryptionError``.
    '''
    sector_identifier = client.get_sector_identifier()
    secret = settings.SECRET_KEY.encode('utf-8')
    try:
        plaintext = crypto.aes_base64url_deterministic_decrypt(secret, sub, sector_identifier)
    except crypto.DecryptionError:
        return None
    return plaintext
def normalize_claim_values(values):
    '''Normalize a claim value before it is put in the user info.

    Strings and non-iterable values are returned unchanged. Any other
    iterable becomes a list in which booleans are turned into ``'true'`` /
    ``'false'`` and other items are kept as-is. A scalar boolean is returned
    unchanged; only booleans inside an iterable are converted.
    NOTE: a mapping is iterated, so it yields the list of its keys.
    '''
    if isinstance(values, str) or not hasattr(values, '__iter__'):
        return values
    return [str(item).lower() if isinstance(item, bool) else item for item in values]
def create_user_info(request, client, user, scope_set, id_token=False):
    '''Create the user info dictionary describing ``user`` for ``client``.

    - ``sub`` is included only when the ``openid`` scope is in ``scope_set``;
    - each named OIDC claim of the client whose scopes intersect ``scope_set``
      is resolved by rendering ``claim.value`` as a template (when it contains
      ``{{`` or ``{%``) or by looking it up in the user's attributes; a
      ``<name>_verified: True`` entry is added when a ``<value>:verified``
      attribute exists;
    - granted claims whose name is still missing afterwards get ``''`` for
      common name/email claims and ``None`` otherwise;
    - finally the ``idp_oidc_modify_user_info`` hooks are called with the
      dictionary.

    ``id_token`` is accepted but not used by this function.
    '''
    # granted-but-unresolved claims default to '' for these names, None otherwise
    empty_string_claims = (
        'given_name',
        'family_name',
        'full_name',
        'name',
        'middle_name',
        'nickname',
        'email',
        'preferred_username',
    )
    user_info = {}
    if 'openid' in scope_set:
        user_info['sub'] = make_sub(client, user)
    attributes = get_attributes(
        {
            'user': user,
            'request': request,
            'service': client,
            '__wanted_attributes': client.get_wanted_attributes(),
        }
    )
    # claims whose scope was granted, whether or not a value gets resolved
    granted = set()
    for claim in client.oidcclaim_set.filter(name__isnull=False):
        if set(claim.get_scopes()).isdisjoint(scope_set):
            continue
        granted.add(claim)
        source = claim.value
        if source and ('{{' in source or '{%' in source):
            value = Template(source).render(context=attributes)
        elif source in attributes:
            value = attributes[source]
        else:
            continue
        if value is None:
            continue
        user_info[claim.name] = normalize_claim_values(value)
        if source + ':verified' in attributes:
            user_info[claim.name + '_verified'] = True
    for claim in granted:
        if claim.name in user_info:
            continue
        user_info[claim.name] = '' if claim.name in empty_string_claims else None
    hooks.call_hooks('idp_oidc_modify_user_info', client, user, scope_set, user_info)
    return user_info
def get_issuer(request):
    '''Return the OIDC issuer identifier: the absolute URL of the site root.'''
    site_root = '/'
    return request.build_absolute_uri(site_root)
def get_session_id(request, client):
    """Derive an OIDC Session Id for ``client`` from the current session.

    The id is the hex MD5 digest of the concatenated bytes of the Django
    session key, the client's sector identifier and SECRET_KEY.
    NOTE(review): if the session has no key yet, the value converted is
    None -- confirm callers only use this on saved sessions.
    """
    material = b''.join(
        (
            force_bytes(request.session.session_key),
            force_bytes(client.get_sector_identifier()),
            force_bytes(settings.SECRET_KEY),
        )
    )
    return hashlib.md5(material).hexdigest()
def get_oidc_sessions(request):
    '''Return the front-channel logout registrations stored in the session.

    When none exist, a fresh empty dict is returned (it is not stored).
    '''
    session = request.session
    return session.get('oidc_sessions', {})
def add_oidc_session(request, client):
    '''Record in the session that ``client`` must be notified at logout.

    Registrations are kept in ``request.session['oidc_sessions']`` keyed by
    the client's front-channel logout URI; the key is created even for
    clients without such a URI, which are otherwise ignored. The session is
    flagged as modified only when the stored entry actually changes.
    '''
    registry = request.session.setdefault('oidc_sessions', {})
    logout_uri = client.frontchannel_logout_uri
    if not logout_uri:
        return
    entry = dict(
        frontchannel_logout_uri=logout_uri,
        frontchannel_timeout=client.frontchannel_timeout,
        name=client.name,
        sid=get_session_id(request, client),
        iss=get_issuer(request),
    )
    if registry.get(logout_uri) != entry:
        registry[logout_uri] = entry
        # mutating a nested dict is not noticed by the session backend,
        # so force the session to be saved
        request.session.modified = True
@GlobalCache(timeout=60)
def good_next_url(next_url):
    '''Tell whether ``next_url`` shares its origin with a registered redirect URI.

    Every redirect URI of every OIDC client is compared with ``same_origin``.
    Returns True on the first match and None (not False) otherwise. The
    function is wrapped in ``GlobalCache(timeout=60)``.
    '''
    from authentic2.utils import same_origin

    from .models import OIDCClient

    redirect_uris = (
        uri
        for oidc_client in OIDCClient.objects.all()
        for uri in oidc_client.redirect_uris.split()
    )
    if any(same_origin(uri, next_url) for uri in redirect_uris):
        return True
    return None