# debian-django-ratelimit/ratelimit/tests.py
#
# 616 lines
# 18 KiB
# Python

from django.core.cache import cache, InvalidCacheBackendError
from django.core.exceptions import ImproperlyConfigured
from django.test import RequestFactory, TestCase
from django.test.utils import override_settings
from django.views.generic import View
from ratelimit.decorators import ratelimit
from ratelimit.exceptions import Ratelimited
from ratelimit.mixins import RatelimitMixin
from ratelimit.utils import is_ratelimited, _split_rate
# Module-wide request factory; the tests below use it to build requests.
rf = RequestFactory()
class MockUser(object):
    """Minimal stand-in for a user object attached to test requests.

    Exposes only ``pk`` (always 1) and ``is_authenticated``, which are the
    attributes the tests in this module read from ``request.user``.
    """

    def __init__(self, authenticated=False):
        """Create a user whose ``is_authenticated`` is *authenticated*."""
        self.pk = 1
        self.is_authenticated = authenticated
class RateParsingTests(TestCase):
    """Checks how ``_split_rate`` parses rate strings."""

    def test_simple(self):
        """Each rate string parses to the expected two-item tuple.

        The expected values pair the request count with the period
        expressed in seconds (e.g. ``'100/m'`` -> ``(100, 60)``).
        """
        tests = (
            ('100/s', (100, 1)),
            ('100/10s', (100, 10)),
            ('100/10', (100, 10)),
            ('100/m', (100, 60)),
            ('400/10m', (400, 600)),
            ('1000/h', (1000, 3600)),
            ('800/d', (800, 24 * 60 * 60)),
        )
        for rate, expected in tests:
            # Use the rate string as the message so a failure names the case.
            assert expected == _split_rate(rate), rate
def mykey(group, request):
    """Return the request's ``REMOTE_ADDR`` reversed, for use as a key.

    Used by the key tests in this module both as a callable and through
    the dotted path ``'ratelimit.tests.mykey'``.  ``group`` is accepted
    but not used.
    """
    return request.META['REMOTE_ADDR'][::-1]
