1574 lines
58 KiB
Python
1574 lines
58 KiB
Python
try:
|
|
from hashlib import sha1 as hash_sha1
|
|
except ImportError:
|
|
from sha import sha as hash_sha1
|
|
import base64
|
|
import cPickle
|
|
import os
|
|
import imp
|
|
import lasso
|
|
import crypt
|
|
import random
|
|
import string
|
|
import time
|
|
|
|
from quixote import get_publisher
|
|
|
|
from qommon import get_cfg, get_logger
|
|
from qommon.storage import StorableObject
|
|
from qommon.publisher import sitecharset2utf8, utf82sitecharset
|
|
import qommon.form as qform
|
|
|
|
import misc
|
|
from form import *
|
|
import admin.configuration as configuration
|
|
|
|
# Role identifiers stored in Identity.roles (Identity.is_admin() tests for
# ROLE_ADMIN).
ROLE_NONE = 0
ROLE_ADMIN = 1

# Unique sentinel used by get_identities_by_attributes() as getattr()
# default, so that a missing attribute never compares equal to a wanted value
# (not even None).
_singleton = object()
|
|
|
|
class AlreadyExists(Exception):
    '''Raised when adding an identity that clashes with an existing entry.

    The optional argument lists the keys of the fields declared unique.
    '''
|
|
|
|
class WidgetList(quixote.form.widget.WidgetList):
    '''Widget list whose "add element" control is always rendered last.'''

    def render(self):
        '''Render the list as an HTML fragment.

        jquery.js and widget_list.js are queued on the response.  Title,
        hint and error come first, then every child widget in order except
        the 'add_element' one, which is emitted after all the others.
        '''
        get_response().add_javascript(['jquery.js', 'widget_list.js'])
        r = TemplateIO(html=True)
        classnames = '%s widget' % self.__class__.__name__
        r += htmltext('<div class="%s">') % classnames
        r += self.render_title(self.get_title())
        r += htmltext('<div class="content">')
        r += self.render_hint(self.get_hint())
        r += self.render_error(self.get_error())
        add_element_widget = self.get_widget('add_element')
        for widget in self.get_widgets():
            if widget is add_element_widget:
                continue
            r += widget.render()
        if add_element_widget is not None:
            r += add_element_widget.render()
        # Close the "content" and outer widget <div>s opened above; they
        # used to be left open, producing malformed HTML.
        r += htmltext('</div>')
        r += htmltext('</div>')
        if self.render_br:
            r += htmltext('<br class="%s" />') % classnames
        r += htmltext('\n')
        return r.getvalue()
|
|
|
|
|
|
# Secret questions a user can choose from, keyed by a stable identifier.
# Captions are only marked for translation here (N_).  PasswordAccount has
# dumb_question / smart_answer attributes -- presumably the chosen key and
# the user's answer (TODO confirm against callers).
dumb_questions = {
    'rps': N_('Rock, paper, scissors?'),
    'mother-in-law': N_('Your mother-in-law first name'),
    'best-friend': N_('Your best friend first name'),
    'teacher': N_('Your first teacher name'),
    'dog': N_('Your dog name'),
    'cat': N_('Your cat name'),
    'car': N_('Your car brand'),
}

available_dumb_questions = dumb_questions.keys()
# Kept separate from dumb_questions so that an Authentic derivative can
# offer fewer (or other) questions.
|
|
|
|
# Alphabet for generated passwords and salts.  string.ascii_letters is used
# instead of string.letters, which is locale dependent on Python 2 and may
# then contain non-ASCII bytes.
__pwd_letters = list(string.ascii_letters) + list(string.digits)

# OS-backed random source: generated passwords (and SSHA salts) must not be
# predictable from the Mersenne Twister state of the `random` module.
__pwd_random = random.SystemRandom()

def create_password(length):
    '''Return a random password made of `length` ASCII letters and digits.'''
    choice = __pwd_random.choice
    return ''.join([choice(__pwd_letters) for x in range(length)])
|
|
|
|
def ssha_hash_password(password, salt = None):
    '''Return the "{SSHA}" (salted SHA-1) hash of `password`.

    With an empty `salt` a fresh 8 character salt is generated.  Otherwise
    `salt` must be a previous "{SSHA}..." value: its salt is the part of the
    base64 payload that follows the 20-byte SHA-1 digest, so the same hash
    can be recomputed for comparison.
    '''
    if salt:
        payload = base64.b64decode(salt[len('{SSHA}'):])
        raw_salt = payload[20:]
    else:
        raw_salt = create_password(8)
    digest = hash_sha1()
    digest.update(password)
    digest.update(raw_salt)
    return '{SSHA}' + base64.b64encode(digest.digest() + raw_salt)
|
|
|
|
def sha_hash_password(password, salt = None):
    '''Return the unsalted "{SHA}" hash of `password`.

    `salt` is accepted only so that every scheme shares one signature; it
    is ignored.
    '''
    return '{SHA}' + base64.b64encode(hash_sha1(password).digest())
|
|
|
|
def crypt_hash_password(password, salt = None):
    '''Return the "{crypt}" (crypt(3)) hash of `password`.

    Without `salt` a random two-letter salt is drawn.  Otherwise `salt` is a
    previous "{crypt}..." value whose two characters following the
    "{crypt}" prefix are reused, so the result can be compared to it.
    '''
    if salt:
        salt = salt[7:9]
    else:
        # string.ascii_letters, not the locale dependent string.letters:
        # crypt(3) salts must be drawn from [a-zA-Z0-9./].
        salt = random.choice(string.ascii_letters) + random.choice(string.ascii_letters)
    return '{crypt}%s' % crypt.crypt(password, salt)
|
|
|
|
# Lower-case scheme name (as returned by get_hash_format) -> function that
# produces a "{SCHEME}..." hash for it.
__hash_schemes = dict(
    ssha = ssha_hash_password,
    sha = sha_hash_password,
    crypt = crypt_hash_password
)
|
|
|
|
def hash_password(password, format, salt = None):
    '''Hash `password` with the scheme named `format`.

    Unknown or empty formats return the password unchanged (clear text);
    `salt` is forwarded to the scheme function.
    '''
    if format not in __hash_schemes:
        return password
    hasher = __hash_schemes[format]
    return hasher(password, salt = salt)
|
|
|
|
def get_hash_format(pwd):
    '''Return the lower-cased SCHEME of a "{SCHEME}..." value, else None.'''
    if not pwd or not pwd.startswith('{'):
        return None
    closing = pwd.find('}', 1)
    if closing == -1:
        return None
    return pwd[1:closing].lower()
|
|
|
|
def hashed_password_equals(pwd1, pwd2):
    '''Check a clear-text candidate password against a stored one.

    `pwd1` is the stored value, optionally prefixed by "{SCHEME}"; `pwd2` is
    the clear-text candidate.  For a known scheme, `pwd2` is hashed with it
    (reusing the salt embedded in `pwd1`) and only the hash parts are
    compared: the prefixes are skipped because the stored one may differ in
    case from the one the hash functions emit.  Anything else -- including
    a "{...}" prefix naming an unknown scheme -- is compared as clear text.
    '''
    format = get_hash_format(pwd1)
    if format and format in __hash_schemes:
        prefix_len = len(format) + 2
        return pwd1[prefix_len:] == hash_password(pwd2, format, salt = pwd1)[prefix_len:]
    # For an unknown scheme hash_password() returns pwd2 unchanged; stripping
    # prefix_len characters from both sides used to accept any candidate
    # ending like the stored value.  Compare whole strings instead.
    return pwd1 == pwd2
|
|
|
|
def try_int(s):
    '''Return int(s), or -1 when `s` cannot be converted (e.g. "abc", None).'''
    try:
        return int(s)
    except (ValueError, TypeError):
        # TypeError covers None and other non-numeric objects, which used to
        # escape and crash callers such as get_new_uid().
        return -1
