283 lines
9.3 KiB
Python
283 lines
9.3 KiB
Python
# authentic2 - versatile identity manager
|
|
# Copyright (C) 2010-2019 Entr'ouvert
|
|
#
|
|
# This program is free software: you can redistribute it and/or modify it
|
|
# under the terms of the GNU Affero General Public License as published
|
|
# by the Free Software Foundation, either version 3 of the License, or
|
|
# (at your option) any later version.
|
|
#
|
|
# This program is distributed in the hope that it will be useful,
|
|
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
|
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
|
# GNU Affero General Public License for more details.
|
|
#
|
|
# You should have received a copy of the GNU Affero General Public License
|
|
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
|
|
|
import base64
|
|
import hashlib
|
|
import json
|
|
import urllib.parse
|
|
import uuid
|
|
|
|
from django.conf import settings
|
|
from django.core.exceptions import ImproperlyConfigured
|
|
from django.utils.encoding import force_bytes, force_text
|
|
from jwcrypto.jwk import JWK, InvalidJWKValue, JWKSet
|
|
from jwcrypto.jwt import JWT
|
|
|
|
from authentic2 import crypto, hooks
|
|
from authentic2.attributes_ng.engine import get_attributes
|
|
from authentic2.decorators import GlobalCache
|
|
from authentic2.utils.template import Template
|
|
|
|
from . import app_settings
|
|
|
|
|
|
def base64url(content):
    '''Return ``content`` (bytes) encoded as unpadded URL-safe base64 bytes.'''
    encoded = base64.urlsafe_b64encode(content)
    # base64 padding can only ever appear at the end of the output
    return encoded.rstrip(b'=')
|
|
|
|
|
|
def get_jwkset():
    '''Load the key set configured in A2_IDP_OIDC_JWKSET.

    The setting is serialized to JSON and parsed back with jwcrypto's
    ``JWKSet.from_json``.

    :returns: a ``JWKSet`` holding at least one key
    :raises ImproperlyConfigured: if the setting cannot be serialized to JSON,
        is rejected by jwcrypto as an invalid JWK value, or contains no key
    '''
    try:
        serialized = json.dumps(app_settings.JWKSET)
    except Exception as e:
        # chain the cause so the original serialization error stays visible
        raise ImproperlyConfigured('invalid setting A2_IDP_OIDC_JWKSET: %s' % e) from e
    try:
        jwkset = JWKSet.from_json(serialized)
    except InvalidJWKValue as e:
        raise ImproperlyConfigured('invalid setting A2_IDP_OIDC_JWKSET: %s' % e) from e
    if not jwkset['keys']:
        raise ImproperlyConfigured('empty A2_IDP_OIDC_JWKSET')
    return jwkset
|
|
|
|
|
|
def get_first_sig_key_by_type(kty=None):
    '''Return the first key of the configured JWK set usable for signing.

    A key qualifies when its ``kty`` equals ``kty`` and its ``use`` parameter
    is either absent or ``'sig'``.

    :param kty: the wanted key type (e.g. ``'RSA'``, ``'EC'``)
    :returns: the matching key, or None when ``kty`` is falsy or no key matches
    '''
    if not kty:
        return None
    for candidate in get_jwkset()['keys']:
        # XXX: remove when jwcrypto version is over 0.9.1 everywhere
        # (older jwcrypto keeps the JWK parameters in a private ``_params`` dict)
        params = candidate._params if hasattr(candidate, '_params') else candidate
        if params['kty'] != kty:
            continue
        usage = params.get('use')
        if usage is None or usage == 'sig':
            return candidate
    return None
|
|
|
|
|
|
def get_first_rsa_sig_key():
    '''Return the first RSA signature key of A2_IDP_OIDC_JWKSET, or None.'''
    key_type = 'RSA'
    return get_first_sig_key_by_type(kty=key_type)
|
|
|
|
|
|
def get_first_ec_sig_key():
    '''Return the first EC signature key of A2_IDP_OIDC_JWKSET, or None.'''
    key_type = 'EC'
    return get_first_sig_key_by_type(kty=key_type)
