# django-mellon - SAML2 authentication for Django # Copyright (C) 2014-2019 Entr'ouvert # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU Affero General Public License as # published by the Free Software Foundation, either version 3 of the # License, or (at your option) any later version. # This program is distributed in the hope that it will be useful, # but WITHOUT ANY WARRANTY; without even the implied warranty of # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # GNU Affero General Public License for more details. # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . import datetime import importlib import logging from functools import wraps from xml.parsers import expat import isodate import lasso from django.conf import settings from django.contrib import auth from django.template.loader import render_to_string from django.urls import reverse from django.utils.encoding import force_text from django.utils.six.moves.urllib.parse import urlparse from django.utils.timezone import get_default_timezone, is_aware, make_aware, make_naive, now from . import app_settings logger = logging.getLogger(__name__) class CreateServerError(Exception): pass def create_metadata(request): entity_id = reverse('mellon_metadata') login_url = reverse(app_settings.LOGIN_URL) logout_url = reverse(app_settings.LOGOUT_URL) public_keys = [] for public_key in app_settings.PUBLIC_KEYS: if public_key.startswith('/'): # clean PEM file content = open(public_key).read() public_key = ''.join(content.splitlines()[1:-1]) public_keys.append(public_key) name_id_formats = app_settings.NAME_ID_FORMATS ctx = { 'request': request, 'entity_id': request.build_absolute_uri(entity_id), 'login_url': request.build_absolute_uri(login_url), 'logout_url': request.build_absolute_uri(logout_url), 'public_keys': public_keys, 'name_id_formats': name_id_formats, 'assertion_consumer_bindings': app_settings.ASSERTION_CONSUMER_BINDINGS, 'default_assertion_consumer_binding': app_settings.DEFAULT_ASSERTION_CONSUMER_BINDING, 'organization': app_settings.ORGANIZATION, 'contact_persons': app_settings.CONTACT_PERSONS, } if app_settings.METADATA_PUBLISH_DISCOVERY_RESPONSE: ctx['discovery_endpoint_url'] = request.build_absolute_uri(reverse('mellon_login')) return render_to_string('mellon/metadata.xml', ctx) def create_server(request): root = request.build_absolute_uri('/') cache = getattr(settings, '_MELLON_SERVER_CACHE', {}) if root not in cache: metadata = create_metadata(request) if app_settings.PRIVATE_KEY: private_key = app_settings.PRIVATE_KEY private_key_password = app_settings.PRIVATE_KEY_PASSWORD elif app_settings.PRIVATE_KEYS: private_key = app_settings.PRIVATE_KEYS[0] private_key_password = None if isinstance(private_key, (tuple, list)): private_key_password = private_key[1] private_key = private_key[0] else: # no signature private_key = None private_key_password = None server = lasso.Server.newFromBuffers( metadata, private_key_content=private_key, private_key_password=private_key_password ) if not server: raise CreateServerError if app_settings.SIGNATURE_METHOD: symbol_name = 'SIGNATURE_METHOD_' + app_settings.SIGNATURE_METHOD.replace('-', '_').upper() if hasattr(lasso, symbol_name): server.signatureMethod = getattr(lasso, symbol_name) else: logger.warning('mellon: unable to set signature method %s', app_settings.SIGNATURE_METHOD) server.setEncryptionPrivateKeyWithPassword(private_key, private_key_password) private_keys = app_settings.PRIVATE_KEYS # skip first key if it is already loaded if not app_settings.PRIVATE_KEY: private_keys = app_settings.PRIVATE_KEYS[1:] for key in private_keys: password = None if isinstance(key, (tuple, list)): password = key[1] key = key[0] server.setEncryptionPrivateKeyWithPassword(key, password) for idp in get_idps(): if idp and idp.get('METADATA'): try: server.addProviderFromBuffer(lasso.PROVIDER_ROLE_IDP, idp['METADATA']) except lasso.Error as e: logger.error('bad metadata in idp %s, %s', idp['ENTITY_ID'], e) cache[root] = server settings._MELLON_SERVER_CACHE = cache return cache.get(root) def create_login(request): server = create_server(request) login = lasso.Login(server) if not app_settings.PRIVATE_KEY and not app_settings.PRIVATE_KEYS: login.setSignatureHint(lasso.PROFILE_SIGNATURE_HINT_FORBID) return login def get_idp(entity_id): for adapter in get_adapters(): if hasattr(adapter, 'get_idp'): idp = adapter.get_idp(entity_id) if idp: return idp return {} def get_idps(): for adapter in get_adapters(): if hasattr(adapter, 'get_idps'): yield from adapter.get_idps() def flatten_datetime(d): d = d.copy() for key, value in d.items(): if isinstance(value, datetime.datetime): d[key] = value.isoformat() return d def iso8601_to_datetime(date_string, default=None): """Convert a string formatted as an ISO8601 date into a datetime value. This function ignores the sub-second resolution""" try: dt = isodate.parse_datetime(date_string) except Exception: return default if is_aware(dt): if not settings.USE_TZ: dt = make_naive(dt, get_default_timezone()) else: if settings.USE_TZ: dt = make_aware(dt, get_default_timezone()) return dt def get_seconds_expiry(datetime_expiry): return (datetime_expiry - now()).total_seconds() def to_list(func): @wraps(func) def f(*args, **kwargs): return list(func(*args, **kwargs)) return f def import_object(path): module, name = path.rsplit('.', 1) module = importlib.import_module(module) return getattr(module, name) @to_list def get_adapters(idp={}, **kwargs): idp = idp or {} adapters = tuple(idp.get('ADAPTER', ())) + tuple(app_settings.ADAPTER) for adapter in adapters: klass = import_object(adapter) yield klass(**kwargs) def get_values(saml_attributes, name): values = saml_attributes.get(name) if values is None: return () if not isinstance(values, (list, tuple)): return (values,) return values def get_setting(idp, name, default=None): """Get a parameter from an IdP specific configuration or from the main settings. """ return idp.get(name) or getattr(app_settings, name, default) def make_session_dump(lasso_name_id, indexes): session_infos = [] name_id = force_text(lasso_name_id.content) name_id_format = force_text(lasso_name_id.format) name_qualifier = lasso_name_id.nameQualifier and force_text(lasso_name_id.nameQualifier) sp_name_qualifier = lasso_name_id.spNameQualifier and force_text(lasso_name_id.spNameQualifier) for index in indexes: issuer = index.saml_identifier.issuer.entity_id session_infos.append( { 'entity_id': issuer, 'session_index': index.session_index, 'name_id_content': name_id, 'name_id_format': name_id_format, 'name_id_name_qualifier': name_qualifier, 'name_id_sp_name_qualifier': sp_name_qualifier, } ) session_dump = render_to_string('mellon/session_dump.xml', {'session_infos': session_infos}) return session_dump def create_logout(request): server = create_server(request) logout = lasso.Logout(server) if not app_settings.PRIVATE_KEY and not app_settings.PRIVATE_KEYS: logout.setSignatureHint(lasso.PROFILE_SIGNATURE_HINT_FORBID) return logout def is_nonnull(s): return '\x00' not in s def same_origin(url1, url2): """ Checks if two URLs are 'same-origin' """ p1, p2 = urlparse(url1), urlparse(url2) if url1.startswith('/') or url2.startswith('/'): return True try: return (p1.scheme, p1.hostname, p1.port) == (p2.scheme, p2.hostname, p2.port) except ValueError: return False def get_status_codes_and_message(profile): assert profile, 'missing lasso.Profile' assert profile.response, 'missing response in profile' assert profile.response.status, 'missing status in response' from .views import lasso_decode status_codes = [] status = profile.response.status a = status while a.statusCode: status_codes.append(lasso_decode(a.statusCode.value)) a = a.statusCode message = None if status.statusMessage: message = lasso_decode(status.statusMessage) return status_codes, message def login(request, user): for adapter in get_adapters(): if hasattr(adapter, 'auth_login'): adapter.auth_login(request, user) break else: auth.login(request, user) def get_xml_encoding(content): xml_encoding = 'utf-8' def xmlDeclHandler(version, encoding, standalone): global xml_encoding if encoding: xml_encoding = encoding parser = expat.ParserCreate() parser.XmlDeclHandler = xmlDeclHandler try: parser.Parse(content, True) except expat.ExpatError as e: raise ValueError('invalid XML %s' % e) return xml_encoding def get_local_path(request, url): if not url: return parsed = urlparse(url) path = parsed.path if request.META.get('SCRIPT_NAME'): path = path[len(request.META['SCRIPT_NAME']) :] return path