import copy
import datetime
import json
import logging
import threading
import urllib.parse
from itertools import chain, islice

import requests
from authentic2.a2_rbac.models import OrganizationalUnit as OU
from authentic2.a2_rbac.models import Role, RoleParenting
from authentic2.models import AttributeValue
from authentic2.saml.models import LibertyProvider
from django.conf import settings
from django.contrib.auth import get_user_model
from django.db import connection, transaction
from django.urls import reverse
from django.utils.encoding import force_str

from hobo.agent.common import notify_agents
from hobo.signature import sign_url

User = get_user_model()

logger = logging.getLogger(__name__)


def batch(iterable, size):
    """Batch an iterable as an iterable of iterables of at most `size` elements.

    All batches share the same underlying iterator: each batch must be fully
    consumed before the next one is requested.
    """
    sourceiter = iter(iterable)
    for first in sourceiter:
        yield chain([first], islice(sourceiter, size - 1))


def _get_forbidden_technical_role_checker():
    """Return a predicate telling whether a role must not be provisionned.

    Roles whose slug starts with '_' are technical roles and are not provisionned,
    unless their slug starts with one of the prefixes listed in the
    HOBO_PROVISION_ROLE_PREFIXES setting. The setting is read once, when the
    predicate is built.
    """
    allowed_prefixes = tuple(getattr(settings, 'HOBO_PROVISION_ROLE_PREFIXES', []) or [])

    def is_forbidden_technical_role(role):
        return role.slug.startswith('_') and not role.slug.startswith(allowed_prefixes)

    return is_forbidden_technical_role


