passerelle/passerelle/utils/__init__.py

389 lines
15 KiB
Python

# Copyright (C) 2019 Entr'ouvert
#
# This program is free software: you can redistribute it and/or modify it
# under the terms of the GNU Affero General Public License as published
# by the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
from __future__ import absolute_import
import base64
from functools import wraps
import hashlib
import re
from itertools import islice, chain
import warnings
from requests import Session as RequestSession, Response as RequestResponse
from requests.adapters import HTTPAdapter
from requests.structures import CaseInsensitiveDict
from urllib3.exceptions import InsecureRequestWarning
from django.conf import settings
from django.core.cache import cache
from django.core.exceptions import PermissionDenied
from django.http import HttpResponse, HttpResponseBadRequest
from django.template import Template, Context
from django.utils.decorators import available_attrs
from django.utils.encoding import force_bytes, force_text
from django.utils.six import BytesIO
from django.views.generic.detail import SingleObjectMixin
from django.contrib.contenttypes.models import ContentType
from django.db import transaction
from passerelle.base.signature import check_query, check_url
def response_for_json(request, data):
    """Serialize *data* to JSON and return it in an HttpResponse.

    When the query string contains a ``jsonpCallback`` parameter (checked
    first) or a ``callback`` parameter, the JSON is wrapped in a call to that
    function and served as ``application/javascript`` (JSONP). Callback names
    that are not plain JavaScript identifiers get a 400 response instead.
    """
    import json

    body = json.dumps(data)
    content_type = 'application/json'
    callback_param = next(
        (name for name in ('jsonpCallback', 'callback') if name in request.GET),
        None,
    )
    if callback_param is not None:
        callback = request.GET[callback_param]
        # only identifier-like names are accepted, so no arbitrary JavaScript
        # can be injected in front of the payload
        if not re.match(r'^[$A-Za-z_][0-9A-Za-z_$]*$', callback):
            return HttpResponseBadRequest('invalid JSONP callback name')
        body = '%s(%s);' % (callback, body)
        content_type = 'application/javascript'
    response = HttpResponse(content_type=content_type)
    response.write(body)
    return response
def get_request_users(request):
    """Return the list of ApiUser objects the *request* authenticates as.

    Candidates are collected as follows:

    * users with an empty ``keytype`` are always candidates;
    * if the query string has both ``orig`` and ``signature``: the ``SIGN``
      users named ``orig`` whose key validates the query string with
      check_query();
    * else, if it has ``apikey``: the ``API`` users with that key;
    * else, for an ``Authorization: Basic`` header: the ``SIGN`` users whose
      username and key match the decoded credentials.

    A candidate with an ``ipsource`` is kept only when it equals the
    request's ``REMOTE_ADDR``.
    """
    from passerelle.base.models import ApiUser

    query = request.GET
    meta = request.META
    candidates = list(ApiUser.objects.filter(keytype=''))

    if 'orig' in query and 'signature' in query:
        signed_query = meta['QUERY_STRING']
        candidates.extend(
            user
            for user in ApiUser.objects.filter(keytype='SIGN', username=query['orig'])
            if check_query(signed_query, user.key)
        )
    elif 'apikey' in query:
        candidates.extend(ApiUser.objects.filter(keytype='API', key=query['apikey']))
    elif 'HTTP_AUTHORIZATION' in meta:
        scheme, separator, credentials = meta['HTTP_AUTHORIZATION'].partition(' ')
        if separator and scheme.lower() == 'basic':
            try:
                decoded = force_text(base64.b64decode(force_bytes(credentials.strip())))
                username, password = decoded.split(':', 1)
            except (TypeError, ValueError):
                # malformed Basic credentials are ignored
                pass
            else:
                candidates.extend(
                    ApiUser.objects.filter(keytype='SIGN', username=username, key=password)
                )

    remote_addr = meta.get('REMOTE_ADDR')
    return [
        user
        for user in candidates
        # users without an ipsource restriction accept any client address
        if not user.ipsource or user.ipsource == remote_addr
    ]
