diff --git a/tests/test_ctl.py b/tests/test_ctl.py index 9754cfb36..7af2b8943 100644 --- a/tests/test_ctl.py +++ b/tests/test_ctl.py @@ -2,6 +2,7 @@ import collections import os import sys +import mock import psycopg2 import pytest from django.core.management import CommandError, call_command @@ -26,7 +27,14 @@ from .utilities import clean_temporary_pub, create_temporary_pub @pytest.fixture def pub(): - return create_temporary_pub() + yield create_temporary_pub() + clean_temporary_pub() + + +@pytest.fixture +def sql_pub(): + yield create_temporary_pub(sql_mode=True) + clean_temporary_pub() def pytest_generate_tests(metafunc): @@ -477,3 +485,15 @@ def test_ctl_no_command(capsys): assert 'error: You must use a command' in captured.err finally: sys.argv = old_argv + + +def test_dbshell(sql_pub): + + with pytest.raises(CommandError): + call_command('dbshell') # missing tenant name + + with mock.patch('subprocess.call') as call: + call.side_effect = lambda *args: 0 + call_command('dbshell', '--domain', 'example.net') + assert call.call_args[0][-1][0] == 'psql' + assert call.call_args[0][-1][-1] == sql_pub.cfg['postgresql']['database'] diff --git a/wcs/ctl/management/commands/dbshell.py b/wcs/ctl/management/commands/dbshell.py new file mode 100644 index 000000000..20fa2eab0 --- /dev/null +++ b/wcs/ctl/management/commands/dbshell.py @@ -0,0 +1,28 @@ +# w.c.s. - web application for online forms +# Copyright (C) 2005-2021 Entr'ouvert +# +# This program is free software; you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation; either version 2 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program; if not, see . + +from django.db.backends.postgresql.client import DatabaseClient + +from . import TenantCommand + + +class Command(TenantCommand): + def add_arguments(self, parser): + parser.add_argument('-d', '--domain', '--vhost', metavar='DOMAIN') + + def handle(self, *args, **options): + pub = self.init_tenant_publisher(options['domain'], register_tld_names=False) + DatabaseClient.runshell_db(conn_params=pub.cfg['postgresql'])