300 lines
10 KiB
Python
300 lines
10 KiB
Python
# authentic2 - versatile identity manager
|
|
# Copyright (C) 2010-2020 Entr'ouvert
|
|
#
|
|
# This program is free software: you can redistribute it and/or modify it
|
|
# under the terms of the GNU Affero General Public License as published
|
|
# by the Free Software Foundation, either version 3 of the License, or
|
|
# (at your option) any later version.
|
|
#
|
|
# This program is distributed in the hope that it will be useful,
|
|
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
|
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
|
# GNU Affero General Public License for more details.
|
|
#
|
|
# You should have received a copy of the GNU Affero General Public License
|
|
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
|
# authentic2
|
|
|
|
import base64
|
|
import re
|
|
import socket
|
|
import urllib.parse
|
|
from contextlib import closing, contextmanager
|
|
|
|
import httmock
|
|
from django.core.management import call_command as django_call_command
|
|
from django.shortcuts import resolve_url
|
|
from django.test import TestCase
|
|
from django.urls import reverse
|
|
from django.utils.encoding import force_text, iri_to_uri
|
|
from lxml import etree
|
|
|
|
from authentic2 import models, utils
|
|
from authentic2.apps.journal.models import Event
|
|
|
|
|
|
def login(app, user, path=None, password=None, remember_me=None, args=None, kwargs=None, fail=False):
    """Log ``user`` in through the password login form using the webtest ``app``.

    :param app: webtest application used to drive the browser session.
    :param user: a user object (its ``username``, optional ``clear_password``
        and optional ``id`` attributes are used) or a plain username string.
    :param path: optional URL or URL name (resolved with ``resolve_url`` using
        ``args``/``kwargs``); it must answer with a 302 leading to the login
        page, and after a successful login the user must land back on it.
    :param password: password to submit; when falsy, ``user.clear_password``
        is used if present, otherwise the username itself.
    :param remember_me: when not None, the value checked into the
        ``remember_me`` field.
    :param fail: when true, assert that the login is rejected (HTTP 200 and no
        authenticated user in the session) instead of accepted.
    :return: the last response (the followed redirect on success).
    """
    if not path:
        login_page = app.get(reverse('auth_login'))
    else:
        path = resolve_url(path, *(args or []), **(kwargs or {}))
        login_page = app.get(path, status=302).maybe_follow()
    assert login_page.request.path == reverse('auth_login')

    username = getattr(user, 'username', user)
    form = login_page.form
    form.set('username', username)
    # password is supposed to be the same as username, unless the user object
    # carries an explicit clear_password
    form.set('password', password or getattr(user, 'clear_password', username))
    if remember_me is not None:
        form.set('remember_me', bool(remember_me))
    response = form.submit(name='login-password-submit')

    if fail:
        assert response.status_code == 200
        assert '_auth_user_id' not in app.session
        return response

    response = response.follow()
    expected_path = path if path else reverse('auth_homepage')
    assert response.request.path == expected_path
    assert '_auth_user_id' in app.session
    assert not hasattr(user, 'id') or app.session['_auth_user_id'] == str(user.id)
    return response
|
|
|
|
|
|
def logout(app):
    """Log the current user out of the webtest ``app`` and return the final response.

    Asserts that a user is authenticated before, and no longer after, the
    logout. The logout page's form is submitted; when the resulting page
    contains a ``continue-link``, its "Continue logout" link is followed too.
    """
    assert '_auth_user_id' in app.session
    logout_page = app.get(reverse('auth_logout')).maybe_follow()
    response = logout_page.form.submit().maybe_follow()
    # an intermediate page may still need to be confirmed
    if 'continue-link' in response.text:
        response = response.click('Continue logout').maybe_follow()
    assert '_auth_user_id' not in app.session
    return response
|
|
|
|
|
|
def basic_authorization_header(user, password=None):
    """Return an HTTP Basic ``Authorization`` header mapping for ``user``.

    :param user: object with a ``username`` attribute.
    :param password: password to encode; when falsy the username is used as
        the password.
    :return: ``{'Authorization': 'Basic <base64("username:password")>'}``
        where the credentials are UTF-8 encoded before base64 encoding.
    """
    cred = '%s:%s' % (user.username, password or user.username)
    # base64 output is pure ASCII, so a plain decode replaces Django's
    # deprecated force_text() (removed in Django 4.0)
    b64_cred = base64.b64encode(cred.encode('utf-8')).decode('ascii')
    return {'Authorization': 'Basic %s' % b64_cred}
