debian-django-tenant-schemas/tenant_schemas/postgresql_backend/base.py

222 lines
8.2 KiB
Python

import re
import warnings
import psycopg2
import threading
import weakref
from django.conf import settings
from django.contrib.contenttypes.models import ContentTypeManager, ContentType
from django.core.exceptions import ImproperlyConfigured, ValidationError
import django.db.utils
from django.db import connection
try:
from functools import lru_cache
except ImportError:
from django.utils.lru_cache import lru_cache
from tenant_schemas.utils import get_public_schema_name, get_limit_set_calls
from tenant_schemas.postgresql_backend.introspection import DatabaseSchemaIntrospection
# Dotted path of the database backend this module wraps. Projects may override
# it with the ORIGINAL_BACKEND setting; otherwise the psycopg2 backend is used.
ORIGINAL_BACKEND = getattr(settings, 'ORIGINAL_BACKEND', 'django.db.backends.postgresql_psycopg2')
# Django 1.9+ takes care to rename the default backend to 'django.db.backends.postgresql'
# Backend module loaded once at import time; its DatabaseWrapper is the base
# class of the tenant-aware DatabaseWrapper defined in this file.
original_backend = django.db.utils.load_backend(ORIGINAL_BACKEND)
# Extra schemas (PG_EXTRA_SEARCH_PATHS setting, empty by default) appended to
# the end of every search_path built in DatabaseWrapper._cursor().
EXTRA_SEARCH_PATHS = getattr(settings, 'PG_EXTRA_SEARCH_PATHS', [])
# from the postgresql doc: a letter or underscore followed by at most 62
# letters, digits or underscores (63 characters in total).
# The pattern is anchored with ``\Z`` rather than ``$`` because ``$`` also
# matches just before a trailing newline, which would let a value such as
# "tenant\n" pass validation before being interpolated into SQL.
SQL_IDENTIFIER_RE = re.compile(r'^[_a-zA-Z][_a-zA-Z0-9]{,62}\Z')
# Schema names starting with "pg_" (in any letter case) are treated as reserved.
SQL_SCHEMA_NAME_RESERVED_RE = re.compile(r'^pg_', re.IGNORECASE)
def _is_valid_identifier(identifier):
    """Return True when ``identifier`` matches ``SQL_IDENTIFIER_RE``, else False."""
    return SQL_IDENTIFIER_RE.match(identifier) is not None
def _check_identifier(identifier):
    """
    Raise ``ValidationError`` unless ``identifier`` passes ``_is_valid_identifier``.

    Returns ``None`` when the identifier is acceptable.
    """
    if _is_valid_identifier(identifier):
        return
    raise ValidationError("Invalid string used for the identifier.")
def _is_valid_schema_name(name):
    """
    Return True when ``name`` is a valid identifier that does not match the
    reserved-prefix pattern ``SQL_SCHEMA_NAME_RESERVED_RE``, else False.
    """
    if not _is_valid_identifier(name):
        return False
    return SQL_SCHEMA_NAME_RESERVED_RE.match(name) is None
def _check_schema_name(name):
    """
    Raise ``ValidationError`` unless ``name`` passes ``_is_valid_schema_name``.

    Returns ``None`` when the schema name is acceptable.
    """
    if _is_valid_schema_name(name):
        return
    raise ValidationError("Invalid string used for the schema name.")
