Port get_schema_fields from DRF
This commit is contained in:
parent
5c7274c02b
commit
982c5ff1a6
|
@ -2,7 +2,7 @@
|
|||
from __future__ import absolute_import
|
||||
|
||||
from django.template import Template, TemplateDoesNotExist, loader
|
||||
from rest_framework.compat import template_render
|
||||
from rest_framework.compat import coreapi, template_render
|
||||
from rest_framework.filters import BaseFilterBackend
|
||||
|
||||
from .. import compat
|
||||
|
@ -89,3 +89,15 @@ class DjangoFilterBackend(BaseFilterBackend):
|
|||
return template_render(template, context={
|
||||
'filter': filter_instance
|
||||
})
|
||||
|
||||
def get_schema_fields(self, view):
|
||||
# This is not compatible with widgets where the query param differs from the
|
||||
# filter's attribute name. Notably, this includes `MultiWidget`, where query
|
||||
# params will be of the format `<name>_0`, `<name>_1`, etc...
|
||||
assert coreapi is not None, 'coreapi must be installed to use `get_schema_fields()`'
|
||||
filter_class = self.get_filter_class(view, view.get_queryset())
|
||||
|
||||
return [] if not filter_class else [
|
||||
coreapi.Field(name=field_name, required=False, location='query')
|
||||
for field_name in filter_class().filters.keys()
|
||||
]
|
||||
|
|
|
@ -2,6 +2,7 @@ from __future__ import unicode_literals
|
|||
|
||||
import datetime
|
||||
from decimal import Decimal
|
||||
from unittest import skipIf
|
||||
|
||||
from django.conf.urls import url
|
||||
from django.test import TestCase
|
||||
|
@ -15,6 +16,7 @@ except ImportError:
|
|||
from django.core.urlresolvers import reverse
|
||||
|
||||
from rest_framework import generics, serializers, status
|
||||
from rest_framework.compat import coreapi
|
||||
from rest_framework.test import APIRequestFactory
|
||||
|
||||
from django_filters import filters
|
||||
|
@ -121,6 +123,35 @@ urlpatterns = [
|
|||
]
|
||||
|
||||
|
||||
@skipIf(coreapi is None, 'coreapi must be installed')
|
||||
class GetSchemaFieldsTests(TestCase):
|
||||
def test_fields_with_filter_fields_list(self):
|
||||
backend = DjangoFilterBackend()
|
||||
fields = backend.get_schema_fields(FilterFieldsRootView())
|
||||
fields = [f.name for f in fields]
|
||||
|
||||
self.assertEqual(fields, ['decimal', 'date'])
|
||||
|
||||
def test_fields_with_filter_fields_dict(self):
|
||||
class DictFilterFieldsRootView(FilterFieldsRootView):
|
||||
filter_fields = {
|
||||
'decimal': ['exact', 'lt', 'gt'],
|
||||
}
|
||||
|
||||
backend = DjangoFilterBackend()
|
||||
fields = backend.get_schema_fields(DictFilterFieldsRootView())
|
||||
fields = [f.name for f in fields]
|
||||
|
||||
self.assertEqual(fields, ['decimal', 'decimal__lt', 'decimal__gt'])
|
||||
|
||||
def test_fields_with_filter_class(self):
|
||||
backend = DjangoFilterBackend()
|
||||
fields = backend.get_schema_fields(FilterClassRootView())
|
||||
fields = [f.name for f in fields]
|
||||
|
||||
self.assertEqual(fields, ['text', 'decimal', 'date'])
|
||||
|
||||
|
||||
class CommonFilteringTestCase(TestCase):
|
||||
def _serialize_object(self, obj):
|
||||
return {'id': obj.id, 'text': obj.text, 'decimal': str(obj.decimal), 'date': obj.date.isoformat()}
|
||||
|
|
Loading…
Reference in New Issue