This repository has been archived on 2023-02-21. You can view files and clone it, but cannot push or open issues or pull requests.
authentic2-pratic/src/authentic2_pratic/utils.py

211 lines
7.9 KiB
Python

import base64
import requests
from requests.adapters import HTTPAdapter
import logging
from django.core.exceptions import ValidationError
from django.db.transaction import atomic
from . import models, app_settings
from authentic2.saml.models import LibertyProvider, LibertyServiceProvider
from authentic2_idp_cas.models import Service as CASService
def normalize_cert(certificate_pem):
    '''Normalize content of the certificate.

    Remove the PEM ``-----BEGIN ...-----`` / ``-----END ...-----`` markers
    and every whitespace character, decode the remaining base64 payload
    and re-encode it. The same certificate therefore always yields the same
    single-line base64 string, whatever line wrapping or indentation the
    input uses (newlines replaced by spaces, tab-prefixed lines, trailing
    blank lines, CRLF line endings...).

    :param certificate_pem: PEM-encoded certificate text
    :returns: base64 encoding of the decoded certificate bytes, with no
        line breaks
    '''
    import re
    body = re.sub(r'-----(?:BEGIN|END)[^-]*-----', '', certificate_pem)
    # '-' is not part of the base64 alphabet, so once the markers are gone
    # only the payload (plus whitespace, removed here) remains.
    content = base64.b64decode(''.join(body.split()))
    return base64.b64encode(content)
def explode_dn(dn):
    r'''Split a DN as displayed by mod_ssl or nginx_ssl into its components.

    The DN is the slash-separated one-line form, e.g.
    ``/C=FR/O=Example/CN=John Doe``. Each value is backslash-unescaped
    (``\xC3\xA9`` becomes the bytes C3 A9) and then decoded as UTF-8.

    * a component is split on its first ``=`` only, so values that contain
      ``=`` are kept whole;
    * a component without any ``=`` cannot start a new attribute, so it is
      treated as the continuation of the previous value (the ``/`` that
      separated them belonged to that value);
    * empty components (empty DN, repeated slashes) are skipped.

    :param dn: the DN string
    :returns: list of ``(attribute, value)`` tuples, in DN order
    :raises ValueError: if the first non-empty component has no ``=``
    '''
    # codecs.escape_decode is the decoder behind Python 2's 'string_escape'
    # codec; calling it directly also works on Python 3.
    import codecs
    raw_parts = []
    for part in dn.strip('/').split('/'):
        if not part:
            continue
        if '=' in part:
            key, value = part.split('=', 1)
            raw_parts.append([key, value])
        elif raw_parts:
            raw_parts[-1][1] += '/' + part
        else:
            raise ValueError('malformed DN component %r in %r' % (part, dn))
    return [(key, codecs.escape_decode(value)[0].decode('utf-8'))
            for key, value in raw_parts]
