passerelle/passerelle/apps/ldap/models.py

399 lines
15 KiB
Python

# passerelle - uniform access to multiple data sources and services
# Copyright (C) 2022 Entr'ouvert
#
# This program is free software: you can redistribute it and/or modify it
# under the terms of the GNU Affero General Public License as published
# by the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
import base64
import contextlib
import hashlib

import ldap
import ldap.filter
from django.core.cache import cache
from django.core.exceptions import ValidationError
from django.db import models
from django.utils.html import format_html
from django.utils.translation import gettext_lazy as _
from OpenSSL import crypto

from passerelle.base.models import BaseResource
from passerelle.utils.api import endpoint
from passerelle.utils.jsonresponse import APIError
from passerelle.utils.models import LDAPURLField, resource_file_upload_to
from passerelle.utils.templates import render_to_string

from . import forms
# True when the installed python-ldap exposes OPT_X_TLS_REQUIRE_SAN. The flag
# selects the help text of the "TLS check hostname" field and guards the
# corresponding set_option() calls when a connection is opened.
LDAP_HAS_OPT_X_TLS_REQUIRE_SAN = hasattr(ldap, 'OPT_X_TLS_REQUIRE_SAN') # only in python-ldap >= 3.4.0
# Accepted values of the search endpoint's "search_op" parameter; each one
# selects how the "q" value is matched against the search attribute.
SEARCH_OP_SUBSTRING = 'substring'
SEARCH_OP_PREFIX = 'prefix'
SEARCH_OP_APPROX = 'approx'
SEARCH_OP_EXACT = 'exact'
# LDAP connector resource: the model fields below store the server URL, the
# optional bind credentials and the TLS settings that get_connection() applies
# when it opens a connection.
class Resource(BaseResource):
# URL of the LDAP server, handed as is to ldap.initialize().
ldap_url = LDAPURLField(verbose_name=_('Server URL'), max_length=512)
# Optional bind credentials; clean() requires both or neither to be set.
# NOTE(review): the password is kept in a plain CharField.
ldap_bind_dn = models.CharField(verbose_name=_('Bind DN'), max_length=256, null=True, blank=True)
ldap_bind_password = models.CharField(
verbose_name=_('Bind password'), max_length=128, null=True, blank=True
)
# Optional TLS client certificate and private key; clean() requires both or
# neither to be uploaded.
ldap_tls_cert = models.FileField(
verbose_name=_('TLS client certificate'),
upload_to=resource_file_upload_to,
null=True,
blank=True,
validators=[forms.validate_certificate],
)
ldap_tls_key = models.FileField(
verbose_name=_('TLS client key'),
upload_to=resource_file_upload_to,
null=True,
blank=True,
validators=[forms.validate_private_key],
)
# Optional certificate file passed to the connection as OPT_X_TLS_CACERTFILE.
ldap_tls_cacert = models.FileField(
verbose_name=_('TLS trusted certificate'),
upload_to=resource_file_upload_to,
null=True,
blank=True,
validators=[forms.validate_certificate],
)
# Hostname checking is only applied when python-ldap provides
# OPT_X_TLS_REQUIRE_SAN; otherwise the help text warns that the option is
# not supported.
ldap_tls_check_hostname = models.BooleanField(
verbose_name=_('TLS check hostname'),
default=True,
blank=True,
help_text=None
if LDAP_HAS_OPT_X_TLS_REQUIRE_SAN
else _('Warning: this option is actually not supported (python-ldap < 3.4)'),
)
# Selects OPT_X_TLS_DEMAND (checked) or OPT_X_TLS_NEVER (unchecked) for
# OPT_X_TLS_REQUIRE_CERT.
ldap_tls_check_cert = models.BooleanField(
verbose_name=_('TLS check certificate'),
default=True,
blank=True,
)
# Category label of this connector.
category = _('Misc')
class Meta:
verbose_name = _('LDAP')
def tls_cert(self, value):
    '''Render the TLS client certificate as a downloadable HTML link.

    ``value`` is the file value of the ``ldap_tls_cert`` field.

    Returns ``None`` when no file is set. Otherwise returns an ``<a>``
    element whose target is a base64 ``data:`` URI embedding the file content
    and whose label is the certificate subject, or the content size when the
    content cannot be loaded as a PEM certificate.
    '''
    if not value.name:
        return None
    with value as fd:
        content = fd.read()
    try:
        cert = crypto.load_certificate(crypto.FILETYPE_PEM, content)
        name = ','.join(
            '%s=%s' % (a.decode(), b.decode()) for a, b in cert.get_subject().get_components()
        )
    except Exception:
        # Best effort: this is display only, so any parsing or decoding
        # failure falls back to a neutral label instead of an error.
        name = '%s bytes' % len(content)
    # The element must be closed with </a> (it used to end with a second,
    # never-closed <a/>), and the generic binary MIME type is
    # application/octet-stream.
    return format_html(
        '<a href="data:application/octet-stream;base64,{}" target="_blank" download="tls.crt">{}</a>',
        base64.b64encode(content).decode(),
        name,
    )
