360 lines
11 KiB
Python
360 lines
11 KiB
Python
# django-mellon - SAML2 authentication for Django
|
|
# Copyright (C) 2014-2019 Entr'ouvert
|
|
# This program is free software: you can redistribute it and/or modify
|
|
# it under the terms of the GNU Affero General Public License as
|
|
# published by the Free Software Foundation, either version 3 of the
|
|
# License, or (at your option) any later version.
|
|
|
|
# This program is distributed in the hope that it will be useful,
|
|
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
|
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
|
# GNU Affero General Public License for more details.
|
|
|
|
# You should have received a copy of the GNU Affero General Public License
|
|
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
|
|
|
|
|
import datetime
|
|
import importlib
|
|
import logging
|
|
from functools import wraps
|
|
from urllib.parse import urlparse
|
|
from xml.parsers import expat
|
|
|
|
import isodate
|
|
import lasso
|
|
from django.conf import settings
|
|
from django.contrib import auth
|
|
from django.template.loader import render_to_string
|
|
from django.urls import reverse
|
|
from django.utils.encoding import force_str
|
|
from django.utils.timezone import get_default_timezone, is_aware, make_aware, make_naive, now
|
|
|
|
from . import app_settings
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class CreateServerError(Exception):
|
|
pass
|
|
|
|
|
|
def create_metadata(request):
|
|
entity_id = reverse('mellon_metadata')
|
|
login_url = reverse(app_settings.LOGIN_URL)
|
|
logout_url = reverse(app_settings.LOGOUT_URL)
|
|
public_keys = []
|
|
for public_key in app_settings.PUBLIC_KEYS:
|
|
if public_key.startswith('/'):
|
|
# clean PEM file
|
|
content = open(public_key).read()
|
|
public_key = ''.join(content.splitlines()[1:-1])
|
|
public_keys.append(public_key)
|
|
name_id_formats = app_settings.NAME_ID_FORMATS
|
|
ctx = {
|
|
'request': request,
|
|
'entity_id': request.build_absolute_uri(entity_id),
|
|
'login_url': request.build_absolute_uri(login_url),
|
|
'logout_url': request.build_absolute_uri(logout_url),
|
|
'public_keys': public_keys,
|
|
'name_id_formats': name_id_formats,
|
|
'assertion_consumer_bindings': app_settings.ASSERTION_CONSUMER_BINDINGS,
|
|
'default_assertion_consumer_binding': app_settings.DEFAULT_ASSERTION_CONSUMER_BINDING,
|
|
'organization': app_settings.ORGANIZATION,
|
|
'contact_persons': app_settings.CONTACT_PERSONS,
|
|
}
|
|
if app_settings.METADATA_PUBLISH_DISCOVERY_RESPONSE:
|
|
ctx['discovery_endpoint_url'] = request.build_absolute_uri(reverse('mellon_login'))
|
|
return render_to_string('mellon/metadata.xml', ctx)
|
|
|
|
|
|
def create_server(request, remote_provider_id=None):
|
|
root = request.build_absolute_uri('/')
|
|
cache = getattr(settings, '_MELLON_SERVER_CACHE', {})
|
|
if root not in cache:
|
|
metadata = create_metadata(request)
|
|
if app_settings.PRIVATE_KEY:
|
|
private_key = app_settings.PRIVATE_KEY
|
|
private_key_password = app_settings.PRIVATE_KEY_PASSWORD
|
|
elif app_settings.PRIVATE_KEYS:
|
|
private_key = app_settings.PRIVATE_KEYS[0]
|
|
private_key_password = None
|
|
if isinstance(private_key, (tuple, list)):
|
|
private_key_password = private_key[1]
|
|
private_key = private_key[0]
|
|
else: # no signature
|
|
private_key = None
|
|
private_key_password = None
|
|
server = lasso.Server.newFromBuffers(
|
|
metadata, private_key_content=private_key, private_key_password=private_key_password
|
|
)
|
|
if not server:
|
|
raise CreateServerError
|
|
if app_settings.SIGNATURE_METHOD:
|
|
symbol_name = 'SIGNATURE_METHOD_' + app_settings.SIGNATURE_METHOD.replace('-', '_').upper()
|
|
if hasattr(lasso, symbol_name):
|
|
server.signatureMethod = getattr(lasso, symbol_name)
|
|
else:
|
|
logger.warning('mellon: unable to set signature method %s', app_settings.SIGNATURE_METHOD)
|
|
|
|
server.setEncryptionPrivateKeyWithPassword(private_key, private_key_password)
|
|
private_keys = app_settings.PRIVATE_KEYS
|
|
# skip first key if it is already loaded
|
|
if not app_settings.PRIVATE_KEY:
|
|
private_keys = app_settings.PRIVATE_KEYS[1:]
|
|
for key in private_keys:
|
|
password = None
|
|
if isinstance(key, (tuple, list)):
|
|
password = key[1]
|
|
key = key[0]
|
|
server.setEncryptionPrivateKeyWithPassword(key, password)
|
|
for idp in get_idps():
|
|
if remote_provider_id and idp.get('ENTITY_ID') != remote_provider_id:
|
|
continue
|
|
if idp and idp.get('METADATA'):
|
|
try:
|
|
server.addProviderFromBuffer(lasso.PROVIDER_ROLE_IDP, idp['METADATA'])
|
|
except lasso.Error as e:
|
|
logger.error('bad metadata in idp %s, %s', idp['ENTITY_ID'], e)
|
|
if not server.providers and remote_provider_id:
|
|
logger.warning('mellon: create_server, no provider found for issuer %r', remote_provider_id)
|
|
cache[root] = server
|
|
settings._MELLON_SERVER_CACHE = cache
|
|
return cache.get(root)
|
|
|
|
|
|
def create_login(request):
|
|
server = create_server(request)
|
|
login = lasso.Login(server)
|
|
if not app_settings.PRIVATE_KEY and not app_settings.PRIVATE_KEYS:
|
|
login.setSignatureHint(lasso.PROFILE_SIGNATURE_HINT_FORBID)
|
|
return login
|
|
|
|
|
|
def get_idp(entity_id):
|
|
for adapter in get_adapters():
|
|
if hasattr(adapter, 'get_idp'):
|
|
idp = adapter.get_idp(entity_id)
|
|
if idp:
|
|
return idp
|
|
return {}
|
|
|
|
|
|
def get_idps():
|
|
for adapter in get_adapters():
|
|
if hasattr(adapter, 'get_idps'):
|
|
yield from adapter.get_idps()
|
|
|
|
|
|
def flatten_datetime(d):
|
|
d = d.copy()
|
|
for key, value in d.items():
|
|
if isinstance(value, datetime.datetime):
|
|
d[key] = value.isoformat()
|
|
return d
|
|
|
|
|
|
def iso8601_to_datetime(date_string, default=None):
|
|
"""Convert a string formatted as an ISO8601 date into a datetime
|
|
value.
|
|
|
|
This function ignores the sub-second resolution"""
|
|
try:
|
|
dt = isodate.parse_datetime(date_string)
|
|
except Exception:
|
|
return default
|
|
if is_aware(dt):
|
|
if not settings.USE_TZ:
|
|
dt = make_naive(dt, get_default_timezone())
|
|
else:
|
|
if settings.USE_TZ:
|
|
dt = make_aware(dt, get_default_timezone())
|
|
return dt
|
|
|
|
|
|
def get_seconds_expiry(datetime_expiry):
|
|
return (datetime_expiry - now()).total_seconds()
|
|
|
|
|
|
def to_list(func):
|
|
@wraps(func)
|
|
def f(*args, **kwargs):
|
|
return list(func(*args, **kwargs))
|
|
|
|
return f
|
|
|
|
|
|
def import_object(path):
|
|
module, name = path.rsplit('.', 1)
|
|
module = importlib.import_module(module)
|
|
return getattr(module, name)
|
|
|
|
|
|
@to_list
|
|
def get_adapters(idp={}, **kwargs):
|
|
idp = idp or {}
|
|
adapters = tuple(idp.get('ADAPTER', ())) + tuple(app_settings.ADAPTER)
|
|
for adapter in adapters:
|
|
klass = import_object(adapter)
|
|
yield klass(**kwargs)
|
|
|
|
|
|
def get_values(saml_attributes, name):
|
|
values = saml_attributes.get(name)
|
|
if values is None:
|
|
return ()
|
|
if not isinstance(values, (list, tuple)):
|
|
return (values,)
|
|
return values
|
|
|
|
|
|
def get_setting(idp, name, default=None):
|
|
"""Get a parameter from an IdP specific configuration or from the main
|
|
settings.
|
|
"""
|
|
return idp.get(name) or getattr(app_settings, name, default)
|
|
|
|
|
|
def make_session_dump(session_indexes):
|
|
return render_to_string('mellon/session_dump.xml', {'session_indexes': session_indexes})
|
|
|
|
|
|
def create_logout(request):
|
|
server = create_server(request)
|
|
logout = lasso.Logout(server)
|
|
if not app_settings.PRIVATE_KEY and not app_settings.PRIVATE_KEYS:
|
|
logout.setSignatureHint(lasso.PROFILE_SIGNATURE_HINT_FORBID)
|
|
return logout
|
|
|
|
|
|
def is_nonnull(s):
|
|
return '\x00' not in s
|
|
|
|
|
|
PROTOCOLS_TO_PORT = {
|
|
'http': '80',
|
|
'https': '443',
|
|
}
|
|
|
|
|
|
def netloc_to_host_port(netloc):
|
|
if not netloc:
|
|
return None, None
|
|
splitted = netloc.split(':', 1)
|
|
if len(splitted) > 1:
|
|
return splitted[0], splitted[1]
|
|
return splitted[0], None
|
|
|
|
|
|
def same_domain(domain1, domain2):
|
|
if domain1 == domain2:
|
|
return True
|
|
|
|
if not domain1 or not domain2:
|
|
return False
|
|
|
|
if domain2.startswith('.'):
|
|
# p1 is a sub-domain or the base domain
|
|
if domain1.endswith(domain2) or domain1 == domain2[1:]:
|
|
return True
|
|
return False
|
|
|
|
|
|
def same_origin(url1, url2):
|
|
"""Checks if both URL use the same domain. It understands domain patterns on url2, i.e. .example.com
|
|
matches www.example.com.
|
|
|
|
If not scheme is given in url2, scheme compare is skipped.
|
|
If not scheme and not port are given, port compare is skipped.
|
|
The last two rules allow authorizing complete domains easily.
|
|
"""
|
|
p1, p2 = urlparse(url1), urlparse(url2)
|
|
p1_host, p1_port = netloc_to_host_port(p1.netloc)
|
|
p2_host, p2_port = netloc_to_host_port(p2.netloc)
|
|
|
|
# url2 is relative, always same domain
|
|
if url2.startswith('/') and not url2.startswith('//'):
|
|
return True
|
|
|
|
if p2.scheme and p1.scheme != p2.scheme:
|
|
return False
|
|
|
|
if not same_domain(p1_host, p2_host):
|
|
return False
|
|
|
|
try:
|
|
if (p2_port or (p1_port and p2.scheme)) and (
|
|
(p1_port or PROTOCOLS_TO_PORT[p1.scheme]) != (p2_port or PROTOCOLS_TO_PORT[p2.scheme])
|
|
):
|
|
return False
|
|
except (ValueError, KeyError):
|
|
return False
|
|
|
|
return True
|
|
|
|
|
|
def get_status_codes_and_message(profile):
|
|
assert profile, 'missing lasso.Profile'
|
|
assert profile.response, 'missing response in profile'
|
|
assert profile.response.status, 'missing status in response'
|
|
|
|
from .views import lasso_decode
|
|
|
|
status_codes = []
|
|
|
|
status = profile.response.status
|
|
a = status
|
|
while a.statusCode:
|
|
status_codes.append(lasso_decode(a.statusCode.value))
|
|
a = a.statusCode
|
|
message = None
|
|
if status.statusMessage:
|
|
message = lasso_decode(status.statusMessage)
|
|
return status_codes, message
|
|
|
|
|
|
def login(request, user):
|
|
for adapter in get_adapters():
|
|
if hasattr(adapter, 'auth_login'):
|
|
adapter.auth_login(request, user)
|
|
break
|
|
else:
|
|
auth.login(request, user)
|
|
|
|
|
|
def get_xml_encoding(content):
|
|
xml_encoding = 'utf-8'
|
|
|
|
def xmlDeclHandler(version, encoding, standalone):
|
|
global xml_encoding
|
|
|
|
if encoding:
|
|
xml_encoding = encoding
|
|
|
|
parser = expat.ParserCreate()
|
|
parser.XmlDeclHandler = xmlDeclHandler
|
|
try:
|
|
parser.Parse(content, True)
|
|
except expat.ExpatError as e:
|
|
raise ValueError('invalid XML %s' % e)
|
|
return xml_encoding
|
|
|
|
|
|
def get_local_path(request, url):
|
|
if not url:
|
|
return
|
|
parsed = urlparse(url)
|
|
path = parsed.path
|
|
if request.META.get('SCRIPT_NAME'):
|
|
path = path[len(request.META['SCRIPT_NAME']) :]
|
|
return path
|
|
|
|
|
|
def is_slo_supported(request, issuer):
|
|
server = create_server(request, remote_provider_id=issuer)
|
|
# verify that at least one logout method is supported
|
|
return (
|
|
server.getFirstHttpMethod(server.providers[issuer], lasso.MD_PROTOCOL_TYPE_SINGLE_LOGOUT)
|
|
!= lasso.HTTP_METHOD_NONE
|
|
)
|