django-mellon/mellon/adapters.py

616 lines
24 KiB
Python

# django-mellon - SAML2 authentication for Django
# Copyright (C) 2014-2019 Entr'ouvert
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as
# published by the Free Software Foundation, either version 3 of the
# License, or (at your option) any later version.
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
import hashlib
import logging
import os
import threading
import time
import uuid
from urllib.parse import urlparse
from xml.etree import ElementTree as ET
import lasso
import requests
import requests.exceptions
from atomicwrites import atomic_write
from django.contrib import auth, messages
from django.contrib.auth import get_user_model
from django.contrib.auth.models import Group
from django.core.exceptions import FieldDoesNotExist, PermissionDenied
from django.core.files.storage import default_storage
from django.utils.translation import gettext as _
from . import app_settings, models, models_utils, utils
User = auth.get_user_model()
logger = logging.getLogger(__name__)
class UserCreationError(Exception):
pass
def display_truncated_list(l, max_length=10):
s = '[' + ', '.join(map(str, l))
if len(l) > max_length:
s += '..truncated more than %d items (%d)]' % (max_length, len(l))
else:
s += ']'
return s
class DefaultAdapter:
user_class = get_user_model()
def __init__(self, request=None):
self.request = request
def get_idp(self, entity_id):
'''Find the first IdP definition matching entity_id'''
for idp in self.get_idps():
if entity_id == idp['ENTITY_ID']:
return idp
def get_identity_providers_setting(self):
return app_settings.IDENTITY_PROVIDERS
def get_users_queryset(self, idp, saml_attributes):
return User.objects.all()
def get_idps(self):
for i, idp in enumerate(self.get_identity_providers_setting()):
if self.load_idp(idp, i):
yield idp
def load_metadata_path(self, idp, i):
now = time.time()
path = idp['METADATA_PATH']
if not os.path.exists(path):
logger.error('mellon: metadata path %s does not exist', path)
return
metadata_last_update = idp.get('METADATA_LAST_UPDATE', None)
if not idp.get('METADATA'):
should_load = True
elif not metadata_last_update:
should_load = True
else:
try:
mtime = os.stat(path).st_mtime
except OSError as e:
logger.error('mellon: metadata path %s : stat() call failed, %s', path, e)
return
should_load = metadata_last_update <= mtime
if not should_load:
return
try:
with open(path) as fd:
metadata = fd.read()
except OSError as e:
logger.error('mellon: metadata path %s : open()/read() call failed, %s', path, e)
return
entity_id = self.load_entity_id(metadata, i)
if not entity_id:
logger.error('mellon: invalid metadata file retrieved from %s', path)
return
if 'ENTITY_ID' in idp and idp['ENTITY_ID'] != entity_id:
raise RuntimeError(
f'mellon: metadata path {path} : entityID changed {entity_id!r} != {idp["ENTITY_ID"]!r}'
)
idp['METADATA'] = metadata
idp['METADATA_LAST_UPDATE'] = now
def load_metadata_url(self, idp, i):
now = time.time()
url = idp['METADATA_URL']
metadata_cache_time = utils.get_setting(idp, 'METADATA_CACHE_TIME')
timeout = utils.get_setting(idp, 'METADATA_HTTP_TIMEOUT')
metadata_last_update = idp.get('METADATA_LAST_UPDATE', None)
metadata_last_url_update = idp.get('METADATA_LAST_URL_CHECK', 0)
try:
hostname = urlparse(url).hostname
except (ValueError, TypeError) as e:
logger.error('invalid METADATA_URL %r: %s', url, e)
return
if not hostname:
logger.error('no hostname in METADATA_URL %r', url)
return
try:
url_fingerprint = hashlib.md5(url.encode('ascii')).hexdigest()
file_cache_key = '%s_%s.xml' % (hostname, url_fingerprint)
except (UnicodeError, TypeError, ValueError):
logger.error('unable to compute file_cache_key')
return
cache_directory = default_storage.path('mellon_metadata_cache')
file_cache_path = os.path.join(cache_directory, file_cache_key)
if not os.path.exists(cache_directory):
os.makedirs(cache_directory)
if os.path.exists(file_cache_path) and 'METADATA' not in idp:
try:
with open(file_cache_path) as fd:
metadata = fd.read()
# use file cache mtime as last_update time, prevent too many loading from different workers
except OSError:
logger.error('metadata url %s : error when loading the file cache %s', url, file_cache_path)
else:
idp['METADATA'] = metadata
metadata_last_update = idp['METADATA_LAST_UPDATE'] = os.stat(file_cache_path).st_mtime
# fresh cache, skip loading
if (
metadata_last_update
and 'METADATA' in idp
and (
(now - metadata_last_update) < metadata_cache_time
# if HTTP GET is failing, try every 5 minutes
or (now - metadata_last_url_update) < 60 * 5
)
):
return
def __http_get():
try:
verify_ssl_certificate = utils.get_setting(idp, 'VERIFY_SSL_CERTIFICATE')
try:
response = requests.get(url, verify=verify_ssl_certificate, timeout=timeout)
response.raise_for_status()
except requests.exceptions.RequestException as e:
logger.warning('metadata url %s : HTTP request failed %s', url, e)
return
entity_id = self.load_entity_id(response.text, i)
if not entity_id:
logger.warning('invalid metadata file retrieved from %s', url)
return
if 'ENTITY_ID' in idp and idp['ENTITY_ID'] != entity_id:
# entityID change is always en error
logger.error(
'mellon: metadata url %s, entityID changed %r != %r', url, entity_id, idp['ENTITY_ID']
)
return
if idp.get('METADATA') != response.text or not os.path.exists(file_cache_path):
idp['METADATA'] = response.text
logger.info('mellon: metadata url %s updated', url)
idp['METADATA_LAST_UPDATE'] = now
try:
with atomic_write(file_cache_path, mode='wb', overwrite=True) as fd:
fd.write(response.text.encode('utf-8'))
except OSError as e:
logger.error(
'mellon: metadata url %s : could not write file cache %s, %s',
url,
file_cache_path,
e,
)
else:
os.utime(file_cache_path)
finally:
# release thread object
idp.pop('METADATA_URL_UPDATE_THREAD', None)
if metadata_last_update:
# we have cache, update in background
idp['METADATA_LAST_URL_CHECK'] = time.time()
t = threading.Thread(target=__http_get)
t.start()
# store thread in idp for tests
idp['METADATA_URL_UPDATE_THREAD'] = t
else:
# synchronous update
__http_get()
def load_metadata(self, idp, i):
now = time.time()
if 'METADATA' in idp and idp['METADATA'].startswith('/'):
# support legacy configuration
idp['METADATA_PATH'] = idp['METADATA']
del idp['METADATA']
metadata_cache_time = utils.get_setting(idp, 'METADATA_CACHE_TIME')
metadata_last_update = idp.get('METADATA_LAST_UPDATE', None)
age = metadata_last_update and (now - metadata_last_update)
if age is None or age > metadata_cache_time:
if 'METADATA_PATH' in idp:
self.load_metadata_path(idp, i)
elif 'METADATA_URL' in idp:
self.load_metadata_url(idp, i)
if 'METADATA' in idp:
if 'ENTITY_ID' not in idp:
entity_id = self.load_entity_id(idp['METADATA'], i)
if entity_id:
idp['ENTITY_ID'] = entity_id
if age and age > (24 * metadata_cache_time):
logger.error(
'mellon: metadata for idp %s: not updated since %.1f hours',
idp.get('ENTITY_ID', i),
age / 3600.0,
)
if 'ENTITY_ID' in idp:
return idp['METADATA']
logger.error('mellon: could not load metadata for %d-th idp: %s', i, idp)
return None
def load_entity_id(self, metadata, i):
try:
doc = ET.fromstring(metadata)
except (TypeError, ET.ParseError):
return None
if doc.tag != '{%s}EntityDescriptor' % lasso.SAML2_METADATA_HREF:
return None
if 'entityID' not in doc.attrib:
return None
return doc.attrib['entityID']
def load_idp(self, idp, i):
self.load_metadata(idp, i)
return 'ENTITY_ID' in idp
def authorize(self, idp, saml_attributes):
if not idp:
return False
required_classref = utils.get_setting(idp, 'AUTHN_CLASSREF')
if required_classref:
given_classref = saml_attributes['authn_context_class_ref']
if given_classref is None or given_classref not in required_classref:
logger.info(
'mellon: refused login because of authn_classref mismatch (%r vs %s)',
given_classref,
required_classref,
)
raise PermissionDenied
return True
def format_username(self, idp, saml_attributes):
realm = utils.get_setting(idp, 'REALM')
username_template = utils.get_setting(idp, 'USERNAME_TEMPLATE')
try:
username = username_template.format(realm=realm, attributes=saml_attributes, idp=idp)[
: self.user_class._meta.get_field('username').max_length
]
except ValueError:
logger.error('mellon: invalid username template %r', username_template)
except (AttributeError, KeyError, IndexError) as e:
logger.error('mellon: invalid reference in username template %r: %s', username_template, e)
except Exception:
logger.exception('mellon: unknown error when formatting username')
else:
return username
def create_user(self, user_class):
return user_class.objects.create(username=uuid.uuid4().hex[:30])
def finish_create_user(self, idp, saml_attributes, user):
username = self.format_username(idp, saml_attributes)
if not username:
raise UserCreationError
user.username = username
user.save()
def lookup_user(self, idp, saml_attributes):
transient_federation_attribute = utils.get_setting(idp, 'TRANSIENT_FEDERATION_ATTRIBUTE')
if saml_attributes['name_id_format'] == lasso.SAML2_NAME_IDENTIFIER_FORMAT_TRANSIENT:
if transient_federation_attribute and saml_attributes.get(transient_federation_attribute):
name_id = saml_attributes[transient_federation_attribute]
if not isinstance(name_id, str):
if len(name_id) == 1:
name_id = name_id[0]
else:
logger.warning(
'mellon: more than one value for attribute %r, cannot federate',
transient_federation_attribute,
)
return None
saml_attributes['transient_name_id_content'] = name_id
else:
if self.request:
messages.warning(
self.request,
_('A transient NameID was received but TRANSIENT_FEDERATION_ATTRIBUTE is not set.'),
)
logger.warning(
'mellon: transient NameID was received but TRANSIENT_FEDERATION_ATTRIBUTE is not set'
)
return None
else:
name_id = saml_attributes['name_id_content']
entity_id = saml_attributes['issuer']
try:
to_update = {
'nid_format': saml_attributes['name_id_format'],
'nid_name_qualifier': saml_attributes.get('name_id_name_qualifier'),
'nid_sp_name_qualifier': saml_attributes.get('name_id_sp_name_qualifier'),
'nid_sp_provided_id': saml_attributes.get('name_id_sp_provided_id'),
}
saml_identifier = models.UserSAMLIdentifier.objects.select_related('user').get(
name_id=name_id, issuer=models_utils.get_issuer(entity_id)
)
# nid_* attributes are new, we must update them if they are not initialized, eventually
for key, value in to_update.items():
if getattr(saml_identifier, key) != value:
models.UserSAMLIdentifier.objects.filter(pk=saml_identifier.pk).update(**to_update)
break
user = saml_identifier.user
user.saml_identifier = saml_identifier
logger.info('mellon: looked up user %s with name_id %s from issuer %s', user, name_id, entity_id)
return user
except models.UserSAMLIdentifier.DoesNotExist:
pass
user = self.lookup_by_attributes(idp, saml_attributes)
created = False
if not user:
if not utils.get_setting(idp, 'PROVISION'):
logger.warning('mellon: login refused for %s, PROVISION is disabled', saml_attributes)
return None
created = True
user = self.create_user(User)
nameid_user = self._link_user(idp, saml_attributes, user)
if user != nameid_user:
logger.info(
'mellon: looked up user %s with name_id %s from issuer %s', nameid_user, name_id, entity_id
)
if created:
user.delete()
return nameid_user
if created:
try:
self.finish_create_user(idp, saml_attributes, nameid_user)
except UserCreationError:
user.delete()
return None
logger.info(
'mellon: created new user %s with name_id %s from issuer %s', nameid_user, name_id, entity_id
)
return nameid_user
def lookup_by_attributes(self, idp, saml_attributes):
rules = utils.get_setting(idp, 'LOOKUP_BY_ATTRIBUTES')
if rules:
return self._lookup_by_attributes(idp, saml_attributes, rules)
return None
def _lookup_by_attributes(self, idp, saml_attributes, lookup_by_attributes):
if not isinstance(lookup_by_attributes, list):
logger.error(
'mellon: invalid LOOKUP_BY_ATTRIBUTES configuration %r: it must be a list',
lookup_by_attributes,
)
return None
users = set()
for line in lookup_by_attributes:
if not isinstance(line, dict):
logger.error(
'mellon: invalid LOOKUP_BY_ATTRIBUTES configuration %r: it must be a list of dicts', line
)
continue
user_field = line.get('user_field')
if not hasattr(user_field, 'isalpha'):
logger.error(
'mellon: invalid LOOKUP_BY_ATTRIBUTES configuration %r: user_field is missing', line
)
continue
try:
User._meta.get_field(user_field)
except FieldDoesNotExist:
logger.error(
'mellon: invalid LOOKUP_BY_ATTRIBUTES configuration %r, user field %s does not exist',
line,
user_field,
)
continue
saml_attribute = line.get('saml_attribute')
if not hasattr(saml_attribute, 'isalpha'):
logger.error(
'mellon: invalid LOOKUP_BY_ATTRIBUTES configuration %r: saml_attribute is missing', line
)
continue
values = saml_attributes.get(saml_attribute)
if not values:
logger.warning(
'mellon: looking for user by saml attribute %r and user field %r, skipping because empty',
saml_attribute,
user_field,
)
continue
ignore_case = line.get('ignore-case', False)
for value in values:
key = user_field
if ignore_case:
key += '__iexact'
users_qs = self.get_users_queryset(idp, saml_attributes)
users_found = users_qs.filter(**{key: value})
if not users_found:
logger.warning(
'mellon: looking for users by attribute %r and user field %r with value %r: not found',
saml_attribute,
user_field,
value,
)
continue
logger.info(
'mellon: looking for user by attribute %r and user field %r with value %r: found %s',
saml_attribute,
user_field,
value,
display_truncated_list(users_found),
)
users.update(users_found)
if len(users) == 1:
user = list(users)[0]
logger.info(
'mellon: looking for user by attributes %r: found user %s', lookup_by_attributes, user
)
return user
elif len(users) > 1:
logger.warning(
'mellon: looking for user by attributes %r: too many users found(%d), failing',
lookup_by_attributes,
len(users),
)
return None
def _link_user(self, idp, saml_attributes, user):
name_id_content = saml_attributes['name_id_content']
if saml_attributes['name_id_format'] == lasso.SAML2_NAME_IDENTIFIER_FORMAT_TRANSIENT:
name_id_content = saml_attributes['transient_name_id_content']
to_update = {
'nid_format': saml_attributes['name_id_format'],
'nid_name_qualifier': saml_attributes.get('name_id_name_qualifier'),
'nid_sp_name_qualifier': saml_attributes.get('name_id_sp_name_qualifier'),
'nid_sp_provided_id': saml_attributes.get('name_id_sp_provided_id'),
}
saml_id, created = models.UserSAMLIdentifier.objects.get_or_create(
name_id=name_id_content,
issuer=models_utils.get_issuer(saml_attributes['issuer']),
defaults={
'user': user,
**to_update,
},
)
# nid_* attributes are new, we must update them eventually
for key, value in to_update.items():
if getattr(saml_id, key) != value:
models.UserSAMLIdentifier.objects.filter(pk=saml_id.pk).update(**to_update)
break
if created:
user.saml_identifier = saml_id
return user
else:
saml_id.user.saml_identifier = saml_id
return saml_id.user
def provision(self, user, idp, saml_attributes):
self.provision_attribute(user, idp, saml_attributes)
self.provision_superuser(user, idp, saml_attributes)
self.provision_groups(user, idp, saml_attributes)
def provision_attribute(self, user, idp, saml_attributes):
realm = utils.get_setting(idp, 'REALM')
attribute_mapping = utils.get_setting(idp, 'ATTRIBUTE_MAPPING')
attribute_set = False
for field, tpl in attribute_mapping.items():
try:
value = tpl.format(realm=realm, attributes=saml_attributes, idp=idp)
except ValueError:
logger.warning('mellon: invalid attribute mapping template %r', tpl)
except (AttributeError, KeyError, IndexError) as e:
logger.warning('mellon: invalid reference in attribute mapping template %r: %s', tpl, e)
else:
model_field = user._meta.get_field(field)
if hasattr(model_field, 'max_length'):
value = value[: model_field.max_length]
if getattr(user, field) != value:
old_value = getattr(user, field)
setattr(user, field, value)
attribute_set = True
logger.info(
'mellon: set field %s of user %s to value %r (old value %r)',
field,
user,
value,
old_value,
)
if attribute_set:
user.save()
def provision_superuser(self, user, idp, saml_attributes):
superuser_mapping = utils.get_setting(idp, 'SUPERUSER_MAPPING')
if not superuser_mapping:
return
attribute_set = False
for key, values in superuser_mapping.items():
if key in saml_attributes:
if not isinstance(values, (tuple, list)):
values = [values]
values = set(values)
attribute_values = saml_attributes[key]
if not isinstance(attribute_values, (tuple, list)):
attribute_values = [attribute_values]
attribute_values = set(attribute_values)
if attribute_values & values:
if not (user.is_staff and user.is_superuser):
user.is_staff = True
user.is_superuser = True
attribute_set = True
logger.info('mellon: flag is_staff and is_superuser added to user %s', user)
break
else:
if user.is_superuser or user.is_staff:
user.is_staff = False
user.is_superuser = False
logger.info('mellon: flag is_staff and is_superuser removed from user %s', user)
attribute_set = True
if attribute_set:
user.save()
def provision_groups(self, user, idp, saml_attributes):
group_attribute = utils.get_setting(idp, 'GROUP_ATTRIBUTE')
create_group = utils.get_setting(idp, 'CREATE_GROUP')
if group_attribute in saml_attributes:
values = saml_attributes[group_attribute]
if not isinstance(values, (list, tuple)):
values = [values]
groups = []
for value in set(values):
if create_group:
group, _ = Group.objects.get_or_create(name=value)
else:
try:
group = Group.objects.get(name=value)
except Group.DoesNotExist:
continue
groups.append(group)
for group in Group.objects.filter(pk__in=[g.pk for g in groups]).exclude(user=user):
logger.info('mellon: adding group %s (%s) to user %s (%s)', group, group.pk, user, user.pk)
User.groups.through.objects.get_or_create(group=group, user=user)
qs = User.groups.through.objects.exclude(group__pk__in=[g.pk for g in groups]).filter(user=user)
for rel in qs:
logger.info(
'mellon: removing group %s (%s) from user %s (%s)',
rel.group,
rel.group.pk,
rel.user,
rel.user.pk,
)
qs.delete()