238 lines
7.7 KiB
Python
238 lines
7.7 KiB
Python
# passerelle - uniform access to multiple data sources and services
|
|
# Copyright (C) 2019 Entr'ouvert
|
|
#
|
|
# This program is free software: you can redistribute it and/or modify it
|
|
# under the terms of the GNU Affero General Public License as published
|
|
# by the Free Software Foundation, either version 3 of the License, or
|
|
# (at your option) any later version.
|
|
#
|
|
# This program is distributed in the hope that it will be useful,
|
|
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
|
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
|
# GNU Affero General Public License for more details.
|
|
#
|
|
# You should have received a copy of the GNU Affero General Public License
|
|
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
|
|
|
from __future__ import absolute_import
|
|
|
|
import contextlib
|
|
import io
|
|
import json
|
|
import os
|
|
import re
|
|
from urllib import parse as urlparse
|
|
|
|
import paramiko
|
|
from django import forms
|
|
from django.core import validators
|
|
from django.db import models
|
|
from django.utils.encoding import force_bytes, force_str
|
|
from django.utils.translation import gettext_lazy as _
|
|
from paramiko.dsskey import DSSKey
|
|
from paramiko.ecdsakey import ECDSAKey
|
|
|
|
try:
|
|
from paramiko.ed25519key import Ed25519Key
|
|
except ImportError:
|
|
Ed25519Key = None
|
|
from paramiko._version import __version_info__
|
|
from paramiko.rsakey import RSAKey
|
|
|
|
|
|
def _load_private_key(content_or_file, password=None):
    """Parse an SSH private key by trying each supported paramiko key class.

    :param content_or_file: the key as text/bytes, or a file-like object with
        ``read()``; a file-like object must also support ``seek()`` since it
        is rewound before each key class is tried.
    :param password: passphrase of an encrypted key. An empty value (e.g. the
        ``''`` an empty form field yields) is treated as ``None``.
    :returns: a paramiko key object, or ``None`` when no key class could
        parse the content.
    :raises paramiko.PasswordRequiredException: when the key is encrypted and
        no passphrase was given.
    """
    if hasattr(content_or_file, 'read'):
        fd = content_or_file
    else:
        fd = io.TextIOWrapper(io.BytesIO(force_bytes(content_or_file)))

    # paramiko raises PasswordRequiredException only when password is None;
    # with '' it would instead try to decrypt with an empty passphrase and
    # fail with a misleading (or uncaught) error.
    password = password or None

    for pkey_class in (RSAKey, DSSKey, Ed25519Key, ECDSAKey):
        if pkey_class is None:
            # Ed25519Key is set to None when the installed paramiko lacks it.
            continue
        try:
            fd.seek(0)
            return pkey_class.from_private_key(fd, password=password)
        except paramiko.PasswordRequiredException:
            # Re-raised explicitly so the broader SSHException handler below
            # does not swallow it.
            raise
        except paramiko.SSHException:
            # Not parseable as this key type: try the next one.
            continue
    return None
