journal: permit custom prefetching (#51808)

This commit is contained in:
Benjamin Dauvergne 2021-03-09 22:25:10 +01:00
parent 32cc38c814
commit e77c98b57b
3 changed files with 62 additions and 3 deletions

View File

@ -330,13 +330,16 @@ class JournalForm(forms.Form):
first = len(page) <= limit
last = True
page = page[-limit:]
models.prefetch_events_references(page)
models.prefetch_events_references(page, prefetcher=self.prefetcher)
if page:
self.data = self.data.copy()
self.cleaned_data['after_cursor'] = self.data['after_cursor'] = page[0].cursor.minus_one()
self.cleaned_data['before_cursor'] = ''
return Page(self, page, first, last)
def prefetcher(self, model, pks):
return []
@cached_property
def date_hierarchy(self):
self.is_valid()

View File

@ -20,6 +20,7 @@ from collections import defaultdict
from contextlib import contextmanager
from datetime import datetime, timedelta
import django
from django.conf import settings
from django.contrib.auth import get_user_model
from django.contrib.contenttypes.models import ContentType
@ -458,7 +459,7 @@ class EventCursor(str):
return EventCursor('%s %s' % (self.timestamp.timestamp(), self.event_id - 1))
def prefetch_events_references(events):
def prefetch_events_references(events, prefetcher=None):
'''Prefetch references on an iterable of events, prevent N+1 queries problem.'''
grouped_references = defaultdict(set)
references = {}
@ -473,6 +474,25 @@ def prefetch_events_references(events):
content_type = ContentType.objects.get_for_id(content_type_id)
for instance in content_type.get_all_objects_for_this_type(pk__in=instance_pks):
references[(content_type_id, instance.pk)] = instance
if prefetcher:
deleted_pks = [pk for pk in instance_pks if (content_type_id, pk) not in references]
if deleted_pks:
for found_pk, instance in prefetcher(content_type.model_class(), deleted_pks):
references[(content_type_id, found_pk)] = instance
# prefetch the user column if absent
if prefetcher:
user_to_events = {}
for event in events:
if event.user is None and event.user_id:
user_to_events.setdefault(event.user_id, []).append(event)
for found_pk, instance in prefetcher(User, user_to_events.keys()):
for event in user_to_events[found_pk]:
# prevent TypeError in user's field descriptor __set__ method
if django.VERSION < (2,):
event._user_cache = instance
else:
event._state.fields_cache['user'] = instance
# assign references to events
for event in events:

View File

@ -28,7 +28,13 @@ from authentic2.a2_rbac.models import OrganizationalUnit as OU
from authentic2.a2_rbac.utils import get_default_ou
from authentic2.apps.journal.forms import JournalForm
from authentic2.apps.journal.journal import Journal
from authentic2.apps.journal.models import Event, EventType, EventTypeDefinition, clean_registry
from authentic2.apps.journal.models import (
Event,
EventType,
EventTypeDefinition,
clean_registry,
prefetch_events_references,
)
from authentic2.models import Service
User = get_user_model()
@ -146,6 +152,7 @@ def test_references(db):
assert list(event.get_typed_references(User, None)) == [None, None]
event = Event.objects.get()
assert list(event.get_typed_references(Service, User)) == [None, None]
assert event.user is None
def test_event_types(clean_event_types_definition_registry):
@ -669,3 +676,32 @@ def test_statistics_ou_with_no_service(db, freezer):
ou_with_no_service = OU.objects.create(name='Second OU')
stats = event_type_definition.get_method_statistics('month', services_ou=ou_with_no_service)
assert stats == {'x_labels': [], 'series': []}
def test_prefetcher(db):
event_type = EventType.objects.get_for_name('user.login')
for i in range(10):
user = User.objects.create()
Event.objects.create(type=event_type, user=user, references=[user])
Event.objects.create(type=event_type, user=user, references=[user])
User.objects.all().delete()
events = list(Event.objects.all())
prefetch_events_references(events)
for event in events:
assert event.user is None
assert list(event.get_typed_references(User)) == [None]
def prefetcher(model, pks):
if not issubclass(model, User):
return
for pk in pks:
yield pk, 'deleted %s' % pk
events = list(Event.objects.all())
prefetch_events_references(events, prefetcher=prefetcher)
for event in events:
s = 'deleted %s' % event.user_id
assert event.user == s
assert list(event.get_typed_references((str, User))) == [s]