ldap: subclass LDAPObject to provide uniform unicode support (fixes #30577)

This commit is contained in:
Benjamin Dauvergne 2019-02-12 23:49:32 +01:00
parent 93a0935e49
commit bc6892289e
2 changed files with 222 additions and 69 deletions

View File

@ -4,10 +4,13 @@ try:
import ldap.sasl
from ldap.filter import filter_format
from ldap.dn import escape_dn_chars
from ldap.ldapobject import ReconnectLDAPObject
from ldap.ldapobject import ReconnectLDAPObject as NativeLDAPObject
from ldap.controls import SimplePagedResultsControl
PYTHON_LDAP3 = map(int, ldap.__version__.split('.')) >= [3]
LDAPObject = NativeLDAPObject
except ImportError:
ldap = None
PYTHON_LDAP3 = None
import logging
import random
import base64
@ -57,14 +60,153 @@ for bundle_path in CA_BUNDLE_PATHS:
DEFAULT_CA_BUNDLE = bundle_path
break
if PYTHON_LDAP3 is True:
class LDAPObject(NativeLDAPObject):
def __init__(self, uri, trace_level=0, trace_file=None,
trace_stack_limit=5, bytes_mode=None,
bytes_strictness=None, retry_max=1, retry_delay=60.0):
NativeLDAPObject.__init__(self, uri=uri, trace_level=trace_level,
trace_file=trace_file,
trace_stack_limit=trace_stack_limit,
bytes_mode=False, bytes_strictness=bytes_strictness,
retry_max=retry_max,
retry_delay=retry_delay)
def map_bytes(d):
if isinstance(d, six.string_types):
return force_bytes(d)
@to_list
def _convert_results_to_unicode(self, result_list):
for dn, attrs in result_list:
if dn is not None:
# tuple is a real entry with a DN not a search reference
attrs = {attribute: map(force_text, attrs[attribute]) for attribute in attrs}
yield dn, attrs
def search_s(self, base, scope, filterstr='(objectclass=*)', attrlist=None, attrsonly=0):
return self._convert_results_to_unicode(
NativeLDAPObject.search_s(self, base, scope,
filterstr=filterstr,
attrlist=attrlist,
attrsonly=attrsonly))
def modify_s(self, dn, modlist):
new_modlist = []
for mod_op, mod_typ, mod_vals in modlist:
def convert(v):
if hasattr(v, 'isnumeric'):
# unicode case
v = v.encode('utf-8')
return v
if mod_vals is None:
pass
elif isinstance(mod_vals, list):
mod_vals = map(convert, mod_vals)
else:
mod_vals = convert(mod_vals)
new_modlist.append((mod_op, mod_typ, mod_vals))
return NativeLDAPObject.modify_s(self, dn, new_modlist)
def result3(self, msgid=ldap.RES_ANY, all=1, timeout=None, resp_ctrl_classes=None):
result_type, data, msgid, serverctrls = NativeLDAPObject.result3(self,
msgid=msgid,
all=all,
timeout=timeout,
resp_ctrl_classes=resp_ctrl_classes)
if data:
data = self._convert_results_to_unicode(data)
return result_type, data, msgid, serverctrls
elif PYTHON_LDAP3 is False:
class LDAPObject(NativeLDAPObject):
def simple_bind_s(self, who='', cred='', serverctrls=None, clientctrls=None):
who = force_bytes(who)
cred = force_bytes(cred)
return NativeLDAPObject.simple_bind_s(self, who=who, cred=cred,
serverctrls=serverctrls,
clientctrls=clientctrls)
def passwd_s(self, dn, oldpw, newpw, serverctrls=None, clientctrls=None):
dn = force_bytes(dn)
oldpw = force_bytes(oldpw)
newpw = force_bytes(newpw)
return NativeLDAPObject.passwd_s(self, dn, oldpw, newpw,
serverctrls=serverctrls,
clientctrls=clientctrls)
@to_list
def _convert_results_to_unicode(self, result_list):
for dn, attrs in result_list:
if dn is not None:
# tuple is a real entry with a DN not a search reference
attrs = {attribute: map(force_text, attrs[attribute]) for attribute in attrs}
yield force_text(dn), attrs
def search_s(self, base, scope, filterstr='(objectclass=*)', attrlist=None, attrsonly=0):
base = force_bytes(base)
filterstr = force_bytes(filterstr)
if attrlist:
attrlist = map(force_bytes, attrlist)
return self._convert_results_to_unicode(
NativeLDAPObject.search_s(self, base, scope,
filterstr=filterstr,
attrlist=attrlist,
attrsonly=attrsonly))
def search_ext(self, base, scope, filterstr='(objectclass=*)',
attrlist=None, attrsonly=0, serverctrls=None,
clientctrls=None, timeout=-1, sizelimit=0):
base = force_bytes(base)
filterstr = force_bytes(filterstr)
if attrlist:
attrlist = map(force_bytes, attrlist)
return NativeLDAPObject.search_ext(self, base, scope,
filterstr=filterstr,
attrlist=attrlist,
attrsonly=attrsonly,
serverctrls=serverctrls,
clientctrls=clientctrls,
timeout=timeout,
sizelimit=sizelimit)
def modify_s(self, dn, modlist):
dn = force_bytes(dn)
new_modlist = []
for mod_op, mod_typ, mod_vals in modlist:
mod_typ = force_bytes(mod_typ)
def convert(v):
if hasattr(v, 'isnumeric'):
# unicode case
v = force_bytes(v)
return v
if mod_vals is None:
pass
elif isinstance(mod_vals, list):
mod_vals = map(convert, mod_vals)
else:
mod_vals = convert(mod_vals)
new_modlist.append((mod_op, mod_typ, mod_vals))
return NativeLDAPObject.modify_s(self, dn, new_modlist)
def result3(self, msgid=ldap.RES_ANY, all=1, timeout=None, resp_ctrl_classes=None):
result_type, data, msgid, serverctrls = NativeLDAPObject.result3(self,
msgid=msgid,
all=all,
timeout=timeout,
resp_ctrl_classes=resp_ctrl_classes)
if data:
data = self._convert_results_to_unicode(data)
return result_type, data, msgid, serverctrls
def map_text(d):
if d is None:
return d
elif isinstance(d, six.string_types):
return force_text(d)
elif isinstance(d, (list, tuple)):
return d.__class__(map_bytes(x) for x in d)
return d.__class__(map_text(x) for x in d)
elif isinstance(d, dict):
return {map_bytes(k): map_bytes(v) for k, v in d.iteritems()}
return {map_text(k): map_text(v) for k, v in d.iteritems()}
raise NotImplementedError
class LDAPUser(get_user_model()):
@ -168,7 +310,7 @@ class LDAPUser(get_user_model()):
def check_password(self, raw_password):
connection = self.ldap_backend.get_connection(self.block)
try:
connection.simple_bind_s(self.dn, force_bytes(raw_password))
connection.simple_bind_s(self.dn, raw_password)
except ldap.INVALID_CREDENTIALS:
return False
except ldap.LDAPError, e:
@ -371,35 +513,32 @@ class LDAPBackend(object):
return user
def authenticate_block(self, block, username, password):
utf8_username = force_bytes(username)
utf8_password = force_bytes(password)
for conn in self.get_connections(block):
authz_ids = []
user_basedn = block.get('user_basedn') or block['basedn']
user_basedn = force_text(block.get('user_basedn') or block['basedn'])
try:
if block['user_dn_template']:
template = force_bytes(block['user_dn_template'])
escaped_username = escape_dn_chars(utf8_username)
template = force_text(block['user_dn_template'])
escaped_username = escape_dn_chars(username)
authz_ids.append(template.format(username=escaped_username))
else:
try:
if block.get('bind_with_username'):
authz_ids.append(utf8_username)
authz_ids.append(username)
elif block['user_filter']:
# allow multiple occurences of the username in the filter
user_filter = block['user_filter']
user_filter = force_text(block['user_filter'])
n = len(user_filter.split('%s')) - 1
try:
query = filter_format(user_filter, (utf8_username,) * n)
query = filter_format(user_filter, (username,) * n)
except TypeError, e:
log.error('user_filter syntax error %r: %s', block['user_filter'],
e)
return
log.debug('looking up dn for username %r using query %r', username,
query)
results = conn.search_s(user_basedn, ldap.SCOPE_SUBTREE, query)
results = conn.search_s(user_basedn, ldap.SCOPE_SUBTREE, query, [])
# remove search references
results = [result for result in results if result[0] is not None]
log.debug('found dns %r', results)
@ -430,7 +569,7 @@ class LDAPBackend(object):
if failed:
continue
try:
conn.simple_bind_s(authz_id, utf8_password)
conn.simple_bind_s(authz_id, password)
user_login_success(authz_id)
if not block['connect_with_user_credentials']:
try:
@ -609,7 +748,7 @@ class LDAPBackend(object):
if member_of_attribute:
group_dns.update(attributes.get(member_of_attribute, []))
if group_filter:
group_filter = force_bytes(group_filter)
group_filter = force_text(group_filter)
params = attributes.copy()
params['user_dn'] = dn
query = FilterFormatter().format(group_filter, **params)
@ -757,47 +896,47 @@ class LDAPBackend(object):
@classmethod
def get_ldap_attributes_names(cls, block):
attributes = set()
attributes.update(map(str, block['attributes']))
attributes.update(map_text(block['attributes']))
for field in ('email_field', 'fname_field', 'lname_field',
'member_of_attribute'):
'member_of_attribute'):
if block[field]:
attributes.add(block[field])
for external_id_tuple in block['external_id_tuples']:
for external_id_tuple in map_text(block['external_id_tuples']):
attributes.update(cls.attribute_name_from_external_id_tuple(
external_id_tuple))
for from_at, to_at in block['attribute_mappings']:
for from_at, to_at in map_text(block['attribute_mappings']):
attributes.add(to_at)
for mapping in block['user_attributes']:
from_ldap = mapping.get('from_ldap')
if from_ldap:
attributes.add(from_ldap)
return list(set(map(str.lower, map(str, attributes))))
return list(set(attribute.lower() for attribute in attributes))
@classmethod
def get_ldap_attributes(cls, block, conn, dn):
'''Retrieve some attributes from LDAP, add mandatory values then apply
defined mappings between atrribute names'''
attributes = cls.get_ldap_attributes_names(block)
attribute_mappings = block['attribute_mappings']
mandatory_attributes_values = block['mandatory_attributes_values']
attribute_mappings = map_text(block['attribute_mappings'])
mandatory_attributes_values = map_text(block['mandatory_attributes_values'])
try:
results = conn.search_s(dn, ldap.SCOPE_BASE, '(objectclass=*)', attributes)
results = conn.search_s(dn, ldap.SCOPE_BASE, u'(objectclass=*)', attributes)
except ldap.LDAPError:
log.exception('unable to retrieve attributes of dn %r', dn)
return None
attribute_map = cls.normalize_ldap_results(results[0][1])
# add mandatory attributes
for key, mandatory_values in mandatory_attributes_values.iteritems():
key = force_bytes(key)
key = force_text(key)
old = attribute_map.setdefault(key, [])
new = set(old) | set(mandatory_values)
attribute_map[key] = list(new)
# apply mappings
for from_attribute, to_attribute in attribute_mappings:
from_attribute = force_bytes(from_attribute)
from_attribute = force_text(from_attribute)
if from_attribute not in attribute_map:
continue
to_attribute = force_bytes(to_attribute)
to_attribute = force_text(to_attribute)
old = attribute_map.setdefault(to_attribute, [])
new = set(old) | set(attribute_map[from_attribute])
attribute_map[to_attribute] = list(new)
@ -822,9 +961,9 @@ class LDAPBackend(object):
if quote:
decoded.append((attribute, urllib.unquote(value)))
else:
decoded.append((attribute, force_bytes(value)))
filters = [filter_format('(%s=%s)', (a, b)) for a, b in decoded]
return '(&{0})'.format(''.join(filters))
decoded.append((attribute, force_text(value)))
filters = [filter_format(u'(%s=%s)', (a, b)) for a, b in decoded]
return u'(&{0})'.format(''.join(filters))
def build_external_id(self, external_id_tuple, attributes):
'''Build the exernal id for the user, use attribute that eventually
@ -839,8 +978,7 @@ class LDAPBackend(object):
v = attributes[attribute]
if isinstance(v, list):
v = v[0]
if isinstance(v, unicode):
v = force_bytes(v)
v = force_text(v)
if quote:
v = urllib.quote(v)
l.append(v)
@ -856,14 +994,14 @@ class LDAPBackend(object):
def lookup_by_external_id(self, block, attributes):
User = get_user_model()
for eid_tuple in block['external_id_tuples']:
for eid_tuple in map_text(block['external_id_tuples']):
external_id = self.build_external_id(eid_tuple, attributes)
if not external_id:
continue
log.debug('lookup using external_id %r: %r', eid_tuple, external_id)
users = LDAPUser.objects.prefetch_related('groups').filter(
userexternalid__external_id__iexact=external_id,
userexternalid__source=block['realm']).order_by('-last_login')
userexternalid__source=force_text(block['realm'])).order_by('-last_login')
# ordering of NULLs cannot be done through the ORM
users = sorted(users, reverse=True, key=lambda u: (u.last_login is not None, u.last_login))
if users:
@ -902,15 +1040,15 @@ class LDAPBackend(object):
user.save()
user._changed = False
external_id = self.build_external_id(
block['external_id_tuples'][0],
map_text(block['external_id_tuples'][0]),
attributes)
if external_id:
new, created = UserExternalId.objects.get_or_create(
user=user, external_id=external_id, source=block['realm'])
user=user, external_id=external_id, source=force_text(block['realm']))
if block['clean_external_id_on_update']:
UserExternalId.objects \
.exclude(id=new.id) \
.filter(user=user, source=block['realm']) \
.filter(user=user, source=force_text(block['realm'])) \
.delete()
def _return_user(self, dn, password, conn, block, attributes=None):
@ -954,7 +1092,7 @@ class LDAPBackend(object):
names = set()
for block in cls.get_config():
names.update(cls.get_ldap_attributes_names(block))
names.update(block['mandatory_attributes_values'].keys())
names.update(map_text(block['mandatory_attributes_values']).keys())
return [(a, '%s (LDAP)' % a) for a in sorted(names)]
@classmethod
@ -978,10 +1116,10 @@ class LDAPBackend(object):
for block in cls.get_config():
conn = cls.get_connection(block)
if conn is None:
logger.warning(u'unable to synchronize with LDAP servers %r', block['url'])
logger.warning(u'unable to synchronize with LDAP servers %s', force_text(block['url']))
continue
user_basedn = block.get('user_basedn') or block['basedn']
user_filter = block['sync_ldap_users_filter'] or block['user_filter']
user_basedn = force_text(block.get('user_basedn') or block['basedn'])
user_filter = force_text(block['sync_ldap_users_filter'] or block['user_filter'])
user_filter = user_filter.replace('%s', '*')
attrs = cls.get_ldap_attributes_names(block)
users = cls.paged_search(conn, user_basedn, ldap.SCOPE_SUBTREE, user_filter,
@ -1008,7 +1146,7 @@ class LDAPBackend(object):
else:
modlist = []
if block['active_directory']:
key = 'unicodePwd'
key = u'unicodePwd'
value = cls.ad_encoding(new_password)
if old_password:
modlist = [
@ -1018,9 +1156,8 @@ class LDAPBackend(object):
else:
modlist = [(ldap.MOD_REPLACE, key, [value])]
else:
key = 'userPassword'
value = force_bytes(new_password)
modlist = [(ldap.MOD_REPLACE, key, [value])]
key = u'userPassword'
modlist = [(ldap.MOD_REPLACE, key, [new_password])]
conn.modify_s(dn, modlist)
log.debug('modified password for dn %r', dn)
@ -1040,10 +1177,10 @@ class LDAPBackend(object):
'''Try each replicas, and yield successfull connections'''
if not block['url']:
raise ImproperlyConfigured("block['url'] must contain at least one url")
for url in block['url']:
for url in map_text(block['url']):
for key, value in block['global_ldap_options'].iteritems():
ldap.set_option(key, value)
conn = ReconnectLDAPObject(url)
conn = LDAPObject(url)
if block['timeout'] > 0:
conn.set_option(ldap.OPT_NETWORK_TIMEOUT, block['timeout'])
conn.set_option(ldap.OPT_X_TLS_REQUIRE_CERT,
@ -1097,26 +1234,27 @@ class LDAPBackend(object):
'''Bind to the LDAP server'''
try:
if credentials:
who = credentials[0]
conn.bind_s(*credentials)
who, password = credentials[0], credentials[1]
password = force_text(password)
conn.simple_bind_s(who, password)
elif block['bindsasl']:
sasl_mech, who, sasl_params = block['bindsasl']
sasl_mech, who, sasl_params = map_text(block['bindsasl'])
handler_class = getattr(ldap.sasl, sasl_mech)
auth = handler_class(*sasl_params)
conn.sasl_interactive_bind_s(who, auth)
elif block['binddn'] and block['bindpw']:
who = block['binddn']
conn.bind_s(block['binddn'], block['bindpw'])
who = force_text(block['binddn'])
conn.simple_bind_s(who, force_text(block['bindpw']))
else:
who = 'anonymous'
conn.simple_bind_s()
return True, None
except ldap.INVALID_CREDENTIALS:
return False, 'invalid credentials'
return False, u'invalid credentials'
except ldap.INVALID_DN_SYNTAX:
return False, 'invalid dn syntax %r' % who
return False, u'invalid dn syntax %s' % who
except (ldap.TIMEOUT, ldap.CONNECT_ERROR, ldap.SERVER_DOWN):
return False, 'ldap is down'
return False, u'ldap is down'
@classmethod
def get_connection(cls, block, credentials=()):
@ -1155,7 +1293,7 @@ class LDAPBackend(object):
raise ImproperlyConfigured(
'LDAP_AUTH_SETTINGS: attribute %r must be a string' % d)
try:
block[d] = force_bytes(block[d])
block[d] = force_text(block[d])
except UnicodeEncodeError:
raise ImproperlyConfigured(
'LDAP_AUTH_SETTINGS: attribute %r must be a string' % d)
@ -1174,32 +1312,33 @@ class LDAPBackend(object):
'LDAP_AUTH_SETTINGS: attribute %r is required but is empty')
# force_bytes all strings in iterable or dict
if isinstance(block[d], (list, tuple, dict)):
block[d] = map_bytes(block[d])
block[d] = map_text(block[d])
# lowercase LDAP attribute names
block['external_id_tuples'] = map(
lambda t: map(str.lower, map(str, t)), block['external_id_tuples'])
block['attribute_mappings'] = map(
lambda t: map(str.lower, map(str, t)), block['attribute_mappings'])
block['external_id_tuples'] = map_text([[t.lower() for t in id_tuple]
for id_tuple in block['external_id_tuples']])
block['attribute_mappings'] = map_text([[t.lower() for t in at_mapping]
for at_mapping in block['attribute_mappings']])
assert block['external_id_tuples'] is not None
for key in cls._TO_LOWERCASE:
# we handle strings, list of strings and list of list or tuple whose first element is a
# string
if isinstance(block[key], six.string_types):
block[key] = force_bytes(block[key]).lower()
block[key] = force_text(block[key]).lower()
elif isinstance(block[key], (list, tuple)):
new_seq = []
for elt in block[key]:
if isinstance(elt, six.string_types):
elt = force_bytes(elt).lower()
elt = force_text(elt).lower()
elif isinstance(elt, (list, tuple)):
elt = list(elt)
elt[0] = force_bytes(elt[0]).lower()
elt[0] = force_text(elt[0]).lower()
elt = tuple(elt)
new_seq.append(elt)
block[key] = tuple(new_seq)
elif isinstance(block[key], dict):
newdict = {}
for subkey in block[key]:
newdict[force_bytes(subkey).lower()] = block[key][subkey]
newdict[force_text(subkey).lower()] = block[key][subkey]
block[key] = newdict
else:
raise NotImplementedError(
@ -1223,7 +1362,7 @@ class LDAPBackendPasswordLost(LDAPBackend):
for block in config:
if user_external_id.source != force_text(block['realm']):
continue
for external_id_tuple in block['external_id_tuples']:
for external_id_tuple in map_text(block['external_id_tuples']):
conn = self.ldap_backend.get_connection(block)
try:
if external_id_tuple == ('dn:noquote',):

View File

@ -5,6 +5,9 @@ import string
import ldap.dn
import ldap.filter
from django.utils.encoding import force_text
class DnFormatter(string.Formatter):
def get_value(self, key, args, kwargs):
value = super(DnFormatter, self).get_value(key, args, kwargs)
@ -20,6 +23,12 @@ class DnFormatter(string.Formatter):
value = super(DnFormatter, self).format_field(value, format_spec)
return ldap.dn.escape_dn_chars(value)
def convert_field(self, value, conversion):
if conversion == 's':
return force_text(value)
return super(DnFormatter, self).convert_field(value, conversion)
class FilterFormatter(string.Formatter):
def get_value(self, key, args, kwargs):
value = super(FilterFormatter, self).get_value(key, args, kwargs)
@ -34,3 +43,8 @@ class FilterFormatter(string.Formatter):
def format_field(self, value, format_spec):
value = super(FilterFormatter, self).format_field(value, format_spec)
return ldap.filter.escape_filter_chars(value)
def convert_field(self, value, conversion):
if conversion == 's':
return force_text(value)
return super(FilterFormatter, self).convert_field(value, conversion)