485 lines
16 KiB
Python
485 lines
16 KiB
Python
# authentic2 - versatile identity manager
|
|
# Copyright (C) 2010-2020 Entr'ouvert
|
|
#
|
|
# This program is free software: you can redistribute it and/or modify it
|
|
# under the terms of the GNU Affero General Public License as published
|
|
# by the Free Software Foundation, either version 3 of the License, or
|
|
# (at your option) any later version.
|
|
#
|
|
# This program is distributed in the hope that it will be useful,
|
|
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
|
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
|
# GNU Affero General Public License for more details.
|
|
#
|
|
# You should have received a copy of the GNU Affero General Public License
|
|
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
|
# authentic2
|
|
|
|
import base64
|
|
import functools
|
|
import inspect
|
|
import re
|
|
import socket
|
|
import urllib.parse
|
|
from contextlib import closing, contextmanager
|
|
|
|
import pytest
|
|
from django.contrib.messages.storage.cookie import MessageDecoder, MessageEncoder
|
|
|
|
try:
    from django.contrib.messages.storage.cookie import MessageSerializer
except ImportError:
    # MessageSerializer cannot be imported from the installed Django: define a
    # local stand-in exposing the same dumps()/loads() pair.
    import json

    class MessageSerializer:
        """Fallback JSON serializer built on MessageEncoder / MessageDecoder."""

        def dumps(self, obj):
            """Serialize obj to compact JSON and return it as latin-1 bytes."""
            return json.dumps(
                obj,
                separators=(',', ':'),
                cls=MessageEncoder,
            ).encode('latin-1')

        def loads(self, data):
            """Decode latin-1 bytes and parse the JSON they contain."""
            return json.loads(data.decode('latin-1'), cls=MessageDecoder)
|
|
|
|
|
|
from django.contrib.auth import get_user_model
|
|
from django.core import signing
|
|
from django.core.management import call_command as django_call_command
|
|
from django.shortcuts import resolve_url
|
|
from django.test import TestCase
|
|
from django.urls import reverse
|
|
from django.utils.encoding import force_str, iri_to_uri
|
|
from lxml import etree
|
|
|
|
from authentic2.apps.journal.models import Event
|
|
from authentic2.utils import misc as utils_misc
|
|
|
|
# Names of user attributes, kept as a set for order-independent comparisons.
# NOTE(review): presumably the keys expected in a serialized user; confirm
# against the tests that consume this constant.
USER_ATTRIBUTES_SET = {
    'ou',
    'id',
    'uuid',
    'is_staff',
    'is_superuser',
    'first_name',
    'first_name_verified',
    'last_name',
    'last_name_verified',
    'date_joined',
    'last_login',
    'username',
    'password',
    'email',
    'is_active',
    'modified',
    'email_verified',
    'email_verified_date',
    'email_verified_sources',
    'phone_verified_on',
    'last_account_deletion_alert',
    'deactivation',
    'deactivation_reason',
    'full_name',
}
|
|
|
|
|
|
def create_user(**kwargs):
    """Fetch or create a user matching kwargs and (re)set its password.

    The ``password`` keyword is removed from kwargs before the lookup. When it
    is missing or empty, ``kwargs['username']`` is used as the password
    instead. The remaining kwargs are passed to ``get_or_create()``.

    The plain-text password is also stored on the returned instance as
    ``clear_password``.
    """
    user_model = get_user_model()
    explicit_password = kwargs.pop('password', None)
    secret = explicit_password if explicit_password else kwargs['username']
    account, _created = user_model.objects.get_or_create(**kwargs)
    if not secret:
        return account
    account.clear_password = secret
    account.set_password(secret)
    account.save()
    return account