class RatelimitTests(TestCase):
    """Tests for the ``ratelimit`` decorator and ``is_ratelimited``."""

    def setUp(self):
        """Start every test from an empty default cache."""
        cache.clear()

    def test_no_key(self):
        """Calling a view decorated without ``key`` is a config error."""
        @ratelimit(rate='1/m', block=True)
        def view(request):
            return True

        req = rf.get('/')
        with self.assertRaises(ImproperlyConfigured):
            view(req)

    def test_ip(self):
        """With ``key='ip'`` and ``block=True`` the second request raises."""
        @ratelimit(key='ip', rate='1/m', block=True)
        def view(request):
            return True

        req = rf.get('/')
        assert view(req), 'First request works.'
        with self.assertRaises(Ratelimited):
            view(req)

    def test_block(self):
        """``block=True`` raises; ``block=False`` only flags the request."""
        @ratelimit(key='ip', rate='1/m', block=True)
        def blocked(request):
            return request.limited

        @ratelimit(key='ip', rate='1/m', block=False)
        def unblocked(request):
            return request.limited

        req = rf.get('/')
        assert not blocked(req), 'First request works.'
        with self.assertRaises(Ratelimited):
            blocked(req)
        assert unblocked(req), 'Request is limited but not blocked.'

    def test_method(self):
        """Only requests whose method is listed in ``method`` are limited."""
        post = rf.post('/')
        get = rf.get('/')

        @ratelimit(key='ip', method='POST', rate='1/m', group='a')
        def limit_post(request):
            return request.limited

        @ratelimit(key='ip', method=['POST', 'GET'], rate='1/m', group='a')
        def limit_get(request):
            return request.limited

        assert not limit_post(post), 'Do not limit first POST.'
        assert limit_post(post), 'Limit second POST.'
        assert not limit_post(get), 'Do not limit GET.'
        assert limit_get(post), 'Limit first POST.'
        assert limit_get(get), 'Limit first GET.'

    def test_unsafe_methods(self):
        """``ratelimit.UNSAFE`` skips GET/HEAD/OPTIONS, limits the rest."""
        @ratelimit(key='ip', method=ratelimit.UNSAFE, rate='0/m')
        def limit_unsafe(request):
            return request.limited

        get = rf.get('/')
        head = rf.head('/')
        options = rf.options('/')
        delete = rf.delete('/')
        post = rf.post('/')
        put = rf.put('/')

        assert not limit_unsafe(get)
        assert not limit_unsafe(head)
        assert not limit_unsafe(options)
        assert limit_unsafe(delete)
        assert limit_unsafe(post)
        assert limit_unsafe(put)

        # TODO: When all supported versions have this, drop the `if`.
        if hasattr(rf, 'patch'):
            patch = rf.patch('/')
            assert limit_unsafe(patch)

    def test_key_get(self):
        """``key='get:foo'`` counts each value of GET param ``foo`` apart."""
        req_a = rf.get('/', {'foo': 'a'})
        req_b = rf.get('/', {'foo': 'b'})

        @ratelimit(key='get:foo', rate='1/m', method='GET')
        def view(request):
            return request.limited

        assert not view(req_a)
        assert view(req_a)
        assert not view(req_b)
        assert view(req_b)

    def test_key_post(self):
        """``key='post:foo'`` counts each value of POST param ``foo`` apart."""
        req_a = rf.post('/', {'foo': 'a'})
        req_b = rf.post('/', {'foo': 'b'})

        @ratelimit(key='post:foo', rate='1/m')
        def view(request):
            return request.limited

        assert not view(req_a)
        assert view(req_a)
        assert not view(req_b)
        assert view(req_b)

    def test_key_header(self):
        """Stacked ``header:`` keys work when one header is not present."""
        req = rf.post('/')
        # Only X-Real-IP is set; the inner decorator names a missing header.
        req.META['HTTP_X_REAL_IP'] = '1.2.3.4'

        @ratelimit(key='header:x-real-ip', rate='1/m')
        @ratelimit(key='header:x-missing-header', rate='1/m')
        def view(request):
            return request.limited

        assert not view(req)
        assert view(req)

    def test_rate(self):
        """A rate of ``'2/m'`` lets two requests through, limits the third."""
        req = rf.post('/')

        @ratelimit(key='ip', rate='2/m')
        def twice(request):
            return request.limited

        assert not twice(req), 'First request is not limited.'
        del req.limited
        assert not twice(req), 'Second request is not limited.'
        del req.limited
        assert twice(req), 'Third request is limited.'

    def test_zero_rate(self):
        """A rate of ``'0/m'`` limits the very first request."""
        req = rf.post('/')

        @ratelimit(key='ip', rate='0/m')
        def never(request):
            return request.limited

        assert never(req)

    def test_none_rate(self):
        """A rate of ``None`` never limits, across repeated requests."""
        req = rf.post('/')

        @ratelimit(key='ip', rate=None)
        def always(request):
            return request.limited

        assert not always(req)
        del req.limited
        assert not always(req)
        del req.limited
        assert not always(req)
        del req.limited
        assert not always(req)
        del req.limited
        assert not always(req)
        del req.limited
        assert not always(req)

    def test_callable_rate(self):
        """A callable rate may return a different tuple per request."""
        auth = rf.post('/')
        unauth = rf.post('/')
        auth.user = MockUser(authenticated=True)
        unauth.user = MockUser(authenticated=False)

        def get_rate(group, request):
            if request.user.is_authenticated:
                return (2, 60)
            return (1, 60)

        @ratelimit(key='user_or_ip', rate=get_rate)
        def view(request):
            return request.limited

        assert not view(unauth)
        assert view(unauth)
        assert not view(auth)
        assert not view(auth)
        assert view(auth)

    def test_callable_rate_none(self):
        """A callable rate returning ``None`` switches limiting off."""
        req = rf.post('/')
        req.never_limit = False

        def get_rate(group, request):
            return None if request.never_limit else '1/m'

        @ratelimit(key='ip', rate=get_rate)
        def view(request):
            return request.limited

        assert not view(req)
        del req.limited
        assert view(req)
        req.never_limit = True
        del req.limited
        assert not view(req)
        del req.limited
        assert not view(req)

    def test_callable_rate_zero(self):
        """A callable rate returning ``'0/m'`` limits that request at once."""
        auth = rf.post('/')
        unauth = rf.post('/')
        auth.user = MockUser(authenticated=True)
        unauth.user = MockUser(authenticated=False)

        def get_rate(group, request):
            if request.user.is_authenticated:
                return '1/m'
            return '0/m'

        @ratelimit(key='ip', rate=get_rate)
        def view(request):
            return request.limited

        assert view(unauth)
        del unauth.limited
        assert not view(auth)
        del auth.limited
        assert view(auth)
        assert view(unauth)

    @override_settings(RATELIMIT_USE_CACHE='fake-cache')
    def test_bad_cache(self):
        """RATELIMIT_USE_CACHE='fake-cache' raises InvalidCacheBackendError."""
        @ratelimit(key='ip', rate='1/m')
        def view(request):
            return request

        req = rf.post('/')
        with self.assertRaises(InvalidCacheBackendError):
            view(req)

    @override_settings(RATELIMIT_USE_CACHE='connection-errors')
    def test_cache_connection_error(self):
        """With the 'connection-errors' cache the view still returns."""
        @ratelimit(key='ip', rate='1/m')
        def view(request):
            return request

        req = rf.post('/')
        assert view(req)

    def test_user_or_ip(self):
        """The ``'user_or_ip'`` key counts unauth and auth requests apart."""
        @ratelimit(key='user_or_ip', rate='1/m', block=False)
        def view(request):
            return request.limited

        unauth = rf.post('/')
        unauth.user = MockUser(authenticated=False)
        assert not view(unauth), 'First unauthenticated request is allowed.'
        assert view(unauth), 'Second unauthenticated request is limited.'

        auth = rf.post('/')
        auth.user = MockUser(authenticated=True)
        assert not view(auth), 'First authenticated request is allowed.'
        assert view(auth), 'Second authenticated is limited.'

    def test_key_path(self):
        """A dotted-path string may name the key function."""
        @ratelimit(key='ratelimit.tests.mykey', rate='1/m')
        def view(request):
            return request.limited

        req = rf.post('/')
        assert not view(req)
        assert view(req)

    def test_callable_key(self):
        """A callable may be passed directly as the key."""
        @ratelimit(key=mykey, rate='1/m')
        def view(request):
            return request.limited

        req = rf.post('/')
        assert not view(req)
        assert view(req)

    def test_stacked_decorator(self):
        """Allow @ratelimit to be stacked."""
        # Put the shorter one first and make sure the second one doesn't
        # reset request.limited back to False.
        @ratelimit(rate='1/m', block=False, key=lambda x, y: 'min')
        @ratelimit(rate='10/d', block=False, key=lambda x, y: 'day')
        def view(request):
            return request.limited

        req = rf.post('/')
        assert not view(req), 'First unauthenticated request is allowed.'
        assert view(req), 'Second unauthenticated request is limited.'

    def test_stacked_methods(self):
        """Different methods should result in different counts."""
        @ratelimit(rate='1/m', key='ip', method='GET')
        @ratelimit(rate='1/m', key='ip', method='POST')
        def view(request):
            return request.limited

        get = rf.get('/')
        post = rf.post('/')
        assert not view(get)
        assert not view(post)
        assert view(get)
        assert view(post)

    def test_sorted_methods(self):
        """Order of the methods shouldn't matter."""
        @ratelimit(rate='1/m', key='ip', method=['GET', 'POST'], group='a')
        def get_post(request):
            return request.limited

        @ratelimit(rate='1/m', key='ip', method=['POST', 'GET'], group='a')
        def post_get(request):
            return request.limited

        req = rf.get('/')
        assert not get_post(req)
        assert post_get(req)

    def test_is_ratelimited(self):
        """``increment=False`` checks the count without adding to it."""
        def get_key(group, request):
            return 'test_is_ratelimited_key'

        def not_increment(request):
            return is_ratelimited(request, increment=False,
                                  method=is_ratelimited.ALL, key=get_key,
                                  rate='1/m', group='a')

        def do_increment(request):
            return is_ratelimited(request, increment=True,
                                  method=is_ratelimited.ALL, key=get_key,
                                  rate='1/m', group='a')

        req = rf.get('/')
        # Does not increment. Count still 0. Does not rate limit
        # because 0 < 1.
        assert not not_increment(req), 'Request should not be rate limited.'

        # Increments. Does not rate limit because 0 < 1. Count now 1.
        assert not do_increment(req), 'Request should not be rate limited.'

        # Does not increment. Count still 1. Not limited because 1 > 1
        # is false.
        assert not not_increment(req), 'Request should not be rate limited.'

        # Count = 2, 2 > 1.
        assert do_increment(req), 'Request should be rate limited.'
        assert not_increment(req), 'Request should be rate limited.'

    @override_settings(RATELIMIT_USE_CACHE='connection-errors')
    def test_is_ratelimited_cache_connection_error_without_increment(self):
        """'connection-errors' cache: a non-incrementing check is not limited."""
        def get_key(group, request):
            return 'test_is_ratelimited_key'

        def not_increment(request):
            return is_ratelimited(request, increment=False,
                                  method=is_ratelimited.ALL, key=get_key,
                                  rate='1/m', group='a')

        req = rf.get('/')
        assert not not_increment(req)

    @override_settings(RATELIMIT_USE_CACHE='connection-errors')
    def test_is_ratelimited_cache_connection_error_with_increment(self):
        """'connection-errors' cache: an incrementing check is not limited."""
        def get_key(group, request):
            return 'test_is_ratelimited_key'

        def do_increment(request):
            return is_ratelimited(request, increment=True,
                                  method=is_ratelimited.ALL, key=get_key,
                                  rate='1/m', group='a')

        req = rf.get('/')
        assert not do_increment(req)
        assert req.limited is False

    @override_settings(RATELIMIT_USE_CACHE='connection-errors-redis')
    def test_is_ratelimited_cache_connection_error_with_increment_redis(self):
        """'connection-errors-redis' cache: an incrementing check is limited."""
        def get_key(group, request):
            return 'test_is_ratelimited_key'

        def do_increment(request):
            return is_ratelimited(request, increment=True,
                                  method=is_ratelimited.ALL, key=get_key,
                                  rate='1/m', group='a')

        req = rf.get('/')
        assert do_increment(req)
        assert req.limited is True

    @override_settings(RATELIMIT_USE_CACHE='instant-expiration')
    def test_cache_timeout(self):
        """'instant-expiration' cache: the second request is still blocked."""
        @ratelimit(key='ip', rate='1/m', block=True)
        def view(request):
            return True

        req = rf.get('/')
        assert view(req), 'First request works.'
        with self.assertRaises(Ratelimited):
            view(req)
