authentic/src/authentic2_idp_oidc/models.py

439 lines
15 KiB
Python

# authentic2 - versatile identity manager
# Copyright (C) 2010-2019 Entr'ouvert
#
# This program is free software: you can redistribute it and/or modify it
# under the terms of the GNU Affero General Public License as published
# by the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
import uuid
from importlib import import_module
from django.db import models
from django.core.validators import URLValidator
from django.core.exceptions import ValidationError, ImproperlyConfigured
from django.utils.translation import ugettext_lazy as _
from django.conf import settings
from django.utils import six
from django.utils.timezone import now
from django.utils.six.moves.urllib import parse as urlparse
from django.contrib.contenttypes.fields import GenericForeignKey, GenericRelation
from authentic2.a2_rbac.models import OrganizationalUnit
from authentic2.models import Service
from . import utils, managers, app_settings
def generate_uuid():
    """Return a fresh random (version 4) UUID rendered as a text string.

    Used as the ``default`` callable of the identifier/secret fields of the
    models in this module.
    """
    random_id = uuid.uuid4()
    return six.text_type(random_id)
def validate_https_url(data):
    """Validate a whitespace-separated list of URLs (model field validator).

    Despite its name, both the ``http`` and ``https`` schemes are accepted.
    Blank input (empty or whitespace only) is considered valid. Every URL is
    checked and all failures are reported together rather than stopping at
    the first one.

    Raises:
        ValidationError: wrapping the list of per-URL validation errors.
    """
    stripped = data.strip()
    if not stripped:
        return
    # URLValidator holds no per-call state, so one instance serves every URL.
    check_url = URLValidator(schemes=['http', 'https'])
    failures = []
    for candidate in stripped.split():
        try:
            check_url(candidate)
        except ValidationError as error:
            failures.append(error)
    if failures:
        raise ValidationError(failures)
def strip_words(data):
    """Normalize whitespace-separated words to one word per line.

    Any run of whitespace (spaces, tabs, newlines) acts as a separator, and
    leading/trailing whitespace is discarded. Used by ``OIDCClient.clean`` to
    store redirect URI lists in a canonical one-URI-per-line form.

    Returns:
        The words joined with ``'\\n'``; an empty string for blank input.
    """
    return u'\n'.join(data.split())
