218 lines
8.4 KiB
Python
218 lines
8.4 KiB
Python
# authentic2-wallonie-connect - Authentic2 plugin for the Wallonie Connect usecase
|
|
# Copyright (C) 2019 Entr'ouvert
|
|
#
|
|
# This program is free software: you can redistribute it and/or modify it
|
|
# under the terms of the GNU Affero General Public License as published
|
|
# by the Free Software Foundation, either version 3 of the License, or
|
|
# (at your option) any later version.
|
|
#
|
|
# This program is distributed in the hope that it will be useful,
|
|
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
|
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
|
# GNU Affero General Public License for more details.
|
|
#
|
|
# You should have received a copy of the GNU Affero General Public License
|
|
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
|
|
|
import functools
|
|
import json
|
|
import sys
|
|
|
|
from django.db import transaction
|
|
from django.utils import six
|
|
|
|
from authentic2_idp_oidc.models import OIDCClient
|
|
from authentic2.a2_rbac.models import Role, OrganizationalUnit
|
|
from authentic2.custom_user.models import User
|
|
|
|
|
|
from django.core.management.base import BaseCommand
|
|
|
|
|
|
class DryRun(Exception):
    """Signal raised at the end of a dry run so the enclosing transaction
    opened by ``dryrun`` is rolled back instead of committed."""
|
|
|
|
|
|
def dryrun(func):
    """Decorate *func* so that it runs inside ``transaction.atomic()``.

    If *func* raises ``DryRun``, the exception leaves the atomic block (so
    the transaction is rolled back) and is then swallowed; the wrapper
    returns ``None`` in that case. Otherwise the wrapper returns whatever
    *func* returned. Any other exception propagates unchanged.
    """
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        try:
            with transaction.atomic():
                result = func(*args, **kwargs)
        except DryRun:
            # Intentional rollback requested by the wrapped function.
            result = None
        return result
    return wrapper
