From 542a5b60978b4400f5a63a6f6ff8eb739be959af Mon Sep 17 00:00:00 2001 From: Bertrand Bordage Date: Sat, 3 Jun 2017 17:45:26 +0200 Subject: [PATCH] Simplifies tests while increasing their robustness by explicitely checking detected table names. --- cachalot/tests/api.py | 7 +- cachalot/tests/multi_db.py | 9 +- cachalot/tests/read.py | 451 ++++++++++++-------------------- cachalot/tests/settings.py | 160 ++++------- cachalot/tests/test_utils.py | 40 +++ cachalot/tests/thread_safety.py | 9 +- cachalot/tests/transaction.py | 57 ++-- cachalot/tests/write.py | 82 ++---- 8 files changed, 296 insertions(+), 519 deletions(-) create mode 100644 cachalot/tests/test_utils.py diff --git a/cachalot/tests/api.py b/cachalot/tests/api.py index dbd3316..6e83947 100644 --- a/cachalot/tests/api.py +++ b/cachalot/tests/api.py @@ -9,18 +9,19 @@ from django.contrib.auth.models import User from django.core.cache import DEFAULT_CACHE_ALIAS, caches from django.core.management import call_command from django.db import connection, transaction, DEFAULT_DB_ALIAS -from django.template import engines, Context +from django.template import engines from django.test import TransactionTestCase from jinja2.exceptions import TemplateSyntaxError from ..api import * from .models import Test +from .test_utils import TestUtilsMixin -class APITestCase(TransactionTestCase): +class APITestCase(TestUtilsMixin, TransactionTestCase): def setUp(self): + super(APITestCase, self).setUp() self.t1 = Test.objects.create(name='test1') - self.is_sqlite = connection.vendor == 'sqlite' self.cache_alias2 = next(alias for alias in settings.CACHES if alias != DEFAULT_CACHE_ALIAS) diff --git a/cachalot/tests/multi_db.py b/cachalot/tests/multi_db.py index c189456..a9833b8 100644 --- a/cachalot/tests/multi_db.py +++ b/cachalot/tests/multi_db.py @@ -8,14 +8,17 @@ from django.db import DEFAULT_DB_ALIAS, connections, transaction from django.test import TransactionTestCase from .models import Test +from .test_utils import TestUtilsMixin @skipIf(len(settings.DATABASES) == 1, 'We can’t change the DB used since there’s only one configured') -class MultiDatabaseTestCase(TransactionTestCase): +class MultiDatabaseTestCase(TestUtilsMixin, TransactionTestCase): multi_db = True def setUp(self): + super(MultiDatabaseTestCase, self).setUp() + self.t1 = Test.objects.create(name='test1') self.t2 = Test.objects.create(name='test2') self.db_alias2 = next(alias for alias in settings.DATABASES @@ -23,10 +26,6 @@ class MultiDatabaseTestCase(TransactionTestCase): connection2 = connections[self.db_alias2] self.is_sqlite2 = connection2.vendor == 'sqlite' self.is_mysql2 = connection2.vendor == 'mysql' - if connection2.vendor in ('mysql', 'postgresql'): - # We need to reopen the connection or Django - # will execute an extra SQL request below. - connection2.cursor() def test_read(self): with self.assertNumQueries(1): diff --git a/cachalot/tests/read.py b/cachalot/tests/read.py index 42f73a8..905d4ae 100644 --- a/cachalot/tests/read.py +++ b/cachalot/tests/read.py @@ -18,8 +18,9 @@ from django.test import ( TransactionTestCase, skipUnlessDBFeature, override_settings) from pytz import UTC -from ..utils import _get_table_cache_key +from ..utils import _get_table_cache_key, UncachableQuery from .models import Test, TestChild +from .test_utils import TestUtilsMixin DJANGO_GTE_1_9 = django_version[:2] >= (1, 9) @@ -28,7 +29,7 @@ if DJANGO_GTE_1_9: from django.db.models.functions import Now -class ReadTestCase(TransactionTestCase): +class ReadTestCase(TestUtilsMixin, TransactionTestCase): """ Tests if every SQL request that only reads data is cached. @@ -38,6 +39,8 @@ class ReadTestCase(TransactionTestCase): """ def setUp(self): + super(ReadTestCase, self).setUp() + self.group = Group.objects.create(name='test_group') self.group__permissions = list(Permission.objects.all()[:3]) self.group.permissions.add(*self.group__permissions) @@ -57,8 +60,6 @@ class ReadTestCase(TransactionTestCase): name='test2', owner=self.admin, public=True, date='1944-06-06', datetime='1944-06-06T06:35:00') - self.is_sqlite = connection.vendor == 'sqlite' - def test_empty(self): with self.assertNumQueries(0): data1 = list(Test.objects.none()) @@ -121,131 +122,87 @@ class ReadTestCase(TransactionTestCase): self.assertListEqual(data2, [self.t1, self.t2]) def test_filter(self): - with self.assertNumQueries(1): - data1 = list(Test.objects.filter(public=True)) - with self.assertNumQueries(0): - data2 = list(Test.objects.filter(public=True)) - self.assertListEqual(data2, data1) - self.assertListEqual(data2, [self.t2]) + qs = Test.objects.filter(public=True) + self.assert_tables(qs, 'cachalot_test') + self.assert_query_cached(qs, [self.t2]) - with self.assertNumQueries(1): - data1 = list(Test.objects.filter(name__in=['test2', 'test72'])) - with self.assertNumQueries(0): - data2 = list(Test.objects.filter(name__in=['test2', 'test72'])) - self.assertListEqual(data2, data1) - self.assertListEqual(data2, [self.t2]) + qs = Test.objects.filter(name__in=['test2', 'test72']) + self.assert_tables(qs, 'cachalot_test') + self.assert_query_cached(qs, [self.t2]) - with self.assertNumQueries(1): - data1 = list(Test.objects.filter( - date__gt=datetime.date(1900, 1, 1))) - with self.assertNumQueries(0): - data2 = list(Test.objects.filter( - date__gt=datetime.date(1900, 1, 1))) - self.assertListEqual(data2, data1) - self.assertListEqual(data2, [self.t2]) + qs = Test.objects.filter(date__gt=datetime.date(1900, 1, 1)) + self.assert_tables(qs, 'cachalot_test') + self.assert_query_cached(qs, [self.t2]) - with self.assertNumQueries(1): - data1 = list(Test.objects.filter( - datetime__lt=datetime.datetime(1900, 1, 1))) - with self.assertNumQueries(0): - data2 = list(Test.objects.filter( - datetime__lt=datetime.datetime(1900, 1, 1))) - self.assertListEqual(data2, data1) - self.assertListEqual(data2, [self.t1]) + qs = Test.objects.filter(datetime__lt=datetime.datetime(1900, 1, 1)) + self.assert_tables(qs, 'cachalot_test') + self.assert_query_cached(qs, [self.t1]) def test_filter_empty(self): - with self.assertNumQueries(1): - data1 = list(Test.objects.filter(public=True, - name='user')) - with self.assertNumQueries(0): - data2 = list(Test.objects.filter(public=True, - name='user')) - self.assertListEqual(data2, data1) - self.assertListEqual(data2, []) + qs = Test.objects.filter(public=True, name='user') + self.assert_tables(qs, 'cachalot_test') + self.assert_query_cached(qs, []) def test_exclude(self): - with self.assertNumQueries(1): - data1 = list(Test.objects.exclude(public=True)) - with self.assertNumQueries(0): - data2 = list(Test.objects.exclude(public=True)) - self.assertListEqual(data2, data1) - self.assertListEqual(data2, [self.t1]) + qs = Test.objects.exclude(public=True) + self.assert_tables(qs, 'cachalot_test') + self.assert_query_cached(qs, [self.t1]) - with self.assertNumQueries(1): - data1 = list(Test.objects.exclude(name__in=['test2', 'test72'])) - with self.assertNumQueries(0): - data2 = list(Test.objects.exclude(name__in=['test2', 'test72'])) - self.assertListEqual(data2, data1) - self.assertListEqual(data2, [self.t1]) + qs = Test.objects.exclude(name__in=['test2', 'test72']) + self.assert_tables(qs, 'cachalot_test') + self.assert_query_cached(qs, [self.t1]) def test_slicing(self): - with self.assertNumQueries(1): - data1 = list(Test.objects.all()[:1]) - with self.assertNumQueries(0): - data2 = list(Test.objects.all()[:1]) - self.assertListEqual(data2, data1) - self.assertListEqual(data2, [self.t1]) + qs = Test.objects.all()[:1] + self.assert_tables(qs, 'cachalot_test') + self.assert_query_cached(qs, [self.t1]) def test_order_by(self): - with self.assertNumQueries(1): - data1 = list(Test.objects.order_by('pk')) - with self.assertNumQueries(0): - data2 = list(Test.objects.order_by('pk')) - self.assertListEqual(data2, data1) - self.assertListEqual(data2, [self.t1, self.t2]) + qs = Test.objects.order_by('pk') + self.assert_tables(qs, 'cachalot_test') + self.assert_query_cached(qs, [self.t1, self.t2]) - with self.assertNumQueries(1): - data1 = list(Test.objects.order_by('-name')) - with self.assertNumQueries(0): - data2 = list(Test.objects.order_by('-name')) - self.assertListEqual(data2, data1) - self.assertListEqual(data2, [self.t2, self.t1]) + qs = Test.objects.order_by('-name') + self.assert_tables(qs, 'cachalot_test') + self.assert_query_cached(qs, [self.t2, self.t1]) def test_random_order_by(self): - with self.assertNumQueries(1): - list(Test.objects.order_by('?')) - with self.assertNumQueries(1): - list(Test.objects.order_by('?')) + qs = Test.objects.order_by('?') + with self.assertRaises(UncachableQuery): + self.assert_tables(qs, 'cachalot_test') + self.assert_query_cached(qs, after=1, compare_results=False) @skipIf(connection.vendor == 'mysql', 'MySQL does not support limit/offset on a subquery. ' 'Since Django only applies ordering in subqueries when they are ' 'offset/limited, we can’t test it on MySQL.') def test_random_order_by_subquery(self): - with self.assertNumQueries(1): - list(Test.objects.filter( - pk__in=Test.objects.order_by('?')[:10])) - with self.assertNumQueries(1): - list(Test.objects.filter( - pk__in=Test.objects.order_by('?')[:10])) + qs = Test.objects.filter( + pk__in=Test.objects.order_by('?')[:10]) + with self.assertRaises(UncachableQuery): + self.assert_tables(qs, 'cachalot_test') + self.assert_query_cached(qs, after=1, compare_results=False) def test_reverse(self): - with self.assertNumQueries(1): - data1 = list(Test.objects.reverse()) - with self.assertNumQueries(0): - data2 = list(Test.objects.reverse()) - self.assertListEqual(data2, data1) - self.assertListEqual(data2, [self.t2, self.t1]) + qs = Test.objects.reverse() + self.assert_tables(qs, 'cachalot_test') + self.assert_query_cached(qs, [self.t2, self.t1]) def test_distinct(self): # We ensure that the query without distinct should return duplicate # objects, in order to have a real-world example. - data1 = list(Test.objects.filter( - owner__user_permissions__content_type__app_label='auth')) - self.assertEqual(len(data1), 3) - self.assertListEqual(data1, [self.t1] * 3) + qs = Test.objects.filter( + owner__user_permissions__content_type__app_label='auth') + self.assert_tables(qs, 'cachalot_test', 'auth_user', + 'auth_user_user_permissions', 'auth_permission', + 'django_content_type') + self.assert_query_cached(qs, [self.t1, self.t1, self.t1]) - with self.assertNumQueries(1): - data2 = list(Test.objects.filter( - owner__user_permissions__content_type__app_label='auth' - ).distinct()) - with self.assertNumQueries(0): - data3 = list(Test.objects.filter( - owner__user_permissions__content_type__app_label='auth' - ).distinct()) - self.assertListEqual(data3, data2) - self.assertEqual(len(data3), 1) - self.assertListEqual(data3, [self.t1]) + qs = qs.distinct() + self.assert_tables(qs, 'cachalot_test', 'auth_user', + 'auth_user_user_permissions', 'auth_permission', + 'django_content_type') + self.assert_query_cached(qs, [self.t1]) def test_iterator(self): with self.assertNumQueries(1): @@ -276,12 +233,9 @@ class ReadTestCase(TransactionTestCase): self.assertDictEqual(data2[1], {'name': 'test2', 'public': True}) def test_values_list(self): - with self.assertNumQueries(1): - data1 = list(Test.objects.values_list('name', flat=True)) - with self.assertNumQueries(0): - data2 = list(Test.objects.values_list('name', flat=True)) - self.assertListEqual(data2, data1) - self.assertListEqual(data2, ['test1', 'test2']) + qs = Test.objects.values_list('name', flat=True) + self.assert_tables(qs, 'cachalot_test') + self.assert_query_cached(qs, ['test1', 'test2']) def test_earliest(self): with self.assertNumQueries(1): @@ -300,33 +254,24 @@ class ReadTestCase(TransactionTestCase): self.assertEqual(data2, self.t2) def test_dates(self): - with self.assertNumQueries(1): - data1 = list(Test.objects.dates('date', 'year')) - with self.assertNumQueries(0): - data2 = list(Test.objects.dates('date', 'year')) - self.assertListEqual(data2, data1) - self.assertListEqual(data2, [datetime.date(1789, 1, 1), - datetime.date(1944, 1, 1)]) + qs = Test.objects.dates('date', 'year') + self.assert_tables(qs, 'cachalot_test') + self.assert_query_cached(qs, [datetime.date(1789, 1, 1), + datetime.date(1944, 1, 1)]) def test_datetimes(self): - with self.assertNumQueries(1): - data1 = list(Test.objects.datetimes('datetime', 'hour')) - with self.assertNumQueries(0): - data2 = list(Test.objects.datetimes('datetime', 'hour')) - self.assertListEqual(data2, data1) - self.assertListEqual(data2, [datetime.datetime(1789, 7, 14, 16), - datetime.datetime(1944, 6, 6, 6)]) + qs = Test.objects.datetimes('datetime', 'hour') + self.assert_tables(qs, 'cachalot_test') + self.assert_query_cached(qs, [datetime.datetime(1789, 7, 14, 16), + datetime.datetime(1944, 6, 6, 6)]) @skipIf(connection.vendor == 'mysql', 'Time zones are not supported by MySQL.') @override_settings(USE_TZ=True) def test_datetimes_with_time_zones(self): - with self.assertNumQueries(1): - data1 = list(Test.objects.datetimes('datetime', 'hour')) - with self.assertNumQueries(0): - data2 = list(Test.objects.datetimes('datetime', 'hour')) - self.assertListEqual(data2, data1) - self.assertListEqual(data2, [ + qs = Test.objects.datetimes('datetime', 'hour') + self.assert_tables(qs, 'cachalot_test') + self.assert_query_cached(qs, [ datetime.datetime(1789, 7, 14, 16, tzinfo=UTC), datetime.datetime(1944, 6, 6, 6, tzinfo=UTC)]) @@ -356,62 +301,41 @@ class ReadTestCase(TransactionTestCase): self.assertListEqual(data2, ['cuddle', 'discuss', 'touch']) def test_subquery(self): - with self.assertNumQueries(1): - data1 = list(Test.objects.filter(owner__in=User.objects.all())) - with self.assertNumQueries(0): - data2 = list(Test.objects.filter(owner__in=User.objects.all())) - self.assertListEqual(data2, data1) - self.assertListEqual(data2, [self.t1, self.t2]) + qs = Test.objects.filter(owner__in=User.objects.all()) + self.assert_tables(qs, 'cachalot_test', 'auth_user') + self.assert_query_cached(qs, [self.t1, self.t2]) - with self.assertNumQueries(1): - data3 = list(Test.objects.filter( - owner__groups__permissions__in=Permission.objects.all())) - with self.assertNumQueries(0): - data4 = list(Test.objects.filter( - owner__groups__permissions__in=Permission.objects.all())) - self.assertListEqual(data4, data3) - self.assertListEqual(data4, [self.t1, self.t1, self.t1]) + qs = Test.objects.filter( + owner__groups__permissions__in=Permission.objects.all()) + self.assert_tables(qs, 'cachalot_test', 'auth_user', + 'auth_user_groups', 'auth_group', + 'auth_group_permissions', 'auth_permission') + self.assert_query_cached(qs, [self.t1, self.t1, self.t1]) - with self.assertNumQueries(1): - data5 = list( - Test.objects.filter( - owner__groups__permissions__in=Permission.objects.all() - ).distinct()) - with self.assertNumQueries(0): - data6 = list( - Test.objects.filter( - owner__groups__permissions__in=Permission.objects.all() - ).distinct()) - self.assertListEqual(data6, data5) - self.assertListEqual(data6, [self.t1]) + qs = Test.objects.filter( + owner__groups__permissions__in=Permission.objects.all() + ).distinct() + self.assert_tables(qs, 'cachalot_test', 'auth_user', + 'auth_user_groups', 'auth_group', + 'auth_group_permissions', 'auth_permission') + self.assert_query_cached(qs, [self.t1]) - with self.assertNumQueries(1): - data7 = list( - TestChild.objects.exclude(permissions__isnull=True)) - with self.assertNumQueries(0): - data8 = list( - TestChild.objects.exclude(permissions__isnull=True)) - self.assertListEqual(data7, data8) - self.assertListEqual(data7, []) + qs = TestChild.objects.exclude(permissions__isnull=True) + self.assert_tables(qs, 'cachalot_testparent', 'cachalot_testchild', + 'cachalot_testchild_permissions', 'auth_permission') + self.assert_query_cached(qs, []) def test_raw_subquery(self): raw_sql = RawSQL('SELECT id FROM auth_permission WHERE id = %s', (self.t1__permission.pk,)) - with self.assertNumQueries(1): - data3 = list(Test.objects.filter(permission=raw_sql)) - with self.assertNumQueries(0): - data4 = list(Test.objects.filter(permission=raw_sql)) - self.assertListEqual(data4, data3) - self.assertListEqual(data4, [self.t1]) + qs = Test.objects.filter(permission=raw_sql) + self.assert_tables(qs, 'cachalot_test', 'auth_permission') + self.assert_query_cached(qs, [self.t1]) - with self.assertNumQueries(1): - data5 = list(Test.objects.filter( - pk__in=Test.objects.filter(permission=raw_sql))) - with self.assertNumQueries(0): - data6 = list(Test.objects.filter( - pk__in=Test.objects.filter(permission=raw_sql))) - self.assertListEqual(data6, data5) - self.assertListEqual(data6, [self.t1]) + qs = Test.objects.filter( + pk__in=Test.objects.filter(permission=raw_sql)) + self.assert_tables(qs, 'cachalot_test', 'auth_permission') + self.assert_query_cached(qs, [self.t1]) def test_aggregate(self): Test.objects.create(name='test3', owner=self.user) @@ -424,14 +348,10 @@ class ReadTestCase(TransactionTestCase): def test_annotate(self): Test.objects.create(name='test3', owner=self.user) - with self.assertNumQueries(1): - data1 = list(User.objects.annotate(n=Count('test')).order_by('pk') - .values_list('n', flat=True)) - with self.assertNumQueries(0): - data2 = list(User.objects.annotate(n=Count('test')).order_by('pk') - .values_list('n', flat=True)) - self.assertListEqual(data2, data1) - self.assertListEqual(data2, [2, 1]) + qs = (User.objects.annotate(n=Count('test')).order_by('pk') + .values_list('n', flat=True)) + self.assert_tables(qs, 'auth_user', 'cachalot_test') + self.assert_query_cached(qs, [2, 1]) def test_only(self): with self.assertNumQueries(1): @@ -584,15 +504,11 @@ class ReadTestCase(TransactionTestCase): ['test1', 'test2']) def test_having(self): - with self.assertNumQueries(1): - data1 = list(User.objects.annotate(n=Count('user_permissions')) - .filter(n__gte=1)) - self.assertListEqual(data1, [self.user]) - - with self.assertNumQueries(0): - data2 = list(User.objects.annotate(n=Count('user_permissions')) - .filter(n__gte=1)) - self.assertListEqual(data2, [self.user]) + qs = (User.objects.annotate(n=Count('user_permissions')) + .filter(n__gte=1)) + self.assert_tables(qs, 'auth_user', 'auth_user_user_permissions', + 'auth_permission') + self.assert_query_cached(qs, [self.user]) with self.assertNumQueries(1): self.assertEqual(User.objects.annotate(n=Count('user_permissions')) @@ -624,26 +540,20 @@ class ReadTestCase(TransactionTestCase): def test_extra_where(self): sql_condition = ("owner_id IN " "(SELECT id FROM auth_user WHERE username = 'admin')") - with self.assertNumQueries(1): - data1 = list(Test.objects.extra(where=[sql_condition])) - self.assertListEqual(data1, [self.t2]) - with self.assertNumQueries(0): - data2 = list(Test.objects.extra(where=[sql_condition])) - self.assertListEqual(data2, [self.t2]) + qs = Test.objects.extra(where=[sql_condition]) + self.assert_tables(qs, 'cachalot_test', 'auth_user') + self.assert_query_cached(qs, [self.t2]) def test_extra_tables(self): - with self.assertNumQueries(1): - list(Test.objects.extra(tables=['auth_user'])) - with self.assertNumQueries(0): - list(Test.objects.extra(tables=['auth_user'])) + qs = Test.objects.extra(tables=['auth_user'], + select={'extra_id': 'auth_user.id'}) + self.assert_tables(qs, 'cachalot_test', 'auth_user') + self.assert_query_cached(qs) def test_extra_order_by(self): - with self.assertNumQueries(1): - data1 = list(Test.objects.extra(order_by=['-cachalot_test.name'])) - self.assertListEqual(data1, [self.t2, self.t1]) - with self.assertNumQueries(0): - data2 = list(Test.objects.extra(order_by=['-cachalot_test.name'])) - self.assertListEqual(data2, [self.t2, self.t1]) + qs = Test.objects.extra(order_by=['-cachalot_test.name']) + self.assert_tables(qs, 'cachalot_test') + self.assert_query_cached(qs, [self.t2, self.t1]) def test_table_inheritance(self): with self.assertNumQueries(3 if self.is_sqlite else 2): @@ -719,17 +629,15 @@ class ReadTestCase(TransactionTestCase): [('é',) + l for l in Test.objects.values_list(*attnames)]) def test_missing_table_cache_key(self): - with self.assertNumQueries(1): - list(Test.objects.all()) - with self.assertNumQueries(0): - list(Test.objects.all()) + qs = Test.objects.all() + self.assert_tables(qs, 'cachalot_test') + self.assert_query_cached(qs) table_cache_key = _get_table_cache_key(connection.alias, Test._meta.db_table) cache.delete(table_cache_key) - with self.assertNumQueries(1): - list(Test.objects.all()) + self.assert_query_cached(qs) def test_unicode_get(self): with self.assertNumQueries(1): @@ -748,23 +656,18 @@ class ReadTestCase(TransactionTestCase): table_name = '"%s"' % table_name with connection.cursor() as cursor: cursor.execute('CREATE TABLE %s (taste VARCHAR(20));' % table_name) - with self.assertNumQueries(1): - list(Test.objects.extra(tables=['Clémentine'])) - with self.assertNumQueries(0): - list(Test.objects.extra(tables=['Clémentine'])) + qs = Test.objects.extra(tables=['Clémentine'], + select={'taste': '%s.taste' % table_name}) + # Here the table `Clémentine` is not detected because it is not + # registered by Django, and we only check for registered tables + # to avoid creating an extra SQL query fetching table names. + self.assert_tables(qs, 'cachalot_test') + self.assert_query_cached(qs) with connection.cursor() as cursor: cursor.execute('DROP TABLE %s;' % table_name) -class ParameterTypeTestCase(TransactionTestCase): - def setUp(self): - self.is_sqlite = connection.vendor == 'sqlite' - self.is_mysql = connection.vendor == 'mysql' - if connection.vendor in ('mysql', 'postgresql'): - # We need to reopen the connection or Django - # will execute an extra SQL request below. - connection.cursor() - +class ParameterTypeTestCase(TestUtilsMixin, TransactionTestCase): def test_binary(self): """ Binary data should be cached on PostgreSQL & MySQL, but not on SQLite, @@ -773,20 +676,17 @@ class ParameterTypeTestCase(TransactionTestCase): So this also tests how django-cachalot handles unknown params, in this case the `memory` object passed to SQLite. """ - with self.assertNumQueries(1): - list(Test.objects.filter(bin=None)) - with self.assertNumQueries(0): - list(Test.objects.filter(bin=None)) + qs = Test.objects.filter(bin=None) + self.assert_tables(qs, 'cachalot_test') + self.assert_query_cached(qs) - with self.assertNumQueries(1): - list(Test.objects.filter(bin=b'abc')) - with self.assertNumQueries(1 if self.is_sqlite else 0): - list(Test.objects.filter(bin=b'abc')) + qs = Test.objects.filter(bin=b'abc') + self.assert_tables(qs, 'cachalot_test') + self.assert_query_cached(qs, after=1 if self.is_sqlite else 0) - with self.assertNumQueries(1): - list(Test.objects.filter(bin=b'def')) - with self.assertNumQueries(1 if self.is_sqlite else 0): - list(Test.objects.filter(bin=b'def')) + qs = Test.objects.filter(bin=b'def') + self.assert_tables(qs, 'cachalot_test') + self.assert_query_cached(qs, after=1 if self.is_sqlite else 0) def test_float(self): with self.assertNumQueries(2 if self.is_sqlite else 1): @@ -814,14 +714,11 @@ class ParameterTypeTestCase(TransactionTestCase): Test.objects.create(name='test1', a_decimal=Decimal('123.45')) with self.assertNumQueries(2 if self.is_sqlite else 1): Test.objects.create(name='test1', a_decimal=Decimal('12.3')) - with self.assertNumQueries(1): - data1 = list(Test.objects.values_list('a_decimal', flat=True).filter( - a_decimal__isnull=False).order_by('a_decimal')) - with self.assertNumQueries(0): - data2 = list(Test.objects.values_list('a_decimal', flat=True).filter( - a_decimal__isnull=False).order_by('a_decimal')) - self.assertListEqual(data2, data1) - self.assertListEqual(data2, [Decimal('12.3'), Decimal('123.45')]) + + qs = Test.objects.values_list('a_decimal', flat=True).filter( + a_decimal__isnull=False).order_by('a_decimal') + self.assert_tables(qs, 'cachalot_test') + self.assert_query_cached(qs, [Decimal('12.3'), Decimal('123.45')]) with self.assertNumQueries(1): Test.objects.get(a_decimal=Decimal('123.45')) @@ -833,14 +730,11 @@ class ParameterTypeTestCase(TransactionTestCase): Test.objects.create(name='test1', ip='127.0.0.1') with self.assertNumQueries(2 if self.is_sqlite else 1): Test.objects.create(name='test2', ip='192.168.0.1') - with self.assertNumQueries(1): - data1 = list(Test.objects.values_list('ip', flat=True).filter( - ip__isnull=False).order_by('ip')) - with self.assertNumQueries(0): - data2 = list(Test.objects.values_list('ip', flat=True).filter( - ip__isnull=False).order_by('ip')) - self.assertListEqual(data2, data1) - self.assertListEqual(data2, ['127.0.0.1', '192.168.0.1']) + + qs = Test.objects.values_list('ip', flat=True).filter( + ip__isnull=False).order_by('ip') + self.assert_tables(qs, 'cachalot_test') + self.assert_query_cached(qs, ['127.0.0.1', '192.168.0.1']) with self.assertNumQueries(1): Test.objects.get(ip='127.0.0.1') @@ -852,15 +746,12 @@ class ParameterTypeTestCase(TransactionTestCase): Test.objects.create(name='test1', ip='2001:db8:a0b:12f0::1/64') with self.assertNumQueries(2 if self.is_sqlite else 1): Test.objects.create(name='test2', ip='2001:db8:0:85a3::ac1f:8001') - with self.assertNumQueries(1): - data1 = list(Test.objects.values_list('ip', flat=True).filter( - ip__isnull=False).order_by('ip')) - with self.assertNumQueries(0): - data2 = list(Test.objects.values_list('ip', flat=True).filter( - ip__isnull=False).order_by('ip')) - self.assertListEqual(data2, data1) - self.assertListEqual(data2, [ - '2001:db8:0:85a3::ac1f:8001', '2001:db8:a0b:12f0::1/64']) + + qs = Test.objects.values_list('ip', flat=True).filter( + ip__isnull=False).order_by('ip') + self.assert_tables(qs, 'cachalot_test') + self.assert_query_cached(qs, ['2001:db8:0:85a3::ac1f:8001', + '2001:db8:a0b:12f0::1/64']) with self.assertNumQueries(1): Test.objects.get(ip='2001:db8:0:85a3::ac1f:8001') @@ -872,17 +763,12 @@ class ParameterTypeTestCase(TransactionTestCase): Test.objects.create(name='test1', duration=datetime.timedelta(30)) with self.assertNumQueries(2 if self.is_sqlite else 1): Test.objects.create(name='test2', duration=datetime.timedelta(60)) - with self.assertNumQueries(1): - data1 = list(Test.objects.values_list( - 'duration', flat=True).filter( - duration__isnull=False).order_by('duration')) - with self.assertNumQueries(0): - data2 = list(Test.objects.values_list( - 'duration', flat=True).filter( - duration__isnull=False).order_by('duration')) - self.assertListEqual(data2, data1) - self.assertListEqual(data2, [ - datetime.timedelta(30), datetime.timedelta(60)]) + + qs = Test.objects.values_list('duration', flat=True).filter( + duration__isnull=False).order_by('duration') + self.assert_tables(qs, 'cachalot_test') + self.assert_query_cached(qs, [datetime.timedelta(30), + datetime.timedelta(60)]) with self.assertNumQueries(1): Test.objects.get(duration=datetime.timedelta(30)) @@ -896,16 +782,11 @@ class ParameterTypeTestCase(TransactionTestCase): with self.assertNumQueries(2 if self.is_sqlite else 1): Test.objects.create(name='test2', uuid='ebb3b6e1-1737-4321-93e3-4c35d61ff491') - with self.assertNumQueries(1): - data1 = list(Test.objects.values_list( - 'uuid', flat=True).filter( - uuid__isnull=False).order_by('uuid')) - with self.assertNumQueries(0): - data2 = list(Test.objects.values_list( - 'uuid', flat=True).filter( - uuid__isnull=False).order_by('uuid')) - self.assertListEqual(data2, data1) - self.assertListEqual(data2, [ + + qs = Test.objects.values_list('uuid', flat=True).filter( + uuid__isnull=False).order_by('uuid') + self.assert_tables(qs, 'cachalot_test') + self.assert_query_cached(qs, [ UUID('1cc401b7-09f4-4520-b8d0-c267576d196b'), UUID('ebb3b6e1-1737-4321-93e3-4c35d61ff491')]) diff --git a/cachalot/tests/settings.py b/cachalot/tests/settings.py index 1c172f3..100ecc6 100644 --- a/cachalot/tests/settings.py +++ b/cachalot/tests/settings.py @@ -14,54 +14,35 @@ from django.test.utils import override_settings from ..api import invalidate from .models import Test, TestParent, TestChild +from .test_utils import TestUtilsMixin -class SettingsTestCase(TransactionTestCase): - def setUp(self): - if connection.vendor in ('mysql', 'postgresql'): - # We need to reopen the connection or Django - # will execute an extra SQL request below. - connection.cursor() - +class SettingsTestCase(TestUtilsMixin, TransactionTestCase): @override_settings(CACHALOT_ENABLED=False) def test_decorator(self): - with self.assertNumQueries(1): - list(Test.objects.all()) - with self.assertNumQueries(1): - list(Test.objects.all()) + self.assert_query_cached(Test.objects.all(), after=1) def test_django_override(self): with self.settings(CACHALOT_ENABLED=False): - with self.assertNumQueries(1): - list(Test.objects.all()) - with self.assertNumQueries(1): - list(Test.objects.all()) + qs = Test.objects.all() + self.assert_query_cached(qs, after=1) with self.settings(CACHALOT_ENABLED=True): - with self.assertNumQueries(1): - list(Test.objects.all()) - with self.assertNumQueries(0): - list(Test.objects.all()) + self.assert_query_cached(qs) def test_enabled(self): + qs = Test.objects.all() + with self.settings(CACHALOT_ENABLED=True): - with self.assertNumQueries(1): - list(Test.objects.all()) - with self.assertNumQueries(0): - list(Test.objects.all()) + self.assert_query_cached(qs) with self.settings(CACHALOT_ENABLED=False): - with self.assertNumQueries(1): - list(Test.objects.all()) - with self.assertNumQueries(1): - list(Test.objects.all()) + self.assert_query_cached(qs, after=1) with self.assertNumQueries(0): list(Test.objects.all()) - is_sqlite = connection.vendor == 'sqlite' - with self.settings(CACHALOT_ENABLED=False): - with self.assertNumQueries(2 if is_sqlite else 1): + with self.assertNumQueries(2 if self.is_sqlite else 1): t = Test.objects.create(name='test') with self.assertNumQueries(1): data = list(Test.objects.all()) @@ -74,66 +55,53 @@ class SettingsTestCase(TransactionTestCase): if alias != DEFAULT_CACHE_ALIAS) invalidate(Test, cache_alias=other_cache_alias) + qs = Test.objects.all() + with self.settings(CACHALOT_CACHE=DEFAULT_CACHE_ALIAS): - with self.assertNumQueries(1): - list(Test.objects.all()) - with self.assertNumQueries(0): - list(Test.objects.all()) + self.assert_query_cached(qs) with self.settings(CACHALOT_CACHE=other_cache_alias): - with self.assertNumQueries(1): - list(Test.objects.all()) - with self.assertNumQueries(0): - list(Test.objects.all()) + self.assert_query_cached(qs) Test.objects.create(name='test') # Only `CACHALOT_CACHE` is invalidated, so changing the database should # not invalidate all caches. with self.settings(CACHALOT_CACHE=other_cache_alias): - with self.assertNumQueries(0): - list(Test.objects.all()) - with self.assertNumQueries(0): - list(Test.objects.all()) + self.assert_query_cached(qs, before=0) def test_cache_timeout(self): + qs = Test.objects.all() + with self.assertNumQueries(1): - list(Test.objects.all()) + list(qs.all()) sleep(1) with self.assertNumQueries(0): - list(Test.objects.all()) + list(qs.all()) invalidate(Test) with self.settings(CACHALOT_TIMEOUT=0): with self.assertNumQueries(1): - list(Test.objects.all()) + list(qs.all()) sleep(0.05) with self.assertNumQueries(1): - list(Test.objects.all()) + list(qs.all()) # We have to test with a full second and not a shorter time because # memcached only takes the integer part of the timeout into account. with self.settings(CACHALOT_TIMEOUT=1): - with self.assertNumQueries(1): - list(Test.objects.all()) - with self.assertNumQueries(0): - list(Test.objects.all()) + self.assert_query_cached(qs) sleep(1) with self.assertNumQueries(1): list(Test.objects.all()) def test_cache_random(self): - with self.assertNumQueries(1): - list(Test.objects.order_by('?')) - with self.assertNumQueries(1): - list(Test.objects.order_by('?')) + qs = Test.objects.order_by('?') + self.assert_query_cached(qs, after=1, compare_results=False) with self.settings(CACHALOT_CACHE_RANDOM=True): - with self.assertNumQueries(1): - list(Test.objects.order_by('?')) - with self.assertNumQueries(0): - list(Test.objects.order_by('?')) + self.assert_query_cached(qs) def test_invalidate_raw(self): with self.assertNumQueries(1): @@ -148,85 +116,45 @@ class SettingsTestCase(TransactionTestCase): def test_only_cachable_tables(self): with self.settings(CACHALOT_ONLY_CACHABLE_TABLES=('cachalot_test',)): - with self.assertNumQueries(1): - list(Test.objects.all()) - with self.assertNumQueries(0): - list(Test.objects.all()) + self.assert_query_cached(Test.objects.all()) + self.assert_query_cached(TestParent.objects.all(), after=1) + self.assert_query_cached(Test.objects.select_related('owner'), + after=1) - with self.assertNumQueries(1): - list(TestParent.objects.all()) - with self.assertNumQueries(1): - list(TestParent.objects.all()) - - with self.assertNumQueries(1): - list(Test.objects.select_related('owner')) - with self.assertNumQueries(1): - list(Test.objects.select_related('owner')) - - with self.assertNumQueries(1): - list(TestParent.objects.all()) - with self.assertNumQueries(0): - list(TestParent.objects.all()) + self.assert_query_cached(TestParent.objects.all()) with self.settings(CACHALOT_ONLY_CACHABLE_TABLES=( 'cachalot_test', 'cachalot_testchild', 'auth_user')): - with self.assertNumQueries(1): - list(Test.objects.select_related('owner')) - with self.assertNumQueries(0): - list(Test.objects.select_related('owner')) + self.assert_query_cached(Test.objects.select_related('owner')) # TestChild uses multi-table inheritance, and since its parent, # 'cachalot_testparent', is not cachable, a basic # TestChild query can’t be cached - with self.assertNumQueries(1): - list(TestChild.objects.all()) - with self.assertNumQueries(1): - list(TestChild.objects.all()) + self.assert_query_cached(TestChild.objects.all(), after=1) # However, if we only fetch data from the 'cachalot_testchild' # table, it’s cachable. - with self.assertNumQueries(1): - list(TestChild.objects.values('public')) - with self.assertNumQueries(0): - list(TestChild.objects.values('public')) + self.assert_query_cached(TestChild.objects.values('public')) def test_uncachable_tables(self): - with self.settings(CACHALOT_UNCACHABLE_TABLES=('cachalot_test',)): - with self.assertNumQueries(1): - list(Test.objects.all()) - with self.assertNumQueries(1): - list(Test.objects.all()) - - with self.assertNumQueries(1): - list(Test.objects.all()) - with self.assertNumQueries(0): - list(Test.objects.all()) + qs = Test.objects.all() with self.settings(CACHALOT_UNCACHABLE_TABLES=('cachalot_test',)): - with self.assertNumQueries(1): - list(Test.objects.all()) - with self.assertNumQueries(1): - list(Test.objects.all()) + self.assert_query_cached(qs, after=1) + + self.assert_query_cached(qs) + + with self.settings(CACHALOT_UNCACHABLE_TABLES=('cachalot_test',)): + self.assert_query_cached(qs, after=1) def test_only_cachable_and_uncachable_table(self): with self.settings( CACHALOT_ONLY_CACHABLE_TABLES=('cachalot_test', 'cachalot_testparent'), CACHALOT_UNCACHABLE_TABLES=('cachalot_test',)): - with self.assertNumQueries(1): - list(Test.objects.all()) - with self.assertNumQueries(1): - list(Test.objects.all()) - - with self.assertNumQueries(1): - list(TestParent.objects.all()) - with self.assertNumQueries(0): - list(TestParent.objects.all()) - - with self.assertNumQueries(1): - list(User.objects.all()) - with self.assertNumQueries(1): - list(User.objects.all()) + self.assert_query_cached(Test.objects.all(), after=1) + self.assert_query_cached(TestParent.objects.all()) + self.assert_query_cached(User.objects.all(), after=1) def test_compatibility(self): """ diff --git a/cachalot/tests/test_utils.py b/cachalot/tests/test_utils.py new file mode 100644 index 0000000..3c3d402 --- /dev/null +++ b/cachalot/tests/test_utils.py @@ -0,0 +1,40 @@ +from django.db import connection + +from ..utils import _get_tables + + +class TestUtilsMixin: + def setUp(self): + self.is_sqlite = connection.vendor == 'sqlite' + self.is_mysql = connection.vendor == 'mysql' + self.force_repoen_connection() + + def force_repoen_connection(self): + if connection.vendor in ('mysql', 'postgresql'): + # We need to reopen the connection or Django + # will execute an extra SQL request below. + connection.cursor() + + def assert_tables(self, queryset, *tables): + self.assertSetEqual(_get_tables(queryset.db, queryset.query), + set(tables)) + + def assert_query_cached(self, queryset, result=None, result_type=None, + compare_results=True, before=1, after=0): + with self.assertNumQueries(before): + data1 = list(queryset.all()) + with self.assertNumQueries(after): + data2 = list(queryset.all()) + if not compare_results: + return + if result_type is None: + result_type = list if result is None else type(result) + assert_functions = { + list: self.assertListEqual, + set: self.assertSetEqual, + dict: self.assertDictEqual, + } + assert_function = assert_functions.get(result_type, self.assertEqual) + assert_function(data2, data1) + if result is not None: + assert_function(data2, result) diff --git a/cachalot/tests/thread_safety.py b/cachalot/tests/thread_safety.py index fd10c94..30dd2a9 100644 --- a/cachalot/tests/thread_safety.py +++ b/cachalot/tests/thread_safety.py @@ -7,6 +7,7 @@ from django.db import connection, transaction from django.test import TransactionTestCase, skipUnlessDBFeature from .models import Test +from .test_utils import TestUtilsMixin class TestThread(Thread): @@ -21,13 +22,7 @@ class TestThread(Thread): @skipUnlessDBFeature('test_db_allows_multiple_connections') -class ThreadSafetyTestCase(TransactionTestCase): - def setUp(self): - if connection.vendor in ('mysql', 'postgresql'): - # We need to reopen the connection or Django - # will execute an extra SQL request below. - connection.cursor() - +class ThreadSafetyTestCase(TestUtilsMixin, TransactionTestCase): def test_concurrent_caching(self): t1 = TestThread().start_and_join() t = Test.objects.create(name='test') diff --git a/cachalot/tests/transaction.py b/cachalot/tests/transaction.py index 7e0013e..60b378f 100644 --- a/cachalot/tests/transaction.py +++ b/cachalot/tests/transaction.py @@ -3,23 +3,16 @@ from __future__ import unicode_literals from django.contrib.auth.models import User -from django.db import connection, transaction +from django.db import transaction from django.test import TransactionTestCase from .models import Test +from .test_utils import TestUtilsMixin -class AtomicTestCase(TransactionTestCase): - def setUp(self): - if connection.vendor in ('mysql', 'postgresql'): - # We need to reopen the connection or Django - # will execute an extra SQL request below. - connection.cursor() - +class AtomicTestCase(TestUtilsMixin, TransactionTestCase): def test_successful_read_atomic(self): - is_sqlite = connection.vendor == 'sqlite' - - with self.assertNumQueries(2 if is_sqlite else 1): + with self.assertNumQueries(2 if self.is_sqlite else 1): with transaction.atomic(): data1 = list(Test.objects.all()) self.assertListEqual(data1, []) @@ -29,9 +22,7 @@ class AtomicTestCase(TransactionTestCase): self.assertListEqual(data2, []) def test_unsuccessful_read_atomic(self): - is_sqlite = connection.vendor == 'sqlite' - - with self.assertNumQueries(2 if is_sqlite else 1): + with self.assertNumQueries(2 if self.is_sqlite else 1): try: with transaction.atomic(): data1 = list(Test.objects.all()) @@ -49,23 +40,21 @@ class AtomicTestCase(TransactionTestCase): data1 = list(Test.objects.all()) self.assertListEqual(data1, []) - is_sqlite = connection.vendor == 'sqlite' - - with self.assertNumQueries(2 if is_sqlite else 1): + with self.assertNumQueries(2 if self.is_sqlite else 1): with transaction.atomic(): t1 = Test.objects.create(name='test1') with self.assertNumQueries(1): data2 = list(Test.objects.all()) self.assertListEqual(data2, [t1]) - with self.assertNumQueries(2 if is_sqlite else 1): + with self.assertNumQueries(2 if self.is_sqlite else 1): with transaction.atomic(): t2 = Test.objects.create(name='test2') with self.assertNumQueries(1): data3 = list(Test.objects.all()) self.assertListEqual(data3, [t1, t2]) - with self.assertNumQueries(4 if is_sqlite else 3): + with self.assertNumQueries(4 if self.is_sqlite else 3): with transaction.atomic(): data4 = list(Test.objects.all()) t3 = Test.objects.create(name='test3') @@ -80,9 +69,7 @@ class AtomicTestCase(TransactionTestCase): data1 = list(Test.objects.all()) self.assertListEqual(data1, []) - is_sqlite = connection.vendor == 'sqlite' - - with self.assertNumQueries(2 if is_sqlite else 1): + with self.assertNumQueries(2 if self.is_sqlite else 1): try: with transaction.atomic(): Test.objects.create(name='test') @@ -97,9 +84,7 @@ class AtomicTestCase(TransactionTestCase): Test.objects.get(name='test') def test_cache_inside_atomic(self): - is_sqlite = connection.vendor == 'sqlite' - - with self.assertNumQueries(2 if is_sqlite else 1): + with self.assertNumQueries(2 if self.is_sqlite else 1): with transaction.atomic(): data1 = list(Test.objects.all()) data2 = list(Test.objects.all()) @@ -107,9 +92,7 @@ class AtomicTestCase(TransactionTestCase): self.assertListEqual(data2, []) def test_invalidation_inside_atomic(self): - is_sqlite = connection.vendor == 'sqlite' - - with self.assertNumQueries(4 if is_sqlite else 3): + with self.assertNumQueries(4 if self.is_sqlite else 3): with transaction.atomic(): data1 = list(Test.objects.all()) t = Test.objects.create(name='test') @@ -118,9 +101,7 @@ class AtomicTestCase(TransactionTestCase): self.assertListEqual(data2, [t]) def test_successful_nested_read_atomic(self): - is_sqlite = connection.vendor == 'sqlite' - - with self.assertNumQueries(7 if is_sqlite else 6): + with self.assertNumQueries(7 if self.is_sqlite else 6): with transaction.atomic(): list(Test.objects.all()) with transaction.atomic(): @@ -135,10 +116,7 @@ class AtomicTestCase(TransactionTestCase): list(User.objects.all()) def test_unsuccessful_nested_read_atomic(self): - is_sqlite = connection.vendor == 'sqlite' - num_queries = 6 if is_sqlite else 5 - - with self.assertNumQueries(num_queries): + with self.assertNumQueries(6 if self.is_sqlite else 5): with transaction.atomic(): try: with transaction.atomic(): @@ -151,9 +129,7 @@ class AtomicTestCase(TransactionTestCase): list(Test.objects.all()) def test_successful_nested_write_atomic(self): - is_sqlite = connection.vendor == 'sqlite' - - with self.assertNumQueries(13 if is_sqlite else 12): + with self.assertNumQueries(13 if self.is_sqlite else 12): with transaction.atomic(): t1 = Test.objects.create(name='test1') with transaction.atomic(): @@ -170,10 +146,7 @@ class AtomicTestCase(TransactionTestCase): self.assertListEqual(data3, [t1, t2, t3, t4]) def test_unsuccessful_nested_write_atomic(self): - is_sqlite = connection.vendor == 'sqlite' - num_queries = 16 if is_sqlite else 15 - - with self.assertNumQueries(num_queries): + with self.assertNumQueries(16 if self.is_sqlite else 15): with transaction.atomic(): t1 = Test.objects.create(name='test1') try: diff --git a/cachalot/tests/write.py b/cachalot/tests/write.py index 90e1c30..0d451f0 100644 --- a/cachalot/tests/write.py +++ b/cachalot/tests/write.py @@ -13,24 +13,18 @@ from django.db.models.expressions import RawSQL from django.test import TransactionTestCase, skipUnlessDBFeature from .models import Test, TestParent, TestChild +from .test_utils import TestUtilsMixin DJANGO_GTE_1_11 = django_version[:2] >= (1, 11) -class WriteTestCase(TransactionTestCase): +class WriteTestCase(TestUtilsMixin, TransactionTestCase): """ Tests if every SQL request writing data is not cached and invalidates the implied data. """ - def setUp(self): - self.is_sqlite = connection.vendor == 'sqlite' - if connection.vendor in ('mysql', 'postgresql'): - # We need to reopen the connection or Django - # will execute an extra SQL request below. - connection.cursor() - def test_create(self): with self.assertNumQueries(1): data1 = list(Test.objects.all()) @@ -536,37 +530,25 @@ class WriteTestCase(TransactionTestCase): (permission.pk,)) with self.assertNumQueries(1): data1 = list(Test.objects.filter(permission=raw_sql)) - with self.assertNumQueries(0): - data2 = list(Test.objects.filter(permission=raw_sql)) - self.assertListEqual(data2, data1) - self.assertListEqual(data2, []) + self.assertListEqual(data1, []) test = Test.objects.create(name='test', permission=permission) with self.assertNumQueries(1): - data3 = list(Test.objects.filter(permission=raw_sql)) - with self.assertNumQueries(0): - data4 = list(Test.objects.filter(permission=raw_sql)) - self.assertListEqual(data4, data3) - self.assertListEqual(data4, [test]) + data2 = list(Test.objects.filter(permission=raw_sql)) + self.assertListEqual(data2, [test]) - Permission.objects.first().save() + permission.save() with self.assertNumQueries(1): - data5 = list(Test.objects.filter(permission=raw_sql)) - with self.assertNumQueries(0): - data6 = list(Test.objects.filter(permission=raw_sql)) - self.assertListEqual(data6, data5) - self.assertListEqual(data6, [test]) + data3 = list(Test.objects.filter(permission=raw_sql)) + self.assertListEqual(data3, [test]) test.delete() with self.assertNumQueries(1): - data7 = list(Test.objects.filter(permission=raw_sql)) - with self.assertNumQueries(0): - data8 = list(Test.objects.filter(permission=raw_sql)) - self.assertListEqual(data8, data7) - self.assertListEqual(data8, []) + data4 = list(Test.objects.filter(permission=raw_sql)) + self.assertListEqual(data4, []) def test_invalidate_nested_raw_subquery(self): permission = Permission.objects.first() @@ -575,44 +557,28 @@ class WriteTestCase(TransactionTestCase): with self.assertNumQueries(1): data1 = list(Test.objects.filter( pk__in=Test.objects.filter(permission=raw_sql))) - with self.assertNumQueries(0): - data2 = list(Test.objects.filter( - pk__in=Test.objects.filter(permission=raw_sql))) - self.assertListEqual(data2, data1) - self.assertListEqual(data2, []) + self.assertListEqual(data1, []) test = Test.objects.create(name='test', permission=permission) with self.assertNumQueries(1): - data3 = list(Test.objects.filter( + data2 = list(Test.objects.filter( pk__in=Test.objects.filter(permission=raw_sql))) - with self.assertNumQueries(0): - data4 = list(Test.objects.filter( - pk__in=Test.objects.filter(permission=raw_sql))) - self.assertListEqual(data4, data3) - self.assertListEqual(data4, [test]) + self.assertListEqual(data2, [test]) - Permission.objects.first().save() + permission.save() with self.assertNumQueries(1): - data5 = list(Test.objects.filter( + data3 = list(Test.objects.filter( pk__in=Test.objects.filter(permission=raw_sql))) - with self.assertNumQueries(0): - data6 = list(Test.objects.filter( - pk__in=Test.objects.filter(permission=raw_sql))) - self.assertListEqual(data6, data5) - self.assertListEqual(data6, [test]) + self.assertListEqual(data3, [test]) test.delete() with self.assertNumQueries(1): - data7 = list(Test.objects.filter( + data4 = list(Test.objects.filter( pk__in=Test.objects.filter(permission=raw_sql))) - with self.assertNumQueries(0): - data8 = list(Test.objects.filter( - pk__in=Test.objects.filter(permission=raw_sql))) - self.assertListEqual(data8, data7) - self.assertListEqual(data8, []) + self.assertListEqual(data4, []) def test_invalidate_select_related(self): with self.assertNumQueries(1): @@ -993,7 +959,7 @@ class WriteTestCase(TransactionTestCase): []) -class DatabaseCommandTestCase(TransactionTestCase): +class DatabaseCommandTestCase(TestUtilsMixin, TransactionTestCase): def setUp(self): self.t = Test.objects.create(name='test1') @@ -1003,10 +969,7 @@ class DatabaseCommandTestCase(TransactionTestCase): call_command('flush', verbosity=0, interactive=False) - if connection.vendor == 'mysql': - # We need to reopen the connection or Django - # will execute an extra SQL request below. - connection.cursor() + self.force_repoen_connection() with self.assertNumQueries(1): self.assertListEqual(list(Test.objects.all()), []) @@ -1018,10 +981,7 @@ class DatabaseCommandTestCase(TransactionTestCase): call_command('loaddata', 'cachalot/tests/loaddata_fixture.json', verbosity=0, interactive=False) - if connection.vendor in ('mysql', 'postgresql'): - # We need to reopen the connection or Django - # will execute an extra SQL request below. - connection.cursor() + self.force_repoen_connection() with self.assertNumQueries(1): self.assertListEqual([t.name for t in Test.objects.all()],