passerelle/passerelle/utils/sftp.py

238 lines
7.7 KiB
Python

# passerelle - uniform access to multiple data sources and services
# Copyright (C) 2019 Entr'ouvert
#
# This program is free software: you can redistribute it and/or modify it
# under the terms of the GNU Affero General Public License as published
# by the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
from __future__ import absolute_import
import contextlib
import io
import json
import os
import re
from urllib import parse as urlparse
import paramiko
from django import forms
from django.core import validators
from django.db import models
from django.utils.encoding import force_bytes, force_str
from django.utils.translation import gettext_lazy as _
from paramiko.dsskey import DSSKey
from paramiko.ecdsakey import ECDSAKey
try:
from paramiko.ed25519key import Ed25519Key
except ImportError:
Ed25519Key = None
from paramiko._version import __version_info__
from paramiko.rsakey import RSAKey
def _load_private_key(content_or_file, password=None):
    """Parse an SSH private key, trying each supported key type in turn.

    ``content_or_file`` is either the key text (str or bytes) or an object
    exposing ``read``; raw content is wrapped in an in-memory text stream.

    Returns the first key object successfully built by one of the paramiko
    key classes, or ``None`` when none of them accepts the content.

    Raises ``paramiko.PasswordRequiredException`` when a key class reports
    that a password is needed.
    """
    if hasattr(content_or_file, 'read'):
        stream = content_or_file
    else:
        raw = force_bytes(content_or_file)
        stream = io.TextIOWrapper(io.BytesIO(raw))

    # Ed25519Key is None when the installed paramiko does not provide it.
    candidates = [cls for cls in (RSAKey, DSSKey, Ed25519Key, ECDSAKey) if cls is not None]
    for key_class in candidates:
        try:
            # Every attempt must re-read the key from the beginning.
            stream.seek(0)
            return key_class.from_private_key(stream, password=password)
        except paramiko.PasswordRequiredException:
            # Must be propagated before the generic handler below sees it.
            raise
        except paramiko.SSHException:
            # Not a key of this type: try the next class.
            continue
    return None
class SFTP:
    """Connection settings for an SFTP server, parsed from an ``sftp://`` URL.

    The URL supplies the username, password, hostname, port (22 when absent)
    and the base path. An optional private key can be given as text, together
    with its password.

    Raises ``ValueError`` when the URL scheme is not ``sftp`` or when the
    hostname is missing.
    """

    # Instances define __eq__ on mutable attributes, so they stay unhashable.
    __hash__ = None

    def __init__(self, url, private_key_content=None, private_key_password=None):
        self.url = url
        parsed = urlparse.urlparse(url)
        if not parsed.scheme == 'sftp':
            raise ValueError('invalid scheme %s' % parsed.scheme)
        if not parsed.hostname:
            raise ValueError('missing hostname')
        self.username = parsed.username or None
        self.password = parsed.password or None
        self.hostname = parsed.hostname
        self.port = parsed.port or 22
        self.path = parsed.path.strip('/')
        # force_str(None) returns the literal string 'None', which would then
        # be serialized and compared as if it were key material; keep a
        # missing key as None.
        if private_key_content is None:
            self.private_key_content = None
        else:
            self.private_key_content = force_str(private_key_content)
        self.private_key_password = private_key_password
        if private_key_content:
            self.private_key = _load_private_key(private_key_content, private_key_password)
        else:
            self.private_key = None
        self._client = None
        self._transport = None

    def __json__(self):
        """Return the constructor arguments as a JSON-serializable dict."""
        return {
            'url': self.url,
            'private_key_content': self.private_key_content,
            'private_key_password': self.private_key_password,
        }

    def __str__(self):
        # Mask the "user:password" part of the URL.
        return re.sub(r'://([^/]*:[^/]*?)@', '://***:***@', self.url)

    def __eq__(self, other):
        return (
            isinstance(other, SFTP)
            and other.url == self.url
            and other.private_key_content == self.private_key_content
            and other.private_key_password == self.private_key_password
        )

    @staticmethod
    def _is_under_base(path, base):
        """Tell whether ``path``, once normalized, is ``base`` or lies below it.

        The comparison is done on a path-separator boundary so that a sibling
        such as ``/home/foobar`` is not accepted for the base ``/home/foo``.
        ``path`` and ``base`` must be of the same type (both str or both
        bytes).
        """
        sep = b'/' if isinstance(base, bytes) else '/'
        normalized = os.path.normpath(path)
        root = base.rstrip(sep)
        # ``root`` is empty when the base is the filesystem root.
        return normalized == (root or sep) or normalized.startswith(root + sep)

    # Paramiko can hang the process if connections are not closed, so the
    # client must always be used as a context manager.
    @contextlib.contextmanager
    def client(self):
        """Yield a connected paramiko SFTP client, closing everything on exit.

        When the URL carries a path, the client changes to that directory and
        every path it later resolves must stay under it; a path outside of it
        raises ``ValueError``.
        """
        ssh = paramiko.SSHClient()
        try:
            # NOTE(review): AutoAddPolicy accepts host keys that are not
            # already known, so the server identity is not verified here -
            # confirm this is intended.
            # The policy is passed as an instance to paramiko < 2.2 and as a
            # class to newer releases.
            if __version_info__ < (2, 2):
                ssh.set_missing_host_key_policy(paramiko.client.AutoAddPolicy())
            else:
                ssh.set_missing_host_key_policy(paramiko.client.AutoAddPolicy)
            ssh.connect(
                hostname=self.hostname,
                port=self.port,
                timeout=5,
                pkey=self.private_key,
                look_for_keys=False,
                allow_agent=False,
                username=self.username,
                password=self.password,
            )
            client = ssh.open_sftp()
            try:
                if self.path:
                    client.chdir(self.path)
                    base_cwd = client._cwd
                    old_adjust_cwd = client._adjust_cwd

                    # Wrap paramiko's path resolution to confine the client
                    # to the base directory.
                    def _adjust_cwd(path):
                        path = old_adjust_cwd(path)
                        if not self._is_under_base(path, base_cwd):
                            raise ValueError('all paths must be under base path %s: %s' % (base_cwd, path))
                        return path

                    client._adjust_cwd = _adjust_cwd
                yield client
            finally:
                client.close()
        finally:
            ssh.close()
