debian-django-tenant-schemas/tenant_schemas/postgresql_backend/base.py

222 lines
8.2 KiB
Python
Raw Normal View History

2013-09-17 07:15:28 +02:00
import re
import warnings
import psycopg2
import threading
import weakref
2011-02-06 18:42:00 +01:00
from django.conf import settings
from django.contrib.contenttypes.models import ContentTypeManager, ContentType
from django.core.exceptions import ImproperlyConfigured, ValidationError
import django.db.utils
from django.db import connection
try:
from functools import lru_cache
except ImportError:
from django.utils.lru_cache import lru_cache
from tenant_schemas.utils import get_public_schema_name, get_limit_set_calls
from tenant_schemas.postgresql_backend.introspection import DatabaseSchemaIntrospection
2011-02-06 18:42:00 +01:00
# Dotted path of the database backend to wrap; can be overridden with the
# ORIGINAL_BACKEND setting, otherwise the psycopg2 PostgreSQL backend is used.
ORIGINAL_BACKEND = getattr(settings, 'ORIGINAL_BACKEND', 'django.db.backends.postgresql_psycopg2')
# Django 1.9+ takes care to rename the default backend to 'django.db.backends.postgresql'
# Loaded backend module; its DatabaseWrapper is the base class of the
# DatabaseWrapper defined in this module.
original_backend = django.db.utils.load_backend(ORIGINAL_BACKEND)
2011-02-06 18:42:00 +01:00
# Additional schemas appended to every search_path built in
# DatabaseWrapper._cursor(); read from the PG_EXTRA_SEARCH_PATHS setting
# (empty by default).
EXTRA_SEARCH_PATHS = getattr(settings, 'PG_EXTRA_SEARCH_PATHS', [])
2011-02-06 18:42:00 +01:00
# From the PostgreSQL documentation: an identifier starts with a letter or an
# underscore, followed by letters, digits or underscores. The pattern accepts
# at most 63 characters in total (1 leading + up to 62 following).
#
# The pattern is anchored with \Z rather than $: in Python, $ also matches
# just before a trailing newline, which would let "name\n" pass validation
# even though the validated name is interpolated into SQL statements.
SQL_IDENTIFIER_RE = re.compile(r'^[_a-zA-Z][_a-zA-Z0-9]{,62}\Z')
# Schema names starting with "pg_" (case-insensitive) are treated as reserved.
SQL_SCHEMA_NAME_RESERVED_RE = re.compile(r'^pg_', re.IGNORECASE)
def _is_valid_identifier(identifier):
    """Return True when ``identifier`` matches SQL_IDENTIFIER_RE, else False."""
    return SQL_IDENTIFIER_RE.match(identifier) is not None
2011-02-06 18:42:00 +01:00
2013-06-03 22:00:44 +02:00
2011-02-06 18:42:00 +01:00
def _check_identifier(identifier):
    """
    Validate ``identifier`` as an SQL identifier.

    Raises ValidationError when _is_valid_identifier() rejects it; returns
    None otherwise.
    """
    if _is_valid_identifier(identifier):
        return
    raise ValidationError("Invalid string used for the identifier.")
def _is_valid_schema_name(name):
    """
    Return True when ``name`` is a valid identifier that does not start with
    the reserved "pg_" prefix, else False.
    """
    if not _is_valid_identifier(name):
        return False
    return SQL_SCHEMA_NAME_RESERVED_RE.match(name) is None
def _check_schema_name(name):
    """
    Validate ``name`` as a schema name.

    Raises ValidationError when _is_valid_schema_name() rejects it; returns
    None otherwise.
    """
    if _is_valid_schema_name(name):
        return
    raise ValidationError("Invalid string used for the schema name.")
