# hobo/hobo/agent/authentic2/provisionning.py

from itertools import chain, islice
import json
from django.utils.six.moves.urllib.parse import urljoin
import threading
import copy
import logging
import requests
from django.contrib.auth import get_user_model
from django.db import connection
from django.urls import reverse
from django.conf import settings
from django.utils.encoding import force_text
from django_rbac.utils import get_role_model, get_ou_model, get_role_parenting_model
from hobo.agent.common import notify_agents
from hobo.signature import sign_url
from authentic2.saml.models import LibertyProvider
from authentic2.a2_rbac.models import RoleAttribute
from authentic2.models import AttributeValue
# Resolve the concrete model classes once at import time; the rest of the
# module uses these names for isinstance() checks and ORM queries.
User = get_user_model()
Role = get_role_model()
OU = get_ou_model()
RoleParenting = get_role_parenting_model()
logger = logging.getLogger(__name__)
def batch(iterable, size):
    """Batch an iterable as an iterable of lists of at most ``size`` elements.

    Each batch is materialized as a list before being yielded. This keeps the
    batches correct even when the caller does not consume them one after the
    other (e.g. ``list(batch(...))``), which lazily-sliced batches sharing a
    single source iterator would get wrong.

    :param iterable: any iterable; it is consumed exactly once.
    :param size: maximum number of elements per batch, must be >= 1.
    :raises ValueError: if ``size`` is lower than 1 (raised on first iteration).
    """
    if size < 1:
        raise ValueError('batch size must be >= 1')
    sourceiter = iter(iterable)
    while True:
        chunk = list(islice(sourceiter, size))
        if not chunk:
            return
        yield chunk
