diff --git a/src/authentic2/custom_user/models.py b/src/authentic2/custom_user/models.py
index c3f60f2ea..711de9b5c 100644
--- a/src/authentic2/custom_user/models.py
+++ b/src/authentic2/custom_user/models.py
@@ -400,6 +400,18 @@ class User(AbstractBaseUser, PermissionMixin):
deleted_user.save()
return super().delete(**kwargs)
+ def get_missing_required_on_login_attributes(self):
+ attributes = Attribute.objects.filter(required_on_login=True, disabled=False).order_by(
+ 'order', 'label'
+ )
+
+ missing = []
+ for attribute in attributes:
+ value = getattr(self.attributes, attribute.name, None)
+ if not value:
+ missing.append(attribute)
+ return missing
+
class DeletedUser(models.Model):
deleted = models.DateTimeField(verbose_name=_('Deletion date'), auto_now_add=True)
diff --git a/src/authentic2/idp/saml/saml2_endpoints.py b/src/authentic2/idp/saml/saml2_endpoints.py
index 2c4fd3931..376af09d5 100644
--- a/src/authentic2/idp/saml/saml2_endpoints.py
+++ b/src/authentic2/idp/saml/saml2_endpoints.py
@@ -112,6 +112,7 @@ from authentic2.utils import misc as utils_misc
from authentic2.utils.misc import datetime_to_xs_datetime, find_authentication_event
from authentic2.utils.misc import get_backends as get_idp_backends
from authentic2.utils.misc import login_require, make_url
+from authentic2.utils.view_decorators import enable_view_restriction
from . import app_settings
@@ -494,6 +495,7 @@ def build_assertion(request, login, provider, nid_format='transient'):
return kwargs['name_id_content']
+@enable_view_restriction
@never_cache
@csrf_exempt
@log_assert
@@ -662,6 +664,7 @@ def need_consent_for_federation(request, login, nid_format):
return HttpResponseRedirect(url)
+@enable_view_restriction
@never_cache
def continue_sso(request):
consent_answer = None
@@ -1025,6 +1028,7 @@ def check_delegated_authentication_permission(request):
return request.user.is_superuser()
+@enable_view_restriction
@never_cache
@csrf_exempt
@login_required
diff --git a/src/authentic2/middleware.py b/src/authentic2/middleware.py
index 3b0516a89..1681a3bb8 100644
--- a/src/authentic2/middleware.py
+++ b/src/authentic2/middleware.py
@@ -126,6 +126,10 @@ class ViewRestrictionMiddleware(MiddlewareMixin):
if view:
return view
+ view = self.check_required_on_login_attribute_restriction(request, user)
+ if view:
+ return view
+
for plugin in plugins.get_plugins():
if hasattr(plugin, 'check_view_restrictions'):
view = plugin.check_view_restrictions(request, user)
@@ -136,6 +140,16 @@ class ViewRestrictionMiddleware(MiddlewareMixin):
request.session['last_password_reset_check'] = now
return None
+ def check_required_on_login_attribute_restriction(self, request, user):
+ # do not bother superuser with this
+ if user.is_superuser:
+ return None
+
+ missing = user.get_missing_required_on_login_attributes()
+ if missing:
+ return 'profile_required_edit'
+ return None
+
def check_password_reset_view_restriction(self, request, user):
# If user is authenticated and a password_reset_flag is set, force
# redirect to password change and show a message.
@@ -152,6 +166,9 @@ class ViewRestrictionMiddleware(MiddlewareMixin):
def process_view(self, request, view_func, view_args, view_kwargs):
'''If current view is not the one where we should be, redirect'''
+ if not getattr(view_func, 'enable_view_restriction', False):
+ return
+
view = self.check_view_restrictions(request)
if not view:
return
@@ -160,10 +177,6 @@ class ViewRestrictionMiddleware(MiddlewareMixin):
# do not block on the restricted view
if url_name == view:
return
-
- # prevent blocking people when they logout
- if url_name == 'auth_logout':
- return
return utils_misc.redirect_and_come_back(request, view)
diff --git a/src/authentic2/templates/authentic2/accounts_edit_required.html b/src/authentic2/templates/authentic2/accounts_edit_required.html
new file mode 100644
index 000000000..597523d8a
--- /dev/null
+++ b/src/authentic2/templates/authentic2/accounts_edit_required.html
@@ -0,0 +1,10 @@
+{% extends "authentic2/accounts_edit.html" %}
+{% load i18n %}
+
+{% block content %}
+{% block required-attributes-message %}
+
{% trans "The following informations are required if you want to use this service:"%} {% for attribute in view.missing_attributes %}{{ attribute.label }}{% if not forloop.last %}, {% endif %}{% endfor %}
+
+{% endblock %}
+{{ block.super }}
+{% endblock %}
diff --git a/src/authentic2/urls.py b/src/authentic2/urls.py
index bb4764a7c..9781a7354 100644
--- a/src/authentic2/urls.py
+++ b/src/authentic2/urls.py
@@ -58,6 +58,7 @@ accounts_urlpatterns = [
),
url(r'^logged-in/$', views.logged_in, name='logged-in'),
url(r'^edit/$', views.edit_profile, name='profile_edit'),
+ url(r'^edit/required/$', views.edit_required_profile, name='profile_required_edit'),
url(r'^edit/(?P[-\w]+)/$', views.edit_profile, name='profile_edit_with_scope'),
url(r'^change-email/$', views.email_change, name='email-change'),
url(r'^change-email/verify/$', views.email_change_verify, name='email-change-verify'),
diff --git a/src/authentic2/utils/view_decorators.py b/src/authentic2/utils/view_decorators.py
new file mode 100644
index 000000000..884d3d922
--- /dev/null
+++ b/src/authentic2/utils/view_decorators.py
@@ -0,0 +1,20 @@
+# authentic2 - versatile identity manager
+# Copyright (C) 2010-2021 Entr'ouvert
+#
+# This program is free software: you can redistribute it and/or modify it
+# under the terms of the GNU Affero General Public License as published
+# by the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Affero General Public License for more details.
+#
+# You should have received a copy of the GNU Affero General Public License
+# along with this program. If not, see .
+
+
+def enable_view_restriction(view):
+ view.enable_view_restriction = True
+ return view
diff --git a/src/authentic2/views.py b/src/authentic2/views.py
index c049ec992..10d9485a5 100644
--- a/src/authentic2/views.py
+++ b/src/authentic2/views.py
@@ -63,6 +63,7 @@ from .utils import misc as utils_misc
from .utils import switch_user as utils_switch_user
from .utils.evaluate import HTTPHeaders
from .utils.service import get_service_from_request, get_service_from_token, set_service_ref
+from .utils.view_decorators import enable_view_restriction
User = get_user_model()
@@ -157,6 +158,29 @@ edit_profile = decorators.setting_enabled('A2_PROFILE_CAN_EDIT_PROFILE')(
)
+class EditRequired(EditProfile):
+ template_names = ['authentic2/accounts_edit_required.html']
+
+ def dispatch(self, request, *args, **kwargs):
+ self.missing_attributes = request.user.get_missing_required_on_login_attributes()
+ if not self.missing_attributes:
+ return utils_misc.redirect(request, self.get_success_url())
+ return super().dispatch(request, *args, **kwargs)
+
+ @classmethod
+ def get_fields(cls, scopes=None):
+ # only show the required fields
+ attribute_names = models.Attribute.objects.filter(required_on_login=True, disabled=False).values_list(
+ 'name', flat=True
+ )
+
+ fields, labels = utils_misc.get_fields_and_labels(attribute_names)
+ return fields, labels
+
+
+edit_required_profile = login_required(EditRequired.as_view())
+
+
class EmailChangeView(cbv.TemplateNamesMixin, FormView):
template_names = ['profiles/email_change.html', 'authentic2/change_email.html']
title = _('Email Change')
@@ -404,7 +428,7 @@ class Homepage(cbv.TemplateNamesMixin, TemplateView):
return ctx
-homepage = Homepage.as_view()
+homepage = enable_view_restriction(Homepage.as_view())
class ProfileView(cbv.TemplateNamesMixin, TemplateView):
@@ -536,7 +560,7 @@ class ProfileView(cbv.TemplateNamesMixin, TemplateView):
return context
-profile = login_required(ProfileView.as_view())
+profile = enable_view_restriction(login_required(ProfileView.as_view()))
def logout_list(request):
diff --git a/src/authentic2_idp_cas/views.py b/src/authentic2_idp_cas/views.py
index 7190f708a..9d37a23f8 100644
--- a/src/authentic2_idp_cas/views.py
+++ b/src/authentic2_idp_cas/views.py
@@ -36,6 +36,7 @@ from authentic2.utils.misc import (
normalize_attribute_values,
redirect,
)
+from authentic2.utils.view_decorators import enable_view_restriction
from authentic2.views import logout as logout_view
from authentic2_idp_cas.constants import (
ATTRIBUTES_ELT,
@@ -467,9 +468,9 @@ class LogoutView(View):
return redirect(request, next_url)
-login = LoginView.as_view()
+login = enable_view_restriction(LoginView.as_view())
logout = LogoutView.as_view()
-_continue = ContinueView.as_view()
+_continue = enable_view_restriction(ContinueView.as_view())
validate = ValidateView.as_view()
service_validate = ServiceValidateView.as_view()
proxy = ProxyView.as_view()
diff --git a/src/authentic2_idp_oidc/views.py b/src/authentic2_idp_oidc/views.py
index 0b0cd1fda..cedc828a1 100644
--- a/src/authentic2_idp_oidc/views.py
+++ b/src/authentic2_idp_oidc/views.py
@@ -48,6 +48,7 @@ from authentic2 import hooks
from authentic2.decorators import setting_enabled
from authentic2.exponential_retry_timeout import ExponentialRetryTimeout
from authentic2.utils.misc import last_authentication_event, login_require, make_url, redirect
+from authentic2.utils.view_decorators import enable_view_restriction
from authentic2.views import logout as a2_logout
from django_rbac.utils import get_ou_model
@@ -228,6 +229,7 @@ def certs(request, *args, **kwargs):
return HttpResponse(utils.get_jwkset().export(private_keys=False), content_type='application/json')
+@enable_view_restriction
@setting_enabled('ENABLE', settings=app_settings)
def authorize(request, *args, **kwargs):
validated_redirect_uri = None
diff --git a/tests/middlewares/__init__.py b/tests/middlewares/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/tests/middlewares/test_required_on_login_restriction.py b/tests/middlewares/test_required_on_login_restriction.py
new file mode 100644
index 000000000..4ce75c7e0
--- /dev/null
+++ b/tests/middlewares/test_required_on_login_restriction.py
@@ -0,0 +1,37 @@
+# authentic2 - versatile identity manager
+# Copyright (C) 2021 Entr'ouvert
+#
+# This program is free software: you can redistribute it and/or modify it
+# under the terms of the GNU Affero General Public License as published
+# by the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Affero General Public License for more details.
+#
+# You should have received a copy of the GNU Affero General Public License
+# along with this program. If not, see .
+
+from authentic2.models import Attribute
+
+from ..utils import login
+
+
+def test_simple(app, db, simple_user):
+ Attribute.objects.create(
+ name='cgu_2021',
+ label='J\'accepte les conditions générales d\'utilisation',
+ kind='boolean',
+ required_on_login=True,
+ user_visible=True,
+ )
+ resp = login(app, simple_user, path='/accounts/')
+ assert resp.location == '/accounts/edit/required/?next=/accounts/'
+ resp = resp.follow()
+ resp.form.set('cgu_2021', True)
+ resp = resp.form.submit()
+ assert resp.location == '/accounts/'
+ resp = resp.follow()
+ assert 'les conditions générales d\'utilisation\xa0:\nTrue' in resp.pyquery.text()
diff --git a/tests/test_user_manager.py b/tests/test_user_manager.py
index 3dd9f98cf..368f82605 100644
--- a/tests/test_user_manager.py
+++ b/tests/test_user_manager.py
@@ -336,9 +336,9 @@ def test_export_csv(settings, app, superuser, django_assert_num_queries):
user_count = User.objects.count()
# queries should be batched to keep prefetching working without
# overspending memory for the queryset cache, 4 queries by batches
- num_queries = int(4 + 4 * (user_count / DEFAULT_BATCH_SIZE + bool(user_count % DEFAULT_BATCH_SIZE)))
+ num_queries = int(4 * (user_count / DEFAULT_BATCH_SIZE + bool(user_count % DEFAULT_BATCH_SIZE)))
# export task also perform one query to set trigram an another to get users count
- num_queries += 2
+ num_queries += 3
with django_assert_num_queries(num_queries):
response = response.click('CSV')