2011-02-06 18:42:00 +01:00
2013-06-03 22:00:44 +02:00
class DatabaseWrapper(original_backend.DatabaseWrapper):
    """
    Database wrapper that can switch the PostgreSQL search_path through
    set_tenant() and set_schema().

    The setters only record the requested schema; the search_path itself is
    applied lazily, the next time _cursor() is called.
    """
    # Whether the public schema is appended after the tenant schema.
    include_public_schema = True

    def __init__(self, *args, **kwargs):
        super(DatabaseWrapper, self).__init__(*args, **kwargs)

        # Replace the stock introspection with the patched one, which only
        # lists the tables of the currently selected schema.
        self.introspection = DatabaseSchemaIntrospection(self)

        # Start on the public schema until a tenant is selected.
        self.set_schema_to_public()

    def close(self):
        # Force the search_path to be applied again on the next cursor.
        self.search_path_set = False
        super(DatabaseWrapper, self).close()

    def rollback(self):
        super(DatabaseWrapper, self).rollback()
        # Django's rollback clears the search path, so it has to be set
        # again the next time a cursor is requested.
        self.search_path_set = False

    def set_tenant(self, tenant, include_public=True):
        """
        Select the schema of ``tenant`` as the current schema and remember
        the tenant object. Does not touch the db connection itself.
        """
        self.set_schema(tenant.schema_name, include_public)
        # set_schema() installed a FakeTenant; replace it with the real one.
        self.tenant = tenant

    def set_schema(self, schema_name, include_public=True):
        """
        Select ``schema_name`` as the current schema. Does not touch the db
        connection itself; the search_path is applied by the next _cursor().
        """
        self.schema_name = schema_name
        self.include_public_schema = include_public
        self.tenant = FakeTenant(schema_name=schema_name)
        self.set_settings_schema(schema_name)
        self.search_path_set = False

    def set_schema_to_public(self):
        """
        Instructs to stay in the common 'public' schema.
        """
        self.set_schema(get_public_schema_name())

    def set_settings_schema(self, schema_name):
        """Record ``schema_name`` under the 'SCHEMA' key of settings_dict."""
        self.settings_dict['SCHEMA'] = schema_name

    def get_schema(self):
        """Deprecated accessor for ``schema_name``."""
        message = "connection.get_schema() is deprecated, use connection.schema_name instead."
        warnings.warn(message, DeprecationWarning)
        return self.schema_name

    def get_tenant(self):
        """Deprecated accessor for ``tenant``."""
        message = "connection.get_tenant() is deprecated, use connection.tenant instead."
        warnings.warn(message, DeprecationWarning)
        return self.tenant

    def _cursor(self, name=None):
        """
        Return a cursor from the wrapped backend, applying the search_path
        for the current schema first when needed.

        Every Django db operation on PostgreSQL is expected to obtain its
        cursor through this method, which is why the path is changed here.
        """
        parent_cursor = super(DatabaseWrapper, self)._cursor
        # ``name`` is only forwarded when given: it is only supported and
        # required by Django 1.11 (server-side cursor).
        cursor = parent_cursor(name=name) if name else parent_cursor()

        # Optionally limit the number of executions - under load, running
        # `set search_path` can be quite time consuming.
        if get_limit_set_calls() and self.search_path_set:
            return cursor

        # Actual search_path modification for the cursor. The database
        # searches schemata from left to right when looking for an object
        # (table, index, sequence, etc.).
        if not self.schema_name:
            raise ImproperlyConfigured("Database schema not set. Did you forget "
                                       "to call set_schema() or set_tenant()?")
        _check_schema_name(self.schema_name)

        public_schema_name = get_public_schema_name()
        search_paths = [self.schema_name]
        if self.schema_name != public_schema_name and self.include_public_schema:
            search_paths.append(public_schema_name)
        search_paths.extend(EXTRA_SEARCH_PATHS)

        # A named cursor can only be used once, so a separate plain cursor
        # is opened for the SET statements; otherwise the cursor is reused.
        path_cursor = self.connection.cursor() if name else cursor

        # If an error already happened in this transaction and a rollback is
        # about to follow, a database error while setting the search_path is
        # simply ignored. Should the next instruction not be a rollback, it
        # will fail as well, so nothing is hidden.
        try:
            path_cursor.execute('SET search_path = {0}'.format(','.join(search_paths)))
            path_cursor.execute('SET application_name = {0}'.format(self.schema_name))
        except (django.db.utils.DatabaseError, psycopg2.InternalError):
            path_applied = False
        else:
            path_applied = True
        self.search_path_set = path_applied

        if name:
            path_cursor.close()
        return cursor
