import re
import warnings
import psycopg2
import threading
import weakref

from django.conf import settings
from django.contrib.contenttypes.models import ContentTypeManager, ContentType
from django.core.exceptions import ImproperlyConfigured, ValidationError
import django.db.utils
from django.db import connection

try:
    from functools import lru_cache
except ImportError:
    from django.utils.lru_cache import lru_cache

from tenant_schemas.utils import get_public_schema_name, get_limit_set_calls
from tenant_schemas.postgresql_backend.introspection import DatabaseSchemaIntrospection

ORIGINAL_BACKEND = getattr(settings, 'ORIGINAL_BACKEND', 'django.db.backends.postgresql_psycopg2')
# Django 1.9+ takes care to rename the default backend to 'django.db.backends.postgresql'
original_backend = django.db.utils.load_backend(ORIGINAL_BACKEND)

# Extra schemata appended verbatim to every search_path (taken from settings).
EXTRA_SEARCH_PATHS = getattr(settings, 'PG_EXTRA_SEARCH_PATHS', [])

# from the postgresql doc
# NOTE: the pattern is anchored with `\Z`, not `$`. `$` also matches just
# before a trailing newline, which would let a name such as "tenant\n" pass
# validation and end up interpolated into the SET statements below.
SQL_IDENTIFIER_RE = re.compile(r'^[_a-zA-Z][_a-zA-Z0-9]{,62}\Z')
SQL_SCHEMA_NAME_RESERVED_RE = re.compile(r'^pg_', re.IGNORECASE)


def _is_valid_identifier(identifier):
    """Return True if `identifier` matches SQL_IDENTIFIER_RE in full."""
    return bool(SQL_IDENTIFIER_RE.match(identifier))


def _check_identifier(identifier):
    """Raise ValidationError unless `identifier` is a valid SQL identifier."""
    if not _is_valid_identifier(identifier):
        raise ValidationError("Invalid string used for the identifier.")


def _is_valid_schema_name(name):
    """Return True if `name` is a valid identifier not starting with 'pg_'."""
    return _is_valid_identifier(name) and not SQL_SCHEMA_NAME_RESERVED_RE.match(name)


def _check_schema_name(name):
    """Raise ValidationError unless `name` is usable as a schema name."""
    if not _is_valid_schema_name(name):
        raise ValidationError("Invalid string used for the schema name.")


class DatabaseWrapper(original_backend.DatabaseWrapper):
    """
    Adds the capability to manipulate the search_path using set_tenant and set_schema_name
    """
    include_public_schema = True

    def __init__(self, *args, **kwargs):
        super(DatabaseWrapper, self).__init__(*args, **kwargs)

        # Use a patched version of the DatabaseIntrospection that only returns the table list for the
        # currently selected schema.
        self.introspection = DatabaseSchemaIntrospection(self)
        self.set_schema_to_public()

    def close(self):
        """Close the connection and mark the search_path as needing a reset."""
        self.search_path_set = False
        super(DatabaseWrapper, self).close()

    def rollback(self):
        """Roll back, then mark the search_path as needing a reset."""
        super(DatabaseWrapper, self).rollback()
        # Django's rollback clears the search path so we have to set it again the next time.
        self.search_path_set = False

    def set_tenant(self, tenant, include_public=True):
        """
        Main API method to select the current database schema from a tenant
        object. It does not touch the db connection itself: the search_path
        is applied by the next call to `_cursor()`.
        """
        self.set_schema(tenant.schema_name, include_public)
        # set_schema() installs a FakeTenant; replace it with the real one.
        self.tenant = tenant

    def set_schema(self, schema_name, include_public=True):
        """
        Main API method to select the current database schema by name. It
        does not touch the db connection itself: the search_path is applied
        by the next call to `_cursor()`.
        """
        self.tenant = FakeTenant(schema_name=schema_name)
        self.schema_name = schema_name
        self.include_public_schema = include_public
        self.set_settings_schema(schema_name)
        self.search_path_set = False

    def set_schema_to_public(self):
        """
        Instructs to stay in the common 'public' schema.
        """
        self.set_schema(get_public_schema_name())

    def set_settings_schema(self, schema_name):
        """Record the selected schema under settings_dict['SCHEMA']."""
        self.settings_dict['SCHEMA'] = schema_name

    def get_schema(self):
        """Deprecated accessor; use `connection.schema_name` instead."""
        # stacklevel=2 attributes the warning to the caller, not this file.
        warnings.warn("connection.get_schema() is deprecated, use connection.schema_name instead.",
                      category=DeprecationWarning, stacklevel=2)
        return self.schema_name

    def get_tenant(self):
        """Deprecated accessor; use `connection.tenant` instead."""
        # stacklevel=2 attributes the warning to the caller, not this file.
        warnings.warn("connection.get_tenant() is deprecated, use connection.tenant instead.",
                      category=DeprecationWarning, stacklevel=2)
        return self.tenant

    def _cursor(self, name=None):
        """
        Here it happens. We hope every Django db operation using PostgreSQL
        must go through this to get the cursor handle. We change the path.

        `name`, when given, requests a named (server-side) cursor. Raises
        ImproperlyConfigured if no schema is selected and ValidationError if
        the selected schema name is not a valid schema identifier.
        """
        if name:
            # Only supported and required by Django 1.11 (server-side cursor)
            cursor = super(DatabaseWrapper, self)._cursor(name=name)
        else:
            cursor = super(DatabaseWrapper, self)._cursor()

        # optionally limit the number of executions - under load, the execution
        # of `set search_path` can be quite time consuming
        if (not get_limit_set_calls()) or not self.search_path_set:
            # Actual search_path modification for the cursor. Database will
            # search schemata from left to right when looking for the object
            # (table, index, sequence, etc.).
            if not self.schema_name:
                raise ImproperlyConfigured("Database schema not set. Did you forget "
                                           "to call set_schema() or set_tenant()?")
            # The name is interpolated into SQL below, so validate it first.
            _check_schema_name(self.schema_name)
            public_schema_name = get_public_schema_name()

            if self.schema_name == public_schema_name:
                search_paths = [public_schema_name]
            elif self.include_public_schema:
                search_paths = [self.schema_name, public_schema_name]
            else:
                search_paths = [self.schema_name]

            search_paths.extend(EXTRA_SEARCH_PATHS)

            if name:
                # Named cursor can only be used once
                cursor_for_search_path = self.connection.cursor()
            else:
                # Reuse
                cursor_for_search_path = cursor

            # In the event that an error already happened in this transaction and we are going
            # to rollback we should just ignore database error when setting the search_path
            # if the next instruction is not a rollback it will just fail also, so
            # we do not have to worry that it's not the good one
            try:
                cursor_for_search_path.execute('SET search_path = {0}'.format(','.join(search_paths)))
                cursor_for_search_path.execute('SET application_name = {0}'.format(self.schema_name))
            except (django.db.utils.DatabaseError, psycopg2.InternalError):
                self.search_path_set = False
            else:
                self.search_path_set = True
            finally:
                # Close the throw-away helper cursor even when an unexpected
                # exception escapes, so it is never leaked.
                if name:
                    cursor_for_search_path.close()

        return cursor