|
|
|
|
|
|
def login(
    app,
    user,
    path=None,
    login=None,
    password=None,
    remember_me=None,
    phone_authn=False,
    ou=None,
    args=None,
    kwargs=None,
    fail=False,
):
    """Log in through the password login form and return the last response.

    Parameters:
        app: test client; must provide ``get()`` and ``session``.
        user: a user object, or a bare username when it has no ``username``
            attribute.
        path: optional view name or URL resolved with ``resolve_url(path,
            *args, **kwargs)``; it must answer with a 302 leading to the login
            page. Without it, the login page is requested directly.
        login: identifier typed in the form instead of the user's own
            username (or phone number when ``phone_authn`` is true).
        password: password to submit. When empty, ``user.clear_password`` is
            used if present, otherwise a fallback derived from the username.
        remember_me: when not None, value set on the ``remember_me`` field.
        phone_authn: submit a phone number (``login`` or ``user.phone``) as
            the identifier.
        ou: when not None and the form has an ``ou`` field, its pk is selected.
        fail: expect the login to be refused (status 200, no session user).
    """
    if path:
        args = args or []
        kwargs = kwargs or {}
        path = resolve_url(path, *args, **kwargs)
        login_page = app.get(path, status=302).maybe_follow()
    else:
        login_page = app.get(reverse('auth_login'))
    assert login_page.request.path == reverse('auth_login')
    form = login_page.forms['login-password-form']

    if phone_authn:
        identifier = login or user.phone
        # The identifier is a phone number here. Fall back on the user's
        # username when it has one, instead of referencing a name that is
        # never bound in this branch (previously an UnboundLocalError).
        fallback_password = getattr(user, 'username', identifier)
    else:
        if login:
            identifier = login
        elif hasattr(user, 'username'):
            identifier = user.username
        else:
            identifier = user
        # password is supposed to be the same as username
        fallback_password = identifier
    form.set('username', identifier)
    form.set(
        'password',
        password or (user.clear_password if hasattr(user, 'clear_password') else fallback_password),
    )

    if ou is not None and 'ou' in form.fields:
        form.set('ou', str(ou.pk))
    if remember_me is not None:
        form.set('remember_me', bool(remember_me))
    response = form.submit(name='login-password-submit')

    if fail:
        assert response.status_code == 200
        assert '_auth_user_id' not in app.session
    else:
        response = response.follow()
        if path:
            assert response.request.path == path
        else:
            assert response.request.path == reverse('auth_homepage')
        assert '_auth_user_id' in app.session
        # Only compare ids when `user` is a model instance and not a bare name.
        assert not hasattr(user, 'id') or (app.session['_auth_user_id'] == str(user.id))
    return response
|
|
|
|
|
|
def logout(app):
    """Log the current user out through the logout view.

    Asserts that a user is logged in beforehand and that none is afterwards.
    Submits the logout form and, when the resulting page contains
    'continue-link', clicks 'Continue logout'. Returns the last response.
    """
    session_key = '_auth_user_id'
    assert session_key in app.session
    page = app.get(reverse('auth_logout')).maybe_follow()
    page = page.form.submit().maybe_follow()
    needs_confirmation = 'continue-link' in page.text
    if needs_confirmation:
        page = page.click('Continue logout').maybe_follow()
    assert session_key not in app.session
    return page
|
|
|
|
|
|
def basic_authorization_header(user_or_id, password=None):
    """Build an HTTP Basic ``Authorization`` header as a dict.

    When ``user_or_id`` is an instance of the user model, its username is the
    login and also the default password. Otherwise ``user_or_id`` is used as
    the login and ``password`` is used as given.
    """
    if isinstance(user_or_id, get_user_model()):
        identifier = user_or_id.username
        password = password or user_or_id.username
    else:
        identifier = user_or_id
    raw_credentials = ('%s:%s' % (identifier, password)).encode('utf-8')
    token = force_str(base64.b64encode(raw_credentials))
    return {'Authorization': 'Basic %s' % str(token)}
|
|
|
|
|
|
def basic_authorization_oidc_client(client):
    """Build an HTTP Basic ``Authorization`` header from an OIDC client.

    The credentials are ``client.client_id`` and ``client.client_secret``.
    """
    raw_credentials = f'{client.client_id}:{client.client_secret}'.encode('utf-8')
    token = force_str(base64.b64encode(raw_credentials))
    return {'Authorization': 'Basic %s' % str(token)}
