auth_oidc: add an oidc-sync-provider command (#62710)
gitea/authentic/pipeline/head Something is wrong with the build of this commit Details

This commit is contained in:
Paul Marillonnet 2022-05-30 17:59:15 +02:00
parent 2aeb5bad51
commit 73ac9f079a
4 changed files with 298 additions and 3 deletions

View File

@ -0,0 +1,45 @@
# authentic2 - versatile identity manager
# Copyright (C) 2010-2022 Entr'ouvert
#
# This program is free software: you can redistribute it and/or modify it
# under the terms of the GNU Affero General Public License as published
# by the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
import logging
from authentic2.base_commands import LogToConsoleCommand
from authentic2_auth_oidc.models import OIDCProvider
class Command(LogToConsoleCommand):
loggername = 'authentic2_auth_oidc.models'
def add_arguments(self, parser):
parser.add_argument('--provider', type=str, default=None)
def core_command(self, *args, **kwargs):
provider = kwargs['provider']
logger = logging.getLogger(self.loggername)
providers = OIDCProvider.objects.filter(a2_synchronization_supported=True)
if provider:
providers = providers.filter(slug=provider)
if not providers.count():
logger.error('no provider supporting synchronization found, exiting')
return
logger.info(
'got %s provider(s): %s',
providers.count(),
' '.join(providers.values_list('slug', flat=True)),
)
for provider in providers:
provider.perform_synchronization()

View File

@ -0,0 +1,23 @@
# Generated by Django 2.2.26 on 2022-08-03 09:30
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
('authentic2_auth_oidc', '0016_auto_20221019_1148'),
]
operations = [
migrations.AddField(
model_name='oidcprovider',
name='a2_synchronization_supported',
field=models.BooleanField(default=False, verbose_name='Authentic2 synchronization supported'),
),
migrations.AddField(
model_name='oidcprovider',
name='last_sync_time',
field=models.DateTimeField(blank=True, null=True, verbose_name='Last synchronization time'),
),
]

View File

@ -15,12 +15,16 @@
# along with this program. If not, see <http://www.gnu.org/licenses/>.
import json
import logging
from datetime import datetime, timedelta
import django
import requests
from django.conf import settings
from django.core.exceptions import ValidationError
from django.db import models
from django.shortcuts import render
from django.utils.timezone import now
from django.utils.translation import gettext_lazy as _
from django.utils.translation import pgettext_lazy
from jwcrypto.jwk import InvalidJWKValue, JWKSet
@ -32,7 +36,7 @@ from authentic2.apps.authenticators.models import (
BaseAuthenticator,
)
from authentic2.utils.misc import make_url
from authentic2.utils.template import validate_template
from authentic2.utils.template import Template, validate_template
from . import managers
@ -120,6 +124,17 @@ class OIDCProvider(BaseAuthenticator):
verbose_name=_('max authentication age'), blank=True, null=True
)
# authentic2 specific synchronization api
a2_synchronization_supported = models.BooleanField(
verbose_name=_('Authentic2 synchronization supported'),
default=False,
)
last_sync_time = models.DateTimeField(
verbose_name=_('Last synchronization time'),
null=True,
blank=True,
)
# metadata
created = models.DateTimeField(verbose_name=_('creation date'), auto_now_add=True)
modified = models.DateTimeField(verbose_name=_('last modification date'), auto_now=True)
@ -254,6 +269,83 @@ class OIDCProvider(BaseAuthenticator):
]
return render(request, template_names, context)
def perform_synchronization(self, sync_time=None, timeout=30):
logger = logging.getLogger(__name__)
if not self.a2_synchronization_supported:
logger.error('OIDC provider %s does not support synchronization', self.slug)
return
if not sync_time:
sync_time = now() - timedelta(minutes=1)
# check all existing users
def chunks(l, n):
for i in range(0, len(l), n):
yield l[i : i + n]
url = self.issuer + '/api/users/synchronization/'
unknown_uuids = []
auth = (self.client_id, self.client_secret)
for accounts in chunks(OIDCAccount.objects.filter(provider=self), 100):
subs = [x.sub for x in accounts]
resp = requests.post(url, json={'known_uuids': subs}, auth=auth, timeout=timeout)
resp.raise_for_status()
unknown_uuids.extend(resp.json().get('unknown_uuids'))
deletion_ratio = len(unknown_uuids) / OIDCAccount.objects.filter(provider=self).count()
if deletion_ratio > 0.05: # higher than 5%, something definitely went wrong
logger.error(
'deletion ratio is abnormally high (%s), aborting unkwown users deletion', deletion_ratio
)
else:
OIDCAccount.objects.filter(sub__in=unknown_uuids).delete()
# update recently modified users
url = self.issuer + '/api/users/?modified__gt=%s&claim_resolution' % (
self.last_sync_time or datetime.utcfromtimestamp(0)
).strftime('%Y-%m-%dT%H:%M:%S')
while url:
resp = requests.get(url, auth=auth, timeout=timeout)
resp.raise_for_status()
url = resp.json().get('next')
logger.info('got %s users', len(resp.json()['results']))
for user_dict in resp.json()['results']:
if not user_dict.get('sub', None):
continue
try:
account = OIDCAccount.objects.get(sub=user_dict['sub'])
except OIDCAccount.DoesNotExist:
continue
except OIDCAccount.MultipleObjectsReturned:
continue
had_changes = False
for claim in self.claim_mappings.all():
if '{{' in claim.claim or '{%' in claim.claim:
template = Template(claim.claim)
attribute_value = template.render(context=user_dict)
else:
attribute_value = user_dict.get(claim.claim)
try:
old_attribute_value = getattr(account.user, claim.attribute)
except AttributeError:
try:
old_attribute_value = getattr(account.user.attributes, claim.attribute)
except AttributeError:
old_attribute_value = None
if old_attribute_value == attribute_value:
continue
had_changes = True
setattr(account.user, claim.attribute, attribute_value)
try:
setattr(account.user.attributes, claim.attribute, attribute_value)
except AttributeError:
pass
if had_changes:
logger.debug('had changes, saving %r', account.user)
account.user.save()
self.last_sync_time = sync_time
self.save(update_fields=['last_sync_time'])
class OIDCClaimMapping(AuthenticatorRelatedObjectBase):
NOT_VERIFIED = 0

