combo/combo/utils.py

341 lines
12 KiB
Python

# combo - content management system
# Copyright (C) 2015 Entr'ouvert
#
# This program is free software: you can redistribute it and/or modify it
# under the terms of the GNU Affero General Public License as published
# by the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
import datetime
import base64
import hmac
import hashlib
import binascii
import re
from HTMLParser import HTMLParser
import logging
import random
from StringIO import StringIO
import urlparse
from Crypto.Cipher import AES
from Crypto.Protocol.KDF import PBKDF2
from Crypto import Random
from requests import Response, Session as RequestsSession
from django.conf import settings
from django.core.cache import cache
from django.template import Context, Template, TemplateSyntaxError, VariableDoesNotExist
from django.utils.html import strip_tags
from django.utils.http import urlencode, quote
from .middleware import get_request
class DecryptionError(Exception):
    '''Raised by aes_hex_decrypt when a payload cannot be decoded.'''
def aes_hex_encrypt(key, data):
    '''Encrypt *data* with AES in CFB mode and return it hex-encoded.

    The AES key is derived from the *key* material with PBKDF2, using the
    IV as salt.  The result is the hex-encoded 2-byte IV seed (4 hex
    characters) followed by the hex-encoded ciphertext; aes_hex_decrypt
    reverses the operation.

    NOTE(review): only 2 random bytes are drawn and repeated 8 times to
    build the 16-byte IV, so IV and salt carry just 16 bits of entropy.
    The layout is shared with aes_hex_decrypt (and possibly other
    services), so it cannot be strengthened here alone.
    '''
    seed = Random.get_random_bytes(2)
    iv = seed * 8
    cipher = AES.new(PBKDF2(key, iv), AES.MODE_CFB, iv)
    ciphertext = cipher.encrypt(data)
    return '%s%s' % (binascii.hexlify(seed), binascii.hexlify(ciphertext))
def aes_hex_decrypt(key, payload, raise_on_error=True):
    '''Decrypt data encrypted with aes_hex_encrypt.

    *payload* is the 4 hex characters of the IV seed followed by the
    hex-encoded ciphertext.  The AES key is re-derived from *key* with
    PBKDF2, salted with the rebuilt IV.

    Returns the decrypted data.  On a malformed payload, raises
    DecryptionError, or returns None when *raise_on_error* is false.
    No integrity check is performed: a wrong key yields garbage, not an
    error.
    '''
    def fail(message):
        if raise_on_error:
            raise DecryptionError(message)
        return None

    try:
        iv, crypted = payload[:4], payload[4:]
    except (ValueError, TypeError):
        return fail('bad payload')
    if len(iv) != 4:
        # anything shorter than the 4 hex characters of the seed would not
        # rebuild the 16-byte IV aes_hex_encrypt uses, and AES setup would fail
        return fail('bad payload')
    try:
        iv = binascii.unhexlify(iv) * 8
        crypted = binascii.unhexlify(crypted)
    except (TypeError, ValueError, binascii.Error):
        # Python 2 raises TypeError, Python 3 binascii.Error (a ValueError)
        return fail('incorrect hexadecimal encoding')
    aes_key = PBKDF2(key, iv)
    aes = AES.new(aes_key, AES.MODE_CFB, iv)
    return aes.decrypt(crypted)
class NothingInCacheException(Exception):
    '''Raised by Requests.request when raise_if_not_cached is set and no
    cached GET response can be served.'''
