From e6b2e5dbf4e956461d36528cae4c0760edd50e31 Mon Sep 17 00:00:00 2001 From: Valentin Deniaud Date: Mon, 21 Sep 2020 15:11:52 +0200 Subject: [PATCH] api: add find duplicate users endpoint (#46424) --- src/authentic2/api_views.py | 33 +++++++ src/authentic2/app_settings.py | 6 ++ src/authentic2/custom_user/managers.py | 42 ++++++++- .../migrations/0028_trigram_unaccent_index.py | 81 +++++++++++++++++ src/authentic2/settings.py | 1 + src/authentic2/utils/lookups.py | 24 ++++++ tests/test_api.py | 86 +++++++++++++++++++ tests/test_large_userbase.py | 29 ++++++- tests/test_migrations.py | 16 ++++ 9 files changed, 314 insertions(+), 4 deletions(-) create mode 100644 src/authentic2/migrations/0028_trigram_unaccent_index.py create mode 100644 src/authentic2/utils/lookups.py diff --git a/src/authentic2/api_views.py b/src/authentic2/api_views.py index cabf579d7..9adebefe1 100644 --- a/src/authentic2/api_views.py +++ b/src/authentic2/api_views.py @@ -523,6 +523,10 @@ class BaseUserSerializer(serializers.ModelSerializer): exclude = ('user_permissions', 'groups') +class DuplicateUserSerializer(BaseUserSerializer): + duplicate_distance = serializers.FloatField(required=True, source='dist') + + class SlugFromNameDefault: requires_context = False serializer_instance = None @@ -800,6 +804,35 @@ class UsersAPI(api_mixins.GetOrCreateMixinView, HookMixin, ExceptionHandlerMixin utils.send_email_change_email(user, serializer.validated_data['email'], request=request) return Response({'result': 1}) + @action(detail=False, methods=['get'], + permission_classes=(DjangoPermission('custom_user.search_user'),)) + def find_duplicates(self, request): + serializer = self.get_serializer(data=request.query_params, partial=True) + if not serializer.is_valid(): + response = { + 'results': [], + 'errors': serializer.errors + } + return Response(response, status.HTTP_400_BAD_REQUEST) + data = serializer.validated_data + + first_name = data.get('first_name') + last_name = data.get('last_name') + if not 
(first_name and last_name): + response = { + 'results': [], + 'errors': 'first_name and last_name parameters are mandatory.', + } + return Response(response, status.HTTP_400_BAD_REQUEST) + + attributes = data.pop('attributes', {}) + birthdate = attributes.get('birthdate') + qs = User.objects.find_duplicates(first_name, last_name, birthdate=birthdate) + + return Response({ + 'results': DuplicateUserSerializer(qs, many=True).data, + }) + class RolesAPI(api_mixins.GetOrCreateMixinView, ExceptionHandlerMixin, ModelViewSet): permission_classes = (permissions.IsAuthenticated,) diff --git a/src/authentic2/app_settings.py b/src/authentic2/app_settings.py index 1a95f4fa1..949cdd3fc 100644 --- a/src/authentic2/app_settings.py +++ b/src/authentic2/app_settings.py @@ -342,6 +342,12 @@ default_settings = dict( A2_TOKEN_EXISTS_WARNING=Setting( default=True, definition='If an active token exists, warn user before generating a new one.'), + A2_DUPLICATES_THRESHOLD=Setting( + default=0.7, + definition='Trigram similarity threshold for considering user as duplicate.'), + A2_DUPLICATES_BIRTHDATE_BONUS=Setting( + default=0.3, + definition='Bonus in case of birthdate match (no bonus is 0, max is 1).'), ) app_settings = AppSettings(default_settings) diff --git a/src/authentic2/custom_user/managers.py b/src/authentic2/custom_user/managers.py index de512ce51..0832b31b1 100644 --- a/src/authentic2/custom_user/managers.py +++ b/src/authentic2/custom_user/managers.py @@ -16,14 +16,20 @@ import datetime import logging +import unicodedata -from django.db import models, transaction +from django.contrib.contenttypes.models import ContentType +from django.contrib.postgres.search import TrigramDistance +from django.db import models, transaction, connection +from django.db.models import F, Value, FloatField, Subquery, OuterRef +from django.db.models.functions import Lower, Coalesce from django.utils import six from django.utils import timezone from django.contrib.auth.models import BaseUserManager 
from authentic2 import app_settings -from authentic2.models import Attribute +from authentic2.models import Attribute, AttributeValue +from authentic2.utils.lookups import Unaccent, ImmutableConcat class UserQuerySet(models.QuerySet): @@ -72,6 +78,38 @@ class UserQuerySet(models.QuerySet): self = self.distinct() return self + def find_duplicates(self, first_name, last_name, birthdate=None): + with connection.cursor() as cursor: + cursor.execute( + "SET pg_trgm.similarity_threshold = %f" % app_settings.A2_DUPLICATES_THRESHOLD + ) + + name = '%s %s' % (first_name, last_name) + name = unicodedata.normalize('NFKD', name).encode('ascii', 'ignore').decode('ascii').lower() + + qs = self.annotate(name=Lower(Unaccent(ImmutableConcat('first_name', Value(' '), 'last_name')))) + qs = qs.filter(name__trigram_similar=name) + qs = qs.annotate(dist=TrigramDistance('name', name)) + qs = qs.order_by('dist') + qs = qs[:5] + + # alter distance according to additional parameters + if birthdate: + bonus = app_settings.A2_DUPLICATES_BIRTHDATE_BONUS + content_type = ContentType.objects.get_for_model(self.model) + same_birthdate = AttributeValue.objects.filter( + object_id=OuterRef('pk'), + content_type=content_type, + attribute__kind='birthdate', + content=birthdate + ).annotate(bonus=Value(1 - bonus, output_field=FloatField())) + qs = qs.annotate(dist=Coalesce( + Subquery(same_birthdate.values('bonus'), output_field=FloatField()) * F('dist'), + F('dist') + )) + + return qs + @transaction.atomic def cleanup(self, threshold=600, timestamp=None): '''Delete all deleted users for more than 10 minutes.''' diff --git a/src/authentic2/migrations/0028_trigram_unaccent_index.py b/src/authentic2/migrations/0028_trigram_unaccent_index.py new file mode 100644 index 000000000..4c1e89c01 --- /dev/null +++ b/src/authentic2/migrations/0028_trigram_unaccent_index.py @@ -0,0 +1,81 @@ +# -*- coding: utf-8 -*- +# Generated by Django 1.11.18 on 2020-09-17 15:38 +from __future__ import unicode_literals + 
+from django.db import migrations, transaction +from django.db.migrations.operations.base import Operation +from django.db.utils import InternalError, OperationalError, ProgrammingError + + +class SafeExtensionOperation(Operation): + reversible = True + + def state_forwards(self, app_label, state): + pass + + def database_forwards(self, app_label, schema_editor, from_state, to_state): + if schema_editor.connection.vendor != 'postgresql': + return + try: + with transaction.atomic(): + try: + schema_editor.execute('CREATE EXTENSION IF NOT EXISTS %s SCHEMA public' % self.name) + except (OperationalError, ProgrammingError): + # OperationalError if the extension is not available + # ProgrammingError in case of denied permission + RunSQLIfExtension.extensions_installed = False + except InternalError: + # InternalError (current transaction is aborted, commands ignored + # until end of transaction block) would be raised when django- + # tenant-schemas sets search_path. + RunSQLIfExtension.extensions_installed = False + + def database_backwards(self, app_label, schema_editor, from_state, to_state): + try: + with transaction.atomic(): + schema_editor.execute('DROP EXTENSION IF EXISTS %s' % self.name) + except InternalError: + # Raised when other objects depend on the extension. This happens in a multitenant + # context, where extension is installed in schema "public" but referenced in others (via + # public.gist_trgm_ops). In this case, do nothing, as the query should be successful + # when last tenant is processed. 
+ pass + + +class RunSQLIfExtension(migrations.RunSQL): + extensions_installed = True + + def __getattribute__(self, name): + if name == 'sql' and not self.extensions_installed: + return migrations.RunSQL.noop + return object.__getattribute__(self, name) + + +class UnaccentExtension(SafeExtensionOperation): + name = 'unaccent' + + +class TrigramExtension(SafeExtensionOperation): + name = 'pg_trgm' + + +class Migration(migrations.Migration): + + dependencies = [ + ('authentic2', '0027_remove_deleteduser'), + ] + + operations = [ + TrigramExtension(), + UnaccentExtension(), + RunSQLIfExtension( + sql=["CREATE OR REPLACE FUNCTION immutable_unaccent(text) RETURNS varchar AS $$ " + "SELECT public.unaccent('public.unaccent',$1::text); $$ LANGUAGE 'sql' IMMUTABLE"], + reverse_sql=['DROP FUNCTION IF EXISTS immutable_unaccent(text)'] + ), + RunSQLIfExtension( + sql=["CREATE INDEX custom_user_name_gist_idx ON custom_user_user USING gist " + "(LOWER(immutable_unaccent(first_name || ' ' || last_name)) public.gist_trgm_ops)"], + reverse_sql=['DROP INDEX IF EXISTS custom_user_name_gist_idx'], + ), + ] diff --git a/src/authentic2/settings.py b/src/authentic2/settings.py index bca5618b4..c0ea6d9ef 100644 --- a/src/authentic2/settings.py +++ b/src/authentic2/settings.py @@ -128,6 +128,7 @@ INSTALLED_APPS = ( 'django.contrib.messages', 'django.contrib.admin', 'django.contrib.humanize', + 'django.contrib.postgres', 'django_select2', 'django_tables2', 'mellon', diff --git a/src/authentic2/utils/lookups.py b/src/authentic2/utils/lookups.py new file mode 100644 index 000000000..135c9143b --- /dev/null +++ b/src/authentic2/utils/lookups.py @@ -0,0 +1,24 @@ +from django.contrib.postgres.lookups import Unaccent as PGUnaccent +from django.db.models import Func +from django.db.models.functions import Concat, ConcatPair as DjConcatPair + +class Unaccent(PGUnaccent): + function = 'immutable_unaccent' + + +class ConcatPair(DjConcatPair): + """Django ConcatPair does not implement as_postgresql, 
using CONCAT as a default. + + But we need immutable concatenation, || being immutable while CONCAT is not. + """ + def as_postgresql(self, compiler, connection): + return super(ConcatPair, self).as_sql( + compiler, connection, template='%(expressions)s', arg_joiner=' || ' + ) + + +class ImmutableConcat(Concat): + def _paired(self, expressions): + if len(expressions) == 2: + return ConcatPair(*expressions) + return ConcatPair(expressions[0], self._paired(expressions[1:])) diff --git a/tests/test_api.py b/tests/test_api.py index 120f4f791..5c0916cc8 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -1839,3 +1839,89 @@ def test_free_text_search(app, admin, settings): resp = app.get('/api/users/?q=10/02/1982') assert len(resp.json['results']) == 1 assert resp.json['results'][0]['id'] == user.id + + +def test_find_duplicates(app, admin, settings): + settings.LANGUAGE_CODE = 'fr' + app.authorization = ('Basic', (admin.username, admin.username)) + + first_name = 'Jean-Kévin' + last_name = 'Du Château' + user = User.objects.create(first_name=first_name, last_name=last_name) + + exact_match = { + 'first_name': first_name, + 'last_name': last_name, + } + resp = app.get('/api/users/find_duplicates/', params=exact_match) + assert resp.json['results'][0]['id'] == user.id + assert resp.json['results'][0]['duplicate_distance'] == 0 + + typo = { + 'first_name': 'Jean Kévin', + 'last_name': 'Du Châtau', + } + resp = app.get('/api/users/find_duplicates/', params=typo) + assert resp.json['results'][0]['id'] == user.id + assert resp.json['results'][0]['duplicate_distance'] > 0 + + typo = { + 'first_name': 'Jean Kévin', + 'last_name': 'Château', + } + resp = app.get('/api/users/find_duplicates/', params=typo) + assert resp.json['results'][0]['id'] == user.id + + other_person = { + 'first_name': 'Jean-Kévin', + 'last_name': 'Du Chêne', + } + user = User.objects.create(first_name='Éléonore', last_name='âêîôû') + resp = app.get('/api/users/find_duplicates/', params=other_person) 
+ assert len(resp.json['results']) == 0 + + other_person = { + 'first_name': 'Pierre', + 'last_name': 'Du Château', + } + resp = app.get('/api/users/find_duplicates/', params=other_person) + assert len(resp.json['results']) == 0 + + +def test_find_duplicates_unaccent(app, admin, settings): + settings.LANGUAGE_CODE = 'fr' + app.authorization = ('Basic', (admin.username, admin.username)) + + user = User.objects.create(first_name='Éléonore', last_name='âêîôû') + + resp = app.get('/api/users/find_duplicates/', params={'first_name': 'Eleonore', 'last_name': 'aeiou'}) + assert resp.json['results'][0]['id'] == user.id + + +def test_find_duplicates_birthdate(app, admin, settings): + settings.LANGUAGE_CODE = 'fr' + app.authorization = ('Basic', (admin.username, admin.username)) + + Attribute.objects.create(kind='birthdate', name='birthdate', label='birthdate', required=False, searchable=True) + + user = User.objects.create(first_name='Jean', last_name='Dupont') + homonym = User.objects.create(first_name='Jean', last_name='Dupont') + user.attributes.birthdate = datetime.date(1980, 1, 2) + homonym.attributes.birthdate = datetime.date(1980, 1, 3) + + params = { + 'first_name': 'Jeanne', + 'last_name': 'Dupont', + } + resp = app.get('/api/users/find_duplicates/', params=params) + assert len(resp.json['results']) == 2 + + params['birthdate'] = '1980-01-2', + resp = app.get('/api/users/find_duplicates/', params=params) + assert len(resp.json['results']) == 2 + assert resp.json['results'][0]['id'] == user.pk + + params['birthdate'] = '1980-01-3', + resp = app.get('/api/users/find_duplicates/', params=params) + assert len(resp.json['results']) == 2 + assert resp.json['results'][0]['id'] == homonym.pk diff --git a/tests/test_large_userbase.py b/tests/test_large_userbase.py index 567a5df69..2010112bc 100644 --- a/tests/test_large_userbase.py +++ b/tests/test_large_userbase.py @@ -64,5 +64,30 @@ def large_userbase(db, freezer): for user_id in user_ids) -def 
test_large_userbase(large_userbase): - pass +def test_large_userbase_find_duplicates(large_userbase, app, admin): + app.authorization = ('Basic', (admin.username, admin.username)) + + user = User.objects.first() + params = { + 'first_name': user.first_name, + 'last_name': user.last_name, + } + + for i in range(100): + resp = app.get('/api/users/find_duplicates/', params=params) + assert len(resp.json['results']) >= 1 + + +def test_large_userbase_find_duplicates_with_birthdate(large_userbase, app, admin): + app.authorization = ('Basic', (admin.username, admin.username)) + + user = User.objects.first() + params = { + 'first_name': user.first_name, + 'last_name': user.last_name, + 'birthdate': str(user.attributes.birthdate), + } + + for i in range(100): + resp = app.get('/api/users/find_duplicates/', params=params) + assert len(resp.json['results']) >= 1 diff --git a/tests/test_migrations.py b/tests/test_migrations.py index eba7ff351..86b12ff9e 100644 --- a/tests/test_migrations.py +++ b/tests/test_migrations.py @@ -14,6 +14,10 @@ # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . 
+import mock + +from django.db.utils import ProgrammingError + def test_migration_0019_add_user_deleted(transactional_db, migration): old_apps = migration.before([ @@ -33,3 +37,15 @@ def test_migration_0019_add_user_deleted(transactional_db, migration): NewUser = new_apps.get_model('custom_user', 'User') new_user = NewUser.objects.get(id=user.id) assert new_user.deleted + + +def test_migration_0028_trigram_unaccent_index(transactional_db, migration): + migration.before([('authentic2', '0027_remove_deleteduser')]) + + def programming_error(*args, **kwargs): + raise ProgrammingError + + # when an error occurs, ensure migration runs anyway without complaining + with mock.patch('django.db.backends.postgresql.schema.DatabaseSchemaEditor.execute') as mocked: + mocked.side_effect = programming_error + migration.apply([('authentic2', '0028_trigram_unaccent_index')])