|
|
|
|
|
|
class Field:
    def __init__(self, key, title, admin_required = False,
            on_register = True, widget_class = 'StringWidget',
            read_only = False, unique = False, multivalued = False,
            invisible = False, **keywords):
        '''
        Field defines one attribute of the user identity

        key             attribute name on the identity object
        title           caption shown in forms (translated with _())
        admin_required  whether the field is required even in the
                        administrative part
        on_register     whether the registration page should show a field
                        for this information
        widget_class    widget used to edit this field: a widget class or
                        the name of a class of qommon.form
        read_only       whether the field can be modified by the user
        unique          whether the value must be unique among records of
                        the identity storage (adds a check when the user
                        record is added)
        multivalued     whether the field can have multiple values for the
                        same user
        invisible       whether a read-only field is hidden instead of
                        being displayed as text
        keywords        extra widget arguments; 'value' is the default value
        '''
        self.key = key
        self.title = title
        self.keywords = keywords
        self.admin_required = admin_required
        if isinstance(widget_class, str):
            self.widget_class = getattr(qform, widget_class)
        else:
            self.widget_class = widget_class
        self.on_register = on_register
        self.read_only = read_only
        self.unique = unique
        self.multivalued = multivalued
        self.invisible = invisible

    def get_value(self, identity = None):
        '''Return this field's value on `identity`, or the default value.

        The default is the 'value' keyword.  It is read without being
        removed: popping it lost the default after the first call.
        Multivalued fields always yield a list.
        '''
        value = self.keywords.get('value')
        if identity is not None:
            value = getattr(identity, self.key, value)
        if self.multivalued:
            if value is None:
                value = []
            elif not isinstance(value, list):
                value = [ value ]
        return value

    def get_one_value_human_representation(self, value):
        '''Return a human representation for a scalar value'''
        if value is True:
            return _('true')
        if value is False:
            return _('false')
        if value is None:
            return ''
        if isinstance(value, str) or isinstance(value, unicode):
            return value
        return str(value)

    def get_human_representation(self, identity = None):
        '''Textual value for display; multiple values are comma separated.'''
        value = self.get_value(identity)
        if self.multivalued:
            return ', '.join(map(self.get_one_value_human_representation, value))
        else:
            return self.get_one_value_human_representation(value)

    def get_tabular_representation(self, identity = None):
        '''HTML value for tables; each value gets its own <div class="value">.'''
        value = self.get_value(identity)
        if self.multivalued:
            elements = [ htmltext('<div class="value">%s</div>') % \
                    self.get_one_value_human_representation(v) for v in value]
            return htmltext(''.join(map(str, elements)))
        else:
            return self.get_one_value_human_representation(value)

    def add_to_form(self, form, identity = None, admin = False):
        '''Add a widget editing (or displaying) this field to `form`.

        Editable fields -- or any field when `admin` is true -- get a widget
        of widget_class (EmailWidget for 'email', ValidatedStringWidget when
        a 'regex' keyword is set), wrapped in a WidgetList when multivalued.
        Read-only fields that are not invisible are shown as text.  Option
        captions are translated.  Returns form.get_widget(self.key).
        '''
        widget_class = self.widget_class
        if self.key == 'email':
            widget_class = EmailWidget
        if self.keywords.get('regex'):
            widget_class = ValidatedStringWidget
        keywords = dict(self.keywords)
        if admin:
            keywords['required'] = self.admin_required
        value = keywords.pop('value', None)
        if identity:
            value = getattr(identity, self.key, value)
        if 'options' in keywords:
            # translate caption
            new_options = []
            for p in keywords['options']:
                if isinstance(p, (list, tuple)):
                    p = (p[0], _(p[1]))
                else:
                    p = _(p)
                new_options.append(p)
            keywords['options'] = new_options
        if not self.read_only or admin:
            if self.multivalued:
                keywords['render_br'] = False
                form.add(WidgetList, self.key, title = _(self.title),
                        value = value,
                        hint = keywords.pop('hint', None),
                        required = keywords.pop('required', False),
                        element_type = widget_class, element_kwargs = keywords,
                        add_element_label = _('Add value'))
            else:
                if value is not None:
                    keywords['value'] = value
                form.add(widget_class, self.key, title = _(self.title),
                        **keywords)
        elif not self.invisible:
            form.add(HtmlWidget, htmltext('''<div class="HtmlWidget widget">
<div class="title">%s:</div>
<div class="content">%s </div></div>''') % (_(self.title),
                    self.get_human_representation(identity)))
        return form.get_widget(self.key)
|
|
|
|
class IdentityStoreException(Exception):
    '''Base class for identity store failures (e.g. the LDAP server is down).'''
|
|
|
|
class TooMuchAccounts(IdentityStoreException):
    '''Raised when a lookup expecting one directory entry finds several.'''
|
|
|
|
class Identity(StorableObject):
    '''A user identity: display data, login accounts, roles and lasso
    identity dumps.'''
    _names = 'identities'

    id = None
    name = None
    email = None
    accounts = None
    roles = None
    lasso_dump = None # identity dump
    lasso_proxy_dump = None # identity dump, when acting as service provider
    lasso_proxy_dump_migrated = False
    resource_id = None
    entry_id = None
    disabled = False
    proxied_identity_origin = None
    shared_attributes = None

    ecp_id = None

    locked_down = False

    def __init__(self):
        self.accounts = []

    def __str__(self):
        if self.name:
            return self.name
        if self.email:
            return self.email
        return str(self.id)

    def migrate(self):
        '''Upgrade an identity stored by an older version.

        Each step stores the object when it changes it:
         - an old 'uid' attribute is renamed to 'id';
         - an identity whose id starts with 'https-' gets
           proxied_identity_origin set to the providerId of the configured
           IdP whose provider key equals the id minus its last '-' part;
         - a random 4 letter ecp_id is assigned when missing.
        '''
        if hasattr(self, 'uid'):
            self.id = self.uid
            del self.uid
            self.store()
        if str(self.id).startswith('https-') and not self.proxied_identity_origin:
            proxied_idp = '-'.join(self.id.split('-')[:-1])
            for lp in get_cfg('providers', {}).values():
                if lp['role'] != lasso.PROVIDER_ROLE_IDP:
                    continue
                p = lasso.Provider(lp['role'],
                        misc.get_abs_path(lp['metadata']),
                        misc.get_abs_path(lp['publickey']), None)
                base_key = misc.get_provider_key(p.providerId)
                if base_key == proxied_idp:
                    self.proxied_identity_origin = p.providerId
                    self.store()
                    break
        if not self.lasso_proxy_dump_migrated:
            self.lasso_proxy_dump_migrated = True
        # NOTE(review): the flag is always true at this point, so this
        # bootstrap copy never runs.  Making it run would overwrite existing
        # lasso_proxy_dump values, so the intended behaviour needs
        # confirming before changing it.
        if not self.lasso_proxy_dump_migrated and get_cfg('idp', {}).get('idff_proxy'):
            # bootstrap with a copy
            self.lasso_proxy_dump = self.lasso_dump
            self.store()

        if not self.ecp_id:
            # ascii_lowercase: string.lowercase is locale dependent on Python 2
            self.ecp_id = ''.join([random.choice(string.ascii_lowercase) for x in range(4)])
            self.store()

    def is_admin(self):
        return ROLE_ADMIN in (self.roles or [])

    def get_dst_view(self, href):
        '''Return the DST view for service `href`.

        Only the Personal Profile (lasso.PP_HREF) is supported; any other
        service raises ValueError.
        '''
        if href == lasso.PP_HREF:
            return self.get_pp_view()
        # Raising a plain string is invalid since Python 2.6 (it turns into a
        # TypeError); raise a real exception instead.
        raise ValueError('Unknown dst')

    def get_pp_view(self):
        '''Return the ID-SIS Personal Profile XML document for this identity.

        InformalName carries the name and MsgContact the e-mail address split
        at its last '@' (omitted when there is no '@').  Values are
        XML-escaped.
        '''
        from xml.sax.saxutils import escape
        l = ['<PP xmlns="urn:liberty:id-sis-pp:2003-08">']
        if self.name:
            l.append(' <InformalName>%s</InformalName>' % escape(sitecharset2utf8(str(self))))
        if self.email:
            # rsplit: split('@') made the % formatting fail whenever the
            # address did not contain exactly one '@'
            parts = self.email.rsplit('@', 1)
            if len(parts) == 2:
                l.append(' <MsgContact><MsgAccount>%s</MsgAccount><MsgProvider>%s</MsgProvider></MsgContact>' % (escape(parts[0]), escape(parts[1])))
        l.append('</PP>')
        return '\n'.join(l)

    def get_preferred_language(self):
        return None

    def get_display_name(self):
        '''Name, else e-mail, else '#<id>'.'''
        if self.name:
            return self.name
        if self.email:
            return self.email
        return '#%s' % self.id
    display_name = property(get_display_name)
|
|
|
|
class CertificateAccount:
    '''Account authenticated by a client certificate, identified by the hex
    SHA-1 digest of the certificate's DER encoding.'''
    certificate_sha1 = None
    dn = None

    def __repr__(self):
        return '<CertificateAccount: certificate_sha1: %s' % self.certificate_sha1 + '>'

    def __init__(self, sha1 = None, certificate = None, dn = None):
        '''Create the account from a precomputed hex `sha1` or from
        `certificate` (a PEM block or bare base64 DER), which takes
        precedence.  `dn` is stored as is.'''
        self.dn = dn
        if certificate:
            # Markers without trailing/leading newlines so that CRLF files
            # are recognised too; the base64 decoder discards line breaks.
            start = '-----BEGIN CERTIFICATE-----'
            end = '-----END CERTIFICATE-----'
            i = certificate.find(start)
            if i != -1:
                i = i + len(start)
                # Look for the footer after the header.  When it is missing,
                # keep everything: slicing to -1 used to drop a character.
                j = certificate.find(end, i)
                if j == -1:
                    j = len(certificate)
                certificate = certificate[i:j]
            der = base64.b64decode(certificate)
            h = hash_sha1()
            h.update(der)
            sha1 = h.hexdigest()
        self.certificate_sha1 = sha1

    def equals(self, account):
        '''True when `account` carries the same certificate digest.'''
        return hasattr(account, 'certificate_sha1') and self.certificate_sha1 == account.certificate_sha1
|
|
|
|
class PasswordAccount:
    '''Account authenticated by a username and a (possibly hashed) password.'''
    username = None
    password = None
    # presumably a dumb_questions key and the user's answer to it
    dumb_question = None
    smart_answer = None

    def equals(self, account):
        '''True when `account` is a PasswordAccount with the same username
        whose password matches this account's stored password.'''
        if not isinstance(account, PasswordAccount) or \
                self.username != account.username:
            return False
        return hashed_password_equals(self.password, account.password)

    def type(self):
        return _("Password Account")
|
|
|
|
def get_pwd_length():
    '''Length of generated passwords: the configured minimum but at least 6,
    capped by the configured maximum when one is set.'''
    settings = configuration.get_configuration('passwords')
    lower = settings.get('min_length')
    upper = settings.get('max_length')
    length = max([lower, 6])
    if not upper:
        return length
    return min([length, upper])
|
|
|
|
def get_pwd_hashing_scheme():
    '''Configured password hashing scheme name, or None when unset.'''
    settings = configuration.get_configuration('passwords')
    scheme = settings.get('hashed_scheme')
    return scheme
