diff --git a/combo/apps/lingo/views.py b/combo/apps/lingo/views.py index 13442208..fba3bd5a 100644 --- a/combo/apps/lingo/views.py +++ b/combo/apps/lingo/views.py @@ -40,11 +40,7 @@ import eopayment from combo.data.models import Page from combo.utils import check_request_signature, aes_hex_decrypt, DecryptionError - -if 'mellon' in settings.INSTALLED_APPS: - from mellon.models import UserSAMLIdentifier -else: - UserSAMLIdentifier = None +from combo.profile.utils import get_user_from_name_id from .models import (Regie, BasketItem, Transaction, TransactionOperation, LingoBasketCell, SelfDeclaredInvoicePayment) @@ -125,12 +121,9 @@ class AddBasketItemApiView(View): try: if request.GET.get('NameId'): - if UserSAMLIdentifier is None: - raise Exception('missing mellon?') - try: - user = UserSAMLIdentifier.objects.get(name_id=request.GET.get('NameId')).user - except UserSAMLIdentifier.DoesNotExist: - raise Exception('unknown name id') + user = get_user_from_name_id(request.GET.get('NameId')) + if user is None: + raise User.DoesNotExist() elif request.GET.get('email'): user = User.objects.get(email=request.GET.get('email')) else: @@ -198,12 +191,9 @@ class RemoveBasketItemApiView(View): try: if request.GET.get('NameId'): - if UserSAMLIdentifier is None: - raise Exception('missing mellon?') - try: - user = UserSAMLIdentifier.objects.get(name_id=request.GET.get('NameId')).user - except UserSAMLIdentifier.DoesNotExist: - raise Exception('unknown name id') + user = get_user_from_name_id(request.GET.get('NameId')) + if user is None: + raise User.DoesNotExist() elif request.GET.get('email'): user = User.objects.get(email=request.GET.get('email')) else: diff --git a/combo/apps/newsletters/forms.py b/combo/apps/newsletters/forms.py index 4c1ebfb3..b5b98cac 100644 --- a/combo/apps/newsletters/forms.py +++ b/combo/apps/newsletters/forms.py @@ -38,8 +38,9 @@ class NewslettersManageForm(forms.Form): logger.error('Error occured while getting newsletters: %r', e) return self.params = {} - if hasattr(self.user, 'saml_identifiers') and self.user.saml_identifiers.exists(): - self.params['uuid'] = self.user.saml_identifiers.first().name_id + user_name_id = self.user.get_name_id() + if user_name_id: + self.params['uuid'] = user_name_id # get mobile number from mellon session as it is not user attribute if self.request.session.get('mellon_session'): diff --git a/combo/apps/wcs/models.py b/combo/apps/wcs/models.py index 9f45da62..252e0167 100644 --- a/combo/apps/wcs/models.py +++ b/combo/apps/wcs/models.py @@ -313,8 +313,10 @@ class WcsCurrentFormsCell(WcsUserDataBaseCell): def get_api_url(self, context): user = self.get_concerned_user(context) - if hasattr(user, 'saml_identifiers') and user.saml_identifiers.exists(): - return '/api/users/%s/forms' % user.saml_identifiers.first().name_id + if user: + user_name_id = user.get_name_id() + if user_name_id: + return '/api/users/%s/forms' % user_name_id return '/api/user/forms' @property @@ -386,8 +388,10 @@ class WcsCurrentDraftsCell(WcsUserDataBaseCell): def get_api_url(self, context): user = self.get_concerned_user(context) - if hasattr(user, 'saml_identifiers') and user.saml_identifiers.exists(): - return '/api/users/%s/drafts' % user.saml_identifiers.first().name_id + if user: + user_name_id = user.get_name_id() + if user_name_id: + return '/api/users/%s/drafts' % user_name_id return '/api/user/drafts' def get_cell_extra_context(self, context): diff --git a/combo/profile/__init__.py b/combo/profile/__init__.py index c5fb5908..e34e80f2 100644 --- a/combo/profile/__init__.py +++ b/combo/profile/__init__.py @@ -21,6 +21,13 @@ import django.apps from django.utils.translation import ugettext_lazy as _ +def user_get_name_id(user): + saml_identifier = user.saml_identifiers.first() + if saml_identifier: + return saml_identifier.name_id + return None + + class AppConfig(django.apps.AppConfig): name = 'combo.profile' verbose_name = _('Profile') @@ -28,6 +35,8 @@ class AppConfig(django.apps.AppConfig): def ready(self): from combo.apps.search import engines engines.register(self.get_search_engines) + from django.contrib.auth import get_user_model + get_user_model().add_to_class('get_name_id', user_get_name_id) def get_search_engines(self): from combo.data.models import Page diff --git a/combo/profile/utils.py b/combo/profile/utils.py new file mode 100644 index 00000000..e8ca8fe0 --- /dev/null +++ b/combo/profile/utils.py @@ -0,0 +1,32 @@ +# combo - content management system +# Copyright (C) 2014-2019 Entr'ouvert +# +# This program is free software: you can redistribute it and/or modify it +# under the terms of the GNU Affero General Public License as published +# by the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. +# +# You should have received a copy of the GNU Affero General Public License +# along with this program. If not, see . + +from django.conf import settings + + +if 'mellon' in settings.INSTALLED_APPS: + from mellon.models import UserSAMLIdentifier +else: + UserSAMLIdentifier = None + + +def get_user_from_name_id(name_id): + if not UserSAMLIdentifier: + return None + try: + return UserSAMLIdentifier.objects.get(name_id=name_id).user + except UserSAMLIdentifier.DoesNotExist: + return None diff --git a/combo/public/templatetags/combo.py b/combo/public/templatetags/combo.py index 8791de6b..47e498f6 100644 --- a/combo/public/templatetags/combo.py +++ b/combo/public/templatetags/combo.py @@ -34,9 +34,6 @@ from combo.public.menu import get_menu_context from combo.utils import NothingInCacheException, flatten_context from combo.apps.dashboard.models import DashboardCell, Tile -if 'mellon' in settings.INSTALLED_APPS: - from mellon.models import UserSAMLIdentifier - register = template.Library() def skeleton_text(context, placeholder_name, content=''): @@ -248,9 +245,10 @@ def signed(obj): @register.filter def name_id(user): - saml_id = UserSAMLIdentifier.objects.filter(user=user).last() - if saml_id: - return saml_id.name_id + if user: + user_name_id = user.get_name_id() + if user_name_id: + return user_name_id # it is important to raise this so get_templated_url is aborted and no call # is tried with a missing user argument. raise VariableDoesNotExist('name_id') diff --git a/combo/public/views.py b/combo/public/views.py index 60fb0829..04091e46 100644 --- a/combo/public/views.py +++ b/combo/public/views.py @@ -45,14 +45,13 @@ from haystack.query import SearchQuerySet, SQ if 'mellon' in settings.INSTALLED_APPS: from mellon.utils import get_idps - from mellon.models import UserSAMLIdentifier else: get_idps = lambda: [] - UserSAMLIdentifier = None from combo.data.models import (CellBase, PostException, Page, Redirect, ParentContentCell, TextCell, PageSnapshot) from combo.profile.models import Profile +from combo.profile.utils import get_user_from_name_id from combo.apps.search.models import SearchCell from combo import utils @@ -81,11 +80,8 @@ def modify_global_context(request, ctx): ctx['selected_user'] = User.objects.get(id=ctx['user_id']) except (User.DoesNotExist, ValueError): pass - if 'name_id' in ctx and UserSAMLIdentifier: - try: - ctx['selected_user'] = UserSAMLIdentifier.objects.get(name_id=ctx['name_id']).user - except UserSAMLIdentifier.DoesNotExist: - pass + if 'name_id' in ctx: + ctx['selected_user'] = get_user_from_name_id(ctx['name_id']) @csrf_exempt def ajax_page_cell(request, page_pk, cell_reference): diff --git a/combo/utils/requests_wrapper.py b/combo/utils/requests_wrapper.py index 6dbde7d4..66425309 100644 --- a/combo/utils/requests_wrapper.py +++ b/combo/utils/requests_wrapper.py @@ -79,12 +79,13 @@ class Requests(RequestsSession): else: query_params = {} if federation_key == 'nameid': - query_params['NameID'] = user.saml_identifiers.first().name_id + query_params['NameID'] = user.get_name_id() elif federation_key == 'email': query_params['email'] = user.email else: # 'auto' - if hasattr(user, 'saml_identifiers') and user.saml_identifiers.exists(): - query_params['NameID'] = user.saml_identifiers.first().name_id + user_name_id = user.get_name_id() + if user_name_id: + query_params['NameID'] = user_name_id else: query_params['email'] = user.email diff --git a/combo/utils/urls.py b/combo/utils/urls.py index 24be39f6..57e67be7 100644 --- a/combo/utils/urls.py +++ b/combo/utils/urls.py @@ -42,8 +42,9 @@ def get_templated_url(url, context=None): user = getattr(context.get('request'), 'user', None) if user and user.is_authenticated(): template_vars['user_email'] = quote(user.email) - if hasattr(user, 'saml_identifiers') and user.saml_identifiers.exists(): - template_vars['user_nameid'] = quote(user.saml_identifiers.first().name_id) + user_nameid = user.get_name_id() + if user_nameid: + template_vars['user_nameid'] = quote(user_nameid) template_vars.update(settings.TEMPLATE_VARS) if '{{' in url or '{%' in url: # Django template try: diff --git a/tests/test_lingo_remote_regie.py b/tests/test_lingo_remote_regie.py index 67b8480b..6ae70878 100644 --- a/tests/test_lingo_remote_regie.py +++ b/tests/test_lingo_remote_regie.py @@ -75,15 +75,9 @@ class MockUser(object): def is_authenticated(self): return True - def __init__(self): - class MockSAMLUsers(object): - def exists(self): - return True - def first(self): - class MockSAMLUser(object): - name_id = 'r2d2' - return MockSAMLUser() - self.saml_identifiers = MockSAMLUsers() + def get_name_id(self): + return 'r2d2' + @mock.patch('combo.utils.requests_wrapper.RequestsSession.request') def test_remote_regie_active_invoices_cell(mock_request, remote_regie): diff --git a/tests/test_newsletters_cell.py b/tests/test_newsletters_cell.py index 311d1546..30719786 100644 --- a/tests/test_newsletters_cell.py +++ b/tests/test_newsletters_cell.py @@ -223,9 +223,7 @@ def test_get_subscriptions_with_name_id_and_mobile(mock_get, cell, user): fake_saml_request = mock.Mock() fake_saml_request.user = mock.Mock(email=USER_EMAIL) - fake_saml_request.user.saml_identifiers = mock.Mock() - fake_saml_request.user.saml_identifiers.exists.return_value = True - fake_saml_request.user.saml_identifiers.first.return_value = mock.Mock(name_id='nameid') + fake_saml_request.user.get_name_id.return_value = 'nameid' fake_saml_request.session = {'mellon_session': {'mobile': '0607080900'}} form = NewslettersManageForm(instance=cell, request=fake_saml_request) diff --git a/tests/test_profile.py b/tests/test_profile.py index 57681e14..26f8b39f 100644 --- a/tests/test_profile.py +++ b/tests/test_profile.py @@ -13,9 +13,8 @@ pytestmark = pytest.mark.django_db @override_settings( KNOWN_SERVICES={'authentic': {'idp': {'title': 'IdP', 'url': 'http://example.org/'}}}) -@mock.patch('combo.public.templatetags.combo.UserSAMLIdentifier') @mock.patch('combo.utils.requests.get') -def test_profile_cell(requests_get, user_saml, app, admin_user): +def test_profile_cell(requests_get, app, admin_user): page = Page() page.save() @@ -28,24 +27,9 @@ def test_profile_cell(requests_get, user_saml, app, admin_user): json=lambda: data, status_code=200) - def filter_mock(user=None): - assert user is admin_user - return mock.Mock(last=lambda: mock.Mock(name_id='123456')) - - mocked_objects = mock.Mock() - mocked_objects.filter = mock.Mock(side_effect=filter_mock) - user_saml.objects = mocked_objects + admin_user.get_name_id = lambda: '123456' context = cell.get_cell_extra_context({'synchronous': True, 'selected_user': admin_user}) assert context['profile_fields']['first_name']['value'] == 'Foo' assert context['profile_fields']['birthdate']['value'] == datetime.date(2018, 8, 10) assert requests_get.call_args[0][0] == 'http://example.org/api/users/123456/' - - def filter_mock_missing(user=None): - return mock.Mock(last=lambda: None) - - mocked_objects.filter = mock.Mock(side_effect=filter_mock_missing) - - context = cell.get_cell_extra_context({'synchronous': True, 'selected_user': admin_user}) - assert context['error'] == 'unknown user' - assert requests_get.call_count == 1 # no new call was made diff --git a/tests/test_public.py b/tests/test_public.py index 92efafb2..b0983734 100644 --- a/tests/test_public.py +++ b/tests/test_public.py @@ -813,7 +813,7 @@ def test_sub_slug(app, john_doe, jane_doe): assert 'XXYY' in resp.text # custom behaviour for , it will add the SAML user to context - with mock.patch('combo.public.views.UserSAMLIdentifier') as user_saml: + with mock.patch('combo.profile.utils.UserSAMLIdentifier') as user_saml: class DoesNotExist(Exception): pass user_saml.DoesNotExist = DoesNotExist diff --git a/tests/test_requests.py b/tests/test_requests.py index 3d56edea..4d3b3a03 100644 --- a/tests/test_requests.py +++ b/tests/test_requests.py @@ -8,22 +8,18 @@ from django.utils.six.moves.urllib import parse as urlparse from combo.utils import requests, check_query, NothingInCacheException -class MockSAMLUser(object): - name_id = 'r2d2' - class MockUser(object): email = 'foo@example.net' def is_authenticated(self): return True + def get_name_id(self): + if self.samlized: + return 'r2d2' + return None + def __init__(self, samlized=True): - class MockSAMLUsers(object): - def exists(self): - return True - def first(self): - return MockSAMLUser() - if samlized: - self.saml_identifiers = MockSAMLUsers() + self.samlized = samlized def test_nosign(): diff --git a/tests/test_utils.py b/tests/test_utils.py index 37011684..8ccbde0c 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -9,22 +9,18 @@ from django.test.client import RequestFactory from django.contrib.auth.models import AnonymousUser -class MockSAMLUser(object): - name_id = 'r2&d2' - class MockUser(object): email = 'foo=3@example.net' def is_authenticated(self): return True def __init__(self, samlized=True): - class MockSAMLUsers(object): - def exists(self): - return True - def first(self): - return MockSAMLUser() - if samlized: - self.saml_identifiers = MockSAMLUsers() + self.samlized = samlized + + def get_name_id(self): + if self.samlized: + return 'r2&d2' + return None def test_crypto_url(): diff --git a/tests/test_wcs.py b/tests/test_wcs.py index 449fe231..9add0ead 100644 --- a/tests/test_wcs.py +++ b/tests/test_wcs.py @@ -161,6 +161,14 @@ user2.store() WCS_DIR = tempfile.mkdtemp() WCS_PID = None + +class MockUser(object): + email = 'foo@example.net' + def is_authenticated(self): + return True + def get_name_id(self): + return None + def run_wcs_script(script, hostname): script_path = os.path.join(WCS_DIR, script + '.py') fd = open(script_path, 'w') @@ -345,10 +353,6 @@ def test_current_forms_cell_render(context): cell = WcsCurrentFormsCell(page=page, placeholder='content', order=0) cell.save() - class MockUser(object): - email = 'foo@example.net' - def is_authenticated(self): - return True context['request'].user = MockUser() # query should fail as nothing is cached @@ -412,10 +416,6 @@ def test_current_forms_cell_render_single_site(context): cell.wcs_site = 'default' cell.save() - class MockUser(object): - email = 'foo@example.net' - def is_authenticated(self): - return True context['request'].user = MockUser() # query should fail as nothing is cached @@ -506,10 +506,6 @@ def test_current_drafts_cell_render_logged_in(context): cell.save() context['synchronous'] = True # to get fresh content - class MockUser(object): - email = 'foo@example.net' - def is_authenticated(self): - return True context['request'].user = MockUser() # default is to get current forms from all wcs sites @@ -695,20 +691,14 @@ def test_backoffice_submission_cell_render(context): result = cell.render(context) assert '/backoffice/submission/a-private-form/' not in result - class MockUser(object): - email = 'foo@example.net' - def is_authenticated(self): - return True context['request'].user = MockUser() result = cell.render(context) assert '/backoffice/submission/a-private-form/' not in result - class MockUser(object): + class MockUser2(MockUser): email = 'foo2@example.net' - def is_authenticated(self): - return True - context['request'].user = MockUser() + context['request'].user = MockUser2() result = cell.render(context) assert '/backoffice/submission/a-private-form/' in result