class Provisionning(threading.local):
    """Collect created/modified/deleted users and roles and provision them to services.

    Changes are recorded in a stack of contexts (see start()/stop(), or use the
    instance as a context manager). When a context is stopped, its changes are sent
    from a background thread, once the current transaction is committed.

    As a threading.local subclass, every thread gets its own ``stack``. ``threads``
    is declared in ``__slots__``: it is stored on the object itself and is therefore
    shared by all threads.
    """

    __slots__ = ['threads']

    def __init__(self):
        # threading.local runs __init__ again the first time each new thread uses the
        # object (provisionning threads included, see do_provision()). Only create the
        # shared `threads` set once: resetting it would forget the threads registered
        # so far, and wait() would not join them.
        if getattr(self, 'threads', None) is None:
            self.threads = set()
        self.stack = []

    def start(self):
        """Push a new context collecting saved and deleted instances."""
        self.stack.append(
            {
                'saved': {},
                'deleted': {},
                # remember whether the context was opened inside an atomic block, see add_saved()
                'in_atomic_block': connection.in_atomic_block,
            }
        )

    def clear(self):
        """Drop every pending context without provisionning anything."""
        self.stack = []

    def stop(self, provision=True, wait=True):
        """Pop the innermost context and, if `provision` is true, provision its changes.

        Provisionning is registered with transaction.on_commit(); when `wait` is true
        the callback also waits for the provisionning threads to finish.
        """
        if not self.stack:
            return

        context = self.stack.pop()
        context.pop('in_atomic_block')

        if provision:

            def callback():
                self.provision(**context)
                if wait:
                    self.wait()

            transaction.on_commit(callback)

    @property
    def saved(self):
        """Saved instances of the innermost context, keyed by class (User or Role), or None."""
        if self.stack:
            return self.stack[-1]['saved']
        return None

    @property
    def deleted(self):
        """Deleted instances of the innermost context, keyed by class (User or Role), or None."""
        if self.stack:
            return self.stack[-1]['deleted']
        return None

    def add_saved(self, *args):
        """Record instances to provision in the innermost context.

        Instances that are not users are recorded as roles. If the context was
        opened inside an atomic block they are recorded immediately; otherwise the
        recording goes through transaction.on_commit(), so that changes from a
        transaction which is rolled back are not provisionned.
        """
        if not self.stack:
            return
        in_atomic_block = self.stack[-1]['in_atomic_block']

        def callback():
            for instance in args:
                klass = User if isinstance(instance, User) else Role
                self.saved.setdefault(klass, set()).add(instance)

        if in_atomic_block:
            callback()
        else:
            transaction.on_commit(callback)

    def add_deleted(self, *args):
        """Record instances to deprovision in the innermost context.

        The instances are also removed from the saved instances. Recording follows
        the same rules as add_saved().
        """
        if not self.stack:
            return
        in_atomic_block = self.stack[-1]['in_atomic_block']

        def callback():
            for instance in args:
                klass = User if isinstance(instance, User) else Role
                self.deleted.setdefault(klass, set()).add(instance)
                self.saved.get(klass, set()).discard(instance)

        if in_atomic_block:
            callback()
        else:
            transaction.on_commit(callback)

    def resolve_ou(self, instances, ous):
        """Set `instance.ou` from the `ous` mapping (OU id -> OU) when the id is known."""
        for instance in instances:
            if instance.ou_id in ous:
                instance.ou = ous[instance.ou_id]

    def _get_user_roles(self, users):
        """Map user ids to the set of roles they hold, directly or through parent roles.

        Memberships or parent relations pointing to unknown roles are ignored.
        """
        roles = {role.id: role for role in Role.objects.all()}

        # child role id -> parent role ids; use the raw foreign key values so that no
        # query is issued per RoleParenting row (broken relations are filtered below)
        parents = {}
        for child_id, parent_id in RoleParenting.objects.filter(deleted__isnull=True).values_list(
            'child_id', 'parent_id'
        ):
            parents.setdefault(child_id, []).append(parent_id)

        user_roles = {}
        Through = Role.members.through
        qs = Through.objects.filter(user__in=users).values_list('user_id', 'role_id')
        for user_id, role_id in qs:
            # unknown role_id can happen
            if role_id not in roles:
                continue
            held_roles = user_roles.setdefault(user_id, set())
            held_roles.add(roles[role_id])
            for parent_id in parents.get(role_id, []):
                # broken parent/child relationship can happen
                if parent_id in roles:
                    held_roles.add(roles[parent_id])
        return user_roles

    def notify_users(self, ous, users, mode='provision', sync=False):
        """Send users to the services of every organizational unit and to global services.

        :param ous: mapping of organizational unit ids to OU instances
        :param users: the users to send
        :param mode: 'provision' sends the full description of the users (reloaded
            from the database); any other value sends a 'deprovision' message
            carrying only their uuid
        :param sync: forwarded to notify_agents()
        """
        if not users:
            return

        is_forbidden_technical_role = _get_forbidden_technical_role_checker()

        if mode == 'provision':
            users = (
                User.objects.filter(id__in=[u.id for u in users])
                .select_related('ou')
                .prefetch_related('attribute_values__attribute')
            )
        else:
            self.resolve_ou(users, ous)

        # every user is sent to the services of every OU and to the services without OU (None)
        ous = {}
        for ou in [None] + list(OU.objects.all()):
            for user in users:
                ous.setdefault(ou, set()).add(user)

        issuer = force_str(self.get_entity_id())
        if mode == 'provision':

            def user_to_json(ou, service, user, user_roles):
                from authentic2.api_views import BaseUserSerializer

                all_user_roles = user_roles.get(user.id, [])
                # filter user's roles visible by the service's ou
                roles = [
                    role
                    for role in all_user_roles
                    if not is_forbidden_technical_role(role)
                    and (role.ou_id is None or (ou and role.ou_id == ou.id))
                ]
                data = {
                    'uuid': user.uuid,
                    'username': user.username,
                    'first_name': user.first_name,
                    'last_name': user.last_name,
                    'email': user.email,
                    'is_active': user.is_active,
                    'roles': [
                        {
                            'uuid': role.uuid,
                            'name': role.name,
                            'slug': role.slug,
                        }
                        for role in roles
                    ],
                }
                data.update(BaseUserSerializer(user).data)
                # the user is superuser for the service if ANY of their roles attached to
                # this service grants it; the roles are held in a set, so the result must
                # not depend on the iteration order
                role_is_superuser = False
                if service:
                    role_is_superuser = any(
                        role.is_superuser for role in all_user_roles if role.service_id == service.pk
                    )
                data['is_superuser'] = user.is_superuser or role_is_superuser
                return data

            # Find roles giving a superuser attribute. If there is any role of this kind,
            # the superuser flag depends on the service: one provisionning message is sent
            # per service, by batches of 500 users.
            roles_with_attributes = (
                Role.objects.filter(members__in=users)
                .parents(include_self=True)
                .filter(is_superuser=True)
                .exists()
            )
            user_roles = self._get_user_roles(users)

            if roles_with_attributes:
                for ou, ou_users in ous.items():
                    for service, audience in self.get_audience(ou):
                        for batched_users in batch(ou_users, 500):
                            batched_users = list(batched_users)
                            for user in batched_users:
                                logger.info('provisionning user %s to %s', user, audience)
                            self.notify_agents(
                                {
                                    '@type': 'provision',
                                    'issuer': issuer,
                                    'audience': [audience],
                                    'full': False,
                                    'objects': {
                                        '@type': 'user',
                                        'data': [
                                            user_to_json(ou, service, user, user_roles)
                                            for user in batched_users
                                        ],
                                    },
                                },
                                sync=sync,
                            )
            else:
                for ou, ou_users in ous.items():
                    audience = [entity_id for service, entity_id in self.get_audience(ou)]
                    if not audience:
                        continue
                    logger.info(
                        'provisionning users %s to %s',
                        ', '.join(map(force_str, ou_users)),
                        ', '.join(audience),
                    )
                    self.notify_agents(
                        {
                            '@type': 'provision',
                            'issuer': issuer,
                            'audience': audience,
                            'full': False,
                            'objects': {
                                '@type': 'user',
                                'data': [user_to_json(ou, None, user, user_roles) for user in ou_users],
                            },
                        },
                        sync=sync,
                    )
        else:
            audience = [entity_id for ou in ous for _, entity_id in self.get_audience(ou)]
            logger.info(
                'deprovisionning users %s from %s',
                ', '.join(map(force_str, users)),
                ', '.join(audience),
            )
            self.notify_agents(
                {
                    '@type': 'deprovision',
                    'issuer': issuer,
                    'audience': audience,
                    'full': False,
                    'objects': {
                        '@type': 'user',
                        'data': [{'uuid': user.uuid} for user in users],
                    },
                },
                sync=sync,
            )

    def notify_roles(self, ous, roles, mode='provision', full=False, sync=False):
        """Send roles to the services of their organizational unit.

        Technical roles are filtered out (see _get_forbidden_technical_role_checker()).
        Roles without OU are sent to the services without OU, and are also added to
        the roles sent to the services of every OU present among `roles`.

        :param ous: mapping of organizational unit ids to OU instances
        :param roles: the roles to send
        :param mode: used as the message '@type'; 'provision' sends the full role
            description, any other value only their uuid
        :param full: forwarded in the message 'full' field
        :param sync: forwarded to notify_agents()
        """
        is_forbidden_technical_role = _get_forbidden_technical_role_checker()

        roles = {role for role in roles if not is_forbidden_technical_role(role)}
        if not roles:
            return
        self.resolve_ou(roles, ous)
        ous = {}
        for role in roles:
            ous.setdefault(role.ou, []).append(role)

        def helper(ou, roles):
            if mode == 'provision':
                data = [
                    {
                        'uuid': role.uuid,
                        'name': role.name,
                        'slug': role.slug,
                        'description': role.description,
                        'details': role.details,
                        'emails': role.emails,
                        'emails_to_members': role.emails_to_members,
                    }
                    for role in roles
                ]
            else:
                data = [{'uuid': role.uuid} for role in roles]

            audience = [entity_id for service, entity_id in self.get_audience(ou)]
            logger.info('%sning roles %s to %s', mode, roles, audience)
            self.notify_agents(
                {
                    '@type': mode,
                    'audience': audience,
                    'full': full,
                    'objects': {
                        '@type': 'role',
                        'data': data,
                    },
                },
                sync=sync,
            )

        global_roles = set(ous.get(None, []))
        for ou, ou_roles in ous.items():
            helper(ou, set(ou_roles) | global_roles)

    def provision(self, saved, deleted):
        """Start a background thread provisionning `saved` and deprovisionning `deleted`.

        Nothing is done when:
        - we are not in a tenant,
        - provisionning is disabled (HOBO_ROLE_EXPORT setting),
        - there is nothing to do.
        """
        if (
            not hasattr(connection, 'tenant')
            or not connection.tenant
            or not hasattr(connection.tenant, 'domain_url')
        ):
            return
        if not getattr(settings, 'HOBO_ROLE_EXPORT', True):
            return
        if not (saved or deleted):
            return

        t = threading.Thread(target=self.do_provision, kwargs={'saved': saved, 'deleted': deleted})
        t.start()
        self.threads.add(t)

    def do_provision(self, saved, deleted):
        """Body of the provisionning thread: send roles first, then users."""
        try:
            ous = {ou.id: ou for ou in OU.objects.all()}
            self.notify_roles(ous, saved.get(Role, []))
            self.notify_roles(ous, deleted.get(Role, []), mode='deprovision')
            self.notify_users(ous, saved.get(User, []))
            self.notify_users(ous, deleted.get(User, []), mode='deprovision')
        except Exception:
            # the thread is the last step: log the error, nothing can be propagated
            logger.exception('error in provisionning thread')
        finally:
            self.threads.discard(threading.current_thread())

    def wait(self):
        """Join the registered provisionning threads."""
        for thread in list(self.threads):
            thread.join()
            # a thread which finished before being registered could not remove itself
            self.threads.discard(thread)

    def __enter__(self):
        self.start()

    def __exit__(self, exc_type, exc_value, exc_tb):
        if not self.stack:
            return
        # only provision when the block exited without exception
        self.stop(provision=exc_type is None)

    def get_audience(self, ou):
        """Return (service, entity_id) pairs for the SAML services of `ou` (None: no OU)."""
        if ou:
            qs = LibertyProvider.objects.filter(ou=ou)
        else:
            qs = LibertyProvider.objects.filter(ou__isnull=True)
        return [(service, service.entity_id) for service in qs]

    def get_entity_id(self):
        """Return the URL of this tenant's SAML IdP metadata, used as message issuer."""
        tenant = getattr(connection, 'tenant', None)
        assert tenant
        base_url = tenant.get_base_url()
        return urllib.parse.urljoin(base_url, reverse('a2-idp-saml-metadata'))

    def pre_save(self, sender, instance, raw, using, update_fields, **kwargs):
        """Signal handler: record modified (not new) users, roles and user attributes."""
        if not self.stack:
            return
        # we skip new instances
        if not instance.pk:
            return
        if not isinstance(instance, (User, Role, AttributeValue)):
            return
        # ignore last_login update on login
        if isinstance(instance, User) and (update_fields and set(update_fields) == {'last_login'}):
            return
        elif isinstance(instance, AttributeValue):
            if not isinstance(instance.owner, User):
                return
            instance = instance.owner
        self.add_saved(instance)

    def post_save(self, sender, instance, created, raw, using, update_fields, **kwargs):
        """Signal handler: record new users, roles and user attributes, and role parentings."""
        if not self.stack:
            return
        # a role parenting change affects every member of the child role
        if isinstance(instance, RoleParenting):
            self.add_saved(*list(instance.child.all_members()))
            return
        # otherwise, during post_save we only handle new instances
        if not created:
            return
        if not isinstance(instance, (User, Role, AttributeValue)):
            return
        elif isinstance(instance, AttributeValue):
            if not isinstance(instance.owner, User):
                return
            instance = instance.owner
        self.add_saved(instance)

    def pre_delete(self, sender, instance, using, **kwargs):
        """Signal handler: record deleted users/roles, and users affected by a deletion."""
        if not self.stack:
            return
        if isinstance(instance, (User, Role)):
            # keep a copy: the instance is modified by the deletion
            self.add_deleted(copy.copy(instance))
        elif isinstance(instance, AttributeValue):
            if not isinstance(instance.owner, User):
                return
            instance = instance.owner
            self.add_saved(instance)
        elif isinstance(instance, RoleParenting):
            self.add_saved(*list(instance.child.all_members()))

    def m2m_changed(self, sender, instance, action, reverse, model, pk_set, using, **kwargs):
        """Signal handler: record both sides of role membership changes.

        `reverse` is the m2m_changed signal argument (it shadows django's reverse()
        inside this method only).
        """
        if not self.stack:
            return
        if action != 'pre_clear' and action.startswith('pre_'):
            return
        if sender is Role.members.through:
            self.add_saved(instance)
            # on a clear, pk_set is None
            for other_instance in model.objects.filter(pk__in=pk_set or []):
                self.add_saved(other_instance)
            if action == 'pre_clear':
                # when the action is pre_clear we need to lookup the current value of the members
                # relation, to re-provision all previously enroled users.
                if not reverse:
                    for other_instance in instance.members.all():
                        self.add_saved(other_instance)

    def post_soft_create(self, sender, instance, **kwargs):
        """Signal handler: a new role parenting affects every member of the child role."""
        if isinstance(instance, RoleParenting):
            self.add_saved(*list(instance.child.all_members()))

    def post_soft_delete(self, sender, instance, **kwargs):
        """Signal handler: a removed role parenting affects every member of the child role."""
        if isinstance(instance, RoleParenting):
            self.add_saved(*list(instance.child.all_members()))

    def notify_agents(self, data, sync=False):
        """Send a provisionning message, over HTTP when possible, else with hobo's notify_agents.

        When HOBO_PROVISIONNING_DEBUG is set, the message is also appended to the file
        named by DEBUG_PROVISIONNING_LOG_PATH.
        """
        log_path = getattr(settings, 'DEBUG_PROVISIONNING_LOG_PATH', '')
        if log_path and getattr(settings, 'HOBO_PROVISIONNING_DEBUG', False):
            try:
                with open(log_path, 'a', encoding='utf-8') as f:
                    f.write('%s %s ' % (datetime.datetime.now().isoformat(), connection.tenant.domain_url))
                    json.dump(data, f, indent=2)
                    f.write('\n')
            except OSError:
                # debugging aid only: never prevent provisionning
                pass

        if getattr(settings, 'HOBO_HTTP_PROVISIONNING', True):
            leftover_audience = self.notify_agents_http(data, sync=sync)
            if not leftover_audience:
                return
            logger.info('leftover AMQP audience: %s', leftover_audience)
            data['audience'] = leftover_audience
        if data['audience']:
            # module-level hobo.agent.common.notify_agents, not this method
            notify_agents(data)

    def get_http_services_by_url(self):
        """Map SAML metadata URLs to the known services exposing a provisionning URL."""
        services_by_url = {}
        known_services = getattr(settings, 'KNOWN_SERVICES', {})
        for services in known_services.values():
            for service in services.values():
                metadata_url = service.get('saml-sp-metadata-url')
                # a service without metadata URL can not be part of an audience
                if service.get('provisionning-url') and metadata_url:
                    services_by_url[metadata_url] = service
        return services_by_url

    def notify_agents_http(self, data, sync=False):
        """PUT `data` to every audience member exposing a provisionning URL.

        One signed request is sent per entity ID; data['audience'] is set to that
        single entity ID before each request.

        Return the entity IDs which were not notified (no provisionning URL or
        request error), in their original order.
        """
        services_by_url = self.get_http_services_by_url()
        audience = data.get('audience')
        rest_audience = [x for x in audience if x in services_by_url]
        # work on a copy: the caller's audience list must not be modified
        leftover_audience = list(audience)
        for entity_id in rest_audience:
            service = services_by_url[entity_id]
            data['audience'] = [entity_id]
            url = service['provisionning-url'] + '?orig=%s' % service['orig']
            if sync:
                url += '&sync=1'
            try:
                response = requests.put(sign_url(url, service['secret']), json=data)
                response.raise_for_status()
            except requests.RequestException as e:
                logger.error('error provisionning to %s (%s)', entity_id, e)
            else:
                leftover_audience.remove(entity_id)
        return leftover_audience