misc: add Service.home_url field (#61735)

This commit is contained in:
Benjamin Dauvergne 2022-02-16 17:25:05 +01:00
parent fa86daeaab
commit 2c95dddb83
10 changed files with 56 additions and 17 deletions

View File

@ -20,7 +20,7 @@ from pkg_resources import get_distribution
from . import app_settings, constants
from .models import Service
from .utils import misc as utils_misc
from .utils.service import get_service
from .utils.service import get_home_url, get_service
class UserFederations:
@ -82,10 +82,5 @@ def home(request):
ctx['home_service'] = service
if service.ou:
ctx['home_ou'] = service.ou
if request.session.get('home_url'):
ctx['home_url'] = request.session['home_url']
elif service and service.ou and service.ou.home_url:
ctx['home_url'] = service.ou.home_url
else:
ctx['home_url'] = app_settings.A2_HOMEPAGE_URL or settings.LOGIN_REDIRECT_URL
ctx['home_url'] = get_home_url(request)
return ctx

View File

@ -0,0 +1,18 @@
# Generated by Django 2.2.23 on 2022-02-16 16:23
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
('authentic2', '0034_attribute_required_on_login'),
]
operations = [
migrations.AddField(
model_name='service',
name='home_url',
field=models.URLField(blank=True, max_length=256, null=True, verbose_name='Home URL'),
),
]

View File

@ -381,6 +381,7 @@ class Service(models.Model):
unauthorized_url = models.URLField(
verbose_name=_('callback url when unauthorized'), max_length=256, null=True, blank=True
)
home_url = models.URLField(verbose_name=_('Home URL'), max_length=256, null=True, blank=True)
objects = managers.ServiceManager()
@ -452,7 +453,7 @@ class Service(models.Model):
return reverse('a2-manager-service', kwargs={'service_pk': self.pk})
def get_base_urls(self):
return []
return [self.home_url] if self.home_url else []
Service._meta.natural_key = [['slug', 'ou']]

View File

@ -87,6 +87,7 @@ class LibertyProviderForm(ModelForm):
'slug',
'ou',
'unauthorized_url',
'home_url',
'entity_id',
'entity_id_sha1',
'federation_source',

View File

@ -422,6 +422,9 @@ class LibertyProvider(Service):
self.clean()
self.save()
def get_base_urls(self):
return super().get_base_urls() + [self.entity_id]
class Meta:
ordering = ('service_ptr__name',)
verbose_name = _('SAML provider')

View File

@ -18,6 +18,7 @@ import urllib.parse
from django.apps import apps
from authentic2 import app_settings
from authentic2.utils.cache import GlobalCache
from authentic2.utils.misc import same_origin
@ -61,6 +62,8 @@ def set_home_url(request, url=None):
url = select_next_url(request, default=None)
if not url or not urllib.parse.urlparse(url).netloc:
# clean saved home_url
request.session.pop('home_url', None)
return
urls_map = _base_urls_map()
for base_url, (Model, pk) in urls_map.items():
@ -82,3 +85,17 @@ def get_service(request):
else:
request._service = None
return request._service
def get_home_url(request):
service = get_service(request)
if request.session.get('home_url'):
return request.session['home_url']
elif service and service.home_url:
return service.home_url
elif service and service.ou and service.ou.home_url:
return service.ou.home_url
elif request.user.is_authenticated and request.user.ou and request.user.ou.home_url:
return request.user.ou.home_url
else:
return app_settings.A2_HOMEPAGE_URL

View File

@ -90,6 +90,7 @@ class ServiceAdmin(admin.ModelAdmin):
'slug',
'ou',
'unauthorized_url',
'home_url',
'urls',
'identifier_attribute',
'proxy',

View File

@ -72,6 +72,9 @@ class Service(LogoutUrlAbstract, BaseService):
wanted.add(attribute.attribute_name)
return list(wanted)
def get_base_urls(self):
return super().get_base_urls() + [url for url in self.get_urls() if url]
def __str__(self):
return str(self.name)

View File

@ -221,6 +221,9 @@ class OIDCClient(Service):
raise NotImplementedError('unknown self.authorization_mode %s' % self.authorization_mode)
return sector_identifier
def get_base_urls(self):
return super().get_base_urls() + [url for url in self.redirect_uris.split() if url]
def __repr__(self):
return '<OIDCClient name:%r client_id:%r identifier_policy:%r>' % (
self.name,

View File

@ -23,7 +23,7 @@ def test_home(app, settings, simple_user, service):
utils.set_service(app, service)
settings.LOGIN_REDIRECT_URL = 'https://portal1/'
settings.A2_HOMEPAGE_URL = 'https://portal1/'
resp = app.get('/login/')
body = resp.pyquery('body')
@ -33,19 +33,16 @@ def test_home(app, settings, simple_user, service):
assert body.attr('data-home-ou-slug') == service.ou.slug
assert body.attr('data-home-ou-name') == service.ou.name
settings.A2_HOMEPAGE_URL = 'https://portal2/'
service.ou.home_url = 'https://portal2/'
service.ou.save()
resp = app.get('/login/')
body = resp.pyquery('body')
assert body.attr('data-home-url') == 'https://portal2/'
service.ou.home_url = 'https://portal3/'
service.ou.save()
resp = app.get('/login/')
body = resp.pyquery('body')
assert body.attr('data-home-url') == 'https://portal3/'
# if user comes back from a different service, the information is updated
new_service = Service.objects.create(ou=get_default_ou(), slug='service2', name='Service2')
new_service = Service.objects.create(
ou=get_default_ou(), slug='service2', name='Service2', home_url='https://portal3/'
)
utils.set_service(app, new_service)
resp = app.get('/login/')