import logging
import datetime
import importlib
from functools import wraps
import isodate

from django.contrib import auth
from django.core.exceptions import ValidationError
from django.core.urlresolvers import reverse
from django.core.validators import URLValidator
from django.template.loader import render_to_string
from django.utils.text import slugify
from django.utils.timezone import make_aware, now, make_naive, is_aware, get_default_timezone
from django.conf import settings
from django.utils.six.moves.urllib.parse import urlparse
import lasso

from . import app_settings
from federation_utils import get_federation_from_url, idp_metadata_is_file


def create_metadata(request):
    '''Return the SP metadata XML document, rendering and caching it on first use.

    The rendered document is stored in ``settings._MELLON_METADATA_CACHE``,
    keyed by the path returned by ``reverse('mellon_metadata')``. The cache key
    does not include the host, so the absolute URLs embedded in the document
    are those built from the request that first populated the cache.
    '''
    entity_id = reverse('mellon_metadata')
    cache = getattr(settings, '_MELLON_METADATA_CACHE', {})
    if entity_id not in cache:
        login_url = reverse(app_settings.LOGIN_URL)
        logout_url = reverse(app_settings.LOGOUT_URL)
        public_keys = []
        for public_key in app_settings.PUBLIC_KEYS:
            if public_key.startswith('/'):
                # Entry is a path to a PEM file: keep only the base64 body by
                # dropping the first and last (BEGIN/END armour) lines. The
                # context manager closes the descriptor deterministically.
                with open(public_key) as pem_file:
                    public_key = ''.join(pem_file.read().splitlines()[1:-1])
            public_keys.append(public_key)
        name_id_formats = app_settings.NAME_ID_FORMATS
        cache[entity_id] = render_to_string('mellon/metadata.xml', {
            'entity_id': request.build_absolute_uri(entity_id),
            'login_url': request.build_absolute_uri(login_url),
            'logout_url': request.build_absolute_uri(logout_url),
            'public_keys': public_keys,
            'name_id_formats': name_id_formats,
            'default_assertion_consumer_binding': app_settings.DEFAULT_ASSERTION_CONSUMER_BINDING,
            'organization': app_settings.ORGANIZATION,
            'contact_persons': app_settings.CONTACT_PERSONS,
            'discovery_endpoint_url': request.build_absolute_uri(reverse('mellon_login')),
        })
    settings._MELLON_METADATA_CACHE = cache
    return settings._MELLON_METADATA_CACHE[entity_id]


# NOTE(review): not referenced in this module; kept in case other modules import it.
SERVERS = {}


def create_server(request):
    '''Build a ``lasso.Server`` for this SP.

    The server is loaded with the SP metadata from ``create_metadata``, the
    configured private keys, and the metadata of every IdP yielded by
    ``get_idps()``.

    The signing key is chosen as follows:

    - ``PRIVATE_KEY`` (with ``PRIVATE_KEY_PASSWORD``) if set;
    - otherwise the first ``PRIVATE_KEYS`` entry, which may be a bare key or a
      ``(key, password)`` tuple/list;
    - otherwise ``None`` (no signature).

    Every configured key is also registered as an encryption private key. An
    IdP whose metadata cannot be parsed or read is logged and skipped.
    '''
    logger = logging.getLogger(__name__)
    metadata = create_metadata(request)
    if app_settings.PRIVATE_KEY:
        private_key = app_settings.PRIVATE_KEY
        private_key_password = app_settings.PRIVATE_KEY_PASSWORD
    elif app_settings.PRIVATE_KEYS:
        private_key = app_settings.PRIVATE_KEYS[0]
        private_key_password = None
        if isinstance(private_key, (tuple, list)):
            private_key_password = private_key[1]
            private_key = private_key[0]
    else:  # no signature
        private_key = None
        private_key_password = None
    server = lasso.Server.newFromBuffers(metadata, private_key_content=private_key,
                                         private_key_password=private_key_password)
    server.setEncryptionPrivateKeyWithPassword(private_key, private_key_password)
    private_keys = app_settings.PRIVATE_KEYS
    # skip first key if it is already loaded (it was used as the signing key above)
    if not app_settings.PRIVATE_KEY:
        private_keys = app_settings.PRIVATE_KEYS[1:]
    for key in private_keys:
        password = None
        if isinstance(key, (tuple, list)):
            password = key[1]
            key = key[0]
        server.setEncryptionPrivateKeyWithPassword(key, password)
    for idp in get_idps():
        idp_metadata = idp.get('METADATA')
        try:
            # METADATA may be either inline XML or a path to a file holding it.
            if idp_metadata_is_file(idp_metadata):
                with open(idp_metadata, 'r') as f:
                    idp_metadata = f.read()
            server.addProviderFromBuffer(lasso.PROVIDER_ROLE_IDP, idp_metadata)
        except lasso.Error as e:
            logger.error(u'bad metadata in idp %r', idp)
            logger.debug(u'lasso error: %s', e)
        except IOError:
            logger.warning('No such metadata file: %r', idp_metadata)
    return server


