263 lines
9.3 KiB
Python
263 lines
9.3 KiB
Python
from quixote import get_publisher, get_field, redirect, get_path, get_request, get_session
|
|
import quixote.http_request as http_request
|
|
from qommon import get_cfg
|
|
import urlparse
|
|
try:
|
|
from urlparse import parse_qsl
|
|
except ImportError:
|
|
from cgi import parse_qsl
|
|
import urllib
|
|
import itertools
|
|
|
|
import cPickle
|
|
import lasso
|
|
import re
|
|
import os
|
|
from qommon import get_logger
|
|
|
|
from qommon.misc import *
|
|
from rfc822 import formatdate
|
|
|
|
def protect_form_from_get_parameters(function):
    '''Decorator hiding ``request.form`` from ``function`` on GET requests.

    When the current request method is GET, ``request.form`` is replaced by
    an empty dict for the duration of the call, so the wrapped callable sees
    no submitted fields.  The original form is put back afterwards, whether
    the callable returns or raises.  For any other method the callable is
    invoked with the request untouched.

    Returns the wrapping function, which forwards all positional and keyword
    arguments and returns whatever ``function`` returns.
    '''
    # Local import: keeps this block independent of the file-level imports.
    from functools import wraps

    @wraps(function)
    def f(*args, **kwargs):
        request = get_request()
        if request.get_method() != 'GET':
            return function(*args, **kwargs)

        form_save = request.form
        request.form = {}
        try:
            return function(*args, **kwargs)
        finally:
            # Restore even on error, otherwise whatever handles the exception
            # further up would keep working with an emptied form.
            request.form = form_save
    return f
|
|
|
|
def _set_encryption_private_key(server, section):
    '''Load the 'encryption_privatekey' of configuration ``section`` into ``server``.

    The key is only loaded when the resolved path is non-empty and exists on
    disk; a lasso error while loading it is logged as a warning.
    '''
    encryption_privatekey = get_abs_path(get_cfg(section).get('encryption_privatekey'))
    if encryption_privatekey and os.path.exists(encryption_privatekey):
        try:
            server.setEncryptionPrivateKey(encryption_privatekey)
        except lasso.Error:
            get_logger().warn('Failed to set encryption private key')


def _configure_provider_encryption(server, klp, lp):
    '''Apply the encryption options of provider entry ``lp`` (key ``klp``).'''
    encryption_mode = lasso.ENCRYPTION_MODE_NONE
    if lp.get('encrypt_nameid', False):
        encryption_mode |= lasso.ENCRYPTION_MODE_NAMEID
    if lp.get('encrypt_assertion', False):
        encryption_mode |= lasso.ENCRYPTION_MODE_ASSERTION

    provider_t = get_provider_and_label(klp)[0]
    provider = server.getProvider(provider_t.providerId)
    if provider is None:
        get_logger().warn('Failed to set encryption mode for provider: %s' % provider_t.providerId)
        return

    provider.setEncryptionMode(encryption_mode)
    cfg_sym_key_type = get_cfg('providers').get(
            get_provider_key(provider_t.providerId), {}).get('sym_key_type', '')
    # An unknown 'sym_key_type' value raises KeyError here (unchanged).
    sym_key_type = {
            'aes256': lasso.ENCRYPTION_SYM_KEY_TYPE_AES_256,
            'aes128': lasso.ENCRYPTION_SYM_KEY_TYPE_AES_128,
            '3des': lasso.ENCRYPTION_SYM_KEY_TYPE_3DES,
            '': lasso.ENCRYPTION_SYM_KEY_TYPE_DEFAULT
            } [cfg_sym_key_type]
    provider.setEncryptionSymKeyType(sym_key_type)