|
|
|
|
|
|
def get_response_form(response, form='form'):
    """Return the value stored under ``form`` in the response contexts.

    The contexts in ``response.context`` are scanned in order and the first
    one containing the key wins. Returns None when no context has it.
    """
    for context in list(response.context):
        if form in context:
            return context[form]
    return None
|
|
|
|
|
|
def assert_equals_url(url1, url2, **kwargs):
    """Assert that url1 equals url2 once kwargs are added to url2's query string.

    Scheme, host, path and fragment must match exactly. Query strings are
    compared per parameter as sets of values, so order and repetition are
    ignored.

    The value '*' is special: a parameter set to '*' only has to be present
    in url1, whatever its value.
    """
    url1 = iri_to_uri(utils_misc.make_url(url1, params=None))
    parts1 = urllib.parse.urlsplit(url1)
    url2 = iri_to_uri(utils_misc.make_url(url2, params=kwargs))
    parts2 = urllib.parse.urlsplit(url2)
    query_index = 3  # position of the query string in a SplitResult
    for index, (left, right) in enumerate(zip(parts1, parts2)):
        if index == query_index:
            left = {name: set(values) for name, values in urllib.parse.parse_qs(left, True).items()}
            right = {
                # A wildcard takes the values found in url1. When url1 lacks
                # the parameter, the wildcard list is kept and the comparison
                # below fails.
                name: left.get(name, values) if values == ['*'] else set(values)
                for name, values in urllib.parse.parse_qs(right, True).items()
            }
        assert left == right, 'URLs are not equal: %s != %s' % (parts1, parts2)
|
|
|
|
|
|
def assert_redirects_complex(response, expected_url, **kwargs):
    """Assert that response is a 302 redirect to expected_url.

    A missing scheme or host in expected_url is taken from ``response.url``.
    The ``Location`` header is then compared with ``assert_equals_url()``,
    which receives kwargs as extra expected query parameters.
    """
    assert response.status_code == 302, 'code should be 302'
    actual = urllib.parse.urlsplit(response.url)
    expected = urllib.parse.urlsplit(expected_url)
    completed = expected._replace(
        scheme=expected.scheme or actual.scheme,
        netloc=expected.netloc or actual.netloc,
    )
    assert_equals_url(response['Location'], urllib.parse.urlunsplit(completed), **kwargs)
|
|
|
|
|
|
def assert_xpath_constraints(xml, constraints, namespaces):
    """Check an XML document against a list of (xpath, expected) pairs.

    ``xml`` is the document source, or an object whose ``content`` attribute
    holds it. Every xpath must match at least one node. The expected value
    selects the comparison:

    * str: every matched node (its text, or the node itself when it has no
      ``text`` attribute) must equal it;
    * set: the set of matched values must equal it;
    * list: the matched values must equal it, in order;
    * compiled regexp (anything with a ``pattern`` attribute): every matched
      value must match it.

    Any other type raises NotImplementedError.
    """
    source = xml.content if hasattr(xml, 'content') else xml
    doc = etree.fromstring(source)
    for xpath, content in constraints:
        nodes = doc.xpath(xpath, namespaces=namespaces)
        assert len(nodes) > 0, 'xpath %s not found' % xpath

        if isinstance(content, str):
            for node in nodes:
                if hasattr(node, 'text'):
                    assert node.text == content, 'xpath %s does not contain %s but %s' % (
                        xpath,
                        content,
                        node.text,
                    )
                else:
                    assert node == content, 'xpath %s does not contain %s but %s' % (xpath, content, node)
            continue

        values = [getattr(node, 'text', node) for node in nodes]
        if isinstance(content, set):
            assert set(values) == content
            continue
        if isinstance(content, list):
            assert values == content
            continue
        if hasattr(content, 'pattern'):
            for value in values:
                assert content.match(value), 'xpath %s does not match regexp %s' % (
                    xpath,
                    content.pattern,
                )
            continue
        raise NotImplementedError(
            'comparing xpath result to type %s: %r is not implemented' % (type(content), content)
        )
