debian-django-ratelimit/ratelimit/utils.py

187 lines
5.1 KiB
Python

import hashlib
import re
import time
import zlib
from importlib import import_module
from django.conf import settings
from django.core.cache import caches
from django.core.exceptions import ImproperlyConfigured
from ratelimit import ALL, UNSAFE
# Names exported by a star-import of this module.
__all__ = ['is_ratelimited']
# Length, in seconds, of each period suffix accepted in a rate string.
_PERIODS = {
    's': 1,      # second
    'm': 60,     # minute
    'h': 3600,   # hour: 60 * 60
    'd': 86400,  # day: 24 * 60 * 60
}
# Extend the expiration time by a few seconds to avoid misses: this value is
# added to the rate period when a counter is first stored in the cache
# (see get_usage_count).
EXPIRATION_FUDGE = 5
def user_or_ip(request):
    """Return a string identifying the client behind ``request``.

    Authenticated requests are identified by the user's primary key
    (converted to ``str``); all others by ``request.META['REMOTE_ADDR']``.
    """
    user = request.user
    if not user.is_authenticated:
        return request.META['REMOTE_ADDR']
    return str(user.pk)
# Built-in key names. Each maps to a callable that takes the request and
# returns the value the request is counted under.
_SIMPLE_KEYS = {
    'ip': lambda request: request.META['REMOTE_ADDR'],
    'user': lambda request: str(request.user.pk),
    'user_or_ip': user_or_ip,
}
def get_header(request, header):
    """Return the value of HTTP header ``header`` from ``request.META``.

    The header name is upper-cased, dashes become underscores and the result
    is prefixed with ``HTTP_`` to form the ``META`` key. An empty string is
    returned when that key is absent.
    """
    meta_name = 'HTTP_%s' % header.upper().replace('-', '_')
    return request.META.get(meta_name, '')
# Handlers for ``<accessor>:<name>`` keys. Each takes the request and the
# name after the colon and returns the matching value ('' when the GET/POST
# parameter or the header is missing).
_ACCESSOR_KEYS = {
    'get': lambda request, name: request.GET.get(name, ''),
    'post': lambda request, name: request.POST.get(name, ''),
    'header': get_header,
}
def _method_match(request, method=ALL):
    """Return True when ``request.method`` is covered by ``method``.

    ``method`` may be ``ALL`` (always matches), a single method name, or a
    list/tuple of names. Names are upper-cased before being compared with
    ``request.method``.
    """
    if method == ALL:
        return True
    candidates = method if isinstance(method, (list, tuple)) else [method]
    allowed = [name.upper() for name in candidates]
    return request.method in allowed
# ``<count>/<multiplier><period>``, where the multiplier and the period letter
# are both optional. Matching is case-insensitive so that '5/M' is read as
# five per minute instead of silently falling back to seconds.
rate_re = re.compile(r'([\d]+)/([\d]*)([smhd])?', re.IGNORECASE)


def _split_rate(rate):
    """Convert ``rate`` into a ``(count, seconds)`` pair.

    A tuple is returned unchanged. Otherwise ``rate`` is parsed as
    ``'<count>/<multiplier><period>'``: the period letter (``s``, ``m``,
    ``h`` or ``d``, in either case) defaults to seconds, and an optional
    integer multiplier scales the period, e.g. ``'10/5m'`` is ten per five
    minutes.

    Raises ImproperlyConfigured when the string does not start with a
    recognisable rate, rather than failing on a missing regex match.
    """
    if isinstance(rate, tuple):
        return rate

    match = rate_re.match(rate)
    if match is None:
        raise ImproperlyConfigured(
            'Could not understand ratelimit rate: %r' % (rate,))

    count, multi, period = match.groups()
    # The period letter is optional; a bare '5/' or '5/10' means seconds.
    seconds = _PERIODS[(period or 's').lower()]
    if multi:
        seconds *= int(multi)
    return int(count), seconds
def _get_window(value, period):
    """Return the timestamp that identifies the current window for ``value``.

    For a one-second period this is simply the current time in whole
    seconds. Otherwise it is the earliest second that is not before now and
    is congruent to ``crc32(value)`` modulo ``period``, so different values
    get window boundaries at different offsets. ``value`` may be ``bytes``
    or text (encoded as UTF-8 before hashing).
    """
    now = int(time.time())
    if period == 1:
        return now

    raw = value if isinstance(value, bytes) else value.encode('utf-8')
    offset = zlib.crc32(raw) % period
    boundary = now - (now % period) + offset
    # A boundary already in the past belongs to the previous window.
    return boundary + period if boundary < now else boundary
def _make_cache_key(group, rate, value, methods):
    """Build the cache key for one group/rate/value/window combination.

    The key is the ``RATELIMIT_CACHE_PREFIX`` setting (default ``'rl:'``)
    followed by the MD5 hex digest of the group, the normalised rate, the
    value, the current window and, unless ``methods`` is None, a method
    component: empty for ``ALL``, the sorted upper-cased names for a
    list/tuple, or ``methods`` itself otherwise.
    """
    limit, seconds = _split_rate(rate)
    normalised_rate = '%d/%ds' % (limit, seconds)
    current_window = _get_window(value, seconds)
    pieces = [group + normalised_rate, value, str(current_window)]

    if methods is not None:
        if methods == ALL:
            method_part = ''
        elif isinstance(methods, (list, tuple)):
            method_part = ''.join(sorted(m.upper() for m in methods))
        else:
            method_part = methods
        pieces.append(method_part)

    prefix = getattr(settings, 'RATELIMIT_CACHE_PREFIX', 'rl:')
    digest = hashlib.md5(''.join(pieces).encode('utf-8')).hexdigest()
    return prefix + digest
