misc: provide origin service in template context (#20699)
This commit is contained in:
parent
19d892f537
commit
41aa3734e3
|
@ -18,6 +18,7 @@ from pkg_resources import get_distribution
|
|||
from django.conf import settings
|
||||
|
||||
from . import utils, app_settings, constants
|
||||
from .models import Service
|
||||
|
||||
|
||||
class UserFederations(object):
|
||||
|
@ -59,4 +60,9 @@ def a2_processor(request):
|
|||
if hasattr(request, 'session'):
|
||||
variables['LAST_LOGIN'] = request.session.get(constants.LAST_LOGIN_SESSION_KEY)
|
||||
variables['USER_SWITCHED'] = constants.SWITCH_USER_SESSION_KEY in request.session
|
||||
if 'service_pk' in request.session:
|
||||
try:
|
||||
variables['service'] = Service.objects.get(pk=request.session['service_pk'])
|
||||
except Service.DoesNotExist:
|
||||
pass
|
||||
return variables
|
||||
|
|
|
@ -32,6 +32,7 @@ from django.utils.six.moves.urllib import parse as urlparse
|
|||
from django.shortcuts import render
|
||||
|
||||
from . import app_settings, utils, plugins
|
||||
from .utils.service import get_service_from_request
|
||||
|
||||
|
||||
class CollectIPMiddleware(MiddlewareMixin):
|
||||
|
@ -205,3 +206,17 @@ class CookieTestMiddleware(MiddlewareMixin):
|
|||
# set test cookie for 1 year
|
||||
response.set_cookie(self.COOKIE_NAME, '1', max_age=365 * 24 * 3600)
|
||||
return response
|
||||
|
||||
|
||||
class SaveServiceInSessionMiddleware:
|
||||
def __init__(self, get_response):
|
||||
self.get_response = get_response
|
||||
|
||||
def __call__(self, request):
|
||||
service = None
|
||||
|
||||
service = get_service_from_request(request)
|
||||
if service:
|
||||
request.session['service_pk'] = service.pk
|
||||
|
||||
return self.get_response(request)
|
||||
|
|
|
@ -95,6 +95,7 @@ MIDDLEWARE = (
|
|||
'django.middleware.common.CommonMiddleware',
|
||||
'django.middleware.http.ConditionalGetMiddleware',
|
||||
'django.contrib.sessions.middleware.SessionMiddleware',
|
||||
'authentic2.middleware.SaveServiceInSessionMiddleware',
|
||||
'django.middleware.csrf.CsrfViewMiddleware',
|
||||
'django.middleware.locale.LocaleMiddleware',
|
||||
'django.contrib.auth.middleware.AuthenticationMiddleware',
|
||||
|
|
|
@ -11,6 +11,10 @@
|
|||
{{ form.media.css }}
|
||||
{% endblock %}
|
||||
|
||||
{% block bodyargs %}
|
||||
data-service-slug="{{ service.slug }}" data-service-name="{{ service.name }}"
|
||||
{% endblock %}
|
||||
|
||||
{% block extrascripts %}
|
||||
{{ block.super }}
|
||||
{{ form.media.js }}
|
||||
|
|
|
@ -16,6 +16,10 @@
|
|||
|
||||
import pytest
|
||||
|
||||
from django.urls import reverse
|
||||
|
||||
from authentic2.a2_rbac.utils import get_default_ou
|
||||
from authentic2.models import Service
|
||||
from authentic2.utils.template import Template, TemplateError
|
||||
|
||||
pytestmark = pytest.mark.django_db
|
||||
|
@ -111,3 +115,25 @@ def test_render_template_missing_variable():
|
|||
with pytest.raises(TemplateError) as raised:
|
||||
template.render(context=context)
|
||||
assert 'missing template variable' in raised
|
||||
|
||||
|
||||
def test_service_in_template(app, simple_user, service):
|
||||
resp = app.get(reverse('auth_login') + '?service=%s' % service.slug)
|
||||
|
||||
assert resp.pyquery('body').attr('data-service-slug') == service.slug
|
||||
assert resp.pyquery('body').attr('data-service-name') == service.name
|
||||
|
||||
resp.form.set('username', simple_user.username)
|
||||
resp.form.set('password', simple_user.username)
|
||||
response = resp.form.submit(name='login-password-submit')
|
||||
|
||||
resp = app.get(reverse('account_management'))
|
||||
assert resp.pyquery('body').attr('data-service-slug') == service.slug
|
||||
assert resp.pyquery('body').attr('data-service-name') == service.name
|
||||
|
||||
# if user comes back from a different service, the information is updated
|
||||
new_service = Service.objects.create(ou=get_default_ou(), slug='service2',
|
||||
name='Service2')
|
||||
resp = app.get(reverse('account_management') + '?service=%s' % new_service.slug)
|
||||
assert resp.pyquery('body').attr('data-service-slug') == new_service.slug
|
||||
assert resp.pyquery('body').attr('data-service-name') == new_service.name
|
||||
|
|
Loading…
Reference in New Issue