def get_lasso_server(role = lasso.PROVIDER_ROLE_IDP, protocol = 'liberty'):
    '''Build a ``lasso.Server`` from the site configuration.

    ``role`` is the role the local server plays; ``protocol`` is either
    'liberty' or 'saml2' and selects which metadata file of the 'idp'
    configuration is used.

    Returns None when there is no 'idp' configuration, otherwise the server
    with:

    * encryption private keys loaded (see NOTE below),
    * every configured provider whose metadata file exists and whose role
      differs from ``role`` added,
    * for a liberty IdP, a Disco service registered when lasso provides
      ``DiscoServiceInstance``,
    * for an IdP, every file of the 'affiliations' directory of the
      application loaded when the server supports ``loadAffiliation``.

    Raises ValueError when ``protocol`` is not a known protocol.
    '''
    idp_cfg = get_cfg('idp')
    if not idp_cfg:
        return None

    if protocol == 'liberty':
        metadata_option = 'metadata'
    elif protocol == 'saml2':
        metadata_option = 'saml2_metadata'
    else:
        # A string used to be raised here, which is itself a TypeError on
        # Python >= 2.6; raise a real exception instead.
        raise ValueError('unknown protocol: %r' % (protocol,))

    server = lasso.Server(
            get_abs_path(idp_cfg[metadata_option]),
            get_abs_path(idp_cfg['privatekey']),
            None, None)

    if protocol == 'saml2':
        _set_encryption_private_key(server, 'idp')
    # NOTE(review): the key of the 'sp' section is loaded for both protocols
    # and after the 'idp' one -- presumably it takes precedence when both are
    # set; confirm this is intended.
    _set_encryption_private_key(server, 'sp')

    for klp, lp in get_cfg('providers', {}).items():
        metadata_fn = get_abs_path(lp['metadata'])
        if not os.path.exists(metadata_fn):
            continue

        changed = False
        if lp['role'] == lasso.PROVIDER_ROLE_NONE:
            changed = True
            # "fix" remote role: assume the opposite of the local one
            if role == lasso.PROVIDER_ROLE_IDP:
                lp['role'] = lasso.PROVIDER_ROLE_SP
            else:
                lp['role'] = lasso.PROVIDER_ROLE_IDP

        try:
            publickey_fn = None
            cacertchain_fn = None
            if 'publickey' in lp and os.path.exists(get_abs_path(lp['publickey'])):
                publickey_fn = get_abs_path(lp['publickey'])
            if 'cacertchain' in lp and os.path.exists(get_abs_path(lp['cacertchain'])):
                cacertchain_fn = get_abs_path(lp['cacertchain'])

            # only add providers of the opposite role
            if lp['role'] != role:
                try:
                    server.addProvider(lp['role'], metadata_fn, publickey_fn, cacertchain_fn)
                except lasso.Error as error:
                    if error[0] == -203: # protocol mismatch
                        continue
                    get_logger().warn('Could not load provider %s: %s' % (klp, error))

                # NOTE(review): assumed to apply only to providers handled by
                # the branch above -- confirm against the original layout.
                if hasattr(lasso, 'ENCRYPTION_SYM_KEY_TYPE_DEFAULT'):
                    _configure_provider_encryption(server, klp, lp)
        finally:
            # Restore the role on every exit path (including the ``continue``
            # above), so the temporary "fix" never stays in the configuration.
            if changed:
                lp['role'] = lasso.PROVIDER_ROLE_NONE

    if (protocol == 'liberty' and role == lasso.PROVIDER_ROLE_IDP
            and hasattr(lasso, 'DiscoServiceInstance')):
        # also acts as Disco service
        disco_service = lasso.DiscoServiceInstance(
                lasso.DISCO_HREF,
                idp_cfg['providerid'],
                lasso.DiscoDescription_newWithBriefSoapHttpDescription(
                    lasso.SECURITY_MECH_NULL,
                    '%s/soapEndpoint' % idp_cfg['base_url'],
                    'Discovery SOAP Endpoint'))
        server.addService(disco_service)

    if role == lasso.PROVIDER_ROLE_IDP and hasattr(server, 'loadAffiliation'):
        affiliations_dir = os.path.join(get_publisher().app_dir, 'affiliations')
        if os.path.exists(affiliations_dir):
            filenames = os.listdir(affiliations_dir)
        else:
            filenames = []

        for f in filenames:
            filename = os.path.join(affiliations_dir, f)
            try:
                server.loadAffiliation(filename)
            except Exception as error:
                # Best effort: one unreadable file must not prevent the
                # server from being built, but leave a trace of it.
                get_logger().warn('Could not load affiliation %s: %s' % (filename, error))

    if hasattr(server, 'role'):
        server.role = role
    return server
