misc: expose HTTP headers in authenticators conditions (#47084)
This commit is contained in:
parent
7e38340a92
commit
b6f471b9fa
|
@ -28,6 +28,19 @@ from django.utils import six
|
|||
import ast
|
||||
|
||||
|
||||
class HTTPHeaders:
|
||||
def __init__(self, request):
|
||||
self.request = request
|
||||
|
||||
def __contains__(self, header):
|
||||
meta_header = 'HTTP_' + header.replace('-', '_').upper()
|
||||
return meta_header in self.request.META
|
||||
|
||||
def __getitem__(self, header):
|
||||
meta_header = 'HTTP_' + header.replace('-', '_').upper()
|
||||
return self.request.META.get(meta_header)
|
||||
|
||||
|
||||
class Unparse(ast.NodeVisitor):
|
||||
def visit_Name(self, node):
|
||||
return node.id
|
||||
|
@ -113,6 +126,12 @@ class BaseExpressionValidator(ast.NodeVisitor):
|
|||
six.reraise(*sys.exc_info())
|
||||
return compile(tree, expression, mode='eval')
|
||||
|
||||
# python 3.8 introduced ast.Constant to replace Num, Str, Bytes and NameConstant (True, False, None)
|
||||
if sys.version_info < (3, 8):
|
||||
CONSTANT_CLASSES = (ast.Num, ast.Str, ast.Bytes)
|
||||
else:
|
||||
CONSTANT_CLASSES = (ast.Constant,)
|
||||
|
||||
|
||||
class ConditionValidator(BaseExpressionValidator):
|
||||
'''
|
||||
|
@ -123,6 +142,7 @@ class ConditionValidator(BaseExpressionValidator):
|
|||
- unary operator expressions with all operators,
|
||||
- if expressions (x if y else z),
|
||||
- compare expressions with all operators.
|
||||
- subscript of direct variable reference.
|
||||
|
||||
Are implicitely forbidden:
|
||||
- binary expressions (so no "'aaa' * 99999999999" or 233333333333333233**2232323233232323 bombs),
|
||||
|
@ -134,7 +154,6 @@ class ConditionValidator(BaseExpressionValidator):
|
|||
- call,
|
||||
- Repr node (i dunno what it is),
|
||||
- attribute access,
|
||||
- subscript.
|
||||
'''
|
||||
authorized_nodes = [
|
||||
ast.Load,
|
||||
|
@ -144,6 +163,8 @@ class ConditionValidator(BaseExpressionValidator):
|
|||
ast.BoolOp,
|
||||
ast.UnaryOp,
|
||||
ast.IfExp,
|
||||
ast.Subscript,
|
||||
ast.Index,
|
||||
ast.boolop,
|
||||
ast.cmpop,
|
||||
ast.Compare,
|
||||
|
@ -151,8 +172,8 @@ class ConditionValidator(BaseExpressionValidator):
|
|||
|
||||
def __init__(self, authorized_nodes=None, forbidden_nodes=None):
|
||||
super(ConditionValidator, self).__init__(
|
||||
authorized_nodes=authorized_nodes,
|
||||
forbidden_nodes=forbidden_nodes)
|
||||
authorized_nodes=authorized_nodes,
|
||||
forbidden_nodes=forbidden_nodes)
|
||||
if six.PY3:
|
||||
self.authorized_nodes.append(ast.NameConstant)
|
||||
|
||||
|
@ -160,6 +181,14 @@ class ConditionValidator(BaseExpressionValidator):
|
|||
if node.id.startswith('_'):
|
||||
raise ExpressionError(_('name must not start with a _'), code='invalid-variable', node=node)
|
||||
|
||||
def check_Subscript(self, node):
|
||||
# check subscript are constant number or strings
|
||||
if (not isinstance(node.slice, ast.Index)
|
||||
or not isinstance(node.slice.value, CONSTANT_CLASSES)
|
||||
# with python <3.8 the node class is enough to determine the value
|
||||
or (sys.version_info >= (3, 8) and not isinstance(node.slice.value.value, (int, str, bytes)))):
|
||||
raise ExpressionError(_('subscript index MUST be a constant'), code='invalid-subscript', node=node)
|
||||
|
||||
|
||||
validate_condition = ConditionValidator()
|
||||
|
||||
|
|
|
@ -55,6 +55,7 @@ from authentic2.custom_user.models import iter_attributes
|
|||
from . import (utils, app_settings, decorators, constants,
|
||||
models, cbv, hooks, validators, attribute_kinds)
|
||||
from .utils.service import get_service_from_request, get_service_from_token, set_service_ref
|
||||
from .utils.evaluate import HTTPHeaders
|
||||
from .utils import switch_user
|
||||
from .a2_rbac.utils import get_default_ou
|
||||
from .a2_rbac.models import OrganizationalUnit as OU
|
||||
|
@ -316,7 +317,10 @@ def login(request, template_name='authentic2/login.html',
|
|||
'context': context}
|
||||
remote_addr = request.META.get('REMOTE_ADDR')
|
||||
login_hint = set(request.session.get('login-hint', []))
|
||||
show_ctx = dict(remote_addr=remote_addr, login_hint=login_hint)
|
||||
show_ctx = dict(
|
||||
remote_addr=remote_addr,
|
||||
login_hint=login_hint,
|
||||
headers=HTTPHeaders(request))
|
||||
if service:
|
||||
show_ctx['service_ou_slug'] = service.ou and service.ou.slug
|
||||
show_ctx['service_slug'] = service.slug
|
||||
|
|
|
@ -50,7 +50,7 @@ def test_login_inactive_user(db, app):
|
|||
assert '_auth_user_id' not in app.session
|
||||
|
||||
|
||||
def test_login_with_conditionnal_enabled_authenticators(db, app, settings, caplog):
|
||||
def test_show_condition(db, app, settings, caplog):
|
||||
response = app.get('/login/')
|
||||
assert 'name="login-password-submit"' in response
|
||||
|
||||
|
@ -64,16 +64,28 @@ def test_login_with_conditionnal_enabled_authenticators(db, app, settings, caplo
|
|||
settings.AUTH_FRONTENDS_KWARGS = {'password': {'show_condition': '\'admin\' in unknown'}}
|
||||
response = app.get('/login/')
|
||||
assert 'name="login-password-submit"' not in response
|
||||
|
||||
|
||||
def test_show_condition_service(db, app, settings):
|
||||
settings.AUTH_FRONTENDS_KWARGS = {'password': {'show_condition': 'service_slug == \'portal\''}}
|
||||
response = app.get('/login/', params={'service': 'portal'})
|
||||
assert 'name="login-password-submit"' not in response
|
||||
|
||||
# Create a service
|
||||
service = models.Service.objects.create(name='Service', slug='portal')
|
||||
models.Service.objects.create(name='Service', slug='portal')
|
||||
response = app.get('/login/', params={'service': 'portal'})
|
||||
assert 'name="login-password-submit"' in response
|
||||
|
||||
|
||||
def test_show_condition_with_headers(app, settings):
|
||||
settings.A2_AUTH_OIDC_ENABLE = False # prevent db access by OIDC frontend
|
||||
settings.AUTH_FRONTENDS_KWARGS = {'password': {'show_condition': '\'X-Entrouvert\' in headers'}}
|
||||
response = app.get('/login/')
|
||||
assert 'name="login-password-submit"' not in response
|
||||
response = app.get('/login/', headers={'x-entrouvert': '1'})
|
||||
assert 'name="login-password-submit"' in response
|
||||
|
||||
|
||||
def test_registration_url_on_login_page(db, app):
|
||||
response = app.get('/login/?next=/whatever')
|
||||
assert 'register/?next=/whatever"' in response
|
||||
|
|
|
@ -21,7 +21,7 @@ import pytest
|
|||
|
||||
from authentic2.utils.evaluate import (
|
||||
BaseExpressionValidator, ConditionValidator, ExpressionError,
|
||||
evaluate_condition)
|
||||
evaluate_condition, HTTPHeaders)
|
||||
|
||||
|
||||
def test_base():
|
||||
|
@ -57,17 +57,32 @@ def test_condition_validator():
|
|||
with pytest.raises(ExpressionError) as raised:
|
||||
v('1 + 2')
|
||||
|
||||
v('a[1]')
|
||||
|
||||
def test_evaluate_condition():
|
||||
v = ConditionValidator()
|
||||
v('a[\'xx\']')
|
||||
|
||||
assert evaluate_condition('False', validator=v) is False
|
||||
assert evaluate_condition('True', validator=v) is True
|
||||
assert evaluate_condition('True and False', validator=v) is False
|
||||
assert evaluate_condition('True or False', validator=v) is True
|
||||
assert evaluate_condition('a or b', ctx=dict(a=True, b=False), validator=v) is True
|
||||
assert evaluate_condition('a < 1', ctx=dict(a=0), validator=v) is True
|
||||
with pytest.raises(ExpressionError, match='MUST be a constant'):
|
||||
v('a[1:2]')
|
||||
|
||||
with pytest.raises(ExpressionError, match='MUST be a constant'):
|
||||
v('headers[headers]')
|
||||
|
||||
|
||||
def test_evaluate_condition(rf):
|
||||
assert evaluate_condition('False') is False
|
||||
assert evaluate_condition('True') is True
|
||||
assert evaluate_condition('True and False') is False
|
||||
assert evaluate_condition('True or False') is True
|
||||
assert evaluate_condition('a or b', ctx=dict(a=True, b=False)) is True
|
||||
assert evaluate_condition('a < 1', ctx=dict(a=0)) is True
|
||||
with pytest.raises(ExpressionError) as exc_info:
|
||||
evaluate_condition('a < 1', validator=v)
|
||||
evaluate_condition('a < 1')
|
||||
assert exc_info.value.code == 'undefined-variable'
|
||||
assert evaluate_condition('a < 1', validator=v, on_raise=False) is False
|
||||
assert evaluate_condition('a < 1', on_raise=False) is False
|
||||
|
||||
|
||||
def test_http_headers(rf):
|
||||
request = rf.get('/', HTTP_X_ENTROUVERT='1')
|
||||
headers = HTTPHeaders(request)
|
||||
assert evaluate_condition('"X-Entrouvert" in headers', ctx={'headers': headers}) is True
|
||||
assert evaluate_condition('headers["X-Entrouvert"]', ctx={'headers': headers}) == '1'
|
||||
|
|
Loading…
Reference in New Issue