misc: maintain home url, service and ou (#61199)

Home service is defined on SSO requests.

Home URL is the OU home URL or the default homepage url, but on
account's pages if you pass a ?next= to an /accounts/ or /register/ view
and this URL can be linked ton an existing service, the home service is
set and the URL is take as the home URL.
This commit is contained in:
Benjamin Dauvergne 2022-01-28 00:31:04 +01:00
parent f72d1d3b2a
commit 051d27b068
23 changed files with 273 additions and 205 deletions

View File

@ -29,7 +29,7 @@ from . import app_settings, views
from .forms import authentication as authentication_forms
from .utils import misc as utils_misc
from .utils.evaluate import evaluate_condition
from .utils.service import get_service_from_request
from .utils.service import get_service
from .utils.views import csrf_token_check
logger = logging.getLogger(__name__)
@ -88,7 +88,8 @@ class LoginPasswordAuthenticator(BaseAuthenticator):
return []
return OU.objects.filter(pk__in=service_ou_ids)
def get_preferred_ous(self, request, service):
def get_preferred_ous(self, request):
service = get_service(request)
preferred_ous_cookie = utils_misc.get_remember_cookie(request, 'preferred-ous')
preferred_ous = []
if preferred_ous_cookie:
@ -102,7 +103,6 @@ class LoginPasswordAuthenticator(BaseAuthenticator):
return preferred_ous
def login(self, request, *args, **kwargs):
service = get_service_from_request(request)
context = kwargs.get('context', {})
is_post = request.method == 'POST' and self.submit_name in request.POST
data = request.POST if is_post else None
@ -112,7 +112,7 @@ class LoginPasswordAuthenticator(BaseAuthenticator):
# Special handling when the form contains an OU selector
if app_settings.A2_LOGIN_FORM_OU_SELECTOR:
preferred_ous = self.get_preferred_ous(request, service)
preferred_ous = self.get_preferred_ous(request)
if preferred_ous:
initial['ou'] = preferred_ous[0]
@ -135,7 +135,7 @@ class LoginPasswordAuthenticator(BaseAuthenticator):
if form.cleaned_data.get('remember_me'):
request.session['remember_me'] = True
request.session.set_expiry(app_settings.A2_USER_REMEMBER_ME)
response = utils_misc.login(request, form.get_user(), how, service=service)
response = utils_misc.login(request, form.get_user(), how)
if 'ou' in form.fields:
utils_misc.prepend_remember_cookie(
request, response, 'preferred-ous', form.cleaned_data['ou'].pk

View File

@ -20,5 +20,4 @@ CANCEL_FIELD_NAME = 'cancel'
AUTHENTICATION_EVENTS_SESSION_KEY = 'authentication-events'
SWITCH_USER_SESSION_KEY = '_switch_user'
LAST_LOGIN_SESSION_KEY = '_last_login'
SERVICE_FIELD_NAME = 'service'
NEXT_URL_SIGNATURE = 'next-signature'

View File

@ -20,6 +20,7 @@ from pkg_resources import get_distribution
from . import app_settings, constants
from .models import Service
from .utils import misc as utils_misc
from .utils.service import get_service
class UserFederations:
@ -69,3 +70,19 @@ def a2_processor(request):
except Service.DoesNotExist:
pass
return variables
def home(request):
ctx = {}
service = get_service(request)
if service:
ctx['home_service'] = service
if service.ou:
ctx['home_ou'] = service.ou
if request.session.get('home_url'):
ctx['home_url'] = request.session['home_url']
elif service and service.ou and service.ou.home_url:
ctx['home_url'] = service.ou.home_url
else:
ctx['home_url'] = app_settings.A2_HOMEPAGE_URL or settings.LOGIN_REDIRECT_URL
return ctx

View File

@ -113,6 +113,7 @@ from authentic2.utils import misc as utils_misc
from authentic2.utils.misc import datetime_to_xs_datetime, find_authentication_event
from authentic2.utils.misc import get_backends as get_idp_backends
from authentic2.utils.misc import login_require, make_url
from authentic2.utils.service import set_service
from authentic2.utils.view_decorators import check_view_restriction, enable_view_restriction
from . import app_settings
@ -582,6 +583,7 @@ def sso(request):
},
)
else:
set_service(request, provider_loaded)
policy = get_sp_options_policy(provider_loaded)
if not policy:
return error_page(request, _('sso: No SP policy defined'), logger=logger, warning=True)
@ -628,7 +630,7 @@ def sso(request):
return sso_after_process_request(request, login, nid_format=nid_format)
def need_login(request, login, nid_format, service):
def need_login(request, login, nid_format):
"""Redirect to the login page with a nonce parameter to verify later that
the login form was submitted
"""
@ -640,7 +642,6 @@ def need_login(request, login, nid_format, service):
request,
next_url=next_url,
params={NONCE_FIELD_NAME: nonce},
service=service,
login_hint=get_login_hints_extension(login),
)
@ -789,7 +790,7 @@ def sso_after_process_request(
if not passive and (user.is_anonymous or (force_authn and not did_auth)):
logger.debug('login required')
return need_login(request, login, nid_format, service)
return need_login(request, login, nid_format)
# No user is authenticated and passive is True, deny request
if passive and user.is_anonymous:
@ -1296,6 +1297,7 @@ def slo_soap(request):
except ObjectDoesNotExist:
logger.warning('provider %r unknown', logout.remoteProviderId)
return return_logout_error(request, logout, AUTHENTIC_STATUS_CODE_UNAUTHORIZED)
set_service(request, provider)
policy = get_sp_options_policy(provider)
if not policy:
logger.warning('No policy found for %s', logout.remoteProviderId)
@ -1385,6 +1387,7 @@ def slo(request):
except ObjectDoesNotExist:
logger.debug('provider %r unknown', logout.remoteProviderId)
return return_logout_error(request, logout, AUTHENTIC_STATUS_CODE_UNAUTHORIZED)
set_service(request, provider)
policy = get_sp_options_policy(provider)
if not policy:
logger.debug('No policy found for %s', logout.remoteProviderId)

View File

@ -15,7 +15,7 @@
# along with this program. If not, see <http://www.gnu.org/licenses/>.
from authentic2.apps.journal.journal import Journal as BaseJournal
from authentic2.utils.service import get_service_from_request
from authentic2.utils.service import get_service
class Journal(BaseJournal):
@ -25,7 +25,7 @@ class Journal(BaseJournal):
@property
def service(self):
return self._service or (get_service_from_request(self.request) if self.request else None)
return self._service or get_service(self.request) if self.request else None
def massage_kwargs(self, record_parameters, kwargs):
if 'service' not in kwargs and 'service' in record_parameters:

View File

@ -28,12 +28,10 @@ from django.conf import settings
from django.contrib import messages
from django.db.models import Model
from django.utils.deprecation import MiddlewareMixin
from django.utils.functional import SimpleLazyObject
from django.utils.translation import ugettext as _
from . import app_settings, plugins
from .utils import misc as utils_misc
from .utils.service import get_service_from_request, get_service_from_session
class CollectIPMiddleware(MiddlewareMixin):
@ -263,18 +261,6 @@ class CookieTestMiddleware(MiddlewareMixin):
return response
class SaveServiceInSessionMiddleware:
def __init__(self, get_response):
self.get_response = get_response
def __call__(self, request):
service = get_service_from_request(request)
if service:
request.session['service_pk'] = service.pk
request.service = SimpleLazyObject(lambda: get_service_from_session(request))
return self.get_response(request)
def journal_middleware(get_response):
from . import journal

View File

@ -451,6 +451,9 @@ class Service(models.Model):
def get_absolute_url(self):
return reverse('a2-manager-service', kwargs={'service_pk': self.pk})
def get_base_urls(self):
return []
Service._meta.natural_key = [['slug', 'ou']]

View File

@ -81,6 +81,7 @@ TEMPLATES = [
'django.contrib.messages.context_processors.messages',
'django.template.context_processors.static',
'authentic2.context_processors.a2_processor',
'authentic2.context_processors.home',
],
},
},
@ -96,7 +97,6 @@ MIDDLEWARE = (
'django.middleware.common.CommonMiddleware',
'django.middleware.http.ConditionalGetMiddleware',
'django.contrib.sessions.middleware.SessionMiddleware',
'authentic2.middleware.SaveServiceInSessionMiddleware',
'django.middleware.csrf.CsrfViewMiddleware',
'django.middleware.locale.LocaleMiddleware',
'django.contrib.auth.middleware.AuthenticationMiddleware',

View File

@ -12,7 +12,9 @@
{% endblock %}
{% block bodyargs %}
data-service-slug="{{ service.slug }}" data-service-name="{{ service.name }}"
{% if home_url %}data-home-url="{{ home_url }}"{% endif %}
{% if home_service %}data-home-service-slug="{{ home_service.slug }}" data-home-service-name="{{ home_service.name }}"{% endif %}
{% if home_ou %}data-home-ou-slug="{{ home_ou.slug }}" data-home-ou-name="{{ home_ou.name }}"{% endif %}
{% endblock %}
{% block extrascripts %}

View File

@ -47,10 +47,8 @@ from django.utils.formats import localize
from django.utils.translation import ungettext
from authentic2.saml.saml2utils import filter_attribute_private_key, filter_element_private_key
from authentic2.utils import crypto
from .. import app_settings, constants, plugins
from .service import set_service_ref
from .. import app_settings, constants, crypto, plugins
class CleanLogMessage(logging.Filter):
@ -455,15 +453,13 @@ def last_authentication_event(request=None, session=None):
return None
def login(request, user, how, service=None, service_slug=None, nonce=None, record=True, **kwargs):
def login(request, user, how, nonce=None, record=True, **kwargs):
"""Login a user model, record the authentication event and redirect to next
URL or settings.LOGIN_REDIRECT_URL."""
from .. import hooks
from .service import get_service
from .views import check_cookie_works
if service:
assert service_slug is None
service_slug = service.slug
check_cookie_works(request)
last_login = user.last_login
auth_login(request, user)
@ -472,23 +468,21 @@ def login(request, user, how, service=None, service_slug=None, nonce=None, recor
if constants.LAST_LOGIN_SESSION_KEY not in request.session:
request.session[constants.LAST_LOGIN_SESSION_KEY] = localize(to_current_timezone(last_login), True)
record_authentication_event(request, how, nonce=nonce)
hooks.call_hooks('event', name='login', user=user, how=how, service=service_slug)
hooks.call_hooks('event', name='login', user=user, how=how, service=get_service(request))
# prevent logint-hint to influence next use of the login page
if 'login-hint' in request.session:
del request.session['login-hint']
if record:
request.journal.record('user.login', how=how, service=service)
request.journal.record('user.login', how=how)
return continue_to_next_url(request, **kwargs)
def login_require(request, next_url=None, login_url='auth_login', service=None, login_hint=(), **kwargs):
def login_require(request, next_url=None, login_url='auth_login', login_hint=(), **kwargs):
'''Require a login and come back to current URL'''
next_url = next_url or request.get_full_path()
params = kwargs.setdefault('params', {})
params[REDIRECT_FIELD_NAME] = next_url
if service:
set_service_ref(params, service)
if login_hint:
request.session['login-hint'] = list(login_hint)
elif 'login-hint' in request.session:
@ -735,14 +729,12 @@ def get_fk_model(model, fieldname):
return field.related_model
def get_registration_url(request, service=None):
def get_registration_url(request):
next_url = select_next_url(request, settings.LOGIN_REDIRECT_URL)
next_url = make_url(
next_url, request=request, keep_params=True, include=(constants.NONCE_FIELD_NAME,), resolve=False
)
params = {REDIRECT_FIELD_NAME: next_url}
if service:
set_service_ref(params, service)
return make_url('registration_register', params=params)
@ -1041,9 +1033,17 @@ def get_next_url(params, field_name=None):
return next_url
def select_next_url(request, default, field_name=None, include_post=False, replace=None):
EMPTY = object()
def select_next_url(request, default=EMPTY, field_name=None, include_post=False, replace=None):
'''Select the first valid next URL'''
# pylint: disable=consider-using-ternary
if default is EMPTY:
if request.user.is_authenticated and request.user.ou and request.user.ou.home_url:
default = request.user.ou.home_url
else:
default = settings.LOGIN_REDIRECT_URL
next_url = (include_post and get_next_url(request.POST, field_name=field_name)) or get_next_url(
request.GET, field_name=field_name
)
@ -1144,7 +1144,7 @@ def same_origin(url1, url2):
return True
def simulate_authentication(request, user, method, backend=None, service=None, record=False, **kwargs):
def simulate_authentication(request, user, method, backend=None, record=False, **kwargs):
"""Simulate a normal login by eventually forcing a backend attribute on the
user instance"""
if not getattr(user, 'backend', None) and not backend:
@ -1152,7 +1152,7 @@ def simulate_authentication(request, user, method, backend=None, service=None, r
if backend:
user = copy.deepcopy(user)
user.backend = backend
return login(request, user, method, service=service, record=record, **kwargs)
return login(request, user, method, record=record, **kwargs)
def get_manager_login_url():

View File

@ -14,64 +14,71 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
from authentic2.constants import SERVICE_FIELD_NAME
import urllib.parse
from django.apps import apps
from authentic2.decorators import GlobalCache
from authentic2.utils.misc import same_origin
def service_ref(service):
if service.ou:
return '%s %s' % (service.ou.slug, service.slug)
else:
return service.slug
def get_service_from_ref(ref):
@GlobalCache(timeout=60)
def _base_urls_map():
from authentic2.models import Service
splitted = ref.split(' ')
base_urls_map = {}
for service in Service.objects.select_related().select_subclasses():
for url in service.get_base_urls():
base_urls_map[url] = (type(service), service.pk)
return base_urls_map
try:
ou_slug, service_slug = splitted
except ValueError:
pass
else:
return Service.objects.filter(ou__slug=ou_slug, slug=service_slug).first()
try:
(service_slug,) = splitted
except ValueError:
return None
def clean_service(request):
request.session.pop('sevice_type', None)
request.session.pop('sevice_pk', None)
service = Service.objects.filter(ou__isnull=True, slug=service_slug).first()
def _set_session_service(session, service):
if 'home_url' in session:
del session['home_url']
if service:
return service
try:
return Service.objects.get(slug=service_slug)
except (Service.DoesNotExist, Service.MultipleObjectsReturned):
return None
session['service_type'] = [type(service)._meta.app_label, type(service)._meta.model_name]
session['service_pk'] = service.pk
def get_service_from_request(request):
service_ref = request.GET.get(SERVICE_FIELD_NAME)
if service_ref and '\x00' not in service_ref:
return get_service_from_ref(service_ref)
return None
def set_service(request, service):
# do not set service on non document fetch (<script> tags, etc..)
headers = request.headers
if 'sec-fetch-dest' in headers and headers['sec-fetch-dest'] != 'document':
return
request._service = service
_set_session_service(request.session, service)
def get_service_from_session(request):
session = getattr(request, 'session', None)
if session and 'service_pk' in session:
from authentic2.models import Service
def set_home_url(request, url=None):
if not url:
from .misc import select_next_url
return Service.objects.get(pk=session['service_pk'])
return None
url = select_next_url(request, default=None)
if not url or not urllib.parse.urlparse(url).netloc:
return
urls_map = _base_urls_map()
for base_url, (Model, pk) in urls_map.items():
if same_origin(base_url, url):
set_service(request, Model.object.get(pk=pk))
break
else:
clean_service(request)
request.session['home_url'] = url
def get_service_from_token(params):
ref = params.get(SERVICE_FIELD_NAME)
if not ref:
return None
return get_service_from_ref(ref)
def set_service_ref(params, service):
params[SERVICE_FIELD_NAME] = service_ref(service)
def get_service(request):
if not hasattr(request, '_service'):
if 'service_type' in request.session and 'service_pk' in request.session:
ServiceKlass = apps.get_app_config(request.session['service_type'][0]).get_model(
request.session['service_type'][1]
)
request._service = ServiceKlass.objects.get(pk=request.session['service_pk'])
else:
request._service = None
return request._service

View File

@ -63,7 +63,7 @@ from .utils import crypto
from .utils import misc as utils_misc
from .utils import switch_user as utils_switch_user
from .utils.evaluate import make_condition_context
from .utils.service import get_service_from_request, get_service_from_token, set_service_ref
from .utils.service import get_service, set_home_url
from .utils.view_decorators import enable_view_restriction
User = get_user_model()
@ -71,7 +71,13 @@ User = get_user_model()
logger = logging.getLogger(__name__)
class EditProfile(cbv.HookMixin, cbv.TemplateNamesMixin, UpdateView):
class HomeURLMixin:
def dispatch(self, request, *args, **kwargs):
set_home_url(request)
return super().dispatch(request, *args, **kwargs)
class EditProfile(HomeURLMixin, cbv.HookMixin, cbv.TemplateNamesMixin, UpdateView):
model = User
template_names = ['profiles/edit_profile.html', 'authentic2/accounts_edit.html']
title = _('Edit account data')
@ -182,7 +188,7 @@ class EditRequired(EditProfile):
edit_required_profile = login_required(EditRequired.as_view())
class EmailChangeView(cbv.TemplateNamesMixin, FormView):
class EmailChangeView(HomeURLMixin, cbv.TemplateNamesMixin, FormView):
template_names = ['profiles/email_change.html', 'authentic2/change_email.html']
title = _('Email Change')
success_url = '..'
@ -295,8 +301,6 @@ def login(request, template_name='authentic2/login.html', redirect_field_name=RE
redirect_to = request.GET.get(redirect_field_name)
service = get_service_from_request(request)
if not redirect_to or ' ' in redirect_to:
redirect_to = settings.LOGIN_REDIRECT_URL
# Heavier security check -- redirects to http://example.com should
@ -311,7 +315,7 @@ def login(request, template_name='authentic2/login.html', redirect_field_name=RE
blocks = []
registration_url = utils_misc.get_registration_url(request, service=service)
registration_url = utils_misc.get_registration_url(request)
context = {
'cancel': app_settings.A2_LOGIN_DISPLAY_A_CANCEL_BUTTON and nonce is not None,
@ -346,6 +350,7 @@ def login(request, template_name='authentic2/login.html', redirect_field_name=RE
parameters = {'request': request, 'context': context}
login_hint = set(request.session.get('login-hint', []))
show_ctx = make_condition_context(request=request, login_hint=login_hint)
service = get_service(request)
if service:
show_ctx['service_ou_slug'] = service.ou and service.ou.slug
show_ctx['service_slug'] = service.slug
@ -421,7 +426,12 @@ class Homepage(cbv.TemplateNamesMixin, TemplateView):
template_names = ['idp/homepage.html', 'authentic2/homepage.html']
def dispatch(self, request, *args, **kwargs):
if app_settings.A2_HOMEPAGE_URL:
home_url = request.session.get('home_url')
home_url = home_url or (
request.user.is_authenticated and request.user and request.user.ou and request.user.ou.home_url
)
home_url = app_settings.A2_HOMEPAGE_URL
if home_url:
return utils_misc.redirect(request, app_settings.A2_HOMEPAGE_URL)
return login_required(super().dispatch)(request, *args, **kwargs)
@ -435,7 +445,7 @@ class Homepage(cbv.TemplateNamesMixin, TemplateView):
homepage = enable_view_restriction(Homepage.as_view())
class ProfileView(cbv.TemplateNamesMixin, TemplateView):
class ProfileView(HomeURLMixin, cbv.TemplateNamesMixin, TemplateView):
template_names = ['idp/account_management.html', 'authentic2/accounts.html']
title = _('Your account')
@ -889,7 +899,7 @@ class PasswordResetConfirmView(cbv.RedirectToNextURLViewMixin, FormView):
password_reset_confirm = PasswordResetConfirmView.as_view()
class BaseRegistrationView(FormView):
class BaseRegistrationView(HomeURLMixin, FormView):
form_class = registration_forms.RegistrationForm
template_name = 'registration/registration_form.html'
title = _('Registration')
@ -912,6 +922,7 @@ class BaseRegistrationView(FormView):
if 'ou' in self.token:
self.ou = OU.objects.get(pk=self.token['ou'])
self.next_url = self.token.pop(REDIRECT_FIELD_NAME, utils_misc.select_next_url(request, None))
set_home_url(request, self.next_url)
return super().dispatch(request, *args, **kwargs)
def form_valid(self, form):
@ -978,11 +989,6 @@ class BaseRegistrationView(FormView):
for field in form.cleaned_data:
self.token[field] = form.cleaned_data[field]
# propagate service to the registration completion view
service = get_service_from_request(self.request)
if service:
set_service_ref(self.token, service)
self.token.pop(REDIRECT_FIELD_NAME, None)
self.token.pop('email', None)
@ -1066,8 +1072,7 @@ class RegistrationCompletionView(CreateView):
if self.ou:
self.email_is_unique |= self.ou.email_is_unique
self.init_fields_labels_and_help_texts()
# if registration is done during an SSO add the service to the registration event
self.service = get_service_from_token(self.token)
set_home_url(request, self.get_success_url())
return super().dispatch(request, *args, **kwargs)
def init_fields_labels_and_help_texts(self):
@ -1180,9 +1185,7 @@ class RegistrationCompletionView(CreateView):
def get(self, request, *args, **kwargs):
if len(self.users) == 1 and self.email_is_unique:
# Found one user, EMAIL is unique, log her in
utils_misc.simulate_authentication(
request, self.users[0], method=self.authentication_method, service=self.service
)
utils_misc.simulate_authentication(request, self.users[0], method=self.authentication_method)
return utils_misc.redirect(request, self.get_success_url())
confirm_data = self.token.get('confirm_data', False)
@ -1220,9 +1223,7 @@ class RegistrationCompletionView(CreateView):
uid = request.POST['uid']
for user in self.users:
if str(user.id) == uid:
utils_misc.simulate_authentication(
request, user, method=self.authentication_method, service=self.service
)
utils_misc.simulate_authentication(request, user, method=self.authentication_method)
return utils_misc.redirect(request, self.get_success_url())
return super().post(request, *args, **kwargs)
@ -1284,14 +1285,12 @@ class RegistrationCompletionView(CreateView):
view=self,
authentication_method=self.authentication_method,
token=self.token,
service=self.service and self.service.slug,
service=get_service(request),
)
self.send_registration_success_email(user)
def registration_success(self, request, user):
utils_misc.simulate_authentication(
request, user, method=self.authentication_method, service=self.service
)
utils_misc.simulate_authentication(request, user, method=self.authentication_method)
message_template = loader.get_template('authentic2/registration_success_message.html')
messages.info(self.request, message_template.render(request=request))
return utils_misc.redirect(request, self.get_success_url())
@ -1319,7 +1318,7 @@ class RegistrationCompletionView(CreateView):
registration_completion = RegistrationCompletionView.as_view()
class AccountDeleteView(TemplateView):
class AccountDeleteView(HomeURLMixin, TemplateView):
template_name = 'authentic2/accounts_delete_request.html'
title = _('Request account deletion')
@ -1407,7 +1406,7 @@ class RegistrationCompleteView(TemplateView):
registration_complete = RegistrationCompleteView.as_view()
class PasswordChangeView(DjPasswordChangeView):
class PasswordChangeView(HomeURLMixin, DjPasswordChangeView):
title = _('Password Change')
do_not_call_in_templates = True
@ -1471,7 +1470,7 @@ class SuView(View):
su = SuView.as_view()
class Consents(ListView):
class Consents(HomeURLMixin, ListView):
template_name = 'authentic2/consents.html'
title = _('Consent Management')
model = OIDCAuthorization

View File

@ -42,7 +42,6 @@ from authentic2.utils import misc as utils_misc
from authentic2.utils import views as utils_views
from authentic2.utils.crypto import check_hmac_url, hash_chain, hmac_url
from authentic2.utils.models import safe_get_or_create
from authentic2.utils.service import get_service_from_ref, get_service_from_request, service_ref
from . import app_settings, models
from .utils import (
@ -69,7 +68,6 @@ class LoginOrLinkView(View):
"""
_next_url = None
service = None
@property
def next_url(self):
@ -114,7 +112,7 @@ class LoginOrLinkView(View):
def handle_authorization_response(self, request, code, state):
# check state signature and parse it
try:
state, self._next_url, self.service = self.decode_state(state)
state, self._next_url = self.decode_state(state)
except ValueError:
return utils_misc.redirect(request, settings.LOGIN_REDIRECT_URL)
@ -186,10 +184,8 @@ class LoginOrLinkView(View):
else:
return self.login(request)
def encode_state(self, state, next_url, service):
encoded_state = state + ' ' + self.next_url + ' '
if service:
encoded_state += service_ref(service)
def encode_state(self, state, next_url):
encoded_state = state + ' ' + self.next_url
encoded_state += ' ' + hmac_url(settings.SECRET_KEY, encoded_state)
return encoded_state
@ -197,32 +193,23 @@ class LoginOrLinkView(View):
payload, signature = state.rsplit(' ', 1)
if not check_hmac_url(settings.SECRET_KEY, payload, signature):
raise ValueError
# service_ref can be made of one or two parts
try:
state, next_url, service_ref = payload.split(' ')
except ValueError:
state, next_url, ou_slug, service_slug = payload.split(' ')
service_ref = ou_slug + ' ' + service_slug
service = get_service_from_ref(service_ref)
return state, next_url, service
state, next_url, *dummy = payload.split(' ')
return state, next_url
def make_authorization_request(self, request):
scope = ' '.join(set(['openid'] + app_settings.scopes))
service = self.service or get_service_from_request(request)
nonce_seed, nonce, state = hash_chain(3)
# encode the target service and next_url in the state
full_state = state + ' ' + self.next_url + ' '
if service:
full_state += service_ref(service)
full_state += ' ' + hmac_url(settings.SECRET_KEY, full_state)
params = {
'client_id': app_settings.client_id,
'scope': scope,
'redirect_uri': self.redirect_uri,
'response_type': 'code',
'state': self.encode_state(state, self.next_url, service),
'state': self.encode_state(state, self.next_url),
'nonce': nonce,
'acr_values': 'eidas1',
}
@ -340,7 +327,7 @@ class LoginOrLinkView(View):
def finish_login(self, request, user, user_info, created):
self.update_user_info(user, user_info)
utils_views.check_cookie_works(request)
utils_misc.login(request, user, 'france-connect', service=self.service)
utils_misc.login(request, user, 'france-connect')
# keep id_token around for logout
request.session['fc_id_token'] = self.id_token

View File

@ -36,6 +36,7 @@ from authentic2.utils.misc import (
normalize_attribute_values,
redirect,
)
from authentic2.utils.service import set_service
from authentic2.utils.view_decorators import enable_view_restriction
from authentic2.views import logout as logout_view
from authentic2_idp_cas.constants import (
@ -151,6 +152,7 @@ class LoginView(CasMixin, View):
model = Service.objects.for_service(service)
if not model:
return self.failure(request, service, 'service unknown')
set_service(request, model)
if renew and gateway:
return self.failure(request, service, 'renew and gateway cannot be requested at the same time')
@ -464,6 +466,7 @@ class LogoutView(View):
if referrer:
model = Service.objects.for_service(referrer)
if model:
set_service(request, model)
return logout_view(request, next_url=next_url, check_referer=False, do_local=False)
return redirect(request, next_url)

View File

@ -41,6 +41,7 @@ from authentic2.a2_rbac.models import OrganizationalUnit
from authentic2.decorators import setting_enabled
from authentic2.exponential_retry_timeout import ExponentialRetryTimeout
from authentic2.utils.misc import last_authentication_event, login_require, make_url, redirect
from authentic2.utils.service import set_service
from authentic2.utils.view_decorators import check_view_restriction
from authentic2.views import logout as a2_logout
@ -246,6 +247,8 @@ def authorize(request, *args, **kwargs):
client = get_client(client_id=client_id)
if not client:
raise InvalidRequest(_('Unknown client identifier: "%s"') % client_id)
# define the current service
set_service(request, client)
try:
client.validate_redirect_uri(redirect_uri)
except ValueError:
@ -333,7 +336,7 @@ def authorize_for_client(request, client, redirect_uri):
params = {}
if nonce is not None:
params['nonce'] = nonce
return login_require(request, params=params, service=client, login_hint=login_hint)
return login_require(request, params=params, login_hint=login_hint)
# view restriction and passive SSO
if hasattr(request, 'view_restriction_response'):
@ -352,7 +355,7 @@ def authorize_for_client(request, client, redirect_uri):
params = {}
if nonce is not None:
params['nonce'] = nonce
return login_require(request, params=params, service=client, login_hint=login_hint)
return login_require(request, params=params, login_hint=login_hint)
iat = now() # iat = issued at
@ -812,6 +815,7 @@ def logout(request):
)
for provider in providers:
if post_logout_redirect_uri in provider.post_logout_redirect_uris.split():
set_service(request, provider)
break
else:
messages.warning(request, _('Invalid post logout URI'))

View File

@ -160,12 +160,16 @@ class FranceConnectMock:
@pytest.fixture
def franceconnect(settings, db):
def service(db):
return Service.objects.create(name='portail', slug='portail', ou=get_default_ou())
@pytest.fixture
def franceconnect(settings, service, db):
settings.A2_FC_ENABLE = True
settings.A2_FC_CLIENT_ID = CLIENT_ID
settings.A2_FC_CLIENT_SECRET = CLIENT_SECRET
Service.objects.create(name='portail', slug='portail', ou=get_default_ou())
mock_object = FranceConnectMock()
with mock_object():
yield mock_object

View File

@ -32,12 +32,12 @@ from authentic2.a2_rbac.models import OrganizationalUnit as OU
from authentic2.a2_rbac.utils import get_default_ou
from authentic2.apps.journal.models import Event
from authentic2.custom_user.models import DeletedUser
from authentic2.models import Attribute, Service
from authentic2.models import Attribute
from authentic2_auth_fc import models
from authentic2_auth_fc.backends import FcBackend
from authentic2_auth_fc.utils import requests_retry_session
from ..utils import get_link_from_mail, login
from ..utils import get_link_from_mail, login, set_service
User = get_user_model()
@ -54,7 +54,7 @@ def test_fc_url_on_login(app, franceconnect):
def test_retry_authorization_if_state_is_lost(settings, app, franceconnect, hooks):
response = app.get('/fc/callback/?next=/idp/&service=default%20portail', status=302)
response = app.get('/fc/callback/?next=/idp/', status=302)
# clear fc-state cookie
app.cookiejar.clear()
response = franceconnect.handle_authorization(app, response.location, status=302)
@ -81,26 +81,26 @@ def test_login_autorun(settings, app, franceconnect):
assert response.location == reverse('fc-login-or-link')
def test_create(settings, app, franceconnect, hooks):
def test_create(settings, app, franceconnect, hooks, service):
# test direct creation
response = app.get('/login/?service=portail&next=/idp/')
set_service(app, service)
response = app.get('/login/?next=/idp/')
response = response.click(href='callback')
assert User.objects.count() == 0
assert Event.objects.which_references(Service.objects.get()).count() == 0
assert Event.objects.which_references(service).count() == 0
response = franceconnect.handle_authorization(app, response.location, status=302)
assert 'fc-state' not in app.cookies
assert User.objects.count() == 1
# check login for service=portail was registered
assert Event.objects.which_references(Service.objects.get()).count() == 1
assert Event.objects.which_references(service).count() == 1
user = User.objects.get()
assert user.verified_attributes.first_name == 'Ÿuñe'
assert user.verified_attributes.last_name == 'Frédérique'
assert path(response.location) == '/idp/'
assert hooks.event[1]['kwargs']['name'] == 'login'
assert hooks.event[1]['kwargs']['service'] == 'portail'
assert hooks.event[1]['kwargs']['service'] == service
# we must be connected
assert app.session['_auth_user_id']
assert app.session.get_expire_at_browser_close()
@ -130,7 +130,7 @@ def test_create_expired(settings, app, franceconnect, hooks):
# test direct creation failure on an expired id_token
franceconnect.exp = now() - datetime.timedelta(seconds=30)
response = app.get('/login/?service=portail&next=/idp/')
response = app.get('/login/?next=/idp/')
response = response.click(href='callback')
assert User.objects.count() == 0

View File

@ -965,6 +965,9 @@ def test_role_control_access(login_first, oidc_settings, oidc_client, simple_use
def test_registration_service_slug(oidc_settings, app, simple_oidc_client, simple_user, hooks, mailoutbox):
redirect_uri = simple_oidc_client.redirect_uris.split()[0]
simple_oidc_client.ou.home_url = 'https://portal/'
simple_oidc_client.ou.save()
params = {
'client_id': simple_oidc_client.client_id,
'scope': 'openid profile email',
@ -977,19 +980,18 @@ def test_registration_service_slug(oidc_settings, app, simple_oidc_client, simpl
authorize_url = make_url('oidc-authorize', params=params)
response = app.get(authorize_url)
location = urllib.parse.urlparse(response['Location'])
query = urllib.parse.parse_qs(location.query)
assert query['service'] == ['default client']
response = response.follow().click('Register')
location = urllib.parse.urlparse(response.request.url)
query = urllib.parse.parse_qs(location.query)
assert query['service'] == ['default client']
response.form.set('email', 'john.doe@example.com')
response = response.form.submit()
assert len(mailoutbox) == 1
link = utils.get_link_from_mail(mailoutbox[0])
response = app.get(link)
body = response.pyquery('body')[0]
assert body.attrib['data-home-ou-slug'] == 'default'
assert body.attrib['data-home-ou-name'] == 'Default organizational unit'
assert body.attrib['data-home-service-slug'] == 'client'
assert body.attrib['data-home-service-name'] == 'client'
assert body.attrib['data-home-url'] == 'https://portal/'
response.form.set('first_name', 'John')
response.form.set('last_name', 'Doe')
response.form.set('password1', 'T0==toto')
@ -999,11 +1001,11 @@ def test_registration_service_slug(oidc_settings, app, simple_oidc_client, simpl
assert hooks.event[0]['kwargs']['service'].slug == 'client'
assert hooks.event[1]['kwargs']['name'] == 'registration'
assert hooks.event[1]['kwargs']['service'] == 'client'
assert hooks.event[1]['kwargs']['service'].slug == 'client'
assert hooks.event[2]['kwargs']['name'] == 'login'
assert hooks.event[2]['kwargs']['how'] == 'email'
assert hooks.event[2]['kwargs']['service'] == 'client'
assert hooks.event[2]['kwargs']['service'].slug == 'client'
def test_claim_default_value(oidc_settings, normal_oidc_client, simple_user, app):

View File

@ -0,0 +1,57 @@
# authentic2 - versatile identity manager
# Copyright (C) 2010-2022 Entr'ouvert
#
# This program is free software: you can redistribute it and/or modify it
# under the terms of the GNU Affero General Public License as published
# by the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
from . import utils
def test_home(app, settings, simple_user, service):
from authentic2.a2_rbac.utils import get_default_ou
from authentic2.models import Service
utils.set_service(app, service)
settings.LOGIN_REDIRECT_URL = 'https://portal1/'
resp = app.get('/login/')
body = resp.pyquery('body')
assert body.attr('data-home-url') == 'https://portal1/'
assert body.attr('data-home-service-slug') == service.slug
assert body.attr('data-home-service-name') == service.name
assert body.attr('data-home-ou-slug') == service.ou.slug
assert body.attr('data-home-ou-name') == service.ou.name
settings.A2_HOMEPAGE_URL = 'https://portal2/'
resp = app.get('/login/')
body = resp.pyquery('body')
assert body.attr('data-home-url') == 'https://portal2/'
service.ou.home_url = 'https://portal3/'
service.ou.save()
resp = app.get('/login/')
body = resp.pyquery('body')
assert body.attr('data-home-url') == 'https://portal3/'
# if user comes back from a different service, the information is updated
new_service = Service.objects.create(ou=get_default_ou(), slug='service2', name='Service2')
utils.set_service(app, new_service)
resp = app.get('/login/')
body = resp.pyquery('body')
assert body.attr('data-home-url') == 'https://portal3/'
assert body.attr('data-home-service-slug') == new_service.slug
assert body.attr('data-home-service-name') == new_service.name
assert body.attr('data-home-ou-slug') == new_service.ou.slug
assert body.attr('data-home-ou-name') == new_service.ou.name

View File

@ -33,7 +33,7 @@ from django.utils.encoding import force_bytes, force_str, force_text
from django.utils.translation import gettext as _
from authentic2.a2_rbac.models import OrganizationalUnit, Role
from authentic2.constants import NONCE_FIELD_NAME, SERVICE_FIELD_NAME
from authentic2.constants import NONCE_FIELD_NAME
from authentic2.custom_user.models import User
from authentic2.idp.saml import saml2_endpoints
from authentic2.idp.saml.saml2_endpoints import get_extensions, get_login_hints_extension
@ -330,7 +330,6 @@ class Scenario:
reverse('auth_login'),
**{
'nonce': '*',
SERVICE_FIELD_NAME: 'default ' + self.sp.slug,
REDIRECT_FIELD_NAME: make_url(
'a2-idp-saml-continue', params={NONCE_FIELD_NAME: request_id}
),

View File

@ -22,7 +22,7 @@ from django.contrib.auth import get_user_model
from authentic2 import models
from authentic2.utils.misc import get_token_login_url
from .utils import assert_event, login
from .utils import assert_event, login, set_service
User = get_user_model()
@ -85,22 +85,22 @@ def test_show_condition(db, app, settings, caplog):
assert len(caplog.records) == 1
def test_show_condition_service(db, app, settings):
def test_show_condition_service(db, rf, app, settings):
portal = models.Service.objects.create(pk=1, name='Service', slug='portal')
service = models.Service.objects.create(pk=2, name='Service', slug='service')
settings.AUTH_FRONTENDS_KWARGS = {'password': {'show_condition': 'service_slug == \'portal\''}}
response = app.get('/login/', params={})
response = app.get('/login/')
assert 'name="login-password-submit"' not in response
# service doesn't exist
response = app.get('/login/', params={'service': 'portal'})
assert 'name="login-password-submit"' not in response
set_service(app, portal)
# Create a service
models.Service.objects.create(name='Service', slug='portal')
response = app.get('/login/', params={'service': 'portal'})
response = app.get('/login/')
assert 'name="login-password-submit"' in response
models.Service.objects.create(name='Service', slug='service')
response = app.get('/login/', params={'service': 'service'})
set_service(app, service)
response = app.get('/login/')
assert 'name="login-password-submit"' not in response
@ -251,29 +251,31 @@ def test_ou_selector(app, settings, simple_user, ou1, ou2, user_ou1, role_ou1):
response = app.get('/login/')
assert response.pyquery.find('select#id_ou option[selected]')[0].text == 'Default organizational unit'
set_service(app, service)
# service is specified but not access-control is defined, default for user is selected
response = app.get('/login/?service=service')
response = app.get('/login/')
assert response.pyquery.find('select#id_ou option[selected]')[0].text == 'Default organizational unit'
# service is specified, access control is defined but role is empty, default for user is selected
service.authorized_roles.through.objects.create(service=service, role=role_ou1)
response = app.get('/login/?service=service')
response = app.get('/login/')
assert response.pyquery.find('select#id_ou option[selected]')[0].text == 'Default organizational unit'
# user is added to role_ou1, default for user is still selected
user_ou1.roles.add(role_ou1)
response = app.get('/login/?service=service')
response = app.get('/login/')
assert response.pyquery.find('select#id_ou option[selected]')[0].text == 'Default organizational unit'
# Clear cookies, OU1 is selected
app.cookiejar.clear()
response = app.get('/login/?service=service')
set_service(app, service)
response = app.get('/login/')
assert response.pyquery.find('select#id_ou option[selected]')[0].text == 'OU1'
# if we change the user's ou, then default selected OU changes
user_ou1.ou = ou2
user_ou1.save()
response = app.get('/login/?service=service')
response = app.get('/login/')
assert response.pyquery.find('select#id_ou option[selected]')[0].text == 'OU2'

View File

@ -15,10 +15,7 @@
# along with this program. If not, see <http://www.gnu.org/licenses/>.
import pytest
from django.urls import reverse
from authentic2.a2_rbac.utils import get_default_ou
from authentic2.models import Service
from authentic2.utils.template import Template, TemplateError
pytestmark = pytest.mark.django_db
@ -111,24 +108,3 @@ def test_render_template_missing_variable():
with pytest.raises(TemplateError) as raised:
template.render(context=context)
assert 'missing template variable' in raised
def test_service_in_template(app, simple_user, service):
resp = app.get(reverse('auth_login') + '?service=%s' % service.slug)
assert resp.pyquery('body').attr('data-service-slug') == service.slug
assert resp.pyquery('body').attr('data-service-name') == service.name
resp.form.set('username', simple_user.username)
resp.form.set('password', simple_user.username)
resp.form.submit(name='login-password-submit')
resp = app.get(reverse('account_management'))
assert resp.pyquery('body').attr('data-service-slug') == service.slug
assert resp.pyquery('body').attr('data-service-name') == service.name
# if user comes back from a different service, the information is updated
new_service = Service.objects.create(ou=get_default_ou(), slug='service2', name='Service2')
resp = app.get(reverse('account_management') + '?service=%s' % new_service.slug)
assert resp.pyquery('body').attr('data-service-slug') == new_service.slug
assert resp.pyquery('body').attr('data-service-name') == new_service.name

View File

@ -307,3 +307,21 @@ def assert_event(event_type_name, user=None, session=None, service=None, target_
)
elif data and count > 1:
assert qs.filter(**{'data__' + k: v for k, v in data.items()}).count() == 1
def set_service(app, service):
from importlib import import_module
from django.conf import settings
from authentic2.utils.service import _set_session_service
engine = import_module(settings.SESSION_ENGINE)
if app.session == {}:
session = engine.SessionStore()
else:
session = app.session
_set_session_service(session, service)
session.save()
if app.session == {}:
app.set_cookie(settings.SESSION_COOKIE_NAME, session.session_key)