diff --git a/mellon/migrations/0004_migrate_issuer.py b/mellon/migrations/0004_migrate_issuer.py new file mode 100644 index 0000000..06be080 --- /dev/null +++ b/mellon/migrations/0004_migrate_issuer.py @@ -0,0 +1,31 @@ +# Generated by Django 2.2.19 on 2021-09-14 18:54 + +from django.db import migrations +from django.db.models import Count + + +def migrate_issuer_forward(apps, schema_editor): + UserSAMLIdentifier = apps.get_model('mellon', 'UserSAMLIdentifier') + Issuer = apps.get_model('mellon', 'Issuer') + issuers = UserSAMLIdentifier.objects.values_list('issuer').annotate(total=Count('id')) + for issuer, total in issuers: + issuer_instance = Issuer.objects.create(entity_id=issuer) + UserSAMLIdentifier.objects.filter(issuer=issuer).update(issuer_fk=issuer_instance) + + +def migrate_issuer_backward(apps, schema_editor): + UserSAMLIdentifier = apps.get_model('mellon', 'UserSAMLIdentifier') + Issuer = apps.get_model('mellon', 'Issuer') + for issuer in Issuer.objects.all(): + UserSAMLIdentifier.objects.filter(issuer_fk=issuer).update(issuer=issuer.entity_id) + + +class Migration(migrations.Migration): + dependencies = [ + ('mellon', '0003_add_issuer_model'), + ] + + operations = [] + operations = [ + migrations.RunPython(migrate_issuer_forward, migrate_issuer_backward), + ] diff --git a/tests/conftest.py b/tests/conftest.py index 2e668be..b8df92e 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -18,6 +18,9 @@ import os import django_webtest import pytest +from django.core.management import call_command +from django.db import connection +from django.db.migrations.executor import MigrationExecutor @pytest.fixture(autouse=True) @@ -82,3 +85,51 @@ def metadata_path(tmpdir, metadata): with metadata_path.open('w') as fd: fd.write(metadata) yield str(metadata_path) + + +@pytest.fixture() +def migration(request, transactional_db): + # see https://gist.github.com/asfaltboy/b3e6f9b5d95af8ba2cc46f2ba6eae5e2 + """ + This fixture returns a helper object to test Django data migrations. + The fixture returns an object with two methods; + - `before` to initialize db to the state before the migration under test + - `after` to execute the migration and bring db to the state after the migration + The methods return `old_apps` and `new_apps` respectively; these can + be used to initiate the ORM models as in the migrations themselves. + For example: + def test_foo_set_to_bar(migration): + old_apps = migration.before([('my_app', '0001_inital')]) + Foo = old_apps.get_model('my_app', 'foo') + Foo.objects.create(bar=False) + assert Foo.objects.count() == 1 + assert Foo.objects.filter(bar=False).count() == Foo.objects.count() + # executing migration + new_apps = migration.apply([('my_app', '0002_set_foo_bar')]) + Foo = new_apps.get_model('my_app', 'foo') + + assert Foo.objects.filter(bar=False).count() == 0 + assert Foo.objects.filter(bar=True).count() == Foo.objects.count() + Based on: https://gist.github.com/blueyed/4fb0a807104551f103e6 + """ + + class Migrator: + def before(self, targets, at_end=True): + """Specify app and starting migration names as in: + before([('app', '0001_before')]) => app/migrations/0001_before.py + """ + executor = MigrationExecutor(connection) + executor.migrate(targets) + executor.loader.build_graph() + return executor._create_project_state(with_applied_migrations=True).apps + + def apply(self, targets): + """Migrate forwards to the "targets" migration""" + executor = MigrationExecutor(connection) + executor.migrate(targets) + executor.loader.build_graph() + return executor._create_project_state(with_applied_migrations=True).apps + + yield Migrator() + + call_command('migrate', verbosity=0) diff --git a/tests/test_migrations.py b/tests/test_migrations.py new file mode 100644 index 0000000..86108ff --- /dev/null +++ b/tests/test_migrations.py @@ -0,0 +1,73 @@ +# django-mellon - SAML2 authentication for Django +# Copyright (C) 2014-2021 Entr'ouvert +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as +# published by the Free Software Foundation, either version 3 of the +# License, or (at your option) any later version. + +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. + +# You should have received a copy of the GNU Affero General Public License +# along with this program. If not, see . + +import pytest +from django.contrib.auth.models import User + +from mellon.models import Issuer, UserSAMLIdentifier + + +@pytest.fixture +def user_and_issuers(db): + user1 = User.objects.create(username='user1') + user2 = User.objects.create(username='user2') + issuer1 = Issuer.objects.create(entity_id='https://idp1') + issuer2 = Issuer.objects.create(entity_id='https://idp2') + UserSAMLIdentifier.objects.create(user=user1, issuer=issuer1, name_id='xxx') + UserSAMLIdentifier.objects.create(user=user2, issuer=issuer2, name_id='yyy') + + +def test_migration_0004_migrate_issuer_back_and_forward(transactional_db, user_and_issuers, migration): + migration.before([('mellon', '0002_sessionindex')]) + new_apps = migration.apply([('mellon', '0004_migrate_issuer')]) + + UserSAMLIdentifier = new_apps.get_model('mellon', 'UserSAMLIdentifier') + Issuer = new_apps.get_model('mellon', 'Issuer') + User = new_apps.get_model('auth', 'User') + + user1 = User.objects.get(username='user1') + user2 = User.objects.get(username='user2') + + assert UserSAMLIdentifier.objects.count() == 2 + assert Issuer.objects.count() == 2 + assert UserSAMLIdentifier.objects.get(user=user1, issuer_fk__entity_id='https://idp1', name_id='xxx') + assert UserSAMLIdentifier.objects.get(user=user2, issuer_fk__entity_id='https://idp2', name_id='yyy') + + +def test_migration_0004_migrate_issuer(transactional_db, migration): + old_apps = migration.before([('mellon', '0003_add_issuer_model')]) + + UserSAMLIdentifier = old_apps.get_model('mellon', 'UserSAMLIdentifier') + User = old_apps.get_model('auth', 'User') + + user1 = User.objects.create(username='user1') + user2 = User.objects.create(username='user2') + + UserSAMLIdentifier.objects.create(user=user1, issuer='https://idp1', name_id='xxx') + UserSAMLIdentifier.objects.create(user=user2, issuer='https://idp2', name_id='yyy') + + new_apps = migration.apply([('mellon', '0004_migrate_issuer')]) + + UserSAMLIdentifier = new_apps.get_model('mellon', 'UserSAMLIdentifier') + Issuer = new_apps.get_model('mellon', 'Issuer') + User = new_apps.get_model('auth', 'User') + + user1 = User.objects.get(username='user1') + user2 = User.objects.get(username='user2') + + assert UserSAMLIdentifier.objects.count() == 2 + assert Issuer.objects.count() == 2 + assert UserSAMLIdentifier.objects.get(user=user1, issuer_fk__entity_id='https://idp1', name_id='xxx') + assert UserSAMLIdentifier.objects.get(user=user2, issuer_fk__entity_id='https://idp2', name_id='yyy')