137 lines
5.4 KiB
Python
137 lines
5.4 KiB
Python
# combo - content management system
|
|
# Copyright (C) 2015-2018 Entr'ouvert
|
|
#
|
|
# This program is free software: you can redistribute it and/or modify it
|
|
# under the terms of the GNU Affero General Public License as published
|
|
# by the Free Software Foundation, either version 3 of the License, or
|
|
# (at your option) any later version.
|
|
#
|
|
# This program is distributed in the hope that it will be useful,
|
|
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
|
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
|
# GNU Affero General Public License for more details.
|
|
#
|
|
# You should have received a copy of the GNU Affero General Public License
|
|
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
|
|
|
import hashlib
|
|
import logging
|
|
|
|
from requests import Response, Session as RequestsSession
|
|
|
|
from django.conf import settings
|
|
from django.core.cache import cache
|
|
from django.utils.encoding import smart_bytes
|
|
from django.utils.http import urlencode
|
|
from django.utils.six.moves.urllib import parse as urlparse
|
|
from django.utils.six import BytesIO
|
|
|
|
from .signature import sign_url
|
|
|
|
class NothingInCacheException(Exception):
    """Signal that a cache-only GET could not be answered from the cache.

    Raised by ``Requests.request`` when ``raise_if_not_cached`` is set and
    a cacheable GET has no usable cached response.
    """
