api: add find duplicate users endpoint (#46424)
This commit is contained in:
parent
458712039c
commit
e6b2e5dbf4
|
@ -523,6 +523,10 @@ class BaseUserSerializer(serializers.ModelSerializer):
|
|||
exclude = ('user_permissions', 'groups')
|
||||
|
||||
|
||||
class DuplicateUserSerializer(BaseUserSerializer):
|
||||
duplicate_distance = serializers.FloatField(required=True, source='dist')
|
||||
|
||||
|
||||
class SlugFromNameDefault:
|
||||
requires_context = False
|
||||
serializer_instance = None
|
||||
|
@ -800,6 +804,35 @@ class UsersAPI(api_mixins.GetOrCreateMixinView, HookMixin, ExceptionHandlerMixin
|
|||
utils.send_email_change_email(user, serializer.validated_data['email'], request=request)
|
||||
return Response({'result': 1})
|
||||
|
||||
@action(detail=False, methods=['get'],
|
||||
permission_classes=(DjangoPermission('custom_user.search_user'),))
|
||||
def find_duplicates(self, request):
|
||||
serializer = self.get_serializer(data=request.query_params, partial=True)
|
||||
if not serializer.is_valid():
|
||||
response = {
|
||||
'results': [],
|
||||
'errors': serializer.errors
|
||||
}
|
||||
return Response(response, status.HTTP_400_BAD_REQUEST)
|
||||
data = serializer.validated_data
|
||||
|
||||
first_name = data.get('first_name')
|
||||
last_name = data.get('last_name')
|
||||
if not (first_name and last_name):
|
||||
response = {
|
||||
'results': [],
|
||||
'errors': 'first_name and last_name parameters are mandatory.',
|
||||
}
|
||||
return Response(response, status.HTTP_400_BAD_REQUEST)
|
||||
|
||||
attributes = data.pop('attributes', {})
|
||||
birthdate = attributes.get('birthdate')
|
||||
qs = User.objects.find_duplicates(first_name, last_name, birthdate=birthdate)
|
||||
|
||||
return Response({
|
||||
'results': DuplicateUserSerializer(qs, many=True).data,
|
||||
})
|
||||
|
||||
|
||||
class RolesAPI(api_mixins.GetOrCreateMixinView, ExceptionHandlerMixin, ModelViewSet):
|
||||
permission_classes = (permissions.IsAuthenticated,)
|
||||
|
|
|
@ -342,6 +342,12 @@ default_settings = dict(
|
|||
A2_TOKEN_EXISTS_WARNING=Setting(
|
||||
default=True,
|
||||
definition='If an active token exists, warn user before generating a new one.'),
|
||||
A2_DUPLICATES_THRESHOLD=Setting(
|
||||
default=0.7,
|
||||
definition='Trigram similarity threshold for considering user as duplicate.'),
|
||||
A2_DUPLICATES_BIRTHDATE_BONUS=Setting(
|
||||
default=0.3,
|
||||
definition='Bonus in case of birthdate match (no bonus is 0, max is 1).'),
|
||||
)
|
||||
|
||||
app_settings = AppSettings(default_settings)
|
||||
|
|
|
@ -16,14 +16,20 @@
|
|||
|
||||
import datetime
|
||||
import logging
|
||||
import unicodedata
|
||||
|
||||
from django.db import models, transaction
|
||||
from django.contrib.contenttypes.models import ContentType
|
||||
from django.contrib.postgres.search import TrigramDistance
|
||||
from django.db import models, transaction, connection
|
||||
from django.db.models import F, Value, FloatField, Subquery, OuterRef
|
||||
from django.db.models.functions import Lower, Coalesce
|
||||
from django.utils import six
|
||||
from django.utils import timezone
|
||||
from django.contrib.auth.models import BaseUserManager
|
||||
|
||||
from authentic2 import app_settings
|
||||
from authentic2.models import Attribute
|
||||
from authentic2.models import Attribute, AttributeValue
|
||||
from authentic2.utils.lookups import Unaccent, ImmutableConcat
|
||||
|
||||
|
||||
class UserQuerySet(models.QuerySet):
|
||||
|
@ -72,6 +78,38 @@ class UserQuerySet(models.QuerySet):
|
|||
self = self.distinct()
|
||||
return self
|
||||
|
||||
def find_duplicates(self, first_name, last_name, birthdate=None):
|
||||
with connection.cursor() as cursor:
|
||||
cursor.execute(
|
||||
"SET pg_trgm.similarity_threshold = %f" % app_settings.A2_DUPLICATES_THRESHOLD
|
||||
)
|
||||
|
||||
name = '%s %s' % (first_name, last_name)
|
||||
name = unicodedata.normalize('NFKD', name).encode('ascii', 'ignore').decode('ascii').lower()
|
||||
|
||||
qs = self.annotate(name=Lower(Unaccent(ImmutableConcat('first_name', Value(' '), 'last_name'))))
|
||||
qs = qs.filter(name__trigram_similar=name)
|
||||
qs = qs.annotate(dist=TrigramDistance('name', name))
|
||||
qs = qs.order_by('dist')
|
||||
qs = qs[:5]
|
||||
|
||||
# alter distance according to additionnal parameters
|
||||
if birthdate:
|
||||
bonus = app_settings.A2_DUPLICATES_BIRTHDATE_BONUS
|
||||
content_type = ContentType.objects.get_for_model(self.model)
|
||||
same_birthdate = AttributeValue.objects.filter(
|
||||
object_id=OuterRef('pk'),
|
||||
content_type=content_type,
|
||||
attribute__kind='birthdate',
|
||||
content=birthdate
|
||||
).annotate(bonus=Value(1 - bonus, output_field=FloatField()))
|
||||
qs = qs.annotate(dist=Coalesce(
|
||||
Subquery(same_birthdate.values('bonus'), output_field=FloatField()) * F('dist'),
|
||||
F('dist')
|
||||
))
|
||||
|
||||
return qs
|
||||
|
||||
@transaction.atomic
|
||||
def cleanup(self, threshold=600, timestamp=None):
|
||||
'''Delete all deleted users for more than 10 minutes.'''
|
||||
|
|
|
@ -0,0 +1,81 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# Generated by Django 1.11.18 on 2020-09-17 15:38
|
||||
from __future__ import unicode_literals
|
||||
|
||||
from django.db import migrations, transaction
|
||||
from django.db.migrations.operations.base import Operation
|
||||
from django.db.utils import InternalError, OperationalError, ProgrammingError
|
||||
|
||||
|
||||
class SafeExtensionOperation(Operation):
|
||||
reversible = True
|
||||
|
||||
def state_forwards(self, app_label, state):
|
||||
pass
|
||||
|
||||
def database_forwards(self, app_label, schema_editor, from_state, to_state):
|
||||
if schema_editor.connection.vendor != 'postgresql':
|
||||
return
|
||||
try:
|
||||
with transaction.atomic():
|
||||
try:
|
||||
schema_editor.execute('CREATE EXTENSION IF NOT EXISTS %s SCHEMA public' % self.name)
|
||||
except (OperationalError, ProgrammingError):
|
||||
# OperationalError if the extension is not available
|
||||
# ProgrammingError in case of denied permission
|
||||
RunSQLIfExtension.extensions_installed = False
|
||||
except InternalError:
|
||||
# InternalError (current transaction is aborted, commands ignored
|
||||
# until end of transaction block) would be raised when django-
|
||||
# tenant-schemas set search_path.
|
||||
RunSQLIfExtension.extensions_installed = False
|
||||
|
||||
def database_backwards(self, app_label, schema_editor, from_state, to_state):
|
||||
try:
|
||||
with transaction.atomic():
|
||||
schema_editor.execute('DROP EXTENSION IF EXISTS %s' % self.name)
|
||||
except InternalError:
|
||||
# Raised when other objects depend on the extension. This happens in a multitenant
|
||||
# context, where extension in installed in schema "public" but referenced in others (via
|
||||
# public.gist_trgm_ops). In this case, do nothing, as the query should be successful
|
||||
# when last tenant is processed.
|
||||
pass
|
||||
|
||||
|
||||
class RunSQLIfExtension(migrations.RunSQL):
|
||||
extensions_installed = True
|
||||
|
||||
def __getattribute__(self, name):
|
||||
if name == 'sql' and not self.extensions_installed:
|
||||
return migrations.RunSQL.noop
|
||||
return object.__getattribute__(self, name)
|
||||
|
||||
|
||||
class UnaccentExtension(SafeExtensionOperation):
|
||||
name = 'unaccent'
|
||||
|
||||
|
||||
class TrigramExtension(SafeExtensionOperation):
|
||||
name = 'pg_trgm'
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
|
||||
dependencies = [
|
||||
('authentic2', '0027_remove_deleteduser'),
|
||||
]
|
||||
|
||||
operations = [
|
||||
TrigramExtension(),
|
||||
UnaccentExtension(),
|
||||
RunSQLIfExtension(
|
||||
sql=["CREATE OR REPLACE FUNCTION immutable_unaccent(text) RETURNS varchar AS $$ "
|
||||
"SELECT public.unaccent('public.unaccent',$1::text); $$ LANGUAGE 'sql' IMMUTABLE"],
|
||||
reverse_sql=['DROP FUNCTION IF EXISTS immutable_unaccent(text)']
|
||||
),
|
||||
RunSQLIfExtension(
|
||||
sql=["CREATE INDEX custom_user_name_gist_idx ON custom_user_user USING gist "
|
||||
"(LOWER(immutable_unaccent(first_name || ' ' || last_name)) public.gist_trgm_ops)"],
|
||||
reverse_sql=['DROP INDEX IF EXISTS custom_user_name_gist_idx'],
|
||||
),
|
||||
]
|
|
@ -128,6 +128,7 @@ INSTALLED_APPS = (
|
|||
'django.contrib.messages',
|
||||
'django.contrib.admin',
|
||||
'django.contrib.humanize',
|
||||
'django.contrib.postgres',
|
||||
'django_select2',
|
||||
'django_tables2',
|
||||
'mellon',
|
||||
|
|
|
@ -0,0 +1,24 @@
|
|||
from django.contrib.postgres.lookups import Unaccent as PGUnaccent
|
||||
from django.db.models import Func
|
||||
from django.db.models.functions import Concat, ConcatPair as DjConcatPair
|
||||
|
||||
class Unaccent(PGUnaccent):
|
||||
function = 'immutable_unaccent'
|
||||
|
||||
|
||||
class ConcatPair(DjConcatPair):
|
||||
"""Django ConcatPair does not implement as_postgresql, using CONCAT as a default.
|
||||
|
||||
But we need immutable concatenation, || being immutable while CONCAT is not.
|
||||
"""
|
||||
def as_postgresql(self, compiler, connection):
|
||||
return super(ConcatPair, self).as_sql(
|
||||
compiler, connection, template='%(expressions)s', arg_joiner=' || '
|
||||
)
|
||||
|
||||
|
||||
class ImmutableConcat(Concat):
|
||||
def _paired(self, expressions):
|
||||
if len(expressions) == 2:
|
||||
return ConcatPair(*expressions)
|
||||
return ConcatPair(expressions[0], self._paired(expressions[1:]))
|
|
@ -1839,3 +1839,89 @@ def test_free_text_search(app, admin, settings):
|
|||
resp = app.get('/api/users/?q=10/02/1982')
|
||||
assert len(resp.json['results']) == 1
|
||||
assert resp.json['results'][0]['id'] == user.id
|
||||
|
||||
|
||||
def test_find_duplicates(app, admin, settings):
|
||||
settings.LANGUAGE_CODE = 'fr'
|
||||
app.authorization = ('Basic', (admin.username, admin.username))
|
||||
|
||||
first_name = 'Jean-Kévin'
|
||||
last_name = 'Du Château'
|
||||
user = User.objects.create(first_name=first_name, last_name=last_name)
|
||||
|
||||
exact_match = {
|
||||
'first_name': first_name,
|
||||
'last_name': last_name,
|
||||
}
|
||||
resp = app.get('/api/users/find_duplicates/', params=exact_match)
|
||||
assert resp.json['results'][0]['id'] == user.id
|
||||
assert resp.json['results'][0]['duplicate_distance'] == 0
|
||||
|
||||
typo = {
|
||||
'first_name': 'Jean Kévin',
|
||||
'last_name': 'Du Châtau',
|
||||
}
|
||||
resp = app.get('/api/users/find_duplicates/', params=typo)
|
||||
assert resp.json['results'][0]['id'] == user.id
|
||||
assert resp.json['results'][0]['duplicate_distance'] > 0
|
||||
|
||||
typo = {
|
||||
'first_name': 'Jean Kévin',
|
||||
'last_name': 'Château',
|
||||
}
|
||||
resp = app.get('/api/users/find_duplicates/', params=typo)
|
||||
assert resp.json['results'][0]['id'] == user.id
|
||||
|
||||
other_person = {
|
||||
'first_name': 'Jean-Kévin',
|
||||
'last_name': 'Du Chêne',
|
||||
}
|
||||
user = User.objects.create(first_name='Éléonore', last_name='âêîôû')
|
||||
resp = app.get('/api/users/find_duplicates/', params=other_person)
|
||||
assert len(resp.json['results']) == 0
|
||||
|
||||
other_person = {
|
||||
'first_name': 'Pierre',
|
||||
'last_name': 'Du Château',
|
||||
}
|
||||
resp = app.get('/api/users/find_duplicates/', params=other_person)
|
||||
assert len(resp.json['results']) == 0
|
||||
|
||||
|
||||
def test_find_duplicates_unaccent(app, admin, settings):
|
||||
settings.LANGUAGE_CODE = 'fr'
|
||||
app.authorization = ('Basic', (admin.username, admin.username))
|
||||
|
||||
user = User.objects.create(first_name='Éléonore', last_name='âêîôû')
|
||||
|
||||
resp = app.get('/api/users/find_duplicates/', params={'first_name': 'Eleonore', 'last_name': 'aeiou'})
|
||||
assert resp.json['results'][0]['id'] == user.id
|
||||
|
||||
|
||||
def test_find_duplicates_birthdate(app, admin, settings):
|
||||
settings.LANGUAGE_CODE = 'fr'
|
||||
app.authorization = ('Basic', (admin.username, admin.username))
|
||||
|
||||
Attribute.objects.create(kind='birthdate', name='birthdate', label='birthdate', required=False, searchable=True)
|
||||
|
||||
user = User.objects.create(first_name='Jean', last_name='Dupont')
|
||||
homonym = User.objects.create(first_name='Jean', last_name='Dupont')
|
||||
user.attributes.birthdate = datetime.date(1980, 1, 2)
|
||||
homonym.attributes.birthdate = datetime.date(1980, 1, 3)
|
||||
|
||||
params = {
|
||||
'first_name': 'Jeanne',
|
||||
'last_name': 'Dupont',
|
||||
}
|
||||
resp = app.get('/api/users/find_duplicates/', params=params)
|
||||
assert len(resp.json['results']) == 2
|
||||
|
||||
params['birthdate'] = '1980-01-2',
|
||||
resp = app.get('/api/users/find_duplicates/', params=params)
|
||||
assert len(resp.json['results']) == 2
|
||||
assert resp.json['results'][0]['id'] == user.pk
|
||||
|
||||
params['birthdate'] = '1980-01-3',
|
||||
resp = app.get('/api/users/find_duplicates/', params=params)
|
||||
assert len(resp.json['results']) == 2
|
||||
assert resp.json['results'][0]['id'] == homonym.pk
|
||||
|
|
|
@ -64,5 +64,30 @@ def large_userbase(db, freezer):
|
|||
for user_id in user_ids)
|
||||
|
||||
|
||||
def test_large_userbase(large_userbase):
|
||||
pass
|
||||
def test_large_userbase_find_duplicates(large_userbase, app, admin):
|
||||
app.authorization = ('Basic', (admin.username, admin.username))
|
||||
|
||||
user = User.objects.first()
|
||||
params = {
|
||||
'first_name': user.first_name,
|
||||
'last_name': user.last_name,
|
||||
}
|
||||
|
||||
for i in range(100):
|
||||
resp = app.get('/api/users/find_duplicates/', params=params)
|
||||
assert len(resp.json['results']) >= 1
|
||||
|
||||
|
||||
def test_large_userbase_find_duplicates_with_birthdate(large_userbase, app, admin):
|
||||
app.authorization = ('Basic', (admin.username, admin.username))
|
||||
|
||||
user = User.objects.first()
|
||||
params = {
|
||||
'first_name': user.first_name,
|
||||
'last_name': user.last_name,
|
||||
'birthdate': str(user.attributes.birthdate),
|
||||
}
|
||||
|
||||
for i in range(100):
|
||||
resp = app.get('/api/users/find_duplicates/', params=params)
|
||||
assert len(resp.json['results']) >= 1
|
||||
|
|
|
@ -14,6 +14,10 @@
|
|||
# You should have received a copy of the GNU Affero General Public License
|
||||
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
|
||||
import mock
|
||||
|
||||
from django.db.utils import ProgrammingError
|
||||
|
||||
|
||||
def test_migration_0019_add_user_deleted(transactional_db, migration):
|
||||
old_apps = migration.before([
|
||||
|
@ -33,3 +37,15 @@ def test_migration_0019_add_user_deleted(transactional_db, migration):
|
|||
NewUser = new_apps.get_model('custom_user', 'User')
|
||||
new_user = NewUser.objects.get(id=user.id)
|
||||
assert new_user.deleted
|
||||
|
||||
|
||||
def test_migration_0028_trigram_unaccent_index(transactional_db, migration):
|
||||
migration.before([('authentic2', '0027_remove_deleteduser')])
|
||||
|
||||
def programming_error(*args, **kwargs):
|
||||
raise ProgrammingError
|
||||
|
||||
# when an error occurs, ensure migration runs anyway without complaining
|
||||
with mock.patch('django.db.backends.postgresql.schema.DatabaseSchemaEditor.execute') as mocked:
|
||||
mocked.side_effect = programming_error
|
||||
migration.apply([('authentic2', '0028_trigram_unaccent_index')])
|
||||
|
|
Loading…
Reference in New Issue