def get_trusted_services():
    """Return the trusted services declared in ``settings.KNOWN_SERVICES``.

    ``KNOWN_SERVICES`` maps a service type to a mapping of slug -> service
    settings. A service is trusted when it has both a ``secret`` and a
    ``verif_orig``. Each returned entry is a copy of the service settings
    with its ``service_type`` and ``slug`` added. A missing
    ``KNOWN_SERVICES`` setting yields an empty list.
    """
    known_services = getattr(settings, 'KNOWN_SERVICES', {})
    trusted = []
    for service_type, services in known_services.items():
        for slug, service in services.items():
            if not (service.get('secret') and service.get('verif_orig')):
                continue
            entry = service.copy()
            entry.update(service_type=service_type, slug=slug)
            trusted.append(entry)
    return trusted
def is_trusted(request):
    """Tell whether the request's query string is signed by a trusted service.

    Both ``orig`` and ``signature`` must be present and non-empty. The request
    is trusted when a service from get_trusted_services() has a
    ``verif_orig`` equal to ``orig`` and a ``secret`` that validates the full
    request path with check_url().
    """
    params = request.GET
    if not (params.get('orig') and params.get('signature')):
        return False
    full_path = request.get_full_path()
    orig = params['orig']
    return any(
        service.get('verif_orig') == orig
        and service.get('secret')
        and check_url(full_path, service['secret'])
        for service in get_trusted_services()
    )
def is_authorized(request, obj, perm):
    """Check whether *request* is allowed to use permission *perm* on *obj*.

    Superusers and requests signed by a trusted service (see is_trusted())
    are always authorized. Otherwise the request is authorized when one of
    the API users it authenticates as (see get_request_users()) holds an
    AccessRight with codename *perm* on *obj*.

    Returns True for superusers and trusted requests. Otherwise it returns
    the (possibly empty) set of matching API users; callers rely on its
    truthiness.
    """
    from passerelle.base.models import AccessRight
    if request.user.is_superuser:
        return True
    if is_trusted(request):
        return True
    resource_type = ContentType.objects.get_for_model(obj)
    # select_related() loads each right's API user in the same query instead
    # of issuing one extra query per AccessRight row when reading .apiuser
    rights = AccessRight.objects.filter(
        resource_type=resource_type, resource_pk=obj.id, codename=perm
    ).select_related('apiuser')
    allowed_users = {right.apiuser for right in rights}
    if not allowed_users:
        # nobody holds this right: skip authenticating the request users
        return allowed_users
    return allowed_users.intersection(get_request_users(request))
def protected_api(perm):
    """Build a decorator that guards a class-based view method with *perm*.

    The decorated method must belong to a SingleObjectMixin view (otherwise a
    plain Exception is raised). The object returned by get_object() is
    checked with is_authorized(), and PermissionDenied is raised when the
    request lacks *perm* on it.
    """
    def decorator(view_func):
        @wraps(view_func, assigned=available_attrs(view_func))
        def guarded_view(instance, request, *args, **kwargs):
            if not isinstance(instance, SingleObjectMixin):
                raise Exception("protected_api must be applied on a method of a class based view")
            if is_authorized(request, instance.get_object(), perm):
                return view_func(instance, request, *args, **kwargs)
            raise PermissionDenied()
        return guarded_view
    return decorator
def content_type_match(ctype):
    """Tell whether *ctype* matches one of ``settings.LOGGED_CONTENT_TYPES_MESSAGES``.

    Each setting entry is a regular expression applied with re.match (that
    is, anchored at the start of *ctype*). An empty or missing *ctype* never
    matches.
    """
    patterns = settings.LOGGED_CONTENT_TYPES_MESSAGES
    if not ctype:
        return False
    return any(re.match(pattern, ctype) for pattern in patterns)
