debian-django-ratelimit/ratelimit/tests.py

430 lines
14 KiB
Python

import django
from django.core.cache import cache, InvalidCacheBackendError
from django.test import RequestFactory, TestCase
from django.test.utils import override_settings
from django.views.generic import View
from ratelimit.decorators import ratelimit
from ratelimit.exceptions import Ratelimited
from ratelimit.mixins import RateLimitMixin
from ratelimit.helpers import is_ratelimited
class RatelimitTests(TestCase):
def setUp(self):
cache.clear()
def test_limit_ip(self):
@ratelimit(ip=True, method=None, rate='1/m', block=True)
def view(request):
return True
req = RequestFactory().get('/')
assert view(req), 'First request works.'
with self.assertRaises(Ratelimited):
view(req)
def test_block(self):
@ratelimit(ip=True, method=None, rate='1/m', block=True)
def blocked(request):
return request.limited
@ratelimit(ip=True, method=None, rate='1/m', block=False)
def unblocked(request):
return request.limited
req = RequestFactory().get('/')
assert not blocked(req), 'First request works.'
with self.assertRaises(Ratelimited):
blocked(req)
assert unblocked(req), 'Request is limited but not blocked.'
def test_method(self):
rf = RequestFactory()
post = rf.post('/')
get = rf.get('/')
@ratelimit(ip=True, method=['POST'], rate='1/m')
def limit_post(request):
return request.limited
@ratelimit(ip=True, method=['POST', 'GET'], rate='1/m')
def limit_get(request):
return request.limited
assert not limit_post(post), 'Do not limit first POST.'
assert limit_post(post), 'Limit second POST.'
assert not limit_post(get), 'Do not limit GET.'
assert limit_get(post), 'Limit first POST.'
assert limit_get(get), 'Limit first GET.'
def test_field(self):
james = RequestFactory().post('/', {'username': 'james'})
john = RequestFactory().post('/', {'username': 'john'})
@ratelimit(ip=False, field='username', rate='1/m')
def username(request):
return request.limited
assert not username(james), "james' first request is fine."
assert username(james), "james' second request is limited."
assert not username(john), "john's first request is fine."
def test_field_unicode(self):
post = RequestFactory().post('/', {'username': u'fran\xe7ois'})
@ratelimit(ip=False, field='username', rate='1/m')
def view(request):
return request.limited
assert not view(post), 'First request is not limited.'
assert view(post), 'Second request is limited.'
def test_field_empty(self):
post = RequestFactory().post('/', {})
@ratelimit(ip=False, field='username', rate='1/m')
def view(request):
return request.limited
assert not view(post), 'First request is not limited.'
assert view(post), 'Second request is limited.'
def test_rate(self):
req = RequestFactory().post('/')
@ratelimit(ip=True, rate='2/m')
def twice(request):
return request.limited
assert not twice(req), 'First request is not limited.'
assert not twice(req), 'Second request is not limited.'
assert twice(req), 'Third request is limited.'
def test_skip_if(self):
req = RequestFactory().post('/')
@ratelimit(rate='1/m', skip_if=lambda r: getattr(r, 'skip', False))
def view(request):
return request.limited
assert not view(req), 'First request is not limited.'
assert view(req), 'Second request is limited.'
del req.limited
req.skip = True
assert not view(req), 'Skipped request is not limited.'
@override_settings(RATELIMIT_USE_CACHE='fake.cache')
def test_bad_cache(self):
"""The RATELIMIT_USE_CACHE setting works if the cache exists."""
@ratelimit()
def view(request):
return request
req = RequestFactory().post('/')
with self.assertRaises(InvalidCacheBackendError):
view(req)
def test_keys(self):
"""Allow custom functions to set cache keys."""
class User(object):
def __init__(self, authenticated=False):
self.pk = 1
self.authenticated = authenticated
def is_authenticated(self):
return self.authenticated
def user_or_ip(req):
if req.user.is_authenticated():
return 'uip:%d' % req.user.pk
return 'uip:%s' % req.META['REMOTE_ADDR']
@ratelimit(ip=False, rate='1/m', block=False, keys=user_or_ip)
def view(request):
return request.limited
req = RequestFactory().post('/')
req.user = User(authenticated=False)
assert not view(req), 'First unauthenticated request is allowed.'
assert view(req), 'Second unauthenticated request is limited.'
del req.limited
req.user = User(authenticated=True)
assert not view(req), 'First authenticated request is allowed.'
assert view(req), 'Second authenticated is limited.'
def test_stacked_decorator(self):
"""Allow @ratelimit to be stacked."""
# Put the shorter one first and make sure the second one doesn't
# reset request.limited back to False.
@ratelimit(ip=False, rate='1/m', block=False, keys=lambda x: 'min')
@ratelimit(ip=False, rate='10/d', block=False, keys=lambda x: 'day')
def view(request):
return request.limited
req = RequestFactory().post('/')
assert not view(req), 'First unauthenticated request is allowed.'
assert view(req), 'Second unauthenticated request is limited.'
def test_is_ratelimited(self):
def get_keys(request):
return 'test_is_ratelimited_key'
def not_increment(request):
return is_ratelimited(request, increment=False, ip=False,
method=None, keys=[get_keys], rate='1/m')
def do_increment(request):
return is_ratelimited(request, increment=True, ip=False,
method=None, keys=[get_keys], rate='1/m')
req = RequestFactory().get('/')
# Does not increment. Count still 0. Does not rate limit
# because 0 < 1.
assert not not_increment(req), 'Request should not be rate limited.'
# Increments. Does not rate limit because 0 < 1. Count now 1.
assert not do_increment(req), 'Request should not be rate limited.'
# Does not increment. Count still 1. Rate limits because 1 < 1
# is false.
assert not_increment(req), 'Request should be rate limited.'
#do it here, since python < 2.7 does not have unittest.skipIf
if django.VERSION >= (1, 4):
class RateLimitCBVTests(TestCase):
SKIP_REASON = u'Class Based View supported by Django >=1.4'
def setUp(self):
cache.clear()
def test_limit_ip(self):
class RLView(RateLimitMixin, View):
ratelimit_ip = True
ratelimit_method = None
ratelimit_rate = '1/m'
ratelimit_block = True
rlview = RLView.as_view()
req = RequestFactory().get('/')
assert rlview(req), 'First request works.'
with self.assertRaises(Ratelimited):
rlview(req)
def test_block(self):
class BlockedView(RateLimitMixin, View):
ratelimit_ip = True
ratelimit_method = None
ratelimit_rate = '1/m'
ratelimit_block = True
def get(self, request, *args, **kwargs):
return request.limited
class UnBlockedView(RateLimitMixin, View):
ratelimit_ip = True
ratelimit_method = None
ratelimit_rate = '1/m'
ratelimit_block = False
def get(self, request, *args, **kwargs):
return request.limited
blocked = BlockedView.as_view()
unblocked = UnBlockedView.as_view()
req = RequestFactory().get('/')
assert not blocked(req), 'First request works.'
with self.assertRaises(Ratelimited):
blocked(req)
assert unblocked(req), 'Request is limited but not blocked.'
def test_method(self):
rf = RequestFactory()
post = rf.post('/')
get = rf.get('/')
class LimitPostView(RateLimitMixin, View):
ratelimit_ip = True
ratelimit_method = ['POST']
ratelimit_rate = '1/m'
def post(self, request, *args, **kwargs):
return request.limited
get = post
class LimitGetView(RateLimitMixin, View):
ratelimit_ip = True
ratelimit_method = ['POST', 'GET']
ratelimit_rate = '1/m'
def post(self, request, *args, **kwargs):
return request.limited
get = post
limit_post = LimitPostView.as_view()
limit_get = LimitGetView.as_view()
assert not limit_post(post), 'Do not limit first POST.'
assert limit_post(post), 'Limit second POST.'
assert not limit_post(get), 'Do not limit GET.'
assert limit_get(post), 'Limit first POST.'
assert limit_get(get), 'Limit first GET.'
def test_field(self):
james = RequestFactory().post('/', {'username': 'james'})
john = RequestFactory().post('/', {'username': 'john'})
class UsernameView(RateLimitMixin, View):
ratelimit_ip = False
ratelimit_field = 'username'
ratelimit_rate = '1/m'
def post(self, request, *args, **kwargs):
return request.limited
get = post
username = UsernameView.as_view()
assert not username(james), "james' first request is fine."
assert username(james), "james' second request is limited."
assert not username(john), "john's first request is fine."
def test_field_unicode(self):
post = RequestFactory().post('/', {'username': u'fran\xe7ois'})
class UsernameView(RateLimitMixin, View):
ratelimit_ip = False
ratelimit_field = 'username'
ratelimit_rate = '1/m'
def post(self, request, *args, **kwargs):
return request.limited
get = post
view = UsernameView.as_view()
assert not view(post), 'First request is not limited.'
assert view(post), 'Second request is limited.'
def test_field_empty(self):
post = RequestFactory().post('/', {})
class EmptyFieldView(RateLimitMixin, View):
ratelimit_ip = False
ratelimit_field = 'username'
ratelimit_rate = '1/m'
def post(self, request, *args, **kwargs):
return request.limited
get = post
view = EmptyFieldView.as_view()
assert not view(post), 'First request is not limited.'
assert view(post), 'Second request is limited.'
def test_rate(self):
req = RequestFactory().post('/')
class TwiceView(RateLimitMixin, View):
ratelimit_ip = True
ratelimit_rate = '2/m'
def post(self, request, *args, **kwargs):
return request.limited
get = post
twice = TwiceView.as_view()
assert not twice(req), 'First request is not limited.'
assert not twice(req), 'Second request is not limited.'
assert twice(req), 'Third request is limited.'
def test_skip_if(self):
req = RequestFactory().post('/')
class SkipIfView(RateLimitMixin, View):
ratelimit_rate = '1/m'
ratelimit_skip_if = lambda r: getattr(r, 'skip', False)
def post(self, request, *args, **kwargs):
return request.limited
get = post
view = SkipIfView.as_view()
assert not view(req), 'First request is not limited.'
assert view(req), 'Second request is limited.'
del req.limited
req.skip = True
assert not view(req), 'Skipped request is not limited.'
@override_settings(RATELIMIT_USE_CACHE='fake-cache')
def test_bad_cache(self):
"""The RATELIMIT_USE_CACHE setting works if the cache exists."""
class BadCacheView(RateLimitMixin, View):
def post(self, request, *args, **kwargs):
return request
get = post
view = BadCacheView.as_view()
req = RequestFactory().post('/')
with self.assertRaises(InvalidCacheBackendError):
view(req)
def test_keys(self):
"""Allow custom functions to set cache keys."""
class User(object):
def __init__(self, authenticated=False):
self.pk = 1
self.authenticated = authenticated
def is_authenticated(self):
return self.authenticated
def user_or_ip(req):
if req.user.is_authenticated():
return 'uip:%d' % req.user.pk
return 'uip:%s' % req.META['REMOTE_ADDR']
class KeysView(RateLimitMixin, View):
ratelimit_ip = False
ratelimit_block = False
ratelimit_rate = '1/m'
ratelimit_keys = user_or_ip
def post(self, request, *args, **kwargs):
return request.limited
get = post
view = KeysView.as_view()
req = RequestFactory().post('/')
req.user = User(authenticated=False)
assert not view(req), 'First unauthenticated request is allowed.'
assert view(req), 'Second unauthenticated request is limited.'
del req.limited
req.user = User(authenticated=True)
assert not view(req), 'First authenticated request is allowed.'
assert view(req), 'Second authenticated is limited.'