585 lines
20 KiB
Python
585 lines
20 KiB
Python
# authentic2 - versatile identity manager
|
|
# Copyright (C) 2010-2019 Entr'ouvert
|
|
#
|
|
# This program is free software: you can redistribute it and/or modify it
|
|
# under the terms of the GNU Affero General Public License as published
|
|
# by the Free Software Foundation, either version 3 of the License, or
|
|
# (at your option) any later version.
|
|
#
|
|
# This program is distributed in the hope that it will be useful,
|
|
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
|
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
|
# GNU Affero General Public License for more details.
|
|
#
|
|
# You should have received a copy of the GNU Affero General Public License
|
|
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
|
|
|
from __future__ import unicode_literals
|
|
|
|
import csv
|
|
import io
|
|
|
|
from chardet.universaldetector import UniversalDetector
|
|
import attr
|
|
|
|
from django import forms
|
|
from django.core.exceptions import FieldDoesNotExist
|
|
from django.core.validators import RegexValidator
|
|
from django.db import IntegrityError
|
|
from django.db.transaction import atomic
|
|
from django.utils import six
|
|
from django.utils.translation import ugettext as _
|
|
|
|
from authentic2 import app_settings
|
|
from authentic2.a2_rbac.utils import get_default_ou
|
|
from authentic2.custom_user.models import User
|
|
from authentic2.forms.profile import modelform_factory, BaseUserForm
|
|
from authentic2.models import Attribute, AttributeValue, UserExternalId
|
|
|
|
|
|
class UTF8Recoder(object):
    """Iterator wrapper that re-encodes each text line to UTF-8 bytes.

    Used to feed unicode input to the Python 2 ``csv`` module, which only
    accepts byte strings.
    """

    def __init__(self, fd):
        self.fd = fd

    def __iter__(self):
        return self

    def __next__(self):
        # next(...) works with both Python 2 (.next) and Python 3
        # (.__next__) iterators, unlike the original self.fd.next() which
        # broke on any Python 3 iterator.
        return next(self.fd).encode('utf-8')

    # Python 2 iterator protocol.
    next = __next__
|
|
|
|
|
|
class UnicodeReader(object):
    """CSV reader yielding rows of unicode strings (Python 2 helper).

    The wrapped ``csv.reader`` consumes UTF-8 bytes (via :class:`UTF8Recoder`)
    and every produced cell is decoded back to unicode.
    """

    def __init__(self, fd, dialect='excel', **kwargs):
        self.reader = csv.reader(UTF8Recoder(fd), dialect=dialect, **kwargs)

    def __iter__(self):
        return self

    def __next__(self):
        # next(...) instead of .next() so the wrapped reader may follow
        # either iterator protocol.
        row = next(self.reader)
        return [s.decode('utf-8') for s in row]

    # Python 2 iterator protocol.
    next = __next__
