debian-django-rest-framework/rest_framework/filters.py

315 lines
11 KiB
Python

"""
Provides generic filtering backends that can be used to filter the results
returned by list views.
"""
from __future__ import unicode_literals
import operator
from functools import reduce
from django.conf import settings
from django.core.exceptions import ImproperlyConfigured
from django.db import models
from django.template import loader
from django.utils import six
from django.utils.translation import ugettext_lazy as _
from rest_framework.compat import (
crispy_forms, distinct, django_filters, guardian, template_render
)
from rest_framework.settings import api_settings
if 'crispy_forms' in settings.INSTALLED_APPS and crispy_forms and django_filters:
    # django-crispy-forms is installed: lay the DjangoFilterBackend controls
    # out as a bootstrap3 form when they are displayed as HTML.
    from crispy_forms.helper import FormHelper
    from crispy_forms.layout import Layout, Submit

    class FilterSet(django_filters.FilterSet):
        """
        A django-filter `FilterSet` whose form fields carry no help text and
        whose form is given a crispy-forms `FormHelper`: GET method, the
        bootstrap3 template pack, every filter field followed by a submit
        button.
        """

        def __init__(self, *args, **kwargs):
            super(FilterSet, self).__init__(*args, **kwargs)
            form = self.form

            # Suppress per-field help text in the rendered filter form.
            for form_field in form.fields.values():
                form_field.help_text = None

            field_names = list(form.fields.keys())
            submit_button = Submit('', _('Submit'), css_class='btn-default')

            crispy_helper = FormHelper()
            crispy_helper.form_method = 'GET'
            crispy_helper.template_pack = 'bootstrap3'
            crispy_helper.layout = Layout(*(field_names + [submit_button]))
            form.helper = crispy_helper

    filter_template = 'rest_framework/filters/django_filter_crispyforms.html'

elif django_filters:
    # If django-crispy-forms is not installed, use the standard
    # 'form.as_p' rendering when DjangoFilterBackend is displayed as HTML.
    class FilterSet(django_filters.FilterSet):
        """
        A django-filter `FilterSet` whose form fields carry no help text.
        """

        def __init__(self, *args, **kwargs):
            super(FilterSet, self).__init__(*args, **kwargs)
            for form_field in self.form.fields.values():
                form_field.help_text = None

    filter_template = 'rest_framework/filters/django_filter.html'

else:
    # django-filter is not installed, so there is no FilterSet base class
    # and no filter template to offer.
    FilterSet = None
    filter_template = None
class BaseFilterBackend(object):
    """
    Abstract base class for every filter backend.

    Concrete backends must provide their own `filter_queryset()`.
    """

    def filter_queryset(self, request, queryset, view):
        """
        Return `queryset` filtered for the given `request` and `view`.

        The base implementation always raises `NotImplementedError`.
        """
        message = ".filter_queryset() must be overridden."
        raise NotImplementedError(message)
class DjangoFilterBackend(BaseFilterBackend):
    """
    A filter backend that uses django-filter.

    The `FilterSet` is the view's `filter_class` when set. Otherwise one is
    generated from the view's `filter_fields`, subclassing
    `default_filter_set`.
    """
    # Base class for FilterSets auto-generated from a view's `filter_fields`;
    # subclasses may override it to customise the generated FilterSet.
    default_filter_set = FilterSet
    # Template used by `to_html()` to render the filter form.
    template = filter_template

    def __init__(self):
        assert django_filters, 'Using DjangoFilterBackend, but django-filter is not installed'

    def get_filter_class(self, view, queryset=None):
        """
        Return the django-filters `FilterSet` used to filter the queryset.

        Returns, in order of preference:
        - the view's `filter_class`, after asserting that `queryset.model`
          is a subclass of the FilterSet's `Meta.model`;
        - a FilterSet generated from the view's `filter_fields`, based on
          `self.default_filter_set`;
        - `None` when the view defines neither.

        `queryset` must be supplied whenever the view sets `filter_class` or
        `filter_fields`, because its `.model` is read in both cases.
        """
        filter_class = getattr(view, 'filter_class', None)
        filter_fields = getattr(view, 'filter_fields', None)

        if filter_class:
            filter_model = filter_class.Meta.model
            assert issubclass(queryset.model, filter_model), \
                'FilterSet model %s does not match queryset model %s' % \
                (filter_model, queryset.model)
            return filter_class

        if filter_fields:
            # Use the (overridable) class attribute rather than the
            # module-level FilterSet so subclasses can swap the base class.
            base_filter_set = self.default_filter_set

            class AutoFilterSet(base_filter_set):
                class Meta:
                    model = queryset.model
                    fields = filter_fields

            return AutoFilterSet

        return None

    def filter_queryset(self, request, queryset, view):
        """
        Filter `queryset` with the view's FilterSet bound to the request's
        query parameters; return it unchanged when the view has none.
        """
        filter_class = self.get_filter_class(view, queryset)
        if filter_class:
            return filter_class(request.query_params, queryset=queryset).qs
        return queryset

    def to_html(self, request, queryset, view):
        """
        Render the filter form with `self.template`, or return `None` when
        the view has no FilterSet.
        """
        filter_class = self.get_filter_class(view, queryset)
        if not filter_class:
            return None
        filter_instance = filter_class(request.query_params, queryset=queryset)
        context = {
            'filter': filter_instance
        }
        template = loader.get_template(self.template)
        return template_render(template, context)
