356 lines
12 KiB
Python
356 lines
12 KiB
Python
# combo - content management system
|
|
# Copyright (C) 2015 Entr'ouvert
|
|
#
|
|
# This program is free software: you can redistribute it and/or modify it
|
|
# under the terms of the GNU Affero General Public License as published
|
|
# by the Free Software Foundation, either version 3 of the License, or
|
|
# (at your option) any later version.
|
|
#
|
|
# This program is distributed in the hope that it will be useful,
|
|
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
|
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
|
# GNU Affero General Public License for more details.
|
|
#
|
|
# You should have received a copy of the GNU Affero General Public License
|
|
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
|
|
|
import datetime
|
|
import base64
|
|
import hmac
|
|
import hashlib
|
|
import binascii
|
|
import re
|
|
|
|
from HTMLParser import HTMLParser
|
|
import logging
|
|
import random
|
|
from StringIO import StringIO
|
|
import urlparse
|
|
|
|
from Crypto.Cipher import AES
|
|
from Crypto.Protocol.KDF import PBKDF2
|
|
from Crypto import Random
|
|
|
|
from requests import Response, Session as RequestsSession
|
|
|
|
from django.conf import settings
|
|
from django.core.cache import cache
|
|
from django.template import Context, Template, TemplateSyntaxError, VariableDoesNotExist
|
|
from django.template.context import BaseContext
|
|
from django.utils.html import strip_tags
|
|
from django.utils.http import urlencode, quote
|
|
|
|
from .middleware import get_request
|
|
|
|
|
|
class DecryptionError(Exception):
    '''Raised by aes_hex_decrypt() when a payload cannot be decoded.'''
|
|
|
|
|
|
def aes_hex_encrypt(key, data):
    '''Encrypt data with AES in CFB mode and return a hexadecimal string.

    The AES key is derived from the key material with PBKDF2.  Two fresh
    random bytes are drawn on every call; repeated 8 times they form the
    16-byte IV, which is also used as the PBKDF2 salt.  The result is those
    two bytes in hex (4 digits) followed by the hex-encoded ciphertext, the
    format read back by aes_hex_decrypt().
    '''
    # NOTE(review): only 2 random bytes (16 bits) feed the IV/salt, so there
    # are just 65536 distinct IVs per key.  The 4-hex-digit prefix is part of
    # the payload format, so changing this would break existing payloads.
    seed = Random.get_random_bytes(2)
    iv = seed * 8
    cipher = AES.new(PBKDF2(key, iv), AES.MODE_CFB, iv)
    ciphertext = cipher.encrypt(data)
    return '%s%s' % (binascii.hexlify(seed), binascii.hexlify(ciphertext))
|
|
|
|
|
|
def aes_hex_decrypt(key, payload, raise_on_error=True):
    '''Decrypt data encrypted with aes_hex_encrypt.

    The payload is 4 hexadecimal digits (the 2 bytes that, repeated 8 times,
    form the IV and PBKDF2 salt) followed by the hex-encoded ciphertext.

    On a malformed payload, raise DecryptionError, or return None when
    raise_on_error is False.
    '''
    def fail(message):
        if raise_on_error:
            raise DecryptionError(message)
        return None

    try:
        iv, crypted = payload[:4], payload[4:]
    except (ValueError, TypeError):
        return fail('bad payload')
    if len(iv) != 4:
        # too short to hold the IV prefix: AES.new() would otherwise fail on
        # a wrong-size IV, bypassing raise_on_error
        return fail('bad payload')
    try:
        iv = binascii.unhexlify(iv) * 8
        crypted = binascii.unhexlify(crypted)
    except (TypeError, ValueError, binascii.Error):
        # Python 2 raises TypeError, Python 3 binascii.Error (a ValueError)
        return fail('incorrect hexadecimal encoding')
    aes_key = PBKDF2(key, iv)
    aes = AES.new(aes_key, AES.MODE_CFB, iv)
    return aes.decrypt(crypted)
|
|
|
|
|
|
class NothingInCacheException(Exception):
    '''Raised by Requests.request() when raise_if_not_cached is set and a GET
    cannot be answered from the cache.'''
