Code linting and added runtests.py

This commit is contained in:
Tom Christie 2014-08-19 13:28:07 +01:00
parent e385a7b8eb
commit bf09c32de8
60 changed files with 548 additions and 305 deletions

View File

@ -28,7 +28,7 @@ install:
- export PYTHONPATH=.
script:
- py.test
- ./runtests.py
matrix:
exclude:

View File

@ -1,2 +0,0 @@
[pytest]
addopts = --tb=short

View File

@ -1,3 +1,9 @@
# Test requirements
pytest-django==2.6
pytest==2.5.2
pytest-cov==1.6
# Optional packages
markdown>=2.1.0
PyYAML>=3.10
defusedxml>=0.3

View File

@ -1,3 +1 @@
-e .
Django>=1.3
pytest-django==2.6

View File

@ -1,9 +1,9 @@
"""
______ _____ _____ _____ __ _
| ___ \ ___/ ___|_ _| / _| | |
| |_/ / |__ \ `--. | | | |_ _ __ __ _ _ __ ___ _____ _____ _ __| | __
______ _____ _____ _____ __
| ___ \ ___/ ___|_ _| / _| | |
| |_/ / |__ \ `--. | | | |_ _ __ __ _ _ __ ___ _____ _____ _ __| |__
| /| __| `--. \ | | | _| '__/ _` | '_ ` _ \ / _ \ \ /\ / / _ \| '__| |/ /
| |\ \| |___/\__/ / | | | | | | | (_| | | | | | | __/\ V V / (_) | | | <
| |\ \| |___/\__/ / | | | | | | | (_| | | | | | | __/\ V V / (_) | | | <
\_| \_\____/\____/ \_/ |_| |_| \__,_|_| |_| |_|\___| \_/\_/ \___/|_| |_|\_|
"""

View File

@ -21,7 +21,7 @@ def get_authorization_header(request):
Hide some test client ickyness where the header can be unicode.
"""
auth = request.META.get('HTTP_AUTHORIZATION', b'')
if type(auth) == type(''):
if isinstance(auth, type('')):
# Work around django test client oddness
auth = auth.encode(HTTP_HEADER_ENCODING)
return auth

View File

@ -1,6 +1,5 @@
import binascii
import os
from hashlib import sha1
from django.conf import settings
from django.db import models

View File

@ -1,15 +1,10 @@
# -*- coding: utf-8 -*-
import datetime
from south.db import db
from south.v2 import SchemaMigration
from django.db import models
from rest_framework.settings import api_settings
try:
from django.contrib.auth import get_user_model
except ImportError: # django < 1.5
except ImportError: # django < 1.5
from django.contrib.auth.models import User
else:
User = get_user_model()
@ -26,12 +21,10 @@ class Migration(SchemaMigration):
))
db.send_create_signal('authtoken', ['Token'])
def backwards(self, orm):
# Deleting model 'Token'
db.delete_table('authtoken_token')
models = {
'auth.group': {
'Meta': {'object_name': 'Group'},

View File

@ -131,6 +131,7 @@ def list_route(methods=['get'], **kwargs):
return func
return decorator
# These are now pending deprecation, in favor of `detail_route` and `list_route`.
def link(**kwargs):
@ -139,11 +140,13 @@ def link(**kwargs):
"""
msg = 'link is pending deprecation. Use detail_route instead.'
warnings.warn(msg, PendingDeprecationWarning, stacklevel=2)
def decorator(func):
func.bind_to_methods = ['get']
func.detail = True
func.kwargs = kwargs
return func
return decorator
@ -153,9 +156,11 @@ def action(methods=['post'], **kwargs):
"""
msg = 'action is pending deprecation. Use detail_route instead.'
warnings.warn(msg, PendingDeprecationWarning, stacklevel=2)
def decorator(func):
func.bind_to_methods = methods
func.detail = True
func.kwargs = kwargs
return func
return decorator
return decorator

View File

@ -23,6 +23,7 @@ class APIException(Exception):
def __str__(self):
return self.detail
class ParseError(APIException):
status_code = status.HTTP_400_BAD_REQUEST
default_detail = 'Malformed request.'

View File

@ -116,7 +116,7 @@ class OrderingFilter(BaseFilterBackend):
def get_ordering(self, request):
"""
Ordering is set by a comma delimited ?ordering=... query parameter.
The `ordering` query parameter can be overridden by setting
the `ordering_param` value on the OrderingFilter or by
specifying an `ORDERING_PARAM` value in the API settings.

View File

@ -25,6 +25,7 @@ def strict_positive_int(integer_string, cutoff=None):
ret = min(ret, cutoff)
return ret
def get_object_or_404(queryset, *filter_args, **filter_kwargs):
"""
Same as Django's standard shortcut, but make sure to raise 404
@ -162,10 +163,11 @@ class GenericAPIView(views.APIView):
raise Http404(_("Page is not 'last', nor can it be converted to an int."))
try:
page = paginator.page(page_number)
except InvalidPage as e:
raise Http404(_('Invalid page (%(page_number)s): %(message)s') % {
'page_number': page_number,
'message': str(e)
except InvalidPage as exc:
error_format = _('Invalid page (%(page_number)s): %(message)s')
raise Http404(error_format % {
'page_number': page_number,
'message': str(exc)
})
if deprecated_style:
@ -208,7 +210,6 @@ class GenericAPIView(views.APIView):
return filter_backends
########################
### The following methods provide default implementations
### that you may want to override for more complex cases.
@ -284,8 +285,8 @@ class GenericAPIView(views.APIView):
if self.model is not None:
return self.model._default_manager.all()
raise ImproperlyConfigured("'%s' must define 'queryset' or 'model'"
% self.__class__.__name__)
error_format = "'%s' must define 'queryset' or 'model'"
raise ImproperlyConfigured(error_format % self.__class__.__name__)
def get_object(self, queryset=None):
"""

View File

@ -54,8 +54,10 @@ class DefaultContentNegotiation(BaseContentNegotiation):
for media_type in media_type_set:
if media_type_matches(renderer.media_type, media_type):
# Return the most specific media type as accepted.
if (_MediaType(renderer.media_type).precedence >
_MediaType(media_type).precedence):
if (
_MediaType(renderer.media_type).precedence >
_MediaType(media_type).precedence
):
# Eg client requests '*/*'
# Accepted media type is 'application/json'
return renderer, renderer.media_type

View File

