custom_user: specialize free_text_search for common search terms (#49957)

This commit is contained in:
Benjamin Dauvergne 2021-01-08 12:00:46 +01:00
parent f4908a01f4
commit ab66385315
3 changed files with 169 additions and 43 deletions

View File

@ -17,68 +17,82 @@
import datetime
import logging
import unicodedata
import uuid
from django.contrib.contenttypes.models import ContentType
from django.contrib.postgres.search import TrigramDistance
from django.core.exceptions import ValidationError
from django.db import models, transaction, connection
from django.db.models import F, Value, FloatField, Subquery, OuterRef
from django.db.models import F, Value, FloatField, Subquery, OuterRef, Q
from django.db.models.functions import Lower, Coalesce
from django.utils import six
from django.utils import timezone
from django.contrib.auth.models import BaseUserManager
from django.contrib.postgres.search import SearchQuery
from authentic2 import app_settings
from authentic2.models import Attribute, AttributeValue, UserExternalId
from authentic2.models import AttributeValue, UserExternalId
from authentic2.utils.lookups import Unaccent, ImmutableConcat
from authentic2.utils.date import parse_date
from authentic2.attribute_kinds import clean_number
class UserQuerySet(models.QuerySet):
def free_text_search(self, search):
terms = search.split()
search = search.strip()
if not terms:
return self
if len(search) == 0:
return self.none()
searchable_attributes = Attribute.objects.filter(searchable=True)
queries = []
for term in terms:
q = None
if '@' in search and len(search.split()) == 1:
qs = self.filter(email__iexact=search).order_by('last_name', 'first_name')
if qs.exists():
return qs
specific_queries = []
for a in searchable_attributes:
kind = a.get_kind()
free_text_search_function = kind.get('free_text_search')
if free_text_search_function:
q = free_text_search_function(term)
if q is not None:
specific_queries.append(q & models.query.Q(attribute_values__attribute=a))
try:
guid = uuid.UUID(search)
except ValueError:
pass
else:
return self.filter(uuid=guid.hex)
# if the term is recognized by some specific attribute type, like a
# date, does not use the later generic matcher
if specific_queries:
queries.append(six.moves.reduce(models.query.Q.__or__, specific_queries))
continue
try:
phone_number = clean_number(search)
except ValidationError:
pass
else:
attribute_values = AttributeValue.objects.filter(
search_vector=SearchQuery(phone_number), attribute__kind='phone_number')
qs = self.filter(attribute_values__in=attribute_values).order_by('last_name', 'first_name')
if qs.exists():
return qs
q = (
models.query.Q(username__icontains=term)
| models.query.Q(first_name__icontains=term)
| models.query.Q(last_name__icontains=term)
| models.query.Q(email__icontains=term)
)
for a in searchable_attributes:
if a.name in ('first_name', 'last_name'):
continue
q = q | models.query.Q(
attribute_values__content__icontains=term, attribute_values__attribute=a)
queries.append(q)
self = self.filter(six.moves.reduce(models.query.Q.__and__, queries))
# search by attributes can match multiple times
if searchable_attributes:
self = self.distinct()
return self
try:
date = parse_date(search)
except ValueError:
pass
else:
attribute_values = AttributeValue.objects.filter(
search_vector=SearchQuery(date.isoformat()), attribute__kind='birthdate')
qs = self.filter(attribute_values__in=attribute_values).order_by('last_name', 'first_name')
if qs.exists():
return qs
def find_duplicates(self, first_name=None, last_name=None, fullname=None, birthdate=None):
qs = self.find_duplicates(fullname=search, limit=None)
extra_user_ids = set()
attribute_values = AttributeValue.objects.filter(search_vector=SearchQuery(search), attribute__searchable=True)
extra_user_ids.update(self.filter(attribute_values__in=attribute_values).values_list('id', flat=True))
if len(search.split()) == 1:
extra_user_ids.update(
self.filter(
Q(username__istartswith=search)
| Q(email__istartswith=search)
).values_list('id', flat=True))
if extra_user_ids:
qs = qs | self.filter(id__in=extra_user_ids)
qs = qs.order_by('dist', 'last_name', 'first_name')
return qs
def find_duplicates(self, first_name=None, last_name=None, fullname=None, birthdate=None, limit=5):
with connection.cursor() as cursor:
cursor.execute(
"SET pg_trgm.similarity_threshold = %f" % app_settings.A2_DUPLICATES_THRESHOLD
@ -96,7 +110,8 @@ class UserQuerySet(models.QuerySet):
qs = qs.filter(name__trigram_similar=name)
qs = qs.annotate(dist=TrigramDistance('name', name))
qs = qs.order_by('dist')
qs = qs[:5]
if limit is not None:
qs = qs[:limit]
# alter distance according to additionnal parameters
if birthdate:

View File

@ -0,0 +1,33 @@
# authentic2 - versatile identity manager
# Copyright (C) 2010-2020 Entr'ouvert
#
# This program is free software: you can redistribute it and/or modify it
# under the terms of the GNU Affero General Public License as published
# by the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
from datetime import datetime
from django.utils import formats
def parse_date(formatted_date):
parsed_date = None
for date_format in formats.get_format('DATE_INPUT_FORMATS'):
try:
parsed_date = datetime.strptime(formatted_date, date_format)
except ValueError:
continue
else:
break
if not parsed_date:
raise ValueError
return parsed_date.date()

View File

@ -14,10 +14,15 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
from datetime import date
from django.contrib.auth import get_user_model
from authentic2.models import Attribute
from django_rbac.utils import get_permission_model, get_role_model
import pytest
Permission = get_permission_model()
Role = get_role_model()
User = get_user_model()
@ -59,3 +64,76 @@ def test_user_mark_as_deleted(db):
User.objects.create(username='foo2', email='foo@example.net')
assert len(User.objects.filter(email='foo@example.net')) == 2
assert len(User.objects.filter(email='foo@example.net', deleted__isnull=True)) == 1
@pytest.fixture
def fts(db):
Attribute.objects.create(name='adresse', label='adresse', searchable=True, kind='string')
Attribute.objects.create(name='telephone', label='telephone', searchable=True, kind='phone_number')
Attribute.objects.create(name='dob', label='dob', searchable=True, kind='birthdate')
user1 = User.objects.create(
username='foo1234',
first_name='Jo',
last_name='darmettein',
email='jean.darmette@example.net'
)
user2 = User.objects.create(
username='bar1234',
first_name='Lea',
last_name='darmettein',
email='micheline.darmette@example.net'
)
user3 = User.objects.create(
first_name='',
last_name='peuplier',
)
user1.attributes.adresse = '4 rue des peupliers 13001 MARSEILLE'
user2.attributes.adresse = '4 rue des peupliers 13001 MARSEILLE'
user1.attributes.telephone = '0601020304'
user2.attributes.telephone = '0601020305'
user1.attributes.dob = date(1970, 1, 1)
user2.attributes.dob = date(1972, 2, 2)
return locals()
def test_fts_uuid(fts):
assert User.objects.free_text_search(fts['user1'].uuid).count() == 1
assert User.objects.free_text_search(fts['user2'].uuid).count() == 1
def test_fts_phone(fts):
assert list(User.objects.free_text_search('0601020304')) == [fts['user1']]
assert list(User.objects.free_text_search('0601020305')) == [fts['user2']]
def test_fts_dob(fts):
assert User.objects.free_text_search('01/01/1970').count() == 1
assert User.objects.free_text_search('02/02/1972').count() == 1
assert User.objects.free_text_search('03/03/1973').count() == 0
def test_fts_email(fts):
assert User.objects.free_text_search('jean.darmette@example.net').count() == 1
assert User.objects.free_text_search('micheline.darmette@example.net').count() == 1
def test_fts_username(fts):
assert User.objects.free_text_search('foo1234').count() == 1
assert User.objects.free_text_search('bar1234').count() == 1
def test_fts_trigram(fts):
assert User.objects.free_text_search('darmettein').count() == 2
# dist attribute signals queryset from find_duplicates()
assert hasattr(User.objects.free_text_search('darmettein')[0], 'dist')
assert User.objects.free_text_search('lea darmettein').count() == 1
assert hasattr(User.objects.free_text_search('darmettein')[0], 'dist')
def test_fts_legacy(fts):
assert User.objects.free_text_search('rue des peupliers').count() == 2
def test_fts_legacy_and_trigram(fts):
assert User.objects.free_text_search('peuplier').count() == 3