hobo/hobo/agent/authentic2/provisionning.py

Ignoring revisions in .git-blame-ignore-revs. Click here to bypass and see the normal blame view.

527 lines
19 KiB
Python
Raw Normal View History

import copy
import datetime
import json
import logging
import threading
import urllib.parse
from itertools import chain, islice
import requests
from authentic2.a2_rbac.models import OrganizationalUnit as OU
from authentic2.a2_rbac.models import Role, RoleParenting
from authentic2.models import AttributeValue
from authentic2.saml.models import LibertyProvider
from django.conf import settings
from django.contrib.auth import get_user_model
from django.db import connection, transaction
from django.urls import reverse
from django.utils.encoding import force_str
from hobo.agent.common import notify_agents
from hobo.signature import sign_url
# Module-level logger for provisionning messages and errors.
logger = logging.getLogger(__name__)
# Active user model, resolved once at import time; used below both for
# queries and for isinstance() checks on signal instances.
User = get_user_model()
def batch(iterable, size):
    """Split ``iterable`` into consecutive chunks of at most ``size`` elements.

    Each chunk is a lazy iterator that shares the underlying source iterator,
    so a chunk must be fully consumed before the next one is requested.
    """
    source = iter(iterable)
    while True:
        try:
            head = next(source)
        except StopIteration:
            return
        yield chain((head,), islice(source, size - 1))
