api: accept get/update_or_create parameter to user and role creation endpoint (fixes #22376)

This commit is contained in:
Benjamin Dauvergne 2019-03-15 03:17:04 +01:00
parent f685bb066e
commit d03f4fc8d3
3 changed files with 187 additions and 10 deletions

View File

@ -0,0 +1,100 @@
# authentic2 - versatile identity manager
# Copyright (C) 2010-2018 Entr'ouvert
#
# This program is free software: you can redistribute it and/or modify it
# under the terms of the GNU Affero General Public License as published
# by the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
from django.db import transaction
from rest_framework.serializers import raise_errors_on_nested_writes
from rest_framework.utils import model_meta
class GetOrCreateModelSerializer(object):
def get_or_create(self, keys, validated_data):
raise_errors_on_nested_writes('get_or_create', self, validated_data)
ModelClass = self.Meta.model
# Remove many-to-many relationships from validated_data.
# They are not valid arguments to the default `.create()` method,
# as they require that the instance has already been saved.
info = model_meta.get_field_info(ModelClass)
many_to_many = {}
for field_name, relation_info in info.relations.items():
if relation_info.to_many and (field_name in validated_data):
many_to_many[field_name] = validated_data.pop(field_name)
kwargs = {}
defaults = kwargs['defaults'] = {}
missing_keys = set(keys) - set(validated_data)
if missing_keys:
raise TypeError('Keys %s are missing' % missing_keys)
for key, value in validated_data.items():
if key in keys:
kwargs[key] = value
else:
defaults[key] = value
with transaction.atomic():
instance, created = self.Meta.model._default_manager.get_or_create(**kwargs)
if many_to_many and created:
self.update(instance, many_to_many)
return instance
def update_or_create(self, keys, validated_data):
raise_errors_on_nested_writes('update_or_create', self, validated_data)
ModelClass = self.Meta.model
# Remove many-to-many relationships from validated_data.
# They are not valid arguments to the default `.create()` method,
# as they require that the instance has already been saved.
info = model_meta.get_field_info(ModelClass)
many_to_many = {}
get_or_create_data = validated_data.copy()
for field_name, relation_info in info.relations.items():
if relation_info.to_many and (field_name in validated_data):
many_to_many[field_name] = get_or_create_data.pop(field_name)
kwargs = {}
defaults = kwargs['defaults'] = {}
missing_keys = set(keys) - set(get_or_create_data)
if missing_keys:
raise TypeError('Keys %s are missing' % missing_keys)
for key, value in get_or_create_data.items():
if key in keys:
kwargs[key] = value
else:
defaults[key] = value
with transaction.atomic():
instance, created = self.Meta.model._default_manager.get_or_create(**kwargs)
if many_to_many or not created:
self.update(instance, validated_data)
return instance
def create(self, validated_data):
try:
keys = self.context['view'].request.GET.getlist('get_or_create')
except Exception:
pass
else:
if keys:
return self.get_or_create(keys, validated_data)
try:
keys = self.context['view'].request.GET.getlist('update_or_create')
except Exception:
pass
else:
if keys:
return self.update_or_create(keys, validated_data)
return super(GetOrCreateModelSerializer, self).create(validated_data)

View File

@ -46,7 +46,8 @@ from django_filters.rest_framework import FilterSet
from .passwords import get_password_checker
from .custom_user.models import User
from . import utils, decorators, attribute_kinds, app_settings, hooks
from . import (utils, decorators, attribute_kinds, app_settings, hooks,
api_mixins)
from .models import Attribute, PasswordReset, Service
from .a2_rbac.utils import get_default_ou
@ -321,7 +322,8 @@ def user(request):
return request.user.to_json()
class BaseUserSerializer(serializers.ModelSerializer):
class BaseUserSerializer(api_mixins.GetOrCreateModelSerializer,
serializers.ModelSerializer):
ou = serializers.SlugRelatedField(
queryset=get_ou_model().objects.all(),
slug_field='slug',
@ -490,7 +492,7 @@ class BaseUserSerializer(serializers.ModelSerializer):
exclude = ('date_joined', 'user_permissions', 'groups', 'last_login')
class RoleSerializer(serializers.ModelSerializer):
class RoleSerializer(api_mixins.GetOrCreateModelSerializer, serializers.ModelSerializer):
ou = serializers.SlugRelatedField(
many=False,
required=False,

View File

@ -22,22 +22,25 @@ import random
import uuid
from django.core.urlresolvers import reverse
from django.contrib.auth.hashers import check_password
from django.contrib.auth import get_user_model
from django.contrib.contenttypes.models import ContentType
from authentic2.a2_rbac.utils import get_default_ou
from django_rbac.utils import get_role_model, get_ou_model
from django_rbac.models import SEARCH_OP
from authentic2.models import Service
from django.core import mail
from django.contrib.auth.hashers import check_password
from django.core.urlresolvers import reverse
from authentic2_idp_oidc.models import OIDCClient
from django_rbac.models import SEARCH_OP
from django_rbac.utils import get_role_model, get_ou_model
from authentic2.a2_rbac.models import Role
from authentic2.a2_rbac.utils import get_default_ou
from authentic2.models import Service
from utils import login, basic_authorization_header, get_link_from_mail
pytestmark = pytest.mark.django_db
User = get_user_model()
def test_api_user_simple(logged_app):
resp = logged_app.get('/api/user/')
@ -1146,3 +1149,75 @@ def test_validate_password_regex(app, settings):
assert response.json['checks'][3]['result'] is True
assert response.json['checks'][4]['label'] == 'must contain "ok"'
assert response.json['checks'][4]['result'] is True
def test_api_users_get_or_create(settings, app, admin):
app.authorization = ('Basic', (admin.username, admin.username))
# test missing first_name
payload = {
'email': 'john.doe@example.net',
'first_name': 'John',
'last_name': 'Doe',
}
resp = app.post_json('/api/users/?get_or_create=email', params=payload, status=201)
id = resp.json['id']
assert User.objects.get(id=id).first_name == 'John'
assert User.objects.get(id=id).last_name == 'Doe'
resp = app.post_json('/api/users/?get_or_create=email', params=payload, status=201)
assert id == resp.json['id']
assert User.objects.get(id=id).first_name == 'John'
assert User.objects.get(id=id).last_name == 'Doe'
payload['first_name'] = 'Jane'
resp = app.post_json('/api/users/?update_or_create=email', params=payload, status=201)
assert id == resp.json['id']
assert User.objects.get(id=id).first_name == 'Jane'
assert User.objects.get(id=id).last_name == 'Doe'
def test_api_users_get_or_create_multi_key(settings, app, admin):
app.authorization = ('Basic', (admin.username, admin.username))
# test missing first_name
payload = {
'email': 'john.doe@example.net',
'first_name': 'John',
'last_name': 'Doe',
}
resp = app.post_json('/api/users/?get_or_create=first_name&get_or_create=last_name', params=payload, status=201)
id = resp.json['id']
assert User.objects.get(id=id).first_name == 'John'
assert User.objects.get(id=id).last_name == 'Doe'
resp = app.post_json('/api/users/?get_or_create=first_name&get_or_create=last_name', params=payload, status=201)
assert id == resp.json['id']
assert User.objects.get(id=id).first_name == 'John'
assert User.objects.get(id=id).last_name == 'Doe'
payload['email'] = 'john.doe@example2.net'
resp = app.post_json('/api/users/?update_or_create=first_name&update_or_create=last_name', params=payload, status=201)
assert id == resp.json['id']
assert User.objects.get(id=id).email == 'john.doe@example2.net'
def test_api_roles_get_or_create(settings, ou1, app, admin):
app.authorization = ('Basic', (admin.username, admin.username))
# test missing first_name
payload = {
'ou_slug': 'ou1',
'name': 'Role 1',
'slug': 'role-1',
}
resp = app.post_json('/api/roles/?get_or_create=slug', params=payload, status=201)
uuid = resp.json['uuid']
assert Role.objects.get(uuid=uuid).name == 'Role 1'
assert Role.objects.get(uuid=uuid).slug == 'role-1'
resp = app.post_json('/api/roles/?get_or_create=slug', params=payload, status=201)
assert uuid == resp.json['uuid']
payload['name'] = 'Role 2'
resp = app.post_json('/api/roles/?update_or_create=slug', params=payload, status=201)
assert uuid == resp.json['uuid']
assert Role.objects.get(uuid=uuid).name == 'Role 2'
assert Role.objects.get(uuid=uuid).slug == 'role-1'