def get_federation_metadata(federation):
    '''Resolve a federation setting into a ``(metadata, pem_certificate)`` pair.

    - A 2-tuple whose second item ends with ``.pem`` is returned as
      ``(federation[0], federation[1])``. Any other 2-tuple yields
      ``(None, None)``.
    - A value that passes Django's ``URLValidator`` is fetched with
      ``get_federation_from_url``; the certificate part is ``None``.
    - Anything else is assumed to be a local file path and returned as the
      metadata part unchanged; the certificate part is ``None``.
    '''
    logger = logging.getLogger(__name__)
    fedmd = None
    pemcert = None
    if isinstance(federation, tuple) and len(federation) == 2:
        logger.info('Loading local cert-based federation %r', federation)
        if federation[1].endswith('.pem'):
            fedmd = federation[0]
            pemcert = federation[1]
    else:
        urlval = URLValidator()
        try:
            urlval(federation)
        except ValidationError:
            logger.info('Loading file-based federation %s', federation)
            fedmd = federation
        else:
            logger.info('Fetching and loading url-based federation %s', federation)
            fedmd = get_federation_from_url(federation)
    return (fedmd, pemcert)


def create_login(request):
    '''Return a ``lasso.Login`` bound to a freshly built server.

    Signatures are forbidden when neither ``PRIVATE_KEY`` nor ``PRIVATE_KEYS``
    is configured, since there is no key to sign with.
    '''
    server = create_server(request)
    login = lasso.Login(server)
    if not app_settings.PRIVATE_KEY and not app_settings.PRIVATE_KEYS:
        login.setSignatureHint(lasso.PROFILE_SIGNATURE_HINT_FORBID)
    return login


def get_idp(entity_id):
    '''Return the first non-empty IdP configuration an adapter provides for
    *entity_id*, or ``{}`` if no adapter knows it.'''
    for adapter in get_adapters():
        if hasattr(adapter, 'get_idp'):
            idp = adapter.get_idp(entity_id)
            if idp:
                return idp
    return {}


def get_idps():
    '''Yield every IdP configuration from every adapter that implements
    ``get_idps``, in adapter order.'''
    for adapter in get_adapters():
        if hasattr(adapter, 'get_idps'):
            for idp in adapter.get_idps():
                yield idp


def get_federations():
    '''Yield every federation setting from every adapter that implements
    ``get_federations``, in adapter order.'''
    for adapter in get_adapters():
        if hasattr(adapter, 'get_federations'):
            for federation in adapter.get_federations():
                yield federation


def flatten_datetime(d):
    '''Return a shallow copy of *d* with ``datetime.datetime`` values replaced
    by their ISO 8601 string (``isoformat()``). The input is not modified.'''
    d = d.copy()
    # Iterate over a snapshot so the loop is safe on both Python 2 and 3.
    for key, value in list(d.items()):
        if isinstance(value, datetime.datetime):
            d[key] = value.isoformat()
    return d


def iso8601_to_datetime(date_string, default=None):
    '''Convert a string formatted as an ISO 8601 date into a datetime value.

    Returns *default* when *date_string* cannot be parsed. The result is made
    aware or naive (using the default time zone) to match ``settings.USE_TZ``.

    NOTE(review): earlier documentation stated that sub-second resolution is
    ignored; nothing here truncates microseconds, so that depends on
    ``isodate.parse_datetime`` — confirm.
    '''
    try:
        dt = isodate.parse_datetime(date_string)
    except Exception:
        # Any parse failure (bad format, None, wrong type) yields the default;
        # unlike a bare except, this does not swallow KeyboardInterrupt/SystemExit.
        return default
    if is_aware(dt):
        if not settings.USE_TZ:
            dt = make_naive(dt, get_default_timezone())
    else:
        if settings.USE_TZ:
            dt = make_aware(dt, get_default_timezone())
    return dt


def get_seconds_expiry(datetime_expiry):
    '''Return the number of seconds from now until *datetime_expiry*
    (negative if it is in the past).'''
    return (datetime_expiry - now()).total_seconds()


def to_list(func):
    '''Decorator: make a generator function return a list instead.'''
    @wraps(func)
    def f(*args, **kwargs):
        return list(func(*args, **kwargs))
    return f


def import_object(path):
    '''Import and return the attribute named by a dotted path such as
    ``"package.module.Name"``.'''
    module, name = path.rsplit('.', 1)
    module = importlib.import_module(module)
    return getattr(module, name)


@to_list
def get_adapters(idp=None):
    '''Instantiate and return (as a list) the configured adapters.

    IdP-specific adapters from ``idp['ADAPTER']`` come first, followed by the
    global ``app_settings.ADAPTER`` ones. Each entry is a dotted import path
    whose target is called with no arguments.
    '''
    idp = idp or {}
    adapters = tuple(idp.get('ADAPTER', ())) + tuple(app_settings.ADAPTER)
    for adapter in adapters:
        yield import_object(adapter)()


