authentic/tests/utils.py

280 lines
10 KiB
Python

# authentic2 - versatile identity manager
# Copyright (C) 2010-2020 Entr'ouvert
#
# This program is free software: you can redistribute it and/or modify it
# under the terms of the GNU Affero General Public License as published
# by the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
# authentic2
import re
import base64
import socket
from contextlib import contextmanager, closing
from lxml import etree
from django.core.management import call_command as django_call_command
from django.test import TestCase
from django.urls import reverse
from django.utils.encoding import iri_to_uri, force_text
from django.shortcuts import resolve_url
from django.utils import six
from django.utils.six.moves.urllib import parse as urlparse
from authentic2 import utils, models
from authentic2.apps.journal.models import Event
def login(app, user, path=None, password=None, remember_me=None, args=None, kwargs=None, fail=False):
    """Log ``user`` in through the password login form and check the outcome.

    :param app: webtest-style test application (uses ``get`` and ``session``).
    :param user: user object (``username``, optionally ``clear_password`` and
        ``id``) or a plain username string.
    :param path: optional URL or URL name (resolved with ``args``/``kwargs``)
        that must redirect (302) to the login page; on success the user must
        land back on it. Without it the login page is fetched directly and the
        user must land on ``auth_homepage``.
    :param password: password to type; defaults to ``user.clear_password``
        when present, else to the username.
    :param remember_me: when not ``None``, value given to the ``remember_me``
        checkbox.
    :param fail: when true, expect the login to be rejected (HTTP 200 and no
        authenticated user in the session).
    :returns: the last response obtained.
    """
    login_url = reverse('auth_login')
    if not path:
        login_page = app.get(login_url)
    else:
        path = resolve_url(path, *(args or []), **(kwargs or {}))
        login_page = app.get(path, status=302).maybe_follow()
    assert login_page.request.path == login_url

    username = getattr(user, 'username', user)
    login_form = login_page.form
    login_form.set('username', username)
    # test convention: unless told otherwise, the password equals the username
    login_form.set('password', password or getattr(user, 'clear_password', username))
    if remember_me is not None:
        login_form.set('remember_me', bool(remember_me))
    response = login_form.submit(name='login-password-submit')

    if fail:
        assert response.status_code == 200
        assert '_auth_user_id' not in app.session
        return response

    response = response.follow()
    landing_path = path if path else reverse('auth_homepage')
    assert response.request.path == landing_path
    assert '_auth_user_id' in app.session
    assert not hasattr(user, 'id') or (app.session['_auth_user_id'] == str(user.id))
    return response
def logout(app):
    """Log the current user out through the logout confirmation form.

    Requires an authenticated session on entry. It submits the confirmation
    form, and when the resulting page contains ``continue-link`` it clicks
    the "Continue logout" link. It then asserts that the session no longer
    holds a user.

    :returns: the last response obtained.
    """
    assert '_auth_user_id' in app.session
    confirmation = app.get(reverse('auth_logout')).maybe_follow()
    result = confirmation.form.submit().maybe_follow()
    needs_continue = 'continue-link' in result.text
    if needs_continue:
        result = result.click('Continue logout').maybe_follow()
    assert '_auth_user_id' not in app.session
    return result
def basic_authorization_header(user, password=None):
    """Build an HTTP Basic ``Authorization`` header for ``user``.

    :param user: object with a ``username`` attribute, or a plain username
        string.
    :param password: password to encode; defaults to the username (test
        convention).
    :returns: a one-entry dict mapping ``'Authorization'`` to
        ``'Basic <base64("username:password")>'``.
    """
    username = user.username if hasattr(user, 'username') else user
    cred = '%s:%s' % (username, password or username)
    # base64 output is pure ASCII, so decoding it directly gives the str value
    b64_cred = base64.b64encode(cred.encode('utf-8')).decode('ascii')
    return {'Authorization': 'Basic %s' % b64_cred}
