536 lines
21 KiB
Python
536 lines
21 KiB
Python
# Copyright (C) 2019 Entr'ouvert
|
|
#
|
|
# This program is free software: you can redistribute it and/or modify it
|
|
# under the terms of the GNU Affero General Public License as published
|
|
# by the Free Software Foundation, either version 3 of the License, or
|
|
# (at your option) any later version.
|
|
#
|
|
# This program is distributed in the hope that it will be useful,
|
|
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
|
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
|
# GNU Affero General Public License for more details.
|
|
#
|
|
# You should have received a copy of the GNU Affero General Public License
|
|
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
|
|
|
|
|
import base64
|
|
import hashlib
|
|
import re
|
|
import time
|
|
import urllib.parse
|
|
import warnings
|
|
from functools import wraps
|
|
from io import BytesIO
|
|
from itertools import chain, islice
|
|
|
|
from django.conf import settings
|
|
from django.contrib.contenttypes.models import ContentType
|
|
from django.core.cache import cache
|
|
from django.core.exceptions import PermissionDenied
|
|
from django.db import transaction
|
|
from django.http import HttpResponse, HttpResponseBadRequest
|
|
from django.utils.encoding import force_bytes, force_str
|
|
from django.utils.functional import lazy
|
|
from django.utils.html import mark_safe
|
|
from django.utils.translation import ngettext_lazy
|
|
from django.views.generic.detail import SingleObjectMixin
|
|
from requests import Response as RequestResponse
|
|
from requests import Session as RequestSession
|
|
from requests.adapters import HTTPAdapter
|
|
from requests.structures import CaseInsensitiveDict
|
|
from urllib3.exceptions import InsecureRequestWarning
|
|
from urllib3.util.retry import Retry
|
|
|
|
from passerelle.base.signature import check_query, check_url
|
|
|
|
# legacy import, other modules keep importing to_json from passerelle.utils
|
|
from .jsonresponse import to_json # noqa F401 pylint: disable=unused-import
|
|
from .sftp import SFTP, SFTPField # noqa F401 pylint: disable=unused-import
|
|
from .soap import SOAPClient, SOAPTransport # noqa F401 pylint: disable=unused-import
|
|
|
|
mark_safe_lazy = lazy(mark_safe, str)
|
|
|
|
|
|
class ImportSiteError(Exception):
    """Raised when a site configuration import cannot be completed."""
|
|
|
|
|
|
def response_for_json(request, data):
    """Serialize *data* to a JSON HttpResponse.

    If the querystring carries a ``jsonpCallback`` or ``callback`` parameter,
    the payload is wrapped in a JSONP call instead and served as JavaScript;
    an invalid callback identifier yields a 400 response.
    """
    import json

    response = HttpResponse(content_type='application/json')
    payload = json.dumps(data)
    for param in ('jsonpCallback', 'callback'):
        if param not in request.GET:
            continue
        callback = request.GET[param]
        # only accept well-formed JavaScript identifiers, to avoid injection
        if not re.match(r'^[$A-Za-z_][0-9A-Za-z_$]*$', callback):
            return HttpResponseBadRequest('invalid JSONP callback name')
        payload = '%s(%s);' % (callback, payload)
        response['Content-Type'] = 'application/javascript'
        break
    response.write(payload)
    return response