class OIDCClient(Service):
    """OpenID Connect client (relying party) registered with this provider.

    Extends ``authentic2.models.Service`` with OIDC specific settings:
    credentials, redirect URIs, subject identifier policy, ID token signature
    algorithm, supported flow and authorization mode.
    """
    # Subject identifier policies (stored in ``identifier_policy``).
    POLICY_UUID = 1
    POLICY_PAIRWISE = 2
    POLICY_EMAIL = 3
    POLICY_PAIRWISE_REVERSIBLE = 4

    IDENTIFIER_POLICIES = [
        (POLICY_UUID, _('uuid')),
        (POLICY_PAIRWISE, _('pairwise unreversible')),
        (POLICY_PAIRWISE_REVERSIBLE, _('pairwise reversible')),
        (POLICY_EMAIL, _('email')),
    ]

    # ID token signature algorithms (stored in ``idtoken_algo``).
    ALGO_RSA = 1
    ALGO_HMAC = 2
    ALGO_EC = 3
    ALGO_CHOICES = [
        (ALGO_HMAC, _('HMAC')),
        (ALGO_RSA, _('RSA')),
        (ALGO_EC, _('EC')),
    ]

    # OAuth2/OIDC flows (stored in ``authorization_flow``).
    FLOW_AUTHORIZATION_CODE = 1
    FLOW_IMPLICIT = 2
    FLOW_RESOURCE_OWNER_CRED = 3
    FLOW_CHOICES = [
        (FLOW_AUTHORIZATION_CODE, _('authorization code')),
        (FLOW_IMPLICIT, _('implicit/native')),
        (FLOW_RESOURCE_OWNER_CRED, _('resource owner password credentials')),
    ]

    # How user consent is scoped (stored in ``authorization_mode``); it also
    # selects the sector identifier source in get_sector_identifier().
    AUTHORIZATION_MODE_BY_SERVICE = 1
    AUTHORIZATION_MODE_BY_OU = 2
    AUTHORIZATION_MODE_NONE = 3

    AUTHORIZATION_MODES = [
        (AUTHORIZATION_MODE_BY_SERVICE, _('authorization by service')),
        (AUTHORIZATION_MODE_BY_OU, _('authorization by ou')),
        (AUTHORIZATION_MODE_NONE, _('none')),
    ]

    client_id = models.CharField(
        max_length=255,
        verbose_name=_('client id'),
        unique=True,
        default=generate_uuid)
    client_secret = models.CharField(
        max_length=255,
        verbose_name=_('client secret'),
        default=generate_uuid)
    idtoken_duration = models.DurationField(
        verbose_name=_('time during which the token is valid'),
        blank=True,
        null=True,
        default=None)
    access_token_duration = models.DurationField(
        verbose_name=_('time during which the access token is valid'),
        blank=True,
        null=True,
        default=None)
    authorization_mode = models.PositiveIntegerField(
        default=AUTHORIZATION_MODE_BY_SERVICE,
        choices=AUTHORIZATION_MODES,
        verbose_name=_('authorization mode'))
    authorization_flow = models.PositiveIntegerField(
        verbose_name=_('authorization flow'),
        default=FLOW_AUTHORIZATION_CODE,
        choices=FLOW_CHOICES)
    # Whitespace-separated list; normalized to one URI per line by clean().
    redirect_uris = models.TextField(
        verbose_name=_('redirect URIs'),
        validators=[validate_https_url])
    post_logout_redirect_uris = models.TextField(
        verbose_name=_('post logout redirect URIs'),
        blank=True,
        default='',
        validators=[validate_https_url])
    sector_identifier_uri = models.URLField(
        verbose_name=_('sector identifier URI'),
        blank=True)
    identifier_policy = models.PositiveIntegerField(
        verbose_name=_('identifier policy'),
        default=POLICY_PAIRWISE,
        choices=IDENTIFIER_POLICIES)
    scope = models.TextField(
        verbose_name=_('resource owner credentials grant scope'),
        help_text=_('Permitted or default scopes (for credentials grant)'),
        default='',
        blank=True)
    idtoken_algo = models.PositiveIntegerField(
        default=ALGO_HMAC,
        choices=ALGO_CHOICES,
        verbose_name=_('IDToken signature algorithm'))
    has_api_access = models.BooleanField(
        verbose_name=_('has API access'),
        default=False)
    frontchannel_logout_uri = models.URLField(
        verbose_name=_('frontchannel logout URI'),
        blank=True)
    frontchannel_timeout = models.PositiveIntegerField(
        verbose_name=_('frontchannel timeout'),
        null=True,
        blank=True)

    authorizations = GenericRelation('OIDCAuthorization',
                                     content_type_field='client_ct',
                                     object_id_field='client_id')

    # metadata
    created = models.DateTimeField(
        verbose_name=_('created'),
        auto_now_add=True)
    modified = models.DateTimeField(
        verbose_name=_('modified'),
        auto_now=True)

    def clean(self):
        """Normalize the URI lists and check cross-field constraints.

        Raises:
            ValidationError: if an RSA/EC ID token algorithm is chosen while
                ``utils.get_jwkset()`` reports the JWK set is not configured,
                or if a pairwise identifier policy is chosen but no sector
                identifier can be derived (see get_sector_identifier()).
        """
        self.redirect_uris = strip_words(self.redirect_uris)
        self.post_logout_redirect_uris = strip_words(self.post_logout_redirect_uris)
        if self.idtoken_algo in (OIDCClient.ALGO_RSA, OIDCClient.ALGO_EC):
            # Asymmetric signatures need a configured key set.
            try:
                utils.get_jwkset()
            except ImproperlyConfigured:
                raise ValidationError(
                    _('You cannot use algorithm %(algorithm)s, setting A2_IDP_OIDC_JWKSET is not defined') %
                    {'algorithm': self.get_idtoken_algo_display()})
        if self.identifier_policy in [self.POLICY_PAIRWISE, self.POLICY_PAIRWISE_REVERSIBLE]:
            try:
                self.get_sector_identifier()
            except ValueError:
                # The adjacent literals are concatenated: keep the trailing
                # space so the message does not read "pairwiseidentifiers".
                raise ValidationError(
                    _('Redirect URIs must have the same domain or you must define a '
                      'sector identifier URI if you want to use pairwise '
                      'identifiers'))

    def get_wanted_attributes(self):
        """Return the ``value`` of this client's claims as a flat values list.

        NOTE(review): ``OIDCClaim.name`` is not nullable, so the
        ``name__isnull=False`` filter appears to keep every claim; confirm
        whether non-empty names were intended.
        """
        return self.oidcclaim_set.filter(name__isnull=False).values_list('value', flat=True)

    def validate_redirect_uri(self, redirect_uri):
        """Check that ``redirect_uri`` matches one of ``redirect_uris``.

        Each declared URI is compared component by component:

        - scheme: must be equal;
        - netloc: a declared netloc starting with ``*`` matches the netloc
          with the leading ``*`` removed, or any netloc ending in ``'.'``
          followed by it; otherwise netlocs must be equal;
        - path: a declared path ending with ``*`` matches that prefix
          (trailing ``/`` ignored) or anything below it; otherwise paths must
          be equal up to a trailing ``/``;
        - query and fragment: only compared when ``redirect_uri`` has one; a
          declared value of ``*`` accepts any value.

        Returns ``None`` on the first match.

        Raises:
            ValueError: if ``redirect_uri`` exceeds
                ``app_settings.REDIRECT_URI_MAX_LENGTH`` or matches no
                declared URI.
        """
        if len(redirect_uri) > app_settings.REDIRECT_URI_MAX_LENGTH:
            raise ValueError('redirect_uri length > %s' % app_settings.REDIRECT_URI_MAX_LENGTH)
        parsed_uri = urlparse.urlparse(redirect_uri)
        for valid_redirect_uri in self.redirect_uris.split():
            parsed_valid_uri = urlparse.urlparse(valid_redirect_uri)
            if parsed_uri.scheme != parsed_valid_uri.scheme:
                continue
            if parsed_valid_uri.netloc.startswith('*'):
                # globing on the left
                # NOTE(review): for a declared netloc like '*.example.com' the
                # stripped value keeps its leading dot ('.example.com'), so it
                # only matches '.example.com' or netlocs ending in
                # '..example.com', i.e. effectively no real host. The
                # '*example.com' form is the one matching example.com and its
                # subdomains.
                netloc = parsed_valid_uri.netloc.lstrip('*')
                if (parsed_uri.netloc != netloc
                        and not parsed_uri.netloc.endswith('.' + netloc)):
                    continue
            elif parsed_uri.netloc != parsed_valid_uri.netloc:
                continue
            if parsed_valid_uri.path.endswith('*'):
                path = parsed_valid_uri.path.rstrip('*').rstrip('/')
                if (parsed_uri.path.rstrip('/') != path
                        and not parsed_uri.path.startswith(path + '/')):
                    continue
            else:
                if parsed_uri.path.rstrip('/') != parsed_valid_uri.path.rstrip('/'):
                    continue
            if parsed_uri.query and (
                    parsed_valid_uri.query != parsed_uri.query and
                    parsed_valid_uri.query != '*'):
                # xxx parameter validation
                continue
            if parsed_uri.fragment and (
                    parsed_valid_uri.fragment != parsed_uri.fragment and
                    parsed_valid_uri.fragment != '*'):
                continue
            return
        raise ValueError('redirect_uri is not declared')

    def scope_set(self):
        """Return the client's ``scope`` field parsed by ``utils.scope_set``."""
        return utils.scope_set(self.scope)

    def get_sector_identifier(self):
        """Return the sector identifier used for pairwise identifier policies.

        - by-service or no authorization mode: ``utils.url_domain`` of
          ``sector_identifier_uri`` when set, otherwise the value shared by
          all redirect URIs (``None`` when there are no redirect URIs);
        - by-OU mode: the slug of the client's OU.

        Raises:
            ValueError: if the redirect URIs do not share the same domain.
            ValidationError: if OU mode is selected but the client has no OU.
            NotImplementedError: for an unknown ``authorization_mode``.
        """
        if self.authorization_mode in (self.AUTHORIZATION_MODE_BY_SERVICE, self.AUTHORIZATION_MODE_NONE):
            sector_identifier = None
            if self.sector_identifier_uri:
                sector_identifier = utils.url_domain(self.sector_identifier_uri)
            else:
                for redirect_uri in self.redirect_uris.split():
                    hostname = utils.url_domain(redirect_uri)
                    if sector_identifier is None:
                        sector_identifier = hostname
                    elif sector_identifier != hostname:
                        raise ValueError('all redirect_uri do not have the same hostname')
        elif self.authorization_mode == self.AUTHORIZATION_MODE_BY_OU:
            if not self.ou:
                raise ValidationError(
                    _('OU-based authorization requires that the client be '
                      'within an OU.'))
            sector_identifier = self.ou.slug
        else:
            raise NotImplementedError('unknown self.authorization_mode %s' % self.authorization_mode)
        return sector_identifier

    def __repr__(self):
        return ('<OIDCClient name:%r client_id:%r identifier_policy:%r>' %
                (self.name, self.client_id, self.get_identifier_policy_display()))
