agent/a2: prevent useless thread launching (#34484)

This commit is contained in:
Benjamin Dauvergne 2019-07-02 11:38:34 +02:00
parent b5bebd3e43
commit e7abfc8ea7
6 changed files with 219 additions and 124 deletions

View File

@ -1,3 +1,5 @@
from django.conf import settings
from .provisionning import provisionning
@ -6,9 +8,9 @@ class ProvisionningMiddleware(object):
provisionning.start()
def process_exception(self, request, exception):
provisionning.clean()
provisionning.clear()
def process_response(self, request, response):
provisionning.provision()
provisionning.stop(provision=True, wait=getattr(settings, 'HOBO_PROVISIONNING_SYNCHRONOUS', False))
provisionning.clear()
return response

View File

@ -15,44 +15,69 @@ from authentic2.saml.models import LibertyProvider
from authentic2.a2_rbac.models import RoleAttribute
from authentic2.models import AttributeValue
User = get_user_model()
Role = get_role_model()
OU = get_ou_model()
RoleParenting = get_role_parenting_model()
class Provisionning(object):
local = threading.local()
logger = logging.getLogger(__name__)
class Provisionning(threading.local):
__slots__ = ['threads']
threads = set()
def __init__(self):
self.User = get_user_model()
self.Role = get_role_model()
self.OU = get_ou_model()
self.RoleParenting = get_role_parenting_model()
self.logger = logging.getLogger(__name__)
self.stack = []
def start(self):
self.local.saved = {}
self.local.deleted = {}
self.stack.append({
'saved': {},
'deleted': {},
})
def clean(self):
if hasattr(self.local, 'saved'):
del self.local.saved
if hasattr(self.local, 'deleted'):
del self.local.deleted
def clear(self):
self.stack = []
def saved(self, *args):
if not hasattr(self.local, 'saved'):
def stop(self, provision=True, wait=True):
if not self.stack:
return
context = self.stack.pop()
if provision:
self.provision(**context)
if wait:
self.wait()
@property
def saved(self):
if self.stack:
return self.stack[-1]['saved']
return None
@property
def deleted(self):
if self.stack:
return self.stack[-1]['deleted']
return None
def add_saved(self, *args):
if not self.stack:
return
for instance in args:
klass = self.User if isinstance(instance, self.User) else self.Role
self.local.saved.setdefault(klass, set()).add(instance)
klass = User if isinstance(instance, User) else Role
self.saved.setdefault(klass, set()).add(instance)
def deleted(self, *args):
if not hasattr(self.local, 'saved'):
def add_deleted(self, *args):
if not self.stack:
return
for instance in args:
klass = self.User if isinstance(instance, self.User) else self.Role
self.local.deleted.setdefault(klass, set()).add(instance)
self.local.saved.get(klass, set()).discard(instance)
klass = User if isinstance(instance, User) else Role
self.deleted.setdefault(klass, set()).add(instance)
self.saved.get(klass, set()).discard(instance)
def resolve_ou(self, instances, ous):
for instance in instances:
@ -61,7 +86,7 @@ class Provisionning(object):
def notify_users(self, ous, users, mode='provision'):
if mode == 'provision':
users = (self.User.objects.filter(id__in=[u.id for u in users])
users = (User.objects.filter(id__in=[u.id for u in users])
.select_related('ou').prefetch_related('attribute_values__attribute'))
else:
self.resolve_ou(users, ous)
@ -105,19 +130,19 @@ class Provisionning(object):
# Find roles giving a superuser attribute
# If there is any role of this kind, we do one provisionning message for each user and
# each service.
roles_with_attributes = (self.Role.objects.filter(members__in=users)
roles_with_attributes = (Role.objects.filter(members__in=users)
.parents(include_self=True)
.filter(attributes__name='is_superuser')
.exists())
all_roles = (self.Role.objects.filter(members__in=users).parents()
all_roles = (Role.objects.filter(members__in=users).parents()
.prefetch_related('attributes').distinct())
roles = dict((r.id, r) for r in all_roles)
user_roles = {}
parents = {}
for rp in self.RoleParenting.objects.filter(child__in=all_roles):
for rp in RoleParenting.objects.filter(child__in=all_roles):
parents.setdefault(rp.child.id, []).append(rp.parent.id)
Through = self.Role.members.through
Through = Role.members.through
for u_id, r_id in Through.objects.filter(role__members__in=users).values_list('user_id',
'role_id'):
user_roles.setdefault(u_id, set()).add(roles[r_id])
@ -133,7 +158,7 @@ class Provisionning(object):
for ou, users in ous.iteritems():
for service, audience in self.get_audience(ou):
for user in users:
self.logger.info(u'provisionning user %s to %s', user, audience)
logger.info(u'provisionning user %s to %s', user, audience)
notify_agents({
'@type': 'provision',
'issuer': issuer,
@ -149,7 +174,7 @@ class Provisionning(object):
audience = [a for service, a in self.get_audience(ou)]
if not audience:
continue
self.logger.info(u'provisionning users %s to %s',
logger.info(u'provisionning users %s to %s',
u', '.join(map(unicode, users)), u', '.join(audience))
notify_agents({
'@type': 'provision',
@ -162,9 +187,9 @@ class Provisionning(object):
}
})
elif users:
audience = [audience for ou in self.OU.objects.all()
audience = [audience for ou in OU.objects.all()
for s, audience in self.get_audience(ou)]
self.logger.info(u'deprovisionning users %s from %s', u', '.join(map(unicode, users)),
logger.info(u'deprovisionning users %s from %s', u', '.join(map(unicode, users)),
u', '.join(audience))
notify_agents({
'@type': 'deprovision',
@ -213,7 +238,7 @@ class Provisionning(object):
]
audience = [entity_id for service, entity_id in self.get_audience(ou)]
self.logger.info(u'%sning roles %s to %s', mode, roles, audience)
logger.info(u'%sning roles %s to %s', mode, roles, audience)
notify_agents({
'@type': mode,
'audience': audience,
@ -229,33 +254,35 @@ class Provisionning(object):
sent_roles = set(ou_roles) | global_roles
helper(ou, sent_roles)
def provision(self):
def provision(self, saved, deleted):
# Returns if:
# - we are not in a tenant
# - provsionning is disabled
# - there is nothing to do
if (not hasattr(connection, 'tenant') or not connection.tenant or not
hasattr(connection.tenant, 'domain_url')):
return
if not getattr(settings, 'HOBO_ROLE_EXPORT', True):
return
# exit early if not started
if not hasattr(self.local, 'saved') or not hasattr(self.local, 'deleted'):
if not (saved or deleted):
return
t = threading.Thread(target=self.do_provision, kwargs={
'saved': getattr(self.local, 'saved', {}),
'deleted': getattr(self.local, 'deleted', {}),
})
t = threading.Thread(
target=self.do_provision,
kwargs={'saved': saved, 'deleted': deleted})
t.start()
self.threads.add(t)
def do_provision(self, saved, deleted, thread=None):
def do_provision(self, saved, deleted):
try:
ous = {ou.id: ou for ou in self.OU.objects.all()}
self.notify_roles(ous, saved.get(self.Role, []))
self.notify_roles(ous, deleted.get(self.Role, []), mode='deprovision')
self.notify_users(ous, saved.get(self.User, []))
self.notify_users(ous, deleted.get(self.User, []), mode='deprovision')
ous = {ou.id: ou for ou in OU.objects.all()}
self.notify_roles(ous, saved.get(Role, []))
self.notify_roles(ous, deleted.get(Role, []), mode='deprovision')
self.notify_users(ous, saved.get(User, []))
self.notify_users(ous, deleted.get(User, []), mode='deprovision')
except Exception:
# last step, clear everything
self.logger.exception(u'error in provisionning thread')
logger.exception(u'error in provisionning thread')
finally:
self.threads.discard(threading.current_thread())
@ -267,12 +294,9 @@ class Provisionning(object):
self.start()
def __exit__(self, exc_type, exc_value, exc_tb):
if exc_type is None:
self.provision()
self.clean()
self.wait()
else:
self.clean()
if not self.stack:
return
self.stop(provision=exc_type is None)
def get_audience(self, ou):
if ou:
@ -298,64 +322,72 @@ class Provisionning(object):
return urljoin(base_url, reverse('a2-idp-saml-metadata'))
def pre_save(self, sender, instance, raw, using, update_fields, **kwargs):
if not self.stack:
return
# we skip new instances
if not instance.pk:
return
if not isinstance(instance, (self.User, self.Role, RoleAttribute, AttributeValue)):
if not isinstance(instance, (User, Role, RoleAttribute, AttributeValue)):
return
# ignore last_login update on login
if isinstance(instance, self.User) and update_fields == ['last_login']:
if isinstance(instance, User) and (update_fields and set(update_fields) == set(['last_login'])):
return
if isinstance(instance, RoleAttribute):
instance = instance.role
elif isinstance(instance, AttributeValue):
if not isinstance(instance.owner, self.User):
if not isinstance(instance.owner, User):
return
instance = instance.owner
self.saved(instance)
self.add_saved(instance)
def post_save(self, sender, instance, created, raw, using, update_fields, **kwargs):
if not self.stack:
return
# during post_save we only handle new instances
if isinstance(instance, self.RoleParenting):
self.saved(*list(instance.child.all_members()))
if isinstance(instance, RoleParenting):
self.add_saved(*list(instance.child.all_members()))
return
if not created:
return
if not isinstance(instance, (self.User, self.Role, RoleAttribute, AttributeValue)):
if not isinstance(instance, (User, Role, RoleAttribute, AttributeValue)):
return
if isinstance(instance, RoleAttribute):
instance = instance.role
elif isinstance(instance, AttributeValue):
if not isinstance(instance.owner, self.User):
if not isinstance(instance.owner, User):
return
instance = instance.owner
self.saved(instance)
self.add_saved(instance)
def pre_delete(self, sender, instance, using, **kwargs):
if isinstance(instance, (self.User, self.Role)):
self.deleted(copy.copy(instance))
if not self.stack:
return
if isinstance(instance, (User, Role)):
self.add_deleted(copy.copy(instance))
elif isinstance(instance, RoleAttribute):
instance = instance.role
self.saved(instance)
self.add_saved(instance)
elif isinstance(instance, AttributeValue):
if not isinstance(instance.owner, self.User):
if not isinstance(instance.owner, User):
return
instance = instance.owner
self.saved(instance)
elif isinstance(instance, self.RoleParenting):
self.saved(*list(instance.child.all_members()))
self.add_saved(instance)
elif isinstance(instance, RoleParenting):
self.add_saved(*list(instance.child.all_members()))
def m2m_changed(self, sender, instance, action, reverse, model, pk_set, using, **kwargs):
if not self.stack:
return
if action != 'pre_clear' and action.startswith('pre_'):
return
if sender is self.Role.members.through:
self.saved(instance)
if sender is Role.members.through:
self.add_saved(instance)
# on a clear, pk_set is None
for other_instance in model.objects.filter(pk__in=pk_set or []):
self.saved(other_instance)
self.add_saved(other_instance)
if action == 'pre_clear':
# when the action is pre_clear we need to lookup the current value of the members
# relation, to re-provision all previously enroled users.
if not reverse:
for other_instance in instance.members.all():
self.saved(other_instance)
self.add_saved(other_instance)

View File

@ -1,30 +1,31 @@
import os
import tempfile
import shutil
import json
from django_webtest import WebTestMixin, DjangoTestApp
import pytest
from django.db import connection, transaction
from tenant_schemas.postgresql_backend.base import FakeTenant
from tenant_schemas.utils import tenant_context
from hobo.multitenant.models import Tenant
@pytest.fixture
def tenant_base(request, settings):
base = tempfile.mkdtemp('authentic-tenant-base')
def tenant_base(tmpdir, settings):
base = str(tmpdir.mkdir('authentic-tenant-base'))
settings.TENANT_BASE = base
def fin():
shutil.rmtree(base)
request.addfinalizer(fin)
return base
@pytest.fixture(scope='function')
def tenant(transactional_db, request, settings, tenant_base):
from hobo.multitenant.models import Tenant
base = tenant_base
@pytest.fixture
def tenant_factory(transactional_db, tenant_base, settings):
tenants = []
@pytest.mark.django_db
def make_tenant(name):
tenant_dir = os.path.join(base, name)
settings.ALLOWED_HOSTS = ['*']
def factory(name):
tenant_dir = os.path.join(tenant_base, name)
os.mkdir(tenant_dir)
with open(os.path.join(tenant_dir, 'unsecure'), 'w') as fd:
fd.write('1')
@ -37,32 +38,63 @@ def tenant(transactional_db, request, settings, tenant_base):
'other_variable': 'foo',
},
'services': [
{'slug': 'test',
'service-id': 'authentic',
'title': 'Test',
'this': True,
'secret_key': '12345',
'base_url': 'http://%s' % name,
'variables': {
'other_variable': 'bar',
}
{
'slug': 'test',
'service-id': 'authentic',
'title': 'Test',
'this': True,
'secret_key': '12345',
'base_url': 'http://%s' % name,
'variables': {
'other_variable': 'bar',
}
},
{'slug': 'other',
'title': 'Other',
'service-id': 'welco',
'secret_key': 'abcdef',
'base_url': 'http://other.example.net'},
]}, fd)
t = Tenant(domain_url=name,
schema_name=name.replace('-', '_').replace('.', '_'))
t.create_schema()
{
'slug': 'other',
'title': 'Other',
'service-id': 'welco',
'secret_key': 'abcdef',
'base_url': 'http://other.example.net'
},
]
}, fd)
schema_name = name.replace('-', '_').replace('.', '_')
t = Tenant(domain_url=name, schema_name=schema_name)
with transaction.atomic():
t.create_schema()
tenants.append(t)
return t
tenants = [make_tenant('authentic.example.net')]
def fin():
from django.db import connection
try:
yield factory
finally:
# cleanup all created tenants
connection.set_schema_to_public()
for t in tenants:
t.delete(True)
request.addfinalizer(fin)
return tenants[0]
with tenant_context(FakeTenant('public')):
for tenant in tenants:
tenant.delete(force_drop=True)
@pytest.fixture
def tenant(tenant_factory):
return tenant_factory('authentic.example.net')
@pytest.fixture
def app_factory(request):
wtm = WebTestMixin()
wtm._patch_settings()
def factory(hostname='testserver'):
if hasattr(hostname, 'domain_url'):
hostname = hostname.domain_url
return DjangoTestApp(extra_environ={'HTTP_HOST': hostname})
try:
yield factory
finally:
wtm._unpatch_settings()
@pytest.fixture
def notify_agents(mocker):
yield mocker.patch('hobo.agent.authentic2.provisionning.notify_agents')

View File

@ -39,3 +39,6 @@ CACHES = {
}
HOBO_ROLE_EXPORT = True
SESSION_COOKIE_SECURE = False
CSRF_COOKIE_SECURE = False

View File

@ -17,6 +17,8 @@ from authentic2.a2_rbac.utils import get_default_ou
from authentic2.models import Attribute, AttributeValue
from hobo.agent.authentic2.provisionning import provisionning
User = get_user_model()
pytestmark = pytest.mark.django_db
@ -126,7 +128,6 @@ def test_provision_user(transactional_db, tenant, caplog):
role2.attributes.create(kind='json', name='emails', value='["zob@example.net"]')
child_role = Role.objects.create(name='child', ou=get_default_ou())
notify_agents.reset_mock()
User = get_user_model()
attribute = Attribute.objects.create(label='Code postal', name='code_postal',
kind='string')
with provisionning:
@ -277,7 +278,6 @@ def test_provision_user(transactional_db, tenant, caplog):
data = objects['data']
assert isinstance(data, list)
assert len(data) == 1
print data
for o in data:
assert set(o.keys()) >= set(['uuid', 'username', 'first_name',
'is_superuser', 'last_name', 'email', 'roles'])
@ -447,17 +447,15 @@ def test_provision_createsuperuser(transactional_db, tenant, caplog):
with tenant_context(tenant):
# create a provider so notification messages have an audience.
LibertyProvider.objects.create(ou=None, name='provider',
entity_id='http://provider.com',
protocol_conformance=lasso.PROTOCOL_SAML_2_0)
entity_id='http://provider.com',
protocol_conformance=lasso.PROTOCOL_SAML_2_0)
with patch('hobo.agent.authentic2.provisionning.notify_agents') as notify_agents:
call_command('createsuperuser', domain=tenant.domain_url, uuid='coin',
username='coin', email='coin@coin.org', interactive=False)
assert notify_agents.call_count == 1
@patch('hobo.agent.authentic2.provisionning.notify_agents')
def test_command_hobo_provision(notify_agents, transactional_db, tenant, caplog):
User = get_user_model()
with tenant_context(tenant):
ou = get_default_ou()
LibertyProvider.objects.create(ou=ou, name='provider',
@ -486,3 +484,31 @@ def test_command_hobo_provision(notify_agents, transactional_db, tenant, caplog)
assert msg_2['full'] is False
assert msg_2['objects']['@type'] == 'user'
assert len(msg_2['objects']['data']) == 10
def test_middleware(notify_agents, app_factory, tenant, settings):
settings.HOBO_PROVISIONNING_SYNCHRONOUS = True
with tenant_context(tenant):
user = User.objects.create(username='john', ou=get_default_ou())
user.set_password('password')
user.save()
LibertyProvider.objects.create(ou=get_default_ou(),
name='provider',
entity_id='http://provider.com',
protocol_conformance=lasso.PROTOCOL_SAML_2_0)
assert notify_agents.call_count == 0
app = app_factory(tenant)
resp = app.get('/login/')
form = resp.form
form.set('username', 'john')
form.set('password', 'password')
resp = form.submit(name='login-password-submit').follow()
resp = resp.click('Your account')
resp = resp.click('Edit')
resp.form.set('edit-profile-first_name', 'John')
resp.form.set('edit-profile-last_name', 'Doe')
assert notify_agents.call_count == 0
resp = resp.form.submit().follow()
assert notify_agents.call_count == 1

View File

@ -39,7 +39,7 @@ deps:
cssselect
WebTest
django-mellon
django-webtest<1.9.3
django-webtest
celery<4
Markdown<3
django-tables2<2.0