authentic/src/authentic2/csv_import.py

585 lines
20 KiB
Python

# authentic2 - versatile identity manager
# Copyright (C) 2010-2019 Entr'ouvert
#
# This program is free software: you can redistribute it and/or modify it
# under the terms of the GNU Affero General Public License as published
# by the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
from __future__ import unicode_literals
import csv
import io
from chardet.universaldetector import UniversalDetector
import attr
from django import forms
from django.core.exceptions import FieldDoesNotExist
from django.core.validators import RegexValidator
from django.db import IntegrityError
from django.db.transaction import atomic
from django.utils import six
from django.utils.translation import ugettext as _
from authentic2 import app_settings
from authentic2.a2_rbac.utils import get_default_ou
from authentic2.custom_user.models import User
from authentic2.forms.profile import modelform_factory, BaseUserForm
from authentic2.models import Attribute, AttributeValue, UserExternalId
class UTF8Recoder(object):
    """Iterator wrapping a text-mode file object and yielding UTF-8 encoded lines.

    Used by UnicodeReader because the Python 2 csv module only accepts byte
    strings.
    """

    def __init__(self, fd):
        self.fd = fd

    def __iter__(self):
        return self

    def next(self):
        # builtin next() works on both Python 2 and 3 iterators, unlike the
        # previous fd.next() call which only exists on Python 2 file objects
        return next(self.fd).encode('utf-8')

    # Python 3 iterator protocol: without this alias the class advertised
    # itself as an iterator (__iter__ returns self) but could not be iterated
    __next__ = next
class UnicodeReader(object):
    """CSV reader yielding rows of unicode cell values.

    The Python 2 csv module cannot parse unicode input, so the stream is
    recoded to UTF-8 bytes (UTF8Recoder) before parsing and every cell is
    decoded back to unicode afterwards.
    """

    def __init__(self, fd, dialect='excel', **kwargs):
        self.reader = csv.reader(UTF8Recoder(fd), dialect=dialect, **kwargs)

    def next(self):
        # builtin next() works on both Python 2 (.next) and 3 (.__next__)
        # readers, unlike the previous self.reader.next() (Python 2 only)
        row = next(self.reader)
        return [s.decode('utf-8') for s in row]

    # Python 3 iterator protocol: required for `for row in reader` on py3
    __next__ = next

    def __iter__(self):
        return self
class CsvImporter(object):
    """Load CSV content from bytes, text or a file object into self.rows.

    run() returns True on success; on failure it returns False and self.error
    holds an Error instance describing the problem.
    """
    rows = None
    error = None
    error_description = None
    encoding = None

    def run(self, fd_or_str, encoding):
        """Decode fd_or_str with `encoding` ('detect' to guess via chardet)
        and parse it as CSV into self.rows; return True on success."""
        # Normalize the input into a readable file object.
        if isinstance(fd_or_str, six.binary_type):
            input_fd = io.BytesIO(fd_or_str)
        elif isinstance(fd_or_str, six.text_type):
            input_fd = io.StringIO(fd_or_str)
        elif not hasattr(fd_or_str, 'read1'):
            # file-like object without read1(): try reopening its fileno in
            # binary mode, else fall back to slurping its whole content
            try:
                input_fd = io.open(fd_or_str.fileno(), closefd=False, mode='rb')
            except Exception:
                try:
                    fd_or_str.seek(0)
                except Exception:
                    pass
                content = fd_or_str.read()
                if isinstance(content, six.text_type):
                    input_fd = io.StringIO(content)
                else:
                    input_fd = io.BytesIO(content)
        else:
            input_fd = fd_or_str
        assert hasattr(input_fd, 'read'), 'fd_or_str is not a string or a file object'

        def set_encoding(input_fd, encoding):
            """Return a text-mode wrapper of input_fd, or None on failure
            (self.error is then set)."""
            # detect StringIO
            if hasattr(input_fd, 'line_buffering'):
                return input_fd
            if encoding == 'detect':
                detector = UniversalDetector()
                try:
                    for line in input_fd:
                        detector.feed(line)
                        if detector.done:
                            break
                    else:
                        # chardet never became confident before EOF
                        self.error = Error('cannot-detect-encoding', _('Cannot detect encoding'))
                        return None
                    detector.close()
                    encoding = detector.result['encoding']
                finally:
                    # rewind so the content can be re-read for parsing
                    input_fd.seek(0)
            if not hasattr(input_fd, 'readable'):
                input_fd = io.open(input_fd.fileno(), 'rb', closefd=False)
            return io.TextIOWrapper(input_fd, encoding=encoding)

        def parse_csv():
            """Sniff the CSV dialect then load every row; return True on
            success, False with self.error set otherwise."""
            try:
                # NOTE(review): sniff() receives UTF-8 bytes here, which
                # matches the py2 csv API; on py3 Sniffer expects text —
                # confirm intended runtime before changing
                dialect = csv.Sniffer().sniff(input_fd.read().encode('utf-8'))
            except csv.Error as e:
                self.error = Error('unknown-csv-dialect', _('Unknown CSV dialect: %s') % e)
                return False
            finally:
                input_fd.seek(0)
            if not dialect:
                self.error = Error('unknown-csv-dialect', _('Unknown CSV dialect'))
                return False
            reader = UnicodeReader(input_fd, dialect)
            self.rows = list(reader)
            return True

        input_fd = set_encoding(input_fd, encoding)
        if input_fd is None:
            return False
        return parse_csv()
