override ContentType.__init__ to use a thread local cache (fixes #26206)
This commit is contained in:
parent
df51fe2c3e
commit
631743c45e
|
@ -2,11 +2,17 @@ import re
|
|||
import warnings
|
||||
import psycopg2
|
||||
import threading
|
||||
import weakref
|
||||
|
||||
from django.conf import settings
|
||||
from django.contrib.contenttypes.models import ContentType, ContentTypeManager
|
||||
from django.contrib.contenttypes.models import ContentTypeManager, ContentType
|
||||
from django.core.exceptions import ImproperlyConfigured, ValidationError
|
||||
import django.db.utils
|
||||
from django.db import connection
|
||||
try:
|
||||
from functools import lru_cache
|
||||
except ImportError:
|
||||
from django.utils.lru_cache import lru_cache
|
||||
|
||||
from tenant_schemas.utils import get_public_schema_name, get_limit_set_calls
|
||||
from tenant_schemas.postgresql_backend.introspection import DatabaseSchemaIntrospection
|
||||
|
@ -82,14 +88,6 @@ class DatabaseWrapper(original_backend.DatabaseWrapper):
|
|||
self.include_public_schema = include_public
|
||||
self.set_settings_schema(schema_name)
|
||||
self.search_path_set = False
|
||||
# Content type can no longer be cached as public and tenant schemas
|
||||
# have different models. If someone wants to change this, the cache
|
||||
# needs to be separated between public and shared schemas. If this
|
||||
# cache isn't cleared, this can cause permission problems. For example,
|
||||
# on public, a particular model has id 14, but on the tenants it has
|
||||
# the id 15. if 14 is cached instead of 15, the permissions for the
|
||||
# wrong model will be fetched.
|
||||
ContentType.objects.clear_cache()
|
||||
|
||||
def set_schema_to_public(self):
|
||||
"""
|
||||
|
@ -178,9 +176,41 @@ class FakeTenant:
|
|||
|
||||
# Make the ContentType cache tenant and thread safe
|
||||
ContentTypeManager._thread_local_cache = threading.local()
|
||||
ContentTypeManager_old__init__ = ContentTypeManager.__init__
|
||||
|
||||
|
||||
class ContentTypeCacheDescriptor(object):
|
||||
def __get__(self, obj, owner):
|
||||
if not hasattr(owner._thread_local_cache, '_cache'):
|
||||
owner._thread_local_cache._cache = {}
|
||||
return owner._thread_local_cache._cache
|
||||
# use weak for transient Manager
|
||||
owner._thread_local_cache._cache = weakref.WeakKeyDictionary()
|
||||
global_cache = owner._thread_local_cache._cache
|
||||
get_cache = global_cache.get(owner)
|
||||
if not get_cache:
|
||||
# use an LRU cache to evict dead tenants with time and to prevent
|
||||
# bloat with lot of tenants
|
||||
@lru_cache(maxsize=200)
|
||||
def get_cache(schema_name):
|
||||
return {}
|
||||
global_cache[owner] = get_cache
|
||||
tenant = getattr(connection, 'tenant', None)
|
||||
schema_name = getattr(tenant, 'schema_name', 'public')
|
||||
return get_cache(schema_name)
|
||||
|
||||
ContentTypeManager._cache = ContentTypeCacheDescriptor()
|
||||
|
||||
|
||||
def ContentTypeManager_new__init__(self, *args, **kwargs):
|
||||
ContentTypeManager_old__init__(self, *args, **kwargs)
|
||||
if '_cache' in self.__dict__:
|
||||
del self._cache
|
||||
ContentTypeManager.__init__ = ContentTypeManager_new__init__
|
||||
|
||||
if hasattr(ContentType._meta, 'local_managers'):
|
||||
for manager in ContentType._meta.local_managers:
|
||||
if '_cache' in manager.__dict__:
|
||||
del manager._cache
|
||||
else:
|
||||
for _, manager, _ in ContentType._meta.managers:
|
||||
if '_cache' in manager.__dict__:
|
||||
del manager._cache
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
from django.conf import settings
|
||||
from django.core.management import call_command
|
||||
from django.db import connection
|
||||
from django.test import TestCase
|
||||
from django.test import TestCase, TransactionTestCase
|
||||
from tenant_schemas.utils import get_public_schema_name, get_tenant_model
|
||||
|
||||
ALLOWED_TEST_DOMAIN = '.test.com'
|
||||
|
@ -46,6 +46,10 @@ class TenantTestCase(TestCase):
|
|||
verbosity=0)
|
||||
|
||||
|
||||
class TenantTransactionTestCase(TenantTestCase, TransactionTestCase):
|
||||
pass
|
||||
|
||||
|
||||
class FastTenantTestCase(TenantTestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
|
|
|
@ -12,7 +12,7 @@ from django.core.management import call_command
|
|||
from django.db import connection
|
||||
from dts_test_app.models import DummyModel, ModelWithFkToPublicUser
|
||||
|
||||
from tenant_schemas.test.cases import TenantTestCase
|
||||
from tenant_schemas.test.cases import TenantTestCase, TenantTransactionTestCase
|
||||
from tenant_schemas.tests.models import Tenant, NonAutoSyncTenant
|
||||
from tenant_schemas.tests.testcases import BaseTestCase
|
||||
from tenant_schemas.utils import tenant_context, schema_context, schema_exists, get_tenant_model, get_public_schema_name
|
||||
|
@ -171,6 +171,71 @@ class TenantDataAndSettingsTest(BaseTestCase):
|
|||
DummyModel(name="Survived it!").save()
|
||||
|
||||
|
||||
class TenantContentTypeTest(TenantTransactionTestCase):
|
||||
def test_content_type_cache(self):
|
||||
import threading
|
||||
from django.apps import apps
|
||||
from django.contrib.contenttypes.models import ContentType
|
||||
|
||||
connection.set_schema_to_public()
|
||||
tenant1 = Tenant(domain_url='something.test.com',
|
||||
schema_name='tenant1')
|
||||
tenant1.save(verbosity=BaseTestCase.get_verbosity())
|
||||
|
||||
connection.set_schema_to_public()
|
||||
tenant2 = Tenant(domain_url='example.com', schema_name='tenant2')
|
||||
tenant2.save(verbosity=BaseTestCase.get_verbosity())
|
||||
|
||||
# go to tenant1's path
|
||||
connection.set_tenant(tenant1)
|
||||
|
||||
def collects_cts(d):
|
||||
for app in apps.get_app_configs():
|
||||
for model in app.get_models():
|
||||
d[model] = ContentType.objects.get_for_model(model)
|
||||
|
||||
# check cache is thread and tenant local
|
||||
tenant1_cts = {}
|
||||
collects_cts(tenant1_cts)
|
||||
|
||||
# switch temporarily to tenant2's path
|
||||
tenant2_cts = {}
|
||||
with tenant_context(tenant2):
|
||||
collects_cts(tenant2_cts)
|
||||
tenant1_thread_cts = {}
|
||||
|
||||
class T(threading.Thread):
|
||||
def run(self):
|
||||
with tenant_context(tenant1):
|
||||
from django.db import connection
|
||||
collects_cts(tenant1_thread_cts)
|
||||
connection.close()
|
||||
|
||||
t = T()
|
||||
t.start()
|
||||
t.join()
|
||||
|
||||
for app in apps.get_app_configs():
|
||||
for model in app.get_models():
|
||||
assert tenant1_cts[model] is not tenant2_cts[model]
|
||||
assert tenant2_cts[model] is not tenant1_thread_cts[model]
|
||||
assert tenant1_cts[model] is not tenant1_thread_cts[model]
|
||||
|
||||
# check cache is effective
|
||||
tenant12_cts = {}
|
||||
with tenant_context(tenant1):
|
||||
collects_cts(tenant12_cts)
|
||||
|
||||
tenant22_cts = {}
|
||||
with tenant_context(tenant2):
|
||||
collects_cts(tenant22_cts)
|
||||
|
||||
for app in apps.get_app_configs():
|
||||
for model in app.get_models():
|
||||
assert tenant12_cts[model] is tenant1_cts[model]
|
||||
assert tenant22_cts[model] is tenant2_cts[model]
|
||||
|
||||
|
||||
class TenantSyncTest(BaseTestCase):
|
||||
"""
|
||||
Tests if the shared apps and the tenant apps get synced correctly
|
||||
|
|
Loading…
Reference in New Issue