def log_http_request(logger, request, response=None, exception=None, error_log=True, extra=None):
    """Log an outgoing HTTP request together with its response or exception.

    :param logger: logger to write to. When it has a ``connector`` attribute,
        payload size limits come from ``logger.connector.logging_parameters``;
        otherwise they come from ``settings.LOGGED_REQUESTS_MAX_SIZE`` and
        ``settings.LOGGED_RESPONSES_MAX_SIZE``.
    :param request: the sent request (its method, url, headers and body are
        read), or None
    :param response: the received response, if any
    :param exception: the exception raised while sending; only used when
        there is no response
    :param error_log: when false, always log at INFO level
    :param extra: additional ``extra`` data for the log record; the caller's
        dictionary is left unmodified

    Headers and truncated payloads are recorded only when the logger's own
    level is DEBUG. Response content is recorded only when its Content-Type
    passes content_type_match(). 3xx responses are logged as warnings,
    statuses of 400 and above and exceptions as errors, and everything else
    as info.
    """
    import logging

    log_function = logger.info
    message = ''
    # work on a copy: the caller's dict must not collect our keys
    extra = dict(extra or {})
    # the logger's own level, not its effective level
    is_debug = logger.level == logging.DEBUG

    if request is not None:
        message = '%s %s' % (request.method, request.url)
        extra['request_url'] = request.url
        if is_debug:
            extra['request_headers'] = dict(request.headers.items())
            body = request.body
            if body:
                if hasattr(logger, 'connector'):
                    max_size = logger.connector.logging_parameters.requests_max_size
                else:
                    max_size = settings.LOGGED_REQUESTS_MAX_SIZE
                if isinstance(body, (bytes, bytearray, str)):
                    extra['request_payload'] = repr(body[:max_size])
                else:
                    # streamed body (file object, generator...): it cannot be
                    # sliced, and reading it would consume the data being sent
                    extra['request_payload'] = '<%s>' % type(body).__name__

    if response is not None:
        message = message + ' (=> %s)' % response.status_code
        extra['response_status'] = response.status_code
        if is_debug:
            extra['response_headers'] = dict(response.headers.items())
            # log body only if content type is allowed
            if content_type_match(response.headers.get('Content-Type')):
                if hasattr(logger, 'connector'):
                    max_size = logger.connector.logging_parameters.responses_max_size
                else:
                    max_size = settings.LOGGED_RESPONSES_MAX_SIZE
                extra['response_content'] = repr(response.content[:max_size])
        if response.status_code // 100 == 3:
            log_function = logger.warning
        elif response.status_code // 100 >= 4:
            log_function = logger.error
    elif exception:
        if message:
            message = message + ' (=> %s)' % repr(exception)
        else:
            message = repr(exception)
        extra['response_exception'] = repr(exception)
        log_function = logger.error

    # allow resources to disable any error log at requests level
    if not error_log:
        log_function = logger.info
    log_function(message, extra=extra)
