import logging import datetime import importlib from functools import wraps import isodate from xml.parsers import expat from django.contrib import auth from django.contrib.auth.models import Group from django.core.urlresolvers import reverse from django.http import QueryDict, HttpResponseRedirect from django.template.loader import render_to_string from django.utils.timezone import make_aware, now, make_naive, is_aware, get_default_timezone from django.conf import settings from django.utils.six.moves.urllib.parse import urlparse import lasso from . import app_settings from .exceptions import RolesNotInSession def create_metadata(request): entity_id = reverse('mellon_metadata') login_url = reverse(app_settings.LOGIN_URL) logout_url = reverse(app_settings.LOGOUT_URL) public_keys = [] for public_key in app_settings.PUBLIC_KEYS: if public_key.startswith('/'): # clean PEM file content = open(public_key).read() public_key = ''.join(content.splitlines()[1:-1]) public_keys.append(public_key) name_id_formats = app_settings.NAME_ID_FORMATS return render_to_string('mellon/metadata.xml', { 'entity_id': request.build_absolute_uri(entity_id), 'login_url': request.build_absolute_uri(login_url), 'logout_url': request.build_absolute_uri(logout_url), 'public_keys': public_keys, 'name_id_formats': name_id_formats, 'default_assertion_consumer_binding': app_settings.DEFAULT_ASSERTION_CONSUMER_BINDING, 'organization': app_settings.ORGANIZATION, 'contact_persons': app_settings.CONTACT_PERSONS, 'discovery_endpoint_url': request.build_absolute_uri(reverse('mellon_login')), }) def create_server(request): logger = logging.getLogger(__name__) root = request.build_absolute_uri('/') cache = getattr(settings, '_MELLON_SERVER_CACHE', {}) if root not in cache: metadata = create_metadata(request) if app_settings.PRIVATE_KEY: private_key = app_settings.PRIVATE_KEY private_key_password = app_settings.PRIVATE_KEY_PASSWORD elif app_settings.PRIVATE_KEYS: private_key = app_settings.PRIVATE_KEYS[0] private_key_password = None if isinstance(private_key, (tuple, list)): private_key_password = private_key[1] private_key = private_key[0] else: # no signature private_key = None private_key_password = None server = lasso.Server.newFromBuffers(metadata, private_key_content=private_key, private_key_password=private_key_password) if app_settings.SIGNATURE_METHOD: symbol_name = 'SIGNATURE_METHOD_' + app_settings.SIGNATURE_METHOD.replace('-', '_').upper() if hasattr(lasso, symbol_name): server.signatureMethod = getattr(lasso, symbol_name) else: logger.warning('mellon: unable to set signature method %s', app_settings.SIGNATURE_METHOD) server.setEncryptionPrivateKeyWithPassword(private_key, private_key_password) private_keys = app_settings.PRIVATE_KEYS # skip first key if it is already loaded if not app_settings.PRIVATE_KEY: private_keys = app_settings.PRIVATE_KEYS[1:] for key in private_keys: password = None if isinstance(key, (tuple, list)): password = key[1] key = key[0] server.setEncryptionPrivateKeyWithPassword(key, password) for idp in get_idps(): try: server.addProviderFromBuffer(lasso.PROVIDER_ROLE_IDP, idp['METADATA']) except lasso.Error as e: logger.error(u'bad metadata in idp %r', idp['ENTITY_ID']) logger.debug(u'lasso error: %s', e) continue cache[root] = server settings._MELLON_SERVER_CACHE = cache return settings._MELLON_SERVER_CACHE.get(root) def create_login(request): server = create_server(request) login = lasso.Login(server) if not app_settings.PRIVATE_KEY and not app_settings.PRIVATE_KEYS: login.setSignatureHint(lasso.PROFILE_SIGNATURE_HINT_FORBID) return login def get_idp(entity_id): for adapter in get_adapters(): if hasattr(adapter, 'get_idp'): idp = adapter.get_idp(entity_id) if idp: return idp return {} def get_idps(): for adapter in get_adapters(): if hasattr(adapter, 'get_idps'): for idp in adapter.get_idps(): yield idp def flatten_datetime(d): d = d.copy() for key, value in d.items(): if isinstance(value, datetime.datetime): d[key] = value.isoformat() return d def iso8601_to_datetime(date_string, default=None): '''Convert a string formatted as an ISO8601 date into a datetime value. This function ignores the sub-second resolution''' try: dt = isodate.parse_datetime(date_string) except: return default if is_aware(dt): if not settings.USE_TZ: dt = make_naive(dt, get_default_timezone()) else: if settings.USE_TZ: dt = make_aware(dt, get_default_timezone()) return dt def get_seconds_expiry(datetime_expiry): return (datetime_expiry - now()).total_seconds() def to_list(func): @wraps(func) def f(*args, **kwargs): return list(func(*args, **kwargs)) return f def import_object(path): module, name = path.rsplit('.', 1) module = importlib.import_module(module) return getattr(module, name) @to_list def get_adapters(idp={}): idp = idp or {} adapters = tuple(idp.get('ADAPTER', ())) + tuple(app_settings.ADAPTER) for adapter in adapters: yield import_object(adapter)() def get_values(saml_attributes, name): values = saml_attributes.get(name) if values is None: return () if not isinstance(values, (list, tuple)): return (values,) return values def get_setting(idp, name, default=None): '''Get a parameter from an IdP specific configuration or from the main settings. ''' return idp.get(name) or getattr(app_settings, name, default) def create_logout(request): logger = logging.getLogger(__name__) server = create_server(request) mellon_session = request.session.get('mellon_session', {}) entity_id = mellon_session.get('issuer') session_index = mellon_session.get('session_index') name_id_format = mellon_session.get('name_id_format') name_id_content = mellon_session.get('name_id_content') name_id_name_qualifier = mellon_session.get('name_id_name_qualifier') name_id_sp_name_qualifier = mellon_session.get('name_id_sp_name_qualifier') session_dump = render_to_string('mellon/session_dump.xml', { 'entity_id': entity_id, 'session_index': session_index, 'name_id_format': name_id_format, 'name_id_content': name_id_content, 'name_id_name_qualifier': name_id_name_qualifier, 'name_id_sp_name_qualifier': name_id_sp_name_qualifier, }) logger.debug('session_dump %s', session_dump) logout = lasso.Logout(server) if not app_settings.PRIVATE_KEY: logout.setSignatureHint(lasso.PROFILE_SIGNATURE_HINT_FORBID) logout.setSessionFromDump(session_dump) return logout def is_nonnull(s): return not '\x00' in s def same_origin(url1, url2): """ Checks if two URLs are 'same-origin' """ p1, p2 = urlparse(url1), urlparse(url2) if url1.startswith('/') or url2.startswith('/'): return True try: return (p1.scheme, p1.hostname, p1.port) == (p2.scheme, p2.hostname, p2.port) except ValueError: return False def get_status_codes_and_message(profile): assert profile, 'missing lasso.Profile' assert profile.response, 'missing response in profile' assert profile.response.status, 'missing status in response' from .views import lasso_decode status_codes = [] status = profile.response.status a = status while a.statusCode: status_codes.append(lasso_decode(a.statusCode.value)) a = a.statusCode message = None if status.statusMessage: message = lasso_decode(status.statusMessage) return status_codes, message def login(request, user): for adapter in get_adapters(): if hasattr(adapter, 'auth_login'): adapter.auth_login(request, user) break else: auth.login(request, user) def get_xml_encoding(content): xml_encoding = 'utf-8' def xmlDeclHandler(version, encoding, standalone): xml_encoding = encoding parser = expat.ParserCreate() parser.XmlDeclHandler = xmlDeclHandler parser.Parse(content, True) return xml_encoding def get_local_path(request, url): if not url: return parsed = urlparse(url) path = parsed.path if request.META.get('SCRIPT_NAME'): path = path[len(request.META['SCRIPT_NAME']):] return path def has_superuser_flag(idp, saml_attributes): superuser_mapping = get_setting(idp, 'SUPERUSER_MAPPING') if not superuser_mapping: return False for key, values in superuser_mapping.items(): if key in saml_attributes: if not isinstance(values, (tuple, list)): values = [values] values = set(values) attribute_values = saml_attributes[key] if not isinstance(attribute_values, (tuple, list)): attribute_values = [attribute_values] attribute_values = set(attribute_values) if attribute_values & values: return True return False def user_has_roles(request, roles): if request.user.is_staff and request.session.get('is_staff'): return True groups = set(roles).intersection(request.user.groups.all()) if not groups: if request.user.is_staff: raise RolesNotInSession(('staff',)) return False role_uuids = {getattr(group, 'role').uuid for group in groups} if not role_uuids: return True if set(request.session['mellon_session']['role-slug']) & role_uuids: return True raise RolesNotInSession(role_uuids) def user_has_role(request, role): return user_has_roles(request, {role}) def get_role_request_url(request, roles): login_url = reverse(app_settings.LOGIN_URL) q = QueryDict(mutable=True) q.setlist('roles', roles) q['next'] = request.get_full_path() return '?'.join((login_url, q.urlencode(safe='/')))