|
|
|
|
|
|
class CsvImporter(object):
    """Load CSV content from a string or file object into a list of rows.

    After a successful :meth:`run`, ``rows`` holds the parsed rows; on
    failure ``error`` holds an :class:`Error` instance describing why.
    """

    # parsed rows (lists of unicode cells) after a successful run()
    rows = None
    # Error instance set when run() fails
    error = None
    error_description = None
    encoding = None

    def run(self, fd_or_str, encoding):
        """Parse ``fd_or_str`` (bytes, text or file object) using ``encoding``.

        ``encoding`` may be 'detect' to auto-detect with chardet. Returns
        True on success, False on failure (``self.error`` is then set).
        """
        # Normalize fd_or_str into an io stream (binary or text).
        if isinstance(fd_or_str, six.binary_type):
            input_fd = io.BytesIO(fd_or_str)
        elif isinstance(fd_or_str, six.text_type):
            input_fd = io.StringIO(fd_or_str)
        elif not hasattr(fd_or_str, 'read1'):
            # Not an io-module object: try to get a binary view of the
            # underlying file descriptor, falling back to slurping the
            # whole content into memory.
            try:
                input_fd = io.open(fd_or_str.fileno(), closefd=False, mode='rb')
            except Exception:
                try:
                    fd_or_str.seek(0)
                except Exception:
                    pass
                content = fd_or_str.read()
                if isinstance(content, six.text_type):
                    input_fd = io.StringIO(content)
                else:
                    input_fd = io.BytesIO(content)
        else:
            input_fd = fd_or_str

        assert hasattr(input_fd, 'read'), 'fd_or_str is not a string or a file object'

        def set_encoding(input_fd, encoding):
            # detect StringIO
            if hasattr(input_fd, 'line_buffering'):
                return input_fd

            if encoding == 'detect':
                detector = UniversalDetector()

                try:
                    for line in input_fd:
                        detector.feed(line)
                        if detector.done:
                            break
                    else:
                        # for/else: the detector never reached a conclusion
                        # before the input was exhausted.
                        self.error = Error('cannot-detect-encoding', _('Cannot detect encoding'))
                        return None
                    detector.close()
                    encoding = detector.result['encoding']
                finally:
                    # always rewind so parse_csv() re-reads from the start
                    input_fd.seek(0)

            if not hasattr(input_fd, 'readable'):
                input_fd = io.open(input_fd.fileno(), 'rb', closefd=False)
            return io.TextIOWrapper(input_fd, encoding=encoding)

        def parse_csv():
            try:
                # NOTE(review): csv.Sniffer().sniff() expects text under
                # Python 3; the .encode('utf-8') here looks Python 2
                # specific — confirm behaviour under Python 3.
                dialect = csv.Sniffer().sniff(input_fd.read().encode('utf-8'))
            except csv.Error as e:
                self.error = Error('unknown-csv-dialect', _('Unknown CSV dialect: %s') % e)
                return False
            finally:
                input_fd.seek(0)

            if not dialect:
                self.error = Error('unknown-csv-dialect', _('Unknown CSV dialect'))
                return False
            reader = UnicodeReader(input_fd, dialect)
            self.rows = list(reader)
            return True

        input_fd = set_encoding(input_fd, encoding)
        if input_fd is None:
            return False

        return parse_csv()
|
|
|
|
|
|
@attr.s
class CsvHeader(object):
    """Description of one CSV column: its position, name and behaviour flags."""

    column = attr.ib()
    name = attr.ib(default='')
    field = attr.ib(default=False, converter=bool)
    attribute = attr.ib(default=False, converter=bool)
    create = attr.ib(default=True, metadata={'flag': True})
    update = attr.ib(default=True, metadata={'flag': True})
    key = attr.ib(default=False, metadata={'flag': True})
    unique = attr.ib(default=False, metadata={'flag': True})
    globally_unique = attr.ib(default=False, metadata={'flag': True})
    verified = attr.ib(default=False, metadata={'flag': True})

    @property
    def flags(self):
        """List every flag attribute as 'name' when set, 'no-name' when not."""
        names = []
        for field_def in attr.fields(self.__class__):
            if not field_def.metadata.get('flag'):
                continue
            if getattr(self, field_def.name):
                names.append(field_def.name)
            else:
                names.append('no-' + field_def.name.replace('_', '-'))
        return names
|
|
|
|
|
|
@attr.s
class Error(object):
    """An import error identified by a machine-readable code.

    ``description`` is informational only and excluded from comparisons
    (cmp=False): two errors with the same code compare equal.
    """

    code = attr.ib()
    description = attr.ib(default='', cmp=False)
|
|
|
|
|
|
@attr.s(cmp=False)
class LineError(Error):
    """An :class:`Error` located at a line (and optionally column) of the CSV."""

    line = attr.ib(default=0)
    column = attr.ib(default=0)

    @classmethod
    def from_error(cls, error):
        """Build a LineError from a plain Error; line/column default to 0."""
        return cls(**attr.asdict(error))

    def as_error(self):
        """Return the position-less Error equivalent of this error."""
        return Error(self.code, self.description)

    def __eq__(self, other):
        # Plain Errors compare by code only; other objects by
        # (code, line, column), ignoring description.
        # NOTE(review): isinstance(other, Error) is also True for
        # LineError instances, so LineError-to-LineError comparisons take
        # the first branch and ignore line/column — confirm this is intended.
        if isinstance(other, Error):
            return self.as_error() == other
        return (self.code, self.line, self.column) == (other.code, other.line, other.column)
