diff --git a/tenant_schemas/postgresql_backend/base.py b/tenant_schemas/postgresql_backend/base.py index 2bc836d..d6ae824 100644 --- a/tenant_schemas/postgresql_backend/base.py +++ b/tenant_schemas/postgresql_backend/base.py @@ -2,11 +2,17 @@ import re import warnings import psycopg2 import threading +import weakref from django.conf import settings -from django.contrib.contenttypes.models import ContentType, ContentTypeManager +from django.contrib.contenttypes.models import ContentTypeManager, ContentType from django.core.exceptions import ImproperlyConfigured, ValidationError import django.db.utils +from django.db import connection +try: + from functools import lru_cache +except ImportError: + from django.utils.lru_cache import lru_cache from tenant_schemas.utils import get_public_schema_name, get_limit_set_calls from tenant_schemas.postgresql_backend.introspection import DatabaseSchemaIntrospection @@ -82,14 +88,6 @@ class DatabaseWrapper(original_backend.DatabaseWrapper): self.include_public_schema = include_public self.set_settings_schema(schema_name) self.search_path_set = False - # Content type can no longer be cached as public and tenant schemas - # have different models. If someone wants to change this, the cache - # needs to be separated between public and shared schemas. If this - # cache isn't cleared, this can cause permission problems. For example, - # on public, a particular model has id 14, but on the tenants it has - # the id 15. if 14 is cached instead of 15, the permissions for the - # wrong model will be fetched. - ContentType.objects.clear_cache() def set_schema_to_public(self): """ @@ -178,9 +176,41 @@ class FakeTenant: # Make the ContentType cache tenant and thread safe ContentTypeManager._thread_local_cache = threading.local() +ContentTypeManager_old__init__ = ContentTypeManager.__init__ + + class ContentTypeCacheDescriptor(object): def __get__(self, obj, owner): if not hasattr(owner._thread_local_cache, '_cache'): - owner._thread_local_cache._cache = {} - return owner._thread_local_cache._cache + # use weak for transient Manager + owner._thread_local_cache._cache = weakref.WeakKeyDictionary() + global_cache = owner._thread_local_cache._cache + get_cache = global_cache.get(owner) + if not get_cache: + # use an LRU cache to evict dead tenants with time and to prevent + # bloat with lot of tenants + @lru_cache(maxsize=200) + def get_cache(schema_name): + return {} + global_cache[owner] = get_cache + tenant = getattr(connection, 'tenant', None) + schema_name = getattr(tenant, 'schema_name', 'public') + return get_cache(schema_name) + ContentTypeManager._cache = ContentTypeCacheDescriptor() + + +def ContentTypeManager_new__init__(self, *args, **kwargs): + ContentTypeManager_old__init__(self, *args, **kwargs) + if '_cache' in self.__dict__: + del self._cache +ContentTypeManager.__init__ = ContentTypeManager_new__init__ + +if hasattr(ContentType._meta, 'local_managers'): + for manager in ContentType._meta.local_managers: + if '_cache' in manager.__dict__: + del manager._cache +else: + for _, manager, _ in ContentType._meta.managers: + if '_cache' in manager.__dict__: + del manager._cache diff --git a/tenant_schemas/test/cases.py b/tenant_schemas/test/cases.py index beef8dd..68c6ee5 100644 --- a/tenant_schemas/test/cases.py +++ b/tenant_schemas/test/cases.py @@ -1,7 +1,7 @@ from django.conf import settings from django.core.management import call_command from django.db import connection -from django.test import TestCase +from django.test import TestCase, TransactionTestCase from tenant_schemas.utils import get_public_schema_name, get_tenant_model ALLOWED_TEST_DOMAIN = '.test.com' @@ -46,6 +46,10 @@ class TenantTestCase(TestCase): verbosity=0) +class TenantTransactionTestCase(TenantTestCase, TransactionTestCase): + pass + + class FastTenantTestCase(TenantTestCase): @classmethod def setUpClass(cls): diff --git a/tenant_schemas/tests/test_tenants.py b/tenant_schemas/tests/test_tenants.py index 2bf26d8..9ef4966 100644 --- a/tenant_schemas/tests/test_tenants.py +++ b/tenant_schemas/tests/test_tenants.py @@ -12,7 +12,7 @@ from django.core.management import call_command from django.db import connection from dts_test_app.models import DummyModel, ModelWithFkToPublicUser -from tenant_schemas.test.cases import TenantTestCase +from tenant_schemas.test.cases import TenantTestCase, TenantTransactionTestCase from tenant_schemas.tests.models import Tenant, NonAutoSyncTenant from tenant_schemas.tests.testcases import BaseTestCase from tenant_schemas.utils import tenant_context, schema_context, schema_exists, get_tenant_model, get_public_schema_name @@ -171,6 +171,71 @@ class TenantDataAndSettingsTest(BaseTestCase): DummyModel(name="Survived it!").save() +class TenantContentTypeTest(TenantTransactionTestCase): + def test_content_type_cache(self): + import threading + from django.apps import apps + from django.contrib.contenttypes.models import ContentType + + connection.set_schema_to_public() + tenant1 = Tenant(domain_url='something.test.com', + schema_name='tenant1') + tenant1.save(verbosity=BaseTestCase.get_verbosity()) + + connection.set_schema_to_public() + tenant2 = Tenant(domain_url='example.com', schema_name='tenant2') + tenant2.save(verbosity=BaseTestCase.get_verbosity()) + + # go to tenant1's path + connection.set_tenant(tenant1) + + def collects_cts(d): + for app in apps.get_app_configs(): + for model in app.get_models(): + d[model] = ContentType.objects.get_for_model(model) + + # check cache is thread and tenant local + tenant1_cts = {} + collects_cts(tenant1_cts) + + # switch temporarily to tenant2's path + tenant2_cts = {} + with tenant_context(tenant2): + collects_cts(tenant2_cts) + tenant1_thread_cts = {} + + class T(threading.Thread): + def run(self): + with tenant_context(tenant1): + from django.db import connection + collects_cts(tenant1_thread_cts) + connection.close() + + t = T() + t.start() + t.join() + + for app in apps.get_app_configs(): + for model in app.get_models(): + assert tenant1_cts[model] is not tenant2_cts[model] + assert tenant2_cts[model] is not tenant1_thread_cts[model] + assert tenant1_cts[model] is not tenant1_thread_cts[model] + + # check cache is effective + tenant12_cts = {} + with tenant_context(tenant1): + collects_cts(tenant12_cts) + + tenant22_cts = {} + with tenant_context(tenant2): + collects_cts(tenant22_cts) + + for app in apps.get_app_configs(): + for model in app.get_models(): + assert tenant12_cts[model] is tenant1_cts[model] + assert tenant22_cts[model] is tenant2_cts[model] + + class TenantSyncTest(BaseTestCase): """ Tests if the shared apps and the tenant apps get synced correctly