|
|
|
|
|
|
def get_request_users(request):
    """Return the list of ApiUser objects matching the credentials carried by *request*.

    Three mutually exclusive authentication schemes are tried, in order:
    - ``?orig=...&signature=...``: HMAC signature of the whole query string,
      verified with ``check_query`` against each SIGN user of that origin;
    - ``?apikey=...``: plain API key, matched against API users;
    - HTTP Basic ``Authorization`` header, decoded and matched against SIGN
      users (username/key pair).

    Users with an empty keytype (no authentication required) are always
    candidates.  The final list is filtered on each user's declared source IP.
    """
    from passerelle.base.models import ApiUser

    users = []

    # users with no authentication requirement are always candidates
    users.extend(ApiUser.objects.filter(keytype=''))

    if 'orig' in request.GET and 'signature' in request.GET:
        # signature authentication: verify the query-string HMAC against the
        # key of every SIGN user declared for this origin
        orig = request.GET['orig']
        query = request.META['QUERY_STRING']
        signature_users = ApiUser.objects.filter(keytype='SIGN', username=orig)
        for signature_user in signature_users:
            if check_query(query, signature_user.key):
                users.append(signature_user)

    elif 'apikey' in request.GET:
        # API-key authentication
        users.extend(ApiUser.objects.filter(keytype='API', key=request.GET['apikey']))

    elif 'HTTP_AUTHORIZATION' in request.META:
        # HTTP Basic authentication, credentials matched against SIGN users
        http_authorization = request.headers['Authorization'].split(' ', 1)
        scheme = http_authorization[0].lower()
        if scheme == 'basic' and len(http_authorization) > 1:
            param = http_authorization[1]
            try:
                decoded = force_str(base64.b64decode(force_bytes(param.strip())))
                username, password = decoded.split(':', 1)
            except (TypeError, ValueError):
                # malformed base64 or missing ':' separator: ignore the header
                pass
            else:
                users.extend(ApiUser.objects.filter(keytype='SIGN', username=username, key=password))

    def ip_match(ip, match):
        # an empty ipsource on the user means "no IP restriction"
        if not ip:
            return True
        if ip == match:
            return True
        return False

    # keep only users whose declared source IP matches the client address
    users = [x for x in users if ip_match(x.ipsource, request.META.get('REMOTE_ADDR'))]
    return users
|
|
|
|
|
|
def get_trusted_services():
    """
    All services in settings.KNOWN_SERVICES are "trusted"
    """
    trusted = []
    for service_type in getattr(settings, 'KNOWN_SERVICES', {}):
        for slug, service in settings.KNOWN_SERVICES[service_type].items():
            # only services with both a shared secret and a declared origin
            # can be trusted
            if not (service.get('secret') and service.get('verif_orig')):
                continue
            entry = dict(service, service_type=service_type, slug=slug)
            trusted.append(entry)
    return trusted
|
|
|
|
|
|
def is_trusted(request):
    """
    True if query-string is signed by a trusted service (see get_trusted_services() above)
    """
    orig = request.GET.get('orig')
    if not orig or not request.GET.get('signature'):
        return False
    full_path = request.get_full_path()
    for service in get_trusted_services():
        if service.get('verif_orig') != orig:
            continue
        secret = service.get('secret')
        if secret and check_url(full_path, secret):
            return True
    return False
|
|
|
|
|
|
def is_authorized(request, obj, perm):
    """Check that *request* is allowed to exercise permission *perm* on resource *obj*.

    Superusers and trusted services are always authorized; otherwise the
    requesting API users must intersect the users granted an AccessRight
    with codename *perm* on the resource.  Returns a truthy set of matching
    users (or True), falsy otherwise.
    """
    from passerelle.base.models import AccessRight

    if request.user.is_superuser or is_trusted(request):
        return True
    resource_type = ContentType.objects.get_for_model(obj)
    rights = AccessRight.objects.filter(resource_type=resource_type, resource_pk=obj.id, codename=perm)
    granted_users = {right.apiuser for right in rights}
    return granted_users.intersection(get_request_users(request))
|
|
|
|
|
|
def protected_api(perm):
    """Decorator for class-based-view methods enforcing access rights.

    The decorated method must belong to a view using SingleObjectMixin; the
    view's object is fetched and the request checked with is_authorized(),
    raising PermissionDenied on failure.
    """

    def decorator(view_func):
        @wraps(view_func)
        def wrapped(instance, request, *args, **kwargs):
            if not isinstance(instance, SingleObjectMixin):
                raise Exception("protected_api must be applied on a method of a class based view")
            if not is_authorized(request, instance.get_object(), perm):
                raise PermissionDenied()
            return view_func(instance, request, *args, **kwargs)

        return wrapped

    return decorator
