# authentic2-wallonie-connect/src/authentic2_wallonie_connect/management/commands/wc-base-import.py

# 218 lines
# 8.4 KiB
# Python

# authentic2-wallonie-connect - Authentic2 plugin for the Wallonie Connect usecase
# Copyright (C) 2019 Entr'ouvert
#
# This program is free software: you can redistribute it and/or modify it
# under the terms of the GNU Affero General Public License as published
# by the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
import functools
import json
import sys
from django.db import transaction
from django.utils import six
from authentic2_idp_oidc.models import OIDCClient
from authentic2.a2_rbac.models import Role, OrganizationalUnit
from authentic2.custom_user.models import User
from django.core.management.base import BaseCommand
class DryRun(Exception):
    """Signal that the current import was only a dry run.

    Raised at the end of an import when changes must not be kept; the
    ``dryrun`` decorator catches it after it has left the transaction block.
    """
def dryrun(func):
    """Decorate ``func`` so that it runs inside ``transaction.atomic()``.

    A ``DryRun`` exception raised by ``func`` propagates out of the atomic
    block and is then swallowed here; in that case the wrapper returns
    ``None``.  Any other exception is left untouched, and a normal return
    value is passed through.
    """
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        try:
            with transaction.atomic():
                result = func(*args, **kwargs)
        except DryRun:
            # Expected outcome of a dry run: nothing to report to the caller.
            return None
        return result

    return wrapper