def clean(self):
    '''Check that paired settings are provided together.

    Raises ``ValidationError`` when only one of bind DN / bind password is
    set, or when only one of client certificate / client key is uploaded.
    '''
    has_bind_dn = bool(self.ldap_bind_dn)
    has_bind_password = bool(self.ldap_bind_password)
    if has_bind_dn != has_bind_password:
        raise ValidationError('Bind DN and password must be set together.')

    has_client_cert = bool(self.ldap_tls_cert.name)
    has_client_key = bool(self.ldap_tls_key.name)
    if has_client_cert != has_client_key:
        raise ValidationError('Client certificate and key must be set together.')
def get_description_fields(self):
    '''Return the parent's description fields with a rendered client certificate.

    Every ``(field, value)`` pair from the parent implementation is kept as
    is, except the ``ldap_tls_cert`` one whose value is replaced by the
    result of ``tls_cert()``.
    '''
    described = []
    for field, value in super().get_description_fields():
        if field.name == 'ldap_tls_cert':
            value = self.tls_cert(value)
        described.append((field, value))
    return described
def check_status(self):
    '''Check the directory is reachable by connecting and calling ``whoami_s()``.

    Nothing is returned; any exception raised while connecting, binding or
    querying propagates to the caller.
    '''
    with self.get_connection() as connection:
        connection.whoami_s()
@contextlib.contextmanager
def get_connection(self):
    '''Context manager yielding a bound LDAP connection.

    The connection is configured from the resource settings: operation and
    network timeouts of 5, TLS hostname and certificate checking, optional
    client certificate/key and trusted certificate. It is then bound with the
    configured DN and password, or anonymously when no bind DN is set.

    The connection is always unbound on exit, including when the bind fails,
    when the ``with`` body raises, or when a generator using the connection
    is closed before being exhausted.
    '''
    conn = ldap.initialize(self.ldap_url)
    try:
        conn.set_option(ldap.OPT_TIMEOUT, 5)
        conn.set_option(ldap.OPT_NETWORK_TIMEOUT, 5)
        if LDAP_HAS_OPT_X_TLS_REQUIRE_SAN:
            if self.ldap_tls_check_hostname:
                conn.set_option(ldap.OPT_X_TLS_REQUIRE_SAN, ldap.OPT_X_TLS_DEMAND)
            else:
                conn.set_option(ldap.OPT_X_TLS_REQUIRE_SAN, ldap.OPT_X_TLS_NEVER)
        if self.ldap_tls_check_cert:
            conn.set_option(ldap.OPT_X_TLS_REQUIRE_CERT, ldap.OPT_X_TLS_DEMAND)
        else:
            conn.set_option(ldap.OPT_X_TLS_REQUIRE_CERT, ldap.OPT_X_TLS_NEVER)
        if self.ldap_tls_cert.name and self.ldap_tls_key.name:
            conn.set_option(ldap.OPT_X_TLS_CERTFILE, self.ldap_tls_cert.path)
            conn.set_option(ldap.OPT_X_TLS_KEYFILE, self.ldap_tls_key.path)
        if self.ldap_tls_cacert.name:
            conn.set_option(ldap.OPT_X_TLS_CACERTFILE, self.ldap_tls_cacert.path)
        # Kept as the last TLS option: presumably required so that the TLS
        # settings above are applied to a new context (see python-ldap docs).
        conn.set_option(ldap.OPT_X_TLS_NEWCTX, 0)
        if self.ldap_bind_dn:
            conn.simple_bind_s(self.ldap_bind_dn, self.ldap_bind_password or '')
        else:
            conn.simple_bind_s()
        yield conn
    finally:
        # Release the connection on every exit path, not only on success.
        conn.unbind()
