combo/combo/utils.py

165 lines
6.0 KiB
Python

# combo - content management system
# Copyright (C) 2015 Entr'ouvert
#
# This program is free software: you can redistribute it and/or modify it
# under the terms of the GNU Affero General Public License as published
# by the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
import datetime
import base64
import hmac
import hashlib
from HTMLParser import HTMLParser
import logging
import random
from StringIO import StringIO
import urlparse
from requests import Response, Session as RequestsSession
import requests
from django.core.cache import cache
from django.utils.html import strip_tags
from django.utils.http import urlencode, quote
class NothingInCacheException(Exception):
    """Raised by `Requests.request()` when `raise_if_not_cached` is set and
    the cache holds no content for the requested URL."""
class Requests(RequestsSession):
    """`requests` session adding remote-service signing and GET caching.

    `request()` accepts these extra keyword arguments, which are removed from
    `kwargs` before the call is handed over to the parent session:

    * `remote_service`: mapping whose 'url', 'orig' and 'secret' entries are
      read; when given, the URL is rebuilt on the service base URL, user
      parameters are added to the query string and the result is signed.
    * `cache_duration`: timeout passed to `cache.set()` for GET responses
      with a 200 status (default 15); a false value disables caching.
    * `user`: either a dict of query parameters, used as is, or a user
      object.
    * `without_user`: when true, no empty 'NameID'/'email' parameters are
      sent for a missing or unauthenticated user.
    * `federation_key`: 'auto' (default), 'email' or 'nameid'.
    * `raise_if_not_cached`: raise `NothingInCacheException` instead of
      performing a cacheable GET whose content is not in cache.
    """

    @staticmethod
    def _user_query_params(user, without_user, federation_key):
        """Return the query parameters identifying `user` to a remote service."""
        if isinstance(user, dict):
            return user.copy()
        if not user or not user.is_authenticated():
            if without_user:
                return {}
            return {'NameID': '', 'email': ''}
        if federation_key == 'nameid':
            # NOTE(review): presumably fails if the user has no SAML
            # identifier -- confirm against callers.
            return {'NameID': user.saml_identifiers.first().name_id}
        if federation_key == 'email':
            return {'email': user.email}
        # 'auto': prefer the SAML NameID, fall back to the email address.
        if hasattr(user, 'saml_identifiers') and user.saml_identifiers.exists():
            return {'NameID': user.saml_identifiers.first().name_id}
        return {'email': user.email}

    def request(self, method, url, **kwargs):
        """Perform the request, with optional signing and caching.

        Returns a `Response`; for a cache hit it is a bare `Response` with
        status 200 whose `raw` wraps the cached content.  Raises
        `NothingInCacheException` when `raise_if_not_cached` is set and a
        cacheable GET is not in cache.
        """
        remote_service = kwargs.pop('remote_service', None)
        cache_duration = kwargs.pop('cache_duration', 15)
        user = kwargs.pop('user', None)
        without_user = kwargs.pop('without_user', False)
        federation_key = kwargs.pop('federation_key', 'auto')  # 'auto', 'email', 'nameid'
        raise_if_not_cached = kwargs.pop('raise_if_not_cached', False)

        if remote_service:
            query_params = self._user_query_params(user, without_user, federation_key)
            query_params['orig'] = remote_service.get('orig')

            # Scheme, host, params and fragment come from the service base
            # URL; path and extra query string come from `url`.
            base = urlparse.urlparse(remote_service.get('url'))
            query = urlencode(query_params)
            # Split on the first '?' only, so a literal '?' inside the query
            # string does not break the unpacking.
            path, has_query, extra_query = url.partition('?')
            if has_query:
                query += '&' + extra_query
            url = urlparse.urlunparse(
                (base.scheme, base.netloc, path, base.params, query, base.fragment))

        cacheable = method == 'GET' and cache_duration
        if cacheable:
            # The key is computed before signing: the signature carries a
            # timestamp and a nonce and so differs on every call.
            cache_key = hashlib.md5(url).hexdigest()
            cache_content = cache.get(cache_key)
            if cache_content:
                response = Response()
                response.status_code = 200
                response.raw = StringIO(cache_content)
                return response
            elif raise_if_not_cached:
                raise NothingInCacheException()

        if remote_service:  # sign
            url = sign_url(url, remote_service.get('secret'))

        response = super(Requests, self).request(method, url, **kwargs)
        if response.status_code != 200:
            logging.error('failed to %s %s (%s)', method, url, response.status_code)
        if cacheable and response.status_code == 200:
            cache.set(cache_key, response.content, cache_duration)

        return response
# Shared session instance.  This rebinds the module-level name `requests`
# (imported above as the library module) to a `Requests` object, so from this
# line on `requests` in this module refers to the signing/caching session.
requests = Requests()
# Simple signature scheme for query strings
def sign_url(url, key, algo='sha256', timestamp=None, nonce=None):
    """Return `url` with its query string replaced by a signed version.

    The query string is handed to `sign_query()` together with `key`, `algo`,
    `timestamp` and `nonce`; every other URL component is kept as parsed.
    """
    scheme, netloc, path, params, query, fragment = urlparse.urlparse(url)
    signed_query = sign_query(query, key, algo, timestamp, nonce)
    return urlparse.urlunparse(
        (scheme, netloc, path, params, signed_query, fragment))
