# passerelle/passerelle/apps/csvdatasource/models.py

# passerelle - uniform access to multiple data sources and services
# Copyright (C) 2016 Entr'ouvert
#
# This program is free software: you can redistribute it and/or modify it
# under the terms of the GNU Affero General Public License as published
# by the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
import datetime
import os
import re
import csv
from collections import OrderedDict
import tempfile
import six
import pytz
from pyexcel_ods import get_data as get_data_ods
from pyexcel_xls import get_data as get_data_xls
from django.contrib.postgres.fields import JSONField
from django.utils.encoding import force_str, smart_text, force_text
from django.utils.timezone import make_aware
from django.conf import settings
from django.db import models, transaction
from django.core.exceptions import ValidationError
from django.urls import reverse
from django.utils.timezone import now
from django.utils.translation import ugettext_lazy as _
from passerelle.base.models import BaseResource
from passerelle.utils import batch
from passerelle.utils.jsonresponse import APIError
from passerelle.utils.api import endpoint
from passerelle.utils.conversion import normalize

identifier_re = re.compile(r"^[^\d\W]\w*\Z", re.UNICODE)

code_cache = OrderedDict()


def get_code(expr):
# limit size of code cache to 1024
if len(code_cache) > 1024:
for key in list(code_cache.keys())[: len(code_cache) - 1024]:
code_cache.pop(key)
if expr not in code_cache:
code_cache[expr] = compile(expr, '<inline>', 'eval')
return code_cache[expr]
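# Illustrative use of the cache above (not part of the module's public API):
# repeated calls with the same expression string return the same compiled
# code object, so the per-row evaluation done in execute_query() only pays
# the compile cost once.
#
#   code = get_code("int(id) > 10")
#   eval(code, {}, {'id': '42'})        # -> True
#   get_code("int(id) > 10") is code    # -> True, served from code_cache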


class Query(models.Model):
resource = models.ForeignKey('CsvDataSource', on_delete=models.CASCADE, related_name='queries')
slug = models.SlugField(_('Name (slug)'))
label = models.CharField(_('Label'), max_length=100)
description = models.TextField(_('Description'), blank=True)
filters = models.TextField(
_('Filters'), blank=True, help_text=_('List of filter clauses (Python expression)')
)
order = models.TextField(_('Sort Order'), blank=True, help_text=_('Columns to use for sorting rows'))
distinct = models.TextField(_('Distinct'), blank=True, help_text=_('Distinct columns'))
projections = models.TextField(
_('Projections'), blank=True, help_text=_('List of projections (name:expression)')
)
structure = models.CharField(
_('Structure'),
max_length=20,
choices=[
('array', _('Array')),
('dict', _('Dictionary')),
('keyed-distinct', _('Keyed Dictionary')),
('tuples', _('Tuples')),
('onerow', _('Single Row')),
('one', _('Single Value')),
],
default='dict',
help_text=_('Data structure used for the response'),
)
class Meta:
ordering = ['slug']
unique_together = ['resource', 'slug']
def get_list(self, attribute):
if not getattr(self, attribute):
return []
return getattr(self, attribute).strip().splitlines()
def export_json(self):
return {
'slug': self.slug,
'label': self.label,
'description': self.description,
'filters': self.filters,
'projections': self.projections,
'order': self.order,
'distinct': self.distinct,
'structure': self.structure,
}
@classmethod
def import_json(cls, d):
return cls(**d)
@property
def name(self):
return self.slug
def delete_url(self):
return reverse('csv-delete-query', kwargs={'connector_slug': self.resource.slug, 'pk': self.pk})
def edit_url(self):
return reverse('csv-edit-query', kwargs={'connector_slug': self.resource.slug, 'pk': self.pk})
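# Illustrative example of how a Query is typically filled in (the column
# names used here are assumptions, only the syntax matters): each line of
# `filters` is a Python expression evaluated against a row, `order` and
# `distinct` take one expression per line, and `projections` uses the
# name:expression form enforced by CsvDataSource.execute_query() below.
#
#   filters:      int(population) > 10000
#                 normalize(name).startswith(normalize(query['prefix']))
#   order:        name.lower()
#   projections:  id:code
#                 text:name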


def upload_to(instance, filename):
return '%s/%s/%s' % (instance.get_connector_slug(), instance.slug, filename)
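# For instance, a connector whose slug is "mydata" would typically store its
# uploads under something like "csvdatasource/mydata/file.csv" (assuming
# get_connector_slug() returns this app's slug); clean_used_files() below
# relies on that layout.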


class CsvDataSource(BaseResource):
csv_file = models.FileField(
_('Spreadsheet file'), upload_to=upload_to, help_text=_('Supported file formats: csv, ods, xls, xlsx')
)
columns_keynames = models.CharField(
max_length=256,
verbose_name=_('Column keynames'),
default='id, text',
help_text=_('ex: id,text,data1,data2'),
blank=True,
)
skip_header = models.BooleanField(_('Skip first line'), default=False)
_dialect_options = JSONField(editable=False, null=True)
sheet_name = models.CharField(_('Sheet name'), blank=True, max_length=150)
category = _('Data Sources')
documentation_url = (
'https://doc-publik.entrouvert.com/admin-fonctionnel/parametrage-avance/source-de-donnees-csv/'
)
class Meta:
verbose_name = _('Spreadsheet File')
def clean(self, *args, **kwargs):
file_type = self.csv_file.name.split('.')[-1]
if file_type in ('ods', 'xls', 'xlsx') and not self.sheet_name:
raise ValidationError(_('You must specify a sheet name'))
if file_type not in ('ods', 'xls', 'xlsx'):
try:
self._detect_dialect_options()
except Exception as e:
raise ValidationError(_('Could not detect CSV dialect: %s') % e)
try:
self.get_rows()
except Exception as e:
raise ValidationError(_('Invalid CSV file: %s') % e)
return super(CsvDataSource, self).clean(*args, **kwargs)
def _detect_dialect_options(self):
content = self.get_content_without_bom()
dialect = csv.Sniffer().sniff(content)
self.dialect_options = {k: v for k, v in vars(dialect).items() if not k.startswith('_')}
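    # For reference, the sniffed options stored above typically look like the
    # following (values are illustrative and depend on the file):
    #   {'delimiter': ';', 'quotechar': '"', 'doublequote': True,
    #    'skipinitialspace': False, 'lineterminator': '\r\n', 'quoting': 0}
    # They are passed back to csv.reader() as keyword arguments in get_rows().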
def save(self, *args, **kwargs):
cache = kwargs.pop('cache', True)
result = super(CsvDataSource, self).save(*args, **kwargs)
if cache:
self.cache_data()
return result
def cache_data(self):
with transaction.atomic():
TableRow.objects.filter(resource=self).delete()
for block in batch(enumerate(self.get_rows()), 5000):
TableRow.objects.bulk_create(
TableRow(resource=self, line_number=i, data=data) for i, data in block
)
def csv_file_datetime(self):
ctime = os.fstat(self.csv_file.fileno()).st_ctime
try:
return make_aware(datetime.datetime.fromtimestamp(ctime))
except (pytz.NonExistentTimeError, pytz.AmbiguousTimeError):
timezone = pytz.timezone(settings.TIME_ZONE)
return timezone.localize(datetime.datetime.fromtimestamp(ctime), is_dst=False)
@property
def dialect_options(self):
"""turn dict items into string"""
file_type = self.csv_file.name.split('.')[-1]
if file_type in ('ods', 'xls', 'xlsx'):
return None
# Set dialect_options if None
if self._dialect_options is None:
self._detect_dialect_options()
self.save(cache=False)
options = {}
for k, v in self._dialect_options.items():
if isinstance(v, six.text_type):
v = force_str(v.encode('ascii'))
options[force_str(k.encode('ascii'))] = v
return options
@dialect_options.setter
def dialect_options(self, value):
self._dialect_options = value
@classmethod
def get_verbose_name(cls):
return cls._meta.verbose_name
def get_content_without_bom(self):
self.csv_file.seek(0)
content = self.csv_file.read()
return force_str(content.decode('utf-8-sig', 'ignore').encode('utf-8'))
def get_rows(self):
file_type = self.csv_file.name.split('.')[-1]
if file_type not in ('ods', 'xls', 'xlsx'):
content = self.get_content_without_bom()
reader = csv.reader(content.splitlines(), **self.dialect_options)
rows = list(reader)
else:
if file_type == 'ods':
content = get_data_ods(self.csv_file)
elif file_type == 'xls' or file_type == 'xlsx':
                # The suffix is necessary because pyexcel cannot detect the
                # file type from the content alone
with tempfile.NamedTemporaryFile(mode='wb', suffix='.' + file_type) as fd:
self.csv_file.seek(0)
for buf in iter(lambda: self.csv_file.read(32768), b''):
fd.write(buf)
fd.flush()
content = get_data_xls(fd.name)
if len(content.keys()) == 1:
# if there's a single sheet, ignore specified sheet name and
# take the first one.
self.sheet_name = list(content.keys())[0]
if self.sheet_name not in content:
return []
rows = content[self.sheet_name]
if not rows:
return []
if self.skip_header:
rows = rows[1:]
rows = [[smart_text(x) for x in y] for y in rows if y]
titles = [t.strip() for t in self.columns_keynames.split(',')]
indexes = [titles.index(t) for t in titles if t]
caption = [titles[i] for i in indexes]
def get_cell(row, index):
try:
return row[index]
except IndexError:
return ''
        return [{key: get_cell(row, index) for key, index in zip(caption, indexes)} for row in rows]
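    # Illustrative mapping performed above: with columns_keynames "id,,text"
    # (an empty entry skips that column), a row ['1', 'x', 'foo'] becomes
    # {'id': '1', 'text': 'foo'}; missing trailing cells map to ''.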
def get_cached_rows(self, initial=True):
found = False
for row in TableRow.objects.filter(resource=self):
found = True
yield row.data
if not found and initial:
            # if there were no rows, the data has probably not been cached in
            # the database yet
self.cache_data()
for data in self.get_cached_rows(initial=False):
yield data
@property
def titles(self):
return [smart_text(t.strip()) for t in self.columns_keynames.split(',')]
@endpoint(perm='can_access', methods=['get'], name='data')
def data(self, request, **kwargs):
params = request.GET
filters = []
for column_title in [t.strip() for t in self.columns_keynames.split(',') if t]:
if column_title in params.keys():
query_value = request.GET.get(column_title, '')
if 'case-insensitive' in params:
filters.append("%s.lower() == %r" % (column_title, query_value.lower()))
else:
filters.append("%s == %r" % (column_title, query_value))
query = Query(filters='\n'.join(filters))
return self.execute_query(request, query, query_params=params.dict())
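    # Illustrative call to the data endpoint above (the URL prefix depends on
    # the deployment): with the default "id, text" keynames,
    #   GET .../data?text=Foo&case-insensitive
    # builds the filter "text.lower() == 'foo'" and returns the matching rows.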
@endpoint(perm='can_access', methods=['get'], name='query', pattern=r'^(?P<query_name>[\w-]+)/$')
def select(self, request, query_name, **kwargs):
try:
query = Query.objects.get(resource=self.id, slug=query_name)
except Query.DoesNotExist:
raise APIError(u'no such query')
return self.execute_query(request, query, query_params=kwargs)
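    # Illustrative call to the named-query endpoint above:
    #   GET .../query/my-query/?limit=10&offset=20
    # runs the stored query; limit/offset pagination is handled in
    # execute_query(), and whatever the endpoint receives in **kwargs is
    # exposed to the query's expressions as the `query` dict.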
def execute_query(self, request, query, query_params=None):
query_params = query_params or {}
titles = self.titles
data = self.get_cached_rows()
def stream_expressions(expressions, data, kind, titles=None):
codes = []
for i, expr in enumerate(expressions):
try:
code = get_code(expr)
except (TypeError, SyntaxError) as e:
data = {
'expr': expr,
'error': smart_text(e),
}
if titles:
data['name'] = titles[i]
else:
data['idx'] = i
raise APIError(u'invalid %s expression' % kind, data=data)
codes.append((code, expr))
for row in data:
new_row = []
row_vars = dict(row)
row_vars['query'] = query_params
for i, (code, expr) in enumerate(codes):
try:
result = eval(code, {'normalize': normalize}, row_vars)
except Exception as e:
data = {
'expr': expr,
'row': repr(row),
'error': smart_text(e),
}
if titles:
data['name'] = titles[i]
else:
data['idx'] = i
raise APIError(u'invalid %s expression' % kind, data=data)
new_row.append(result)
yield new_row, row
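        # Each expression is evaluated per row with the row's columns bound as
        # local variables, plus `query` (the query parameters) and the
        # normalize() helper; e.g. a filters line such as
        #   normalize(text) == normalize(query.get('q', ''))
        # assumes a "text" column but shows the names that are available.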
filters = query.get_list('filters')
if filters:
data = [row for new_row, row in stream_expressions(filters, data, kind='filters') if all(new_row)]
order = query.get_list('order')
if order:
generator = stream_expressions(order, data, kind='order')
new_data = [(tuple(new_row), row) for new_row, row in generator]
new_data.sort(key=lambda x: x[0])
data = [row for key, row in new_data]
distinct = query.get_list('distinct')
if distinct:
generator = stream_expressions(distinct, data, kind='distinct')
seen = set()
new_data = []
for new_row, row in generator:
new_row = tuple(new_row)
try:
hash(new_row)
except TypeError:
raise APIError(
u'distinct value is unhashable',
data={
'row': repr(row),
'distinct': repr(new_row),
},
)
if new_row in seen:
continue
new_data.append(row)
seen.add(new_row)
data = new_data
projection = query.get_list('projections')
if projection:
expressions = []
titles = []
for mapping in projection:
name, expr = mapping.split(':', 1)
if not identifier_re.match(name):
raise APIError(u'invalid projection name', data=name)
titles.append(name)
expressions.append(expr)
new_data = []
for new_row, row in stream_expressions(expressions, data, kind='projection', titles=titles):
new_data.append(dict(zip(titles, new_row)))
data = new_data
if 'id' in request.GET:
            # an ?id= filter is always provided, whatever the query defines
filters = ["id == %r" % force_text(request.GET['id'])]
data = [row for new_row, row in stream_expressions(filters, data, kind='filters') if new_row[0]]
        # allow JSONP queries from select2; filtering is done here, after the
        # projection step, because a projection named "text" is required for
        # backward compatibility with earlier uses of csvdatasource with select2
if 'q' in request.GET:
filters = ["%s in normalize(text.lower())" % repr(normalize(request.GET['q'].lower()))]
data = [row for new_row, row in stream_expressions(filters, data, kind='filters') if new_row[0]]
        # materialize the iterator as a list before pagination
data = list(data)
if 'limit' in request.GET:
try:
limit = int(request.GET['limit'])
except ValueError:
raise APIError('invalid limit parameter')
if limit < 1:
raise APIError('invalid limit parameter')
try:
offset = int(request.GET.get('offset') or 0)
except ValueError:
raise APIError('invalid offset parameter')
if offset < 0:
raise APIError('invalid offset parameter')
# paginate data
data = data[offset : offset + limit]
if query.structure == 'array':
return {'data': [[row[t] for t in titles] for row in data]}
elif query.structure == 'dict':
return {'data': data}
elif query.structure == 'keyed-distinct':
distinct = query.get_list('distinct')
if len(distinct) != 1:
raise APIError('keyed format requires a single distinct field')
return {'data': {x[distinct[0]]: x for x in data}}
elif query.structure == 'tuples':
return {'data': [[[t, row[t]] for t in titles] for row in data]}
elif query.structure == 'onerow':
if len(data) != 1:
raise APIError('more or less than one row', data=data)
return {'data': data[0]}
elif query.structure == 'one':
if len(data) != 1:
raise APIError('more or less than one row', data=data)
if len(data[0]) != 1:
raise APIError('more or less than one column', data=data)
return {'data': list(data[0].values())[0]}
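    # Summary of the shapes returned above for rows keyed "id, text"
    # (illustrative data):
    #   array          {'data': [['1', 'foo'], ...]}
    #   dict           {'data': [{'id': '1', 'text': 'foo'}, ...]}
    #   tuples         {'data': [[['id', '1'], ['text', 'foo']], ...]}
    #   onerow         {'data': {'id': '1', 'text': 'foo'}}
    #   one            {'data': 'foo'}  (requires a single-column projection)
    #   keyed-distinct {'data': {'<distinct value>': {...row...}, ...}}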
def export_json(self):
d = super(CsvDataSource, self).export_json()
d['queries'] = [query.export_json() for query in Query.objects.filter(resource=self)]
return d
@classmethod
def import_json_real(cls, overwrite, instance, d, **kwargs):
queries = d.pop('queries', [])
instance = super(CsvDataSource, cls).import_json_real(overwrite, instance, d, **kwargs)
new = []
if instance and overwrite:
Query.objects.filter(resource=instance).delete()
for query in queries:
q = Query.import_json(query)
q.resource = instance
new.append(q)
Query.objects.bulk_create(new)
return instance
def create_query_url(self):
return reverse('csv-new-query', kwargs={'connector_slug': self.slug})
def daily(self):
super().daily()
self.clean_used_files()
def clean_used_files(self):
if not os.path.exists(self.csv_file.path):
return
base_dir = os.path.dirname(self.csv_file.path)
if not os.path.exists(base_dir):
return
if os.path.dirname(self.csv_file.name) != os.path.join(self.get_connector_slug(), self.slug):
# path is not compliant with upload_to, do nothing
return
for filename in os.listdir(base_dir):
filepath = os.path.join(base_dir, filename)
if not os.path.isfile(filepath):
continue
if os.path.basename(self.csv_file.name) == filename:
# current file
continue
mtime = os.stat(filepath).st_mtime
if mtime > (now() + datetime.timedelta(days=-7)).timestamp():
# too young
continue
if getattr(settings, 'CSVDATASOURCE_REMOVE_ON_CLEAN', False) is True:
# remove
os.unlink(filepath)
else:
# move file in unused-files dir
unused_dir = os.path.join(base_dir, 'unused-files')
os.makedirs(unused_dir, exist_ok=True)
os.rename(filepath, os.path.join(unused_dir, filename))
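    # Superseded uploads (anything but the current file, untouched for at
    # least a week) are moved to an "unused-files" subdirectory by default,
    # or deleted outright when settings.CSVDATASOURCE_REMOVE_ON_CLEAN is True.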


class TableRow(models.Model):
resource = models.ForeignKey('CsvDataSource', on_delete=models.CASCADE)
line_number = models.IntegerField(null=False)
data = JSONField(blank=True)
class Meta:
ordering = ('line_number',)
unique_together = ('resource', 'line_number')