|
|
|
|
|
|
class SFTP:
    """SFTP connection settings parsed from an ``sftp://`` URL.

    The URL provides the hostname, an optional port (22 by default), optional
    username/password credentials and an optional base directory. An SSH
    private key (with an optional password) can be supplied separately.
    """

    def __init__(self, url, private_key_content=None, private_key_password=None):
        """Parse *url* and load the optional private key.

        :raises ValueError: if the scheme is not ``sftp`` or the hostname is
            missing.
        :raises paramiko.PasswordRequiredException: propagated from
            ``_load_private_key`` when the key is encrypted and no password
            is given.
        """
        self.url = url
        parsed = urlparse.urlparse(url)
        if parsed.scheme != 'sftp':
            raise ValueError('invalid scheme %s' % parsed.scheme)
        if not parsed.hostname:
            raise ValueError('missing hostname')
        self.username = parsed.username or None
        self.password = parsed.password or None
        self.hostname = parsed.hostname
        self.port = parsed.port or 22
        # Base directory, relative to the login directory.
        self.path = parsed.path.strip('/')
        # force_str(None) returns the string 'None'; keep a missing key as None
        # so it is not serialized (and redisplayed in forms) as a bogus key.
        if private_key_content is None:
            self.private_key_content = None
        else:
            self.private_key_content = force_str(private_key_content)
        self.private_key_password = private_key_password
        if private_key_content:
            # May be None if the content could not be parsed as any key type.
            self.private_key = _load_private_key(private_key_content, private_key_password)
        else:
            self.private_key = None
        self._client = None
        self._transport = None

    def __json__(self):
        """Return the constructor arguments as a JSON-serializable dict."""
        return {
            'url': self.url,
            'private_key_content': self.private_key_content,
            'private_key_password': self.private_key_password,
        }

    def __str__(self):
        """Return the URL with its ``user:password`` part masked."""
        return re.sub(r'://([^/]*:[^/]*?)@', '://***:***@', self.url)

    # Note: defining __eq__ without __hash__ makes instances unhashable.
    def __eq__(self, other):
        return (
            isinstance(other, SFTP)
            and other.url == self.url
            and other.private_key_content == self.private_key_content
            and other.private_key_password == self.private_key_password
        )

    @staticmethod
    def _is_under_base(path, base):
        """Return True if normalized *path* is *base* itself or lies below it.

        A bare ``startswith(base)`` would also accept siblings such as
        ``/base2`` for base ``/base``, so a separator is required right after
        the base. Works with both ``bytes`` and ``str`` paths.
        """
        sep = b'/' if isinstance(base, bytes) else '/'
        path = os.path.normpath(path)
        base = os.path.normpath(base)
        if path == base:
            return True
        prefix = base if base.endswith(sep) else base + sep
        return path.startswith(prefix)

    # Paramiko can hang processes if connections are not closed, so always use
    # this as a context manager.
    @contextlib.contextmanager
    def client(self):
        """Connect over SSH and yield a paramiko SFTP client.

        When the URL has a path, the client first changes to that directory and
        every later path is checked to stay under it (ValueError otherwise).
        Both the SFTP client and the SSH connection are closed on exit.
        """
        ssh = paramiko.SSHClient()
        try:
            # Unknown host keys are accepted automatically (AutoAddPolicy).
            # paramiko < 2.2 is given a policy instance, later versions the class.
            if __version_info__ < (2, 2):
                ssh.set_missing_host_key_policy(paramiko.client.AutoAddPolicy())
            else:
                ssh.set_missing_host_key_policy(paramiko.client.AutoAddPolicy)
            ssh.connect(
                hostname=self.hostname,
                port=self.port,
                timeout=5,
                pkey=self.private_key,
                look_for_keys=False,
                allow_agent=False,
                username=self.username,
                password=self.password,
            )
            client = ssh.open_sftp()
            try:
                if self.path:
                    client.chdir(self.path)
                    # NOTE(review): _cwd/_adjust_cwd are private paramiko
                    # attributes; this relies on their current behavior.
                    base_cwd = client._cwd
                    old_adjust_cwd = client._adjust_cwd
                    is_under_base = self._is_under_base

                    def _adjust_cwd(path):
                        # Resolve the path as paramiko would, then refuse
                        # anything escaping the base directory.
                        path = old_adjust_cwd(path)
                        if not is_under_base(path, base_cwd):
                            raise ValueError('all paths must be under base path %s: %s' % (base_cwd, path))
                        return path

                    client._adjust_cwd = _adjust_cwd
                yield client
            finally:
                client.close()
        finally:
            ssh.close()
|
|
|
|
|
|
class SFTPURLField(forms.URLField):
    """URL form field that only accepts URLs using the ``sftp`` scheme."""

    default_validators = [
        validators.URLValidator(schemes=['sftp']),
    ]