def ldap_search(self, base_dn, scope, ldap_filter, ldap_attributes, sizelimit=-1, timeout=5):
    '''Run an LDAP search and yield ``(dn, attributes)`` pairs as they arrive.

    ``base_dn``, ``scope``, ``ldap_filter``, ``ldap_attributes``,
    ``sizelimit`` and ``timeout`` are forwarded to ``search_ext()``.

    ``attributes`` is a ``cidict`` mapping attribute names to decoded values:
    a single string when exactly one value could be decoded, a list of
    strings when several could. Values that cannot be decoded are dropped,
    and attributes left without any decoded value are omitted. Results
    without a DN are skipped.

    Iteration ends on the first empty result, or silently when the server
    reports that the size limit was exceeded.
    '''

    def decode_values(raw_values):
        # Keep only the values that decode to text with the default codec.
        texts = []
        for raw in raw_values:
            try:
                texts.append(raw.decode())
            except UnicodeDecodeError:
                continue
        return texts

    with self.get_connection() as connection:
        msgid = connection.search_ext(
            base_dn, scope, ldap_filter, ldap_attributes, timeout=timeout, sizelimit=sizelimit
        )
        while True:
            try:
                _result_type, batch = connection.result(msgid, all=0)
            except ldap.SIZELIMIT_EXCEEDED:
                return
            if not batch:
                return
            for dn, raw_attributes in batch:
                if not dn:
                    continue
                entry = cidict()
                for attr_name, raw_values in raw_attributes.items():
                    texts = decode_values(raw_values)
                    if not texts:
                        continue
                    # A lone value is unwrapped; several values stay a list.
                    entry[attr_name] = texts[0] if len(texts) == 1 else texts
                yield dn, entry
def search(
    self,
    ldap_base_dn,
    scope,
    ldap_filter,
    ldap_attributes,
    sizelimit,
    id_attribute,
    search_attribute,
    text_template,
):
    '''Search the directory and return the entries in data source format.

    Returns ``{'data': [...]}`` where each item has the keys ``id`` (value of
    ``id_attribute``), ``text`` (``text_template`` rendered with the entry
    attributes, or the value of ``search_attribute`` when no template is
    given), ``dn`` and ``attributes``. Entries without a value for
    ``id_attribute`` are skipped; items are sorted by ``(text, id)``.

    Successful results are stored in the Django cache under a key derived
    from the resource id and from all the search parameters, and returned
    from the cache on later identical calls.

    When the directory raises ``ldap.LDAPError``, no exception propagates: a
    dict with ``err`` set to 1, an error class and description, and a single
    disabled placeholder item is returned instead. It is not cached.
    '''
    ldap_attributes = tuple(sorted(ldap_attributes))
    cache_fingerprint = str(
        [
            ldap_base_dn,
            scope,
            ldap_filter,
            ldap_attributes,
            sizelimit,
            id_attribute,
            search_attribute,
            text_template,
        ]
    )
    # The built-in hash() of a str is salted per process, so it cannot serve
    # as a key in a cache shared by several processes; use a stable digest.
    fingerprint_digest = hashlib.sha256(cache_fingerprint.encode()).hexdigest()
    cache_key = f'ldap-{self.id}-{fingerprint_digest}'
    cache_value = cache.get(cache_key)
    # The fingerprint stored with the data guards against a value cached for
    # different parameters under the same key.
    if cache_value and cache_value[0] == cache_fingerprint:
        return {'data': cache_value[1]}
    try:
        entries = list(
            self.ldap_search(ldap_base_dn, scope, ldap_filter, ldap_attributes, sizelimit=sizelimit)
        )
    except ldap.LDAPError as e:
        # add a disabled entry to show something on search errors, with display_disabled_items on w.c.s.
        return {
            'err': 1,
            'data': [
                {
                    'id': '',
                    'text': _('Directory server is unavailable'),
                    'disabled': True,
                }
            ],
            'err_class': 'directory-server-unavailable',
            'err_desc': str(e),
        }
    data = []
    for dn, attributes in entries:
        entry_id = attributes.get(id_attribute)
        if not entry_id:
            continue
        if text_template:
            entry_text = render_to_string(text_template, attributes)
        else:
            entry_text = attributes.get(search_attribute)
        data.append(
            {
                'id': entry_id,
                'text': entry_text,
                'dn': dn,
                'attributes': attributes,
            }
        )
    data.sort(key=lambda x: (x['text'], x['id']))
    cache.set(cache_key, (cache_fingerprint, data))
    return {'data': data}