|
|
|
|
|
|
def get_response_form(response, form='form'):
    """Return the value stored under ``form`` in the first template context
    of ``response`` that defines it, or None when no context does."""
    matches = (context[form] for context in list(response.context) if form in context)
    return next(matches, None)
|
|
|
|
|
|
def assert_equals_url(url1, url2, **kwargs):
    """Assert that ``url1`` equals ``url2`` once ``kwargs`` are added to the
    query string of ``url2``.

    Scheme, netloc, path and fragment must match exactly. Query strings are
    compared parameter by parameter, ignoring parameter order and the order
    of repeated values (each parameter's values are compared as a set).

    The value ``'*'`` is a wildcard: for such a parameter only its presence in
    ``url1`` is checked, not its value.
    """
    splitted1 = urllib.parse.urlsplit(iri_to_uri(utils.make_url(url1, params=None)))
    splitted2 = urllib.parse.urlsplit(iri_to_uri(utils.make_url(url2, params=kwargs)))
    query_index = 3  # position of the query component in a SplitResult
    for position, (actual, expected) in enumerate(zip(splitted1, splitted2)):
        if position == query_index:
            actual_params = {
                key: set(values)
                for key, values in urllib.parse.parse_qs(actual, keep_blank_values=True).items()
            }
            expected_params = {}
            for key, values in urllib.parse.parse_qs(expected, keep_blank_values=True).items():
                # a wildcard takes whatever url1 has (or stays unmatched if absent)
                expected_params[key] = actual_params.get(key, values) if values == ['*'] else set(values)
            actual, expected = actual_params, expected_params
        assert actual == expected, "URLs are not equal: %s != %s" % (splitted1, splitted2)
|
|
|
|
|
|
def assert_redirects_complex(response, expected_url, **kwargs):
    """Assert ``response`` is a 302 redirect to ``expected_url`` plus query ``kwargs``.

    When ``expected_url`` has no scheme or netloc, those of the actual
    redirect URL are used, so relative expectations can be given. The final
    comparison is delegated to ``assert_equals_url``.
    """
    assert response.status_code == 302, 'code should be 302'
    actual = urllib.parse.urlsplit(response.url)
    expected = urllib.parse.urlsplit(expected_url)
    expected = expected._replace(
        scheme=expected.scheme or actual.scheme,
        netloc=expected.netloc or actual.netloc,
    )
    assert_equals_url(response['Location'], urllib.parse.urlunsplit(expected), **kwargs)
|
|
|
|
|
|
def assert_xpath_constraints(xml, constraints, namespaces):
    """Assert that XPath expressions evaluated on ``xml`` satisfy ``constraints``.

    :param xml: XML document accepted by ``etree.fromstring``, or an object
        with a ``content`` attribute holding it (e.g. an HTTP response).
    :param constraints: iterable of ``(xpath, content)`` pairs. Each xpath must
        select at least one result. Each result is reduced to its ``text``
        when it has one, otherwise used as is, and checked against ``content``:

        - ``str``: every result must equal it;
        - ``set``: the set of results must equal it;
        - ``list``: the ordered list of results must equal it;
        - object with a ``pattern`` attribute (compiled regexp): every result
          must satisfy ``content.match(result)``.
    :param namespaces: prefix-to-URI mapping passed to ``doc.xpath``.
    :raises AssertionError: when a constraint is not satisfied.
    :raises NotImplementedError: for any other type of ``content``.
    """
    if hasattr(xml, 'content'):
        xml = xml.content
    doc = etree.fromstring(xml)
    for xpath, content in constraints:
        nodes = doc.xpath(xpath, namespaces=namespaces)
        assert len(nodes) > 0, 'xpath %s not found' % xpath
        if isinstance(content, str):
            for node in nodes:
                if hasattr(node, 'text'):
                    assert node.text == content, 'xpath %s does not contain %s but %s' % (
                        xpath,
                        content,
                        node.text,
                    )
                else:
                    assert node == content, 'xpath %s does not contain %s but %s' % (xpath, content, node)
        else:
            values = [node.text if hasattr(node, 'text') else node for node in nodes]
            if isinstance(content, set):
                assert set(values) == content, 'xpath %s values %s != %s' % (xpath, set(values), content)
            elif isinstance(content, list):
                assert values == content, 'xpath %s values %s != %s' % (xpath, values, content)
            elif hasattr(content, 'pattern'):
                for value in values:
                    # an element without text yields None; report it as a
                    # mismatch instead of letting match() raise TypeError
                    assert value is not None and content.match(value), 'xpath %s does not match regexp %s' % (
                        xpath,
                        content.pattern,
                    )
            else:
                raise NotImplementedError(
                    'comparing xpath result to type %s: %r is not implemented' % (type(content), content)
                )
