api: add find duplicate users endpoint (#46424)

This commit is contained in:
Valentin Deniaud 2020-09-21 15:11:52 +02:00
parent 458712039c
commit e6b2e5dbf4
9 changed files with 314 additions and 4 deletions

View File

@ -523,6 +523,10 @@ class BaseUserSerializer(serializers.ModelSerializer):
exclude = ('user_permissions', 'groups')
class DuplicateUserSerializer(BaseUserSerializer):
duplicate_distance = serializers.FloatField(required=True, source='dist')
class SlugFromNameDefault:
requires_context = False
serializer_instance = None
@ -800,6 +804,35 @@ class UsersAPI(api_mixins.GetOrCreateMixinView, HookMixin, ExceptionHandlerMixin
utils.send_email_change_email(user, serializer.validated_data['email'], request=request)
return Response({'result': 1})
@action(detail=False, methods=['get'],
permission_classes=(DjangoPermission('custom_user.search_user'),))
def find_duplicates(self, request):
serializer = self.get_serializer(data=request.query_params, partial=True)
if not serializer.is_valid():
response = {
'results': [],
'errors': serializer.errors
}
return Response(response, status.HTTP_400_BAD_REQUEST)
data = serializer.validated_data
first_name = data.get('first_name')
last_name = data.get('last_name')
if not (first_name and last_name):
response = {
'results': [],
'errors': 'first_name and last_name parameters are mandatory.',
}
return Response(response, status.HTTP_400_BAD_REQUEST)
attributes = data.pop('attributes', {})
birthdate = attributes.get('birthdate')
qs = User.objects.find_duplicates(first_name, last_name, birthdate=birthdate)
return Response({
'results': DuplicateUserSerializer(qs, many=True).data,
})
class RolesAPI(api_mixins.GetOrCreateMixinView, ExceptionHandlerMixin, ModelViewSet):
permission_classes = (permissions.IsAuthenticated,)

View File

@ -342,6 +342,12 @@ default_settings = dict(
A2_TOKEN_EXISTS_WARNING=Setting(
default=True,
definition='If an active token exists, warn user before generating a new one.'),
A2_DUPLICATES_THRESHOLD=Setting(
default=0.7,
definition='Trigram similarity threshold for considering user as duplicate.'),
A2_DUPLICATES_BIRTHDATE_BONUS=Setting(
default=0.3,
definition='Bonus in case of birthdate match (no bonus is 0, max is 1).'),
)
app_settings = AppSettings(default_settings)

View File

@ -16,14 +16,20 @@
import datetime
import logging
import unicodedata
from django.db import models, transaction
from django.contrib.contenttypes.models import ContentType
from django.contrib.postgres.search import TrigramDistance
from django.db import models, transaction, connection
from django.db.models import F, Value, FloatField, Subquery, OuterRef
from django.db.models.functions import Lower, Coalesce
from django.utils import six
from django.utils import timezone
from django.contrib.auth.models import BaseUserManager
from authentic2 import app_settings
from authentic2.models import Attribute
from authentic2.models import Attribute, AttributeValue
from authentic2.utils.lookups import Unaccent, ImmutableConcat
class UserQuerySet(models.QuerySet):
@ -72,6 +78,38 @@ class UserQuerySet(models.QuerySet):
self = self.distinct()
return self
def find_duplicates(self, first_name, last_name, birthdate=None):
with connection.cursor() as cursor:
cursor.execute(
"SET pg_trgm.similarity_threshold = %f" % app_settings.A2_DUPLICATES_THRESHOLD
)
name = '%s %s' % (first_name, last_name)
name = unicodedata.normalize('NFKD', name).encode('ascii', 'ignore').decode('ascii').lower()
qs = self.annotate(name=Lower(Unaccent(ImmutableConcat('first_name', Value(' '), 'last_name'))))
qs = qs.filter(name__trigram_similar=name)
qs = qs.annotate(dist=TrigramDistance('name', name))
qs = qs.order_by('dist')
qs = qs[:5]
# alter distance according to additionnal parameters
if birthdate:
bonus = app_settings.A2_DUPLICATES_BIRTHDATE_BONUS
content_type = ContentType.objects.get_for_model(self.model)
same_birthdate = AttributeValue.objects.filter(
object_id=OuterRef('pk'),
content_type=content_type,
attribute__kind='birthdate',
content=birthdate
).annotate(bonus=Value(1 - bonus, output_field=FloatField()))
qs = qs.annotate(dist=Coalesce(
Subquery(same_birthdate.values('bonus'), output_field=FloatField()) * F('dist'),
F('dist')
))
return qs
@transaction.atomic
def cleanup(self, threshold=600, timestamp=None):
'''Delete all deleted users for more than 10 minutes.'''

View File

@ -0,0 +1,81 @@
# -*- coding: utf-8 -*-
# Generated by Django 1.11.18 on 2020-09-17 15:38
from __future__ import unicode_literals
from django.db import migrations, transaction
from django.db.migrations.operations.base import Operation
from django.db.utils import InternalError, OperationalError, ProgrammingError
class SafeExtensionOperation(Operation):
reversible = True
def state_forwards(self, app_label, state):
pass
def database_forwards(self, app_label, schema_editor, from_state, to_state):
if schema_editor.connection.vendor != 'postgresql':
return
try:
with transaction.atomic():
try:
schema_editor.execute('CREATE EXTENSION IF NOT EXISTS %s SCHEMA public' % self.name)
except (OperationalError, ProgrammingError):
# OperationalError if the extension is not available
# ProgrammingError in case of denied permission
RunSQLIfExtension.extensions_installed = False
except InternalError:
# InternalError (current transaction is aborted, commands ignored
# until end of transaction block) would be raised when django-
# tenant-schemas set search_path.
RunSQLIfExtension.extensions_installed = False
def database_backwards(self, app_label, schema_editor, from_state, to_state):
try:
with transaction.atomic():
schema_editor.execute('DROP EXTENSION IF EXISTS %s' % self.name)
except InternalError:
# Raised when other objects depend on the extension. This happens in a multitenant
# context, where extension in installed in schema "public" but referenced in others (via
# public.gist_trgm_ops). In this case, do nothing, as the query should be successful
# when last tenant is processed.
pass
class RunSQLIfExtension(migrations.RunSQL):
extensions_installed = True
def __getattribute__(self, name):
if name == 'sql' and not self.extensions_installed:
return migrations.RunSQL.noop
return object.__getattribute__(self, name)
class UnaccentExtension(SafeExtensionOperation):
name = 'unaccent'
class TrigramExtension(SafeExtensionOperation):
name = 'pg_trgm'
class Migration(migrations.Migration):
dependencies = [
('authentic2', '0027_remove_deleteduser'),
]
operations = [
TrigramExtension(),
UnaccentExtension(),
RunSQLIfExtension(
sql=["CREATE OR REPLACE FUNCTION immutable_unaccent(text) RETURNS varchar AS $$ "
"SELECT public.unaccent('public.unaccent',$1::text); $$ LANGUAGE 'sql' IMMUTABLE"],
reverse_sql=['DROP FUNCTION IF EXISTS immutable_unaccent(text)']
),
RunSQLIfExtension(
sql=["CREATE INDEX custom_user_name_gist_idx ON custom_user_user USING gist "
"(LOWER(immutable_unaccent(first_name || ' ' || last_name)) public.gist_trgm_ops)"],
reverse_sql=['DROP INDEX IF EXISTS custom_user_name_gist_idx'],
),
]

View File

@ -128,6 +128,7 @@ INSTALLED_APPS = (
'django.contrib.messages',
'django.contrib.admin',
'django.contrib.humanize',
'django.contrib.postgres',
'django_select2',
'django_tables2',
'mellon',

View File

@ -0,0 +1,24 @@
from django.contrib.postgres.lookups import Unaccent as PGUnaccent
from django.db.models import Func
from django.db.models.functions import Concat, ConcatPair as DjConcatPair
class Unaccent(PGUnaccent):
function = 'immutable_unaccent'
class ConcatPair(DjConcatPair):
"""Django ConcatPair does not implement as_postgresql, using CONCAT as a default.
But we need immutable concatenation, || being immutable while CONCAT is not.
"""
def as_postgresql(self, compiler, connection):
return super(ConcatPair, self).as_sql(
compiler, connection, template='%(expressions)s', arg_joiner=' || '
)
class ImmutableConcat(Concat):
def _paired(self, expressions):
if len(expressions) == 2:
return ConcatPair(*expressions)
return ConcatPair(expressions[0], self._paired(expressions[1:]))

View File

@ -1839,3 +1839,89 @@ def test_free_text_search(app, admin, settings):
resp = app.get('/api/users/?q=10/02/1982')
assert len(resp.json['results']) == 1
assert resp.json['results'][0]['id'] == user.id
def test_find_duplicates(app, admin, settings):
settings.LANGUAGE_CODE = 'fr'
app.authorization = ('Basic', (admin.username, admin.username))
first_name = 'Jean-Kévin'
last_name = 'Du Château'
user = User.objects.create(first_name=first_name, last_name=last_name)
exact_match = {
'first_name': first_name,
'last_name': last_name,
}
resp = app.get('/api/users/find_duplicates/', params=exact_match)
assert resp.json['results'][0]['id'] == user.id
assert resp.json['results'][0]['duplicate_distance'] == 0
typo = {
'first_name': 'Jean Kévin',
'last_name': 'Du Châtau',
}
resp = app.get('/api/users/find_duplicates/', params=typo)
assert resp.json['results'][0]['id'] == user.id
assert resp.json['results'][0]['duplicate_distance'] > 0
typo = {
'first_name': 'Jean Kévin',
'last_name': 'Château',
}
resp = app.get('/api/users/find_duplicates/', params=typo)
assert resp.json['results'][0]['id'] == user.id
other_person = {
'first_name': 'Jean-Kévin',
'last_name': 'Du Chêne',
}
user = User.objects.create(first_name='Éléonore', last_name='âêîôû')
resp = app.get('/api/users/find_duplicates/', params=other_person)
assert len(resp.json['results']) == 0
other_person = {
'first_name': 'Pierre',
'last_name': 'Du Château',
}
resp = app.get('/api/users/find_duplicates/', params=other_person)
assert len(resp.json['results']) == 0
def test_find_duplicates_unaccent(app, admin, settings):
settings.LANGUAGE_CODE = 'fr'
app.authorization = ('Basic', (admin.username, admin.username))
user = User.objects.create(first_name='Éléonore', last_name='âêîôû')
resp = app.get('/api/users/find_duplicates/', params={'first_name': 'Eleonore', 'last_name': 'aeiou'})
assert resp.json['results'][0]['id'] == user.id
def test_find_duplicates_birthdate(app, admin, settings):
settings.LANGUAGE_CODE = 'fr'
app.authorization = ('Basic', (admin.username, admin.username))
Attribute.objects.create(kind='birthdate', name='birthdate', label='birthdate', required=False, searchable=True)
user = User.objects.create(first_name='Jean', last_name='Dupont')
homonym = User.objects.create(first_name='Jean', last_name='Dupont')
user.attributes.birthdate = datetime.date(1980, 1, 2)
homonym.attributes.birthdate = datetime.date(1980, 1, 3)
params = {
'first_name': 'Jeanne',
'last_name': 'Dupont',
}
resp = app.get('/api/users/find_duplicates/', params=params)
assert len(resp.json['results']) == 2
params['birthdate'] = '1980-01-2',
resp = app.get('/api/users/find_duplicates/', params=params)
assert len(resp.json['results']) == 2
assert resp.json['results'][0]['id'] == user.pk
params['birthdate'] = '1980-01-3',
resp = app.get('/api/users/find_duplicates/', params=params)
assert len(resp.json['results']) == 2
assert resp.json['results'][0]['id'] == homonym.pk

View File

@ -64,5 +64,30 @@ def large_userbase(db, freezer):
for user_id in user_ids)
def test_large_userbase(large_userbase):
pass
def test_large_userbase_find_duplicates(large_userbase, app, admin):
app.authorization = ('Basic', (admin.username, admin.username))
user = User.objects.first()
params = {
'first_name': user.first_name,
'last_name': user.last_name,
}
for i in range(100):
resp = app.get('/api/users/find_duplicates/', params=params)
assert len(resp.json['results']) >= 1
def test_large_userbase_find_duplicates_with_birthdate(large_userbase, app, admin):
app.authorization = ('Basic', (admin.username, admin.username))
user = User.objects.first()
params = {
'first_name': user.first_name,
'last_name': user.last_name,
'birthdate': str(user.attributes.birthdate),
}
for i in range(100):
resp = app.get('/api/users/find_duplicates/', params=params)
assert len(resp.json['results']) >= 1

View File

@ -14,6 +14,10 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
import mock
from django.db.utils import ProgrammingError
def test_migration_0019_add_user_deleted(transactional_db, migration):
old_apps = migration.before([
@ -33,3 +37,15 @@ def test_migration_0019_add_user_deleted(transactional_db, migration):
NewUser = new_apps.get_model('custom_user', 'User')
new_user = NewUser.objects.get(id=user.id)
assert new_user.deleted
def test_migration_0028_trigram_unaccent_index(transactional_db, migration):
migration.before([('authentic2', '0027_remove_deleteduser')])
def programming_error(*args, **kwargs):
raise ProgrammingError
# when an error occurs, ensure migration runs anyway without complaining
with mock.patch('django.db.backends.postgresql.schema.DatabaseSchemaEditor.execute') as mocked:
mocked.side_effect = programming_error
migration.apply([('authentic2', '0028_trigram_unaccent_index')])