View File

@ -17,8 +17,11 @@
import datetime
import importlib
import json
import random
import uuid
from io import BufferedReader, BufferedWriter, TextIOWrapper
import httmock
import py
import pytest
import webtest
@ -40,9 +43,10 @@ from authentic2.a2_rbac.utils import get_default_ou, get_operation
from authentic2.apps.journal.models import Event
from authentic2.custom_user.models import DeletedUser
from authentic2.models import UserExternalId
from authentic2_auth_oidc.models import OIDCAccount, OIDCProvider
from authentic2.utils import crypto
from authentic2_auth_oidc.models import OIDCAccount, OIDCClaimMapping, OIDCProvider
from .utils import call_command, login
from .utils import call_command, check_log, login
User = get_user_model()
@ -520,3 +524,134 @@ def test_clean_user_exports(settings, app, superuser, freezer):
call_command('clean-user-exports')
with pytest.raises(webtest.app.AppError):
resp.click('Download CSV')
@pytest.mark.parametrize('deletion_number,deletion_valid', [(2, True), (5, True), (10, False)])
def test_oidc_sync_provider(db, app, admin, settings, caplog, deletion_number, deletion_valid):
oidc_provider = OIDCProvider.objects.create(
issuer='https://some.provider',
name='Some Provider',
slug='some-provider',
ou=get_default_ou(),
)
OIDCClaimMapping.objects.create(
authenticator=oidc_provider,
attribute='username',
idtoken_claim=False,
claim='username',
)
OIDCClaimMapping.objects.create(
authenticator=oidc_provider,
attribute='email',
idtoken_claim=False,
claim='email',
)
# last one, with an idtoken claim
OIDCClaimMapping.objects.create(
authenticator=oidc_provider,
attribute='last_name',
idtoken_claim=True,
claim='family_name',
)
# typo in template string
OIDCClaimMapping.objects.create(
authenticator=oidc_provider,
attribute='first_name',
idtoken_claim=True,
claim='given_name',
)
User = get_user_model()
for i in range(100):
user = User.objects.create(
first_name='John%s' % i,
last_name='Doe%s' % i,
username='john.doe.%s' % i,
email='john.doe.%s@ad.dre.ss' % i,
ou=get_default_ou(),
)
identifier = uuid.UUID(user.uuid).bytes
sector_identifier = 'some.provider'
cipher_args = [
settings.SECRET_KEY.encode('utf-8'),
identifier,
sector_identifier,
]
sub = crypto.aes_base64url_deterministic_encrypt(*cipher_args).decode('utf-8')
OIDCAccount.objects.create(user=user, provider=oidc_provider, sub=sub)
def synchronization_post_deletion_response(url, request):
headers = {'content-type': 'application/json'}
content = {
'unknown_uuids': [
account.sub for account in random.sample(list(OIDCAccount.objects.all()), deletion_number)
]
}
return httmock.response(status_code=200, headers=headers, content=content, request=request)
def synchronization_get_modified_response(url, request):
headers = {'content-type': 'application/json'}
# randomized batch of modified users
modified_users = random.sample(list(User.objects.all()), 20)
results = []
for count, user in enumerate(modified_users):
user_json = user.to_json()
user_json['username'] = f'modified_{count}'
user_json['first_name'] = 'Mod'
user_json['last_name'] = 'Ified'
# mocking claim resolution by oidc provider
user_json['given_name'] = 'Mod'
user_json['family_name'] = 'Ified'
# add user sub to response
try:
account = OIDCAccount.objects.get(user=user)
except OIDCAccount.DoesNotExist:
pass
else:
user_json['sub'] = account.sub
results.append(user_json)
content = {'results': results}
return httmock.response(status_code=200, headers=headers, content=content, request=request)
with httmock.HTTMock(
httmock.urlmatch(
netloc=r'some\.provider',
path=r'^/api/users/synchronization/$',
method='POST',
)(synchronization_post_deletion_response)
):
with httmock.HTTMock(
httmock.urlmatch(
netloc=r'some\.provider',
path=r'^/api/users/*',
method='GET',
)(synchronization_get_modified_response)
):
with check_log(caplog, 'no provider supporting synchronization'):
call_command('oidc-sync-provider', '-v1')
oidc_provider.a2_synchronization_supported = True
oidc_provider.save()
with check_log(caplog, 'no provider supporting synchronization'):
call_command('oidc-sync-provider', '--provider', 'whatever', '-v1')
with check_log(caplog, 'got 20 users'):
call_command('oidc-sync-provider', '-v1')
if deletion_valid:
# existing users check
assert OIDCAccount.objects.count() == 100 - deletion_number
else:
assert OIDCAccount.objects.count() == 100
assert caplog.records[3].levelname == 'ERROR'
assert 'deletion ratio is abnormally high' in caplog.records[3].message
# users update
assert User.objects.filter(username__startswith='modified').count() in range(
20 - deletion_number, 21
)
assert User.objects.filter(first_name='Mod', last_name='Ified').count() in range(
20 - deletion_number, 21
)