|
|
|
|
def get_provider_and_label(provider_key):
    '''Return ``(provider, label)`` for the provider stored under ``provider_key``.

    ``provider`` is a ``lasso.Provider`` built from the configured role,
    metadata file and optional public key.  ``label`` is, in order of
    preference: the configured 'label', the organization display name, the
    organization name, or the provider id.  A name taken from the
    organization is decoded from UTF-8 and re-encoded as ISO-8859-1.

    Raises KeyError (carrying ``provider_key``) when no such provider is
    configured or when the lasso provider cannot be built.
    '''
    lp = get_cfg('providers', {}).get(provider_key)
    if not lp:
        raise KeyError(provider_key)

    publickey_fn = None
    if lp.get('publickey'):
        publickey_fn = get_abs_path(lp['publickey'])
    # the CA certificate chain is not needed just to get the provider label

    try:
        provider = lasso.Provider(lp['role'], get_abs_path(lp['metadata']), publickey_fn, None)
    except Exception:
        # Callers only expect KeyError from this function.
        raise KeyError(provider_key)

    if lp.get('label'):
        return (provider, lp.get('label'))

    if not hasattr(provider, str('getOrganization')):
        return (provider, provider.providerId)

    organization = provider.getOrganization()
    if not organization:
        return (provider, provider.providerId)

    # Non-greedy patterns: with several elements on one line the first name
    # must be returned (name[0]), not the last one.
    name = re.findall(r"<OrganizationDisplayName[^>]*>(.*?)</OrganizationDisplayName>", organization)
    if not name:
        name = re.findall(r"<OrganizationName[^>]*>(.*?)</OrganizationName>", organization)
    if not name:
        return (provider, provider.providerId)
    return (provider, unicode(name[0], 'utf8').encode('iso-8859-1'))
|
|
|
|
def is_cached(request, response, last_modified = None, etag = None):
    '''Set validation headers and test the request against ``last_modified``.

    ``last_modified`` is a time value (seconds since epoch); when given it
    is formatted with ``formatdate`` and sent as 'Last-Modified'.  ``etag``,
    when given, is sent as 'ETag'.

    Returns None after setting status 304 when the 'If-Modified-Since'
    request header is exactly the formatted ``last_modified``; returns True
    in every other case.
    '''
    if last_modified:
        last_modified = formatdate(last_modified)
        response.set_header('Last-Modified', last_modified)
    if etag:
        response.set_header('ETag', etag)

    # Handle If-Modified-Since (plain string comparison)
    client_stamp = request.get_header('If-Modified-Since')
    if not client_stamp or client_stamp != last_modified:
        # XXX: Handle ETag
        return True

    response.set_status(304)
    return None
|
|
|
|
def mtime(file):
    '''Return the ``st_mtime`` reported by ``os.stat`` for ``file``.'''
    return os.stat(file).st_mtime
