diff --git a/combo/apps/lingo/views.py b/combo/apps/lingo/views.py
index 13442208..fba3bd5a 100644
--- a/combo/apps/lingo/views.py
+++ b/combo/apps/lingo/views.py
@@ -40,11 +40,7 @@ import eopayment
from combo.data.models import Page
from combo.utils import check_request_signature, aes_hex_decrypt, DecryptionError
-
-if 'mellon' in settings.INSTALLED_APPS:
- from mellon.models import UserSAMLIdentifier
-else:
- UserSAMLIdentifier = None
+from combo.profile.utils import get_user_from_name_id
from .models import (Regie, BasketItem, Transaction, TransactionOperation,
LingoBasketCell, SelfDeclaredInvoicePayment)
@@ -125,12 +121,9 @@ class AddBasketItemApiView(View):
try:
if request.GET.get('NameId'):
- if UserSAMLIdentifier is None:
- raise Exception('missing mellon?')
- try:
- user = UserSAMLIdentifier.objects.get(name_id=request.GET.get('NameId')).user
- except UserSAMLIdentifier.DoesNotExist:
- raise Exception('unknown name id')
+ user = get_user_from_name_id(request.GET.get('NameId'))
+ if user is None:
+ raise User.DoesNotExist()
elif request.GET.get('email'):
user = User.objects.get(email=request.GET.get('email'))
else:
@@ -198,12 +191,9 @@ class RemoveBasketItemApiView(View):
try:
if request.GET.get('NameId'):
- if UserSAMLIdentifier is None:
- raise Exception('missing mellon?')
- try:
- user = UserSAMLIdentifier.objects.get(name_id=request.GET.get('NameId')).user
- except UserSAMLIdentifier.DoesNotExist:
- raise Exception('unknown name id')
+ user = get_user_from_name_id(request.GET.get('NameId'))
+ if user is None:
+ raise User.DoesNotExist()
elif request.GET.get('email'):
user = User.objects.get(email=request.GET.get('email'))
else:
diff --git a/combo/apps/newsletters/forms.py b/combo/apps/newsletters/forms.py
index 4c1ebfb3..b5b98cac 100644
--- a/combo/apps/newsletters/forms.py
+++ b/combo/apps/newsletters/forms.py
@@ -38,8 +38,9 @@ class NewslettersManageForm(forms.Form):
logger.error('Error occured while getting newsletters: %r', e)
return
self.params = {}
- if hasattr(self.user, 'saml_identifiers') and self.user.saml_identifiers.exists():
- self.params['uuid'] = self.user.saml_identifiers.first().name_id
+ user_name_id = self.user.get_name_id()
+ if user_name_id:
+ self.params['uuid'] = user_name_id
# get mobile number from mellon session as it is not user attribute
if self.request.session.get('mellon_session'):
diff --git a/combo/apps/wcs/models.py b/combo/apps/wcs/models.py
index 9f45da62..252e0167 100644
--- a/combo/apps/wcs/models.py
+++ b/combo/apps/wcs/models.py
@@ -313,8 +313,10 @@ class WcsCurrentFormsCell(WcsUserDataBaseCell):
def get_api_url(self, context):
user = self.get_concerned_user(context)
- if hasattr(user, 'saml_identifiers') and user.saml_identifiers.exists():
- return '/api/users/%s/forms' % user.saml_identifiers.first().name_id
+ if user:
+ user_name_id = user.get_name_id()
+ if user_name_id:
+ return '/api/users/%s/forms' % user_name_id
return '/api/user/forms'
@property
@@ -386,8 +388,10 @@ class WcsCurrentDraftsCell(WcsUserDataBaseCell):
def get_api_url(self, context):
user = self.get_concerned_user(context)
- if hasattr(user, 'saml_identifiers') and user.saml_identifiers.exists():
- return '/api/users/%s/drafts' % user.saml_identifiers.first().name_id
+ if user:
+ user_name_id = user.get_name_id()
+ if user_name_id:
+ return '/api/users/%s/drafts' % user_name_id
return '/api/user/drafts'
def get_cell_extra_context(self, context):
diff --git a/combo/profile/__init__.py b/combo/profile/__init__.py
index c5fb5908..e34e80f2 100644
--- a/combo/profile/__init__.py
+++ b/combo/profile/__init__.py
@@ -21,6 +21,13 @@ import django.apps
from django.utils.translation import ugettext_lazy as _
+def user_get_name_id(user):
+ saml_identifier = user.saml_identifiers.first()
+ if saml_identifier:
+ return saml_identifier.name_id
+ return None
+
+
class AppConfig(django.apps.AppConfig):
name = 'combo.profile'
verbose_name = _('Profile')
@@ -28,6 +35,8 @@ class AppConfig(django.apps.AppConfig):
def ready(self):
from combo.apps.search import engines
engines.register(self.get_search_engines)
+ from django.contrib.auth import get_user_model
+ get_user_model().add_to_class('get_name_id', user_get_name_id)
def get_search_engines(self):
from combo.data.models import Page
diff --git a/combo/profile/utils.py b/combo/profile/utils.py
new file mode 100644
index 00000000..e8ca8fe0
--- /dev/null
+++ b/combo/profile/utils.py
@@ -0,0 +1,32 @@
+# combo - content management system
+# Copyright (C) 2014-2019 Entr'ouvert
+#
+# This program is free software: you can redistribute it and/or modify it
+# under the terms of the GNU Affero General Public License as published
+# by the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Affero General Public License for more details.
+#
+# You should have received a copy of the GNU Affero General Public License
+# along with this program. If not, see .
+
+from django.conf import settings
+
+
+if 'mellon' in settings.INSTALLED_APPS:
+ from mellon.models import UserSAMLIdentifier
+else:
+ UserSAMLIdentifier = None
+
+
+def get_user_from_name_id(name_id):
+ if not UserSAMLIdentifier:
+ return None
+ try:
+ return UserSAMLIdentifier.objects.get(name_id=name_id).user
+ except UserSAMLIdentifier.DoesNotExist:
+ return None
diff --git a/combo/public/templatetags/combo.py b/combo/public/templatetags/combo.py
index 8791de6b..47e498f6 100644
--- a/combo/public/templatetags/combo.py
+++ b/combo/public/templatetags/combo.py
@@ -34,9 +34,6 @@ from combo.public.menu import get_menu_context
from combo.utils import NothingInCacheException, flatten_context
from combo.apps.dashboard.models import DashboardCell, Tile
-if 'mellon' in settings.INSTALLED_APPS:
- from mellon.models import UserSAMLIdentifier
-
register = template.Library()
def skeleton_text(context, placeholder_name, content=''):
@@ -248,9 +245,10 @@ def signed(obj):
@register.filter
def name_id(user):
- saml_id = UserSAMLIdentifier.objects.filter(user=user).last()
- if saml_id:
- return saml_id.name_id
+ if user:
+ user_name_id = user.get_name_id()
+ if user_name_id:
+ return user_name_id
# it is important to raise this so get_templated_url is aborted and no call
# is tried with a missing user argument.
raise VariableDoesNotExist('name_id')
diff --git a/combo/public/views.py b/combo/public/views.py
index 60fb0829..04091e46 100644
--- a/combo/public/views.py
+++ b/combo/public/views.py
@@ -45,14 +45,13 @@ from haystack.query import SearchQuerySet, SQ
if 'mellon' in settings.INSTALLED_APPS:
from mellon.utils import get_idps
- from mellon.models import UserSAMLIdentifier
else:
get_idps = lambda: []
- UserSAMLIdentifier = None
from combo.data.models import (CellBase, PostException, Page, Redirect,
ParentContentCell, TextCell, PageSnapshot)
from combo.profile.models import Profile
+from combo.profile.utils import get_user_from_name_id
from combo.apps.search.models import SearchCell
from combo import utils
@@ -81,11 +80,8 @@ def modify_global_context(request, ctx):
ctx['selected_user'] = User.objects.get(id=ctx['user_id'])
except (User.DoesNotExist, ValueError):
pass
- if 'name_id' in ctx and UserSAMLIdentifier:
- try:
- ctx['selected_user'] = UserSAMLIdentifier.objects.get(name_id=ctx['name_id']).user
- except UserSAMLIdentifier.DoesNotExist:
- pass
+ if 'name_id' in ctx:
+ ctx['selected_user'] = get_user_from_name_id(ctx['name_id'])
@csrf_exempt
def ajax_page_cell(request, page_pk, cell_reference):
diff --git a/combo/utils/requests_wrapper.py b/combo/utils/requests_wrapper.py
index 6dbde7d4..66425309 100644
--- a/combo/utils/requests_wrapper.py
+++ b/combo/utils/requests_wrapper.py
@@ -79,12 +79,13 @@ class Requests(RequestsSession):
else:
query_params = {}
if federation_key == 'nameid':
- query_params['NameID'] = user.saml_identifiers.first().name_id
+ query_params['NameID'] = user.get_name_id()
elif federation_key == 'email':
query_params['email'] = user.email
else: # 'auto'
- if hasattr(user, 'saml_identifiers') and user.saml_identifiers.exists():
- query_params['NameID'] = user.saml_identifiers.first().name_id
+ user_name_id = user.get_name_id()
+ if user_name_id:
+ query_params['NameID'] = user_name_id
else:
query_params['email'] = user.email
diff --git a/combo/utils/urls.py b/combo/utils/urls.py
index 24be39f6..57e67be7 100644
--- a/combo/utils/urls.py
+++ b/combo/utils/urls.py
@@ -42,8 +42,9 @@ def get_templated_url(url, context=None):
user = getattr(context.get('request'), 'user', None)
if user and user.is_authenticated():
template_vars['user_email'] = quote(user.email)
- if hasattr(user, 'saml_identifiers') and user.saml_identifiers.exists():
- template_vars['user_nameid'] = quote(user.saml_identifiers.first().name_id)
+ user_nameid = user.get_name_id()
+ if user_nameid:
+ template_vars['user_nameid'] = quote(user_nameid)
template_vars.update(settings.TEMPLATE_VARS)
if '{{' in url or '{%' in url: # Django template
try:
diff --git a/tests/test_lingo_remote_regie.py b/tests/test_lingo_remote_regie.py
index 67b8480b..6ae70878 100644
--- a/tests/test_lingo_remote_regie.py
+++ b/tests/test_lingo_remote_regie.py
@@ -75,15 +75,9 @@ class MockUser(object):
def is_authenticated(self):
return True
- def __init__(self):
- class MockSAMLUsers(object):
- def exists(self):
- return True
- def first(self):
- class MockSAMLUser(object):
- name_id = 'r2d2'
- return MockSAMLUser()
- self.saml_identifiers = MockSAMLUsers()
+ def get_name_id(self):
+ return 'r2d2'
+
@mock.patch('combo.utils.requests_wrapper.RequestsSession.request')
def test_remote_regie_active_invoices_cell(mock_request, remote_regie):
diff --git a/tests/test_newsletters_cell.py b/tests/test_newsletters_cell.py
index 311d1546..30719786 100644
--- a/tests/test_newsletters_cell.py
+++ b/tests/test_newsletters_cell.py
@@ -223,9 +223,7 @@ def test_get_subscriptions_with_name_id_and_mobile(mock_get, cell, user):
fake_saml_request = mock.Mock()
fake_saml_request.user = mock.Mock(email=USER_EMAIL)
- fake_saml_request.user.saml_identifiers = mock.Mock()
- fake_saml_request.user.saml_identifiers.exists.return_value = True
- fake_saml_request.user.saml_identifiers.first.return_value = mock.Mock(name_id='nameid')
+ fake_saml_request.user.get_name_id.return_value = 'nameid'
fake_saml_request.session = {'mellon_session': {'mobile': '0607080900'}}
form = NewslettersManageForm(instance=cell, request=fake_saml_request)
diff --git a/tests/test_profile.py b/tests/test_profile.py
index 57681e14..26f8b39f 100644
--- a/tests/test_profile.py
+++ b/tests/test_profile.py
@@ -13,9 +13,8 @@ pytestmark = pytest.mark.django_db
@override_settings(
KNOWN_SERVICES={'authentic': {'idp': {'title': 'IdP', 'url': 'http://example.org/'}}})
-@mock.patch('combo.public.templatetags.combo.UserSAMLIdentifier')
@mock.patch('combo.utils.requests.get')
-def test_profile_cell(requests_get, user_saml, app, admin_user):
+def test_profile_cell(requests_get, app, admin_user):
page = Page()
page.save()
@@ -28,24 +27,9 @@ def test_profile_cell(requests_get, user_saml, app, admin_user):
json=lambda: data,
status_code=200)
- def filter_mock(user=None):
- assert user is admin_user
- return mock.Mock(last=lambda: mock.Mock(name_id='123456'))
-
- mocked_objects = mock.Mock()
- mocked_objects.filter = mock.Mock(side_effect=filter_mock)
- user_saml.objects = mocked_objects
+ admin_user.get_name_id = lambda: '123456'
context = cell.get_cell_extra_context({'synchronous': True, 'selected_user': admin_user})
assert context['profile_fields']['first_name']['value'] == 'Foo'
assert context['profile_fields']['birthdate']['value'] == datetime.date(2018, 8, 10)
assert requests_get.call_args[0][0] == 'http://example.org/api/users/123456/'
-
- def filter_mock_missing(user=None):
- return mock.Mock(last=lambda: None)
-
- mocked_objects.filter = mock.Mock(side_effect=filter_mock_missing)
-
- context = cell.get_cell_extra_context({'synchronous': True, 'selected_user': admin_user})
- assert context['error'] == 'unknown user'
- assert requests_get.call_count == 1 # no new call was made
diff --git a/tests/test_public.py b/tests/test_public.py
index 92efafb2..b0983734 100644
--- a/tests/test_public.py
+++ b/tests/test_public.py
@@ -813,7 +813,7 @@ def test_sub_slug(app, john_doe, jane_doe):
assert 'XXYY' in resp.text
# custom behaviour for , it will add the SAML user to context
- with mock.patch('combo.public.views.UserSAMLIdentifier') as user_saml:
+ with mock.patch('combo.profile.utils.UserSAMLIdentifier') as user_saml:
class DoesNotExist(Exception):
pass
user_saml.DoesNotExist = DoesNotExist
diff --git a/tests/test_requests.py b/tests/test_requests.py
index 3d56edea..4d3b3a03 100644
--- a/tests/test_requests.py
+++ b/tests/test_requests.py
@@ -8,22 +8,18 @@ from django.utils.six.moves.urllib import parse as urlparse
from combo.utils import requests, check_query, NothingInCacheException
-class MockSAMLUser(object):
- name_id = 'r2d2'
-
class MockUser(object):
email = 'foo@example.net'
def is_authenticated(self):
return True
+ def get_name_id(self):
+ if self.samlized:
+ return 'r2d2'
+ return None
+
def __init__(self, samlized=True):
- class MockSAMLUsers(object):
- def exists(self):
- return True
- def first(self):
- return MockSAMLUser()
- if samlized:
- self.saml_identifiers = MockSAMLUsers()
+ self.samlized = samlized
def test_nosign():
diff --git a/tests/test_utils.py b/tests/test_utils.py
index 37011684..8ccbde0c 100644
--- a/tests/test_utils.py
+++ b/tests/test_utils.py
@@ -9,22 +9,18 @@ from django.test.client import RequestFactory
from django.contrib.auth.models import AnonymousUser
-class MockSAMLUser(object):
- name_id = 'r2&d2'
-
class MockUser(object):
email = 'foo=3@example.net'
def is_authenticated(self):
return True
def __init__(self, samlized=True):
- class MockSAMLUsers(object):
- def exists(self):
- return True
- def first(self):
- return MockSAMLUser()
- if samlized:
- self.saml_identifiers = MockSAMLUsers()
+ self.samlized = samlized
+
+ def get_name_id(self):
+ if self.samlized:
+ return 'r2&d2'
+ return None
def test_crypto_url():
diff --git a/tests/test_wcs.py b/tests/test_wcs.py
index 449fe231..9add0ead 100644
--- a/tests/test_wcs.py
+++ b/tests/test_wcs.py
@@ -161,6 +161,14 @@ user2.store()
WCS_DIR = tempfile.mkdtemp()
WCS_PID = None
+
+class MockUser(object):
+ email = 'foo@example.net'
+ def is_authenticated(self):
+ return True
+ def get_name_id(self):
+ return None
+
def run_wcs_script(script, hostname):
script_path = os.path.join(WCS_DIR, script + '.py')
fd = open(script_path, 'w')
@@ -345,10 +353,6 @@ def test_current_forms_cell_render(context):
cell = WcsCurrentFormsCell(page=page, placeholder='content', order=0)
cell.save()
- class MockUser(object):
- email = 'foo@example.net'
- def is_authenticated(self):
- return True
context['request'].user = MockUser()
# query should fail as nothing is cached
@@ -412,10 +416,6 @@ def test_current_forms_cell_render_single_site(context):
cell.wcs_site = 'default'
cell.save()
- class MockUser(object):
- email = 'foo@example.net'
- def is_authenticated(self):
- return True
context['request'].user = MockUser()
# query should fail as nothing is cached
@@ -506,10 +506,6 @@ def test_current_drafts_cell_render_logged_in(context):
cell.save()
context['synchronous'] = True # to get fresh content
- class MockUser(object):
- email = 'foo@example.net'
- def is_authenticated(self):
- return True
context['request'].user = MockUser()
# default is to get current forms from all wcs sites
@@ -695,20 +691,14 @@ def test_backoffice_submission_cell_render(context):
result = cell.render(context)
assert '/backoffice/submission/a-private-form/' not in result
- class MockUser(object):
- email = 'foo@example.net'
- def is_authenticated(self):
- return True
context['request'].user = MockUser()
result = cell.render(context)
assert '/backoffice/submission/a-private-form/' not in result
- class MockUser(object):
+ class MockUser2(MockUser):
email = 'foo2@example.net'
- def is_authenticated(self):
- return True
- context['request'].user = MockUser()
+ context['request'].user = MockUser2()
result = cell.render(context)
assert '/backoffice/submission/a-private-form/' in result