|
|
|
|
|
|
def make_idtoken(client, claims):
    '''Make a serialized JWT targeted for this client.

    The signature algorithm is selected by ``client.idtoken_algo``:

    * ``client.ALGO_HMAC``: HS256, keyed with the client secret;
    * ``client.ALGO_RSA``: RS256, with the first RSA signature key of
      A2_IDP_OIDC_JWKSET;
    * ``client.ALGO_EC``: ES256, with the first EC signature key of
      A2_IDP_OIDC_JWKSET.

    For the two asymmetric algorithms the JOSE header carries the ``kid`` of
    the signing key when that key has one.

    :param client: the OIDC client the token is issued to
    :param claims: the claims to sign (passed unchanged to ``JWT``)
    :returns: the compact serialization of the signed token
    :raises ImproperlyConfigured: if no suitable signing key is configured
    :raises NotImplementedError: if ``client.idtoken_algo`` is unknown
    '''
    if client.idtoken_algo == client.ALGO_HMAC:
        header = {'alg': 'HS256'}
        # the client secret, base64url-encoded, is used as a symmetric "oct" key
        k = base64url(client.client_secret.encode('utf-8'))
        jwk = JWK(kty='oct', k=force_text(k))
    elif client.idtoken_algo == client.ALGO_RSA:
        header = {'alg': 'RS256'}
        jwk = get_first_rsa_sig_key()
        # check for a missing key before dereferencing it, so misconfiguration
        # surfaces as ImproperlyConfigured rather than AttributeError
        if jwk is None:
            raise ImproperlyConfigured('no RSA key for signature operation in A2_IDP_OIDC_JWKSET')
        if jwk.key_id is not None:
            # "kid" must be a string (RFC 7515 §4.1.4): omit it instead of emitting null
            header['kid'] = jwk.key_id
    elif client.idtoken_algo == client.ALGO_EC:
        header = {'alg': 'ES256'}
        jwk = get_first_ec_sig_key()
        if jwk is None:
            raise ImproperlyConfigured('no EC key for signature operation in A2_IDP_OIDC_JWKSET')
        if jwk.key_id is not None:
            # same as RSA: lets a verifier pick the matching key from the key set
            header['kid'] = jwk.key_id
    else:
        raise NotImplementedError
    jwt = JWT(header=header, claims=claims)
    jwt.make_signed_token(jwk)
    return jwt.serialize()
|
|
|
|
|
|
def scope_set(data):
    '''Return the set of whitespace-separated scopes found in ``data``.'''
    # str.split() with no argument already drops surrounding whitespace
    return set(data.split())
|
|
|
|
|
|
def clean_words(data):
    '''Return the whitespace-separated words of ``data``, sorted and joined
    with single spaces.'''
    words = data.split()
    words.sort()
    return ' '.join(words)
|
|
|
|
|
|
def url_domain(url):
    '''Return the network location of ``url`` without its port.

    NOTE(review): the result is everything before the first ':' of the
    netloc, so URLs carrying userinfo (``user:pw@host``) or bracketed IPv6
    hosts are not split correctly; the case of the host is preserved.
    '''
    netloc = urllib.parse.urlparse(url).netloc
    host, _sep, _port = netloc.partition(':')
    return host
|
|
|
|
|
|
def make_sub(client, user):
    '''Compute the ``sub`` claim for ``user`` according to the client's
    identifier policy.

    :returns: a pairwise identifier, the user's UUID as text, or the user's
        email, depending on ``client.identifier_policy``
    :raises NotImplementedError: for any other policy
    '''
    policy = client.identifier_policy
    pairwise_policies = (client.POLICY_PAIRWISE, client.POLICY_PAIRWISE_REVERSIBLE)
    if policy in pairwise_policies:
        return make_pairwise_sub(client, user)
    if policy == client.POLICY_UUID:
        return force_text(user.uuid)
    if policy == client.POLICY_EMAIL:
        return user.email
    raise NotImplementedError
|
|
|
|
|
|
def make_pairwise_sub(client, user):
    '''Make a pairwise sub, reversible or not depending on the client policy.

    :raises NotImplementedError: if the policy is not a pairwise one
    '''
    policy = client.identifier_policy
    if policy == client.POLICY_PAIRWISE:
        return make_pairwise_unreversible_sub(client, user)
    if policy == client.POLICY_PAIRWISE_REVERSIBLE:
        return make_pairwise_reversible_sub(client, user)
    raise NotImplementedError('unknown pairwise client.identifier_policy %s' % policy)