class Request(RequestSession):
    """requests.Session subclass used by connectors to call remote services.

    Every request sent through it is logged with log_http_request(). Error
    levels apply unless the resource's ``log_requests_errors`` is false.

    When a ``resource`` is given, keyword arguments not set by the caller are
    filled from it:

    - ``auth``: (``basic_auth_username``, ``basic_auth_password``) when the
      username is set and the password attribute exists;
    - ``cert``: path of ``client_certificate`` when set;
    - ``verify``: path of ``trusted_certificate_authorities`` when set,
      otherwise the value of ``verify_cert`` when that attribute exists (a
      False value disables server certificate verification);
    - ``proxies``: ``http_proxy`` for both HTTP and HTTPS when set.

    After that, ``settings.REQUESTS_PROXIES`` (when non-empty) and
    ``settings.REQUESTS_TIMEOUT`` are the defaults for ``proxies`` and
    ``timeout``.

    Two extra keyword arguments control caching of GET requests in the Django
    cache:

    - ``cache_duration``: the cache timeout; caching is off when falsy;
    - ``invalidate_cache``: ignore a cached entry, refreshing it when the new
      response succeeds.

    Only 2xx responses are cached.
    """

    # shared by all instances: one HTTPAdapter (hence one connection pool)
    # per resource class
    ADAPTER_REGISTRY = {}

    def __init__(self, *args, **kwargs):
        self.logger = kwargs.pop('logger')
        self.resource = kwargs.pop('resource', None)
        super(Request, self).__init__(*args, **kwargs)
        if self.resource:
            resource_class = type(self.resource)
            adapter = Request.ADAPTER_REGISTRY.get(resource_class)
            if adapter is None:
                # only build an adapter when this class has none yet;
                # setdefault(key, HTTPAdapter()) instantiated (and discarded)
                # a new adapter and pool manager on every session creation
                adapter = Request.ADAPTER_REGISTRY.setdefault(resource_class, HTTPAdapter())
            self.mount('https://', adapter)
            self.mount('http://', adapter)

    def request(self, method, url, **kwargs):
        """Send a request, applying resource defaults and the optional GET cache."""
        from requests.utils import get_encoding_from_headers

        cache_duration = kwargs.pop('cache_duration', None)
        invalidate_cache = kwargs.pop('invalidate_cache', False)

        if self.resource:
            if 'auth' not in kwargs:
                username = getattr(self.resource, 'basic_auth_username', None)
                if username and hasattr(self.resource, 'basic_auth_password'):
                    kwargs['auth'] = (username, self.resource.basic_auth_password)
            if 'cert' not in kwargs:
                keystore = getattr(self.resource, 'client_certificate', None)
                if keystore:
                    kwargs['cert'] = keystore.path
            if 'verify' not in kwargs:
                trusted_certificate_authorities = getattr(
                    self.resource, 'trusted_certificate_authorities', None)
                if trusted_certificate_authorities:
                    kwargs['verify'] = trusted_certificate_authorities.path
                elif hasattr(self.resource, 'verify_cert'):
                    kwargs['verify'] = self.resource.verify_cert
            if 'proxies' not in kwargs:
                proxy = getattr(self.resource, 'http_proxy', None)
                if proxy:
                    kwargs['proxies'] = {'http': proxy, 'https': proxy}

        # requests upper-cases the method itself; do the same here so that
        # request('get', ...) is cached like request('GET', ...)
        use_cache = bool(cache_duration) and method.upper() == 'GET'
        if use_cache:
            cache_key = hashlib.md5(force_bytes('%r;%r' % (url, kwargs))).hexdigest()
            cache_content = cache.get(cache_key)
            if cache_content and not invalidate_cache:
                response = RequestResponse()
                response.raw = BytesIO(cache_content.get('content'))
                response.headers = CaseInsensitiveDict(cache_content.get('headers', {}))
                response.status_code = cache_content.get('status_code')
                # requests sets this on live responses; without it .text would
                # guess the charset instead of using the one in Content-Type
                response.encoding = get_encoding_from_headers(response.headers)
                return response

        if settings.REQUESTS_PROXIES and 'proxies' not in kwargs:
            kwargs['proxies'] = settings.REQUESTS_PROXIES
        if 'timeout' not in kwargs:
            kwargs['timeout'] = settings.REQUESTS_TIMEOUT

        with warnings.catch_warnings():
            if kwargs.get('verify') is False:
                # verification was explicitly disabled: silence urllib3's warning
                warnings.simplefilter(action='ignore', category=InsecureRequestWarning)
            response = super(Request, self).request(method, url, **kwargs)

        if use_cache and response.status_code // 100 == 2:
            cache.set(cache_key, {
                'content': response.content,
                'headers': response.headers,
                'status_code': response.status_code,
            }, cache_duration)
        return response

    def send(self, request, **kwargs):
        """Send a prepared request and log it with its response or exception."""
        try:
            response = super(Request, self).send(request, **kwargs)
        except Exception as exc:
            self.log_http_request(request, exception=exc)
            raise
        self.log_http_request(request, response=response)
        return response

    def log_http_request(self, request, response=None, exception=None):
        """Log through the module-level log_http_request(), honouring the resource's
        ``log_requests_errors`` flag (default True)."""
        error_log = getattr(self.resource, 'log_requests_errors', True)
        log_http_request(self.logger, request=request, response=response, exception=exception, error_log=error_log)