|
|
|
|
|
|
class Authentic2TestCase(TestCase):
    """TestCase exposing the module-level assertion helpers as methods."""

    def assertEqualsURL(self, url1, url2, **kwargs):
        """Delegate to assert_equals_url()."""
        assert_equals_url(url1, url2, **kwargs)

    def assertRedirectsComplex(self, response, expected_url, **kwargs):
        """Delegate to assert_redirects_complex()."""
        assert_redirects_complex(response, expected_url, **kwargs)

    def assertXPathConstraints(self, xml, constraints, namespaces):
        """Delegate to assert_xpath_constraints()."""
        assert_xpath_constraints(xml, constraints, namespaces)
|
|
|
|
|
|
@contextmanager
def check_log(caplog, message, levelname=None):
    """Context manager asserting that a log record was emitted in its body.

    Only records appended to ``caplog.records`` while the body runs are
    considered. One of them must contain ``message``. When ``levelname`` is
    given, only records of that level count.
    """
    start = len(caplog.records)
    yield
    found = False
    for record in caplog.records[start:]:
        if levelname and record.levelname != levelname:
            continue
        if message in record.message:
            found = True
            break
    assert found, ('%r not found in log records' % message)
|
|
|
|
|
|
def get_links_from_mail(mail):
    '''Return every http(s) link found in mail.body, in order of appearance.

    A link runs up to the next space or newline.
    '''
    link_pattern = re.compile('https?://[^ \n]*')
    return link_pattern.findall(mail.body)
|
|
|
|
|
|
def get_link_from_mail(mail):
    '''Return the single link contained in this mail.

    Fails with an AssertionError when the mail contains no link or several.
    '''
    found = get_links_from_mail(mail)
    assert found, 'there is not link in this mail'
    assert len(found) == 1, 'there are more than one link in this mail'
    first = found[0]
    return first
|
|
|
|
|
|
def saml_sp_metadata(base_url):
    """Return SAML 2.0 service-provider metadata for base_url, as XML text.

    base_url is used for the entityID ("<base_url>/") and for the two
    AssertionConsumerService locations.
    """
    # NOTE(review): the SingleLogoutService Location below is a fixed URL and
    # is not derived from base_url; confirm this is intended.
    # NOTE(review): leading whitespace inside this literal could not be
    # recovered from the reviewed copy; check it against version control.
    return '''<?xml version="1.0" encoding="UTF-8" standalone="yes"?>
<EntityDescriptor
entityID="{base_url}/"
xmlns="urn:oasis:names:tc:SAML:2.0:metadata">
<SPSSODescriptor
AuthnRequestsSigned="true"
WantAssertionsSigned="true"
protocolSupportEnumeration="urn:oasis:names:tc:SAML:2.0:protocol">
<SingleLogoutService
Binding="urn:oasis:names:tc:SAML:2.0:bindings:HTTP-Redirect"
Location="https://files.entrouvert.org/mellon/logout" />
<AssertionConsumerService
index="0"
isDefault="true"
Binding="urn:oasis:names:tc:SAML:2.0:bindings:HTTP-POST"
Location="{base_url}/sso/POST" />
<AssertionConsumerService
index="1"
Binding="urn:oasis:names:tc:SAML:2.0:bindings:HTTP-Artifact"
Location="{base_url}/mellon/artifactResponse" />
</SPSSODescriptor>
</EntityDescriptor>'''.format(
        base_url=base_url
    )
|
|
|
|
|
|
def find_free_tcp_port():
    """Return a TCP port number that was free when this function ran.

    A socket is bound to port 0 on all interfaces so that the system picks
    the port. The socket is closed before returning, so nothing reserves the
    port afterwards.
    """
    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    try:
        sock.bind(('', 0))
        sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        _host, port = sock.getsockname()
        return port
    finally:
        sock.close()