@ -62,9 +62,11 @@ class IsAuthenticatedOrReadOnly(BasePermission):
"""
def has_permission(self, request, view):
return (request.method in SAFE_METHODS or
request.user and
request.user.is_authenticated())
return (
request.method in SAFE_METHODS or
request.user and
request.user.is_authenticated()
)
class DjangoModelPermissions(BasePermission):
@ -122,9 +124,11 @@ class DjangoModelPermissions(BasePermission):
perms = self.get_required_permissions(request.method, model_cls)
return (request.user and
return (
request.user and
(request.user.is_authenticated() or not self.authenticated_users_only) and
request.user.has_perms(perms))
request.user.has_perms(perms)
)
class DjangoModelPermissionsOrAnonReadOnly(DjangoModelPermissions):
@ -212,6 +216,8 @@ class TokenHasReadWriteScope(BasePermission):
required = oauth2_constants.READ if read_only else oauth2_constants.WRITE
return oauth2_provider_scope.check(required, request.auth.scope)
assert False, ('TokenHasReadWriteScope requires either the'
'`OAuthAuthentication` or `OAuth2Authentication` authentication '
'class to be used.')
assert False, (
'TokenHasReadWriteScope requires either the'
'`OAuthAuthentication` or `OAuth2Authentication` authentication '
'class to be used.'
)

View File

@ -8,7 +8,6 @@ REST framework also provides an HTML renderer the renders the browsable API.
"""
from __future__ import unicode_literals
import copy
import json
import django
from django import forms
@ -75,7 +74,6 @@ class JSONRenderer(BaseRenderer):
# E.g. If we're being called by the BrowsableAPIRenderer.
return renderer_context.get('indent', None)
def render(self, data, accepted_media_type=None, renderer_context=None):
"""
Render `data` into JSON, returning a bytestring.
@ -86,8 +84,10 @@ class JSONRenderer(BaseRenderer):
renderer_context = renderer_context or {}
indent = self.get_indent(accepted_media_type, renderer_context)
ret = json.dumps(data, cls=self.encoder_class,
indent=indent, ensure_ascii=self.ensure_ascii)
ret = json.dumps(
data, cls=self.encoder_class,
indent=indent, ensure_ascii=self.ensure_ascii
)
# On python 2.x json.dumps() returns bytestrings if ensure_ascii=True,
# but if ensure_ascii=False, the return type is underspecified,
@ -454,8 +454,10 @@ class BrowsableAPIRenderer(BaseRenderer):
if method in ('DELETE', 'OPTIONS'):
return True # Don't actually need to return a form
if (not getattr(view, 'get_serializer', None)
or not any(is_form_media_type(parser.media_type) for parser in view.parser_classes)):
if (
not getattr(view, 'get_serializer', None)
or not any(is_form_media_type(parser.media_type) for parser in view.parser_classes)
):
return
serializer = view.get_serializer(instance=obj, data=data, files=files)
@ -576,7 +578,7 @@ class BrowsableAPIRenderer(BaseRenderer):
'version': VERSION,
'breadcrumblist': self.get_breadcrumbs(request),
'allowed_methods': view.allowed_methods,
'available_formats': [renderer.format for renderer in view.renderer_classes],
'available_formats': [renderer_cls.format for renderer_cls in view.renderer_classes],
'response_headers': response_headers,
'put_form': self.get_rendered_html_form(view, 'PUT', request),
@ -625,4 +627,3 @@ class MultiPartRenderer(BaseRenderer):
def render(self, data, accepted_media_type=None, renderer_context=None):
return encode_multipart(self.BOUNDARY, data)

View File

@ -295,8 +295,11 @@ class Request(object):
Return the content body of the request, as a stream.
"""
try:
content_length = int(self.META.get('CONTENT_LENGTH',
self.META.get('HTTP_CONTENT_LENGTH')))
content_length = int(
self.META.get(
'CONTENT_LENGTH', self.META.get('HTTP_CONTENT_LENGTH')
)
)
except (ValueError, TypeError):
content_length = 0
@ -320,9 +323,11 @@ class Request(object):
)
# We only need to use form overloading on form POST requests.
if (not USE_FORM_OVERLOADING
if (
not USE_FORM_OVERLOADING
or self._request.method != 'POST'
or not is_form_media_type(self._content_type)):
or not is_form_media_type(self._content_type)
):
return
# At this point we're committed to parsing the request as form data.
@ -330,15 +335,19 @@ class Request(object):
self._files = self._request.FILES
# Method overloading - change the method and remove the param from the content.
if (self._METHOD_PARAM and
self._METHOD_PARAM in self._data):
if (
self._METHOD_PARAM and
self._METHOD_PARAM in self._data
):
self._method = self._data[self._METHOD_PARAM].upper()
# Content overloading - modify the content type, and force re-parse.
if (self._CONTENT_PARAM and
if (
self._CONTENT_PARAM and
self._CONTENTTYPE_PARAM and
self._CONTENT_PARAM in self._data and
self._CONTENTTYPE_PARAM in self._data):
self._CONTENTTYPE_PARAM in self._data
):
self._content_type = self._data[self._CONTENTTYPE_PARAM]
self._stream = BytesIO(self._data[self._CONTENT_PARAM].encode(self.parser_context['encoding']))
self._data, self._files = (Empty, Empty)

View File

@ -62,8 +62,10 @@ class Response(SimpleTemplateResponse):
ret = renderer.render(self.data, media_type, context)
if isinstance(ret, six.text_type):
assert charset, 'renderer returned unicode, and did not specify ' \
'a charset value.'
assert charset, (
'renderer returned unicode, and did not specify '
'a charset value.'
)
return bytes(ret.encode(charset))
if not ret:

View File