def is_ratelimited(request, group=None, fn=None, key=None, rate=None,
                   method=ALL, increment=False):
    """Return True when ``request`` has exceeded its rate limit.

    Args:
        request: the request being checked; ``request.limited`` may be set.
        group: name of the counter group. When None it is derived from
            ``fn`` as ``module[.class].name``.
        fn: the callable being protected; only used to derive ``group``.
        key: what to count by; passed through to get_usage_count.
        rate: a rate string/tuple, or a callable taking ``(group, request)``
            that returns one. A rate of None means "no limit".
        method: HTTP method(s) the limit applies to.
        increment: when true the counter is incremented and
            ``request.limited`` is updated with the outcome.

    Returns:
        False when ``RATELIMIT_ENABLE`` is false (``request.limited`` is
        reset to False), when the request method does not match, or when
        the rate resolves to None. Otherwise True if the usage count is
        greater than the limit. If no count is available the request is
        limited unless ``RATELIMIT_FAIL_OPEN`` is true.

    Raises:
        ImproperlyConfigured: if neither ``group`` nor ``fn`` is given.
    """
    if group is None:
        if fn is None:
            # Without fn there is nothing to derive a group name from; fail
            # with a clear message instead of an AttributeError on None.
            raise ImproperlyConfigured(
                'Ratelimit group or fn must be specified')
        if hasattr(fn, '__self__'):
            # Bound method: include the class name in the group.
            parts = fn.__module__, fn.__self__.__class__.__name__, fn.__name__
        else:
            parts = (fn.__module__, fn.__name__)
        group = '.'.join(parts)

    if not getattr(settings, 'RATELIMIT_ENABLE', True):
        request.limited = False
        return False

    if not _method_match(request, method):
        return False

    # Remember any earlier verdict so it is never downgraded below.
    old_limited = getattr(request, 'limited', False)

    if callable(rate):
        rate = rate(group, request)

    if rate is None:
        request.limited = old_limited
        return False

    usage = get_usage_count(request, group, fn, key, rate, method, increment)

    fail_open = getattr(settings, 'RATELIMIT_FAIL_OPEN', False)

    usage_count = usage.get('count')
    if usage_count is None:
        # No count available: block unless configured to fail open.
        limited = not fail_open
    else:
        usage_limit = usage.get('limit')
        limited = usage_count > usage_limit

    if increment:
        request.limited = old_limited or limited
    return limited
def get_usage_count(request, group=None, fn=None, key=None, rate=None,
                    method=ALL, increment=False):
    """Look up, and optionally increment, the usage counter for a request.

    ``key`` selects the value to count by. It may be a callable taking
    ``(group, request)``, one of the names in ``_SIMPLE_KEYS``, an
    ``accessor:name`` string handled by ``_ACCESSOR_KEYS``, or a dotted
    path to a callable taking ``(group, request)``. ``fn`` is accepted but
    not used here.

    Returns a dict with ``count`` (current usage), ``limit`` (from ``rate``)
    and ``time_left`` (seconds between now and the current window value).

    Raises ImproperlyConfigured when ``key`` is missing or not understood.
    """
    if not key:
        raise ImproperlyConfigured('Ratelimit key must be specified')

    limit, period = _split_rate(rate)
    backend = caches[getattr(settings, 'RATELIMIT_USE_CACHE', 'default')]

    # Resolve the key into the concrete value this request is counted under.
    if callable(key):
        value = key(group, request)
    elif key in _SIMPLE_KEYS:
        value = _SIMPLE_KEYS[key](request)
    elif ':' in key:
        accessor, _, name = key.partition(':')
        if accessor not in _ACCESSOR_KEYS:
            raise ImproperlyConfigured('Unknown ratelimit key: %s' % key)
        value = _ACCESSOR_KEYS[accessor](request, name)
    elif '.' in key:
        module_path, _, attr = key.rpartition('.')
        value = getattr(import_module(module_path), attr)(group, request)
    else:
        raise ImproperlyConfigured(
            'Could not understand ratelimit key: %s' % key)

    cache_key = _make_cache_key(group, rate, value, method)
    time_left = _get_window(value, period) - int(time.time())

    initial_value = 1 if increment else 0
    timeout = period + EXPIRATION_FUDGE
    # add() only stores the key when it is absent, so a fresh window starts
    # at the initial value.
    if backend.add(cache_key, initial_value, timeout):
        count = initial_value
    elif increment:
        try:
            count = backend.incr(cache_key)
        except ValueError:
            # incr() failed on this key; fall back to the initial value.
            count = initial_value
    else:
        count = backend.get(cache_key, initial_value)

    return {'count': count, 'limit': limit, 'time_left': time_left}
# Expose the method constants as attributes of is_ratelimited.
is_ratelimited.ALL = ALL
is_ratelimited.UNSAFE = UNSAFE