453 lines
17 KiB
Python
453 lines
17 KiB
Python
from itertools import chain, islice
|
|
import json
|
|
from django.utils.six.moves.urllib.parse import urljoin
|
|
import threading
|
|
import copy
|
|
import logging
|
|
import requests
|
|
|
|
from django.contrib.auth import get_user_model
|
|
from django.db import connection
|
|
from django.urls import reverse
|
|
from django.conf import settings
|
|
from django.utils.encoding import force_text
|
|
|
|
from django_rbac.utils import get_role_model, get_ou_model, get_role_parenting_model
|
|
from hobo.agent.common import notify_agents
|
|
from hobo.signature import sign_url
|
|
from authentic2.saml.models import LibertyProvider
|
|
from authentic2.a2_rbac.models import RoleAttribute
|
|
from authentic2.models import AttributeValue
|
|
|
|
# Concrete model classes returned by the project's model-lookup helpers,
# resolved once at import time (evaluated left to right).
User, Role, OU, RoleParenting = (
    get_user_model(),
    get_role_model(),
    get_ou_model(),
    get_role_parenting_model(),
)

# Module-level logger named after this module.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
def batch(iterable, size):
    """Split ``iterable`` into consecutive chunks of at most ``size`` items.

    Each chunk is a lazy iterator that shares the underlying source iterator,
    so it must be fully consumed before the next chunk is requested.
    """
    source = iter(iterable)
    while True:
        try:
            head = next(source)
        except StopIteration:
            return
        # the head element plus up to size - 1 following ones, pulled lazily
        yield chain((head,), islice(source, size - 1))