|
|
|
|
|
|
class ImportUserForm(BaseUserForm):
    """User form used to validate CSV cell values.

    Model-level uniqueness checks are disabled here; uniqueness is
    enforced separately by UserCsvImporter.check_unique_constraints().
    """

    def clean(self):
        # Fixed: the original called super(BaseUserForm, self).clean(),
        # which skipped BaseUserForm.clean() entirely (classic wrong
        # first-argument-to-super bug).
        super(ImportUserForm, self).clean()
        self._validate_unique = False
|
|
|
|
# Pseudo-column names used to link imported users to an external source.
SOURCE_NAME = '_source_name'
SOURCE_ID = '_source_id'
# Set literal instead of set([...]) — same value, idiomatic form.
SOURCE_COLUMNS = {SOURCE_NAME, SOURCE_ID}
|
|
|
|
|
|
class ImportUserFormWithExternalId(ImportUserForm):
    """ImportUserForm variant with the two external-id pseudo-columns."""

    # Fields are assigned through locals() so their names stay in sync
    # with the SOURCE_NAME / SOURCE_ID module constants.
    locals()[SOURCE_NAME] = forms.CharField(
        label=_('Source name'),
        required=False,
        validators=[
            RegexValidator(
                r'^[a-zA-Z0-9_-]+$',
                _('_source_name must contain no spaces and only letters, digits, - and _'),
                'invalid')])
    locals()[SOURCE_ID] = forms.CharField(
        label=_('Source external id'))
|
|
|
|
|
|
@attr.s
class CsvRow(object):
    """One parsed CSV data row: its cells, errors, validity and action."""

    line = attr.ib()
    # attr.Factory(list) replaces the original default=[], which was a
    # single list shared by every instance built without explicit cells.
    cells = attr.ib(default=attr.Factory(list))
    # same shared-mutable-default fix; check_unique_constraints() mutates
    # this list in place via row.errors.extend(...).
    errors = attr.ib(default=attr.Factory(list))
    is_valid = attr.ib(default=True)
    # 'create' or 'update', set by UserCsvImporter.do_import_row()
    action = attr.ib(default=None)

    def __getitem__(self, header):
        """Return the cell matching a CsvHeader instance or a header name."""
        for cell in self.cells:
            if cell.header == header or cell.header.name == header:
                return cell
        raise KeyError(header.name)

    def __iter__(self):
        return iter(self.cells)
|
|
|
|
|
|
@attr.s
class CsvCell(object):
    """One cell of a CSV row, tied to the header describing its column."""

    line = attr.ib()
    # the CsvHeader for this cell's column
    header = attr.ib()
    value = attr.ib(default=None)
    # True when the row was too short to contain this column
    missing = attr.ib(default=False)
    # attr.Factory(list) replaces the original default=[], a single list
    # shared by every instance built without explicit errors.
    errors = attr.ib(default=attr.Factory(list))
    action = attr.ib(default=None)

    @property
    def column(self):
        """1-based column index, taken from the header."""
        return self.header.column
|
|
|
|
|
|
class Simulate(Exception):
    """Raised to roll back the import transaction during a simulated run."""
    pass
|
|
|
|
|
|
class CancelImport(Exception):
    """Raised to abort the import of the current row."""
    pass