|
|
|
|
|
|
class Authentic2TestCase(TestCase):
    """Django ``TestCase`` exposing the module-level URL, redirect and XPath
    assertion helpers as unittest-style camelCase methods."""

    def assertEqualsURL(self, url1, url2, **kwargs):
        """Assert ``url1`` equals ``url2`` with ``kwargs`` added to its query
        string; see ``assert_equals_url``."""
        return assert_equals_url(url1, url2, **kwargs)

    def assertRedirectsComplex(self, response, expected_url, **kwargs):
        """Assert ``response`` redirects to ``expected_url`` with query
        ``kwargs``; see ``assert_redirects_complex``."""
        return assert_redirects_complex(response, expected_url, **kwargs)

    def assertXPathConstraints(self, xml, constraints, namespaces):
        """Assert XPath ``constraints`` hold on ``xml``; see
        ``assert_xpath_constraints``."""
        return assert_xpath_constraints(xml, constraints, namespaces)
|
|
|
|
|
|
@contextmanager
def check_log(caplog, message, levelname=None):
    """Context manager asserting that a log record emitted inside the ``with``
    body contains ``message``.

    :param caplog: log-capture object exposing a ``records`` list (e.g.
        pytest's ``caplog`` fixture); records present on entry are ignored.
    :param message: substring to look for in ``record.message``.
    :param levelname: when given, only records with this ``levelname`` count.
    """
    first_new = len(caplog.records)
    yield
    found = False
    for record in caplog.records[first_new:]:
        if levelname and record.levelname != levelname:
            continue
        if message in record.message:
            found = True
            break
    assert found, ('%r not found in log records' % message)
|
|
|
|
|
|
def get_links_from_mail(mail):
    '''Extract http(s) links from the body of a mail sent by Django.

    A link runs from "http://" or "https://" up to the next whitespace
    character (space, tab, carriage return, newline...), so bodies using CRLF
    line endings do not leave a trailing carriage return in the links.
    '''
    return re.findall(r'https?://\S*', mail.body)
|
|
|
|
|
|
def get_link_from_mail(mail):
    '''Extract the first and only link from this mail.

    Raises AssertionError when the mail body contains no link, or more than
    one (as found by get_links_from_mail).
    '''
    links = get_links_from_mail(mail)
    assert links, 'there is no link in this mail'
    assert len(links) == 1, 'there is more than one link in this mail: %s' % links
    return links[0]
|
|
|
|
|
|
def saml_sp_metadata(base_url):
    """Return SAML 2.0 service-provider metadata XML for an SP rooted at ``base_url``.

    The entity ID is ``<base_url>/``. The HTTP-POST assertion consumer
    service is ``<base_url>/sso/POST`` (index 0, default) and the
    HTTP-Artifact one is ``<base_url>/mellon/artifactResponse`` (index 1).
    """
    # NOTE(review): the SingleLogoutService Location is hard-coded instead of
    # being derived from base_url like the other endpoints — confirm intended.
    template = '''<?xml version="1.0" encoding="UTF-8" standalone="yes"?>
<EntityDescriptor
entityID="{base_url}/"
xmlns="urn:oasis:names:tc:SAML:2.0:metadata">
<SPSSODescriptor
AuthnRequestsSigned="true"
WantAssertionsSigned="true"
protocolSupportEnumeration="urn:oasis:names:tc:SAML:2.0:protocol">
<SingleLogoutService
Binding="urn:oasis:names:tc:SAML:2.0:bindings:HTTP-Redirect"
Location="https://files.entrouvert.org/mellon/logout" />
<AssertionConsumerService
index="0"
isDefault="true"
Binding="urn:oasis:names:tc:SAML:2.0:bindings:HTTP-POST"
Location="{base_url}/sso/POST" />
<AssertionConsumerService
index="1"
Binding="urn:oasis:names:tc:SAML:2.0:bindings:HTTP-Artifact"
Location="{base_url}/mellon/artifactResponse" />
</SPSSODescriptor>
</EntityDescriptor>'''
    return template.format(base_url=base_url)