class Provisionning(threading.local):
    """Collect User and Role changes and provision them to the tenant's services.

    ``start()`` / ``stop()`` (or the ``with`` statement) open and close
    possibly nested contexts kept on a per-thread stack.  While a context is
    open, the signal-handler methods (``pre_save``, ``post_save``,
    ``pre_delete``, ``m2m_changed``, ``post_soft_create``,
    ``post_soft_delete``) record saved and deleted instances in it.
    ``stop()`` schedules ``provision()`` with ``transaction.on_commit``, which
    runs ``do_provision()`` in a separate thread.  Messages go over HTTP to
    services declaring a ``provisionning-url`` in ``KNOWN_SERVICES``, and
    through ``hobo.agent.common.notify_agents`` for the remaining audience.

    Attributes set on ``self`` normally live in the per-thread dictionary of
    ``threading.local``.  ``threads`` is declared as a slot, and slots are
    shared by every thread.
    """

    __slots__ = ['threads']

    def __init__(self):
        # threading.local runs __init__ again the first time the object is
        # used from each new thread (provisionning threads included).  The
        # per-thread ``stack`` must be initialised there, but the shared
        # ``threads`` slot must not be reset, otherwise threads registered
        # from another thread would be forgotten and wait() would skip them.
        if not hasattr(self, 'threads'):
            self.threads = set()
        self.stack = []

    def start(self):
        """Open a new (possibly nested) provisionning context."""
        self.stack.append(
            {
                'saved': {},
                'deleted': {},
                # tells add_saved()/add_deleted() whether to record
                # immediately or to defer recording to transaction commit
                'in_atomic_block': connection.in_atomic_block,
            }
        )

    def clear(self):
        """Drop every open context of the current thread without provisionning."""
        self.stack = []

    def stop(self, provision=True, wait=True):
        """Close the innermost context.

        :param provision: when true, ``provision()`` of the recorded changes
            is scheduled with ``transaction.on_commit``; otherwise they are
            discarded.
        :param wait: when true, the scheduled callback also calls ``wait()``
            to join the provisionning threads.
        """
        if not self.stack:
            return
        context = self.stack.pop()
        context.pop('in_atomic_block')
        if provision:

            def callback():
                self.provision(**context)
                if wait:
                    self.wait()

            transaction.on_commit(callback)

    @property
    def saved(self):
        """Saved instances of the innermost context keyed by class, or None without context."""
        if self.stack:
            return self.stack[-1]['saved']
        return None

    @property
    def deleted(self):
        """Deleted instances of the innermost context keyed by class, or None without context."""
        if self.stack:
            return self.stack[-1]['deleted']
        return None

    def add_saved(self, *args):
        """Record ``args`` as saved in the innermost context.

        User instances are filed under ``User``, anything else under ``Role``.
        If the context was opened inside an atomic block, the instances are
        recorded immediately; otherwise recording is deferred with
        ``transaction.on_commit``.
        """
        if not self.stack:
            return
        in_atomic_block = self.stack[-1]['in_atomic_block']

        def callback():
            # NOTE(review): when deferred, this reads the innermost context
            # at commit time, which may differ from the one open now.
            for instance in args:
                klass = User if isinstance(instance, User) else Role
                self.saved.setdefault(klass, set()).add(instance)

        if in_atomic_block:
            callback()
        else:
            transaction.on_commit(callback)

    def add_deleted(self, *args):
        """Record ``args`` as deleted in the innermost context.

        A deleted instance is also removed from the saved instances.  The
        immediate versus deferred recording follows the same rule as
        ``add_saved()``.
        """
        if not self.stack:
            return
        in_atomic_block = self.stack[-1]['in_atomic_block']

        def callback():
            for instance in args:
                klass = User if isinstance(instance, User) else Role
                self.deleted.setdefault(klass, set()).add(instance)
                self.saved.get(klass, set()).discard(instance)

        if in_atomic_block:
            callback()
        else:
            transaction.on_commit(callback)

    def resolve_ou(self, instances, ous):
        """Set ``instance.ou`` from ``ous`` (OU id -> OU) for every instance whose ``ou_id`` is known."""
        for instance in instances:
            if instance.ou_id in ous:
                instance.ou = ous[instance.ou_id]

    @staticmethod
    def _is_forbidden_technical_role(role, allowed_prefixes):
        """Tell whether ``role`` is a technical role that must not be provisionned.

        Technical roles have a slug starting with ``_``; they are only
        provisionned when their slug starts with one of ``allowed_prefixes``.
        """
        return role.slug.startswith('_') and not role.slug.startswith(tuple(allowed_prefixes))

    def notify_users(self, ous, users, mode='provision', sync=False):
        """Send a (de)provisionning message for ``users``.

        Every user is sent to the services without OU and to the services of
        every OU.

        :param ous: mapping of OU id to OU instance, used to resolve the OU of
            deprovisionned users.
        :param users: iterable of User instances.
        :param mode: ``'provision'`` sends the serialized users; any other
            value sends a ``deprovision`` message carrying only uuids.
        :param sync: forwarded to ``notify_agents()``.
        """
        allowed_prefixes = getattr(settings, 'HOBO_PROVISION_ROLE_PREFIXES', []) or []

        if mode == 'provision':
            # reload the users with what the serialization needs
            users = (
                User.objects.filter(id__in=[u.id for u in users])
                .select_related('ou')
                .prefetch_related('attribute_values__attribute')
            )
        else:
            self.resolve_ou(users, ous)

        # target OU (None meaning services without OU) -> users to send
        ous = {}
        for ou in [None] + list(OU.objects.all()):
            for user in users:
                ous.setdefault(ou, set()).add(user)

        issuer = force_str(self.get_entity_id())
        if mode == 'provision':

            def user_to_json(ou, service, user, user_roles):
                from authentic2.api_views import BaseUserSerializer

                # only send the user's roles visible by the service's ou
                visible_roles = [
                    role
                    for role in user_roles.get(user.id, [])
                    if (
                        not self._is_forbidden_technical_role(role, allowed_prefixes)
                        and (role.ou_id is None or (ou and role.ou_id == ou.id))
                    )
                ]
                data = {
                    'uuid': user.uuid,
                    'username': user.username,
                    'first_name': user.first_name,
                    'last_name': user.last_name,
                    'email': user.email,
                    'is_active': user.is_active,
                    'roles': [
                        {
                            'uuid': role.uuid,
                            'name': role.name,
                            'slug': role.slug,
                        }
                        for role in visible_roles
                    ],
                }
                data.update(BaseUserSerializer(user).data)
                # check if user is superuser through a role of this service;
                # user_roles holds sets, so any matching superuser role must
                # win whatever the iteration order
                role_is_superuser = bool(service) and any(
                    role.is_superuser
                    for role in user_roles.get(user.id, [])
                    if role.service_id == service.pk
                )
                data['is_superuser'] = user.is_superuser or role_is_superuser
                return data

            # Find roles giving a superuser attribute
            # If there is any role of this kind, we do one provisionning message for each user and
            # each service.
            roles_with_attributes = (
                Role.objects.filter(members__in=users)
                .parents(include_self=True)
                .filter(is_superuser=True)
                .exists()
            )
            roles_by_id = {r.id: r for r in Role.objects.all()}

            # child role id -> parent role ids
            parents = {}
            for rp in RoleParenting.objects.filter(deleted__isnull=True):
                # broken parent/child relationship can happen
                try:
                    parents.setdefault(rp.child.id, []).append(rp.parent.id)
                except Role.DoesNotExist:
                    pass

            # user id -> roles of the user, including parents of direct roles
            user_roles = {}
            Through = Role.members.through
            # only the memberships of the provisionned users are needed
            qs = Through.objects.filter(user__in=users).values_list('user_id', 'role_id')
            for u_id, r_id in qs:
                # unknown r_id can happen
                if r_id not in roles_by_id:
                    continue
                u_roles = user_roles.setdefault(u_id, set())
                u_roles.add(roles_by_id[r_id])
                for p_id in parents.get(r_id, []):
                    if p_id in roles_by_id:
                        u_roles.add(roles_by_id[p_id])

            if roles_with_attributes:
                # is_superuser depends on the service: one message per service
                for ou, ou_users in ous.items():
                    for service, audience in self.get_audience(ou):
                        for batched_users in batch(ou_users, 500):
                            batched_users = list(batched_users)
                            for user in batched_users:
                                logger.info('provisionning user %s to %s', user, audience)
                            self.notify_agents(
                                {
                                    '@type': 'provision',
                                    'issuer': issuer,
                                    'audience': [audience],
                                    'full': False,
                                    'objects': {
                                        '@type': 'user',
                                        'data': [
                                            user_to_json(ou, service, user, user_roles)
                                            for user in batched_users
                                        ],
                                    },
                                },
                                sync=sync,
                            )
            else:
                # one message per OU, addressed to all services of the OU
                for ou, ou_users in ous.items():
                    audience = [entity_id for _service, entity_id in self.get_audience(ou)]
                    if not audience:
                        continue
                    logger.info(
                        'provisionning users %s to %s',
                        ', '.join(map(force_str, ou_users)),
                        ', '.join(audience),
                    )
                    self.notify_agents(
                        {
                            '@type': 'provision',
                            'issuer': issuer,
                            'audience': audience,
                            'full': False,
                            'objects': {
                                '@type': 'user',
                                'data': [user_to_json(ou, None, user, user_roles) for user in ou_users],
                            },
                        },
                        sync=sync,
                    )
        elif users:
            audience = [entity_id for ou in ous for _service, entity_id in self.get_audience(ou)]
            logger.info(
                'deprovisionning users %s from %s', ', '.join(map(force_str, users)), ', '.join(audience)
            )
            self.notify_agents(
                {
                    '@type': 'deprovision',
                    'issuer': issuer,
                    'audience': audience,
                    'full': False,
                    'objects': {
                        '@type': 'user',
                        'data': [
                            {
                                'uuid': user.uuid,
                            }
                            for user in users
                        ],
                    },
                },
                sync=sync,
            )

    def notify_roles(self, ous, roles, mode='provision', full=False, sync=False):
        """Send a (de)provisionning message for ``roles``.

        Forbidden technical roles are dropped.  Roles are grouped by OU; the
        services of each OU receive that OU's roles plus the global roles
        (roles without OU).

        :param ous: mapping of OU id to OU instance used to resolve ``role.ou``.
        :param roles: iterable of Role instances.
        :param mode: used as the message ``@type``; ``'provision'`` sends the
            role details, any other value only sends uuids.
        :param full: copied into the message ``full`` field.
        :param sync: forwarded to ``notify_agents()``.
        """
        allowed_prefixes = getattr(settings, 'HOBO_PROVISION_ROLE_PREFIXES', []) or []
        roles = {role for role in roles if not self._is_forbidden_technical_role(role, allowed_prefixes)}
        if not roles:
            return
        self.resolve_ou(roles, ous)
        roles_by_ou = {}
        for role in roles:
            roles_by_ou.setdefault(role.ou, []).append(role)

        def helper(ou, ou_roles):
            if mode == 'provision':
                data = [
                    {
                        'uuid': role.uuid,
                        'name': role.name,
                        'slug': role.slug,
                        'description': role.description,
                        'details': role.details,
                        'emails': role.emails,
                        'emails_to_members': role.emails_to_members,
                    }
                    for role in ou_roles
                ]
            else:
                data = [
                    {
                        'uuid': role.uuid,
                    }
                    for role in ou_roles
                ]
            audience = [entity_id for service, entity_id in self.get_audience(ou)]
            logger.info('%sning roles %s to %s', mode, ou_roles, audience)
            self.notify_agents(
                {
                    '@type': mode,
                    'audience': audience,
                    'full': full,
                    'objects': {
                        '@type': 'role',
                        'data': data,
                    },
                },
                sync=sync,
            )

        global_roles = set(roles_by_ou.get(None, []))
        for ou, ou_roles in roles_by_ou.items():
            helper(ou, set(ou_roles) | global_roles)

    def provision(self, saved, deleted):
        """Start a thread provisionning ``saved`` and ``deleted`` (class -> instances).

        Nothing is done outside a tenant (no ``connection.tenant`` with a
        ``domain_url``), when ``HOBO_ROLE_EXPORT`` is false, or when there is
        nothing to send.
        """
        if (
            not hasattr(connection, 'tenant')
            or not connection.tenant
            or not hasattr(connection.tenant, 'domain_url')
        ):
            return
        if not getattr(settings, 'HOBO_ROLE_EXPORT', True):
            return
        if not (saved or deleted):
            return
        t = threading.Thread(target=self.do_provision, kwargs={'saved': saved, 'deleted': deleted})
        # register the thread before starting it: do_provision() discards
        # itself when done and could otherwise finish before being added,
        # leaving a stale entry in ``threads``
        self.threads.add(t)
        t.start()

    def do_provision(self, saved, deleted):
        """Thread body: notify saved/deleted roles, then saved/deleted users.

        Exceptions are logged, not propagated.
        """
        try:
            ous = {ou.id: ou for ou in OU.objects.all()}
            self.notify_roles(ous, saved.get(Role, []))
            self.notify_roles(ous, deleted.get(Role, []), mode='deprovision')
            self.notify_users(ous, saved.get(User, []))
            self.notify_users(ous, deleted.get(User, []), mode='deprovision')
        except Exception:
            logger.exception('error in provisionning thread')
        finally:
            # last step: unregister this thread so wait() no longer joins it
            self.threads.discard(threading.current_thread())

    def wait(self):
        """Join every currently registered provisionning thread."""
        for thread in list(self.threads):
            thread.join()

    def __enter__(self):
        # NOTE: returns None, so ``with provisionning as x`` binds None
        self.start()

    def __exit__(self, exc_type, exc_value, exc_tb):
        # provision only when the block exited without exception
        if not self.stack:
            return
        self.stop(provision=exc_type is None)

    def get_audience(self, ou):
        """Return ``(LibertyProvider, entity_id)`` pairs for the services of ``ou``.

        When ``ou`` is falsy, the services without OU are returned.
        """
        if ou:
            qs = LibertyProvider.objects.filter(ou=ou)
        else:
            qs = LibertyProvider.objects.filter(ou__isnull=True)
        return [(service, service.entity_id) for service in qs]

    def get_entity_id(self):
        """Return the SAML metadata URL of the current tenant, used as message issuer."""
        tenant = getattr(connection, 'tenant', None)
        assert tenant
        base_url = tenant.get_base_url()
        return urllib.parse.urljoin(base_url, reverse('a2-idp-saml-metadata'))

    def pre_save(self, sender, instance, raw, using, update_fields, **kwargs):
        """Record existing Users/Roles being saved, and owners of saved user attributes.

        New instances (no pk yet) are left to ``post_save``; saves touching
        only ``last_login`` of a User are ignored.
        """
        if not self.stack:
            return
        # we skip new instances
        if not instance.pk:
            return
        if not isinstance(instance, (User, Role, AttributeValue)):
            return
        # ignore last_login update on login
        if isinstance(instance, User) and (update_fields and set(update_fields) == {'last_login'}):
            return
        elif isinstance(instance, AttributeValue):
            if not isinstance(instance.owner, User):
                return
            instance = instance.owner
        self.add_saved(instance)

    def post_save(self, sender, instance, created, raw, using, update_fields, **kwargs):
        """Record newly created Users/Roles/user attributes, and members affected by a RoleParenting."""
        if not self.stack:
            return
        if isinstance(instance, RoleParenting):
            self.add_saved(*list(instance.child.all_members()))
            return
        # during post_save we only handle new instances
        if not created:
            return
        if not isinstance(instance, (User, Role, AttributeValue)):
            return
        elif isinstance(instance, AttributeValue):
            if not isinstance(instance.owner, User):
                return
            instance = instance.owner
        self.add_saved(instance)

    def pre_delete(self, sender, instance, using, **kwargs):
        """Record deleted Users/Roles, and users affected by deleted attributes or parentings."""
        if not self.stack:
            return
        if isinstance(instance, (User, Role)):
            # a copy is kept, presumably so its fields survive Django clearing
            # the pk of the deleted instance
            self.add_deleted(copy.copy(instance))
        elif isinstance(instance, AttributeValue):
            if not isinstance(instance.owner, User):
                return
            instance = instance.owner
            self.add_saved(instance)
        elif isinstance(instance, RoleParenting):
            self.add_saved(*list(instance.child.all_members()))

    def m2m_changed(self, sender, instance, action, reverse, model, pk_set, using, **kwargs):
        """Record both sides of a ``Role.members`` change.

        ``pre_*`` actions are ignored except ``pre_clear``, where the current
        members must be looked up before they are removed.
        """
        if not self.stack:
            return
        if action != 'pre_clear' and action.startswith('pre_'):
            return
        if sender is Role.members.through:
            self.add_saved(instance)
            # on a clear, pk_set is None
            for other_instance in model.objects.filter(pk__in=pk_set or []):
                self.add_saved(other_instance)
            if action == 'pre_clear':
                # when the action is pre_clear we need to lookup the current value of the members
                # relation, to re-provision all previously enroled users.
                if not reverse:
                    for other_instance in instance.members.all():
                        self.add_saved(other_instance)

    def post_soft_create(self, sender, instance, **kwargs):
        """Re-provision the members of the child role of a created RoleParenting."""
        if isinstance(instance, RoleParenting):
            self.add_saved(*list(instance.child.all_members()))

    def post_soft_delete(self, sender, instance, **kwargs):
        """Re-provision the members of the child role of a deleted RoleParenting."""
        if isinstance(instance, RoleParenting):
            self.add_saved(*list(instance.child.all_members()))

    def notify_agents(self, data, sync=False):
        """Deliver the provisionning message ``data``.

        When ``HOBO_PROVISIONNING_DEBUG`` is set, the message is appended to
        ``DEBUG_PROVISIONNING_LOG_PATH`` (write errors are ignored).  When
        ``HOBO_HTTP_PROVISIONNING`` is true (default), HTTP delivery is tried
        first; the remaining audience is passed to
        ``hobo.agent.common.notify_agents``.
        """
        log_path = getattr(settings, 'DEBUG_PROVISIONNING_LOG_PATH', '')
        if log_path and getattr(settings, 'HOBO_PROVISIONNING_DEBUG', False):
            try:
                with open(log_path, 'a') as f:
                    f.write('%s %s ' % (datetime.datetime.now().isoformat(), connection.tenant.domain_url))
                    json.dump(data, f, indent=2)
                    f.write('\n')
            except OSError:
                # debug logging is best-effort
                pass
        if getattr(settings, 'HOBO_HTTP_PROVISIONNING', True):
            leftover_audience = self.notify_agents_http(data, sync=sync)
            if not leftover_audience:
                return
            logger.info('leftover AMQP audience: %s', leftover_audience)
            data['audience'] = leftover_audience
        if data['audience']:
            notify_agents(data)

    def get_http_services_by_url(self):
        """Map ``saml-sp-metadata-url`` to service for KNOWN_SERVICES having a ``provisionning-url``."""
        services_by_url = {}
        known_services = getattr(settings, 'KNOWN_SERVICES', {})
        for services in known_services.values():
            for service in services.values():
                if service.get('provisionning-url'):
                    services_by_url[service['saml-sp-metadata-url']] = service
        return services_by_url

    def notify_agents_http(self, data, sync=False):
        """PUT ``data`` to every audience member reachable over HTTP.

        Each request carries a single-entry audience and is signed with the
        service secret.  Errors are logged.

        :returns: a new list with the audience entries that were not
            delivered over HTTP; the caller's audience list is not modified.
        """
        services_by_url = self.get_http_services_by_url()
        audience = data.get('audience') or []
        # work on a copy: the list in ``data`` belongs to the caller
        leftover_audience = list(audience)
        for entity_id in [x for x in audience if x in services_by_url]:
            service = services_by_url[entity_id]
            data['audience'] = [entity_id]
            url = service['provisionning-url'] + '?orig=%s' % service['orig']
            if sync:
                url += '&sync=1'
            try:
                response = requests.put(sign_url(url, service['secret']), json=data)
                response.raise_for_status()
            except requests.RequestException as e:
                logger.error('error provisionning to %s (%s)', entity_id, e)
            else:
                leftover_audience.remove(entity_id)
        return leftover_audience