def sign_query(query, key, algo='sha256', timestamp=None, nonce=None):
    """Return `query` extended with signature parameters.

    The parameters `algo`, `timestamp` and `nonce` are appended to the query
    string, the result is signed with `sign_string()` and the base64-encoded
    signature is appended last as `signature`.

    `timestamp` is a datetime, defaulting to the current UTC time; it is
    serialized as '%Y-%m-%dT%H:%M:%SZ'.  `nonce` defaults to a random
    128-bit value in hexadecimal.
    """
    if timestamp is None:
        timestamp = datetime.datetime.utcnow()
    timestamp = timestamp.strftime('%Y-%m-%dT%H:%M:%SZ')
    if nonce is None:
        # Draw the nonce from the OS entropy source rather than the default
        # Mersenne Twister; '%x' also avoids the trailing 'L' that
        # hex() adds to long integers on Python 2.
        nonce = '%x' % random.SystemRandom().getrandbits(128)
    signature_params = urlencode((
        ('algo', algo),
        ('timestamp', timestamp),
        ('nonce', nonce)))
    if query:
        new_query = query + '&' + signature_params
    else:
        new_query = signature_params
    signature = base64.b64encode(sign_string(new_query, key, algo=algo))
    # The signature must stay the last parameter: it covers everything
    # that precedes it.
    new_query += '&signature=' + quote(signature)
    return new_query
def _to_bytes(value):
    """Coerce `value` to a byte string usable as HMAC key or message."""
    if isinstance(value, bytes):
        return value
    if not hasattr(value, 'encode'):
        # Non-string values keep the historical str() coercion.
        value = str(value)
        if isinstance(value, bytes):  # Python 2: str() already yields bytes
            return value
    return value.encode('utf-8')


def sign_string(s, key, algo='sha256', timedelta=30):
    """Return the raw HMAC digest of `s` keyed with `key`.

    `algo` names the `hashlib` constructor used as digest (e.g. 'sha256').
    Text values of `s` and `key` are encoded as UTF-8, so non-ASCII text is
    supported; byte strings are used unchanged.  `timedelta` is accepted for
    interface compatibility and is not used.
    """
    digestmod = getattr(hashlib, algo)
    mac = hmac.new(_to_bytes(key), msg=_to_bytes(s), digestmod=digestmod)
    return mac.digest()
def ellipsize(text, length=50):
    """Return `text` without HTML tags and entities, shortened if needed.

    Tags are stripped and entities unescaped first.  A result shorter than
    `length` characters is returned whole; otherwise its first `length - 10`
    characters are kept (none when `length` is below 10) and '...' is
    appended.
    """
    text = HTMLParser().unescape(strip_tags(text))
    if len(text) < length:
        return text
    # Clamp the bound: a negative value would slice from the end of the text
    # and could return more than `length` characters.
    kept = max(length - 10, 0)
    return text[:kept] + '...'
def check_query(query, key, known_nonce=None, timedelta=30):
    """Check a query string signed by `sign_query()`; return True if valid.

    `key` is the shared secret.  `known_nonce`, when given, is called with
    the list of `nonce` values parsed from the query and must return a true
    value if the nonce was already seen.  `timedelta` is the maximum
    accepted distance, in seconds, between the signed timestamp and the
    current UTC time.

    Returns False, instead of raising, when a signature parameter is
    missing or malformed, or when `signature` is not the last parameter.
    """
    parsed = urlparse.parse_qs(query)
    # A query lacking any of the signature parameters cannot be valid.
    for name in ('signature', 'algo', 'timestamp', 'nonce'):
        if name not in parsed:
            return False
    unsigned_query, separator, signature_part = query.partition('&signature=')
    # Only what precedes the signature is covered by it: refuse parameters
    # appended afterwards, as they could be forged freely.
    if not separator or '&' in signature_part:
        return False
    try:
        signature = base64.b64decode(parsed['signature'][0])
        timestamp = datetime.datetime.strptime(
            parsed['timestamp'][0], '%Y-%m-%dT%H:%M:%SZ')
    except (TypeError, ValueError):
        # Invalid base64 or unparsable timestamp.
        return False
    # NOTE(review): `algo` comes from the query itself and is passed through
    # unvalidated -- consider restricting it to an allow-list.
    algo = parsed['algo'][0]
    # Kept as the parsed list (not its first item) so that existing
    # `known_nonce` callbacks receive the same value as before.
    nonce = parsed['nonce']
    if known_nonce is not None and known_nonce(nonce):
        return False
    if abs(datetime.datetime.utcnow() - timestamp) > datetime.timedelta(seconds=timedelta):
        return False
    return check_string(unsigned_query, signature, key, algo=algo)
def check_string(s, signature, key, algo='sha256'):
    """Return True if `signature` is the signature of `s` under `key`.

    The expected value is computed with `sign_string()`.  Apart from the
    length check, every character pair is examined whatever the position of
    the first difference, so the comparison time does not reveal it.
    """
    expected = sign_string(s, key, algo=algo)
    if len(expected) != len(signature):
        return False
    mismatch = 0
    for given_char, expected_char in zip(signature, expected):
        # Accumulate differences instead of returning at the first one.
        mismatch |= ord(given_char) ^ ord(expected_char)
    return mismatch == 0