class OIDCAuthorization(models.Model):
    """Scopes a user granted to a client, with an expiration date.

    The client side is a generic foreign key (``client_ct`` + ``client_id``),
    so the target can be any model, not only ``OIDCClient``.
    """
    client_ct = models.ForeignKey(
        'contenttypes.ContentType',
        verbose_name=_('client ct'),
        on_delete=models.CASCADE)
    client_id = models.PositiveIntegerField(
        verbose_name=_('client id'))
    client = GenericForeignKey('client_ct', 'client_id')
    user = models.ForeignKey(
        to=settings.AUTH_USER_MODEL,
        verbose_name=_('user'),
        on_delete=models.CASCADE)
    scopes = models.TextField(
        blank=False,
        verbose_name=_('scopes'))

    # metadata
    created = models.DateTimeField(
        verbose_name=_('created'),
        auto_now_add=True)
    expired = models.DateTimeField(
        verbose_name=_('expire'))

    objects = managers.OIDCExpiredManager()

    def scope_set(self):
        """Return the granted ``scopes`` parsed by ``utils.scope_set``."""
        granted = self.scopes
        return utils.scope_set(granted)

    def __repr__(self):
        # Only resolve the related objects when their id is set; otherwise the
        # falsy id itself is shown.
        client_label = self.client_id and six.text_type(self.client)
        user_label = self.user_id and six.text_type(self.user)
        template = '<OIDCAuthorization client:%r user:%r scopes:%r>'
        return template % (client_label, user_label, self.scopes)