@attr.s
class CsvHeader(object):
    """Description of one CSV column: target field/attribute and its flags."""
    # 1-based column index in the CSV file
    column = attr.ib()
    # column name: first whitespace-separated token of the header cell
    name = attr.ib(default='')
    # True when the name maps to a User model field
    field = attr.ib(default=False, converter=bool)
    # True when the name maps to an authentic2 Attribute
    attribute = attr.ib(default=False, converter=bool)
    # the attributes below carry metadata {'flag': True}: they can be toggled
    # from the header cell with "<flag>" / "no-<flag>" tokens
    create = attr.ib(default=True, metadata={'flag': True})
    update = attr.ib(default=True, metadata={'flag': True})
    key = attr.ib(default=False, metadata={'flag': True})
    unique = attr.ib(default=False, metadata={'flag': True})
    globally_unique = attr.ib(default=False, metadata={'flag': True})
    verified = attr.ib(default=False, metadata={'flag': True})

    @property
    def flags(self):
        """List of flag names: truthy flags as-is, falsy ones as 'no-...'.

        NOTE(review): truthy names keep underscores while negated names are
        dash-separated — looks inconsistent, confirm before normalizing.
        """
        flags = []
        for attribute in attr.fields(self.__class__):
            if attribute.metadata.get('flag'):
                if getattr(self, attribute.name):
                    flags.append(attribute.name)
                else:
                    flags.append('no-' + attribute.name.replace('_', '-'))
        return flags
@attr.s
class Error(object):
    """Import error: a machine-readable code plus a human description.

    description is excluded from equality (cmp=False), so two errors with the
    same code compare equal.
    """
    code = attr.ib()
    description = attr.ib(default='', cmp=False)
@attr.s(cmp=False)
class LineError(Error):
    """Error located at a specific line/column of the CSV file."""
    # 0 means "no specific line/column"
    line = attr.ib(default=0)
    column = attr.ib(default=0)

    @classmethod
    def from_error(cls, error):
        """Promote a plain Error to a LineError (line/column default to 0)."""
        return cls(**attr.asdict(error))

    def as_error(self):
        """Strip the location, returning a plain Error."""
        return Error(self.code, self.description)

    def __eq__(self, other):
        # Comparing against a (plain or line) Error compares codes only;
        # otherwise `other` is assumed to expose code/line/column.
        # NOTE(review): isinstance(other, Error) is also True for LineError
        # instances, so the tuple branch is only hit for non-Error objects —
        # confirm that is intended.
        if isinstance(other, Error):
            return self.as_error() == other
        return (self.code, self.line, self.column) == (other.code, other.line, other.column)