|
|
|
|
class BaseIdentitiesStore:
    '''Behaviour shared by identity stores.

    Subclasses provide the storage primitives (values(), get_identity(),
    save(), ...); the lookups implemented here scan values() linearly.
    '''
    fields = [ Field('name', N_('Name'), required = True, size = 30),
               Field('email', N_('Email'), required = True, size = 30) ]

    username_regex = r'[a-zA-Z0-9.-]+$'
    identity_class = Identity

    def create_password(self, for_account='unknown'):
        '''Create a new password for a user contained in this store.'''
        password = create_password(get_pwd_length())
        # Never write the clear-text password itself to the logs.
        get_logger().debug('Generated new password for account %r' % (for_account,))
        return password

    def hash_password(self, password):
        '''Hash `password` with the configured scheme (clear text if none).'''
        return hash_password(password, get_pwd_hashing_scheme())

    def get_identity_for_account(self, account):
        '''Return the first identity owning an account equal to `account`.'''
        for i in self.values():
            for a in i.accounts or []:
                if a.equals(account):
                    return i
        return None

    def load_identities(self):
        pass

    def reload_identities(self):
        pass

    def has_identity_with_username(self, username):
        return self.get_identity_for_username(username) is not None

    def get_identity_for_username(self, username, throw=False):
        '''Return the identity with an account named `username`, or None.

        `throw` is unused here; the LDAP store uses it to re-raise
        ambiguous matches.'''
        for i in self.values():
            for a in i.accounts or []:
                if hasattr(a, 'username') and a.username == username:
                    return i
        return None

    def get_identity_for_name_identifier(self, name_identifier):
        '''Return the identity whose lasso dump or proxy dump holds a
        federation with `name_identifier` as local or remote name
        identifier, or None.'''
        for i in self.values():
            if self._dump_has_name_identifier(i.lasso_dump, name_identifier):
                return i
            if self._dump_has_name_identifier(i.lasso_proxy_dump, name_identifier):
                return i
        return None

    def _dump_has_name_identifier(dump, name_identifier):
        '''True when lasso identity `dump` has a federation whose local or
        remote name identifier content equals `name_identifier`.'''
        # cheap substring test first so unrelated dumps are never parsed
        if not dump or name_identifier not in dump:
            return False
        identity = lasso.Identity.newFromDump(dump)
        for p in identity.providerIds:
            federation = identity.getFederation(p)
            for ni in (federation.localNameIdentifier, federation.remoteNameIdentifier):
                if ni and ni.content == name_identifier:
                    return True
        return False
    _dump_has_name_identifier = staticmethod(_dump_has_name_identifier)

    def count(self, letter = None):
        '''Number of identities; -1 means unknown.'''
        return -1

    def is_big(self):
        return False

    def values_by_letter(self, letter, limit = 0, offset = 0):
        '''Identities whose name starts with `letter` (case-insensitive);
        `limit` and `offset` are ignored here.'''
        letter = letter.lower()
        return [x for x in self.values() if x.name and x.name.lower().startswith(letter)]

    def get_identities_by_attributes(self, map):
        '''Select identity objects whose attributes equal every item of `map`'''
        result = []
        for x in self.values():
            for key, value in map.iteritems():
                if getattr(x, key, _singleton) != value:
                    break
            else:
                result.append(x)
        return result

    def last_modified(self):
        return time.time()

    def last_modified_uid(self, uid):
        return time.time()

    def get_identity_class(cls):
        return cls.identity_class
    get_identity_class = classmethod(get_identity_class)

    def administrators(self):
        '''Identities holding ROLE_ADMIN.'''
        return [ identity for identity in self.values() \
                if identity.is_admin() ]
|
|
|
|
|
|
class IdentitiesStorePickle(BaseIdentitiesStore):
    '''Identity store keeping every identity in a single pickled dict
    (identities.pck in the application directory), keyed by id.'''
    label = N_('Old default storage (pickled file)')

    def __init__(self):
        self.identities = None

    def _get_filename(self):
        '''Path of the pickle holding all identities.'''
        return os.path.join(get_publisher().app_dir, 'identities.pck')

    def is_bootstrapping(self):
        '''True while the store holds no identity.'''
        return len(self.identities) == 0

    def load_identities(self):
        '''Read identities from disk unless some are already loaded.'''
        if not self.identities:
            self.reload_identities()

    def reload_identities(self):
        '''(Re)read all identities from disk.

        A missing or unreadable pickle gives an empty store.  Records from
        older versions only carry 'uid'; it is copied to 'id'.
        '''
        # NOTE(review): unpickling trusts the file completely; it must only
        # be writable by the application itself.
        try:
            f = open(self._get_filename())
            try:
                self.identities = cPickle.load(f)
            finally:
                # the handle used to be left open
                f.close()
        except Exception:
            self.identities = {}
        for i in self.identities.values():
            if not i.id:
                i.id = i.uid

    def count(self, letter = None):
        return len(self.identities)

    def keys(self):
        return self.identities.keys()

    def values(self, limit = 0, offset = 0):
        return self.identities.values()

    def get_identity(self, uid):
        '''Raises KeyError for unknown ids.'''
        return self.identities[uid]

    def has_identity(self, uid):
        return uid in self.identities

    def add(self, identity):
        '''Register `identity`, giving it a new numeric id if it has none
        (the file is only written by save()).'''
        if not identity.id:
            identity.id = self.get_new_uid()
        self.identities[identity.id] = identity

    def remove(self, identity):
        del self.identities[identity.id]
        self.save(None)

    def save(self, identity):
        '''Write the whole dict to disk; a single identity cannot be saved,
        so `identity` is ignored.'''
        s = cPickle.dumps(self.identities)
        f = open(self._get_filename(), 'w')
        try:
            f.write(s)
        finally:
            f.close()

    def get_new_uid(self):
        '''Return 1 + the largest numeric id in the store, at least 1.'''
        uids = [try_int(x.id) for x in self.values()]
        if len(uids) == 0:
            return 1
        # Non-numeric ids count as -1.  Never hand out 0: it is falsy and
        # reload_identities() would then look for a missing 'uid'.
        return max(max(uids) + 1, 1)

    def connect(self, session):
        pass
|
|
|
|
class MiniIdentityLdap(StorableObject):
    '''Locally stored part of an LDAP-backed identity.

    IdentitiesStoreLdap keeps standard attributes in the directory and the
    Authentic-specific ones (roles, lasso dumps, ...) in these records,
    stored under the identity id (its DN).
    '''
    _names = 'ldap_identities'

    uid = None
    disabled = False
    roles = None
    resource_id = None
    lasso_dump = None
    lasso_proxy_dump = None
|
|
|
|
|
|
|
|
try:
|
|
import ldap
|
|
import ldap.async
|
|
import ldap.modlist
|
|
import ldap.filter
|
|
except ImportError:
|
|
ldap = None
|
|
|
|
class IdentitiesStoreLdap(BaseIdentitiesStore):
|
|
# use LDAP for standard attributes and authentic storage (cpickle) for
|
|
# authentic specific attributes (roles, lasso_dump)
|
|
|
|
label = N_('LDAP Directory')
|
|
admin_keys = ('ldap_url', 'ldap_new_user_base', 'ldap_base',
|
|
'ldap_bind_dn', 'ldap_bind_password', 'ldap_object_class',
|
|
'ldap_object_uid', 'ldap_object_name', 'ldap_object_email',
|
|
'ldap_big_directory', 'ldap_published_attributes',
|
|
'ldap_field_mapping','ldap_read_only')
|
|
default_field_mapping = { 'email': 'mail', 'name': 'cn' }
|
|
|
|
ldap_object_name = None
|
|
ldap_object_classes = None
|
|
ldap_bind_dn = None
|
|
ldap_bind_password = None
|
|
ldap_published_attributes = None
|
|
ldap_field_mapping = default_field_mapping
|
|
|
|
def __init__(self):
    # No connection yet: it is opened on demand by connect_admin() or
    # load_identities().
    self.ldap_conn = None
|
|
|
|
def connect_admin(self):
    '''Make sure the shared LDAP connection exists and bind it with the
    configured administrative credentials (ldap_bind_dn and, when set,
    ldap_bind_password).  Nothing is (re)bound when no bind DN is
    configured.  LDAP errors raised while binding are ignored.'''
    if not self.ldap_conn:
        self.ldap_conn = ldap.initialize(self.ldap_url)
    self.ldap_conn.protocol_version = ldap.VERSION3
    # Bind on every call: connect() and get_identity_for_account() rebind
    # this same connection as an end user.
    try:
        if self.ldap_bind_dn is not None:
            if self.ldap_bind_password is None:
                self.ldap_conn.simple_bind(self.ldap_bind_dn)
            else:
                self.ldap_conn.simple_bind(self.ldap_bind_dn, self.ldap_bind_password)
    except ldap.LDAPError:
        # NOTE(review): simple_bind() is python-ldap's asynchronous variant,
        # so most bind failures are presumably reported only when a result
        # is read, not here -- confirm.
        pass
|
|
|
|
def field_name2attribute(self, field_name):
    '''Map an identity field name to its LDAP attribute; names missing from
    ldap_field_mapping are used unchanged.'''
    mapping = self.ldap_field_mapping
    if field_name in mapping:
        return mapping[field_name]
    return field_name
|
|
|
|
def field2attribute(self, field):
    '''LDAP attribute storing Field `field` (looked up by its key).'''
    key = field.key
    return self.field_name2attribute(key)
|
|
|
|
def ldap2identity(self, v):
    '''Build an identity object from one LDAP search result.

    `v` is a (dn, attributes) pair; the DN becomes the identity id.
    Directory values are converted from UTF-8 to the site charset.  Roles,
    lasso dumps, resource id and disabled flag come from the
    MiniIdentityLdap record stored under the same id, when there is one.
    '''
    dn = v[0]
    attrs = v[1]
    i = self.get_identity_class()()
    i.id = dn

    # Display name: space-joined first values of the '+'-separated
    # attributes of ldap_object_name.
    name = None
    if self.ldap_object_name:
        # strip before the membership test as well as before the lookup
        keys = [key.strip() for key in self.ldap_object_name.split('+')]
        values = [attrs[key][0] for key in keys if key in attrs]
        if values:
            name = ' '.join([utf82sitecharset(x) for x in values if x])
    if name is None:
        name = _('Unknown')
    i.name = name
    i.email = attrs.get(self.ldap_object_email, [None])[0]

    # Deprecated: use custom fields
    if self.ldap_published_attributes:
        i.attributes = {}
        i.attributes['dn'] = i.id
        for attr in self.ldap_published_attributes.split(' '):
            if attr in attrs:
                i.attributes[attr] = [utf82sitecharset(x) for x in attrs[attr]]

    # Load custom fields
    for field in self.fields:
        attribute = self.field2attribute(field)
        if attribute in attrs:
            if field.multivalued:
                values = [ utf82sitecharset(value) for value in attrs[attribute] ]
                setattr(i, field.key, values)
            else:
                setattr(i, field.key, utf82sitecharset(attrs[attribute][0]))

    try:
        pa = PasswordAccount()
        pa.username = attrs[self.ldap_object_uid][0]
        pa.password = attrs.get('userPassword', [None])[0]
        i.accounts = [pa]
    except KeyError:
        # no naming attribute: no password account
        i.accounts = []

    try:
        mini = MiniIdentityLdap.get(i.id)
    except KeyError:
        pass
    else:
        # roles defaults to None on records that save() did not write
        i.roles = (mini.roles or [])[:]
        i.lasso_dump = mini.lasso_dump
        i.lasso_proxy_dump = mini.lasso_proxy_dump
        i.resource_id = mini.resource_id
        i.disabled = mini.disabled
    return i