class Provisionning(threading.local):
    """Collect saved/deleted users and roles and push them to the services.

    Each thread has its own ``stack`` of provisionning contexts (see start()
    and stop()). The pre_save/post_save/pre_delete/m2m_changed methods, which
    have the signatures of Django model signal receivers (presumably connected
    elsewhere), record the touched objects in the innermost context. When a
    context is stopped, its content is provisionned from a background thread.
    """
    # ``threads`` is a slot: unlike ``stack`` (stored in the per-thread
    # __dict__ of threading.local) it is shared by all threads. It holds the
    # background provisionning threads so that wait() can join them.
    __slots__ = ['threads']

    def __init__(self):
        # threading.local calls __init__ again the first time each new thread
        # touches the instance; create the shared ``threads`` set only once so
        # that those calls do not forget threads that are still running.
        try:
            self.threads
        except AttributeError:
            self.threads = set()
        self.stack = []

    def start(self):
        """Open a new context collecting saved and deleted objects."""
        self.stack.append({
            'saved': {},
            'deleted': {},
        })

    def clear(self):
        """Drop every open context of the current thread without provisionning."""
        self.stack = []

    def stop(self, provision=True, wait=True):
        """Close the innermost context of the current thread.

        :param provision: send the collected objects to the services.
        :param wait: join the background provisionning threads afterwards.
        """
        if not self.stack:
            return
        context = self.stack.pop()
        if provision:
            self.provision(**context)
        if wait:
            self.wait()

    @property
    def saved(self):
        """Saved objects of the innermost context, by class, or None."""
        if self.stack:
            return self.stack[-1]['saved']
        return None

    @property
    def deleted(self):
        """Deleted objects of the innermost context, by class, or None."""
        if self.stack:
            return self.stack[-1]['deleted']
        return None

    def add_saved(self, *args):
        """Record users/roles to provision; anything not a User counts as a Role."""
        if not self.stack:
            return
        for instance in args:
            klass = User if isinstance(instance, User) else Role
            self.saved.setdefault(klass, set()).add(instance)

    def add_deleted(self, *args):
        """Record users/roles to deprovision and stop provisionning them."""
        if not self.stack:
            return
        for instance in args:
            klass = User if isinstance(instance, User) else Role
            self.deleted.setdefault(klass, set()).add(instance)
            self.saved.get(klass, set()).discard(instance)

    def resolve_ou(self, instances, ous):
        """Set ``instance.ou`` from the ``ous`` mapping (OU id -> OU) when known."""
        for instance in instances:
            if instance.ou_id in ous:
                instance.ou = ous[instance.ou_id]

    @staticmethod
    def _is_forbidden_technical_role(role):
        """Return True for technical roles (slug starting with ``_``) whose slug
        does not start with one of settings.HOBO_PROVISION_ROLE_PREFIXES."""
        allowed_prefixes = tuple(getattr(settings, 'HOBO_PROVISION_ROLE_PREFIXES', []) or [])
        return role.slug.startswith('_') and not role.slug.startswith(allowed_prefixes)

    def notify_users(self, ous, users, mode='provision'):
        """Send a (de)provisionning message about ``users`` to the services.

        :param ous: mapping of OU id to OU, used to resolve ``user.ou`` when
            deprovisionning.
        :param users: iterable of User instances.
        :param mode: ``'provision'`` (default); any other value deprovisions.
        """
        if not users:
            # nothing to send: skip loading every role and role parenting
            return
        if mode == 'provision':
            # reload the users, with their OU and attribute values prefetched
            users = (User.objects.filter(id__in=[u.id for u in users])
                     .select_related('ou').prefetch_related('attribute_values__attribute'))
        else:
            self.resolve_ou(users, ous)
        # every user is sent to the services of every OU and to the services
        # attached to no OU (key None)
        ous = {}
        for ou in [None] + list(OU.objects.all()):
            for user in users:
                ous.setdefault(ou, set()).add(user)
        issuer = force_text(self.get_entity_id())
        if mode == 'provision':
            self._provision_users(ous, users, issuer)
        else:
            self._deprovision_users(ous, users, issuer)

    def _get_user_roles(self, users):
        """Map user ids to the set of Role objects they hold, directly or
        through a parent role."""
        roles = dict((r.id, r) for r in Role.objects.all().prefetch_related('attributes'))
        parents = {}
        # Broken parent/child relationships (pointing to a missing role) can
        # happen: skip them. Reading the ids avoids two queries per row.
        for child_id, parent_id in RoleParenting.objects.values_list('child_id', 'parent_id'):
            if child_id in roles and parent_id in roles:
                parents.setdefault(child_id, []).append(parent_id)
        user_roles = {}
        Through = Role.members.through
        # only the memberships of the users being provisionned are needed
        memberships = Through.objects.filter(user__in=users).values_list('user_id', 'role_id')
        for u_id, r_id in memberships:
            # unknown r_id can happen
            if r_id not in roles:
                continue
            current = user_roles.setdefault(u_id, set())
            current.add(roles[r_id])
            for p_id in parents.get(r_id, []):
                current.add(roles[p_id])
        return user_roles

    def _user_to_json(self, ou, service, user, user_roles):
        """Serialize ``user`` for the services of ``ou``.

        :param service: the target service, or None; when given, roles of that
            service with an ``is_superuser`` attribute set to ``'true'`` make
            the user superuser.
        :param user_roles: mapping returned by _get_user_roles().
        """
        from authentic2.api_views import BaseUserSerializer
        # filter user's roles visible by the service's ou
        roles = [role for role in user_roles.get(user.id, [])
                 if (not self._is_forbidden_technical_role(role)
                     and (role.ou_id is None or (ou and role.ou_id == ou.id)))]
        data = {
            'uuid': user.uuid,
            'username': user.username,
            'first_name': user.first_name,
            'last_name': user.last_name,
            'email': user.email,
            'is_active': user.is_active,
            'roles': [
                {
                    'uuid': role.uuid,
                    'name': role.name,
                    'slug': role.slug,
                } for role in roles],
        }
        data.update(BaseUserSerializer(user).data)
        # check if user is superuser through a role
        role_is_superuser = False
        if service:
            for role in user_roles.get(user.id, []):
                if role.service_id != service.pk:
                    continue
                for attribute in role.attributes.all():
                    if attribute.name == 'is_superuser' and attribute.value == 'true':
                        role_is_superuser = True
        data['is_superuser'] = user.is_superuser or role_is_superuser
        return data

    def _provision_users(self, ous, users, issuer):
        """Send provision messages for ``users``; ``ous`` maps OU -> users."""
        # Find roles giving a superuser attribute. If there is any role of
        # this kind, ``is_superuser`` depends on the service, so one message is
        # sent per service and per batch of users instead of one per OU.
        roles_with_attributes = (Role.objects.filter(members__in=users)
                                 .parents(include_self=True)
                                 .filter(attributes__name='is_superuser')
                                 .exists())
        user_roles = self._get_user_roles(users)
        if roles_with_attributes:
            for ou, ou_users in ous.items():
                for service, audience in self.get_audience(ou):
                    for batched_users in batch(ou_users, 500):
                        batched_users = list(batched_users)
                        for user in batched_users:
                            logger.info('provisionning user %s to %s', user, audience)
                        self.notify_agents({
                            '@type': 'provision',
                            'issuer': issuer,
                            'audience': [audience],
                            'full': False,
                            'objects': {
                                '@type': 'user',
                                'data': [self._user_to_json(ou, service, user, user_roles)
                                         for user in batched_users],
                            }
                        })
        else:
            for ou, ou_users in ous.items():
                audience = [a for service, a in self.get_audience(ou)]
                if not audience:
                    continue
                logger.info(u'provisionning users %s to %s', u', '.join(
                    map(force_text, ou_users)), u', '.join(audience))
                self.notify_agents({
                    '@type': 'provision',
                    'issuer': issuer,
                    'audience': audience,
                    'full': False,
                    'objects': {
                        '@type': 'user',
                        'data': [self._user_to_json(ou, None, user, user_roles)
                                 for user in ou_users],
                    }
                })

    def _deprovision_users(self, ous, users, issuer):
        """Send one deprovision message for ``users`` to the services of all ``ous``."""
        audience = [entity_id for ou in ous.keys()
                    for service, entity_id in self.get_audience(ou)]
        logger.info(u'deprovisionning users %s from %s', u', '.join(
            map(force_text, users)), u', '.join(audience))
        self.notify_agents({
            '@type': 'deprovision',
            'issuer': issuer,
            'audience': audience,
            'full': False,
            'objects': {
                '@type': 'user',
                'data': [{
                    'uuid': user.uuid,
                } for user in users]
            }
        })

    def notify_roles(self, ous, roles, mode='provision', full=False):
        """Send (de)provisionning messages about ``roles``.

        Roles are grouped by OU; the services of each OU receive the roles of
        that OU plus the global roles (roles without OU).

        :param ous: mapping of OU id to OU, used to resolve ``role.ou``.
        :param mode: used as the message ``@type`` (``'provision'`` or
            ``'deprovision'``).
        :param full: forwarded as the message ``full`` flag.
        """
        roles = {role for role in roles if not self._is_forbidden_technical_role(role)}
        if not roles:
            return
        if mode == 'provision':
            self.complete_roles(roles)
        self.resolve_ou(roles, ous)
        roles_by_ou = {}
        for role in roles:
            roles_by_ou.setdefault(role.ou, []).append(role)

        def send(ou, ou_roles):
            if mode == 'provision':
                data = [
                    {
                        'uuid': role.uuid,
                        'name': role.name,
                        'slug': role.slug,
                        'description': role.description,
                        'details': role.details,
                        'emails': role.emails,
                        'emails_to_members': role.emails_to_members,
                    } for role in ou_roles
                ]
            else:
                data = [
                    {
                        'uuid': role.uuid,
                    } for role in ou_roles
                ]
            audience = [entity_id for service, entity_id in self.get_audience(ou)]
            logger.info(u'%sning roles %s to %s', mode, ou_roles, audience)
            self.notify_agents({
                '@type': mode,
                'audience': audience,
                'full': full,
                'objects': {
                    '@type': 'role',
                    'data': data,
                }
            })

        global_roles = set(roles_by_ou.get(None, []))
        for ou, ou_roles in roles_by_ou.items():
            send(ou, set(ou_roles) | global_roles)

    def provision(self, saved, deleted):
        """Provision ``saved``/``deleted`` (class -> objects) in a background thread."""
        # Returns if:
        # - we are not in a tenant
        # - provisionning is disabled
        # - there is nothing to do
        if (not hasattr(connection, 'tenant') or not connection.tenant
                or not hasattr(connection.tenant, 'domain_url')):
            return
        if not getattr(settings, 'HOBO_ROLE_EXPORT', True):
            return
        if not (saved or deleted):
            return
        thread = threading.Thread(
            target=self.do_provision,
            kwargs={'saved': saved, 'deleted': deleted})
        # Register the thread before starting it: do_provision() discards it
        # when done, and a discard() running before the add() would leave a
        # finished thread in the shared set forever.
        self.threads.add(thread)
        try:
            thread.start()
        except Exception:
            # never started: it would never unregister itself nor be joinable
            self.threads.discard(thread)
            raise

    def do_provision(self, saved, deleted):
        """Body of the background thread started by provision()."""
        try:
            ous = {ou.id: ou for ou in OU.objects.all()}
            self.notify_roles(ous, saved.get(Role, []))
            self.notify_roles(ous, deleted.get(Role, []), mode='deprovision')
            self.notify_users(ous, saved.get(User, []))
            self.notify_users(ous, deleted.get(User, []), mode='deprovision')
        except Exception:
            # nobody can catch errors raised in this thread: log them
            logger.exception(u'error in provisionning thread')
        finally:
            # last step: unregister this thread so wait() stops joining it
            self.threads.discard(threading.current_thread())

    def wait(self):
        """Join every registered provisionning thread (the set is shared by all threads)."""
        for thread in list(self.threads):
            thread.join()

    def __enter__(self):
        self.start()
        return self

    def __exit__(self, exc_type, exc_value, exc_tb):
        if not self.stack:
            return
        # only provision when the block ended without an exception
        self.stop(provision=exc_type is None)

    def get_audience(self, ou):
        """Return ``(service, entity_id)`` pairs of the providers of ``ou``
        (of providers without OU when ``ou`` is None)."""
        if ou:
            qs = LibertyProvider.objects.filter(ou=ou)
        else:
            qs = LibertyProvider.objects.filter(ou__isnull=True)
        return [(service, service.entity_id) for service in qs]

    def complete_roles(self, roles):
        """Set ``emails``, ``emails_to_members`` and ``details`` on each role,
        from its JSON role attributes or from defaults."""
        for role in roles:
            role.emails = []
            role.emails_to_members = True
            role.details = u''
            for attribute in role.attributes.all():
                if (attribute.name in ('emails', 'emails_to_members', 'details')
                        and attribute.kind == 'json'):
                    setattr(role, attribute.name, json.loads(attribute.value))

    def get_entity_id(self):
        """Return the URL of the SAML metadata of the current tenant."""
        tenant = getattr(connection, 'tenant', None)
        assert tenant
        base_url = tenant.get_base_url()
        return urljoin(base_url, reverse('a2-idp-saml-metadata'))

    def pre_save(self, sender, instance, raw, using, update_fields, **kwargs):
        """Record updates of existing users, roles, role attributes and user
        attribute values."""
        if not self.stack:
            return
        # we skip new instances
        if not instance.pk:
            return
        if not isinstance(instance, (User, Role, RoleAttribute, AttributeValue)):
            return
        # ignore last_login update on login
        if isinstance(instance, User) and (update_fields and set(update_fields) == set(['last_login'])):
            return
        if isinstance(instance, RoleAttribute):
            instance = instance.role
        elif isinstance(instance, AttributeValue):
            if not isinstance(instance.owner, User):
                return
            instance = instance.owner
        self.add_saved(instance)

    def post_save(self, sender, instance, created, raw, using, update_fields, **kwargs):
        """Record new objects, and the members of a child role on role parenting."""
        if not self.stack:
            return
        # during post_save we only handle new instances
        if isinstance(instance, RoleParenting):
            self.add_saved(*list(instance.child.all_members()))
            return
        if not created:
            return
        if not isinstance(instance, (User, Role, RoleAttribute, AttributeValue)):
            return
        if isinstance(instance, RoleAttribute):
            instance = instance.role
        elif isinstance(instance, AttributeValue):
            if not isinstance(instance.owner, User):
                return
            instance = instance.owner
        self.add_saved(instance)

    def pre_delete(self, sender, instance, using, **kwargs):
        """Record deleted users/roles, and objects affected by other deletions."""
        if not self.stack:
            return
        if isinstance(instance, (User, Role)):
            # keep a copy, presumably so later changes made to the deleted
            # instance do not affect the recorded one
            self.add_deleted(copy.copy(instance))
        elif isinstance(instance, RoleAttribute):
            instance = instance.role
            self.add_saved(instance)
        elif isinstance(instance, AttributeValue):
            if not isinstance(instance.owner, User):
                return
            instance = instance.owner
            self.add_saved(instance)
        elif isinstance(instance, RoleParenting):
            self.add_saved(*list(instance.child.all_members()))

    def m2m_changed(self, sender, instance, action, reverse, model, pk_set, using, **kwargs):
        """Record both sides of a ``Role.members`` change."""
        if not self.stack:
            return
        # handle post_* actions, plus pre_clear: after a clear the previous
        # members can no longer be looked up
        if action != 'pre_clear' and action.startswith('pre_'):
            return
        if sender is Role.members.through:
            self.add_saved(instance)
            # on a clear, pk_set is None
            for other_instance in model.objects.filter(pk__in=pk_set or []):
                self.add_saved(other_instance)
            if action == 'pre_clear':
                # when the action is pre_clear we need to lookup the current value of the members
                # relation, to re-provision all previously enroled users.
                if not reverse:
                    for other_instance in instance.members.all():
                        self.add_saved(other_instance)

    def notify_agents(self, data):
        """Send a notification, over HTTP when enabled, the rest through AMQP.

        With settings.HOBO_HTTP_PROVISIONNING, each audience entry matching a
        known service with a ``provisionning-url`` is sent its own signed PUT;
        entries whose PUT fails, and the other ones, are left in
        ``data['audience']`` for the module-level notify_agents().
        """
        if getattr(settings, 'HOBO_HTTP_PROVISIONNING', False):
            services_by_url = {}
            for services in settings.KNOWN_SERVICES.values():
                for service in services.values():
                    if service.get('provisionning-url'):
                        services_by_url[service['saml-sp-metadata-url']] = service
            # work on a copy: the caller's audience list must not be mutated
            amqp_audience = list(data.get('audience') or [])
            rest_audience = [x for x in amqp_audience if x in services_by_url]
            for entity_id in rest_audience:
                service = services_by_url[entity_id]
                data['audience'] = [entity_id]
                try:
                    response = requests.put(
                        sign_url(service['provisionning-url'] + '?orig=%s' % service['orig'], service['secret']),
                        json=data)
                    response.raise_for_status()
                except requests.RequestException as e:
                    logger.error(u'error provisionning to %s (%s)', entity_id, e)
                else:
                    amqp_audience.remove(entity_id)
            data['audience'] = amqp_audience
            if amqp_audience:
                logger.info(u'leftover AMQP audience: %s', amqp_audience)
        if data.get('audience'):
            notify_agents(data)