208 lines
7.8 KiB
Python
208 lines
7.8 KiB
Python
import base64
|
|
import requests
|
|
from requests.adapters import HTTPAdapter
|
|
import logging
|
|
|
|
from django.core.exceptions import ValidationError
|
|
|
|
from . import models, app_settings
|
|
|
|
from authentic2.saml.models import LibertyProvider, LibertyServiceProvider
|
|
from authentic2_idp_cas.models import Service as CASService
|
|
|
|
def normalize_cert(certificate_pem):
    '''Normalize the content of a PEM encoded certificate.

    The first and last non-empty lines (the PEM armor) are dropped, the
    remaining base64 text is decoded and re-encoded as one unbroken base64
    string, giving a form that no longer depends on the original line
    wrapping or indentation.

    :param certificate_pem: the PEM text of the certificate
    :returns: the base64 encoding of the decoded certificate body
    '''
    # Strip every line and ignore blank ones *before* removing the armor:
    # with a trailing blank line, a plain splitlines()[1:-1] would keep the
    # END marker and feed it to the base64 decoder. Stripping also removes the
    # leading tabs some servers (e.g. nginx) put on continuation lines.
    lines = [line for line in (raw.strip() for raw in certificate_pem.splitlines()) if line]
    base64_content = ''.join(lines[1:-1])
    content = base64.b64decode(base64_content)
    return base64.b64encode(content)
|
|
|
|
def _unescape_dn_value(value):
    '''Resolve backslash escapes in one DN value and decode it as UTF-8.

    For a byte string (Python 2 ``str``) this is the same operation as
    ``value.decode('string_escape').decode('utf-8')``: the ``string_escape``
    codec is implemented by ``codecs.escape_decode``.
    '''
    import codecs

    if not isinstance(value, bytes):
        # Text input (Python 3): WSGI native strings carry the raw bytes
        # decoded as ISO-8859-1 (PEP 3333), so re-encoding recovers them.
        # NOTE(review): text containing characters above U+00FF would raise
        # UnicodeEncodeError here.
        value = value.encode('latin-1')
    return codecs.escape_decode(value)[0].decode('utf-8')


def explode_dn(dn):
    '''Extract the sub elements of a DN as displayed by mod_ssl or nginx_ssl.

    ``'/C=FR/CN=John Doe'`` gives ``[('C', u'FR'), ('CN', u'John Doe')]``.
    Values are unescaped and decoded from UTF-8; names are kept as is. Only
    the first ``=`` of a component separates name and value, so values may
    themselves contain ``=``.

    :param dn: the slash separated DN
    :returns: the list of ``(name, value)`` pairs in DN order, empty for an
              empty DN
    :raises IndexError: if a component contains no ``=``
    '''
    dn = dn.strip('/')
    if not dn:
        return []
    pairs = [component.split('=', 1) for component in dn.split('/')]
    return [(pair[0], _unescape_dn_value(pair[1])) for pair in pairs]