|
|
|
|
def identity2ldap(self, identity):
    '''Convert `identity` to a (dn, entry) pair, dn being identity.id.

    entry maps LDAP attribute names to value lists.  It holds the naming
    attribute (first account's username), the e-mail, the first
    PasswordAccount's password, objectclass (ldap_object_class and its
    superclasses, cached on first use) and every custom field set on the
    identity, converted to UTF-8.
    '''
    entry = {}
    entry[self.ldap_object_uid] = [identity.accounts[0].username]
    entry[self.ldap_object_email] = [identity.email]
    for account in identity.accounts:
        if isinstance(account, PasswordAccount):
            entry['userPassword'] = [account.password]
            break

    if self.ldap_object_classes is None:
        # Only cache a fully computed chain: a failure half-way used to
        # leave a partial list cached forever.
        self.ldap_object_classes = self._fetch_object_classes()
    entry['objectclass'] = self.ldap_object_classes

    # Convert custom fields
    for field in self.fields:
        attribute = self.field2attribute(field)
        value = getattr(identity, field.key, None)
        if value is not None:
            if field.multivalued:
                entry[attribute] = [sitecharset2utf8(str(v)) for v in value]
            else:
                entry[attribute] = [sitecharset2utf8(str(value))]

    return identity.id, entry

def _fetch_object_classes(self):
    '''Return ldap_object_class and its superclasses, root first, read from
    the server schema (only the first superclass is followed).'''
    # explicit import: the module header only imports ldap.async,
    # ldap.modlist and ldap.filter
    from ldap import schema as ldap_schema
    subschema = ldap_schema.urlfetch(self.ldap_url)[1]
    classes = []
    name = self.ldap_object_class
    while name:
        classes.append(name)
        object_class = subschema.get_obj(ldap_schema.ObjectClass, name)
        # stop instead of crashing when the schema does not know the class
        if object_class is None or not object_class.sup:
            break
        name = object_class.sup[0]
    classes.reverse()
    return classes
|
|
|
|
def init(self):
    '''Apply attribute-name defaults and precompute the search filters.

    uid_request_str keeps a %(username)s placeholder and alpha_request_str a
    %s placeholder; callers escape the values they insert.  Only the first
    attribute of a '+'-joined ldap_object_name is used for initial-letter
    searches.
    '''
    if not self.ldap_object_uid:
        self.ldap_object_uid = 'uid'
    if not self.ldap_object_name:
        self.ldap_object_name = 'cn'
    if not self.ldap_object_email:
        self.ldap_object_email = 'mail'

    object_class = self.ldap_object_class
    first_name_attribute = self.ldap_object_name.split('+')[0]

    self.class_request_str = 'objectclass=%s' % object_class
    self.uid_request_str = '(&(%(uid)s=%%(username)s)(objectclass=%(oc)s))' % {
            'oc': object_class, 'uid': self.ldap_object_uid }
    self.alpha_request_str = '(&(%(name)s=%%s*)(objectclass=%(oc)s))' % {
            'oc': object_class, 'name': first_name_attribute }
|
|
|
|
def is_bootstrapping(self):
    '''True until at least one MiniIdentityLdap record has been stored.'''
    stored = MiniIdentityLdap.count()
    return stored == 0
|
|
|
|
def connect(self, session):
    '''Rebind the shared connection with the (dn, password) pair saved on
    `session` by init_session(); invalid credentials are ignored.'''
    if not session or not hasattr(session, 'ldap_infos'):
        return
    try:
        self.ldap_conn.simple_bind_s(session.ldap_infos[0], session.ldap_infos[1])
    except ldap.INVALID_CREDENTIALS:
        pass
|
|
|
|
def load_identities(self):
    '''Check that python-ldap is available and open the connection if needed.'''
    if not ldap:
        raise IdentityStoreException('Missing ldap module')
    if self.ldap_conn:
        return
    self.ldap_conn = ldap.initialize(self.ldap_url)
|
|
|
|
def keys(self):
    '''Listing keys is not supported by the LDAP store.'''
    raise NotImplementedError()
|
|
|
|
def values(self, limit = 0, offset = 0):
|
|
self.connect_admin()
|
|
s = ldap.async.List(self.ldap_conn)
|
|
s.startSearch(self.ldap_base, ldap.SCOPE_SUBTREE,
|
|
self.class_request_str, sizelimit = limit)
|
|
try:
|
|
partial = s.processResults()
|
|
except ldap.SIZELIMIT_EXCEEDED:
|
|
pass
|
|
except ldap.SERVER_DOWN:
|
|
raise IdentityStoreException('LDAP server is down')
|
|
else:
|
|
if partial:
|
|
# it is ignored; this info should be passed down
|
|
pass
|
|
return [self.ldap2identity(v[1]) for v in s.allResults]
|
|
|
|
def count(self, letter = None):
|
|
self.connect_admin()
|
|
s = ldap.async.List(self.ldap_conn)
|
|
if letter is None:
|
|
request = self.class_request_str
|
|
else:
|
|
letter = sitecharset2utf8(letter)
|
|
letter = ldap.filter.escape_filter_chars(letter)
|
|
request = self.alpha_request_str % letter
|
|
try:
|
|
s.startSearch(self.ldap_base, ldap.SCOPE_SUBTREE,
|
|
request, attrsOnly = 1, sizelimit = 99999)
|
|
partial = s.processResults()
|
|
except ldap.SIZELIMIT_EXCEEDED:
|
|
pass
|
|
except ldap.SERVER_DOWN:
|
|
raise IdentityStoreException('LDAP server is down')
|
|
else:
|
|
if partial:
|
|
# it is ignored; this info should be passed down
|
|
pass
|
|
return len(s.allResults)
|
|
|
|
def values_by_letter(self, letter, limit = None, offset = 0):
|
|
self.connect_admin()
|
|
s = ldap.async.List(self.ldap_conn)
|
|
letter = sitecharset2utf8(letter)
|
|
letter = ldap.filter.escape_filter_chars(letter)
|
|
try:
|
|
s.startSearch(self.ldap_base, ldap.SCOPE_SUBTREE,
|
|
self.alpha_request_str % letter, sizelimit = limit + offset)
|
|
partial = s.processResults()
|
|
except ldap.SIZELIMIT_EXCEEDED:
|
|
pass
|
|
except ldap.SERVER_DOWN:
|
|
raise IdentityStoreException('LDAP server is down')
|
|
else:
|
|
if partial:
|
|
# it is ignored; this info should be passed down
|
|
pass
|
|
return [self.ldap2identity(v[1]) for v in s.allResults[offset:offset+limit]]
|
|
|
|
def get_identity(self, uid):
    '''Fetch the identity whose DN is `uid`.

    Raises KeyError for None, int ids (new accounts not yet stored), unknown
    DNs and other LDAP errors; TooMuchAccounts when the base search answers
    several entries; IdentityStoreException when the server is down.
    '''
    self.connect_admin()
    if type(uid) == int or uid is None:
        # no DN yet: an int id marks a new account
        raise KeyError
    try:
        found = self.ldap_conn.search_s(uid, ldap.SCOPE_BASE)
        if not found:
            raise KeyError
        if len(found) > 1:
            raise TooMuchAccounts('Too much answer for dn:%s' % uid)
        return self.ldap2identity(found[0])
    except IndexError:
        raise KeyError
    except ldap.SERVER_DOWN:
        raise IdentityStoreException('LDAP server is down')
    except ldap.LDAPError:
        raise KeyError
|
|
|
|
def has_identity(self, uid):
    '''True when get_identity(uid) succeeds, False when it raises KeyError.'''
    try:
        self.get_identity(uid)
        return True
    except KeyError:
        return False