def get_values(saml_attributes, name):
    '''Return the values of attribute *name* as a sequence.

    Returns ``()`` when the attribute is absent (or ``None``) and wraps a
    scalar into a 1-tuple; lists and tuples are returned as-is.
    '''
    values = saml_attributes.get(name)
    if values is None:
        return ()
    if not isinstance(values, (list, tuple)):
        return (values,)
    return values


def get_setting(idp, name, default=None):
    '''Get a parameter from an IdP specific configuration or from the main
       settings.

    NOTE(review): because of ``or``, a falsy IdP value (False, 0, '', None)
    falls back to the global setting, so it cannot override a truthy global
    value. Confirm this is intended before relying on per-IdP ``False``.
    '''
    return idp.get(name) or getattr(app_settings, name, default)


def create_logout(request):
    '''Return a ``lasso.Logout`` primed with the session stored in
    ``request.session['mellon_session']``.

    The issuer, session index and NameID fields are rendered into a Lasso
    session dump via the ``mellon/session_dump.xml`` template. Signatures are
    forbidden when no private key is configured, consistent with
    ``create_login``.
    '''
    logger = logging.getLogger(__name__)
    server = create_server(request)
    mellon_session = request.session.get('mellon_session', {})
    entity_id = mellon_session.get('issuer')
    session_index = mellon_session.get('session_index')
    name_id_format = mellon_session.get('name_id_format')
    name_id_content = mellon_session.get('name_id_content')
    name_id_name_qualifier = mellon_session.get('name_id_name_qualifier')
    name_id_sp_name_qualifier = mellon_session.get('name_id_sp_name_qualifier')
    session_dump = render_to_string('mellon/session_dump.xml', {
        'entity_id': entity_id,
        'session_index': session_index,
        'name_id_format': name_id_format,
        'name_id_content': name_id_content,
        'name_id_name_qualifier': name_id_name_qualifier,
        'name_id_sp_name_qualifier': name_id_sp_name_qualifier,
    })
    logger.debug('session_dump %s', session_dump)
    logout = lasso.Logout(server)
    # Previously only PRIVATE_KEY was checked, so a setup using PRIVATE_KEYS
    # alone had logout signatures forbidden although a key was loaded.
    if not app_settings.PRIVATE_KEY and not app_settings.PRIVATE_KEYS:
        logout.setSignatureHint(lasso.PROFILE_SIGNATURE_HINT_FORBID)
    logout.setSessionFromDump(session_dump)
    return logout


def is_nonnull(s):
    '''Return True if *s* contains no NUL character.'''
    return '\x00' not in s


def same_origin(url1, url2):
    """
    Checks if two URLs are 'same-origin'

    A path-relative URL (``/path``) on either side is considered same-origin.
    Network-path references such as ``//host/path`` or ``/\\host/path`` are NOT
    treated as relative, because browsers resolve them to another host (a
    classic open-redirect bypass). Otherwise (scheme, hostname, port) are
    compared. Malformed URLs that make parsing fail yield False.
    """
    def is_path_relative(url):
        return url.startswith('/') and not url.startswith(('//', '/\\'))

    if is_path_relative(url1) or is_path_relative(url2):
        return True
    try:
        # urlparse itself may raise ValueError (e.g. invalid IPv6 netloc), as
        # may the .port property, so both stay inside the try.
        p1, p2 = urlparse(url1), urlparse(url2)
        return (p1.scheme, p1.hostname, p1.port) == (p2.scheme, p2.hostname, p2.port)
    except ValueError:
        return False


def get_status_codes_and_message(profile):
    '''Return ``(status_codes, message)`` from a Lasso profile's response.

    *status_codes* lists the values of the nested status codes from outermost
    to innermost, decoded as UTF-8. *message* is the decoded status message or
    ``None``. Asserts that the profile, its response and its status exist.
    '''
    assert profile, 'missing lasso.Profile'
    assert profile.response, 'missing response in profile'
    assert profile.response.status, 'missing status in response'
    status_codes = []
    status = profile.response.status
    a = status
    # Status codes are nested: each statusCode may carry a sub statusCode.
    while a.statusCode:
        status_codes.append(a.statusCode.value.decode('utf-8'))
        a = a.statusCode
    message = None
    if status.statusMessage:
        message = status.statusMessage.decode('utf-8')
    return status_codes, message


def login(request, user):
    '''Log *user* in via the first adapter that implements ``auth_login``,
    falling back to ``django.contrib.auth.login`` when none does.'''
    for adapter in get_adapters():
        if hasattr(adapter, 'auth_login'):
            adapter.auth_login(request, user)
            break
    else:
        auth.login(request, user)