# Post-processing used by SSLInfo.read_env: maps an attribute name to a
# callable applied to the raw environment value before it is stored.
TRANSFORM = dict(
    cert=normalize_cert,
)
class SSLInfo(object):
    """
    Read-only holder for the SSL/X.509 variables of a request.

    The variables are taken from ``request.environ`` for a WSGIRequest and
    from ``request._req.subprocess_env`` for a ModPythonRequest. Currently
    only these two request classes are supported; any other raises
    EnvironmentError. Which environment keys feed which attribute comes
    from ``app_settings.X509_KEYS`` (see ``read_env``). The constructor
    reads ``issuer_dn`` and ``subject_dn``, so those must be configured
    there, otherwise construction raises AttributeError.

    In addition to the configured attributes, an instance exposes:

    * ``verify`` -- True only if the configured verify variable equals
      ``'SUCCESS'``;
    * ``collectivity`` -- the Collectivity whose certificate issuer and
      subject DNs match ``issuer_dn`` / ``subject_dn``, or None;
    * ``users`` -- a queryset of the Users matching the same DNs (empty
      when either DN is missing).

    Values are read as ``info.attr`` or ``info.get('attr')``; assigning any
    attribute raises AttributeError. Because ``__setattr__`` is blocked,
    the class itself stores values directly in ``__dict__``.
    """

    def __init__(self, request):
        name = request.__class__.__name__
        if name == 'WSGIRequest':
            env = request.environ
        elif name == 'ModPythonRequest':
            env = request._req.subprocess_env
        else:
            raise EnvironmentError('The SSL authentication currently only '
                                   'works with mod_python or wsgi requests')
        self.read_env(env)
        if self.issuer_dn and self.subject_dn:
            # match on both the issuer and the subject DN of the certificate
            kwargs = dict(certificate_issuer_dn=self.issuer_dn,
                          certificate_subject_dn=self.subject_dn)
            try:
                self.__dict__['collectivity'] = models.Collectivity.objects.get(**kwargs)
            except models.Collectivity.DoesNotExist:
                self.__dict__['collectivity'] = None
            self.__dict__['users'] = models.User.objects.filter(**kwargs).select_related()
        else:
            self.__dict__['collectivity'] = None
            self.__dict__['users'] = models.User.objects.none()

    def read_env(self, env):
        '''Populate attributes from *env* according to X509_KEYS.

        Each X509_KEYS entry maps an attribute name to one key or a list of
        keys. The first key present in *env* with a non-empty value wins,
        and is passed through TRANSFORM when the attribute has an entry
        there. Attributes with no matching key are set to None. Finally
        ``verify`` is normalised to a boolean.
        '''
        for attr, keys in app_settings.X509_KEYS.iteritems():
            if isinstance(keys, basestring):
                keys = [keys]
            for key in keys:
                if key in env and env[key]:
                    v = env[key]
                    if attr in TRANSFORM:
                        v = TRANSFORM[attr](v)
                    self.__dict__[attr] = v
                    break
            else:
                self.__dict__[attr] = None
        # .get(): a 'verify' entry missing from X509_KEYS must not crash;
        # an absent value counts as a failed verification.
        self.__dict__['verify'] = self.__dict__.get('verify') == 'SUCCESS'

    def get(self, attr):
        '''Return the stored value of *attr*; raise AttributeError if unknown.'''
        return self.__getattr__(attr)

    def __getattr__(self, attr):
        # Python only calls this for names not found by normal lookup;
        # get() also calls it explicitly. Only stored SSL values are served.
        if attr in self.__dict__:
            return self.__dict__[attr]
        raise AttributeError('SSLInfo does not contain key %s' % attr)

    def __setattr__(self, attr, value):
        raise AttributeError('SSL vars are read only!')

    def __repr__(self):
        return '<SSLInfo %s>' % self.__dict__
@atomic
def sync_saml_provider(service_or_service_instance, timeout=30):
    '''Create, update or delete the SAML provider of a service or instance.

    Only two kinds of objects get a provider:

    * a global ``models.Service``: slug ``sid-<service slug>``, no OU;
    * a ``models.ServiceInstance`` whose service is not global: slug
      ``siid-<collectivity slug>-<instance slug>``, OU set to the
      instance's collectivity.

    Other objects of those two types are ignored. Name and slug are cut to
    140 and 128 characters.

    If the object has no ``metadata_url``, the existing provider with the
    computed slug (if any) is deleted. Otherwise the metadata is
    downloaded, checked with ``full_clean`` and saved on the provider. The
    matching ``LibertyServiceProvider`` is then created (enabled) if
    needed, and ``users_can_manage_federations`` is forced to False.

    :param service_or_service_instance: a ``models.Service`` or a
        ``models.ServiceInstance``
    :param timeout: value passed to ``requests`` as the ``timeout`` of the
        metadata download, in seconds
    :returns: True when the provider was saved; False when downloading,
        decoding or validating the metadata failed (a warning is logged);
        None when there was nothing to synchronize or the provider was
        deleted
    :raises TypeError: if given any other kind of object
    '''
    logger = logging.getLogger(__name__)
    if not isinstance(service_or_service_instance, (models.ServiceInstance, models.Service)):
        raise TypeError('expected a Service or a ServiceInstance, got %s'
                        % type(service_or_service_instance).__name__)
    if isinstance(service_or_service_instance, models.ServiceInstance):
        assert service_or_service_instance.service, 'service instance badly initialized'
        if service_or_service_instance.service.is_global:
            # do not create provider for global service instances
            return
        saml_slug = u'siid-%s-%s' % (service_or_service_instance.collectivity.slug,
                                     service_or_service_instance.slug)
        name = service_or_service_instance.service.name
        ou = service_or_service_instance.collectivity
    else:
        # if service is not global, do not create it
        if not service_or_service_instance.is_global:
            return
        saml_slug = u'sid-%s' % service_or_service_instance.slug
        name = service_or_service_instance.name
        ou = None
    # enforce limits of LibertyProvider model
    name = name[:140]
    saml_slug = saml_slug[:128]
    metadata_url = service_or_service_instance.metadata_url
    try:
        liberty_provider = LibertyProvider.objects.get(slug=saml_slug)
    except LibertyProvider.DoesNotExist:
        liberty_provider = None
    if not metadata_url:
        # no metadata: remove the provider if there is one, and stop here
        # since there is nothing to download
        if liberty_provider is not None:
            liberty_provider.delete()
        return
    if liberty_provider is None:
        liberty_provider = LibertyProvider()
    liberty_provider.slug = saml_slug
    liberty_provider.name = name
    liberty_provider.ou = ou
    try:
        with requests.Session() as session:
            # do not fail early: retry failed connections
            adapter = HTTPAdapter(max_retries=3)
            session.mount('http://', adapter)
            session.mount('https://', adapter)
            # bounded wait: a stalled server must not hang the caller and
            # the surrounding transaction forever
            response = session.get(metadata_url, timeout=timeout)
            # an HTTP error page is not metadata
            response.raise_for_status()
    except requests.RequestException as e:
        logger.warning('updating of metadata of provider %r failed: %s',
                       saml_slug, e)
        return False
    try:
        liberty_provider.metadata = response.content.decode('utf-8')
    except UnicodeDecodeError as e:
        logger.warning('updating of metadata of provider %r failed: %s',
                       saml_slug, e)
        return False
    try:
        liberty_provider.full_clean(exclude=('ou', 'entity_id', 'protocol_conformance'))
    except ValidationError as e:
        logger.warning('updating of metadata of provider %r failed: %s',
                       saml_slug, e)
        return False
    # If a provider already exists for this entity ID, update that row
    # instead of inserting another one. entity_id is presumably filled from
    # the metadata during full_clean -- TODO confirm in LibertyProvider.clean.
    try:
        liberty_provider.id = LibertyProvider.objects.get(entity_id=liberty_provider.entity_id).id
    except LibertyProvider.DoesNotExist:
        pass
    liberty_provider.save()
    liberty_service_provider, created = LibertyServiceProvider.objects.get_or_create(
        liberty_provider=liberty_provider,
        defaults={'enabled': True, 'users_can_manage_federations': False})
    if not created and liberty_service_provider.users_can_manage_federations:
        liberty_service_provider.users_can_manage_federations = False
        liberty_service_provider.save()
    return True