|
|
|
|
def add(self, identity):
|
|
uid = ldap.dn.escape_dn_chars(identity.accounts[0].username)
|
|
# Make the id right
|
|
identity.id = '%s=%s,%s' % (self.ldap_object_uid, uid,
|
|
self.ldap_new_user_base or self.ldap_base)
|
|
try:
|
|
existing_identity = self.get_identity(identity.id)
|
|
return self.save(identity)
|
|
except KeyError:
|
|
pass
|
|
# New identity, id == dn
|
|
dn, entry = self.identity2ldap(identity)
|
|
# Naming attribute must be present
|
|
if self.ldap_object_uid in entry and \
|
|
entry[self.ldap_object_uid] != [ identity.accounts[0].username ]:
|
|
raise IdentityStoreException('Mismatch of naming attribute: %s != %s'
|
|
% (entry[self.ldap_object_uid],
|
|
[ identity.accounts[0].username ]))
|
|
entry[self.ldap_object_uid] = [ identity.accounts[0].username ]
|
|
self.connect_admin()
|
|
addList = ldap.modlist.addModlist(entry)
|
|
|
|
# Verify uniqueness of attributes
|
|
unique_fields = [ field for field in self.fields if field.unique ]
|
|
if unique_fields:
|
|
keys = [ (field.key, self.field2attribute(field)) for field in unique_fields ]
|
|
predicates = []
|
|
for attribute, ldap_key in keys:
|
|
value = ldap.filter.escape_filter_chars(getattr(identity, attribute))
|
|
predicates.append('(%s=%s)' % (ldap_key, value))
|
|
predicate = '(|%s)' % ''.join(predicates)
|
|
try:
|
|
result = self.ldap_conn.search_s(self.ldap_base,
|
|
ldap.SCOPE_SUBTREE,
|
|
predicate)
|
|
if result:
|
|
raise AlreadyExists([key[0] for key in keys])
|
|
except ldap.LDAPError, e:
|
|
raise IdentityStoreException(str(e))
|
|
try:
|
|
self.ldap_conn.add_s(dn, addList)
|
|
except ldap.ALREADY_EXISTS:
|
|
raise AlreadyExists()
|
|
|
|
def remove(self, identity):
    '''Delete the identity's directory entry, then its MiniIdentityLdap
    record; failing to remove the latter is only logged.'''
    self.connect_admin()
    dn = identity.id
    self.ldap_conn.delete_s(dn)
    try:
        MiniIdentityLdap.remove_object(dn)
    except Exception:
        get_logger().exception('IdentitiesStoreLdap.remove: failure to remove MiniIdentityLdap for %s' % dn)
|
|
|
|
def get_identity_for_account(self, account):
    '''Authenticate a PasswordAccount against the directory.

    The entry is looked up by username, then a simple bind with the given
    password checks it.  Returns the identity, with the clear password kept
    in its first account, or None when `account` is not a PasswordAccount,
    has no username or password, is unknown, or the credentials are wrong.

    Raises IdentityStoreException when the server is down.
    '''
    if not isinstance(account, PasswordAccount):
        return None
    # An empty password turns simple_bind_s() into an unauthenticated bind,
    # which many servers accept: it must never count as a login.
    if not account.username or not account.password:
        return None
    self.connect_admin()
    try:
        escaped_username = ldap.filter.escape_filter_chars(account.username)
        result = self.ldap_conn.search_s(self.ldap_base,
                ldap.SCOPE_SUBTREE,
                self.uid_request_str % { 'username': escaped_username } )
        if not result:
            return None
        uid = result[0][0]
        self.ldap_conn.simple_bind_s(uid, account.password)
    except ldap.INVALID_CREDENTIALS:
        return None
    except ldap.SERVER_DOWN:
        raise IdentityStoreException('LDAP server is down')
    except IndexError:
        return None
    identity = self.get_identity(uid)
    # Workaround in order to keep the password clear in memory.
    identity.accounts[0].password = account.password
    return identity
|
|
|
|
def has_identity_with_username(self, username):
|
|
self.connect_admin()
|
|
try:
|
|
escaped_username = ldap.filter.escape_filter_chars(username)
|
|
result = self.ldap_conn.search_s(self.ldap_base,
|
|
ldap.SCOPE_SUBTREE,
|
|
self.uid_request_str % { 'username': escaped_username })
|
|
if result:
|
|
return True
|
|
except ldap.SERVER_DOWN:
|
|
raise IdentityStoreException('LDAP server is down')
|
|
except ldap.LDAPError, e:
|
|
get_logger().error('%s.has_identity_for_username: %s' % (type(self), e))
|
|
return False
|
|
|
|
def get_identity_for_username(self, username, throw=False):
|
|
self.connect_admin()
|
|
try:
|
|
escaped_username = ldap.filter.escape_filter_chars(username)
|
|
result = self.ldap_conn.search_s(self.ldap_base,
|
|
ldap.SCOPE_SUBTREE,
|
|
self.uid_request_str % { 'username': escaped_username })
|
|
if len(result) == 0:
|
|
return None
|
|
if len(result) > 1:
|
|
raise TooMuchAccounts('Too much accounts for username:%s' % username)
|
|
return self.ldap2identity(result[0])
|
|
except ldap.SERVER_DOWN:
|
|
raise IdentityStoreException('LDAP server is down')
|
|
except IdentityStoreException, e:
|
|
get_logger().error('%s.get_identity_for_username: %s' % (type(self), e))
|
|
if throw:
|
|
raise
|
|
except ldap.LDAPError, e:
|
|
get_logger().error('%s.get_identity_for_username: %s' % (type(self), e))
|
|
return None
|
|
|
|
def init_session(self, session, account):
    '''Store the user's DN and password on `session` as ldap_infos, for
    connect() to rebind with later.  An unknown username raises
    IndexError.'''
    self.connect_admin()
    # Escape the username like every other filter built from
    # uid_request_str; special characters (*, parentheses, ...) used to
    # reach the filter verbatim.
    escaped_username = ldap.filter.escape_filter_chars(account.username)
    result = self.ldap_conn.search_s(self.ldap_base,
            ldap.SCOPE_SUBTREE,
            self.uid_request_str % { 'username': escaped_username })
    uid = result[0][0]
    session.ldap_infos = (uid, account.password)
|
|
|
|
def is_big(self):
    '''Configured ldap_big_directory flag, returned as is.'''
    big = self.ldap_big_directory
    return big
|
|
|
|
def save(self, identity):
    '''Persist `identity`.

    Roles, lasso dumps, resource id and disabled flag always go to the
    local MiniIdentityLdap record.  Unless ldap_read_only is set, the
    directory entry is then updated with the admin bind for the name,
    e-mail, custom fields and password that changed.

    Returns True when the directory entry was modified.  Returns False when
    nothing changed, when the store is read-only, or when the server
    refused the change (STRONG_AUTH_REQUIRED, INVALID_SYNTAX).
    '''
    mini = MiniIdentityLdap()
    mini.id = identity.id
    mini.uid = identity.id
    mini.disabled = identity.disabled
    if identity.roles:
        mini.roles = identity.roles[:]
    else:
        mini.roles = []
    mini.lasso_dump = identity.lasso_dump
    mini.lasso_proxy_dump = identity.lasso_proxy_dump
    mini.resource_id = identity.resource_id
    mini.store()

    if self.ldap_read_only:
        return False

    # Administrative bind: the session password is not available here.
    self.connect_admin()
    dn, old = self.ldap_conn.search_s(identity.id, ldap.SCOPE_BASE)[0]
    new = old.copy()
    modified = False

    if self.ldap_object_name in old and \
            old[self.ldap_object_name][0] != sitecharset2utf8(identity.name):
        new[self.ldap_object_name] = [sitecharset2utf8(identity.name)]
        modified = True
    if self.ldap_object_email in old and \
            old[self.ldap_object_email][0] != identity.email:
        new[self.ldap_object_email] = [identity.email]
        modified = True

    # Custom fields, converted to UTF-8 as identity2ldap() does
    for field in self.fields:
        attribute = self.field2attribute(field)
        value = getattr(identity, field.key, None)
        if field.multivalued:
            if value is None or value == '':
                value = []
            elif not isinstance(value, list):
                value = [ value ]
            value = [sitecharset2utf8(str(v)) for v in value]
            if set(old.get(attribute, [])) != set(value):
                new[attribute] = value
                # a multivalued change used to be computed but never sent
                modified = True
        else:
            if value is not None:
                value = sitecharset2utf8(str(value))
            if old.get(attribute, [None])[0] != value:
                if value is None and attribute in old:
                    # None means: no more value
                    del new[attribute]
                else:
                    new[attribute] = [value]
                modified = True

    # Password of the first PasswordAccount, if any.  It is left alone when
    # it is one of the stored values or matches a stored hash (the account
    # may hold the clear password kept by get_identity_for_account); an
    # empty password never wipes the stored one.
    for account in identity.accounts:
        if isinstance(account, PasswordAccount):
            break
    else:
        account = None
    if account and account.password and 'userPassword' in old:
        stored = old['userPassword']
        unchanged = account.password in stored
        for stored_password in stored:
            if unchanged:
                break
            try:
                unchanged = hashed_password_equals(stored_password, account.password)
            except Exception:
                # undecodable stored value: treat as different
                unchanged = False
        if not unchanged:
            new['userPassword'] = [account.password]
            modified = True

    if not modified:
        return False
    modList = ldap.modlist.modifyModlist(old, new)
    try:
        self.ldap_conn.modify_s(dn, modList)
    except ldap.STRONG_AUTH_REQUIRED:
        # failed to modify, probably because not logged in, probably
        # because called from SOAP for a defederate or something,
        # ignore; the federation data is anyway stored elsewhere
        return False
    except ldap.INVALID_SYNTAX:
        return False
    # log the DN: there may be no PasswordAccount to take a username from
    get_logger().debug('LDAPStore: updating account %s' % dn)
    return True
|
|
|
|
def get_identity_for_name_identifier(self, name_identifier):
    """Return the identity federated under *name_identifier*, or None.

    Every ``MiniIdentityLdap`` holding a Lasso identity dump and/or proxy
    dump is examined.  A dump matches when one of its federations has a
    local or remote name identifier whose content equals
    *name_identifier*; the matching entry is then resolved to a full
    identity with ``self.get_identity``.
    """
    def has_federation(dump):
        # True if a federation of this Lasso identity dump uses
        # name_identifier as its local or remote name identifier.
        identity = lasso.Identity.newFromDump(dump)
        for provider_id in identity.providerIds:
            federation = identity.getFederation(provider_id)
            if federation is None:
                continue
            for nid in (federation.localNameIdentifier,
                        federation.remoteNameIdentifier):
                if nid and nid.content == name_identifier:
                    return True
        return False

    for i in MiniIdentityLdap.values():
        # Only parse dumps that actually exist: an entry may carry a proxy
        # dump without a direct dump, and newFromDump() must not get None.
        dumps = [dump for dump in (i.lasso_dump, i.lasso_proxy_dump) if dump]
        # Cheap substring shortcut so the XML is only parsed when the name
        # identifier can possibly occur in one of the dumps.
        if not [dump for dump in dumps if name_identifier in dump]:
            continue
        for dump in dumps:
            if has_federation(dump):
                # Direct and proxy matches both return the full identity
                # (a proxy match used to return the bare mini identity).
                return self.get_identity(i.id)
    return None