2011-02-06 18:42:00 +01:00
class FakeTenant(object):
    """
    Minimal tenant-like object exposing only ``schema_name``.

    Database models apparently cannot be imported in a backend, so this
    class is used to wrap a bare schema name in a tenant-like structure.
    """

    def __init__(self, schema_name):
        self.schema_name = schema_name

    def __repr__(self):
        # Show the wrapped schema name in logs and debugger output.
        return '{0}(schema_name={1!r})'.format(type(self).__name__, self.schema_name)
# Make the ContentType cache tenant and thread safe
# Per-thread storage attached to the manager class; the cache descriptor
# defined in this module keeps its dictionaries on it.
ContentTypeManager._thread_local_cache = threading.local()
# Keep a reference to the original constructor so that the replacement
# defined in this module can delegate to it.
ContentTypeManager_old__init__ = ContentTypeManager.__init__
class ContentTypeCacheDescriptor(object):
    """
    Descriptor returning a cache dict that is specific to the current
    thread, the accessing manager, the manager's model and the schema of the
    tenant currently set on ``connection``.
    """

    def __get__(self, obj, owner):
        """
        Return the cache dict for ``obj`` (a manager instance).

        When accessed on the class (``obj`` is None) the descriptor itself
        is returned, as there is no manager to key the cache on; None cannot
        be used as a key of a WeakKeyDictionary.
        """
        if obj is None:
            return self

        if not hasattr(owner._thread_local_cache, '_cache'):
            # use weak for transient Manager
            owner._thread_local_cache._cache = weakref.WeakKeyDictionary()
        global_cache = owner._thread_local_cache._cache

        # Only allocate the per-manager mapping when it is actually missing.
        model_cache = global_cache.get(obj)
        if model_cache is None:
            model_cache = global_cache[obj] = weakref.WeakKeyDictionary()

        get_cache = model_cache.get(obj.model)
        if get_cache is None:
            # use an LRU cache to evict dead tenants with time and to prevent
            # bloat with lot of tenants
            @lru_cache(maxsize=200)
            def get_cache(schema_name):
                return {}
            model_cache[obj.model] = get_cache

        # Fall back to 'public' when the connection carries no tenant.
        tenant = getattr(connection, 'tenant', None)
        schema_name = getattr(tenant, 'schema_name', 'public')
        return get_cache(schema_name)
# Install the descriptor on the manager class: reading ``_cache`` on a
# manager that has no instance attribute of that name goes through
# ContentTypeCacheDescriptor.__get__.
ContentTypeManager._cache = ContentTypeCacheDescriptor()
def ContentTypeManager_new__init__(self, *args, **kwargs):
    """
    Replacement constructor for ContentTypeManager.

    Runs the original constructor, then removes any ``_cache`` entry it left
    in the instance ``__dict__``, so that an instance attribute does not
    shadow the class-level ``_cache``.
    """
    ContentTypeManager_old__init__(self, *args, **kwargs)
    instance_attrs = vars(self)
    if '_cache' not in instance_attrs:
        return
    del self._cache
# From now on every new ContentTypeManager is built by the patched constructor.
ContentTypeManager.__init__ = ContentTypeManager_new__init__
# Managers that already exist were built by the original constructor; remove
# any ``_cache`` instance attribute they carry so that it does not shadow the
# class-level ``_cache``.
if '_cache' in ContentType.objects.__dict__:
    del ContentType.objects._cache
# The managers of ContentType are exposed either as ``_meta.local_managers``
# or as ``_meta.managers`` yielding 3-tuples with the manager in the middle
# (presumably depending on the Django version - TODO confirm).
if hasattr(ContentType._meta, 'local_managers'):
    for manager in ContentType._meta.local_managers:
        if '_cache' in manager.__dict__:
            del manager._cache
else:
    for _, manager, _ in ContentType._meta.managers:
        if '_cache' in manager.__dict__:
            del manager._cache