agent/a2: prevent useless thread launching (#34484)
This commit is contained in:
parent
b5bebd3e43
commit
e7abfc8ea7
|
@ -1,3 +1,5 @@
|
|||
from django.conf import settings
|
||||
|
||||
from .provisionning import provisionning
|
||||
|
||||
|
||||
|
@ -6,9 +8,9 @@ class ProvisionningMiddleware(object):
|
|||
provisionning.start()
|
||||
|
||||
def process_exception(self, request, exception):
|
||||
provisionning.clean()
|
||||
provisionning.clear()
|
||||
|
||||
def process_response(self, request, response):
|
||||
provisionning.provision()
|
||||
provisionning.stop(provision=True, wait=getattr(settings, 'HOBO_PROVISIONNING_SYNCHRONOUS', False))
|
||||
provisionning.clear()
|
||||
return response
|
||||
|
||||
|
|
|
@ -15,44 +15,69 @@ from authentic2.saml.models import LibertyProvider
|
|||
from authentic2.a2_rbac.models import RoleAttribute
|
||||
from authentic2.models import AttributeValue
|
||||
|
||||
User = get_user_model()
|
||||
Role = get_role_model()
|
||||
OU = get_ou_model()
|
||||
RoleParenting = get_role_parenting_model()
|
||||
|
||||
class Provisionning(object):
|
||||
local = threading.local()
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Provisionning(threading.local):
|
||||
__slots__ = ['threads']
|
||||
threads = set()
|
||||
|
||||
def __init__(self):
|
||||
self.User = get_user_model()
|
||||
self.Role = get_role_model()
|
||||
self.OU = get_ou_model()
|
||||
self.RoleParenting = get_role_parenting_model()
|
||||
self.logger = logging.getLogger(__name__)
|
||||
self.stack = []
|
||||
|
||||
def start(self):
|
||||
self.local.saved = {}
|
||||
self.local.deleted = {}
|
||||
self.stack.append({
|
||||
'saved': {},
|
||||
'deleted': {},
|
||||
})
|
||||
|
||||
def clean(self):
|
||||
if hasattr(self.local, 'saved'):
|
||||
del self.local.saved
|
||||
if hasattr(self.local, 'deleted'):
|
||||
del self.local.deleted
|
||||
def clear(self):
|
||||
self.stack = []
|
||||
|
||||
def saved(self, *args):
|
||||
if not hasattr(self.local, 'saved'):
|
||||
def stop(self, provision=True, wait=True):
|
||||
if not self.stack:
|
||||
return
|
||||
|
||||
context = self.stack.pop()
|
||||
|
||||
if provision:
|
||||
self.provision(**context)
|
||||
if wait:
|
||||
self.wait()
|
||||
|
||||
@property
|
||||
def saved(self):
|
||||
if self.stack:
|
||||
return self.stack[-1]['saved']
|
||||
return None
|
||||
|
||||
@property
|
||||
def deleted(self):
|
||||
if self.stack:
|
||||
return self.stack[-1]['deleted']
|
||||
return None
|
||||
|
||||
def add_saved(self, *args):
|
||||
if not self.stack:
|
||||
return
|
||||
|
||||
for instance in args:
|
||||
klass = self.User if isinstance(instance, self.User) else self.Role
|
||||
self.local.saved.setdefault(klass, set()).add(instance)
|
||||
klass = User if isinstance(instance, User) else Role
|
||||
self.saved.setdefault(klass, set()).add(instance)
|
||||
|
||||
def deleted(self, *args):
|
||||
if not hasattr(self.local, 'saved'):
|
||||
def add_deleted(self, *args):
|
||||
if not self.stack:
|
||||
return
|
||||
|
||||
for instance in args:
|
||||
klass = self.User if isinstance(instance, self.User) else self.Role
|
||||
self.local.deleted.setdefault(klass, set()).add(instance)
|
||||
self.local.saved.get(klass, set()).discard(instance)
|
||||
klass = User if isinstance(instance, User) else Role
|
||||
self.deleted.setdefault(klass, set()).add(instance)
|
||||
self.saved.get(klass, set()).discard(instance)
|
||||
|
||||
def resolve_ou(self, instances, ous):
|
||||
for instance in instances:
|
||||
|
@ -61,7 +86,7 @@ class Provisionning(object):
|
|||
|
||||
def notify_users(self, ous, users, mode='provision'):
|
||||
if mode == 'provision':
|
||||
users = (self.User.objects.filter(id__in=[u.id for u in users])
|
||||
users = (User.objects.filter(id__in=[u.id for u in users])
|
||||
.select_related('ou').prefetch_related('attribute_values__attribute'))
|
||||
else:
|
||||
self.resolve_ou(users, ous)
|
||||
|
@ -105,19 +130,19 @@ class Provisionning(object):
|
|||
# Find roles giving a superuser attribute
|
||||
# If there is any role of this kind, we do one provisionning message for each user and
|
||||
# each service.
|
||||
roles_with_attributes = (self.Role.objects.filter(members__in=users)
|
||||
roles_with_attributes = (Role.objects.filter(members__in=users)
|
||||
.parents(include_self=True)
|
||||
.filter(attributes__name='is_superuser')
|
||||
.exists())
|
||||
|
||||
all_roles = (self.Role.objects.filter(members__in=users).parents()
|
||||
all_roles = (Role.objects.filter(members__in=users).parents()
|
||||
.prefetch_related('attributes').distinct())
|
||||
roles = dict((r.id, r) for r in all_roles)
|
||||
user_roles = {}
|
||||
parents = {}
|
||||
for rp in self.RoleParenting.objects.filter(child__in=all_roles):
|
||||
for rp in RoleParenting.objects.filter(child__in=all_roles):
|
||||
parents.setdefault(rp.child.id, []).append(rp.parent.id)
|
||||
Through = self.Role.members.through
|
||||
Through = Role.members.through
|
||||
for u_id, r_id in Through.objects.filter(role__members__in=users).values_list('user_id',
|
||||
'role_id'):
|
||||
user_roles.setdefault(u_id, set()).add(roles[r_id])
|
||||
|
@ -133,7 +158,7 @@ class Provisionning(object):
|
|||
for ou, users in ous.iteritems():
|
||||
for service, audience in self.get_audience(ou):
|
||||
for user in users:
|
||||
self.logger.info(u'provisionning user %s to %s', user, audience)
|
||||
logger.info(u'provisionning user %s to %s', user, audience)
|
||||
notify_agents({
|
||||
'@type': 'provision',
|
||||
'issuer': issuer,
|
||||
|
@ -149,7 +174,7 @@ class Provisionning(object):
|
|||
audience = [a for service, a in self.get_audience(ou)]
|
||||
if not audience:
|
||||
continue
|
||||
self.logger.info(u'provisionning users %s to %s',
|
||||
logger.info(u'provisionning users %s to %s',
|
||||
u', '.join(map(unicode, users)), u', '.join(audience))
|
||||
notify_agents({
|
||||
'@type': 'provision',
|
||||
|
@ -162,9 +187,9 @@ class Provisionning(object):
|
|||
}
|
||||
})
|
||||
elif users:
|
||||
audience = [audience for ou in self.OU.objects.all()
|
||||
audience = [audience for ou in OU.objects.all()
|
||||
for s, audience in self.get_audience(ou)]
|
||||
self.logger.info(u'deprovisionning users %s from %s', u', '.join(map(unicode, users)),
|
||||
logger.info(u'deprovisionning users %s from %s', u', '.join(map(unicode, users)),
|
||||
u', '.join(audience))
|
||||
notify_agents({
|
||||
'@type': 'deprovision',
|
||||
|
@ -213,7 +238,7 @@ class Provisionning(object):
|
|||
]
|
||||
|
||||
audience = [entity_id for service, entity_id in self.get_audience(ou)]
|
||||
self.logger.info(u'%sning roles %s to %s', mode, roles, audience)
|
||||
logger.info(u'%sning roles %s to %s', mode, roles, audience)
|
||||
notify_agents({
|
||||
'@type': mode,
|
||||
'audience': audience,
|
||||
|
@ -229,33 +254,35 @@ class Provisionning(object):
|
|||
sent_roles = set(ou_roles) | global_roles
|
||||
helper(ou, sent_roles)
|
||||
|
||||
def provision(self):
|
||||
def provision(self, saved, deleted):
|
||||
# Returns if:
|
||||
# - we are not in a tenant
|
||||
# - provsionning is disabled
|
||||
# - there is nothing to do
|
||||
if (not hasattr(connection, 'tenant') or not connection.tenant or not
|
||||
hasattr(connection.tenant, 'domain_url')):
|
||||
return
|
||||
if not getattr(settings, 'HOBO_ROLE_EXPORT', True):
|
||||
return
|
||||
# exit early if not started
|
||||
if not hasattr(self.local, 'saved') or not hasattr(self.local, 'deleted'):
|
||||
if not (saved or deleted):
|
||||
return
|
||||
|
||||
t = threading.Thread(target=self.do_provision, kwargs={
|
||||
'saved': getattr(self.local, 'saved', {}),
|
||||
'deleted': getattr(self.local, 'deleted', {}),
|
||||
})
|
||||
t = threading.Thread(
|
||||
target=self.do_provision,
|
||||
kwargs={'saved': saved, 'deleted': deleted})
|
||||
t.start()
|
||||
self.threads.add(t)
|
||||
|
||||
def do_provision(self, saved, deleted, thread=None):
|
||||
def do_provision(self, saved, deleted):
|
||||
try:
|
||||
ous = {ou.id: ou for ou in self.OU.objects.all()}
|
||||
self.notify_roles(ous, saved.get(self.Role, []))
|
||||
self.notify_roles(ous, deleted.get(self.Role, []), mode='deprovision')
|
||||
self.notify_users(ous, saved.get(self.User, []))
|
||||
self.notify_users(ous, deleted.get(self.User, []), mode='deprovision')
|
||||
ous = {ou.id: ou for ou in OU.objects.all()}
|
||||
self.notify_roles(ous, saved.get(Role, []))
|
||||
self.notify_roles(ous, deleted.get(Role, []), mode='deprovision')
|
||||
self.notify_users(ous, saved.get(User, []))
|
||||
self.notify_users(ous, deleted.get(User, []), mode='deprovision')
|
||||
except Exception:
|
||||
# last step, clear everything
|
||||
self.logger.exception(u'error in provisionning thread')
|
||||
logger.exception(u'error in provisionning thread')
|
||||
finally:
|
||||
self.threads.discard(threading.current_thread())
|
||||
|
||||
|
@ -267,12 +294,9 @@ class Provisionning(object):
|
|||
self.start()
|
||||
|
||||
def __exit__(self, exc_type, exc_value, exc_tb):
|
||||
if exc_type is None:
|
||||
self.provision()
|
||||
self.clean()
|
||||
self.wait()
|
||||
else:
|
||||
self.clean()
|
||||
if not self.stack:
|
||||
return
|
||||
self.stop(provision=exc_type is None)
|
||||
|
||||
def get_audience(self, ou):
|
||||
if ou:
|
||||
|
@ -298,64 +322,72 @@ class Provisionning(object):
|
|||
return urljoin(base_url, reverse('a2-idp-saml-metadata'))
|
||||
|
||||
def pre_save(self, sender, instance, raw, using, update_fields, **kwargs):
|
||||
if not self.stack:
|
||||
return
|
||||
# we skip new instances
|
||||
if not instance.pk:
|
||||
return
|
||||
if not isinstance(instance, (self.User, self.Role, RoleAttribute, AttributeValue)):
|
||||
if not isinstance(instance, (User, Role, RoleAttribute, AttributeValue)):
|
||||
return
|
||||
# ignore last_login update on login
|
||||
if isinstance(instance, self.User) and update_fields == ['last_login']:
|
||||
if isinstance(instance, User) and (update_fields and set(update_fields) == set(['last_login'])):
|
||||
return
|
||||
if isinstance(instance, RoleAttribute):
|
||||
instance = instance.role
|
||||
elif isinstance(instance, AttributeValue):
|
||||
if not isinstance(instance.owner, self.User):
|
||||
if not isinstance(instance.owner, User):
|
||||
return
|
||||
instance = instance.owner
|
||||
self.saved(instance)
|
||||
self.add_saved(instance)
|
||||
|
||||
def post_save(self, sender, instance, created, raw, using, update_fields, **kwargs):
|
||||
if not self.stack:
|
||||
return
|
||||
# during post_save we only handle new instances
|
||||
if isinstance(instance, self.RoleParenting):
|
||||
self.saved(*list(instance.child.all_members()))
|
||||
if isinstance(instance, RoleParenting):
|
||||
self.add_saved(*list(instance.child.all_members()))
|
||||
return
|
||||
if not created:
|
||||
return
|
||||
if not isinstance(instance, (self.User, self.Role, RoleAttribute, AttributeValue)):
|
||||
if not isinstance(instance, (User, Role, RoleAttribute, AttributeValue)):
|
||||
return
|
||||
if isinstance(instance, RoleAttribute):
|
||||
instance = instance.role
|
||||
elif isinstance(instance, AttributeValue):
|
||||
if not isinstance(instance.owner, self.User):
|
||||
if not isinstance(instance.owner, User):
|
||||
return
|
||||
instance = instance.owner
|
||||
self.saved(instance)
|
||||
self.add_saved(instance)
|
||||
|
||||
def pre_delete(self, sender, instance, using, **kwargs):
|
||||
if isinstance(instance, (self.User, self.Role)):
|
||||
self.deleted(copy.copy(instance))
|
||||
if not self.stack:
|
||||
return
|
||||
if isinstance(instance, (User, Role)):
|
||||
self.add_deleted(copy.copy(instance))
|
||||
elif isinstance(instance, RoleAttribute):
|
||||
instance = instance.role
|
||||
self.saved(instance)
|
||||
self.add_saved(instance)
|
||||
elif isinstance(instance, AttributeValue):
|
||||
if not isinstance(instance.owner, self.User):
|
||||
if not isinstance(instance.owner, User):
|
||||
return
|
||||
instance = instance.owner
|
||||
self.saved(instance)
|
||||
elif isinstance(instance, self.RoleParenting):
|
||||
self.saved(*list(instance.child.all_members()))
|
||||
self.add_saved(instance)
|
||||
elif isinstance(instance, RoleParenting):
|
||||
self.add_saved(*list(instance.child.all_members()))
|
||||
|
||||
def m2m_changed(self, sender, instance, action, reverse, model, pk_set, using, **kwargs):
|
||||
if not self.stack:
|
||||
return
|
||||
if action != 'pre_clear' and action.startswith('pre_'):
|
||||
return
|
||||
if sender is self.Role.members.through:
|
||||
self.saved(instance)
|
||||
if sender is Role.members.through:
|
||||
self.add_saved(instance)
|
||||
# on a clear, pk_set is None
|
||||
for other_instance in model.objects.filter(pk__in=pk_set or []):
|
||||
self.saved(other_instance)
|
||||
self.add_saved(other_instance)
|
||||
if action == 'pre_clear':
|
||||
# when the action is pre_clear we need to lookup the current value of the members
|
||||
# relation, to re-provision all previously enroled users.
|
||||
if not reverse:
|
||||
for other_instance in instance.members.all():
|
||||
self.saved(other_instance)
|
||||
self.add_saved(other_instance)
|
||||
|
|
|
@ -1,30 +1,31 @@
|
|||
import os
|
||||
import tempfile
|
||||
import shutil
|
||||
import json
|
||||
|
||||
from django_webtest import WebTestMixin, DjangoTestApp
|
||||
import pytest
|
||||
|
||||
from django.db import connection, transaction
|
||||
from tenant_schemas.postgresql_backend.base import FakeTenant
|
||||
from tenant_schemas.utils import tenant_context
|
||||
|
||||
from hobo.multitenant.models import Tenant
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def tenant_base(request, settings):
|
||||
base = tempfile.mkdtemp('authentic-tenant-base')
|
||||
def tenant_base(tmpdir, settings):
|
||||
base = str(tmpdir.mkdir('authentic-tenant-base'))
|
||||
settings.TENANT_BASE = base
|
||||
|
||||
def fin():
|
||||
shutil.rmtree(base)
|
||||
request.addfinalizer(fin)
|
||||
return base
|
||||
|
||||
|
||||
@pytest.fixture(scope='function')
|
||||
def tenant(transactional_db, request, settings, tenant_base):
|
||||
from hobo.multitenant.models import Tenant
|
||||
base = tenant_base
|
||||
@pytest.fixture
|
||||
def tenant_factory(transactional_db, tenant_base, settings):
|
||||
tenants = []
|
||||
|
||||
@pytest.mark.django_db
|
||||
def make_tenant(name):
|
||||
tenant_dir = os.path.join(base, name)
|
||||
settings.ALLOWED_HOSTS = ['*']
|
||||
|
||||
def factory(name):
|
||||
tenant_dir = os.path.join(tenant_base, name)
|
||||
os.mkdir(tenant_dir)
|
||||
with open(os.path.join(tenant_dir, 'unsecure'), 'w') as fd:
|
||||
fd.write('1')
|
||||
|
@ -37,32 +38,63 @@ def tenant(transactional_db, request, settings, tenant_base):
|
|||
'other_variable': 'foo',
|
||||
},
|
||||
'services': [
|
||||
{'slug': 'test',
|
||||
'service-id': 'authentic',
|
||||
'title': 'Test',
|
||||
'this': True,
|
||||
'secret_key': '12345',
|
||||
'base_url': 'http://%s' % name,
|
||||
'variables': {
|
||||
'other_variable': 'bar',
|
||||
}
|
||||
{
|
||||
'slug': 'test',
|
||||
'service-id': 'authentic',
|
||||
'title': 'Test',
|
||||
'this': True,
|
||||
'secret_key': '12345',
|
||||
'base_url': 'http://%s' % name,
|
||||
'variables': {
|
||||
'other_variable': 'bar',
|
||||
}
|
||||
},
|
||||
{'slug': 'other',
|
||||
'title': 'Other',
|
||||
'service-id': 'welco',
|
||||
'secret_key': 'abcdef',
|
||||
'base_url': 'http://other.example.net'},
|
||||
]}, fd)
|
||||
t = Tenant(domain_url=name,
|
||||
schema_name=name.replace('-', '_').replace('.', '_'))
|
||||
t.create_schema()
|
||||
{
|
||||
'slug': 'other',
|
||||
'title': 'Other',
|
||||
'service-id': 'welco',
|
||||
'secret_key': 'abcdef',
|
||||
'base_url': 'http://other.example.net'
|
||||
},
|
||||
]
|
||||
}, fd)
|
||||
schema_name = name.replace('-', '_').replace('.', '_')
|
||||
t = Tenant(domain_url=name, schema_name=schema_name)
|
||||
with transaction.atomic():
|
||||
t.create_schema()
|
||||
tenants.append(t)
|
||||
return t
|
||||
tenants = [make_tenant('authentic.example.net')]
|
||||
|
||||
def fin():
|
||||
from django.db import connection
|
||||
try:
|
||||
yield factory
|
||||
finally:
|
||||
# cleanup all created tenants
|
||||
connection.set_schema_to_public()
|
||||
for t in tenants:
|
||||
t.delete(True)
|
||||
request.addfinalizer(fin)
|
||||
return tenants[0]
|
||||
with tenant_context(FakeTenant('public')):
|
||||
for tenant in tenants:
|
||||
tenant.delete(force_drop=True)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def tenant(tenant_factory):
|
||||
return tenant_factory('authentic.example.net')
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def app_factory(request):
|
||||
wtm = WebTestMixin()
|
||||
wtm._patch_settings()
|
||||
|
||||
def factory(hostname='testserver'):
|
||||
if hasattr(hostname, 'domain_url'):
|
||||
hostname = hostname.domain_url
|
||||
return DjangoTestApp(extra_environ={'HTTP_HOST': hostname})
|
||||
|
||||
try:
|
||||
yield factory
|
||||
finally:
|
||||
wtm._unpatch_settings()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def notify_agents(mocker):
|
||||
yield mocker.patch('hobo.agent.authentic2.provisionning.notify_agents')
|
||||
|
|
|
@ -39,3 +39,6 @@ CACHES = {
|
|||
}
|
||||
|
||||
HOBO_ROLE_EXPORT = True
|
||||
|
||||
SESSION_COOKIE_SECURE = False
|
||||
CSRF_COOKIE_SECURE = False
|
||||
|
|
|
@ -17,6 +17,8 @@ from authentic2.a2_rbac.utils import get_default_ou
|
|||
from authentic2.models import Attribute, AttributeValue
|
||||
from hobo.agent.authentic2.provisionning import provisionning
|
||||
|
||||
User = get_user_model()
|
||||
|
||||
pytestmark = pytest.mark.django_db
|
||||
|
||||
|
||||
|
@ -126,7 +128,6 @@ def test_provision_user(transactional_db, tenant, caplog):
|
|||
role2.attributes.create(kind='json', name='emails', value='["zob@example.net"]')
|
||||
child_role = Role.objects.create(name='child', ou=get_default_ou())
|
||||
notify_agents.reset_mock()
|
||||
User = get_user_model()
|
||||
attribute = Attribute.objects.create(label='Code postal', name='code_postal',
|
||||
kind='string')
|
||||
with provisionning:
|
||||
|
@ -277,7 +278,6 @@ def test_provision_user(transactional_db, tenant, caplog):
|
|||
data = objects['data']
|
||||
assert isinstance(data, list)
|
||||
assert len(data) == 1
|
||||
print data
|
||||
for o in data:
|
||||
assert set(o.keys()) >= set(['uuid', 'username', 'first_name',
|
||||
'is_superuser', 'last_name', 'email', 'roles'])
|
||||
|
@ -447,17 +447,15 @@ def test_provision_createsuperuser(transactional_db, tenant, caplog):
|
|||
with tenant_context(tenant):
|
||||
# create a provider so notification messages have an audience.
|
||||
LibertyProvider.objects.create(ou=None, name='provider',
|
||||
entity_id='http://provider.com',
|
||||
protocol_conformance=lasso.PROTOCOL_SAML_2_0)
|
||||
entity_id='http://provider.com',
|
||||
protocol_conformance=lasso.PROTOCOL_SAML_2_0)
|
||||
with patch('hobo.agent.authentic2.provisionning.notify_agents') as notify_agents:
|
||||
call_command('createsuperuser', domain=tenant.domain_url, uuid='coin',
|
||||
username='coin', email='coin@coin.org', interactive=False)
|
||||
assert notify_agents.call_count == 1
|
||||
|
||||
|
||||
@patch('hobo.agent.authentic2.provisionning.notify_agents')
|
||||
def test_command_hobo_provision(notify_agents, transactional_db, tenant, caplog):
|
||||
User = get_user_model()
|
||||
with tenant_context(tenant):
|
||||
ou = get_default_ou()
|
||||
LibertyProvider.objects.create(ou=ou, name='provider',
|
||||
|
@ -486,3 +484,31 @@ def test_command_hobo_provision(notify_agents, transactional_db, tenant, caplog)
|
|||
assert msg_2['full'] is False
|
||||
assert msg_2['objects']['@type'] == 'user'
|
||||
assert len(msg_2['objects']['data']) == 10
|
||||
|
||||
|
||||
def test_middleware(notify_agents, app_factory, tenant, settings):
|
||||
settings.HOBO_PROVISIONNING_SYNCHRONOUS = True
|
||||
|
||||
with tenant_context(tenant):
|
||||
user = User.objects.create(username='john', ou=get_default_ou())
|
||||
user.set_password('password')
|
||||
user.save()
|
||||
LibertyProvider.objects.create(ou=get_default_ou(),
|
||||
name='provider',
|
||||
entity_id='http://provider.com',
|
||||
protocol_conformance=lasso.PROTOCOL_SAML_2_0)
|
||||
assert notify_agents.call_count == 0
|
||||
|
||||
app = app_factory(tenant)
|
||||
resp = app.get('/login/')
|
||||
form = resp.form
|
||||
form.set('username', 'john')
|
||||
form.set('password', 'password')
|
||||
resp = form.submit(name='login-password-submit').follow()
|
||||
resp = resp.click('Your account')
|
||||
resp = resp.click('Edit')
|
||||
resp.form.set('edit-profile-first_name', 'John')
|
||||
resp.form.set('edit-profile-last_name', 'Doe')
|
||||
assert notify_agents.call_count == 0
|
||||
resp = resp.form.submit().follow()
|
||||
assert notify_agents.call_count == 1
|
||||
|
|
Loading…
Reference in New Issue