|
|
|
|
|
|
class Command(BaseCommand):
    """Import a locality, its OIDC services and its users from JSON files.

    Each path given on the command line is parsed as JSON and synchronised
    into the database by ``do``. Unless ``--no-dry-run`` is passed, every
    change made for a file is rolled back once that file is processed.
    """

    help = 'Import localities, OIDC services and users from JSON files'

    def add_arguments(self, parser):
        # Boolean flag. Without action='store_true' argparse required a value:
        # a bare ``--no-dry-run`` errored out and ``--no-dry-run false`` was a
        # truthy string that silently committed the import.
        parser.add_argument('--no-dry-run', action='store_true', default=False)
        parser.add_argument('paths', nargs='+')

    def handle(self, paths, no_dry_run=False, verbosity=1, **options):
        self.no_dry_run = no_dry_run
        self.verbosity = verbosity
        for path in paths:
            # Close the file as soon as it has been parsed.
            with open(path) as fd:
                content = json.load(fd)
            self.do(content=content)

    def info(self, *args, **kwargs):
        """Write to stdout unless the command runs with verbosity 0."""
        if self.verbosity >= 1:
            self.stdout.write(*args, **kwargs)

    @staticmethod
    def _update_fields(instance, values):
        """Set every attribute of *instance* that differs from *values*.

        Returns True if at least one attribute changed; saving is left to
        the caller.
        """
        modified = False
        for key, value in values.items():
            if getattr(instance, key) != value:
                setattr(instance, key, value)
                modified = True
        return modified

    @dryrun
    def do(self, content):
        """Synchronise one JSON document with the database.

        *content* is read as a dict with keys:

        - ``locality``: ``{'name', 'slug'}`` -> an OrganizationalUnit;
        - ``services`` (optional list): OIDC clients of that unit;
        - ``users`` (optional list): users and the services they may use.

        Runs inside a transaction (see ``dryrun``). Unless
        ``self.no_dry_run`` is true, ``DryRun`` is raised at the end so all
        changes are rolled back.
        """
        ou = self._sync_ou(content['locality'])
        services = self._sync_services(ou, content.get('services', []))

        content_users = content.get('users', [])
        assert isinstance(content_users, list)
        for content_user in content_users:
            self._sync_user(content_user, services)

        if self.no_dry_run:
            return
        raise DryRun

    def _sync_ou(self, locality):
        """Create the OU identified by ``locality['slug']`` or rename it; return it."""
        self.info('Locality %s' % locality['name'], ending=' ')
        ou, created = OrganizationalUnit.objects.get_or_create(
            slug=locality['slug'],
            defaults={'name': locality['name']})
        if created:
            self.info(self.style.SUCCESS('CREATED'))
        elif ou.name != locality['name']:
            ou.name = locality['name']
            ou.save()
            self.info(self.style.SUCCESS('UPDATED'))
        else:
            self.info('unchanged')
        return ou

    def _sync_services(self, ou, content_services):
        """Synchronise every service of *ou*.

        Returns a dict keyed by service slug, as built by ``_sync_service``.
        """
        assert isinstance(content_services, list)
        services = {}
        for service in content_services:
            services[service['slug']] = self._sync_service(ou, service)
        return services

    def _sync_service(self, ou, service):
        """Create/update one OIDC client and its access role.

        Returns ``{'oidc_client': ...}``, plus ``'access_role'`` unless the
        service is ``open_to_all``.
        """
        name = service['name']
        self.info('Service %s' % name, ending=' ')
        slug = service['slug']
        frontchannel_logout_uri = service['frontchannel_logout_uri']
        assert isinstance(frontchannel_logout_uri, six.text_type)
        post_logout_redirect_uris = service.get('post_logout_redirect_uris', [])
        assert isinstance(post_logout_redirect_uris, list)
        open_to_all = service.get('open_to_all', False)
        redirect_uris = service.get('redirect_uris', [])
        assert isinstance(redirect_uris, list)

        # Field values in the form they are stored on the model: URI lists
        # are newline-separated text, so comparisons must use that form too
        # (comparing against the raw lists always looked "modified" and
        # stored the list's repr).
        expected = {
            'name': name,
            'client_id': service['client_id'],
            'client_secret': service['client_secret'],
            'frontchannel_logout_uri': frontchannel_logout_uri,
            'post_logout_redirect_uris': '\n'.join(post_logout_redirect_uris),
            'redirect_uris': '\n'.join(redirect_uris),
        }
        oidc_client, created = OIDCClient.objects.get_or_create(
            slug=slug, ou=ou, defaults=dict(expected))
        entry = {'oidc_client': oidc_client}

        if created:
            self.info(self.style.SUCCESS('CREATED'))
        else:
            modified = self._update_fields(oidc_client, expected)
            # FIXME: open_to_all
            if modified:
                oidc_client.save()
                self.info(self.style.SUCCESS('MODIFIED'))
            else:
                self.info('unchanged')

        if open_to_all:
            # No access role: the service is not restricted to role members.
            Role.objects.filter(slug=slug, ou=ou).delete()
        else:
            access_role, role_created = Role.objects.get_or_create(
                slug=slug,
                ou=ou,
                defaults={'name': name})
            if not role_created and access_role.name != name:
                access_role.name = name
                access_role.save()
            entry['access_role'] = access_role
        return entry

    def _sync_user(self, content_user, services):
        """Validate one user record, create/update the user, then its accesses."""
        required = ('email', 'username')
        data = {}
        for string_key in ('email', 'first_name', 'last_name', 'password', 'username'):
            assert string_key in content_user, 'missing key ' + string_key
            value = content_user[string_key]
            assert isinstance(value, six.text_type), 'invalid type for key ' + string_key
            if string_key in required:
                assert value, 'missing value for key ' + string_key + ' %s' % content_user
            data[string_key] = value
        assert data['password'].startswith('{SSHA}')
        uuid = content_user.get('uuid') or None
        assert uuid is None or (isinstance(uuid, six.text_type) and uuid), 'invalid uuid %s %s' % (uuid, content_user)
        allowed_services = content_user.get('allowed_services', [])
        assert isinstance(allowed_services, list)

        # Look the user up by uuid when given (username is then updatable),
        # otherwise by username.
        defaults = data.copy()
        if uuid is not None:
            self.info('User %s-%s' % (data['username'], uuid), ending=' ')
            lookup = {'uuid': uuid}
        else:
            self.info('User %s' % data['username'], ending=' ')
            lookup = {'username': defaults.pop('username')}
        user, created = User.objects.get_or_create(defaults=defaults, **lookup)

        if created:
            self.info(self.style.SUCCESS('CREATED'))
        elif self._update_fields(user, defaults):
            user.save()
            self.info(self.style.SUCCESS('MODIFIED'))
        else:
            self.info('unchanged')

        self._sync_access(user, services, allowed_services)

    def _sync_access(self, user, services, allowed_services):
        """Make *user* a member of exactly the access roles of *allowed_services*.

        Services without an access role (``open_to_all``) are skipped: there
        is no membership to grant or revoke.
        """
        for service_slug in allowed_services:
            assert service_slug in services, 'unknown service %s' % service_slug
            role = services[service_slug].get('access_role')
            if role is None:
                continue
            service = services[service_slug]['oidc_client']
            self.info('Access to service %s' % service.name, ending=' ')
            if role.members.filter(pk=user.pk).exists():
                self.info('unchanged')
            else:
                role.members.add(user)
                self.info(self.style.SUCCESS('ADDED'))

        # sorted() keeps the progress output deterministic.
        for service_slug in sorted(set(services) - set(allowed_services)):
            role = services[service_slug].get('access_role')
            if role is None:
                continue
            service = services[service_slug]['oidc_client']
            self.info('Access to service %s' % service.name, ending=' ')
            if role.members.filter(pk=user.pk).exists():
                role.members.remove(user)
                self.info(self.style.SUCCESS('REMOVED'))
            else:
                self.info('unchanged')
|