527 lines
19 KiB
Python
527 lines
19 KiB
Python
import copy
|
|
import datetime
|
|
import json
|
|
import logging
|
|
import threading
|
|
import urllib.parse
|
|
from itertools import chain, islice
|
|
|
|
import requests
|
|
from authentic2.a2_rbac.models import OrganizationalUnit as OU
|
|
from authentic2.a2_rbac.models import Role, RoleParenting
|
|
from authentic2.models import AttributeValue
|
|
from authentic2.saml.models import LibertyProvider
|
|
from django.conf import settings
|
|
from django.contrib.auth import get_user_model
|
|
from django.db import connection, transaction
|
|
from django.urls import reverse
|
|
from django.utils.encoding import force_str
|
|
|
|
from hobo.agent.common import notify_agents
|
|
from hobo.signature import sign_url
|
|
|
|
# Module-level logger for provisioning traces and errors.
logger = logging.getLogger(__name__)

# The project's active user model, resolved through Django's get_user_model().
User = get_user_model()
|
|
|
|
|
|
def batch(iterable, size):
    """Batch an iterable as an iterable of lists of at most *size* elements.

    Each batch is materialized as a list. Batches therefore stay correct even
    when the caller does not fully consume one batch before asking for the
    next. Lazy ``islice`` views over a shared iterator would silently overlap
    or shift in that case.

    :param iterable: any iterable; it is consumed once, in order.
    :param size: maximum number of elements per batch (at least 1).
    :raises ValueError: when iteration starts, if *size* is lower than 1.
    """
    if size < 1:
        raise ValueError('batch size must be at least 1, got %r' % (size,))
    sourceiter = iter(iterable)
    while True:
        chunk = list(islice(sourceiter, size))
        if not chunk:
            return
        yield chunk