|
|
|
|
|
|
class UserCsvImporter(object):
    """Import users from CSV content.

    Parses a header row describing User fields/attributes and their flags,
    validates each data row through a Django form, checks in-file and
    database unique constraints, then creates or updates matching users.
    """

    csv_importer = None
    # list of LineError accumulated during run()
    errors = None
    # list of CsvHeader, one per column
    headers = None
    headers_by_name = None
    # list of CsvRow built by parse_rows()
    rows = None
    has_errors = False
    # target organizational unit
    ou = None
    # counters maintained by do_import_row()
    updated = 0
    created = 0
    rows_with_errors = 0

    def add_error(self, line_error):
        """Record an error, normalizing plain Errors to LineError."""
        if not hasattr(line_error, 'line'):
            line_error = LineError.from_error(line_error)
        self.errors.append(line_error)

    def run(self, fd_or_str, encoding, ou=None, simulate=False):
        """Run the whole import pipeline.

        :param fd_or_str: CSV content (bytes, text or file object)
        :param encoding: character encoding, or 'detect'
        :param ou: target organizational unit (defaults to the default OU)
        :param simulate: when True, roll back all database changes
        :return: True when no global error occurred
        """
        self.ou = ou or get_default_ou()
        self.errors = []
        self.csv_importer = CsvImporter()

        def parse_csv():
            if not self.csv_importer.run(fd_or_str, encoding):
                self.add_error(self.csv_importer.error)

        def do_import():
            # maps unique values to the line where they first appeared
            unique_map = {}

            try:
                with atomic():
                    for row in self.rows:
                        if not self.do_import_row(row, unique_map):
                            self.rows_with_errors += 1
                    if simulate:
                        # raising rolls the transaction back
                        raise Simulate
            except Simulate:
                pass

        # Stop at the first stage that produced errors.
        for action in [
                parse_csv,
                self.parse_header_row,
                self.parse_rows,
                do_import]:
            action()
            if self.errors:
                break

        self.has_errors = self.has_errors or bool(self.errors)
        return not bool(self.errors)

    def parse_header_row(self):
        """Parse the first CSV row into CsvHeader objects and validate them."""
        self.headers = []
        self.headers_by_name = {}

        try:
            header_row = self.csv_importer.rows[0]
        except IndexError:
            self.add_error(Error('no-header-row', _('Missing header row')))
            return

        for i, head in enumerate(header_row):
            self.parse_header(head, column=i + 1)

        if not self.headers:
            self.add_error(Error('empty-header-row', _('Empty header row')))
            return

        key_counts = sum(1 for header in self.headers if header.key)

        # exactly one key column is required to match existing users
        if not key_counts:
            self.add_error(Error('missing-key-column', _('Missing key column')))
        if key_counts > 1:
            self.add_error(Error('too-many-key-columns', _('Too many key columns')))

        # _source_name and _source_id only make sense together
        header_names = set(self.headers_by_name)
        if header_names & SOURCE_COLUMNS and not SOURCE_COLUMNS.issubset(header_names):
            self.add_error(
                Error('invalid-external-id-pair',
                      _('You must have a _source_name and a _source_id column')))

    def parse_header(self, head, column):
        """Parse one header cell of the form 'name [flag|no-flag ...]'."""
        splitted = head.split()
        try:
            header = CsvHeader(column, splitted[0])
            if header.name in self.headers_by_name:
                self.add_error(
                    Error('duplicate-header', _('Header "%s" is duplicated') % header.name))
                return
            self.headers_by_name[header.name] = header
        except IndexError:
            # empty header cell: keep a nameless placeholder
            header = CsvHeader(column)
        else:
            if header.name in SOURCE_COLUMNS:
                # the external id column is the natural key
                if header.name == SOURCE_ID:
                    header.key = True
            else:
                try:
                    if header.name in ['email', 'first_name', 'last_name', 'username']:
                        User._meta.get_field(header.name)
                        header.field = True
                        if header.name == 'email':
                            # by default email are expected to be verified
                            header.verified = True
                    if header.name == 'email' and self.email_is_unique:
                        header.unique = True
                        if app_settings.A2_EMAIL_IS_UNIQUE:
                            header.globally_unique = True
                    if header.name == 'username' and self.username_is_unique:
                        header.unique = True
                        if app_settings.A2_USERNAME_IS_UNIQUE:
                            header.globally_unique = True
                except FieldDoesNotExist:
                    pass
                if not header.field:
                    try:
                        attribute = Attribute.objects.get(name=header.name)  # NOQA: F841
                        header.attribute = True
                    except Attribute.DoesNotExist:
                        pass

        self.headers.append(header)

        if (not (header.field or header.attribute)
                and header.name not in SOURCE_COLUMNS):
            self.add_error(LineError('unknown-or-missing-attribute',
                                     _('unknown or missing attribute "%s"') % head,
                                     line=1, column=column))
            return

        for flag in splitted[1:]:
            if header.name in SOURCE_COLUMNS:
                self.add_error(LineError(
                    'flag-forbidden-on-source-columns',
                    _('You cannot set flags on _source_name and _source_id columns'),
                    line=1))
                break
            value = True
            if flag.startswith('no-'):
                value = False
                flag = flag[3:]
            flag = flag.replace('-', '_')
            try:
                # only attributes declared with metadata {'flag': True} may
                # be toggled from the header row
                if not getattr(attr.fields(CsvHeader), flag).metadata['flag']:
                    raise TypeError
                setattr(header, flag, value)
            except (AttributeError, TypeError, KeyError):
                # fixed: interpolate the offending flag into the message
                # (the original left a literal "%s" in the description)
                self.add_error(LineError('unknown-flag', _('unknown flag "%s"') % flag,
                                         line=1, column=column))

    def parse_rows(self):
        """Validate every data row through a generated model form."""
        base_form_class = ImportUserForm
        if SOURCE_NAME in self.headers_by_name:
            base_form_class = ImportUserFormWithExternalId
        form_class = modelform_factory(User, fields=self.headers_by_name.keys(), form=base_form_class)
        rows = self.rows = []
        for i, row in enumerate(self.csv_importer.rows[1:]):
            # data rows start at line 2 (line 1 is the header)
            csv_row = self.parse_row(form_class, row, line=i + 2)
            self.has_errors = self.has_errors or not csv_row.is_valid
            rows.append(csv_row)

    def parse_row(self, form_class, row, line):
        """Build a CsvRow from one raw row, collecting per-cell form errors."""
        data = {}

        for header in self.headers:
            try:
                data[header.name] = row[header.column - 1]
            except IndexError:
                # row shorter than the header row: cell is missing
                pass

        form = form_class(data=data)
        form.is_valid()

        def get_form_errors(form, name):
            return [Error('data-error', six.text_type(value)) for value in form.errors.get(name, [])]

        cells = [
            CsvCell(
                line=line,
                header=header,
                value=data.get(header.name),
                missing=header.name not in data,
                errors=get_form_errors(form, header.name))
            for header in self.headers]
        cell_errors = any(bool(cell.errors) for cell in cells)
        errors = get_form_errors(form, '__all__')
        return CsvRow(
            line=line,
            cells=cells,
            errors=errors,
            is_valid=not bool(cell_errors or errors))

    @property
    def email_is_unique(self):
        """True when email unicity is enforced globally or on the target OU."""
        return app_settings.A2_EMAIL_IS_UNIQUE or self.ou.email_is_unique

    @property
    def username_is_unique(self):
        """True when username unicity is enforced globally or on the target OU."""
        return app_settings.A2_USERNAME_IS_UNIQUE or self.ou.username_is_unique

    def check_unique_constraints(self, row, unique_map, user=None):
        """Check in-file and database unique constraints for one row.

        ``unique_map`` maps already-seen unique values to the line where
        they first appeared; ``user`` (when updating) is excluded from the
        database checks. Returns True when the row passes.
        """
        ou_users = User.objects.filter(ou=self.ou)
        users = User.objects.all()
        if user:
            users = users.exclude(pk=user.pk)
            ou_users = ou_users.exclude(pk=user.pk)
        errors = []
        # 1. duplicates inside the file itself
        for cell in row:
            header = cell.header
            if header.name == SOURCE_ID:
                unique_key = (SOURCE_ID, row[SOURCE_NAME].value, cell.value)
            elif header.key or header.globally_unique or header.unique:
                unique_key = (header.name, cell.value)
            else:
                continue
            if unique_key in unique_map:
                errors.append(
                    Error('unique-constraint-failed',
                          _('Unique constraint on column "%(column)s" failed: '
                            'value already appear on line %(line)d')
                          # fixed: report the line of the FIRST occurrence
                          # (stored in unique_map), not the current line
                          % {'column': header.name, 'line': unique_map[unique_key]}))
            else:
                unique_map[unique_key] = row.line

        # 2. duplicates against the database
        for cell in row:
            if (not cell.header.globally_unique and not cell.header.unique) or (user and not cell.header.update):
                continue
            qs = ou_users
            if cell.header.globally_unique:
                qs = users
            if cell.header.field:
                unique = not qs.filter(**{cell.header.name: cell.value}).exists()
            elif cell.header.attribute:
                atvs = AttributeValue.objects.filter(attribute__name=cell.header.name, content=cell.value)
                unique = not qs.filter(attribute_values__in=atvs).exists()
            if not unique:
                errors.append(
                    Error('unique-constraint-failed',
                          _('Unique constraint on column "%s" failed') % cell.header.name))
        row.errors.extend(errors)
        row.is_valid = row.is_valid and not bool(errors)
        return not bool(errors)

    @atomic
    def do_import_row(self, row, unique_map):
        """Create or update the user described by ``row``.

        Returns True on success; records errors on the row or the importer
        and returns False otherwise.
        """
        if not row.is_valid:
            return False

        # find the single key column (guaranteed by parse_header_row)
        for header in self.headers:
            if header.key:
                header_key = header
                break
        else:
            assert False, 'should not happen'

        user = None
        if header_key.name == SOURCE_ID:
            # lookup by external id
            source_name = row[SOURCE_NAME].value
            source_id = row[SOURCE_ID].value
            # fixed: key_value was never assigned on this branch, making the
            # "too many users" error below raise NameError
            key_value = source_id
            userexternalids = UserExternalId.objects.filter(source=source_name, external_id=source_id)
            users = User.objects.filter(userexternalid__in=userexternalids)[:2]
        else:
            # lookup by field/attribute
            key_value = row[header_key].value
            if header_key.field:
                users = User.objects.filter(
                    **{header_key.name: key_value})
            elif header_key.attribute:
                atvs = AttributeValue.objects.filter(attribute__name=header_key.name, content=key_value)
                users = User.objects.filter(attribute_values__in=atvs)
            # two matches are enough to detect ambiguity
            users = users[:2]

        if users:
            row.action = 'update'
        else:
            row.action = 'create'

        if len(users) > 1:
            row.errors.append(
                Error('key-matches-too-many-users',
                      _('Key value "%s" matches too many users') % key_value))
            return False

        user = None
        if users:
            user = users[0]

        if not self.check_unique_constraints(row, unique_map, user=user):
            return False

        if not user:
            user = User(ou=self.ou)

        # apply User model fields
        for cell in row.cells:
            if not cell.header.field:
                continue
            if (row.action == 'create' and cell.header.create) or (row.action == 'update' and cell.header.update):
                if getattr(user, cell.header.name) != cell.value:
                    setattr(user, cell.header.name, cell.value)
                    if cell.header.name == 'email' and cell.header.verified:
                        user.email_verified = True
                    cell.action = 'updated'
                    continue
            cell.action = 'nothing'

        user.save()

        # fixed: only create the external-id link for new users; the
        # original also ran on updates, where the link already exists and
        # the create() was guaranteed to fail with IntegrityError
        if row.action == 'create' and header_key.name == SOURCE_ID:
            try:
                UserExternalId.objects.create(user=user,
                                              source=source_name,
                                              external_id=source_id)
            except IntegrityError:
                # should never happen since we have a unique index...
                source_full_id = '%s.%s' % (source_name, source_id)
                self.errors.append(
                    Error('external-id-already-exist',
                          _('External id "%s" already exists') % source_full_id))
                raise CancelImport

        # apply attribute values (possibly on the verified facade)
        for cell in row.cells:
            if cell.header.field or not cell.header.attribute:
                continue
            if (row.action == 'create' and cell.header.create) or (row.action == 'update' and cell.header.update):
                attributes = user.attributes
                if cell.header.verified:
                    attributes = user.verified_attributes
                if getattr(attributes, cell.header.name) != cell.value:
                    setattr(attributes, cell.header.name, cell.value)
                    cell.action = 'updated'
                    continue
            cell.action = 'nothing'

        # bump self.created or self.updated
        setattr(self, row.action + 'd', getattr(self, row.action + 'd') + 1)
        return True
|