class FakeTenant:
    """
    We can't import any db model in a backend (apparently?), so this class is used
    for wrapping schema names in a tenant-like structure.
    """
    def __init__(self, schema_name):
        self.schema_name = schema_name


# Make the ContentType cache tenant and thread safe
ContentTypeManager._thread_local_cache = threading.local()
ContentTypeManager_old__init__ = ContentTypeManager.__init__


class ContentTypeCacheDescriptor(object):
    """
    Non-data descriptor installed as `ContentTypeManager._cache`. It returns a
    dict that is specific to the current thread, the manager instance, the
    manager's model and the schema name of `connection.tenant`.
    """

    def __get__(self, obj, owner):
        if obj is None:
            # Class-level access (ContentTypeManager._cache): there is no
            # manager instance to key the cache on, so return the descriptor.
            return self

        if not hasattr(owner._thread_local_cache, '_cache'):
            # use weak for transient Manager
            owner._thread_local_cache._cache = weakref.WeakKeyDictionary()

        global_cache = owner._thread_local_cache._cache
        model_cache = global_cache.setdefault(obj, weakref.WeakKeyDictionary())

        get_cache = model_cache.get(obj.model)
        if get_cache is None:
            # use an LRU cache to evict dead tenants with time and to prevent
            # bloat with lot of tenants
            @lru_cache(maxsize=200)
            def get_cache(schema_name):
                return {}

            model_cache[obj.model] = get_cache

        # Fall back to 'public' when the connection has no tenant selected.
        tenant = getattr(connection, 'tenant', None)
        schema_name = getattr(tenant, 'schema_name', 'public')

        return get_cache(schema_name)


ContentTypeManager._cache = ContentTypeCacheDescriptor()


def ContentTypeManager_new__init__(self, *args, **kwargs):
    """Run the original __init__, then drop any instance-level `_cache`."""
    ContentTypeManager_old__init__(self, *args, **kwargs)

    # An instance attribute would shadow the non-data descriptor above.
    if '_cache' in self.__dict__:
        del self._cache


ContentTypeManager.__init__ = ContentTypeManager_new__init__

# Managers created before the patch may already carry an instance-level
# `_cache`; remove it so lookups reach the descriptor.
if '_cache' in ContentType.objects.__dict__:
    del ContentType.objects._cache

if hasattr(ContentType._meta, 'local_managers'):
    for manager in ContentType._meta.local_managers:
        if '_cache' in manager.__dict__:
            del manager._cache
else:
    # Here `_meta.managers` is iterated as 3-tuples with the manager second.
    for _, manager, _ in ContentType._meta.managers:
        if '_cache' in manager.__dict__:
            del manager._cache