misc: maintain home url, service and ou (#61199)
Home service is defined on SSO requests. Home URL is the OU home URL or the default homepage url, but on account's pages if you pass a ?next= to an /accounts/ or /register/ view and this URL can be linked ton an existing service, the home service is set and the URL is take as the home URL.
This commit is contained in:
parent
f72d1d3b2a
commit
051d27b068
|
@ -29,7 +29,7 @@ from . import app_settings, views
|
|||
from .forms import authentication as authentication_forms
|
||||
from .utils import misc as utils_misc
|
||||
from .utils.evaluate import evaluate_condition
|
||||
from .utils.service import get_service_from_request
|
||||
from .utils.service import get_service
|
||||
from .utils.views import csrf_token_check
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
@ -88,7 +88,8 @@ class LoginPasswordAuthenticator(BaseAuthenticator):
|
|||
return []
|
||||
return OU.objects.filter(pk__in=service_ou_ids)
|
||||
|
||||
def get_preferred_ous(self, request, service):
|
||||
def get_preferred_ous(self, request):
|
||||
service = get_service(request)
|
||||
preferred_ous_cookie = utils_misc.get_remember_cookie(request, 'preferred-ous')
|
||||
preferred_ous = []
|
||||
if preferred_ous_cookie:
|
||||
|
@ -102,7 +103,6 @@ class LoginPasswordAuthenticator(BaseAuthenticator):
|
|||
return preferred_ous
|
||||
|
||||
def login(self, request, *args, **kwargs):
|
||||
service = get_service_from_request(request)
|
||||
context = kwargs.get('context', {})
|
||||
is_post = request.method == 'POST' and self.submit_name in request.POST
|
||||
data = request.POST if is_post else None
|
||||
|
@ -112,7 +112,7 @@ class LoginPasswordAuthenticator(BaseAuthenticator):
|
|||
|
||||
# Special handling when the form contains an OU selector
|
||||
if app_settings.A2_LOGIN_FORM_OU_SELECTOR:
|
||||
preferred_ous = self.get_preferred_ous(request, service)
|
||||
preferred_ous = self.get_preferred_ous(request)
|
||||
if preferred_ous:
|
||||
initial['ou'] = preferred_ous[0]
|
||||
|
||||
|
@ -135,7 +135,7 @@ class LoginPasswordAuthenticator(BaseAuthenticator):
|
|||
if form.cleaned_data.get('remember_me'):
|
||||
request.session['remember_me'] = True
|
||||
request.session.set_expiry(app_settings.A2_USER_REMEMBER_ME)
|
||||
response = utils_misc.login(request, form.get_user(), how, service=service)
|
||||
response = utils_misc.login(request, form.get_user(), how)
|
||||
if 'ou' in form.fields:
|
||||
utils_misc.prepend_remember_cookie(
|
||||
request, response, 'preferred-ous', form.cleaned_data['ou'].pk
|
||||
|
|
|
@ -20,5 +20,4 @@ CANCEL_FIELD_NAME = 'cancel'
|
|||
AUTHENTICATION_EVENTS_SESSION_KEY = 'authentication-events'
|
||||
SWITCH_USER_SESSION_KEY = '_switch_user'
|
||||
LAST_LOGIN_SESSION_KEY = '_last_login'
|
||||
SERVICE_FIELD_NAME = 'service'
|
||||
NEXT_URL_SIGNATURE = 'next-signature'
|
||||
|
|
|
@ -20,6 +20,7 @@ from pkg_resources import get_distribution
|
|||
from . import app_settings, constants
|
||||
from .models import Service
|
||||
from .utils import misc as utils_misc
|
||||
from .utils.service import get_service
|
||||
|
||||
|
||||
class UserFederations:
|
||||
|
@ -69,3 +70,19 @@ def a2_processor(request):
|
|||
except Service.DoesNotExist:
|
||||
pass
|
||||
return variables
|
||||
|
||||
|
||||
def home(request):
|
||||
ctx = {}
|
||||
service = get_service(request)
|
||||
if service:
|
||||
ctx['home_service'] = service
|
||||
if service.ou:
|
||||
ctx['home_ou'] = service.ou
|
||||
if request.session.get('home_url'):
|
||||
ctx['home_url'] = request.session['home_url']
|
||||
elif service and service.ou and service.ou.home_url:
|
||||
ctx['home_url'] = service.ou.home_url
|
||||
else:
|
||||
ctx['home_url'] = app_settings.A2_HOMEPAGE_URL or settings.LOGIN_REDIRECT_URL
|
||||
return ctx
|
||||
|
|
|
@ -113,6 +113,7 @@ from authentic2.utils import misc as utils_misc
|
|||
from authentic2.utils.misc import datetime_to_xs_datetime, find_authentication_event
|
||||
from authentic2.utils.misc import get_backends as get_idp_backends
|
||||
from authentic2.utils.misc import login_require, make_url
|
||||
from authentic2.utils.service import set_service
|
||||
from authentic2.utils.view_decorators import check_view_restriction, enable_view_restriction
|
||||
|
||||
from . import app_settings
|
||||
|
@ -582,6 +583,7 @@ def sso(request):
|
|||
},
|
||||
)
|
||||
else:
|
||||
set_service(request, provider_loaded)
|
||||
policy = get_sp_options_policy(provider_loaded)
|
||||
if not policy:
|
||||
return error_page(request, _('sso: No SP policy defined'), logger=logger, warning=True)
|
||||
|
@ -628,7 +630,7 @@ def sso(request):
|
|||
return sso_after_process_request(request, login, nid_format=nid_format)
|
||||
|
||||
|
||||
def need_login(request, login, nid_format, service):
|
||||
def need_login(request, login, nid_format):
|
||||
"""Redirect to the login page with a nonce parameter to verify later that
|
||||
the login form was submitted
|
||||
"""
|
||||
|
@ -640,7 +642,6 @@ def need_login(request, login, nid_format, service):
|
|||
request,
|
||||
next_url=next_url,
|
||||
params={NONCE_FIELD_NAME: nonce},
|
||||
service=service,
|
||||
login_hint=get_login_hints_extension(login),
|
||||
)
|
||||
|
||||
|
@ -789,7 +790,7 @@ def sso_after_process_request(
|
|||
|
||||
if not passive and (user.is_anonymous or (force_authn and not did_auth)):
|
||||
logger.debug('login required')
|
||||
return need_login(request, login, nid_format, service)
|
||||
return need_login(request, login, nid_format)
|
||||
|
||||
# No user is authenticated and passive is True, deny request
|
||||
if passive and user.is_anonymous:
|
||||
|
@ -1296,6 +1297,7 @@ def slo_soap(request):
|
|||
except ObjectDoesNotExist:
|
||||
logger.warning('provider %r unknown', logout.remoteProviderId)
|
||||
return return_logout_error(request, logout, AUTHENTIC_STATUS_CODE_UNAUTHORIZED)
|
||||
set_service(request, provider)
|
||||
policy = get_sp_options_policy(provider)
|
||||
if not policy:
|
||||
logger.warning('No policy found for %s', logout.remoteProviderId)
|
||||
|
@ -1385,6 +1387,7 @@ def slo(request):
|
|||
except ObjectDoesNotExist:
|
||||
logger.debug('provider %r unknown', logout.remoteProviderId)
|
||||
return return_logout_error(request, logout, AUTHENTIC_STATUS_CODE_UNAUTHORIZED)
|
||||
set_service(request, provider)
|
||||
policy = get_sp_options_policy(provider)
|
||||
if not policy:
|
||||
logger.debug('No policy found for %s', logout.remoteProviderId)
|
||||
|
|
|
@ -15,7 +15,7 @@
|
|||
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
|
||||
from authentic2.apps.journal.journal import Journal as BaseJournal
|
||||
from authentic2.utils.service import get_service_from_request
|
||||
from authentic2.utils.service import get_service
|
||||
|
||||
|
||||
class Journal(BaseJournal):
|
||||
|
@ -25,7 +25,7 @@ class Journal(BaseJournal):
|
|||
|
||||
@property
|
||||
def service(self):
|
||||
return self._service or (get_service_from_request(self.request) if self.request else None)
|
||||
return self._service or get_service(self.request) if self.request else None
|
||||
|
||||
def massage_kwargs(self, record_parameters, kwargs):
|
||||
if 'service' not in kwargs and 'service' in record_parameters:
|
||||
|
|
|
@ -28,12 +28,10 @@ from django.conf import settings
|
|||
from django.contrib import messages
|
||||
from django.db.models import Model
|
||||
from django.utils.deprecation import MiddlewareMixin
|
||||
from django.utils.functional import SimpleLazyObject
|
||||
from django.utils.translation import ugettext as _
|
||||
|
||||
from . import app_settings, plugins
|
||||
from .utils import misc as utils_misc
|
||||
from .utils.service import get_service_from_request, get_service_from_session
|
||||
|
||||
|
||||
class CollectIPMiddleware(MiddlewareMixin):
|
||||
|
@ -263,18 +261,6 @@ class CookieTestMiddleware(MiddlewareMixin):
|
|||
return response
|
||||
|
||||
|
||||
class SaveServiceInSessionMiddleware:
|
||||
def __init__(self, get_response):
|
||||
self.get_response = get_response
|
||||
|
||||
def __call__(self, request):
|
||||
service = get_service_from_request(request)
|
||||
if service:
|
||||
request.session['service_pk'] = service.pk
|
||||
request.service = SimpleLazyObject(lambda: get_service_from_session(request))
|
||||
return self.get_response(request)
|
||||
|
||||
|
||||
def journal_middleware(get_response):
|
||||
from . import journal
|
||||
|
||||
|
|
|
@ -451,6 +451,9 @@ class Service(models.Model):
|
|||
def get_absolute_url(self):
|
||||
return reverse('a2-manager-service', kwargs={'service_pk': self.pk})
|
||||
|
||||
def get_base_urls(self):
|
||||
return []
|
||||
|
||||
|
||||
Service._meta.natural_key = [['slug', 'ou']]
|
||||
|
||||
|
|
|
@ -81,6 +81,7 @@ TEMPLATES = [
|
|||
'django.contrib.messages.context_processors.messages',
|
||||
'django.template.context_processors.static',
|
||||
'authentic2.context_processors.a2_processor',
|
||||
'authentic2.context_processors.home',
|
||||
],
|
||||
},
|
||||
},
|
||||
|
@ -96,7 +97,6 @@ MIDDLEWARE = (
|
|||
'django.middleware.common.CommonMiddleware',
|
||||
'django.middleware.http.ConditionalGetMiddleware',
|
||||
'django.contrib.sessions.middleware.SessionMiddleware',
|
||||
'authentic2.middleware.SaveServiceInSessionMiddleware',
|
||||
'django.middleware.csrf.CsrfViewMiddleware',
|
||||
'django.middleware.locale.LocaleMiddleware',
|
||||
'django.contrib.auth.middleware.AuthenticationMiddleware',
|
||||
|
|
|
@ -12,7 +12,9 @@
|
|||
{% endblock %}
|
||||
|
||||
{% block bodyargs %}
|
||||
data-service-slug="{{ service.slug }}" data-service-name="{{ service.name }}"
|
||||
{% if home_url %}data-home-url="{{ home_url }}"{% endif %}
|
||||
{% if home_service %}data-home-service-slug="{{ home_service.slug }}" data-home-service-name="{{ home_service.name }}"{% endif %}
|
||||
{% if home_ou %}data-home-ou-slug="{{ home_ou.slug }}" data-home-ou-name="{{ home_ou.name }}"{% endif %}
|
||||
{% endblock %}
|
||||
|
||||
{% block extrascripts %}
|
||||
|
|
|
@ -47,10 +47,8 @@ from django.utils.formats import localize
|
|||
from django.utils.translation import ungettext
|
||||
|
||||
from authentic2.saml.saml2utils import filter_attribute_private_key, filter_element_private_key
|
||||
from authentic2.utils import crypto
|
||||
|
||||
from .. import app_settings, constants, plugins
|
||||
from .service import set_service_ref
|
||||
from .. import app_settings, constants, crypto, plugins
|
||||
|
||||
|
||||
class CleanLogMessage(logging.Filter):
|
||||
|
@ -455,15 +453,13 @@ def last_authentication_event(request=None, session=None):
|
|||
return None
|
||||
|
||||
|
||||
def login(request, user, how, service=None, service_slug=None, nonce=None, record=True, **kwargs):
|
||||
def login(request, user, how, nonce=None, record=True, **kwargs):
|
||||
"""Login a user model, record the authentication event and redirect to next
|
||||
URL or settings.LOGIN_REDIRECT_URL."""
|
||||
from .. import hooks
|
||||
from .service import get_service
|
||||
from .views import check_cookie_works
|
||||
|
||||
if service:
|
||||
assert service_slug is None
|
||||
service_slug = service.slug
|
||||
check_cookie_works(request)
|
||||
last_login = user.last_login
|
||||
auth_login(request, user)
|
||||
|
@ -472,23 +468,21 @@ def login(request, user, how, service=None, service_slug=None, nonce=None, recor
|
|||
if constants.LAST_LOGIN_SESSION_KEY not in request.session:
|
||||
request.session[constants.LAST_LOGIN_SESSION_KEY] = localize(to_current_timezone(last_login), True)
|
||||
record_authentication_event(request, how, nonce=nonce)
|
||||
hooks.call_hooks('event', name='login', user=user, how=how, service=service_slug)
|
||||
hooks.call_hooks('event', name='login', user=user, how=how, service=get_service(request))
|
||||
# prevent logint-hint to influence next use of the login page
|
||||
if 'login-hint' in request.session:
|
||||
del request.session['login-hint']
|
||||
if record:
|
||||
request.journal.record('user.login', how=how, service=service)
|
||||
request.journal.record('user.login', how=how)
|
||||
return continue_to_next_url(request, **kwargs)
|
||||
|
||||
|
||||
def login_require(request, next_url=None, login_url='auth_login', service=None, login_hint=(), **kwargs):
|
||||
def login_require(request, next_url=None, login_url='auth_login', login_hint=(), **kwargs):
|
||||
'''Require a login and come back to current URL'''
|
||||
|
||||
next_url = next_url or request.get_full_path()
|
||||
params = kwargs.setdefault('params', {})
|
||||
params[REDIRECT_FIELD_NAME] = next_url
|
||||
if service:
|
||||
set_service_ref(params, service)
|
||||
if login_hint:
|
||||
request.session['login-hint'] = list(login_hint)
|
||||
elif 'login-hint' in request.session:
|
||||
|
@ -735,14 +729,12 @@ def get_fk_model(model, fieldname):
|
|||
return field.related_model
|
||||
|
||||
|
||||
def get_registration_url(request, service=None):
|
||||
def get_registration_url(request):
|
||||
next_url = select_next_url(request, settings.LOGIN_REDIRECT_URL)
|
||||
next_url = make_url(
|
||||
next_url, request=request, keep_params=True, include=(constants.NONCE_FIELD_NAME,), resolve=False
|
||||
)
|
||||
params = {REDIRECT_FIELD_NAME: next_url}
|
||||
if service:
|
||||
set_service_ref(params, service)
|
||||
return make_url('registration_register', params=params)
|
||||
|
||||
|
||||
|
@ -1041,9 +1033,17 @@ def get_next_url(params, field_name=None):
|
|||
return next_url
|
||||
|
||||
|
||||
def select_next_url(request, default, field_name=None, include_post=False, replace=None):
|
||||
EMPTY = object()
|
||||
|
||||
|
||||
def select_next_url(request, default=EMPTY, field_name=None, include_post=False, replace=None):
|
||||
'''Select the first valid next URL'''
|
||||
# pylint: disable=consider-using-ternary
|
||||
if default is EMPTY:
|
||||
if request.user.is_authenticated and request.user.ou and request.user.ou.home_url:
|
||||
default = request.user.ou.home_url
|
||||
else:
|
||||
default = settings.LOGIN_REDIRECT_URL
|
||||
next_url = (include_post and get_next_url(request.POST, field_name=field_name)) or get_next_url(
|
||||
request.GET, field_name=field_name
|
||||
)
|
||||
|
@ -1144,7 +1144,7 @@ def same_origin(url1, url2):
|
|||
return True
|
||||
|
||||
|
||||
def simulate_authentication(request, user, method, backend=None, service=None, record=False, **kwargs):
|
||||
def simulate_authentication(request, user, method, backend=None, record=False, **kwargs):
|
||||
"""Simulate a normal login by eventually forcing a backend attribute on the
|
||||
user instance"""
|
||||
if not getattr(user, 'backend', None) and not backend:
|
||||
|
@ -1152,7 +1152,7 @@ def simulate_authentication(request, user, method, backend=None, service=None, r
|
|||
if backend:
|
||||
user = copy.deepcopy(user)
|
||||
user.backend = backend
|
||||
return login(request, user, method, service=service, record=record, **kwargs)
|
||||
return login(request, user, method, record=record, **kwargs)
|
||||
|
||||
|
||||
def get_manager_login_url():
|
||||
|
|
|
@ -14,64 +14,71 @@
|
|||
# You should have received a copy of the GNU Affero General Public License
|
||||
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
|
||||
from authentic2.constants import SERVICE_FIELD_NAME
|
||||
import urllib.parse
|
||||
|
||||
from django.apps import apps
|
||||
|
||||
from authentic2.decorators import GlobalCache
|
||||
from authentic2.utils.misc import same_origin
|
||||
|
||||
|
||||
def service_ref(service):
|
||||
if service.ou:
|
||||
return '%s %s' % (service.ou.slug, service.slug)
|
||||
else:
|
||||
return service.slug
|
||||
|
||||
|
||||
def get_service_from_ref(ref):
|
||||
@GlobalCache(timeout=60)
|
||||
def _base_urls_map():
|
||||
from authentic2.models import Service
|
||||
|
||||
splitted = ref.split(' ')
|
||||
base_urls_map = {}
|
||||
for service in Service.objects.select_related().select_subclasses():
|
||||
for url in service.get_base_urls():
|
||||
base_urls_map[url] = (type(service), service.pk)
|
||||
return base_urls_map
|
||||
|
||||
try:
|
||||
ou_slug, service_slug = splitted
|
||||
except ValueError:
|
||||
pass
|
||||
else:
|
||||
return Service.objects.filter(ou__slug=ou_slug, slug=service_slug).first()
|
||||
|
||||
try:
|
||||
(service_slug,) = splitted
|
||||
except ValueError:
|
||||
return None
|
||||
def clean_service(request):
|
||||
request.session.pop('sevice_type', None)
|
||||
request.session.pop('sevice_pk', None)
|
||||
|
||||
service = Service.objects.filter(ou__isnull=True, slug=service_slug).first()
|
||||
|
||||
def _set_session_service(session, service):
|
||||
if 'home_url' in session:
|
||||
del session['home_url']
|
||||
if service:
|
||||
return service
|
||||
try:
|
||||
return Service.objects.get(slug=service_slug)
|
||||
except (Service.DoesNotExist, Service.MultipleObjectsReturned):
|
||||
return None
|
||||
session['service_type'] = [type(service)._meta.app_label, type(service)._meta.model_name]
|
||||
session['service_pk'] = service.pk
|
||||
|
||||
|
||||
def get_service_from_request(request):
|
||||
service_ref = request.GET.get(SERVICE_FIELD_NAME)
|
||||
if service_ref and '\x00' not in service_ref:
|
||||
return get_service_from_ref(service_ref)
|
||||
return None
|
||||
def set_service(request, service):
|
||||
# do not set service on non document fetch (<script> tags, etc..)
|
||||
headers = request.headers
|
||||
if 'sec-fetch-dest' in headers and headers['sec-fetch-dest'] != 'document':
|
||||
return
|
||||
request._service = service
|
||||
_set_session_service(request.session, service)
|
||||
|
||||
|
||||
def get_service_from_session(request):
|
||||
session = getattr(request, 'session', None)
|
||||
if session and 'service_pk' in session:
|
||||
from authentic2.models import Service
|
||||
def set_home_url(request, url=None):
|
||||
if not url:
|
||||
from .misc import select_next_url
|
||||
|
||||
return Service.objects.get(pk=session['service_pk'])
|
||||
return None
|
||||
url = select_next_url(request, default=None)
|
||||
if not url or not urllib.parse.urlparse(url).netloc:
|
||||
return
|
||||
urls_map = _base_urls_map()
|
||||
for base_url, (Model, pk) in urls_map.items():
|
||||
if same_origin(base_url, url):
|
||||
set_service(request, Model.object.get(pk=pk))
|
||||
break
|
||||
else:
|
||||
clean_service(request)
|
||||
request.session['home_url'] = url
|
||||
|
||||
|
||||
def get_service_from_token(params):
|
||||
ref = params.get(SERVICE_FIELD_NAME)
|
||||
if not ref:
|
||||
return None
|
||||
return get_service_from_ref(ref)
|
||||
|
||||
|
||||
def set_service_ref(params, service):
|
||||
params[SERVICE_FIELD_NAME] = service_ref(service)
|
||||
def get_service(request):
|
||||
if not hasattr(request, '_service'):
|
||||
if 'service_type' in request.session and 'service_pk' in request.session:
|
||||
ServiceKlass = apps.get_app_config(request.session['service_type'][0]).get_model(
|
||||
request.session['service_type'][1]
|
||||
)
|
||||
request._service = ServiceKlass.objects.get(pk=request.session['service_pk'])
|
||||
else:
|
||||
request._service = None
|
||||
return request._service
|
||||
|
|
|
@ -63,7 +63,7 @@ from .utils import crypto
|
|||
from .utils import misc as utils_misc
|
||||
from .utils import switch_user as utils_switch_user
|
||||
from .utils.evaluate import make_condition_context
|
||||
from .utils.service import get_service_from_request, get_service_from_token, set_service_ref
|
||||
from .utils.service import get_service, set_home_url
|
||||
from .utils.view_decorators import enable_view_restriction
|
||||
|
||||
User = get_user_model()
|
||||
|
@ -71,7 +71,13 @@ User = get_user_model()
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class EditProfile(cbv.HookMixin, cbv.TemplateNamesMixin, UpdateView):
|
||||
class HomeURLMixin:
|
||||
def dispatch(self, request, *args, **kwargs):
|
||||
set_home_url(request)
|
||||
return super().dispatch(request, *args, **kwargs)
|
||||
|
||||
|
||||
class EditProfile(HomeURLMixin, cbv.HookMixin, cbv.TemplateNamesMixin, UpdateView):
|
||||
model = User
|
||||
template_names = ['profiles/edit_profile.html', 'authentic2/accounts_edit.html']
|
||||
title = _('Edit account data')
|
||||
|
@ -182,7 +188,7 @@ class EditRequired(EditProfile):
|
|||
edit_required_profile = login_required(EditRequired.as_view())
|
||||
|
||||
|
||||
class EmailChangeView(cbv.TemplateNamesMixin, FormView):
|
||||
class EmailChangeView(HomeURLMixin, cbv.TemplateNamesMixin, FormView):
|
||||
template_names = ['profiles/email_change.html', 'authentic2/change_email.html']
|
||||
title = _('Email Change')
|
||||
success_url = '..'
|
||||
|
@ -295,8 +301,6 @@ def login(request, template_name='authentic2/login.html', redirect_field_name=RE
|
|||
|
||||
redirect_to = request.GET.get(redirect_field_name)
|
||||
|
||||
service = get_service_from_request(request)
|
||||
|
||||
if not redirect_to or ' ' in redirect_to:
|
||||
redirect_to = settings.LOGIN_REDIRECT_URL
|
||||
# Heavier security check -- redirects to http://example.com should
|
||||
|
@ -311,7 +315,7 @@ def login(request, template_name='authentic2/login.html', redirect_field_name=RE
|
|||
|
||||
blocks = []
|
||||
|
||||
registration_url = utils_misc.get_registration_url(request, service=service)
|
||||
registration_url = utils_misc.get_registration_url(request)
|
||||
|
||||
context = {
|
||||
'cancel': app_settings.A2_LOGIN_DISPLAY_A_CANCEL_BUTTON and nonce is not None,
|
||||
|
@ -346,6 +350,7 @@ def login(request, template_name='authentic2/login.html', redirect_field_name=RE
|
|||
parameters = {'request': request, 'context': context}
|
||||
login_hint = set(request.session.get('login-hint', []))
|
||||
show_ctx = make_condition_context(request=request, login_hint=login_hint)
|
||||
service = get_service(request)
|
||||
if service:
|
||||
show_ctx['service_ou_slug'] = service.ou and service.ou.slug
|
||||
show_ctx['service_slug'] = service.slug
|
||||
|
@ -421,7 +426,12 @@ class Homepage(cbv.TemplateNamesMixin, TemplateView):
|
|||
template_names = ['idp/homepage.html', 'authentic2/homepage.html']
|
||||
|
||||
def dispatch(self, request, *args, **kwargs):
|
||||
if app_settings.A2_HOMEPAGE_URL:
|
||||
home_url = request.session.get('home_url')
|
||||
home_url = home_url or (
|
||||
request.user.is_authenticated and request.user and request.user.ou and request.user.ou.home_url
|
||||
)
|
||||
home_url = app_settings.A2_HOMEPAGE_URL
|
||||
if home_url:
|
||||
return utils_misc.redirect(request, app_settings.A2_HOMEPAGE_URL)
|
||||
return login_required(super().dispatch)(request, *args, **kwargs)
|
||||
|
||||
|
@ -435,7 +445,7 @@ class Homepage(cbv.TemplateNamesMixin, TemplateView):
|
|||
homepage = enable_view_restriction(Homepage.as_view())
|
||||
|
||||
|
||||
class ProfileView(cbv.TemplateNamesMixin, TemplateView):
|
||||
class ProfileView(HomeURLMixin, cbv.TemplateNamesMixin, TemplateView):
|
||||
template_names = ['idp/account_management.html', 'authentic2/accounts.html']
|
||||
title = _('Your account')
|
||||
|
||||
|
@ -889,7 +899,7 @@ class PasswordResetConfirmView(cbv.RedirectToNextURLViewMixin, FormView):
|
|||
password_reset_confirm = PasswordResetConfirmView.as_view()
|
||||
|
||||
|
||||
class BaseRegistrationView(FormView):
|
||||
class BaseRegistrationView(HomeURLMixin, FormView):
|
||||
form_class = registration_forms.RegistrationForm
|
||||
template_name = 'registration/registration_form.html'
|
||||
title = _('Registration')
|
||||
|
@ -912,6 +922,7 @@ class BaseRegistrationView(FormView):
|
|||
if 'ou' in self.token:
|
||||
self.ou = OU.objects.get(pk=self.token['ou'])
|
||||
self.next_url = self.token.pop(REDIRECT_FIELD_NAME, utils_misc.select_next_url(request, None))
|
||||
set_home_url(request, self.next_url)
|
||||
return super().dispatch(request, *args, **kwargs)
|
||||
|
||||
def form_valid(self, form):
|
||||
|
@ -978,11 +989,6 @@ class BaseRegistrationView(FormView):
|
|||
for field in form.cleaned_data:
|
||||
self.token[field] = form.cleaned_data[field]
|
||||
|
||||
# propagate service to the registration completion view
|
||||
service = get_service_from_request(self.request)
|
||||
if service:
|
||||
set_service_ref(self.token, service)
|
||||
|
||||
self.token.pop(REDIRECT_FIELD_NAME, None)
|
||||
self.token.pop('email', None)
|
||||
|
||||
|
@ -1066,8 +1072,7 @@ class RegistrationCompletionView(CreateView):
|
|||
if self.ou:
|
||||
self.email_is_unique |= self.ou.email_is_unique
|
||||
self.init_fields_labels_and_help_texts()
|
||||
# if registration is done during an SSO add the service to the registration event
|
||||
self.service = get_service_from_token(self.token)
|
||||
set_home_url(request, self.get_success_url())
|
||||
return super().dispatch(request, *args, **kwargs)
|
||||
|
||||
def init_fields_labels_and_help_texts(self):
|
||||
|
@ -1180,9 +1185,7 @@ class RegistrationCompletionView(CreateView):
|
|||
def get(self, request, *args, **kwargs):
|
||||
if len(self.users) == 1 and self.email_is_unique:
|
||||
# Found one user, EMAIL is unique, log her in
|
||||
utils_misc.simulate_authentication(
|
||||
request, self.users[0], method=self.authentication_method, service=self.service
|
||||
)
|
||||
utils_misc.simulate_authentication(request, self.users[0], method=self.authentication_method)
|
||||
return utils_misc.redirect(request, self.get_success_url())
|
||||
confirm_data = self.token.get('confirm_data', False)
|
||||
|
||||
|
@ -1220,9 +1223,7 @@ class RegistrationCompletionView(CreateView):
|
|||
uid = request.POST['uid']
|
||||
for user in self.users:
|
||||
if str(user.id) == uid:
|
||||
utils_misc.simulate_authentication(
|
||||
request, user, method=self.authentication_method, service=self.service
|
||||
)
|
||||
utils_misc.simulate_authentication(request, user, method=self.authentication_method)
|
||||
return utils_misc.redirect(request, self.get_success_url())
|
||||
return super().post(request, *args, **kwargs)
|
||||
|
||||
|
@ -1284,14 +1285,12 @@ class RegistrationCompletionView(CreateView):
|
|||
view=self,
|
||||
authentication_method=self.authentication_method,
|
||||
token=self.token,
|
||||
service=self.service and self.service.slug,
|
||||
service=get_service(request),
|
||||
)
|
||||
self.send_registration_success_email(user)
|
||||
|
||||
def registration_success(self, request, user):
|
||||
utils_misc.simulate_authentication(
|
||||
request, user, method=self.authentication_method, service=self.service
|
||||
)
|
||||
utils_misc.simulate_authentication(request, user, method=self.authentication_method)
|
||||
message_template = loader.get_template('authentic2/registration_success_message.html')
|
||||
messages.info(self.request, message_template.render(request=request))
|
||||
return utils_misc.redirect(request, self.get_success_url())
|
||||
|
@ -1319,7 +1318,7 @@ class RegistrationCompletionView(CreateView):
|
|||
registration_completion = RegistrationCompletionView.as_view()
|
||||
|
||||
|
||||
class AccountDeleteView(TemplateView):
|
||||
class AccountDeleteView(HomeURLMixin, TemplateView):
|
||||
template_name = 'authentic2/accounts_delete_request.html'
|
||||
title = _('Request account deletion')
|
||||
|
||||
|
@ -1407,7 +1406,7 @@ class RegistrationCompleteView(TemplateView):
|
|||
registration_complete = RegistrationCompleteView.as_view()
|
||||
|
||||
|
||||
class PasswordChangeView(DjPasswordChangeView):
|
||||
class PasswordChangeView(HomeURLMixin, DjPasswordChangeView):
|
||||
title = _('Password Change')
|
||||
do_not_call_in_templates = True
|
||||
|
||||
|
@ -1471,7 +1470,7 @@ class SuView(View):
|
|||
su = SuView.as_view()
|
||||
|
||||
|
||||
class Consents(ListView):
|
||||
class Consents(HomeURLMixin, ListView):
|
||||
template_name = 'authentic2/consents.html'
|
||||
title = _('Consent Management')
|
||||
model = OIDCAuthorization
|
||||
|
|
|
@ -42,7 +42,6 @@ from authentic2.utils import misc as utils_misc
|
|||
from authentic2.utils import views as utils_views
|
||||
from authentic2.utils.crypto import check_hmac_url, hash_chain, hmac_url
|
||||
from authentic2.utils.models import safe_get_or_create
|
||||
from authentic2.utils.service import get_service_from_ref, get_service_from_request, service_ref
|
||||
|
||||
from . import app_settings, models
|
||||
from .utils import (
|
||||
|
@ -69,7 +68,6 @@ class LoginOrLinkView(View):
|
|||
"""
|
||||
|
||||
_next_url = None
|
||||
service = None
|
||||
|
||||
@property
|
||||
def next_url(self):
|
||||
|
@ -114,7 +112,7 @@ class LoginOrLinkView(View):
|
|||
def handle_authorization_response(self, request, code, state):
|
||||
# check state signature and parse it
|
||||
try:
|
||||
state, self._next_url, self.service = self.decode_state(state)
|
||||
state, self._next_url = self.decode_state(state)
|
||||
except ValueError:
|
||||
return utils_misc.redirect(request, settings.LOGIN_REDIRECT_URL)
|
||||
|
||||
|
@ -186,10 +184,8 @@ class LoginOrLinkView(View):
|
|||
else:
|
||||
return self.login(request)
|
||||
|
||||
def encode_state(self, state, next_url, service):
|
||||
encoded_state = state + ' ' + self.next_url + ' '
|
||||
if service:
|
||||
encoded_state += service_ref(service)
|
||||
def encode_state(self, state, next_url):
|
||||
encoded_state = state + ' ' + self.next_url
|
||||
encoded_state += ' ' + hmac_url(settings.SECRET_KEY, encoded_state)
|
||||
return encoded_state
|
||||
|
||||
|
@ -197,32 +193,23 @@ class LoginOrLinkView(View):
|
|||
payload, signature = state.rsplit(' ', 1)
|
||||
if not check_hmac_url(settings.SECRET_KEY, payload, signature):
|
||||
raise ValueError
|
||||
# service_ref can be made of one or two parts
|
||||
try:
|
||||
state, next_url, service_ref = payload.split(' ')
|
||||
except ValueError:
|
||||
state, next_url, ou_slug, service_slug = payload.split(' ')
|
||||
service_ref = ou_slug + ' ' + service_slug
|
||||
service = get_service_from_ref(service_ref)
|
||||
return state, next_url, service
|
||||
state, next_url, *dummy = payload.split(' ')
|
||||
return state, next_url
|
||||
|
||||
def make_authorization_request(self, request):
|
||||
scope = ' '.join(set(['openid'] + app_settings.scopes))
|
||||
service = self.service or get_service_from_request(request)
|
||||
|
||||
nonce_seed, nonce, state = hash_chain(3)
|
||||
|
||||
# encode the target service and next_url in the state
|
||||
full_state = state + ' ' + self.next_url + ' '
|
||||
if service:
|
||||
full_state += service_ref(service)
|
||||
full_state += ' ' + hmac_url(settings.SECRET_KEY, full_state)
|
||||
params = {
|
||||
'client_id': app_settings.client_id,
|
||||
'scope': scope,
|
||||
'redirect_uri': self.redirect_uri,
|
||||
'response_type': 'code',
|
||||
'state': self.encode_state(state, self.next_url, service),
|
||||
'state': self.encode_state(state, self.next_url),
|
||||
'nonce': nonce,
|
||||
'acr_values': 'eidas1',
|
||||
}
|
||||
|
@ -340,7 +327,7 @@ class LoginOrLinkView(View):
|
|||
def finish_login(self, request, user, user_info, created):
|
||||
self.update_user_info(user, user_info)
|
||||
utils_views.check_cookie_works(request)
|
||||
utils_misc.login(request, user, 'france-connect', service=self.service)
|
||||
utils_misc.login(request, user, 'france-connect')
|
||||
|
||||
# keep id_token around for logout
|
||||
request.session['fc_id_token'] = self.id_token
|
||||
|
|
|
@ -36,6 +36,7 @@ from authentic2.utils.misc import (
|
|||
normalize_attribute_values,
|
||||
redirect,
|
||||
)
|
||||
from authentic2.utils.service import set_service
|
||||
from authentic2.utils.view_decorators import enable_view_restriction
|
||||
from authentic2.views import logout as logout_view
|
||||
from authentic2_idp_cas.constants import (
|
||||
|
@ -151,6 +152,7 @@ class LoginView(CasMixin, View):
|
|||
model = Service.objects.for_service(service)
|
||||
if not model:
|
||||
return self.failure(request, service, 'service unknown')
|
||||
set_service(request, model)
|
||||
if renew and gateway:
|
||||
return self.failure(request, service, 'renew and gateway cannot be requested at the same time')
|
||||
|
||||
|
@ -464,6 +466,7 @@ class LogoutView(View):
|
|||
if referrer:
|
||||
model = Service.objects.for_service(referrer)
|
||||
if model:
|
||||
set_service(request, model)
|
||||
return logout_view(request, next_url=next_url, check_referer=False, do_local=False)
|
||||
return redirect(request, next_url)
|
||||
|
||||
|
|
|
@ -41,6 +41,7 @@ from authentic2.a2_rbac.models import OrganizationalUnit
|
|||
from authentic2.decorators import setting_enabled
|
||||
from authentic2.exponential_retry_timeout import ExponentialRetryTimeout
|
||||
from authentic2.utils.misc import last_authentication_event, login_require, make_url, redirect
|
||||
from authentic2.utils.service import set_service
|
||||
from authentic2.utils.view_decorators import check_view_restriction
|
||||
from authentic2.views import logout as a2_logout
|
||||
|
||||
|
@ -246,6 +247,8 @@ def authorize(request, *args, **kwargs):
|
|||
client = get_client(client_id=client_id)
|
||||
if not client:
|
||||
raise InvalidRequest(_('Unknown client identifier: "%s"') % client_id)
|
||||
# define the current service
|
||||
set_service(request, client)
|
||||
try:
|
||||
client.validate_redirect_uri(redirect_uri)
|
||||
except ValueError:
|
||||
|
@ -333,7 +336,7 @@ def authorize_for_client(request, client, redirect_uri):
|
|||
params = {}
|
||||
if nonce is not None:
|
||||
params['nonce'] = nonce
|
||||
return login_require(request, params=params, service=client, login_hint=login_hint)
|
||||
return login_require(request, params=params, login_hint=login_hint)
|
||||
|
||||
# view restriction and passive SSO
|
||||
if hasattr(request, 'view_restriction_response'):
|
||||
|
@ -352,7 +355,7 @@ def authorize_for_client(request, client, redirect_uri):
|
|||
params = {}
|
||||
if nonce is not None:
|
||||
params['nonce'] = nonce
|
||||
return login_require(request, params=params, service=client, login_hint=login_hint)
|
||||
return login_require(request, params=params, login_hint=login_hint)
|
||||
|
||||
iat = now() # iat = issued at
|
||||
|
||||
|
@ -812,6 +815,7 @@ def logout(request):
|
|||
)
|
||||
for provider in providers:
|
||||
if post_logout_redirect_uri in provider.post_logout_redirect_uris.split():
|
||||
set_service(request, provider)
|
||||
break
|
||||
else:
|
||||
messages.warning(request, _('Invalid post logout URI'))
|
||||
|
|
|
@ -160,12 +160,16 @@ class FranceConnectMock:
|
|||
|
||||
|
||||
@pytest.fixture
|
||||
def franceconnect(settings, db):
|
||||
def service(db):
|
||||
return Service.objects.create(name='portail', slug='portail', ou=get_default_ou())
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def franceconnect(settings, service, db):
|
||||
settings.A2_FC_ENABLE = True
|
||||
settings.A2_FC_CLIENT_ID = CLIENT_ID
|
||||
settings.A2_FC_CLIENT_SECRET = CLIENT_SECRET
|
||||
|
||||
Service.objects.create(name='portail', slug='portail', ou=get_default_ou())
|
||||
mock_object = FranceConnectMock()
|
||||
with mock_object():
|
||||
yield mock_object
|
||||
|
|
|
@ -32,12 +32,12 @@ from authentic2.a2_rbac.models import OrganizationalUnit as OU
|
|||
from authentic2.a2_rbac.utils import get_default_ou
|
||||
from authentic2.apps.journal.models import Event
|
||||
from authentic2.custom_user.models import DeletedUser
|
||||
from authentic2.models import Attribute, Service
|
||||
from authentic2.models import Attribute
|
||||
from authentic2_auth_fc import models
|
||||
from authentic2_auth_fc.backends import FcBackend
|
||||
from authentic2_auth_fc.utils import requests_retry_session
|
||||
|
||||
from ..utils import get_link_from_mail, login
|
||||
from ..utils import get_link_from_mail, login, set_service
|
||||
|
||||
User = get_user_model()
|
||||
|
||||
|
@ -54,7 +54,7 @@ def test_fc_url_on_login(app, franceconnect):
|
|||
|
||||
|
||||
def test_retry_authorization_if_state_is_lost(settings, app, franceconnect, hooks):
|
||||
response = app.get('/fc/callback/?next=/idp/&service=default%20portail', status=302)
|
||||
response = app.get('/fc/callback/?next=/idp/', status=302)
|
||||
# clear fc-state cookie
|
||||
app.cookiejar.clear()
|
||||
response = franceconnect.handle_authorization(app, response.location, status=302)
|
||||
|
@ -81,26 +81,26 @@ def test_login_autorun(settings, app, franceconnect):
|
|||
assert response.location == reverse('fc-login-or-link')
|
||||
|
||||
|
||||
def test_create(settings, app, franceconnect, hooks):
|
||||
def test_create(settings, app, franceconnect, hooks, service):
|
||||
# test direct creation
|
||||
|
||||
response = app.get('/login/?service=portail&next=/idp/')
|
||||
set_service(app, service)
|
||||
response = app.get('/login/?next=/idp/')
|
||||
response = response.click(href='callback')
|
||||
|
||||
assert User.objects.count() == 0
|
||||
assert Event.objects.which_references(Service.objects.get()).count() == 0
|
||||
assert Event.objects.which_references(service).count() == 0
|
||||
response = franceconnect.handle_authorization(app, response.location, status=302)
|
||||
assert 'fc-state' not in app.cookies
|
||||
assert User.objects.count() == 1
|
||||
# check login for service=portail was registered
|
||||
assert Event.objects.which_references(Service.objects.get()).count() == 1
|
||||
assert Event.objects.which_references(service).count() == 1
|
||||
|
||||
user = User.objects.get()
|
||||
assert user.verified_attributes.first_name == 'Ÿuñe'
|
||||
assert user.verified_attributes.last_name == 'Frédérique'
|
||||
assert path(response.location) == '/idp/'
|
||||
assert hooks.event[1]['kwargs']['name'] == 'login'
|
||||
assert hooks.event[1]['kwargs']['service'] == 'portail'
|
||||
assert hooks.event[1]['kwargs']['service'] == service
|
||||
# we must be connected
|
||||
assert app.session['_auth_user_id']
|
||||
assert app.session.get_expire_at_browser_close()
|
||||
|
@ -130,7 +130,7 @@ def test_create_expired(settings, app, franceconnect, hooks):
|
|||
# test direct creation failure on an expired id_token
|
||||
franceconnect.exp = now() - datetime.timedelta(seconds=30)
|
||||
|
||||
response = app.get('/login/?service=portail&next=/idp/')
|
||||
response = app.get('/login/?next=/idp/')
|
||||
response = response.click(href='callback')
|
||||
|
||||
assert User.objects.count() == 0
|
||||
|
|
|
@ -965,6 +965,9 @@ def test_role_control_access(login_first, oidc_settings, oidc_client, simple_use
|
|||
def test_registration_service_slug(oidc_settings, app, simple_oidc_client, simple_user, hooks, mailoutbox):
|
||||
redirect_uri = simple_oidc_client.redirect_uris.split()[0]
|
||||
|
||||
simple_oidc_client.ou.home_url = 'https://portal/'
|
||||
simple_oidc_client.ou.save()
|
||||
|
||||
params = {
|
||||
'client_id': simple_oidc_client.client_id,
|
||||
'scope': 'openid profile email',
|
||||
|
@ -977,19 +980,18 @@ def test_registration_service_slug(oidc_settings, app, simple_oidc_client, simpl
|
|||
authorize_url = make_url('oidc-authorize', params=params)
|
||||
response = app.get(authorize_url)
|
||||
|
||||
location = urllib.parse.urlparse(response['Location'])
|
||||
query = urllib.parse.parse_qs(location.query)
|
||||
assert query['service'] == ['default client']
|
||||
response = response.follow().click('Register')
|
||||
location = urllib.parse.urlparse(response.request.url)
|
||||
query = urllib.parse.parse_qs(location.query)
|
||||
assert query['service'] == ['default client']
|
||||
|
||||
response.form.set('email', 'john.doe@example.com')
|
||||
response = response.form.submit()
|
||||
assert len(mailoutbox) == 1
|
||||
link = utils.get_link_from_mail(mailoutbox[0])
|
||||
response = app.get(link)
|
||||
body = response.pyquery('body')[0]
|
||||
assert body.attrib['data-home-ou-slug'] == 'default'
|
||||
assert body.attrib['data-home-ou-name'] == 'Default organizational unit'
|
||||
assert body.attrib['data-home-service-slug'] == 'client'
|
||||
assert body.attrib['data-home-service-name'] == 'client'
|
||||
assert body.attrib['data-home-url'] == 'https://portal/'
|
||||
response.form.set('first_name', 'John')
|
||||
response.form.set('last_name', 'Doe')
|
||||
response.form.set('password1', 'T0==toto')
|
||||
|
@ -999,11 +1001,11 @@ def test_registration_service_slug(oidc_settings, app, simple_oidc_client, simpl
|
|||
assert hooks.event[0]['kwargs']['service'].slug == 'client'
|
||||
|
||||
assert hooks.event[1]['kwargs']['name'] == 'registration'
|
||||
assert hooks.event[1]['kwargs']['service'] == 'client'
|
||||
assert hooks.event[1]['kwargs']['service'].slug == 'client'
|
||||
|
||||
assert hooks.event[2]['kwargs']['name'] == 'login'
|
||||
assert hooks.event[2]['kwargs']['how'] == 'email'
|
||||
assert hooks.event[2]['kwargs']['service'] == 'client'
|
||||
assert hooks.event[2]['kwargs']['service'].slug == 'client'
|
||||
|
||||
|
||||
def test_claim_default_value(oidc_settings, normal_oidc_client, simple_user, app):
|
||||
|
|
|
@ -0,0 +1,57 @@
|
|||
# authentic2 - versatile identity manager
|
||||
# Copyright (C) 2010-2022 Entr'ouvert
|
||||
#
|
||||
# This program is free software: you can redistribute it and/or modify it
|
||||
# under the terms of the GNU Affero General Public License as published
|
||||
# by the Free Software Foundation, either version 3 of the License, or
|
||||
# (at your option) any later version.
|
||||
#
|
||||
# This program is distributed in the hope that it will be useful,
|
||||
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
# GNU Affero General Public License for more details.
|
||||
#
|
||||
# You should have received a copy of the GNU Affero General Public License
|
||||
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
|
||||
from . import utils
|
||||
|
||||
|
||||
def test_home(app, settings, simple_user, service):
|
||||
from authentic2.a2_rbac.utils import get_default_ou
|
||||
from authentic2.models import Service
|
||||
|
||||
utils.set_service(app, service)
|
||||
|
||||
settings.LOGIN_REDIRECT_URL = 'https://portal1/'
|
||||
|
||||
resp = app.get('/login/')
|
||||
body = resp.pyquery('body')
|
||||
assert body.attr('data-home-url') == 'https://portal1/'
|
||||
assert body.attr('data-home-service-slug') == service.slug
|
||||
assert body.attr('data-home-service-name') == service.name
|
||||
assert body.attr('data-home-ou-slug') == service.ou.slug
|
||||
assert body.attr('data-home-ou-name') == service.ou.name
|
||||
|
||||
settings.A2_HOMEPAGE_URL = 'https://portal2/'
|
||||
resp = app.get('/login/')
|
||||
body = resp.pyquery('body')
|
||||
assert body.attr('data-home-url') == 'https://portal2/'
|
||||
|
||||
service.ou.home_url = 'https://portal3/'
|
||||
service.ou.save()
|
||||
resp = app.get('/login/')
|
||||
body = resp.pyquery('body')
|
||||
assert body.attr('data-home-url') == 'https://portal3/'
|
||||
|
||||
# if user comes back from a different service, the information is updated
|
||||
new_service = Service.objects.create(ou=get_default_ou(), slug='service2', name='Service2')
|
||||
utils.set_service(app, new_service)
|
||||
|
||||
resp = app.get('/login/')
|
||||
body = resp.pyquery('body')
|
||||
assert body.attr('data-home-url') == 'https://portal3/'
|
||||
assert body.attr('data-home-service-slug') == new_service.slug
|
||||
assert body.attr('data-home-service-name') == new_service.name
|
||||
assert body.attr('data-home-ou-slug') == new_service.ou.slug
|
||||
assert body.attr('data-home-ou-name') == new_service.ou.name
|
|
@ -33,7 +33,7 @@ from django.utils.encoding import force_bytes, force_str, force_text
|
|||
from django.utils.translation import gettext as _
|
||||
|
||||
from authentic2.a2_rbac.models import OrganizationalUnit, Role
|
||||
from authentic2.constants import NONCE_FIELD_NAME, SERVICE_FIELD_NAME
|
||||
from authentic2.constants import NONCE_FIELD_NAME
|
||||
from authentic2.custom_user.models import User
|
||||
from authentic2.idp.saml import saml2_endpoints
|
||||
from authentic2.idp.saml.saml2_endpoints import get_extensions, get_login_hints_extension
|
||||
|
@ -330,7 +330,6 @@ class Scenario:
|
|||
reverse('auth_login'),
|
||||
**{
|
||||
'nonce': '*',
|
||||
SERVICE_FIELD_NAME: 'default ' + self.sp.slug,
|
||||
REDIRECT_FIELD_NAME: make_url(
|
||||
'a2-idp-saml-continue', params={NONCE_FIELD_NAME: request_id}
|
||||
),
|
||||
|
|
|
@ -22,7 +22,7 @@ from django.contrib.auth import get_user_model
|
|||
from authentic2 import models
|
||||
from authentic2.utils.misc import get_token_login_url
|
||||
|
||||
from .utils import assert_event, login
|
||||
from .utils import assert_event, login, set_service
|
||||
|
||||
User = get_user_model()
|
||||
|
||||
|
@ -85,22 +85,22 @@ def test_show_condition(db, app, settings, caplog):
|
|||
assert len(caplog.records) == 1
|
||||
|
||||
|
||||
def test_show_condition_service(db, app, settings):
|
||||
def test_show_condition_service(db, rf, app, settings):
|
||||
portal = models.Service.objects.create(pk=1, name='Service', slug='portal')
|
||||
service = models.Service.objects.create(pk=2, name='Service', slug='service')
|
||||
settings.AUTH_FRONTENDS_KWARGS = {'password': {'show_condition': 'service_slug == \'portal\''}}
|
||||
response = app.get('/login/', params={})
|
||||
|
||||
response = app.get('/login/')
|
||||
assert 'name="login-password-submit"' not in response
|
||||
|
||||
# service doesn't exist
|
||||
response = app.get('/login/', params={'service': 'portal'})
|
||||
assert 'name="login-password-submit"' not in response
|
||||
set_service(app, portal)
|
||||
|
||||
# Create a service
|
||||
models.Service.objects.create(name='Service', slug='portal')
|
||||
response = app.get('/login/', params={'service': 'portal'})
|
||||
response = app.get('/login/')
|
||||
assert 'name="login-password-submit"' in response
|
||||
|
||||
models.Service.objects.create(name='Service', slug='service')
|
||||
response = app.get('/login/', params={'service': 'service'})
|
||||
set_service(app, service)
|
||||
|
||||
response = app.get('/login/')
|
||||
assert 'name="login-password-submit"' not in response
|
||||
|
||||
|
||||
|
@ -251,29 +251,31 @@ def test_ou_selector(app, settings, simple_user, ou1, ou2, user_ou1, role_ou1):
|
|||
response = app.get('/login/')
|
||||
assert response.pyquery.find('select#id_ou option[selected]')[0].text == 'Default organizational unit'
|
||||
|
||||
set_service(app, service)
|
||||
# service is specified but not access-control is defined, default for user is selected
|
||||
response = app.get('/login/?service=service')
|
||||
response = app.get('/login/')
|
||||
assert response.pyquery.find('select#id_ou option[selected]')[0].text == 'Default organizational unit'
|
||||
|
||||
# service is specified, access control is defined but role is empty, default for user is selected
|
||||
service.authorized_roles.through.objects.create(service=service, role=role_ou1)
|
||||
response = app.get('/login/?service=service')
|
||||
response = app.get('/login/')
|
||||
assert response.pyquery.find('select#id_ou option[selected]')[0].text == 'Default organizational unit'
|
||||
|
||||
# user is added to role_ou1, default for user is still selected
|
||||
user_ou1.roles.add(role_ou1)
|
||||
response = app.get('/login/?service=service')
|
||||
response = app.get('/login/')
|
||||
assert response.pyquery.find('select#id_ou option[selected]')[0].text == 'Default organizational unit'
|
||||
|
||||
# Clear cookies, OU1 is selected
|
||||
app.cookiejar.clear()
|
||||
response = app.get('/login/?service=service')
|
||||
set_service(app, service)
|
||||
response = app.get('/login/')
|
||||
assert response.pyquery.find('select#id_ou option[selected]')[0].text == 'OU1'
|
||||
|
||||
# if we change the user's ou, then default selected OU changes
|
||||
user_ou1.ou = ou2
|
||||
user_ou1.save()
|
||||
response = app.get('/login/?service=service')
|
||||
response = app.get('/login/')
|
||||
assert response.pyquery.find('select#id_ou option[selected]')[0].text == 'OU2'
|
||||
|
||||
|
||||
|
|
|
@ -15,10 +15,7 @@
|
|||
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
|
||||
import pytest
|
||||
from django.urls import reverse
|
||||
|
||||
from authentic2.a2_rbac.utils import get_default_ou
|
||||
from authentic2.models import Service
|
||||
from authentic2.utils.template import Template, TemplateError
|
||||
|
||||
pytestmark = pytest.mark.django_db
|
||||
|
@ -111,24 +108,3 @@ def test_render_template_missing_variable():
|
|||
with pytest.raises(TemplateError) as raised:
|
||||
template.render(context=context)
|
||||
assert 'missing template variable' in raised
|
||||
|
||||
|
||||
def test_service_in_template(app, simple_user, service):
|
||||
resp = app.get(reverse('auth_login') + '?service=%s' % service.slug)
|
||||
|
||||
assert resp.pyquery('body').attr('data-service-slug') == service.slug
|
||||
assert resp.pyquery('body').attr('data-service-name') == service.name
|
||||
|
||||
resp.form.set('username', simple_user.username)
|
||||
resp.form.set('password', simple_user.username)
|
||||
resp.form.submit(name='login-password-submit')
|
||||
|
||||
resp = app.get(reverse('account_management'))
|
||||
assert resp.pyquery('body').attr('data-service-slug') == service.slug
|
||||
assert resp.pyquery('body').attr('data-service-name') == service.name
|
||||
|
||||
# if user comes back from a different service, the information is updated
|
||||
new_service = Service.objects.create(ou=get_default_ou(), slug='service2', name='Service2')
|
||||
resp = app.get(reverse('account_management') + '?service=%s' % new_service.slug)
|
||||
assert resp.pyquery('body').attr('data-service-slug') == new_service.slug
|
||||
assert resp.pyquery('body').attr('data-service-name') == new_service.name
|
||||
|
|
|
@ -307,3 +307,21 @@ def assert_event(event_type_name, user=None, session=None, service=None, target_
|
|||
)
|
||||
elif data and count > 1:
|
||||
assert qs.filter(**{'data__' + k: v for k, v in data.items()}).count() == 1
|
||||
|
||||
|
||||
def set_service(app, service):
|
||||
from importlib import import_module
|
||||
|
||||
from django.conf import settings
|
||||
|
||||
from authentic2.utils.service import _set_session_service
|
||||
|
||||
engine = import_module(settings.SESSION_ENGINE)
|
||||
if app.session == {}:
|
||||
session = engine.SessionStore()
|
||||
else:
|
||||
session = app.session
|
||||
_set_session_service(session, service)
|
||||
session.save()
|
||||
if app.session == {}:
|
||||
app.set_cookie(settings.SESSION_COOKIE_NAME, session.session_key)
|
||||
|
|
Loading…
Reference in New Issue