|
|
|
|
|
|
def request_select2(app, response, term='', fetch_all=False, page=1, get_kwargs=None):
    """Query the select2 endpoint advertised by the first <select> of a page.

    The URL and field id are read from the ``data-ajax--url`` and
    ``data-field_id`` attributes of the first ``select`` element in
    ``response``. ``term`` is the search term, ``page`` is sent only when
    truthy, and ``get_kwargs`` are extra keyword arguments for ``app.get()``.

    Returns the decoded JSON payload, or the raw response when its content
    type is not exactly 'application/json'. With ``fetch_all``, the following
    pages are requested while the payload reports 'more', and their results
    are appended to the returned payload's 'results'.
    """
    select_attributes = response.pyquery('select')[0].attrib
    endpoint = select_attributes['data-ajax--url']
    query = {'field_id': select_attributes['data-field_id'], 'term': term}
    if page:
        query['page'] = page

    extra = get_kwargs or {}
    reply = app.get(endpoint, params=query, **extra)
    if reply['content-type'] != 'application/json':
        return reply

    payload = reply.json
    collected = payload['results']
    if fetch_all and payload['more']:
        following = request_select2(app, response, term, fetch_all, page + 1, get_kwargs)
        collected.extend(following['results'])

    return payload
|
|
|
|
|
|
@contextmanager
def run_on_commit_hooks():
    """Context manager running pending on-commit hooks after its body.

    Once the body has completed without raising, the hooks queued on the
    default database connection are removed from the queue and called in
    order. Hooks registered while these run go to a fresh queue and are not
    called here.
    """
    yield

    from django.db import connection

    pending = connection.run_on_commit
    connection.run_on_commit = []
    while pending:
        # Queue entries are tuples whose second item is the callable. Index it
        # rather than unpacking a fixed-size tuple, because Django 4.2+ adds a
        # third `robust` item to each entry.
        entry = pending.pop(0)
        hook = entry[1]
        hook()
|
|
|
|
|
|
def call_command(*args, **kwargs):
    """Run a Django management command, then the pending on-commit hooks.

    Arguments are passed unchanged to Django's ``call_command`` and its
    return value is returned.
    """
    with run_on_commit_hooks():
        outcome = django_call_command(*args, **kwargs)
    return outcome
|
|
|
|
|
|
def text_content(node):
    """Return the concatenated text of node and all its descendants.

    Equivalent to xmlNodeGetContent from libxml. Returns an empty string
    when node is None.
    """
    if node is None:
        return ''
    fragments = node.itertext()
    return ''.join(fragments)
|
|
|
|
|
|
def assert_event(event_type_name, user=None, session=None, service=None, target_user=None, api=False, **data):
    """Assert that a journal event matching the given criteria was recorded.

    Events are filtered on type name and ``api`` flag, then on ``user`` and
    ``session`` (None means "the event has none"), and on references to
    ``service`` and ``target_user`` when they are given.

    * Without ``data``: exactly one event must match; nothing is returned.
    * With ``data`` and a single match: every key of ``data`` must have the
      same value in ``event.data``; the event is returned.
    * With ``data`` and several matches: exactly one of them must match
      ``data``; it is returned.
    """
    from authentic2.models import Service

    events = Event.objects.filter(type__name=event_type_name, api=api)
    events = events.filter(user__isnull=True) if user is None else events.filter(user=user)
    if session is None:
        events = events.filter(session__isnull=True)
    else:
        events = events.filter(session=session.session_key)
    # Unlike user and session, a missing service adds no filter.
    if service is not None:
        events = events.which_references(Service(pk=service.pk))
    if target_user is not None:
        events = events.which_references(target_user)

    matching = events.count()
    assert matching > 0

    if not data:
        assert matching == 1
        return None

    if matching == 1:
        event = events.get()
        assert event.data, 'no event.data, should be %s' % data
        for key, expected in data.items():
            assert event.data.get(key) == expected, 'event.data[%s] != data[%s] (%s != %s)' % (
                key,
                key,
                event.data.get(key),
                expected,
            )
        return event

    # Several candidates: let the database pick the one carrying `data`.
    events = events.filter(**{'data__' + key: value for key, value in data.items()})
    assert events.count() == 1
    return events.get()
|
|
|
|
|
|
def clear_events():
    """Delete every journal event."""
    every_event = Event.objects.all()
    every_event.delete()
