utils: use same_origin() from authentic2 (#69740)

This commit is contained in:
Benjamin Dauvergne 2022-09-30 00:05:15 +02:00
parent e9008debf5
commit 43ce1d8141
2 changed files with 57 additions and 8 deletions

View File

@ -245,18 +245,67 @@ def is_nonnull(s):
return '\x00' not in s
PROTOCOLS_TO_PORT = {
'http': '80',
'https': '443',
}
def netloc_to_host_port(netloc):
if not netloc:
return None, None
splitted = netloc.split(':', 1)
if len(splitted) > 1:
return splitted[0], splitted[1]
return splitted[0], None
def same_domain(domain1, domain2):
if domain1 == domain2:
return True
if not domain1 or not domain2:
return False
if domain2.startswith('.'):
# p1 is a sub-domain or the base domain
if domain1.endswith(domain2) or domain1 == domain2[1:]:
return True
return False
def same_origin(url1, url2):
"""
Checks if two URLs are 'same-origin'
"""Checks if both URL use the same domain. It understands domain patterns on url2, i.e. .example.com
matches www.example.com.
If not scheme is given in url2, scheme compare is skipped.
If not scheme and not port are given, port compare is skipped.
The last two rules allow authorizing complete domains easily.
"""
p1, p2 = urlparse(url1), urlparse(url2)
if url1.startswith('/') or url2.startswith('/'):
p1_host, p1_port = netloc_to_host_port(p1.netloc)
p2_host, p2_port = netloc_to_host_port(p2.netloc)
# url2 is relative, always same domain
if url2.startswith('/') and not url2.startswith('//'):
return True
try:
return (p1.scheme, p1.hostname, p1.port) == (p2.scheme, p2.hostname, p2.port)
except ValueError:
if p2.scheme and p1.scheme != p2.scheme:
return False
if not same_domain(p1_host, p2_host):
return False
try:
if (p2_port or (p1_port and p2.scheme)) and (
(p1_port or PROTOCOLS_TO_PORT[p1.scheme]) != (p2_port or PROTOCOLS_TO_PORT[p2.scheme])
):
return False
except (ValueError, KeyError):
return False
return True
def get_status_codes_and_message(profile):
assert profile, 'missing lasso.Profile'

View File

@ -84,7 +84,7 @@ def check_next_url(request, next_url):
except UnicodeError:
log.warning('next parameter ignored, as is\'s not an ASCII string')
return
if not utils.same_origin(next_url, request.build_absolute_uri()):
if not utils.same_origin(request.build_absolute_uri(), next_url):
log.warning('next parameter ignored as it is not of the same origin')
return
return next_url
@ -725,7 +725,7 @@ class LogoutView(ProfileMixin, LogMixin, View):
'''Launch a logout request to the identity provider'''
next_url = request.GET.get(REDIRECT_FIELD_NAME)
referer = request.headers.get('Referer')
if not referer or utils.same_origin(referer, request.build_absolute_uri()):
if not referer or utils.same_origin(request.build_absolute_uri(), referer):
if hasattr(request, 'user') and request.user.is_authenticated:
logout = None
try: