provisionning: add ?sync=1 parameter to /__provision__ API (#56920)

When used by the /api/provision API on authentic, it garantees the
provisionning is made synchronously.
This commit is contained in:
Benjamin Dauvergne 2021-09-14 07:00:19 +02:00
parent 6e896dbcdd
commit 96ce1fb02f
4 changed files with 28 additions and 19 deletions

View File

@ -122,7 +122,7 @@ class Provisionning(threading.local):
if instance.ou_id in ous:
instance.ou = ous[instance.ou_id]
def notify_users(self, ous, users, mode='provision'):
def notify_users(self, ous, users, mode='provision', sync=False):
allowed_technical_roles_prefixes = getattr(settings, 'HOBO_PROVISION_ROLE_PREFIXES', []) or []
if mode == 'provision':
@ -240,7 +240,8 @@ class Provisionning(threading.local):
for user in batched_users
],
},
}
},
sync=sync,
)
else:
for ou, users in ous.items():
@ -262,7 +263,8 @@ class Provisionning(threading.local):
'@type': 'user',
'data': [user_to_json(ou, None, user, user_roles) for user in users],
},
}
},
sync=sync,
)
elif users:
audience = [audience for ou in ous.keys() for s, audience in self.get_audience(ou)]
@ -284,10 +286,11 @@ class Provisionning(threading.local):
for user in users
],
},
}
},
sync=sync,
)
def notify_roles(self, ous, roles, mode='provision', full=False):
def notify_roles(self, ous, roles, mode='provision', full=False, sync=False):
allowed_technical_roles_prefixes = getattr(settings, 'HOBO_PROVISION_ROLE_PREFIXES', []) or []
def is_forbidden_technical_role(role):
@ -340,7 +343,8 @@ class Provisionning(threading.local):
'@type': 'role',
'data': data,
},
}
},
sync=sync,
)
global_roles = set(ous.get(None, []))
@ -486,7 +490,7 @@ class Provisionning(threading.local):
for other_instance in instance.members.all():
self.add_saved(other_instance)
def notify_agents(self, data):
def notify_agents(self, data, sync=False):
log_path = getattr(settings, 'DEBUG_PROVISIONNING_LOG_PATH', '')
if log_path and getattr(settings, 'HOBO_PROVISIONNING_DEBUG', False):
try:
@ -498,7 +502,7 @@ class Provisionning(threading.local):
pass
if getattr(settings, 'HOBO_HTTP_PROVISIONNING', False):
leftover_audience = self.notify_agents_http(data)
leftover_audience = self.notify_agents_http(data, sync=sync)
if not leftover_audience:
return
logger.info('leftover AMQP audience: %s', leftover_audience)
@ -515,7 +519,7 @@ class Provisionning(threading.local):
services_by_url[service['saml-sp-metadata-url']] = service
return services_by_url
def notify_agents_http(self, data):
def notify_agents_http(self, data, sync=False):
services_by_url = self.get_http_services_by_url()
audience = data.get('audience')
rest_audience = [x for x in audience if x in services_by_url]
@ -523,11 +527,11 @@ class Provisionning(threading.local):
for audience in rest_audience:
service = services_by_url[audience]
data['audience'] = [audience]
url = service['provisionning-url'] + '?orig=%s' % service['orig']
if sync:
url += '&sync=1'
try:
response = requests.put(
sign_url(service['provisionning-url'] + '?orig=%s' % service['orig'], service['secret']),
json=data,
)
response = requests.put(sign_url(url, service['secret']), json=data)
response.raise_for_status()
except requests.RequestException as e:
logger.error(u'error provisionning to %s (%s)', audience, e)

View File

