authentic/src/authentic2_idp_oidc/apps.py

111 lines
4.6 KiB
Python

# authentic2_idp_oidc - Authentic2 OIDC IdP plugin
# Copyright (C) 2017 Entr'ouvert
#
# This program is free software: you can redistribute it and/or modify it
# under the terms of the GNU Affero General Public License as published
# by the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
import django.apps
from django.utils.encoding import smart_bytes
from rest_framework.exceptions import APIException
class AppConfig(django.apps.AppConfig):
    '''Application configuration of the authentic2 OIDC IdP plugin.

    The a2_hook_api_* methods below translate between plain user UUIDs and
    encrypted pairwise identifiers (sub) whenever request.user carries an
    oidc_client whose identifier policy is POLICY_PAIRWISE_REVERSIBLE.
    '''

    name = 'authentic2_idp_oidc'

    @staticmethod
    def _get_pairwise_client(request):
        '''Return the OIDC client attached to request.user if it uses the
        reversible pairwise identifier policy, None otherwise.
        '''
        client = getattr(request.user, 'oidc_client', None)
        if client is None:
            return None
        if client.identifier_policy != client.POLICY_PAIRWISE_REVERSIBLE:
            return None
        return client

    @staticmethod
    def _decrypt_sub(client, sub):
        '''Return the hexadecimal UUID hidden behind a pairwise sub, or None
        when the sub cannot be deciphered.
        '''
        # imports are kept local, like in the hooks, so that loading the
        # application configuration does not import the plugin's utils module
        import uuid

        from . import utils

        decrypted = utils.reverse_pairwise_sub(client, smart_bytes(sub))
        if not decrypted:
            return None
        try:
            return uuid.UUID(bytes=decrypted).hex
        except ValueError:
            # uuid.UUID() refuses payloads which are not exactly 16 bytes long;
            # the sub comes from the caller, so treat it as undecipherable
            # instead of letting the exception escape the hook
            return None

    # implement translation of encrypted pairwise identifiers when an OIDC client is using the
    # A2 API
    def a2_hook_api_modify_serializer(self, view, serializer):
        '''Replace the uuid field of the serializer by the pairwise sub of the
        serialized user for the requesting OIDC client.
        '''
        from rest_framework import serializers

        from . import utils

        client = self._get_pairwise_client(view.request)
        if client is None:
            return

        def get_oidc_uuid(user):
            return utils.make_pairwise_reversible_sub(client, user)

        # the function is set on the serializer instance, it is the method
        # looked up by name by the SerializerMethodField declared just after
        serializer.get_oidc_uuid = get_oidc_uuid
        serializer.fields['uuid'] = serializers.SerializerMethodField(
            method_name='get_oidc_uuid')

    def a2_hook_api_modify_view_before_get_object(self, view):
        '''Decrypt sub used as pk argument in URL.'''
        client = self._get_pairwise_client(view.request)
        if client is None:
            return
        lookup_url_kwarg = view.lookup_url_kwarg or view.lookup_field
        if lookup_url_kwarg not in view.kwargs:
            return
        uuid_hex = self._decrypt_sub(client, view.kwargs[lookup_url_kwarg])
        if uuid_hex:
            view.kwargs[lookup_url_kwarg] = uuid_hex
        # an undecipherable sub is left untouched in view.kwargs

    def a2_hook_api_modify_serializer_after_validation(self, view, serializer):
        '''Replace the pairwise subs received in known_uuids by hexadecimal
        UUIDs.

        Only applies to the SynchronizationSerializer of the UsersAPI view.
        Two attributes are saved on the request for
        a2_hook_api_modify_response():

        - uuid_map: maps each deciphered UUID to the sub it was received as,
        - unknown_uuids: the subs which could not be deciphered.
        '''
        if view.__class__.__name__ != 'UsersAPI':
            return
        if serializer.__class__.__name__ != 'SynchronizationSerializer':
            return
        request = view.request
        client = self._get_pairwise_client(request)
        if client is None:
            return
        new_known_uuids = []
        uuid_map = request.uuid_map = {}
        request.unknown_uuids = []
        for sub in serializer.validated_data['known_uuids']:
            uuid_hex = self._decrypt_sub(client, sub)
            if uuid_hex:
                new_known_uuids.append(uuid_hex)
                uuid_map[uuid_hex] = sub
            else:
                # undecipherable subs are removed from known_uuids and kept
                # aside, they are reported as unknown in the response
                request.unknown_uuids.append(sub)
        serializer.validated_data['known_uuids'] = new_known_uuids

    def a2_hook_api_modify_response(self, view, method_name, data):
        '''Reverse mapping applied in a2_hook_api_modify_serializer_after_validation using the
        uuid_map saved on the request.

        data['unknown_uuids'] is replaced by the list of the original subs of
        the unknown UUIDs, followed by the subs which could not be deciphered.
        '''
        request = view.request
        if not hasattr(request.user, 'oidc_client'):
            return
        if view.__class__.__name__ != 'UsersAPI':
            return
        if method_name != 'synchronization':
            return
        if not hasattr(request, 'uuid_map'):
            # the validation hook did not translate anything for this request
            return
        uuid_map = request.uuid_map
        new_unknown_uuids = [uuid_map[u] for u in data['unknown_uuids']]
        new_unknown_uuids.extend(request.unknown_uuids)
        data['unknown_uuids'] = new_unknown_uuids