|
|
|
|
# Attribute name -> callable applied by SSLInfo.read_env to the raw value
# read from the environment before it is stored.
TRANSFORM = dict(cert=normalize_cert)
|
|
|
|
class SSLInfo(object):
    """Read-only view of the SSL client certificate variables of a request.

    The variables are read from the WSGI environ (``WSGIRequest``) or from
    the mod_python subprocess environment (``ModPythonRequest``); any other
    request type is rejected with ``EnvironmentError``. Which environment
    keys feed which attribute is configured by ``app_settings.X509_KEYS``.

    When both ``issuer_dn`` and ``subject_dn`` are set, ``collectivity`` is
    the ``models.Collectivity`` registered for that certificate (or None) and
    ``users`` the queryset of matching ``models.User``; otherwise
    ``collectivity`` is None and ``users`` is empty.

    Attributes are read as ``info.attr`` or ``info.get('attr')``; any
    assignment raises ``AttributeError``.
    """

    def __init__(self, request):
        name = request.__class__.__name__
        if name == 'WSGIRequest':
            env = request.environ
        elif name == 'ModPythonRequest':
            env = request._req.subprocess_env
        else:
            # implicit concatenation: a backslash continuation inside the
            # literal would embed the next line's indentation in the message
            raise EnvironmentError(
                'The SSL authentication currently only works with '
                'mod_python or wsgi requests')
        self.read_env(env)

        # values are stored through __dict__ because __setattr__ forbids writes;
        # .get() tolerates X509_KEYS configurations without these attributes
        issuer_dn = self.__dict__.get('issuer_dn')
        subject_dn = self.__dict__.get('subject_dn')
        if issuer_dn and subject_dn:
            kwargs = dict(certificate_issuer_dn=issuer_dn,
                          certificate_subject_dn=subject_dn)
            try:
                # NOTE(review): .get() raises MultipleObjectsReturned if several
                # collectivities share this DN pair — confirm it is unique.
                collectivity = models.Collectivity.objects.get(**kwargs)
            except models.Collectivity.DoesNotExist:
                collectivity = None
            users = models.User.objects.filter(**kwargs).select_related()
        else:
            collectivity = None
            users = models.User.objects.none()
        self.__dict__['collectivity'] = collectivity
        self.__dict__['users'] = users

    def read_env(self, env):
        '''Fill the attributes from *env* as described by X509_KEYS.

        For each attribute the configured key, or list of keys, is tried in
        order; the first key present with a non-empty value wins, and its
        value goes through TRANSFORM when a transform exists for the
        attribute. Attributes without such a key are set to None. Finally
        ``verify`` becomes a boolean, True only for the value 'SUCCESS'.
        '''
        for attr, keys in app_settings.X509_KEYS.items():
            if isinstance(keys, basestring):
                keys = [keys]
            for key in keys:
                if key in env and env[key]:
                    value = env[key]
                    if attr in TRANSFORM:
                        value = TRANSFORM[attr](value)
                    self.__dict__[attr] = value
                    break
            else:
                # no key of this attribute carried a value
                self.__dict__[attr] = None
        # .get(): do not crash when 'verify' is absent from X509_KEYS
        self.__dict__['verify'] = self.__dict__.get('verify') == 'SUCCESS'

    def get(self, attr):
        '''Return attribute *attr*; raise AttributeError when it is unset.'''
        return self.__getattr__(attr)

    def __getattr__(self, attr):
        # Python only calls this for names missing from normal lookup; get()
        # calls it directly, hence the explicit __dict__ check.
        if attr in self.__dict__:
            return self.__dict__[attr]
        raise AttributeError('SSLInfo does not contain key %s' % attr)

    def __setattr__(self, attr, value):
        raise AttributeError('SSL vars are read only!')

    def __repr__(self):
        return '<SSLInfo %s>' % self.__dict__
