agent/a2: prevent useless thread launching (#34484)
This commit is contained in:
parent
b5bebd3e43
commit
e7abfc8ea7
|
@ -1,3 +1,5 @@
|
||||||
|
from django.conf import settings
|
||||||
|
|
||||||
from .provisionning import provisionning
|
from .provisionning import provisionning
|
||||||
|
|
||||||
|
|
||||||
|
@ -6,9 +8,9 @@ class ProvisionningMiddleware(object):
|
||||||
provisionning.start()
|
provisionning.start()
|
||||||
|
|
||||||
def process_exception(self, request, exception):
|
def process_exception(self, request, exception):
|
||||||
provisionning.clean()
|
provisionning.clear()
|
||||||
|
|
||||||
def process_response(self, request, response):
|
def process_response(self, request, response):
|
||||||
provisionning.provision()
|
provisionning.stop(provision=True, wait=getattr(settings, 'HOBO_PROVISIONNING_SYNCHRONOUS', False))
|
||||||
|
provisionning.clear()
|
||||||
return response
|
return response
|
||||||
|
|
||||||
|
|
|
@ -15,44 +15,69 @@ from authentic2.saml.models import LibertyProvider
|
||||||
from authentic2.a2_rbac.models import RoleAttribute
|
from authentic2.a2_rbac.models import RoleAttribute
|
||||||
from authentic2.models import AttributeValue
|
from authentic2.models import AttributeValue
|
||||||
|
|
||||||
|
User = get_user_model()
|
||||||
|
Role = get_role_model()
|
||||||
|
OU = get_ou_model()
|
||||||
|
RoleParenting = get_role_parenting_model()
|
||||||
|
|
||||||
class Provisionning(object):
|
logger = logging.getLogger(__name__)
|
||||||
local = threading.local()
|
|
||||||
|
|
||||||
|
class Provisionning(threading.local):
|
||||||
|
__slots__ = ['threads']
|
||||||
threads = set()
|
threads = set()
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.User = get_user_model()
|
self.stack = []
|
||||||
self.Role = get_role_model()
|
|
||||||
self.OU = get_ou_model()
|
|
||||||
self.RoleParenting = get_role_parenting_model()
|
|
||||||
self.logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
def start(self):
|
def start(self):
|
||||||
self.local.saved = {}
|
self.stack.append({
|
||||||
self.local.deleted = {}
|
'saved': {},
|
||||||
|
'deleted': {},
|
||||||
|
})
|
||||||
|
|
||||||
def clean(self):
|
def clear(self):
|
||||||
if hasattr(self.local, 'saved'):
|
self.stack = []
|
||||||
del self.local.saved
|
|
||||||
if hasattr(self.local, 'deleted'):
|
|
||||||
del self.local.deleted
|
|
||||||
|
|
||||||
def saved(self, *args):
|
def stop(self, provision=True, wait=True):
|
||||||
if not hasattr(self.local, 'saved'):
|
if not self.stack:
|
||||||
|
return
|
||||||
|
|
||||||
|
context = self.stack.pop()
|
||||||
|
|
||||||
|
if provision:
|
||||||
|
self.provision(**context)
|
||||||
|
if wait:
|
||||||
|
self.wait()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def saved(self):
|
||||||
|
if self.stack:
|
||||||
|
return self.stack[-1]['saved']
|
||||||
|
return None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def deleted(self):
|
||||||
|
if self.stack:
|
||||||
|
return self.stack[-1]['deleted']
|
||||||
|
return None
|
||||||
|
|
||||||
|
def add_saved(self, *args):
|
||||||
|
if not self.stack:
|
||||||
return
|
return
|
||||||
|
|
||||||
for instance in args:
|
for instance in args:
|
||||||
klass = self.User if isinstance(instance, self.User) else self.Role
|
klass = User if isinstance(instance, User) else Role
|
||||||
self.local.saved.setdefault(klass, set()).add(instance)
|
self.saved.setdefault(klass, set()).add(instance)
|
||||||
|
|
||||||
def deleted(self, *args):
|
def add_deleted(self, *args):
|
||||||
if not hasattr(self.local, 'saved'):
|
if not self.stack:
|
||||||
return
|
return
|
||||||
|
|
||||||
for instance in args:
|
for instance in args:
|
||||||
klass = self.User if isinstance(instance, self.User) else self.Role
|
klass = User if isinstance(instance, User) else Role
|
||||||
self.local.deleted.setdefault(klass, set()).add(instance)
|
self.deleted.setdefault(klass, set()).add(instance)
|
||||||
self.local.saved.get(klass, set()).discard(instance)
|
self.saved.get(klass, set()).discard(instance)
|
||||||
|
|
||||||
def resolve_ou(self, instances, ous):
|
def resolve_ou(self, instances, ous):
|
||||||
for instance in instances:
|
for instance in instances:
|
||||||
|
@ -61,7 +86,7 @@ class Provisionning(object):
|
||||||
|
|
||||||
def notify_users(self, ous, users, mode='provision'):
|
def notify_users(self, ous, users, mode='provision'):
|
||||||
if mode == 'provision':
|
if mode == 'provision':
|
||||||
users = (self.User.objects.filter(id__in=[u.id for u in users])
|
users = (User.objects.filter(id__in=[u.id for u in users])
|
||||||
.select_related('ou').prefetch_related('attribute_values__attribute'))
|
.select_related('ou').prefetch_related('attribute_values__attribute'))
|
||||||
else:
|
else:
|
||||||
self.resolve_ou(users, ous)
|
self.resolve_ou(users, ous)
|
||||||
|
@ -105,19 +130,19 @@ class Provisionning(object):
|
||||||
# Find roles giving a superuser attribute
|
# Find roles giving a superuser attribute
|
||||||
# If there is any role of this kind, we do one provisionning message for each user and
|
# If there is any role of this kind, we do one provisionning message for each user and
|
||||||
# each service.
|
# each service.
|
||||||
roles_with_attributes = (self.Role.objects.filter(members__in=users)
|
roles_with_attributes = (Role.objects.filter(members__in=users)
|
||||||
.parents(include_self=True)
|
.parents(include_self=True)
|
||||||
.filter(attributes__name='is_superuser')
|
.filter(attributes__name='is_superuser')
|
||||||
.exists())
|
.exists())
|
||||||
|
|
||||||
all_roles = (self.Role.objects.filter(members__in=users).parents()
|
all_roles = (Role.objects.filter(members__in=users).parents()
|
||||||
.prefetch_related('attributes').distinct())
|
.prefetch_related('attributes').distinct())
|
||||||
roles = dict((r.id, r) for r in all_roles)
|
roles = dict((r.id, r) for r in all_roles)
|
||||||
user_roles = {}
|
user_roles = {}
|
||||||
parents = {}
|
parents = {}
|
||||||
for rp in self.RoleParenting.objects.filter(child__in=all_roles):
|
for rp in RoleParenting.objects.filter(child__in=all_roles):
|
||||||
parents.setdefault(rp.child.id, []).append(rp.parent.id)
|
parents.setdefault(rp.child.id, []).append(rp.parent.id)
|
||||||
Through = self.Role.members.through
|
Through = Role.members.through
|
||||||
for u_id, r_id in Through.objects.filter(role__members__in=users).values_list('user_id',
|
for u_id, r_id in Through.objects.filter(role__members__in=users).values_list('user_id',
|
||||||
'role_id'):
|
'role_id'):
|
||||||
user_roles.setdefault(u_id, set()).add(roles[r_id])
|
user_roles.setdefault(u_id, set()).add(roles[r_id])
|
||||||
|
@ -133,7 +158,7 @@ class Provisionning(object):
|
||||||
for ou, users in ous.iteritems():
|
for ou, users in ous.iteritems():
|
||||||
for service, audience in self.get_audience(ou):
|
for service, audience in self.get_audience(ou):
|
||||||
for user in users:
|
for user in users:
|
||||||
self.logger.info(u'provisionning user %s to %s', user, audience)
|
logger.info(u'provisionning user %s to %s', user, audience)
|
||||||
notify_agents({
|
notify_agents({
|
||||||
'@type': 'provision',
|
'@type': 'provision',
|
||||||
'issuer': issuer,
|
'issuer': issuer,
|
||||||
|
@ -149,7 +174,7 @@ class Provisionning(object):
|
||||||
audience = [a for service, a in self.get_audience(ou)]
|
audience = [a for service, a in self.get_audience(ou)]
|
||||||
if not audience:
|
if not audience:
|
||||||
continue
|
continue
|
||||||
self.logger.info(u'provisionning users %s to %s',
|
logger.info(u'provisionning users %s to %s',
|
||||||
u', '.join(map(unicode, users)), u', '.join(audience))
|
u', '.join(map(unicode, users)), u', '.join(audience))
|
||||||
notify_agents({
|
notify_agents({
|
||||||
'@type': 'provision',
|
'@type': 'provision',
|
||||||
|
@ -162,9 +187,9 @@ class Provisionning(object):
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
elif users:
|
elif users:
|
||||||
audience = [audience for ou in self.OU.objects.all()
|
audience = [audience for ou in OU.objects.all()
|
||||||
for s, audience in self.get_audience(ou)]
|
for s, audience in self.get_audience(ou)]
|
||||||
self.logger.info(u'deprovisionning users %s from %s', u', '.join(map(unicode, users)),
|
logger.info(u'deprovisionning users %s from %s', u', '.join(map(unicode, users)),
|
||||||
u', '.join(audience))
|
u', '.join(audience))
|
||||||
notify_agents({
|
notify_agents({
|
||||||
'@type': 'deprovision',
|
'@type': 'deprovision',
|
||||||
|
@ -213,7 +238,7 @@ class Provisionning(object):
|
||||||
]
|
]
|
||||||
|
|
||||||
audience = [entity_id for service, entity_id in self.get_audience(ou)]
|
audience = [entity_id for service, entity_id in self.get_audience(ou)]
|
||||||
self.logger.info(u'%sning roles %s to %s', mode, roles, audience)
|
logger.info(u'%sning roles %s to %s', mode, roles, audience)
|
||||||
notify_agents({
|
notify_agents({
|
||||||
'@type': mode,
|
'@type': mode,
|
||||||
'audience': audience,
|
'audience': audience,
|
||||||
|
@ -229,33 +254,35 @@ class Provisionning(object):
|
||||||
sent_roles = set(ou_roles) | global_roles
|
sent_roles = set(ou_roles) | global_roles
|
||||||
helper(ou, sent_roles)
|
helper(ou, sent_roles)
|
||||||
|
|
||||||
def provision(self):
|
def provision(self, saved, deleted):
|
||||||
|
# Returns if:
|
||||||
|
# - we are not in a tenant
|
||||||
|
# - provsionning is disabled
|
||||||
|
# - there is nothing to do
|
||||||
if (not hasattr(connection, 'tenant') or not connection.tenant or not
|
if (not hasattr(connection, 'tenant') or not connection.tenant or not
|
||||||
hasattr(connection.tenant, 'domain_url')):
|
hasattr(connection.tenant, 'domain_url')):
|
||||||
return
|
return
|
||||||
if not getattr(settings, 'HOBO_ROLE_EXPORT', True):
|
if not getattr(settings, 'HOBO_ROLE_EXPORT', True):
|
||||||
return
|
return
|
||||||
# exit early if not started
|
if not (saved or deleted):
|
||||||
if not hasattr(self.local, 'saved') or not hasattr(self.local, 'deleted'):
|
|
||||||
return
|
return
|
||||||
|
|
||||||
t = threading.Thread(target=self.do_provision, kwargs={
|
t = threading.Thread(
|
||||||
'saved': getattr(self.local, 'saved', {}),
|
target=self.do_provision,
|
||||||
'deleted': getattr(self.local, 'deleted', {}),
|
kwargs={'saved': saved, 'deleted': deleted})
|
||||||
})
|
|
||||||
t.start()
|
t.start()
|
||||||
self.threads.add(t)
|
self.threads.add(t)
|
||||||
|
|
||||||
def do_provision(self, saved, deleted, thread=None):
|
def do_provision(self, saved, deleted):
|
||||||
try:
|
try:
|
||||||
ous = {ou.id: ou for ou in self.OU.objects.all()}
|
ous = {ou.id: ou for ou in OU.objects.all()}
|
||||||
self.notify_roles(ous, saved.get(self.Role, []))
|
self.notify_roles(ous, saved.get(Role, []))
|
||||||
self.notify_roles(ous, deleted.get(self.Role, []), mode='deprovision')
|
self.notify_roles(ous, deleted.get(Role, []), mode='deprovision')
|
||||||
self.notify_users(ous, saved.get(self.User, []))
|
self.notify_users(ous, saved.get(User, []))
|
||||||
self.notify_users(ous, deleted.get(self.User, []), mode='deprovision')
|
self.notify_users(ous, deleted.get(User, []), mode='deprovision')
|
||||||
except Exception:
|
except Exception:
|
||||||
# last step, clear everything
|
# last step, clear everything
|
||||||
self.logger.exception(u'error in provisionning thread')
|
logger.exception(u'error in provisionning thread')
|
||||||
finally:
|
finally:
|
||||||
self.threads.discard(threading.current_thread())
|
self.threads.discard(threading.current_thread())
|
||||||
|
|
||||||
|
@ -267,12 +294,9 @@ class Provisionning(object):
|
||||||
self.start()
|
self.start()
|
||||||
|
|
||||||
def __exit__(self, exc_type, exc_value, exc_tb):
|
def __exit__(self, exc_type, exc_value, exc_tb):
|
||||||
if exc_type is None:
|
if not self.stack:
|
||||||
self.provision()
|
return
|
||||||
self.clean()
|
self.stop(provision=exc_type is None)
|
||||||
self.wait()
|
|
||||||
else:
|
|
||||||
self.clean()
|
|
||||||
|
|
||||||
def get_audience(self, ou):
|
def get_audience(self, ou):
|
||||||
if ou:
|
if ou:
|
||||||
|
@ -298,64 +322,72 @@ class Provisionning(object):
|
||||||
return urljoin(base_url, reverse('a2-idp-saml-metadata'))
|
return urljoin(base_url, reverse('a2-idp-saml-metadata'))
|
||||||
|
|
||||||
def pre_save(self, sender, instance, raw, using, update_fields, **kwargs):
|
def pre_save(self, sender, instance, raw, using, update_fields, **kwargs):
|
||||||
|
if not self.stack:
|
||||||
|
return
|
||||||
# we skip new instances
|
# we skip new instances
|
||||||
if not instance.pk:
|
if not instance.pk:
|
||||||
return
|
return
|
||||||
if not isinstance(instance, (self.User, self.Role, RoleAttribute, AttributeValue)):
|
if not isinstance(instance, (User, Role, RoleAttribute, AttributeValue)):
|
||||||
return
|
return
|
||||||
# ignore last_login update on login
|
# ignore last_login update on login
|
||||||
if isinstance(instance, self.User) and update_fields == ['last_login']:
|
if isinstance(instance, User) and (update_fields and set(update_fields) == set(['last_login'])):
|
||||||
return
|
return
|
||||||
if isinstance(instance, RoleAttribute):
|
if isinstance(instance, RoleAttribute):
|
||||||
instance = instance.role
|
instance = instance.role
|
||||||
elif isinstance(instance, AttributeValue):
|
elif isinstance(instance, AttributeValue):
|
||||||
if not isinstance(instance.owner, self.User):
|
if not isinstance(instance.owner, User):
|
||||||
return
|
return
|
||||||
instance = instance.owner
|
instance = instance.owner
|
||||||
self.saved(instance)
|
self.add_saved(instance)
|
||||||
|
|
||||||
def post_save(self, sender, instance, created, raw, using, update_fields, **kwargs):
|
def post_save(self, sender, instance, created, raw, using, update_fields, **kwargs):
|
||||||
|
if not self.stack:
|
||||||
|
return
|
||||||
# during post_save we only handle new instances
|
# during post_save we only handle new instances
|
||||||
if isinstance(instance, self.RoleParenting):
|
if isinstance(instance, RoleParenting):
|
||||||
self.saved(*list(instance.child.all_members()))
|
self.add_saved(*list(instance.child.all_members()))
|
||||||
return
|
return
|
||||||
if not created:
|
if not created:
|
||||||
return
|
return
|
||||||
if not isinstance(instance, (self.User, self.Role, RoleAttribute, AttributeValue)):
|
if not isinstance(instance, (User, Role, RoleAttribute, AttributeValue)):
|
||||||
return
|
return
|
||||||
if isinstance(instance, RoleAttribute):
|
if isinstance(instance, RoleAttribute):
|
||||||
instance = instance.role
|
instance = instance.role
|
||||||
elif isinstance(instance, AttributeValue):
|
elif isinstance(instance, AttributeValue):
|
||||||
if not isinstance(instance.owner, self.User):
|
if not isinstance(instance.owner, User):
|
||||||
return
|
return
|
||||||
instance = instance.owner
|
instance = instance.owner
|
||||||
self.saved(instance)
|
self.add_saved(instance)
|
||||||
|
|
||||||
def pre_delete(self, sender, instance, using, **kwargs):
|
def pre_delete(self, sender, instance, using, **kwargs):
|
||||||
if isinstance(instance, (self.User, self.Role)):
|
if not self.stack:
|
||||||
self.deleted(copy.copy(instance))
|
return
|
||||||
|
if isinstance(instance, (User, Role)):
|
||||||
|
self.add_deleted(copy.copy(instance))
|
||||||
elif isinstance(instance, RoleAttribute):
|
elif isinstance(instance, RoleAttribute):
|
||||||
instance = instance.role
|
instance = instance.role
|
||||||
self.saved(instance)
|
self.add_saved(instance)
|
||||||
elif isinstance(instance, AttributeValue):
|
elif isinstance(instance, AttributeValue):
|
||||||
if not isinstance(instance.owner, self.User):
|
if not isinstance(instance.owner, User):
|
||||||
return
|
return
|
||||||
instance = instance.owner
|
instance = instance.owner
|
||||||
self.saved(instance)
|
self.add_saved(instance)
|
||||||
elif isinstance(instance, self.RoleParenting):
|
elif isinstance(instance, RoleParenting):
|
||||||
self.saved(*list(instance.child.all_members()))
|
self.add_saved(*list(instance.child.all_members()))
|
||||||
|
|
||||||
def m2m_changed(self, sender, instance, action, reverse, model, pk_set, using, **kwargs):
|
def m2m_changed(self, sender, instance, action, reverse, model, pk_set, using, **kwargs):
|
||||||
|
if not self.stack:
|
||||||
|
return
|
||||||
if action != 'pre_clear' and action.startswith('pre_'):
|
if action != 'pre_clear' and action.startswith('pre_'):
|
||||||
return
|
return
|
||||||
if sender is self.Role.members.through:
|
if sender is Role.members.through:
|
||||||
self.saved(instance)
|
self.add_saved(instance)
|
||||||
# on a clear, pk_set is None
|
# on a clear, pk_set is None
|
||||||
for other_instance in model.objects.filter(pk__in=pk_set or []):
|
for other_instance in model.objects.filter(pk__in=pk_set or []):
|
||||||
self.saved(other_instance)
|
self.add_saved(other_instance)
|
||||||
if action == 'pre_clear':
|
if action == 'pre_clear':
|
||||||
# when the action is pre_clear we need to lookup the current value of the members
|
# when the action is pre_clear we need to lookup the current value of the members
|
||||||
# relation, to re-provision all previously enroled users.
|
# relation, to re-provision all previously enroled users.
|
||||||
if not reverse:
|
if not reverse:
|
||||||
for other_instance in instance.members.all():
|
for other_instance in instance.members.all():
|
||||||
self.saved(other_instance)
|
self.add_saved(other_instance)
|
||||||
|
|
|
@ -1,30 +1,31 @@
|
||||||
import os
|
import os
|
||||||
import tempfile
|
|
||||||
import shutil
|
|
||||||
import json
|
import json
|
||||||
|
|
||||||
|
from django_webtest import WebTestMixin, DjangoTestApp
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
|
from django.db import connection, transaction
|
||||||
|
from tenant_schemas.postgresql_backend.base import FakeTenant
|
||||||
|
from tenant_schemas.utils import tenant_context
|
||||||
|
|
||||||
|
from hobo.multitenant.models import Tenant
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def tenant_base(request, settings):
|
def tenant_base(tmpdir, settings):
|
||||||
base = tempfile.mkdtemp('authentic-tenant-base')
|
base = str(tmpdir.mkdir('authentic-tenant-base'))
|
||||||
settings.TENANT_BASE = base
|
settings.TENANT_BASE = base
|
||||||
|
|
||||||
def fin():
|
|
||||||
shutil.rmtree(base)
|
|
||||||
request.addfinalizer(fin)
|
|
||||||
return base
|
return base
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope='function')
|
@pytest.fixture
|
||||||
def tenant(transactional_db, request, settings, tenant_base):
|
def tenant_factory(transactional_db, tenant_base, settings):
|
||||||
from hobo.multitenant.models import Tenant
|
tenants = []
|
||||||
base = tenant_base
|
|
||||||
|
|
||||||
@pytest.mark.django_db
|
settings.ALLOWED_HOSTS = ['*']
|
||||||
def make_tenant(name):
|
|
||||||
tenant_dir = os.path.join(base, name)
|
def factory(name):
|
||||||
|
tenant_dir = os.path.join(tenant_base, name)
|
||||||
os.mkdir(tenant_dir)
|
os.mkdir(tenant_dir)
|
||||||
with open(os.path.join(tenant_dir, 'unsecure'), 'w') as fd:
|
with open(os.path.join(tenant_dir, 'unsecure'), 'w') as fd:
|
||||||
fd.write('1')
|
fd.write('1')
|
||||||
|
@ -37,32 +38,63 @@ def tenant(transactional_db, request, settings, tenant_base):
|
||||||
'other_variable': 'foo',
|
'other_variable': 'foo',
|
||||||
},
|
},
|
||||||
'services': [
|
'services': [
|
||||||
{'slug': 'test',
|
{
|
||||||
'service-id': 'authentic',
|
'slug': 'test',
|
||||||
'title': 'Test',
|
'service-id': 'authentic',
|
||||||
'this': True,
|
'title': 'Test',
|
||||||
'secret_key': '12345',
|
'this': True,
|
||||||
'base_url': 'http://%s' % name,
|
'secret_key': '12345',
|
||||||
'variables': {
|
'base_url': 'http://%s' % name,
|
||||||
'other_variable': 'bar',
|
'variables': {
|
||||||
}
|
'other_variable': 'bar',
|
||||||
|
}
|
||||||
},
|
},
|
||||||
{'slug': 'other',
|
{
|
||||||
'title': 'Other',
|
'slug': 'other',
|
||||||
'service-id': 'welco',
|
'title': 'Other',
|
||||||
'secret_key': 'abcdef',
|
'service-id': 'welco',
|
||||||
'base_url': 'http://other.example.net'},
|
'secret_key': 'abcdef',
|
||||||
]}, fd)
|
'base_url': 'http://other.example.net'
|
||||||
t = Tenant(domain_url=name,
|
},
|
||||||
schema_name=name.replace('-', '_').replace('.', '_'))
|
]
|
||||||
t.create_schema()
|
}, fd)
|
||||||
|
schema_name = name.replace('-', '_').replace('.', '_')
|
||||||
|
t = Tenant(domain_url=name, schema_name=schema_name)
|
||||||
|
with transaction.atomic():
|
||||||
|
t.create_schema()
|
||||||
|
tenants.append(t)
|
||||||
return t
|
return t
|
||||||
tenants = [make_tenant('authentic.example.net')]
|
try:
|
||||||
|
yield factory
|
||||||
def fin():
|
finally:
|
||||||
from django.db import connection
|
# cleanup all created tenants
|
||||||
connection.set_schema_to_public()
|
connection.set_schema_to_public()
|
||||||
for t in tenants:
|
with tenant_context(FakeTenant('public')):
|
||||||
t.delete(True)
|
for tenant in tenants:
|
||||||
request.addfinalizer(fin)
|
tenant.delete(force_drop=True)
|
||||||
return tenants[0]
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def tenant(tenant_factory):
|
||||||
|
return tenant_factory('authentic.example.net')
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def app_factory(request):
|
||||||
|
wtm = WebTestMixin()
|
||||||
|
wtm._patch_settings()
|
||||||
|
|
||||||
|
def factory(hostname='testserver'):
|
||||||
|
if hasattr(hostname, 'domain_url'):
|
||||||
|
hostname = hostname.domain_url
|
||||||
|
return DjangoTestApp(extra_environ={'HTTP_HOST': hostname})
|
||||||
|
|
||||||
|
try:
|
||||||
|
yield factory
|
||||||
|
finally:
|
||||||
|
wtm._unpatch_settings()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def notify_agents(mocker):
|
||||||
|
yield mocker.patch('hobo.agent.authentic2.provisionning.notify_agents')
|
||||||
|
|
|
@ -39,3 +39,6 @@ CACHES = {
|
||||||
}
|
}
|
||||||
|
|
||||||
HOBO_ROLE_EXPORT = True
|
HOBO_ROLE_EXPORT = True
|
||||||
|
|
||||||
|
SESSION_COOKIE_SECURE = False
|
||||||
|
CSRF_COOKIE_SECURE = False
|
||||||
|
|
|
@ -17,6 +17,8 @@ from authentic2.a2_rbac.utils import get_default_ou
|
||||||
from authentic2.models import Attribute, AttributeValue
|
from authentic2.models import Attribute, AttributeValue
|
||||||
from hobo.agent.authentic2.provisionning import provisionning
|
from hobo.agent.authentic2.provisionning import provisionning
|
||||||
|
|
||||||
|
User = get_user_model()
|
||||||
|
|
||||||
pytestmark = pytest.mark.django_db
|
pytestmark = pytest.mark.django_db
|
||||||
|
|
||||||
|
|
||||||
|
@ -126,7 +128,6 @@ def test_provision_user(transactional_db, tenant, caplog):
|
||||||
role2.attributes.create(kind='json', name='emails', value='["zob@example.net"]')
|
role2.attributes.create(kind='json', name='emails', value='["zob@example.net"]')
|
||||||
child_role = Role.objects.create(name='child', ou=get_default_ou())
|
child_role = Role.objects.create(name='child', ou=get_default_ou())
|
||||||
notify_agents.reset_mock()
|
notify_agents.reset_mock()
|
||||||
User = get_user_model()
|
|
||||||
attribute = Attribute.objects.create(label='Code postal', name='code_postal',
|
attribute = Attribute.objects.create(label='Code postal', name='code_postal',
|
||||||
kind='string')
|
kind='string')
|
||||||
with provisionning:
|
with provisionning:
|
||||||
|
@ -277,7 +278,6 @@ def test_provision_user(transactional_db, tenant, caplog):
|
||||||
data = objects['data']
|
data = objects['data']
|
||||||
assert isinstance(data, list)
|
assert isinstance(data, list)
|
||||||
assert len(data) == 1
|
assert len(data) == 1
|
||||||
print data
|
|
||||||
for o in data:
|
for o in data:
|
||||||
assert set(o.keys()) >= set(['uuid', 'username', 'first_name',
|
assert set(o.keys()) >= set(['uuid', 'username', 'first_name',
|
||||||
'is_superuser', 'last_name', 'email', 'roles'])
|
'is_superuser', 'last_name', 'email', 'roles'])
|
||||||
|
@ -447,17 +447,15 @@ def test_provision_createsuperuser(transactional_db, tenant, caplog):
|
||||||
with tenant_context(tenant):
|
with tenant_context(tenant):
|
||||||
# create a provider so notification messages have an audience.
|
# create a provider so notification messages have an audience.
|
||||||
LibertyProvider.objects.create(ou=None, name='provider',
|
LibertyProvider.objects.create(ou=None, name='provider',
|
||||||
entity_id='http://provider.com',
|
entity_id='http://provider.com',
|
||||||
protocol_conformance=lasso.PROTOCOL_SAML_2_0)
|
protocol_conformance=lasso.PROTOCOL_SAML_2_0)
|
||||||
with patch('hobo.agent.authentic2.provisionning.notify_agents') as notify_agents:
|
with patch('hobo.agent.authentic2.provisionning.notify_agents') as notify_agents:
|
||||||
call_command('createsuperuser', domain=tenant.domain_url, uuid='coin',
|
call_command('createsuperuser', domain=tenant.domain_url, uuid='coin',
|
||||||
username='coin', email='coin@coin.org', interactive=False)
|
username='coin', email='coin@coin.org', interactive=False)
|
||||||
assert notify_agents.call_count == 1
|
assert notify_agents.call_count == 1
|
||||||
|
|
||||||
|
|
||||||
@patch('hobo.agent.authentic2.provisionning.notify_agents')
|
|
||||||
def test_command_hobo_provision(notify_agents, transactional_db, tenant, caplog):
|
def test_command_hobo_provision(notify_agents, transactional_db, tenant, caplog):
|
||||||
User = get_user_model()
|
|
||||||
with tenant_context(tenant):
|
with tenant_context(tenant):
|
||||||
ou = get_default_ou()
|
ou = get_default_ou()
|
||||||
LibertyProvider.objects.create(ou=ou, name='provider',
|
LibertyProvider.objects.create(ou=ou, name='provider',
|
||||||
|
@ -486,3 +484,31 @@ def test_command_hobo_provision(notify_agents, transactional_db, tenant, caplog)
|
||||||
assert msg_2['full'] is False
|
assert msg_2['full'] is False
|
||||||
assert msg_2['objects']['@type'] == 'user'
|
assert msg_2['objects']['@type'] == 'user'
|
||||||
assert len(msg_2['objects']['data']) == 10
|
assert len(msg_2['objects']['data']) == 10
|
||||||
|
|
||||||
|
|
||||||
|
def test_middleware(notify_agents, app_factory, tenant, settings):
|
||||||
|
settings.HOBO_PROVISIONNING_SYNCHRONOUS = True
|
||||||
|
|
||||||
|
with tenant_context(tenant):
|
||||||
|
user = User.objects.create(username='john', ou=get_default_ou())
|
||||||
|
user.set_password('password')
|
||||||
|
user.save()
|
||||||
|
LibertyProvider.objects.create(ou=get_default_ou(),
|
||||||
|
name='provider',
|
||||||
|
entity_id='http://provider.com',
|
||||||
|
protocol_conformance=lasso.PROTOCOL_SAML_2_0)
|
||||||
|
assert notify_agents.call_count == 0
|
||||||
|
|
||||||
|
app = app_factory(tenant)
|
||||||
|
resp = app.get('/login/')
|
||||||
|
form = resp.form
|
||||||
|
form.set('username', 'john')
|
||||||
|
form.set('password', 'password')
|
||||||
|
resp = form.submit(name='login-password-submit').follow()
|
||||||
|
resp = resp.click('Your account')
|
||||||
|
resp = resp.click('Edit')
|
||||||
|
resp.form.set('edit-profile-first_name', 'John')
|
||||||
|
resp.form.set('edit-profile-last_name', 'Doe')
|
||||||
|
assert notify_agents.call_count == 0
|
||||||
|
resp = resp.form.submit().follow()
|
||||||
|
assert notify_agents.call_count == 1
|
||||||
|
|
Loading…
Reference in New Issue