diff --git a/src/authentic2/custom_user/models.py b/src/authentic2/custom_user/models.py index c3f60f2ea..711de9b5c 100644 --- a/src/authentic2/custom_user/models.py +++ b/src/authentic2/custom_user/models.py @@ -400,6 +400,18 @@ class User(AbstractBaseUser, PermissionMixin): deleted_user.save() return super().delete(**kwargs) + def get_missing_required_on_login_attributes(self): + attributes = Attribute.objects.filter(required_on_login=True, disabled=False).order_by( + 'order', 'label' + ) + + missing = [] + for attribute in attributes: + value = getattr(self.attributes, attribute.name, None) + if not value: + missing.append(attribute) + return missing + class DeletedUser(models.Model): deleted = models.DateTimeField(verbose_name=_('Deletion date'), auto_now_add=True) diff --git a/src/authentic2/idp/saml/saml2_endpoints.py b/src/authentic2/idp/saml/saml2_endpoints.py index 2c4fd3931..376af09d5 100644 --- a/src/authentic2/idp/saml/saml2_endpoints.py +++ b/src/authentic2/idp/saml/saml2_endpoints.py @@ -112,6 +112,7 @@ from authentic2.utils import misc as utils_misc from authentic2.utils.misc import datetime_to_xs_datetime, find_authentication_event from authentic2.utils.misc import get_backends as get_idp_backends from authentic2.utils.misc import login_require, make_url +from authentic2.utils.view_decorators import enable_view_restriction from . import app_settings @@ -494,6 +495,7 @@ def build_assertion(request, login, provider, nid_format='transient'): return kwargs['name_id_content'] +@enable_view_restriction @never_cache @csrf_exempt @log_assert @@ -662,6 +664,7 @@ def need_consent_for_federation(request, login, nid_format): return HttpResponseRedirect(url) +@enable_view_restriction @never_cache def continue_sso(request): consent_answer = None @@ -1025,6 +1028,7 @@ def check_delegated_authentication_permission(request): return request.user.is_superuser() +@enable_view_restriction @never_cache @csrf_exempt @login_required diff --git a/src/authentic2/middleware.py b/src/authentic2/middleware.py index 3b0516a89..1681a3bb8 100644 --- a/src/authentic2/middleware.py +++ b/src/authentic2/middleware.py @@ -126,6 +126,10 @@ class ViewRestrictionMiddleware(MiddlewareMixin): if view: return view + view = self.check_required_on_login_attribute_restriction(request, user) + if view: + return view + for plugin in plugins.get_plugins(): if hasattr(plugin, 'check_view_restrictions'): view = plugin.check_view_restrictions(request, user) @@ -136,6 +140,16 @@ class ViewRestrictionMiddleware(MiddlewareMixin): request.session['last_password_reset_check'] = now return None + def check_required_on_login_attribute_restriction(self, request, user): + # do not bother superuser with this + if user.is_superuser: + return None + + missing = user.get_missing_required_on_login_attributes() + if missing: + return 'profile_required_edit' + return None + def check_password_reset_view_restriction(self, request, user): # If user is authenticated and a password_reset_flag is set, force # redirect to password change and show a message. @@ -152,6 +166,9 @@ class ViewRestrictionMiddleware(MiddlewareMixin): def process_view(self, request, view_func, view_args, view_kwargs): '''If current view is not the one where we should be, redirect''' + if not getattr(view_func, 'enable_view_restriction', False): + return + view = self.check_view_restrictions(request) if not view: return @@ -160,10 +177,6 @@ class ViewRestrictionMiddleware(MiddlewareMixin): # do not block on the restricted view if url_name == view: return - - # prevent blocking people when they logout - if url_name == 'auth_logout': - return return utils_misc.redirect_and_come_back(request, view) diff --git a/src/authentic2/templates/authentic2/accounts_edit_required.html b/src/authentic2/templates/authentic2/accounts_edit_required.html new file mode 100644 index 000000000..597523d8a --- /dev/null +++ b/src/authentic2/templates/authentic2/accounts_edit_required.html @@ -0,0 +1,10 @@ +{% extends "authentic2/accounts_edit.html" %} +{% load i18n %} + +{% block content %} +{% block required-attributes-message %} +
{% trans "The following informations are required if you want to use this service:"%} {% for attribute in view.missing_attributes %}{{ attribute.label }}{% if not forloop.last %}, {% endif %}{% endfor %} +
+{% endblock %} +{{ block.super }} +{% endblock %} diff --git a/src/authentic2/urls.py b/src/authentic2/urls.py index bb4764a7c..9781a7354 100644 --- a/src/authentic2/urls.py +++ b/src/authentic2/urls.py @@ -58,6 +58,7 @@ accounts_urlpatterns = [ ), url(r'^logged-in/$', views.logged_in, name='logged-in'), url(r'^edit/$', views.edit_profile, name='profile_edit'), + url(r'^edit/required/$', views.edit_required_profile, name='profile_required_edit'), url(r'^edit/(?P[-\w]+)/$', views.edit_profile, name='profile_edit_with_scope'), url(r'^change-email/$', views.email_change, name='email-change'), url(r'^change-email/verify/$', views.email_change_verify, name='email-change-verify'), diff --git a/src/authentic2/utils/view_decorators.py b/src/authentic2/utils/view_decorators.py new file mode 100644 index 000000000..884d3d922 --- /dev/null +++ b/src/authentic2/utils/view_decorators.py @@ -0,0 +1,20 @@ +# authentic2 - versatile identity manager +# Copyright (C) 2010-2021 Entr'ouvert +# +# This program is free software: you can redistribute it and/or modify it +# under the terms of the GNU Affero General Public License as published +# by the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. +# +# You should have received a copy of the GNU Affero General Public License +# along with this program. If not, see . + + +def enable_view_restriction(view): + view.enable_view_restriction = True + return view diff --git a/src/authentic2/views.py b/src/authentic2/views.py index c049ec992..10d9485a5 100644 --- a/src/authentic2/views.py +++ b/src/authentic2/views.py @@ -63,6 +63,7 @@ from .utils import misc as utils_misc from .utils import switch_user as utils_switch_user from .utils.evaluate import HTTPHeaders from .utils.service import get_service_from_request, get_service_from_token, set_service_ref +from .utils.view_decorators import enable_view_restriction User = get_user_model() @@ -157,6 +158,29 @@ edit_profile = decorators.setting_enabled('A2_PROFILE_CAN_EDIT_PROFILE')( ) +class EditRequired(EditProfile): + template_names = ['authentic2/accounts_edit_required.html'] + + def dispatch(self, request, *args, **kwargs): + self.missing_attributes = request.user.get_missing_required_on_login_attributes() + if not self.missing_attributes: + return utils_misc.redirect(request, self.get_success_url()) + return super().dispatch(request, *args, **kwargs) + + @classmethod + def get_fields(cls, scopes=None): + # only show the required fields + attribute_names = models.Attribute.objects.filter(required_on_login=True, disabled=False).values_list( + 'name', flat=True + ) + + fields, labels = utils_misc.get_fields_and_labels(attribute_names) + return fields, labels + + +edit_required_profile = login_required(EditRequired.as_view()) + + class EmailChangeView(cbv.TemplateNamesMixin, FormView): template_names = ['profiles/email_change.html', 'authentic2/change_email.html'] title = _('Email Change') @@ -404,7 +428,7 @@ class Homepage(cbv.TemplateNamesMixin, TemplateView): return ctx -homepage = Homepage.as_view() +homepage = enable_view_restriction(Homepage.as_view()) class ProfileView(cbv.TemplateNamesMixin, TemplateView): @@ -536,7 +560,7 @@ class ProfileView(cbv.TemplateNamesMixin, TemplateView): return context -profile = login_required(ProfileView.as_view()) +profile = enable_view_restriction(login_required(ProfileView.as_view())) def logout_list(request): diff --git a/src/authentic2_idp_cas/views.py b/src/authentic2_idp_cas/views.py index 7190f708a..9d37a23f8 100644 --- a/src/authentic2_idp_cas/views.py +++ b/src/authentic2_idp_cas/views.py @@ -36,6 +36,7 @@ from authentic2.utils.misc import ( normalize_attribute_values, redirect, ) +from authentic2.utils.view_decorators import enable_view_restriction from authentic2.views import logout as logout_view from authentic2_idp_cas.constants import ( ATTRIBUTES_ELT, @@ -467,9 +468,9 @@ class LogoutView(View): return redirect(request, next_url) -login = LoginView.as_view() +login = enable_view_restriction(LoginView.as_view()) logout = LogoutView.as_view() -_continue = ContinueView.as_view() +_continue = enable_view_restriction(ContinueView.as_view()) validate = ValidateView.as_view() service_validate = ServiceValidateView.as_view() proxy = ProxyView.as_view() diff --git a/src/authentic2_idp_oidc/views.py b/src/authentic2_idp_oidc/views.py index 0b0cd1fda..cedc828a1 100644 --- a/src/authentic2_idp_oidc/views.py +++ b/src/authentic2_idp_oidc/views.py @@ -48,6 +48,7 @@ from authentic2 import hooks from authentic2.decorators import setting_enabled from authentic2.exponential_retry_timeout import ExponentialRetryTimeout from authentic2.utils.misc import last_authentication_event, login_require, make_url, redirect +from authentic2.utils.view_decorators import enable_view_restriction from authentic2.views import logout as a2_logout from django_rbac.utils import get_ou_model @@ -228,6 +229,7 @@ def certs(request, *args, **kwargs): return HttpResponse(utils.get_jwkset().export(private_keys=False), content_type='application/json') +@enable_view_restriction @setting_enabled('ENABLE', settings=app_settings) def authorize(request, *args, **kwargs): validated_redirect_uri = None diff --git a/tests/middlewares/__init__.py b/tests/middlewares/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/middlewares/test_required_on_login_restriction.py b/tests/middlewares/test_required_on_login_restriction.py new file mode 100644 index 000000000..4ce75c7e0 --- /dev/null +++ b/tests/middlewares/test_required_on_login_restriction.py @@ -0,0 +1,37 @@ +# authentic2 - versatile identity manager +# Copyright (C) 2021 Entr'ouvert +# +# This program is free software: you can redistribute it and/or modify it +# under the terms of the GNU Affero General Public License as published +# by the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. +# +# You should have received a copy of the GNU Affero General Public License +# along with this program. If not, see . + +from authentic2.models import Attribute + +from ..utils import login + + +def test_simple(app, db, simple_user): + Attribute.objects.create( + name='cgu_2021', + label='J\'accepte les conditions générales d\'utilisation', + kind='boolean', + required_on_login=True, + user_visible=True, + ) + resp = login(app, simple_user, path='/accounts/') + assert resp.location == '/accounts/edit/required/?next=/accounts/' + resp = resp.follow() + resp.form.set('cgu_2021', True) + resp = resp.form.submit() + assert resp.location == '/accounts/' + resp = resp.follow() + assert 'les conditions générales d\'utilisation\xa0:\nTrue' in resp.pyquery.text() diff --git a/tests/test_user_manager.py b/tests/test_user_manager.py index 3dd9f98cf..368f82605 100644 --- a/tests/test_user_manager.py +++ b/tests/test_user_manager.py @@ -336,9 +336,9 @@ def test_export_csv(settings, app, superuser, django_assert_num_queries): user_count = User.objects.count() # queries should be batched to keep prefetching working without # overspending memory for the queryset cache, 4 queries by batches - num_queries = int(4 + 4 * (user_count / DEFAULT_BATCH_SIZE + bool(user_count % DEFAULT_BATCH_SIZE))) + num_queries = int(4 * (user_count / DEFAULT_BATCH_SIZE + bool(user_count % DEFAULT_BATCH_SIZE))) # export task also perform one query to set trigram an another to get users count - num_queries += 2 + num_queries += 3 with django_assert_num_queries(num_queries): response = response.click('CSV')