multitenant: use real tenant in migration commands (#48071)

This commit is contained in:
Benjamin Dauvergne 2020-10-27 17:19:51 +01:00
parent 1e492e7aba
commit d41a4e3409
3 changed files with 40 additions and 27 deletions

View File

@ -24,7 +24,8 @@ from django.db import connection
from django.conf import settings
from tenant_schemas.utils import get_public_schema_name, schema_exists
from hobo.multitenant.middleware import TenantMiddleware
from tenant_schemas.postgresql_backend.base import FakeTenant
from hobo.multitenant.middleware import TenantMiddleware, TenantNotFound
from hobo.multitenant.management.commands import SyncCommon
@ -39,12 +40,15 @@ class MigrateSchemasCommand(SyncCommon):
def handle(self, *args, **options):
super(MigrateSchemasCommand, self).handle(*args, **options)
if self.schema_name:
if not schema_exists(self.schema_name):
raise RuntimeError('Schema "{}" does not exist'.format(
self.schema_name))
if self.domain:
try:
tenant = TenantMiddleware.get_tenant_by_hostname(self.domain)
except TenantNotFound:
raise RuntimeError('Tenant "{}" does not exist'.format(self.domain))
else:
self.run_migrations(self.schema_name, settings.TENANT_APPS)
self.run_migrations(tenant, settings.TENANT_APPS)
elif self.schema_name:
self.run_migrations_on_schema(self.schema_name, settings.TENANT_APPS)
else:
app_labels = []
for app in apps.get_app_configs():
@ -54,7 +58,7 @@ class MigrateSchemasCommand(SyncCommon):
loader.load_disk()
all_migrations = set([(app, migration) for app, migration in loader.disk_migrations if app in app_labels])
for tenant in TenantMiddleware.get_tenants():
connection.set_schema(tenant.schema_name, include_public=False)
connection.set_tenant(tenant, include_public=False)
applied_migrations = self.get_applied_migrations(app_labels)
if options.get('fake') or options.get('migration_name') or options.get('app_label'):
# never skip migrations if explicit migration actions
@ -62,9 +66,9 @@ class MigrateSchemasCommand(SyncCommon):
applied_migrations = []
if all([x in applied_migrations for x in all_migrations]):
if int(self.options.get('verbosity', 1)) >= 1:
self._notice("=== Skipping migrations of schema %s" % tenant.schema_name)
self._notice("=== Skipping migrations of tenant %s" % tenant.domain_url)
continue
self.run_migrations(tenant.schema_name, settings.TENANT_APPS)
self.run_migrations(tenant, settings.TENANT_APPS)
def get_applied_migrations(self, app_labels):
applied_migrations = []
@ -75,10 +79,20 @@ class MigrateSchemasCommand(SyncCommon):
applied_migrations = [x for x in applied_migrations if x[0] in app_labels]
return applied_migrations
def run_migrations(self, schema_name, included_apps):
def run_migrations(self, tenant, included_apps):
if int(self.options.get('verbosity', 1)) >= 1:
self._notice("=== Running migrate for schema %s" % schema_name)
connection.set_schema(schema_name, include_public=False)
self._notice("=== Running migrate for tenant %s" % tenant.domain_url)
connection.set_tenant(tenant, include_public=False)
command = MigrateCommand()
command.requires_system_checks = False
command.requires_migrations_checks = False
command.execute(*self.args, **self.options)
connection.set_schema_to_public()
def run_migrations_on_schema(self, schema, included_apps):
if int(self.options.get('verbosity', 1)) >= 1:
self._notice("=== Running migrate for schema %s" % schema)
connection.set_schema(schema, include_public=False)
command = MigrateCommand()
command.requires_system_checks = False
command.requires_migrations_checks = False

View File

@ -18,8 +18,7 @@ from django.core.management.commands.showmigrations import Command as ShowMigrat
from django.db import connection
from django.conf import settings
from tenant_schemas.utils import schema_exists
from hobo.multitenant.middleware import TenantMiddleware
from hobo.multitenant.middleware import TenantMiddleware, TenantNotFound
from hobo.multitenant.management.commands import SyncCommon
@ -33,22 +32,22 @@ class ShowMigrationsSchemasCommand(SyncCommon):
def handle(self, *args, **options):
super(ShowMigrationsSchemasCommand, self).handle(*args, **options)
if self.schema_name:
if not schema_exists(self.schema_name):
if self.domain:
try:
tenant = TenantMiddleware.get_tenant_by_hostname(self.domain)
except TenantNotFound:
raise RuntimeError('Schema "{}" does not exist'.format(
self.schema_name))
else:
self.run_showmigrations(self.schema_name, settings.TENANT_APPS)
self.run_showmigrations(tenant, settings.TENANT_APPS)
else:
for tenant in TenantMiddleware.get_tenants():
self.run_showmigrations(tenant.schema_name, settings.TENANT_APPS)
self.run_showmigrations(tenant, settings.TENANT_APPS)
def run_showmigrations(self, schema_name, included_apps):
self._notice("=== Show migrations for schema %s" % schema_name)
connection.set_schema(schema_name, include_public=False)
def run_showmigrations(self, tenant, included_apps):
self._notice("=== Show migrations for schema %s" % tenant.domain_url)
connection.set_tenant(tenant, include_public=False)
command = ShowMigrationsCommand()
command.execute(*self.args, **self.options)
connection.set_schema_to_public()
Command = ShowMigrationsSchemasCommand

View File

@ -72,15 +72,15 @@ def test_migrate_schemas_skip_applied(db, capsys):
assert 'Running migrate for schema www_example_com' in captured.out
call_command('migrate_schemas', verbosity=1)
captured = capsys.readouterr()
assert 'Skipping migrations of schema www_example_com' in captured.out
assert 'Skipping migrations of tenant www.example.com' in captured.out
call_command('migrate_schemas', 'common', '0001_initial', verbosity=1)
captured = capsys.readouterr()
assert 'Running migrate for schema www_example_com' in captured.out
assert 'Running migrate for tenant www.example.com' in captured.out
assert 'Unapplying common.0002' in captured.out
call_command('migrate_schemas', verbosity=1)
captured = capsys.readouterr()
assert 'Running migrate for schema www_example_com' in captured.out
assert 'Running migrate for tenant www.example.com' in captured.out
assert 'Applying common.0002' in captured.out
call_command('migrate_schemas', verbosity=1)
captured = capsys.readouterr()
assert 'Skipping migrations of schema www_example_com' in captured.out
assert 'Skipping migrations of tenant www.example.com' in captured.out