def export_site(slugs=None):
    """Dump the passerelle configuration to a JSON-serializable dictionary.

    The result has two keys:

    - ``apiusers``: every ApiUser, exported with export_json();
    - ``resources``: the exported instances of each non-abstract direct
      subclass of BaseResource. When *slugs* is given, only resources whose
      slug is in it are exported.

    When a resource's export_json() raises NotImplementedError, the remaining
    instances of that resource class are skipped.
    """
    from passerelle.base.models import ApiUser
    from passerelle.base.models import BaseResource

    exported_resources = []
    site = {
        'apiusers': [apiuser.export_json() for apiuser in ApiUser.objects.all()],
        'resources': exported_resources,
    }
    concrete_classes = [cls for cls in BaseResource.__subclasses__() if not cls._meta.abstract]
    for resource_class in concrete_classes:
        for resource in resource_class.objects.all():
            if slugs and resource.slug not in slugs:
                continue
            try:
                exported_resources.append(resource.export_json())
            except NotImplementedError:
                # export not supported by this resource class
                break
    return site
def import_site(d, if_empty=False, clean=False, overwrite=False, import_users=False):
    """Load the passerelle configuration (users, resources and ACLs) from a dictionary
    loaded from JSON, as produced by export_site().

    :param d: dictionary with optional ``apiusers`` and ``resources`` lists
    :param if_empty: only import when no resource exists yet (and, with
        *import_users*, no API user either); otherwise return without doing
        anything
    :param clean: delete existing resources (and API users when
        *import_users* is set) before importing
    :param overwrite: passed to the ``import_json()`` class methods
    :param import_users: also import (and clean) API users

    Cleaning and importing run in one database transaction, so a failed
    import also rolls back the deletions instead of leaving an emptied site.
    """
    from passerelle.base.models import ApiUser
    from passerelle.base.models import BaseResource
    d = d.copy()

    def concrete_resource_classes():
        # non-abstract direct subclasses of BaseResource
        return [cls for cls in BaseResource.__subclasses__() if not cls._meta.abstract]

    def is_empty():
        if import_users and ApiUser.objects.count():
            return False
        for subclass in concrete_resource_classes():
            if subclass.objects.count():
                return False
        return True

    if if_empty and not is_empty():
        return

    with transaction.atomic():
        if clean:
            for subclass in concrete_resource_classes():
                subclass.objects.all().delete()
            if import_users:
                ApiUser.objects.all().delete()
        if import_users:
            for apiuser in d.get('apiusers', []):
                ApiUser.import_json(apiuser, overwrite=overwrite)
        for resource in d.get('resources', []):
            BaseResource.import_json(resource, overwrite=overwrite, import_users=import_users)
def batch(iterable, size):
    """Split *iterable* into consecutive chunks of at most *size* elements.

    Each chunk is a lazy iterator drawing from one shared source iterator.
    A chunk must be fully consumed before the next one is requested;
    otherwise the next chunk starts right where the previous one stopped.
    """
    source = iter(iterable)
    while True:
        chunk = islice(source, size)
        # Pull the first element eagerly. This detects exhaustion, and it
        # guarantees progress even if the caller never consumes the chunk
        # (the source would otherwise never be exhausted).
        for first in chunk:
            break
        else:
            return
        yield chain((first,), chunk)
# legacy import, other modules keep importing to_json from passerelle.utils
from .jsonresponse import to_json
from .soap import SOAPClient, SOAPTransport
from .sftp import SFTPField, SFTP