|
|
|
|
|
|
def make_pairwise_unreversible_sub(client, user):
    '''Return a one-way pairwise sub for ``user`` and this client's sector.

    The sub is the standard base64 encoding of
    SHA-256(sector identifier + user UUID + SECRET_KEY).
    '''
    material = ''.join((client.get_sector_identifier(), str(user.uuid), settings.SECRET_KEY))
    digest = hashlib.sha256(material.encode('utf-8')).digest()
    return base64.b64encode(digest).decode('utf-8')
|
|
|
|
|
|
def make_pairwise_reversible_sub(client, user):
    '''Return a reversible pairwise sub built from ``user.uuid``.'''
    user_uuid = user.uuid
    return make_pairwise_reversible_sub_from_uuid(client, user_uuid=user_uuid)
|
|
|
|
|
|
def make_pairwise_reversible_sub_from_uuid(client, user_uuid):
    '''Encrypt a user UUID into a pairwise sub for this client's sector.

    The 16 raw UUID bytes are encrypted with
    ``crypto.aes_base64url_deterministic_encrypt`` keyed by SECRET_KEY, with
    the client's sector identifier as the additional input.

    :param user_uuid: the user UUID, either as a string accepted by
        :class:`uuid.UUID` or as a :class:`uuid.UUID` instance
    :returns: the sub as a str, or None when ``user_uuid`` is not a valid UUID
    '''
    if isinstance(user_uuid, uuid.UUID):
        # uuid.UUID() only accepts strings positionally; use the instance directly
        identifier = user_uuid.bytes
    else:
        try:
            identifier = uuid.UUID(user_uuid).bytes
        except (ValueError, TypeError, AttributeError):
            # malformed string, None, or a non-string value: not a UUID
            return None
    sector_identifier = client.get_sector_identifier()
    return crypto.aes_base64url_deterministic_encrypt(
        settings.SECRET_KEY.encode('utf-8'), identifier, sector_identifier
    ).decode('utf-8')
|
|
|
|
|
|
def reverse_pairwise_sub(client, sub):
    '''Decrypt a reversible pairwise ``sub`` for this client's sector.

    :returns: the value returned by ``crypto.aes_base64url_deterministic_decrypt``
        (presumably the raw UUID bytes — confirm), or None when decryption
        raises ``crypto.DecryptionError``
    '''
    sector = client.get_sector_identifier()
    secret = settings.SECRET_KEY.encode('utf-8')
    try:
        plaintext = crypto.aes_base64url_deterministic_decrypt(secret, sub, sector)
    except crypto.DecryptionError:
        return None
    return plaintext
|
|
|
|
|
|
def normalize_claim_values(values):
    '''Prepare a claim value for release.

    Strings and non-iterable values (including a lone bool) are returned
    unchanged. Any other iterable becomes a list in which booleans are
    replaced by the strings ``'true'`` / ``'false'`` and other items are kept
    as-is.

    NOTE(review): a dict yields a list of its keys and bytes yield a list of
    ints here, since both are iterable — confirm callers never pass those.
    '''
    is_scalar = isinstance(values, str) or not hasattr(values, '__iter__')
    if is_scalar:
        return values
    return [str(item).lower() if isinstance(item, bool) else item for item in values]