class OIDCCode(models.Model):
    """Authorization code issued to a client for a user, tied to a session."""
    uuid = models.CharField(
        max_length=128,
        verbose_name=_('uuid'),
        default=generate_uuid)
    client = models.ForeignKey(
        to=OIDCClient,
        verbose_name=_('client'),
        on_delete=models.CASCADE)
    user = models.ForeignKey(
        to=settings.AUTH_USER_MODEL,
        verbose_name=_('user'),
        on_delete=models.CASCADE)
    scopes = models.TextField(
        verbose_name=_('scopes'))
    state = models.TextField(
        null=True,
        verbose_name=_('state'))
    nonce = models.TextField(
        null=True,
        verbose_name=_('nonce'))
    redirect_uri = models.TextField(
        verbose_name=_('redirect URI'),
        validators=[URLValidator()])
    session_key = models.CharField(
        verbose_name=_('session key'),
        max_length=128)
    auth_time = models.DateTimeField(
        verbose_name=_('auth time'))

    # metadata
    created = models.DateTimeField(
        verbose_name=_('created'),
        auto_now_add=True)
    expired = models.DateTimeField(
        verbose_name=_('expire'))

    objects = managers.OIDCExpiredManager()

    @property
    def session(self):
        """Return the loaded session store for ``session_key``, or ``None``.

        The store of ``settings.SESSION_ENGINE`` is loaded and kept only if
        its key still equals ``session_key`` after loading. A found session is
        cached on the instance; a missing one is looked up again next time.
        """
        if not hasattr(self, '_session'):
            store_class = import_module(settings.SESSION_ENGINE).SessionStore
            candidate = store_class(session_key=self.session_key)
            candidate.load()
            if candidate._session_key == self.session_key:
                self._session = candidate
        return getattr(self, '_session', None)

    def scope_set(self):
        """Return the granted ``scopes`` parsed by ``utils.scope_set``."""
        granted = self.scopes
        return utils.scope_set(granted)

    def is_valid(self):
        """Return True if the code is not expired and its session still exists."""
        if self.expired < now():
            return False
        return self.session is not None

    def __repr__(self):
        client_label = self.client_id and six.text_type(self.client)
        user_label = self.user_id and six.text_type(self.user)
        template = '<OIDCCode uuid:%s client:%s user:%s expired:%s scopes:%s>'
        return template % (self.uuid, client_label, user_label, self.expired, self.scopes)
