misc: split utils in specific modules (#21165)
This commit is contained in:
parent
5eba146729
commit
71db20a5e4
355
combo/utils.py
355
combo/utils.py
|
@ -1,355 +0,0 @@
|
|||
# combo - content management system
|
||||
# Copyright (C) 2015 Entr'ouvert
|
||||
#
|
||||
# This program is free software: you can redistribute it and/or modify it
|
||||
# under the terms of the GNU Affero General Public License as published
|
||||
# by the Free Software Foundation, either version 3 of the License, or
|
||||
# (at your option) any later version.
|
||||
#
|
||||
# This program is distributed in the hope that it will be useful,
|
||||
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
# GNU Affero General Public License for more details.
|
||||
#
|
||||
# You should have received a copy of the GNU Affero General Public License
|
||||
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
|
||||
import datetime
|
||||
import base64
|
||||
import hmac
|
||||
import hashlib
|
||||
import binascii
|
||||
import re
|
||||
|
||||
from HTMLParser import HTMLParser
|
||||
import logging
|
||||
import random
|
||||
from StringIO import StringIO
|
||||
import urlparse
|
||||
|
||||
from Crypto.Cipher import AES
|
||||
from Crypto.Protocol.KDF import PBKDF2
|
||||
from Crypto import Random
|
||||
|
||||
from requests import Response, Session as RequestsSession
|
||||
|
||||
from django.conf import settings
|
||||
from django.core.cache import cache
|
||||
from django.template import Context, Template, TemplateSyntaxError, VariableDoesNotExist
|
||||
from django.template.context import BaseContext
|
||||
from django.utils.html import strip_tags
|
||||
from django.utils.http import urlencode, quote
|
||||
|
||||
from .middleware import get_request
|
||||
|
||||
|
||||
class DecryptionError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
def aes_hex_encrypt(key, data):
|
||||
'''Generate an AES key from any key material using PBKDF2, and encrypt data using CFB mode. A
|
||||
new IV is generated each time, the IV is also used as salt for PBKDF2.
|
||||
'''
|
||||
iv = Random.get_random_bytes(2) * 8
|
||||
aes_key = PBKDF2(key, iv)
|
||||
aes = AES.new(aes_key, AES.MODE_CFB, iv)
|
||||
crypted = aes.encrypt(data)
|
||||
return '%s%s' % (binascii.hexlify(iv[:2]), binascii.hexlify(crypted))
|
||||
|
||||
|
||||
def aes_hex_decrypt(key, payload, raise_on_error=True):
|
||||
'''Decrypt data encrypted with aes_base64_encrypt'''
|
||||
try:
|
||||
iv, crypted = payload[:4], payload[4:]
|
||||
except (ValueError, TypeError):
|
||||
if raise_on_error:
|
||||
raise DecryptionError('bad payload')
|
||||
return None
|
||||
try:
|
||||
iv = binascii.unhexlify(iv) * 8
|
||||
crypted = binascii.unhexlify(crypted)
|
||||
except TypeError:
|
||||
if raise_on_error:
|
||||
raise DecryptionError('incorrect hexadecimal encoding')
|
||||
return None
|
||||
aes_key = PBKDF2(key, iv)
|
||||
aes = AES.new(aes_key, AES.MODE_CFB, iv)
|
||||
return aes.decrypt(crypted)
|
||||
|
||||
|
||||
class NothingInCacheException(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class Requests(RequestsSession):
|
||||
|
||||
def request(self, method, url, **kwargs):
|
||||
remote_service = kwargs.pop('remote_service', None)
|
||||
cache_duration = kwargs.pop('cache_duration', 15)
|
||||
invalidate_cache = kwargs.pop('invalidate_cache', False)
|
||||
user = kwargs.pop('user', None)
|
||||
without_user = kwargs.pop('without_user', False)
|
||||
federation_key = kwargs.pop('federation_key', 'auto') # 'auto', 'email', 'nameid'
|
||||
raise_if_not_cached = kwargs.pop('raise_if_not_cached', False)
|
||||
log_errors = kwargs.pop('log_errors', True)
|
||||
|
||||
if remote_service == 'auto':
|
||||
remote_service = None
|
||||
scheme, netloc, path, params, query, fragment = urlparse.urlparse(url)
|
||||
for services in settings.KNOWN_SERVICES.values():
|
||||
for service in services.values():
|
||||
remote_url = service.get('url')
|
||||
remote_scheme, remote_netloc, r_path, r_params, r_query, r_fragment = \
|
||||
urlparse.urlparse(remote_url)
|
||||
if remote_scheme == scheme and remote_netloc == netloc:
|
||||
remote_service = service
|
||||
break
|
||||
else:
|
||||
continue
|
||||
break
|
||||
if remote_service:
|
||||
# only keeps the path (URI) in url parameter, scheme and netloc are
|
||||
# in remote_service
|
||||
url = urlparse.urlunparse(('', '', path, params, query, fragment))
|
||||
else:
|
||||
logging.warning('service not found in settings.KNOWN_SERVICES for %s', url)
|
||||
|
||||
if remote_service:
|
||||
if isinstance(user, dict):
|
||||
query_params = user.copy()
|
||||
elif not user or not user.is_authenticated():
|
||||
if without_user:
|
||||
query_params = {}
|
||||
else:
|
||||
query_params = {'NameID': '', 'email': ''}
|
||||
else:
|
||||
query_params = {}
|
||||
if federation_key == 'nameid':
|
||||
query_params['NameID'] = user.saml_identifiers.first().name_id
|
||||
elif federation_key == 'email':
|
||||
query_params['email'] = user.email
|
||||
else: # 'auto'
|
||||
if hasattr(user, 'saml_identifiers') and user.saml_identifiers.exists():
|
||||
query_params['NameID'] = user.saml_identifiers.first().name_id
|
||||
else:
|
||||
query_params['email'] = user.email
|
||||
|
||||
query_params['orig'] = remote_service.get('orig')
|
||||
|
||||
remote_service_base_url = remote_service.get('url')
|
||||
scheme, netloc, old_path, params, old_query, fragment = urlparse.urlparse(
|
||||
remote_service_base_url)
|
||||
|
||||
query = urlencode(query_params)
|
||||
if '?' in url:
|
||||
path, old_query = url.split('?')
|
||||
query += '&' + old_query
|
||||
else:
|
||||
path = url
|
||||
|
||||
url = urlparse.urlunparse((scheme, netloc, path, params, query, fragment))
|
||||
|
||||
if method == 'GET' and cache_duration:
|
||||
# handle cache
|
||||
cache_key = hashlib.md5(url).hexdigest()
|
||||
cache_content = cache.get(cache_key)
|
||||
if cache_content and not invalidate_cache:
|
||||
response = Response()
|
||||
response.status_code = 200
|
||||
response.raw = StringIO(cache_content)
|
||||
return response
|
||||
elif raise_if_not_cached:
|
||||
raise NothingInCacheException()
|
||||
|
||||
if remote_service: # sign
|
||||
url = sign_url(url, remote_service.get('secret'))
|
||||
|
||||
kwargs['timeout'] = kwargs.get('timeout') or settings.REQUESTS_TIMEOUT
|
||||
|
||||
response = super(Requests, self).request(method, url, **kwargs)
|
||||
if log_errors and (response.status_code // 100 != 2):
|
||||
logging.error('failed to %s %s (%s)', method, url, response.status_code)
|
||||
if method == 'GET' and cache_duration and (response.status_code // 100 == 2):
|
||||
cache.set(cache_key, response.content, cache_duration)
|
||||
|
||||
return response
|
||||
|
||||
requests = Requests()
|
||||
|
||||
|
||||
class TemplateError(Exception):
|
||||
def __init__(self, msg, params=()):
|
||||
self.msg = msg
|
||||
self.params = params
|
||||
|
||||
def __str__(self):
|
||||
return self.msg % self.params
|
||||
|
||||
|
||||
def get_templated_url(url, context=None):
|
||||
if '{{' not in url and '{%' not in url and '[' not in url:
|
||||
return url
|
||||
template_vars = Context()
|
||||
if context:
|
||||
template_vars.update(context)
|
||||
template_vars['user_email'] = ''
|
||||
template_vars['user_nameid'] = ''
|
||||
user = getattr(context.get('request'), 'user', None)
|
||||
if user and user.is_authenticated():
|
||||
template_vars['user_email'] = quote(user.email)
|
||||
if hasattr(user, 'saml_identifiers') and user.saml_identifiers.exists():
|
||||
template_vars['user_nameid'] = quote(user.saml_identifiers.first().name_id)
|
||||
template_vars.update(settings.TEMPLATE_VARS)
|
||||
if '{{' in url or '{%' in url: # Django template
|
||||
try:
|
||||
return Template(url).render(template_vars)
|
||||
except VariableDoesNotExist as e:
|
||||
raise TemplateError(e.msg, e.params)
|
||||
except TemplateSyntaxError:
|
||||
raise TemplateError('syntax error')
|
||||
# ezt-like template
|
||||
def repl(matchobj):
|
||||
varname = matchobj.group(0)[1:-1]
|
||||
if varname == '[':
|
||||
return '['
|
||||
if varname not in template_vars:
|
||||
raise TemplateError('unknown variable %s', varname)
|
||||
return unicode(template_vars[varname])
|
||||
return re.sub(r'(\[.+?\])', repl, url)
|
||||
|
||||
|
||||
# Simple signature scheme for query strings
|
||||
|
||||
def sign_url(url, key, algo='sha256', timestamp=None, nonce=None):
|
||||
parsed = urlparse.urlparse(url)
|
||||
new_query = sign_query(parsed.query, key, algo, timestamp, nonce)
|
||||
return urlparse.urlunparse(parsed[:4] + (new_query,) + parsed[5:])
|
||||
|
||||
def sign_query(query, key, algo='sha256', timestamp=None, nonce=None):
|
||||
if timestamp is None:
|
||||
timestamp = datetime.datetime.utcnow()
|
||||
timestamp = timestamp.strftime('%Y-%m-%dT%H:%M:%SZ')
|
||||
if nonce is None:
|
||||
nonce = hex(random.getrandbits(128))[2:]
|
||||
new_query = query
|
||||
if new_query:
|
||||
new_query += '&'
|
||||
new_query += urlencode((
|
||||
('algo', algo),
|
||||
('timestamp', timestamp),
|
||||
('nonce', nonce)))
|
||||
signature = base64.b64encode(sign_string(new_query, key, algo=algo))
|
||||
new_query += '&signature=' + quote(signature)
|
||||
return new_query
|
||||
|
||||
def sign_string(s, key, algo='sha256', timedelta=30):
|
||||
digestmod = getattr(hashlib, algo)
|
||||
hash = hmac.HMAC(str(key), digestmod=digestmod, msg=s)
|
||||
return hash.digest()
|
||||
|
||||
def ellipsize(text, length=50):
|
||||
text = HTMLParser().unescape(strip_tags(text))
|
||||
if len(text) < length:
|
||||
return text
|
||||
return text[:(length-10)] + '...'
|
||||
|
||||
|
||||
def check_request_signature(django_request, keys=[]):
|
||||
query_string = django_request.META['QUERY_STRING']
|
||||
if not query_string:
|
||||
return False
|
||||
orig = django_request.GET.get('orig', '')
|
||||
known_services = getattr(settings, 'KNOWN_SERVICES', None)
|
||||
if known_services and orig:
|
||||
for services in known_services.itervalues():
|
||||
for service in services.itervalues():
|
||||
if 'verif_orig' in service and service['verif_orig'] == orig:
|
||||
keys.append(service['secret'])
|
||||
break
|
||||
return check_query(query_string, keys)
|
||||
|
||||
|
||||
def check_query(query, keys, known_nonce=None, timedelta=30):
|
||||
parsed = urlparse.parse_qs(query)
|
||||
if not ('signature' in parsed and 'algo' in parsed and
|
||||
'timestamp' in parsed and 'nonce' in parsed):
|
||||
return False
|
||||
signature = base64.b64decode(parsed['signature'][0])
|
||||
algo = parsed['algo'][0]
|
||||
timestamp = parsed['timestamp'][0]
|
||||
timestamp = datetime.datetime.strptime(timestamp, '%Y-%m-%dT%H:%M:%SZ')
|
||||
nonce = parsed['nonce']
|
||||
unsigned_query = query.split('&signature=')[0]
|
||||
if known_nonce is not None and known_nonce(nonce):
|
||||
return False
|
||||
if abs(datetime.datetime.utcnow() - timestamp) > datetime.timedelta(seconds=timedelta):
|
||||
return False
|
||||
return check_string(unsigned_query, signature, keys, algo=algo)
|
||||
|
||||
|
||||
def check_string(s, signature, keys, algo='sha256'):
|
||||
if not isinstance(keys, list):
|
||||
keys = [keys]
|
||||
for key in keys:
|
||||
signature2 = sign_string(s, key, algo=algo)
|
||||
if len(signature2) != len(signature):
|
||||
continue
|
||||
res = 0
|
||||
# constant time compare
|
||||
for a, b in zip(signature, signature2):
|
||||
res |= ord(a) ^ ord(b)
|
||||
if res == 0:
|
||||
return True
|
||||
return False
|
||||
|
||||
# _make_key and _HashedSeq imported/adapted from functools from Python 3.2+
|
||||
def _make_key(args, kwds,
|
||||
kwd_mark=(object(),),
|
||||
fasttypes={int, str, frozenset, type(None)},
|
||||
tuple=tuple, type=type, len=len):
|
||||
key = args
|
||||
if kwds:
|
||||
key += kwd_mark
|
||||
for item in kwds.items():
|
||||
key += item
|
||||
if len(key) == 1 and type(key[0]) in fasttypes:
|
||||
return key[0]
|
||||
return _HashedSeq(key)
|
||||
|
||||
|
||||
class _HashedSeq(list):
|
||||
__slots__ = 'hashvalue'
|
||||
|
||||
def __init__(self, tup, hash=hash):
|
||||
self[:] = tup
|
||||
self.hashvalue = hash(tup)
|
||||
|
||||
def __hash__(self):
|
||||
return self.hashvalue
|
||||
|
||||
|
||||
def cache_during_request(func):
|
||||
def inner(*args, **kwargs):
|
||||
request = get_request()
|
||||
if request:
|
||||
cache_key = (id(func), _make_key(args, kwargs))
|
||||
if cache_key in request.cache:
|
||||
return request.cache[cache_key]
|
||||
result = func(*args, **kwargs)
|
||||
if request:
|
||||
request.cache[cache_key] = result
|
||||
return result
|
||||
return inner
|
||||
|
||||
|
||||
def flatten_context(context):
|
||||
# flatten a context to a dictionary, with full support for embedded Context
|
||||
# objects.
|
||||
flat_context = {}
|
||||
if isinstance(context, BaseContext):
|
||||
for ctx in context.dicts:
|
||||
flat_context.update(flatten_context(ctx))
|
||||
else:
|
||||
flat_context.update(context)
|
||||
return flat_context
|
|
@ -0,0 +1,23 @@
|
|||
# combo - content management system
|
||||
# Copyright (C) 2015 Entr'ouvert
|
||||
#
|
||||
# This program is free software: you can redistribute it and/or modify it
|
||||
# under the terms of the GNU Affero General Public License as published
|
||||
# by the Free Software Foundation, either version 3 of the License, or
|
||||
# (at your option) any later version.
|
||||
#
|
||||
# This program is distributed in the hope that it will be useful,
|
||||
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
# GNU Affero General Public License for more details.
|
||||
#
|
||||
# You should have received a copy of the GNU Affero General Public License
|
||||
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
|
||||
# import specific symbols for compatibility
|
||||
from .cache import cache_during_request
|
||||
from .crypto import aes_hex_decrypt, aes_hex_encrypt, DecryptionError
|
||||
from .misc import ellipsize, flatten_context
|
||||
from .requests_wrapper import requests, NothingInCacheException
|
||||
from .signature import check_query, check_request_signature, sign_url
|
||||
from .urls import get_templated_url, TemplateError
|
|
@ -0,0 +1,56 @@
|
|||
# combo - content management system
|
||||
# Copyright (C) 2015-2018 Entr'ouvert
|
||||
#
|
||||
# This program is free software: you can redistribute it and/or modify it
|
||||
# under the terms of the GNU Affero General Public License as published
|
||||
# by the Free Software Foundation, either version 3 of the License, or
|
||||
# (at your option) any later version.
|
||||
#
|
||||
# This program is distributed in the hope that it will be useful,
|
||||
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
# GNU Affero General Public License for more details.
|
||||
#
|
||||
# You should have received a copy of the GNU Affero General Public License
|
||||
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
|
||||
from combo.middleware import get_request
|
||||
|
||||
# _make_key and _HashedSeq imported/adapted from functools from Python 3.2+
|
||||
def _make_key(args, kwds,
|
||||
kwd_mark=(object(),),
|
||||
fasttypes={int, str, frozenset, type(None)},
|
||||
tuple=tuple, type=type, len=len):
|
||||
key = args
|
||||
if kwds:
|
||||
key += kwd_mark
|
||||
for item in kwds.items():
|
||||
key += item
|
||||
if len(key) == 1 and type(key[0]) in fasttypes:
|
||||
return key[0]
|
||||
return _HashedSeq(key)
|
||||
|
||||
|
||||
class _HashedSeq(list):
|
||||
__slots__ = 'hashvalue'
|
||||
|
||||
def __init__(self, tup, hash=hash):
|
||||
self[:] = tup
|
||||
self.hashvalue = hash(tup)
|
||||
|
||||
def __hash__(self):
|
||||
return self.hashvalue
|
||||
|
||||
|
||||
def cache_during_request(func):
|
||||
def inner(*args, **kwargs):
|
||||
request = get_request()
|
||||
if request:
|
||||
cache_key = (id(func), _make_key(args, kwargs))
|
||||
if cache_key in request.cache:
|
||||
return request.cache[cache_key]
|
||||
result = func(*args, **kwargs)
|
||||
if request:
|
||||
request.cache[cache_key] = result
|
||||
return result
|
||||
return inner
|
|
@ -0,0 +1,55 @@
|
|||
# combo - content management system
|
||||
# Copyright (C) 2015-2018 Entr'ouvert
|
||||
#
|
||||
# This program is free software: you can redistribute it and/or modify it
|
||||
# under the terms of the GNU Affero General Public License as published
|
||||
# by the Free Software Foundation, either version 3 of the License, or
|
||||
# (at your option) any later version.
|
||||
#
|
||||
# This program is distributed in the hope that it will be useful,
|
||||
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
# GNU Affero General Public License for more details.
|
||||
#
|
||||
# You should have received a copy of the GNU Affero General Public License
|
||||
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
|
||||
import binascii
|
||||
|
||||
from Crypto.Cipher import AES
|
||||
from Crypto.Protocol.KDF import PBKDF2
|
||||
from Crypto import Random
|
||||
|
||||
|
||||
class DecryptionError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
def aes_hex_encrypt(key, data):
|
||||
'''Generate an AES key from any key material using PBKDF2, and encrypt data using CFB mode. A
|
||||
new IV is generated each time, the IV is also used as salt for PBKDF2.
|
||||
'''
|
||||
iv = Random.get_random_bytes(2) * 8
|
||||
aes_key = PBKDF2(key, iv)
|
||||
aes = AES.new(aes_key, AES.MODE_CFB, iv)
|
||||
crypted = aes.encrypt(data)
|
||||
return '%s%s' % (binascii.hexlify(iv[:2]), binascii.hexlify(crypted))
|
||||
|
||||
def aes_hex_decrypt(key, payload, raise_on_error=True):
|
||||
'''Decrypt data encrypted with aes_base64_encrypt'''
|
||||
try:
|
||||
iv, crypted = payload[:4], payload[4:]
|
||||
except (ValueError, TypeError):
|
||||
if raise_on_error:
|
||||
raise DecryptionError('bad payload')
|
||||
return None
|
||||
try:
|
||||
iv = binascii.unhexlify(iv) * 8
|
||||
crypted = binascii.unhexlify(crypted)
|
||||
except TypeError:
|
||||
if raise_on_error:
|
||||
raise DecryptionError('incorrect hexadecimal encoding')
|
||||
return None
|
||||
aes_key = PBKDF2(key, iv)
|
||||
aes = AES.new(aes_key, AES.MODE_CFB, iv)
|
||||
return aes.decrypt(crypted)
|
|
@ -0,0 +1,39 @@
|
|||
# combo - content management system
|
||||
# Copyright (C) 2015-2018 Entr'ouvert
|
||||
#
|
||||
# This program is free software: you can redistribute it and/or modify it
|
||||
# under the terms of the GNU Affero General Public License as published
|
||||
# by the Free Software Foundation, either version 3 of the License, or
|
||||
# (at your option) any later version.
|
||||
#
|
||||
# This program is distributed in the hope that it will be useful,
|
||||
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
# GNU Affero General Public License for more details.
|
||||
#
|
||||
# You should have received a copy of the GNU Affero General Public License
|
||||
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
|
||||
from HTMLParser import HTMLParser
|
||||
|
||||
from django.template.context import BaseContext
|
||||
from django.utils.html import strip_tags
|
||||
|
||||
|
||||
def ellipsize(text, length=50):
|
||||
text = HTMLParser().unescape(strip_tags(text))
|
||||
if len(text) < length:
|
||||
return text
|
||||
return text[:(length-10)] + '...'
|
||||
|
||||
|
||||
def flatten_context(context):
|
||||
# flatten a context to a dictionary, with full support for embedded Context
|
||||
# objects.
|
||||
flat_context = {}
|
||||
if isinstance(context, BaseContext):
|
||||
for ctx in context.dicts:
|
||||
flat_context.update(flatten_context(ctx))
|
||||
else:
|
||||
flat_context.update(context)
|
||||
return flat_context
|
|
@ -0,0 +1,127 @@
|
|||
# combo - content management system
|
||||
# Copyright (C) 2015-2018 Entr'ouvert
|
||||
#
|
||||
# This program is free software: you can redistribute it and/or modify it
|
||||
# under the terms of the GNU Affero General Public License as published
|
||||
# by the Free Software Foundation, either version 3 of the License, or
|
||||
# (at your option) any later version.
|
||||
#
|
||||
# This program is distributed in the hope that it will be useful,
|
||||
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
# GNU Affero General Public License for more details.
|
||||
#
|
||||
# You should have received a copy of the GNU Affero General Public License
|
||||
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
|
||||
import hashlib
|
||||
import logging
|
||||
from StringIO import StringIO
|
||||
import urlparse
|
||||
|
||||
from requests import Response, Session as RequestsSession
|
||||
|
||||
from django.conf import settings
|
||||
from django.core.cache import cache
|
||||
from django.utils.http import urlencode
|
||||
|
||||
from .signature import sign_url
|
||||
|
||||
class NothingInCacheException(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class Requests(RequestsSession):
|
||||
|
||||
def request(self, method, url, **kwargs):
|
||||
remote_service = kwargs.pop('remote_service', None)
|
||||
cache_duration = kwargs.pop('cache_duration', 15)
|
||||
invalidate_cache = kwargs.pop('invalidate_cache', False)
|
||||
user = kwargs.pop('user', None)
|
||||
without_user = kwargs.pop('without_user', False)
|
||||
federation_key = kwargs.pop('federation_key', 'auto') # 'auto', 'email', 'nameid'
|
||||
raise_if_not_cached = kwargs.pop('raise_if_not_cached', False)
|
||||
log_errors = kwargs.pop('log_errors', True)
|
||||
|
||||
if remote_service == 'auto':
|
||||
remote_service = None
|
||||
scheme, netloc, path, params, query, fragment = urlparse.urlparse(url)
|
||||
for services in settings.KNOWN_SERVICES.values():
|
||||
for service in services.values():
|
||||
remote_url = service.get('url')
|
||||
remote_scheme, remote_netloc, r_path, r_params, r_query, r_fragment = \
|
||||
urlparse.urlparse(remote_url)
|
||||
if remote_scheme == scheme and remote_netloc == netloc:
|
||||
remote_service = service
|
||||
break
|
||||
else:
|
||||
continue
|
||||
break
|
||||
if remote_service:
|
||||
# only keeps the path (URI) in url parameter, scheme and netloc are
|
||||
# in remote_service
|
||||
url = urlparse.urlunparse(('', '', path, params, query, fragment))
|
||||
else:
|
||||
logging.warning('service not found in settings.KNOWN_SERVICES for %s', url)
|
||||
|
||||
if remote_service:
|
||||
if isinstance(user, dict):
|
||||
query_params = user.copy()
|
||||
elif not user or not user.is_authenticated():
|
||||
if without_user:
|
||||
query_params = {}
|
||||
else:
|
||||
query_params = {'NameID': '', 'email': ''}
|
||||
else:
|
||||
query_params = {}
|
||||
if federation_key == 'nameid':
|
||||
query_params['NameID'] = user.saml_identifiers.first().name_id
|
||||
elif federation_key == 'email':
|
||||
query_params['email'] = user.email
|
||||
else: # 'auto'
|
||||
if hasattr(user, 'saml_identifiers') and user.saml_identifiers.exists():
|
||||
query_params['NameID'] = user.saml_identifiers.first().name_id
|
||||
else:
|
||||
query_params['email'] = user.email
|
||||
|
||||
query_params['orig'] = remote_service.get('orig')
|
||||
|
||||
remote_service_base_url = remote_service.get('url')
|
||||
scheme, netloc, old_path, params, old_query, fragment = urlparse.urlparse(
|
||||
remote_service_base_url)
|
||||
|
||||
query = urlencode(query_params)
|
||||
if '?' in url:
|
||||
path, old_query = url.split('?')
|
||||
query += '&' + old_query
|
||||
else:
|
||||
path = url
|
||||
|
||||
url = urlparse.urlunparse((scheme, netloc, path, params, query, fragment))
|
||||
|
||||
if method == 'GET' and cache_duration:
|
||||
# handle cache
|
||||
cache_key = hashlib.md5(url).hexdigest()
|
||||
cache_content = cache.get(cache_key)
|
||||
if cache_content and not invalidate_cache:
|
||||
response = Response()
|
||||
response.status_code = 200
|
||||
response.raw = StringIO(cache_content)
|
||||
return response
|
||||
elif raise_if_not_cached:
|
||||
raise NothingInCacheException()
|
||||
|
||||
if remote_service: # sign
|
||||
url = sign_url(url, remote_service.get('secret'))
|
||||
|
||||
kwargs['timeout'] = kwargs.get('timeout') or settings.REQUESTS_TIMEOUT
|
||||
|
||||
response = super(Requests, self).request(method, url, **kwargs)
|
||||
if log_errors and (response.status_code // 100 != 2):
|
||||
logging.error('failed to %s %s (%s)', method, url, response.status_code)
|
||||
if method == 'GET' and cache_duration and (response.status_code // 100 == 2):
|
||||
cache.set(cache_key, response.content, cache_duration)
|
||||
|
||||
return response
|
||||
|
||||
requests = Requests()
|
|
@ -0,0 +1,102 @@
|
|||
# combo - content management system
|
||||
# Copyright (C) 2015-2018 Entr'ouvert
|
||||
#
|
||||
# This program is free software: you can redistribute it and/or modify it
|
||||
# under the terms of the GNU Affero General Public License as published
|
||||
# by the Free Software Foundation, either version 3 of the License, or
|
||||
# (at your option) any later version.
|
||||
#
|
||||
# This program is distributed in the hope that it will be useful,
|
||||
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
# GNU Affero General Public License for more details.
|
||||
#
|
||||
# You should have received a copy of the GNU Affero General Public License
|
||||
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
|
||||
import base64
|
||||
import datetime
|
||||
import hmac
|
||||
import hashlib
|
||||
import random
|
||||
import urlparse
|
||||
|
||||
from django.conf import settings
|
||||
from django.utils.http import quote, urlencode
|
||||
|
||||
# Simple signature scheme for query strings
|
||||
|
||||
def sign_url(url, key, algo='sha256', timestamp=None, nonce=None):
|
||||
parsed = urlparse.urlparse(url)
|
||||
new_query = sign_query(parsed.query, key, algo, timestamp, nonce)
|
||||
return urlparse.urlunparse(parsed[:4] + (new_query,) + parsed[5:])
|
||||
|
||||
def sign_query(query, key, algo='sha256', timestamp=None, nonce=None):
|
||||
if timestamp is None:
|
||||
timestamp = datetime.datetime.utcnow()
|
||||
timestamp = timestamp.strftime('%Y-%m-%dT%H:%M:%SZ')
|
||||
if nonce is None:
|
||||
nonce = hex(random.getrandbits(128))[2:]
|
||||
new_query = query
|
||||
if new_query:
|
||||
new_query += '&'
|
||||
new_query += urlencode((
|
||||
('algo', algo),
|
||||
('timestamp', timestamp),
|
||||
('nonce', nonce)))
|
||||
signature = base64.b64encode(sign_string(new_query, key, algo=algo))
|
||||
new_query += '&signature=' + quote(signature)
|
||||
return new_query
|
||||
|
||||
def sign_string(s, key, algo='sha256', timedelta=30):
|
||||
digestmod = getattr(hashlib, algo)
|
||||
hash = hmac.HMAC(str(key), digestmod=digestmod, msg=s)
|
||||
return hash.digest()
|
||||
|
||||
def check_request_signature(django_request, keys=[]):
|
||||
query_string = django_request.META['QUERY_STRING']
|
||||
if not query_string:
|
||||
return False
|
||||
orig = django_request.GET.get('orig', '')
|
||||
known_services = getattr(settings, 'KNOWN_SERVICES', None)
|
||||
if known_services and orig:
|
||||
for services in known_services.itervalues():
|
||||
for service in services.itervalues():
|
||||
if 'verif_orig' in service and service['verif_orig'] == orig:
|
||||
keys.append(service['secret'])
|
||||
break
|
||||
return check_query(query_string, keys)
|
||||
|
||||
|
||||
def check_query(query, keys, known_nonce=None, timedelta=30):
|
||||
parsed = urlparse.parse_qs(query)
|
||||
if not ('signature' in parsed and 'algo' in parsed and
|
||||
'timestamp' in parsed and 'nonce' in parsed):
|
||||
return False
|
||||
signature = base64.b64decode(parsed['signature'][0])
|
||||
algo = parsed['algo'][0]
|
||||
timestamp = parsed['timestamp'][0]
|
||||
timestamp = datetime.datetime.strptime(timestamp, '%Y-%m-%dT%H:%M:%SZ')
|
||||
nonce = parsed['nonce']
|
||||
unsigned_query = query.split('&signature=')[0]
|
||||
if known_nonce is not None and known_nonce(nonce):
|
||||
return False
|
||||
if abs(datetime.datetime.utcnow() - timestamp) > datetime.timedelta(seconds=timedelta):
|
||||
return False
|
||||
return check_string(unsigned_query, signature, keys, algo=algo)
|
||||
|
||||
|
||||
def check_string(s, signature, keys, algo='sha256'):
|
||||
if not isinstance(keys, list):
|
||||
keys = [keys]
|
||||
for key in keys:
|
||||
signature2 = sign_string(s, key, algo=algo)
|
||||
if len(signature2) != len(signature):
|
||||
continue
|
||||
res = 0
|
||||
# constant time compare
|
||||
for a, b in zip(signature, signature2):
|
||||
res |= ord(a) ^ ord(b)
|
||||
if res == 0:
|
||||
return True
|
||||
return False
|
|
@ -0,0 +1,62 @@
|
|||
# combo - content management system
|
||||
# Copyright (C) 2015-2018 Entr'ouvert
|
||||
#
|
||||
# This program is free software: you can redistribute it and/or modify it
|
||||
# under the terms of the GNU Affero General Public License as published
|
||||
# by the Free Software Foundation, either version 3 of the License, or
|
||||
# (at your option) any later version.
|
||||
#
|
||||
# This program is distributed in the hope that it will be useful,
|
||||
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
# GNU Affero General Public License for more details.
|
||||
#
|
||||
# You should have received a copy of the GNU Affero General Public License
|
||||
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
|
||||
import re
|
||||
|
||||
from django.conf import settings
|
||||
from django.template import Context, Template, TemplateSyntaxError, VariableDoesNotExist
|
||||
from django.utils.http import quote
|
||||
|
||||
|
||||
class TemplateError(Exception):
|
||||
def __init__(self, msg, params=()):
|
||||
self.msg = msg
|
||||
self.params = params
|
||||
|
||||
def __str__(self):
|
||||
return self.msg % self.params
|
||||
|
||||
|
||||
def get_templated_url(url, context=None):
|
||||
if '{{' not in url and '{%' not in url and '[' not in url:
|
||||
return url
|
||||
template_vars = Context()
|
||||
if context:
|
||||
template_vars.update(context)
|
||||
template_vars['user_email'] = ''
|
||||
template_vars['user_nameid'] = ''
|
||||
user = getattr(context.get('request'), 'user', None)
|
||||
if user and user.is_authenticated():
|
||||
template_vars['user_email'] = quote(user.email)
|
||||
if hasattr(user, 'saml_identifiers') and user.saml_identifiers.exists():
|
||||
template_vars['user_nameid'] = quote(user.saml_identifiers.first().name_id)
|
||||
template_vars.update(settings.TEMPLATE_VARS)
|
||||
if '{{' in url or '{%' in url: # Django template
|
||||
try:
|
||||
return Template(url).render(template_vars)
|
||||
except VariableDoesNotExist as e:
|
||||
raise TemplateError(e.msg, e.params)
|
||||
except TemplateSyntaxError:
|
||||
raise TemplateError('syntax error')
|
||||
# ezt-like template
|
||||
def repl(matchobj):
|
||||
varname = matchobj.group(0)[1:-1]
|
||||
if varname == '[':
|
||||
return '['
|
||||
if varname not in template_vars:
|
||||
raise TemplateError('unknown variable %s', varname)
|
||||
return unicode(template_vars[varname])
|
||||
return re.sub(r'(\[.+?\])', repl, url)
|
|
@ -415,7 +415,7 @@ def test_json_force_async():
|
|||
cell.url = 'http://example.net/test-force-async'
|
||||
cell.template_string = '{{json.hello}}'
|
||||
cell.force_async = True
|
||||
with mock.patch('combo.utils.RequestsSession.request') as requests_get:
|
||||
with mock.patch('combo.utils.requests_wrapper.RequestsSession.request') as requests_get:
|
||||
requests_get.return_value = mock.Mock(content=json.dumps({'hello': 'world'}), status_code=200)
|
||||
with pytest.raises(NothingInCacheException):
|
||||
cell.render({})
|
||||
|
@ -431,7 +431,7 @@ def test_json_force_async():
|
|||
cell.url = 'http://example.net/test-force-async-2'
|
||||
cell.template_string = '{{json.hello}}'
|
||||
cell.force_async = False
|
||||
with mock.patch('combo.utils.RequestsSession.request') as requests_get:
|
||||
with mock.patch('combo.utils.requests_wrapper.RequestsSession.request') as requests_get:
|
||||
requests_get.return_value = mock.Mock(content=json.dumps({'hello': 'world2'}), status_code=200)
|
||||
# raise if nothing in cache
|
||||
with pytest.raises(NothingInCacheException):
|
||||
|
|
|
@ -137,7 +137,7 @@ def test_successfull_items_payment(regie, user):
|
|||
assert urlparse.urlparse(qs['return_url'][0]).path.startswith(
|
||||
reverse('lingo-return', kwargs={'regie_pk': regie.id}))
|
||||
# simulate successful call to callback URL
|
||||
with mock.patch('combo.utils.RequestsSession.request') as request:
|
||||
with mock.patch('combo.utils.requests_wrapper.RequestsSession.request') as request:
|
||||
resp = client.get(reverse('lingo-callback', kwargs={'regie_pk': regie.id}), args)
|
||||
assert resp.status_code == 200
|
||||
# simulate successful return URL
|
||||
|
@ -273,7 +273,7 @@ def test_cancel_basket_item(key, regie, user):
|
|||
assert BasketItem.objects.filter(amount=21, cancellation_date__isnull=True).exists()
|
||||
basket_item_id_2 = json.loads(resp.content)['id']
|
||||
|
||||
with mock.patch('combo.utils.RequestsSession.request') as request:
|
||||
with mock.patch('combo.utils.requests_wrapper.RequestsSession.request') as request:
|
||||
url = '%s?email=%s&orig=wcs' % (reverse('api-remove-basket-item'), user_email)
|
||||
url = sign_url(url, key)
|
||||
data = {'basket_item_id': basket_item_id, 'notify': 'true'}
|
||||
|
@ -282,7 +282,7 @@ def test_cancel_basket_item(key, regie, user):
|
|||
assert not BasketItem.objects.filter(amount=42, cancellation_date__isnull=True).exists()
|
||||
assert BasketItem.objects.filter(amount=21, cancellation_date__isnull=True).exists()
|
||||
|
||||
with mock.patch('combo.utils.RequestsSession.request') as request:
|
||||
with mock.patch('combo.utils.requests_wrapper.RequestsSession.request') as request:
|
||||
url = '%s?email=%s&orig=wcs' % (reverse('api-remove-basket-item'), user_email)
|
||||
url = sign_url(url, key)
|
||||
data = {'basket_item_id': basket_item_id_2}
|
||||
|
@ -315,7 +315,7 @@ def test_cancel_basket_item_from_cell(key, regie, user):
|
|||
|
||||
# check a successful case
|
||||
login()
|
||||
with mock.patch('combo.utils.RequestsSession.request') as request:
|
||||
with mock.patch('combo.utils.requests_wrapper.RequestsSession.request') as request:
|
||||
client.post(reverse('lingo-cancel-item', kwargs={'pk': basket_item_id}))
|
||||
url = request.call_args[0][1]
|
||||
assert url.startswith('http://example.org/testitem/jump/trigger/cancelled')
|
||||
|
@ -367,7 +367,7 @@ def test_payment_callback(regie, user):
|
|||
|
||||
# call callback with GET
|
||||
callback_url = reverse('lingo-callback', kwargs={'regie_pk': regie.id})
|
||||
with mock.patch('combo.utils.RequestsSession.request') as request:
|
||||
with mock.patch('combo.utils.requests_wrapper.RequestsSession.request') as request:
|
||||
get_resp = client.get(callback_url, data)
|
||||
url = request.call_args[0][1]
|
||||
assert url.startswith('http://example.org/testitem/jump/trigger/paid')
|
||||
|
@ -388,7 +388,7 @@ def test_payment_callback(regie, user):
|
|||
assert data['amount'] == '11.50'
|
||||
|
||||
# call callback with POST
|
||||
with mock.patch('combo.utils.RequestsSession.request') as request:
|
||||
with mock.patch('combo.utils.requests_wrapper.RequestsSession.request') as request:
|
||||
post_resp = client.post(callback_url, urllib.urlencode(data),
|
||||
content_type='text/html')
|
||||
assert post_resp.status_code == 200
|
||||
|
@ -415,7 +415,7 @@ def test_payment_callback_no_regie(regie, user):
|
|||
|
||||
# call callback with GET
|
||||
callback_url = reverse('lingo-callback', kwargs={'regie_pk': regie.id})
|
||||
with mock.patch('combo.utils.RequestsSession.request') as request:
|
||||
with mock.patch('combo.utils.requests_wrapper.RequestsSession.request') as request:
|
||||
get_resp = client.get(callback_url, data)
|
||||
url = request.call_args[0][1]
|
||||
assert url.startswith('http://example.org/testitem/jump/trigger/paid')
|
||||
|
@ -527,7 +527,7 @@ def test_extra_fees(key, regie, user):
|
|||
User.objects.get_or_create(email=user_email)
|
||||
amount = 42
|
||||
data = {'amount': amount, 'display_name': 'test amount'}
|
||||
with mock.patch('combo.utils.RequestsSession.request') as request:
|
||||
with mock.patch('combo.utils.requests_wrapper.RequestsSession.request') as request:
|
||||
mock_json = mock.Mock()
|
||||
mock_json.status_code = 200
|
||||
mock_json.json.return_value = {'err': 0, 'data': [{'subject': 'Extra Fees', 'amount': '5'}]}
|
||||
|
@ -541,7 +541,7 @@ def test_extra_fees(key, regie, user):
|
|||
assert BasketItem.objects.filter(amount=5, extra_fee=True).exists()
|
||||
assert BasketItem.objects.filter(amount=5, extra_fee=True)[0].regie_id == regie.id
|
||||
|
||||
with mock.patch('combo.utils.RequestsSession.request') as request:
|
||||
with mock.patch('combo.utils.requests_wrapper.RequestsSession.request') as request:
|
||||
mock_json = mock.Mock()
|
||||
mock_json.status_code = 200
|
||||
mock_json.json.return_value = {'err': 0, 'data': [{'subject': 'Extra Fees', 'amount': '7'}]}
|
||||
|
@ -557,7 +557,7 @@ def test_extra_fees(key, regie, user):
|
|||
assert not BasketItem.objects.filter(amount=5, extra_fee=True).exists()
|
||||
assert BasketItem.objects.filter(amount=7, extra_fee=True).exists()
|
||||
|
||||
with mock.patch('combo.utils.RequestsSession.request') as request:
|
||||
with mock.patch('combo.utils.requests_wrapper.RequestsSession.request') as request:
|
||||
mock_json = mock.Mock()
|
||||
mock_json.status_code = 200
|
||||
mock_json.json.return_value = {'err': 0, 'data': [{'subject': 'Extra Fees', 'amount': '4'}]}
|
||||
|
@ -572,7 +572,7 @@ def test_extra_fees(key, regie, user):
|
|||
# test payment
|
||||
login()
|
||||
|
||||
with mock.patch('combo.utils.RequestsSession.request') as request:
|
||||
with mock.patch('combo.utils.requests_wrapper.RequestsSession.request') as request:
|
||||
mock_json = mock.Mock()
|
||||
mock_json.status_code = 200
|
||||
mock_json.json.return_value = {'err': 0, 'data': [{'subject': 'Extra Fees', 'amount': '2'}]}
|
||||
|
@ -588,7 +588,7 @@ def test_extra_fees(key, regie, user):
|
|||
assert data['amount'] == '44.00'
|
||||
|
||||
# test again, without specifying a regie
|
||||
with mock.patch('combo.utils.RequestsSession.request') as request:
|
||||
with mock.patch('combo.utils.requests_wrapper.RequestsSession.request') as request:
|
||||
mock_json = mock.Mock()
|
||||
mock_json.status_code = 200
|
||||
mock_json.json.return_value = {'err': 0, 'data': [{'subject': 'Extra Fees', 'amount': '3'}]}
|
||||
|
@ -626,7 +626,7 @@ def test_payment_callback_error(regie, user):
|
|||
|
||||
# call callback with GET
|
||||
callback_url = reverse('lingo-callback', kwargs={'regie_pk': regie.id})
|
||||
with mock.patch('combo.utils.RequestsSession.request') as request:
|
||||
with mock.patch('combo.utils.requests_wrapper.RequestsSession.request') as request:
|
||||
mock_response = mock.Mock()
|
||||
def kaboom():
|
||||
raise Exception('kaboom')
|
||||
|
@ -651,7 +651,7 @@ def test_payment_callback_error(regie, user):
|
|||
basket_item.payment_date = timezone.now() - timedelta(hours=1)
|
||||
basket_item.save()
|
||||
|
||||
with mock.patch('combo.utils.RequestsSession.request') as request:
|
||||
with mock.patch('combo.utils.requests_wrapper.RequestsSession.request') as request:
|
||||
mock_response = mock.Mock()
|
||||
mock_response.status_code = 200
|
||||
request.return_value = mock_response
|
||||
|
|
|
@ -68,7 +68,7 @@ class MockUser(object):
|
|||
return MockSAMLUser()
|
||||
self.saml_identifiers = MockSAMLUsers()
|
||||
|
||||
@mock.patch('combo.utils.RequestsSession.request')
|
||||
@mock.patch('combo.utils.requests_wrapper.RequestsSession.request')
|
||||
def test_remote_regie_active_invoices_cell(mock_request, remote_regie):
|
||||
assert remote_regie.is_remote() == True
|
||||
|
||||
|
@ -113,7 +113,7 @@ def test_remote_regie_active_invoices_cell(mock_request, remote_regie):
|
|||
content = cell.render(context)
|
||||
assert 'No items yet' in content
|
||||
|
||||
@mock.patch('combo.utils.RequestsSession.request')
|
||||
@mock.patch('combo.utils.requests_wrapper.RequestsSession.request')
|
||||
def test_remote_regie_past_invoices_cell(mock_request, remote_regie):
|
||||
assert remote_regie.is_remote() == True
|
||||
|
||||
|
|
|
@ -137,7 +137,7 @@ def test_geojson_on_restricted_cell(layer, user):
|
|||
user.groups.add(group)
|
||||
user.save()
|
||||
|
||||
with mock.patch('combo.utils.RequestsSession.request') as requests_get:
|
||||
with mock.patch('combo.utils.requests_wrapper.RequestsSession.request') as requests_get:
|
||||
requests_get.return_value = mock.Mock(
|
||||
content=SAMPLE_GEOJSON_CONTENT,
|
||||
json=lambda: json.loads(SAMPLE_GEOJSON_CONTENT),
|
||||
|
@ -155,7 +155,7 @@ def test_get_geojson(layer, user):
|
|||
cell.layers.add(layer)
|
||||
|
||||
# check cache duration
|
||||
with mock.patch('combo.utils.RequestsSession.request') as requests_get:
|
||||
with mock.patch('combo.utils.requests_wrapper.RequestsSession.request') as requests_get:
|
||||
requests_get.return_value = mock.Mock(
|
||||
content=SAMPLE_GEOJSON_CONTENT,
|
||||
json=lambda: json.loads(SAMPLE_GEOJSON_CONTENT),
|
||||
|
@ -173,7 +173,7 @@ def test_get_geojson(layer, user):
|
|||
# check user params
|
||||
layer.geojson_url = 'http://example.org/geojson?t2'
|
||||
layer.save()
|
||||
with mock.patch('combo.utils.RequestsSession.request') as requests_get:
|
||||
with mock.patch('combo.utils.requests_wrapper.RequestsSession.request') as requests_get:
|
||||
requests_get.return_value = mock.Mock(
|
||||
content=SAMPLE_GEOJSON_CONTENT,
|
||||
json=lambda: json.loads(SAMPLE_GEOJSON_CONTENT),
|
||||
|
@ -185,7 +185,7 @@ def test_get_geojson(layer, user):
|
|||
login()
|
||||
layer.geojson_url = 'http://example.org/geojson?t3'
|
||||
layer.save()
|
||||
with mock.patch('combo.utils.RequestsSession.request') as requests_get:
|
||||
with mock.patch('combo.utils.requests_wrapper.RequestsSession.request') as requests_get:
|
||||
requests_get.return_value = mock.Mock(
|
||||
content=SAMPLE_GEOJSON_CONTENT,
|
||||
json=lambda: json.loads(SAMPLE_GEOJSON_CONTENT),
|
||||
|
@ -197,7 +197,7 @@ def test_get_geojson(layer, user):
|
|||
layer.geojson_url = 'http://example.org/geojson?t4'
|
||||
layer.include_user_identifier = False
|
||||
layer.save()
|
||||
with mock.patch('combo.utils.RequestsSession.request') as requests_get:
|
||||
with mock.patch('combo.utils.requests_wrapper.RequestsSession.request') as requests_get:
|
||||
requests_get.return_value = mock.Mock(
|
||||
content=SAMPLE_GEOJSON_CONTENT,
|
||||
json=lambda: json.loads(SAMPLE_GEOJSON_CONTENT),
|
||||
|
@ -210,7 +210,7 @@ def test_get_geojson(layer, user):
|
|||
layer.geojson_url = 'http://example.org/geojson?t5'
|
||||
layer.include_user_identifier = False
|
||||
layer.save()
|
||||
with mock.patch('combo.utils.RequestsSession.request') as requests_get:
|
||||
with mock.patch('combo.utils.requests_wrapper.RequestsSession.request') as requests_get:
|
||||
requests_get.return_value = mock.Mock(
|
||||
content=SAMPLE_GEOJSON_CONTENT,
|
||||
json=lambda: json.loads(SAMPLE_GEOJSON_CONTENT),
|
||||
|
|
|
@ -147,7 +147,7 @@ def test_get_subscriptions(mock_get, cell, user):
|
|||
assert cell.get_subscriptions(user) == expected_subscriptions
|
||||
assert mock_get.call_args[1]['user'].email == USER_EMAIL
|
||||
|
||||
@mock.patch('combo.utils.RequestsSession.request')
|
||||
@mock.patch('combo.utils.requests_wrapper.RequestsSession.request')
|
||||
def test_get_subscriptions_signature_check(mock_get, cell, user):
|
||||
restrictions = ('mail', 'sms')
|
||||
cell.transports_restrictions = ','.join(restrictions)
|
||||
|
|
|
@ -25,13 +25,13 @@ class MockUser(object):
|
|||
|
||||
|
||||
def test_nosign():
|
||||
with mock.patch('combo.utils.RequestsSession.request') as request:
|
||||
with mock.patch('combo.utils.requests_wrapper.RequestsSession.request') as request:
|
||||
requests.get('http://example.org/foo/bar/')
|
||||
assert request.call_args[0][1] == 'http://example.org/foo/bar/'
|
||||
|
||||
def test_sign():
|
||||
remote_service = {'url': 'http://example.org', 'secret': 'secret', 'orig': 'myself'}
|
||||
with mock.patch('combo.utils.RequestsSession.request') as request:
|
||||
with mock.patch('combo.utils.requests_wrapper.RequestsSession.request') as request:
|
||||
requests.get('/foo/bar/', remote_service=remote_service)
|
||||
url = request.call_args[0][1]
|
||||
assert url.startswith('http://example.org/foo/bar/?')
|
||||
|
@ -54,7 +54,7 @@ def test_sign():
|
|||
|
||||
|
||||
def test_auto_sign():
|
||||
with mock.patch('combo.utils.RequestsSession.request') as request:
|
||||
with mock.patch('combo.utils.requests_wrapper.RequestsSession.request') as request:
|
||||
requests.get('http://example.org/foo/bar/', remote_service='auto')
|
||||
url = request.call_args[0][1]
|
||||
assert url.startswith('http://example.org/foo/bar/?')
|
||||
|
@ -69,7 +69,7 @@ def test_auto_sign():
|
|||
|
||||
def test_sign_user():
|
||||
remote_service = {'url': 'http://example.org', 'secret': 'secret', 'orig': 'myself'}
|
||||
with mock.patch('combo.utils.RequestsSession.request') as request:
|
||||
with mock.patch('combo.utils.requests_wrapper.RequestsSession.request') as request:
|
||||
|
||||
user = MockUser(samlized=True)
|
||||
|
||||
|
@ -110,7 +110,7 @@ def test_sign_user():
|
|||
|
||||
def test_sign_anonymous_user():
|
||||
remote_service = {'url': 'http://example.org', 'secret': 'secret', 'orig': 'myself'}
|
||||
with mock.patch('combo.utils.RequestsSession.request') as request:
|
||||
with mock.patch('combo.utils.requests_wrapper.RequestsSession.request') as request:
|
||||
|
||||
user = AnonymousUser()
|
||||
|
||||
|
@ -125,7 +125,7 @@ def test_sign_anonymous_user():
|
|||
assert check_query(querystring, 'secret') == True
|
||||
|
||||
def test_requests_cache():
|
||||
with mock.patch('combo.utils.RequestsSession.request') as requests_get:
|
||||
with mock.patch('combo.utils.requests_wrapper.RequestsSession.request') as requests_get:
|
||||
requests_get.return_value = mock.Mock(content='hello world', status_code=200)
|
||||
# default cache, nothing in there
|
||||
assert requests.get('http://cache.example.org/').content == 'hello world'
|
||||
|
|
Loading…
Reference in New Issue