|
|
|
|
|
|
class Provisionning(threading.local):
    """Collect changes to users and roles and provision them to the services.

    Model signal handlers (``pre_save``, ``post_save``, ``pre_delete``,
    ``m2m_changed``, ``post_soft_create``, ``post_soft_delete``) record the
    changed users and roles into the context on top of a per-thread stack.

    A context is opened with ``start()`` (or ``with provisionning:``) and
    closed with ``stop()``. Closing a context schedules, through
    ``transaction.on_commit``, a background thread. That thread sends the
    collected changes either over HTTP or through
    ``hobo.agent.common.notify_agents``.
    """

    # ``threads`` is a slot: slot storage lives on the object itself, not in
    # the per-thread dictionary of threading.local. So the set of running
    # provisioning threads is shared by all threads: provisioning threads
    # discard themselves from it and ``wait()`` joins all of them. Every other
    # attribute (``stack``) is thread-local.
    __slots__ = ['threads']

    def __init__(self):
        # threading.local calls __init__ again the first time each new thread
        # touches this object (including the provisioning threads themselves).
        # Only create the shared ``threads`` set once. Re-creating it would drop
        # threads registered from other threads, and ``wait()`` could then
        # return before they finish.
        if not hasattr(self, 'threads'):
            self.threads = set()
        self.stack = []

    def start(self):
        """Push a new, empty change-collection context on this thread's stack."""
        self.stack.append(
            {
                'saved': {},
                'deleted': {},
                # whether the context was opened inside a transaction, see add_saved()
                'in_atomic_block': connection.in_atomic_block,
            }
        )

    def clear(self):
        """Drop every context of this thread without provisioning anything."""
        self.stack = []

    def stop(self, provision=True, wait=True):
        """Pop the current context and, if *provision*, provision its changes.

        Provisioning is registered with ``transaction.on_commit``. If *wait* is
        true, the callback then joins every registered provisioning thread.
        Does nothing when no context is open.
        """
        if not self.stack:
            return

        context = self.stack.pop()
        # provision() only takes the collected 'saved' and 'deleted' mappings
        context.pop('in_atomic_block')

        if provision:

            def callback():
                self.provision(**context)
                if wait:
                    self.wait()

            transaction.on_commit(callback)

    @property
    def saved(self):
        """Saved instances of the current context (model class -> set), or None."""
        if self.stack:
            return self.stack[-1]['saved']
        return None

    @property
    def deleted(self):
        """Deleted instances of the current context (model class -> set), or None."""
        if self.stack:
            return self.stack[-1]['deleted']
        return None

    def add_saved(self, *args):
        """Record the users/roles in *args* as saved in the current context.

        Instances that are not users are filed under ``Role``. If the context
        was opened outside a transaction, recording is deferred with
        ``transaction.on_commit``.
        """
        if not self.stack:
            return

        context = self.stack[-1]

        def callback():
            # Write into the context that was current when the change happened.
            # When the callback is deferred by on_commit, the top of the stack
            # may have changed or been popped by then (self.saved would be None).
            saved = context['saved']
            for instance in args:
                klass = User if isinstance(instance, User) else Role
                saved.setdefault(klass, set()).add(instance)

        if context['in_atomic_block']:
            callback()
        else:
            transaction.on_commit(callback)

    def add_deleted(self, *args):
        """Record the users/roles in *args* as deleted in the current context.

        Deleted instances are also removed from the context's saved set.
        Deferral works as in ``add_saved``.
        """
        if not self.stack:
            return

        context = self.stack[-1]

        def callback():
            # same reasoning as in add_saved(): use the context captured at call time
            saved = context['saved']
            deleted = context['deleted']
            for instance in args:
                klass = User if isinstance(instance, User) else Role
                deleted.setdefault(klass, set()).add(instance)
                saved.get(klass, set()).discard(instance)

        if context['in_atomic_block']:
            callback()
        else:
            transaction.on_commit(callback)

    def resolve_ou(self, instances, ous):
        """Set ``instance.ou`` from the *ous* mapping (OU id -> OU) when possible."""
        for instance in instances:
            if instance.ou_id in ous:
                instance.ou = ous[instance.ou_id]

    @staticmethod
    def _is_forbidden_technical_role(role, allowed_prefixes):
        """Return True for a technical role (slug starting with ``_``) whose slug
        does not start with one of *allowed_prefixes*."""
        return role.slug.startswith('_') and not role.slug.startswith(tuple(allowed_prefixes))

    def _get_user_roles(self, users):
        """Map each id of *users* to the set of its roles, direct and inherited."""
        roles = {role.id: role for role in Role.objects.all()}

        # child role id -> parent role ids, from the non soft-deleted RoleParenting rows
        parents = {}
        parentings = RoleParenting.objects.filter(deleted__isnull=True).values_list('child_id', 'parent_id')
        for child_id, parent_id in parentings:
            # broken parent/child relationship can happen: skip rows referencing unknown roles
            if child_id in roles and parent_id in roles:
                parents.setdefault(child_id, []).append(parent_id)

        user_roles = {}
        Through = Role.members.through
        memberships = Through.objects.filter(user__in=users).values_list('user_id', 'role_id')
        for user_id, role_id in memberships:
            # unknown role_id can happen
            if role_id not in roles:
                continue
            user_role_set = user_roles.setdefault(user_id, set())
            user_role_set.add(roles[role_id])
            user_role_set.update(roles[parent_id] for parent_id in parents.get(role_id, ()))
        return user_roles

    def _user_to_json(self, ou, service, user, user_roles, allowed_prefixes):
        """Serialize *user* for the services of *ou* (and *service*, when given).

        Only the user's roles that are global or belong to *ou*, and are not
        forbidden technical roles, are listed. ``is_superuser`` is true when the
        user is a superuser or holds a superuser role attached to *service*.
        """
        from authentic2.api_views import BaseUserSerializer

        user_role_set = user_roles.get(user.id, ())
        # filter user's roles visible by the service's ou
        roles = [
            role
            for role in user_role_set
            if (
                not self._is_forbidden_technical_role(role, allowed_prefixes)
                and (role.ou_id is None or (ou and role.ou_id == ou.id))
            )
        ]
        data = {
            'uuid': user.uuid,
            'username': user.username,
            'first_name': user.first_name,
            'last_name': user.last_name,
            'email': user.email,
            'is_active': user.is_active,
            'roles': [
                {
                    'uuid': role.uuid,
                    'name': role.name,
                    'slug': role.slug,
                }
                for role in roles
            ],
        }
        data.update(BaseUserSerializer(user).data)
        # Check if user is superuser through a role of the service. Any such role
        # is enough: the role set is unordered, so keeping only the last role seen
        # would give an arbitrary result.
        role_is_superuser = bool(service) and any(
            role.is_superuser for role in user_role_set if role.service_id == service.pk
        )
        data['is_superuser'] = user.is_superuser or role_is_superuser
        return data

    def notify_users(self, ous, users, mode='provision', sync=False):
        """Provision (``mode='provision'``) or deprovision *users* to the services.

        :param ous: mapping of OU id to OU, used to resolve ``user.ou``
        :param users: users to send; in provision mode they are reloaded from the
            database, so users that no longer exist are skipped
        :param sync: passed to ``notify_agents``
        """
        if not users:
            # nothing to send: do not load every OU, role and role parenting
            return

        allowed_prefixes = getattr(settings, 'HOBO_PROVISION_ROLE_PREFIXES', []) or []

        if mode == 'provision':
            users = (
                User.objects.filter(id__in=[u.id for u in users])
                .select_related('ou')
                .prefetch_related('attribute_values__attribute')
            )
        else:
            self.resolve_ou(users, ous)

        # Every user is sent to the services without OU and to the services of
        # every OU. The roles visible to each OU are filtered in _user_to_json().
        users_by_ou = {}
        for ou in [None] + list(OU.objects.all()):
            for user in users:
                users_by_ou.setdefault(ou, set()).add(user)

        issuer = force_str(self.get_entity_id())
        if mode == 'provision':
            self._provision_users(users, users_by_ou, issuer, allowed_prefixes, sync)
        else:
            self._deprovision_users(users, users_by_ou, issuer, sync)

    def _provision_users(self, users, users_by_ou, issuer, allowed_prefixes, sync):
        """Send 'provision' messages for *users* (see notify_users)."""
        # Find roles giving a superuser attribute. If there is any role of this
        # kind, the message depends on the target service: send one message per
        # service and per batch of 500 users, instead of one message per OU.
        roles_with_attributes = (
            Role.objects.filter(members__in=users)
            .parents(include_self=True)
            .filter(is_superuser=True)
            .exists()
        )
        user_roles = self._get_user_roles(users)

        if roles_with_attributes:
            for ou, ou_users in users_by_ou.items():
                for service, audience in self.get_audience(ou):
                    for batched_users in batch(ou_users, 500):
                        batched_users = list(batched_users)
                        for user in batched_users:
                            logger.info('provisionning user %s to %s', user, audience)
                        self.notify_agents(
                            {
                                '@type': 'provision',
                                'issuer': issuer,
                                'audience': [audience],
                                'full': False,
                                'objects': {
                                    '@type': 'user',
                                    'data': [
                                        self._user_to_json(ou, service, user, user_roles, allowed_prefixes)
                                        for user in batched_users
                                    ],
                                },
                            },
                            sync=sync,
                        )
        else:
            for ou, ou_users in users_by_ou.items():
                audience = [entity_id for _, entity_id in self.get_audience(ou)]
                if not audience:
                    continue
                logger.info(
                    'provisionning users %s to %s',
                    ', '.join(map(force_str, ou_users)),
                    ', '.join(audience),
                )
                self.notify_agents(
                    {
                        '@type': 'provision',
                        'issuer': issuer,
                        'audience': audience,
                        'full': False,
                        'objects': {
                            '@type': 'user',
                            'data': [
                                self._user_to_json(ou, None, user, user_roles, allowed_prefixes)
                                for user in ou_users
                            ],
                        },
                    },
                    sync=sync,
                )

    def _deprovision_users(self, users, users_by_ou, issuer, sync):
        """Send one 'deprovision' message for *users* to the services of every OU."""
        audience = [entity_id for ou in users_by_ou for _, entity_id in self.get_audience(ou)]
        logger.info(
            'deprovisionning users %s from %s', ', '.join(map(force_str, users)), ', '.join(audience)
        )
        self.notify_agents(
            {
                '@type': 'deprovision',
                'issuer': issuer,
                'audience': audience,
                'full': False,
                'objects': {
                    '@type': 'user',
                    'data': [
                        {
                            'uuid': user.uuid,
                        }
                        for user in users
                    ],
                },
            },
            sync=sync,
        )

    def notify_roles(self, ous, roles, mode='provision', full=False, sync=False):
        """Send *roles* with ``@type`` *mode* ('provision' or 'deprovision').

        Forbidden technical roles are skipped. Roles are grouped by OU; each
        group, plus the global roles, is sent to the services of that OU.
        """
        allowed_prefixes = getattr(settings, 'HOBO_PROVISION_ROLE_PREFIXES', []) or []
        roles = {role for role in roles if not self._is_forbidden_technical_role(role, allowed_prefixes)}
        if not roles:
            return

        self.resolve_ou(roles, ous)
        roles_by_ou = {}
        for role in roles:
            roles_by_ou.setdefault(role.ou, []).append(role)

        # NOTE(review): global roles only reach the services without OU and the
        # services of OUs that also have changed roles here; services of other
        # OUs do not receive them. Confirm this is intended.
        global_roles = set(roles_by_ou.get(None, []))
        for ou, ou_roles in roles_by_ou.items():
            sent_roles = set(ou_roles) | global_roles
            if mode == 'provision':
                data = [
                    {
                        'uuid': role.uuid,
                        'name': role.name,
                        'slug': role.slug,
                        'description': role.description,
                        'details': role.details,
                        'emails': role.emails,
                        'emails_to_members': role.emails_to_members,
                    }
                    for role in sent_roles
                ]
            else:
                data = [
                    {
                        'uuid': role.uuid,
                    }
                    for role in sent_roles
                ]

            audience = [entity_id for _, entity_id in self.get_audience(ou)]
            logger.info('%sning roles %s to %s', mode, sent_roles, audience)
            self.notify_agents(
                {
                    '@type': mode,
                    'audience': audience,
                    'full': full,
                    'objects': {
                        '@type': 'role',
                        'data': data,
                    },
                },
                sync=sync,
            )

    def provision(self, saved, deleted):
        """Start a provisioning thread for *saved* and *deleted*, when applicable."""
        # Returns if:
        # - we are not in a tenant
        # - provisionning is disabled
        # - there is nothing to do
        if (
            not hasattr(connection, 'tenant')
            or not connection.tenant
            or not hasattr(connection.tenant, 'domain_url')
        ):
            return
        if not getattr(settings, 'HOBO_ROLE_EXPORT', True):
            return
        if not (saved or deleted):
            return

        thread = threading.Thread(target=self.do_provision, kwargs={'saved': saved, 'deleted': deleted})
        # Register the thread before starting it: do_provision() discards it when
        # done, and that must not happen before it is added (it would then stay
        # registered forever).
        self.threads.add(thread)
        try:
            thread.start()
        except Exception:
            self.threads.discard(thread)
            raise

    def do_provision(self, saved, deleted):
        """Send roles, then users, from *saved* and *deleted* (body of the provisioning thread).

        Errors are logged, not raised. The current thread is always removed from
        ``threads`` at the end.
        """
        try:
            ous = {ou.id: ou for ou in OU.objects.all()}
            self.notify_roles(ous, saved.get(Role, []))
            self.notify_roles(ous, deleted.get(Role, []), mode='deprovision')
            self.notify_users(ous, saved.get(User, []))
            self.notify_users(ous, deleted.get(User, []), mode='deprovision')
        except Exception:
            # this runs in a background thread where nothing would catch the error: log it
            logger.exception('error in provisionning thread')
        finally:
            self.threads.discard(threading.current_thread())

    def wait(self):
        """Join every registered provisioning thread."""
        for thread in list(self.threads):
            thread.join()

    def __enter__(self):
        self.start()
        return self

    def __exit__(self, exc_type, exc_value, exc_tb):
        if not self.stack:
            return
        # only provision when the block exited without an exception
        self.stop(provision=exc_type is None)

    def get_audience(self, ou):
        """Return ``(service, entity_id)`` pairs for the LibertyProviders of *ou*.

        When *ou* is falsy, the providers without OU are used.
        """
        if ou:
            qs = LibertyProvider.objects.filter(ou=ou)
        else:
            qs = LibertyProvider.objects.filter(ou__isnull=True)
        return [(service, service.entity_id) for service in qs]

    def get_entity_id(self):
        """Return the URL of this tenant's SAML IdP metadata (used as message issuer)."""
        tenant = getattr(connection, 'tenant', None)
        assert tenant
        base_url = tenant.get_base_url()
        return urllib.parse.urljoin(base_url, reverse('a2-idp-saml-metadata'))

    def pre_save(self, sender, instance, raw, using, update_fields, **kwargs):
        """Signal handler: record an existing user or role about to be saved.

        For an AttributeValue, its owner is recorded when it is a user. New
        instances are left to post_save. Saves only updating ``last_login`` are
        ignored.
        """
        if not self.stack:
            return
        # we skip new instances
        if not instance.pk:
            return
        if not isinstance(instance, (User, Role, AttributeValue)):
            return
        # ignore last_login update on login
        if isinstance(instance, User) and (update_fields and set(update_fields) == {'last_login'}):
            return
        elif isinstance(instance, AttributeValue):
            if not isinstance(instance.owner, User):
                return
            instance = instance.owner
        self.add_saved(instance)

    def post_save(self, sender, instance, created, raw, using, update_fields, **kwargs):
        """Signal handler: record newly created users, roles and attribute owners.

        For a saved RoleParenting, all members of its child role are recorded.
        """
        if not self.stack:
            return
        # during post_save we only handle new instances
        if isinstance(instance, RoleParenting):
            self.add_saved(*list(instance.child.all_members()))
            return
        if not created:
            return
        if not isinstance(instance, (User, Role, AttributeValue)):
            return
        elif isinstance(instance, AttributeValue):
            if not isinstance(instance.owner, User):
                return
            instance = instance.owner
        self.add_saved(instance)

    def pre_delete(self, sender, instance, using, **kwargs):
        """Signal handler: record deletions of users and roles.

        Deleting an attribute value or a RoleParenting re-provisions the
        affected users instead.
        """
        if not self.stack:
            return
        if isinstance(instance, (User, Role)):
            # presumably a copy so the recorded object keeps its field values
            # (e.g. pk) once the original instance is deleted
            self.add_deleted(copy.copy(instance))
        elif isinstance(instance, AttributeValue):
            if not isinstance(instance.owner, User):
                return
            instance = instance.owner
            self.add_saved(instance)
        elif isinstance(instance, RoleParenting):
            self.add_saved(*list(instance.child.all_members()))

    def m2m_changed(self, sender, instance, action, reverse, model, pk_set, using, **kwargs):
        """Signal handler for ``Role.members`` changes: record both ends of the relation."""
        if not self.stack:
            return
        # only post_* actions are handled, plus pre_clear (members are still readable then)
        if action != 'pre_clear' and action.startswith('pre_'):
            return
        if sender is Role.members.through:
            self.add_saved(instance)
            # on a clear, pk_set is None
            for other_instance in model.objects.filter(pk__in=pk_set or []):
                self.add_saved(other_instance)
            if action == 'pre_clear':
                # when the action is pre_clear we need to lookup the current value of the members
                # relation, to re-provision all previously enroled users.
                if not reverse:
                    for other_instance in instance.members.all():
                        self.add_saved(other_instance)

    def post_soft_create(self, sender, instance, **kwargs):
        """Soft-create hook: re-provision the members of a new RoleParenting's child role."""
        # add_saved() would ignore the members anyway; skip the query
        if not self.stack:
            return
        if isinstance(instance, RoleParenting):
            self.add_saved(*list(instance.child.all_members()))

    def post_soft_delete(self, sender, instance, **kwargs):
        """Soft-delete hook: re-provision the members of a removed RoleParenting's child role."""
        # add_saved() would ignore the members anyway; skip the query
        if not self.stack:
            return
        if isinstance(instance, RoleParenting):
            self.add_saved(*list(instance.child.all_members()))

    def notify_agents(self, data, sync=False):
        """Send the provisioning message *data* to its ``audience``.

        When ``HOBO_HTTP_PROVISIONNING`` is enabled (the default), services with
        an HTTP provisionning URL are notified directly. The remaining audience,
        if any, goes through ``hobo.agent.common.notify_agents``. When
        ``HOBO_PROVISIONNING_DEBUG`` is set, the message is also appended to
        ``DEBUG_PROVISIONNING_LOG_PATH``.
        """
        log_path = getattr(settings, 'DEBUG_PROVISIONNING_LOG_PATH', '')
        if log_path and getattr(settings, 'HOBO_PROVISIONNING_DEBUG', False):
            try:
                with open(log_path, 'a') as f:
                    f.write('%s %s ' % (datetime.datetime.now().isoformat(), connection.tenant.domain_url))
                    json.dump(data, f, indent=2)
                    f.write('\n')
            except OSError:
                # debug logging is best effort
                pass

        if getattr(settings, 'HOBO_HTTP_PROVISIONNING', True):
            leftover_audience = self.notify_agents_http(data, sync=sync)
            if not leftover_audience:
                return
            logger.info('leftover AMQP audience: %s', leftover_audience)
            data['audience'] = leftover_audience

        if data['audience']:
            notify_agents(data)

    def get_http_services_by_url(self):
        """Map ``saml-sp-metadata-url`` -> service, for KNOWN_SERVICES with a provisionning URL."""
        services_by_url = {}
        known_services = getattr(settings, 'KNOWN_SERVICES', {})
        for services in known_services.values():
            for service in services.values():
                if service.get('provisionning-url'):
                    services_by_url[service['saml-sp-metadata-url']] = service
        return services_by_url

    def notify_agents_http(self, data, sync=False):
        """PUT *data* to every audience member reachable over HTTP.

        One signed request is made per service, with ``audience`` restricted to
        that service. Neither *data* nor its audience list is modified.
        Returns the audience entries that were not notified successfully over
        HTTP.
        """
        services_by_url = self.get_http_services_by_url()
        audience = data.get('audience') or []
        # work on a copy: the caller's audience list must not be mutated
        leftover_audience = list(audience)
        for entity_id in [x for x in audience if x in services_by_url]:
            service = services_by_url[entity_id]
            url = service['provisionning-url'] + '?orig=%s' % service['orig']
            if sync:
                url += '&sync=1'
            # NOTE(review): no timeout is given to requests.put, so an unresponsive
            # service blocks this thread (and wait()) indefinitely.
            try:
                response = requests.put(sign_url(url, service['secret']), json=dict(data, audience=[entity_id]))
                response.raise_for_status()
            except requests.RequestException as e:
                logger.error('error provisionning to %s (%s)', entity_id, e)
            else:
                leftover_audience.remove(entity_id)
        return leftover_audience
|