misc: split utils in specific modules (#21165)

This commit is contained in:
Frédéric Péters 2018-01-14 09:24:46 +01:00
parent 5eba146729
commit 71db20a5e4
14 changed files with 495 additions and 386 deletions

View File

@ -1,355 +0,0 @@
# combo - content management system
# Copyright (C) 2015 Entr'ouvert
#
# This program is free software: you can redistribute it and/or modify it
# under the terms of the GNU Affero General Public License as published
# by the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
import datetime
import base64
import hmac
import hashlib
import binascii
import re
from HTMLParser import HTMLParser
import logging
import random
from StringIO import StringIO
import urlparse
from Crypto.Cipher import AES
from Crypto.Protocol.KDF import PBKDF2
from Crypto import Random
from requests import Response, Session as RequestsSession
from django.conf import settings
from django.core.cache import cache
from django.template import Context, Template, TemplateSyntaxError, VariableDoesNotExist
from django.template.context import BaseContext
from django.utils.html import strip_tags
from django.utils.http import urlencode, quote
from .middleware import get_request
class DecryptionError(Exception):
pass
def aes_hex_encrypt(key, data):
'''Generate an AES key from any key material using PBKDF2, and encrypt data using CFB mode. A
new IV is generated each time, the IV is also used as salt for PBKDF2.
'''
iv = Random.get_random_bytes(2) * 8
aes_key = PBKDF2(key, iv)
aes = AES.new(aes_key, AES.MODE_CFB, iv)
crypted = aes.encrypt(data)
return '%s%s' % (binascii.hexlify(iv[:2]), binascii.hexlify(crypted))
def aes_hex_decrypt(key, payload, raise_on_error=True):
'''Decrypt data encrypted with aes_base64_encrypt'''
try:
iv, crypted = payload[:4], payload[4:]
except (ValueError, TypeError):
if raise_on_error:
raise DecryptionError('bad payload')
return None
try:
iv = binascii.unhexlify(iv) * 8
crypted = binascii.unhexlify(crypted)
except TypeError:
if raise_on_error:
raise DecryptionError('incorrect hexadecimal encoding')
return None
aes_key = PBKDF2(key, iv)
aes = AES.new(aes_key, AES.MODE_CFB, iv)
return aes.decrypt(crypted)
class NothingInCacheException(Exception):
pass
class Requests(RequestsSession):
def request(self, method, url, **kwargs):
remote_service = kwargs.pop('remote_service', None)
cache_duration = kwargs.pop('cache_duration', 15)
invalidate_cache = kwargs.pop('invalidate_cache', False)
user = kwargs.pop('user', None)
without_user = kwargs.pop('without_user', False)
federation_key = kwargs.pop('federation_key', 'auto') # 'auto', 'email', 'nameid'
raise_if_not_cached = kwargs.pop('raise_if_not_cached', False)
log_errors = kwargs.pop('log_errors', True)
if remote_service == 'auto':
remote_service = None
scheme, netloc, path, params, query, fragment = urlparse.urlparse(url)
for services in settings.KNOWN_SERVICES.values():
for service in services.values():
remote_url = service.get('url')
remote_scheme, remote_netloc, r_path, r_params, r_query, r_fragment = \
urlparse.urlparse(remote_url)
if remote_scheme == scheme and remote_netloc == netloc:
remote_service = service
break
else:
continue
break
if remote_service:
# only keeps the path (URI) in url parameter, scheme and netloc are
# in remote_service
url = urlparse.urlunparse(('', '', path, params, query, fragment))
else:
logging.warning('service not found in settings.KNOWN_SERVICES for %s', url)
if remote_service:
if isinstance(user, dict):
query_params = user.copy()
elif not user or not user.is_authenticated():
if without_user:
query_params = {}
else:
query_params = {'NameID': '', 'email': ''}
else:
query_params = {}
if federation_key == 'nameid':
query_params['NameID'] = user.saml_identifiers.first().name_id
elif federation_key == 'email':
query_params['email'] = user.email
else: # 'auto'
if hasattr(user, 'saml_identifiers') and user.saml_identifiers.exists():
query_params['NameID'] = user.saml_identifiers.first().name_id
else:
query_params['email'] = user.email
query_params['orig'] = remote_service.get('orig')
remote_service_base_url = remote_service.get('url')
scheme, netloc, old_path, params, old_query, fragment = urlparse.urlparse(
remote_service_base_url)
query = urlencode(query_params)
if '?' in url:
path, old_query = url.split('?')
query += '&' + old_query
else:
path = url
url = urlparse.urlunparse((scheme, netloc, path, params, query, fragment))
if method == 'GET' and cache_duration:
# handle cache
cache_key = hashlib.md5(url).hexdigest()
cache_content = cache.get(cache_key)
if cache_content and not invalidate_cache:
response = Response()
response.status_code = 200
response.raw = StringIO(cache_content)
return response
elif raise_if_not_cached:
raise NothingInCacheException()
if remote_service: # sign
url = sign_url(url, remote_service.get('secret'))
kwargs['timeout'] = kwargs.get('timeout') or settings.REQUESTS_TIMEOUT
response = super(Requests, self).request(method, url, **kwargs)
if log_errors and (response.status_code // 100 != 2):
logging.error('failed to %s %s (%s)', method, url, response.status_code)
if method == 'GET' and cache_duration and (response.status_code // 100 == 2):
cache.set(cache_key, response.content, cache_duration)
return response
requests = Requests()
class TemplateError(Exception):
def __init__(self, msg, params=()):
self.msg = msg
self.params = params
def __str__(self):
return self.msg % self.params
def get_templated_url(url, context=None):
if '{{' not in url and '{%' not in url and '[' not in url:
return url
template_vars = Context()
if context:
template_vars.update(context)
template_vars['user_email'] = ''
template_vars['user_nameid'] = ''
user = getattr(context.get('request'), 'user', None)
if user and user.is_authenticated():
template_vars['user_email'] = quote(user.email)
if hasattr(user, 'saml_identifiers') and user.saml_identifiers.exists():
template_vars['user_nameid'] = quote(user.saml_identifiers.first().name_id)
template_vars.update(settings.TEMPLATE_VARS)
if '{{' in url or '{%' in url: # Django template
try:
return Template(url).render(template_vars)
except VariableDoesNotExist as e:
raise TemplateError(e.msg, e.params)
except TemplateSyntaxError:
raise TemplateError('syntax error')
# ezt-like template
def repl(matchobj):
varname = matchobj.group(0)[1:-1]
if varname == '[':
return '['
if varname not in template_vars:
raise TemplateError('unknown variable %s', varname)
return unicode(template_vars[varname])
return re.sub(r'(\[.+?\])', repl, url)
# Simple signature scheme for query strings
def sign_url(url, key, algo='sha256', timestamp=None, nonce=None):
parsed = urlparse.urlparse(url)
new_query = sign_query(parsed.query, key, algo, timestamp, nonce)
return urlparse.urlunparse(parsed[:4] + (new_query,) + parsed[5:])
def sign_query(query, key, algo='sha256', timestamp=None, nonce=None):
if timestamp is None:
timestamp = datetime.datetime.utcnow()
timestamp = timestamp.strftime('%Y-%m-%dT%H:%M:%SZ')
if nonce is None:
nonce = hex(random.getrandbits(128))[2:]
new_query = query
if new_query:
new_query += '&'
new_query += urlencode((
('algo', algo),
('timestamp', timestamp),
('nonce', nonce)))
signature = base64.b64encode(sign_string(new_query, key, algo=algo))
new_query += '&signature=' + quote(signature)
return new_query
def sign_string(s, key, algo='sha256', timedelta=30):
digestmod = getattr(hashlib, algo)
hash = hmac.HMAC(str(key), digestmod=digestmod, msg=s)
return hash.digest()
def ellipsize(text, length=50):
text = HTMLParser().unescape(strip_tags(text))
if len(text) < length:
return text
return text[:(length-10)] + '...'
def check_request_signature(django_request, keys=[]):
query_string = django_request.META['QUERY_STRING']
if not query_string:
return False
orig = django_request.GET.get('orig', '')
known_services = getattr(settings, 'KNOWN_SERVICES', None)
if known_services and orig:
for services in known_services.itervalues():
for service in services.itervalues():
if 'verif_orig' in service and service['verif_orig'] == orig:
keys.append(service['secret'])
break
return check_query(query_string, keys)
def check_query(query, keys, known_nonce=None, timedelta=30):
parsed = urlparse.parse_qs(query)
if not ('signature' in parsed and 'algo' in parsed and
'timestamp' in parsed and 'nonce' in parsed):
return False
signature = base64.b64decode(parsed['signature'][0])
algo = parsed['algo'][0]
timestamp = parsed['timestamp'][0]
timestamp = datetime.datetime.strptime(timestamp, '%Y-%m-%dT%H:%M:%SZ')
nonce = parsed['nonce']
unsigned_query = query.split('&signature=')[0]
if known_nonce is not None and known_nonce(nonce):
return False
if abs(datetime.datetime.utcnow() - timestamp) > datetime.timedelta(seconds=timedelta):
return False
return check_string(unsigned_query, signature, keys, algo=algo)
def check_string(s, signature, keys, algo='sha256'):
if not isinstance(keys, list):
keys = [keys]
for key in keys:
signature2 = sign_string(s, key, algo=algo)
if len(signature2) != len(signature):
continue
res = 0
# constant time compare
for a, b in zip(signature, signature2):
res |= ord(a) ^ ord(b)
if res == 0:
return True
return False
# _make_key and _HashedSeq imported/adapted from functools from Python 3.2+
def _make_key(args, kwds,
kwd_mark=(object(),),
fasttypes={int, str, frozenset, type(None)},
tuple=tuple, type=type, len=len):
key = args
if kwds:
key += kwd_mark
for item in kwds.items():
key += item
if len(key) == 1 and type(key[0]) in fasttypes:
return key[0]
return _HashedSeq(key)
class _HashedSeq(list):
__slots__ = 'hashvalue'
def __init__(self, tup, hash=hash):
self[:] = tup
self.hashvalue = hash(tup)
def __hash__(self):
return self.hashvalue
def cache_during_request(func):
def inner(*args, **kwargs):
request = get_request()
if request:
cache_key = (id(func), _make_key(args, kwargs))
if cache_key in request.cache:
return request.cache[cache_key]
result = func(*args, **kwargs)
if request:
request.cache[cache_key] = result
return result
return inner
def flatten_context(context):
# flatten a context to a dictionary, with full support for embedded Context
# objects.
flat_context = {}
if isinstance(context, BaseContext):
for ctx in context.dicts:
flat_context.update(flatten_context(ctx))
else:
flat_context.update(context)
return flat_context

23
combo/utils/__init__.py Normal file
View File

@ -0,0 +1,23 @@
# combo - content management system
# Copyright (C) 2015 Entr'ouvert
#
# This program is free software: you can redistribute it and/or modify it
# under the terms of the GNU Affero General Public License as published
# by the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
# import specific symbols for compatibility
from .cache import cache_during_request
from .crypto import aes_hex_decrypt, aes_hex_encrypt, DecryptionError
from .misc import ellipsize, flatten_context
from .requests_wrapper import requests, NothingInCacheException
from .signature import check_query, check_request_signature, sign_url
from .urls import get_templated_url, TemplateError

56
combo/utils/cache.py Normal file
View File

@ -0,0 +1,56 @@
# combo - content management system
# Copyright (C) 2015-2018 Entr'ouvert
#
# This program is free software: you can redistribute it and/or modify it
# under the terms of the GNU Affero General Public License as published
# by the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
from combo.middleware import get_request
# _make_key and _HashedSeq imported/adapted from functools from Python 3.2+
def _make_key(args, kwds,
kwd_mark=(object(),),
fasttypes={int, str, frozenset, type(None)},
tuple=tuple, type=type, len=len):
key = args
if kwds:
key += kwd_mark
for item in kwds.items():
key += item
if len(key) == 1 and type(key[0]) in fasttypes:
return key[0]
return _HashedSeq(key)
class _HashedSeq(list):
__slots__ = 'hashvalue'
def __init__(self, tup, hash=hash):
self[:] = tup
self.hashvalue = hash(tup)
def __hash__(self):
return self.hashvalue
def cache_during_request(func):
def inner(*args, **kwargs):
request = get_request()
if request:
cache_key = (id(func), _make_key(args, kwargs))
if cache_key in request.cache:
return request.cache[cache_key]
result = func(*args, **kwargs)
if request:
request.cache[cache_key] = result
return result
return inner

55
combo/utils/crypto.py Normal file
View File

@ -0,0 +1,55 @@
# combo - content management system
# Copyright (C) 2015-2018 Entr'ouvert
#
# This program is free software: you can redistribute it and/or modify it
# under the terms of the GNU Affero General Public License as published
# by the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
import binascii
from Crypto.Cipher import AES
from Crypto.Protocol.KDF import PBKDF2
from Crypto import Random
class DecryptionError(Exception):
pass
def aes_hex_encrypt(key, data):
'''Generate an AES key from any key material using PBKDF2, and encrypt data using CFB mode. A
new IV is generated each time, the IV is also used as salt for PBKDF2.
'''
iv = Random.get_random_bytes(2) * 8
aes_key = PBKDF2(key, iv)
aes = AES.new(aes_key, AES.MODE_CFB, iv)
crypted = aes.encrypt(data)
return '%s%s' % (binascii.hexlify(iv[:2]), binascii.hexlify(crypted))
def aes_hex_decrypt(key, payload, raise_on_error=True):
'''Decrypt data encrypted with aes_base64_encrypt'''
try:
iv, crypted = payload[:4], payload[4:]
except (ValueError, TypeError):
if raise_on_error:
raise DecryptionError('bad payload')
return None
try:
iv = binascii.unhexlify(iv) * 8
crypted = binascii.unhexlify(crypted)
except TypeError:
if raise_on_error:
raise DecryptionError('incorrect hexadecimal encoding')
return None
aes_key = PBKDF2(key, iv)
aes = AES.new(aes_key, AES.MODE_CFB, iv)
return aes.decrypt(crypted)

39
combo/utils/misc.py Normal file
View File

@ -0,0 +1,39 @@
# combo - content management system
# Copyright (C) 2015-2018 Entr'ouvert
#
# This program is free software: you can redistribute it and/or modify it
# under the terms of the GNU Affero General Public License as published
# by the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
from HTMLParser import HTMLParser
from django.template.context import BaseContext
from django.utils.html import strip_tags
def ellipsize(text, length=50):
text = HTMLParser().unescape(strip_tags(text))
if len(text) < length:
return text
return text[:(length-10)] + '...'
def flatten_context(context):
# flatten a context to a dictionary, with full support for embedded Context
# objects.
flat_context = {}
if isinstance(context, BaseContext):
for ctx in context.dicts:
flat_context.update(flatten_context(ctx))
else:
flat_context.update(context)
return flat_context

View File

@ -0,0 +1,127 @@
# combo - content management system
# Copyright (C) 2015-2018 Entr'ouvert
#
# This program is free software: you can redistribute it and/or modify it
# under the terms of the GNU Affero General Public License as published
# by the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
import hashlib
import logging
from StringIO import StringIO
import urlparse
from requests import Response, Session as RequestsSession
from django.conf import settings
from django.core.cache import cache
from django.utils.http import urlencode
from .signature import sign_url
class NothingInCacheException(Exception):
pass
class Requests(RequestsSession):
def request(self, method, url, **kwargs):
remote_service = kwargs.pop('remote_service', None)
cache_duration = kwargs.pop('cache_duration', 15)
invalidate_cache = kwargs.pop('invalidate_cache', False)
user = kwargs.pop('user', None)
without_user = kwargs.pop('without_user', False)
federation_key = kwargs.pop('federation_key', 'auto') # 'auto', 'email', 'nameid'
raise_if_not_cached = kwargs.pop('raise_if_not_cached', False)
log_errors = kwargs.pop('log_errors', True)
if remote_service == 'auto':
remote_service = None
scheme, netloc, path, params, query, fragment = urlparse.urlparse(url)
for services in settings.KNOWN_SERVICES.values():
for service in services.values():
remote_url = service.get('url')
remote_scheme, remote_netloc, r_path, r_params, r_query, r_fragment = \
urlparse.urlparse(remote_url)
if remote_scheme == scheme and remote_netloc == netloc:
remote_service = service
break
else:
continue
break
if remote_service:
# only keeps the path (URI) in url parameter, scheme and netloc are
# in remote_service
url = urlparse.urlunparse(('', '', path, params, query, fragment))
else:
logging.warning('service not found in settings.KNOWN_SERVICES for %s', url)
if remote_service:
if isinstance(user, dict):
query_params = user.copy()
elif not user or not user.is_authenticated():
if without_user:
query_params = {}
else:
query_params = {'NameID': '', 'email': ''}
else:
query_params = {}
if federation_key == 'nameid':
query_params['NameID'] = user.saml_identifiers.first().name_id
elif federation_key == 'email':
query_params['email'] = user.email
else: # 'auto'
if hasattr(user, 'saml_identifiers') and user.saml_identifiers.exists():
query_params['NameID'] = user.saml_identifiers.first().name_id
else:
query_params['email'] = user.email
query_params['orig'] = remote_service.get('orig')
remote_service_base_url = remote_service.get('url')
scheme, netloc, old_path, params, old_query, fragment = urlparse.urlparse(
remote_service_base_url)
query = urlencode(query_params)
if '?' in url:
path, old_query = url.split('?')
query += '&' + old_query
else:
path = url
url = urlparse.urlunparse((scheme, netloc, path, params, query, fragment))
if method == 'GET' and cache_duration:
# handle cache
cache_key = hashlib.md5(url).hexdigest()
cache_content = cache.get(cache_key)
if cache_content and not invalidate_cache:
response = Response()
response.status_code = 200
response.raw = StringIO(cache_content)
return response
elif raise_if_not_cached:
raise NothingInCacheException()
if remote_service: # sign
url = sign_url(url, remote_service.get('secret'))
kwargs['timeout'] = kwargs.get('timeout') or settings.REQUESTS_TIMEOUT
response = super(Requests, self).request(method, url, **kwargs)
if log_errors and (response.status_code // 100 != 2):
logging.error('failed to %s %s (%s)', method, url, response.status_code)
if method == 'GET' and cache_duration and (response.status_code // 100 == 2):
cache.set(cache_key, response.content, cache_duration)
return response
requests = Requests()

102
combo/utils/signature.py Normal file
View File

@ -0,0 +1,102 @@
# combo - content management system
# Copyright (C) 2015-2018 Entr'ouvert
#
# This program is free software: you can redistribute it and/or modify it
# under the terms of the GNU Affero General Public License as published
# by the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
import base64
import datetime
import hmac
import hashlib
import random
import urlparse
from django.conf import settings
from django.utils.http import quote, urlencode
# Simple signature scheme for query strings
def sign_url(url, key, algo='sha256', timestamp=None, nonce=None):
parsed = urlparse.urlparse(url)
new_query = sign_query(parsed.query, key, algo, timestamp, nonce)
return urlparse.urlunparse(parsed[:4] + (new_query,) + parsed[5:])
def sign_query(query, key, algo='sha256', timestamp=None, nonce=None):
if timestamp is None:
timestamp = datetime.datetime.utcnow()
timestamp = timestamp.strftime('%Y-%m-%dT%H:%M:%SZ')
if nonce is None:
nonce = hex(random.getrandbits(128))[2:]
new_query = query
if new_query:
new_query += '&'
new_query += urlencode((
('algo', algo),
('timestamp', timestamp),
('nonce', nonce)))
signature = base64.b64encode(sign_string(new_query, key, algo=algo))
new_query += '&signature=' + quote(signature)
return new_query
def sign_string(s, key, algo='sha256', timedelta=30):
digestmod = getattr(hashlib, algo)
hash = hmac.HMAC(str(key), digestmod=digestmod, msg=s)
return hash.digest()
def check_request_signature(django_request, keys=[]):
query_string = django_request.META['QUERY_STRING']
if not query_string:
return False
orig = django_request.GET.get('orig', '')
known_services = getattr(settings, 'KNOWN_SERVICES', None)
if known_services and orig:
for services in known_services.itervalues():
for service in services.itervalues():
if 'verif_orig' in service and service['verif_orig'] == orig:
keys.append(service['secret'])
break
return check_query(query_string, keys)
def check_query(query, keys, known_nonce=None, timedelta=30):
parsed = urlparse.parse_qs(query)
if not ('signature' in parsed and 'algo' in parsed and
'timestamp' in parsed and 'nonce' in parsed):
return False
signature = base64.b64decode(parsed['signature'][0])
algo = parsed['algo'][0]
timestamp = parsed['timestamp'][0]
timestamp = datetime.datetime.strptime(timestamp, '%Y-%m-%dT%H:%M:%SZ')
nonce = parsed['nonce']
unsigned_query = query.split('&signature=')[0]
if known_nonce is not None and known_nonce(nonce):
return False
if abs(datetime.datetime.utcnow() - timestamp) > datetime.timedelta(seconds=timedelta):
return False
return check_string(unsigned_query, signature, keys, algo=algo)
def check_string(s, signature, keys, algo='sha256'):
if not isinstance(keys, list):
keys = [keys]
for key in keys:
signature2 = sign_string(s, key, algo=algo)
if len(signature2) != len(signature):
continue
res = 0
# constant time compare
for a, b in zip(signature, signature2):
res |= ord(a) ^ ord(b)
if res == 0:
return True
return False

62
combo/utils/urls.py Normal file
View File

@ -0,0 +1,62 @@
# combo - content management system
# Copyright (C) 2015-2018 Entr'ouvert
#
# This program is free software: you can redistribute it and/or modify it
# under the terms of the GNU Affero General Public License as published
# by the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
import re
from django.conf import settings
from django.template import Context, Template, TemplateSyntaxError, VariableDoesNotExist
from django.utils.http import quote
class TemplateError(Exception):
def __init__(self, msg, params=()):
self.msg = msg
self.params = params
def __str__(self):
return self.msg % self.params
def get_templated_url(url, context=None):
if '{{' not in url and '{%' not in url and '[' not in url:
return url
template_vars = Context()
if context:
template_vars.update(context)
template_vars['user_email'] = ''
template_vars['user_nameid'] = ''
user = getattr(context.get('request'), 'user', None)
if user and user.is_authenticated():
template_vars['user_email'] = quote(user.email)
if hasattr(user, 'saml_identifiers') and user.saml_identifiers.exists():
template_vars['user_nameid'] = quote(user.saml_identifiers.first().name_id)
template_vars.update(settings.TEMPLATE_VARS)
if '{{' in url or '{%' in url: # Django template
try:
return Template(url).render(template_vars)
except VariableDoesNotExist as e:
raise TemplateError(e.msg, e.params)
except TemplateSyntaxError:
raise TemplateError('syntax error')
# ezt-like template
def repl(matchobj):
varname = matchobj.group(0)[1:-1]
if varname == '[':
return '['
if varname not in template_vars:
raise TemplateError('unknown variable %s', varname)
return unicode(template_vars[varname])
return re.sub(r'(\[.+?\])', repl, url)

View File

@ -415,7 +415,7 @@ def test_json_force_async():
cell.url = 'http://example.net/test-force-async'
cell.template_string = '{{json.hello}}'
cell.force_async = True
with mock.patch('combo.utils.RequestsSession.request') as requests_get:
with mock.patch('combo.utils.requests_wrapper.RequestsSession.request') as requests_get:
requests_get.return_value = mock.Mock(content=json.dumps({'hello': 'world'}), status_code=200)
with pytest.raises(NothingInCacheException):
cell.render({})
@ -431,7 +431,7 @@ def test_json_force_async():
cell.url = 'http://example.net/test-force-async-2'
cell.template_string = '{{json.hello}}'
cell.force_async = False
with mock.patch('combo.utils.RequestsSession.request') as requests_get:
with mock.patch('combo.utils.requests_wrapper.RequestsSession.request') as requests_get:
requests_get.return_value = mock.Mock(content=json.dumps({'hello': 'world2'}), status_code=200)
# raise if nothing in cache
with pytest.raises(NothingInCacheException):

View File

@ -137,7 +137,7 @@ def test_successfull_items_payment(regie, user):
assert urlparse.urlparse(qs['return_url'][0]).path.startswith(
reverse('lingo-return', kwargs={'regie_pk': regie.id}))
# simulate successful call to callback URL
with mock.patch('combo.utils.RequestsSession.request') as request:
with mock.patch('combo.utils.requests_wrapper.RequestsSession.request') as request:
resp = client.get(reverse('lingo-callback', kwargs={'regie_pk': regie.id}), args)
assert resp.status_code == 200
# simulate successful return URL
@ -273,7 +273,7 @@ def test_cancel_basket_item(key, regie, user):
assert BasketItem.objects.filter(amount=21, cancellation_date__isnull=True).exists()
basket_item_id_2 = json.loads(resp.content)['id']
with mock.patch('combo.utils.RequestsSession.request') as request:
with mock.patch('combo.utils.requests_wrapper.RequestsSession.request') as request:
url = '%s?email=%s&orig=wcs' % (reverse('api-remove-basket-item'), user_email)
url = sign_url(url, key)
data = {'basket_item_id': basket_item_id, 'notify': 'true'}
@ -282,7 +282,7 @@ def test_cancel_basket_item(key, regie, user):
assert not BasketItem.objects.filter(amount=42, cancellation_date__isnull=True).exists()
assert BasketItem.objects.filter(amount=21, cancellation_date__isnull=True).exists()
with mock.patch('combo.utils.RequestsSession.request') as request:
with mock.patch('combo.utils.requests_wrapper.RequestsSession.request') as request:
url = '%s?email=%s&orig=wcs' % (reverse('api-remove-basket-item'), user_email)
url = sign_url(url, key)
data = {'basket_item_id': basket_item_id_2}
@ -315,7 +315,7 @@ def test_cancel_basket_item_from_cell(key, regie, user):
# check a successful case
login()
with mock.patch('combo.utils.RequestsSession.request') as request:
with mock.patch('combo.utils.requests_wrapper.RequestsSession.request') as request:
client.post(reverse('lingo-cancel-item', kwargs={'pk': basket_item_id}))
url = request.call_args[0][1]
assert url.startswith('http://example.org/testitem/jump/trigger/cancelled')
@ -367,7 +367,7 @@ def test_payment_callback(regie, user):
# call callback with GET
callback_url = reverse('lingo-callback', kwargs={'regie_pk': regie.id})
with mock.patch('combo.utils.RequestsSession.request') as request:
with mock.patch('combo.utils.requests_wrapper.RequestsSession.request') as request:
get_resp = client.get(callback_url, data)
url = request.call_args[0][1]
assert url.startswith('http://example.org/testitem/jump/trigger/paid')
@ -388,7 +388,7 @@ def test_payment_callback(regie, user):
assert data['amount'] == '11.50'
# call callback with POST
with mock.patch('combo.utils.RequestsSession.request') as request:
with mock.patch('combo.utils.requests_wrapper.RequestsSession.request') as request:
post_resp = client.post(callback_url, urllib.urlencode(data),
content_type='text/html')
assert post_resp.status_code == 200
@ -415,7 +415,7 @@ def test_payment_callback_no_regie(regie, user):
# call callback with GET
callback_url = reverse('lingo-callback', kwargs={'regie_pk': regie.id})
with mock.patch('combo.utils.RequestsSession.request') as request:
with mock.patch('combo.utils.requests_wrapper.RequestsSession.request') as request:
get_resp = client.get(callback_url, data)
url = request.call_args[0][1]
assert url.startswith('http://example.org/testitem/jump/trigger/paid')
@ -527,7 +527,7 @@ def test_extra_fees(key, regie, user):
User.objects.get_or_create(email=user_email)
amount = 42
data = {'amount': amount, 'display_name': 'test amount'}
with mock.patch('combo.utils.RequestsSession.request') as request:
with mock.patch('combo.utils.requests_wrapper.RequestsSession.request') as request:
mock_json = mock.Mock()
mock_json.status_code = 200
mock_json.json.return_value = {'err': 0, 'data': [{'subject': 'Extra Fees', 'amount': '5'}]}
@ -541,7 +541,7 @@ def test_extra_fees(key, regie, user):
assert BasketItem.objects.filter(amount=5, extra_fee=True).exists()
assert BasketItem.objects.filter(amount=5, extra_fee=True)[0].regie_id == regie.id
with mock.patch('combo.utils.RequestsSession.request') as request:
with mock.patch('combo.utils.requests_wrapper.RequestsSession.request') as request:
mock_json = mock.Mock()
mock_json.status_code = 200
mock_json.json.return_value = {'err': 0, 'data': [{'subject': 'Extra Fees', 'amount': '7'}]}
@ -557,7 +557,7 @@ def test_extra_fees(key, regie, user):
assert not BasketItem.objects.filter(amount=5, extra_fee=True).exists()
assert BasketItem.objects.filter(amount=7, extra_fee=True).exists()
with mock.patch('combo.utils.RequestsSession.request') as request:
with mock.patch('combo.utils.requests_wrapper.RequestsSession.request') as request:
mock_json = mock.Mock()
mock_json.status_code = 200
mock_json.json.return_value = {'err': 0, 'data': [{'subject': 'Extra Fees', 'amount': '4'}]}
@ -572,7 +572,7 @@ def test_extra_fees(key, regie, user):
# test payment
login()
with mock.patch('combo.utils.RequestsSession.request') as request:
with mock.patch('combo.utils.requests_wrapper.RequestsSession.request') as request:
mock_json = mock.Mock()
mock_json.status_code = 200
mock_json.json.return_value = {'err': 0, 'data': [{'subject': 'Extra Fees', 'amount': '2'}]}
@ -588,7 +588,7 @@ def test_extra_fees(key, regie, user):
assert data['amount'] == '44.00'
# test again, without specifying a regie
with mock.patch('combo.utils.RequestsSession.request') as request:
with mock.patch('combo.utils.requests_wrapper.RequestsSession.request') as request:
mock_json = mock.Mock()
mock_json.status_code = 200
mock_json.json.return_value = {'err': 0, 'data': [{'subject': 'Extra Fees', 'amount': '3'}]}
@ -626,7 +626,7 @@ def test_payment_callback_error(regie, user):
# call callback with GET
callback_url = reverse('lingo-callback', kwargs={'regie_pk': regie.id})
with mock.patch('combo.utils.RequestsSession.request') as request:
with mock.patch('combo.utils.requests_wrapper.RequestsSession.request') as request:
mock_response = mock.Mock()
def kaboom():
raise Exception('kaboom')
@ -651,7 +651,7 @@ def test_payment_callback_error(regie, user):
basket_item.payment_date = timezone.now() - timedelta(hours=1)
basket_item.save()
with mock.patch('combo.utils.RequestsSession.request') as request:
with mock.patch('combo.utils.requests_wrapper.RequestsSession.request') as request:
mock_response = mock.Mock()
mock_response.status_code = 200
request.return_value = mock_response

View File

@ -68,7 +68,7 @@ class MockUser(object):
return MockSAMLUser()
self.saml_identifiers = MockSAMLUsers()
@mock.patch('combo.utils.RequestsSession.request')
@mock.patch('combo.utils.requests_wrapper.RequestsSession.request')
def test_remote_regie_active_invoices_cell(mock_request, remote_regie):
assert remote_regie.is_remote() == True
@ -113,7 +113,7 @@ def test_remote_regie_active_invoices_cell(mock_request, remote_regie):
content = cell.render(context)
assert 'No items yet' in content
@mock.patch('combo.utils.RequestsSession.request')
@mock.patch('combo.utils.requests_wrapper.RequestsSession.request')
def test_remote_regie_past_invoices_cell(mock_request, remote_regie):
assert remote_regie.is_remote() == True

View File

@ -137,7 +137,7 @@ def test_geojson_on_restricted_cell(layer, user):
user.groups.add(group)
user.save()
with mock.patch('combo.utils.RequestsSession.request') as requests_get:
with mock.patch('combo.utils.requests_wrapper.RequestsSession.request') as requests_get:
requests_get.return_value = mock.Mock(
content=SAMPLE_GEOJSON_CONTENT,
json=lambda: json.loads(SAMPLE_GEOJSON_CONTENT),
@ -155,7 +155,7 @@ def test_get_geojson(layer, user):
cell.layers.add(layer)
# check cache duration
with mock.patch('combo.utils.RequestsSession.request') as requests_get:
with mock.patch('combo.utils.requests_wrapper.RequestsSession.request') as requests_get:
requests_get.return_value = mock.Mock(
content=SAMPLE_GEOJSON_CONTENT,
json=lambda: json.loads(SAMPLE_GEOJSON_CONTENT),
@ -173,7 +173,7 @@ def test_get_geojson(layer, user):
# check user params
layer.geojson_url = 'http://example.org/geojson?t2'
layer.save()
with mock.patch('combo.utils.RequestsSession.request') as requests_get:
with mock.patch('combo.utils.requests_wrapper.RequestsSession.request') as requests_get:
requests_get.return_value = mock.Mock(
content=SAMPLE_GEOJSON_CONTENT,
json=lambda: json.loads(SAMPLE_GEOJSON_CONTENT),
@ -185,7 +185,7 @@ def test_get_geojson(layer, user):
login()
layer.geojson_url = 'http://example.org/geojson?t3'
layer.save()
with mock.patch('combo.utils.RequestsSession.request') as requests_get:
with mock.patch('combo.utils.requests_wrapper.RequestsSession.request') as requests_get:
requests_get.return_value = mock.Mock(
content=SAMPLE_GEOJSON_CONTENT,
json=lambda: json.loads(SAMPLE_GEOJSON_CONTENT),
@ -197,7 +197,7 @@ def test_get_geojson(layer, user):
layer.geojson_url = 'http://example.org/geojson?t4'
layer.include_user_identifier = False
layer.save()
with mock.patch('combo.utils.RequestsSession.request') as requests_get:
with mock.patch('combo.utils.requests_wrapper.RequestsSession.request') as requests_get:
requests_get.return_value = mock.Mock(
content=SAMPLE_GEOJSON_CONTENT,
json=lambda: json.loads(SAMPLE_GEOJSON_CONTENT),
@ -210,7 +210,7 @@ def test_get_geojson(layer, user):
layer.geojson_url = 'http://example.org/geojson?t5'
layer.include_user_identifier = False
layer.save()
with mock.patch('combo.utils.RequestsSession.request') as requests_get:
with mock.patch('combo.utils.requests_wrapper.RequestsSession.request') as requests_get:
requests_get.return_value = mock.Mock(
content=SAMPLE_GEOJSON_CONTENT,
json=lambda: json.loads(SAMPLE_GEOJSON_CONTENT),

View File

@ -147,7 +147,7 @@ def test_get_subscriptions(mock_get, cell, user):
assert cell.get_subscriptions(user) == expected_subscriptions
assert mock_get.call_args[1]['user'].email == USER_EMAIL
@mock.patch('combo.utils.RequestsSession.request')
@mock.patch('combo.utils.requests_wrapper.RequestsSession.request')
def test_get_subscriptions_signature_check(mock_get, cell, user):
restrictions = ('mail', 'sms')
cell.transports_restrictions = ','.join(restrictions)

View File

@ -25,13 +25,13 @@ class MockUser(object):
def test_nosign():
with mock.patch('combo.utils.RequestsSession.request') as request:
with mock.patch('combo.utils.requests_wrapper.RequestsSession.request') as request:
requests.get('http://example.org/foo/bar/')
assert request.call_args[0][1] == 'http://example.org/foo/bar/'
def test_sign():
remote_service = {'url': 'http://example.org', 'secret': 'secret', 'orig': 'myself'}
with mock.patch('combo.utils.RequestsSession.request') as request:
with mock.patch('combo.utils.requests_wrapper.RequestsSession.request') as request:
requests.get('/foo/bar/', remote_service=remote_service)
url = request.call_args[0][1]
assert url.startswith('http://example.org/foo/bar/?')
@ -54,7 +54,7 @@ def test_sign():
def test_auto_sign():
with mock.patch('combo.utils.RequestsSession.request') as request:
with mock.patch('combo.utils.requests_wrapper.RequestsSession.request') as request:
requests.get('http://example.org/foo/bar/', remote_service='auto')
url = request.call_args[0][1]
assert url.startswith('http://example.org/foo/bar/?')
@ -69,7 +69,7 @@ def test_auto_sign():
def test_sign_user():
remote_service = {'url': 'http://example.org', 'secret': 'secret', 'orig': 'myself'}
with mock.patch('combo.utils.RequestsSession.request') as request:
with mock.patch('combo.utils.requests_wrapper.RequestsSession.request') as request:
user = MockUser(samlized=True)
@ -110,7 +110,7 @@ def test_sign_user():
def test_sign_anonymous_user():
remote_service = {'url': 'http://example.org', 'secret': 'secret', 'orig': 'myself'}
with mock.patch('combo.utils.RequestsSession.request') as request:
with mock.patch('combo.utils.requests_wrapper.RequestsSession.request') as request:
user = AnonymousUser()
@ -125,7 +125,7 @@ def test_sign_anonymous_user():
assert check_query(querystring, 'secret') == True
def test_requests_cache():
with mock.patch('combo.utils.RequestsSession.request') as requests_get:
with mock.patch('combo.utils.requests_wrapper.RequestsSession.request') as requests_get:
requests_get.return_value = mock.Mock(content='hello world', status_code=200)
# default cache, nothing in there
assert requests.get('http://cache.example.org/').content == 'hello world'