class RatelimitCBVTests(TestCase):
    """Tests for rate limiting class-based views via ``RatelimitMixin``."""

    def setUp(self):
        """Start every test from an empty default cache."""
        cache.clear()

    def test_limit_ip(self):
        """With ``ratelimit_block = True`` the second request raises."""
        class RLView(RatelimitMixin, View):
            ratelimit_key = 'ip'
            ratelimit_method = ratelimit.ALL
            ratelimit_rate = '1/m'
            ratelimit_block = True

        rlview = RLView.as_view()

        req = rf.get('/')
        assert rlview(req), 'First request works.'
        with self.assertRaises(Ratelimited):
            rlview(req)

    def test_block(self):
        """Blocking views raise; non-blocking views only flag the request."""
        class BlockedView(RatelimitMixin, View):
            ratelimit_group = 'cbv:block'
            ratelimit_key = 'ip'
            ratelimit_method = ratelimit.ALL
            ratelimit_rate = '1/m'
            ratelimit_block = True

            def get(self, request, *args, **kwargs):
                return request.limited

        class UnBlockedView(RatelimitMixin, View):
            ratelimit_group = 'cbv:block'
            ratelimit_key = 'ip'
            ratelimit_method = ratelimit.ALL
            ratelimit_rate = '1/m'
            ratelimit_block = False

            def get(self, request, *args, **kwargs):
                return request.limited

        blocked = BlockedView.as_view()
        unblocked = UnBlockedView.as_view()

        req = rf.get('/')
        assert not blocked(req), 'First request works.'
        with self.assertRaises(Ratelimited):
            blocked(req)
        assert unblocked(req), 'Request is limited but not blocked.'

    def test_method(self):
        """Only methods listed in ``ratelimit_method`` are limited."""
        post = rf.post('/')
        get = rf.get('/')

        class LimitPostView(RatelimitMixin, View):
            ratelimit_group = 'cbv:method'
            ratelimit_key = 'ip'
            ratelimit_method = ['POST']
            ratelimit_rate = '1/m'

            def post(self, request, *args, **kwargs):
                return request.limited
            # GET is served by the same handler as POST.
            get = post

        class LimitGetView(RatelimitMixin, View):
            ratelimit_group = 'cbv:method'
            ratelimit_key = 'ip'
            ratelimit_method = ['POST', 'GET']
            ratelimit_rate = '1/m'

            def post(self, request, *args, **kwargs):
                return request.limited
            # GET is served by the same handler as POST.
            get = post

        limit_post = LimitPostView.as_view()
        limit_get = LimitGetView.as_view()

        assert not limit_post(post), 'Do not limit first POST.'
        assert limit_post(post), 'Limit second POST.'
        assert not limit_post(get), 'Do not limit GET.'
        assert limit_get(post), 'Limit first POST.'
        assert limit_get(get), 'Limit first GET.'

    def test_rate(self):
        """A rate of ``'2/m'`` lets two requests through, limits the third."""
        req = rf.post('/')

        class TwiceView(RatelimitMixin, View):
            ratelimit_key = 'ip'
            ratelimit_rate = '2/m'

            def post(self, request, *args, **kwargs):
                return request.limited
            get = post

        twice = TwiceView.as_view()

        assert not twice(req), 'First request is not limited.'
        assert not twice(req), 'Second request is not limited.'
        assert twice(req), 'Third request is limited.'

    @override_settings(RATELIMIT_USE_CACHE='fake-cache')
    def test_bad_cache(self):
        """RATELIMIT_USE_CACHE='fake-cache' should raise (currently skipped)."""
        # skipTest() raises, so everything below it is currently unreachable.
        self.skipTest('I do not know why this fails when the other works.')

        class BadCacheView(RatelimitMixin, View):
            ratelimit_key = 'ip'

            def post(self, request, *args, **kwargs):
                return request
            get = post

        view = BadCacheView.as_view()

        req = rf.post('/')
        with self.assertRaises(InvalidCacheBackendError):
            view(req)

    def test_keys(self):
        """Allow custom functions to set cache keys."""
        def user_or_ip(group, req):
            if req.user.is_authenticated:
                return 'uip:%d' % req.user.pk
            return 'uip:%s' % req.META['REMOTE_ADDR']

        class KeysView(RatelimitMixin, View):
            ratelimit_key = user_or_ip
            ratelimit_block = False
            ratelimit_rate = '1/m'

            def post(self, request, *args, **kwargs):
                return request.limited
            get = post

        view = KeysView.as_view()

        req = rf.post('/')
        req.user = MockUser(authenticated=False)

        assert not view(req), 'First unauthenticated request is allowed.'
        assert view(req), 'Second unauthenticated request is limited.'

        del req.limited
        req.user = MockUser(authenticated=True)

        assert not view(req), 'First authenticated request is allowed.'
        assert view(req), 'Second authenticated is limited.'

    def test_method_decorator(self):
        """``@ratelimit`` can decorate a view method directly."""
        class TestView(View):
            @ratelimit(key='ip', rate='1/m', block=False)
            def post(self, request):
                return request.limited

        view = TestView.as_view()

        req = rf.post('/')
        assert not view(req)
        assert view(req)