def sync_oauth2_client(service_or_service_instance):
    '''Synchronize the OAuth2 client of a service or service instance.

    Not implemented yet: the argument is ignored and None is returned.
    '''
    # TODO: not implemented
    return None
@atomic
def sync_cas_provider(service_or_service_instance):
    '''Create, update or delete the CAS service of a service or instance.

    Only two kinds of objects get a CAS service:

    * a global ``models.Service``: slug ``sid-<service slug>``, no OU;
    * a ``models.ServiceInstance`` whose service is not global: slug
      ``siid-<collectivity slug>-<instance slug>``, OU set to the
      instance's collectivity.

    Other objects of those two types are ignored. When the object has no
    ``cas_service_url``, the CAS services with the computed slug are
    deleted. Otherwise the CAS service is created. If it already exists,
    its ``urls``, ``name``, ``ou`` and ``identifier_attribute`` are updated,
    and it is saved only when at least one of them differs.

    :param service_or_service_instance: a ``models.Service`` or a
        ``models.ServiceInstance``
    :returns: None
    :raises TypeError: if given any other kind of object
    '''
    if not isinstance(service_or_service_instance, (models.ServiceInstance, models.Service)):
        raise TypeError('expected a Service or a ServiceInstance, got %s'
                        % type(service_or_service_instance).__name__)
    if isinstance(service_or_service_instance, models.ServiceInstance):
        assert service_or_service_instance.service, 'service instance badly initialized'
        if service_or_service_instance.service.is_global:
            # do not create provider for global service instances
            return
        slug = u'siid-%s-%s' % (service_or_service_instance.collectivity.slug,
                                service_or_service_instance.slug)
        name = service_or_service_instance.service.name
        ou = service_or_service_instance.collectivity
    else:
        # if service is not global, do not create it
        if not service_or_service_instance.is_global:
            return
        slug = u'sid-%s' % service_or_service_instance.slug
        name = service_or_service_instance.name
        ou = None
    if not service_or_service_instance.cas_service_url:
        CASService.objects.filter(slug=slug).delete()
        return
    defaults = {
        'urls': service_or_service_instance.cas_service_url,
        'name': name,
        'ou': ou,
        'identifier_attribute': 'pratic_uid',
    }
    service, created = CASService.objects.get_or_create(
        slug=slug, defaults=defaults)
    if created:
        return
    # existing service: copy over the fields that differ, save only if needed
    changed = False
    for key, value in defaults.iteritems():
        if getattr(service, key) != value:
            setattr(service, key, value)
            changed = True
    if changed:
        service.save()