summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorBertrand Bordage <bordage.bertrand@gmail.com>2017-06-03 13:18:20 (GMT)
committerBertrand Bordage <bordage.bertrand@gmail.com>2017-06-03 13:18:20 (GMT)
commit3b5dde8d3063b5ae5450599d7064ac7c3f10f0fb (patch)
tree0b0248306936b6410aa5a921ff57cb3a7cd069ba
parentcb877355ac4f2cb60de85a907372f672e31fbb18 (diff)
downloaddjango-cachalot-3b5dde8d3063b5ae5450599d7064ac7c3f10f0fb.zip
django-cachalot-3b5dde8d3063b5ae5450599d7064ac7c3f10f0fb.tar.gz
django-cachalot-3b5dde8d3063b5ae5450599d7064ac7c3f10f0fb.tar.bz2
Adds tests for RawSQL.
-rw-r--r--cachalot/tests/read.py20
-rw-r--r--cachalot/tests/write.py45
2 files changed, 65 insertions, 0 deletions
diff --git a/cachalot/tests/read.py b/cachalot/tests/read.py
index 2d7c782..42f73a8 100644
--- a/cachalot/tests/read.py
+++ b/cachalot/tests/read.py
@@ -12,6 +12,7 @@ from django.contrib.contenttypes.models import ContentType
from django.core.cache import cache
from django.db import connection, transaction
from django.db.models import Count
+from django.db.models.expressions import RawSQL
from django.db.transaction import TransactionManagementError
from django.test import (
TransactionTestCase, skipUnlessDBFeature, override_settings)
@@ -393,6 +394,25 @@ class ReadTestCase(TransactionTestCase):
self.assertListEqual(data7, data8)
self.assertListEqual(data7, [])
+ def test_raw_subquery(self):
+ raw_sql = RawSQL('SELECT id FROM auth_permission WHERE id = %s',
+ (self.t1__permission.pk,))
+ with self.assertNumQueries(1):
+ data3 = list(Test.objects.filter(permission=raw_sql))
+ with self.assertNumQueries(0):
+ data4 = list(Test.objects.filter(permission=raw_sql))
+ self.assertListEqual(data4, data3)
+ self.assertListEqual(data4, [self.t1])
+
+ with self.assertNumQueries(1):
+ data5 = list(Test.objects.filter(
+ pk__in=Test.objects.filter(permission=raw_sql)))
+ with self.assertNumQueries(0):
+ data6 = list(Test.objects.filter(
+ pk__in=Test.objects.filter(permission=raw_sql)))
+ self.assertListEqual(data6, data5)
+ self.assertListEqual(data6, [self.t1])
+
def test_aggregate(self):
Test.objects.create(name='test3', owner=self.user)
with self.assertNumQueries(1):
diff --git a/cachalot/tests/write.py b/cachalot/tests/write.py
index 0823ff6..aa43425 100644
--- a/cachalot/tests/write.py
+++ b/cachalot/tests/write.py
@@ -9,6 +9,7 @@ from django.core.exceptions import MultipleObjectsReturned
from django.core.management import call_command
from django.db import connection, transaction
from django.db.models import Count
+from django.db.models.expressions import RawSQL
from django.test import TransactionTestCase, skipUnlessDBFeature
from .models import Test, TestParent, TestChild
@@ -467,6 +468,50 @@ class WriteTestCase(TransactionTestCase):
data12 = list(User.objects.exclude(user_permissions=None))
self.assertListEqual(data12, [])
+ def test_invalidate_raw_subquery(self):
+ permission = Permission.objects.first()
+ raw_sql = RawSQL('SELECT id FROM auth_permission WHERE id = %s',
+ (permission.pk,))
+ with self.assertNumQueries(1):
+ data1 = list(Test.objects.filter(permission=raw_sql))
+ with self.assertNumQueries(0):
+ data2 = list(Test.objects.filter(permission=raw_sql))
+ self.assertListEqual(data2, data1)
+ self.assertListEqual(data2, [])
+
+ test = Test.objects.create(name='test', permission=permission)
+
+ with self.assertNumQueries(1):
+ data3 = list(Test.objects.filter(
+ pk__in=Test.objects.filter(permission=raw_sql)))
+ with self.assertNumQueries(0):
+ data4 = list(Test.objects.filter(
+ pk__in=Test.objects.filter(permission=raw_sql)))
+ self.assertListEqual(data4, data3)
+ self.assertListEqual(data4, [test])
+
+ Permission.objects.first().save()
+
+ with self.assertNumQueries(1):
+ data5 = list(Test.objects.filter(
+ pk__in=Test.objects.filter(permission=raw_sql)))
+ with self.assertNumQueries(0):
+ data6 = list(Test.objects.filter(
+ pk__in=Test.objects.filter(permission=raw_sql)))
+ self.assertListEqual(data6, data5)
+ self.assertListEqual(data6, [test])
+
+ test.delete()
+
+ with self.assertNumQueries(1):
+ data7 = list(Test.objects.filter(
+ pk__in=Test.objects.filter(permission=raw_sql)))
+ with self.assertNumQueries(0):
+ data8 = list(Test.objects.filter(
+ pk__in=Test.objects.filter(permission=raw_sql)))
+ self.assertListEqual(data8, data7)
+ self.assertListEqual(data8, [])
+
def test_invalidate_nested_subqueries(self):
with self.assertNumQueries(1):
data1 = list(