auth_oidc: rewrite loading of jwkset by URL (#85934)
gitea/authentic/pipeline/head This commit looks good Details

Use the new common HTTP API.
This commit is contained in:
Benjamin Dauvergne 2024-01-23 21:50:41 +01:00
parent eceb4b2424
commit e63b9c2898
7 changed files with 143 additions and 106 deletions

View File

@ -39,6 +39,11 @@ class OIDCProviderEditForm(forms.ModelForm):
if self.instance.jwkset_url:
self.fields['jwkset_json'].disabled = True
self.fields['jwkset_json'].help_text = _('JSON is fetched from the WebKey Set URL')
self.old_jwkset = self.instance.jwkset_json or {}
def save(self, commit=True):
super().save(commit=commit)
self.instance.log_jwkset_change(self.old_jwkset, self.instance.jwkset_json)
class OIDCProviderAdvancedForm(forms.ModelForm):

View File

@ -19,20 +19,15 @@ import logging
from authentic2.base_commands import LogToConsoleCommand
from authentic2_auth_oidc.models import OIDCProvider
logger = logging.getLogger('authentic2.auth.oidc')
class Command(LogToConsoleCommand):
loggername = 'authentic2_auth_oidc.models'
def core_command(self, *args, **kwargs):
logger = logging.getLogger(self.loggername)
providers = OIDCProvider.objects.exclude(jwkset_url='')
if not providers.count():
return
logger.info(
'got %s provider(s): %s',
providers.count(),
' '.join(providers.values_list('slug', flat=True)),
)
for provider in providers:
provider.set_jwkset_json_from_url()
provider.save()
for oidc_provider in OIDCProvider.objects.all():
try:
oidc_provider.refresh_jwkset_json()
except Exception as e:
logger.warning('auth_oidc: could not refresh jwkset for provider %s (%s)', oidc_provider, e)

View File

@ -21,7 +21,7 @@ from datetime import datetime, timedelta
import requests
from django.conf import settings
from django.core.exceptions import ValidationError
from django.db import models
from django.db import models, transaction
from django.db.models import JSONField
from django.shortcuts import render
from django.utils.timezone import now
@ -36,18 +36,22 @@ from authentic2.apps.authenticators.models import (
BaseAuthenticator,
)
from authentic2.apps.journal.journal import journal
from authentic2.utils import http
from authentic2.utils.misc import make_url
from authentic2.utils.template import validate_template
from . import managers, utils
def validate_jwkset(data):
data = json.dumps(data)
def parse_jwkset(data):
try:
JWKSet.from_json(data)
except InvalidJWKValue as e:
raise ValidationError(_('Invalid JWKSet: %s') % e)
return JWKSet.from_json(data)
except InvalidJWKValue:
raise ValidationError(_('Invalid JWKSet'))
def validate_jwkset(data):
parse_jwkset(json.dumps(data))
class OIDCProvider(BaseAuthenticator):
@ -190,7 +194,10 @@ class OIDCProvider(BaseAuthenticator):
@property
def jwkset(self):
if self.jwkset_json:
return JWKSet.from_json(json.dumps(self.jwkset_json))
try:
return parse_jwkset(json.dumps(self.jwkset_json))
except ValidationError:
pass
return None
def get_short_description(self):
@ -213,18 +220,25 @@ class OIDCProvider(BaseAuthenticator):
def clean(self):
super().clean()
if self.jwkset_url:
try:
self.jwkset_json = self.load_jwkset_url()
except ValidationError as e:
raise ValidationError({'jwkset_url': e})
if self.idtoken_algo not in (self.ALGO_NONE, self.ALGO_HMAC):
key_sig_mapping = {
self.ALGO_RSA: 'RSA',
self.ALGO_EC: 'EC',
}
if not self.jwkset_json:
jwkset = self.jwkset
if not jwkset:
raise ValidationError(
_('Provider signature method is %s yet no jwkset was provided.')
% key_sig_mapping[self.idtoken_algo]
)
# verify that a key is available for the chosen algorithm
for key in self.jwkset:
for key in jwkset['keys']:
# compatibility with jwcrypto < 1
key_type = key.get('kty', None) if isinstance(key, dict) else key.key_type
if key_type == key_sig_mapping[self.idtoken_algo]:
@ -240,51 +254,46 @@ class OIDCProvider(BaseAuthenticator):
def save(self, *args, **kwargs):
if not self.ou:
self.ou = get_default_ou()
if self.jwkset_url:
self.set_jwkset_json_from_url()
if self.jwkset_url and not self.jwkset_json:
raise ValueError('model is not initialized')
if self.jwkset_json:
validate_jwkset(self.jwkset_json)
return super().save(*args, **kwargs)
def set_jwkset_json_from_url(self):
logger = logging.getLogger(__name__)
def load_jwkset_url(self):
try:
response = requests.get(
self.jwkset_url,
timeout=settings.REQUESTS_TIMEOUT,
)
response.raise_for_status()
except requests.RequestException:
logger.error('Unable to reach JWKSet content from URL %s', self.jwkset_url)
response = http.get(self.jwkset_url)
except http.HTTPError as e:
raise ValidationError(_('JWKSet URL is unreachable: %s') % e)
return parse_jwkset(response.content).export(as_dict=True)
def refresh_jwkset_json(self):
if not self.jwkset_url:
return
if not hasattr(response, 'json'):
logger.error('JWKSet URL %s is not JSON', self.jwkset_url)
return
try:
json_value = response.json()
except json.JSONDecodeError:
logger.error('JWKSet from URL %s is invalid', self.jwkset_url)
return
if not isinstance(json_value, dict) or 'keys' not in json_value:
logger.error('JWKSet from URL %s does not contain a \'keys\' entry', self.jwkset_url)
return
if self.jwkset_json != json_value:
old_keyset = {key.get('kid') for key in (self.jwkset_json or dict()).get('keys', [])}
new_keyset = {key.get('kid') for key in json_value.get('keys', [])}
logger.debug(
'provider %s renewed its JWKSet with new keys [%s] whereas old keys [%s] are now deprecated',
self,
', '.join(new_keyset - old_keyset) or '',
', '.join(old_keyset - new_keyset) or '',
)
journal.record(
'provider.keyset.change',
provider=self.name,
new_keyset=new_keyset,
old_keyset=old_keyset,
)
old_jwkset = self.jwkset_json
new_jwkset = self.load_jwkset_url()
# JSON is checked as part of attribute jwkset_json validation
self.jwkset_json = json_value
if old_jwkset == new_jwkset:
return
with transaction.atomic():
self.jwkset_json = new_jwkset
self.log_jwkset_change(old_jwkset, new_jwkset)
self.save(update_fields=['jwkset_json', 'modified'])
def log_jwkset_change(self, old_jwkset, new_jwkset):
if old_jwkset == new_jwkset:
return
old_keyset = {key.get('kid') for key in (old_jwkset or dict()).get('keys', [])}
new_keyset = {key.get('kid') for key in new_jwkset.get('keys', [])}
journal.record(
'provider.keyset.change',
provider=self.name,
new_keyset=new_keyset,
old_keyset=old_keyset,
)
def authorization_claims_parameter(self):
idtoken_claims = {}

View File

@ -0,0 +1,37 @@
# authentic2 - versatile identity manager
# Copyright (C) Entr'ouvert
import pytest
import responses
from jwcrypto.jwk import JWK, JWKSet
KID_RSA = '1e9gdk7'
KID_EC = 'jb20Cg8'
@pytest.fixture
def kid_rsa():
return KID_RSA
@pytest.fixture
def kid_ec():
return KID_EC
@pytest.fixture
def jwkset(kid_rsa, kid_ec):
key_rsa = JWK.generate(kty='RSA', size=512, kid=kid_rsa)
key_ec = JWK.generate(kty='EC', size=256, kid=kid_ec)
jwkset = JWKSet()
jwkset.add(key_rsa)
jwkset.add(key_ec)
return jwkset
@responses.activate
@pytest.fixture
def jwkset_url(jwkset):
jwkset_url = 'https://www.example.com/common/discovery/v3.0/keys'
responses.get(jwkset_url, json=jwkset.export(as_dict=True))
yield jwkset_url

View File

@ -231,7 +231,7 @@ def test_auth_oidc_refresh_jwkset_json(db, app, admin, settings, caplog):
)
issuer = ('https://www.example.com',)
provider = OIDCProvider.objects.create(
provider = OIDCProvider(
ou=get_default_ou(),
name='Foo',
slug='foo',
@ -249,12 +249,15 @@ def test_auth_oidc_refresh_jwkset_json(db, app, admin, settings, caplog):
claims_parameter_supported=False,
button_label='Connect with Foo',
)
provider.clean()
provider.save()
assert {key['kid'] for key in provider.jwkset_json['keys']} == {'123', '456'}
kid_rsa = 'abcdefg'
kid_ec = 'hijklmn'
responses.get(
responses.replace(
responses.GET,
jwkset_url,
json={
'headers': {

View File

@ -19,7 +19,6 @@ import json
import pytest
import responses
from django.utils.html import escape
from jwcrypto.jwk import JWK, JWKSet
from webtest import Upload
from authentic2.a2_rbac.models import Role
@ -35,7 +34,7 @@ from .test_misc import oidc_provider, oidc_provider_jwkset # pylint: disable=un
@pytest.mark.freeze_time('2022-04-19 14:00')
@responses.activate
def test_authenticators_oidc(app, superuser, ou1, ou2):
def test_authenticators_oidc(app, superuser, ou1, ou2, jwkset_url, kid_rsa):
resp = login(app, superuser, path='/manage/authenticators/')
resp = resp.click('Add new authenticator')
@ -68,11 +67,12 @@ def test_authenticators_oidc(app, superuser, ou1, ou2):
resp.form['authorization_endpoint'] = 'https://oidc.example.com/authorize'
resp.form['token_endpoint'] = 'https://oidc.example.com/token'
resp.form['userinfo_endpoint'] = 'https://oidc.example.com/user_info'
resp.form['idtoken_algo'] = 2
resp.form['button_label'] = 'Test'
resp.form['button_description'] = 'test'
resp.form['client_id'] = 'auie'
resp.form['client_secret'] = 'tsrn'
resp.form['idtoken_algo'].select(text='RSA')
resp.form['jwkset_url'] = jwkset_url
resp = resp.form.submit().follow()
assert_event('authenticator.edit', user=superuser, session=app.session)
@ -90,33 +90,25 @@ def test_authenticators_oidc(app, superuser, ou1, ou2):
assert_event('authenticator.enable', user=superuser, session=app.session)
resp = resp.click('Journal of edits')
assert 'enable' in resp.text
assert (
'edit (ou, issuer, scopes, strategy, client_id, button_label, idtoken_algo, '
'client_secret, token_endpoint, userinfo_endpoint, button_description, authorization_endpoint)'
in resp.text
)
assert 'creation' in resp.text
jwkset_url = 'https://www.example.com/common/discovery/v3.0/keys'
kid_rsa = '123'
def generate_remote_jwkset_json():
key_rsa = JWK.generate(kty='RSA', size=512, kid=kid_rsa)
jwkset = JWKSet()
jwkset.add(key_rsa)
return jwkset.export(as_dict=True)
responses.get(
jwkset_url,
json={
'headers': {
'content-type': 'application/json',
},
'status_code': 200,
**generate_remote_jwkset_json(),
},
)
assert resp.pyquery('.journal-list--message-column:contains("creation")')
assert resp.pyquery('.journal-list--message-column:contains("enable")')
edit_message = resp.pyquery('.journal-list--message-column:contains("edit")').text()
terms = {term.strip(',').strip('(').strip(')') for term in edit_message.split()}
assert terms == {
'edit',
'ou',
'issuer',
'scopes',
'strategy',
'client_id',
'button_label',
'client_secret',
'token_endpoint',
'userinfo_endpoint',
'button_description',
'jwkset_url',
'authorization_endpoint',
}
provider.refresh_from_db()
provider.jwkset_url = jwkset_url
@ -125,7 +117,7 @@ def test_authenticators_oidc(app, superuser, ou1, ou2):
resp = app.get('/manage/authenticators/%s/edit/' % provider.pk)
assert resp.pyquery('input#id_jwkset_url')[0].value == jwkset_url
assert 'disabled' in resp.pyquery('textarea#id_jwkset_json')[0].keys()
assert '"kid": "123"' in resp.pyquery('textarea#id_jwkset_json')[0].text
assert f'"kid": "{kid_rsa}"' in resp.pyquery('textarea#id_jwkset_json')[0].text
assert (
resp.pyquery('div[aria-labelledby="id_jwkset_json_title"] div.hint')[0].text
== 'JSON is fetched from the WebKey Set URL'

View File

@ -54,10 +54,15 @@ from authentic2_auth_oidc.utils import IDToken, IDTokenError, parse_id_token, re
from authentic2_auth_oidc.views import oidc_login
from tests import utils
from .conftest import KID_EC, KID_RSA
pytestmark = pytest.mark.django_db
User = get_user_model()
ANOTHER_KID_RSA = 'mt80xpd'
ANOTHER_KID_EC = 'iet7tm31'
def test_base64url_decode():
with pytest.raises(ValueError):
@ -65,10 +70,6 @@ def test_base64url_decode():
base64url_decode('aa')
KID_RSA = '1e9gdk7'
ANOTHER_KID_RSA = 'mt80xpd'
KID_EC = 'jb20Cg8'
ANOTHER_KID_EC = 'iet7tm31'
JWKSET_URL = 'https://www.example.com/common/discovery/v3.0/keys'
header_rsa_decoded = {'alg': 'RS256', 'kid': KID_RSA}
header_ec_decoded = {'alg': 'ES256', 'kid': KID_EC}
@ -121,12 +122,7 @@ def test_idtoken(oidc_provider):
@pytest.fixture
def oidc_provider_jwkset():
key_rsa = JWK.generate(kty='RSA', size=512, kid=KID_RSA)
key_ec = JWK.generate(kty='EC', size=256, kid=KID_EC)
jwkset = JWKSet()
jwkset.add(key_rsa)
jwkset.add(key_ec)
def oidc_provider_jwkset(jwkset):
return jwkset
@ -403,7 +399,7 @@ def test_oidc_provider_jwkset_url(db):
with HTTMock(jwkset_url_mock):
issuer = ('https://www.example.com',)
provider = OIDCProvider.objects.create(
provider = OIDCProvider(
ou=get_default_ou(),
name='Foo',
slug='foo',
@ -421,8 +417,8 @@ def test_oidc_provider_jwkset_url(db):
claims_parameter_supported=False,
button_label='Connect with Foo',
)
assert provider.jwkset_json
assert isinstance(provider.jwkset_json, dict)
provider.clean()
provider.save()
assert provider.jwkset
assert len(provider.jwkset_json['keys']) == 2
assert {key['kid'] for key in provider.jwkset_json['keys']} == {ANOTHER_KID_RSA, ANOTHER_KID_EC}