@endpoint(
    description=_('Search'),
    name='search',
    perm='can_access',
    parameters={
        'ldap_base_dn': {
            'description': _('Base DN for the LDAP search'),
            'example_value': 'dc=company,dc=com',
        },
        'search_attribute': {
            'description': _('Attribute to search for the substring search'),
            'example_value': 'cn',
        },
        'id_attribute': {
            'description': _('Attribute used as a unique identifier'),
            'example_value': 'uid',
        },
        'text_template': {
            'description': _(
                'Optional template string based on LDAP attributes '
                'to create a text value, if none given the search_attribute is used'
            ),
            'example_value': '{{ givenName }} {{ surname }}',
        },
        'ldap_attributes': {
            'description': _('Space separated list of LDAP attributes to retrieve'),
            'example_value': 'l sn givenName locality',
        },
        'id': {
            'description': _('Identifier for exact retrieval, using the id_attribute'),
            'example_value': 'johndoe',
        },
        'q': {
            'description': _('Substring to search in the search_attribute'),
            'example_value': 'John Doe',
        },
        'sizelimit': {
            'description': _('Maximum number of entries to retrieve, between 1 and 200, default is 30.')
        },
        'scope': {
            'description': _('Scope of the LDAP search, subtree or onelevel, default is subtree.'),
        },
        'filter': {
            'description': _('Extra LDAP filter.'),
            'example_value': 'objectClass=*',
        },
        'search_op': {
            'description': _(
                'Search operator, can be "substring" (the default value), "prefix", "approx" or "exact"'
            ),
            'example_value': SEARCH_OP_SUBSTRING,
        },
    },
)
def search_endpoint(
    self,
    request,
    ldap_base_dn,
    id_attribute,
    search_attribute=None,
    text_template=None,
    ldap_attributes=None,
    id=None,
    q=None,
    sizelimit=None,
    scope=None,
    filter=None,
    search_op=SEARCH_OP_SUBSTRING,
):
    '''Validate the query parameters, build the LDAP filter and run the search.

    At least one of ``q``, ``id`` or ``filter`` is required. ``id`` selects
    entries whose ``id_attribute`` equals it and takes precedence over ``q``,
    which is matched against ``search_attribute`` according to ``search_op``.
    An extra ``filter`` is wrapped in parentheses when needed and combined
    with the previous filter using a logical AND.

    ``sizelimit`` defaults to 30 when absent, empty or 0, and is clamped to
    the range 1..200. An unknown ``scope`` falls back to a subtree search.

    Returns the result of ``search()``. Raises ``APIError`` with HTTP status
    400 when a required parameter is missing, an attribute name contains non
    ASCII characters, ``sizelimit`` is not an integer or ``search_op`` is
    unknown.
    '''
    if not q and not id and not filter:
        raise APIError('filter or q or id are mandatory parameters', http_status=400)
    if q and not search_attribute:
        raise APIError('search_attribute is mandatory with q parameter', http_status=400)
    if not search_attribute and not text_template:
        raise APIError('search_attribute or text_template are mandatory parameters', http_status=400)

    if search_attribute:
        search_attribute = search_attribute.lower()
        if not search_attribute.isascii():
            raise APIError('search_attribute contains non ASCII characters', http_status=400)
    id_attribute = id_attribute.lower()
    if not id_attribute.isascii():
        raise APIError('id_attribute contains non ASCII characters', http_status=400)

    # Always retrieve the attributes needed to build the "id" and "text" values.
    ldap_attributes = set(ldap_attributes.split()) if ldap_attributes else set()
    ldap_attributes.add(id_attribute)
    if search_attribute:
        ldap_attributes.add(search_attribute)
    if not all(attribute.isascii() for attribute in ldap_attributes):
        raise APIError('ldap_attributes contains non ASCII characters', http_status=400)

    # An absent or empty sizelimit means "use the default". Anything else must
    # be an integer: a non numeric value used to reach min() as a string and
    # fail there with a TypeError.
    if sizelimit is None or sizelimit == '':
        sizelimit = 30
    else:
        try:
            sizelimit = int(sizelimit)
        except (ValueError, TypeError):
            raise APIError('sizelimit must be an integer', http_status=400) from None
    sizelimit = max(1, min(sizelimit or 30, 200))

    q_filter_templates = {
        SEARCH_OP_SUBSTRING: '(%s=*%s*)',
        SEARCH_OP_PREFIX: '(%s=%s*)',
        SEARCH_OP_APPROX: '(%s~=%s)',
        SEARCH_OP_EXACT: '(%s=%s)',
    }
    ldap_filter = None
    if id:
        ldap_filter = '(%s=%s)' % (id_attribute, ldap.filter.escape_filter_chars(id))
    elif q:
        q_filter_template = q_filter_templates.get(search_op)
        if q_filter_template is None:
            raise APIError('unknown search_op %r' % search_op, http_status=400)
        ldap_filter = q_filter_template % (search_attribute, ldap.filter.escape_filter_chars(q))
    if filter:
        if not filter.startswith('('):
            filter = '(%s)' % filter
        if ldap_filter:
            ldap_filter = '(&%s%s)' % (ldap_filter, filter)
        else:
            ldap_filter = filter

    scopes = {
        'subtree': ldap.SCOPE_SUBTREE,
        'onelevel': ldap.SCOPE_ONELEVEL,
    }
    scope = scopes.get(scope, ldap.SCOPE_SUBTREE)
    return self.search(
        ldap_base_dn=ldap_base_dn,
        scope=scope,
        ldap_filter=ldap_filter,
        ldap_attributes=ldap_attributes,
        sizelimit=sizelimit,
        id_attribute=id_attribute,
        search_attribute=search_attribute,
        text_template=text_template,
    )
