django-mellon/mellon/utils.py

221 lines
7.5 KiB
Python

import logging
import datetime
import importlib
from functools import wraps
from xml.etree import ElementTree as ET
import requests
import requests.exceptions
import dateutil.parser
from django.core.urlresolvers import reverse
from django.template.loader import render_to_string
from django.utils.timezone import make_aware, now, make_naive, is_aware
from django.conf import settings
import lasso
from . import app_settings
def create_metadata(request):
    '''Render, and cache, the SAML metadata of this service provider.

    The document is rendered from the ``mellon/metadata.xml`` template using
    absolute URLs built from ``request``. Rendered documents are cached in a
    dict stored on the Django settings object under
    ``_MELLON_METADATA_CACHE``.

    The cache is keyed by the absolute entity id: the rendered document
    embeds the host of the request, so keying it by the bare URL path would
    hand the first host's metadata to every other host.

    Entries of ``app_settings.PUBLIC_KEYS`` starting with ``/`` are read as
    files whose first and last lines are dropped; other entries are passed
    to the template unchanged.
    '''
    entity_id = request.build_absolute_uri(reverse('mellon_metadata'))
    cache = getattr(settings, '_MELLON_METADATA_CACHE', {})
    if entity_id not in cache:
        login_url = reverse('mellon_login')
        logout_url = reverse('mellon_logout')
        public_keys = []
        for public_key in app_settings.PUBLIC_KEYS:
            if public_key.startswith('/'):
                # clean PEM file: keep only the lines between the first and
                # the last one, and close the file once it has been read
                with open(public_key) as pem_file:
                    public_key = ''.join(pem_file.read().splitlines()[1:-1])
            public_keys.append(public_key)
        cache[entity_id] = render_to_string('mellon/metadata.xml', {
            'entity_id': entity_id,
            'login_url': request.build_absolute_uri(login_url),
            'logout_url': request.build_absolute_uri(logout_url),
            'public_keys': public_keys,
            'name_id_formats': app_settings.NAME_ID_FORMATS,
            'default_assertion_consumer_binding':
                app_settings.DEFAULT_ASSERTION_CONSUMER_BINDING,
        })
    # store the dict back: getattr() may have returned a fresh default
    settings._MELLON_METADATA_CACHE = cache
    return cache[entity_id]
# Cache of lasso.Server objects, keyed by the absolute root URL of the
# request they were built for.
SERVERS = {}


def _split_private_key(key):
    '''Split a private key setting into a (key, password) pair.

    A setting is either a bare key or a (key, password) tuple/list; a bare
    key gets a password of None.
    '''
    if isinstance(key, (tuple, list)):
        return key[0], key[1]
    return key, None


def _get_signing_key():
    '''Return the (key, password) pair used for signing, or (None, None).'''
    if app_settings.PRIVATE_KEY:
        return app_settings.PRIVATE_KEY, app_settings.PRIVATE_KEY_PASSWORD
    if app_settings.PRIVATE_KEYS:
        # PRIVATE_KEYS is a sequence of keys: sign with its first entry, not
        # with the sequence itself
        return _split_private_key(app_settings.PRIVATE_KEYS[0])
    # no signature
    return None, None


def _load_idp_metadata(idp, index, logger):
    '''Return the metadata of one IdP, or None when it cannot be obtained.

    ``METADATA`` takes precedence over ``METADATA_URL``. A ``METADATA``
    value starting with ``/`` is read as a file, any other value is used as
    the metadata itself. Failures to download ``METADATA_URL`` and missing
    configuration are logged as errors.
    '''
    if 'METADATA' in idp:
        metadata = idp['METADATA']
        if metadata.startswith('/'):
            with open(metadata) as metadata_file:
                return metadata_file.read()
        # inline metadata is used as it is
        return metadata
    if 'METADATA_URL' in idp:
        verify_ssl_certificate = get_setting(idp, 'VERIFY_SSL_CERTIFICATE')
        try:
            response = requests.get(idp['METADATA_URL'],
                                    verify=verify_ssl_certificate)
            response.raise_for_status()
        except requests.exceptions.RequestException as e:
            logger.error(u'retrieval of metadata URL %r failed with error %s for %d-th idp',
                         idp['METADATA_URL'], e, index)
            return None
        return response.content
    logger.error(u'missing METADATA or METADATA_URL in %d-th idp', index)
    return None