class DatabaseWrapper(original_backend.DatabaseWrapper):
    """
    Adds the capability to manipulate the search_path using set_tenant and set_schema.

    Selecting a schema only records it on the wrapper; the ``SET search_path``
    statement is issued lazily the next time a cursor is requested.
    """
    # Whether the public schema is added after the tenant schema in the search path.
    include_public_schema = True

    def __init__(self, *args, **kwargs):
        super(DatabaseWrapper, self).__init__(*args, **kwargs)

        # Use a patched version of the DatabaseIntrospection that only returns the table list for the
        # currently selected schema.
        self.introspection = DatabaseSchemaIntrospection(self)
        # Start on the public schema; this also initialises ``schema_name``,
        # ``tenant`` and ``search_path_set``.
        self.set_schema_to_public()

    def close(self):
        """Close the connection and mark the search path as needing to be set again."""
        self.search_path_set = False
        super(DatabaseWrapper, self).close()

    def rollback(self):
        """Roll back, then mark the search path as needing to be set again."""
        super(DatabaseWrapper, self).rollback()
        # Django's rollback clears the search path so we have to set it again the next time.
        self.search_path_set = False

    def set_tenant(self, tenant, include_public=True):
        """
        Main API method to select the current database schema from a tenant
        object, but it does not actually modify the db connection.

        ``tenant`` must expose a ``schema_name`` attribute. ``include_public``
        controls whether the public schema is also placed on the search path.
        """
        self.set_schema(tenant.schema_name, include_public)
        # set_schema() stored a FakeTenant placeholder; replace it with the real tenant.
        self.tenant = tenant

    def set_schema(self, schema_name, include_public=True):
        """
        Main API method to select the current database schema by name,
        but it does not actually modify the db connection.

        The search path is flagged as stale so the next cursor request applies it.
        """
        self.tenant = FakeTenant(schema_name=schema_name)
        self.schema_name = schema_name
        self.include_public_schema = include_public
        self.set_settings_schema(schema_name)
        self.search_path_set = False

    def set_schema_to_public(self):
        """
        Instructs to stay in the common 'public' schema.
        """
        self.set_schema(get_public_schema_name())

    def set_settings_schema(self, schema_name):
        """Record the selected schema under the ``SCHEMA`` key of ``settings_dict``."""
        self.settings_dict['SCHEMA'] = schema_name

    def get_schema(self):
        """Deprecated accessor for ``schema_name``; emits a DeprecationWarning."""
        warnings.warn("connection.get_schema() is deprecated, use connection.schema_name instead.",
                      category=DeprecationWarning)
        return self.schema_name

    def get_tenant(self):
        """Deprecated accessor for ``tenant``; emits a DeprecationWarning."""
        warnings.warn("connection.get_tenant() is deprecated, use connection.tenant instead.",
                      category=DeprecationWarning)
        return self.tenant

    def _cursor(self, name=None):
        """
        Here it happens. We hope every Django db operation using PostgreSQL
        must go through this to get the cursor handle. We change the path.

        Returns the cursor obtained from the wrapped backend. Raises
        ImproperlyConfigured when no schema has been selected, and
        ValidationError when the selected schema name is not acceptable.
        """
        if name:
            # Only supported and required by Django 1.11 (server-side cursor)
            cursor = super(DatabaseWrapper, self)._cursor(name=name)
        else:
            cursor = super(DatabaseWrapper, self)._cursor()

        # optionally limit the number of executions - under load, the execution
        # of `set search_path` can be quite time consuming
        if (not get_limit_set_calls()) or not self.search_path_set:
            # Actual search_path modification for the cursor. Database will
            # search schemata from left to right when looking for the object
            # (table, index, sequence, etc.).
            if not self.schema_name:
                raise ImproperlyConfigured("Database schema not set. Did you forget "
                                           "to call set_schema() or set_tenant()?")
            # The name is interpolated into SQL below, so it must be validated first.
            _check_schema_name(self.schema_name)
            public_schema_name = get_public_schema_name()

            if self.schema_name == public_schema_name:
                search_paths = [public_schema_name]
            elif self.include_public_schema:
                search_paths = [self.schema_name, public_schema_name]
            else:
                search_paths = [self.schema_name]

            search_paths.extend(EXTRA_SEARCH_PATHS)

            if name:
                # Named cursor can only be used once
                cursor_for_search_path = self.connection.cursor()
            else:
                # Reuse
                cursor_for_search_path = cursor

            # In the event that an error already happened in this transaction and we are going
            # to rollback we should just ignore database error when setting the search_path
            # if the next instruction is not a rollback it will just fail also, so
            # we do not have to worry that it's not the good one
            try:
                cursor_for_search_path.execute('SET search_path = {0}'.format(','.join(search_paths)))
                cursor_for_search_path.execute('SET application_name = {0}'.format(self.schema_name))
            except (django.db.utils.DatabaseError, psycopg2.InternalError):
                self.search_path_set = False
            else:
                self.search_path_set = True
            finally:
                # Always release the helper cursor, including when an exception
                # that is not handled above propagates to the caller.
                if name:
                    cursor_for_search_path.close()

        return cursor