class SFTPURLField(forms.URLField):
    """URL form field whose default validator only accepts the ``sftp`` scheme."""

    default_validators = [
        validators.URLValidator(schemes=['sftp']),
    ]
class SFTPWidget(forms.MultiWidget):
    """Composite widget with four inputs for an SFTP connection.

    In order: URL (text), private key file (file upload), private key
    content (textarea) and private key password (text).
    """

    template_name = 'passerelle/widgets/sftp.html'

    def __init__(self, **kwargs):
        super().__init__(
            widgets=[
                forms.TextInput,
                forms.FileInput,
                forms.Textarea,
                forms.TextInput,
            ],
            **kwargs,
        )

    def decompress(self, value):
        """Split a stored value into one initial value per sub-widget.

        ``value`` may be a mapping or any object exposing ``__json__``. The
        file input never receives an initial value.
        """
        if not value:
            return [None] * 4
        data = value.__json__() if hasattr(value, '__json__') else value
        url = data['url']
        key_content = data.get('private_key_content')
        key_password = data.get('private_key_password')
        return [url, None, key_content, key_password]

    # XXX: Django bug https://code.djangoproject.com/ticket/29205 -
    # the required attribute is derived from the parent field's ``required``
    # flag rather than from each sub-field, so it is disabled altogether.
    def use_required_attribute(self, initial):
        return False
class SFTPFormField(forms.MultiValueField):
    """Form field combining a URL and an optional private key into an ``SFTP``.

    Sub-fields, in order: the sftp URL, an optional private key file, an
    optional private key content and an optional private key password.
    """

    widget = SFTPWidget

    def __init__(self, **kwargs):
        fields = [
            SFTPURLField(),
            forms.FileField(required=False),
            forms.CharField(required=False),
            forms.CharField(required=False),
        ]
        super().__init__(fields=fields, require_all_fields=False, **kwargs)

    def compress(self, data_list):
        """Build an ``SFTP`` object from the cleaned sub-field values.

        An uploaded key file takes precedence over the key content typed in
        the textarea. Returns ``None`` when ``data_list`` is empty.

        Raises ``forms.ValidationError`` when the key cannot be decoded or
        parsed, or when it needs a password that was not usable.
        """
        if not data_list:
            return None
        url, private_key_file, private_key_content, private_key_password = data_list
        if private_key_file:
            try:
                private_key_content = private_key_file.read().decode('ascii')
            except UnicodeDecodeError:
                # A binary or non-ASCII upload is reported as a form error
                # instead of propagating as an unhandled exception.
                raise forms.ValidationError(_('SSH private key invalid'))
        if private_key_content:
            try:
                pkey = _load_private_key(private_key_content, private_key_password)
            except paramiko.PasswordRequiredException:
                raise forms.ValidationError(_('SSH private key needs a password'))
            # _load_private_key returns None when no key type matches.
            if not pkey:
                raise forms.ValidationError(_('SSH private key invalid'))
        return SFTP(
            url=url, private_key_content=private_key_content, private_key_password=private_key_password
        )
class SFTPField(models.Field):
    """Model field storing an ``SFTP`` object as JSON in a text column."""

    description = 'A SFTP connection'

    def __init__(self, **kwargs):
        # The field defaults to None unless the caller chose another default.
        if 'default' not in kwargs:
            kwargs['default'] = None
        super().__init__(**kwargs)

    def get_internal_type(self):
        return 'TextField'

    def from_db_value(self, value, *args, **kwargs):
        return self.to_python(value)

    def to_python(self, value):
        """Convert a stored JSON string into an ``SFTP``; empty values give None."""
        if not value:
            return None
        return value if isinstance(value, SFTP) else SFTP(**json.loads(value))

    def get_prep_value(self, value):
        """Serialize the value to JSON; empty values are stored as ''."""
        return json.dumps(value.__json__()) if value else ''

    def formfield(self, **kwargs):
        # Caller-supplied options override the default form class.
        options = {'form_class': SFTPFormField, **kwargs}
        return super().formfield(**options)