107 lines
4.3 KiB
Python
107 lines
4.3 KiB
Python
# hobo - portal to configure and deploy applications
# Copyright (C) 2015-2021 Entr'ouvert
#
# This program is free software: you can redistribute it and/or modify it
# under the terms of the GNU Affero General Public License as published
# by the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
|
|
|
from django.conf import settings
|
|
|
|
from rest_framework import permissions, serializers, status
|
|
from rest_framework.response import Response
|
|
from rest_framework.generics import GenericAPIView
|
|
|
|
from . import provisionning
|
|
|
|
|
|
class ProvisionSerializer(serializers.Serializer):
    """Validate the body of a provisioning request.

    Exactly one of ``user_uuid`` or ``role_uuid`` must be supplied.
    ``service_type`` and ``service_url`` are optional filters narrowing the
    set of services that get notified.
    """

    user_uuid = serializers.CharField(required=False)
    role_uuid = serializers.CharField(required=False)
    service_type = serializers.CharField(required=False)
    service_url = serializers.CharField(required=False)

    def validate_service_type(self, value):
        """Reject service types that are not keys of ``settings.KNOWN_SERVICES``.

        Validating here gives the client a 400 validation error instead of
        letting an unknown type reach the provisioning engine.
        """
        if value not in getattr(settings, 'KNOWN_SERVICES', {}):
            raise serializers.ValidationError('unknown service type')
        return value

    def validate(self, data):
        """Ensure exactly one of ``user_uuid`` / ``role_uuid`` is provided.

        Raises:
            serializers.ValidationError: if neither or both are given.
        """
        if not (data.get('user_uuid') or data.get('role_uuid')):
            raise serializers.ValidationError('must provide user_uuid or role_uuid')
        if data.get('user_uuid') and data.get('role_uuid'):
            raise serializers.ValidationError('cannot provision both user & role')
        return data
|
|
|
|
|
|
class ProvisionView(GenericAPIView):
    """Trigger provisioning of a single user or role.

    The POST body is validated by ``ProvisionSerializer``. The JSON response
    contains ``err`` (0 when every targeted service was reached, 1 otherwise)
    together with ``leftover_audience`` and ``reached_audience`` as reported by
    ``ApiProvisionningEngine``.
    """

    # NOTE(review): any authenticated user may trigger provisioning of any
    # user or role; confirm whether a stricter permission is intended.
    permission_classes = (permissions.IsAuthenticated,)
    serializer_class = ProvisionSerializer

    def post(self, request):
        """Provision the requested user or role and report which services were reached."""
        serializer = self.get_serializer(data=request.data)
        if not serializer.is_valid():
            return Response({'err': 1, 'errors': serializer.errors}, status.HTTP_400_BAD_REQUEST)
        data = serializer.validated_data

        engine = ApiProvisionningEngine(
            service_type=data.get('service_type'),
            service_url=data.get('service_url'),
        )

        user_uuid = data.get('user_uuid')
        role_uuid = data.get('role_uuid')
        if user_uuid:
            try:
                user = provisionning.User.objects.get(uuid=user_uuid)
            except provisionning.User.DoesNotExist:
                # Unknown objects are reported with HTTP 200 and err=1.
                return Response({'err': 1, 'err_desc': 'unknown user UUID'})
            engine.notify_users(ous=None, users=[user])
        elif role_uuid:
            try:
                role = provisionning.Role.objects.get(uuid=role_uuid)
            except provisionning.Role.DoesNotExist:
                return Response({'err': 1, 'err_desc': 'unknown role UUID'})
            ous = {ou.id: ou for ou in provisionning.OU.objects.all()}
            engine.notify_roles(ous=ous, roles=[role])

        # The audience attributes are populated by the engine's notify_agents();
        # fall back to empty lists if they were never set (e.g. notify_agents()
        # was not reached) rather than failing with AttributeError.
        leftover_audience = getattr(engine, 'leftover_audience', [])
        reached_audience = getattr(engine, 'reached_audience', [])
        return Response(
            {
                'err': 1 if leftover_audience else 0,
                'leftover_audience': leftover_audience,
                'reached_audience': reached_audience,
            }
        )
|
|
|
|
|
|
# Plain view callable built from ProvisionView via as_view();
# presumably referenced from the URL configuration — confirm.
provision_view = ProvisionView.as_view()
|
|
|
|
|
|
class ApiProvisionningEngine(provisionning.Provisionning):
    """Provisioning engine used by the API, optionally restricted to some services.

    ``service_type`` limits HTTP provisioning to the services of that type in
    ``settings.KNOWN_SERVICES`` that declare a ``provisionning-url``.
    ``service_url`` further keeps only services whose ``url`` contains it
    (substring match). After notification, ``leftover_audience`` lists the
    filtered services that were not reached and ``reached_audience`` the rest.
    """

    def __init__(self, service_type=None, service_url=None):
        super().__init__()
        self.service_type = service_type
        self.service_url = service_url
        # Initialised up front so callers can always read them, even when
        # notify_agents() is never invoked during a notification.
        self.leftover_audience = []
        self.reached_audience = []

    def get_http_services_by_url(self):
        """Return the HTTP-provisioned services to notify, keyed by SAML metadata URL."""
        if self.service_type:
            # An unknown service type yields no services instead of a KeyError.
            services_of_type = settings.KNOWN_SERVICES.get(self.service_type, {})
            services_by_url = {
                service['saml-sp-metadata-url']: service
                for service in services_of_type.values()
                if service.get('provisionning-url')
            }
        else:
            services_by_url = super().get_http_services_by_url()
        if self.service_url:
            services_by_url = {k: v for k, v in services_by_url.items() if self.service_url in v['url']}
        return services_by_url

    def notify_agents(self, data):
        """Send ``data`` over HTTP and record reached / leftover filtered services."""
        leftover = self.notify_agents_http(data)
        services_by_url = self.get_http_services_by_url()
        # Only report leftovers among the filtered services; the list returned
        # by notify_agents_http() may contain other audience entries.
        self.leftover_audience = [x for x in leftover if x in services_by_url]
        self.reached_audience = [x for x in services_by_url if x not in self.leftover_audience]
|