class Requests(RequestsSession):
    '''requests.Session that can federate, sign and cache calls to known services.'''

    def request(self, method, url, **kwargs):
        '''Perform an HTTP request.

        On top of requests.Session.request, these keyword arguments are
        accepted (and removed before delegating):

        remote_service -- a service dict (keys 'url', 'secret', 'orig'), or
            'auto' to find the service in settings.KNOWN_SERVICES by the
            scheme and netloc of *url*.  When a service is known, user
            identification parameters and 'orig' are added to the query
            string, the URL is rebuilt on the service base URL and signed
            with the service secret.
        cache_duration -- seconds to cache successful GET responses
            (default 15; a false value disables caching).
        invalidate_cache -- ignore a cached response (a successful response
            is still stored in the cache).
        user -- a user object (checked with is_authenticated()), or a dict
            of query parameters used as is.
        without_user -- for anonymous users, send no NameID/email parameters
            instead of empty ones.
        federation_key -- 'auto', 'email' or 'nameid': how to identify *user*.
        raise_if_not_cached -- raise NothingInCacheException instead of
            performing a GET request that cannot be served from the cache.
        log_errors -- log non-2xx responses (default True).
        '''
        remote_service = kwargs.pop('remote_service', None)
        cache_duration = kwargs.pop('cache_duration', 15)
        invalidate_cache = kwargs.pop('invalidate_cache', False)
        user = kwargs.pop('user', None)
        without_user = kwargs.pop('without_user', False)
        federation_key = kwargs.pop('federation_key', 'auto')  # 'auto', 'email', 'nameid'
        raise_if_not_cached = kwargs.pop('raise_if_not_cached', False)
        log_errors = kwargs.pop('log_errors', True)

        if remote_service == 'auto':
            remote_service = None
            scheme, netloc, path, params, query, fragment = urlparse.urlparse(url)
            for services in settings.KNOWN_SERVICES.values():
                for service in services.values():
                    remote_url = service.get('url')
                    remote_scheme, remote_netloc, r_path, r_params, r_query, r_fragment = \
                        urlparse.urlparse(remote_url)
                    if remote_scheme == scheme and remote_netloc == netloc:
                        remote_service = service
                        break
                else:
                    continue
                break
            if remote_service:
                # only keep the path (URI) in url; scheme and netloc come
                # from remote_service
                url = urlparse.urlunparse(('', '', path, params, query, fragment))
            else:
                logging.warning('service not found in settings.KNOWN_SERVICES for %s', url)

        if remote_service:
            if isinstance(user, dict):
                query_params = user.copy()
            elif not user or not user.is_authenticated():
                if without_user:
                    query_params = {}
                else:
                    query_params = {'NameID': '', 'email': ''}
            else:
                query_params = {}
                if federation_key == 'nameid':
                    # NOTE(review): assumes the user has a SAML identifier;
                    # .first() returning None would raise AttributeError
                    query_params['NameID'] = user.saml_identifiers.first().name_id
                elif federation_key == 'email':
                    query_params['email'] = user.email
                else:  # 'auto'
                    if hasattr(user, 'saml_identifiers') and user.saml_identifiers.exists():
                        query_params['NameID'] = user.saml_identifiers.first().name_id
                    else:
                        query_params['email'] = user.email
            query_params['orig'] = remote_service.get('orig')

            remote_service_base_url = remote_service.get('url')
            scheme, netloc, old_path, params, old_query, fragment = urlparse.urlparse(
                remote_service_base_url)

            query = urlencode(query_params)
            if '?' in url:
                # split on the first '?' only: the query part may itself
                # contain '?' characters
                path, old_query = url.split('?', 1)
                query += '&' + old_query
            else:
                path = url
            url = urlparse.urlunparse((scheme, netloc, path, params, query, fragment))

        if method == 'GET' and cache_duration:
            # The cache key uses the URL before signing (the signature adds a
            # timestamp and a nonce).  hashlib needs bytes: encode text URLs.
            cache_key_source = url if isinstance(url, bytes) else url.encode('utf-8')
            cache_key = hashlib.md5(cache_key_source).hexdigest()
            cache_content = cache.get(cache_key)
            if cache_content and not invalidate_cache:
                response = Response()
                response.status_code = 200
                response.raw = StringIO(cache_content)
                return response
            elif raise_if_not_cached:
                raise NothingInCacheException()

        if remote_service:  # sign
            url = sign_url(url, remote_service.get('secret'))

        response = super(Requests, self).request(method, url, **kwargs)
        if log_errors and (response.status_code // 100 != 2):
            logging.error('failed to %s %s (%s)', method, url, response.status_code)
        if method == 'GET' and cache_duration and (response.status_code // 100 == 2):
            cache.set(cache_key, response.content, cache_duration)
        return response


# Shared session used by callers of this module.
requests = Requests()
class TemplateError(Exception):
    '''Error raised while expanding a templated URL.

    *msg* is a %-format string and *params* its argument(s); they are only
    interpolated when the exception is converted to a string.
    '''

    def __init__(self, msg, params=()):
        # Forward the arguments to Exception so that ``args`` is populated:
        # repr() shows the message and the exception survives pickling.
        super(TemplateError, self).__init__(msg, params)
        self.msg = msg
        self.params = params

    def __str__(self):
        return self.msg % self.params
def get_templated_url(url, context=None):
    '''Expand template variables in *url*.

    Two syntaxes are supported: Django templates (``{{ ... }}`` /
    ``{% ... %}``) and an ezt-like ``[name]`` form, where ``[[]`` yields a
    literal ``[``.  URLs using neither are returned unchanged.

    Variables come from *context* when given, along with ``user_email`` and
    ``user_nameid``, which are URL-quoted values derived from
    ``context['request'].user`` (empty strings otherwise).
    ``settings.TEMPLATE_VARS`` is pushed last, so it shadows other values.

    Raises TemplateError for unknown variables or template syntax errors.
    '''
    has_django_syntax = '{{' in url or '{%' in url
    if not has_django_syntax and '[' not in url:
        return url

    template_vars = Context()
    if context:
        template_vars.update(context)
        user_vars = {'user_email': '', 'user_nameid': ''}
        request_user = getattr(context.get('request'), 'user', None)
        if request_user and request_user.is_authenticated():
            user_vars['user_email'] = quote(request_user.email)
            if hasattr(request_user, 'saml_identifiers') and \
                    request_user.saml_identifiers.exists():
                user_vars['user_nameid'] = quote(
                    request_user.saml_identifiers.first().name_id)
        for var_name, var_value in user_vars.items():
            template_vars[var_name] = var_value
    template_vars.update(settings.TEMPLATE_VARS)

    if has_django_syntax:
        try:
            return Template(url).render(template_vars)
        except VariableDoesNotExist as e:
            raise TemplateError(e.msg, e.params)
        except TemplateSyntaxError:
            raise TemplateError('syntax error')

    def substitute(match):
        name = match.group(0)[1:-1]
        if name == '[':
            return '['
        if name not in template_vars:
            raise TemplateError('unknown variable %s', name)
        return unicode(template_vars[name])

    return re.sub(r'(\[.+?\])', substitute, url)
# Simple signature scheme for query strings
def sign_url(url, key, algo='sha256', timestamp=None, nonce=None):
    '''Return *url* with its query string signed by sign_query.'''
    scheme, netloc, path, params, query, fragment = urlparse.urlparse(url)
    signed_query = sign_query(query, key, algo, timestamp, nonce)
    return urlparse.urlunparse((scheme, netloc, path, params, signed_query, fragment))
def sign_query(query, key, algo='sha256', timestamp=None, nonce=None):
    '''Return *query* with signature parameters appended.

    ``algo``, ``timestamp`` (formatted ``%Y-%m-%dT%H:%M:%SZ``) and ``nonce``
    are appended first.  Then ``signature`` is added: the URL-quoted base64
    of sign_string() over everything before it.  *timestamp* defaults to
    the current naive UTC datetime.  *nonce* defaults to 128 random bits.
    '''
    if timestamp is None:
        timestamp = datetime.datetime.utcnow()
    timestamp = timestamp.strftime('%Y-%m-%dT%H:%M:%SZ')
    if nonce is None:
        # Use the OS CSPRNG instead of the predictable Mersenne Twister, and
        # always render 32 hex digits.  hex() of a Python 2 long appends 'L'
        # and drops leading zeros.
        nonce = '%032x' % random.SystemRandom().getrandbits(128)
    new_query = query
    if new_query:
        new_query += '&'
    new_query += urlencode((
        ('algo', algo),
        ('timestamp', timestamp),
        ('nonce', nonce)))
    signature = base64.b64encode(sign_string(new_query, key, algo=algo))
    new_query += '&signature=' + quote(signature)
    return new_query
def sign_string(s, key, algo='sha256', timedelta=30):
    '''Return the raw HMAC digest of *s* keyed with *key*.

    *algo* names a hashlib constructor (e.g. 'sha256').  Text keys and
    messages are UTF-8 encoded, so non-ASCII values no longer raise
    UnicodeEncodeError.  Byte strings are used as is, and other key types
    go through str() as before.  *timedelta* is unused; it is kept for
    backward compatibility.
    '''
    text_type = type(u'')
    digestmod = getattr(hashlib, algo)
    if isinstance(key, text_type):
        key = key.encode('utf-8')
    elif not isinstance(key, bytes):
        key = str(key)
        if isinstance(key, text_type):  # Python 3: str() returns text
            key = key.encode('utf-8')
    if isinstance(s, text_type):
        s = s.encode('utf-8')
    mac = hmac.HMAC(key, digestmod=digestmod, msg=s)
    return mac.digest()
def ellipsize(text, length=50):
    '''Return *text* as plain text, shortened when longer than *length*.

    HTML tags are stripped and entities unescaped first.  Text of at most
    *length* characters is returned whole.  Longer text is cut to
    ``length - 10`` characters, followed by '...'.
    '''
    text = HTMLParser().unescape(strip_tags(text))
    if len(text) <= length:
        # a text of exactly *length* characters already fits: keep it whole
        return text
    return text[:(length - 10)] + '...'
def check_request_signature(django_request, keys=None):
    '''Check the signature of *django_request*'s query string.

    *keys* is an optional list of candidate secrets (a single key is also
    accepted); it is never modified.  When the query has an ``orig``
    parameter, a secret from settings.KNOWN_SERVICES is also tried: in each
    group, the first service whose ``verif_orig`` equals ``orig``.

    Returns False when the request has no query string, otherwise the
    result of check_query.
    '''
    query_string = django_request.META.get('QUERY_STRING')
    if not query_string:
        return False
    # Work on a private copy.  The former mutable default ``keys=[]`` was
    # appended to below, so secrets accumulated across calls.
    if keys is None:
        keys = []
    elif isinstance(keys, (list, tuple)):
        keys = list(keys)
    else:
        keys = [keys]
    orig = django_request.GET.get('orig', '')
    known_services = getattr(settings, 'KNOWN_SERVICES', None)
    if known_services and orig:
        for services in known_services.values():
            for service in services.values():
                if 'verif_orig' in service and service['verif_orig'] == orig:
                    keys.append(service['secret'])
                    break
    return check_query(query_string, keys)
def check_query(query, keys, known_nonce=None, timedelta=30):
    '''Verify a query string signed by sign_query.

    Returns True only when all of the following hold:
    - the query has signature, algo, timestamp and nonce parameters;
    - algo is a supported digest;
    - the timestamp is within *timedelta* seconds of the current UTC time;
    - *known_nonce*, if given, returns a false value for the nonce string;
    - the signature matches one of *keys*.

    Malformed parameters (bad base64 or timestamp, unsupported algo) make
    it return False instead of raising.  The query may come straight from
    a client request.
    '''
    parsed = urlparse.parse_qs(query)
    if not ('signature' in parsed and 'algo' in parsed and
            'timestamp' in parsed and 'nonce' in parsed):
        return False
    algo = parsed['algo'][0]
    # algo is looked up with getattr(hashlib, algo) when signing: only allow
    # real digest names, not arbitrary attributes of the hashlib module.
    if algo not in ('md5', 'sha1', 'sha224', 'sha256', 'sha384', 'sha512'):
        return False
    try:
        signature = base64.b64decode(parsed['signature'][0])
        timestamp = datetime.datetime.strptime(
            parsed['timestamp'][0], '%Y-%m-%dT%H:%M:%SZ')
    except (TypeError, ValueError, binascii.Error):
        # bad base64 (TypeError on Python 2, binascii.Error on Python 3)
        # or bad timestamp (ValueError)
        return False
    # parse_qs returns lists: hand the nonce itself to known_nonce
    nonce = parsed['nonce'][0]
    unsigned_query = query.split('&signature=')[0]
    if known_nonce is not None and known_nonce(nonce):
        return False
    if abs(datetime.datetime.utcnow() - timestamp) > datetime.timedelta(seconds=timedelta):
        return False
    return check_string(unsigned_query, signature, keys, algo=algo)
def check_string(s, signature, keys, algo='sha256'):
    '''Return True when *signature* is the sign_string() digest of *s*
    under any of *keys*.

    *keys* may be a single key or a list/tuple of candidate keys.
    *signature* must be a byte string, like the output of sign_string.
    '''
    if not isinstance(keys, (list, tuple)):
        keys = [keys]
    for key in keys:
        expected = sign_string(s, key, algo=algo)
        # hmac.compare_digest takes time independent of where the digests
        # differ, so the comparison does not leak the expected signature.
        if hmac.compare_digest(expected, signature):
            return True
    return False
# _make_key and _HashedSeq imported/adapted from functools from Python 3.2+
def _make_key(args, kwds,
kwd_mark=(object(),),
fasttypes={int, str, frozenset, type(None)},
tuple=tuple, type=type, len=len):
key = args
if kwds:
key += kwd_mark
for item in kwds.items():
key += item
if len(key) == 1 and type(key[0]) in fasttypes:
return key[0]
return _HashedSeq(key)
class _HashedSeq(list):
__slots__ = 'hashvalue'
def __init__(self, tup, hash=hash):
self[:] = tup
self.hashvalue = hash(tup)
def __hash__(self):
return self.hashvalue
def cache_during_request(func):
    '''Decorator memoizing *func* for the duration of the current request.

    Results are stored in ``request.cache`` of the request returned by
    get_request(), keyed on the function identity and its arguments.
    *func* is simply called, without caching, outside of a request
    (get_request() returns a false value) or when an argument is
    unhashable.
    '''
    import functools

    @functools.wraps(func)
    def inner(*args, **kwargs):
        request = get_request()
        cache_key = None
        if request:
            try:
                cache_key = (id(func), _make_key(args, kwargs))
            except TypeError:
                # Unhashable argument (e.g. a dict): skip memoization rather
                # than failing a call that would work undecorated.
                cache_key = None
            else:
                if cache_key in request.cache:
                    return request.cache[cache_key]
        result = func(*args, **kwargs)
        if cache_key is not None:
            request.cache[cache_key] = result
        return result

    return inner