|
|
|
|
|
|
def set_service(app, service):
    """Record service in the session of the test client.

    When the client has no session yet (``app.session`` equals ``{}``), a new
    session store is created. After saving, its key is installed as the
    session cookie. Otherwise the existing session is updated and saved.
    """
    from importlib import import_module

    from django.conf import settings

    from authentic2.utils.service import _set_session_service

    session_engine = import_module(settings.SESSION_ENGINE)
    session = session_engine.SessionStore() if app.session == {} else app.session
    _set_session_service(session, service)
    session.save()
    # Still empty here means the session was created above and the client
    # does not know its key yet.
    if app.session == {}:
        app.set_cookie(settings.SESSION_COOKIE_NAME, session.session_key)
|
|
|
|
|
|
def decode_cookie(data):
    """Decode a signed messages cookie.

    Returns the object unsigned with the 'django.contrib.messages' cookie
    signer, or None when the signature is invalid. When unsigning raises
    AttributeError, ``data`` is returned unchanged.
    """
    signer = signing.get_cookie_signer(salt='django.contrib.messages')
    try:
        decoded = signer.unsign_object(data, serializer=MessageSerializer)
    except signing.BadSignature:
        decoded = None
    except AttributeError:
        # NOTE(review): presumably raised when the signer has no
        # unsign_object(); confirm. Legacy decoding is not supported, so the
        # raw value is handed back.
        decoded = data
    return decoded
|
|
|
|
|
|
def scoped_db_fixture(func=None, /, **kwargs):
    '''Create a db fixture with a scope different than 'function' the default one.

    Usable as ``@scoped_db_fixture(scope='module')``: when called without
    ``func``, a partial waiting for the function is returned. ``kwargs`` are
    passed to ``pytest.fixture()`` and must contain a ``scope`` of 'session',
    'module' or 'class'.

    The returned fixture takes the parameters of ``func`` plus ``scoped_db``
    (unless ``func`` already declares it), calls ``func`` inside
    ``scoped_db(...)`` and yields the value bound by that ``with`` statement.
    '''
    if func is None:
        return functools.partial(scoped_db_fixture, **kwargs)
    assert kwargs.get('scope') in [
        'session',
        'module',
        'class',
    ], 'scoped_db_fixture is only usable with a non function scope'
    signature = inspect.signature(func)
    # Parameter names of the wrapped fixture, in declaration order.
    inner_parameters = []
    for parameter in signature.parameters:
        inner_parameters.append(parameter)
    # The generated function needs `scoped_db` as a parameter. `self`, when
    # present, must stay first.
    outer_parameters = ['scoped_db'] if 'scoped_db' not in inner_parameters else []
    if inner_parameters and inner_parameters[0] == 'self':
        outer_parameters = ['self'] + outer_parameters + inner_parameters[1:]
    else:
        outer_parameters = outer_parameters + inner_parameters
    # build a new fixture function, inject the scoped_db fixture inside and the old function in a scoped_db context.
    # The source is generated as text so that the new function has a real
    # signature listing each parameter by name.
    # NOTE(review): indentation inside this template was reconstructed; check
    # it against version control.
    new_function_def = f'''def f({", ".join(outer_parameters)}):
    if inspect.isgeneratorfunction(func):
        def g():
            yield from func({", ".join(inner_parameters)})
    else:
        def g():
            return func({", ".join(inner_parameters)})
    with scoped_db(g) as result:
        yield result'''
    new_function_bytecode = compile(new_function_def, filename=f'<scoped-db-fixture {func}>', mode='single')
    # Namespace in which `f` gets defined; it only sees `func` and `inspect`.
    global_dict = {'func': func, 'inspect': inspect}
    eval(new_function_bytecode, global_dict)  # pylint: disable=eval-used
    new_function = global_dict['f']
    wrapped_func = functools.wraps(func)(new_function)
    # prevent original fixture signature to override the new fixture signature
    # (because of inspect.signature(follow_wrapped=True)) during pytest collection
    del wrapped_func.__wrapped__
    return pytest.fixture(**kwargs)(wrapped_func)
|