class SearchFilter(BaseFilterBackend):
    """
    A filter backend that searches the view's `search_fields` for the terms
    given in a query parameter (by default ?search=...).

    Every term must match at least one search field. A field name may carry
    a one-character prefix selecting the lookup used for it:
    '^' istartswith, '=' iexact, '@' search, '$' iregex; an unprefixed name
    uses icontains.
    """
    # The URL query parameter used for the search.
    search_param = api_settings.SEARCH_PARAM
    # Template used by `to_html()` to render the search box.
    template = 'rest_framework/filters/search.html'

    def get_search_terms(self, request):
        """
        Search terms are set by a ?search=... query parameter,
        and may be comma and/or whitespace delimited.

        Returns the list of terms (empty when the parameter is absent).
        """
        raw_value = request.query_params.get(self.search_param, '')
        # Treat commas as whitespace so a single split() handles both.
        whitespace_only = raw_value.replace(',', ' ')
        return whitespace_only.split()

    def construct_search(self, field_name):
        """
        Translate a `search_fields` entry into an ORM lookup string, stripping
        a recognised prefix and picking the matching lookup.
        """
        prefixed_patterns = (
            ('^', "%s__istartswith"),
            ('=', "%s__iexact"),
            ('@', "%s__search"),
            ('$', "%s__iregex"),
        )
        for prefix, pattern in prefixed_patterns:
            if field_name.startswith(prefix):
                return pattern % field_name[1:]
        return "%s__icontains" % field_name

    def filter_queryset(self, request, queryset, view):
        """
        Narrow `queryset` so each search term matches at least one of the
        view's search fields; return it unchanged when either the fields or
        the terms are missing.
        """
        search_fields = getattr(view, 'search_fields', None)
        search_terms = self.get_search_terms(request)
        if not (search_fields and search_terms):
            return queryset

        orm_lookups = [
            self.construct_search(six.text_type(field))
            for field in search_fields
        ]

        original_queryset = queryset
        for term in search_terms:
            # OR the per-field conditions; successive filter() calls AND terms.
            any_field_matches = reduce(
                operator.or_,
                (models.Q(**{lookup: term}) for lookup in orm_lookups),
            )
            queryset = queryset.filter(any_field_matches)

        # Filtering against a many-to-many field requires us to
        # call queryset.distinct() in order to avoid duplicate items
        # in the resulting queryset.
        return distinct(queryset, original_queryset)

    def to_html(self, request, queryset, view):
        """
        Render the search box pre-filled with the first search term, or
        return '' when the view has no search fields.
        """
        if not getattr(view, 'search_fields', None):
            return ''

        terms = self.get_search_terms(request)
        context = {
            'param': self.search_param,
            'term': terms[0] if terms else '',
        }
        return template_render(loader.get_template(self.template), context)