|
|
|
|
class LdapKeyWidget(qform.ValidatedStringWidget):
    """String widget accepting a single LDAP attribute name.

    The value must start with an ASCII letter followed by any number of
    letters, digits, '-' or ';' characters.  The ';' presumably allows
    attribute options such as "cn;lang-fr" -- TODO confirm intent.
    Validation against ``regex`` is presumably performed by
    qform.ValidatedStringWidget.
    """
    regex = r'^[a-zA-Z][a-zA-Z0-9-;]*$'
|
|
|
|
class LdapObjectNameWidget(qform.ValidatedStringWidget):
    """String widget accepting one or more LDAP attribute names joined by '+'.

    Examples of accepted values: ``cn`` or ``cn+sn``.  Each attribute name
    must start with an ASCII letter followed by letters, digits, '-' or
    ';' (the same syntax as LdapKeyWidget).
    """
    regex = r'^([a-zA-Z][a-zA-Z0-9-;]*)(\+[a-zA-Z][a-zA-Z0-9-;]*)*$'
|
|
|
|
class LdapDnWidget(qform.ValidatedStringWidget):
    """String widget accepting an LDAP distinguished name.

    The value is one or more ``attribute=value`` pairs separated by commas.
    Attribute names follow the LdapKeyWidget syntax and no whitespace is
    accepted right after a separating comma.  A value must start with a
    character other than space # , + " < > ; = / and may afterwards contain
    any character except # , + " < > ; = / unless that character is
    escaped with a backslash.
    """
    regex = r'^([a-zA-Z][a-zA-Z0-9-;]*)=([^ #,+"\<>;=/](\\.|[^#,+"\<>;=/])*)(,([a-zA-Z][a-zA-Z0-9-;]*)=([^ #,+"\<>;=/](\\.|[^#,+"\<>;=/])*))*$'
|
|
|
|
def fill_admin_form(self, form, data_source):
    """Add the LDAP store configuration widgets to the admin *form*.

    Every widget is pre-filled from *data_source* (the previously saved
    settings), defaulting to '' for text settings, False for checkboxes
    and ``self.default_field_mapping`` for the field mapping.
    """
    saved = data_source.get

    # Server location and access mode.
    form.add(StringWidget, 'ldap_url',
             title=_('LDAP URL'),
             required=True,
             hint=htmltext(_('Example: <tt>ldap://directory.example.com</tt>')),
             value=saved('ldap_url', ''))
    form.add(CheckboxWidget, 'ldap_read_only',
             title=_('LDAP is Read only'),
             required=False,
             value=saved('ldap_read_only', False))

    # Where entries are searched and created.
    form.add(StringWidget, 'ldap_base',
             title=_('LDAP Base'),
             required=True,
             hint=htmltext(_('Example: <tt>dc=example, dc=com</tt>')),
             value=saved('ldap_base', ''))
    form.add(StringWidget, 'ldap_new_user_base',
             title=_('LDAP New User Base'),
             required=False,
             hint=htmltext(_('Example: <tt>dc=example, dc=com</tt>, if not set LDAP Base is used')),
             value=saved('ldap_new_user_base', ''))

    # Administrative bind credentials.
    form.add(StringWidget, 'ldap_bind_dn',
             title=_('LDAP Administrative Bind DN'),
             required=False,
             hint=htmltext(_('Example: <tt>cn=admin, dc=example, dc=com</tt>')),
             value=saved('ldap_bind_dn', ''))
    form.add(StringWidget, 'ldap_bind_password',
             title=_('LDAP Administrative Bind password'),
             required=False,
             hint=htmltext(_('Example: <tt>secret</tt>')),
             value=saved('ldap_bind_password', ''))

    # Object class and the attributes mapped onto identity data.
    form.add(StringWidget, 'ldap_object_class',
             title=_('LDAP Object Class'),
             required=True,
             hint=htmltext(_('Example: <tt>posixAccount</tt>')),
             value=saved('ldap_object_class', ''))
    form.add(self.LdapKeyWidget, 'ldap_object_uid',
             title=_('LDAP Object Username Attribute'),
             required=True,
             hint=htmltext(_('Example: <tt>uid</tt>')),
             value=saved('ldap_object_uid', ''))
    form.add(self.LdapObjectNameWidget, 'ldap_object_name',
             title=_('LDAP Object Name Attribute'),
             hint=htmltext(_('Example: <tt>cn</tt>')),
             value=saved('ldap_object_name', ''))
    form.add(self.LdapKeyWidget, 'ldap_object_email',
             title=_('LDAP Object Email Attribute'),
             hint=htmltext(_('Example: <tt>mail</tt>')),
             value=saved('ldap_object_email', ''))

    # Directory size hint and attributes exported in assertions.
    form.add(CheckboxWidget, 'ldap_big_directory',
             title=_('Massive LDAP Directory'),
             value=saved('ldap_big_directory', False))
    form.add(TextWidget, 'ldap_published_attributes',
             title=_('LDAP attributes published in SAML assertions'),
             hint=htmltext(_('Example: <tt>uid mail</tt> (space delimited list)')),
             value=saved('ldap_published_attributes', ''),
             rows=4,
             cols=80)

    # Custom field -> LDAP attribute mapping; the hint lists the field keys.
    known_fields = ', '.join([field.key for field in self.fields])
    form.add(qform.WidgetDict, 'ldap_field_mapping',
             element_value_type=self.LdapKeyWidget,
             title=_('Mapping from custom field name to LDAP attributes'),
             hint=htmltext(_('Possible fields are: ') + known_fields),
             value=saved('ldap_field_mapping', self.default_field_mapping))
|
|
|
|
def administrators(self):
    """Return the full identities of every user holding ROLE_ADMIN.

    Candidates come from ``MiniIdentityLdap.select`` (with errors ignored).
    A candidate that ``self.get_identity`` cannot resolve (KeyError) is
    logged as a warning and skipped.
    """
    # XXX: add support for ldap groups
    admins = []
    candidates = MiniIdentityLdap.select(
        lambda mini: ROLE_ADMIN in (mini.roles or []),
        ignore_errors=True)
    for mini in candidates:
        try:
            admins.append(self.get_identity(mini.id))
        except KeyError:
            get_logger().warning('IdentitiesStoreLdap.administrators(): could not get administrator %r' % mini.id)
    return admins
|
|
|
|
def simple_search(self, predicate, base = None, sizelimit = 0):
|
|
if base is None:
|
|
base = self.ldap_base
|
|
self.connect_admin()
|
|
s = ldap.async.List(self.ldap_conn)
|
|
s.startSearch(base, ldap.SCOPE_SUBTREE,
|
|
predicate, sizelimit = sizelimit)
|
|
try:
|
|
partial = s.processResults()
|
|
except ldap.SIZELIMIT_EXCEEDED:
|
|
pass
|
|
except ldap.SERVER_DOWN:
|
|
raise IdentityStoreException('LDAP server is down')
|
|
else:
|
|
if partial:
|
|
# it is ignored; this info should be passed down
|
|
pass
|
|
return s.allResults
|
|
|
|
def get_identities_by_attributes(self, map):
    """Return a generator of identities matching every (field, value) of *map*.

    Field names are translated with ``self.field_name2attribute``.  Both
    the attribute name and the value are escaped with
    ``ldap.filter.escape_filter_chars``, and all terms are AND-ed into a
    single filter for ``self.simple_search``.  Each search result is
    converted with ``self.ldap2identity(result[1])``.
    """
    def term(field, value):
        # One "(attribute=value)" filter component, fully escaped.
        attribute = self.field_name2attribute(field)
        return '(%s=%s)' % (ldap.filter.escape_filter_chars(attribute),
                            ldap.filter.escape_filter_chars(value))

    terms = ''.join([term(field, value) for field, value in map.iteritems()])
    results = self.simple_search('(&%s)' % terms)
    return (self.ldap2identity(result[1]) for result in results)
|
|
|
|
class IdentitiesStoreStorage(BaseIdentitiesStore):
    """Identity store that delegates everything to a storable identity class.

    Lookups and counts are forwarded to ``self.identity_class`` (``Identity``
    unless another class is given to the constructor).  Identities persist
    themselves through their own ``store()`` and ``remove_self()`` methods.
    """

    label = N_('Default storage (files)')
    identity_class = Identity

    def __init__(self, identity_class=Identity):
        self.identity_class = identity_class

    # ---- queries -------------------------------------------------------

    def is_bootstrapping(self):
        """Return True while no identity has been stored yet."""
        total = self.identity_class.count()
        return total == 0

    def keys(self):
        return self.identity_class.keys()

    def values(self, limit=0, offset=0):
        """Return every identity.

        *limit* and *offset* are accepted for interface compatibility but
        are ignored by this store.
        """
        return self.identity_class.values()

    def get_identity(self, uid):
        return self.identity_class.get(uid)

    def has_identity(self, uid):
        return self.identity_class.has_key(uid)

    def count(self, letter=None):
        """Count all identities, or those of ``values_by_letter`` for *letter*."""
        # values_by_letter is not defined in this class; presumably it is
        # inherited from BaseIdentitiesStore -- verify.
        if letter is not None:
            return len(self.values_by_letter(letter=letter))
        return self.identity_class.count()

    def is_big(self):
        """A store holding more than 1000 identities is considered big."""
        return self.count() > 1000

    # ---- persistence ---------------------------------------------------

    def add(self, identity):
        identity.store()

    def save(self, identity):
        identity.store()

    def remove(self, identity):
        identity.remove_self()

    # ---- session hooks (no-ops for this store) --------------------------

    def connect(self, session):
        pass

    def init_session(self, session, account):
        pass

    # ---- modification times ----------------------------------------------

    def last_modified(self):
        return self.identity_class.last_modified()

    def last_modified_uid(self, uid):
        return self.identity_class.last_modified_id(uid)

    @classmethod
    def get_identity_class(cls):
        return cls.identity_class