# Use a case-insensitive dictionary to map attribute names to their values.
class cidict(dict):
    '''Case-insensitive dictionary.

    Keys must be strings; they are lower-cased before being stored or looked
    up, so ``d['givenName']`` and ``d['givenname']`` address the same item.
    Normalisation applies to every keyed entry point, including constructor
    arguments, ``update()``, ``setdefault()``, ``pop()`` and ``del``, so that
    no mixed-case key can be stored and become unreachable.
    '''

    def __init__(self, *args, **kwargs):
        super().__init__()
        # Route initial content through update() so its keys are normalised.
        self.update(*args, **kwargs)

    def __setitem__(self, key, value):
        super().__setitem__(key.lower(), value)

    def __getitem__(self, key):
        return super().__getitem__(key.lower())

    def __delitem__(self, key):
        super().__delitem__(key.lower())

    def __contains__(self, key):
        return super().__contains__(key.lower())

    def get(self, key, default=None):
        return super().get(key.lower(), default)

    def pop(self, key, *default):
        return super().pop(key.lower(), *default)

    def setdefault(self, key, default=None):
        return super().setdefault(key.lower(), default)

    def update(self, *args, **kwargs):
        # dict() accepts the same arguments as dict.update(); storing through
        # __setitem__ lower-cases every key.
        for key, value in dict(*args, **kwargs).items():
            self[key] = value