class OrderingFilter(BaseFilterBackend):
    """
    A filter backend that orders the queryset by fields named in a query
    parameter (by default ?ordering=...), restricted to an allowed set.

    Requested fields that are not allowed are dropped. When nothing valid was
    requested, the view's `ordering` attribute is used instead.
    """
    # The URL query parameter used for the ordering.
    ordering_param = api_settings.ORDERING_PARAM
    # Fields clients may order by: `None` derives them from the view's
    # serializer, '__all__' allows every model field. A view-level
    # `ordering_fields` attribute takes precedence over this one.
    ordering_fields = None
    # Template used by `to_html()` to render the ordering control.
    template = 'rest_framework/filters/ordering.html'

    def get_ordering(self, request, queryset, view):
        """
        Ordering is set by a comma delimited ?ordering=... query parameter.

        The `ordering` query parameter can be overridden by setting
        the `ordering_param` value on the OrderingFilter or by
        specifying an `ORDERING_PARAM` value in the API settings.

        Returns the valid requested terms (each optionally prefixed with '-'
        for descending order), or the view's default ordering when none were
        requested or all were invalid.
        """
        params = request.query_params.get(self.ordering_param)
        if params:
            fields = [param.strip() for param in params.split(',')]
            ordering = self.remove_invalid_fields(queryset, fields, view)
            if ordering:
                return ordering

        # No ordering was included, or all the ordering fields were invalid
        return self.get_default_ordering(view)

    def get_default_ordering(self, view):
        """
        Return the view's `ordering` attribute, wrapping a single string in a
        tuple; `None` when the view does not define one.
        """
        ordering = getattr(view, 'ordering', None)
        if isinstance(ordering, six.string_types):
            return (ordering,)
        return ordering

    def get_valid_fields(self, queryset, view):
        """
        Return the orderable fields as a list of `(field_name, label)` pairs.

        Raises `ImproperlyConfigured` when no ordering fields are configured
        and the view has no serializer class to derive them from.
        """
        valid_fields = getattr(view, 'ordering_fields', self.ordering_fields)

        if valid_fields is None:
            # Default to allowing ordering on readable serializer fields.
            # The `None` default makes views that lack a `serializer_class`
            # attribute reach the ImproperlyConfigured below instead of
            # failing with an AttributeError.
            serializer_class = getattr(view, 'serializer_class', None)
            if serializer_class is None:
                msg = ("Cannot use %s on a view which does not have either a "
                       "'serializer_class' or 'ordering_fields' attribute.")
                raise ImproperlyConfigured(msg % self.__class__.__name__)
            valid_fields = [
                (field.source or field_name, field.label)
                for field_name, field in serializer_class().fields.items()
                if not getattr(field, 'write_only', False) and not field.source == '*'
            ]
        elif valid_fields == '__all__':
            # View explicitly allows ordering on any model field
            valid_fields = [
                (field.name, getattr(field, 'label', field.name.title()))
                for field in queryset.model._meta.fields
            ]
            # Labels must be strings like the other branches produce; join
            # the title-cased '__' segments rather than returning a list.
            valid_fields += [
                (key, ' '.join(key.title().split('__')))
                for key in queryset.query.aggregates.keys()
            ]
        else:
            valid_fields = [
                (item, item) if isinstance(item, six.string_types) else item
                for item in valid_fields
            ]

        return valid_fields

    def remove_invalid_fields(self, queryset, fields, view):
        """
        Return the terms in `fields` whose name (ignoring a leading '-') is
        one of the valid ordering fields, preserving their order.
        """
        # Build the lookup set once instead of scanning a list per term.
        valid_names = {item[0] for item in self.get_valid_fields(queryset, view)}
        return [term for term in fields if term.lstrip('-') in valid_names]

    def filter_queryset(self, request, queryset, view):
        """
        Apply the resolved ordering to `queryset`; return it unchanged when
        there is no ordering.
        """
        ordering = self.get_ordering(request, queryset, view)
        if ordering:
            return queryset.order_by(*ordering)
        return queryset

    def get_template_context(self, request, queryset, view):
        """
        Build the template context: the current primary ordering term and an
        ascending/descending option pair for every valid field.
        """
        current = self.get_ordering(request, queryset, view)
        # An empty default ordering (e.g. `ordering = ()`) means there is no
        # current selection; indexing it would raise IndexError.
        current = current[0] if current else None
        options = []
        for key, label in self.get_valid_fields(queryset, view):
            options.append((key, '%s - ascending' % label))
            options.append(('-' + key, '%s - descending' % label))
        return {
            'request': request,
            'current': current,
            'param': self.ordering_param,
            'options': options,
        }

    def to_html(self, request, queryset, view):
        """Render the ordering control with `self.template`."""
        template = loader.get_template(self.template)
        context = self.get_template_context(request, queryset, view)
        return template_render(template, context)
class DjangoObjectPermissionsFilter(BaseFilterBackend):
    """
    A filter backend that limits results to those where the requesting user
    has read object level permissions.
    """
    # Permission string template, filled from the queryset model's app label
    # and model name (e.g. 'myapp.view_mymodel').
    perm_format = '%(app_label)s.view_%(model_name)s'

    def __init__(self):
        assert guardian, 'Using DjangoObjectPermissionsFilter, but django-guardian is not installed'

    def filter_queryset(self, request, queryset, view):
        """
        Return the objects in `queryset` for which `request.user` holds the
        `perm_format` permission, as reported by django-guardian's
        `get_objects_for_user`.
        """
        user = request.user
        model_cls = queryset.model
        kwargs = {
            'app_label': model_cls._meta.app_label,
            'model_name': model_cls._meta.model_name
        }
        permission = self.perm_format % kwargs
        # `extra` is assigned on both branches, so no placeholder is needed.
        if guardian.VERSION >= (1, 3):
            # Maintain behavior compatibility with versions prior to 1.3
            extra = {'accept_global_perms': False}
        else:
            extra = {}
        return guardian.shortcuts.get_objects_for_user(user, permission, queryset, **extra)