Port get_schema_fields from DRF

This commit is contained in:
Ryan P Kilby 2016-09-21 17:43:07 -04:00
parent 5c7274c02b
commit 982c5ff1a6
2 changed files with 44 additions and 1 deletions

View File

@ -2,7 +2,7 @@
from __future__ import absolute_import
from django.template import Template, TemplateDoesNotExist, loader
from rest_framework.compat import template_render
from rest_framework.compat import coreapi, template_render
from rest_framework.filters import BaseFilterBackend
from .. import compat
@ -89,3 +89,15 @@ class DjangoFilterBackend(BaseFilterBackend):
return template_render(template, context={
'filter': filter_instance
})
def get_schema_fields(self, view):
# This is not compatible with widgets where the query param differs from the
# filter's attribute name. Notably, this includes `MultiWidget`, where query
# params will be of the format `<name>_0`, `<name>_1`, etc...
assert coreapi is not None, 'coreapi must be installed to use `get_schema_fields()`'
filter_class = self.get_filter_class(view, view.get_queryset())
return [] if not filter_class else [
coreapi.Field(name=field_name, required=False, location='query')
for field_name in filter_class().filters.keys()
]

View File

@ -2,6 +2,7 @@ from __future__ import unicode_literals
import datetime
from decimal import Decimal
from unittest import skipIf
from django.conf.urls import url
from django.test import TestCase
@ -15,6 +16,7 @@ except ImportError:
from django.core.urlresolvers import reverse
from rest_framework import generics, serializers, status
from rest_framework.compat import coreapi
from rest_framework.test import APIRequestFactory
from django_filters import filters
@ -121,6 +123,35 @@ urlpatterns = [
]
@skipIf(coreapi is None, 'coreapi must be installed')
class GetSchemaFieldsTests(TestCase):
def test_fields_with_filter_fields_list(self):
backend = DjangoFilterBackend()
fields = backend.get_schema_fields(FilterFieldsRootView())
fields = [f.name for f in fields]
self.assertEqual(fields, ['decimal', 'date'])
def test_fields_with_filter_fields_dict(self):
class DictFilterFieldsRootView(FilterFieldsRootView):
filter_fields = {
'decimal': ['exact', 'lt', 'gt'],
}
backend = DjangoFilterBackend()
fields = backend.get_schema_fields(DictFilterFieldsRootView())
fields = [f.name for f in fields]
self.assertEqual(fields, ['decimal', 'decimal__lt', 'decimal__gt'])
def test_fields_with_filter_class(self):
backend = DjangoFilterBackend()
fields = backend.get_schema_fields(FilterClassRootView())
fields = [f.name for f in fields]
self.assertEqual(fields, ['text', 'decimal', 'date'])
class CommonFilteringTestCase(TestCase):
def _serialize_object(self, obj):
return {'id': obj.id, 'text': obj.text, 'decimal': str(obj.decimal), 'date': obj.date.isoformat()}