|
|
|
|
|
|
def should_content_type_body_be_logged(ctype):
    """Return True when a body of content-type *ctype* may be written to the logs.

    Allowed types are the regexp patterns listed in
    settings.LOGGED_CONTENT_TYPES_MESSAGES; an empty/missing content-type is
    never logged.
    """
    patterns = settings.LOGGED_CONTENT_TYPES_MESSAGES
    if not ctype:
        return False
    return any(re.match(pattern, ctype) for pattern in patterns)
|
|
|
|
|
|
def make_headers_safe(headers):
    """Convert dict of HTTP headers to text safely, as some services returns 8-bits encoding in headers."""
    safe = {}
    for name, value in headers.items():
        # replace undecodable bytes instead of raising
        safe[force_str(name, errors='replace')] = force_str(value, errors='replace')
    return safe
|
|
|
|
|
|
def log_http_request(
    logger, request, response=None, exception=None, error_log=True, extra=None, duration=None
):
    """Log one outgoing HTTP exchange (request and its response or exception).

    The log level depends on the outcome: info by default, warning on 3xx
    responses, error on 4xx/5xx or on an exception.  At DEBUG level, headers
    and (size-capped) payloads are attached to the record's ``extra`` dict.
    When *error_log* is False everything is logged at info level, letting
    resources silence request errors.
    """
    log_function = logger.info
    message = ''
    extra = extra or {}
    kwargs = {}

    if request is not None:
        message = '%s %s' % (request.method, request.url)
        extra['request_url'] = request.url
        # 10 is logging.DEBUG; only then are headers/payloads recorded
        if logger.level == 10 and request:  # DEBUG
            extra['request_headers'] = make_headers_safe(request.headers)
            if request.body:
                max_size = settings.LOGGED_REQUESTS_MAX_SIZE
                # connector-level logging parameters may override the global cap
                if hasattr(logger, 'connector'):
                    max_size = logger.connector.logging_parameters.requests_max_size or max_size
                extra['request_payload'] = request.body[:max_size]
    if duration is not None:
        extra['request_duration'] = duration
    if response is not None:
        message = message + ' (=> %s)' % response.status_code
        extra['response_status'] = response.status_code
        if logger.level == 10:  # DEBUG
            extra['response_headers'] = make_headers_safe(response.headers)
            # log body only if content type is allowed
            content_type = response.headers.get('Content-Type', '').split(';')[0].strip().lower()
            if should_content_type_body_be_logged(content_type):
                max_size = settings.LOGGED_RESPONSES_MAX_SIZE
                if hasattr(logger, 'connector'):
                    max_size = logger.connector.logging_parameters.responses_max_size or max_size
                content = response.content[:max_size]
                extra['response_content'] = content
        # escalate level based on the status class
        if response.status_code // 100 == 3:
            log_function = logger.warning
        elif response.status_code // 100 >= 4:
            log_function = logger.error
    elif exception:
        if message:
            message = message + ' (=> %s)' % repr(exception)
        else:
            message = repr(exception)
        extra['response_exception'] = repr(exception)
        log_function = logger.error
        kwargs['exc_info'] = exception

    # allow resources to disable any error log at requests level
    if not error_log:
        log_function = logger.info
    log_function(message, extra=extra, **kwargs)
