hobo/hobo/agent/authentic2/provisionning.py

527 lines
19 KiB
Python

import copy
import datetime
import json
import logging
import threading
import urllib.parse
from itertools import chain, islice
import requests
from authentic2.a2_rbac.models import OrganizationalUnit as OU
from authentic2.a2_rbac.models import Role, RoleParenting
from authentic2.models import AttributeValue
from authentic2.saml.models import LibertyProvider
from django.conf import settings
from django.contrib.auth import get_user_model
from django.db import connection, transaction
from django.urls import reverse
from django.utils.encoding import force_str
from hobo.agent.common import notify_agents
from hobo.signature import sign_url
# module-wide logger for provisioning events and errors
logger = logging.getLogger(__name__)
# user model class, resolved once when the module is imported
User = get_user_model()
def batch(iterable, size):
    """Split ``iterable`` into consecutive chunks of at most ``size`` items.

    Every chunk is a lazy iterator reading from the same underlying iterator:
    a chunk must be fully consumed before the next one is requested, otherwise
    its unconsumed items end up in the following chunks.
    """
    items = iter(iterable)
    while True:
        try:
            head = next(items)
        except StopIteration:
            return
        # the head item is already taken, so only size - 1 more belong here
        yield chain((head,), islice(items, size - 1))
class Provisionning(threading.local):
    """Collect user and role changes and push them to the provisioned services.

    Changes are recorded by the ``pre_save``/``post_save``/``pre_delete``/
    ``m2m_changed``/``post_soft_*`` methods (presumably connected to the
    signals of the same names elsewhere) into a stack of contexts opened with
    :meth:`start` (or ``with provisionning:``). When :meth:`stop` closes a
    context, its collected instances are provisioned from a background thread
    once the current transaction commits.

    ``threads`` is declared in ``__slots__``: it is stored on the object itself
    and shared by every thread, whereas the other attributes (like ``stack``)
    live in the per-thread ``threading.local`` storage.
    """

    __slots__ = ['threads']

    def __init__(self):
        # threading.local calls __init__ again in every new thread that uses
        # this object (including the provisioning threads themselves). The
        # shared set of provisioning threads must not be replaced then, or the
        # threads registered so far would be forgotten by wait().
        if not hasattr(self, 'threads'):
            self.threads = set()
        self.stack = []

    def start(self):
        """Open a new collection context on this thread's stack."""
        self.stack.append(
            {
                'saved': {},
                'deleted': {},
                # tells add_saved/add_deleted whether to record changes
                # immediately or at commit time
                'in_atomic_block': connection.in_atomic_block,
            }
        )

    def clear(self):
        """Drop every open context of this thread without provisioning anything."""
        self.stack = []

    def stop(self, provision=True, wait=True):
        """Close the innermost context.

        :param provision: when true, provision the collected changes once the
            current transaction commits.
        :param wait: when true, also block (in that commit callback) until the
            provisioning threads have finished.
        """
        if not self.stack:
            return
        context = self.stack.pop()
        context.pop('in_atomic_block')
        if provision:

            def callback():
                self.provision(**context)
                if wait:
                    self.wait()

            transaction.on_commit(callback)

    @property
    def saved(self):
        """``{User|Role: set(instances)}`` saved in the innermost context, or None."""
        if self.stack:
            return self.stack[-1]['saved']
        return None

    @property
    def deleted(self):
        """``{User|Role: set(instances)}`` deleted in the innermost context, or None."""
        if self.stack:
            return self.stack[-1]['deleted']
        return None

    def add_saved(self, *args):
        """Record ``args`` (users or roles) as saved in the innermost context.

        If the context was opened inside an atomic block, the instances are
        recorded immediately; otherwise recording is deferred with
        ``transaction.on_commit`` (Django drops such callbacks on rollback).
        """
        if not self.stack:
            return
        # bind the context now: by the time a deferred callback runs, stop()
        # may already have popped it from the stack
        context = self.stack[-1]

        def callback():
            for instance in args:
                klass = User if isinstance(instance, User) else Role
                context['saved'].setdefault(klass, set()).add(instance)

        if context['in_atomic_block']:
            callback()
        else:
            transaction.on_commit(callback)

    def add_deleted(self, *args):
        """Record ``args`` (users or roles) as deleted in the innermost context.

        A deleted instance is also removed from the saved ones. Timing follows
        the same rule as :meth:`add_saved`.
        """
        if not self.stack:
            return
        # bind the context now, see add_saved()
        context = self.stack[-1]

        def callback():
            for instance in args:
                klass = User if isinstance(instance, User) else Role
                context['deleted'].setdefault(klass, set()).add(instance)
                context['saved'].get(klass, set()).discard(instance)

        if context['in_atomic_block']:
            callback()
        else:
            transaction.on_commit(callback)

    def resolve_ou(self, instances, ous):
        """Set ``instance.ou`` from the ``ous`` mapping (OU id -> OU) when possible.

        This lets later ``instance.ou`` accesses reuse the already loaded OUs.
        """
        for instance in instances:
            if instance.ou_id in ous:
                instance.ou = ous[instance.ou_id]

    @staticmethod
    def _get_allowed_role_prefixes():
        """Return ``settings.HOBO_PROVISION_ROLE_PREFIXES`` as a tuple (empty if unset)."""
        return tuple(getattr(settings, 'HOBO_PROVISION_ROLE_PREFIXES', []) or [])

    @staticmethod
    def _is_forbidden_technical_role(role, allowed_prefixes):
        """Tell if ``role`` is a technical role (slug starting with ``_``) whose
        slug matches none of ``allowed_prefixes``, i.e. must not be provisioned."""
        return role.slug.startswith('_') and not role.slug.startswith(allowed_prefixes)

    def _get_user_roles(self, users):
        """Map each user id of ``users`` to the set of its roles, parents included.

        Direct memberships come from the ``Role.members`` through table; role
        parents come from the non-deleted ``RoleParenting`` rows.
        """
        roles = {role.id: role for role in Role.objects.all()}
        parents = {}
        for rp in RoleParenting.objects.filter(deleted__isnull=True):
            # broken parent/child relationship can happen
            try:
                parents.setdefault(rp.child.id, []).append(rp.parent.id)
            except Role.DoesNotExist:
                pass
        user_roles = {}
        Through = Role.members.through
        # only the memberships of the users being provisioned are needed
        qs = Through.objects.filter(user__in=users).values_list('user_id', 'role_id')
        for u_id, r_id in qs:
            # unknown r_id can happen
            if r_id not in roles:
                continue
            role_set = user_roles.setdefault(u_id, set())
            role_set.add(roles[r_id])
            for p_id in parents.get(r_id, []):
                # a parent role created after roles was loaded is ignored
                if p_id in roles:
                    role_set.add(roles[p_id])
        return user_roles

    def notify_users(self, ous, users, mode='provision', sync=False):
        """Send a provisioning or deprovisioning message about ``users``.

        :param ous: mapping of OU id to ``OU`` instance.
        :param users: iterable of ``User`` instances.
        :param mode: ``'provision'`` sends the full user data (users are
            reloaded from the database); any other value (``'deprovision'``)
            sends a ``deprovision`` message carrying only the users' uuids.
        :param sync: forwarded to :meth:`notify_agents`.
        """
        if not users:
            # nothing to send: skip the role and membership queries below
            return
        allowed_prefixes = self._get_allowed_role_prefixes()
        if mode == 'provision':
            users = (
                User.objects.filter(id__in=[u.id for u in users])
                .select_related('ou')
                .prefetch_related('attribute_values__attribute')
            )
        else:
            self.resolve_ou(users, ous)

        # every user is listed under every OU, and under None (no OU)
        ous = {}
        for ou in [None] + list(OU.objects.all()):
            for user in users:
                ous.setdefault(ou, set()).add(user)

        issuer = force_str(self.get_entity_id())
        if mode == 'provision':

            def user_to_json(ou, service, user, user_roles):
                from authentic2.api_views import BaseUserSerializer

                data = {}
                # filter user's roles visible by the service's ou
                roles = [
                    role
                    for role in user_roles.get(user.id, [])
                    if (
                        not self._is_forbidden_technical_role(role, allowed_prefixes)
                        and (role.ou_id is None or (ou and role.ou_id == ou.id))
                    )
                ]
                data.update(
                    {
                        'uuid': user.uuid,
                        'username': user.username,
                        'first_name': user.first_name,
                        'last_name': user.last_name,
                        'email': user.email,
                        'is_active': user.is_active,
                        'roles': [
                            {
                                'uuid': role.uuid,
                                'name': role.name,
                                'slug': role.slug,
                            }
                            for role in roles
                        ],
                    }
                )
                data.update(BaseUserSerializer(user).data)
                # the user is superuser for the service if ANY of its roles
                # attached to this service grants it (order independent)
                role_is_superuser = False
                if service:
                    role_is_superuser = any(
                        role.is_superuser
                        for role in user_roles.get(user.id, [])
                        if role.service_id == service.pk
                    )
                data['is_superuser'] = user.is_superuser or role_is_superuser
                return data

            # Find roles giving a superuser attribute
            # If there is any role of this kind, we do one provisionning message for each user and
            # each service.
            roles_with_attributes = (
                Role.objects.filter(members__in=users)
                .parents(include_self=True)
                .filter(is_superuser=True)
                .exists()
            )
            user_roles = self._get_user_roles(users)
            if roles_with_attributes:
                for ou, ou_users in ous.items():
                    for service, audience in self.get_audience(ou):
                        for batched_users in batch(ou_users, 500):
                            batched_users = list(batched_users)
                            for user in batched_users:
                                logger.info('provisionning user %s to %s', user, audience)
                            self.notify_agents(
                                {
                                    '@type': 'provision',
                                    'issuer': issuer,
                                    'audience': [audience],
                                    'full': False,
                                    'objects': {
                                        '@type': 'user',
                                        'data': [
                                            user_to_json(ou, service, user, user_roles)
                                            for user in batched_users
                                        ],
                                    },
                                },
                                sync=sync,
                            )
            else:
                for ou, ou_users in ous.items():
                    audience = [a for service, a in self.get_audience(ou)]
                    if not audience:
                        continue
                    logger.info(
                        'provisionning users %s to %s',
                        ', '.join(map(force_str, ou_users)),
                        ', '.join(audience),
                    )
                    self.notify_agents(
                        {
                            '@type': 'provision',
                            'issuer': issuer,
                            'audience': audience,
                            'full': False,
                            'objects': {
                                '@type': 'user',
                                'data': [user_to_json(ou, None, user, user_roles) for user in ou_users],
                            },
                        },
                        sync=sync,
                    )
        else:
            audience = [entity_id for ou in ous.keys() for s, entity_id in self.get_audience(ou)]
            logger.info(
                'deprovisionning users %s from %s', ', '.join(map(force_str, users)), ', '.join(audience)
            )
            self.notify_agents(
                {
                    '@type': 'deprovision',
                    'issuer': issuer,
                    'audience': audience,
                    'full': False,
                    'objects': {
                        '@type': 'user',
                        'data': [
                            {
                                'uuid': user.uuid,
                            }
                            for user in users
                        ],
                    },
                },
                sync=sync,
            )

    def notify_roles(self, ous, roles, mode='provision', full=False, sync=False):
        """Send a ``provision`` or ``deprovision`` message (``mode``) about ``roles``.

        Technical roles (slug starting with ``_``) are skipped unless their slug
        starts with one of ``settings.HOBO_PROVISION_ROLE_PREFIXES``. Roles are
        grouped by OU; the services of each OU receive that OU's roles plus the
        roles without OU.
        """
        allowed_prefixes = self._get_allowed_role_prefixes()
        roles = {role for role in roles if not self._is_forbidden_technical_role(role, allowed_prefixes)}
        if not roles:
            return
        self.resolve_ou(roles, ous)
        ous = {}
        for role in roles:
            ous.setdefault(role.ou, []).append(role)

        def helper(ou, roles):
            if mode == 'provision':
                data = [
                    {
                        'uuid': role.uuid,
                        'name': role.name,
                        'slug': role.slug,
                        'description': role.description,
                        'details': role.details,
                        'emails': role.emails,
                        'emails_to_members': role.emails_to_members,
                    }
                    for role in roles
                ]
            else:
                data = [
                    {
                        'uuid': role.uuid,
                    }
                    for role in roles
                ]
            audience = [entity_id for service, entity_id in self.get_audience(ou)]
            logger.info('%sning roles %s to %s', mode, roles, audience)
            self.notify_agents(
                {
                    '@type': mode,
                    'audience': audience,
                    'full': full,
                    'objects': {
                        '@type': 'role',
                        'data': data,
                    },
                },
                sync=sync,
            )

        global_roles = set(ous.get(None, []))
        for ou, ou_roles in ous.items():
            sent_roles = set(ou_roles) | global_roles
            helper(ou, sent_roles)

    def provision(self, saved, deleted):
        """Start a background thread provisioning ``saved`` and ``deleted``.

        Nothing is done when:
        - we are not in a tenant,
        - provisionning is disabled (``settings.HOBO_ROLE_EXPORT``),
        - there is nothing to do.
        """
        if (
            not hasattr(connection, 'tenant')
            or not connection.tenant
            or not hasattr(connection.tenant, 'domain_url')
        ):
            return
        if not getattr(settings, 'HOBO_ROLE_EXPORT', True):
            return
        if not (saved or deleted):
            return
        t = threading.Thread(target=self.do_provision, kwargs={'saved': saved, 'deleted': deleted})
        t.start()
        self.threads.add(t)

    def do_provision(self, saved, deleted):
        """Body of the thread started by :meth:`provision`.

        Roles are sent before users. Errors are logged, never raised, and the
        thread always unregisters itself from ``self.threads``.
        """
        try:
            ous = {ou.id: ou for ou in OU.objects.all()}
            self.notify_roles(ous, saved.get(Role, []))
            self.notify_roles(ous, deleted.get(Role, []), mode='deprovision')
            self.notify_users(ous, saved.get(User, []))
            self.notify_users(ous, deleted.get(User, []), mode='deprovision')
        except Exception:
            logger.exception('error in provisionning thread')
        finally:
            # last step: forget this thread
            self.threads.discard(threading.current_thread())

    def wait(self):
        """Block until every registered provisioning thread has finished."""
        for thread in list(self.threads):
            thread.join()
            # a thread finishing before provision() registered it could not
            # unregister itself: forget it now that it is done
            self.threads.discard(thread)

    def __enter__(self):
        self.start()
        return self

    def __exit__(self, exc_type, exc_value, exc_tb):
        if not self.stack:
            return
        # only provision when the block exited without an exception
        self.stop(provision=exc_type is None)

    def get_audience(self, ou):
        """Return ``(service, entity_id)`` pairs of the SAML providers of ``ou``
        (or of the providers without OU when ``ou`` is None)."""
        if ou:
            qs = LibertyProvider.objects.filter(ou=ou)
        else:
            qs = LibertyProvider.objects.filter(ou__isnull=True)
        return [(service, service.entity_id) for service in qs]

    def get_entity_id(self):
        """Return the URL of this tenant's SAML IdP metadata, used as issuer."""
        tenant = getattr(connection, 'tenant', None)
        assert tenant
        base_url = tenant.get_base_url()
        return urllib.parse.urljoin(base_url, reverse('a2-idp-saml-metadata'))

    def pre_save(self, sender, instance, raw, using, update_fields, **kwargs):
        """``pre_save`` receiver: record updated users, roles and user attributes.

        New instances (no primary key yet) are left to :meth:`post_save`; a
        save of a user touching only ``last_login`` is ignored; an
        ``AttributeValue`` records the user owning it.
        """
        if not self.stack:
            return
        # we skip new instances
        if not instance.pk:
            return
        if not isinstance(instance, (User, Role, AttributeValue)):
            return
        # ignore last_login update on login
        if isinstance(instance, User) and (update_fields and set(update_fields) == {'last_login'}):
            return
        elif isinstance(instance, AttributeValue):
            if not isinstance(instance.owner, User):
                return
            instance = instance.owner
        self.add_saved(instance)

    def post_save(self, sender, instance, created, raw, using, update_fields, **kwargs):
        """``post_save`` receiver: record newly created users, roles and user
        attributes, and the members of the child role of a saved parenting."""
        if not self.stack:
            return
        if isinstance(instance, RoleParenting):
            self.add_saved(*list(instance.child.all_members()))
            return
        # during post_save we only handle new instances
        if not created:
            return
        if not isinstance(instance, (User, Role, AttributeValue)):
            return
        elif isinstance(instance, AttributeValue):
            if not isinstance(instance.owner, User):
                return
            instance = instance.owner
        self.add_saved(instance)

    def pre_delete(self, sender, instance, using, **kwargs):
        """``pre_delete`` receiver: record deleted users/roles, and re-provision
        users whose attributes or role parentings are deleted."""
        if not self.stack:
            return
        if isinstance(instance, (User, Role)):
            # NOTE(review): a copy is kept, presumably so that changes made to
            # the instance by the deletion do not affect the recorded one
            self.add_deleted(copy.copy(instance))
        elif isinstance(instance, AttributeValue):
            if not isinstance(instance.owner, User):
                return
            instance = instance.owner
            self.add_saved(instance)
        elif isinstance(instance, RoleParenting):
            self.add_saved(*list(instance.child.all_members()))

    def m2m_changed(self, sender, instance, action, reverse, model, pk_set, using, **kwargs):
        """``m2m_changed`` receiver: record both sides of role membership changes."""
        if not self.stack:
            return
        # act after the change, except for clear where the previous members
        # must be read before they are removed
        if action != 'pre_clear' and action.startswith('pre_'):
            return
        if sender is Role.members.through:
            self.add_saved(instance)
            # on a clear, pk_set is None
            for other_instance in model.objects.filter(pk__in=pk_set or []):
                self.add_saved(other_instance)
            if action == 'pre_clear':
                # when the action is pre_clear we need to lookup the current value of the members
                # relation, to re-provision all previously enroled users.
                if not reverse:
                    for other_instance in instance.members.all():
                        self.add_saved(other_instance)

    def post_soft_create(self, sender, instance, **kwargs):
        """Re-provision the members of the child role of a created parenting."""
        if not self.stack:
            # add_saved() would ignore them anyway: skip the members query
            return
        if isinstance(instance, RoleParenting):
            self.add_saved(*list(instance.child.all_members()))

    def post_soft_delete(self, sender, instance, **kwargs):
        """Re-provision the members of the child role of a deleted parenting."""
        if not self.stack:
            # add_saved() would ignore them anyway: skip the members query
            return
        if isinstance(instance, RoleParenting):
            self.add_saved(*list(instance.child.all_members()))

    def notify_agents(self, data, sync=False):
        """Deliver a provisioning message.

        The message is optionally appended to a debug log file, then pushed over
        HTTP to the services supporting it (unless
        ``settings.HOBO_HTTP_PROVISIONNING`` is false); the remaining audience,
        if any, is notified through ``hobo.agent.common.notify_agents``.
        """
        log_path = getattr(settings, 'DEBUG_PROVISIONNING_LOG_PATH', '')
        if log_path and getattr(settings, 'HOBO_PROVISIONNING_DEBUG', False):
            try:
                with open(log_path, 'a') as f:
                    f.write('%s %s ' % (datetime.datetime.now().isoformat(), connection.tenant.domain_url))
                    json.dump(data, f, indent=2)
                    f.write('\n')
            except OSError:
                # debug logging is best effort
                pass
        if getattr(settings, 'HOBO_HTTP_PROVISIONNING', True):
            leftover_audience = self.notify_agents_http(data, sync=sync)
            if not leftover_audience:
                return
            logger.info('leftover AMQP audience: %s', leftover_audience)
            data['audience'] = leftover_audience
        if data['audience']:
            notify_agents(data)

    def get_http_services_by_url(self):
        """Map SAML SP metadata URL to service settings, for the services of
        ``settings.KNOWN_SERVICES`` having a ``provisionning-url``."""
        services_by_url = {}
        known_services = getattr(settings, 'KNOWN_SERVICES', {})
        for services in known_services.values():
            for service in services.values():
                if service.get('provisionning-url'):
                    services_by_url[service['saml-sp-metadata-url']] = service
        return services_by_url

    def notify_agents_http(self, data, sync=False):
        """Push ``data`` with one signed HTTP PUT per service supporting it.

        ``data['audience']`` is set to the single targeted service before each
        request. Returns a new list with the audience members not reached over
        HTTP (unknown services or failed requests).
        """
        services_by_url = self.get_http_services_by_url()
        audience = data.get('audience')
        http_audience = [entity_id for entity_id in audience if entity_id in services_by_url]
        # work on a copy: the audience list in ``data`` belongs to the caller
        leftover_audience = list(audience)
        for entity_id in http_audience:
            service = services_by_url[entity_id]
            data['audience'] = [entity_id]
            url = service['provisionning-url'] + '?orig=%s' % service['orig']
            if sync:
                url += '&sync=1'
            try:
                response = requests.put(sign_url(url, service['secret']), json=data)
                response.raise_for_status()
            except requests.RequestException as e:
                logger.error('error provisionning to %s (%s)', entity_id, e)
            else:
                leftover_audience.remove(entity_id)
        return leftover_audience