def create_server(request):
    '''Return the lasso.Server matching the root URL of ``request``.

    The server is built once per root URL and kept in ``SERVERS``. It is
    created from the service provider metadata and the signing key, every
    key of ``app_settings.PRIVATE_KEYS`` is passed to
    ``setEncryptionPrivateKeyWithPassword``, and each IdP returned by
    ``get_idps()`` is added as a provider.

    An IdP whose metadata cannot be loaded or is rejected by lasso is logged
    and skipped. For every accepted IdP, the ``ENTITY_ID`` and ``METADATA``
    keys of its configuration dict are set from the loaded metadata.
    '''
    logger = logging.getLogger(__name__)
    root = request.build_absolute_uri('/')
    if root in SERVERS:
        return SERVERS[root]

    sp_metadata = create_metadata(request)
    private_key, private_key_password = _get_signing_key()
    server = lasso.Server.newFromBuffers(sp_metadata,
                                         private_key_content=private_key,
                                         private_key_password=private_key_password)
    server.setEncryptionPrivateKeyWithPassword(private_key, private_key_password)
    for key in app_settings.PRIVATE_KEYS:
        key, password = _split_private_key(key)
        server.setEncryptionPrivateKeyWithPassword(key, password)

    for i, idp in enumerate(get_idps()):
        # kept in its own variable so that a skipped or inline IdP can never
        # inherit the SP metadata or the metadata of the previous IdP
        idp_metadata = _load_idp_metadata(idp, i, logger)
        if idp_metadata is None:
            continue
        try:
            server.addProviderFromBuffer(lasso.PROVIDER_ROLE_IDP, idp_metadata)
        except lasso.Error as e:
            logger.error(u'bad metadata in %d-th idp', i)
            logger.debug(u'lasso error: %s', e)
            continue
        idp['ENTITY_ID'] = ET.fromstring(idp_metadata).attrib['entityID']
        idp['METADATA'] = idp_metadata

    SERVERS[root] = server
    return server
def create_login(request):
    '''Build a lasso.Login profile on the server matching ``request``.

    When neither ``PRIVATE_KEY`` nor ``PRIVATE_KEYS`` is configured, the
    profile is given the ``PROFILE_SIGNATURE_HINT_FORBID`` signature hint.
    '''
    profile = lasso.Login(create_server(request))
    has_private_key = app_settings.PRIVATE_KEY or app_settings.PRIVATE_KEYS
    if not has_private_key:
        profile.setSignatureHint(lasso.PROFILE_SIGNATURE_HINT_FORBID)
    return profile
def get_idp(entity_id):
    '''Return the configuration of the IdP identified by ``entity_id``.

    Adapters are queried in the order given by ``get_adapters()``; those
    without a ``get_idp`` method are skipped. The first truthy answer is
    returned, or an empty dict when no adapter provides one.
    '''
    for candidate in get_adapters():
        if not hasattr(candidate, 'get_idp'):
            continue
        found = candidate.get_idp(entity_id)
        if found:
            return found
    return {}
def get_idps():
    '''Yield every IdP configuration exposed by the adapters.

    Adapters come from ``get_adapters()``; those without a ``get_idps``
    method are skipped.
    '''
    for candidate in get_adapters():
        if not hasattr(candidate, 'get_idps'):
            continue
        for configuration in candidate.get_idps():
            yield configuration
def flatten_datetime(d):
    '''Return a shallow copy of ``d`` with datetime values turned to strings.

    Every ``datetime.datetime`` value is replaced by its ISO 8601
    representation followed by ``Z``. Aware values are converted to UTC
    first, so that the result never carries both a numeric offset and the
    ``Z`` designator; naive values are formatted as they are. Other values
    are kept unchanged and ``d`` itself is not modified.
    '''
    flattened = d.copy()
    # iterate over a snapshot, values of the copy are replaced in the loop
    for key, value in list(flattened.items()):
        if not isinstance(value, datetime.datetime):
            continue
        offset = value.utcoffset()
        if offset is not None:
            # aware value: shift the wall clock to UTC and drop the tzinfo
            value = (value - offset).replace(tzinfo=None)
        flattened[key] = value.isoformat() + 'Z'
    return flattened