def get_response_form(response, form='form'):
    """Return the value of ``form`` from the first template context defining it.

    Contexts are scanned in the order ``response.context`` yields them;
    ``None`` is returned when no context contains the key.
    """
    return next((ctx[form] for ctx in response.context if form in ctx), None)
def assert_equals_url(url1, url2, **kwargs):
    '''Assert that ``url1`` equals ``url2`` once ``kwargs`` are added to its query string.

    Both URLs are first normalised through ``utils.make_url`` and
    ``iri_to_uri``. Then every URL component is compared. Query strings are
    compared as mappings from parameter name to the *set* of its values, so
    parameter order and repeated values do not matter. Blank values are kept.
    A parameter of ``url2`` (or of ``kwargs``) whose only value is ``'*'``
    just requires that parameter to be present in ``url1``, whatever its
    value.
    '''
    parts1 = urlparse.urlsplit(iri_to_uri(utils.make_url(url1, params=None)))
    parts2 = urlparse.urlsplit(iri_to_uri(utils.make_url(url2, params=kwargs)))
    for position, (left, right) in enumerate(zip(parts1, parts2)):
        if position == 3:  # query component of a SplitResult
            left = {
                name: set(values)
                for name, values in urlparse.parse_qs(left, True).items()
            }
            expected_query = {}
            for name, values in urlparse.parse_qs(right, True).items():
                if values == ['*']:
                    # wildcard: reuse url1's values; stays ['*'] (and so mismatches) when absent
                    expected_query[name] = left.get(name, values)
                else:
                    expected_query[name] = set(values)
            right = expected_query
        assert left == right, "URLs are not equal: %s != %s" % (parts1, parts2)
def assert_redirects_complex(response, expected_url, **kwargs):
    """Assert ``response`` is a 302 redirect to ``expected_url`` plus query ``kwargs``.

    A scheme or netloc missing from ``expected_url`` is taken from
    ``response.url``. The ``Location`` header is then compared with
    ``assert_equals_url``, so the ``'*'`` wildcard and order-insensitive
    query comparison apply.
    """
    assert response.status_code == 302, 'code should be 302'
    actual = urlparse.urlsplit(response.url)
    expected = urlparse.urlsplit(expected_url)
    completed_url = urlparse.urlunsplit((
        expected.scheme or actual.scheme,
        expected.netloc or actual.netloc,
        expected.path,
        expected.query,
        expected.fragment,
    ))
    assert_equals_url(response['Location'], completed_url, **kwargs)
def assert_xpath_constraints(xml, constraints, namespaces):
    """Check a list of ``(xpath, expected)`` constraints against an XML document.

    :param xml: XML text/bytes, or an object exposing it as ``.content``.
    :param constraints: iterable of ``(xpath, expected)`` pairs. Every xpath
        must match at least one node. ``expected`` may be:

        - a string: every matched value must equal it;
        - a set: the set of matched values must equal it;
        - a list: the ordered list of matched values must equal it;
        - a compiled regexp (anything with ``.pattern``): every matched value
          must ``match`` it.

        Any other type raises ``NotImplementedError``.
    :param namespaces: prefix-to-URI mapping passed to ``xpath``.
    """
    def node_value(node):
        # nodes exposing ``.text`` (elements) compare by their text, other
        # xpath results (e.g. attribute strings) compare as-is
        return node.text if hasattr(node, 'text') else node

    document = etree.fromstring(xml.content if hasattr(xml, 'content') else xml)
    for xpath, content in constraints:
        nodes = document.xpath(xpath, namespaces=namespaces)
        assert len(nodes) > 0, 'xpath %s not found' % xpath
        values = [node_value(node) for node in nodes]
        if isinstance(content, six.string_types):
            for value in values:
                assert value == content, 'xpath %s does not contain %s but %s' % (xpath, content, value)
        elif isinstance(content, set):
            assert set(values) == content
        elif isinstance(content, list):
            assert values == content
        elif hasattr(content, 'pattern'):
            for value in values:
                assert content.match(value), 'xpath %s does not match regexp %s' % (xpath, content.pattern)
        else:
            raise NotImplementedError(
                'comparing xpath result to type %s: %r is not implemented' % (
                    type(content), content))