|
|
|
|
|
|
def create_user_info(request, client, user, scope_set, id_token=False):
    '''Create user info dictionary.

    * ``sub`` is included only when ``'openid'`` is among the requested scopes.
    * Each named claim of the client whose scopes intersect ``scope_set`` is
      resolved either by rendering its value as a template (when it contains
      ``{{`` or ``{%``) or by looking it up in the user's attributes. ``None``
      results are skipped; other results go through
      :func:`normalize_claim_values`. When ``<value>:verified`` is an
      attribute, ``<name>_verified`` is set to True.
    * In-scope claims that got no value are reported as ``''`` for a fixed
      list of name-like claims and as None otherwise.
    * Finally the ``idp_oidc_modify_user_info`` hooks may alter the result.

    NOTE(review): ``id_token`` is not used by this body.
    '''
    user_info = {}
    if 'openid' in scope_set:
        user_info['sub'] = make_sub(client, user)

    context = {
        'user': user,
        'request': request,
        'service': client,
        '__wanted_attributes': client.get_wanted_attributes(),
    }
    attributes = get_attributes(context)

    # every claim whose scopes match, whether or not a value is found for it
    in_scope = set()
    for claim in client.oidcclaim_set.filter(name__isnull=False):
        if set(claim.get_scopes()).isdisjoint(scope_set):
            continue
        in_scope.add(claim)
        source = claim.value
        if source and ('{{' in source or '{%' in source):
            resolved = Template(source).render(context=attributes)
        elif source in attributes:
            resolved = attributes[source]
        else:
            continue
        if resolved is None:
            continue
        user_info[claim.name] = normalize_claim_values(resolved)
        # check if attribute is verified
        if source + ':verified' in attributes:
            user_info[claim.name + '_verified'] = True

    # claims expected to be text default to '' instead of None when missing
    empty_string_claims = (
        'given_name',
        'family_name',
        'full_name',
        'name',
        'middle_name',
        'nickname',
        'email',
        'preferred_username',
    )
    for claim in in_scope:
        if claim.name in user_info:
            continue
        user_info[claim.name] = '' if claim.name in empty_string_claims else None

    hooks.call_hooks('idp_oidc_modify_user_info', client, user, scope_set, user_info)
    return user_info
|
|
|
|
|
|
def get_issuer(request):
    '''Return the issuer identifier: the absolute URL of the site root.'''
    site_root = '/'
    return request.build_absolute_uri(site_root)
|
|
|
|
|
|
def get_session_id(request, client):
    """Derive an OIDC Session Id from the real session identifier, the sector
    identifier of the RP and the secret key of the Django instance.

    The result is the hex MD5 digest of the three values, each converted
    with ``force_bytes`` and concatenated in that order.

    NOTE(review): if the session has no key yet, ``force_bytes(None)``
    presumably yields ``b'None'`` — confirm callers always have a saved session.
    """
    ingredients = (
        request.session.session_key,
        client.get_sector_identifier(),
        settings.SECRET_KEY,
    )
    material = b''.join(force_bytes(part) for part in ingredients)
    return hashlib.md5(material).hexdigest()
|
|
|
|
|
|
def get_oidc_sessions(request):
    '''Return the OIDC sessions stored in the Django session (a dict keyed by
    front-channel logout URI), or an empty dict when none were recorded.'''
    session = request.session
    return session.get('oidc_sessions', {})
|
|
|
|
|
|
def add_oidc_session(request, client):
    '''Record in the Django session that ``client`` must be notified on
    front-channel logout.

    The ``oidc_sessions`` mapping is always ensured to exist in the session.
    Nothing else happens when the client has no front-channel logout URI.
    Otherwise an entry keyed by that URI is stored, and the session is marked
    modified only if the entry is new or differs from the stored one.
    '''
    sessions = request.session.setdefault('oidc_sessions', {})
    logout_uri = client.frontchannel_logout_uri
    if not logout_uri:
        return
    entry = {
        'frontchannel_logout_uri': logout_uri,
        'frontchannel_timeout': client.frontchannel_timeout,
        'name': client.name,
        'sid': get_session_id(request, client),
        'iss': get_issuer(request),
    }
    if sessions.get(logout_uri) != entry:
        sessions[logout_uri] = entry
        # force session save: the nested dict was mutated in place
        request.session.modified = True
|
|
|
|
|
|
@GlobalCache(timeout=60)
def good_next_url(next_url):
    '''Tell whether ``next_url`` shares its origin with a redirect URI of any
    OIDC client.

    The result goes through ``GlobalCache(timeout=60)`` (presumably cached for
    60 seconds — see authentic2.decorators).

    :returns: True on a match, None otherwise (never False)
    '''
    from authentic2.utils.misc import same_origin

    from .models import OIDCClient

    redirect_uris = (
        url
        for oidc_client in OIDCClient.objects.all()
        for url in oidc_client.redirect_uris.split()
    )
    if any(same_origin(url, next_url) for url in redirect_uris):
        return True
    return None
|