class ImportUserForm(BaseUserForm):
    """User form used to validate one CSV data row."""

    def clean(self):
        # NOTE(review): super(BaseUserForm, ...) deliberately skips
        # BaseUserForm.clean() and calls its parent's clean() instead —
        # presumably to bypass BaseUserForm's extra validation during
        # import; confirm before "fixing" to super(ImportUserForm, ...).
        super(BaseUserForm, self).clean()
        # uniqueness is checked separately (see
        # UserCsvImporter.check_unique_constraints), not by the form
        self._validate_unique = False
# Reserved CSV column names used to key users on an external identifier.
SOURCE_NAME = '_source_name'
SOURCE_ID = '_source_id'
# set literal instead of set([...]): same value, idiomatic and faster
SOURCE_COLUMNS = {SOURCE_NAME, SOURCE_ID}
class ImportUserFormWithExternalId(ImportUserForm):
    """Row form adding the _source_name/_source_id external-id fields.

    The fields are injected through locals() because their names come from
    the SOURCE_NAME/SOURCE_ID constants (they start with an underscore and
    cannot be written as plain identifiers without duplicating the strings).
    """
    locals()[SOURCE_NAME] = forms.CharField(
        label=_('Source name'),
        required=False,
        validators=[
            RegexValidator(
                r'^[a-zA-Z0-9_-]+$',
                _('_source_name must contain no spaces and only letters, digits, - and _'),
                'invalid')])
    locals()[SOURCE_ID] = forms.CharField(
        label=_('Source external id'))
@attr.s
class CsvRow(object):
    """One parsed data row: its cells, errors, validity and final action."""
    # 1-based line number in the CSV file
    line = attr.ib()
    # fix: default=[] built ONE list shared by every instance constructed
    # without an explicit value; Factory(list) gives each row its own list
    cells = attr.ib(default=attr.Factory(list))
    errors = attr.ib(default=attr.Factory(list))
    is_valid = attr.ib(default=True)
    # 'create' or 'update', set by UserCsvImporter.do_import_row()
    action = attr.ib(default=None)

    def __getitem__(self, header):
        """Return the cell matching `header` (a CsvHeader or a column name)."""
        for cell in self.cells:
            if cell.header == header or cell.header.name == header:
                return cell
        # fix: when `header` is a plain string, header.name raised
        # AttributeError instead of the expected KeyError
        raise KeyError(getattr(header, 'name', header))

    def __iter__(self):
        return iter(self.cells)
@attr.s
class CsvCell(object):
    """One parsed cell, tied to the CsvHeader describing its column."""
    # 1-based line number in the CSV file
    line = attr.ib()
    header = attr.ib()
    value = attr.ib(default=None)
    # True when the row was too short to provide this column
    missing = attr.ib(default=False)
    # fix: default=[] was a single list shared across every cell built
    # without explicit errors; Factory(list) creates a fresh one per cell
    errors = attr.ib(default=attr.Factory(list))
    # 'updated' or 'nothing', set during import
    action = attr.ib(default=None)

    @property
    def column(self):
        """1-based column index, taken from the header."""
        return self.header.column
class Simulate(Exception):
    """Control-flow exception: raised inside the import transaction when
    running in simulate mode, so that atomic() rolls everything back."""
    pass
class CancelImport(Exception):
    """Raised to cancel the import of a row (see UserCsvImporter.do_import_row),
    letting the row's savepoint roll its changes back."""
    pass