|
|
|
|
|
|
def find_free_tcp_port():
    """Return a TCP port number that was free at the time of the call.

    An IPv4 TCP socket is bound to port 0 so the operating system assigns an
    unused port, whose number is returned after the socket is closed. Another
    process may take the port before the caller binds it: best effort only.
    """
    with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s:
        # SO_REUSEADDR only influences bind() when set before it
        s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        s.bind(('', 0))
        return s.getsockname()[1]
|
|
|
|
|
|
def request_select2(app, response, term='', get_kwargs=None):
    """Query the select2 AJAX endpoint of the first ``<select>`` in ``response``.

    The endpoint URL and field id are read from the select's
    ``data-ajax--url`` and ``data-field_id`` attributes, and requested with
    ``field_id`` and ``term`` as query parameters plus any extra ``get_kwargs``
    passed to ``app.get``.

    :return: the decoded JSON when the answer is ``application/json``,
        otherwise the raw response object.
    """
    select = response.pyquery('select')[0]
    url = select.attrib['data-ajax--url']
    field_id = select.attrib['data-field_id']
    extra_kwargs = get_kwargs or {}
    answer = app.get(url, params={'field_id': field_id, 'term': term}, **extra_kwargs)
    is_json = answer['content-type'] == 'application/json'
    return answer.json if is_json else answer
|
|
|
|
|
|
@contextmanager
def run_on_commit_hooks():
    """Run the pending ``on_commit`` callbacks of the default connection after
    the ``with`` body completes.

    On normal exit of the body, the callbacks queued on
    ``connection.run_on_commit`` are detached (the queue is reset to an empty
    list) and called in registration order. Callbacks registered while they
    run go to the fresh queue and are not called here. If the body raises,
    nothing is run. Presumably needed because tests run inside a transaction
    that is never committed — TODO confirm.
    """
    yield

    from django.db import connection

    pending = connection.run_on_commit
    connection.run_on_commit = []
    for entry in pending:
        # entries are (sids, func) on older Django releases and
        # (sids, func, robust) since Django 4.2: only the callable is needed
        func = entry[1]
        func()
|
|
|
|
|
|
def call_command(*args, **kwargs):
    """Run Django's ``call_command`` with the same arguments, then fire the
    ``on_commit`` callbacks it queued (see ``run_on_commit_hooks``).

    :return: whatever Django's ``call_command`` returned.
    """
    with run_on_commit_hooks():
        result = django_call_command(*args, **kwargs)
    return result
|
|
|
|
|
|
def text_content(node):
    """Return the concatenated text of ``node`` and all its descendants, in
    document order (like libxml's xmlNodeGetContent); '' for None."""
    if node is None:
        return ''
    return ''.join(node.itertext())
|
|
|
|
|
|
def assert_event(event_type_name, user=None, session=None, service=None, **data):
    """Assert that exactly one journal event of type ``event_type_name`` matches.

    :param user: when given, the event must belong to this user; otherwise it
        must have no user.
    :param session: when given, the event's session must be
        ``session.session_key``; otherwise it must have no session.
    :param service: when given, the event must reference this service;
        otherwise events referencing a ``models.Service`` are excluded.
    :param data: each key/value must equal ``event.data.get(key)``.
    """
    qs = Event.objects.filter(type__name=event_type_name)
    if user:
        qs = qs.filter(user=user)
    else:
        qs = qs.filter(user__isnull=True)
    if session:
        qs = qs.filter(session=session.session_key)
    else:
        qs = qs.filter(session__isnull=True)
    if service:
        qs = qs.which_references(service)
    else:
        # presumably filters out events referencing any Service instance
        qs = qs.exclude(qs._which_references_query(models.Service))

    count = qs.count()
    assert count == 1, 'expected exactly one %r event, found %d' % (event_type_name, count)

    if data:
        event = qs.get()
        assert event.data, 'no event.data, should be %s' % data
        for key, value in data.items():
            assert event.data.get(key) == value, 'event.data[%s] != data[%s] (%s != %s)' % (
                key,
                key,
                event.data.get(key),
                value,
            )
|
|
|
|
|
|
@httmock.HTTMock
@httmock.urlmatch()
def norequest(request, url):
    """HTTMock context (``with norequest: ...``) failing on any HTTP request.

    The catch-all ``urlmatch()`` handler raises for every intercepted request.
    An explicit raise is used instead of ``assert False`` so the check
    survives ``python -O``.
    """
    raise AssertionError('no request should be done')
|