authentic/src/authentic2_idp_oidc/utils.py

267 lines
8.9 KiB
Python

# authentic2 - versatile identity manager
# Copyright (C) 2010-2019 Entr'ouvert
#
# This program is free software: you can redistribute it and/or modify it
# under the terms of the GNU Affero General Public License as published
# by the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
import json
import hashlib
import base64
import uuid
from jwcrypto.jwk import JWK, JWKSet, InvalidJWKValue
from jwcrypto.jwt import JWT
from django.core.exceptions import ImproperlyConfigured
from django.conf import settings
from django.utils import six
from django.utils.encoding import force_bytes, force_text
from django.utils.six.moves.urllib import parse as urlparse
from authentic2 import hooks, crypto
from authentic2.attributes_ng.engine import get_attributes
from authentic2.utils.template import Template
from authentic2.decorators import GlobalCache
from . import app_settings
def base64url(content):
    '''Encode bytes with the URL-safe base64 alphabet, without "=" padding'''
    encoded = base64.urlsafe_b64encode(content)
    return encoded.strip(b'=')
def get_jwkset():
    '''Load and validate the key set declared in app_settings.JWKSET.

    Returns the parsed JWKSet. Raises ImproperlyConfigured when the setting
    cannot be dumped as JSON, is rejected by JWKSet.from_json() or contains
    no key at all.
    '''
    error_template = 'invalid setting A2_IDP_OIDC_JWKSET: %s'
    # the setting is a Python structure, jwcrypto wants its JSON serialization
    try:
        serialized = json.dumps(app_settings.JWKSET)
    except Exception as e:
        raise ImproperlyConfigured(error_template % e)
    try:
        parsed = JWKSet.from_json(serialized)
    except InvalidJWKValue as e:
        raise ImproperlyConfigured(error_template % e)
    if len(parsed['keys']) == 0:
        raise ImproperlyConfigured('empty A2_IDP_OIDC_JWKSET')
    return parsed
def get_first_sig_key_by_type(kty=None):
    '''Return the first key of the configured JWKSet usable for signature.

    A key matches when its "kty" parameter equals kty and its "use"
    parameter is absent or equal to "sig". Returns None when kty is empty
    or when no key matches.
    '''
    if not kty:
        return None
    for candidate in get_jwkset()['keys']:
        params = candidate._params
        if params['kty'] == kty:
            use = params.get('use')
            if use is None or use == 'sig':
                return candidate
    return None
def get_first_rsa_sig_key():
    '''Return the first RSA signature key of the configured JWKSet, or None'''
    rsa_key = get_first_sig_key_by_type('RSA')
    return rsa_key
def get_first_ec_sig_key():
    '''Return the first EC signature key of the configured JWKSet, or None'''
    ec_key = get_first_sig_key_by_type('EC')
    return ec_key
def make_idtoken(client, claims):
    '''Make a serialized JWT targeted for this client.

    The signature algorithm is selected from client.idtoken_algo:
    - client.ALGO_HMAC: HS256, keyed with the client secret,
    - client.ALGO_RSA: RS256, with the first RSA signature key of the JWKSet,
      whose key id is published in the "kid" header,
    - client.ALGO_EC: ES256, with the first EC signature key of the JWKSet.

    Raises ImproperlyConfigured when the needed RSA or EC key is missing and
    NotImplementedError for any other value of client.idtoken_algo.
    '''
    if client.idtoken_algo == client.ALGO_HMAC:
        header = {'alg': 'HS256'}
        k = base64url(client.client_secret.encode('utf-8'))
        jwk = JWK(kty='oct', k=force_text(k))
    elif client.idtoken_algo == client.ALGO_RSA:
        header = {'alg': 'RS256'}
        jwk = get_first_rsa_sig_key()
        # the key must be checked before reading jwk.key_id, otherwise a
        # missing key surfaces as an AttributeError instead of the
        # configuration error below
        if jwk is None:
            raise ImproperlyConfigured('no RSA key for signature operation in A2_IDP_OIDC_JWKSET')
        header['kid'] = jwk.key_id
    elif client.idtoken_algo == client.ALGO_EC:
        header = {'alg': 'ES256'}
        jwk = get_first_ec_sig_key()
        if jwk is None:
            raise ImproperlyConfigured('no EC key for signature operation in A2_IDP_OIDC_JWKSET')
    else:
        raise NotImplementedError
    jwt = JWT(header=header, claims=claims)
    jwt.make_signed_token(jwk)
    return jwt.serialize()
def scope_set(data):
    '''Split a whitespace separated scope string into a set of scope names'''
    return {name.strip() for name in data.split()}
def clean_words(data):
    '''Return the words of data sorted and joined by a single space'''
    words = [word.strip() for word in data.split()]
    words.sort()
    return u' '.join(words)
def url_domain(url):
    '''Return the network location of url, cut at its first ":"'''
    netloc = urlparse.urlparse(url).netloc
    domain, _, _ = netloc.partition(':')
    return domain
def make_sub(client, user):
    '''Compute the subject identifier of user following the client policy.

    Pairwise policies are delegated to make_pairwise_sub(), the UUID policy
    returns the user UUID as text and the email policy returns the user
    email. Any other policy raises NotImplementedError.
    '''
    policy = client.identifier_policy
    if policy in (client.POLICY_PAIRWISE, client.POLICY_PAIRWISE_REVERSIBLE):
        return make_pairwise_sub(client, user)
    if policy == client.POLICY_UUID:
        return force_text(user.uuid)
    if policy == client.POLICY_EMAIL:
        return user.email
    raise NotImplementedError
def make_pairwise_sub(client, user):
    '''Make a pairwise sub, reversible or not depending on the client policy.

    Raises NotImplementedError when the client policy is not a pairwise one.
    '''
    policy = client.identifier_policy
    if policy == client.POLICY_PAIRWISE:
        return make_pairwise_unreversible_sub(client, user)
    if policy == client.POLICY_PAIRWISE_REVERSIBLE:
        return make_pairwise_reversible_sub(client, user)
    message = 'unknown pairwise client.identifier_policy %s' % client.identifier_policy
    raise NotImplementedError(message)
def make_pairwise_unreversible_sub(client, user):
    '''Make a pairwise sub by hashing the sector identifier of the client,
       the user UUID and the Django secret key with SHA-256; the digest is
       returned base64 encoded, as text.'''
    sector_identifier = client.get_sector_identifier()
    seed = sector_identifier + str(user.uuid) + settings.SECRET_KEY
    digest = hashlib.sha256(seed.encode('utf-8')).digest()
    encoded = base64.b64encode(digest)
    return encoded.decode('utf-8')
def make_pairwise_reversible_sub(client, user):
    '''Make a reversible pairwise sub from the UUID of user'''
    user_uuid = user.uuid
    return make_pairwise_reversible_sub_from_uuid(client, user_uuid)
def make_pairwise_reversible_sub_from_uuid(client, user_uuid):
    '''Encrypt user_uuid into a pairwise sub bound to the sector identifier
       of the client, using the Django secret key.

       Returns None when uuid.UUID() rejects user_uuid with a ValueError.
    '''
    try:
        parsed_uuid = uuid.UUID(user_uuid)
    except ValueError:
        return None
    identifier = parsed_uuid.bytes
    sector_identifier = client.get_sector_identifier()
    key = settings.SECRET_KEY.encode('utf-8')
    encrypted = crypto.aes_base64url_deterministic_encrypt(key, identifier, sector_identifier)
    return encrypted.decode('utf-8')
def reverse_pairwise_sub(client, sub):
    '''Decrypt a reversible pairwise sub with the Django secret key and the
       sector identifier of the client.

       Returns the decrypted value, or None when decryption fails with
       crypto.DecryptionError.
    '''
    sector_identifier = client.get_sector_identifier()
    try:
        key = settings.SECRET_KEY.encode('utf-8')
        decrypted = crypto.aes_base64url_deterministic_decrypt(key, sub, sector_identifier)
    except crypto.DecryptionError:
        return None
    return decrypted
def normalize_claim_values(values):
    '''Normalize a claim value before serialization.

    Strings and non-iterable values are returned unchanged; any other
    iterable is copied to a list where booleans are replaced by their
    lowercased string form ("true" or "false").
    '''
    if isinstance(values, six.string_types) or not hasattr(values, '__iter__'):
        return values
    return [str(item).lower() if isinstance(item, bool) else item for item in values]