class Authentic2TestCase(TestCase):
    """Django ``TestCase`` exposing this module's assertion helpers as camelCase methods."""

    def assertEqualsURL(self, url1, url2, **kwargs):
        """Delegate to ``assert_equals_url``."""
        return assert_equals_url(url1, url2, **kwargs)

    def assertRedirectsComplex(self, response, expected_url, **kwargs):
        """Delegate to ``assert_redirects_complex``."""
        return assert_redirects_complex(response, expected_url, **kwargs)

    def assertXPathConstraints(self, xml, constraints, namespaces):
        """Delegate to ``assert_xpath_constraints``."""
        return assert_xpath_constraints(xml, constraints, namespaces)
@contextmanager
def check_log(caplog, message, levelname=None):
    """Assert that a record containing ``message`` is logged inside the ``with`` body.

    Only records appended to ``caplog.records`` after entering the context
    are inspected. When ``levelname`` is given, only records of that level
    count. If the body raises, the exception propagates and nothing is
    checked.
    """
    start = len(caplog.records)
    yield
    found = False
    for record in caplog.records[start:]:
        if levelname and record.levelname != levelname:
            continue
        if message in record.message:
            found = True
            break
    assert found, '%r not found in log records' % message
def get_links_from_mail(mail):
    """Return every http(s) URL found in ``mail.body``, in order of appearance.

    A link ends at the first space or newline.
    """
    link_pattern = re.compile('https?://[^ \n]*')
    return link_pattern.findall(mail.body)
def get_link_from_mail(mail):
    '''Extract the first and only link from this mail.

    Asserts that ``mail.body`` contains exactly one http(s) link, as found by
    ``get_links_from_mail``, and returns it.
    '''
    links = get_links_from_mail(mail)
    assert links, 'there is no link in this mail'
    # show the offending links so a failing test is easy to diagnose
    assert len(links) == 1, 'there is more than one link in this mail: %r' % (links,)
    return links[0]
def saml_sp_metadata(base_url):
    """Return a SAML 2.0 service-provider metadata document for ``base_url``.

    ``base_url`` is substituted into the entityID and into the POST and
    Artifact AssertionConsumerService locations. ``/...`` is appended to it,
    so it is presumably given without a trailing slash. The
    SingleLogoutService location is fixed.
    """
    template = '''<?xml version="1.0" encoding="UTF-8" standalone="yes"?>
<EntityDescriptor
entityID="{base_url}/"
xmlns="urn:oasis:names:tc:SAML:2.0:metadata">
<SPSSODescriptor
AuthnRequestsSigned="true"
WantAssertionsSigned="true"
protocolSupportEnumeration="urn:oasis:names:tc:SAML:2.0:protocol">
<SingleLogoutService
Binding="urn:oasis:names:tc:SAML:2.0:bindings:HTTP-Redirect"
Location="https://files.entrouvert.org/mellon/logout" />
<AssertionConsumerService
index="0"
isDefault="true"
Binding="urn:oasis:names:tc:SAML:2.0:bindings:HTTP-POST"
Location="{base_url}/sso/POST" />
<AssertionConsumerService
index="1"
Binding="urn:oasis:names:tc:SAML:2.0:bindings:HTTP-Artifact"
Location="{base_url}/mellon/artifactResponse" />
</SPSSODescriptor>
</EntityDescriptor>'''
    return template.format(base_url=base_url)
