api: recreate get/update_or_create mixin at the view level (#35710)

This commit is contained in:
Benjamin Dauvergne 2019-09-02 14:34:34 +02:00
parent 7669f2d659
commit aa584ad97d
3 changed files with 83 additions and 90 deletions

View File

@ -16,87 +16,68 @@
from django.db import transaction
from rest_framework.serializers import raise_errors_on_nested_writes
from rest_framework.settings import api_settings
from rest_framework.exceptions import ValidationError
from rest_framework.utils import model_meta
class GetOrCreateModelSerializer(object):
def get_or_create(self, keys, validated_data):
raise_errors_on_nested_writes('get_or_create', self, validated_data)
class GetOrCreateMixinView(object):
_lookup_object = None
ModelClass = self.Meta.model
def get_object(self):
if self._lookup_object is not None:
return self._lookup_object
return super(GetOrCreateMixinView, self).get_object()
def _get_lookup_keys(self, name):
return self.request.GET.getlist(name)
def _get_model_class(self):
serializer_class = self.get_serializer_class()
return serializer_class.Meta.model
def _lookup_instance(self, keys):
ModelClass = self._get_model_class()
kwargs = {}
for key in keys:
try:
kwargs[key] = self.request.data[key]
except KeyError:
raise ValidationError({api_settings.NON_FIELD_ERRORS_KEY: ['key %r is missing' % key]})
try:
return ModelClass.objects.get(**kwargs)
except ModelClass.DoesNotExist:
return None
def _validate_get_keys(self, keys):
ModelClass = self._get_model_class()
# Remove many-to-many relationships from validated_data.
# They are not valid arguments to the default `.create()` method,
# as they require that the instance has already been saved.
info = model_meta.get_field_info(ModelClass)
many_to_many = {}
for field_name, relation_info in info.relations.items():
if relation_info.to_many and (field_name in validated_data):
many_to_many[field_name] = validated_data.pop(field_name)
errors = []
for key in keys:
if key not in info.fields:
errors.append('unknown key %r' % key)
if key in info.relations:
errors.append('relation key %r cannot be used for lookup' % key)
if errors:
raise ValidationError({api_settings.NON_FIELD_ERRORS_KEY: errors})
kwargs = {}
defaults = kwargs['defaults'] = {}
missing_keys = set(keys) - set(validated_data)
if missing_keys:
raise TypeError('Keys %s are missing' % missing_keys)
for key, value in validated_data.items():
if key in keys:
kwargs[key] = value
else:
defaults[key] = value
with transaction.atomic():
instance, created = self.Meta.model._default_manager.get_or_create(**kwargs)
instance._a2_created = created
if many_to_many and created:
self.update(instance, many_to_many)
return instance
def update_or_create(self, keys, validated_data):
raise_errors_on_nested_writes('update_or_create', self, validated_data)
ModelClass = self.Meta.model
# Remove many-to-many relationships from validated_data.
# They are not valid arguments to the default `.create()` method,
# as they require that the instance has already been saved.
info = model_meta.get_field_info(ModelClass)
many_to_many = {}
get_or_create_data = validated_data.copy()
for field_name, relation_info in info.relations.items():
if relation_info.to_many and (field_name in validated_data):
many_to_many[field_name] = get_or_create_data.pop(field_name)
kwargs = {}
defaults = kwargs['defaults'] = {}
missing_keys = set(keys) - set(get_or_create_data)
if missing_keys:
raise TypeError('Keys %s are missing' % missing_keys)
for key, value in get_or_create_data.items():
if key in keys:
kwargs[key] = value
else:
defaults[key] = value
with transaction.atomic():
instance, created = self.Meta.model._default_manager.get_or_create(**kwargs)
instance._a2_created = created
if many_to_many or not created:
self.update(instance, validated_data)
return instance
def create(self, validated_data):
try:
keys = self.context['view'].request.GET.getlist('get_or_create')
except Exception:
pass
else:
if keys:
return self.get_or_create(keys, validated_data)
try:
keys = self.context['view'].request.GET.getlist('update_or_create')
except Exception:
pass
else:
if keys:
return self.update_or_create(keys, validated_data)
return super(GetOrCreateModelSerializer, self).create(validated_data)
@transaction.atomic
def create(self, request, *args, **kwargs):
get_or_create_keys = self._get_lookup_keys('get_or_create')
if get_or_create_keys:
self._validate_get_keys(get_or_create_keys)
self._lookup_object = self._lookup_instance(get_or_create_keys)
if self._lookup_object is not None:
return self.retrieve(request, *args, **kwargs)
update_or_create_keys = self._get_lookup_keys('update_or_create')
if update_or_create_keys:
self._validate_get_keys(update_or_create_keys)
self._lookup_object = self._lookup_instance(update_or_create_keys)
if self._lookup_object is not None:
return self.partial_update(request, *args, **kwargs)
return super(GetOrCreateMixinView, self).create(request, *args, **kwargs)

View File

@ -328,8 +328,7 @@ def user(request):
return request.user.to_json()
class BaseUserSerializer(api_mixins.GetOrCreateModelSerializer,
serializers.ModelSerializer):
class BaseUserSerializer(serializers.ModelSerializer):
ou = serializers.SlugRelatedField(
queryset=get_ou_model().objects.all(),
slug_field='slug',
@ -509,7 +508,7 @@ class BaseUserSerializer(api_mixins.GetOrCreateModelSerializer,
exclude = ('date_joined', 'user_permissions', 'groups', 'last_login')
class RoleSerializer(api_mixins.GetOrCreateModelSerializer, serializers.ModelSerializer):
class RoleSerializer(serializers.ModelSerializer):
ou = serializers.SlugRelatedField(
many=False,
required=False,
@ -608,7 +607,7 @@ class FreeTextSearchFilter(BaseFilterBackend):
return queryset
class UsersAPI(HookMixin, ExceptionHandlerMixin, ModelViewSet):
class UsersAPI(api_mixins.GetOrCreateMixinView, HookMixin, ExceptionHandlerMixin, ModelViewSet):
ordering_fields = ['username', 'first_name', 'last_name', 'modified', 'date_joined']
lookup_field = 'uuid'
serializer_class = BaseUserSerializer
@ -709,7 +708,7 @@ class UsersAPI(HookMixin, ExceptionHandlerMixin, ModelViewSet):
return Response({'result': 1})
class RolesAPI(ExceptionHandlerMixin, ModelViewSet):
class RolesAPI(api_mixins.GetOrCreateMixinView, ExceptionHandlerMixin, ModelViewSet):
permission_classes = (permissions.IsAuthenticated,)
serializer_class = RoleSerializer
lookup_field = 'uuid'

View File

@ -1165,20 +1165,29 @@ def test_api_users_get_or_create(settings, app, admin):
assert User.objects.get(id=id).last_name == 'Doe'
password = User.objects.get(id=id).password
resp = app.post_json('/api/users/?get_or_create=email', params=payload, status=201)
resp = app.post_json('/api/users/?get_or_create=email', params=payload, status=200)
assert id == resp.json['id']
assert User.objects.get(id=id).first_name == 'John'
assert User.objects.get(id=id).last_name == 'Doe'
assert User.objects.get(id=id).password == password
paylaod = {}
payload['first_name'] = 'Jane'
resp = app.post_json('/api/users/?update_or_create=email', params=payload, status=201)
payload = {
'email': 'john.doe@example.net',
'first_name': 'Jane',
}
resp = app.post_json('/api/users/?update_or_create=email', params=payload, status=200)
assert id == resp.json['id']
assert User.objects.get(id=id).first_name == 'Jane'
assert User.objects.get(id=id).last_name == 'Doe'
assert User.objects.get(id=id).password == password
payload['password'] = 'secret'
resp = app.post_json('/api/users/?update_or_create=email', params=payload, status=200)
assert User.objects.get(id=id).first_name == 'Jane'
assert User.objects.get(id=id).last_name == 'Doe'
assert User.objects.get(id=id).password != password
assert User.objects.get(id=id).check_password('secret')
def test_api_users_get_or_create_email_is_unique(settings, app, admin):
settings.A2_EMAIL_IS_UNIQUE = True
@ -1194,13 +1203,13 @@ def test_api_users_get_or_create_email_is_unique(settings, app, admin):
assert User.objects.get(id=id).first_name == 'John'
assert User.objects.get(id=id).last_name == 'Doe'
resp = app.post_json('/api/users/?get_or_create=email', params=payload, status=201)
resp = app.post_json('/api/users/?get_or_create=email', params=payload, status=200)
assert id == resp.json['id']
assert User.objects.get(id=id).first_name == 'John'
assert User.objects.get(id=id).last_name == 'Doe'
payload['first_name'] = 'Jane'
resp = app.post_json('/api/users/?update_or_create=email', params=payload, status=201)
resp = app.post_json('/api/users/?update_or_create=email', params=payload, status=200)
assert id == resp.json['id']
assert User.objects.get(id=id).first_name == 'Jane'
assert User.objects.get(id=id).last_name == 'Doe'
@ -1220,17 +1229,21 @@ def test_api_users_get_or_create_multi_key(settings, app, admin):
assert User.objects.get(id=id).last_name == 'Doe'
password = User.objects.get(id=id).password
resp = app.post_json('/api/users/?get_or_create=first_name&get_or_create=last_name', params=payload, status=201)
resp = app.post_json('/api/users/?get_or_create=first_name&get_or_create=last_name', params=payload, status=200)
assert id == resp.json['id']
assert User.objects.get(id=id).first_name == 'John'
assert User.objects.get(id=id).last_name == 'Doe'
assert User.objects.get(id=id).password == password
payload['email'] = 'john.doe@example2.net'
resp = app.post_json('/api/users/?update_or_create=first_name&update_or_create=last_name', params=payload, status=201)
payload['password'] = 'secret'
resp = app.post_json(
'/api/users/?update_or_create=first_name&update_or_create=last_name',
params=payload, status=200)
assert id == resp.json['id']
assert User.objects.get(id=id).email == 'john.doe@example2.net'
assert User.objects.get(id=id).password == password
assert User.objects.get(id=id).password != password
assert User.objects.get(id=id).check_password('secret')
def test_api_roles_get_or_create(settings, ou1, app, admin):
@ -1247,11 +1260,11 @@ def test_api_roles_get_or_create(settings, ou1, app, admin):
assert Role.objects.get(uuid=uuid).slug == 'role-1'
assert Role.objects.get(uuid=uuid).ou == ou1
resp = app.post_json('/api/roles/?get_or_create=slug', params=payload, status=201)
resp = app.post_json('/api/roles/?get_or_create=slug', params=payload, status=200)
assert uuid == resp.json['uuid']
payload['name'] = 'Role 2'
resp = app.post_json('/api/roles/?update_or_create=slug', params=payload, status=201)
resp = app.post_json('/api/roles/?update_or_create=slug', params=payload, status=200)
assert uuid == resp.json['uuid']
assert Role.objects.get(uuid=uuid).name == 'Role 2'
assert Role.objects.get(uuid=uuid).slug == 'role-1'