@ -449,9 +449,11 @@ class BaseSerializer(WritableField):
# If we have a model manager or similar object then we need
# to iterate through each instance.
if (self.many and
if (
self.many and
not hasattr(obj, '__iter__') and
is_simple_callable(getattr(obj, 'all', None))):
is_simple_callable(getattr(obj, 'all', None))
):
obj = obj.all()
kwargs = {
@ -601,8 +603,10 @@ class BaseSerializer(WritableField):
API schemas for auto-documentation.
"""
return SortedDict(
[(field_name, field.metadata())
for field_name, field in six.iteritems(self.fields)]
[
(field_name, field.metadata())
for field_name, field in six.iteritems(self.fields)
]
)
@ -656,8 +660,10 @@ class ModelSerializer(Serializer):
"""
cls = self.opts.model
assert cls is not None, \
"Serializer class '%s' is missing 'model' Meta option" % self.__class__.__name__
assert cls is not None, (
"Serializer class '%s' is missing 'model' Meta option" %
self.__class__.__name__
)
opts = cls._meta.concrete_model._meta
ret = SortedDict()
nested = bool(self.opts.depth)
@ -668,9 +674,9 @@ class ModelSerializer(Serializer):
# If model is a child via multitable inheritance, use parent's pk
pk_field = pk_field.rel.to._meta.pk
field = self.get_pk_field(pk_field)
if field:
ret[pk_field.name] = field
serializer_pk_field = self.get_pk_field(pk_field)
if serializer_pk_field:
ret[pk_field.name] = serializer_pk_field
# Deal with forward relationships
forward_rels = [field for field in opts.fields if field.serialize]
@ -739,9 +745,11 @@ class ModelSerializer(Serializer):
is_m2m = isinstance(relation.field,
models.fields.related.ManyToManyField)
if (is_m2m and
if (
is_m2m and
hasattr(relation.field.rel, 'through') and
not relation.field.rel.through._meta.auto_created):
not relation.field.rel.through._meta.auto_created
):
has_through_model = True
if nested:
@ -911,10 +919,12 @@ class ModelSerializer(Serializer):
for field_name, field in self.fields.items():
field_name = field.source or field_name
if field_name in exclusions \
and not field.read_only \
and (field.required or hasattr(instance, field_name)) \
and not isinstance(field, Serializer):
if (
field_name in exclusions
and not field.read_only
and (field.required or hasattr(instance, field_name))
and not isinstance(field, Serializer)
):
exclusions.remove(field_name)
return exclusions

View File

@ -46,16 +46,12 @@ DEFAULTS = {
'DEFAULT_PERMISSION_CLASSES': (
'rest_framework.permissions.AllowAny',
),
'DEFAULT_THROTTLE_CLASSES': (
),
'DEFAULT_CONTENT_NEGOTIATION_CLASS':
'rest_framework.negotiation.DefaultContentNegotiation',
'DEFAULT_THROTTLE_CLASSES': (),
'DEFAULT_CONTENT_NEGOTIATION_CLASS': 'rest_framework.negotiation.DefaultContentNegotiation',
# Genric view behavior
'DEFAULT_MODEL_SERIALIZER_CLASS':
'rest_framework.serializers.ModelSerializer',
'DEFAULT_PAGINATION_SERIALIZER_CLASS':
'rest_framework.pagination.PaginationSerializer',
'DEFAULT_MODEL_SERIALIZER_CLASS': 'rest_framework.serializers.ModelSerializer',
'DEFAULT_PAGINATION_SERIALIZER_CLASS': 'rest_framework.pagination.PaginationSerializer',
'DEFAULT_FILTER_BACKENDS': (),
# Throttling

View File

@ -10,15 +10,19 @@ from __future__ import unicode_literals
def is_informational(code):
return code >= 100 and code <= 199
def is_success(code):
return code >= 200 and code <= 299
def is_redirect(code):
return code >= 300 and code <= 399
def is_client_error(code):
return code >= 400 and code <= 499
def is_server_error(code):
return code >= 500 and code <= 599

View File

@ -152,8 +152,10 @@ def urlize_quoted_links(text, trim_url_limit=None, nofollow=True, autoescape=Tru
middle = middle[len(opening):]
lead = lead + opening
# Keep parentheses at the end only if they're balanced.
if (middle.endswith(closing)
and middle.count(closing) == middle.count(opening) + 1):
if (
middle.endswith(closing)
and middle.count(closing) == middle.count(opening) + 1
):
middle = middle[:-len(closing)]
trail = closing + trail

View File

@ -49,9 +49,10 @@ class APIRequestFactory(DjangoRequestFactory):
else:
format = format or self.default_format
assert format in self.renderer_classes, ("Invalid format '{0}'. "
"Available formats are {1}. Set TEST_REQUEST_RENDERER_CLASSES "
"to enable extra request formats.".format(
assert format in self.renderer_classes, (
"Invalid format '{0}'. Available formats are {1}. "
"Set TEST_REQUEST_RENDERER_CLASSES to enable "
"extra request formats.".format(
format,
', '.join(["'" + fmt + "'" for fmt in self.renderer_classes.keys()])
)

View File

@ -8,17 +8,19 @@ your API requires authentication:
...
url(r'^auth', include('rest_framework.urls', namespace='rest_framework'))
)
The urls must be namespaced as 'rest_framework', and you should make sure
your authentication settings include `SessionAuthentication`.
"""
from __future__ import unicode_literals
from django.conf.urls import patterns, url
from django.contrib.auth import views
template_name = {'template_name': 'rest_framework/login.html'}
urlpatterns = patterns('django.contrib.auth.views',
url(r'^login/$', 'login', template_name, name='login'),
url(r'^logout/$', 'logout', template_name, name='logout'),
urlpatterns = patterns(
'',
url(r'^login/$', views.login, template_name, name='login'),
url(r'^logout/$', views.logout, template_name, name='logout')
)

View File

@ -98,14 +98,23 @@ else:
node.flow_style = best_style
return node
SafeDumper.add_representer(decimal.Decimal,
SafeDumper.represent_decimal)
SafeDumper.add_representer(SortedDict,
yaml.representer.SafeRepresenter.represent_dict)
SafeDumper.add_representer(DictWithMetadata,
yaml.representer.SafeRepresenter.represent_dict)
SafeDumper.add_representer(SortedDictWithMetadata,
yaml.representer.SafeRepresenter.represent_dict)
SafeDumper.add_representer(types.GeneratorType,
yaml.representer.SafeRepresenter.represent_list)
SafeDumper.add_representer(
decimal.Decimal,
SafeDumper.represent_decimal
)
SafeDumper.add_representer(
SortedDict,
yaml.representer.SafeRepresenter.represent_dict
)
SafeDumper.add_representer(
DictWithMetadata,
yaml.representer.SafeRepresenter.represent_dict
)
SafeDumper.add_representer(
SortedDictWithMetadata,
yaml.representer.SafeRepresenter.represent_dict
)
SafeDumper.add_representer(
types.GeneratorType,
yaml.representer.SafeRepresenter.represent_list
)

View File

@ -6,8 +6,6 @@ from __future__ import unicode_literals
from django.utils.html import escape
from django.utils.safestring import mark_safe
from rest_framework.compat import apply_markdown
from rest_framework.settings import api_settings
from textwrap import dedent
import re
@ -40,6 +38,7 @@ def dedent(content):
return content.strip()
def camelcase_to_spaces(content):
"""
Translate 'CamelCaseNames' to 'Camel Case Names'.
@ -49,6 +48,7 @@ def camelcase_to_spaces(content):
content = re.sub(camelcase_boundry, ' \\1', content).strip()
return ' '.join(content.split('_')).title()
def markup_description(description):
"""
Apply HTML markup to the given description.

View File

@ -57,7 +57,7 @@ class _MediaType(object):
if key != 'q' and other.params.get(key, None) != self.params.get(key, None):
return False
if self.sub_type != '*' and other.sub_type != '*' and other.sub_type != self.sub_type:
if self.sub_type != '*' and other.sub_type != '*' and other.sub_type != self.sub_type:
return False
if self.main_type != '*' and other.main_type != '*' and other.main_type != self.main_type:

View File

@ -31,6 +31,7 @@ def get_view_name(view_cls, suffix=None):
return name
def get_view_description(view_cls, html=False):
"""
Given a view class, return a textual description to represent the view.
@ -119,7 +120,6 @@ class APIView(View):
headers['Vary'] = 'Accept'
return headers
def http_method_not_allowed(self, request, *args, **kwargs):
"""
If `request.method` does not correspond to a handler method,

View File

@ -127,11 +127,11 @@ class ReadOnlyModelViewSet(mixins.RetrieveModelMixin,
class ModelViewSet(mixins.CreateModelMixin,
mixins.RetrieveModelMixin,
mixins.UpdateModelMixin,
mixins.DestroyModelMixin,
mixins.ListModelMixin,
GenericViewSet):
mixins.RetrieveModelMixin,
mixins.UpdateModelMixin,
mixins.DestroyModelMixin,
mixins.ListModelMixin,
GenericViewSet):
"""
A viewset that provides default `create()`, `retrieve()`, `update()`,
`partial_update()`, `destroy()` and `list()` actions.

86
runtests.py Executable file
View File

@ -0,0 +1,86 @@
#! /usr/bin/env python
from __future__ import print_function
import pytest
import sys
import os
import subprocess
PYTEST_ARGS = {
'default': ['tests'],
'fast': ['tests', '-q'],
}
FLAKE8_ARGS = ['rest_framework', 'tests', '--ignore=E501']
sys.path.append(os.path.dirname(__file__))
def exit_on_failure(ret, message=None):
if ret:
sys.exit(ret)
def flake8_main(args):
print('Running flake8 code linting')
ret = subprocess.call(['flake8'] + args)
print('flake8 failed' if ret else 'flake8 passed')
return ret
def split_class_and_function(string):
class_string, function_string = string.split('.', 1)
return "%s and %s" % (class_string, function_string)
def is_function(string):
# `True` if it looks like a test function is included in the string.
return string.startswith('test_') or '.test_' in string
def is_class(string):
# `True` if first character is uppercase - assume it's a class name.
return string[0] == string[0].upper()
if __name__ == "__main__":
try:
sys.argv.remove('--nolint')
except ValueError:
run_flake8 = True
else:
run_flake8 = False
try:
sys.argv.remove('--lintonly')
except ValueError:
run_tests = True
else:
run_tests = False
try:
sys.argv.remove('--fast')
except ValueError:
style = 'default'
else:
style = 'fast'
run_flake8 = False
if len(sys.argv) > 1:
pytest_args = sys.argv[1:]
first_arg = pytest_args[0]
if first_arg.startswith('-'):
# `runtests.py [flags]`
pytest_args = ['tests'] + pytest_args
elif is_class(first_arg) and is_function(first_arg):
# `runtests.py TestCase.test_function [flags]`
expression = split_class_and_function(first_arg)
pytest_args = ['tests', '-k', expression] + pytest_args[1:]
elif is_class(first_arg) or is_function(first_arg):
# `runtests.py TestCase [flags]`
# `runtests.py test_function [flags]`
pytest_args = ['tests', '-k', pytest_args[0]] + pytest_args[1:]
else:
pytest_args = PYTEST_ARGS[style]
if run_tests:
exit_on_failure(pytest.main(pytest_args))
if run_flake8:
exit_on_failure(flake8_main(FLAKE8_ARGS))

View File

@ -1,5 +1,4 @@
from rest_framework import serializers
from tests.models import NullableForeignKeySource

View File

@ -68,7 +68,6 @@ SECRET_KEY = 'u@x-aj9(hoh#rb-^ymf#g2jx_hp0vj7u5#b@ag1n^seu9e!%cy'
TEMPLATE_LOADERS = (
'django.template.loaders.filesystem.Loader',
'django.template.loaders.app_directories.Loader',
# 'django.template.loaders.eggs.Loader',
)
MIDDLEWARE_CLASSES = (
@ -104,8 +103,8 @@ INSTALLED_APPS = (
# OAuth is optional and won't work if there is no oauth_provider & oauth2
try:
import oauth_provider
import oauth2
import oauth_provider # NOQA
import oauth2 # NOQA
except ImportError:
pass
else:
@ -114,7 +113,7 @@ else:
)
try:
import provider
import provider # NOQA
except ImportError:
pass
else:
@ -125,13 +124,13 @@ else:
# guardian is optional
try:
import guardian
import guardian # NOQA
except ImportError:
pass
else:
ANONYMOUS_USER_ID = -1
AUTHENTICATION_BACKENDS = (
'django.contrib.auth.backends.ModelBackend', # default
'django.contrib.auth.backends.ModelBackend', # default
'guardian.backends.ObjectPermissionBackend',
)
INSTALLED_APPS += (

View File

@ -45,26 +45,39 @@ class MockView(APIView):
return HttpResponse({'a': 1, 'b': 2, 'c': 3})
urlpatterns = patterns('',
urlpatterns = patterns(
'',
(r'^session/$', MockView.as_view(authentication_classes=[SessionAuthentication])),
(r'^basic/$', MockView.as_view(authentication_classes=[BasicAuthentication])),
(r'^token/$', MockView.as_view(authentication_classes=[TokenAuthentication])),
(r'^auth-token/$', 'rest_framework.authtoken.views.obtain_auth_token'),
(r'^oauth/$', MockView.as_view(authentication_classes=[OAuthAuthentication])),
(r'^oauth-with-scope/$', MockView.as_view(authentication_classes=[OAuthAuthentication],
permission_classes=[permissions.TokenHasReadWriteScope]))
(
r'^oauth-with-scope/$',
MockView.as_view(
authentication_classes=[OAuthAuthentication],
permission_classes=[permissions.TokenHasReadWriteScope]
)
)
)
class OAuth2AuthenticationDebug(OAuth2Authentication):
allow_query_params_token = True
if oauth2_provider is not None:
urlpatterns += patterns('',
urlpatterns += patterns(
'',
url(r'^oauth2/', include('provider.oauth2.urls', namespace='oauth2')),
url(r'^oauth2-test/$', MockView.as_view(authentication_classes=[OAuth2Authentication])),
url(r'^oauth2-test-debug/$', MockView.as_view(authentication_classes=[OAuth2AuthenticationDebug])),
url(r'^oauth2-with-scope-test/$', MockView.as_view(authentication_classes=[OAuth2Authentication],
permission_classes=[permissions.TokenHasReadWriteScope])),
url(
r'^oauth2-with-scope-test/$',
MockView.as_view(
authentication_classes=[OAuth2Authentication],
permission_classes=[permissions.TokenHasReadWriteScope]
)
)
)
@ -278,12 +291,16 @@ class OAuthTests(TestCase):
self.TOKEN_KEY = "token_key"
self.TOKEN_SECRET = "token_secret"
self.consumer = Consumer.objects.create(key=self.CONSUMER_KEY, secret=self.CONSUMER_SECRET,
name='example', user=self.user, status=self.consts.ACCEPTED)
self.consumer = Consumer.objects.create(
key=self.CONSUMER_KEY, secret=self.CONSUMER_SECRET,
name='example', user=self.user, status=self.consts.ACCEPTED
)
self.scope = Scope.objects.create(name="resource name", url="api/")
self.token = OAuthToken.objects.create(user=self.user, consumer=self.consumer, scope=self.scope,
token_type=OAuthToken.ACCESS, key=self.TOKEN_KEY, secret=self.TOKEN_SECRET, is_approved=True
self.token = OAuthToken.objects.create(
user=self.user, consumer=self.consumer, scope=self.scope,
token_type=OAuthToken.ACCESS, key=self.TOKEN_KEY, secret=self.TOKEN_SECRET,
is_approved=True
)
def _create_authorization_header(self):
@ -501,24 +518,24 @@ class OAuth2Tests(TestCase):
self.REFRESH_TOKEN = "refresh_token"
self.oauth2_client = oauth2_provider.oauth2.models.Client.objects.create(
client_id=self.CLIENT_ID,
client_secret=self.CLIENT_SECRET,
redirect_uri='',
client_type=0,
name='example',
user=None,
)
client_id=self.CLIENT_ID,
client_secret=self.CLIENT_SECRET,
redirect_uri='',
client_type=0,
name='example',
user=None,
)
self.access_token = oauth2_provider.oauth2.models.AccessToken.objects.create(
token=self.ACCESS_TOKEN,
client=self.oauth2_client,
user=self.user,
)
token=self.ACCESS_TOKEN,
client=self.oauth2_client,
user=self.user,
)
self.refresh_token = oauth2_provider.oauth2.models.RefreshToken.objects.create(
user=self.user,
access_token=self.access_token,
client=self.oauth2_client
)
user=self.user,
access_token=self.access_token,
client=self.oauth2_client
)
def _create_authorization_header(self, token=None):
return "Bearer {0}".format(token or self.access_token.token)
@ -569,8 +586,10 @@ class OAuth2Tests(TestCase):
@unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed')
def test_post_form_passing_auth_url_transport(self):
"""Ensure GETing form over OAuth with correct client credentials in form data succeed"""
response = self.csrf_client.post('/oauth2-test/',
data={'access_token': self.access_token.token})
response = self.csrf_client.post(
'/oauth2-test/',
data={'access_token': self.access_token.token}
)
self.assertEqual(response.status_code, 200)
@unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed')

View File

@ -24,7 +24,8 @@ class NestedResourceRoot(APIView):
class NestedResourceInstance(APIView):
pass
urlpatterns = patterns('',
urlpatterns = patterns(
'',
url(r'^$', Root.as_view()),
url(r'^resource/$', ResourceRoot.as_view()),
url(r'^resource/(?P<key>[0-9]+)$', ResourceInstance.as_view()),
@ -40,34 +41,60 @@ class BreadcrumbTests(TestCase):
def test_root_breadcrumbs(self):
url = '/'
self.assertEqual(get_breadcrumbs(url), [('Root', '/')])
self.assertEqual(
get_breadcrumbs(url),
[('Root', '/')]
)
def test_resource_root_breadcrumbs(self):
url = '/resource/'
self.assertEqual(get_breadcrumbs(url), [('Root', '/'),
('Resource Root', '/resource/')])
self.assertEqual(
get_breadcrumbs(url),
[
('Root', '/'),
('Resource Root', '/resource/')
]
)
def test_resource_instance_breadcrumbs(self):
url = '/resource/123'
self.assertEqual(get_breadcrumbs(url), [('Root', '/'),
('Resource Root', '/resource/'),
('Resource Instance', '/resource/123')])
self.assertEqual(
get_breadcrumbs(url),
[
('Root', '/'),
('Resource Root', '/resource/'),
('Resource Instance', '/resource/123')
]
)
def test_nested_resource_breadcrumbs(self):
url = '/resource/123/'
self.assertEqual(get_breadcrumbs(url), [('Root', '/'),
('Resource Root', '/resource/'),
('Resource Instance', '/resource/123'),
('Nested Resource Root', '/resource/123/')])
self.assertEqual(
get_breadcrumbs(url),
[
('Root', '/'),
('Resource Root', '/resource/'),
('Resource Instance', '/resource/123'),
('Nested Resource Root', '/resource/123/')
]
)
def test_nested_resource_instance_breadcrumbs(self):
url = '/resource/123/abc'
self.assertEqual(get_breadcrumbs(url), [('Root', '/'),
('Resource Root', '/resource/'),
('Resource Instance', '/resource/123'),
('Nested Resource Root', '/resource/123/'),
('Nested Resource Instance', '/resource/123/abc')])
self.assertEqual(
get_breadcrumbs(url),
[
('Root', '/'),
('Resource Root', '/resource/'),
('Resource Instance', '/resource/123'),
('Nested Resource Root', '/resource/123/'),
('Nested Resource Instance', '/resource/123/abc')
]
)
def test_broken_url_breadcrumbs_handled_gracefully(self):
url = '/foobar'
self.assertEqual(get_breadcrumbs(url), [('Root', '/')])
self.assertEqual(
get_breadcrumbs(url),
[('Root', '/')]
)

View File

@ -648,7 +648,7 @@ class DecimalFieldTest(TestCase):
s = DecimalSerializer(data={'decimal_field': '123'})
self.assertFalse(s.is_valid())
self.assertEqual(s.errors, {'decimal_field': ['Ensure this value is less than or equal to 100.']})
self.assertEqual(s.errors, {'decimal_field': ['Ensure this value is less than or equal to 100.']})
def test_raise_min_value(self):
"""
@ -660,7 +660,7 @@ class DecimalFieldTest(TestCase):
s = DecimalSerializer(data={'decimal_field': '99'})
self.assertFalse(s.is_valid())
self.assertEqual(s.errors, {'decimal_field': ['Ensure this value is greater than or equal to 100.']})
self.assertEqual(s.errors, {'decimal_field': ['Ensure this value is greater than or equal to 100.']})
def test_raise_max_digits(self):
"""
@ -672,7 +672,7 @@ class DecimalFieldTest(TestCase):
s = DecimalSerializer(data={'decimal_field': '123.456'})
self.assertFalse(s.is_valid())
self.assertEqual(s.errors, {'decimal_field': ['Ensure that there are no more than 5 digits in total.']})
self.assertEqual(s.errors, {'decimal_field': ['Ensure that there are no more than 5 digits in total.']})
def test_raise_max_decimal_places(self):
"""
@ -684,7 +684,7 @@ class DecimalFieldTest(TestCase):
s = DecimalSerializer(data={'decimal_field': '123.4567'})
self.assertFalse(s.is_valid())
self.assertEqual(s.errors, {'decimal_field': ['Ensure that there are no more than 3 decimal places.']})
self.assertEqual(s.errors, {'decimal_field': ['Ensure that there are no more than 3 decimal places.']})
def test_raise_max_whole_digits(self):
"""
@ -696,7 +696,7 @@ class DecimalFieldTest(TestCase):
s = DecimalSerializer(data={'decimal_field': '12345.6'})
self.assertFalse(s.is_valid())
self.assertEqual(s.errors, {'decimal_field': ['Ensure that there are no more than 4 digits in total.']})
self.assertEqual(s.errors, {'decimal_field': ['Ensure that there are no more than 4 digits in total.']})
class ChoiceFieldTests(TestCase):
@ -729,7 +729,7 @@ class ChoiceFieldTests(TestCase):
def test_invalid_choice_model(self):
s = ChoiceFieldModelSerializer(data={'choice': 'wrong_value'})
self.assertFalse(s.is_valid())
self.assertEqual(s.errors, {'choice': ['Select a valid choice. wrong_value is not one of the available choices.']})
self.assertEqual(s.errors, {'choice': ['Select a valid choice. wrong_value is not one of the available choices.']})
self.assertEqual(s.data['choice'], '')
def test_empty_choice_model(self):
@ -875,7 +875,7 @@ class SlugFieldTests(TestCase):
s = SlugFieldSerializer(data={'slug_field': 'a b'})
self.assertEqual(s.is_valid(), False)
self.assertEqual(s.errors, {'slug_field': ["Enter a valid 'slug' consisting of letters, numbers, underscores or hyphens."]})
self.assertEqual(s.errors, {'slug_field': ["Enter a valid 'slug' consisting of letters, numbers, underscores or hyphens."]})
class URLFieldTests(TestCase):

View File

@ -85,11 +85,8 @@ class FileSerializerTests(TestCase):
"""
Validation should still function when no data dictionary is provided.
"""
now = datetime.datetime.now()
file = BytesIO(six.b('stuff'))
file.name = 'stuff.txt'
file.size = len(file.getvalue())
uploaded_file = UploadedFile(file=file, created=now)
serializer = UploadedFileSerializer(files={'file': file})
uploaded_file = BytesIO(six.b('stuff'))
uploaded_file.name = 'stuff.txt'
uploaded_file.size = len(uploaded_file.getvalue())
serializer = UploadedFileSerializer(files={'file': uploaded_file})
self.assertFalse(serializer.is_valid())

View File

@ -74,7 +74,8 @@ if django_filters:
def get_queryset(self):
return FilterableItem.objects.all()
urlpatterns = patterns('',
urlpatterns = patterns(
'',
url(r'^(?P<pk>\d+)/$', FilterClassDetailView.as_view(), name='detail-view'),
url(r'^$', FilterClassRootView.as_view(), name='root-view'),
url(r'^get-queryset/$', GetQuerysetView.as_view(),
@ -653,8 +654,8 @@ class SensitiveOrderingFilterTests(TestCase):
self.assertEqual(
response.data,
[
{'id': 1, username_field: 'userA'}, # PassB
{'id': 2, username_field: 'userB'}, # PassC
{'id': 3, username_field: 'userC'}, # PassA
{'id': 1, username_field: 'userA'}, # PassB
{'id': 2, username_field: 'userB'}, # PassC
{'id': 3, username_field: 'userC'}, # PassA
]
)

View File

@ -117,18 +117,18 @@ class TestGenericRelations(TestCase):
serializer = TagSerializer(Tag.objects.all(), many=True)
expected = [
{
'tag': 'django',
'tagged_item': 'Bookmark: https://www.djangoproject.com/'
},
{
'tag': 'python',
'tagged_item': 'Bookmark: https://www.djangoproject.com/'
},
{
'tag': 'reminder',
'tagged_item': 'Note: Remember the milk'
}
{
'tag': 'django',
'tagged_item': 'Bookmark: https://www.djangoproject.com/'
},
{
'tag': 'python',
'tagged_item': 'Bookmark: https://www.djangoproject.com/'
},
{
'tag': 'reminder',
'tagged_item': 'Note: Remember the milk'
}
]
self.assertEqual(serializer.data, expected)

View File

@ -34,7 +34,8 @@ def not_found(request):
raise Http404()
urlpatterns = patterns('',
urlpatterns = patterns(
'',
url(r'^$', example),
url(r'^permission_denied$', permission_denied),
url(r'^not_found$', not_found),

View File

@ -94,7 +94,8 @@ class OptionalRelationDetail(generics.RetrieveUpdateDestroyAPIView):
model_serializer_class = serializers.HyperlinkedModelSerializer
urlpatterns = patterns('',
urlpatterns = patterns(
'',
url(r'^basic/$', BasicList.as_view(), name='basicmodel-list'),
url(r'^basic/(?P<pk>\d+)/$', BasicDetail.as_view(), name='basicmodel-detail'),
url(r'^anchor/(?P<pk>\d+)/$', AnchorDetail.as_view(), name='anchor-detail'),

View File

@ -1,7 +1,6 @@
from __future__ import unicode_literals
import datetime
from decimal import Decimal
from django.db import models
from django.core.paginator import Paginator
from django.test import TestCase
from django.utils import unittest
@ -12,6 +11,7 @@ from .models import BasicModel, FilterableItem
factory = APIRequestFactory()
# Helper function to split arguments out of an url
def split_arguments_from_url(url):
if '?' not in url:
@ -274,8 +274,8 @@ class TestUnpaginated(TestCase):
BasicModel(text=i).save()
self.objects = BasicModel.objects
self.data = [
{'id': obj.id, 'text': obj.text}
for obj in self.objects.all()
{'id': obj.id, 'text': obj.text}
for obj in self.objects.all()
]
self.view = DefaultPageSizeKwargView.as_view()
@ -302,8 +302,8 @@ class TestCustomPaginateByParam(TestCase):
BasicModel(text=i).save()
self.objects = BasicModel.objects
self.data = [
{'id': obj.id, 'text': obj.text}
for obj in self.objects.all()
{'id': obj.id, 'text': obj.text}
for obj in self.objects.all()
]
self.view = PaginateByParamView.as_view()
@ -483,8 +483,6 @@ class NonIntegerPaginator(object):
class TestNonIntegerPagination(TestCase):
def test_custom_pagination_serializer(self):
objects = ['john', 'paul', 'george', 'ringo']
paginator = NonIntegerPaginator(objects, 2)

View File

@ -12,6 +12,7 @@ import base64
factory = APIRequestFactory()
class RootView(generics.ListCreateAPIView):
model = BasicModel
authentication_classes = [authentication.BasicAuthentication]
@ -101,42 +102,54 @@ class ModelPermissionsIntegrationTests(TestCase):
self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
def test_options_permitted(self):
request = factory.options('/',
HTTP_AUTHORIZATION=self.permitted_credentials)
request = factory.options(
'/',
HTTP_AUTHORIZATION=self.permitted_credentials
)
response = root_view(request, pk='1')
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertIn('actions', response.data)
self.assertEqual(list(response.data['actions'].keys()), ['POST'])
request = factory.options('/1',
HTTP_AUTHORIZATION=self.permitted_credentials)
request = factory.options(
'/1',
HTTP_AUTHORIZATION=self.permitted_credentials
)
response = instance_view(request, pk='1')
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertIn('actions', response.data)
self.assertEqual(list(response.data['actions'].keys()), ['PUT'])
def test_options_disallowed(self):
request = factory.options('/',
HTTP_AUTHORIZATION=self.disallowed_credentials)
request = factory.options(
'/',
HTTP_AUTHORIZATION=self.disallowed_credentials
)
response = root_view(request, pk='1')
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertNotIn('actions', response.data)
request = factory.options('/1',
HTTP_AUTHORIZATION=self.disallowed_credentials)
request = factory.options(
'/1',
HTTP_AUTHORIZATION=self.disallowed_credentials
)
response = instance_view(request, pk='1')
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertNotIn('actions', response.data)
def test_options_updateonly(self):
request = factory.options('/',
HTTP_AUTHORIZATION=self.updateonly_credentials)
request = factory.options(
'/',
HTTP_AUTHORIZATION=self.updateonly_credentials
)
response = root_view(request, pk='1')
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertNotIn('actions', response.data)
request = factory.options('/1',
HTTP_AUTHORIZATION=self.updateonly_credentials)
request = factory.options(
'/1',
HTTP_AUTHORIZATION=self.updateonly_credentials
)
response = instance_view(request, pk='1')
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertIn('actions', response.data)
@ -153,6 +166,7 @@ class BasicPermModel(models.Model):
# add, change, delete built in to django
)
# Custom object-level permission, that includes 'view' permissions
class ViewObjectPermissions(permissions.DjangoObjectPermissions):
perms_map = {
@ -205,7 +219,7 @@ class ObjectPermissionsIntegrationTests(TestCase):
app_label = BasicPermModel._meta.app_label
f = '{0}_{1}'.format
perms = {
'view': f('view', model_name),
'view': f('view', model_name),
'change': f('change', model_name),
'delete': f('delete', model_name)
}
@ -246,21 +260,27 @@ class ObjectPermissionsIntegrationTests(TestCase):
# Update
def test_can_update_permissions(self):
request = factory.patch('/1', {'text': 'foobar'}, format='json',
HTTP_AUTHORIZATION=self.credentials['writeonly'])
request = factory.patch(
'/1', {'text': 'foobar'}, format='json',
HTTP_AUTHORIZATION=self.credentials['writeonly']
)
response = object_permissions_view(request, pk='1')
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(response.data.get('text'), 'foobar')
def test_cannot_update_permissions(self):
request = factory.patch('/1', {'text': 'foobar'}, format='json',
HTTP_AUTHORIZATION=self.credentials['deleteonly'])
request = factory.patch(
'/1', {'text': 'foobar'}, format='json',
HTTP_AUTHORIZATION=self.credentials['deleteonly']
)
response = object_permissions_view(request, pk='1')
self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)
def test_cannot_update_permissions_non_existing(self):
request = factory.patch('/999', {'text': 'foobar'}, format='json',
HTTP_AUTHORIZATION=self.credentials['deleteonly'])
request = factory.patch(
'/999', {'text': 'foobar'}, format='json',
HTTP_AUTHORIZATION=self.credentials['deleteonly']
)
response = object_permissions_view(request, pk='999')
self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)

View File

@ -108,19 +108,25 @@ class RelatedFieldSourceTests(TestCase):
doesn't exist.
"""
from tests.models import ManyToManySource
class Meta:
model = ManyToManySource
attrs = {
'name': serializers.SlugRelatedField(
slug_field='name', source='banzai'),
'Meta': Meta,
}
TestSerializer = type(str('TestSerializer'),
(serializers.ModelSerializer,), attrs)
TestSerializer = type(
str('TestSerializer'),
(serializers.ModelSerializer,),
attrs
)
with self.assertRaises(AttributeError):
TestSerializer(data={'name': 'foo'})
@unittest.skipIf(get_version() < '1.6.0', 'Upstream behaviour changed in v1.6')
class RelatedFieldChoicesTests(TestCase):
"""
@ -141,4 +147,3 @@ class RelatedFieldChoicesTests(TestCase):
widget_count = len(field.widget.choices)
self.assertEqual(widget_count, choice_count + 1, 'BLANK_CHOICE_DASH option should have been added')

View File

@ -16,7 +16,8 @@ request = factory.get('/') # Just to ensure we have a request in the serializer
def dummy_view(request, pk):
pass
urlpatterns = patterns('',
urlpatterns = patterns(
'',
url(r'^dummyurl/(?P<pk>[0-9]+)/$', dummy_view, name='dummy-url'),
url(r'^manytomanysource/(?P<pk>[0-9]+)/$', dummy_view, name='manytomanysource-detail'),
url(r'^manytomanytarget/(?P<pk>[0-9]+)/$', dummy_view, name='manytomanytarget-detail'),
@ -86,9 +87,9 @@ class HyperlinkedManyToManyTests(TestCase):
queryset = ManyToManySource.objects.all()
serializer = ManyToManySourceSerializer(queryset, many=True, context={'request': request})
expected = [
{'url': 'http://testserver/manytomanysource/1/', 'name': 'source-1', 'targets': ['http://testserver/manytomanytarget/1/']},
{'url': 'http://testserver/manytomanysource/2/', 'name': 'source-2', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/2/']},
{'url': 'http://testserver/manytomanysource/3/', 'name': 'source-3', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/2/', 'http://testserver/manytomanytarget/3/']}
{'url': 'http://testserver/manytomanysource/1/', 'name': 'source-1', 'targets': ['http://testserver/manytomanytarget/1/']},
{'url': 'http://testserver/manytomanysource/2/', 'name': 'source-2', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/2/']},
{'url': 'http://testserver/manytomanysource/3/', 'name': 'source-3', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/2/', 'http://testserver/manytomanytarget/3/']}
]
self.assertEqual(serializer.data, expected)
@ -114,9 +115,9 @@ class HyperlinkedManyToManyTests(TestCase):
queryset = ManyToManySource.objects.all()
serializer = ManyToManySourceSerializer(queryset, many=True, context={'request': request})
expected = [
{'url': 'http://testserver/manytomanysource/1/', 'name': 'source-1', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/2/', 'http://testserver/manytomanytarget/3/']},
{'url': 'http://testserver/manytomanysource/2/', 'name': 'source-2', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/2/']},
{'url': 'http://testserver/manytomanysource/3/', 'name': 'source-3', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/2/', 'http://testserver/manytomanytarget/3/']}
{'url': 'http://testserver/manytomanysource/1/', 'name': 'source-1', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/2/', 'http://testserver/manytomanytarget/3/']},
{'url': 'http://testserver/manytomanysource/2/', 'name': 'source-2', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/2/']},
{'url': 'http://testserver/manytomanysource/3/', 'name': 'source-3', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/2/', 'http://testserver/manytomanytarget/3/']}
]
self.assertEqual(serializer.data, expected)

View File

@ -65,9 +65,9 @@ class PKManyToManyTests(TestCase):
queryset = ManyToManySource.objects.all()
serializer = ManyToManySourceSerializer(queryset, many=True)
expected = [
{'id': 1, 'name': 'source-1', 'targets': [1]},
{'id': 2, 'name': 'source-2', 'targets': [1, 2]},
{'id': 3, 'name': 'source-3', 'targets': [1, 2, 3]}
{'id': 1, 'name': 'source-1', 'targets': [1]},
{'id': 2, 'name': 'source-2', 'targets': [1, 2]},
{'id': 3, 'name': 'source-3', 'targets': [1, 2, 3]}
]
self.assertEqual(serializer.data, expected)
@ -93,9 +93,9 @@ class PKManyToManyTests(TestCase):
queryset = ManyToManySource.objects.all()
serializer = ManyToManySourceSerializer(queryset, many=True)
expected = [
{'id': 1, 'name': 'source-1', 'targets': [1, 2, 3]},
{'id': 2, 'name': 'source-2', 'targets': [1, 2]},
{'id': 3, 'name': 'source-3', 'targets': [1, 2, 3]}
{'id': 1, 'name': 'source-1', 'targets': [1, 2, 3]},
{'id': 2, 'name': 'source-2', 'targets': [1, 2]},
{'id': 3, 'name': 'source-3', 'targets': [1, 2, 3]}
]
self.assertEqual(serializer.data, expected)

View File

@ -76,7 +76,6 @@ class MockGETView(APIView):
return Response({'foo': ['bar', 'baz']})
class MockPOSTView(APIView):
def post(self, request, **kwargs):
return Response({'foo': request.DATA})
@ -102,7 +101,8 @@ class HTMLView1(APIView):
def get(self, request, **kwargs):
return Response('text')
urlpatterns = patterns('',
urlpatterns = patterns(
'',
url(r'^.*\.(?P<format>.+)$', MockView.as_view(renderer_classes=[RendererA, RendererB])),
url(r'^$', MockView.as_view(renderer_classes=[RendererA, RendererB])),
url(r'^cache$', MockGETView.as_view()),
@ -312,16 +312,22 @@ class JSONRendererTests(TestCase):
class Dict(MutableMapping):
def __init__(self):
self._dict = dict()
def __getitem__(self, key):
return self._dict.__getitem__(key)
def __setitem__(self, key, value):
return self._dict.__setitem__(key, value)
def __delitem__(self, key):
return self._dict.__delitem__(key)
def __iter__(self):
return self._dict.__iter__()
def __len__(self):
return self._dict.__len__()
def keys(self):
return self._dict.keys()
@ -330,22 +336,24 @@ class JSONRendererTests(TestCase):
x[2] = 3
ret = JSONRenderer().render(x)
data = json.loads(ret.decode('utf-8'))
self.assertEquals(data, {'key': 'string value', '2': 3})
self.assertEquals(data, {'key': 'string value', '2': 3})
def test_render_obj_with_getitem(self):
class DictLike(object):
def __init__(self):
self._dict = {}
def set(self, value):
self._dict = dict(value)
def __getitem__(self, key):
return self._dict[key]
x = DictLike()
x.set({'a': 1, 'b': 'string'})
with self.assertRaises(TypeError):
JSONRenderer().render(x)
def test_without_content_type_args(self):
"""
Test basic JSON rendering.
@ -394,35 +402,47 @@ class JSONPRendererTests(TestCase):
"""
Test JSONP rendering with View JSON Renderer.
"""
resp = self.client.get('/jsonp/jsonrenderer',
HTTP_ACCEPT='application/javascript')
resp = self.client.get(
'/jsonp/jsonrenderer',
HTTP_ACCEPT='application/javascript'
)
self.assertEqual(resp.status_code, status.HTTP_200_OK)
self.assertEqual(resp['Content-Type'], 'application/javascript; charset=utf-8')
self.assertEqual(resp.content,
('callback(%s);' % _flat_repr).encode('ascii'))
self.assertEqual(
resp.content,
('callback(%s);' % _flat_repr).encode('ascii')
)
def test_without_callback_without_json_renderer(self):
"""
Test JSONP rendering without View JSON Renderer.
"""
resp = self.client.get('/jsonp/nojsonrenderer',
HTTP_ACCEPT='application/javascript')
resp = self.client.get(
'/jsonp/nojsonrenderer',
HTTP_ACCEPT='application/javascript'
)
self.assertEqual(resp.status_code, status.HTTP_200_OK)
self.assertEqual(resp['Content-Type'], 'application/javascript; charset=utf-8')
self.assertEqual(resp.content,
('callback(%s);' % _flat_repr).encode('ascii'))
self.assertEqual(
resp.content,
('callback(%s);' % _flat_repr).encode('ascii')
)
def test_with_callback(self):
"""
Test JSONP rendering with callback function name.
"""
callback_func = 'myjsonpcallback'
resp = self.client.get('/jsonp/nojsonrenderer?callback=' + callback_func,
HTTP_ACCEPT='application/javascript')
resp = self.client.get(
'/jsonp/nojsonrenderer?callback=' + callback_func,
HTTP_ACCEPT='application/javascript'
)
self.assertEqual(resp.status_code, status.HTTP_200_OK)
self.assertEqual(resp['Content-Type'], 'application/javascript; charset=utf-8')
self.assertEqual(resp.content,
('%s(%s);' % (callback_func, _flat_repr)).encode('ascii'))
self.assertEqual(
resp.content,
('%s(%s);' % (callback_func, _flat_repr)).encode('ascii')
)
if yaml:
@ -467,7 +487,6 @@ if yaml:
def assertYAMLContains(self, content, string):
self.assertTrue(string in content, '%r not in %r' % (string, content))
class UnicodeYAMLRendererTests(TestCase):
"""
Tests specific for the Unicode YAML Renderer
@ -592,13 +611,13 @@ class CacheRenderTest(TestCase):
""" Return any errors that would be raised if `obj' is pickled
Courtesy of koffie @ http://stackoverflow.com/a/7218986/109897
"""
if seen == None:
if seen is None:
seen = []
try:
state = obj.__getstate__()
except AttributeError:
return
if state == None:
if state is None:
return
if isinstance(state, tuple):
if not isinstance(state[0], dict):

View File

@ -272,7 +272,8 @@ class MockView(APIView):
return Response(status=status.INTERNAL_SERVER_ERROR)
urlpatterns = patterns('',
urlpatterns = patterns(
'',
(r'^$', MockView.as_view()),
)

View File

@ -100,7 +100,8 @@ new_model_viewset_router = routers.DefaultRouter()
new_model_viewset_router.register(r'', HTMLNewModelViewSet)
urlpatterns = patterns('',
urlpatterns = patterns(
'',
url(r'^setbyview$', MockViewSettingContentType.as_view(renderer_classes=[RendererA, RendererB, RendererC])),
url(r'^.*\.(?P<format>.+)$', MockView.as_view(renderer_classes=[RendererA, RendererB, RendererC])),
url(r'^$', MockView.as_view(renderer_classes=[RendererA, RendererB, RendererC])),

View File

@ -10,7 +10,8 @@ factory = APIRequestFactory()
def null_view(request):
pass
urlpatterns = patterns('',
urlpatterns = patterns(
'',
url(r'^view$', null_view, name='view'),
)

View File

@ -93,7 +93,8 @@ class TestCustomLookupFields(TestCase):
from tests import test_routers
urls = getattr(test_routers, 'urlpatterns')
urls += patterns('',
urls += patterns(
'',
url(r'^', include(self.router.urls)),
)
@ -104,7 +105,8 @@ class TestCustomLookupFields(TestCase):
def test_retrieve_lookup_field_list_view(self):
response = self.client.get('/notes/')
self.assertEqual(response.data,
self.assertEqual(
response.data,
[{
"url": "http://testserver/notes/123/",
"uuid": "123", "text": "foo bar"
@ -113,7 +115,8 @@ class TestCustomLookupFields(TestCase):
def test_retrieve_lookup_field_detail_view(self):
response = self.client.get('/notes/123/')
self.assertEqual(response.data,
self.assertEqual(
response.data,
{
"url": "http://testserver/notes/123/",
"uuid": "123", "text": "foo bar"

View File

@ -7,10 +7,12 @@ from django.utils import unittest
from django.utils.datastructures import MultiValueDict
from django.utils.translation import ugettext_lazy as _
from rest_framework import serializers, fields, relations
from tests.models import (HasPositiveIntegerAsChoice, Album, ActionItem, Anchor, BasicModel,
BlankFieldModel, BlogPost, BlogPostComment, Book, CallableDefaultValueModel, DefaultValueModel,
ManyToManyModel, Person, ReadOnlyManyToManyModel, Photo, RESTFrameworkModel,
ForeignKeySource, ManyToManySource)
from tests.models import (
HasPositiveIntegerAsChoice, Album, ActionItem, Anchor, BasicModel,
BlankFieldModel, BlogPost, BlogPostComment, Book, CallableDefaultValueModel,
DefaultValueModel, ManyToManyModel, Person, ReadOnlyManyToManyModel, Photo,
RESTFrameworkModel, ForeignKeySource
)
from tests.models import BasicModelSerializer
import datetime
import pickle
@ -99,6 +101,7 @@ class ActionItemSerializer(serializers.ModelSerializer):
class Meta:
model = ActionItem
class ActionItemSerializerOptionalFields(serializers.ModelSerializer):
"""
Intended to test that fields with `required=False` are excluded from validation.
@ -109,6 +112,7 @@ class ActionItemSerializerOptionalFields(serializers.ModelSerializer):
model = ActionItem
fields = ('title',)
class ActionItemSerializerCustomRestore(serializers.ModelSerializer):
class Meta:
@ -295,8 +299,10 @@ class BasicTests(TestCase):
in the Meta data
"""
serializer = PersonSerializer(self.person)
self.assertEqual(set(serializer.data.keys()),
set(['name', 'age', 'info']))
self.assertEqual(
set(serializer.data.keys()),
set(['name', 'age', 'info'])
)
def test_field_with_dictionary(self):
"""
@ -331,9 +337,9 @@ class BasicTests(TestCase):
 id field is not populated if `data` is accessed prior to `save()`
"""
serializer = ActionItemSerializer(self.actionitem)
self.assertIsNone(serializer.data.get('id',None), 'New instance. `id` should not be set.')
self.assertIsNone(serializer.data.get('id', None), 'New instance. `id` should not be set.')
serializer.save()
self.assertIsNotNone(serializer.data.get('id',None), 'Model is saved. `id` should be set.')
self.assertIsNotNone(serializer.data.get('id', None), 'Model is saved. `id` should be set.')
def test_fields_marked_as_not_required_are_excluded_from_validation(self):
"""
@ -660,10 +666,10 @@ class ModelValidationTests(TestCase):
serializer.save()
second_serializer = AlbumsSerializer(data={'title': 'a'})
self.assertFalse(second_serializer.is_valid())
self.assertEqual(second_serializer.errors, {'title': ['Album with this Title already exists.'],})
self.assertEqual(second_serializer.errors, {'title': ['Album with this Title already exists.']})
third_serializer = AlbumsSerializer(data=[{'title': 'b', 'ref': '1'}, {'title': 'c'}], many=True)
self.assertFalse(third_serializer.is_valid())
self.assertEqual(third_serializer.errors, [{'ref': ['Album with this Ref already exists.']}, {}])
self.assertEqual(third_serializer.errors, [{'ref': ['Album with this Ref already exists.']}, {}])
def test_foreign_key_is_null_with_partial(self):
"""
@ -959,7 +965,7 @@ class WritableFieldDefaultValueTests(TestCase):
self.assertEqual(got, self.expected)
def test_get_default_value_with_callable(self):
field = self.create_field(default=lambda : self.expected)
field = self.create_field(default=lambda: self.expected)
got = field.get_default_value()
self.assertEqual(got, self.expected)
@ -974,7 +980,7 @@ class WritableFieldDefaultValueTests(TestCase):
self.assertIsNone(got)
def test_get_default_value_returns_non_True_values(self):
values = [None, '', False, 0, [], (), {}] # values that assumed as 'False' in the 'if' clause
values = [None, '', False, 0, [], (), {}] # values that assumed as 'False' in the 'if' clause
for expected in values:
field = self.create_field(default=expected)
got = field.get_default_value()

View File

@ -83,9 +83,9 @@ class BulkCreateSerializerTests(TestCase):
self.assertEqual(serializer.is_valid(), False)
expected_errors = [
{'non_field_errors': ['Invalid data']},
{'non_field_errors': ['Invalid data']},
{'non_field_errors': ['Invalid data']}
{'non_field_errors': ['Invalid data']},
{'non_field_errors': ['Invalid data']},
{'non_field_errors': ['Invalid data']}
]
self.assertEqual(serializer.errors, expected_errors)

View File

@ -328,12 +328,14 @@ class NestedModelSerializerUpdateTests(TestCase):
class BlogPostSerializer(serializers.ModelSerializer):
comments = BlogPostCommentSerializer(many=True, source='blogpostcomment_set')
class Meta:
model = models.BlogPost
fields = ('id', 'title', 'comments')
class PersonSerializer(serializers.ModelSerializer):
posts = BlogPostSerializer(many=True, source='blogpost_set')
class Meta:
model = models.Person
fields = ('id', 'name', 'age', 'posts')

View File

@ -1,9 +1,7 @@
from django.db import models
from django.test import TestCase
from rest_framework.compat import six
from rest_framework.serializers import _resolve_model
from tests.models import BasicModel
from rest_framework.compat import six
class ResolveModelTests(TestCase):

View File

@ -30,4 +30,4 @@ class TestStatus(TestCase):
self.assertFalse(is_server_error(499))
self.assertTrue(is_server_error(500))
self.assertTrue(is_server_error(599))
self.assertFalse(is_server_error(600))
self.assertFalse(is_server_error(600))

View File

@ -48,4 +48,4 @@ class Issue1386Tests(TestCase):
self.assertEqual(i, res)
# example from issue #1386, this shouldn't raise an exception
_ = urlize_quoted_links("asdf:[/p]zxcv.com")
urlize_quoted_links("asdf:[/p]zxcv.com")

View File

@ -28,7 +28,8 @@ def session_view(request):
})
urlpatterns = patterns('',
urlpatterns = patterns(
'',
url(r'^view/$', view),
url(r'^session-view/$', session_view),
)
@ -142,7 +143,8 @@ class TestAPIRequestFactory(TestCase):
assertion error.
"""
factory = APIRequestFactory()
self.assertRaises(AssertionError, factory.post,
self.assertRaises(
AssertionError, factory.post,
path='/view/', data={'example': 1}, format='xml'
)

View File

@ -27,7 +27,7 @@ class NonTimeThrottle(BaseThrottle):
if not hasattr(self.__class__, 'called'):
self.__class__.called = True
return True
return False
return False
class MockView(APIView):
@ -125,36 +125,42 @@ class ThrottlingTests(TestCase):
"""
Ensure for second based throttles.
"""
self.ensure_response_header_contains_proper_throttle_field(MockView,
((0, None),
(0, None),
(0, None),
(0, '1')
))
self.ensure_response_header_contains_proper_throttle_field(
MockView, (
(0, None),
(0, None),
(0, None),
(0, '1')
)
)
def test_minutes_fields(self):
"""
Ensure for minute based throttles.
"""
self.ensure_response_header_contains_proper_throttle_field(MockView_MinuteThrottling,
((0, None),
(0, None),
(0, None),
(0, '60')
))
self.ensure_response_header_contains_proper_throttle_field(
MockView_MinuteThrottling, (
(0, None),
(0, None),
(0, None),
(0, '60')
)
)
def test_next_rate_remains_constant_if_followed(self):
"""
If a client follows the recommended next request rate,
the throttling rate should stay constant.
"""
self.ensure_response_header_contains_proper_throttle_field(MockView_MinuteThrottling,
((0, None),
(20, None),
(40, None),
(60, None),
(80, None)
))
self.ensure_response_header_contains_proper_throttle_field(
MockView_MinuteThrottling, (
(0, None),
(20, None),
(40, None),
(60, None),
(80, None)
)
)
def test_non_time_throttle(self):
"""
@ -170,7 +176,7 @@ class ThrottlingTests(TestCase):
self.assertTrue(MockView_NonTimeThrottling.throttle_classes[0].called)
response = MockView_NonTimeThrottling.as_view()(request)
self.assertFalse('X-Throttle-Wait-Seconds' in response)
self.assertFalse('X-Throttle-Wait-Seconds' in response)
class ScopedRateThrottleTests(TestCase):

View File

@ -1,7 +1,6 @@
from __future__ import unicode_literals
from django.test import TestCase
from rest_framework.templatetags.rest_framework import urlize_quoted_links
import sys
class URLizerTests(TestCase):

10
tox.ini
View File

@ -1,13 +1,19 @@
[tox]
downloadcache = {toxworkdir}/cache/
envlist =
flake8,
py3.4-django1.7,py3.3-django1.7,py3.2-django1.7,py2.7-django1.7,
py3.4-django1.6,py3.3-django1.6,py3.2-django1.6,py2.7-django1.6,py2.6-django1.6,
py3.4-django1.5,py3.3-django1.5,py3.2-django1.5,py2.7-django1.5,py2.6-django1.5,
py2.7-django1.4,py2.6-django1.4,
py2.7-django1.4,py2.6-django1.4
[testenv]
commands = py.test -q
commands = ./runtests.py --fast
[testenv:flake8]
basepython = python2.7
deps = pytest==2.5.2
commands = ./runtests.py --lintonly
[testenv:py3.4-django1.7]
basepython = python3.4