diff --git a/mellon/utils.py b/mellon/utils.py index a076c05..ce965b4 100644 --- a/mellon/utils.py +++ b/mellon/utils.py @@ -8,6 +8,7 @@ from django.core.urlresolvers import reverse from django.template.loader import render_to_string from django.utils.timezone import make_aware, now, make_naive, is_aware, get_default_timezone from django.conf import settings +from django.utils.six.moves.urllib.parse import urlparse import lasso from . import app_settings @@ -197,3 +198,14 @@ def create_logout(request): def is_nonnull(s): return not '\x00' in s + + +def same_origin(url1, url2): + """ + Checks if two URLs are 'same-origin' + """ + p1, p2 = urlparse(url1), urlparse(url2) + try: + return (p1.scheme, p1.hostname, p1.port) == (p2.scheme, p2.hostname, p2.port) + except ValueError: + return False diff --git a/mellon/views.py b/mellon/views.py index b302fb0..a48012d 100644 --- a/mellon/views.py +++ b/mellon/views.py @@ -9,7 +9,7 @@ from django.contrib import auth from django.conf import settings from django.views.decorators.csrf import csrf_exempt from django.shortcuts import render, redirect, resolve_url -from django.utils.http import same_origin, urlencode +from django.utils.http import urlencode from . import app_settings @@ -26,7 +26,7 @@ class LogMixin(object): class LoginView(LogMixin, View): def get_idp(self, request): - entity_id = request.REQUEST.get('entityID') + entity_id = request.POST.get('entityID') or request.GET.get('entityID') if not entity_id: for idp in utils.get_idps(): return idp @@ -315,7 +315,7 @@ class LogoutView(LogMixin, View): next_url = resolve_url(settings.LOGIN_REDIRECT_URL) next_url = request.GET.get('next') or next_url referer = request.META.get('HTTP_REFERER') - if not referer or same_origin(referer, request.build_absolute_uri()): + if not referer or utils.same_origin(referer, request.build_absolute_uri()): if request.user.is_authenticated(): try: issuer = request.session.get('mellon_session', {}).get('issuer') @@ -357,7 +357,7 @@ class LogoutView(LogMixin, View): logout.processResponseMsg(request.META['QUERY_STRING']) except lasso.Error, e: self.log.error('unable to process a logout response %r', e) - if logout.msgRelayState and same_origin(logout.msgRelayState, request.build_absolute_uri()): + if logout.msgRelayState and utils.same_origin(logout.msgRelayState, request.build_absolute_uri()): return redirect(logout.msgRelayState) return redirect(next_url) diff --git a/tests/test_utils.py b/tests/test_utils.py index 61b715d..6cc907b 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -164,7 +164,10 @@ def test_iso8601_to_datetime(private_settings): import pytz private_settings.TIME_ZONE = 'UTC' + if hasattr(django.utils.timezone.get_default_timezone, 'cache_clear'): + django.utils.timezone.get_default_timezone.cache_clear() django.utils.timezone._localtime = None + private_settings.USE_TZ = False # UTC ISO8601 -> naive datetime UTC assert iso8601_to_datetime('2010-10-01T10:10:34Z') == datetime.datetime( 2010, 10, 01, 10, 10, 34)