def create_user_info(request, client, user, scope_set, id_token=False):
    '''Create the user info dictionary sent to client for user.

    - "sub" is added when the "openid" scope is granted,
    - each named claim of the client sharing a scope with scope_set is
      computed, either by rendering its value as a template or by looking it
      up in the resolved attributes,
    - selected claims left without a value get a default: an empty string
      for the usual name and email claims, None for the others,
    - the "idp_oidc_modify_user_info" hooks are finally called with the
      resulting dictionary, which is returned.
    '''
    # claims defaulting to an empty string instead of None
    empty_string_claims = ['given_name', 'family_name', 'full_name', 'name',
                           'middle_name', 'nickname', 'email',
                           'preferred_username']
    user_info = {}
    if 'openid' in scope_set:
        user_info['sub'] = make_sub(client, user)
    attributes = get_attributes({
        'user': user,
        'request': request,
        'service': client,
        '__wanted_attributes': client.get_wanted_attributes(),
    })
    claims_to_show = set()
    for claim in client.oidcclaim_set.filter(name__isnull=False):
        # ignore claims not bound to one of the granted scopes
        if not set(claim.get_scopes()).intersection(scope_set):
            continue
        claims_to_show.add(claim)
        source = claim.value
        if source and ('{{' in source or '{%' in source):
            # the claim value is a template
            attribute_value = Template(source).render(context=attributes)
        elif source in attributes:
            # the claim value is the name of an attribute
            attribute_value = attributes[source]
        else:
            continue
        if attribute_value is None:
            continue
        user_info[claim.name] = normalize_claim_values(attribute_value)
        # check if attribute is verified
        if source + ':verified' in attributes:
            user_info[claim.name + '_verified'] = True
    for claim in claims_to_show:
        if claim.name in user_info:
            continue
        user_info[claim.name] = '' if claim.name in empty_string_claims else None
    hooks.call_hooks('idp_oidc_modify_user_info', client, user, scope_set, user_info)
    return user_info
def get_issuer(request):
    '''Return the absolute URI of the root path "/" for this request'''
    root_path = '/'
    return request.build_absolute_uri(root_path)
def get_session_id(request, client):
    '''Derive an OIDC Session Id from the real session identifier, the sector
       identifier of the RP and the secret key of the Django instance; the
       result is the hexadecimal MD5 digest of their concatenation.'''
    digest = hashlib.md5()
    # successive updates are equivalent to hashing the concatenated parts
    digest.update(force_bytes(request.session.session_key))
    digest.update(force_bytes(client.get_sector_identifier()))
    digest.update(force_bytes(settings.SECRET_KEY))
    return digest.hexdigest()
def get_oidc_sessions(request):
    '''Return the "oidc_sessions" entry of the session, or an empty dict when
       it is absent'''
    session = request.session
    return session.get('oidc_sessions', {})
def add_oidc_session(request, client):
    '''Record in the session the front-channel logout description of client.

    The "oidc_sessions" session entry is always initialized; nothing more is
    done for clients without a front-channel logout URI. Entries are keyed by
    that URI and the session is only marked as modified when the entry
    actually changes.
    '''
    oidc_sessions = request.session.setdefault('oidc_sessions', {})
    uri = client.frontchannel_logout_uri
    if not uri:
        return
    new_entry = {
        'frontchannel_logout_uri': uri,
        'frontchannel_timeout': client.frontchannel_timeout,
        'name': client.name,
        'sid': get_session_id(request, client),
        'iss': get_issuer(request),
    }
    if oidc_sessions.get(uri) != new_entry:
        oidc_sessions[uri] = new_entry
        # force session save
        request.session.modified = True
@GlobalCache(timeout=60)
def good_next_url(next_url):
    '''Return True when next_url has the same origin as one of the redirect
       URIs of any OIDC client, None otherwise.'''
    # imported here and not at module level, as in the original code
    from authentic2.utils import same_origin
    from .models import OIDCClient

    for oidc_client in OIDCClient.objects.all():
        redirect_uris = oidc_client.redirect_uris.split()
        if any(same_origin(url, next_url) for url in redirect_uris):
            return True
    return None