class FakeTenant(object):
    """
    We can't import any db model in a backend (apparently?), so this class is used
    for wrapping schema names in a tenant-like structure.

    The only attribute is ``schema_name``, the value given to the constructor.
    """

    def __init__(self, schema_name):
        self.schema_name = schema_name

    def __repr__(self):
        # Readable form for logs and debugging sessions.
        return '{0}(schema_name={1!r})'.format(type(self).__name__, self.schema_name)
# Make the ContentType cache tenant and thread safe
# Thread-local holder shared by every ContentTypeManager; the per-thread cache
# is stored on it under the ``_cache`` attribute by ContentTypeCacheDescriptor.
ContentTypeManager._thread_local_cache = threading.local()
# Keep a reference to the unpatched constructor so the replacement
# (ContentTypeManager_new__init__) can delegate to it.
ContentTypeManager_old__init__ = ContentTypeManager.__init__
class ContentTypeCacheDescriptor(object):
    """
    Descriptor returning the cache dict that belongs to the current thread,
    the accessing manager, the manager's model and the active schema.

    The owner class must provide a ``_thread_local_cache`` attribute to hold
    the per-thread state, and the accessing instance must expose ``model``.
    """

    def __get__(self, obj, owner):
        if obj is None:
            # Accessed on the class: there is no manager instance to key the
            # cache on (and None cannot be weakly referenced), so follow the
            # usual descriptor convention and return the descriptor itself.
            return self

        thread_state = owner._thread_local_cache
        if not hasattr(thread_state, '_cache'):
            # use weak for transient Manager
            thread_state._cache = weakref.WeakKeyDictionary()
        global_cache = thread_state._cache

        model_cache = global_cache.setdefault(obj, weakref.WeakKeyDictionary())
        get_cache = model_cache.get(obj.model)
        if get_cache is None:
            # use an LRU cache to evict dead tenants with time and to prevent
            # bloat with lot of tenants
            @lru_cache(maxsize=200)
            def get_cache(schema_name):
                return {}

            model_cache[obj.model] = get_cache

        tenant = getattr(connection, 'tenant', None)
        # Fall back to the configured public schema name when no tenant (or no
        # schema name) is available on the connection.
        schema_name = getattr(tenant, 'schema_name', get_public_schema_name())
        return get_cache(schema_name)
# Install the descriptor at class level so that manager instances without their
# own ``_cache`` entry in ``__dict__`` resolve ``_cache`` through it.
ContentTypeManager._cache = ContentTypeCacheDescriptor()
def ContentTypeManager_new__init__(self, *args, **kwargs):
    """
    Replacement constructor for ``ContentTypeManager``.

    Delegates to the saved original constructor, then drops any ``_cache``
    entry left in the instance ``__dict__`` so it cannot shadow the
    class-level ``_cache`` attribute.
    """
    ContentTypeManager_old__init__(self, *args, **kwargs)
    self.__dict__.pop('_cache', None)
# Activate the patched constructor for every manager created from now on.
ContentTypeManager.__init__ = ContentTypeManager_new__init__

# Managers that already exist were built by the original constructor, so strip
# any instance-level ``_cache`` they carry; it would otherwise shadow the
# class-level ``_cache`` attribute.
if '_cache' in vars(ContentType.objects):
    del ContentType.objects._cache

# ``_meta.local_managers`` is used when the attribute exists; otherwise
# ``_meta.managers`` is iterated as 3-tuples with the manager in the middle.
if not hasattr(ContentType._meta, 'local_managers'):
    for _, manager, _ in ContentType._meta.managers:
        if '_cache' in vars(manager):
            del manager._cache
else:
    for manager in ContentType._meta.local_managers:
        if '_cache' in vars(manager):
            del manager._cache