def iso8601_to_datetime(date_string):
    '''Parse a date string into a datetime consistent with ``USE_TZ``.

    The string is parsed by ``dateutil.parser.parse``. An aware result is
    made naive when ``settings.USE_TZ`` is false and a naive result is made
    aware when it is true (``make_naive`` and ``make_aware`` are called
    without an explicit timezone); otherwise the parsed value is returned
    unchanged.
    '''
    parsed = dateutil.parser.parse(date_string)
    aware = is_aware(parsed)
    if aware and not settings.USE_TZ:
        return make_naive(parsed)
    if not aware and settings.USE_TZ:
        return make_aware(parsed)
    return parsed
def get_seconds_expiry(datetime_expiry):
    '''Return the seconds left between ``now()`` and ``datetime_expiry``.

    The value is negative once ``datetime_expiry`` is in the past.
    '''
    remaining = datetime_expiry - now()
    return remaining.total_seconds()
def to_list(func):
    '''Decorator collecting the iterable returned by ``func`` into a list.'''
    @wraps(func)
    def collect(*args, **kwargs):
        items = func(*args, **kwargs)
        return list(items)
    return collect
def import_object(path):
    '''Import and return the object named by a dotted path.

    The part before the last dot is imported as a module and the part after
    it is looked up as an attribute of that module.
    '''
    module_path, attribute = path.rsplit('.', 1)
    return getattr(importlib.import_module(module_path), attribute)
@to_list
def get_adapters(idp=None):
    '''Return instances of the adapters configured for an IdP.

    The dotted paths under the ``ADAPTER`` key of ``idp`` come first,
    followed by those of ``app_settings.ADAPTER``; each path is imported and
    called without arguments. ``idp`` may be omitted, ``None`` or empty, in
    which case only the global adapters are returned.
    '''
    # None rather than a shared mutable {} as default value
    idp = idp or {}
    paths = tuple(idp.get('ADAPTER', ())) + tuple(app_settings.ADAPTER)
    for path in paths:
        yield import_object(path)()
def get_values(saml_attributes, name):
    '''Return the values stored under ``name`` as a sequence.

    A list or tuple is returned as it is, a missing or ``None`` value gives
    an empty tuple and any other value is wrapped in a one-element tuple.
    '''
    found = saml_attributes.get(name)
    if isinstance(found, (list, tuple)):
        return found
    return () if found is None else (found,)
def get_setting(idp, name, default=None):
    '''Get a parameter from an IdP specific configuration or from the main
       settings.

       The IdP value wins whenever it is present and not ``None``, even if
       it is falsy (``False``, ``0``, an empty string...), so that an IdP
       can switch off an option which is enabled globally. Otherwise the
       attribute ``name`` of ``app_settings`` is returned, or ``default``
       when there is no such attribute.
    '''
    value = idp.get(name)
    if value is not None:
        return value
    return getattr(app_settings, name, default)
def create_logout(request):
    '''Build a lasso.Logout profile loaded with the session of ``request``.

    A lasso session dump is rendered from the ``mellon/session_dump.xml``
    template with the values stored under the ``mellon_session`` key of the
    Django session (missing values are passed as ``None``) and loaded into
    the profile.

    As in ``create_login``, the ``PROFILE_SIGNATURE_HINT_FORBID`` signature
    hint is only set when neither ``PRIVATE_KEY`` nor ``PRIVATE_KEYS`` is
    configured.
    '''
    logger = logging.getLogger(__name__)
    server = create_server(request)
    mellon_session = request.session.get('mellon_session', {})
    session_dump = render_to_string('mellon/session_dump.xml', {
        'entity_id': mellon_session.get('issuer'),
        'session_index': mellon_session.get('session_index'),
        'name_id_format': mellon_session.get('name_id_format'),
        'name_id_content': mellon_session.get('name_id_content'),
        'name_id_name_qualifier': mellon_session.get('name_id_name_qualifier'),
        'name_id_sp_name_qualifier':
            mellon_session.get('name_id_sp_name_qualifier'),
    })
    logger.debug('session_dump %s', session_dump)
    logout = lasso.Logout(server)
    # a key coming from PRIVATE_KEYS allows signing as well
    if not app_settings.PRIVATE_KEY and not app_settings.PRIVATE_KEYS:
        logout.setSignatureHint(lasso.PROFILE_SIGNATURE_HINT_FORBID)
    logout.setSessionFromDump(session_dump)
    return logout
def is_nonnull(s):
    '''Tell whether ``s`` is free of NUL (``\\x00``) characters.'''
    return '\x00' not in s