auth_oidc: add an oidc-sync-provider command (#62710)
gitea/authentic/pipeline/head Something is wrong with the build of this commit
Details
gitea/authentic/pipeline/head Something is wrong with the build of this commit
Details
This commit is contained in:
parent
2aeb5bad51
commit
73ac9f079a
|
@ -0,0 +1,45 @@
|
|||
# authentic2 - versatile identity manager
|
||||
# Copyright (C) 2010-2022 Entr'ouvert
|
||||
#
|
||||
# This program is free software: you can redistribute it and/or modify it
|
||||
# under the terms of the GNU Affero General Public License as published
|
||||
# by the Free Software Foundation, either version 3 of the License, or
|
||||
# (at your option) any later version.
|
||||
#
|
||||
# This program is distributed in the hope that it will be useful,
|
||||
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
# GNU Affero General Public License for more details.
|
||||
#
|
||||
# You should have received a copy of the GNU Affero General Public License
|
||||
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
|
||||
import logging
|
||||
|
||||
from authentic2.base_commands import LogToConsoleCommand
|
||||
from authentic2_auth_oidc.models import OIDCProvider
|
||||
|
||||
|
||||
class Command(LogToConsoleCommand):
|
||||
loggername = 'authentic2_auth_oidc.models'
|
||||
|
||||
def add_arguments(self, parser):
|
||||
parser.add_argument('--provider', type=str, default=None)
|
||||
|
||||
def core_command(self, *args, **kwargs):
|
||||
provider = kwargs['provider']
|
||||
|
||||
logger = logging.getLogger(self.loggername)
|
||||
providers = OIDCProvider.objects.filter(a2_synchronization_supported=True)
|
||||
if provider:
|
||||
providers = providers.filter(slug=provider)
|
||||
if not providers.count():
|
||||
logger.error('no provider supporting synchronization found, exiting')
|
||||
return
|
||||
logger.info(
|
||||
'got %s provider(s): %s',
|
||||
providers.count(),
|
||||
' '.join(providers.values_list('slug', flat=True)),
|
||||
)
|
||||
for provider in providers:
|
||||
provider.perform_synchronization()
|
|
@ -0,0 +1,23 @@
|
|||
# Generated by Django 2.2.26 on 2022-08-03 09:30
|
||||
|
||||
from django.db import migrations, models
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
|
||||
dependencies = [
|
||||
('authentic2_auth_oidc', '0016_auto_20221019_1148'),
|
||||
]
|
||||
|
||||
operations = [
|
||||
migrations.AddField(
|
||||
model_name='oidcprovider',
|
||||
name='a2_synchronization_supported',
|
||||
field=models.BooleanField(default=False, verbose_name='Authentic2 synchronization supported'),
|
||||
),
|
||||
migrations.AddField(
|
||||
model_name='oidcprovider',
|
||||
name='last_sync_time',
|
||||
field=models.DateTimeField(blank=True, null=True, verbose_name='Last synchronization time'),
|
||||
),
|
||||
]
|
|
@ -15,12 +15,16 @@
|
|||
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
|
||||
import json
|
||||
import logging
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
import django
|
||||
import requests
|
||||
from django.conf import settings
|
||||
from django.core.exceptions import ValidationError
|
||||
from django.db import models
|
||||
from django.shortcuts import render
|
||||
from django.utils.timezone import now
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
from django.utils.translation import pgettext_lazy
|
||||
from jwcrypto.jwk import InvalidJWKValue, JWKSet
|
||||
|
@ -32,7 +36,7 @@ from authentic2.apps.authenticators.models import (
|
|||
BaseAuthenticator,
|
||||
)
|
||||
from authentic2.utils.misc import make_url
|
||||
from authentic2.utils.template import validate_template
|
||||
from authentic2.utils.template import Template, validate_template
|
||||
|
||||
from . import managers
|
||||
|
||||
|
@ -120,6 +124,17 @@ class OIDCProvider(BaseAuthenticator):
|
|||
verbose_name=_('max authentication age'), blank=True, null=True
|
||||
)
|
||||
|
||||
# authentic2 specific synchronization api
|
||||
a2_synchronization_supported = models.BooleanField(
|
||||
verbose_name=_('Authentic2 synchronization supported'),
|
||||
default=False,
|
||||
)
|
||||
last_sync_time = models.DateTimeField(
|
||||
verbose_name=_('Last synchronization time'),
|
||||
null=True,
|
||||
blank=True,
|
||||
)
|
||||
|
||||
# metadata
|
||||
created = models.DateTimeField(verbose_name=_('creation date'), auto_now_add=True)
|
||||
modified = models.DateTimeField(verbose_name=_('last modification date'), auto_now=True)
|
||||
|
@ -254,6 +269,83 @@ class OIDCProvider(BaseAuthenticator):
|
|||
]
|
||||
return render(request, template_names, context)
|
||||
|
||||
def perform_synchronization(self, sync_time=None, timeout=30):
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
if not self.a2_synchronization_supported:
|
||||
logger.error('OIDC provider %s does not support synchronization', self.slug)
|
||||
return
|
||||
if not sync_time:
|
||||
sync_time = now() - timedelta(minutes=1)
|
||||
|
||||
# check all existing users
|
||||
def chunks(l, n):
|
||||
for i in range(0, len(l), n):
|
||||
yield l[i : i + n]
|
||||
|
||||
url = self.issuer + '/api/users/synchronization/'
|
||||
|
||||
unknown_uuids = []
|
||||
auth = (self.client_id, self.client_secret)
|
||||
for accounts in chunks(OIDCAccount.objects.filter(provider=self), 100):
|
||||
subs = [x.sub for x in accounts]
|
||||
resp = requests.post(url, json={'known_uuids': subs}, auth=auth, timeout=timeout)
|
||||
resp.raise_for_status()
|
||||
unknown_uuids.extend(resp.json().get('unknown_uuids'))
|
||||
deletion_ratio = len(unknown_uuids) / OIDCAccount.objects.filter(provider=self).count()
|
||||
if deletion_ratio > 0.05: # higher than 5%, something definitely went wrong
|
||||
logger.error(
|
||||
'deletion ratio is abnormally high (%s), aborting unkwown users deletion', deletion_ratio
|
||||
)
|
||||
else:
|
||||
OIDCAccount.objects.filter(sub__in=unknown_uuids).delete()
|
||||
|
||||
# update recently modified users
|
||||
url = self.issuer + '/api/users/?modified__gt=%s&claim_resolution' % (
|
||||
self.last_sync_time or datetime.utcfromtimestamp(0)
|
||||
).strftime('%Y-%m-%dT%H:%M:%S')
|
||||
while url:
|
||||
resp = requests.get(url, auth=auth, timeout=timeout)
|
||||
resp.raise_for_status()
|
||||
url = resp.json().get('next')
|
||||
logger.info('got %s users', len(resp.json()['results']))
|
||||
for user_dict in resp.json()['results']:
|
||||
if not user_dict.get('sub', None):
|
||||
continue
|
||||
try:
|
||||
account = OIDCAccount.objects.get(sub=user_dict['sub'])
|
||||
except OIDCAccount.DoesNotExist:
|
||||
continue
|
||||
except OIDCAccount.MultipleObjectsReturned:
|
||||
continue
|
||||
had_changes = False
|
||||
for claim in self.claim_mappings.all():
|
||||
if '{{' in claim.claim or '{%' in claim.claim:
|
||||
template = Template(claim.claim)
|
||||
attribute_value = template.render(context=user_dict)
|
||||
else:
|
||||
attribute_value = user_dict.get(claim.claim)
|
||||
try:
|
||||
old_attribute_value = getattr(account.user, claim.attribute)
|
||||
except AttributeError:
|
||||
try:
|
||||
old_attribute_value = getattr(account.user.attributes, claim.attribute)
|
||||
except AttributeError:
|
||||
old_attribute_value = None
|
||||
if old_attribute_value == attribute_value:
|
||||
continue
|
||||
had_changes = True
|
||||
setattr(account.user, claim.attribute, attribute_value)
|
||||
try:
|
||||
setattr(account.user.attributes, claim.attribute, attribute_value)
|
||||
except AttributeError:
|
||||
pass
|
||||
if had_changes:
|
||||
logger.debug('had changes, saving %r', account.user)
|
||||
account.user.save()
|
||||
self.last_sync_time = sync_time
|
||||
self.save(update_fields=['last_sync_time'])
|
||||
|
||||
|
||||
class OIDCClaimMapping(AuthenticatorRelatedObjectBase):
|
||||
NOT_VERIFIED = 0
|
||||
|
|
|
@ -17,8 +17,11 @@
|
|||
import datetime
|
||||
import importlib
|
||||
import json
|
||||
import random
|
||||
import uuid
|
||||
from io import BufferedReader, BufferedWriter, TextIOWrapper
|
||||
|
||||
import httmock
|
||||
import py
|
||||
import pytest
|
||||
import webtest
|
||||
|
@ -40,9 +43,10 @@ from authentic2.a2_rbac.utils import get_default_ou, get_operation
|
|||
from authentic2.apps.journal.models import Event
|
||||
from authentic2.custom_user.models import DeletedUser
|
||||
from authentic2.models import UserExternalId
|
||||
from authentic2_auth_oidc.models import OIDCAccount, OIDCProvider
|
||||
from authentic2.utils import crypto
|
||||
from authentic2_auth_oidc.models import OIDCAccount, OIDCClaimMapping, OIDCProvider
|
||||
|
||||
from .utils import call_command, login
|
||||
from .utils import call_command, check_log, login
|
||||
|
||||
User = get_user_model()
|
||||
|
||||
|
@ -520,3 +524,134 @@ def test_clean_user_exports(settings, app, superuser, freezer):
|
|||
call_command('clean-user-exports')
|
||||
with pytest.raises(webtest.app.AppError):
|
||||
resp.click('Download CSV')
|
||||
|
||||
|
||||
@pytest.mark.parametrize('deletion_number,deletion_valid', [(2, True), (5, True), (10, False)])
|
||||
def test_oidc_sync_provider(db, app, admin, settings, caplog, deletion_number, deletion_valid):
|
||||
oidc_provider = OIDCProvider.objects.create(
|
||||
issuer='https://some.provider',
|
||||
name='Some Provider',
|
||||
slug='some-provider',
|
||||
ou=get_default_ou(),
|
||||
)
|
||||
OIDCClaimMapping.objects.create(
|
||||
authenticator=oidc_provider,
|
||||
attribute='username',
|
||||
idtoken_claim=False,
|
||||
claim='username',
|
||||
)
|
||||
OIDCClaimMapping.objects.create(
|
||||
authenticator=oidc_provider,
|
||||
attribute='email',
|
||||
idtoken_claim=False,
|
||||
claim='email',
|
||||
)
|
||||
# last one, with an idtoken claim
|
||||
OIDCClaimMapping.objects.create(
|
||||
authenticator=oidc_provider,
|
||||
attribute='last_name',
|
||||
idtoken_claim=True,
|
||||
claim='family_name',
|
||||
)
|
||||
# typo in template string
|
||||
OIDCClaimMapping.objects.create(
|
||||
authenticator=oidc_provider,
|
||||
attribute='first_name',
|
||||
idtoken_claim=True,
|
||||
claim='given_name',
|
||||
)
|
||||
User = get_user_model()
|
||||
for i in range(100):
|
||||
user = User.objects.create(
|
||||
first_name='John%s' % i,
|
||||
last_name='Doe%s' % i,
|
||||
username='john.doe.%s' % i,
|
||||
email='john.doe.%s@ad.dre.ss' % i,
|
||||
ou=get_default_ou(),
|
||||
)
|
||||
identifier = uuid.UUID(user.uuid).bytes
|
||||
sector_identifier = 'some.provider'
|
||||
cipher_args = [
|
||||
settings.SECRET_KEY.encode('utf-8'),
|
||||
identifier,
|
||||
sector_identifier,
|
||||
]
|
||||
sub = crypto.aes_base64url_deterministic_encrypt(*cipher_args).decode('utf-8')
|
||||
OIDCAccount.objects.create(user=user, provider=oidc_provider, sub=sub)
|
||||
|
||||
def synchronization_post_deletion_response(url, request):
|
||||
headers = {'content-type': 'application/json'}
|
||||
content = {
|
||||
'unknown_uuids': [
|
||||
account.sub for account in random.sample(list(OIDCAccount.objects.all()), deletion_number)
|
||||
]
|
||||
}
|
||||
return httmock.response(status_code=200, headers=headers, content=content, request=request)
|
||||
|
||||
def synchronization_get_modified_response(url, request):
|
||||
headers = {'content-type': 'application/json'}
|
||||
# randomized batch of modified users
|
||||
modified_users = random.sample(list(User.objects.all()), 20)
|
||||
results = []
|
||||
for count, user in enumerate(modified_users):
|
||||
user_json = user.to_json()
|
||||
user_json['username'] = f'modified_{count}'
|
||||
user_json['first_name'] = 'Mod'
|
||||
user_json['last_name'] = 'Ified'
|
||||
# mocking claim resolution by oidc provider
|
||||
user_json['given_name'] = 'Mod'
|
||||
user_json['family_name'] = 'Ified'
|
||||
|
||||
# add user sub to response
|
||||
try:
|
||||
account = OIDCAccount.objects.get(user=user)
|
||||
except OIDCAccount.DoesNotExist:
|
||||
pass
|
||||
else:
|
||||
user_json['sub'] = account.sub
|
||||
|
||||
results.append(user_json)
|
||||
content = {'results': results}
|
||||
return httmock.response(status_code=200, headers=headers, content=content, request=request)
|
||||
|
||||
with httmock.HTTMock(
|
||||
httmock.urlmatch(
|
||||
netloc=r'some\.provider',
|
||||
path=r'^/api/users/synchronization/$',
|
||||
method='POST',
|
||||
)(synchronization_post_deletion_response)
|
||||
):
|
||||
|
||||
with httmock.HTTMock(
|
||||
httmock.urlmatch(
|
||||
netloc=r'some\.provider',
|
||||
path=r'^/api/users/*',
|
||||
method='GET',
|
||||
)(synchronization_get_modified_response)
|
||||
):
|
||||
with check_log(caplog, 'no provider supporting synchronization'):
|
||||
call_command('oidc-sync-provider', '-v1')
|
||||
|
||||
oidc_provider.a2_synchronization_supported = True
|
||||
oidc_provider.save()
|
||||
|
||||
with check_log(caplog, 'no provider supporting synchronization'):
|
||||
call_command('oidc-sync-provider', '--provider', 'whatever', '-v1')
|
||||
|
||||
with check_log(caplog, 'got 20 users'):
|
||||
call_command('oidc-sync-provider', '-v1')
|
||||
if deletion_valid:
|
||||
# existing users check
|
||||
assert OIDCAccount.objects.count() == 100 - deletion_number
|
||||
else:
|
||||
assert OIDCAccount.objects.count() == 100
|
||||
assert caplog.records[3].levelname == 'ERROR'
|
||||
assert 'deletion ratio is abnormally high' in caplog.records[3].message
|
||||
|
||||
# users update
|
||||
assert User.objects.filter(username__startswith='modified').count() in range(
|
||||
20 - deletion_number, 21
|
||||
)
|
||||
assert User.objects.filter(first_name='Mod', last_name='Ified').count() in range(
|
||||
20 - deletion_number, 21
|
||||
)
|
||||
|
|
Loading…
Reference in New Issue