class Command(BaseCommand):
    """Import a locality, its OIDC services and its users from JSON files.

    Each file named on the command line must contain a JSON object with a
    ``locality`` entry and optional ``services`` and ``users`` lists.  Every
    file is imported through ``do``, which is wrapped by ``dryrun``: unless
    ``--no-dry-run`` is given, ``DryRun`` is raised at the end of the import.
    """

    help = 'Import a locality, its services and its users from JSON files'

    def add_arguments(self, parser):
        """Declare the command line options."""
        # Must be a boolean flag: without ``store_true`` argparse expects a
        # value and would consume the first path as that value.
        parser.add_argument('--no-dry-run', action='store_true')
        parser.add_argument('paths', nargs='+')

    def handle(self, paths, no_dry_run=False, verbosity=1, **options):
        """Load every JSON file in ``paths`` and import its content."""
        self.no_dry_run = no_dry_run
        self.verbosity = verbosity
        for path in paths:
            # Close the file as soon as it is parsed.
            with open(path) as fd:
                content = json.load(fd)
            self.do(content=content)

    def info(self, *args, **kwargs):
        """Write to stdout unless verbosity is below 1."""
        if self.verbosity >= 1:
            self.stdout.write(*args, **kwargs)

    @dryrun
    def do(self, content):
        """Import one parsed JSON document.

        ``content`` must hold a ``locality`` dict and may hold ``services``
        and ``users`` lists.  Raises ``DryRun`` at the end unless
        ``self.no_dry_run`` is true; ``AssertionError`` or ``KeyError`` are
        raised on malformed content.
        """
        ou = self._import_locality(content['locality'])

        content_services = content.get('services', [])
        assert isinstance(content_services, list)
        services = self._import_services(ou, content_services)

        content_users = content.get('users', [])
        assert isinstance(content_users, list)
        for content_user in content_users:
            self._import_user(content_user, services)

        if self.no_dry_run:
            return
        raise DryRun

    def _import_locality(self, locality):
        """Create or rename the organizational unit matching the locality slug."""
        self.info('Locality %s' % locality['name'], ending=' ')
        ou, created = OrganizationalUnit.objects.get_or_create(
            slug=locality['slug'],
            defaults={'name': locality['name']})
        if created:
            self.info(self.style.SUCCESS('CREATED'))
        elif ou.name != locality['name']:
            ou.name = locality['name']
            ou.save()
            self.info(self.style.SUCCESS('UPDATED'))
        else:
            self.info('unchanged')
        return ou

    def _import_services(self, ou, content_services):
        """Create or update the OIDC clients of ``ou`` and their access roles.

        Returns a dict mapping each service slug to a dict holding its
        ``oidc_client`` and, only for services that are not open to all, its
        ``access_role``.
        """
        services = {}
        for service in content_services:
            name = service['name']
            self.info('Service %s ' % name, ending=' ')
            slug = service['slug']
            client_id = service['client_id']
            client_secret = service['client_secret']
            frontchannel_logout_uri = service['frontchannel_logout_uri']
            assert isinstance(frontchannel_logout_uri, six.text_type)
            post_logout_redirect_uris = service.get('post_logout_redirect_uris', [])
            assert isinstance(post_logout_redirect_uris, list)
            open_to_all = service.get('open_to_all', False)
            redirect_uris = service.get('redirect_uris', [])
            assert isinstance(redirect_uris, list)

            # One set of values for both creation and update, so the URI
            # lists are always compared and stored in their newline-joined
            # form (comparing the raw lists would always report a change and
            # assign a list to the field).
            attributes = {
                'name': name,
                'client_id': client_id,
                'client_secret': client_secret,
                'frontchannel_logout_uri': frontchannel_logout_uri,
                'post_logout_redirect_uris': '\n'.join(post_logout_redirect_uris),
                'redirect_uris': '\n'.join(redirect_uris),
            }
            oidc_client, created = OIDCClient.objects.get_or_create(
                slug=slug, ou=ou, defaults=attributes)
            services[slug] = {
                'oidc_client': oidc_client,
            }
            if created:
                self.info(self.style.SUCCESS('CREATED'))
            else:
                modified = False
                for key, value in attributes.items():
                    if getattr(oidc_client, key) != value:
                        setattr(oidc_client, key, value)
                        modified = True
                # FIXME: open_to_all
                if modified:
                    oidc_client.save()
                    self.info(self.style.SUCCESS('MODIFIED'))
                else:
                    self.info('unchanged')

            if open_to_all:
                # No access role is needed; drop a possibly pre-existing one.
                Role.objects.filter(slug=slug, ou=ou).delete()
            else:
                access_role, created = Role.objects.get_or_create(
                    slug=slug,
                    ou=ou,
                    defaults={
                        'name': name,
                    })
                if not created and access_role.name != name:
                    access_role.name = name
                    access_role.save()
                services[slug]['access_role'] = access_role
        return services

    def _import_user(self, content_user, services):
        """Create or update one user and synchronise its service accesses.

        The user is looked up by ``uuid`` when one is provided, by
        ``username`` otherwise.
        """
        required = ('email', 'username')
        data = {}
        for string_key in ('email', 'first_name', 'last_name', 'password', 'username'):
            assert string_key in content_user, 'missing key ' + string_key
            value = content_user[string_key]
            assert isinstance(value, six.text_type), 'invalid type for key ' + string_key
            if string_key in required:
                assert value, 'missing value for key ' + string_key + ' %s' % content_user
            data[string_key] = value
        assert data['password'].startswith('{SSHA}')

        uuid = content_user.get('uuid') or None
        assert uuid is None or (isinstance(uuid, six.text_type) and uuid), \
            'invalid uuid %s %s' % (uuid, content_user)
        allowed_services = content_user.get('allowed_services', [])
        assert isinstance(allowed_services, list)

        defaults = data.copy()
        if uuid is not None:
            self.info('User %s-%s' % (data['username'], uuid), ending=' ')
            kwargs = {
                'uuid': uuid,
                'defaults': defaults,
            }
        else:
            self.info('User %s' % data['username'], ending=' ')
            # The username is the lookup key here, so it leaves the defaults.
            kwargs = {
                'username': defaults.pop('username'),
                'defaults': defaults,
            }
        user, created = User.objects.get_or_create(**kwargs)
        if created:
            self.info(self.style.SUCCESS('CREATED'))
        else:
            modified = False
            for key, value in defaults.items():
                if getattr(user, key) != value:
                    setattr(user, key, value)
                    modified = True
            if modified:
                user.save()
                self.info(self.style.SUCCESS('MODIFIED'))
            else:
                self.info('unchanged')

        for service_slug in allowed_services:
            self._sync_access(user, services[service_slug], allowed=True)
        # Sorted only to make the output order deterministic.
        for service_slug in sorted(set(services) - set(allowed_services)):
            self._sync_access(user, services[service_slug], allowed=False)

    def _sync_access(self, user, service, allowed):
        """Add ``user`` to, or remove it from, the access role of ``service``."""
        role = service.get('access_role')
        if role is None:
            # Services open to all have no access role: there is nothing to
            # grant or revoke (and no 'access_role' key to look up).
            return
        self.info('Access to service %s' % service['oidc_client'].name, ending=' ')
        is_member = role.members.filter(pk=user.pk).exists()
        if allowed == is_member:
            self.info('unchanged')
        elif allowed:
            role.members.add(user)
            self.info(self.style.SUCCESS('ADDED'))
        else:
            role.members.remove(user)
            self.info(self.style.SUCCESS('REMOVED'))