ci: add pre-commit hooks

Benjamin Dauvergne 2024-02-03 14:53:38 +01:00
parent 435f49178f
commit 92f622d9c4
16 changed files with 417 additions and 320 deletions

.pre-commit-config.yaml (new file)

@@ -0,0 +1,42 @@
# See https://pre-commit.com for more information
# See https://pre-commit.com/hooks.html for more hooks
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.4.0
hooks:
- id: double-quote-string-fixer
- repo: https://github.com/asottile/pyupgrade
rev: v3.3.1
hooks:
- id: pyupgrade
args: ['--keep-percent-format', '--py39-plus']
- repo: https://github.com/adamchainz/django-upgrade
rev: 1.13.0
hooks:
- id: django-upgrade
args: ['--target-version', '3.2']
- repo: https://github.com/PyCQA/isort
rev: 5.12.0
hooks:
- id: isort
args: ['--profile', 'black', '--line-length', '110']
- repo: https://github.com/rtts/djhtml
rev: '3.0.5'
hooks:
- id: djhtml
args: ['--tabwidth', '2']
- repo: https://git.entrouvert.org/pre-commit-debian.git
rev: v0.3
hooks:
- id: pre-commit-debian
- repo: https://github.com/astral-sh/ruff-pre-commit
# Ruff version.
rev: v0.1.15
hooks:
# Run the linter.
- id: ruff
args: ['--fix']
exclude: 'debian/.*'
- id: ruff-format
args: ['--config', '.ruff.toml']
exclude: 'debian/.*'
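(Usage note, not part of the commit: a configuration like this is normally activated per clone with the standard pre-commit commands, `pre-commit install` to register the git hook and `pre-commit run --all-files` for a one-off pass over the whole tree.)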

debian/control

@@ -2,11 +2,17 @@ Source: python-ldaptools
Section: python
Priority: optional
Maintainer: Benjamin Dauvergne <bdauvergne@entrouvert.com>
Build-Depends: debhelper-compat (= 12),
 dh-python,
 python3-all,
 python3-setuptools,
Standards-Version: 3.9.6
Homepage: http://dev.entrouvert.org/projects/ldaptools/

Package: python3-ldaptools
Architecture: all
Depends: python3-ldap,
 python3-six,
 ${misc:Depends},
 ${python3:Depends},
Description: helper library for python-ldap and openldap

setup.py

@@ -1,15 +1,15 @@
#! /usr/bin/env python
import os
import subprocess
from setuptools import find_packages, setup
from setuptools.command.sdist import sdist
class eo_sdist(sdist):
    def run(self):
        print('creating VERSION file')
        if os.path.exists('VERSION'):
            os.remove('VERSION')
        version = get_version()
@@ -17,22 +17,23 @@ class eo_sdist(sdist):
        version_file.write(version)
        version_file.close()
        sdist.run(self)
        print('removing VERSION file')
        if os.path.exists('VERSION'):
            os.remove('VERSION')
def get_version():
    """Use the VERSION, if absent generates a version with git describe, if not
    tag exists, take 0.0- and add the length of the commit log.
    """
    if os.path.exists('VERSION'):
        with open('VERSION') as v:
            return v.read()
    if os.path.exists('.git'):
        p = subprocess.Popen(
            ['git', 'describe', '--dirty=.dirty', '--match=v*'],
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
        )
        result = p.communicate()[0]
        if p.returncode == 0:
@@ -44,33 +45,33 @@ def get_version():
                version = result
            return version
        else:
            return '0.0.post%s' % len(subprocess.check_output(['git', 'rev-list', 'HEAD']).splitlines())
    return '0.0'
setup(
    name='ldaptools',
    version=get_version(),
    license='AGPLv3+',
    description='ldaptools',
    long_description=open('README.rst').read(),
    url='http://dev.entrouvert.org/projects/ldaptools/',
    author="Entr'ouvert",
    author_email='authentic@listes.entrouvert.com',
    maintainer='Benjamin Dauvergne',
    maintainer_email='bdauvergne@entrouvert.com',
    packages=find_packages('src'),
    package_dir={'': 'src'},
    include_package_data=True,
    install_requires=['python-ldap', 'six'],
    entry_points={
        'console_scripts': ['ldapsync=ldaptools.ldapsync.cmd:main'],
    },
    zip_safe=False,
    classifiers=[
        'License :: OSI Approved :: MIT License',
        'Topic :: System :: Systems Administration :: Authentication/Directory',
        'Programming Language :: Python',
    ],
    cmdclass={'sdist': eo_sdist},
)

@@ -1,12 +1,10 @@
class CommandError(Exception):
    def __init__(self, *args, **kwargs):
        self.err_code = kwargs.pop('err_code', None)
        super().__init__(*args, **kwargs)
class ConfigError(CommandError):
    def __init__(self, *args, **kwargs):
        kwargs['err_code'] = 1
        super().__init__(*args, **kwargs)

@@ -3,7 +3,7 @@ import ldap
from ldaptools.utils import idict, istr
class LDAPSource:
    entries = None
    conn = None
    base_dn = None
@@ -18,9 +18,9 @@ class LDAPSource(object):
        self.base_dn = base_dn or self.base_dn
    def search(self):
        for dn, entry in self.conn.paged_search_ext_s(
            self.base_dn, ldap.SCOPE_SUBTREE, filterstr=self.filterstr, attrlist=self.attributes
        ):
            if not dn:
                continue
            entry = idict(entry)

@@ -1,11 +1,9 @@
import argparse
import sys
import ldap.sasl
from ldaptools import ldap_source, ldif_utils, paged
from ldaptools.synchronize import Synchronize
@@ -25,6 +23,7 @@ def or_type(f1, f2):
                return f2(value)
            except argparse.ArgumentTypeError as e2:
                raise argparse.ArgumentTypeError('%s and %s' % (e1.args[0], e2.args[0]))
    return f
@@ -36,57 +35,51 @@ def object_class_pivot(value):
def main(args=None):
    parser = argparse.ArgumentParser(
        description="""\
Synchronize an LDIF file or a source LDAP directory to another directory
Base DN of the source is remapped to another DN in the target directory"""
    )
    parser.add_argument(
        '--object-class-pivot',
        required=True,
        type=object_class_pivot,
        action='append',
        help='an objectClass and an attribute name which is the unique identifier ' 'for this class',
    )
    parser.add_argument(
        '--attributes-file',
        type=argparse.FileType('r'),
        help='a file containing the list of attributes to synchronize',
    )
    parser.add_argument('--attributes', help='a list of attribute names separated by spaces')
    parser.add_argument(
        '--source-uri',
        required=True,
        type=or_type(source_uri, argparse.FileType('r')),
        help='URL of an LDAP directory (ldapi://, ldap:// or ldaps://) or path of ' 'and LDIF file',
    )
    parser.add_argument(
        '--case-insensitive-attribute',
        action='append',
        help='indicate that the attribute must be compared case insensitively',
    )
    parser.add_argument('--source-base-dn', required=True, help='base DN of the source')
    parser.add_argument('--source-bind-dn', help='bind DN for a source LDAP directory')
    parser.add_argument('--source-bind-password', help='bind password for a source LDAP directory')
    parser.add_argument('--source-filter', help='filter to apply to a source LDAP directory')
    parser.add_argument('--source-objectclasses', help='keep only thoses object classes')
    parser.add_argument(
        '--target-uri', type=source_uri, required=True, help='URL of the target LDAP directory'
    )
    parser.add_argument('--target-base-dn', required=True, help='base DN of the target LDAP directory')
    parser.add_argument('--target-bind-dn', help='bind DN for a target LDAP directory')
    parser.add_argument('--target-bind-password', help='bind password for a target LDAP directory')
    parser.add_argument(
        '--fake', action='store_true', help='compute synchronization actions but do not apply'
    )
    parser.add_argument('--verbose', action='store_true', help='print all actions to stdout')
    options = parser.parse_args(args=args)
@@ -116,30 +109,35 @@ Base DN of the source is remapped to another DN in the target directory''')
        print(options.source_uri, end=' ')
    conn = paged.PagedLDAPObject(options.source_uri)
    if options.source_uri.startswith('ldapi://'):
        conn.sasl_interactive_bind_s('', ldap.sasl.external())
    elif options.source_bind_dn and options.source_bind_password:
        conn.simple_bind_s(options.source_bind_dn, options.source_bind_password)
    source = ldap_source.LDAPSource(
        conn, base_dn=options.source_base_dn, attributes=attributes, filterstr=options.source_filter
    )
    if options.verbose:
        print('to', options.target_uri, end=' ')
    target_conn = paged.PagedLDAPObject(options.target_uri)
    if options.target_uri.startswith('ldapi://'):
        target_conn.sasl_interactive_bind_s('', ldap.sasl.external())
    elif options.target_bind_dn and options.target_bind_dn:
        target_conn.simple_bind_s(options.target_bind_dn, options.target_bind_password)
    if options.source_objectclasses:
        source_objectclasses = options.source_objectclasses.split()
    else:
        source_objectclasses = [v[0] for v in options.object_class_pivot]
    synchronize = Synchronize(
        source,
        options.source_base_dn,
        target_conn,
        options.target_base_dn,
        pivot_attributes=options.object_class_pivot,
        objectclasses=source_objectclasses,
        attributes=attributes,
        case_insensitive_attribute=options.case_insensitive_attribute,
    )
    synchronize.build_actions()
    if options.verbose:

@@ -2,7 +2,7 @@ import ldap
import ldif
from ldap.dn import dn2str
from ldaptools.utils import bytes2str_entry, idict, str2bytes_entry, str2dn
class AddError(Exception):

@@ -1,16 +1,24 @@
import ldap
import six
from ldap.controls import SimplePagedResultsControl
from ldap.ldapobject import ReconnectLDAPObject
class PagedResultsSearchObject:
    page_size = 500
    def paged_search_ext_s(
        self,
        base,
        scope,
        filterstr='(objectClass=*)',
        attrlist=None,
        attrsonly=0,
        serverctrls=None,
        clientctrls=None,
        timeout=-1,
        sizelimit=0,
    ):
        """
        Behaves exactly like LDAPObject.search_ext_s() but internally uses the
        simple paged results control to retrieve search results in chunks.
@@ -29,23 +37,17 @@ class PagedResultsSearchObject:
            filterstr=filterstr,
            attrlist=attrlist,
            attrsonly=attrsonly,
            serverctrls=(serverctrls or []) + [req_ctrl],
            clientctrls=clientctrls,
            timeout=timeout,
            sizelimit=sizelimit,
        )
        while True:
            rtype, rdata, rmsgid, rctrls = self.result3(msgid)
            yield from rdata
            # Extract the simple paged results response control
            pctrls = [c for c in rctrls if c.controlType == SimplePagedResultsControl.controlType]
            if pctrls and pctrls[0].cookie:
                # Copy cookie from response control to request control
                req_ctrl.cookie = pctrls[0].cookie
@@ -55,10 +57,10 @@ class PagedResultsSearchObject:
                    filterstr=filterstr,
                    attrlist=attrlist,
                    attrsonly=attrsonly,
                    serverctrls=(serverctrls or []) + [req_ctrl],
                    clientctrls=clientctrls,
                    timeout=timeout,
                    sizelimit=sizelimit,
                )
                continue
            break  # no more pages available
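A minimal usage sketch for the paged search helper above (the server URL and bind credentials are hypothetical, not taken from this commit):

    import ldap

    from ldaptools.paged import PagedLDAPObject

    conn = PagedLDAPObject('ldap://localhost:389')  # hypothetical server
    conn.simple_bind_s('uid=admin,cn=config', 'admin')  # hypothetical credentials
    # paged_search_ext_s() is a generator; it follows the paging cookie transparently
    for dn, entry in conn.paged_search_ext_s('o=orga', ldap.SCOPE_SUBTREE, filterstr='(objectClass=*)'):
        print(dn, entry)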

@@ -1,19 +1,21 @@
import codecs
import os
import shutil
import subprocess
import tempfile
import time
import ldap
import ldap.modlist
import ldap.sasl
try:
    from StringIO import StringIO
except ImportError:
    from io import StringIO
import atexit
from urllib.parse import quote
from ldaptools.ldif_utils import ListLDIFParser
from ldaptools.paged import PagedLDAPObject
@@ -36,14 +38,15 @@ def has_slapd():
    return not (SLAPD_PATH is None or SLAPADD_PATH is None)
class Slapd:
    """Initiliaze an OpenLDAP server with just one database containing branch
    o=orga and loading the core schema. ACL are very permissive.
    """
    root_bind_dn = 'uid=admin,cn=config'
    root_bind_password = 'admin'
    config_ldif = """dn: cn=config
objectClass: olcGlobal
cn: config
olcToolThreads: 1
@@ -88,15 +91,14 @@ olcAccess: {{0}}to *
 by dn.exact=gidNumber={gid}+uidNumber={uid},cn=peercred,cn=external,cn=auth manage
 by * break
"""
    process = None
    schemas = ['core', 'cosine', 'inetorgperson', 'nis', 'eduorg-200210-openldap', 'eduperson', 'supann-2009']
    schemas_ldif = []
    for schema in schemas:
        with codecs.open(
            os.path.join(os.path.dirname(__file__), 'schemas', '%s.ldif' % schema), encoding='utf-8'
        ) as fd:
            schemas_ldif.append(fd.read())
    checkpoints = None
    data_dirs = None
@@ -130,23 +132,25 @@ olcAccess: {{0}}to *
            extra_config += 'olcTLSCertificateKeyFile: %s\n' % real_key
            extra_config += 'olcTLSCertificateFile: %s\n' % real_cert
            extra_config += 'olcSecurity: ssf=1\n'
        config_context.update(
            {
                'slapd_dir': self.slapd_dir,
                'gid': os.getgid(),
                'uid': os.getuid(),
                'extra_config': extra_config,
            }
        )
        self.slapadd(self.config_ldif, context=config_context)
        for schema_ldif in self.schemas_ldif:
            self.slapadd(schema_ldif)
        self.start()
        try:
            self.add_db('o=orga')
            ldif = """dn: o=orga
objectClass: organization
o: orga
"""
            self.add_ldif(ldif)
        except:
            self.stop()
@@ -155,7 +159,7 @@ o: orga
    def add_db(self, suffix):
        path = os.path.join(self.slapd_dir, suffix)
        os.mkdir(path)
        ldif = """dn: olcDatabase={{{index}}}mdb,cn=config
objectClass: olcDatabaseConfig
objectClass: olcMdbConfig
olcDatabase: mdb
@@ -165,7 +169,7 @@ olcReadOnly: FALSE
# Index
olcAccess: {{0}}to * by * manage
"""
        self.add_ldif(ldif, context={'index': self.db_index, 'suffix': suffix, 'path': path})
        self.db_index += 1
        self.data_dirs.append(path)
@@ -177,8 +181,11 @@ olcAccess: {{0}}to * by * manage
        ldif = ldif.format(**context)
        slapadd = subprocess.Popen(
            [SLAPADD_PATH, '-v', '-n%d' % db, '-F', self.config_dir],
            stdin=subprocess.PIPE,
            env=os.environ,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
        )
        stdout, stderr = slapadd.communicate(input=bytearray(ldif, 'utf-8'))
        assert slapadd.returncode == 0, 'slapadd failed: %s' % stderr
@@ -190,15 +197,11 @@ olcAccess: {{0}}to * by * manage
            self.close_fds()
        self._close_fds = close_fds
    def start(self):
        """Launch slapd"""
        assert not self.process
        cmd = [SLAPD_PATH, '-d768', '-F' + self.config_dir, '-h', self.ldap_url]  # put slapd in foreground
        out_file = open(os.path.join(self.slapd_dir, 'stdout'), 'w')
        dev_null = open(os.devnull)
        self.process = subprocess.Popen(cmd, stdin=dev_null, env=os.environ, stdout=out_file, stderr=out_file)
@@ -211,7 +214,7 @@ olcAccess: {{0}}to * by * manage
            try:
                conn = self.get_connection()
                conn.whoami_s()
            except ldap.SERVER_DOWN:
                if c > 100:
                    raise
                time.sleep(0.1)
@@ -219,7 +222,7 @@ olcAccess: {{0}}to * by * manage
            break
    def stop(self):
        """Send SIGTERM to slapd"""
        assert self.process
        process = self.process
@@ -237,18 +240,17 @@ olcAccess: {{0}}to * by * manage
        self.process = None
    def checkpoint(self):
        """Stop slapd and save current data state"""
        assert not self.process
        self.checkpoints.append(os.path.join(self.slapd_dir, 'checkpoint-%d' % len(self.checkpoints)))
        for data_dir in self.data_dirs:
            dirname = os.path.basename(data_dir)
            target = os.path.join(self.checkpoints[-1], dirname)
            shutil.copytree(data_dir, target)
    def restore(self):
        """Stop slapd and restore last data state"""
        assert not self.process
        assert self.checkpoints, 'no checkpoint exists'
        for data_dir in self.data_dirs:
@@ -263,7 +265,7 @@ olcAccess: {{0}}to * by * manage
        self.clean()
    def clean(self):
        """Remove directory"""
        self.close_fds()
        try:
            if self.process:
@@ -312,5 +314,5 @@ olcAccess: {{0}}to * by * manage
        assert self.ldap_url.startswith('ldapi://')
        conn = self.get_connection()
        conn.sasl_interactive_bind_s('', ldap.sasl.external())
        return conn

@@ -1,19 +1,17 @@
import functools
import logging
from itertools import groupby
import ldap
import ldap.dn
import ldap.modlist
from ldap.filter import filter_format
from .utils import batch_generator, bytes2str_entry, idict, istr, str2bytes_entry, str2dn, to_dict_of_set
@functools.total_ordering
class Action:
    dn = None
    new_dn = None
    entry = None
@@ -25,8 +23,12 @@ class Action(object):
        self.results = []
    def __eq__(self, other):
        return (
            other.__class__ is self.__class__
            and self.dn == other.dn
            and other.new_dn == other.new_dn
            and to_dict_of_set(self.entry) == to_dict_of_set(other.entry)
        )
    # - first rename, sorted by dn depth
    # - then update and creations, sorted by depth
@@ -96,8 +98,9 @@ class Delete(Action):
        self.msgids.append(conn.delete(self.dn))
class Synchronize:
    """Synchronize a source or records with an LDAP server"""
    BATCH_SIZE = 100
    # an iterable yield pair of (dn, attributes)
@@ -123,39 +126,52 @@ class Synchronize(object):
    actions = None
    case_insensitive_attribute = None
    def __init__(
        self,
        source,
        source_dn,
        target_conn,
        target_dn,
        attributes=None,
        all_filter=None,
        pivot_attributes=None,
        logger=None,
        case_insensitive_attribute=None,
        objectclasses=None,
    ):
        self.source = source
        self.source_dn = source_dn
        self.target_conn = target_conn
        self.target_dn = target_dn
        self.attributes = list({istr(attribute) for attribute in attributes or self.attributes})
        self.all_filter = all_filter or self.all_filter
        self.pivot_attributes = pivot_attributes or self.pivot_attributes
        self.logger = logger or logging.getLogger(__name__)
        self.case_insensitive_attribute = map(
            istr, case_insensitive_attribute or self.case_insensitive_attribute or []
        )
        self.objectclasses = [istr(v) for v in objectclasses or []]
        self.errors = []
    def massage_dn(self, old_dn):
        return old_dn[: -len(self.source_dn)] + self.target_dn
    def get_pivot_attribute(self, dn, entry):
        """Find a pivot attribute value for an LDAP entry"""
        for objc, attr in self.pivot_attributes:
            if istr(objc) in [
                istr(oc.decode('utf-8')) if isinstance(oc, bytes) else oc for oc in entry['objectclass']
            ]:
                try:
                    value = entry[attr]
                except KeyError:
                    raise Exception('entry %s missing pivot attribute %s: %s' % (dn, attr, entry))
                break
        else:
            raise Exception(
                'entry %s has unknown objectclasses %s'
                % (dn, [objclass for objclass in entry['objectclass']])
            )
        if len(value) != 1:
            raise Exception('entry %s pivot attribute %s must have only one value' % (dn, attr))
        value = value[0]
@@ -170,13 +186,16 @@ class Synchronize(object):
        return objc, attr, value
    def get_target_entries(self, filterstr=None, attributes=[]):
        """Return all target entries"""
        try:
            # Check base DN exist
            self.target_conn.search_s(self.target_dn, ldap.SCOPE_BASE)
            res = self.target_conn.paged_search_ext_s(
                self.target_dn,
                ldap.SCOPE_SUBTREE,
                filterstr=filterstr or self.all_filter,
                attrlist=attributes,
            )
            return ((dn, idict(bytes2str_entry(entry))) for dn, entry in res if dn)
        except ldap.NO_SUCH_OBJECT:
            return []
@@ -189,15 +208,13 @@ class Synchronize(object):
        # Ignore some objectclasses
        if self.objectclasses:
            for dn, entry in entries:
                entry['objectclass'] = [v for v in entry['objectclass'] if istr(v) in self.objectclasses]
        # Transform input entries into filters
        for dn, entry in entries:
            objectclass, attr, value = self.get_pivot_attribute(dn, entry)
            in_dns.append(((attr, value), (dn, entry)))
            filter_tpl = '(&(objectclass=%%s)(%s=%%s))' % attr
            out_filters.append(filter_format(filter_tpl, (objectclass, value)))
        out_filter = '(|%s)' % ''.join(out_filters)
        # Get existing output entries
        out_dns = {}
@@ -230,8 +247,9 @@ class Synchronize(object):
                    for attribute in self.attributes:
                        if attribute in to_dict_of_set(entry):
                            new_entry[attribute] = entry[attribute]
                        if attribute in to_dict_of_set(out_entry) and not to_dict_of_set(entry).get(
                            attribute
                        ):
                            new_entry[attribute] = []
                    self.update(target_dn, new_entry)
                else:
@@ -246,7 +264,7 @@ class Synchronize(object):
        entries.sort(key=lambda dn_entry: len(str2dn(dn_entry[0])))
        for dn, entry in entries:
            for key in entry.keys():
                if str(key.lower()) not in self.attributes:
                    del entry[key]
        # First create, rename and update
        for batch in batch_generator(entries, self.BATCH_SIZE):
@@ -276,11 +294,12 @@ class Synchronize(object):
            self.actions.append(Delete(dn=dn))
    def apply_actions(self):
        """Apply actions, wait for result of different kind of actions
        separately, since openldap seem to reorder some of them"""
        def action_key(action):
            return (action.__class__, str2dn(action.dn))
        for key, sequence in groupby(self.actions, action_key):
            for batch in batch_generator(sequence, self.BATCH_SIZE):
                for action in batch:

@@ -1,5 +1,4 @@
import ldap.dn
# Copied from http://code.activestate.com/recipes/194371-case-insensitive-strings/
@@ -12,7 +11,7 @@ class istr(str):
        self.__lowerCaseMe = strMe.lower()
    def __repr__(self):
        return 'iStr(%s)' % str.__repr__(self)
    def __eq__(self, other):
        if not hasattr(other, 'lower'):
@@ -36,9 +35,6 @@ class istr(str):
    def __ge__(self, other):
        return self.__lowerCaseMe >= other.lower()
    def __hash__(self):
        return hash(self.__lowerCaseMe)
@@ -57,7 +53,7 @@ class istr(str):
    def index(self, other, *args):
        return str.index(self.__lowerCaseMe, other.lower(), *args)
    def lower(self):  # Courtesy Duncan Booth
        return self.__lowerCaseMe
    def rfind(self, other, *args):
@@ -72,6 +68,7 @@ class istr(str):
class idict(dict):
    """A case insensitive dictionary that only permits strings as keys."""
    def __init__(self, indict={}):
        dict.__init__(self)
        self._keydict = {}  # not self.__keydict because I want it to be easily accessible by subclasses
@@ -83,7 +80,7 @@ class idict(dict):
    def findkey(self, item):
        """A caseless way of checking if a key exists or not.
        It returns None or the correct key."""
        if not isinstance(item, str):
            raise TypeError('Keywords for this object must be strings. You supplied %s' % type(item))
        key = item.lower()
        try:
@@ -97,7 +94,7 @@ class idict(dict):
        This is useful when initially setting up default keys - but later might want to preserve an alternative casing.
        (e.g. if later read from a config file - and you might want to write back out with the user's casing preserved).
        """
        key = self.findkey(item)  # does the key exist
        if key is None:
            raise KeyError(item)
        temp = self[key]
@@ -109,9 +106,9 @@ class idict(dict):
        """Returns a lowercase list of all member keywords."""
        return self._keydict.keys()
    def __setitem__(self, item, value):  # setting a keyword
        """To implement lowercase keys."""
        key = self.findkey(item)  # if the key already exists
        if key is not None:
            dict.__delitem__(self, key)
        self._keydict[item.lower()] = item
@@ -119,13 +116,13 @@ class idict(dict):
    def __getitem__(self, item):
        """To implement lowercase keys."""
        key = self.findkey(item)  # does the key exist
        if key is None:
            raise KeyError(item)
        return dict.__getitem__(self, key)
    def __delitem__(self, item):  # deleting a keyword
        key = self.findkey(item)  # does the key exist
        if key is None:
            raise KeyError(item)
        dict.__delitem__(self, key)
@@ -133,7 +130,7 @@ class idict(dict):
    def pop(self, item, default=None):
        """Correctly emulates the pop method."""
        key = self.findkey(item)  # does the key exist
        if key is None:
            if default is None:
                raise KeyError(item)
@@ -150,20 +147,20 @@ class idict(dict):
    def has_key(self, item):
        """A case insensitive test for keys."""
        if not isinstance(item, str):
            return False  # should never have a non-string key
        return item.lower() in self._keydict  # does the key exist
    def __contains__(self, item):
        """A case insensitive __contains__."""
        if not isinstance(item, str):
            return False  # should never have a non-string key
        return item.lower() in self._keydict  # does the key exist
    def setdefault(self, item, default=None):
        """A case insensitive setdefault.
        If no default is supplied it sets the item to None"""
        key = self.findkey(item)  # does the key exist
        if key is not None:
            return self[key]
        self.__setitem__(item, default)
@@ -172,7 +169,7 @@ class idict(dict):
    def get(self, item, default=None):
        """A case insensitive get."""
        key = self.findkey(item)  # does the key exist
        if key is not None:
            return self[key]
        return default
@@ -182,7 +179,7 @@ class idict(dict):
        If your dictionary has overlapping keys (e.g. 'FISH' and 'fish') then one will overwrite the other.
        The one that is kept is arbitrary."""
        for entry in indict:
            self[entry] = indict[entry]  # this uses the new __setitem__ method
    def copy(self):
        """Create a new caselessDict object that is a copy of this one."""
@@ -198,7 +195,7 @@ class idict(dict):
        dict.clear(self)
    def __repr__(self):
        """A caselessDict version of __repr__"""
        return 'caselessDict(' + dict.__repr__(self) + ')'
    def __eq__(self, other):
@@ -234,7 +231,7 @@ def batch_generator(gen, *batch_size):
def to_dict_of_set(d):
    r = idict({k: set(v) for k, v in d.items()})
    if 'objectclass' in r:
        r['objectclass'] = {istr(v) for v in r['objectclass']}
    return r
@@ -260,5 +257,5 @@ def str2bytes_entry(entry):
    bytes_entry = {}
    for key, values in entry.items():
        bytes_entry[key] = [v.encode('utf-8') if isinstance(v, str) else v for v in values]
    return bytes_entry
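A small illustration of the case-insensitive helpers above (example values are made up, assuming idict() is initialised from the passed mapping):

    from ldaptools.utils import idict, istr

    d = idict({'UID': ['admin']})
    assert d['uid'] == ['admin']   # keys are matched case-insensitively
    assert istr('CN') == 'cn'      # istr compares equal regardless of case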

@@ -1,8 +1,7 @@
import os
import tempfile
import pytest
from ldaptools.slapd import Slapd
@@ -52,7 +51,7 @@ def any_slapd(request, slapd_tcp1, slapd_ssl, slapd_tls):
@pytest.fixture
def ldif():
    return """dn: dc=orga2
o: orga
dc: orga2
objectClass: organization
@@ -66,7 +65,7 @@ sn: John
givenName: Doe
mail: john.doe@entrouvert.com
"""
@pytest.fixture
@@ -88,8 +87,10 @@ def ldif_path(request, ldif):
    with open(path, 'w') as f:
        f.write(ldif)
        f.flush()

    def finalize():
        os.unlink(path)

    request.addfinalizer(finalize)
    return path
@@ -101,7 +102,9 @@ def attributes_path(request, attributes):
        for attribute in attributes:
            print(' %s ' % attribute, file=f)
        f.flush()

    def finalize():
        os.unlink(path)

    request.addfinalizer(finalize)
    return path

@@ -5,12 +5,18 @@ from ldaptools.ldapsync.cmd import main
def test_ldapsync_ldif_to_ldapi(slapd, ldif_path, attributes, pivot_attributes):
    args = [
        '--source-uri',
        ldif_path,
        '--source-base-dn',
        'dc=orga2',
        '--target-uri',
        slapd.ldap_url,
        '--target-base-dn',
        'o=orga',
        '--attributes',
        ' '.join(attributes),
        '--source-objectclasses',
        'dcObject organization inetOrgPerson',
        '--verbose',
    ]
    for object_class, pivot_attribute in pivot_attributes:
@@ -19,19 +25,23 @@ def test_ldapsync_ldif_to_ldapi(slapd, ldif_path, attributes, pivot_attributes):
    main(args)
    conn = slapd.get_connection()
    assert len(conn.search_s('o=orga', ldap.SCOPE_SUBTREE)) == 2
    assert {dn for dn, entry in conn.search_s('o=orga', ldap.SCOPE_SUBTREE)} == {'o=orga', 'uid=admin,o=orga'}
def test_ldapsync_ldif_to_ldapi_attributes_file(slapd, ldif_path, attributes_path, pivot_attributes):
    args = [
        '--source-uri',
        ldif_path,
        '--source-base-dn',
        'dc=orga2',
        '--target-uri',
        slapd.ldap_url,
        '--target-base-dn',
        'o=orga',
        '--attributes-file',
        attributes_path,
        '--source-objectclasses',
        'dcObject organization inetOrgPerson',
        '--verbose',
    ]
    for object_class, pivot_attribute in pivot_attributes:
@@ -40,8 +50,7 @@ def test_ldapsync_ldif_to_ldapi_attributes_file(slapd, ldif_path, attributes_pat
    main(args)
    conn = slapd.get_connection()
    assert len(conn.search_s('o=orga', ldap.SCOPE_SUBTREE)) == 2
    assert {dn for dn, entry in conn.search_s('o=orga', ldap.SCOPE_SUBTREE)} == {'o=orga', 'uid=admin,o=orga'}
def test_ldapsync_ldap_to_ldap(slapd_tcp1, slapd_tcp2, ldif, attributes, pivot_attributes):
@@ -49,17 +58,26 @@ def test_ldapsync_ldap_to_ldap(slapd_tcp1, slapd_tcp2, ldif, attributes, pivot_a
    slapd_tcp1.add_ldif(ldif)
    args = [
        '--source-uri',
        slapd_tcp1.ldap_url,
        '--source-bind-dn',
        slapd_tcp1.root_bind_dn,
        '--source-bind-password',
        slapd_tcp1.root_bind_password,
        '--source-base-dn',
        'dc=orga2',
        '--target-uri',
        slapd_tcp2.ldap_url,
        '--target-bind-dn',
        slapd_tcp2.root_bind_dn,
        '--target-bind-password',
        slapd_tcp2.root_bind_password,
        '--target-base-dn',
        'o=orga',
        '--attributes',
        ' '.join(attributes),
        '--source-objectclasses',
        'dcObject organization inetOrgPerson',
        '--verbose',
    ]
    for object_class, pivot_attribute in pivot_attributes:
@@ -68,5 +86,4 @@ def test_ldapsync_ldap_to_ldap(slapd_tcp1, slapd_tcp2, ldif, attributes, pivot_a
    main(args)
    conn = slapd_tcp2.get_connection()
    assert len(conn.search_s('o=orga', ldap.SCOPE_SUBTREE)) == 2
    assert {dn for dn, entry in conn.search_s('o=orga', ldap.SCOPE_SUBTREE)} == {'o=orga', 'uid=admin,o=orga'}

@@ -7,11 +7,15 @@ from ldaptools.ldif_utils import ListLDIFParser
def test_ldifparser():
    parser = ListLDIFParser(
        StringIO(
            """dn: o=orga
objectClass: organization
jpegPhoto:: E+o9UYDeUDNblBzchRD/1+2HMdI=
"""
        )
    )
    parser.parse()
    assert len(list(parser)) == 1
    assert list(parser)[0][0] == 'o=orga'

@@ -1,6 +1,5 @@
import ldap
import pytest
@pytest.mark.parametrize('slapd', [None, 'ldap://localhost:1389'], indirect=True)
@@ -11,14 +10,16 @@ def test_checkpoint(slapd):
    slapd.stop()
    slapd.checkpoint()
    slapd.start()
    slapd.add_ldif(
        """dn: uid=admin,o=orga
objectclass: person
objectclass: uidObject
uid:in
cn: n
sn: n
"""
    )
    conn = slapd.get_connection()
    assert len(conn.search_s('o=orga', ldap.SCOPE_SUBTREE)) == 2
    slapd.stop()
@@ -35,10 +36,13 @@ def test_any(any_slapd):
def test_ssl_client_cert(slapd_ssl):
    conn = slapd_ssl.get_connection_admin()
    conn.modify_s(
        'cn=config',
        [
            (ldap.MOD_ADD, 'olcTLSCACertificateFile', slapd_ssl.tls[1].encode('utf-8')),
            (ldap.MOD_ADD, 'olcTLSVerifyClient', b'demand'),
        ],
    )
    with pytest.raises((ldap.SERVER_DOWN, ldap.CONNECT_ERROR)):
        conn = slapd_ssl.get_connection()

@@ -5,9 +5,9 @@ except ImportError:
import ldap
from ldaptools.ldap_source import LDAPSource
from ldaptools.ldif_utils import ListLDIFParser
from ldaptools.synchronize import Create, Delete, Rename, Synchronize, Update
def test_synchronize_ldif(slapd):
@@ -22,13 +22,13 @@ def test_synchronize_ldif(slapd):
    def syn_ldif(ldif):
        parser = ListLDIFParser(StringIO(ldif))
        parser.parse()
        synchronize = Synchronize(
            parser, 'o=orga', conn, 'o=orga', pivot_attributes=pivot_attributes, attributes=attributes
        )
        synchronize.run()
        return synchronize
    ldif = """dn: o=orga
o: orga
dc: coucou
objectClass: organization
@@ -42,23 +42,27 @@ sn: John
givenName: Doe
mail: john.doe@entrouvert.com
"""
    synchronize = syn_ldif(ldif)
    assert all(not action.errors for action in synchronize.actions)
    assert len(synchronize.actions) == 2
    assert len(conn.search_s('o=orga', ldap.SCOPE_SUBTREE)) == 2
    # Rename
    slapd.add_ldif(
        """dn: ou=people,o=orga
ou: people
objectClass: organizationalUnit
"""
    )
    conn.rename_s('uid=admin,o=orga', 'cn=John Doe', newsuperior='ou=people,o=orga', delold=0)
    assert {dn for dn, entry in conn.search_s('o=orga', ldap.SCOPE_SUBTREE)} == {
        'o=orga',
        'ou=people,o=orga',
        'cn=John Doe,ou=people,o=orga',
    }
    synchronize.run()
    assert not any([action.errors for action in synchronize.actions])
@@ -66,17 +70,16 @@ objectClass: organizationalUnit
    assert isinstance(synchronize.actions[0], Rename)
    assert isinstance(synchronize.actions[1], Delete)
    assert len(conn.search_s('o=orga', ldap.SCOPE_SUBTREE)) == 2
    assert {dn for dn, entry in conn.search_s('o=orga', ldap.SCOPE_SUBTREE)} == {'o=orga', 'uid=admin,o=orga'}
    # Delete one entry
    ldif = """dn: o=orga
o: orga
dc: coucou
objectClass: organization
objectClass: dcobject
"""
    synchronize = syn_ldif(ldif)
    assert all(not action.errors for action in synchronize.actions)
    assert len(synchronize.actions) == 1
@@ -93,7 +96,7 @@ def test_synchronize_ldap(slapd):
    conn = slapd.get_connection_admin()
    slapd.add_db('dc=orga2')
    ldif = """dn: dc=orga2
o: orga
dc: orga2
objectClass: organization
@@ -107,15 +110,14 @@ sn: John
givenName: Doe
mail: john.doe@entrouvert.com
"""
    slapd.add_ldif(ldif)
    source = LDAPSource(conn, base_dn='dc=orga2', attributes=attributes)
    synchronize = Synchronize(
        source, 'dc=orga2', conn, 'o=orga', pivot_attributes=pivot_attributes, attributes=attributes
    )
    synchronize.run()
    assert all(not action.errors for action in synchronize.actions)
@@ -123,19 +125,22 @@ mail: john.doe@entrouvert.com
    assert isinstance(synchronize.actions[0], Update)
    assert isinstance(synchronize.actions[1], Create)
    assert len(conn.search_s('o=orga', ldap.SCOPE_SUBTREE)) == 2
    assert {dn for dn, entry in conn.search_s('o=orga', ldap.SCOPE_SUBTREE)} == {'o=orga', 'uid=admin,o=orga'}
    # Rename
    slapd.add_ldif(
        """dn: ou=people,o=orga
ou: people
objectClass: organizationalUnit
"""
    )
    conn.rename_s('uid=admin,o=orga', 'cn=John Doe', newsuperior='ou=people,o=orga', delold=0)
    assert {dn for dn, entry in conn.search_s('o=orga', ldap.SCOPE_SUBTREE)} == {
        'o=orga',
        'ou=people,o=orga',
        'cn=John Doe,ou=people,o=orga',
    }
    synchronize.run()
    assert not any([action.errors for action in synchronize.actions])
@@ -143,8 +148,7 @@ objectClass: organizationalUnit
    assert isinstance(synchronize.actions[0], Rename)
    assert isinstance(synchronize.actions[1], Delete)
    assert len(conn.search_s('o=orga', ldap.SCOPE_SUBTREE)) == 2
    assert {dn for dn, entry in conn.search_s('o=orga', ldap.SCOPE_SUBTREE)} == {'o=orga', 'uid=admin,o=orga'}
    # Delete one entry
    conn.delete_s('uid=admin,dc=orga2')
@@ -155,26 +159,26 @@ objectClass: organizationalUnit
    assert isinstance(synchronize.actions[0], Delete)
    assert len(conn.search_s('o=orga', ldap.SCOPE_SUBTREE)) == 1
def test_synchronize_deep_rename(slapd):
    pivot_attributes = (
        ('organization', 'o'),
        ('inetOrgPerson', 'uid'),
        ('organizationalUnit', 'ou'),
    )
    attributes = ['o', 'objectClass', 'uid', 'sn', 'givenName', 'mail', 'dc', 'cn', 'description', 'ou']
    conn = slapd.get_connection_admin()
    def syn_ldif(ldif):
        parser = ListLDIFParser(StringIO(ldif))
        parser.parse()
        synchronize = Synchronize(
            parser, 'o=orga', conn, 'o=orga', pivot_attributes=pivot_attributes, attributes=attributes
        )
        synchronize.run()
        return synchronize
    ldif = """dn: o=orga
o: orga
dc: coucou
objectClass: organization
@@ -193,15 +197,15 @@ sn: John
givenName: Doe
mail: john.doe@entrouvert.com
"""
    synchronize = syn_ldif(ldif)
    assert all(not action.errors for action in synchronize.actions)
    assert len(synchronize.actions) == 3
    assert len(conn.search_s('o=orga', ldap.SCOPE_SUBTREE)) == 3
    # Rename
    ldif = """dn: o=orga
o: orga
dc: coucou
objectClass: organization
@@ -220,7 +224,7 @@ sn: John
givenName: Doe
mail: john.doe@entrouvert.com
"""
    synchronize = syn_ldif(ldif)