|
|
|
|
|
|
# Wrapper around requests.Session
|
|
# - log input and output data
|
|
# - use HTTP Basic auth if resource.basic_auth_username and resource.basic_auth_password exist
|
|
# - use client side certificate if resource.client_certificate (FileField) exists
|
|
# - verify server certificate CA if resource.trusted_certificate_authorities (FileField) exists
|
|
# - disable CA verification if resource.verify_cert (BooleanField) exists and is set
|
|
# - use a proxy for HTTP and HTTPS if resource.http_proxy exists
|
|
|
|
|
|
class Request(RequestSession):
    """requests.Session subclass wired to a passerelle resource.

    Adds logging of requests/responses, per-resource authentication,
    certificates, proxies, retries, timeout, optional GET caching and
    configurable content substitutions on responses (see the comment block
    above for the resource attributes honoured).
    """

    # shared per-resource-class HTTPAdapter instances, for connection pooling
    ADAPTER_REGISTRY = {}  # connection pooling
    log_requests_errors = True

    def __init__(self, *args, **kwargs):
        # 'logger' is mandatory, 'resource' optional; both are popped so the
        # remaining kwargs can be forwarded to requests.Session
        self.logger = kwargs.pop('logger')
        self.resource = kwargs.pop('resource', None)
        resource_log_requests_errors = getattr(self.resource, 'log_requests_errors', True)
        self.log_requests_errors = kwargs.pop('log_requests_errors', resource_log_requests_errors)
        timeout = kwargs.pop('timeout', None)
        super().__init__(*args, **kwargs)
        if self.resource:
            # explicit timeout wins over the resource's requests_timeout
            timeout = timeout if timeout is not None else getattr(self.resource, 'requests_timeout', None)
            http_adapter_init_kwargs = {}
            # retry policy: resource setting overrides the global default
            requests_max_retries = dict(settings.REQUESTS_MAX_RETRIES)
            if getattr(self.resource, 'requests_max_retries', None):
                requests_max_retries = dict(self.resource.requests_max_retries)
            if requests_max_retries:
                requests_max_retries.setdefault('read', None)
                http_adapter_init_kwargs['max_retries'] = Retry(**requests_max_retries)
            # one adapter per resource class, created lazily and shared
            adapter = Request.ADAPTER_REGISTRY.setdefault(
                type(self.resource), HTTPAdapter(**http_adapter_init_kwargs)
            )
            self.mount('https://', adapter)
            self.mount('http://', adapter)
        self.timeout = timeout if timeout is not None else settings.REQUESTS_TIMEOUT

    def _substitute(self, search, replace, value):
        """Recursively apply the regexp substitution to str/list/dict values."""
        if isinstance(value, str):
            value, nsub = re.subn(search, replace, value)
            if nsub:
                self.logger.debug('substitution: %d occurences', nsub)
        elif isinstance(value, list):
            value = [self._substitute(search, replace, v) for v in value]
        elif isinstance(value, dict):
            # both keys and values are substituted
            value = {
                self._substitute(search, replace, k): self._substitute(search, replace, v)
                for k, v in value.items()
            }
        return value

    def apply_requests_substitution(self, response, substitution):
        """Rewrite *response* content according to one substitution dict.

        Returns True when the substitution was applied, None when skipped
        (invalid definition, url or content-type filter not matching) or on
        failure (logged).
        """
        if not isinstance(substitution, dict):
            self.logger.warning('substitution: invalid substitution, %r', substitution)
            return
        for key in ['search', 'replace']:
            if key not in substitution:
                self.logger.warning('substitution: missing field "%s": %s', key, substitution)
                return
            if not isinstance(substitution[key], str):
                self.logger.warning(
                    'substitution: invalid type for field "%s", must be str: %s', key, substitution
                )
                return
        search = substitution['search']
        replace = substitution['replace']

        # filter on url
        if isinstance(substitution.get('url'), str):
            url = urllib.parse.urlparse(substitution['url'])
            request_url = urllib.parse.urlparse(response.request.url)
            if url.scheme and url.scheme != request_url.scheme:
                return
            # substitution without a netloc are ignored
            if not url.netloc:
                return
            if request_url.netloc != url.netloc:
                return
            if url.path and url.path != '/' and not request_url.path.startswith(url.path):
                return

        # filter on content-type
        content_type = response.headers.get('Content-Type', '').split(';')[0].strip().lower()
        for content_type_re in settings.REQUESTS_SUBSTITUTIONS_CONTENT_TYPES:
            if re.match(content_type_re, content_type):
                break
        else:
            self.logger.debug('substitution: content_type did not match %s', content_type)
            return

        self.logger.debug('substitution: try %s', substitution)
        try:
            # NOTE(review): '[^;]\+' matches a single character followed by a
            # literal '+', so vendored types like application/hal+json fall
            # through to the text branch; pattern was possibly intended as
            # r'application/([\w.+-]+\+)?json' — confirm before changing.
            if re.match(r'application/([^;]\+)?json', content_type):
                import json

                response._content = json.dumps(self._substitute(search, replace, response.json())).encode()
            else:
                response._content = self._substitute(search, replace, response.text).encode()
            response.encoding = 'utf-8'
            return True
        except Exception:
            self.logger.exception('substitution: "%s" failed', substitution)

    def request(self, method, url, **kwargs):
        """Perform the HTTP request, layering passerelle behaviours on top.

        Extra kwargs: cache_duration (seconds, GET only) enables response
        caching; invalidate_cache forces a refetch.  Resource-level auth,
        client certificate, CA bundle / verify flag and proxy are injected
        when the caller did not provide them.
        """
        cache_duration = kwargs.pop('cache_duration', None)
        invalidate_cache = kwargs.pop('invalidate_cache', False)

        # search in legacy urls
        legacy_urls_mapping = getattr(settings, 'LEGACY_URLS_MAPPING', None)
        if legacy_urls_mapping:
            splitted_url = urllib.parse.urlparse(url)
            hostname = splitted_url.netloc
            if hostname in legacy_urls_mapping:
                url = splitted_url._replace(netloc=legacy_urls_mapping[hostname]).geturl()

        if self.resource:
            # caller-provided kwargs always take precedence over resource settings
            if 'auth' not in kwargs:
                username = getattr(self.resource, 'basic_auth_username', None)
                if username and hasattr(self.resource, 'basic_auth_password'):
                    kwargs['auth'] = (username, self.resource.basic_auth_password)
            if 'cert' not in kwargs:
                keystore = getattr(self.resource, 'client_certificate', None)
                if keystore:
                    kwargs['cert'] = keystore.path
            if 'verify' not in kwargs:
                trusted_certificate_authorities = getattr(
                    self.resource, 'trusted_certificate_authorities', None
                )
                if trusted_certificate_authorities:
                    kwargs['verify'] = trusted_certificate_authorities.path
                elif hasattr(self.resource, 'verify_cert'):
                    kwargs['verify'] = self.resource.verify_cert
            if 'proxies' not in kwargs:
                proxy = getattr(self.resource, 'http_proxy', None)
                if proxy:
                    kwargs['proxies'] = {'http': proxy, 'https': proxy}

        if method == 'GET' and cache_duration:
            # cache key covers the url and every request kwarg
            cache_key = hashlib.md5(force_bytes('%r;%r' % (url, kwargs))).hexdigest()
            cache_content = cache.get(cache_key)
            if cache_content and not invalidate_cache:
                # rebuild a Response object from the cached parts
                response = RequestResponse()
                response.raw = BytesIO(cache_content.get('content'))
                response.headers = CaseInsensitiveDict(cache_content.get('headers', {}))
                response.status_code = cache_content.get('status_code')
                return response

        if settings.REQUESTS_PROXIES and 'proxies' not in kwargs:
            kwargs['proxies'] = settings.REQUESTS_PROXIES

        if 'timeout' not in kwargs:
            kwargs['timeout'] = self.timeout

        with warnings.catch_warnings():
            if kwargs.get('verify') is False:
                # disable urllib3 warnings
                warnings.simplefilter(action='ignore', category=InsecureRequestWarning)
            response = super().request(method, url, **kwargs)

        if self.resource:
            requests_substitutions = self.resource.get_setting('requests_substitutions')
            if isinstance(requests_substitutions, list):
                for requests_substitution in requests_substitutions:
                    if not self.apply_requests_substitution(response, requests_substitution):
                        self.logger.debug('substitution: %s does not match', requests_substitution)

        # only successful (2xx) GET responses are cached
        if method == 'GET' and cache_duration and (response.status_code // 100 == 2):
            cache.set(
                cache_key,
                {
                    'content': response.content,
                    'headers': response.headers,
                    'status_code': response.status_code,
                },
                cache_duration,
            )

        return response

    def send(self, request, **kwargs):
        """Send the prepared request, timing it and logging outcome or failure."""
        start_time = time.time()
        try:
            response = super().send(request, **kwargs)
            duration = time.time() - start_time
        except Exception as exc:
            duration = time.time() - start_time
            self.log_http_request(request, exception=exc, duration=duration)
            raise
        self.log_http_request(request, response=response, duration=duration)
        return response

    def log_http_request(self, request, response=None, exception=None, duration=None):
        """Delegate to the module-level log_http_request with this session's settings."""
        error_log = self.log_requests_errors
        log_http_request(
            self.logger,
            request=request,
            response=response,
            exception=exception,
            error_log=error_log,
            duration=duration,
        )
|
|
|
|
|
|
def export_site(slugs=None):
    '''Dump passerelle configuration (users, resources and ACLs) to JSON dumpable dictionnary'''
    from passerelle.base.models import ApiUser
    from passerelle.views import get_all_apps

    data = {
        'apiusers': [user.export_json() for user in ApiUser.objects.all()],
        'resources': [],
    }
    for app in get_all_apps():
        for resource in app.objects.all():
            if slugs and resource.slug not in slugs:
                continue
            try:
                data['resources'].append(resource.export_json())
            except NotImplementedError:
                # this connector class does not support export: skip the
                # remaining resources of the same app
                break
    return data
|
|
|
|
|
|
def import_site(d, if_empty=False, clean=False, overwrite=False, import_users=False):
    """Load passerelle configuration (users, resources and ACLs) from a dictionnary loaded from
    JSON

    Keyword arguments:
    - if_empty: do nothing unless the current site holds no resources
      (and, when import_users is set, no API users);
    - clean: delete all existing resources (and users if import_users) first;
    - overwrite: passed through to the model-level import_json methods;
    - import_users: also import the 'apiusers' entries.

    Raises ImportSiteError when the dump references unknown connector types.
    """
    from passerelle.base.models import ApiUser, BaseResource
    from passerelle.views import get_all_apps

    # shallow copy: the caller's dict must not be mutated
    d = d.copy()

    def is_empty():
        # the site counts as empty when no app has resources (and no API
        # users exist, when users are being imported)
        if import_users:
            if ApiUser.objects.count():
                return False

        for app in get_all_apps():
            if app.objects.count():
                return False
        return True

    if if_empty and not is_empty():
        return

    if clean:
        for app in get_all_apps():
            app.objects.all().delete()
        if import_users:
            ApiUser.objects.all().delete()

    # everything is imported atomically: either the whole dump applies or none of it
    with transaction.atomic():
        if import_users:
            for apiuser in d.get('apiusers', []):
                ApiUser.import_json(apiuser, overwrite=overwrite)

        unknown_connectors = []

        def import_resource(res):
            # collect unknown connector types instead of failing on the first one
            try:
                BaseResource.import_json(res, overwrite=overwrite, import_users=import_users)
            except BaseResource.UnknownBaseResourceError as e:
                unknown_connectors.append(str(e))

        resources = d.get('resources', [])
        # import SectorResource first, as AddressResource may need them
        for res in [r for r in resources if r['resource_type'] == 'sector.sectorresource']:
            import_resource(res)
        for res in [r for r in resources if r['resource_type'] != 'sector.sectorresource']:
            import_resource(res)
        if unknown_connectors:
            raise ImportSiteError(
                ngettext_lazy('Unknown connector: %s', 'Unknown connectors: %s', len(unknown_connectors))
                % ', '.join(unknown_connectors)
            )
|
|
|
|
|
|
def batch(iterable, size):
    """Batch an iterable as an iterable of iterables of at most size element
    long.

    Each yielded batch is itself lazy; if the caller skips a batch without
    consuming it, iteration resumes from the current position of the source.
    """
    source = iter(iterable)
    while True:
        chunk = islice(source, size)
        # pull one element eagerly: this both detects exhaustion of the
        # source and guarantees forward progress even when the caller never
        # consumes the yielded iterator
        try:
            head = next(chunk)
        except StopIteration:
            return
        yield chain([head], chunk)
|