migrations: client data migration (#73170)

This commit is contained in:
Paul Marillonnet 2023-01-10 11:27:09 +01:00
parent 3b405b5eae
commit cada3f164d
3 changed files with 111 additions and 0 deletions

View File

@ -0,0 +1,30 @@
from django.conf import settings
from django.db import migrations
def add_cut_partners(apps, schema_editor):
cut_partners = getattr(settings, 'A2_CUT_PARTNERS', {})
OIDCClient = apps.get_model('authentic2_idp_oidc', 'OIDCClient')
CUTPartner = apps.get_model('authentic2_cut', 'CUTPartner')
cut_partner_names = [partner['name'] for partner in cut_partners]
for oidc_client in OIDCClient.objects.filter(name__in=cut_partner_names):
cut_partner_data = cut_partners[cut_partner_names.index(oidc_client.name)]
cut_partner = CUTPartner.objects.create(
oidc_client=oidc_client,
domains=cut_partner_data['domains'],
url=cut_partner_data['url'],
stat_emails=cut_partner_data['stat_emails'],
)
class Migration(migrations.Migration):
dependencies = [
('authentic2_cut', '0007_cutpartner'),
]
operations = [
migrations.RunPython(add_cut_partners, reverse_code=migrations.RunPython.noop),
]

View File

@ -13,6 +13,8 @@ except ImportError:
from authentic2.a2_rbac.models import OrganizationalUnit as OU
from django.contrib.auth import get_user_model
from django.core.management import call_command
from django.db import connection
from django.db.migrations.executor import MigrationExecutor
User = get_user_model()
TEST_DIR = pathlib.Path(__file__).parent
@ -157,3 +159,29 @@ def clean_caches():
from authentic2.apps.journal.models import event_type_cache
event_type_cache.cache.clear()
@pytest.fixture()
def migration(request, transactional_db):
# see https://gist.github.com/asfaltboy/b3e6f9b5d95af8ba2cc46f2ba6eae5e2
# copied from authentic mainline's fixtures
class Migrator:
def before(self, targets, at_end=True):
"""Specify app and starting migration names as in:
before([('app', '0001_before')]) => app/migrations/0001_before.py
"""
executor = MigrationExecutor(connection)
executor.migrate(targets)
executor.loader.build_graph()
return executor._create_project_state(with_applied_migrations=True).apps
def apply(self, targets):
"""Migrate forwards to the "targets" migration"""
executor = MigrationExecutor(connection)
executor.migrate(targets)
executor.loader.build_graph()
return executor._create_project_state(with_applied_migrations=True).apps
yield Migrator()
call_command('migrate', verbosity=0)

53
tests/test_migrations.py Normal file
View File

@ -0,0 +1,53 @@
def test_migration_0008_migrate_oidcclients_to_cutpartners(transactional_db, migration, settings):
settings.A2_CUT_PARTNERS = [
{
'domains': ['.lyon.fr'],
'url': 'www.lyon.fr',
'name': 'Ville de Lyon',
'stat_emails': ['lyon@example.com'],
},
{
'domains': ['.entrouvert.org'],
'url': 'www.entrouvert.org',
'name': 'Entrouvert',
'stat_emails': ['entrouvert@example.com'],
},
# partner that does not match an existing oidc client
{
'domains': ['.example.com'],
'url': 'www.example.com',
'name': 'Example',
'stat_emails': ['void@example.com'],
},
]
old_apps = migration.before([('authentic2_cut', '0007_cutpartner')])
OIDCClient = old_apps.get_model('authentic2_idp_oidc', 'OIDCClient')
OIDCClient.objects.create(
name='Ville de Lyon', # matches A2_CUT_PARTNERS first entry
slug='ville-de-lyon',
client_id='abc',
client_secret='def',
)
OIDCClient.objects.create(
name='Entrouvert', # matches A2_CUT_PARTNERS second entry
slug='entrouvert',
client_id='ghi',
client_secret='jkl',
)
CUTPartner = old_apps.get_model('authentic2_cut', 'CUTPartner')
assert not CUTPartner.objects.count()
new_apps = migration.apply([('authentic2_cut', '0008_migrate_oidcclients_to_cutpartners')])
CUTPartner = new_apps.get_model('authentic2_cut', 'CUTPartner')
assert CUTPartner.objects.count() == 2
eo_partner = CUTPartner.objects.get(url='www.entrouvert.org')
assert eo_partner.domains == ['.entrouvert.org']
assert eo_partner.stat_emails == ['entrouvert@example.com']
lyon_partner = CUTPartner.objects.get(url='www.lyon.fr')
assert lyon_partner.domains == ['.lyon.fr']
assert lyon_partner.stat_emails == ['lyon@example.com']