csv: add possibility of filtering any column (#10686)

This commit is contained in:
Frédéric Péters 2016-04-29 07:39:26 +02:00
parent 060d7482f2
commit b71ec791e2
3 changed files with 79 additions and 30 deletions

View File

@ -39,41 +39,65 @@ class CsvDataSource(BaseResource):
def get_absolute_url(self):
return reverse('csvdatasource-detail', kwargs={'slug': self.slug})
def get_data(self, filter_criteria=None, case_insensitive=False):
def get_data(self, filters=None, query=None, case_insensitive=False):
if case_insensitive:
if filters:
for filter_key in filters.keys():
filters[filter_key] = filters[filter_key].lower()
if query:
query = query.lower()
if case_insensitive and filter_criteria:
filter_criteria = filter_criteria.lower()
titles = [t.strip() for t in self.columns_keynames.split(',')]
indexes = [titles.index(t) for t in titles if t]
caption = [titles[i] for i in indexes]
def filter_row(row, titles):
# validate query (a "text" column must exist)
if query:
if 'text' in titles:
col = titles.index('text')
text = unicode(row[col], 'utf-8')
text_idx = titles.index('text')
else:
# no text column -> no query
query = None
# validate filters (appropriate columns must exist)
if filters:
for filter_key in filters.keys():
if not filter_key in titles:
del filters[filter_key]
def filter_row(row):
if query:
text = unicode(row[text_idx], 'utf-8')
if case_insensitive:
text = text.lower()
if filter_criteria not in text:
if not query in text:
return False
if filters:
for filter_key, filter_value in filters.items():
col = titles.index(filter_key)
text = unicode(row[col], 'utf-8')
if case_insensitive:
text = text.lower()
if filter_value != text:
return False
return True
content = self.csv_file.read()
if not content:
return None
# remove BOM
# remove BOM and detect CSV dialect
content = content.decode('utf-8-sig').encode('utf-8')
dialect = csv.Sniffer().sniff(content[:1024])
reader = csv.reader(content.splitlines(), dialect)
if self.skip_header:
reader.next()
titles = [t.strip() for t in self.columns_keynames.split(',')]
indexes = [titles.index(t) for t in titles if t]
caption = [titles[i] for i in indexes]
data = []
for row in reader:
if filter_criteria and not filter_row(row, titles):
if (filters or query) and not filter_row(row):
continue
try:
line = [row[i] for i in indexes]

View File

@ -43,5 +43,9 @@ class CsvDataView(View, SingleObjectMixin):
def get(self, request, *args, **kwargs):
obj = self.get_object()
case_insensitive = 'case-insensitive' in request.GET
return obj.get_data(filter_criteria=request.GET.get('q'),
case_insensitive=case_insensitive)
query = request.GET.get('q')
filters = {}
for column_title in [t.strip() for t in obj.columns_keynames.split(',')]:
if column_title in request.GET:
filters[column_title] = request.GET[column_title]
return obj.get_data(filters, query=query, case_insensitive=case_insensitive)

View File

@ -42,10 +42,10 @@ def test_unfiltered_data():
assert 'another_field' in item
def test_good_filter_data():
filter_criteria = 'ak'
filter_criteria = 'Zakia'
csv = CsvDataSource.objects.create(csv_file=File(StringIO(data), 'data.csv'),
columns_keynames=',id,,text,')
result = csv.get_data(filter_criteria)
result = csv.get_data({'text': filter_criteria})
assert len(result)
for item in result:
assert 'id' in item
@ -56,40 +56,40 @@ def test_bad_filter_data():
filter_criteria = 'bad'
csv = CsvDataSource.objects.create(csv_file=File(StringIO(data), 'data.csv'),
columns_keynames=',id,,text,')
result = csv.get_data(filter_criteria)
result = csv.get_data({'text': filter_criteria})
assert len(result) == 0
def test_useless_filter_data():
csv = CsvDataSource.objects.create(csv_file=File(StringIO(data), 'data.csv'),
columns_keynames='id,,nom,prenom,sexe')
result = csv.get_data('Ali')
result = csv.get_data({'text': 'Ali'})
assert len(result) == 20
def test_columns_keynames_with_spaces():
csv = CsvDataSource.objects.create(csv_file=File(StringIO(data), 'data.csv'),
columns_keynames='id , , nom,text , ')
result = csv.get_data('Yaniss')
result = csv.get_data({'text': 'Yaniss'})
assert len(result) == 1
def test_skipped_header_data():
csv = CsvDataSource.objects.create(csv_file=File(StringIO(data), 'data.csv'),
columns_keynames=',id,,text,',
skip_header=True)
result = csv.get_data('Eliot')
result = csv.get_data({'text': 'Eliot'})
assert len(result) == 0
def test_data():
csv = CsvDataSource.objects.create(csv_file=File(StringIO(data), 'data.csv'),
columns_keynames='fam,id,, text,sexe ')
result = csv.get_data('Sacha')
result = csv.get_data({'text': 'Sacha'})
assert result[0] == {'id': '59', 'text': 'Sacha',
'fam': '2431', 'sexe': 'H'}
def test_unicode_filter_data():
csv = CsvDataSource.objects.create(csv_file=File(StringIO(data), 'data.csv'),
columns_keynames=',id,,text,')
filter_criteria = u'noît'
result = csv.get_data(filter_criteria)
filter_criteria = u'Benoît'
result = csv.get_data({'text': filter_criteria})
assert len(result)
for item in result:
assert 'id' in item
@ -99,18 +99,39 @@ def test_unicode_filter_data():
def test_unicode_case_insensitive_filter_data():
csv = CsvDataSource.objects.create(csv_file=File(StringIO(data), 'data.csv'),
columns_keynames=',id,,text,')
filter_criteria = u'Aëlle'
result = csv.get_data(filter_criteria=filter_criteria,
case_insensitive=True)
filter_criteria = u'anaëlle'
result = csv.get_data({'text': filter_criteria}, case_insensitive=True)
assert len(result)
for item in result:
assert 'id' in item
assert 'text' in item
assert filter_criteria.lower() in item['text'].decode('utf-8')
assert filter_criteria.lower() in item['text'].decode('utf-8').lower()
def test_data_bom():
csv = CsvDataSource.objects.create(csv_file=File(StringIO(data_bom), 'data.csv'),
columns_keynames='fam,id,, text,sexe ')
result = csv.get_data('Eliot')
result = csv.get_data({'text': 'Eliot'})
assert result[0] == {'id': '69981', 'text': 'Eliot',
'fam': '121', 'sexe': 'H'}
def test_multi_filter():
csv = CsvDataSource.objects.create(csv_file=File(StringIO(data), 'data.csv'),
columns_keynames='fam,id,, text,sexe ')
result = csv.get_data({'sexe': 'F'})
assert result[0] == {'id': '6', 'text': 'Shanone',
'fam': '525', 'sexe': 'F'}
assert len(result) == 10
def test_query():
csv = CsvDataSource.objects.create(csv_file=File(StringIO(data), 'data.csv'),
columns_keynames='fam,id,, text,sexe ')
result = csv.get_data(query='liot')
assert result[0]['text'] == 'Eliot'
assert len(result) == 1
def test_query_insensitive():
csv = CsvDataSource.objects.create(csv_file=File(StringIO(data), 'data.csv'),
columns_keynames='fam,id,, text,sexe ')
result = csv.get_data(query='elIo', case_insensitive=True)
assert result[0]['text'] == 'Eliot'
assert len(result) == 1