|
|
|
|
|
|
class Provisionning(threading.local):
    """Track changed users and roles and provision them to services.

    Changes are only recorded while a context opened by ``start`` (or by the
    ``with`` statement) is on the stack. ``stop`` pops that context and, by
    default, provisions what it collected in a background thread, then waits
    for the provisionning threads to finish.
    """

    # NOTE(review): only ``threads`` is declared as a slot; ``stack`` lives in
    # threading.local's per-thread storage. threading.local re-runs
    # ``__init__`` in every thread that uses the instance, so a provisionning
    # thread rebinds ``threads`` as well -- confirm whether slot storage is
    # meant to be shared between threads here.
    __slots__ = ['threads']

    def __init__(self):
        # threads started by provision() that have not finished yet
        self.threads = set()
        # one {'saved': {...}, 'deleted': {...}} dict per open context
        self.stack = []

    def start(self):
        """Open a new, empty provisionning context."""
        self.stack.append({
            'saved': {},
            'deleted': {},
        })

    def clear(self):
        """Drop every open context without provisionning anything."""
        self.stack = []

    def stop(self, provision=True, wait=True):
        """Close the innermost context.

        :param provision: send the collected changes (in a background thread)
        :param wait: after provisionning, join the provisionning threads
        """
        if not self.stack:
            return

        context = self.stack.pop()

        if provision:
            self.provision(**context)
            if wait:
                self.wait()

    @property
    def saved(self):
        """``{User|Role: set(instances)}`` of the innermost context, or None."""
        if self.stack:
            return self.stack[-1]['saved']
        return None

    @property
    def deleted(self):
        """``{User|Role: set(instances)}`` of the innermost context, or None."""
        if self.stack:
            return self.stack[-1]['deleted']
        return None

    def add_saved(self, *args):
        """Record instances to (re-)provision; anything not a User counts as a Role."""
        if not self.stack:
            return

        for instance in args:
            klass = User if isinstance(instance, User) else Role
            self.saved.setdefault(klass, set()).add(instance)

    def add_deleted(self, *args):
        """Record instances to deprovision and forget any pending save of them."""
        if not self.stack:
            return

        for instance in args:
            klass = User if isinstance(instance, User) else Role
            self.deleted.setdefault(klass, set()).add(instance)
            self.saved.get(klass, set()).discard(instance)

    def resolve_ou(self, instances, ous):
        """Attach already-fetched OU objects (``ous`` maps id -> OU) to instances."""
        for instance in instances:
            if instance.ou_id in ous:
                instance.ou = ous[instance.ou_id]

    def notify_users(self, ous, users, mode='provision'):
        """Send users to every service.

        :param ous: mapping of OU id -> OU, used to resolve ``user.ou``
        :param users: user instances to send
        :param mode: ``'provision'`` (full user data) or anything else
            (deprovision, uuid only)
        """
        if not users:
            # nothing to send: skip the role/membership queries below
            return

        allowed_technical_roles_prefixes = tuple(
            getattr(settings, 'HOBO_PROVISION_ROLE_PREFIXES', []) or [])

        if mode == 'provision':
            # reload users with their OU and attribute values in few queries
            users = (User.objects.filter(id__in=[u.id for u in users])
                     .select_related('ou').prefetch_related('attribute_values__attribute'))
        else:
            self.resolve_ou(users, ous)

        # every user is sent to the services of every OU, including OU-less ones
        ous = {}
        for ou in [None] + list(OU.objects.all()):
            for user in users:
                ous.setdefault(ou, set()).add(user)

        def is_forbidden_technical_role(role):
            return (role.slug.startswith('_')
                    and not role.slug.startswith(allowed_technical_roles_prefixes))

        issuer = force_text(self.get_entity_id())
        if mode == 'provision':

            def user_to_json(ou, service, user, user_roles):
                from authentic2.api_views import BaseUserSerializer
                data = {}
                # filter user's roles visible by the service's ou
                roles = [role for role in user_roles.get(user.id, [])
                         if (not is_forbidden_technical_role(role)
                             and (role.ou_id is None or (ou and role.ou_id == ou.id)))]
                data.update({
                    'uuid': user.uuid,
                    'username': user.username,
                    'first_name': user.first_name,
                    'last_name': user.last_name,
                    'email': user.email,
                    'is_active': user.is_active,
                    'roles': [
                        {
                            'uuid': role.uuid,
                            'name': role.name,
                            'slug': role.slug,
                        } for role in roles],
                })
                data.update(BaseUserSerializer(user).data)
                # check if user is superuser through a role of this service
                role_is_superuser = False
                if service:
                    for role in user_roles.get(user.id, []):
                        if role.service_id != service.pk:
                            continue
                        for attribute in role.attributes.all():
                            if attribute.name == 'is_superuser' and attribute.value == 'true':
                                role_is_superuser = True
                data['is_superuser'] = user.is_superuser or role_is_superuser
                return data

            # Find roles giving a superuser attribute. If there is any role of
            # this kind, is_superuser depends on the target service, so users
            # are sent service by service, in batches of at most 500 users.
            roles_with_attributes = (Role.objects.filter(members__in=users)
                                     .parents(include_self=True)
                                     .filter(attributes__name='is_superuser')
                                     .exists())

            all_roles = Role.objects.all().prefetch_related('attributes')
            roles = dict((r.id, r) for r in all_roles)
            user_roles = {}
            parents = {}
            # broken parent/child relationships can happen: only keep rows
            # whose child and parent are both known roles (one query instead
            # of two per RoleParenting row)
            for child_id, parent_id in RoleParenting.objects.values_list('child_id', 'parent_id'):
                if child_id in roles and parent_id in roles:
                    parents.setdefault(child_id, []).append(parent_id)
            Through = Role.members.through
            # only the memberships of the users being provisionned
            qs = Through.objects.filter(user__in=users).values_list('user_id', 'role_id')
            for u_id, r_id in qs:
                # unknown r_id can happen
                if r_id in roles:
                    user_roles.setdefault(u_id, set()).add(roles[r_id])
                    for p_id in parents.get(r_id, []):
                        user_roles[u_id].add(roles[p_id])

            if roles_with_attributes:
                for ou, ou_users in ous.items():
                    for service, audience in self.get_audience(ou):
                        for chunk in batch(ou_users, 500):
                            chunk = list(chunk)
                            for user in chunk:
                                logger.info('provisionning user %s to %s', user, audience)
                            self.notify_agents({
                                '@type': 'provision',
                                'issuer': issuer,
                                'audience': [audience],
                                'full': False,
                                'objects': {
                                    '@type': 'user',
                                    'data': [user_to_json(ou, service, user, user_roles) for user in chunk],
                                }
                            })
            else:
                for ou, ou_users in ous.items():
                    audience = [a for service, a in self.get_audience(ou)]
                    if not audience:
                        continue
                    logger.info(u'provisionning users %s to %s', u', '.join(
                        map(force_text, ou_users)), u', '.join(audience))
                    self.notify_agents({
                        '@type': 'provision',
                        'issuer': issuer,
                        'audience': audience,
                        'full': False,
                        'objects': {
                            '@type': 'user',
                            'data': [user_to_json(ou, None, user, user_roles) for user in ou_users],
                        }
                    })
        else:
            audience = [entity_id for ou in ous
                        for _service, entity_id in self.get_audience(ou)]
            logger.info(u'deprovisionning users %s from %s', u', '.join(
                map(force_text, users)), u', '.join(audience))
            self.notify_agents({
                '@type': 'deprovision',
                'issuer': issuer,
                'audience': audience,
                'full': False,
                'objects': {
                    '@type': 'user',
                    'data': [{
                        'uuid': user.uuid,
                    } for user in users]
                }
            })

    def notify_roles(self, ous, roles, mode='provision', full=False):
        """Send roles to the services of their OU (global roles go everywhere).

        :param ous: mapping of OU id -> OU, used to resolve ``role.ou``
        :param roles: role instances; technical roles (slug starting with
            '_') are dropped unless allowed by HOBO_PROVISION_ROLE_PREFIXES
        :param mode: ``'provision'`` or ``'deprovision'``, used as message type
        :param full: forwarded as the message's ``full`` flag
        """
        allowed_technical_roles_prefixes = tuple(
            getattr(settings, 'HOBO_PROVISION_ROLE_PREFIXES', []) or [])

        def is_forbidden_technical_role(role):
            return (role.slug.startswith('_')
                    and not role.slug.startswith(allowed_technical_roles_prefixes))

        roles = set([role for role in roles if not is_forbidden_technical_role(role)])
        if mode == 'provision':
            self.complete_roles(roles)

        if not roles:
            return

        self.resolve_ou(roles, ous)
        ous = {}
        for role in roles:
            ous.setdefault(role.ou, []).append(role)

        def helper(ou, roles):
            if mode == 'provision':
                data = [
                    {
                        'uuid': role.uuid,
                        'name': role.name,
                        'slug': role.slug,
                        'description': role.description,
                        'details': role.details,
                        'emails': role.emails,
                        'emails_to_members': role.emails_to_members,
                    } for role in roles
                ]
            else:
                data = [
                    {
                        'uuid': role.uuid,
                    } for role in roles
                ]

            audience = [entity_id for service, entity_id in self.get_audience(ou)]
            logger.info(u'%sning roles %s to %s', mode, roles, audience)
            self.notify_agents({
                '@type': mode,
                'audience': audience,
                'full': full,
                'objects': {
                    '@type': 'role',
                    'data': data,
                }
            })

        global_roles = set(ous.get(None, []))
        for ou, ou_roles in ous.items():
            sent_roles = set(ou_roles) | global_roles
            helper(ou, sent_roles)

    def provision(self, saved, deleted):
        """Start a background thread provisionning ``saved`` and ``deleted``.

        Does nothing when not in a tenant, when HOBO_ROLE_EXPORT is false, or
        when no instance was recorded.
        """
        # Returns if:
        # - we are not in a tenant
        # - provisionning is disabled
        # - there is nothing to do
        if (not hasattr(connection, 'tenant') or not connection.tenant or not
                hasattr(connection.tenant, 'domain_url')):
            return
        if not getattr(settings, 'HOBO_ROLE_EXPORT', True):
            return
        # buckets can exist but be empty (add_deleted discards from saved)
        if not any((saved or {}).values()) and not any((deleted or {}).values()):
            return

        # NOTE(review): the new thread gets its own DB connection; confirm the
        # tenant of ``connection`` is available there (get_entity_id asserts it).
        t = threading.Thread(
            target=self.do_provision,
            kwargs={'saved': saved, 'deleted': deleted})
        t.start()
        self.threads.add(t)

    def do_provision(self, saved, deleted):
        """Thread target: notify roles, then users. Errors are logged, not raised."""
        try:
            ous = {ou.id: ou for ou in OU.objects.all()}
            self.notify_roles(ous, saved.get(Role, []))
            self.notify_roles(ous, deleted.get(Role, []), mode='deprovision')
            self.notify_users(ous, saved.get(User, []))
            self.notify_users(ous, deleted.get(User, []), mode='deprovision')
        except Exception:
            # best effort: never let a provisionning failure escape the thread
            logger.exception(u'error in provisionning thread')
        finally:
            self.threads.discard(threading.current_thread())

    def wait(self):
        """Join every provisionning thread currently registered."""
        for thread in list(self.threads):
            thread.join()

    def __enter__(self):
        self.start()

    def __exit__(self, exc_type, exc_value, exc_tb):
        if not self.stack:
            return
        # only provision when the block exited without an exception
        self.stop(provision=exc_type is None)

    def get_audience(self, ou):
        """Return ``[(provider, entity_id)]`` for the LibertyProviders of ``ou``
        (providers without an OU when ``ou`` is None)."""
        if ou:
            qs = LibertyProvider.objects.filter(ou=ou)
        else:
            qs = LibertyProvider.objects.filter(ou__isnull=True)
        return [(service, service.entity_id) for service in qs]

    def complete_roles(self, roles):
        """Set emails / emails_to_members / details on roles from their JSON
        attributes, defaulting to ``[]`` / ``True`` / ``u''``."""
        for role in roles:
            role.emails = []
            role.emails_to_members = True
            role.details = u''
            for attribute in role.attributes.all():
                if (attribute.name in ('emails', 'emails_to_members', 'details')
                        and attribute.kind == 'json'):
                    setattr(role, attribute.name, json.loads(attribute.value))

    def get_entity_id(self):
        """Return the SAML metadata URL of the current tenant (issuer id)."""
        tenant = getattr(connection, 'tenant', None)
        assert tenant
        base_url = tenant.get_base_url()
        return urljoin(base_url, reverse('a2-idp-saml-metadata'))

    def pre_save(self, sender, instance, raw, using, update_fields, **kwargs):
        """Record existing users/roles (or the owner of a changed attribute)
        as saved; signature matches Django's pre_save signal."""
        if not self.stack:
            return
        # we skip new instances
        if not instance.pk:
            return
        if not isinstance(instance, (User, Role, RoleAttribute, AttributeValue)):
            return
        # ignore last_login update on login
        if isinstance(instance, User) and (update_fields and set(update_fields) == set(['last_login'])):
            return
        if isinstance(instance, RoleAttribute):
            instance = instance.role
        elif isinstance(instance, AttributeValue):
            if not isinstance(instance.owner, User):
                return
            instance = instance.owner
        self.add_saved(instance)

    def post_save(self, sender, instance, created, raw, using, update_fields, **kwargs):
        """Record newly created users/roles and members affected by a new
        role parenting; signature matches Django's post_save signal."""
        if not self.stack:
            return
        # a new/changed parenting changes the inherited roles of all members
        if isinstance(instance, RoleParenting):
            self.add_saved(*list(instance.child.all_members()))
            return
        # during post_save we only handle new instances
        if not created:
            return
        if not isinstance(instance, (User, Role, RoleAttribute, AttributeValue)):
            return
        if isinstance(instance, RoleAttribute):
            instance = instance.role
        elif isinstance(instance, AttributeValue):
            if not isinstance(instance.owner, User):
                return
            instance = instance.owner
        self.add_saved(instance)

    def pre_delete(self, sender, instance, using, **kwargs):
        """Record deleted users/roles, or what a deleted attribute/parenting
        affects; signature matches Django's pre_delete signal."""
        if not self.stack:
            return
        if isinstance(instance, (User, Role)):
            # presumably copied so the recorded object keeps its fields after
            # the deletion completes -- confirm
            self.add_deleted(copy.copy(instance))
        elif isinstance(instance, RoleAttribute):
            instance = instance.role
            self.add_saved(instance)
        elif isinstance(instance, AttributeValue):
            if not isinstance(instance.owner, User):
                return
            instance = instance.owner
            self.add_saved(instance)
        elif isinstance(instance, RoleParenting):
            self.add_saved(*list(instance.child.all_members()))

    def m2m_changed(self, sender, instance, action, reverse, model, pk_set, using, **kwargs):
        """Record both sides of a role membership change; signature matches
        Django's m2m_changed signal."""
        if not self.stack:
            return
        # handle post_* actions, plus pre_clear (members are still readable)
        if action != 'pre_clear' and action.startswith('pre_'):
            return
        if sender is Role.members.through:
            self.add_saved(instance)
            # on a clear, pk_set is None
            for other_instance in model.objects.filter(pk__in=pk_set or []):
                self.add_saved(other_instance)
            if action == 'pre_clear':
                # when the action is pre_clear we need to lookup the current value of the members
                # relation, to re-provision all previously enroled users.
                if not reverse:
                    for other_instance in instance.members.all():
                        self.add_saved(other_instance)

    def notify_agents(self, data):
        """Deliver a provisionning message to ``data['audience']``.

        When HOBO_HTTP_PROVISIONNING is set, each entity id whose service in
        settings.KNOWN_SERVICES declares a 'provisionning-url' gets its own
        signed PUT; failed deliveries and all other entity ids are passed on
        to hobo's ``notify_agents`` (logged as the AMQP audience).
        """
        if getattr(settings, 'HOBO_HTTP_PROVISIONNING', False):
            services_by_url = {}
            for services in settings.KNOWN_SERVICES.values():
                for service in services.values():
                    if service.get('provisionning-url'):
                        services_by_url[service['saml-sp-metadata-url']] = service
            audience = data.get('audience') or []
            rest_audience = [x for x in audience if x in services_by_url]
            # copy: the caller's audience list must not be mutated below
            amqp_audience = list(audience)
            for entity_id in rest_audience:
                service = services_by_url[entity_id]
                data['audience'] = [entity_id]
                try:
                    response = requests.put(
                        sign_url(service['provisionning-url'] + '?orig=%s' % service['orig'], service['secret']),
                        json=data)
                    response.raise_for_status()
                except requests.RequestException as e:
                    logger.error(u'error provisionning to %s (%s)', entity_id, e)
                else:
                    amqp_audience.remove(entity_id)
            data['audience'] = amqp_audience
            if amqp_audience:
                logger.info(u'leftover AMQP audience: %s', amqp_audience)

        if data['audience']:
            notify_agents(data)
|