diff --git a/tenant_schemas/tests/test_tenants.py b/tenant_schemas/tests/test_tenants.py index 9ef4966..9d0ce09 100644 --- a/tenant_schemas/tests/test_tenants.py +++ b/tenant_schemas/tests/test_tenants.py @@ -172,10 +172,24 @@ class TenantDataAndSettingsTest(BaseTestCase): class TenantContentTypeTest(TenantTransactionTestCase): + @classmethod + def setUpClass(cls): + super(TenantContentTypeTest, cls).setUpClass() + settings.SHARED_APPS = ('tenant_schemas', ) + settings.TENANT_APPS = ('dts_test_app', + 'django.contrib.contenttypes', + 'django.contrib.auth', ) + settings.INSTALLED_APPS = settings.SHARED_APPS + settings.TENANT_APPS + cls.sync_shared() + def test_content_type_cache(self): import threading from django.apps import apps from django.contrib.contenttypes.models import ContentType + try: + from django.contrib.contenttypes.management import create_contenttypes + except ImportError: + from django.contrib.contenttypes.management import update_contenttypes as create_contenttypes connection.set_schema_to_public() tenant1 = Tenant(domain_url='something.test.com', @@ -186,8 +200,13 @@ class TenantContentTypeTest(TenantTransactionTestCase): tenant2 = Tenant(domain_url='example.com', schema_name='tenant2') tenant2.save(verbosity=BaseTestCase.get_verbosity()) - # go to tenant1's path - connection.set_tenant(tenant1) + with tenant_context(tenant2): + # recreate contenttypes to change primary key + ContentType.objects.all().delete() + assert ContentType.objects.count() == 0 + for app_config in apps.get_app_configs(): + create_contenttypes(app_config) + assert ContentType.objects.count() != 0 def collects_cts(d): for app in apps.get_app_configs(): @@ -196,13 +215,14 @@ class TenantContentTypeTest(TenantTransactionTestCase): # check cache is thread and tenant local tenant1_cts = {} - collects_cts(tenant1_cts) + # go to tenant1's path + with tenant_context(tenant1): + collects_cts(tenant1_cts) # switch temporarily to tenant2's path tenant2_cts = {} with tenant_context(tenant2): collects_cts(tenant2_cts) - tenant1_thread_cts = {} class T(threading.Thread): def run(self): @@ -210,6 +230,7 @@ class TenantContentTypeTest(TenantTransactionTestCase): from django.db import connection collects_cts(tenant1_thread_cts) connection.close() + tenant1_thread_cts = {} t = T() t.start() @@ -220,18 +241,27 @@ class TenantContentTypeTest(TenantTransactionTestCase): assert tenant1_cts[model] is not tenant2_cts[model] assert tenant2_cts[model] is not tenant1_thread_cts[model] assert tenant1_cts[model] is not tenant1_thread_cts[model] + # check pk differs between tenants, as we recreated them + assert tenant1_cts[model].pk is not tenant2_cts[model].pk + assert tenant2_cts[model].pk is not tenant1_thread_cts[model].pk # check cache is effective tenant12_cts = {} with tenant_context(tenant1): + for key in tenant1_cts: + assert ContentType.objects.get(pk=tenant1_cts[key].pk).model == tenant1_cts[key].model collects_cts(tenant12_cts) tenant22_cts = {} with tenant_context(tenant2): + for key in tenant2_cts: + assert ContentType.objects.get(pk=tenant2_cts[key].pk).model == tenant2_cts[key].model collects_cts(tenant22_cts) for app in apps.get_app_configs(): for model in app.get_models(): + assert tenant22_cts[model].pk is tenant2_cts[model].pk + assert tenant12_cts[model].pk is tenant1_cts[model].pk assert tenant12_cts[model] is tenant1_cts[model] assert tenant22_cts[model] is tenant2_cts[model]