def find_free_tcp_port():
    """Return a TCP port number that was unused at the time of the call.

    An IPv4 socket is bound to port 0 on all interfaces, which lets the OS
    pick an unused port. The chosen port is read back and the socket is
    closed. Another process may take the port before the caller binds it,
    so callers must tolerate that race.
    """
    with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s:
        # socket options that affect address binding must be set before bind()
        s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        s.bind(('', 0))
        return s.getsockname()[1]
def request_select2(app, response, term='', get_kwargs=None):
    """Query the select2 AJAX endpoint of the first ``<select>`` in ``response``.

    The endpoint URL and field id come from the select's ``data-ajax--url``
    and ``data-field_id`` attributes. ``term`` is sent as the search term, and
    ``get_kwargs`` are forwarded to ``app.get``.

    :returns: the decoded JSON payload when the reply is ``application/json``,
        otherwise the raw response.
    """
    widget = response.pyquery('select')[0]
    endpoint = widget.attrib['data-ajax--url']
    field_id = widget.attrib['data-field_id']
    query = {'field_id': field_id, 'term': term}
    ajax_response = app.get(endpoint, params=query, **(get_kwargs or {}))
    if ajax_response['content-type'] != 'application/json':
        return ajax_response
    return ajax_response.json
@contextmanager
def run_on_commit_hooks():
    """Run the ``on_commit`` callbacks queued on the default connection.

    Presumably needed because, inside a Django ``TestCase``, the surrounding
    transaction is never committed, so queued callbacks would never fire on
    their own. When the ``with`` body exits normally, the queued callbacks
    are detached from ``connection.run_on_commit`` and called in registration
    order. Callbacks queued by a callback while draining are run too. If the
    body raises, the exception propagates and no callback is run.
    """
    yield
    from django.db import connection
    # Loop until the queue stays empty: detaching the list before running it
    # means callbacks registered by a callback land in the fresh list and
    # would otherwise be dropped.
    while connection.run_on_commit:
        pending = connection.run_on_commit
        connection.run_on_commit = []
        for entry in pending:
            # entries are tuples whose second item is the callback (first item:
            # savepoint ids); newer Django versions append a third item, so
            # index rather than unpack
            callback = entry[1]
            callback()
def call_command(*args, **kwargs):
    """Run a Django management command, then fire the ``on_commit`` callbacks it queued.

    All arguments are forwarded unchanged to Django's ``call_command``, and
    its return value is returned. If the command raises, the callbacks are
    not run.
    """
    with run_on_commit_hooks():
        result = django_call_command(*args, **kwargs)
    return result
def text_content(node):
    '''Return the concatenated text of ``node`` and all its descendants.

    Equivalent to libxml's xmlNodeGetContent. A ``None`` node yields ``''``.
    '''
    if node is None:
        return ''
    return u''.join(node.itertext())
def assert_event(event_type_name, user=None, session=None, service=None, **data):
    """Assert that exactly one journal ``Event`` of type ``event_type_name`` matches.

    - ``user``: the event must belong to it; when falsy, the event must have
      no user.
    - ``session``: the event must carry ``session.session_key``; when falsy,
      the event must have no session.
    - ``service``: the event must reference it (via the queryset's
      ``which_references``); when falsy, events referencing any
      ``models.Service`` are excluded (presumably through the same reference
      mechanism; not visible here).
    - ``**data``: each key/value must match ``event.data``, which must then be
      non-empty.
    """
    events = Event.objects.filter(type__name=event_type_name)
    events = events.filter(user=user) if user else events.filter(user__isnull=True)
    if session:
        events = events.filter(session=session.session_key)
    else:
        events = events.filter(session__isnull=True)
    if service:
        events = events.which_references(service)
    else:
        events = events.exclude(events._which_references_query(models.Service))
    assert events.count() == 1
    if not data:
        return
    event = events.get()
    assert event.data, 'no event.data, should be %s' % data
    for key, expected in data.items():
        actual = event.data.get(key)
        assert actual == expected, (
            'event.data[%s] != data[%s] (%s != %s)' % (key, key, actual, expected)
        )