api: accept get/update_or_create parameter to user and role creation endpoint (fixes #22376)
This commit is contained in:
parent
f685bb066e
commit
d03f4fc8d3
|
@ -0,0 +1,100 @@
|
|||
# authentic2 - versatile identity manager
|
||||
# Copyright (C) 2010-2018 Entr'ouvert
|
||||
#
|
||||
# This program is free software: you can redistribute it and/or modify it
|
||||
# under the terms of the GNU Affero General Public License as published
|
||||
# by the Free Software Foundation, either version 3 of the License, or
|
||||
# (at your option) any later version.
|
||||
#
|
||||
# This program is distributed in the hope that it will be useful,
|
||||
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
# GNU Affero General Public License for more details.
|
||||
#
|
||||
# You should have received a copy of the GNU Affero General Public License
|
||||
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
|
||||
from django.db import transaction
|
||||
|
||||
from rest_framework.serializers import raise_errors_on_nested_writes
|
||||
from rest_framework.utils import model_meta
|
||||
|
||||
|
||||
class GetOrCreateModelSerializer(object):
|
||||
def get_or_create(self, keys, validated_data):
|
||||
raise_errors_on_nested_writes('get_or_create', self, validated_data)
|
||||
|
||||
ModelClass = self.Meta.model
|
||||
|
||||
# Remove many-to-many relationships from validated_data.
|
||||
# They are not valid arguments to the default `.create()` method,
|
||||
# as they require that the instance has already been saved.
|
||||
info = model_meta.get_field_info(ModelClass)
|
||||
many_to_many = {}
|
||||
for field_name, relation_info in info.relations.items():
|
||||
if relation_info.to_many and (field_name in validated_data):
|
||||
many_to_many[field_name] = validated_data.pop(field_name)
|
||||
|
||||
kwargs = {}
|
||||
defaults = kwargs['defaults'] = {}
|
||||
missing_keys = set(keys) - set(validated_data)
|
||||
if missing_keys:
|
||||
raise TypeError('Keys %s are missing' % missing_keys)
|
||||
for key, value in validated_data.items():
|
||||
if key in keys:
|
||||
kwargs[key] = value
|
||||
else:
|
||||
defaults[key] = value
|
||||
with transaction.atomic():
|
||||
instance, created = self.Meta.model._default_manager.get_or_create(**kwargs)
|
||||
if many_to_many and created:
|
||||
self.update(instance, many_to_many)
|
||||
return instance
|
||||
|
||||
def update_or_create(self, keys, validated_data):
|
||||
raise_errors_on_nested_writes('update_or_create', self, validated_data)
|
||||
|
||||
ModelClass = self.Meta.model
|
||||
|
||||
# Remove many-to-many relationships from validated_data.
|
||||
# They are not valid arguments to the default `.create()` method,
|
||||
# as they require that the instance has already been saved.
|
||||
info = model_meta.get_field_info(ModelClass)
|
||||
many_to_many = {}
|
||||
get_or_create_data = validated_data.copy()
|
||||
for field_name, relation_info in info.relations.items():
|
||||
if relation_info.to_many and (field_name in validated_data):
|
||||
many_to_many[field_name] = get_or_create_data.pop(field_name)
|
||||
|
||||
kwargs = {}
|
||||
defaults = kwargs['defaults'] = {}
|
||||
missing_keys = set(keys) - set(get_or_create_data)
|
||||
if missing_keys:
|
||||
raise TypeError('Keys %s are missing' % missing_keys)
|
||||
for key, value in get_or_create_data.items():
|
||||
if key in keys:
|
||||
kwargs[key] = value
|
||||
else:
|
||||
defaults[key] = value
|
||||
with transaction.atomic():
|
||||
instance, created = self.Meta.model._default_manager.get_or_create(**kwargs)
|
||||
if many_to_many or not created:
|
||||
self.update(instance, validated_data)
|
||||
return instance
|
||||
|
||||
def create(self, validated_data):
|
||||
try:
|
||||
keys = self.context['view'].request.GET.getlist('get_or_create')
|
||||
except Exception:
|
||||
pass
|
||||
else:
|
||||
if keys:
|
||||
return self.get_or_create(keys, validated_data)
|
||||
try:
|
||||
keys = self.context['view'].request.GET.getlist('update_or_create')
|
||||
except Exception:
|
||||
pass
|
||||
else:
|
||||
if keys:
|
||||
return self.update_or_create(keys, validated_data)
|
||||
return super(GetOrCreateModelSerializer, self).create(validated_data)
|
|
@ -46,7 +46,8 @@ from django_filters.rest_framework import FilterSet
|
|||
|
||||
from .passwords import get_password_checker
|
||||
from .custom_user.models import User
|
||||
from . import utils, decorators, attribute_kinds, app_settings, hooks
|
||||
from . import (utils, decorators, attribute_kinds, app_settings, hooks,
|
||||
api_mixins)
|
||||
from .models import Attribute, PasswordReset, Service
|
||||
from .a2_rbac.utils import get_default_ou
|
||||
|
||||
|
@ -321,7 +322,8 @@ def user(request):
|
|||
return request.user.to_json()
|
||||
|
||||
|
||||
class BaseUserSerializer(serializers.ModelSerializer):
|
||||
class BaseUserSerializer(api_mixins.GetOrCreateModelSerializer,
|
||||
serializers.ModelSerializer):
|
||||
ou = serializers.SlugRelatedField(
|
||||
queryset=get_ou_model().objects.all(),
|
||||
slug_field='slug',
|
||||
|
@ -490,7 +492,7 @@ class BaseUserSerializer(serializers.ModelSerializer):
|
|||
exclude = ('date_joined', 'user_permissions', 'groups', 'last_login')
|
||||
|
||||
|
||||
class RoleSerializer(serializers.ModelSerializer):
|
||||
class RoleSerializer(api_mixins.GetOrCreateModelSerializer, serializers.ModelSerializer):
|
||||
ou = serializers.SlugRelatedField(
|
||||
many=False,
|
||||
required=False,
|
||||
|
|
|
@ -22,22 +22,25 @@ import random
|
|||
import uuid
|
||||
|
||||
|
||||
from django.core.urlresolvers import reverse
|
||||
from django.contrib.auth.hashers import check_password
|
||||
from django.contrib.auth import get_user_model
|
||||
from django.contrib.contenttypes.models import ContentType
|
||||
from authentic2.a2_rbac.utils import get_default_ou
|
||||
from django_rbac.utils import get_role_model, get_ou_model
|
||||
from django_rbac.models import SEARCH_OP
|
||||
from authentic2.models import Service
|
||||
from django.core import mail
|
||||
from django.contrib.auth.hashers import check_password
|
||||
from django.core.urlresolvers import reverse
|
||||
|
||||
from authentic2_idp_oidc.models import OIDCClient
|
||||
from django_rbac.models import SEARCH_OP
|
||||
from django_rbac.utils import get_role_model, get_ou_model
|
||||
|
||||
from authentic2.a2_rbac.models import Role
|
||||
from authentic2.a2_rbac.utils import get_default_ou
|
||||
from authentic2.models import Service
|
||||
|
||||
from utils import login, basic_authorization_header, get_link_from_mail
|
||||
|
||||
pytestmark = pytest.mark.django_db
|
||||
|
||||
User = get_user_model()
|
||||
|
||||
|
||||
def test_api_user_simple(logged_app):
|
||||
resp = logged_app.get('/api/user/')
|
||||
|
@ -1146,3 +1149,75 @@ def test_validate_password_regex(app, settings):
|
|||
assert response.json['checks'][3]['result'] is True
|
||||
assert response.json['checks'][4]['label'] == 'must contain "ok"'
|
||||
assert response.json['checks'][4]['result'] is True
|
||||
|
||||
|
||||
def test_api_users_get_or_create(settings, app, admin):
|
||||
app.authorization = ('Basic', (admin.username, admin.username))
|
||||
# test missing first_name
|
||||
payload = {
|
||||
'email': 'john.doe@example.net',
|
||||
'first_name': 'John',
|
||||
'last_name': 'Doe',
|
||||
}
|
||||
resp = app.post_json('/api/users/?get_or_create=email', params=payload, status=201)
|
||||
id = resp.json['id']
|
||||
assert User.objects.get(id=id).first_name == 'John'
|
||||
assert User.objects.get(id=id).last_name == 'Doe'
|
||||
|
||||
resp = app.post_json('/api/users/?get_or_create=email', params=payload, status=201)
|
||||
assert id == resp.json['id']
|
||||
assert User.objects.get(id=id).first_name == 'John'
|
||||
assert User.objects.get(id=id).last_name == 'Doe'
|
||||
|
||||
payload['first_name'] = 'Jane'
|
||||
resp = app.post_json('/api/users/?update_or_create=email', params=payload, status=201)
|
||||
assert id == resp.json['id']
|
||||
assert User.objects.get(id=id).first_name == 'Jane'
|
||||
assert User.objects.get(id=id).last_name == 'Doe'
|
||||
|
||||
|
||||
def test_api_users_get_or_create_multi_key(settings, app, admin):
|
||||
app.authorization = ('Basic', (admin.username, admin.username))
|
||||
# test missing first_name
|
||||
payload = {
|
||||
'email': 'john.doe@example.net',
|
||||
'first_name': 'John',
|
||||
'last_name': 'Doe',
|
||||
}
|
||||
resp = app.post_json('/api/users/?get_or_create=first_name&get_or_create=last_name', params=payload, status=201)
|
||||
id = resp.json['id']
|
||||
assert User.objects.get(id=id).first_name == 'John'
|
||||
assert User.objects.get(id=id).last_name == 'Doe'
|
||||
|
||||
resp = app.post_json('/api/users/?get_or_create=first_name&get_or_create=last_name', params=payload, status=201)
|
||||
assert id == resp.json['id']
|
||||
assert User.objects.get(id=id).first_name == 'John'
|
||||
assert User.objects.get(id=id).last_name == 'Doe'
|
||||
|
||||
payload['email'] = 'john.doe@example2.net'
|
||||
resp = app.post_json('/api/users/?update_or_create=first_name&update_or_create=last_name', params=payload, status=201)
|
||||
assert id == resp.json['id']
|
||||
assert User.objects.get(id=id).email == 'john.doe@example2.net'
|
||||
|
||||
|
||||
def test_api_roles_get_or_create(settings, ou1, app, admin):
|
||||
app.authorization = ('Basic', (admin.username, admin.username))
|
||||
# test missing first_name
|
||||
payload = {
|
||||
'ou_slug': 'ou1',
|
||||
'name': 'Role 1',
|
||||
'slug': 'role-1',
|
||||
}
|
||||
resp = app.post_json('/api/roles/?get_or_create=slug', params=payload, status=201)
|
||||
uuid = resp.json['uuid']
|
||||
assert Role.objects.get(uuid=uuid).name == 'Role 1'
|
||||
assert Role.objects.get(uuid=uuid).slug == 'role-1'
|
||||
|
||||
resp = app.post_json('/api/roles/?get_or_create=slug', params=payload, status=201)
|
||||
assert uuid == resp.json['uuid']
|
||||
|
||||
payload['name'] = 'Role 2'
|
||||
resp = app.post_json('/api/roles/?update_or_create=slug', params=payload, status=201)
|
||||
assert uuid == resp.json['uuid']
|
||||
assert Role.objects.get(uuid=uuid).name == 'Role 2'
|
||||
assert Role.objects.get(uuid=uuid).slug == 'role-1'
|
||||
|
|
Loading…
Reference in New Issue