diff --git a/cachalot/tests/read.py b/cachalot/tests/read.py index 2d7c782..42f73a8 100644 --- a/cachalot/tests/read.py +++ b/cachalot/tests/read.py @@ -12,6 +12,7 @@ from django.contrib.contenttypes.models import ContentType from django.core.cache import cache from django.db import connection, transaction from django.db.models import Count +from django.db.models.expressions import RawSQL from django.db.transaction import TransactionManagementError from django.test import ( TransactionTestCase, skipUnlessDBFeature, override_settings) @@ -393,6 +394,25 @@ class ReadTestCase(TransactionTestCase): self.assertListEqual(data7, data8) self.assertListEqual(data7, []) + def test_raw_subquery(self): + raw_sql = RawSQL('SELECT id FROM auth_permission WHERE id = %s', + (self.t1__permission.pk,)) + with self.assertNumQueries(1): + data3 = list(Test.objects.filter(permission=raw_sql)) + with self.assertNumQueries(0): + data4 = list(Test.objects.filter(permission=raw_sql)) + self.assertListEqual(data4, data3) + self.assertListEqual(data4, [self.t1]) + + with self.assertNumQueries(1): + data5 = list(Test.objects.filter( + pk__in=Test.objects.filter(permission=raw_sql))) + with self.assertNumQueries(0): + data6 = list(Test.objects.filter( + pk__in=Test.objects.filter(permission=raw_sql))) + self.assertListEqual(data6, data5) + self.assertListEqual(data6, [self.t1]) + def test_aggregate(self): Test.objects.create(name='test3', owner=self.user) with self.assertNumQueries(1): diff --git a/cachalot/tests/write.py b/cachalot/tests/write.py index 0823ff6..aa43425 100644 --- a/cachalot/tests/write.py +++ b/cachalot/tests/write.py @@ -9,6 +9,7 @@ from django.core.exceptions import MultipleObjectsReturned from django.core.management import call_command from django.db import connection, transaction from django.db.models import Count +from django.db.models.expressions import RawSQL from django.test import TransactionTestCase, skipUnlessDBFeature from .models import Test, TestParent, TestChild @@ -467,6 +468,50 @@ class WriteTestCase(TransactionTestCase): data12 = list(User.objects.exclude(user_permissions=None)) self.assertListEqual(data12, []) + def test_invalidate_raw_subquery(self): + permission = Permission.objects.first() + raw_sql = RawSQL('SELECT id FROM auth_permission WHERE id = %s', + (permission.pk,)) + with self.assertNumQueries(1): + data1 = list(Test.objects.filter(permission=raw_sql)) + with self.assertNumQueries(0): + data2 = list(Test.objects.filter(permission=raw_sql)) + self.assertListEqual(data2, data1) + self.assertListEqual(data2, []) + + test = Test.objects.create(name='test', permission=permission) + + with self.assertNumQueries(1): + data3 = list(Test.objects.filter( + pk__in=Test.objects.filter(permission=raw_sql))) + with self.assertNumQueries(0): + data4 = list(Test.objects.filter( + pk__in=Test.objects.filter(permission=raw_sql))) + self.assertListEqual(data4, data3) + self.assertListEqual(data4, [test]) + + Permission.objects.first().save() + + with self.assertNumQueries(1): + data5 = list(Test.objects.filter( + pk__in=Test.objects.filter(permission=raw_sql))) + with self.assertNumQueries(0): + data6 = list(Test.objects.filter( + pk__in=Test.objects.filter(permission=raw_sql))) + self.assertListEqual(data6, data5) + self.assertListEqual(data6, [test]) + + test.delete() + + with self.assertNumQueries(1): + data7 = list(Test.objects.filter( + pk__in=Test.objects.filter(permission=raw_sql))) + with self.assertNumQueries(0): + data8 = list(Test.objects.filter( + pk__in=Test.objects.filter(permission=raw_sql))) + self.assertListEqual(data8, data7) + self.assertListEqual(data8, []) + def test_invalidate_nested_subqueries(self): with self.assertNumQueries(1): data1 = list(