auth_oidc: rewrite loading of jwkset by URL (#85934)
gitea/authentic/pipeline/head This commit looks good
Details
gitea/authentic/pipeline/head This commit looks good
Details
Use the new common HTTP API.
This commit is contained in:
parent
eceb4b2424
commit
e63b9c2898
|
@ -39,6 +39,11 @@ class OIDCProviderEditForm(forms.ModelForm):
|
|||
if self.instance.jwkset_url:
|
||||
self.fields['jwkset_json'].disabled = True
|
||||
self.fields['jwkset_json'].help_text = _('JSON is fetched from the WebKey Set URL')
|
||||
self.old_jwkset = self.instance.jwkset_json or {}
|
||||
|
||||
def save(self, commit=True):
|
||||
super().save(commit=commit)
|
||||
self.instance.log_jwkset_change(self.old_jwkset, self.instance.jwkset_json)
|
||||
|
||||
|
||||
class OIDCProviderAdvancedForm(forms.ModelForm):
|
||||
|
|
|
@ -19,20 +19,15 @@ import logging
|
|||
from authentic2.base_commands import LogToConsoleCommand
|
||||
from authentic2_auth_oidc.models import OIDCProvider
|
||||
|
||||
logger = logging.getLogger('authentic2.auth.oidc')
|
||||
|
||||
|
||||
class Command(LogToConsoleCommand):
|
||||
loggername = 'authentic2_auth_oidc.models'
|
||||
|
||||
def core_command(self, *args, **kwargs):
|
||||
logger = logging.getLogger(self.loggername)
|
||||
providers = OIDCProvider.objects.exclude(jwkset_url='')
|
||||
if not providers.count():
|
||||
return
|
||||
logger.info(
|
||||
'got %s provider(s): %s',
|
||||
providers.count(),
|
||||
' '.join(providers.values_list('slug', flat=True)),
|
||||
)
|
||||
for provider in providers:
|
||||
provider.set_jwkset_json_from_url()
|
||||
provider.save()
|
||||
for oidc_provider in OIDCProvider.objects.all():
|
||||
try:
|
||||
oidc_provider.refresh_jwkset_json()
|
||||
except Exception as e:
|
||||
logger.warning('auth_oidc: could not refresh jwkset for provider %s (%s)', oidc_provider, e)
|
||||
|
|
|
@ -21,7 +21,7 @@ from datetime import datetime, timedelta
|
|||
import requests
|
||||
from django.conf import settings
|
||||
from django.core.exceptions import ValidationError
|
||||
from django.db import models
|
||||
from django.db import models, transaction
|
||||
from django.db.models import JSONField
|
||||
from django.shortcuts import render
|
||||
from django.utils.timezone import now
|
||||
|
@ -36,18 +36,22 @@ from authentic2.apps.authenticators.models import (
|
|||
BaseAuthenticator,
|
||||
)
|
||||
from authentic2.apps.journal.journal import journal
|
||||
from authentic2.utils import http
|
||||
from authentic2.utils.misc import make_url
|
||||
from authentic2.utils.template import validate_template
|
||||
|
||||
from . import managers, utils
|
||||
|
||||
|
||||
def validate_jwkset(data):
|
||||
data = json.dumps(data)
|
||||
def parse_jwkset(data):
|
||||
try:
|
||||
JWKSet.from_json(data)
|
||||
except InvalidJWKValue as e:
|
||||
raise ValidationError(_('Invalid JWKSet: %s') % e)
|
||||
return JWKSet.from_json(data)
|
||||
except InvalidJWKValue:
|
||||
raise ValidationError(_('Invalid JWKSet'))
|
||||
|
||||
|
||||
def validate_jwkset(data):
|
||||
parse_jwkset(json.dumps(data))
|
||||
|
||||
|
||||
class OIDCProvider(BaseAuthenticator):
|
||||
|
@ -190,7 +194,10 @@ class OIDCProvider(BaseAuthenticator):
|
|||
@property
|
||||
def jwkset(self):
|
||||
if self.jwkset_json:
|
||||
return JWKSet.from_json(json.dumps(self.jwkset_json))
|
||||
try:
|
||||
return parse_jwkset(json.dumps(self.jwkset_json))
|
||||
except ValidationError:
|
||||
pass
|
||||
return None
|
||||
|
||||
def get_short_description(self):
|
||||
|
@ -213,18 +220,25 @@ class OIDCProvider(BaseAuthenticator):
|
|||
|
||||
def clean(self):
|
||||
super().clean()
|
||||
if self.jwkset_url:
|
||||
try:
|
||||
self.jwkset_json = self.load_jwkset_url()
|
||||
except ValidationError as e:
|
||||
raise ValidationError({'jwkset_url': e})
|
||||
|
||||
if self.idtoken_algo not in (self.ALGO_NONE, self.ALGO_HMAC):
|
||||
key_sig_mapping = {
|
||||
self.ALGO_RSA: 'RSA',
|
||||
self.ALGO_EC: 'EC',
|
||||
}
|
||||
if not self.jwkset_json:
|
||||
jwkset = self.jwkset
|
||||
if not jwkset:
|
||||
raise ValidationError(
|
||||
_('Provider signature method is %s yet no jwkset was provided.')
|
||||
% key_sig_mapping[self.idtoken_algo]
|
||||
)
|
||||
# verify that a key is available for the chosen algorithm
|
||||
for key in self.jwkset:
|
||||
for key in jwkset['keys']:
|
||||
# compatibility with jwcrypto < 1
|
||||
key_type = key.get('kty', None) if isinstance(key, dict) else key.key_type
|
||||
if key_type == key_sig_mapping[self.idtoken_algo]:
|
||||
|
@ -240,51 +254,46 @@ class OIDCProvider(BaseAuthenticator):
|
|||
def save(self, *args, **kwargs):
|
||||
if not self.ou:
|
||||
self.ou = get_default_ou()
|
||||
if self.jwkset_url:
|
||||
self.set_jwkset_json_from_url()
|
||||
if self.jwkset_url and not self.jwkset_json:
|
||||
raise ValueError('model is not initialized')
|
||||
if self.jwkset_json:
|
||||
validate_jwkset(self.jwkset_json)
|
||||
return super().save(*args, **kwargs)
|
||||
|
||||
def set_jwkset_json_from_url(self):
|
||||
logger = logging.getLogger(__name__)
|
||||
def load_jwkset_url(self):
|
||||
try:
|
||||
response = requests.get(
|
||||
self.jwkset_url,
|
||||
timeout=settings.REQUESTS_TIMEOUT,
|
||||
)
|
||||
response.raise_for_status()
|
||||
except requests.RequestException:
|
||||
logger.error('Unable to reach JWKSet content from URL %s', self.jwkset_url)
|
||||
response = http.get(self.jwkset_url)
|
||||
except http.HTTPError as e:
|
||||
raise ValidationError(_('JWKSet URL is unreachable: %s') % e)
|
||||
return parse_jwkset(response.content).export(as_dict=True)
|
||||
|
||||
def refresh_jwkset_json(self):
|
||||
if not self.jwkset_url:
|
||||
return
|
||||
|
||||
if not hasattr(response, 'json'):
|
||||
logger.error('JWKSet URL %s is not JSON', self.jwkset_url)
|
||||
return
|
||||
try:
|
||||
json_value = response.json()
|
||||
except json.JSONDecodeError:
|
||||
logger.error('JWKSet from URL %s is invalid', self.jwkset_url)
|
||||
return
|
||||
if not isinstance(json_value, dict) or 'keys' not in json_value:
|
||||
logger.error('JWKSet from URL %s does not contain a \'keys\' entry', self.jwkset_url)
|
||||
return
|
||||
if self.jwkset_json != json_value:
|
||||
old_keyset = {key.get('kid') for key in (self.jwkset_json or dict()).get('keys', [])}
|
||||
new_keyset = {key.get('kid') for key in json_value.get('keys', [])}
|
||||
logger.debug(
|
||||
'provider %s renewed its JWKSet with new keys [%s] whereas old keys [%s] are now deprecated',
|
||||
self,
|
||||
', '.join(new_keyset - old_keyset) or '—',
|
||||
', '.join(old_keyset - new_keyset) or '—',
|
||||
)
|
||||
journal.record(
|
||||
'provider.keyset.change',
|
||||
provider=self.name,
|
||||
new_keyset=new_keyset,
|
||||
old_keyset=old_keyset,
|
||||
)
|
||||
old_jwkset = self.jwkset_json
|
||||
new_jwkset = self.load_jwkset_url()
|
||||
|
||||
# JSON is checked as part of attribute jwkset_json validation
|
||||
self.jwkset_json = json_value
|
||||
if old_jwkset == new_jwkset:
|
||||
return
|
||||
|
||||
with transaction.atomic():
|
||||
self.jwkset_json = new_jwkset
|
||||
self.log_jwkset_change(old_jwkset, new_jwkset)
|
||||
self.save(update_fields=['jwkset_json', 'modified'])
|
||||
|
||||
def log_jwkset_change(self, old_jwkset, new_jwkset):
|
||||
if old_jwkset == new_jwkset:
|
||||
return
|
||||
|
||||
old_keyset = {key.get('kid') for key in (old_jwkset or dict()).get('keys', [])}
|
||||
new_keyset = {key.get('kid') for key in new_jwkset.get('keys', [])}
|
||||
journal.record(
|
||||
'provider.keyset.change',
|
||||
provider=self.name,
|
||||
new_keyset=new_keyset,
|
||||
old_keyset=old_keyset,
|
||||
)
|
||||
|
||||
def authorization_claims_parameter(self):
|
||||
idtoken_claims = {}
|
||||
|
|
|
@ -0,0 +1,37 @@
|
|||
# authentic2 - versatile identity manager
|
||||
# Copyright (C) Entr'ouvert
|
||||
|
||||
import pytest
|
||||
import responses
|
||||
from jwcrypto.jwk import JWK, JWKSet
|
||||
|
||||
KID_RSA = '1e9gdk7'
|
||||
KID_EC = 'jb20Cg8'
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def kid_rsa():
|
||||
return KID_RSA
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def kid_ec():
|
||||
return KID_EC
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def jwkset(kid_rsa, kid_ec):
|
||||
key_rsa = JWK.generate(kty='RSA', size=512, kid=kid_rsa)
|
||||
key_ec = JWK.generate(kty='EC', size=256, kid=kid_ec)
|
||||
jwkset = JWKSet()
|
||||
jwkset.add(key_rsa)
|
||||
jwkset.add(key_ec)
|
||||
return jwkset
|
||||
|
||||
|
||||
@responses.activate
|
||||
@pytest.fixture
|
||||
def jwkset_url(jwkset):
|
||||
jwkset_url = 'https://www.example.com/common/discovery/v3.0/keys'
|
||||
responses.get(jwkset_url, json=jwkset.export(as_dict=True))
|
||||
yield jwkset_url
|
|
@ -231,7 +231,7 @@ def test_auth_oidc_refresh_jwkset_json(db, app, admin, settings, caplog):
|
|||
)
|
||||
|
||||
issuer = ('https://www.example.com',)
|
||||
provider = OIDCProvider.objects.create(
|
||||
provider = OIDCProvider(
|
||||
ou=get_default_ou(),
|
||||
name='Foo',
|
||||
slug='foo',
|
||||
|
@ -249,12 +249,15 @@ def test_auth_oidc_refresh_jwkset_json(db, app, admin, settings, caplog):
|
|||
claims_parameter_supported=False,
|
||||
button_label='Connect with Foo',
|
||||
)
|
||||
provider.clean()
|
||||
provider.save()
|
||||
assert {key['kid'] for key in provider.jwkset_json['keys']} == {'123', '456'}
|
||||
|
||||
kid_rsa = 'abcdefg'
|
||||
kid_ec = 'hijklmn'
|
||||
|
||||
responses.get(
|
||||
responses.replace(
|
||||
responses.GET,
|
||||
jwkset_url,
|
||||
json={
|
||||
'headers': {
|
||||
|
|
|
@ -19,7 +19,6 @@ import json
|
|||
import pytest
|
||||
import responses
|
||||
from django.utils.html import escape
|
||||
from jwcrypto.jwk import JWK, JWKSet
|
||||
from webtest import Upload
|
||||
|
||||
from authentic2.a2_rbac.models import Role
|
||||
|
@ -35,7 +34,7 @@ from .test_misc import oidc_provider, oidc_provider_jwkset # pylint: disable=un
|
|||
|
||||
@pytest.mark.freeze_time('2022-04-19 14:00')
|
||||
@responses.activate
|
||||
def test_authenticators_oidc(app, superuser, ou1, ou2):
|
||||
def test_authenticators_oidc(app, superuser, ou1, ou2, jwkset_url, kid_rsa):
|
||||
resp = login(app, superuser, path='/manage/authenticators/')
|
||||
|
||||
resp = resp.click('Add new authenticator')
|
||||
|
@ -68,11 +67,12 @@ def test_authenticators_oidc(app, superuser, ou1, ou2):
|
|||
resp.form['authorization_endpoint'] = 'https://oidc.example.com/authorize'
|
||||
resp.form['token_endpoint'] = 'https://oidc.example.com/token'
|
||||
resp.form['userinfo_endpoint'] = 'https://oidc.example.com/user_info'
|
||||
resp.form['idtoken_algo'] = 2
|
||||
resp.form['button_label'] = 'Test'
|
||||
resp.form['button_description'] = 'test'
|
||||
resp.form['client_id'] = 'auie'
|
||||
resp.form['client_secret'] = 'tsrn'
|
||||
resp.form['idtoken_algo'].select(text='RSA')
|
||||
resp.form['jwkset_url'] = jwkset_url
|
||||
resp = resp.form.submit().follow()
|
||||
assert_event('authenticator.edit', user=superuser, session=app.session)
|
||||
|
||||
|
@ -90,33 +90,25 @@ def test_authenticators_oidc(app, superuser, ou1, ou2):
|
|||
assert_event('authenticator.enable', user=superuser, session=app.session)
|
||||
|
||||
resp = resp.click('Journal of edits')
|
||||
assert 'enable' in resp.text
|
||||
assert (
|
||||
'edit (ou, issuer, scopes, strategy, client_id, button_label, idtoken_algo, '
|
||||
'client_secret, token_endpoint, userinfo_endpoint, button_description, authorization_endpoint)'
|
||||
in resp.text
|
||||
)
|
||||
assert 'creation' in resp.text
|
||||
|
||||
jwkset_url = 'https://www.example.com/common/discovery/v3.0/keys'
|
||||
kid_rsa = '123'
|
||||
|
||||
def generate_remote_jwkset_json():
|
||||
key_rsa = JWK.generate(kty='RSA', size=512, kid=kid_rsa)
|
||||
jwkset = JWKSet()
|
||||
jwkset.add(key_rsa)
|
||||
return jwkset.export(as_dict=True)
|
||||
|
||||
responses.get(
|
||||
jwkset_url,
|
||||
json={
|
||||
'headers': {
|
||||
'content-type': 'application/json',
|
||||
},
|
||||
'status_code': 200,
|
||||
**generate_remote_jwkset_json(),
|
||||
},
|
||||
)
|
||||
assert resp.pyquery('.journal-list--message-column:contains("creation")')
|
||||
assert resp.pyquery('.journal-list--message-column:contains("enable")')
|
||||
edit_message = resp.pyquery('.journal-list--message-column:contains("edit")').text()
|
||||
terms = {term.strip(',').strip('(').strip(')') for term in edit_message.split()}
|
||||
assert terms == {
|
||||
'edit',
|
||||
'ou',
|
||||
'issuer',
|
||||
'scopes',
|
||||
'strategy',
|
||||
'client_id',
|
||||
'button_label',
|
||||
'client_secret',
|
||||
'token_endpoint',
|
||||
'userinfo_endpoint',
|
||||
'button_description',
|
||||
'jwkset_url',
|
||||
'authorization_endpoint',
|
||||
}
|
||||
|
||||
provider.refresh_from_db()
|
||||
provider.jwkset_url = jwkset_url
|
||||
|
@ -125,7 +117,7 @@ def test_authenticators_oidc(app, superuser, ou1, ou2):
|
|||
resp = app.get('/manage/authenticators/%s/edit/' % provider.pk)
|
||||
assert resp.pyquery('input#id_jwkset_url')[0].value == jwkset_url
|
||||
assert 'disabled' in resp.pyquery('textarea#id_jwkset_json')[0].keys()
|
||||
assert '"kid": "123"' in resp.pyquery('textarea#id_jwkset_json')[0].text
|
||||
assert f'"kid": "{kid_rsa}"' in resp.pyquery('textarea#id_jwkset_json')[0].text
|
||||
assert (
|
||||
resp.pyquery('div[aria-labelledby="id_jwkset_json_title"] div.hint')[0].text
|
||||
== 'JSON is fetched from the WebKey Set URL'
|
||||
|
|
|
@ -54,10 +54,15 @@ from authentic2_auth_oidc.utils import IDToken, IDTokenError, parse_id_token, re
|
|||
from authentic2_auth_oidc.views import oidc_login
|
||||
from tests import utils
|
||||
|
||||
from .conftest import KID_EC, KID_RSA
|
||||
|
||||
pytestmark = pytest.mark.django_db
|
||||
|
||||
User = get_user_model()
|
||||
|
||||
ANOTHER_KID_RSA = 'mt80xpd'
|
||||
ANOTHER_KID_EC = 'iet7tm31'
|
||||
|
||||
|
||||
def test_base64url_decode():
|
||||
with pytest.raises(ValueError):
|
||||
|
@ -65,10 +70,6 @@ def test_base64url_decode():
|
|||
base64url_decode('aa')
|
||||
|
||||
|
||||
KID_RSA = '1e9gdk7'
|
||||
ANOTHER_KID_RSA = 'mt80xpd'
|
||||
KID_EC = 'jb20Cg8'
|
||||
ANOTHER_KID_EC = 'iet7tm31'
|
||||
JWKSET_URL = 'https://www.example.com/common/discovery/v3.0/keys'
|
||||
header_rsa_decoded = {'alg': 'RS256', 'kid': KID_RSA}
|
||||
header_ec_decoded = {'alg': 'ES256', 'kid': KID_EC}
|
||||
|
@ -121,12 +122,7 @@ def test_idtoken(oidc_provider):
|
|||
|
||||
|
||||
@pytest.fixture
|
||||
def oidc_provider_jwkset():
|
||||
key_rsa = JWK.generate(kty='RSA', size=512, kid=KID_RSA)
|
||||
key_ec = JWK.generate(kty='EC', size=256, kid=KID_EC)
|
||||
jwkset = JWKSet()
|
||||
jwkset.add(key_rsa)
|
||||
jwkset.add(key_ec)
|
||||
def oidc_provider_jwkset(jwkset):
|
||||
return jwkset
|
||||
|
||||
|
||||
|
@ -403,7 +399,7 @@ def test_oidc_provider_jwkset_url(db):
|
|||
|
||||
with HTTMock(jwkset_url_mock):
|
||||
issuer = ('https://www.example.com',)
|
||||
provider = OIDCProvider.objects.create(
|
||||
provider = OIDCProvider(
|
||||
ou=get_default_ou(),
|
||||
name='Foo',
|
||||
slug='foo',
|
||||
|
@ -421,8 +417,8 @@ def test_oidc_provider_jwkset_url(db):
|
|||
claims_parameter_supported=False,
|
||||
button_label='Connect with Foo',
|
||||
)
|
||||
assert provider.jwkset_json
|
||||
assert isinstance(provider.jwkset_json, dict)
|
||||
provider.clean()
|
||||
provider.save()
|
||||
assert provider.jwkset
|
||||
assert len(provider.jwkset_json['keys']) == 2
|
||||
assert {key['kid'] for key in provider.jwkset_json['keys']} == {ANOTHER_KID_RSA, ANOTHER_KID_EC}
|
||||
|
|
Loading…
Reference in New Issue