|
|
|
|
def url_add_parameters(url, parameters):
    '''Return ``url`` with ``parameters`` appended to its query string.

    ``parameters`` is an iterable of ``(name, value)`` pairs.  The existing
    query string is parsed with ``parse_qsl`` and re-emitted before the new
    pairs; every name and every value is percent-quoted with
    ``urllib.quote``.  The other components of the URL are kept as they are.
    '''
    splitted = list(urlparse.urlsplit(url))
    # Index 3 of the split result is the query component.
    pairs = itertools.chain(parse_qsl(splitted[3]), parameters)
    # parse_qsl returns decoded names, so they must be quoted again just
    # like the values, or a name holding ' ', '&' or '=' breaks the query.
    splitted[3] = '&'.join(
            '%s=%s' % (urllib.quote(name), urllib.quote(value))
            for name, value in pairs)
    return urlparse.urlunsplit(splitted)
|
|
|
|
def redirect_with_return_url(url, return_url):
    '''Redirect to ``url`` extended with return parameters.

    A tuple or list ``return_url`` is used as the sequence of parameters to
    add; any other value is added as the single 'returnURL' parameter.
    '''
    if isinstance(return_url, (tuple, list)):
        parameters = return_url
    else:
        parameters = (('returnURL', return_url),)
    return redirect(url_add_parameters(url, parameters))
|
|
|
|
def redirect_to_return_url(prefix = 'return'):
    '''Redirect to the URL held by the ``<prefix>URL`` field, if any.

    Returns the result of ``redirect`` when the field has a value, None
    otherwise.
    '''
    # NOTE(review): the target comes straight from a request field --
    # presumably callers validate it; confirm this cannot be used as an
    # open redirect.
    target = get_field('%sURL' % prefix)
    if not target:
        return None
    return redirect(target)
|
|
|
|
def redirect_with_same_qs(url):
    '''Redirect to ``url`` carrying over the query string of the current request.'''
    target = url_with_same_qs(url)
    return redirect(target)
|
|
|
|
def redirect_to_after_url():
    '''Redirect to the ``after_url`` stored in the session, consuming it.

    The session attribute is reset to None in every case.  Returns None when
    there is no session or no stored URL.
    '''
    session = get_session()
    if not session:
        return None

    # Read the stored target and clear it in one go.
    target, session.after_url = session.after_url, None
    if not target:
        return None
    return redirect(target)
|
|
|
|
def redirect_home(suffix='/'):
    '''Redirect to the request's SCRIPT_NAME followed by ``suffix``.'''
    script_name = get_request().environ['SCRIPT_NAME']
    return redirect(script_name + suffix)
|
|
|
|
def url_with_same_qs(url):
    '''Return ``url`` extended with the parameters of the current query string.

    The QUERY_STRING of the request environment is parsed with ``parse_qsl``
    and its pairs are appended to ``url`` by ``url_add_parameters``.  A
    missing QUERY_STRING is handled as an empty one, so ``url`` comes back
    without additional parameters.
    '''
    # environ.get() yields None when the variable is absent, and parse_qsl
    # cannot split None.
    query_string = get_request().environ.get('QUERY_STRING') or ''
    return url_add_parameters(url, parse_qsl(query_string))
|
|
|
|
def change_query(qs):
    '''Replace the form of the current request by the parsed content of ``qs``.

    This allows other endpoints to be used directly, without a redirection.
    '''
    request = get_request()
    parsed_form = http_request.parse_query(qs, request.charset)
    request.form = parsed_form
|
|
|
|
def redirect_to_referer(form = None):
    '''Redirect to the referer, returning None when there is none.

    With a ``form``, the value of ``form.referer()`` is used; without one,
    the HTTP_REFERER of the request environment is used.
    '''
    if not form:
        referer = get_request().environ.get('HTTP_REFERER')
        if referer:
            return redirect(referer)
        return None

    if form.referer():
        return redirect(form.referer())
    return None
|
|
|
|
def good_for_logs(s):
    '''Return ``s`` as a single-line string suitable for a log entry.

    None gives an empty string; any other value is converted with ``str``
    and each line feed or carriage return is replaced by a space.
    '''
    if s is None:
        return ''
    text = str(s)
    for line_break in ('\n', '\r'):
        text = text.replace(line_break, ' ')
    return text
|
|
|