|
|
|
|
|
|
class Requests(RequestsSession):
    """HTTP session that targets known services, signs and caches requests.

    On top of the usual ``requests`` keyword arguments, :meth:`request`
    accepts the following ones (they are removed from ``kwargs`` before the
    call is delegated to :class:`requests.Session`):

    ``remote_service`` (default None)
        A service dict (keys read: ``url``, ``orig``, ``secret``), or
        ``'auto'`` to use the entry of ``settings.KNOWN_SERVICES`` whose
        scheme and host match ``url``. When a service is used, ``url`` is
        resolved against the service URL, identification parameters are
        added and the URL is signed.
    ``cache_duration`` (default 15)
        Cache lifetime for successful GET responses; falsy disables caching.
    ``invalidate_cache`` (default False)
        Ignore any cached content and perform the request.
    ``user`` (default None)
        Either a dict of query parameters sent as-is, or a user object
        (``is_authenticated``, ``get_name_id()``, ``email``), or None.
    ``django_request`` (default None)
        Passed as ``extra['request']`` to the error log record.
    ``without_user`` (default False)
        For a missing/anonymous user, omit the empty ``NameID``/``email``
        parameters.
    ``federation_key`` (default 'auto')
        'nameid', 'email', or 'auto' (NameID when available, else email).
    ``raise_if_not_cached`` (default False)
        Raise :class:`NothingInCacheException` instead of performing a
        cacheable GET that cannot be served from the cache.
    ``log_errors`` (default True)
        Log non-2xx responses at error level.
    """

    def request(self, method, url, **kwargs):
        """Send the request, resolving/signing service URLs and using the cache.

        Returns a :class:`requests.Response`. A GET served from the cache is
        a synthesized response with status 200 whose body is the cached
        content.

        Raises NothingInCacheException when ``raise_if_not_cached`` is set
        and a cacheable GET has no usable cache entry.
        """
        remote_service = kwargs.pop('remote_service', None)
        cache_duration = kwargs.pop('cache_duration', 15)
        invalidate_cache = kwargs.pop('invalidate_cache', False)
        user = kwargs.pop('user', None)
        django_request = kwargs.pop('django_request', None)
        without_user = kwargs.pop('without_user', False)
        federation_key = kwargs.pop('federation_key', 'auto')  # 'auto', 'email', 'nameid'
        raise_if_not_cached = kwargs.pop('raise_if_not_cached', False)
        log_errors = kwargs.pop('log_errors', True)

        # don't use persistent cookies
        self.cookies.clear()

        if remote_service == 'auto':
            remote_service = self._find_known_service(url)
            if remote_service:
                # only keep the path (URI) in url, scheme and netloc are
                # taken from remote_service
                parsed = urlparse.urlparse(url)
                url = urlparse.urlunparse(
                    ('', '', parsed.path, parsed.params, parsed.query, parsed.fragment))
            else:
                logging.warning('service not found in settings.KNOWN_SERVICES for %s', url)

        if remote_service:
            url = self._build_service_url(url, remote_service, user, without_user, federation_key)

        if method == 'GET' and cache_duration:
            # the cache key is computed on the unsigned URL (signing happens below)
            cache_key = hashlib.md5(smart_bytes(url)).hexdigest()
            cache_content = cache.get(cache_key)
            # compare with None: an empty body is a legitimate cached response,
            # only None means "nothing in cache"
            if cache_content is not None and not invalidate_cache:
                response = Response()
                response.status_code = 200
                response.raw = BytesIO(smart_bytes(cache_content))
                return response
            elif raise_if_not_cached:
                raise NothingInCacheException()

        if remote_service:  # sign
            url = sign_url(url, remote_service.get('secret'))

        kwargs['timeout'] = kwargs.get('timeout') or settings.REQUESTS_TIMEOUT

        response = super(Requests, self).request(method, url, **kwargs)
        succeeded = response.status_code // 100 == 2
        if log_errors and not succeeded:
            extra = {}
            if django_request:
                extra['request'] = django_request
            logging.error('failed to %s %s (%s)', method, url, response.status_code, extra=extra)
        if method == 'GET' and cache_duration and succeeded:
            cache.set(cache_key, response.content, cache_duration)

        return response

    @staticmethod
    def _find_known_service(url):
        """Return the first KNOWN_SERVICES entry matching url's scheme and host, or None."""
        scheme, netloc = urlparse.urlparse(url)[:2]
        for services in settings.KNOWN_SERVICES.values():
            for service in services.values():
                remote_scheme, remote_netloc = urlparse.urlparse(service.get('url'))[:2]
                if remote_scheme == scheme and remote_netloc == netloc:
                    return service
        return None

    @staticmethod
    def _user_is_authenticated(user):
        """Return user.is_authenticated whether it is a method or a boolean property."""
        is_authenticated = user.is_authenticated
        if callable(is_authenticated):
            # older Django: a method (or a CallableBool in 1.10/1.11);
            # Django >= 2.0 exposes a plain bool, which is not callable
            is_authenticated = is_authenticated()
        return is_authenticated

    @classmethod
    def _identification_params(cls, user, without_user, federation_key):
        """Return a new dict of query parameters identifying user to the service."""
        if isinstance(user, dict):
            # copy so that the caller's dict is not modified by later additions
            return user.copy()
        if not user or not cls._user_is_authenticated(user):
            return {} if without_user else {'NameID': '', 'email': ''}
        if federation_key == 'nameid':
            return {'NameID': user.get_name_id()}
        if federation_key == 'email':
            return {'email': user.email}
        # 'auto' (and any other value): prefer the NameID, fall back to email
        user_name_id = user.get_name_id()
        if user_name_id:
            return {'NameID': user_name_id}
        return {'email': user.email}

    @classmethod
    def _build_service_url(cls, url, remote_service, user, without_user, federation_key):
        """Return the absolute, unsigned URL of url (a path with optional query) on remote_service.

        Identification parameters and ``orig`` are placed first in the
        query string, followed by url's own query string, if any.
        """
        query_params = cls._identification_params(user, without_user, federation_key)
        query_params['orig'] = remote_service.get('orig')

        scheme, netloc, _, params, _, fragment = urlparse.urlparse(remote_service.get('url'))

        query = urlencode(query_params)
        # split on the first '?' only: any later '?' belongs to the query string
        path, separator, extra_query = url.partition('?')
        if separator:
            query += '&' + extra_query

        return urlparse.urlunparse((scheme, netloc, path, params, query, fragment))
|
|
|
|
# Shared module-level instance of Requests; callers presumably import and use
# this object rather than creating their own session — verify against usages.
requests = Requests()
|