class UserCsvImporter(object):
    """Create or update users in bulk from a CSV file.

    run() chains four phases: CSV parsing (CsvImporter), header-row parsing,
    data-row validation through a generated model form, then the import
    itself inside a transaction. Global and line errors accumulate in
    self.errors; created/updated/rows_with_errors summarize the outcome.
    """

    csv_importer = None
    errors = None
    headers = None
    headers_by_name = None
    rows = None
    has_errors = False
    ou = None
    updated = 0
    created = 0
    rows_with_errors = 0

    def add_error(self, line_error):
        """Record an error, normalizing plain Error instances to LineError."""
        if not hasattr(line_error, 'line'):
            line_error = LineError.from_error(line_error)
        self.errors.append(line_error)

    def run(self, fd_or_str, encoding, ou=None, simulate=False):
        """Import users from fd_or_str decoded with `encoding`.

        When simulate is True the import runs inside a transaction that is
        always rolled back. Returns True when no error occurred.
        """
        self.ou = ou or get_default_ou()
        self.errors = []
        self.csv_importer = CsvImporter()

        def parse_csv():
            if not self.csv_importer.run(fd_or_str, encoding):
                self.add_error(self.csv_importer.error)

        def do_import():
            unique_map = {}
            try:
                with atomic():
                    for row in self.rows:
                        try:
                            imported = self.do_import_row(row, unique_map)
                        except CancelImport:
                            # fix: CancelImport was never caught, so a single
                            # bad row aborted the whole import with an
                            # unhandled exception; the row's own savepoint
                            # has already rolled its changes back.
                            imported = False
                        if not imported:
                            self.rows_with_errors += 1
                    if simulate:
                        # raising makes atomic() roll the whole import back
                        raise Simulate
            except Simulate:
                pass

        for action in [
                parse_csv,
                self.parse_header_row,
                self.parse_rows,
                do_import]:
            action()
            if self.errors:
                # stop at the first failing phase
                break

        self.has_errors = self.has_errors or bool(self.errors)
        return not bool(self.errors)

    def parse_header_row(self):
        """Parse the first CSV row into self.headers and check the global
        constraints: exactly one key column, and _source_name/_source_id
        both present whenever either one is used."""
        self.headers = []
        self.headers_by_name = {}
        try:
            header_row = self.csv_importer.rows[0]
        except IndexError:
            self.add_error(Error('no-header-row', _('Missing header row')))
            return
        for i, head in enumerate(header_row):
            self.parse_header(head, column=i + 1)
        if not self.headers:
            self.add_error(Error('empty-header-row', _('Empty header row')))
            return
        key_counts = sum(1 for header in self.headers if header.key)
        if not key_counts:
            self.add_error(Error('missing-key-column', _('Missing key column')))
        if key_counts > 1:
            self.add_error(Error('too-many-key-columns', _('Too many key columns')))
        header_names = set(self.headers_by_name)
        if header_names & SOURCE_COLUMNS and not SOURCE_COLUMNS.issubset(header_names):
            self.add_error(
                Error('invalid-external-id-pair',
                      _('You must have a _source_name and a _source_id column')))

    def parse_header(self, head, column):
        """Parse one header cell of the form "name [flag ...]" into a
        CsvHeader appended to self.headers."""
        splitted = head.split()
        try:
            header = CsvHeader(column, splitted[0])
            if header.name in self.headers_by_name:
                self.add_error(
                    Error('duplicate-header', _('Header "%s" is duplicated') % header.name))
                return
            self.headers_by_name[header.name] = header
        except IndexError:
            # empty cell: keep a nameless header so column indexes stay aligned
            header = CsvHeader(column)
        else:
            if header.name in SOURCE_COLUMNS:
                if header.name == SOURCE_ID:
                    # the external id column is always the lookup key
                    header.key = True
            else:
                try:
                    if header.name in ['email', 'first_name', 'last_name', 'username']:
                        User._meta.get_field(header.name)
                        header.field = True
                        if header.name == 'email':
                            # by default email are expected to be verified
                            header.verified = True
                        if header.name == 'email' and self.email_is_unique:
                            header.unique = True
                            if app_settings.A2_EMAIL_IS_UNIQUE:
                                header.globally_unique = True
                        if header.name == 'username' and self.username_is_unique:
                            header.unique = True
                            if app_settings.A2_USERNAME_IS_UNIQUE:
                                header.globally_unique = True
                except FieldDoesNotExist:
                    pass
                if not header.field:
                    # not a model field: maybe a declared Attribute
                    try:
                        Attribute.objects.get(name=header.name)
                        header.attribute = True
                    except Attribute.DoesNotExist:
                        pass
        self.headers.append(header)
        if (not (header.field or header.attribute)
                and header.name not in SOURCE_COLUMNS):
            self.add_error(LineError('unknown-or-missing-attribute',
                                     _('unknown or missing attribute "%s"') % head,
                                     line=1, column=column))
            return
        for flag in splitted[1:]:
            if header.name in SOURCE_COLUMNS:
                self.add_error(LineError(
                    'flag-forbidden-on-source-columns',
                    _('You cannot set flags on _source_name and _source_id columns'),
                    line=1))
                break
            value = True
            if flag.startswith('no-'):
                value = False
                flag = flag[3:]
            flag = flag.replace('-', '_')
            try:
                # only CsvHeader attributes marked with metadata 'flag' may
                # be toggled from the header cell
                if not getattr(attr.fields(CsvHeader), flag).metadata['flag']:
                    raise TypeError
                setattr(header, flag, value)
            except (AttributeError, TypeError, KeyError):
                # fix: the flag name was never interpolated into the message
                self.add_error(LineError(
                    'unknown-flag', _('unknown flag "%s"') % flag, line=1, column=column))

    def parse_rows(self):
        """Validate every data row through a model form built from the
        declared headers; populate self.rows with CsvRow instances."""
        base_form_class = ImportUserForm
        if SOURCE_NAME in self.headers_by_name:
            base_form_class = ImportUserFormWithExternalId
        # list() so the fields argument also works with py3 dict views
        form_class = modelform_factory(
            User, fields=list(self.headers_by_name.keys()), form=base_form_class)
        rows = self.rows = []
        for i, row in enumerate(self.csv_importer.rows[1:]):
            # CSV lines are 1-based and line 1 is the header row
            csv_row = self.parse_row(form_class, row, line=i + 2)
            self.has_errors = self.has_errors or not csv_row.is_valid
            rows.append(csv_row)

    def parse_row(self, form_class, row, line):
        """Build a CsvRow from the raw cell values, validated by form_class."""
        data = {}
        for header in self.headers:
            try:
                data[header.name] = row[header.column - 1]
            except IndexError:
                # short row: the corresponding cell is marked as missing
                pass
        form = form_class(data=data)
        form.is_valid()

        def get_form_errors(form, name):
            return [Error('data-error', six.text_type(value))
                    for value in form.errors.get(name, [])]

        cells = [
            CsvCell(
                line=line,
                header=header,
                value=data.get(header.name),
                missing=header.name not in data,
                errors=get_form_errors(form, header.name))
            for header in self.headers]
        cell_errors = any(bool(cell.errors) for cell in cells)
        errors = get_form_errors(form, '__all__')
        return CsvRow(
            line=line,
            cells=cells,
            errors=errors,
            is_valid=not bool(cell_errors or errors))

    @property
    def email_is_unique(self):
        """True when emails must be unique, globally or within the OU."""
        return app_settings.A2_EMAIL_IS_UNIQUE or self.ou.email_is_unique

    @property
    def username_is_unique(self):
        """True when usernames must be unique, globally or within the OU."""
        return app_settings.A2_USERNAME_IS_UNIQUE or self.ou.username_is_unique

    def check_unique_constraints(self, row, unique_map, user=None):
        """Check unique/globally-unique/key columns, both inside the CSV
        (unique_map, shared across rows) and against the database.

        `user` is the user being updated, excluded from database lookups so
        its own current values do not count as conflicts.
        """
        ou_users = User.objects.filter(ou=self.ou)
        users = User.objects.all()
        if user:
            users = users.exclude(pk=user.pk)
            ou_users = ou_users.exclude(pk=user.pk)
        errors = []
        # intra-file uniqueness
        for cell in row:
            header = cell.header
            if header.name == SOURCE_ID:
                unique_key = (SOURCE_ID, row[SOURCE_NAME].value, cell.value)
            elif header.key or header.globally_unique or header.unique:
                unique_key = (header.name, cell.value)
            else:
                continue
            if unique_key in unique_map:
                # fix: the message promises the line of the FIRST occurrence,
                # which is unique_map[unique_key], not the current row's line
                errors.append(
                    Error('unique-constraint-failed',
                          _('Unique constraint on column "%(column)s" failed: '
                            'value already appear on line %(line)d')
                          % {'column': header.name, 'line': unique_map[unique_key]}))
            else:
                unique_map[unique_key] = row.line
        # database uniqueness
        for cell in row:
            if (not cell.header.globally_unique and not cell.header.unique) \
                    or (user and not cell.header.update):
                continue
            qs = ou_users
            if cell.header.globally_unique:
                qs = users
            if cell.header.field:
                unique = not qs.filter(**{cell.header.name: cell.value}).exists()
            elif cell.header.attribute:
                atvs = AttributeValue.objects.filter(
                    attribute__name=cell.header.name, content=cell.value)
                unique = not qs.filter(attribute_values__in=atvs).exists()
            else:
                # fix: `unique` could be left over from a previous iteration
                # (or unbound) for headers that are neither field nor attribute
                continue
            if not unique:
                errors.append(
                    Error('unique-constraint-failed',
                          _('Unique constraint on column "%s" failed') % cell.header.name))
        row.errors.extend(errors)
        row.is_valid = row.is_valid and not bool(errors)
        return not bool(errors)

    @atomic
    def do_import_row(self, row, unique_map):
        """Create or update the user matching `row`; return True on success.

        Runs in its own savepoint (@atomic) so a failed row does not poison
        the surrounding transaction.
        """
        if not row.is_valid:
            return False

        # find the single key column (guaranteed by parse_header_row)
        for header in self.headers:
            if header.key:
                header_key = header
                break
        else:
            assert False, 'should not happen'

        if header_key.name == SOURCE_ID:
            # lookup by external id
            source_name = row[SOURCE_NAME].value
            source_id = row[SOURCE_ID].value
            # fix: key_value was unbound on this branch, crashing the
            # "too many users" error path below with a NameError
            key_value = source_id
            userexternalids = UserExternalId.objects.filter(
                source=source_name, external_id=source_id)
            users = User.objects.filter(userexternalid__in=userexternalids)[:2]
        else:
            # lookup by field/attribute
            key_value = row[header_key].value
            if header_key.field:
                users = User.objects.filter(
                    **{header_key.name: key_value})
            elif header_key.attribute:
                atvs = AttributeValue.objects.filter(
                    attribute__name=header_key.name, content=key_value)
                users = User.objects.filter(attribute_values__in=atvs)
            users = users[:2]

        if users:
            row.action = 'update'
        else:
            row.action = 'create'
        if len(users) > 1:
            row.errors.append(
                Error('key-matches-too-many-users',
                      _('Key value "%s" matches too many users') % key_value))
            return False

        user = None
        if users:
            user = users[0]

        if not self.check_unique_constraints(row, unique_map, user=user):
            return False

        if not user:
            user = User(ou=self.ou)

        # apply model-field columns
        for cell in row.cells:
            if not cell.header.field:
                continue
            if (row.action == 'create' and cell.header.create) \
                    or (row.action == 'update' and cell.header.update):
                if getattr(user, cell.header.name) != cell.value:
                    setattr(user, cell.header.name, cell.value)
                    if cell.header.name == 'email' and cell.header.verified:
                        user.email_verified = True
                    cell.action = 'updated'
                    continue
            cell.action = 'nothing'
        user.save()

        # fix: only create the external-id mapping for NEW users; on update
        # the mapping already exists and creating it again always hit the
        # unique index (IntegrityError) below
        if header_key.name == SOURCE_ID and row.action == 'create':
            try:
                UserExternalId.objects.create(user=user,
                                              source=source_name,
                                              external_id=source_id)
            except IntegrityError:
                # should never happen since we have a unique index...
                source_full_id = '%s.%s' % (source_name, source_id)
                self.errors.append(
                    Error('external-id-already-exist',
                          _('External id "%s" already exists') % source_full_id))
                raise CancelImport

        # apply attribute columns, through the (verified) attributes proxy
        for cell in row.cells:
            if cell.header.field or not cell.header.attribute:
                continue
            if (row.action == 'create' and cell.header.create) \
                    or (row.action == 'update' and cell.header.update):
                attributes = user.attributes
                if cell.header.verified:
                    attributes = user.verified_attributes
                if getattr(attributes, cell.header.name) != cell.value:
                    setattr(attributes, cell.header.name, cell.value)
                    cell.action = 'updated'
                    continue
            cell.action = 'nothing'

        # bump self.created or self.updated ('create'/'update' + 'd')
        setattr(self, row.action + 'd', getattr(self, row.action + 'd') + 1)
        return True