|
|
|
|
|
|
class SFTPWidget(forms.MultiWidget):
    """Multi-widget editing an SFTP value through four inputs.

    Sub-widgets, in order: the URL, a private key file upload, the private key
    pasted as text, and the private key password.
    """

    template_name = 'passerelle/widgets/sftp.html'

    def __init__(self, **kwargs):
        super().__init__(
            widgets=[forms.TextInput, forms.FileInput, forms.Textarea, forms.TextInput], **kwargs
        )

    def decompress(self, value):
        """Split *value* (an SFTP object or its JSON dict) into sub-widget values.

        The file upload input never receives an initial value.
        """
        if value:
            data = value.__json__() if hasattr(value, '__json__') else value
            return [
                data['url'],
                None,
                data.get('private_key_content'),
                data.get('private_key_password'),
            ]
        return [None] * 4

    # XXX: works around a Django bug (https://code.djangoproject.com/ticket/29205):
    # the required attribute is taken from the parent field's ``required`` rather
    # than from each sub-field, so it is disabled entirely here.
    def use_required_attribute(self, initial):
        return False
|
|
|
|
|
|
class SFTPFormField(forms.MultiValueField):
    """Form field combining an SFTP URL and an optional private key into an SFTP object.

    Sub-fields: the ``sftp://`` URL (required), a private key file upload,
    the private key pasted as text, and the private key password. An uploaded
    file takes precedence over pasted text.
    """

    widget = SFTPWidget

    def __init__(self, **kwargs):
        fields = [
            SFTPURLField(),
            forms.FileField(required=False),
            forms.CharField(required=False),
            forms.CharField(required=False),
        ]
        super().__init__(fields=fields, require_all_fields=False, **kwargs)

    def compress(self, data_list):
        """Build an SFTP object from the cleaned sub-field values.

        :returns: ``None`` for an empty value, otherwise an ``SFTP`` instance.
        :raises forms.ValidationError: if the uploaded key file is not ASCII,
            the key needs a password, the key cannot be parsed, or the URL is
            rejected by ``SFTP``.
        """
        if not data_list:
            return None
        url, private_key_file, private_key_content, private_key_password = data_list
        if private_key_file:
            try:
                private_key_content = private_key_file.read().decode('ascii')
            except UnicodeDecodeError:
                # Not a text key file: report it instead of crashing.
                raise forms.ValidationError(_('SSH private key invalid')) from None
        if private_key_content:
            try:
                pkey = _load_private_key(private_key_content, private_key_password)
            except paramiko.PasswordRequiredException:
                raise forms.ValidationError(_('SSH private key needs a password'))
            if not pkey:
                raise forms.ValidationError(_('SSH private key invalid'))
        try:
            return SFTP(
                url=url, private_key_content=private_key_content, private_key_password=private_key_password
            )
        except ValueError as e:
            # e.g. an out-of-range port that passed URLValidator's regex.
            raise forms.ValidationError(_('Invalid SFTP URL')) from e
|
|
|
|
|
|
class SFTPField(models.Field):
    """Model field storing an SFTP object as a JSON document in a text column."""

    description = 'A SFTP connection'

    def __init__(self, **kwargs):
        kwargs.setdefault('default', None)
        super().__init__(**kwargs)

    def get_internal_type(self):
        # Stored as the JSON produced by get_prep_value().
        return 'TextField'

    def from_db_value(self, value, *args, **kwargs):
        return self.to_python(value)

    def to_python(self, value):
        """Convert *value* to an SFTP object.

        Accepts an SFTP instance (returned unchanged), a dict of SFTP
        constructor arguments, or its JSON string. Empty values give ``None``.
        """
        if not value:
            return None
        if isinstance(value, SFTP):
            return value
        if isinstance(value, dict):
            return SFTP(**value)
        return SFTP(**json.loads(value))

    def get_prep_value(self, value):
        """Serialize *value* to a JSON string; empty values are stored as ''."""
        # Normalize JSON strings / dicts to an SFTP instance first.
        value = self.to_python(value)
        if not value:
            return ''
        return json.dumps(value.__json__())

    def formfield(self, **kwargs):
        defaults = {
            'form_class': SFTPFormField,
        }
        defaults.update(kwargs)
        return super().formfield(**defaults)
|