authentic/src/authentic2/manager/user_export.py

149 lines
4.5 KiB
Python

# authentic2 - versatile identity manager
# Copyright (C) 2010-2021 Entr'ouvert
#
# This program is free software: you can redistribute it and/or modify it
# under the terms of the GNU Affero General Public License as published
# by the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
import collections
import datetime
import os
import pickle
import uuid
import tablib
from django.contrib.auth import get_user_model
from django.contrib.contenttypes.models import ContentType
from django.core.files.storage import default_storage
from authentic2.manager.resources import UserResource
from authentic2.models import Attribute, AttributeValue
from authentic2.utils.misc import batch_queryset
def get_user_dataset(qs):
user_resource = UserResource()
fields = user_resource._meta.export_order + ('email_verified', 'is_active', 'modified')
attributes = [attr.name for attr in Attribute.objects.all()]
headers = fields + tuple('attribute_%s' % attr for attr in attributes)
at_mapping = {a.id: a for a in Attribute.objects.all()}
avs = (
AttributeValue.objects.filter(content_type=ContentType.objects.get_for_model(get_user_model()))
.filter(attribute__disabled=False)
.values()
)
user_attrs = collections.defaultdict(dict)
for av in avs:
user_attrs[av['object_id']][at_mapping[av['attribute_id']].name] = av['content']
def iso(rec):
if rec is None or rec == {}:
return ''
if hasattr(rec, 'strftime'):
if isinstance(rec, datetime.datetime):
_format = "%Y-%m-%d %H:%M:%S"
else:
_format = "%Y-%m-%d"
return rec.strftime(_format)
return rec
def create_record(user):
record = []
for field in fields:
if field == 'roles':
value = user_resource.dehydrate_roles(user)
else:
value = getattr(user, field)
record.append(value)
attr_d = user_attrs[user.pk]
for attr in attributes:
record.append(attr_d.get(attr))
return [iso(x) for x in record]
dataset = tablib.Dataset(headers=headers)
for user in qs:
dataset.append(create_record(user))
return dataset
class UserExport:
def __init__(self, uuid):
self.uuid = uuid
self.path = os.path.join(self.base_path(), self.uuid)
self.export_path = os.path.join(self.path, 'export.csv')
self.progress_path = os.path.join(self.path, 'progress')
@classmethod
def base_path(self):
path = default_storage.path('user_exports')
if not os.path.exists(path):
os.makedirs(path)
return path
@property
def exists(self):
return os.path.exists(self.path)
@classmethod
def new(cls):
export = cls(str(uuid.uuid4()))
os.makedirs(export.path)
return export
@property
def csv(self):
return open(self.export_path)
def set_export_content(self, content):
with open(self.export_path, 'w') as f:
f.write(content)
@property
def progress(self):
progress = 0
if os.path.exists(self.progress_path):
with open(self.progress_path) as f:
progress = f.read()
return int(progress) if progress else 0
def set_progress(self, progress):
with open(self.progress_path, 'w') as f:
f.write(str(progress))
def export_users_to_file(uuid, query):
export = UserExport(uuid)
qs = get_user_model().objects.all()
qs.set_trigram_similarity_threshold()
qs.query = query
qs = qs.select_related('ou')
qs = qs.prefetch_related('roles', 'roles__parent_relation__parent')
count = qs.count() or 1
def callback(progress):
export.set_progress(round(progress / count * 100))
qs = batch_queryset(qs, progress_callback=callback)
dataset = get_user_dataset(qs)
if hasattr(dataset, 'csv'):
# compatiblity for tablib < 0.11
csv = dataset.csv
else:
csv = dataset.export('csv')
export.set_export_content(csv)
export.set_progress(100)