|
|
|
|
|
|
sql_string_replace = [ # straight from sqlobject/converter.py
    ('\\', '\\\\'),
    ("'", "''"),
    ('\000', '\\0'),
    ('\b', '\\b'),
    ('\n', '\\n'),
    ('\r', '\\r'),
    ('\t', '\\t'),
]

# Python 2 has separate unicode/long types; Python 3 folds them into str/int.
try:
    _sql_string_types = (str, unicode)
    _sql_integer_types = (int, long)
except NameError:
    _sql_string_types = (str,)
    _sql_integer_types = (int,)


def sqlrepr(s):
    """Render *s* as a PostgreSQL literal for inline SQL statements.

    * ``None`` becomes ``NULL``.
    * Strings are escaped with ``sql_string_replace`` and emitted as an
      escape-string constant ``E'...'``.  The backslash escapes produced by
      that table are only interpreted inside ``E'...'`` when the server runs
      with ``standard_conforming_strings`` on (the default since PostgreSQL
      9.1); in a plain ``'...'`` literal they would be stored verbatim.
    * Integers are emitted with ``str()``.

    :raises TypeError: for any other type, including ``bool``.
    """
    if s is None:
        return 'NULL'
    if isinstance(s, _sql_string_types):
        for orig, repl in sql_string_replace:
            s = s.replace(orig, repl)
        return "E'%s'" % s
    # bool is an int subclass but was never accepted here; keep refusing it.
    if isinstance(s, _sql_integer_types) and not isinstance(s, bool):
        return str(s)
    # The original raised a string, which Python 2.6+ turns into a TypeError
    # anyway; raise it explicitly with a useful message.
    raise TypeError('sqlrepr: unsupported type %r' % type(s))
|
|
|
|
|
|
class LazyPostgresqlIdentity(Identity):
    """Identity whose roles and password accounts are loaded on first access.

    Rows are read through ``get_publisher().store.conn`` and cached on the
    instance in ``real_roles`` / ``real_accounts``.  Assigning ``roles`` or
    ``accounts`` replaces the cache without touching the database.
    """

    real_roles = None
    real_accounts = None

    # PasswordAccount attributes stored in password_accounts, in column
    # order after the leading row id column.
    _account_columns = ('username', 'password', 'dumb_question', 'smart_answer')

    def _fetch_all(self, query):
        # Run *query* with this identity's id bound as its single parameter
        # (never formatted into the SQL text) and always release the cursor,
        # even when execute() or fetchall() raises.
        cursor = get_publisher().store.conn.cursor()
        try:
            cursor.execute(query, (self.id,))
            return cursor.fetchall()
        finally:
            cursor.close()

    def get_roles(self):
        """Return the role ids of this identity, loading them on first use."""
        if self.real_roles is None:
            rows = self._fetch_all(
                'SELECT role_id FROM identity_roles WHERE identity_id = %s')
            self.real_roles = [int(row[0]) for row in rows]
        return self.real_roles

    def set_roles(self, roles):
        self.real_roles = roles

    roles = property(get_roles, set_roles)

    def get_accounts(self):
        """Return the PasswordAccount objects of this identity (lazy)."""
        if self.real_accounts is None:
            rows = self._fetch_all(
                'SELECT * FROM password_accounts WHERE identity_id = %s')
            accounts = []
            for row in rows:
                account = PasswordAccount()
                # row[0] is the password_accounts row id, skip it.
                for index, name in enumerate(self._account_columns):
                    setattr(account, name, row[index + 1])
                accounts.append(account)
            self.real_accounts = accounts
        return self.real_accounts

    def set_accounts(self, accounts):
        self.real_accounts = accounts

    accounts = property(get_accounts, set_accounts)