|
|
|
|
|
|
class Requests(RequestsSession):
    '''requests Session able to target the services of settings.KNOWN_SERVICES.

    request() accepts these extra keyword arguments, all removed from kwargs
    before the call is handed over to requests:

    remote_service (None)
        a service dict as found in settings.KNOWN_SERVICES, or 'auto' to use
        the service whose 'url' has the same scheme and netloc as url.  With a
        service, url is used as a path (optionally with a query string) on the
        service base URL; user identification and the service 'orig' are
        added to the query string and the URL is signed with its 'secret'.
    cache_duration (15)
        seconds during which successful GET responses are cached; a falsy
        value disables the cache.
    invalidate_cache (False)
        ignore cached content; a successful response is still stored.
    user (None)
        Django user to identify, or a dict used as the identification
        query parameters.
    without_user (False)
        for anonymous or missing users, send no NameID/email parameters
        instead of empty ones.
    federation_key ('auto')
        'nameid', 'email', or 'auto' (NameID when the user has SAML
        identifiers, email otherwise).
    raise_if_not_cached (False)
        for cacheable GETs, raise NothingInCacheException instead of doing
        the request when no usable cached content exists.
    log_errors (True)
        log responses whose status is not 2xx.
    '''

    @staticmethod
    def _find_known_service(url):
        '''Return the KNOWN_SERVICES entry sharing url's scheme and netloc, or None.'''
        parsed = urlparse.urlparse(url)
        for services in settings.KNOWN_SERVICES.values():
            for service in services.values():
                remote_url = service.get('url')
                if not remote_url:
                    # misconfigured entry: cannot match (and urlparse(None) crashes)
                    continue
                remote = urlparse.urlparse(remote_url)
                if remote.scheme == parsed.scheme and remote.netloc == parsed.netloc:
                    return service
        return None

    @staticmethod
    def _get_user_query_params(user, without_user, federation_key):
        '''Return a new dict of query parameters identifying user.'''
        if isinstance(user, dict):
            return user.copy()
        if not user or not user.is_authenticated():
            if without_user:
                return {}
            return {'NameID': '', 'email': ''}
        if federation_key == 'nameid':
            return {'NameID': user.saml_identifiers.first().name_id}
        if federation_key == 'email':
            return {'email': user.email}
        # 'auto'
        if hasattr(user, 'saml_identifiers') and user.saml_identifiers.exists():
            return {'NameID': user.saml_identifiers.first().name_id}
        return {'email': user.email}

    def request(self, method, url, **kwargs):
        remote_service = kwargs.pop('remote_service', None)
        cache_duration = kwargs.pop('cache_duration', 15)
        invalidate_cache = kwargs.pop('invalidate_cache', False)
        user = kwargs.pop('user', None)
        without_user = kwargs.pop('without_user', False)
        federation_key = kwargs.pop('federation_key', 'auto')  # 'auto', 'email', 'nameid'
        raise_if_not_cached = kwargs.pop('raise_if_not_cached', False)
        log_errors = kwargs.pop('log_errors', True)

        if remote_service == 'auto':
            remote_service = self._find_known_service(url)
            if remote_service:
                # only keeps the path (URI) in url parameter, scheme and netloc are
                # in remote_service
                scheme, netloc, path, params, query, fragment = urlparse.urlparse(url)
                url = urlparse.urlunparse(('', '', path, params, query, fragment))
            else:
                logging.warning('service not found in settings.KNOWN_SERVICES for %s', url)

        if remote_service:
            query_params = self._get_user_query_params(user, without_user, federation_key)
            query_params['orig'] = remote_service.get('orig')

            scheme, netloc, old_path, params, old_query, fragment = urlparse.urlparse(
                remote_service.get('url'))

            query = urlencode(query_params)
            # split on the first '?' only: a query string may itself contain
            # '?' characters, which made url.split('?') fail to unpack
            path, separator, extra_query = url.partition('?')
            if separator:
                query += '&' + extra_query

            url = urlparse.urlunparse((scheme, netloc, path, params, query, fragment))

        cache_key = None
        if method == 'GET' and cache_duration:
            # hash bytes: md5() of non-ASCII unicode fails under Python 2
            # (ASCII URLs give the same key as before)
            url_bytes = url if isinstance(url, bytes) else url.encode('utf-8')
            cache_key = hashlib.md5(url_bytes).hexdigest()
            cache_content = cache.get(cache_key)
            if cache_content and not invalidate_cache:
                response = Response()
                response.status_code = 200
                response.raw = StringIO(cache_content)
                return response
            elif raise_if_not_cached:
                raise NothingInCacheException()

        if remote_service:  # sign
            url = sign_url(url, remote_service.get('secret'))

        kwargs['timeout'] = kwargs.get('timeout') or settings.REQUESTS_TIMEOUT

        response = super(Requests, self).request(method, url, **kwargs)
        if log_errors and (response.status_code // 100 != 2):
            logging.error('failed to %s %s (%s)', method, url, response.status_code)
        # cache_key is only set for cacheable GET requests
        if cache_key is not None and (response.status_code // 100 == 2):
            cache.set(cache_key, response.content, cache_duration)

        return response
|
|
|
|
# Module-level Requests session, shared by every importer of this module.
requests = Requests()
|
|
|
|
|
|
class TemplateError(Exception):
    '''Error raised while expanding a templated URL.

    msg is a %-format string, interpolated with params when the error is
    converted to a string.
    '''

    def __init__(self, msg, params=()):
        # populate Exception.args so repr() and pickling carry the details
        super(TemplateError, self).__init__(msg, params)
        self.msg = msg
        self.params = params

    def __str__(self):
        try:
            return self.msg % self.params
        except (TypeError, ValueError):
            # msg does not match params (e.g. a literal '%' or a placeholder
            # without params): show it verbatim rather than failing in str()
            return self.msg
|
|
|
|
|
|
def get_templated_url(url, context=None):
    '''Expand the template placeholders of url.

    Two syntaxes are handled: Django templates ({{ ... }} / {% ... %}) and a
    simpler ezt-like one where [name] is replaced by the value of variable
    name ('[[]' produces a literal '[').  A url without any of these markers
    is returned unchanged.

    Variables come from context (when given) plus user_email and user_nameid
    (URL-quoted, derived from context['request'].user), and finally
    settings.TEMPLATE_VARS, which take precedence.

    Raises TemplateError on unknown variables or Django template syntax errors.
    '''
    django_syntax = '{{' in url or '{%' in url
    if not django_syntax and '[' not in url:
        return url

    variables = Context()
    if context:
        variables.update(context)
        variables['user_email'] = ''
        variables['user_nameid'] = ''
        request_user = getattr(context.get('request'), 'user', None)
        if request_user and request_user.is_authenticated():
            variables['user_email'] = quote(request_user.email)
            if (hasattr(request_user, 'saml_identifiers')
                    and request_user.saml_identifiers.exists()):
                name_id = request_user.saml_identifiers.first().name_id
                variables['user_nameid'] = quote(name_id)
    variables.update(settings.TEMPLATE_VARS)

    if django_syntax:
        try:
            return Template(url).render(variables)
        except VariableDoesNotExist as exc:
            raise TemplateError(exc.msg, exc.params)
        except TemplateSyntaxError:
            raise TemplateError('syntax error')

    def substitute(match):
        name = match.group(0)[1:-1]
        if name == '[':
            return '['
        if name not in variables:
            raise TemplateError('unknown variable %s', name)
        return unicode(variables[name])

    return re.sub(r'(\[.+?\])', substitute, url)
|
|
|
|
|
|
# Simple signature scheme for query strings
|
|
|
|
def sign_url(url, key, algo='sha256', timestamp=None, nonce=None):
    '''Return url with signature parameters added to its query string.

    Only the query component changes; it is signed with sign_query().
    '''
    scheme, netloc, path, params, query, fragment = urlparse.urlparse(url)
    signed_query = sign_query(query, key, algo, timestamp, nonce)
    return urlparse.urlunparse((scheme, netloc, path, params, signed_query, fragment))
|
|
|
|
def sign_query(query, key, algo='sha256', timestamp=None, nonce=None):
    '''Return query with algo, timestamp, nonce and signature parameters appended.

    timestamp defaults to the current UTC time (a datetime, formatted as
    '%Y-%m-%dT%H:%M:%SZ'); nonce defaults to 128 random bits as 32 hex digits.
    The signature is the base64-encoded sign_string() digest of the query
    including the algo/timestamp/nonce parameters, and is always appended as
    the last parameter.
    '''
    if timestamp is None:
        timestamp = datetime.datetime.utcnow()
    timestamp = timestamp.strftime('%Y-%m-%dT%H:%M:%SZ')
    if nonce is None:
        # fixed-width hex from the OS CSPRNG: hex(long)[2:] kept Python 2's
        # trailing 'L', dropped leading zeros, and came from the predictable
        # default Mersenne Twister generator
        nonce = '%032x' % random.SystemRandom().getrandbits(128)
    new_query = query
    if new_query:
        new_query += '&'
    new_query += urlencode((
        ('algo', algo),
        ('timestamp', timestamp),
        ('nonce', nonce)))
    signature = base64.b64encode(sign_string(new_query, key, algo=algo))
    new_query += '&signature=' + quote(signature)
    return new_query
|
|
|
|
def sign_string(s, key, algo='sha256', timedelta=30):
    '''Return the raw HMAC digest of s keyed with key, using hashlib.<algo>.

    Text keys and messages are UTF-8 encoded; other non-bytes keys go through
    str() first.  (Previously str(key) failed on non-ASCII unicode keys under
    Python 2.)  timedelta is unused and kept for backward compatibility.
    '''
    digestmod = getattr(hashlib, algo)
    text_type = type(u'')
    if not isinstance(key, (bytes, text_type)):
        key = str(key)
    if isinstance(key, text_type):
        key = key.encode('utf-8')
    if isinstance(s, text_type):
        s = s.encode('utf-8')
    return hmac.new(key, s, digestmod).digest()
|
|
|
|
def ellipsize(text, length=50):
    '''Return text without HTML, shortened for display.

    HTML tags are stripped and entities unescaped first.  Text of at most
    length characters is returned unchanged; longer text is cut to
    length - 10 characters and '...' is appended.
    '''
    text = HTMLParser().unescape(strip_tags(text))
    # '<=': a text of exactly `length` characters already fits (with '<' it
    # was truncated while a shorter one was not)
    if len(text) <= length:
        return text
    # clamp at 0: for length < 10 a negative bound would slice relative to the
    # end of the text and keep an arbitrarily long prefix
    return text[:max(length - 10, 0)] + '...'
|
|
|
|
|
|
def check_request_signature(django_request, keys=None):
    '''Check the signature of the query string of a Django request.

    keys are the secrets to try (a list or a single key; never modified).
    When the query has an 'orig' parameter equal to the 'verif_orig' of a
    service in settings.KNOWN_SERVICES, that service's 'secret' is tried too.
    Returns False when the request has no query string, otherwise the result
    of check_query().
    '''
    query_string = django_request.META['QUERY_STRING']
    if not query_string:
        return False
    # Work on a private copy: appending to the caller's list, or to a shared
    # mutable default, made secrets accumulate from one call to the next.
    if keys is None:
        keys = []
    elif isinstance(keys, (list, tuple, set, frozenset)):
        keys = list(keys)
    else:
        keys = [keys]
    orig = django_request.GET.get('orig', '')
    known_services = getattr(settings, 'KNOWN_SERVICES', None)
    if known_services and orig:
        for services in known_services.values():
            for service in services.values():
                if 'verif_orig' in service and service['verif_orig'] == orig:
                    keys.append(service['secret'])
                    break
    return check_query(query_string, keys)
|
|
|
|
|
|
def check_query(query, keys, known_nonce=None, timedelta=30):
    '''Check a query string signed with sign_query().

    Return True only when:
    - the signature matches one of keys;
    - the timestamp is within timedelta seconds of the current UTC time;
    - known_nonce, if given, returns a falsy value for the nonce string.

    Malformed input returns False instead of raising: missing or undecodable
    parameters, an unknown algorithm, or parameters after the signature.
    '''
    parsed = urlparse.parse_qs(query)
    if not ('signature' in parsed and 'algo' in parsed and
            'timestamp' in parsed and 'nonce' in parsed):
        return False
    # The signed part is everything before '&signature='.  The signature must
    # be the last parameter: anything after it is not covered by the
    # signature and could be appended by anyone.
    unsigned_query, separator, signature_part = query.partition('&signature=')
    if not separator or '&' in signature_part:
        return False
    algo = parsed['algo'][0]
    # algo is attacker-controlled and ends up in getattr(hashlib, algo):
    # only accept genuine hash algorithm names
    supported_algos = (getattr(hashlib, 'algorithms_guaranteed', None)
                       or getattr(hashlib, 'algorithms', ()))
    if algo not in supported_algos or algo.startswith('shake'):
        return False
    try:
        signature = base64.b64decode(parsed['signature'][0])
        timestamp = datetime.datetime.strptime(parsed['timestamp'][0], '%Y-%m-%dT%H:%M:%SZ')
    except (TypeError, ValueError, binascii.Error):
        return False
    # pass the nonce value itself, not parse_qs's list of values
    nonce = parsed['nonce'][0]
    if known_nonce is not None and known_nonce(nonce):
        return False
    if abs(datetime.datetime.utcnow() - timestamp) > datetime.timedelta(seconds=timedelta):
        return False
    return check_string(unsigned_query, signature, keys, algo=algo)
|
|
|
|
|
|
def _constant_time_compare(a, b):
    '''Compare two byte strings in time independent of where they differ.'''
    compare_digest = getattr(hmac, 'compare_digest', None)
    if compare_digest is not None:
        return compare_digest(a, b)
    # fallback for Python < 2.7.7; bytearray yields ints on Python 2 and 3
    if len(a) != len(b):
        return False
    result = 0
    for x, y in zip(bytearray(a), bytearray(b)):
        result |= x ^ y
    return result == 0


def check_string(s, signature, keys, algo='sha256'):
    '''Return True if signature is the sign_string() digest of s for one of keys.

    keys may be a single key or a list/tuple/set of keys.  Digests are
    compared in constant time.
    '''
    if not isinstance(keys, (list, tuple, set, frozenset)):
        keys = [keys]
    for key in keys:
        if _constant_time_compare(sign_string(s, key, algo=algo), signature):
            return True
    return False
|
|
|
|
# _make_key and _HashedSeq imported/adapted from functools from Python 3.2+
|
|
def _make_key(args, kwds,
              kwd_mark=(object(),),
              fasttypes={int, str, frozenset, type(None)},
              tuple=tuple, type=type, len=len):
    '''Build a hashable cache key from positional and keyword arguments.

    Keyword items are flattened after a unique marker object.  A lone
    argument of one of fasttypes is returned as is; anything else is wrapped
    in a _HashedSeq.
    '''
    key = args
    if kwds:
        flat_items = tuple(part for item in kwds.items() for part in item)
        key = args + kwd_mark + flat_items
    single_fast_value = len(key) == 1 and type(key[0]) in fasttypes
    return key[0] if single_fast_value else _HashedSeq(key)
|
|
|
|
|
|
class _HashedSeq(list):
    '''List whose hash is computed once, from the source tuple, at creation.

    This makes it cheap to use repeatedly as a dictionary key.
    '''

    __slots__ = 'hashvalue'

    def __init__(self, tup, hash=hash):
        self.extend(tup)
        self.hashvalue = hash(tup)

    def __hash__(self):
        return self.hashvalue
|
|
|
|
|
|
def cache_during_request(func):
    '''Decorator memoizing func in the cache of the current request.

    Results are stored in request.cache, keyed on the function and its
    arguments.  Without a current request (get_request() falsy), or when the
    arguments are unhashable, func is simply called.
    '''
    from functools import wraps

    @wraps(func)
    def inner(*args, **kwargs):
        request = get_request()
        cache_key = None
        if request:
            try:
                cache_key = (id(func), _make_key(args, kwargs))
            except TypeError:
                # unhashable arguments (lists, dicts...): cannot be cached,
                # call through instead of failing
                cache_key = None
            if cache_key is not None and cache_key in request.cache:
                return request.cache[cache_key]
        result = func(*args, **kwargs)
        if cache_key is not None:
            request.cache[cache_key] = result
        return result

    return inner
|
|
|
|
|
|
def flatten_context(context):
    '''Return a plain dict merging every layer of context.

    Django contexts (BaseContext) are walked recursively through their
    .dicts, later layers overriding earlier ones; anything else is copied
    into a new dict.
    '''
    if not isinstance(context, BaseContext):
        return dict(context)
    merged = {}
    for layer in context.dicts:
        merged.update(flatten_context(layer))
    return merged
|