@ -57,14 +57,14 @@ class ProvisionView(GenericAPIView):
user = provisionning.User.objects.get(uuid=user_uuid)
except provisionning.User.DoesNotExist:
return Response({'err': 1, 'err_desc': 'unknown user UUID'})
engine.notify_users(ous=None, users=[user])
engine.notify_users(ous=None, users=[user], sync=True)
elif role_uuid:
try:
role = provisionning.Role.objects.get(uuid=role_uuid)
except provisionning.Role.DoesNotExist:
return Response({'err': 1, 'err_desc': 'unknown role UUID'})
ous = {ou.id: ou for ou in provisionning.OU.objects.all()}
engine.notify_roles(ous=ous, roles=[role])
engine.notify_roles(ous=ous, roles=[role], sync=True)
response = {
'err': 0,
@ -97,8 +97,8 @@ class ApiProvisionningEngine(provisionning.Provisionning):
services_by_url = {k: v for k, v in services_by_url.items() if self.service_url in v['url']}
return services_by_url
def notify_agents(self, data):
self.leftover_audience = self.notify_agents_http(data)
def notify_agents(self, data, sync=False):
self.leftover_audience = self.notify_agents_http(data, sync=sync)
# only include filtered services in leftovers
services_by_url = self.get_http_services_by_url()
self.leftover_audience = [x for x in self.leftover_audience if x in services_by_url]

View File

@ -24,7 +24,7 @@ from django.utils.deprecation import MiddlewareMixin
from django.utils.encoding import force_bytes, force_text
from django.utils.six.moves.urllib.parse import urlparse
from hobo.provisionning.utils import NotificationProcessing, TryAgain
from hobo.provisionning.utils import NotificationProcessing
from hobo.rest_authentication import PublikAuthentication, PublikAuthenticationFailed
@ -54,7 +54,7 @@ class ProvisionningMiddleware(MiddlewareMixin, NotificationProcessing):
full = notification['full'] if 'full' in notification else False
data = notification['objects']['data']
if 'uwsgi' in sys.modules:
if 'uwsgi' in sys.modules and 'sync' not in request.GET:
from hobo.provisionning.spooler import provision
tenant = getattr(connection, 'tenant', None)

View File

@ -623,6 +623,7 @@ def test_provision_using_http(transactional_db, tenant, settings, caplog):
assert notify_agents.call_count == 1
assert notify_agents.call_args[0][0]['audience'] == ['http://example.com']
assert requests_put.call_count == 1
assert '&sync=1' not in requests_put.call_args[0][0]
# cannot check audience passed to requests.put as it's the same
# dictionary that is altered afterwards and would thus also contain
# http://example.com.
@ -644,6 +645,7 @@ def test_provision_using_http(transactional_db, tenant, settings, caplog):
)
assert notify_agents.call_count == 0
assert requests_put.call_count == 2
assert '&sync=1' not in requests_put.call_args[0][0]
def test_provisionning_api(transactional_db, app_factory, tenant, settings, caplog):
@ -703,6 +705,7 @@ def test_provisionning_api(transactional_db, app_factory, tenant, settings, capl
signature.sign_url('/api/provision/?orig=%s' % orig, key), {'user_uuid': user.uuid}
)
assert requests_put.call_count == 2
assert '&sync=1' in requests_put.call_args[0][0]
assert not resp.json['leftover_audience']
assert set(resp.json['reached_audience']) == {
'http://other.example.net/metadata/',
@ -715,6 +718,7 @@ def test_provisionning_api(transactional_db, app_factory, tenant, settings, capl
{'user_uuid': user.uuid, 'service_type': 'welco'},
)
assert requests_put.call_count == 1
assert '&sync=1' in requests_put.call_args[0][0]
assert not resp.json['leftover_audience']
assert set(resp.json['reached_audience']) == {'http://other.example.net/metadata/'}
@ -724,6 +728,7 @@ def test_provisionning_api(transactional_db, app_factory, tenant, settings, capl
{'user_uuid': user.uuid, 'service_url': 'example.net'},
)
assert requests_put.call_count == 2
assert '&sync=1' in requests_put.call_args[0][0]
with patch('hobo.agent.authentic2.provisionning.requests.put') as requests_put:
resp = app.post_json(