class OIDCAccessToken(models.Model):
    """Access token issued to a client for a user, optionally tied to a session."""
    uuid = models.CharField(
        max_length=128,
        verbose_name=_('uuid'),
        default=generate_uuid)
    client = models.ForeignKey(
        to=OIDCClient,
        verbose_name=_('client'),
        on_delete=models.CASCADE)
    user = models.ForeignKey(
        to=settings.AUTH_USER_MODEL,
        verbose_name=_('user'),
        on_delete=models.CASCADE)
    scopes = models.TextField(
        verbose_name=_('scopes'))
    session_key = models.CharField(
        verbose_name=_('session key'),
        max_length=128,
        blank=True)

    # metadata
    created = models.DateTimeField(
        verbose_name=_('created'),
        auto_now_add=True)
    expired = models.DateTimeField(
        verbose_name=_('expire'))

    objects = managers.OIDCExpiredManager()

    def scope_set(self):
        """Return the granted ``scopes`` parsed by ``utils.scope_set``."""
        return utils.scope_set(self.scopes)

    @property
    def session(self):
        """Return the loaded session store for ``session_key``, or ``None``.

        The store is loaded before its key is compared with ``session_key``.
        Django session backends clear the store's key when loading a session
        that no longer exists, so a terminated session yields ``None`` here,
        as in ``OIDCCode.session``. A found session is cached on the instance.
        """
        if not hasattr(self, '_session'):
            engine = import_module(settings.SESSION_ENGINE)
            session = engine.SessionStore(session_key=self.session_key)
            # Without load() the key comparison below only checks the key's
            # format, so a token would outlive the session it belongs to.
            session.load()
            if session._session_key == self.session_key:
                self._session = session
        return getattr(self, '_session', None)

    def is_valid(self):
        """Return True if the token is not expired and its session still exists."""
        return self.expired >= now() and self.session is not None

    def __repr__(self):
        return '<OIDCAccessToken uuid:%s client:%s user:%s expired:%s scopes:%s>' % (
            self.uuid,
            self.client_id and six.text_type(self.client),
            self.user_id and six.text_type(self.user),
            self.expired,
            self.scopes)
# Add a generic relation named ``oidc_authorizations`` to
# a2_rbac.OrganizationalUnit. It exposes the OIDCAuthorization rows whose
# generic client (client_ct / client_id) is the OU itself.
# NOTE(review): presumably used by OIDCClient.AUTHORIZATION_MODE_BY_OU, where
# consent is stored per OU rather than per client; confirm in the views.
GenericRelation('authentic2_idp_oidc.OIDCAuthorization',
                content_type_field='client_ct',
                object_id_field='client_id').contribute_to_class(
    OrganizationalUnit, 'oidc_authorizations')
class OIDCClaim(models.Model):
    """Maps a claim ``name`` to an attribute ``value`` for a given client.

    ``scopes`` is a comma-separated list of the scopes the claim is tied to.
    """
    client = models.ForeignKey(
        to=OIDCClient,
        verbose_name=_('client'),
        on_delete=models.CASCADE)
    name = models.CharField(
        max_length=128, blank=True,
        verbose_name=_('attribute name'))
    value = models.CharField(
        max_length=128, blank=True,
        verbose_name=_('attribute value'))
    scopes = models.CharField(
        max_length=128, blank=True,
        verbose_name=_('attribute scopes'))

    def __str__(self):
        return u'%s - %s - %s' % (self.name, self.value, self.scopes)

    def get_scopes(self):
        """Return the scope names listed in the comma-separated ``scopes``.

        Whitespace around each item is removed and empty items are dropped.
        For example, ``'openid, profile'`` gives ``['openid', 'profile']``
        (previously ``' profile'`` kept its space and never matched), and a
        blank field gives ``[]`` instead of ``['']``.
        """
        return [scope.strip() for scope in self.scopes.split(',') if scope.strip()]