|
|
|
|
|
|
class IdentitiesStorePostgresql:
    """Identity store backed by a PostgreSQL database.

    Tables used by the queries below:

    * ``identities``: id, one column per custom field, lasso_dump,
      lasso_proxy_dump, resource_id and entry_id.
    * ``password_accounts``: id, username, password, dumb_question,
      smart_answer and identity_id.
    * ``identity_roles``: identity_id and role_id.
    * ``name_identifiers``: id, name_identifier and identity_id.

    Row ids come from the ``*_id_sequence`` sequences.  Values are
    rendered into SQL with ``sqlrepr``.  ``self.fields`` is not assigned in
    this class; presumably it is provided by whoever loads the store --
    verify.
    """

    label = N_('PostgreSQL database')
    admin_keys = ('pg_dbname', 'pg_dbuser', 'pg_dbpassword', 'pg_dbhost')

    # Columns of the identities table stored after the custom fields.
    _extra_columns = ('lasso_dump', 'lasso_proxy_dump', 'resource_id', 'entry_id')
    # PasswordAccount attributes stored in password_accounts, in column
    # order (between the row id and identity_id).
    _account_columns = ('username', 'password', 'dumb_question', 'smart_answer')

    def init(self):
        self.conn = None

    def is_bootstrapping(self):
        """Return True when the identities table has no row yet.

        :raises IdentityStoreException: when the table cannot be queried.
        """
        c = self.conn.cursor()
        # The exception class must be resolvable in the except clause; it
        # used to be an undefined bare name (NameError instead of the
        # intended IdentityStoreException).
        import psycopg2
        try:
            c.execute('SELECT id FROM identities LIMIT 1')
        except psycopg2.ProgrammingError:
            # will for instance happen if the user configured has no sufficient
            # privileges to access the database; leave the connection usable.
            self.conn.rollback()
            raise IdentityStoreException('Error accessing PostgreSQL (insufficient privileges?)')
        return c.fetchone() is None

    def connect(self, session):
        pass

    def load_identities(self):
        """Open the database connection described by the pg_* settings.

        :raises IdentityStoreException: when psycopg2 is missing or the
            connection fails.
        """
        try:
            import psycopg2
        except ImportError:
            raise IdentityStoreException('Missing python-psycopg2 module to connect to the database')

        # Keyword arguments instead of a hand-built "key=value" DSN, so that
        # values containing spaces or quotes (typically passwords) are
        # transmitted intact.
        try:
            self.conn = psycopg2.connect(database=self.pg_dbname,
                                         user=self.pg_dbuser,
                                         password=self.pg_dbpassword,
                                         host=self.pg_dbhost)
        except psycopg2.OperationalError:
            raise IdentityStoreException('Failed to connect to the database')

    def _insert_password_accounts(self, cursor, identity):
        # Insert one password_accounts row per PasswordAccount of identity.
        for account in identity.accounts:
            # isinstance() didn't work well for that
            if account.__class__.__name__ != 'PasswordAccount':
                continue
            values = ["NEXTVAL('password_accounts_id_sequence')"]
            values.extend([sqlrepr(getattr(account, column))
                           for column in self._account_columns])
            values.append(sqlrepr(identity.id))
            cursor.execute('INSERT INTO password_accounts VALUES (%s)' % ', '.join(values))

    def _insert_roles(self, cursor, identity):
        # Insert one identity_roles row per role of identity.
        for role in identity.roles or []:
            cursor.execute('INSERT INTO identity_roles VALUES(%s, %s)' % (
                sqlrepr(identity.id), sqlrepr(role)))

    def add(self, identity):
        """Insert *identity* with its password accounts and roles.

        A new id is drawn from identities_id_sequence and assigned to
        ``identity.id``.  The changes are committed, then the name
        identifiers are saved (and committed) by ``save_name_identifiers``.
        """
        c = self.conn.cursor()

        # Get next available identity id
        c.execute("SELECT NEXTVAL('identities_id_sequence')")
        identity_id = c.fetchone()[0]
        identity.id = identity_id

        values = [sqlrepr(identity_id)]
        values.extend([sqlrepr(getattr(identity, field.key)) for field in self.fields])
        values.extend([sqlrepr(getattr(identity, column)) for column in self._extra_columns])
        c.execute('INSERT INTO identities VALUES (%s)' % ', '.join(values))

        self._insert_password_accounts(c, identity)
        self._insert_roles(c, identity)

        # Commit all database changes
        self.conn.commit()
        self.save_name_identifiers(identity)

    def keys(self):
        raise NotImplementedError

    def db2identity(self, v):
        """Build a LazyPostgresqlIdentity from an ``identities`` row *v*.

        Row layout: id, one column per entry of ``self.fields``, then
        lasso_dump and lasso_proxy_dump.  resource_id and entry_id are not
        loaded.
        """
        identity = LazyPostgresqlIdentity()
        identity.id = v[0]
        for index, field in enumerate(self.fields):
            setattr(identity, field.key, v[index + 1])
        # Offset computed from the field count so each dump reads its own
        # column; both used to be read from the lasso_dump column, and an
        # empty field list crashed.
        offset = len(self.fields) + 1
        identity.lasso_dump = v[offset]
        identity.lasso_proxy_dump = v[offset + 1]
        return identity

    def values(self, limit = 0, offset = 0):
        """Return identities ordered by name, optionally paginated."""
        c = self.conn.cursor()
        statement = 'SELECT * FROM identities ORDER BY name'
        if limit:
            statement += ' LIMIT %d OFFSET %d' % (limit, offset)
        c.execute(statement)
        return [self.db2identity(v) for v in c.fetchall()]

    def _name_prefix_pattern(self, letter):
        # ILIKE pattern for names starting with *letter*; LIKE wildcards and
        # backslashes in letter are escaped so they match literally.
        for char in ('\\', '%', '_'):
            letter = letter.replace(char, '\\' + char)
        return letter + '%'

    def values_by_letter(self, letter, limit = 0, offset = 0):
        """Return identities whose name starts with *letter* (case-insensitive)."""
        c = self.conn.cursor()
        # The pattern is passed as a query parameter instead of being
        # formatted into the SQL text.
        statement = 'SELECT * FROM identities WHERE name ILIKE %s ORDER BY name'
        if limit:
            statement += ' LIMIT %d OFFSET %d' % (limit, offset)
        c.execute(statement, (self._name_prefix_pattern(letter.lower()),))
        return [self.db2identity(v) for v in c.fetchall()]

    def _parse_id(self, uid):
        # Identity ids are integers; anything else cannot match and is
        # refused before it can reach the SQL text.
        try:
            return int(uid)
        except (TypeError, ValueError):
            raise KeyError(uid)

    def get_identity(self, uid):
        """Return the identity whose id is *uid*.

        :raises KeyError: when not connected, when *uid* is not an integer
            id, or when no such identity exists.
        """
        if self.conn is None:
            # XXX : could use a more appropriate exception class
            raise KeyError(uid)
        identity_id = self._parse_id(uid)
        c = self.conn.cursor()
        c.execute('SELECT * FROM identities WHERE id = %s' % sqlrepr(identity_id))
        one = c.fetchone()
        if one is None:
            raise KeyError(uid)
        return self.db2identity(one)

    def has_identity(self, uid):
        """Return True if an identity with id *uid* exists."""
        try:
            identity_id = self._parse_id(uid)
        except KeyError:
            return False
        c = self.conn.cursor()
        c.execute('SELECT id FROM identities WHERE id = %s' % sqlrepr(identity_id))
        return c.fetchone() is not None

    def get_identity_for_account(self, account):
        """Return the identity owning *account* (username and password), or None.

        :raises TypeError: for any account that is not a PasswordAccount.
        """
        if account.__class__.__name__ != 'PasswordAccount':
            # Used to raise a string, which Python 2.6+ turns into TypeError.
            raise TypeError('unknown account type: %s' % account.__class__.__name__)
        c = self.conn.cursor()
        c.execute('SELECT identity_id FROM password_accounts WHERE '
                  'username = %s AND password = %s' % (
                      sqlrepr(account.username), sqlrepr(account.password)))
        one = c.fetchone()
        if one is None:
            return None
        return self.get_identity(one[0])

    def has_identity_with_username(self, username):
        c = self.conn.cursor()
        c.execute('SELECT id FROM password_accounts WHERE username = %s' % sqlrepr(username))
        return c.fetchone() is not None

    def init_session(self, session, account):
        pass

    def count(self, letter = None):
        """Count identities, optionally only those whose name starts with *letter*."""
        # XXX: WARNING: this is expensive when there are lots (millions) of
        # identities.
        c = self.conn.cursor()
        if letter:
            c.execute('SELECT COUNT(id) FROM identities WHERE name ILIKE %s',
                      (self._name_prefix_pattern(letter),))
        else:
            c.execute('SELECT COUNT(id) FROM identities')
        return int(c.fetchone()[0])

    def is_big(self):
        # Always considered big: count() is too expensive to run just to
        # decide this (see count()).
        return True

    def save(self, identity):
        """Update *identity*'s row and replace its accounts, roles and name identifiers.

        The changes are committed by ``save_name_identifiers``.
        """
        c = self.conn.cursor()

        # Update identity attributes
        assignments = ['%s = %s' % (field.key, sqlrepr(getattr(identity, field.key)))
                       for field in self.fields]
        assignments.extend(['%s = %s' % (column, sqlrepr(getattr(identity, column)))
                            for column in ('lasso_dump', 'lasso_proxy_dump')])
        c.execute('UPDATE identities SET %s WHERE id = %s' % (
            ', '.join(assignments), sqlrepr(identity.id)))

        # Delete old password accounts and add new ones
        c.execute('DELETE FROM password_accounts WHERE identity_id = %s' % sqlrepr(identity.id))
        self._insert_password_accounts(c, identity)

        # Delete old roles and save new ones
        c.execute('DELETE FROM identity_roles WHERE identity_id = %s' % sqlrepr(identity.id))
        self._insert_roles(c, identity)

        self.save_name_identifiers(identity)

    def save_name_identifiers(self, identity):
        """Replace the name_identifiers rows of *identity* and commit.

        NameIdentifier elements (namespace
        urn:oasis:names:tc:SAML:1.0:assertion) are collected from both
        lasso_dump and lasso_proxy_dump.
        """
        import xml.dom.minidom
        name_identifiers = []
        for dump in (identity.lasso_dump, identity.lasso_proxy_dump):
            if not dump:
                continue
            doc = xml.dom.minidom.parseString(dump)
            name_identifiers.extend([x.childNodes[0].data for x in doc.getElementsByTagNameNS(
                'urn:oasis:names:tc:SAML:1.0:assertion', 'NameIdentifier')])

        c = self.conn.cursor()
        c.execute('DELETE FROM name_identifiers WHERE identity_id = %s' % sqlrepr(identity.id))
        for name_identifier in name_identifiers:
            statement = 'INSERT INTO name_identifiers VALUES (%s)' % ', '.join(
                ["NEXTVAL('name_identifiers_id_sequence')",
                 sqlrepr(name_identifier),
                 sqlrepr(identity.id)])
            c.execute(str(statement))
        self.conn.commit()

    def get_identity_for_name_identifier(self, name_identifier):
        """Return the identity linked to *name_identifier*.

        :raises KeyError: when no identity is linked to it.
        """
        c = self.conn.cursor()
        c.execute('SELECT identity_id FROM name_identifiers WHERE '
                  'name_identifier = %s' % sqlrepr(name_identifier))
        one = c.fetchone()
        if one is None:
            raise KeyError(name_identifier)
        return self.get_identity(one[0])

    def fill_admin_form(self, form, data_source):
        form.add(StringWidget, 'pg_dbname', title = _('Database Name'), required = True,
                 value = data_source.get('pg_dbname', ''))
        form.add(StringWidget, 'pg_dbuser', title = _('Database User'), required = True,
                 value = data_source.get('pg_dbuser', ''))
        form.add(StringWidget, 'pg_dbpassword', title = _('Database Password'), required = True,
                 value = data_source.get('pg_dbpassword', ''))
        form.add(StringWidget, 'pg_dbhost', title = _('Database Hostname'), required = True,
                 value = data_source.get('pg_dbhost', 'localhost'))

    def last_modified(self):
        return None

    def last_modified_uid(self, uid):
        return None

    def get_identities_by_attributes(self, map):
        """Return identities whose columns equal every value of *map*.

        Keys must be columns of the identities table (``id``, a custom
        field key or one of the extra columns); any other key raises
        ValueError rather than being interpolated into SQL.  A ``None``
        value matches NULL.  An empty *map* returns every identity.
        """
        allowed = set(['id'])
        allowed.update([field.key for field in self.fields])
        allowed.update(self._extra_columns)
        predicates = []
        for key, value in map.iteritems():
            if key not in allowed:
                raise ValueError('unknown identity attribute: %r' % (key,))
            if value is None:
                predicates.append('%s IS NULL' % key)
            else:
                predicates.append('%s = %s' % (key, sqlrepr(value)))
        # Previously: broken '% = %' format, undefined name_identifier and a
        # query against the name_identifiers table.
        statement = 'SELECT * FROM identities'
        if predicates:
            statement += ' WHERE ' + ' AND '.join(predicates)
        c = self.conn.cursor()
        c.execute(statement)
        return [self.db2identity(v) for v in c.fetchall()]
|
|
|
|
# Identity store classes selectable through the 'source' key of the
# identity_storage setting; load_store() and get_store_class() look up
# the configured class here and fall back to the 'common' entry.
stores = dict([
    ('file', IdentitiesStorePickle),
    ('common', IdentitiesStoreStorage),
    ('ldap', IdentitiesStoreLdap),
    ('postgresql', IdentitiesStorePostgresql),
])
|
|
|
|
def _load_identity_store_module(directory, modulename):
    # Import <directory>/<modulename>.py.  The file handle opened by
    # imp.find_module() is always closed (it used to leak on success), and
    # a failing module is logged and skipped instead of silently swallowed.
    fp = None
    try:
        try:
            fp, pathname, description = imp.find_module(modulename, [directory])
            imp.load_module(modulename, fp, pathname, description)
        except Exception:
            import sys
            get_logger().warning('load_store: could not load identity store module %r: %s'
                                 % (modulename, sys.exc_info()[1]))
    finally:
        if fp is not None:
            fp.close()


def load_store():
    """Instantiate and configure the identity store selected in the config.

    First, every ``*.py`` module of the ``identity_stores`` directory next to
    this file is loaded; a module that fails to load is logged and skipped.
    The store class is then taken from ``stores`` using the ``source`` key
    of the ``identity_storage`` setting.  It falls back to ``'common'``
    when the key is missing or unknown.  The store's ``admin_keys`` are
    copied from the setting, and ``init()`` is called when the store
    defines it.
    """
    identity_stores_dir = os.path.join(os.path.dirname(__file__), 'identity_stores')
    if os.path.exists(identity_stores_dir):
        for filename in os.listdir(identity_stores_dir):
            if filename.endswith('.py'):
                _load_identity_store_module(identity_stores_dir, filename[:-3])

    _storage = get_cfg('identity_storage', {'source': 'common'})
    store = stores.get(_storage.get('source'), stores['common'])()
    if hasattr(store, 'admin_keys'):
        for k in store.admin_keys:
            # Configured values win; unset keys default to None only when the
            # store does not already provide a value.
            if k in _storage or not hasattr(store, k):
                setattr(store, k, _storage.get(k, None))
    if hasattr(store, 'init'):
        store.init()
    return store
|
|
|
|
def get_store():
    """Return the identity store attached to the current publisher."""
    publisher = get_publisher()
    return publisher.store
|
|
|
|
def get_store_class():
    """Return the store class selected by the ``identity_storage`` setting.

    Falls back to ``stores['common']`` when the configured source is
    unknown or the setting has no ``source`` key (previously a KeyError).
    """
    _storage = get_cfg('identity_storage', {'source': 'common'})
    return stores.get(_storage.get('source'), stores['common'])
|