agent/a2: prevent useless thread launching (#34484)

This commit is contained in:
Benjamin Dauvergne 2019-07-02 11:38:34 +02:00
parent b5bebd3e43
commit e7abfc8ea7
6 changed files with 219 additions and 124 deletions

View File

@ -1,3 +1,5 @@
from django.conf import settings
from .provisionning import provisionning from .provisionning import provisionning
@ -6,9 +8,9 @@ class ProvisionningMiddleware(object):
provisionning.start() provisionning.start()
def process_exception(self, request, exception): def process_exception(self, request, exception):
provisionning.clean() provisionning.clear()
def process_response(self, request, response): def process_response(self, request, response):
provisionning.provision() provisionning.stop(provision=True, wait=getattr(settings, 'HOBO_PROVISIONNING_SYNCHRONOUS', False))
provisionning.clear()
return response return response

View File

@ -15,44 +15,69 @@ from authentic2.saml.models import LibertyProvider
from authentic2.a2_rbac.models import RoleAttribute from authentic2.a2_rbac.models import RoleAttribute
from authentic2.models import AttributeValue from authentic2.models import AttributeValue
User = get_user_model()
Role = get_role_model()
OU = get_ou_model()
RoleParenting = get_role_parenting_model()
class Provisionning(object): logger = logging.getLogger(__name__)
local = threading.local()
class Provisionning(threading.local):
__slots__ = ['threads']
threads = set() threads = set()
def __init__(self): def __init__(self):
self.User = get_user_model() self.stack = []
self.Role = get_role_model()
self.OU = get_ou_model()
self.RoleParenting = get_role_parenting_model()
self.logger = logging.getLogger(__name__)
def start(self): def start(self):
self.local.saved = {} self.stack.append({
self.local.deleted = {} 'saved': {},
'deleted': {},
})
def clean(self): def clear(self):
if hasattr(self.local, 'saved'): self.stack = []
del self.local.saved
if hasattr(self.local, 'deleted'):
del self.local.deleted
def saved(self, *args): def stop(self, provision=True, wait=True):
if not hasattr(self.local, 'saved'): if not self.stack:
return
context = self.stack.pop()
if provision:
self.provision(**context)
if wait:
self.wait()
@property
def saved(self):
if self.stack:
return self.stack[-1]['saved']
return None
@property
def deleted(self):
if self.stack:
return self.stack[-1]['deleted']
return None
def add_saved(self, *args):
if not self.stack:
return return
for instance in args: for instance in args:
klass = self.User if isinstance(instance, self.User) else self.Role klass = User if isinstance(instance, User) else Role
self.local.saved.setdefault(klass, set()).add(instance) self.saved.setdefault(klass, set()).add(instance)
def deleted(self, *args): def add_deleted(self, *args):
if not hasattr(self.local, 'saved'): if not self.stack:
return return
for instance in args: for instance in args:
klass = self.User if isinstance(instance, self.User) else self.Role klass = User if isinstance(instance, User) else Role
self.local.deleted.setdefault(klass, set()).add(instance) self.deleted.setdefault(klass, set()).add(instance)
self.local.saved.get(klass, set()).discard(instance) self.saved.get(klass, set()).discard(instance)
def resolve_ou(self, instances, ous): def resolve_ou(self, instances, ous):
for instance in instances: for instance in instances:
@ -61,7 +86,7 @@ class Provisionning(object):
def notify_users(self, ous, users, mode='provision'): def notify_users(self, ous, users, mode='provision'):
if mode == 'provision': if mode == 'provision':
users = (self.User.objects.filter(id__in=[u.id for u in users]) users = (User.objects.filter(id__in=[u.id for u in users])
.select_related('ou').prefetch_related('attribute_values__attribute')) .select_related('ou').prefetch_related('attribute_values__attribute'))
else: else:
self.resolve_ou(users, ous) self.resolve_ou(users, ous)
@ -105,19 +130,19 @@ class Provisionning(object):
# Find roles giving a superuser attribute # Find roles giving a superuser attribute
# If there is any role of this kind, we do one provisionning message for each user and # If there is any role of this kind, we do one provisionning message for each user and
# each service. # each service.
roles_with_attributes = (self.Role.objects.filter(members__in=users) roles_with_attributes = (Role.objects.filter(members__in=users)
.parents(include_self=True) .parents(include_self=True)
.filter(attributes__name='is_superuser') .filter(attributes__name='is_superuser')
.exists()) .exists())
all_roles = (self.Role.objects.filter(members__in=users).parents() all_roles = (Role.objects.filter(members__in=users).parents()
.prefetch_related('attributes').distinct()) .prefetch_related('attributes').distinct())
roles = dict((r.id, r) for r in all_roles) roles = dict((r.id, r) for r in all_roles)
user_roles = {} user_roles = {}
parents = {} parents = {}
for rp in self.RoleParenting.objects.filter(child__in=all_roles): for rp in RoleParenting.objects.filter(child__in=all_roles):
parents.setdefault(rp.child.id, []).append(rp.parent.id) parents.setdefault(rp.child.id, []).append(rp.parent.id)
Through = self.Role.members.through Through = Role.members.through
for u_id, r_id in Through.objects.filter(role__members__in=users).values_list('user_id', for u_id, r_id in Through.objects.filter(role__members__in=users).values_list('user_id',
'role_id'): 'role_id'):
user_roles.setdefault(u_id, set()).add(roles[r_id]) user_roles.setdefault(u_id, set()).add(roles[r_id])
@ -133,7 +158,7 @@ class Provisionning(object):
for ou, users in ous.iteritems(): for ou, users in ous.iteritems():
for service, audience in self.get_audience(ou): for service, audience in self.get_audience(ou):
for user in users: for user in users:
self.logger.info(u'provisionning user %s to %s', user, audience) logger.info(u'provisionning user %s to %s', user, audience)
notify_agents({ notify_agents({
'@type': 'provision', '@type': 'provision',
'issuer': issuer, 'issuer': issuer,
@ -149,7 +174,7 @@ class Provisionning(object):
audience = [a for service, a in self.get_audience(ou)] audience = [a for service, a in self.get_audience(ou)]
if not audience: if not audience:
continue continue
self.logger.info(u'provisionning users %s to %s', logger.info(u'provisionning users %s to %s',
u', '.join(map(unicode, users)), u', '.join(audience)) u', '.join(map(unicode, users)), u', '.join(audience))
notify_agents({ notify_agents({
'@type': 'provision', '@type': 'provision',
@ -162,9 +187,9 @@ class Provisionning(object):
} }
}) })
elif users: elif users:
audience = [audience for ou in self.OU.objects.all() audience = [audience for ou in OU.objects.all()
for s, audience in self.get_audience(ou)] for s, audience in self.get_audience(ou)]
self.logger.info(u'deprovisionning users %s from %s', u', '.join(map(unicode, users)), logger.info(u'deprovisionning users %s from %s', u', '.join(map(unicode, users)),
u', '.join(audience)) u', '.join(audience))
notify_agents({ notify_agents({
'@type': 'deprovision', '@type': 'deprovision',
@ -213,7 +238,7 @@ class Provisionning(object):
] ]
audience = [entity_id for service, entity_id in self.get_audience(ou)] audience = [entity_id for service, entity_id in self.get_audience(ou)]
self.logger.info(u'%sning roles %s to %s', mode, roles, audience) logger.info(u'%sning roles %s to %s', mode, roles, audience)
notify_agents({ notify_agents({
'@type': mode, '@type': mode,
'audience': audience, 'audience': audience,
@ -229,33 +254,35 @@ class Provisionning(object):
sent_roles = set(ou_roles) | global_roles sent_roles = set(ou_roles) | global_roles
helper(ou, sent_roles) helper(ou, sent_roles)
def provision(self): def provision(self, saved, deleted):
# Returns if:
# - we are not in a tenant
# - provsionning is disabled
# - there is nothing to do
if (not hasattr(connection, 'tenant') or not connection.tenant or not if (not hasattr(connection, 'tenant') or not connection.tenant or not
hasattr(connection.tenant, 'domain_url')): hasattr(connection.tenant, 'domain_url')):
return return
if not getattr(settings, 'HOBO_ROLE_EXPORT', True): if not getattr(settings, 'HOBO_ROLE_EXPORT', True):
return return
# exit early if not started if not (saved or deleted):
if not hasattr(self.local, 'saved') or not hasattr(self.local, 'deleted'):
return return
t = threading.Thread(target=self.do_provision, kwargs={ t = threading.Thread(
'saved': getattr(self.local, 'saved', {}), target=self.do_provision,
'deleted': getattr(self.local, 'deleted', {}), kwargs={'saved': saved, 'deleted': deleted})
})
t.start() t.start()
self.threads.add(t) self.threads.add(t)
def do_provision(self, saved, deleted, thread=None): def do_provision(self, saved, deleted):
try: try:
ous = {ou.id: ou for ou in self.OU.objects.all()} ous = {ou.id: ou for ou in OU.objects.all()}
self.notify_roles(ous, saved.get(self.Role, [])) self.notify_roles(ous, saved.get(Role, []))
self.notify_roles(ous, deleted.get(self.Role, []), mode='deprovision') self.notify_roles(ous, deleted.get(Role, []), mode='deprovision')
self.notify_users(ous, saved.get(self.User, [])) self.notify_users(ous, saved.get(User, []))
self.notify_users(ous, deleted.get(self.User, []), mode='deprovision') self.notify_users(ous, deleted.get(User, []), mode='deprovision')
except Exception: except Exception:
# last step, clear everything # last step, clear everything
self.logger.exception(u'error in provisionning thread') logger.exception(u'error in provisionning thread')
finally: finally:
self.threads.discard(threading.current_thread()) self.threads.discard(threading.current_thread())
@ -267,12 +294,9 @@ class Provisionning(object):
self.start() self.start()
def __exit__(self, exc_type, exc_value, exc_tb): def __exit__(self, exc_type, exc_value, exc_tb):
if exc_type is None: if not self.stack:
self.provision() return
self.clean() self.stop(provision=exc_type is None)
self.wait()
else:
self.clean()
def get_audience(self, ou): def get_audience(self, ou):
if ou: if ou:
@ -298,64 +322,72 @@ class Provisionning(object):
return urljoin(base_url, reverse('a2-idp-saml-metadata')) return urljoin(base_url, reverse('a2-idp-saml-metadata'))
def pre_save(self, sender, instance, raw, using, update_fields, **kwargs): def pre_save(self, sender, instance, raw, using, update_fields, **kwargs):
if not self.stack:
return
# we skip new instances # we skip new instances
if not instance.pk: if not instance.pk:
return return
if not isinstance(instance, (self.User, self.Role, RoleAttribute, AttributeValue)): if not isinstance(instance, (User, Role, RoleAttribute, AttributeValue)):
return return
# ignore last_login update on login # ignore last_login update on login
if isinstance(instance, self.User) and update_fields == ['last_login']: if isinstance(instance, User) and (update_fields and set(update_fields) == set(['last_login'])):
return return
if isinstance(instance, RoleAttribute): if isinstance(instance, RoleAttribute):
instance = instance.role instance = instance.role
elif isinstance(instance, AttributeValue): elif isinstance(instance, AttributeValue):
if not isinstance(instance.owner, self.User): if not isinstance(instance.owner, User):
return return
instance = instance.owner instance = instance.owner
self.saved(instance) self.add_saved(instance)
def post_save(self, sender, instance, created, raw, using, update_fields, **kwargs): def post_save(self, sender, instance, created, raw, using, update_fields, **kwargs):
if not self.stack:
return
# during post_save we only handle new instances # during post_save we only handle new instances
if isinstance(instance, self.RoleParenting): if isinstance(instance, RoleParenting):
self.saved(*list(instance.child.all_members())) self.add_saved(*list(instance.child.all_members()))
return return
if not created: if not created:
return return
if not isinstance(instance, (self.User, self.Role, RoleAttribute, AttributeValue)): if not isinstance(instance, (User, Role, RoleAttribute, AttributeValue)):
return return
if isinstance(instance, RoleAttribute): if isinstance(instance, RoleAttribute):
instance = instance.role instance = instance.role
elif isinstance(instance, AttributeValue): elif isinstance(instance, AttributeValue):
if not isinstance(instance.owner, self.User): if not isinstance(instance.owner, User):
return return
instance = instance.owner instance = instance.owner
self.saved(instance) self.add_saved(instance)
def pre_delete(self, sender, instance, using, **kwargs): def pre_delete(self, sender, instance, using, **kwargs):
if isinstance(instance, (self.User, self.Role)): if not self.stack:
self.deleted(copy.copy(instance)) return
if isinstance(instance, (User, Role)):
self.add_deleted(copy.copy(instance))
elif isinstance(instance, RoleAttribute): elif isinstance(instance, RoleAttribute):
instance = instance.role instance = instance.role
self.saved(instance) self.add_saved(instance)
elif isinstance(instance, AttributeValue): elif isinstance(instance, AttributeValue):
if not isinstance(instance.owner, self.User): if not isinstance(instance.owner, User):
return return
instance = instance.owner instance = instance.owner
self.saved(instance) self.add_saved(instance)
elif isinstance(instance, self.RoleParenting): elif isinstance(instance, RoleParenting):
self.saved(*list(instance.child.all_members())) self.add_saved(*list(instance.child.all_members()))
def m2m_changed(self, sender, instance, action, reverse, model, pk_set, using, **kwargs): def m2m_changed(self, sender, instance, action, reverse, model, pk_set, using, **kwargs):
if not self.stack:
return
if action != 'pre_clear' and action.startswith('pre_'): if action != 'pre_clear' and action.startswith('pre_'):
return return
if sender is self.Role.members.through: if sender is Role.members.through:
self.saved(instance) self.add_saved(instance)
# on a clear, pk_set is None # on a clear, pk_set is None
for other_instance in model.objects.filter(pk__in=pk_set or []): for other_instance in model.objects.filter(pk__in=pk_set or []):
self.saved(other_instance) self.add_saved(other_instance)
if action == 'pre_clear': if action == 'pre_clear':
# when the action is pre_clear we need to lookup the current value of the members # when the action is pre_clear we need to lookup the current value of the members
# relation, to re-provision all previously enroled users. # relation, to re-provision all previously enroled users.
if not reverse: if not reverse:
for other_instance in instance.members.all(): for other_instance in instance.members.all():
self.saved(other_instance) self.add_saved(other_instance)

View File

@ -1,30 +1,31 @@
import os import os
import tempfile
import shutil
import json import json
from django_webtest import WebTestMixin, DjangoTestApp
import pytest import pytest
from django.db import connection, transaction
from tenant_schemas.postgresql_backend.base import FakeTenant
from tenant_schemas.utils import tenant_context
from hobo.multitenant.models import Tenant
@pytest.fixture @pytest.fixture
def tenant_base(request, settings): def tenant_base(tmpdir, settings):
base = tempfile.mkdtemp('authentic-tenant-base') base = str(tmpdir.mkdir('authentic-tenant-base'))
settings.TENANT_BASE = base settings.TENANT_BASE = base
def fin():
shutil.rmtree(base)
request.addfinalizer(fin)
return base return base
@pytest.fixture(scope='function') @pytest.fixture
def tenant(transactional_db, request, settings, tenant_base): def tenant_factory(transactional_db, tenant_base, settings):
from hobo.multitenant.models import Tenant tenants = []
base = tenant_base
@pytest.mark.django_db settings.ALLOWED_HOSTS = ['*']
def make_tenant(name):
tenant_dir = os.path.join(base, name) def factory(name):
tenant_dir = os.path.join(tenant_base, name)
os.mkdir(tenant_dir) os.mkdir(tenant_dir)
with open(os.path.join(tenant_dir, 'unsecure'), 'w') as fd: with open(os.path.join(tenant_dir, 'unsecure'), 'w') as fd:
fd.write('1') fd.write('1')
@ -37,32 +38,63 @@ def tenant(transactional_db, request, settings, tenant_base):
'other_variable': 'foo', 'other_variable': 'foo',
}, },
'services': [ 'services': [
{'slug': 'test', {
'service-id': 'authentic', 'slug': 'test',
'title': 'Test', 'service-id': 'authentic',
'this': True, 'title': 'Test',
'secret_key': '12345', 'this': True,
'base_url': 'http://%s' % name, 'secret_key': '12345',
'variables': { 'base_url': 'http://%s' % name,
'other_variable': 'bar', 'variables': {
} 'other_variable': 'bar',
}
}, },
{'slug': 'other', {
'title': 'Other', 'slug': 'other',
'service-id': 'welco', 'title': 'Other',
'secret_key': 'abcdef', 'service-id': 'welco',
'base_url': 'http://other.example.net'}, 'secret_key': 'abcdef',
]}, fd) 'base_url': 'http://other.example.net'
t = Tenant(domain_url=name, },
schema_name=name.replace('-', '_').replace('.', '_')) ]
t.create_schema() }, fd)
schema_name = name.replace('-', '_').replace('.', '_')
t = Tenant(domain_url=name, schema_name=schema_name)
with transaction.atomic():
t.create_schema()
tenants.append(t)
return t return t
tenants = [make_tenant('authentic.example.net')] try:
yield factory
def fin(): finally:
from django.db import connection # cleanup all created tenants
connection.set_schema_to_public() connection.set_schema_to_public()
for t in tenants: with tenant_context(FakeTenant('public')):
t.delete(True) for tenant in tenants:
request.addfinalizer(fin) tenant.delete(force_drop=True)
return tenants[0]
@pytest.fixture
def tenant(tenant_factory):
return tenant_factory('authentic.example.net')
@pytest.fixture
def app_factory(request):
wtm = WebTestMixin()
wtm._patch_settings()
def factory(hostname='testserver'):
if hasattr(hostname, 'domain_url'):
hostname = hostname.domain_url
return DjangoTestApp(extra_environ={'HTTP_HOST': hostname})
try:
yield factory
finally:
wtm._unpatch_settings()
@pytest.fixture
def notify_agents(mocker):
yield mocker.patch('hobo.agent.authentic2.provisionning.notify_agents')

View File

@ -39,3 +39,6 @@ CACHES = {
} }
HOBO_ROLE_EXPORT = True HOBO_ROLE_EXPORT = True
SESSION_COOKIE_SECURE = False
CSRF_COOKIE_SECURE = False

View File

@ -17,6 +17,8 @@ from authentic2.a2_rbac.utils import get_default_ou
from authentic2.models import Attribute, AttributeValue from authentic2.models import Attribute, AttributeValue
from hobo.agent.authentic2.provisionning import provisionning from hobo.agent.authentic2.provisionning import provisionning
User = get_user_model()
pytestmark = pytest.mark.django_db pytestmark = pytest.mark.django_db
@ -126,7 +128,6 @@ def test_provision_user(transactional_db, tenant, caplog):
role2.attributes.create(kind='json', name='emails', value='["zob@example.net"]') role2.attributes.create(kind='json', name='emails', value='["zob@example.net"]')
child_role = Role.objects.create(name='child', ou=get_default_ou()) child_role = Role.objects.create(name='child', ou=get_default_ou())
notify_agents.reset_mock() notify_agents.reset_mock()
User = get_user_model()
attribute = Attribute.objects.create(label='Code postal', name='code_postal', attribute = Attribute.objects.create(label='Code postal', name='code_postal',
kind='string') kind='string')
with provisionning: with provisionning:
@ -277,7 +278,6 @@ def test_provision_user(transactional_db, tenant, caplog):
data = objects['data'] data = objects['data']
assert isinstance(data, list) assert isinstance(data, list)
assert len(data) == 1 assert len(data) == 1
print data
for o in data: for o in data:
assert set(o.keys()) >= set(['uuid', 'username', 'first_name', assert set(o.keys()) >= set(['uuid', 'username', 'first_name',
'is_superuser', 'last_name', 'email', 'roles']) 'is_superuser', 'last_name', 'email', 'roles'])
@ -447,17 +447,15 @@ def test_provision_createsuperuser(transactional_db, tenant, caplog):
with tenant_context(tenant): with tenant_context(tenant):
# create a provider so notification messages have an audience. # create a provider so notification messages have an audience.
LibertyProvider.objects.create(ou=None, name='provider', LibertyProvider.objects.create(ou=None, name='provider',
entity_id='http://provider.com', entity_id='http://provider.com',
protocol_conformance=lasso.PROTOCOL_SAML_2_0) protocol_conformance=lasso.PROTOCOL_SAML_2_0)
with patch('hobo.agent.authentic2.provisionning.notify_agents') as notify_agents: with patch('hobo.agent.authentic2.provisionning.notify_agents') as notify_agents:
call_command('createsuperuser', domain=tenant.domain_url, uuid='coin', call_command('createsuperuser', domain=tenant.domain_url, uuid='coin',
username='coin', email='coin@coin.org', interactive=False) username='coin', email='coin@coin.org', interactive=False)
assert notify_agents.call_count == 1 assert notify_agents.call_count == 1
@patch('hobo.agent.authentic2.provisionning.notify_agents')
def test_command_hobo_provision(notify_agents, transactional_db, tenant, caplog): def test_command_hobo_provision(notify_agents, transactional_db, tenant, caplog):
User = get_user_model()
with tenant_context(tenant): with tenant_context(tenant):
ou = get_default_ou() ou = get_default_ou()
LibertyProvider.objects.create(ou=ou, name='provider', LibertyProvider.objects.create(ou=ou, name='provider',
@ -486,3 +484,31 @@ def test_command_hobo_provision(notify_agents, transactional_db, tenant, caplog)
assert msg_2['full'] is False assert msg_2['full'] is False
assert msg_2['objects']['@type'] == 'user' assert msg_2['objects']['@type'] == 'user'
assert len(msg_2['objects']['data']) == 10 assert len(msg_2['objects']['data']) == 10
def test_middleware(notify_agents, app_factory, tenant, settings):
settings.HOBO_PROVISIONNING_SYNCHRONOUS = True
with tenant_context(tenant):
user = User.objects.create(username='john', ou=get_default_ou())
user.set_password('password')
user.save()
LibertyProvider.objects.create(ou=get_default_ou(),
name='provider',
entity_id='http://provider.com',
protocol_conformance=lasso.PROTOCOL_SAML_2_0)
assert notify_agents.call_count == 0
app = app_factory(tenant)
resp = app.get('/login/')
form = resp.form
form.set('username', 'john')
form.set('password', 'password')
resp = form.submit(name='login-password-submit').follow()
resp = resp.click('Your account')
resp = resp.click('Edit')
resp.form.set('edit-profile-first_name', 'John')
resp.form.set('edit-profile-last_name', 'Doe')
assert notify_agents.call_count == 0
resp = resp.form.submit().follow()
assert notify_agents.call_count == 1

View File

@ -39,7 +39,7 @@ deps:
cssselect cssselect
WebTest WebTest
django-mellon django-mellon
django-webtest<1.9.3 django-webtest
celery<4 celery<4
Markdown<3 Markdown<3
django-tables2<2.0 django-tables2<2.0