|
|
|
|
|
|
def sync_saml_provider(service_or_service_instance, timeout=30):
    '''Create, update or delete the SAML provider of a service or instance.

    Non-global services and instances of global services are skipped. The
    provider is identified by a slug ('sid-<service>' for a global service,
    'siid-<collectivity>-<instance>' for an instance). When the object has no
    metadata URL, an existing provider with that slug is deleted. Otherwise
    the metadata is downloaded, validated and saved, and the matching
    LibertyServiceProvider is ensured with users_can_manage_federations off.

    :param service_or_service_instance: a models.Service or
        models.ServiceInstance
    :param timeout: seconds to wait for the metadata server (passed to
        requests); None waits forever
    :returns: None when there is nothing to fetch, False when downloading,
        decoding or validating the metadata failed, True once saved
    :raises TypeError: for any other kind of argument
    '''
    logger = logging.getLogger(__name__)

    if not isinstance(service_or_service_instance, (models.ServiceInstance, models.Service)):
        raise TypeError
    if isinstance(service_or_service_instance, models.ServiceInstance):
        assert service_or_service_instance.service, 'service instance badly initialized'
        if service_or_service_instance.service.is_global:
            # do not create provider for global service instances
            return
        saml_slug = u'siid-%s-%s' % (service_or_service_instance.collectivity.slug,
                                     service_or_service_instance.slug)
        name = service_or_service_instance.service.name
        ou = service_or_service_instance.collectivity
    else:
        # if service is not global, do not create it
        if not service_or_service_instance.is_global:
            return
        saml_slug = u'sid-%s' % service_or_service_instance.slug
        name = service_or_service_instance.name
        ou = None
    # enforce limits of LibertyProvider model
    name = name[:140]
    saml_slug = saml_slug[:128]
    metadata_url = service_or_service_instance.metadata_url

    try:
        liberty_provider = LibertyProvider.objects.get(slug=saml_slug)
    except LibertyProvider.DoesNotExist:
        liberty_provider = None
    if not metadata_url:
        # no metadata url (anymore): remove any existing provider and stop
        # here, there is nothing to download
        if liberty_provider is not None:
            liberty_provider.delete()
        return
    if liberty_provider is None:
        liberty_provider = LibertyProvider()
    liberty_provider.slug = saml_slug
    liberty_provider.name = name
    liberty_provider.ou = ou

    session = requests.Session()
    try:
        # do not fail early
        adapter = HTTPAdapter(max_retries=3)
        session.mount('http://', adapter)
        session.mount('https://', adapter)
        response = session.get(metadata_url, timeout=timeout)
        # an HTTP error page is not metadata
        response.raise_for_status()
    except requests.RequestException as e:
        logger.warning('updating of metadata of provider %r failed: %s',
                       saml_slug, e)
        return False
    finally:
        session.close()

    try:
        liberty_provider.metadata = response.content.decode('utf-8')
    except UnicodeDecodeError as e:
        logger.warning('updating of metadata of provider %r failed: %s',
                       saml_slug, e)
        return False
    try:
        # NOTE(review): entity_id is excluded here yet read below; presumably
        # the model's clean() fills it from the metadata — confirm.
        liberty_provider.full_clean(exclude=('ou', 'entity_id', 'protocol_conformance'))
    except ValidationError as e:
        logger.warning('updating of metadata of provider %r failed: %s',
                       saml_slug, e)
        return False
    # a provider already registered with this entity id is updated in place
    # instead of creating a second row
    try:
        liberty_provider.id = LibertyProvider.objects.get(entity_id=liberty_provider.entity_id).id
    except LibertyProvider.DoesNotExist:
        pass
    liberty_provider.save()
    liberty_service_provider, created = LibertyServiceProvider.objects.get_or_create(
        liberty_provider=liberty_provider,
        defaults={'enabled': True, 'users_can_manage_federations': False})
    if not created and liberty_service_provider.users_can_manage_federations:
        liberty_service_provider.users_can_manage_federations = False
        liberty_service_provider.save()
    return True
|
|
|
|
def sync_oauth2_client(service_or_service_instance):
    '''Synchronize the OAuth2 client of a service or service instance.

    Not implemented yet: the argument is ignored and None is returned.
    '''
    # TODO: implement OAuth2 client synchronization
    return None
|
|
|
|
def sync_cas_provider(service_or_service_instance):
    '''Create, update or delete the CAS service of a service or instance.

    Non-global services and instances of global services are skipped. The
    CAS service is identified by a slug ('sid-<service>' for a global
    service, 'siid-<collectivity>-<instance>' for an instance). Without a
    cas_service_url the CAS services with that slug are deleted; otherwise
    one is created, or its urls/name/ou/identifier_attribute are updated
    (and saved only when something differs).

    :param service_or_service_instance: a models.Service or
        models.ServiceInstance
    :raises TypeError: for any other kind of argument
    '''
    if not isinstance(service_or_service_instance, (models.ServiceInstance, models.Service)):
        raise TypeError
    if isinstance(service_or_service_instance, models.ServiceInstance):
        assert service_or_service_instance.service, 'service instance badly initialized'
        if service_or_service_instance.service.is_global:
            # do not create provider for global service instances
            return
        slug = u'siid-%s-%s' % (service_or_service_instance.collectivity.slug,
                                service_or_service_instance.slug)
        name = service_or_service_instance.service.name
        ou = service_or_service_instance.collectivity
    else:
        # if service is not global, do not create it
        if not service_or_service_instance.is_global:
            return
        slug = u'sid-%s' % service_or_service_instance.slug
        name = service_or_service_instance.name
        ou = None

    if not service_or_service_instance.cas_service_url:
        # no CAS url (anymore): drop any service previously created
        CASService.objects.filter(slug=slug).delete()
        return

    defaults = {
        'urls': service_or_service_instance.cas_service_url,
        'name': name,
        'ou': ou,
        'identifier_attribute': 'pratic_uid',
    }
    service, created = CASService.objects.get_or_create(
        slug=slug, defaults=defaults)
    if created:
        return
    # existing service: write back only when a field actually differs
    changed = False
    for key, value in defaults.items():
        if getattr(service, key) != value:
            setattr(service, key, value)
            changed = True
    if changed:
        service.save()
|