# authentic/src/authentic2/csv_import.py


# authentic2 - versatile identity manager
# Copyright (C) 2010-2019 Entr'ouvert
#
# This program is free software: you can redistribute it and/or modify it
# under the terms of the GNU Affero General Public License as published
# by the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
from __future__ import unicode_literals
import csv
import io
from chardet.universaldetector import UniversalDetector
import attr
from django import forms
from django.core.exceptions import FieldDoesNotExist
from django.core.validators import RegexValidator
from django.db import IntegrityError, models
from django.db.transaction import atomic
from django.utils import six
from django.utils.encoding import force_bytes
from django.utils.encoding import force_text
from django.utils.translation import ugettext as _
from django_rbac.utils import get_role_model
from authentic2 import app_settings
from authentic2.a2_rbac.utils import get_default_ou
from authentic2.custom_user.models import User
from authentic2.forms.profile import modelform_factory, BaseUserForm
from authentic2.models import Attribute, AttributeValue, UserExternalId
from authentic2.utils import send_password_reset_mail
Role = get_role_model()
# http://www.attrs.org/en/stable/changelog.html :
# 19.2.0
# ------
# The cmp argument to attr.s() and attr.ib() is now deprecated.
# Please use eq to add equality methods (__eq__ and __ne__) and order to add
# ordering methods (__lt__, __le__, __gt__, and __ge__) instead - just like
# with dataclasses.
#
# DeprecationWarning: The usage of `cmp` is deprecated and will be removed on
# or after 2021-06-01. Please use `eq` and `order` instead.
# description = attr.ib(default='', cmp=False)
#
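# attrib() and attrs() below are thin wrappers around attr.ib()/attr.s() that
# translate the legacy cmp keyword into eq/order when attrs >= 19.2 is
# installed, so the declarations further down work with old and new versions
# of the library.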
def attrib(**kwargs):
    if [int(x) for x in attr.__version__.split('.')] >= [19, 2]:
        if 'cmp' in kwargs:
            kwargs['eq'] = kwargs['cmp']
            kwargs['order'] = kwargs['cmp']
            del kwargs['cmp']
    return attr.ib(**kwargs)


def attrs(func=None, **kwargs):
    if [int(x) for x in attr.__version__.split('.')] >= [19, 2]:
        if 'cmp' in kwargs:
            kwargs['eq'] = kwargs['cmp']
            kwargs['order'] = kwargs['cmp']
            del kwargs['cmp']
    if func is None:
        return attr.s(**kwargs)
    else:
        return attr.s(func, **kwargs)


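# Codec helpers: the csv module works on bytes under Python 2 and on text
# under Python 3. UTF8Recoder re-encodes the input line by line and
# UnicodeReader decodes each parsed cell back to text, so callers always get
# unicode rows whatever the runtime.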
class UTF8Recoder(object):
    def __init__(self, fd):
        self.fd = fd

    def __iter__(self):
        return self

    def __next__(self):
        if six.PY2:
            return self.fd.next().encode('utf-8')
        else:
            return force_text(self.fd.__next__().encode('utf-8'))

    next = __next__


class UnicodeReader(object):
    def __init__(self, fd, dialect='excel', **kwargs):
        self.reader = csv.reader(UTF8Recoder(fd), dialect=dialect, **kwargs)

    def __next__(self):
        if six.PY2:
            row = self.reader.next()
            return [s.decode('utf-8') for s in row]
        else:
            row = self.reader.__next__()
            return [force_bytes(s).decode('utf-8') for s in row]

    def __iter__(self):
        return self

    next = __next__


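# Low-level CSV loader: CsvImporter.run() accepts bytes, text or a file
# object, optionally autodetects the character encoding with chardet, sniffs
# the CSV dialect and stores the decoded rows in self.rows. On failure it sets
# self.error and returns False.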
class CsvImporter(object):
    rows = None
    error = None
    error_description = None
    encoding = None

    def run(self, fd_or_str, encoding):
        if isinstance(fd_or_str, six.binary_type):
            input_fd = io.BytesIO(fd_or_str)
        elif isinstance(fd_or_str, six.text_type):
            input_fd = io.StringIO(fd_or_str)
        elif not hasattr(fd_or_str, 'read1'):
            try:
                input_fd = io.open(fd_or_str.fileno(), closefd=False, mode='rb')
            except Exception:
                try:
                    fd_or_str.seek(0)
                except Exception:
                    pass
                content = fd_or_str.read()
                if isinstance(content, six.text_type):
                    input_fd = io.StringIO(content)
                else:
                    input_fd = io.BytesIO(content)
        else:
            input_fd = fd_or_str

        assert hasattr(input_fd, 'read'), 'fd_or_str is not a string or a file object'

        def set_encoding(input_fd, encoding):
            # detect StringIO
            if hasattr(input_fd, 'line_buffering'):
                return input_fd
            if encoding == 'detect':
                detector = UniversalDetector()
                try:
                    for line in input_fd:
                        detector.feed(line)
                        if detector.done:
                            break
                    else:
                        self.error = Error('cannot-detect-encoding', _('Cannot detect encoding'))
                        return None
                    detector.close()
                    encoding = detector.result['encoding']
                finally:
                    input_fd.seek(0)
            if not hasattr(input_fd, 'readable'):
                input_fd = io.open(input_fd.fileno(), 'rb', closefd=False)
            return io.TextIOWrapper(input_fd, encoding=encoding)

        def parse_csv(input_fd):
            try:
                content = force_text(input_fd.read().encode('utf-8'))
            except UnicodeDecodeError:
                self.error = Error('bad-encoding', _('Bad encoding'))
                return False
            try:
                dialect = csv.Sniffer().sniff(content)
            except csv.Error as e:
                self.error = Error('unknown-csv-dialect', _('Unknown CSV dialect: %s') % e)
                return False
            finally:
                input_fd.seek(0)
            if not dialect:
                self.error = Error('unknown-csv-dialect', _('Unknown CSV dialect'))
                return False
            reader = UnicodeReader(input_fd, dialect)
            self.rows = list(reader)
            return True

        with input_fd:
            final_input_fd = set_encoding(input_fd, encoding)
            if final_input_fd is None:
                return False
            return parse_csv(final_input_fd)


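# One parsed header cell. Header cells use the "name flag flag ..." syntax,
# e.g. "email unique update"; boolean attributes carrying the 'flag' metadata
# can be toggled from the CSV with "flag" or "no-flag".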
@attrs
class CsvHeader(object):
    column = attrib()
    name = attrib(default='')
    field = attrib(default=False, converter=bool)
    attribute = attrib(default=False, converter=bool)
    create = attrib(default=True, metadata={'flag': True})
    update = attrib(default=True, metadata={'flag': True})
    key = attrib(default=False, metadata={'flag': True})
    unique = attrib(default=False, metadata={'flag': True})
    globally_unique = attrib(default=False, metadata={'flag': True})
    verified = attrib(default=False, metadata={'flag': True})
    delete = attrib(default=False, metadata={'flag': True})
    clear = attrib(default=False, metadata={'flag': True})

    @property
    def flags(self):
        flags = []
        for attribute in attr.fields(self.__class__):
            if attribute.metadata.get('flag'):
                if getattr(self, attribute.name, attribute.default):
                    flags.append(attribute.name)
                else:
                    flags.append('no-' + attribute.name.replace('_', '-'))
        return flags


@attrs
class Error(object):
    code = attrib()
    description = attrib(default='', cmp=False)


@attrs(cmp=False)
class LineError(Error):
    line = attrib(default=0)
    column = attrib(default=0)

    @classmethod
    def from_error(cls, error):
        return cls(**attr.asdict(error))

    def as_error(self):
        return Error(self.code, self.description)

    def __eq__(self, other):
        if isinstance(other, Error):
            return self.as_error() == other
        return (self.code, self.line, self.column) == (other.code, other.line, other.column)


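# Reserved column names: _source_name/_source_id link a row to a
# UserExternalId record (_source_id acts as the key column), _role_name or
# _role_slug assign a role, and @registration selects a post-creation action
# such as sending the password reset email.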
SOURCE_NAME = '_source_name'
SOURCE_ID = '_source_id'
SOURCE_COLUMNS = set([SOURCE_NAME, SOURCE_ID])
ROLE_NAME = '_role_name'
ROLE_SLUG = '_role_slug'
REGISTRATION = '@registration'
REGISTRATION_RESET_EMAIL = 'send-email'
SPECIAL_COLUMNS = SOURCE_COLUMNS | {ROLE_NAME, ROLE_SLUG, REGISTRATION}
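

# The import forms validate one CSV row each. Fields whose names are not valid
# Python identifiers (_source_name, _role_name, @registration, ...) are
# injected through locals() in the class body.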
class ImportUserForm(BaseUserForm):
    locals()[ROLE_NAME] = forms.CharField(
        label=_('Role name'),
        required=False)
    locals()[ROLE_SLUG] = forms.CharField(
        label=_('Role slug'),
        required=False)

    choices = [
        (REGISTRATION_RESET_EMAIL, _('Email user so they can set a password')),
    ]
    locals()[REGISTRATION] = forms.ChoiceField(
        choices=choices,
        label=_('Registration option'),
        required=False)

    def clean(self):
        super(BaseUserForm, self).clean()
        self._validate_unique = False


class ImportUserFormWithExternalId(ImportUserForm):
    locals()[SOURCE_NAME] = forms.CharField(
        label=_('Source name'),
        required=False,
        validators=[
            RegexValidator(
                r'^[a-zA-Z0-9_-]+$',
                _('_source_name must contain no spaces and only letters, digits, - and _'),
                'invalid')])
    locals()[SOURCE_ID] = forms.CharField(
        label=_('Source external id'))


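# CsvRow/CsvCell hold the parsed and validated content of one data line; cells
# keep per-column errors and the action ('updated'/'nothing') applied to them.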
@attrs
class CsvRow(object):
    line = attrib()
    cells = attrib(default=[])
    errors = attrib(default=[])
    is_valid = attrib(default=True)
    action = attrib(default=None)
    user_first_seen = attrib(default=True)

    def __getitem__(self, header):
        for cell in self.cells:
            if cell.header == header or cell.header.name == header:
                return cell
        raise KeyError(header.name)

    ACTIONS = {
        'update': _('update'),
        'create': _('create'),
    }

    @property
    def action_display(self):
        return self.ACTIONS.get(self.action, self.action)

    @property
    def has_errors(self):
        return self.errors or self.has_cell_errors

    @property
    def has_cell_errors(self):
        return any(cell.errors for cell in self)

    def __iter__(self):
        return iter(self.cells)

    def __len__(self):
        return len(self.cells)


@attrs
class CsvCell(object):
    line = attrib()
    header = attrib()
    value = attrib(default=None)
    missing = attrib(default=False)
    errors = attrib(default=[])
    action = attrib(default=None)

    @property
    def column(self):
        return self.header.column


class Simulate(Exception):
    pass


class CancelImport(Exception):
    pass


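# Illustrative driver for UserCsvImporter (a sketch, not part of the module:
# the file name and the reporting below are hypothetical, the manager views
# wire this up with real uploads and organizational units):
#
#     importer = UserCsvImporter()
#     with open('users.csv', 'rb') as fd:
#         ok = importer.run(fd.read(), encoding='detect', simulate=True)
#     if not ok:
#         for error in importer.errors:
#             print(error.code, error.description)
#     print(importer.created, importer.updated, importer.rows_with_errors)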
class UserCsvImporter(object):
    csv_importer = None
    errors = None
    headers = None
    headers_by_name = None
    rows = None
    has_errors = False
    ou = None
    updated = 0
    created = 0
    rows_with_errors = 0

    def add_error(self, line_error):
        if not hasattr(line_error, 'line'):
            line_error = LineError.from_error(line_error)
        self.errors.append(line_error)

    def run(self, fd_or_str, encoding, ou=None, simulate=False):
        self.ou = ou or get_default_ou()
        self.errors = []
        self.csv_importer = CsvImporter()
        self.max_user_id = User.objects.aggregate(max=models.Max('id'))['max'] or -1

        def parse_csv():
            if not self.csv_importer.run(fd_or_str, encoding):
                self.add_error(self.csv_importer.error)

        def do_import():
            unique_map = {}
            try:
                with atomic():
                    for row in self.rows:
                        try:
                            if not self.do_import_row(row, unique_map):
                                self.rows_with_errors += 1
                                row.is_valid = False
                        except CancelImport:
                            self.rows_with_errors += 1
                        if row.errors or not row.is_valid:
                            self.has_errors = True
                    if simulate:
                        raise Simulate
            except Simulate:
                pass

        for action in [
                parse_csv,
                self.parse_header_row,
                self.parse_rows,
                do_import]:
            action()
            if self.errors:
                break

        self.has_errors = self.has_errors or bool(self.errors)
        return not bool(self.errors)

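    # Header parsing: the first CSV row declares, for each column, a field or
    # attribute name optionally followed by flags; exactly one column must
    # carry the 'key' flag (or be _source_id) to identify users.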
    def parse_header_row(self):
        self.headers = []
        self.headers_by_name = {}

        try:
            header_row = self.csv_importer.rows[0]
        except IndexError:
            self.add_error(Error('no-header-row', _('Missing header row')))
            return

        for i, head in enumerate(header_row):
            self.parse_header(head, column=i + 1)

        if not self.headers:
            self.add_error(Error('empty-header-row', _('Empty header row')))
            return

        key_counts = sum(1 for header in self.headers if header.key)
        if not key_counts:
            self.add_error(Error('missing-key-column', _('Missing key column')))
        if key_counts > 1:
            self.add_error(Error('too-many-key-columns', _('Too many key columns')))

        header_names = set(self.headers_by_name)
        if header_names & SOURCE_COLUMNS and not SOURCE_COLUMNS.issubset(header_names):
            self.add_error(
                Error('invalid-external-id-pair',
                      _('You must have a _source_name and a _source_id column')))
        if ROLE_NAME in header_names and ROLE_SLUG in header_names:
            self.add_error(
                Error('invalid-role-column',
                      _('Either specify role names or role slugs, not both')))

    def parse_header(self, head, column):
        splitted = head.split()
        try:
            header = CsvHeader(column, splitted[0])
            if header.name in self.headers_by_name:
                self.add_error(
                    Error('duplicate-header', _('Header "%s" is duplicated') % header.name))
                return
            self.headers_by_name[header.name] = header
        except IndexError:
            header = CsvHeader(column)
        else:
            if header.name in SOURCE_COLUMNS:
                if header.name == SOURCE_ID:
                    header.key = True
            elif header.name not in SPECIAL_COLUMNS:
                try:
                    if header.name in ['email', 'first_name', 'last_name', 'username']:
                        User._meta.get_field(header.name)
                        header.field = True
                        if header.name == 'email':
                            # by default emails are expected to be verified
                            header.verified = True
                        if header.name == 'email' and self.email_is_unique:
                            header.unique = True
                            if app_settings.A2_EMAIL_IS_UNIQUE:
                                header.globally_unique = True
                        if header.name == 'username' and self.username_is_unique:
                            header.unique = True
                            if app_settings.A2_USERNAME_IS_UNIQUE:
                                header.globally_unique = True
                except FieldDoesNotExist:
                    pass
                if not header.field:
                    try:
                        attribute = Attribute.objects.get(name=header.name)  # NOQA: F841
                        header.attribute = True
                    except Attribute.DoesNotExist:
                        pass

        self.headers.append(header)

        if (not (header.field or header.attribute)
                and header.name not in SPECIAL_COLUMNS):
            self.add_error(LineError('unknown-or-missing-attribute',
                                     _('unknown or missing attribute "%s"') % head,
                                     line=1, column=column))
            return

        for flag in splitted[1:]:
            if header.name in SOURCE_COLUMNS:
                self.add_error(LineError(
                    'flag-forbidden-on-source-columns',
                    _('You cannot set flags on _source_name and _source_id columns'),
                    line=1))
                break
            value = True
            if flag.startswith('no-'):
                value = False
                flag = flag[3:]
            flag = flag.replace('-', '_')
            try:
                if not getattr(attr.fields(CsvHeader), flag).metadata['flag']:
                    raise TypeError
                setattr(header, flag, value)
            except (AttributeError, TypeError, KeyError):
                self.add_error(LineError('unknown-flag', _('unknown flag "%s"') % flag, line=1, column=column))

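    # Row parsing: a ModelForm is built dynamically from the declared columns
    # and run on every data line, so field and attribute validation mirrors
    # the regular user editing forms.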
    def parse_rows(self):
        base_form_class = ImportUserForm
        if SOURCE_NAME in self.headers_by_name:
            base_form_class = ImportUserFormWithExternalId
        form_class = modelform_factory(User, fields=self.headers_by_name.keys(), form=base_form_class)
        rows = self.rows = []
        for i, row in enumerate(self.csv_importer.rows[1:]):
            csv_row = self.parse_row(form_class, row, line=i + 2)
            self.has_errors = self.has_errors or not csv_row.is_valid
            rows.append(csv_row)

    def parse_row(self, form_class, row, line):
        data = {}
        for header in self.headers:
            try:
                data[header.name] = row[header.column - 1]
            except IndexError:
                pass

        form = form_class(data=data)
        form.is_valid()

        def get_form_errors(form, name):
            return [Error('data-error', six.text_type(value)) for value in form.errors.get(name, [])]

        cells = [
            CsvCell(
                line=line,
                header=header,
                value=form.cleaned_data.get(header.name),
                missing=header.name not in data,
                errors=get_form_errors(form, header.name))
            for header in self.headers]
        cell_errors = any(bool(cell.errors) for cell in cells)
        errors = get_form_errors(form, '__all__')
        return CsvRow(
            line=line,
            cells=cells,
            errors=errors,
            is_valid=not bool(cell_errors or errors))

    @property
    def email_is_unique(self):
        return app_settings.A2_EMAIL_IS_UNIQUE or self.ou.email_is_unique

    @property
    def username_is_unique(self):
        return app_settings.A2_USERNAME_IS_UNIQUE or self.ou.username_is_unique

    @property
    def allow_duplicate_key(self):
        return ROLE_NAME in self.headers_by_name or ROLE_SLUG in self.headers_by_name

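    # Uniqueness is enforced at two levels: within the file itself through
    # unique_map (keyed on column name and value) and against existing users
    # in the database, restricted to the target organizational unit unless the
    # column is globally unique.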
    def check_unique_constraints(self, row, unique_map, user=None):
        ou_users = User.objects.filter(ou=self.ou)
        # ignore new users
        users = User.objects.exclude(id__gt=self.max_user_id)
        if user:
            users = users.exclude(pk=user.pk)
            ou_users = ou_users.exclude(pk=user.pk)
        errors = []
        for cell in row:
            header = cell.header
            if header.name == SOURCE_ID:
                unique_key = (SOURCE_ID, row[SOURCE_NAME].value, cell.value)
            elif header.key or header.globally_unique or header.unique:
                if not cell.value:
                    # empty values are not checked
                    continue
                unique_key = (header.name, cell.value)
            else:
                continue
            if unique_key in unique_map:
                if user and self.allow_duplicate_key:
                    row.user_first_seen = False
                else:
                    errors.append(
                        Error('unique-constraint-failed',
                              _('Unique constraint on column "%(column)s" failed: '
                                'value already appears on line %(line)d') % {
                                  'column': header.name,
                                  'line': unique_map[unique_key]}))
            else:
                unique_map[unique_key] = row.line
        for cell in row:
            if (not cell.header.globally_unique and not cell.header.unique) or (user and not cell.header.update):
                continue
            if not cell.value:
                continue
            qs = ou_users
            if cell.header.globally_unique:
                qs = users
            if cell.header.field:
                unique = not qs.filter(**{cell.header.name: cell.value}).exists()
            elif cell.header.attribute:
                atvs = AttributeValue.objects.filter(attribute__name=cell.header.name, content=cell.value)
                unique = not qs.filter(attribute_values__in=atvs).exists()
            if not unique:
                if user and self.allow_duplicate_key:
                    row.user_first_seen = False
                else:
                    errors.append(
                        Error('unique-constraint-failed',
                              _('Unique constraint on column "%s" failed') % cell.header.name))
        row.errors.extend(errors)
        row.is_valid = row.is_valid and not bool(errors)
        return not bool(errors)

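    # Import of a single row, in its own savepoint: find the user through the
    # key column, re-check unique constraints, then create or update model
    # fields, attributes, roles and the registration option as allowed by the
    # per-column create/update flags.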
    @atomic
    def do_import_row(self, row, unique_map):
        if not row.is_valid:
            return False

        success = True

        for header in self.headers:
            if header.key:
                header_key = header
                break
        else:
            assert False, 'should not happen'

        user = None
        if header_key.name == SOURCE_ID:
            # lookup by external id
            source_name = row[SOURCE_NAME].value
            source_id = row[SOURCE_ID].value
            # also used by the key-matches-too-many-users error message below
            key_value = source_id
            userexternalids = UserExternalId.objects.filter(source=source_name, external_id=source_id)
            users = User.objects.filter(userexternalid__in=userexternalids)[:2]
        else:
            # lookup by field/attribute
            key_value = row[header_key].value
            if header_key.field:
                users = User.objects.filter(
                    deleted__isnull=True, **{header_key.name: key_value})
            elif header_key.attribute:
                atvs = AttributeValue.objects.filter(attribute__name=header_key.name, content=key_value)
                users = User.objects.filter(
                    attribute_values__in=atvs, deleted__isnull=True)
            users = users[:2]

        if users:
            row.action = 'update'
        else:
            row.action = 'create'

        if len(users) > 1:
            row.errors.append(
                Error('key-matches-too-many-users',
                      _('Key value "%s" matches too many users') % key_value))
            return False

        user = None
        if users:
            user = users[0]

        if not self.check_unique_constraints(row, unique_map, user=user):
            return False

        if not row.user_first_seen:
            cell = next(c for c in row.cells if c.header.name in {ROLE_NAME, ROLE_SLUG})
            return self.add_role(cell, user)

        if not user:
            user = User(ou=self.ou)
            user.set_random_password()

        for cell in row.cells:
            if not cell.header.field:
                continue
            if (row.action == 'create' and cell.header.create) or (row.action == 'update' and cell.header.update):
                if getattr(user, cell.header.name) != cell.value:
                    setattr(user, cell.header.name, cell.value)
                    if cell.header.name == 'email' and cell.header.verified:
                        user.email_verified = True
                    cell.action = 'updated'
                    continue
            cell.action = 'nothing'

        user.save()

        if header_key.name == SOURCE_ID and row.action == 'create':
            try:
                UserExternalId.objects.create(user=user,
                                              source=source_name,
                                              external_id=source_id)
            except IntegrityError:
                # should never happen since we have a unique index...
                source_full_id = '%s.%s' % (source_name, source_id)
                row.errors.append(
                    Error('external-id-already-exist',
                          _('External id "%s" already exists') % source_full_id))
                raise CancelImport

        for cell in row.cells:
            if cell.header.field or not cell.header.attribute:
                continue
            if (row.action == 'create' and cell.header.create) or (row.action == 'update' and cell.header.update):
                attributes = user.attributes
                if cell.header.verified:
                    attributes = user.verified_attributes
                if getattr(attributes, cell.header.name) != cell.value:
                    setattr(attributes, cell.header.name, cell.value)
                    cell.action = 'updated'
                    continue
            cell.action = 'nothing'

        for cell in row.cells:
            if cell.header.field or cell.header.attribute:
                continue
            if cell.header.name in {ROLE_NAME, ROLE_SLUG}:
                success &= self.add_role(cell, user, do_clear=True)
            elif cell.header.name == REGISTRATION and row.action == 'create':
                success &= self.registration_option(cell, user)

        setattr(self, row.action + 'd', getattr(self, row.action + 'd') + 1)
        return success

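    # Role assignment for the _role_name/_role_slug column; honours the
    # delete/clear flags and records an error when the role does not exist in
    # the target organizational unit.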
    def add_role(self, cell, user, do_clear=False):
        if not cell.value.strip():
            return True

        try:
            if cell.header.name == ROLE_NAME:
                role = Role.objects.get(name=cell.value, ou=self.ou)
            elif cell.header.name == ROLE_SLUG:
                role = Role.objects.get(slug=cell.value, ou=self.ou)
        except Role.DoesNotExist:
            cell.errors.append(
                Error('role-not-found',
                      _('Role "%s" does not exist') % cell.value))
            return False

        if cell.header.delete:
            user.roles.remove(role)
        elif cell.header.clear:
            if do_clear:
                user.roles.clear()
            user.roles.add(role)
        else:
            user.roles.add(role)
        cell.action = 'updated'
        return True

    def registration_option(self, cell, user):
        if cell.value == REGISTRATION_RESET_EMAIL:
            send_password_reset_mail(
                user,
                template_names=['authentic2/manager/user_create_registration_email',
                                'authentic2/password_reset'],
                next_url='/accounts/',
                context={'user': user})
        return True