The new API

This commit is contained in:
Roland Hedberg 2012-03-19 15:17:57 +01:00
parent e66da03a90
commit 1b7ad6c4e0
18 changed files with 2082 additions and 2419 deletions

View File

@ -8,8 +8,7 @@ import logging
import re
from oic.utils.http_util import *
from oic.oic.message import OpenIDSchema
#from oic.oic.provider import AuthnFailure
from oic.oic.message import message
LOGGER = logging.getLogger("oicServer")
hdlr = logging.FileHandler('oc3cp.log')
@ -49,7 +48,7 @@ def user_info(oicsrv, userdb, user_id, client_id="", user_info_claims=None):
else:
result = identity
return OpenIDSchema(**result)
return message("OpenIDSchema", **result)
FUNCTIONS = {
"verify_client": verify_client,

32
oc3/htdocs/login.mako Normal file
View File

@ -0,0 +1,32 @@
<%inherit file="root.mako" />
<%def name="title()">Log in</%def>
<div class="login_form" class="block">
<form action="${action}" method="post" class="login form">
<input type="hidden" name="sid" value="${sid}"/>
<table>
<tr>
<td>Username</td>
<td><input type="text" name="login" value="${login}"/></td>
</tr>
<tr>
<td>Password</td>
<td><input type="password" name="password"
value="${password}"/></td>
</tr>
<tr>
</td>
<td><input type="submit" name="form.commit"
value="Log In"/></td>
</tr>
</table>
</form>
</div>
<%def name="add_js()">
<script type="text/javascript">
$(document).ready(function() {
bookie.login.init();
});
</script>
</%def>

1
oc3/modules/__init__.py Normal file
View File

@ -0,0 +1 @@
__author__ = 'rohe0002'

79
oc3/modules/login.mako.py Normal file
View File

@ -0,0 +1,79 @@
# -*- encoding:utf-8 -*-
from mako import runtime, filters, cache
UNDEFINED = runtime.UNDEFINED
__M_dict_builtin = dict
__M_locals_builtin = locals
_magic_number = 6
_modified_time = 1330607174.483153
_template_filename='../htdocs/login.mako'
_template_uri='login.mako'
_template_cache=cache.Cache(__name__, _modified_time)
_source_encoding='utf-8'
_exports = ['add_js', 'title']
def _mako_get_namespace(context, name):
try:
return context.namespaces[(__name__, name)]
except KeyError:
_mako_generate_namespaces(context)
return context.namespaces[(__name__, name)]
def _mako_generate_namespaces(context):
pass
def _mako_inherit(template, context):
_mako_generate_namespaces(context)
return runtime._inherit_from(context, u'root.mako', _template_uri)
def render_body(context,**pageargs):
context.caller_stack._push_frame()
try:
__M_locals = __M_dict_builtin(pageargs=pageargs)
action = context.get('action', UNDEFINED)
login = context.get('login', UNDEFINED)
password = context.get('password', UNDEFINED)
sid = context.get('sid', UNDEFINED)
__M_writer = context.writer()
# SOURCE LINE 1
__M_writer(u'\n')
# SOURCE LINE 2
__M_writer(u'\n\n<div class="login_form" class="block">\n <form action="')
# SOURCE LINE 5
__M_writer(unicode(action))
__M_writer(u'" method="post" class="login form">\n <input type="hidden" name="sid" value="')
# SOURCE LINE 6
__M_writer(unicode(sid))
__M_writer(u'"/>\n <table>\n <tr>\n <td>Username</td>\n <td><input type="text" name="login" value="')
# SOURCE LINE 10
__M_writer(unicode(login))
__M_writer(u'"/></td>\n </tr>\n <tr>\n <td>Password</td>\n <td><input type="password" name="password"\n value="')
# SOURCE LINE 15
__M_writer(unicode(password))
__M_writer(u'"/></td>\n </tr>\n <tr>\n </td>\n <td><input type="submit" name="form.commit"\n value="Log In"/></td>\n </tr>\n </table>\n </form>\n</div>\n\n')
# SOURCE LINE 32
__M_writer(u'\n')
return ''
finally:
context.caller_stack._pop_frame()
def render_add_js(context):
context.caller_stack._push_frame()
try:
__M_writer = context.writer()
# SOURCE LINE 26
__M_writer(u'\n <script type="text/javascript">\n $(document).ready(function() {\n bookie.login.init();\n });\n </script>\n')
return ''
finally:
context.caller_stack._pop_frame()
def render_title(context):
context.caller_stack._push_frame()
try:
__M_writer = context.writer()
# SOURCE LINE 2
__M_writer(u'Log in')
return ''
finally:
context.caller_stack._pop_frame()

133
oc3/modules/root.mako.py Normal file
View File

@ -0,0 +1,133 @@
# -*- encoding:utf-8 -*-
from mako import runtime, filters, cache
UNDEFINED = runtime.UNDEFINED
__M_dict_builtin = dict
__M_locals_builtin = locals
_magic_number = 6
_modified_time = 1332142870.815408
_template_filename=u'templates/root.mako'
_template_uri=u'root.mako'
_template_cache=cache.Cache(__name__, _modified_time)
_source_encoding='utf-8'
_exports = ['css_link', 'pre', 'post', 'css']
def render_body(context,**pageargs):
context.caller_stack._push_frame()
try:
__M_locals = __M_dict_builtin(pageargs=pageargs)
def pre():
return render_pre(context.locals_(__M_locals))
self = context.get('self', UNDEFINED)
set = context.get('set', UNDEFINED)
def post():
return render_post(context.locals_(__M_locals))
next = context.get('next', UNDEFINED)
__M_writer = context.writer()
# SOURCE LINE 1
self.seen_css = set()
__M_writer(u'\n')
# SOURCE LINE 7
__M_writer(u'\n')
# SOURCE LINE 10
__M_writer(u'\n')
# SOURCE LINE 15
__M_writer(u'\n')
# SOURCE LINE 22
__M_writer(u'\n')
# SOURCE LINE 25
__M_writer(u'<html>\n<head><title>OAuth test</title>\n')
# SOURCE LINE 27
__M_writer(unicode(self.css()))
__M_writer(u'\n<meta http-equiv="Content-Type" content="text/html; charset=UTF-8" />\n</head>\n<body>\n')
# SOURCE LINE 31
__M_writer(unicode(pre()))
__M_writer(u'\n')
# SOURCE LINE 34
__M_writer(unicode(next.body()))
__M_writer(u'\n')
# SOURCE LINE 35
__M_writer(unicode(post()))
__M_writer(u'\n</body>\n</html>\n')
return ''
finally:
context.caller_stack._pop_frame()
def render_css_link(context,path,media=''):
context.caller_stack._push_frame()
try:
context._push_buffer()
self = context.get('self', UNDEFINED)
__M_writer = context.writer()
# SOURCE LINE 2
__M_writer(u'\n')
# SOURCE LINE 3
if path not in self.seen_css:
# SOURCE LINE 4
__M_writer(u' <link rel="stylesheet" type="text/css" href="')
__M_writer(filters.html_escape(unicode(path)))
__M_writer(u'" media="')
__M_writer(unicode(media))
__M_writer(u'">\n')
pass
# SOURCE LINE 6
__M_writer(u' ')
self.seen_css.add(path)
__M_writer(u'\n')
finally:
__M_buf, __M_writer = context._pop_buffer_and_writer()
context.caller_stack._pop_frame()
__M_writer(filters.trim(__M_buf.getvalue()))
return ''
def render_pre(context):
context.caller_stack._push_frame()
try:
context._push_buffer()
__M_writer = context.writer()
# SOURCE LINE 11
__M_writer(u'\n<div class="header">\n <h1><a href="/">Login</a></h1>\n</div>\n')
finally:
__M_buf, __M_writer = context._pop_buffer_and_writer()
context.caller_stack._pop_frame()
__M_writer(filters.trim(__M_buf.getvalue()))
return ''
def render_post(context):
context.caller_stack._push_frame()
try:
context._push_buffer()
__M_writer = context.writer()
# SOURCE LINE 16
__M_writer(u'\n<div>\n <div class="footer">\n <p>&#169; Copyright 2011 Ume&#229; Universitet &nbsp;</p>\n </div>\n</div>\n')
finally:
__M_buf, __M_writer = context._pop_buffer_and_writer()
context.caller_stack._pop_frame()
__M_writer(filters.trim(__M_buf.getvalue()))
return ''
def render_css(context):
context.caller_stack._push_frame()
try:
context._push_buffer()
def css_link(path,media=''):
return render_css_link(context,path,media)
__M_writer = context.writer()
# SOURCE LINE 8
__M_writer(u'\n ')
# SOURCE LINE 9
__M_writer(unicode(css_link('/css/main.css', 'screen')))
__M_writer(u'\n')
finally:
__M_buf, __M_writer = context._pop_buffer_and_writer()
context.caller_stack._pop_frame()
__M_writer(filters.trim(__M_buf.getvalue()))
return ''

View File

@ -9,7 +9,8 @@ import logging
import re
from oic.utils.http_util import *
from oic.oic.message import OpenIDSchema, AuthnToken
#from oic.oic.message import OpenIDSchema, AuthnToken
from oic.oic.message import message, msg_deser
from oic.oic.provider import AuthnFailure
from oic.oic.claims_provider import ClaimsClient
from oic.oic import JWT_BEARER
@ -70,16 +71,17 @@ def do_authorization(user, session=None):
#noinspection PyUnusedLocal
def verify_client(environ, areq, cdb):
if areq.client_secret: # client_secret_post
identity = areq.client_id
if "client_secret" in areq: # client_secret_post
identity = areq["client_id"]
if identity in cdb:
if cdb[identity]["client_secret"] == areq.client_secret:
if cdb[identity]["client_secret"] == areq["client_secret"]:
return True
elif areq.client_assertion: # client_secret_jwt or public_key_jwt
assert areq.client_assertion_type == JWT_BEARER
secret = cdb[areq.client_id]["client_secret"]
elif "client_assertion" in areq: # client_secret_jwt or public_key_jwt
assert areq["client_assertion_type"] == JWT_BEARER
secret = cdb[areq["client_id"]]["client_secret"]
key_col = {"hmac": secret}
bjwt = AuthnToken.set_jwt(areq.client_assertion, key_col)
bjwt = msg_deser(areq["client_assertion"], "jwt", "AuthnToken",
key=key_col)
return False
#import sys
@ -118,15 +120,15 @@ def _collect_distributed(srv, cc, user_id, what, alias=""):
if not alias:
alias = srv
for key in resp.claims_names:
for key in resp["claims_names"]:
result["_claims_names"][key] = alias
if resp.jwt:
if "jwt" in resp:
result["_claims_sources"][alias] = {"JWT": resp.jwt}
else:
result["_claims_sources"][alias] = {"endpoint": resp.endpoint}
result["_claims_sources"][alias] = {"endpoint": resp["endpoint"]}
if "access_token" in resp:
result["_claims_sources"][alias]["access_token"] = resp.access_token
result["_claims_sources"][alias]["access_token"] = resp["access_token"]
return result
@ -140,7 +142,7 @@ def user_info(oicsrv, userdb, user_id, client_id="", user_info_claims=None):
result = {}
missing = []
optional = []
for key, restr in user_info_claims.claims.items():
for key, restr in user_info_claims["claims"].items():
try:
result[key] = identity[key]
except KeyError:
@ -183,7 +185,7 @@ def user_info(oicsrv, userdb, user_id, client_id="", user_info_claims=None):
#result = identity
result = {"user_id": user_id}
return OpenIDSchema(**result)
return message("OpenIDSchema", **result)
FUNCTIONS = {
"authenticate": do_authentication,
@ -304,7 +306,7 @@ for endp in ENDPOINTS:
# ----------------------------------------------------------------------------
ROOT = '../'
ROOT = './'
LOOKUP = TemplateLookup(directories=[ROOT + 'templates', ROOT + 'htdocs'],
module_directory=ROOT + 'modules',

37
oc3/templates/root.mako Normal file
View File

@ -0,0 +1,37 @@
<% self.seen_css = set() %>
<%def name="css_link(path, media='')" filter="trim">
% if path not in self.seen_css:
<link rel="stylesheet" type="text/css" href="${path|h}" media="${media}">
% endif
<% self.seen_css.add(path) %>
</%def>
<%def name="css()" filter="trim">
${css_link('/css/main.css', 'screen')}
</%def>
<%def name="pre()" filter="trim">
<div class="header">
<h1><a href="/">Login</a></h1>
</div>
</%def>
<%def name="post()" filter="trim">
<div>
<div class="footer">
<p>&#169; Copyright 2011 Ume&#229; Universitet &nbsp;</p>
</div>
</div>
</%def>
##<!DOCTYPE html PUBLIC "-//W3C//DTD HTML 4.01//EN "
##"http://www.w3.org/TR/xhtml1/DTD/xhtml1-strict.dtd">
<html>
<head><title>OAuth test</title>
${self.css()}
<meta http-equiv="Content-Type" content="text/html; charset=UTF-8" />
</head>
<body>
${pre()}
## ${comps.dict_to_table(pageargs)}
## <hr><hr>
${next.body()}
${post()}
</body>
</html>

View File

@ -1,3 +1 @@
# Complete OpenID Connect implementation
# __author__ = 'rohe0002'
__author__ = 'rohe0002'

View File

@ -3,7 +3,6 @@
__author__ = 'rohe0002'
import httplib2
import inspect
import random
import string
@ -22,15 +21,15 @@ DEFAULT_POST_CONTENT_TYPE = 'application/x-www-form-urlencoded'
REQUEST2ENDPOINT = {
"AuthorizationRequest": "authorization_endpoint",
"AccessTokenRequest": "token_endpoint",
# ROPCAccessTokenRequest: "authorization_endpoint",
# CCAccessTokenRequest: "authorization_endpoint",
# ROPCAccessTokenRequest: "authorization_endpoint",
# CCAccessTokenRequest: "authorization_endpoint",
"RefreshAccessTokenRequest": "token_endpoint",
"TokenRevocationRequest": "token_endpoint",
}
}
RESPONSE2ERROR = {
AuthorizationResponse: [AuthorizationErrorResponse, TokenErrorResponse],
AccessTokenResponse: [TokenErrorResponse]
"AuthorizationResponse": ["AuthorizationErrorResponse", "TokenErrorResponse"],
"AccessTokenResponse": ["TokenErrorResponse"]
}
ENDPOINTS = ["authorization_endpoint", "token_endpoint",
@ -67,24 +66,24 @@ def client_secret_post(cli, cis, request_args=None, http_args=None, **kwargs):
if request_args is None:
request_args = {}
if not cis.client_secret:
if "client_secret" not in cis:
try:
cis.client_secret = http_args["client_secret"]
cis["client_secret"] = http_args["client_secret"]
del http_args["client_secret"]
except (KeyError, TypeError):
cis.client_secret = cli.client_secret
cis["client_secret"] = cli.client_secret
cis.client_id = cli.client_id
cis["client_id"] = cli.client_id
return http_args
#noinspection PyUnusedLocal
def bearer_header(cli, cis, request_args=None, http_args=None, **kwargs):
if cis.access_token:
_acc_token = cis.access_token
cis.access_token = None
if "access_token" in cis:
_acc_token = cis["access_token"]
del cis["access_token"]
# Required under certain circumstances :-) not under other
cis.c_attributes["access_token"] = SINGLE_OPTIONAL_STRING
cis._schema["param"]["access_token"] = SINGLE_OPTIONAL_STRING
else:
try:
_acc_token = request_args["access_token"]
@ -118,11 +117,11 @@ def bearer_body(cli, cis, request_args=None, http_args=None, **kwargs):
if request_args is None:
request_args = {}
if cis.access_token:
if "access_token" in cis:
pass
else:
try:
cis.access_token = request_args["access_token"]
cis["access_token"] = request_args["access_token"]
except KeyError:
try:
_state = kwargs["state"]
@ -131,7 +130,7 @@ def bearer_body(cli, cis, request_args=None, http_args=None, **kwargs):
raise Exception("Missing state specification")
kwargs["state"] = cli.state
cis.access_token = cli.get_token(**kwargs).access_token
cis["access_token"] = cli.get_token(**kwargs).access_token
return http_args
@ -140,7 +139,7 @@ AUTHN_METHOD = {
"client_secret_post" : client_secret_post,
"bearer_header": bearer_header,
"bearer_body": bearer_body,
}
}
# -----------------------------------------------------------------------------
@ -150,7 +149,7 @@ class ExpiredToken(Exception):
# -----------------------------------------------------------------------------
class Token(object):
_class = AccessTokenResponse
_schema = SCHEMA["AccessTokenResponse"]
def __init__(self, resp=None):
self.scope = []
@ -161,19 +160,11 @@ class Token(object):
self.replaced = False
if resp:
for prop in self._class.c_attributes.keys():
try:
_val = getattr(resp, prop)
except KeyError:
continue
if _val:
setattr(self, prop, _val)
for key, val in resp.c_extension.items():
setattr(self, key, val)
for prop, val in resp.items():
setattr(self, prop, val)
try:
_expires_in = resp.expires_in
_expires_in = resp["expires_in"]
except KeyError:
return
@ -210,10 +201,10 @@ class Token(object):
return True
class Grant(object):
_authz_resp = AuthorizationResponse
_acc_resp = AccessTokenResponse
_authz_resp = "AuthorizationResponse"
_acc_resp = "AccessTokenResponse"
_token_class = Token
def __init__(self, exp_in=600, resp=None, seed=""):
self.grant_expiration_time = 0
self.exp_in = exp_in
@ -232,12 +223,15 @@ class Grant(object):
def add_code(self, resp):
try:
self.code = resp.code
self.code = resp["code"]
self.grant_expiration_time = utc_time_sans_frac() + self.exp_in
except KeyError:
pass
def add_token(self, resp):
"""
:param resp: A Authorization Response instance
"""
tok = self._token_class(resp)
if tok.access_token:
self.tokens.append(tok)
@ -255,7 +249,7 @@ class Grant(object):
return self.__dict__.keys()
def update(self, resp):
if isinstance(resp, self._acc_resp):
if resp.type() == self._acc_resp:
if "access_token" in resp or "id_token" in resp:
tok = self._token_class(resp)
if tok not in self.tokens:
@ -265,7 +259,7 @@ class Grant(object):
self.tokens.append(tok)
else:
self.add_code(resp)
elif isinstance(resp, self._authz_resp):
elif resp.type() == self._authz_resp:
self.add_code(resp)
def get_token(self, scope=""):
@ -477,36 +471,37 @@ class KeyStore(object):
if "x509_url" in inst:
try:
_verkey = self.load_x509_cert(inst.x509_url, "verify",
_verkey = self.load_x509_cert(inst["x509_url"], "verify",
_issuer)
except Exception:
raise Exception(KEYLOADERR % ('x509', inst.x509_url))
raise Exception(KEYLOADERR % ('x509', inst["x509_url"]))
else:
_verkey = None
if "x509_encryption_url" in inst:
try:
self.load_x509_cert(inst.x509_encryption_url, "enc",
self.load_x509_cert(inst["x509_encryption_url"], "enc",
_issuer)
except Exception:
raise Exception(KEYLOADERR % ('x509_encryption',
inst.x509_encryption_url))
inst["x509_encryption_url"]))
elif _verkey:
self.set_decrypt_key(_verkey, "rsa", _issuer)
if "jwk_url" in inst:
try:
_verkeys = self.load_jwk(inst.jwk_url, "verify", _issuer)
_verkeys = self.load_jwk(inst["jwk_url"], "verify", _issuer)
except Exception, err:
raise Exception(KEYLOADERR % ('jwk', inst.jwk_url, err))
raise Exception(KEYLOADERR % ('jwk', inst["jwk_url"], err))
else:
_verkeys = []
if "jwk_encryption_url" in inst:
try:
self.load_jwk(inst.jwk_url, "enc", _issuer)
self.load_jwk(inst["jwk_encryption_url"], "enc", _issuer)
except Exception:
raise Exception(KEYLOADERR % ('jwk', inst.jwk_encryption_url))
raise Exception(KEYLOADERR % ('jwk',
inst["jwk_encryption_url"]))
elif _verkeys:
for key in _verkeys:
self.set_decrypt_key(key, "rsa", _issuer)
@ -550,8 +545,8 @@ class Client(PBase):
jwt_keys=None):
PBase.__init__(self, cache, time_out, proxy_info, follow_redirects,
disable_ssl_certificate_validation, ca_certs,
httpclass, jwt_keys)
disable_ssl_certificate_validation, ca_certs,
httpclass, jwt_keys)
self.client_id = client_id
self.client_timeout = client_timeout
@ -608,30 +603,21 @@ class Client(PBase):
return None
# def scope_from_state(self, state):
#
# def grant_from_state_or_scope(self, state, scope):
def _parse_args(self, schema, **kwargs):
ar_args = kwargs.copy()
def _parse_args(self, klass, **kwargs):
ar_args = {}
for prop, val in kwargs.items():
if prop in klass.c_attributes:
ar_args[prop] = val
elif prop.startswith("extra_"):
if prop[6:] not in klass.c_attributes:
ar_args[prop[6:]] = val
# Used to not overwrite defaults
argspec = inspect.getargspec(klass.__init__)
for prop in klass.c_attributes.keys():
if prop not in ar_args:
index = argspec[0].index(prop) -1 # skip self
if not argspec[3][index]:
if prop == "redirect_uri":
ar_args[prop] = getattr(self, "redirect_uris",
[None])[0]
else:
ar_args[prop] = getattr(self, prop, None)
for prop in schema["param"].keys():
if prop in ar_args:
continue
else:
if prop == "redirect_uri":
_val = getattr(self, "redirect_uris", [None])[0]
if _val:
ar_args[prop] = _val
else:
_val = getattr(self, prop, None)
if _val:
ar_args[prop] = _val
return ar_args
@ -693,18 +679,19 @@ class Client(PBase):
else:
raise ExpiredToken()
def construct_request(self, reqclass, request_args=None, extra_args=None):
def construct_request(self, schema, request_args=None, extra_args=None):
if request_args is None:
request_args = {}
args = self._parse_args(reqclass, **request_args)
args = self._parse_args(schema, **request_args)
if extra_args:
args.update(extra_args)
return reqclass(**args)
return Message(schema["name"], schema, **args)
#noinspection PyUnusedLocal
def construct_AuthorizationRequest(self, reqclass=AuthorizationRequest,
def construct_AuthorizationRequest(self,
schema=SCHEMA["AuthorizationRequest"],
request_args=None, extra_args=None,
**kwargs):
@ -716,10 +703,11 @@ class Client(PBase):
else:
request_args = {}
return self.construct_request(reqclass, request_args, extra_args)
return self.construct_request(schema, request_args, extra_args)
#noinspection PyUnusedLocal
def construct_AccessTokenRequest(self, cls=AccessTokenRequest,
def construct_AccessTokenRequest(self,
schema=SCHEMA["AccessTokenRequest"],
request_args=None, extra_args=None,
**kwargs):
@ -727,8 +715,8 @@ class Client(PBase):
if not grant.is_valid():
raise GrantExpired("Authorization Code to old %s > %s" % (
utc_time_sans_frac(),
grant.grant_expiration_time))
utc_time_sans_frac(),
grant.grant_expiration_time))
if request_args is None:
request_args = {}
@ -743,12 +731,12 @@ class Client(PBase):
elif not request_args["client_id"]:
request_args["client_id"] = self.client_id
return self.construct_request(cls, request_args, extra_args)
return self.construct_request(schema, request_args, extra_args)
def construct_RefreshAccessTokenRequest(self,
cls=RefreshAccessTokenRequest,
request_args=None, extra_args=None,
**kwargs):
schema=SCHEMA["RefreshAccessTokenRequest"],
request_args=None, extra_args=None,
**kwargs):
if request_args is None:
request_args = {}
@ -762,9 +750,10 @@ class Client(PBase):
except AttributeError:
pass
return self.construct_request(cls, request_args, extra_args)
return self.construct_request(schema, request_args, extra_args)
def construct_TokenRevocationRequest(self, cls=TokenRevocationRequest,
def construct_TokenRevocationRequest(self,
schema=SCHEMA["TokenRevocationRequest"],
request_args=None, extra_args=None,
**kwargs):
@ -774,15 +763,15 @@ class Client(PBase):
token = self.get_token(**kwargs)
request_args["token"] = token.access_token
return self.construct_request(cls, request_args, extra_args)
return self.construct_request(schema, request_args, extra_args)
def get_or_post(self, uri, method, req, extend=False, **kwargs):
def get_or_post(self, uri, method, req, **kwargs):
if method == "GET":
path = uri + '?' + req.get_urlencoded(extended=extend)
path = uri + '?' + req.to_urlencoded()
body = None
elif method == "POST":
path = uri
body = req.get_urlencoded(extended=extend)
body = req.to_urlencoded()
header_ext = {"content-type": DEFAULT_POST_CONTENT_TYPE}
if "headers" in kwargs.keys():
kwargs["headers"].update(header_ext)
@ -793,13 +782,12 @@ class Client(PBase):
return path, body, kwargs
def uri_and_body(self, cls, cis, method="POST", request_args=None,
extend=False, **kwargs):
def uri_and_body(self, reqmsg, cis, method="POST", request_args=None,
**kwargs):
uri = self._endpoint(self.request2endpoint[cls.__name__],
**request_args)
uri = self._endpoint(self.request2endpoint[reqmsg], **request_args)
uri, body, kwargs = self.get_or_post(uri, method, cis, extend, **kwargs)
uri, body, kwargs = self.get_or_post(uri, method, cis, **kwargs)
try:
h_args = {"headers": kwargs["headers"]}
except KeyError:
@ -807,15 +795,15 @@ class Client(PBase):
return uri, body, h_args, cis
def request_info(self, cls, method="POST", request_args=None,
def request_info(self, schema, method="POST", request_args=None,
extra_args=None, **kwargs):
if request_args is None:
request_args = {}
cis = getattr(self, "construct_%s" % cls.__name__)(cls, request_args,
extra_args,
**kwargs)
cis = getattr(self, "construct_%s" % schema["name"])(schema,
request_args,
extra_args, **kwargs)
if "authn_method" in kwargs:
h_arg = self.init_authentication_method(cis,
@ -830,81 +818,51 @@ class Client(PBase):
else:
kwargs["headers"] = h_arg
if extra_args:
extend = True
else:
extend = False
return self.uri_and_body(schema["name"], cis, method, request_args,
**kwargs)
return self.uri_and_body(cls, cis, method, request_args,
extend=extend, **kwargs)
def parse_response(self, cls, info="", format="json", state="",
extended=False, **kwargs):
def parse_response(self, schema, info="", format="json", state="",
**kwargs):
"""
Parse a response
:param cls: Which class to use when parsing the response
:param schema: Which schema the response should adhere to
:param info: The response, can be either an JSON code or an urlencoded
form:
:param format: Which serialization that was used
:param extended: If non-standard parameters should be honored
:return: The parsed and to some extend verified response
"""
_r2e = self.response2error
err = None
if format == "json":
try:
resp = cls.set_json(info, extended)
assert resp.verify(**kwargs)
except Exception, err:
resp = None
eresp = None
try:
for errcls in _r2e[cls]:
try:
eresp = errcls.set_json(info, extended)
eresp.verify()
break
except Exception:
eresp = None
except KeyError:
pass
elif format == "urlencoded":
if format == "urlencoded":
if '?' in info or '#' in info:
parts = urlparse.urlparse(info)
scheme, netloc, path, params, query, fragment = parts[:6]
# either query of fragment
if query:
pass
info = query
else:
query = fragment
else:
query = info
info = fragment
try:
resp = cls.set_urlencoded(query, extended)
assert resp.verify(**kwargs)
except Exception, err:
resp = None
try:
resp = msg_deser(info, format, schema=schema)
assert resp.verify(**kwargs)
except Exception, err:
resp = None
eresp = None
try:
for errcls in _r2e[cls]:
try:
eresp = errcls.set_urlencoded(query, extended)
eresp.verify()
break
except Exception:
eresp = None
except KeyError:
pass
else:
raise Exception("Unknown package format: '%s'" % format)
eresp = None
try:
for errmsg in _r2e[schema["name"]]:
try:
eresp = msg_deser(info, format, typ=errmsg)
eresp.verify()
break
except Exception:
eresp = None
except KeyError:
pass
# Error responses has higher precedence
if eresp:
@ -913,9 +871,9 @@ class Client(PBase):
if not resp:
raise err
if isinstance(resp, (AuthorizationResponse, AccessTokenResponse)):
if resp._name in ["AuthorizationResponse", "AccessTokenResponse"]:
try:
_state = resp.state
_state = resp["state"]
except (AttributeError, KeyError):
_state = ""
@ -931,7 +889,7 @@ class Client(PBase):
#noinspection PyUnusedLocal
def init_authentication_method(self, cis, authn_method, request_args=None,
http_args=None, **kwargs):
http_args=None, **kwargs):
if http_args is None:
http_args = {}
@ -944,16 +902,15 @@ class Client(PBase):
else:
return http_args
def request_and_return(self, url, respcls=None, method="GET", body=None,
body_type="json", extended=True,
state="", http_args=None, **kwargs):
def request_and_return(self, url, schema=None, method="GET", body=None,
body_type="json", state="", http_args=None,
**kwargs):
"""
:param url: The URL to which the request should be sent
:param respcls: The class the should represent the response
:param schema: The schema the response should adhere to
:param method: Which HTTP method to use
:param body: A message body if any
:param body_type: The format of the body of the return message
:param extended: If non-standard parameters should be honored
:param http_args: Arguments for the HTTP client
:return: A cls or ErrorResponse instance or the HTTP response
instance if no response body was expected.
@ -985,41 +942,44 @@ class Client(PBase):
raise Exception("ERROR: Something went wrong [%s]" % response.status)
if body_type:
return self.parse_response(respcls, content, body_type,
state, extended, **kwargs)
return self.parse_response(schema, content, body_type, state,
**kwargs)
else:
return response
def do_authorization_request(self, cls=AuthorizationRequest,
def do_authorization_request(self, schema=SCHEMA["AuthorizationRequest"],
state="", body_type="", method="GET",
request_args=None, extra_args=None,
http_args=None, resp_cls=None):
http_args=None,
resp_schema=SCHEMA["AuthorizationResponse"]):
url, body, ht_args, csi = self.request_info(cls, method, request_args,
extra_args)
url, body, ht_args, csi = self.request_info(schema, method,
request_args, extra_args)
if http_args is None:
http_args = ht_args
else:
http_args.update(http_args)
resp = self.request_and_return(url, resp_cls, method, body,
body_type, extended=False,
state=state, http_args=http_args)
resp = self.request_and_return(url, resp_schema, method, body,
body_type, state=state,
http_args=http_args)
if isinstance(resp, ErrorResponse):
resp.state = csi.state
if isinstance(resp, Message):
if resp.type() in RESPONSE2ERROR["AuthorizationRequest"]:
resp.state = csi.state
return resp
def do_access_token_request(self, cls=AccessTokenRequest, scope="",
state="", body_type="json", method="POST",
request_args=None, extra_args=None,
http_args=None, resp_cls=AccessTokenResponse,
def do_access_token_request(self, schema=SCHEMA["AccessTokenRequest"],
scope="", state="", body_type="json",
method="POST", request_args=None,
extra_args=None, http_args=None,
resp_schema=SCHEMA["AccessTokenResponse"],
authn_method="", **kwargs):
# method is default POST
url, body, ht_args, csi = self.request_info(cls, method=method,
url, body, ht_args, csi = self.request_info(schema, method=method,
request_args=request_args,
extra_args=extra_args,
scope=scope, state=state,
@ -1031,19 +991,20 @@ class Client(PBase):
else:
http_args.update(http_args)
return self.request_and_return(url, resp_cls, method, body,
body_type, extended=False,
state=state, http_args=http_args)
return self.request_and_return(url, resp_schema, method, body,
body_type, state=state,
http_args=http_args)
def do_access_token_refresh(self, cls=RefreshAccessTokenRequest,
def do_access_token_refresh(self, schema=SCHEMA["RefreshAccessTokenRequest"],
state="", body_type="json", method="POST",
request_args=None, extra_args=None,
http_args=None, resp_cls=AccessTokenResponse,
http_args=None,
resp_schema=SCHEMA["AccessTokenResponse"],
authn_method="", **kwargs):
token = self.get_token(also_expired=True, state=state, **kwargs)
url, body, ht_args, csi = self.request_info(cls, method=method,
url, body, ht_args, csi = self.request_info(schema, method=method,
request_args=request_args,
extra_args=extra_args,
token=token,
@ -1054,16 +1015,16 @@ class Client(PBase):
else:
http_args.update(http_args)
return self.request_and_return(url, resp_cls, method, body,
body_type, extended=False,
state=state, http_args=http_args)
return self.request_and_return(url, resp_schema, method, body,
body_type, state=state,
http_args=http_args)
def do_revocate_token(self, cls=TokenRevocationRequest, scope="", state="",
body_type="json", method="POST",
def do_revocate_token(self, schema=SCHEMA["TokenRevocationRequest"],
scope="", state="", body_type="json", method="POST",
request_args=None, extra_args=None, http_args=None,
resp_cls=None, authn_method=""):
resp_schema=SCHEMA[""], authn_method=""):
url, body, ht_args, csi = self.request_info(cls, method=method,
url, body, ht_args, csi = self.request_info(schema, method=method,
request_args=request_args,
extra_args=extra_args,
scope=scope, state=state,
@ -1074,9 +1035,9 @@ class Client(PBase):
else:
http_args.update(http_args)
return self.request_and_return(url, resp_cls, method, body,
body_type, extended=False,
state=state, http_args=http_args)
return self.request_and_return(url, resp_schema, method, body,
body_type, state=state,
http_args=http_args)
def fetch_protected_resource(self, uri, method="GET", headers=None,
state="", **kwargs):
@ -1111,56 +1072,48 @@ class Server(PBase):
httpclass=None):
PBase.__init__(self, cache, time_out, proxy_info, follow_redirects,
disable_ssl_certificate_validation, ca_certs,
httpclass, jwt_keys)
disable_ssl_certificate_validation, ca_certs,
httpclass, jwt_keys)
def parse_url_request(self, cls, url=None, query=None, extended=False):
def parse_url_request(self, msg, url=None, query=None):
if url:
parts = urlparse.urlparse(url)
scheme, netloc, path, params, query, fragment = parts[:6]
req = cls.set_urlencoded(query, extended)
req = msg_deser(query, "urlencoded", msg)
#req = message(msg).from_urlencoded(query)
req.verify()
return req
def parse_authorization_request(self, rcls=AuthorizationRequest,
url=None, query=None, extended=False):
return self.parse_url_request(rcls, url, query, extended)
def parse_authorization_request(self, reqmsg="AuthorizationRequest",
url=None, query=None):
def parse_jwt_request(self, rcls=AuthorizationRequest, txt="", keystore="",
verify=True, extend=False):
return self.parse_url_request(reqmsg, url, query)
def parse_jwt_request(self, reqmsg="AuthorizationRequest", txt="",
keystore="", verify=True):
if not keystore:
keystore = self.keystore
keys = keystore.get_keys("verify", owner=None)
areq = rcls.set_jwt(txt, keys, verify, extend)
#areq = message().from_(txt, keys, verify)
areq = msg_deser(txt, "jwt", reqmsg, key=keys, verify=verify)
areq.verify()
return areq
def parse_body_request(self, cls=AccessTokenRequest, body=None,
extend=False):
req = cls.set_urlencoded(body, extend)
def parse_body_request(self, reqmsg="AccessTokenRequest", body=None):
#req = message(reqmsg).from_urlencoded(body)
req = msg_deser(body, "urlencoded", reqmsg)
req.verify()
return req
def parse_token_request(self, rcls=AccessTokenRequest, body=None,
extend=False):
return self.parse_body_request(rcls, body, extend)
def parse_token_request(self, reqmsg="AccessTokenRequest", body=None):
return self.parse_body_request(reqmsg, body)
def parse_refresh_token_request(self, rcls=RefreshAccessTokenRequest,
body=None, extend=False):
return self.parse_body_request(rcls, body, extend)
# def is_authorized(self, path, authorization=None):
# if not authorization:
# return False
#
# if authorization.startswith("Bearer"):
# parts = authorization.split(" ")
#
# return True
def parse_refresh_token_request(self, reqmsg="RefreshAccessTokenRequest",
body=None):
return self.parse_body_request(reqmsg, body)

View File

@ -7,17 +7,15 @@ import time
from hashlib import md5
from oic.utils import http_util
from oic.oauth2 import AuthorizationRequest
from oic.oauth2 import AccessTokenRequest
from oic.oauth2 import AuthorizationResponse
from oic.oauth2 import AccessTokenResponse
from oic.oauth2 import Client
from oic.oauth2 import ErrorResponse
from oic.oauth2 import Grant
from oic.oauth2 import rndstr
from oic.oauth2.message import SCHEMA
from oic.oauth2.message import Message
ENDPOINTS = ["authorization_endpoint", "token_endpoint", "userinfo_endpoint",
"check_id_endpoint", "registration_endpoint", "token_revokation_endpoint"]
"check_id_endpoint", "registration_endpoint",
"token_revokation_endpoint"]
def stateID(url, seed):
"""The hash of the time + server path + a seed makes an unique
@ -44,7 +42,7 @@ def factory(kaka, sdb, client_id, **kwargs):
part = http_util.cookie_parts(client_id, kaka)
if part is None:
return None
cons = Consumer(sdb, **kwargs)
cons.restore(part[0])
http_util.parse_cookie(client_id, cons.seed, kaka)
@ -81,7 +79,7 @@ class Consumer(Client):
"""
if client_config is None:
client_config = {}
Client.__init__(self, **client_config)
self.authz_page = authz_page
@ -149,7 +147,7 @@ class Consumer(Client):
"grant": self.grant,
"seed": self.seed,
"redirect_uris": self.redirect_uris,
}
}
for endpoint in ENDPOINTS:
res[endpoint] = getattr(self, endpoint, None)
@ -185,9 +183,9 @@ class Consumer(Client):
self._backup(sid)
self.sdb["seed:%s" % self.seed] = sid
location = self.request_info(AuthorizationRequest, method="GET",
scope=self.scope,
request_args={"state": sid})[0]
location = self.request_info(SCHEMA["AuthorizationRequest"],
method="GET", scope=self.scope,
request_args={"state": sid})[0]
if self.debug:
@ -220,44 +218,47 @@ class Consumer(Client):
if "code" in self.response_type:
# Might be an error response
try:
aresp = self.parse_response(AuthorizationResponse,
aresp = self.parse_response(SCHEMA["AuthorizationResponse"],
info=_query, format="urlencoded")
except Exception, err:
logger.error("%s" % err)
raise
if isinstance(aresp, ErrorResponse):
raise AuthzError(aresp.error)
if isinstance(aresp, Message):
if aresp.type().endswith("ErrorResponse"):
raise AuthzError(aresp["error"])
try:
self.update(aresp.state)
self.update(aresp["state"])
except KeyError:
raise UnknownState(aresp.state)
self._backup(aresp.state)
raise UnknownState(aresp["state"])
self._backup(aresp["state"])
return aresp
else: # implicit flow
atr = self.parse_response(AccessTokenResponse, info=_query,
format="urlencoded", extended=True)
atr = self.parse_response(SCHEMA["AccessTokenResponse"],
info=_query, format="urlencoded",
extended=True)
if isinstance(atr, ErrorResponse):
raise TokenError(atr.error)
if isinstance(atr, Message):
if atr.type().endswith("ErrorResponse"):
raise TokenError(atr["error"])
try:
self.update(atr.state)
self.update(atr["state"])
except KeyError:
raise UnknownState(atr.state)
raise UnknownState(atr["state"])
self.seed = self.grant[self.state].seed
return atr
def complete(self, environ, start_response, logger):
resp = self.handle_authorization_response(environ, start_response,
logger)
if isinstance(resp, AuthorizationResponse):
if resp.type() == "AuthorizationResponse":
# Go get the access token
resp = self.do_access_token_request(state=self.state)
@ -273,9 +274,9 @@ class Consumer(Client):
elif self.client_secret:
http_args = {}
request_args = {
"client_secret":self.client_secret,
"client_id": self.client_id,
"auth_method":"request_body"}
"client_secret":self.client_secret,
"client_id": self.client_id,
"auth_method":"request_body"}
else:
raise Exception("Nothing to authenticate with")
@ -286,7 +287,7 @@ class Consumer(Client):
request_args, http_args = self.client_auth_info()
url, body, ht_args, csi = self.request_info(AccessTokenRequest,
url, body, ht_args, csi = self.request_info(SCHEMA["AccessTokenRequest"],
request_args=request_args,
state=self.state)

File diff suppressed because it is too large Load Diff

View File

@ -1,4 +1,5 @@
#!/usr/bin/env python
from oic.oauth2.message import message, add_non_standard, msg_deser, by_schema, SCHEMA
__author__ = 'rohe0002'
@ -9,17 +10,8 @@ from urlparse import parse_qs
from oic.utils.http_util import *
from oic.oauth2 import rndstr
from oic.oauth2 import Server as SrvMethod
from oic.oauth2 import MissingRequiredAttribute
from oic.oauth2 import AuthorizationResponse
from oic.oauth2 import AuthorizationRequest
from oic.oauth2 import AccessTokenResponse
from oic.oauth2 import AccessTokenRequest
from oic.oauth2 import TokenErrorResponse
from oic.oauth2 import NoneResponse
from oic import oauth2
class AuthnFailure(Exception):
pass
@ -44,10 +36,11 @@ def get_post(environ):
def code_response(**kwargs):
_areq = kwargs["areq"]
_scode = kwargs["scode"]
aresp = AuthorizationResponse()
if _areq.state:
aresp.state = _areq.state
aresp.code = _scode
aresp = message("AuthorizationResponse")
if "state" in _areq:
aresp["state"] = _areq["state"]
aresp["code"] = _scode
add_non_standard(_areq, aresp)
return aresp
def token_response(**kwargs):
@ -56,18 +49,19 @@ def token_response(**kwargs):
_sdb = kwargs["sdb"]
_dic = _sdb.update_to_token(_scode, issue_refresh=False)
aresp = oauth2.factory(AccessTokenResponse, **_dic)
if _areq.scope:
aresp.scope = _areq.scope
aresp.c_extension = _areq.c_extension
aresp = message("AccessTokenResponse", **_dic)
if "state" in _areq:
aresp["state"] = _areq["state"]
return aresp
#noinspection PyUnusedLocal
def none_response(**kwargs):
_areq = kwargs["areq"]
aresp = NoneResponse()
if _areq.state:
aresp.state = _areq.state
aresp = message("NoneResponse")
if "state" in _areq:
aresp["state"] = _areq["state"]
return aresp
def location_url(response_type, redirect_uri, query):
@ -77,8 +71,6 @@ def location_url(response_type, redirect_uri, query):
return "%s#%s" % (redirect_uri, query)
class Provider(object):
authorization_request = AuthorizationRequest
def __init__(self, name, sdb, cdb, function, urlmap=None, debug=0):
self.name = name
self.sdb = sdb
@ -136,8 +128,8 @@ class Provider(object):
#_log_info( "type: %s" % type(session["authzreq"]))
# pick up the original request
areq = self.authorization_request.set_json(session["authzreq"],
extended=True)
areq = msg_deser(session["authzreq"], "json",
schema=SCHEMA["AuthorizationRequest"])
if self.debug:
_log_info("areq: %s" % areq)
@ -149,21 +141,21 @@ class Provider(object):
except Exception:
raise
_log_info("response type: %s" % areq.response_type)
_log_info("response type: %s" % areq["response_type"])
return areq, session
def authn_reply(self, areq, aresp, environ, start_response, logger):
_log_info = logger.info
if areq.redirect_uri:
if "redirect_uri" in areq:
# TODO verify that the uri is reasonable
redirect_uri = areq.redirect_uri
redirect_uri = areq["redirect_uri"]
else:
redirect_uri = self.urlmap[areq.client_id]
redirect_uri = self.urlmap[areq["client_id"]]
location = location_url(areq.response_type, redirect_uri,
aresp.get_urlencoded())
location = location_url(areq["response_type"], redirect_uri,
aresp.to_urlencoded())
if self.debug:
_log_info("Redirected to: '%s' (%s)" % (location, type(location)))
@ -173,13 +165,12 @@ class Provider(object):
def authn_response(self, areq, session):
scode = session["code"]
areq.response_type.sort()
_rtype = " ".join(areq.response_type)
areq["response_type"].sort()
_rtype = " ".join(areq["response_type"])
return self.response_type_map[_rtype](areq=areq, scode=scode,
sdb=self.sdb)
#noinspection PyUnusedLocal
def authenticated(self, environ, start_response, logger, **kwargs):
def authenticated(self, environ, start_response, logger):
_log_info = logger.info
if self.debug:
@ -203,13 +194,11 @@ class Provider(object):
resp = BadRequest("Unknown response type")
return resp(environ, start_response)
aresp.c_extension = areq.c_extension
add_non_standard(aresp, areq)
return self.authn_reply(areq, aresp, environ, start_response, logger)
#noinspection PyUnusedLocal
def authorization_endpoint(self, environ, start_response, logger,
**kwargs):
def authorization_endpoint(self, environ, start_response, logger):
# The AuthorizationRequest endpoint
_log_info = logger.info
@ -230,8 +219,7 @@ class Provider(object):
_log_info("Query: '%s'" % query)
try:
areq = self.srvmethod.parse_authorization_request(query=query,
extended=True)
areq = self.srvmethod.parse_authorization_request(query=query)
except MissingRequiredAttribute, err:
resp = BadRequest("%s" % err)
return resp(environ, start_response)
@ -239,11 +227,11 @@ class Provider(object):
resp = BadRequest("%s" % err)
return resp(environ, start_response)
if areq.redirect_uri:
_redirect = areq.redirect_uri
else:
# A list, so pick one (==the first)
_redirect = self.urlmap[areq.client_id][0]
# if "redirect_uri" in areq:
# _redirect = areq["redirect_uri"]
# else:
# # A list, so pick one (==the first)
# _redirect = self.urlmap[areq["client_id"]][0]
sid = _sdb.create_authz_session("", areq)
bsid = base64.b64encode(sid)
@ -254,8 +242,7 @@ class Provider(object):
return self.function["authenticate"](environ, start_response, bsid)
#noinspection PyUnusedLocal
def token_endpoint(self, environ, start_response, logger, handle):
def token_endpoint(self, environ, start_response, logger):
"""
This is where clients come to get their access tokens
"""
@ -269,39 +256,40 @@ class Provider(object):
if self.debug:
_log_info("body: %s" % body)
areq = AccessTokenRequest.set_urlencoded(body, extended=True)
areq = msg_deser(body, "urlencoded", typ="AccessTokenRequest")
# Client is from basic auth or ...
client = environ["REMOTE_USER"]
if not self.function["verify client"](environ, client, self.cdb):
err = TokenErrorResponse(error="unathorized_client")
resp = Response(err.get_json(), content="application/json",
err = message("TokenErrorResponse", error="unathorized_client")
resp = Response(err.to_json(), content="application/json",
status="401 Unauthorized")
return resp(environ, start_response)
if self.debug:
_log_info("AccessTokenRequest: %s" % areq)
assert areq.grant_type == "authorization_code"
assert areq["grant_type"] == "authorization_code"
# assert that the code is valid
_info = _sdb[areq.code]
_info = _sdb[areq["code"]]
# If redirect_uri was in the initial authorization request
# verify that the one given here is the correct one.
if "redirect_uri" in _info:
assert areq.redirect_uri == _info["redirect_uri"]
assert areq["redirect_uri"] == _info["redirect_uri"]
_tinfo = _sdb.update_to_token(areq.code)
_tinfo = _sdb.update_to_token(areq["code"])
if self.debug:
_log_info("_tinfo: %s" % _tinfo)
atr = oauth2.factory(AccessTokenResponse, **_tinfo)
atr = message("AccessTokenResponse",
**by_schema("AccessTokenResponse", **_tinfo))
if self.debug:
_log_info("AccessTokenResponse: %s" % atr)
resp = Response(atr.get_json(), content="application/json")
resp = Response(atr.to_json(), content="application/json")
return resp(environ, start_response)

View File

@ -1,11 +1,14 @@
__author__ = 'rohe0002'
import urlparse
from oic import oauth2
from oic.oauth2 import AUTHN_METHOD as OAUTH2_AUTHN_METHOD
from oic.oauth2 import DEF_SIGN_ALG
from oic.oauth2 import HTTP_ARGS
from oic.oauth2 import rndstr
from oic.oauth2.message import ErrorResponse
from oic.oauth2.message import message_from_schema
from oic.oic.message import *
from oic.utils import jwt
@ -20,11 +23,12 @@ ENDPOINTS = ["authorization_endpoint", "token_endpoint",
"registration_endpoint", "check_id_endpoint"]
RESPONSE2ERROR = {
AuthorizationResponse: [AuthorizationErrorResponse, TokenErrorResponse],
AccessTokenResponse: [TokenErrorResponse],
IdToken: [ErrorResponse],
RegistrationResponse: [ClientRegistrationErrorResponse],
OpenIDSchema: [UserInfoErrorResponse]
"AuthorizationResponse": ["AuthorizationErrorResponse",
"TokenErrorResponse"],
"AccessTokenResponse": ["TokenErrorResponse"],
"IdToken": ["ErrorResponse"],
"RegistrationResponse": ["ClientRegistrationErrorResponse"],
"OpenIDSchema": ["UserInfoErrorResponse"]
}
REQUEST2ENDPOINT = {
@ -48,15 +52,10 @@ OIDCONF_PATTERN = "%s/.well-known/openid-configuration"
AUTHN_METHOD = OAUTH2_AUTHN_METHOD.copy()
def assertion_jwt(cli, keys, audience, algorithm=DEF_SIGN_ALG):
at = AuthnToken(
iss = cli.client_id,
prn = cli.client_id,
aud = audience,
jti = rndstr(8),
exp = int(epoch_in_a_while(minutes=10)),
iat = utc_now()
)
return at.get_jwt(key=keys, algorithm=algorithm)
at = message("AuthnToken", iss = cli.client_id, prn = cli.client_id,
aud = audience, jti = rndstr(8),
exp = int(epoch_in_a_while(minutes=10)), iat = utc_now())
return at.to_jwt(key=keys, algorithm=algorithm)
#noinspection PyUnusedLocal
def client_secret_jwt(cli, cis, authn_method, request_args=None,
@ -66,30 +65,37 @@ def client_secret_jwt(cli, cis, authn_method, request_args=None,
signing_key = cli.keystore.get_sign_key()
# audience is the OP endpoint
audience = cli._endpoint(REQUEST2ENDPOINT[cis.__class__.__name__])
audience = cli._endpoint(REQUEST2ENDPOINT[cis.type()])
cis.client_assertion = assertion_jwt(cli, signing_key, audience,
"RS256")
cis.client_assertion_type = JWT_BEARER
cis["client_assertion"] = assertion_jwt(cli, signing_key, audience,
"RS256")
cis["client_assertion_type"] = JWT_BEARER
if cis.client_secret:
cis.client_secret = None
try:
del cis["client_secret"]
except KeyError:
pass
return {}
#noinspection PyUnusedLocal
def private_key_jwt(cli, cis, authn_method, request_args=None,
http_args=None, req=None):
http_args=None, req=None):
# signing key is the clients rsa key for instance
signing_key = cli.keystore.get_sign_key()
# audience is the OP endpoint
audience = cli._endpoint(REQUEST2ENDPOINT[cis.__class__.__name__])
audience = cli._endpoint(REQUEST2ENDPOINT[cis.type()])
cis.client_assertion = assertion_jwt(cli, signing_key, audience,
cis["client_assertion"] = assertion_jwt(cli, signing_key, audience,
algorithm="RS512")
cis.client_assertion_type = JWT_BEARER
cis["client_assertion_type"] = JWT_BEARER
try:
del cis["client_secret"]
except KeyError:
pass
return {}
@ -99,12 +105,12 @@ AUTHN_METHOD.update({"client_secret_jwt": client_secret_jwt,
# -----------------------------------------------------------------------------
class Token(oauth2.Token):
_class = AccessTokenResponse
_schema = SCHEMA["AccessTokenResponse"]
class Grant(oauth2.Grant):
_authz_resp = AuthorizationResponse
_acc_resp = AccessTokenResponse
_authz_resp = "AuthorizationResponse"
_acc_resp = "AccessTokenResponse"
_token_class = Token
def add_token(self, resp):
@ -130,8 +136,8 @@ class Client(oauth2.Client):
client_timeout = time_sans_frac() + expire_in
oauth2.Client.__init__(self, client_id, cache, timeout, proxy_info,
follow_redirects, disable_ssl_certificate_validation,
ca_certs, grant_expire_in, client_timeout, httpclass)
follow_redirects, disable_ssl_certificate_validation,
ca_certs, grant_expire_in, client_timeout, httpclass)
self.file_store = "./file/"
self.file_uri = "http://localhost/"
@ -186,21 +192,26 @@ class Client(oauth2.Client):
The request will be signed
"""
oir_args = {}
if userinfo_claims is not None:
# UserInfoClaims
claim = Claims(**userinfo_claims["claims"])
claim = message("Claims", **userinfo_claims["claims"])
uic_args = {}
for prop, val in userinfo_claims.items():
if prop == "claims":
continue
if prop in UserInfoClaim.c_attributes.keys():
if prop in SCHEMA["UserInfoClaim"]["param"].keys():
uic_args[prop] = val
uic = UserInfoClaim(claim, **uic_args)
uic = message("UserInfoClaim", claims=claim, **uic_args)
else:
uic = None
if uic:
oir_args["userinfo"] = uic
if idtoken_claims is not None:
#IdTokenClaims
try:
@ -208,28 +219,29 @@ class Client(oauth2.Client):
except KeyError:
_max_age=MAX_AUTHENTICATION_AGE
id_token = IDTokenClaim(max_age=_max_age)
id_token = message("IDTokenClaim", max_age=_max_age)
if "claims" in idtoken_claims:
idtclaims = Claims(**idtoken_claims["claims"])
id_token.claims = idtclaims
idtclaims = message("Claims", **idtoken_claims["claims"])
id_token["claims"] = idtclaims
else: # uic must be != None
id_token = IDTokenClaim(max_age=MAX_AUTHENTICATION_AGE)
id_token = message("IDTokenClaim", max_age=MAX_AUTHENTICATION_AGE)
oir_args = {"userinfo":uic, "id_token":id_token}
for prop in arq.keys():
_val = getattr(arq, prop)
if _val:
oir_args[prop] = _val
if id_token:
oir_args["id_token"] = id_token
for prop, val in arq.items():
oir_args[prop] = val
for attr in ["scope", "prompt", "response_type"]:
if attr in oir_args:
oir_args[attr] = " ".join(oir_args[attr])
oir = OpenIDRequest(**oir_args)
oir = message("OpenIDRequest", **oir_args)
return oir.get_jwt(extended=True, key=keys, algorithm=algorithm)
return oir.to_jwt(key=keys, algorithm=algorithm)
def construct_AuthorizationRequest(self, cls=AuthorizationRequest,
def construct_AuthorizationRequest(self,
schema=SCHEMA["AuthorizationRequest"],
request_args=None, extra_args=None,
**kwargs):
@ -239,13 +251,13 @@ class Client(oauth2.Client):
else:
request_args = {"nonce": rndstr(12)}
return oauth2.Client.construct_AuthorizationRequest(self, cls,
return oauth2.Client.construct_AuthorizationRequest(self, schema,
request_args,
extra_args,
**kwargs)
def construct_OpenIDRequest(self, cls=OpenIDRequest, request_args=None,
extra_args=None, **kwargs):
def construct_OpenIDRequest(self, schema=SCHEMA["OpenIDRequest"],
request_args=None, extra_args=None, **kwargs):
if request_args is not None:
for arg in ["idtoken_claims", "userinfo_claims"]:
@ -257,7 +269,7 @@ class Client(oauth2.Client):
else:
request_args = {"nonce": rndstr(12)}
areq = oauth2.Client.construct_AuthorizationRequest(self, cls,
areq = oauth2.Client.construct_AuthorizationRequest(self, schema,
request_args,
extra_args,
**kwargs)
@ -266,29 +278,29 @@ class Client(oauth2.Client):
kwargs["keys"] = self.keystore.get_sign_key()
if "userinfo_claims" in kwargs or "idtoken_claims" in kwargs:
areq.request = self.make_openid_request(areq, **kwargs)
areq["request"] = self.make_openid_request(areq, **kwargs)
return areq
#noinspection PyUnusedLocal
def construct_AccessTokenRequest(self, cls=AccessTokenRequest,
def construct_AccessTokenRequest(self, schema=SCHEMA["AccessTokenRequest"],
request_args=None, extra_args=None,
**kwargs):
return oauth2.Client.construct_AccessTokenRequest(self, cls,
return oauth2.Client.construct_AccessTokenRequest(self, schema,
request_args,
extra_args, **kwargs)
def construct_RefreshAccessTokenRequest(self,
cls=RefreshAccessTokenRequest,
request_args=None, extra_args=None,
**kwargs):
schema=SCHEMA["RefreshAccessTokenRequest"],
request_args=None, extra_args=None,
**kwargs):
return oauth2.Client.construct_RefreshAccessTokenRequest(self, cls,
request_args,
extra_args, **kwargs)
return oauth2.Client.construct_RefreshAccessTokenRequest(self, schema,
request_args,
extra_args, **kwargs)
def construct_UserInfoRequest(self, cls=UserInfoRequest,
def construct_UserInfoRequest(self, schema=SCHEMA["UserInfoRequest"],
request_args=None, extra_args=None,
**kwargs):
@ -304,23 +316,25 @@ class Client(oauth2.Client):
request_args["access_token"] = token.access_token
return self.construct_request(cls, request_args, extra_args)
return self.construct_request(schema, request_args, extra_args)
#noinspection PyUnusedLocal
def construct_RegistrationRequest(self, cls=RegistrationRequest,
def construct_RegistrationRequest(self,
schema=SCHEMA["RegistrationRequest"],
request_args=None, extra_args=None,
**kwargs):
return self.construct_request(cls, request_args, extra_args)
return self.construct_request(schema, request_args, extra_args)
#noinspection PyUnusedLocal
def construct_RefreshSessionRequest(self, cls=RefreshSessionRequest,
def construct_RefreshSessionRequest(self,
schema=SCHEMA["RefreshSessionRequest"],
request_args=None, extra_args=None,
**kwargs):
return self.construct_request(cls, request_args, extra_args)
return self.construct_request(schema, request_args, extra_args)
def _id_token_based(self, cls, request_args=None, extra_args=None,
def _id_token_based(self, schema, request_args=None, extra_args=None,
**kwargs):
if request_args is None:
@ -340,82 +354,88 @@ class Client(oauth2.Client):
request_args[_prop] = id_token
return self.construct_request(cls, request_args, extra_args)
return self.construct_request(schema, request_args, extra_args)
def construct_CheckSessionRequest(self, cls=CheckSessionRequest,
request_args=None, extra_args=None,
**kwargs):
def construct_CheckSessionRequest(self,
schema=SCHEMA["CheckSessionRequest"],
request_args=None, extra_args=None,
**kwargs):
return self._id_token_based(cls, request_args, extra_args, **kwargs)
return self._id_token_based(schema, request_args, extra_args, **kwargs)
def construct_CheckIDRequest(self, cls=CheckIDRequest, request_args=None,
def construct_CheckIDRequest(self, schema=SCHEMA["CheckIDRequest"],
request_args=None,
extra_args=None, **kwargs):
# access_token is where the id_token will be placed
return self._id_token_based(cls, request_args, extra_args,
return self._id_token_based(schema, request_args, extra_args,
prop="access_token", **kwargs)
def construct_EndSessionRequest(self, cls=EndSessionRequest,
def construct_EndSessionRequest(self, schema=SCHEMA["EndSessionRequest"],
request_args=None, extra_args=None,
**kwargs):
if request_args is None:
request_args = {}
if "state" in kwargs:
request_args["state"] = kwargs["state"]
elif "state" in request_args:
kwargs["state"] = request_args["state"]
# if "redirect_url" not in request_args:
# request_args["redirect_url"] = self.redirect_url
return self._id_token_based(cls, request_args, extra_args, **kwargs)
# if "redirect_url" not in request_args:
# request_args["redirect_url"] = self.redirect_url
return self._id_token_based(schema, request_args, extra_args, **kwargs)
# ------------------------------------------------------------------------
def do_authorization_request(self, cls=AuthorizationRequest,
def do_authorization_request(self, schema=SCHEMA["AuthorizationRequest"],
state="", body_type="", method="GET",
request_args=None, extra_args=None,
http_args=None, resp_cls=None):
http_args=None,
resp_schema=SCHEMA["AuthorizationResponse"]):
return oauth2.Client.do_authorization_request(self, cls, state,
return oauth2.Client.do_authorization_request(self, schema, state,
body_type, method,
request_args,
extra_args, http_args,
resp_cls)
resp_schema)
def do_access_token_request(self, cls=AccessTokenRequest, scope="",
state="", body_type="json", method="POST",
request_args=None, extra_args=None,
http_args=None, resp_cls=AccessTokenResponse,
def do_access_token_request(self, schema=SCHEMA["AccessTokenRequest"],
scope="", state="", body_type="json",
method="POST", request_args=None,
extra_args=None, http_args=None,
resp_schema=SCHEMA["AccessTokenResponse"],
authn_method="", **kwargs):
return oauth2.Client.do_access_token_request(self, cls, scope, state,
return oauth2.Client.do_access_token_request(self, schema, scope, state,
body_type, method,
request_args, extra_args,
http_args, resp_cls,
http_args, resp_schema,
authn_method, **kwargs)
def do_access_token_refresh(self, cls=RefreshAccessTokenRequest,
def do_access_token_refresh(self, schema=SCHEMA["RefreshAccessTokenRequest"],
state="", body_type="json", method="POST",
request_args=None, extra_args=None,
http_args=None, resp_cls=AccessTokenResponse,
http_args=None,
resp_schema=SCHEMA["AccessTokenResponse"],
**kwargs):
return oauth2.Client.do_access_token_refresh(self, cls, state,
return oauth2.Client.do_access_token_refresh(self, schema, state,
body_type, method,
request_args,
extra_args, http_args,
resp_cls, **kwargs)
resp_schema, **kwargs)
def do_registration_request(self, cls=RegistrationRequest, scope="",
state="", body_type="json", method="POST",
request_args=None, extra_args=None,
http_args=None, resp_cls=RegistrationResponse):
def do_registration_request(self, schema=SCHEMA["RegistrationRequest"],
scope="", state="", body_type="json",
method="POST", request_args=None,
extra_args=None, http_args=None,
resp_schema=SCHEMA["RegistrationResponse"]):
url, body, ht_args, csi = self.request_info(cls, method=method,
url, body, ht_args, csi = self.request_info(schema, method=method,
request_args=request_args,
extra_args=extra_args,
scope=scope, state=state)
@ -425,17 +445,18 @@ class Client(oauth2.Client):
else:
http_args.update(http_args)
return self.request_and_return(url, resp_cls, method, body,
body_type, extended=False,
state=state, http_args=http_args)
return self.request_and_return(url, resp_schema, method, body,
body_type, state=state,
http_args=http_args)
def do_check_session_request(self, cls=CheckSessionRequest, scope="",
def do_check_session_request(self, schema=SCHEMA["CheckSessionRequest"],
scope="",
state="", body_type="json", method="GET",
request_args=None, extra_args=None,
http_args=None,
resp_cls=IdToken):
resp_schema=SCHEMA["IdToken"]):
url, body, ht_args, csi = self.request_info(cls, method=method,
url, body, ht_args, csi = self.request_info(schema, method=method,
request_args=request_args,
extra_args=extra_args,
scope=scope, state=state)
@ -445,17 +466,17 @@ class Client(oauth2.Client):
else:
http_args.update(http_args)
return self.request_and_return(url, resp_cls, method, body,
body_type, extended=False,
state=state, http_args=http_args)
return self.request_and_return(url, resp_schema, method, body,
body_type, state=state,
http_args=http_args)
def do_check_id_request(self, cls=CheckIDRequest, scope="",
state="", body_type="json", method="GET",
request_args=None, extra_args=None,
http_args=None,
resp_cls=IdToken):
def do_check_id_request(self, schema=SCHEMA["CheckIDRequest"], scope="",
state="", body_type="json", method="GET",
request_args=None, extra_args=None,
http_args=None,
resp_schema=SCHEMA["IdToken"]):
url, body, ht_args, csi = self.request_info(cls, method=method,
url, body, ht_args, csi = self.request_info(schema, method=method,
request_args=request_args,
extra_args=extra_args,
scope=scope, state=state)
@ -465,16 +486,16 @@ class Client(oauth2.Client):
else:
http_args.update(http_args)
return self.request_and_return(url, resp_cls, method, body,
body_type, extended=False,
state=state, http_args=http_args)
return self.request_and_return(url, resp_schema, method, body,
body_type, state=state,
http_args=http_args)
def do_end_session_request(self, cls=EndSessionRequest, scope="",
state="", body_type="", method="GET",
request_args=None, extra_args=None,
http_args=None, resp_cls=None):
def do_end_session_request(self, schema=SCHEMA["EndSessionRequest"], scope="",
state="", body_type="", method="GET",
request_args=None, extra_args=None,
http_args=None, resp_schema=SCHEMA[""]):
url, body, ht_args, csi = self.request_info(cls, method=method,
url, body, ht_args, csi = self.request_info(schema, method=method,
request_args=request_args,
extra_args=extra_args,
scope=scope, state=state)
@ -484,15 +505,15 @@ class Client(oauth2.Client):
else:
http_args.update(http_args)
return self.request_and_return(url, resp_cls, method, body,
body_type, extended=False,
state=state, http_args=http_args)
return self.request_and_return(url, resp_schema, method, body,
body_type, state=state,
http_args=http_args)
def user_info_request(self, method="GET", state="", scope="", **kwargs):
uir = UserInfoRequest()
uir = message("UserInfoRequest")
if "token" in kwargs:
if kwargs["token"]:
uir.access_token = kwargs["token"]
uir["access_token"] = kwargs["token"]
token = Token()
token.type = "Bearer"
token.access_token = kwargs["token"]
@ -504,7 +525,7 @@ class Client(oauth2.Client):
token = self.grant[state].get_token(scope)
if token.is_valid():
uir.access_token = token.access_token
uir["access_token"] = token.access_token
else:
# raise oauth2.OldAccessToken
if self.log:
@ -512,12 +533,12 @@ class Client(oauth2.Client):
try:
self.do_access_token_refresh(token=token)
token = self.grant[state].get_token(scope)
uir.access_token = token.access_token
uir["access_token"] = token.access_token
except Exception:
raise
try:
uir.schema = kwargs["schema"]
uir["schema"] = kwargs["schema"]
except KeyError:
pass
@ -538,7 +559,7 @@ class Client(oauth2.Client):
kwargs["headers"] = [("Authorization", token.access_token)]
if not "token_in_message_body" in _behav:
# remove the token from the request
uir.access_token = None
uir["access_token"] = None
path, body, kwargs = self.get_or_post(uri, method, uir, **kwargs)
@ -551,7 +572,7 @@ class Client(oauth2.Client):
kwargs["schema"] = schema
path, body, method, h_args = self.user_info_request(method, state,
scope, **kwargs)
scope, **kwargs)
try:
response, content = self.http.request(path, method, body, **h_args)
@ -565,7 +586,7 @@ class Client(oauth2.Client):
else:
raise Exception("ERROR: Something went wrong [%s]" % response.status)
return OpenIDSchema.set_json(txt=content, extended=True)
return message("OpenIDSchema").from_json(txt=content)
def provider_config(self, issuer, keys=True, endpoints=True):
@ -578,21 +599,21 @@ class Client(oauth2.Client):
(response, content) = self.http.request(url)
if response.status == 200:
pcr = ProviderConfigurationResponse.from_json(content,
extended=True)
pcr = message("ProviderConfigurationResponse").from_json(content)
else:
raise Exception("%s" % response.status)
if pcr.issuer:
if pcr.issuer.endswith("/"):
_pcr_issuer = pcr.issuer[:-1]
if "issuer" in pcr:
if pcr["issuer"].endswith("/"):
_pcr_issuer = pcr["issuer"][:-1]
else:
_pcr_issuer = pcr.issuer
_pcr_issuer = pcr["issuer"]
try:
assert _issuer == _pcr_issuer
except AssertionError:
raise Exception("provider info issuer mismatch")
raise Exception("provider info issuer mismatch '%s' != '%s'" % (
_issuer, _pcr_issuer))
if endpoints:
for key, val in pcr.items():
@ -614,7 +635,7 @@ class Client(oauth2.Client):
keycol = self.keystore.pairkeys(csrc)["verify"]
info = json.loads(jwt.verify(str(spec["JWT"]), keycol))
attr = [n for n, s in userinfo._claim_names.items() if s ==
csrc]
csrc]
assert attr == info.keys()
for key, vals in info.items():
@ -629,14 +650,14 @@ class Client(oauth2.Client):
if "access_token" in spec:
_uinfo = self.do_user_info_request(
token=spec["access_token"],
userinfo_endpoint=spec["endpoint"])
token=spec["access_token"],
userinfo_endpoint=spec["endpoint"])
else:
_uinfo = self.do_user_info_request(token=callback(csrc),
userinfo_endpoint=spec["endpoint"])
userinfo_endpoint=spec["endpoint"])
attr = [n for n, s in userinfo._claim_names.items() if s ==
csrc]
csrc]
assert attr == _uinfo.keys()
for key, vals in _uinfo.items():
@ -662,30 +683,27 @@ class Server(oauth2.Server):
return urlparse.parse_qs(query)
def parse_token_request(self, cls=AccessTokenRequest, body=None,
extended=False):
return oauth2.Server.parse_token_request(self, cls, body, extended)
def parse_token_request(self, msg="AccessTokenRequest", body=None):
return oauth2.Server.parse_token_request(self, msg, body)
def parse_authorization_request(self, rcls=AuthorizationRequest,
url=None, query=None, extended=False):
return oauth2.Server.parse_authorization_request(self, rcls, url,
query, extended)
def parse_authorization_request(self, msg="AuthorizationRequest",
url=None, query=None):
return oauth2.Server.parse_authorization_request(self, msg, url, query)
def parse_jwt_request(self, rcls=AuthorizationRequest, txt="",
keys=None, verify=True, extended=False):
def parse_jwt_request(self, msg="AuthorizationRequest", txt="",
keys=None, verify=True):
return oauth2.Server.parse_jwt_request(self, rcls, txt,
keys, verify, extended)
return oauth2.Server.parse_jwt_request(self, msg, txt,
keys, verify)
def parse_refresh_token_request(self, cls=RefreshAccessTokenRequest,
body=None, extended=False):
return oauth2.Server.parse_refresh_token_request(self, cls, body,
extended)
def parse_refresh_token_request(self, msg="RefreshAccessTokenRequest",
body=None):
return oauth2.Server.parse_refresh_token_request(self, msg, body)
def _deser_id_token(self, str=""):
if not str:
return None
# have to start decoding the jwt without verifying in order to find
# out which key to verify the JWT signature with
_ = json.loads(jwt.unpack(str)[1])
@ -695,7 +713,7 @@ class Server(oauth2.Server):
keys = self.keystore.get_keys("verify", owner=None)
return IdToken.set_jwt(str, key=keys)
return message("IdToken").from_jwt(str, key=keys)
def parse_check_session_request(self, url=None, query=None):
"""
@ -713,46 +731,44 @@ class Server(oauth2.Server):
assert "access_token" in param # ignore the rest
return self._deser_id_token(param["access_token"][0])
def _parse_request(self, cls, data, format, extended):
def _parse_request(self, schema, data, format):
if format == "json":
request = cls.set_json(data, extended)
request = message_from_schema(schema).from_json(data)
elif format == "urlencoded":
if '?' in data:
parts = urlparse.urlparse(data)
scheme, netloc, path, params, query, fragment = parts[:6]
else:
query = data
request = cls.set_urlencoded(query, extended)
request = message_from_schema(schema).from_urlencoded(query)
else:
raise Exception("Unknown package format: '%s'" % format)
request.verify()
return request
def parse_open_id_request(self, data, format="urlencoded", extended=False):
return self._parse_request(OpenIDRequest, data, format, extended)
def parse_user_info_request(self, data, format="urlencoded", extended=False):
return self._parse_request(UserInfoRequest, data, format, extended)
def parse_open_id_request(self, data, format="urlencoded"):
return self._parse_request(SCHEMA["OpenIDRequest"], data, format)
def parse_refresh_session_request(self, url=None, query=None,
extended=False):
def parse_user_info_request(self, data, format="urlencoded"):
return self._parse_request(SCHEMA["UserInfoRequest"], data, format)
def parse_refresh_session_request(self, url=None, query=None):
if url:
parts = urlparse.urlparse(url)
scheme, netloc, path, params, query, fragment = parts[:6]
return RefreshSessionRequest.set_urlencoded(query, extended)
return message("RefreshSessionRequest").from_urlencoded(query)
def parse_registration_request(self, data, format="urlencoded",
extended=True):
return self._parse_request(RegistrationRequest, data, format, extended)
def parse_registration_request(self, data, format="urlencoded"):
return self._parse_request(SCHEMA["RegistrationRequest"], data, format)
def parse_end_session_request(self, query, extended=True):
esr = EndSessionRequest.set_urlencoded(query, extended)
def parse_end_session_request(self, query):
esr = message("EndSessionRequest").from_urlencoded(query)
# if there is a id_token in there it is as a string
esr.id_token = self._deser_id_token(esr.id_token)
esr["id_token"] = self._deser_id_token(esr["id_token"])
return esr
def parse_issuer_request(self, info, format="urlencoded", extended=True):
return self._parse_request(IssuerRequest, info, format, extended)
def parse_issuer_request(self, info, format="urlencoded"):
return self._parse_request(SCHEMA["IssuerRequest"], info, format)

View File

@ -1,83 +1,65 @@
__author__ = 'rohe0002'
from oic import oauth2
from oic.oic import Server as OicServer
from oic.oic import Client
from oic.oic.provider import Provider, get_or_post, Endpoint
from oic.oauth2.message import SINGLE_REQUIRED_STRING
from oic.oauth2.message import SINGLE_OPTIONAL_STRING
from oic.oauth2.message import REQUIRED_LIST_OF_STRINGS
from oic.oauth2.message import ErrorResponse
from oic.utils.http_util import Response, Unauthorized
from oic.oic import REQUEST2ENDPOINT
from oic.oic import RESPONSE2ERROR
#from oic.oic.message import IdToken, OpenIDSchema
from oic.oic.message import Claims, TokenErrorResponse
from oic.oic.message import OpenIDSchema
from oic.oic.message import UserInfoClaim
from oic.oic.message import message
#from oic.oic.message import SCHEMA
from oic.oic.provider import Provider, get_or_post, Endpoint
from oic.oauth2.message import SINGLE_REQUIRED_STRING, Message
from oic.oauth2.message import SINGLE_OPTIONAL_STRING
from oic.oauth2.message import REQUIRED_LIST_OF_STRINGS
from oic.utils.http_util import Response, Unauthorized
# Used in claims.py
from oic.oic.message import RegistrationRequest
from oic.oic.message import RegistrationResponse
#from oic.oic.message import RegistrationRequest
#from oic.oic.message import RegistrationResponse
class UserClaimsRequest(oauth2.Base):
c_attributes = oauth2.Base.c_attributes.copy()
c_attributes["user_id"] = SINGLE_REQUIRED_STRING
c_attributes["client_id"] = SINGLE_REQUIRED_STRING
c_attributes["client_secret"] = SINGLE_REQUIRED_STRING
c_attributes["claims_names"] = REQUIRED_LIST_OF_STRINGS
def verify(self, **kwargs):
if self.jwt:
# Try to decode the JWT, checks the signature
try:
item = message("OpenIDSchema").set_jwt(str(self.jwt),
kwargs["key"])
except Exception, _err:
raise Exception(_err.__class__.__name__)
def __init__(self,
user_id=None,
client_id=None,
client_secret=None,
claims_names=None,
**kwargs):
oauth2.Base.__init__(self, **kwargs)
self.user_id = user_id
self.client_id = client_id
self.client_secret = client_secret
self.claims_names = claims_names
if not item.verify(**kwargs):
return False
class UserClaimsResponse(oauth2.Base):
c_attributes = oauth2.Base.c_attributes.copy()
c_attributes["claims_names"] = REQUIRED_LIST_OF_STRINGS
c_attributes["jwt"] = SINGLE_OPTIONAL_STRING
c_attributes["endpoint"] = SINGLE_OPTIONAL_STRING
c_attributes["access_token"] = SINGLE_OPTIONAL_STRING
return super(self.__class__, self).verify(**kwargs)
def __init__(self,
jwt=None,
claims_names=None,
endpoint=None,
access_token=None,
**kwargs):
oauth2.Base.__init__(self, **kwargs)
self.jwt = jwt
self.claims_names = claims_names
self.endpoint = endpoint
self.access_token = access_token
def verify(self, **kwargs):
if self.jwt:
# Try to decode the JWT, checks the signature
try:
item = OpenIDSchema.set_jwt(str(self.jwt), kwargs["key"])
except Exception, _err:
raise Exception(_err.__class__.__name__)
if not item.verify(**kwargs):
return False
return oauth2.Base.verify(self, **kwargs)
SCHEMA = {
"": {"param": {}},
"UserClaimsRequest": {
"param": {
"user_id": SINGLE_REQUIRED_STRING,
"client_id": SINGLE_REQUIRED_STRING,
"client_secret": SINGLE_REQUIRED_STRING,
"claims_names": REQUIRED_LIST_OF_STRINGS
},
},
"UserClaimsResponse": {
"param": {
"claims_names": REQUIRED_LIST_OF_STRINGS,
"jwt": SINGLE_OPTIONAL_STRING,
"endpoint": SINGLE_OPTIONAL_STRING,
"access_token": SINGLE_OPTIONAL_STRING
},
"verify": verify,
},
}
class OICCServer(OicServer):
def parse_user_claims_request(self, info, format="urlencoded",
extended=True):
return self._parse_request(UserClaimsRequest, info, format, extended)
def parse_user_claims_request(self, info, format="urlencoded"):
return self._parse_request(SCHEMA["UserClaimsRequest"], info, format)
class ClaimsServer(Provider):
@ -85,8 +67,8 @@ class ClaimsServer(Provider):
debug=0, cache=None, timeout=None, proxy_info=None,
follow_redirects=True, ca_certs="", jwt_keys=None):
Provider.__init__(self, name, sdb, cdb, function, userdb, urlmap,
debug, cache, timeout, proxy_info,
follow_redirects, ca_certs, jwt_keys)
debug, cache, timeout, proxy_info,
follow_redirects, ca_certs, jwt_keys)
if jwt_keys is None:
jwt_keys = []
@ -102,19 +84,16 @@ class ClaimsServer(Provider):
def _aggregation(self, info, logger):
jwt_key = self.keystore.get_sign_key()
cresp = UserClaimsResponse(jwt=info.get_jwt(key=jwt_key,
algorithm="RS256"),
claims_names=info.keys())
cresp = Message("UserClaimsResponse", SCHEMA["UserClaimsResponse"],
jwt=info.get_jwt(key=jwt_key, algorithm="RS256"),
claims_names=info.keys())
logger.info("RESPONSE: %s" % (cresp.dictionary(),))
logger.info("RESPONSE: %s" % (cresp.to_dict(),))
return cresp
#noinspection PyUnusedLocal
def _distributed(self, ucreq, logger):
cresp = UserClaimsResponse()
cresp.endpoint = ""
cresp.access_token = ""
return cresp
return Message("UserClaimsResponse", SCHEMA["UserClaimsResponse"])
#noinspection PyUnusedLocal
def do_aggregation(self, info, uid):
@ -134,13 +113,13 @@ class ClaimsServer(Provider):
if not self.function["verify_client"](environ, ucreq, self.cdb):
_log_info("could not verify client")
err = TokenErrorResponse(error="unathorized_client")
err = message("TokenErrorResponse", error="unathorized_client")
resp = Unauthorized(err.get_json(), content="application/json")
return resp(environ, start_response)
if ucreq.claims_names:
args = dict([(n, {"optional": True}) for n in ucreq.claims_names])
uic = UserInfoClaim(claims=Claims(**args))
uic = message("UserInfoClaim", claims=message("Claims", **args))
else:
uic = None
@ -156,7 +135,7 @@ class ClaimsServer(Provider):
else:
cresp = self._distributed(info, logger)
resp = Response(cresp.get_json(), content="application/json")
resp = Response(cresp.to_json(), content="application/json")
return resp(environ, start_response)
class ClaimsClient(Client):
@ -176,22 +155,23 @@ class ClaimsClient(Client):
self.request2endpoint = REQUEST2ENDPOINT.copy()
self.request2endpoint["UserClaimsRequest"] = "userclaims_endpoint"
self.response2error = RESPONSE2ERROR.copy()
self.response2error[UserClaimsResponse] = [ErrorResponse]
self.response2error["UserClaimsResponse"] = ["ErrorResponse"]
#noinspection PyUnusedLocal
def construct_UserClaimsRequest(self, cls=UserClaimsRequest,
request_args=None, extra_args=None,
**kwargs):
def construct_UserClaimsRequest(self, schema=SCHEMA["UserClaimsRequest"],
request_args=None, extra_args=None,
**kwargs):
return self.construct_request(cls, request_args, extra_args)
return self.construct_request(schema, request_args, extra_args)
def do_claims_request(self, cls=UserClaimsRequest,
resp_cls=UserClaimsResponse, body_type="json",
def do_claims_request(self, schema=SCHEMA["UserClaimsRequest"],
resp_schema=SCHEMA["UserClaimsResponse"],
body_type="json",
method="POST", request_args=None, extra_args=None,
http_args=None):
url, body, ht_args, csi = self.request_info(cls, method=method,
url, body, ht_args, csi = self.request_info(schema, method=method,
request_args=request_args,
extra_args=extra_args)
@ -200,11 +180,11 @@ class ClaimsClient(Client):
else:
http_args.update(http_args)
return self.request_and_return(url, resp_cls, method, body,
return self.request_and_return(url, resp_schema, method, body,
body_type, extended=False,
http_args=http_args,
key=self.keystore.pairkeys(
self.keystore.match_owner(url)))
self.keystore.match_owner(url)))
class UserClaimsEndpoint(Endpoint) :
type = "userclaims"

View File

@ -10,26 +10,14 @@ import httplib2
from hashlib import md5
from oic.utils import http_util
#from oic.utils.time_util import time_sans_frac
from oic.oic import Client
from oic.oic import ENDPOINTS
from oic.oic.message import AuthorizationRequest
#from oic.oic.message import IDTokenClaim
from oic.oic.message import UserInfoClaim
from oic.oic.message import Claims
#from oic.oic.message import OpenIDRequest
from oic.oic.message import AuthorizationResponse
from oic.oic.message import AccessTokenResponse
#from oic.oic.message import ProviderConfigurationResponse
from oic.oic.message import RegistrationRequest
from oic.oic.message import RegistrationResponse
from oic.oic.message import IssuerRequest
from oic.oic.message import IssuerResponse
from oic.oauth2.message import ErrorResponse
from oic.oic.message import SCHEMA, msg_deser
from oic.oic.message import message
from oic.oauth2 import Grant
#from oic.oauth2 import DEF_SIGN_ALG
from oic.oauth2 import rndstr
from oic.oauth2.consumer import TokenError
@ -85,9 +73,9 @@ def build_userinfo_claims(claims, format="signed", locale="us-en"):
}
}
"""
claim = Claims(**claims)
claim = message("Claims", **claims)
return UserInfoClaim(claim, format=format, locale=locale)
return message("UserInfoClaim", claims=claim, format=format, locale=locale)
#def construct_openid_request(arq, keys, algorithm=DEF_SIGN_ALG, iss=None,
@ -127,9 +115,12 @@ def clean_response(aresp):
:param aresp: The original AccessTokenResponse
:return: An AccessTokenResponse instance
"""
atr = AccessTokenResponse()
for prop in AccessTokenResponse.c_attributes.keys():
setattr(atr, prop, getattr(aresp, prop))
atr = message("AccessTokenResponse")
for prop in atr.parameters():
try:
atr[prop] = aresp[prop]
except KeyError:
pass
return atr
@ -233,7 +224,17 @@ class Consumer(Client):
_log_info("- begin -")
_path = http_util.geturl(environ, False, False)
self.redirect_uris = [_path + self.config["authz_page"]]
_page = self.config["authz_page"]
if not _path.endswith("/"):
if _page.startswith("/"):
self.redirect_uris = [_path + _page]
else:
self.redirect_uris = ["%s/%s" % (_path, _page)]
else:
if _page.startswith("/"):
self.redirect_uris = [_path + _page[1:]]
else:
self.redirect_uris = ["%s/%s" % (_path, _page)]
# Put myself in the dictionary of sessions, keyed on session-id
if not self.seed:
@ -274,8 +275,8 @@ class Consumer(Client):
extra_args=None)
if self.config["request_method"] == "file":
id_request = areq.request
areq.request = None
id_request = areq["request"]
del areq["request"]
_filedir = self.config["temp_dir"]
_webpath = self.config["temp_path"]
_name = rndstr(10)
@ -287,16 +288,16 @@ class Consumer(Client):
fid.write(id_request)
fid.close()
_webname = "%s%s%s" % (_path,_webpath,_name)
areq.request_uri = _webname
areq["request_uri"] = _webname
self.request_uri = _webname
self._backup(sid)
else:
areq = self.construct_AuthorizationRequest(AuthorizationRequest,
request_args=args)
areq = self.construct_AuthorizationRequest(
SCHEMA["AuthorizationRequest"],
request_args=args)
location = "%s?%s" % (self.authorization_endpoint,
areq.get_urlencoded())
location = areq.request(self.authorization_endpoint)
if self.debug:
_log_info("Redirecting to: %s" % location)
@ -335,39 +336,41 @@ class Consumer(Client):
if "code" in self.config["response_type"]:
# Might be an error response
_log_info("Expect Authorization Response")
aresp = self.parse_response(AuthorizationResponse, info=_query,
aresp = self.parse_response(SCHEMA["AuthorizationResponse"],
info=_query,
format="urlencoded")
if isinstance(aresp, ErrorResponse):
if aresp.type() == "ErrorResponse":
_log_info("ErrorResponse: %s" % aresp)
raise AuthzError(aresp.error)
_log_info("Aresp: %s" % aresp)
_state = aresp["state"]
try:
self.update(aresp.state)
self.update(_state)
except KeyError:
raise UnknownState(aresp.state)
raise UnknownState(_state)
self.redirect_uris = [self.sdb[aresp.state]["redirect_uris"]]
self.redirect_uris = [self.sdb[_state]["redirect_uris"]]
# May have token and id_token information too
if aresp.access_token:
if "access_token" in aresp:
atr = clean_response(aresp)
self.access_token = atr
# update the grant object
self.get_grant(state=aresp.state).add_token(atr)
self.get_grant(state=_state).add_token(atr)
else:
atr = None
self._backup(aresp.state)
self._backup(_state)
idt = None
return aresp, atr, idt
else: # implicit flow
_log_info("Expect Access Token Response")
atr = self.parse_response(AccessTokenResponse, info=_query,
format="urlencoded", extended=True)
if isinstance(atr, ErrorResponse):
atr = self.parse_response(SCHEMA["AccessTokenResponse"],
info=_query, format="urlencoded")
if atr.type() == "ErrorResponse":
raise TokenError(atr.error)
idt = None
@ -397,7 +400,7 @@ class Consumer(Client):
logger.info("Access Token Response: %s" % resp)
if isinstance(resp, ErrorResponse):
if resp.type() == "ErrorResponse":
raise TokenError(resp.error)
#self._backup(self.sdb["seed:%s" % _cli.seed])
@ -413,7 +416,7 @@ class Consumer(Client):
self.log = logger
uinfo = self.do_user_info_request(state=self.state)
if isinstance(uinfo, ErrorResponse):
if uinfo.type() == "ErrorResponse":
raise TokenError(uinfo.error)
self.user_info = uinfo
@ -441,10 +444,11 @@ class Consumer(Client):
raise
if response.status == 200:
result = IssuerResponse.from_json(content)
if result.SWD_service_redirect:
_loc = result.SWD_service_redirect.location
_uri = IssuerRequest(ISSUER_URL, principal).request(_loc)
result = msg_deser(content, "json", "IssuerResponse")
if "SWD_service_redirect" in result:
_loc = result["SWD_service_redirect"]["location"]
_uri = message("IssuerRequest", service=ISSUER_URL,
principal=principal).request(_loc)
return self.discovery_query(_uri, principal)
else:
return result
@ -463,32 +467,35 @@ class Consumer(Client):
def discover(self, principal, idtype="mail"):
_loc = SWD_PATTERN % self.get_domain(principal, idtype)
uri = IssuerRequest(ISSUER_URL, principal).request(_loc)
uri = message("IssuerRequest", service=ISSUER_URL,
principal=principal).request(_loc)
result = self.discovery_query(uri, principal)
return result.locations[0]
return result["locations"][0]
def register(self, server, type="client_associate", **kwargs):
req = RegistrationRequest(type=type)
req = message("RegistrationRequest", type=type)
if type == "client_update":
req.client_id = self.client_id
req.client_secret = self.client_secret
req["client_id"] = self.client_id
req["client_secret"] = self.client_secret
for prop in RegistrationRequest.c_attributes.keys():
for prop in req.parameters():
if prop in ["type", "client_id", "client_secret"]:
continue
try:
val = getattr(self, prop)
if val:
setattr(req, prop, val)
req[prop] = val
except Exception:
val = None
if not val:
if prop in kwargs:
setattr(req, prop, kwargs[prop])
try:
req[prop] = kwargs[prop]
except KeyError:
pass
headers = {"content-type": "application/x-www-form-urlencoded"}
(response, content) = self.http.request(server, "POST",
@ -496,12 +503,12 @@ class Consumer(Client):
headers=headers)
if response.status == 200:
resp = RegistrationResponse.from_json(content)
self.client_secret = resp.client_secret
self.client_id = resp.client_id
self.registration_expires = resp.expires_at
resp = msg_deser(content, "json", "RegistrationResponse")
self.client_secret = resp["client_secret"]
self.client_id = resp["client_id"]
self.registration_expires = resp["expires_at"]
else:
err = ErrorResponse.from_json(content)
err = msg_deser(content, "json", "ErrorResponse")
raise Exception("Registration failed: %s" % err.get_json())
return resp

File diff suppressed because it is too large Load Diff

View File

@ -3,11 +3,8 @@
__author__ = 'rohe0002'
import random
#import httplib2
import base64
#from random import SystemRandom
from urlparse import parse_qs
from oic.oauth2.provider import Provider as AProvider
@ -16,27 +13,17 @@ from oic.utils.http_util import *
from oic.utils import time_util
from oic.oauth2 import MissingRequiredAttribute
from oic.oauth2.provider import AuthnFailure
from oic.oauth2 import rndstr
from oic.oauth2.message import ErrorResponse
from oic.oauth2.provider import AuthnFailure
from oic.oauth2.message import by_schema
from oic.oauth2.message import SCHEMA as OA2_SCHEMA
from oic.oic import Server
from oic.oic.message import AuthorizationResponse, AuthnToken
from oic.oic.message import AuthorizationErrorResponse
from oic.oic.message import msg_deser
from oic.oic.message import SCOPE2CLAIMS
from oic.oic.message import AuthorizationRequest
from oic.oic.message import AccessTokenResponse
from oic.oic.message import AccessTokenRequest
from oic.oic.message import TokenErrorResponse
from oic.oic.message import OpenIDRequest
from oic.oic.message import IdToken
from oic.oic.message import RegistrationRequest
from oic.oic.message import RegistrationResponse
from oic.oic.message import ProviderConfigurationResponse
from oic.oic.message import UserInfoClaim
from oic import oauth2
from oic.oic.message import message
from oic.oic.message import SCHEMA
from oic.oic import JWT_BEARER
class OICError(Exception):
@ -93,53 +80,34 @@ def secret(seed, id):
csum.update(id)
return csum.hexdigest()
##noinspection PyUnusedLocal
#def code_response(**kwargs):
# _areq = kwargs["areq"]
# _scode = kwargs["scode"]
# aresp = AuthorizationResponse()
# if _areq.state:
# aresp.state = _areq.state
# if _areq.nonce:
# aresp.nonce = _areq.nonce
# aresp.code = _scode
# return aresp
#
#def token_response(**kwargs):
# _areq = kwargs["areq"]
# _scode = kwargs["scode"]
# _sdb = kwargs["sdb"]
# _dic = _sdb.update_to_token(_scode, issue_refresh=False)
#
# aresp = oauth2.factory(AccessTokenResponse, **_dic)
# if _areq.scope:
# aresp.scope = _areq.scope
# return aresp
def add_token_info(aresp, sdict):
for prop in AccessTokenResponse.c_attributes.keys():
try:
if sdict[prop]:
setattr(aresp, prop, sdict[prop])
except KeyError:
pass
#def update_info(aresp, sdict):
# for prop in aresp._schema["param"].keys():
# try:
# aresp[prop] = sdict[prop]
# except KeyError:
# pass
def code_token_response(**kwargs):
_areq = kwargs["areq"]
_scode = kwargs["scode"]
_sdb = kwargs["sdb"]
aresp = AuthorizationResponse()
if _areq.state:
aresp.state = _areq.state
if _areq.nonce:
aresp.nonce = _areq.nonce
if _areq.scope:
aresp.scope = _areq.scope
aresp.code = _scode
aresp = message("AuthorizationResponse")
for key in ["state", "nonce", "scope"]:
try:
aresp[key] = _areq[key]
except KeyError:
pass
aresp["code"] = _scode
_dic = _sdb.update_to_token(_scode, issue_refresh=False)
add_token_info(aresp, _dic)
for prop in SCHEMA["AccessTokenResponse"]["param"].keys():
try:
aresp[prop] = _dic[prop]
except KeyError:
pass
return aresp
@ -170,8 +138,6 @@ def verify_acr_level(req, level):
raise AccessDenied
class Provider(AProvider):
authorization_request = AuthorizationRequest
def __init__(self, name, sdb, cdb, function, userdb, urlmap=None,
debug=0, cache=None, timeout=None, proxy_info=None,
follow_redirects=True, ca_certs="", jwt_keys=None):
@ -204,13 +170,15 @@ class Provider(AProvider):
# Handle the idtoken_claims
extra = {}
try:
oidreq = OpenIDRequest.from_json(session["oidreq"])
itc = oidreq.id_token
info_log("ID Token claims: %s" % itc.dictionary())
if itc.max_age:
inawhile = {"seconds": itc.max_age}
if itc.claims:
for key, val in itc.claims.items():
oidreq = msg_deser(session["oidreq"], "json", "OpenIDRequest")
itc = oidreq["id_token"]
info_log("ID Token claims: %s" % itc.to_dict())
try:
inawhile = {"seconds": itc["max_age"]}
except KeyError:
inawhile = {}
if "claims" in itc:
for key, val in itc["claims"].items():
if key == "auth_time":
extra["auth_time"] = time_util.utc_time_sans_frac()
elif key == "acr":
@ -219,14 +187,12 @@ class Provider(AProvider):
except KeyError:
pass
idt = IdToken(iss=self.name,
user_id=session["user_id"],
aud = session["client_id"],
exp = time_util.epoch_in_a_while(**inawhile),
acr=loa,
)
idt = message("IdToken", iss=self.name, user_id=session["user_id"],
aud = session["client_id"],
exp = time_util.epoch_in_a_while(**inawhile), acr=loa)
for key, val in extra.items():
setattr(idt, key, val)
idt[key] = val
if "nonce" in session:
idt.nonce = session["nonce"]
@ -241,17 +207,19 @@ class Provider(AProvider):
if info_log:
info_log("Sign idtoken with '%s'" % ckey)
return idt.get_jwt(key=ckey)
return idt.to_jwt(key=ckey)
def _error(self, environ, start_response, error, descr=None):
response = ErrorResponse(error=error, error_description=descr)
resp = Response(response.get_json(), content="application/json")
response = message(OA2_SCHEMA["ErrorResponse"], error=error,
error_description=descr)
resp = Response(response.to_json(), content="application/json")
return resp(environ, start_response)
def _authz_error(self, environ, start_response, error, descr=None):
response = AuthorizationErrorResponse(error=error,
error_description=descr)
resp = Response(response.get_json(), content="application/json")
response = message("AuthorizationErrorResponse", error=error,
error_description=descr)
resp = Response(response.to_json(), content="application/json")
return resp(environ, start_response)
def authorization_endpoint(self, environ, start_response, logger,
@ -276,8 +244,7 @@ class Provider(AProvider):
# Same serialization used for GET and POST
try:
areq = self.server.parse_authorization_request(query=query,
extended=True)
areq = self.server.parse_authorization_request(query=query)
except MissingRequiredAttribute, err:
resp = BadRequest("%s" % err)
return resp(environ, start_response)
@ -285,54 +252,64 @@ class Provider(AProvider):
resp = BadRequest("%s" % err)
return resp(environ, start_response)
if self.debug:
_log_info("Prompt: '%s'" % areq.prompt)
if "none" in areq.prompt:
if len(areq.prompt) > 1:
return self._error(environ, start_response, "invalid_request")
else:
return self._authz_error(environ, start_response,
"login_required")
if "prompt" in areq:
if self.debug:
_log_info("Prompt: '%s'" % areq["prompt"])
if "none" in areq["prompt"]:
if len(areq["prompt"]) > 1:
return self._error(environ, start_response,
"invalid_request")
else:
return self._authz_error(environ, start_response,
"login_required")
if areq.client_id not in self.cdb:
raise UnknownClient(areq.client_id)
if areq["client_id"] not in self.cdb:
raise UnknownClient(areq["client_id"])
# verify that the redirect URI is resonable
if areq.redirect_uri:
if "redirect_uri" in areq:
try:
assert areq.redirect_uri in self.cdb[
areq.client_id]["redirect_uris"]
assert areq["redirect_uri"] in self.cdb[
areq["client_id"]]["redirect_uris"]
except AssertionError:
return self._authz_error(environ, start_response,
"invalid_request_redirect_uri")
if self.debug:
_log_info("AREQ keys: %s" % areq.keys())
# Is there an request decode it
openid_req = None
if "request" in areq or "request_uri" in areq:
if self.debug:
_log_info("OpenID request")
try:
_keystore = self.server.keystore
jwt_key = _keystore.get_keys("verify", owner=None)
except KeyError: # TODO
raise KeyError("Missing verifying key")
if areq.request:
if "request" in areq:
try:
openid_req = OpenIDRequest.set_jwt(areq.request, jwt_key)
openid_req = message("OpenIDRequest").from_jwt(
areq["request"],
jwt_key)
except Exception:
return self._authz_error(environ, start_response,
"invalid_openid_request_object")
elif areq.request_uri:
elif "request_uri" in areq:
# Do a HTTP get
_req = self.http.request(areq.request_uri)
_req = self.http.request(areq["request_uri"])
if not _req:
return self._authz_error(environ, start_response,
"invalid_request_uri")
try:
openid_req = OpenIDRequest.set_jwt(_req, jwt_key)
openid_req = message("OpenIDRequest").from_jwt(_req,
jwt_key)
except Exception:
return self._authz_error(environ, start_response,
"invalid_openid_request_object")
@ -346,10 +323,10 @@ class Provider(AProvider):
_log_info("SID:%s" % bsid)
if openid_req:
_max_age = -1
if openid_req.id_token:
if openid_req.id_token.max_age:
_max_age = openid_req.id_token.max_age
try:
_max_age = openid_req["id_token"]["max_age"]
except KeyError:
_max_age = -1
if _max_age >= 0:
if "handle" in kwargs:
@ -365,6 +342,19 @@ class Provider(AProvider):
areq=areq, user=user)
except ValueError:
pass
else:
if "handle" in kwargs:
try:
(b64sid, timestamp) = kwargs["handle"]
_log_info("- SSO -")
_scode = base64.b64decode(b64sid)
user = self.sdb[_scode]["user_id"]
_sdb.update(sid, "user_id", user)
return self.authenticated(environ, start_response,
logger, active_auth=bsid,
areq=areq, user=user)
except ValueError:
pass
# DEFAULT: start the authentication process
return self.function["authenticate"](environ, start_response, bsid)
@ -377,32 +367,32 @@ class Provider(AProvider):
except AuthnFailure:
pass
if areq.client_id not in self.cdb:
if areq["client_id"] not in self.cdb:
return False
if areq.client_secret: # client_secret_post
identity = areq.client_id
if self.cdb[identity]["client_secret"] == areq.client_secret:
if "client_secret" in areq: # client_secret_post
identity = areq["client_id"]
if self.cdb[identity]["client_secret"] == areq["client_secret"]:
return True
elif areq.client_assertion: # client_secret_jwt or public_key_jwt
if areq.client_assertion_type != JWT_BEARER:
elif "client_assertion" in areq: # client_secret_jwt or public_key_jwt
if areq["client_assertion_type"] != JWT_BEARER:
return False
key_col = {areq.client_id:
self.keystore.get_verify_key(owner=areq.client_id)}
key_col = {areq["client_id"]:
self.keystore.get_verify_key(owner=areq["client_id"])}
key_col.update({".":self.keystore.get_verify_key()})
if log_info:
log_info("key_col: %s" % (key_col,))
bjwt = AuthnToken.set_jwt(areq.client_assertion, key_col)
bjwt = message("AuthnToken").from_jwt(areq["client_assertion"],
key_col)
try:
assert bjwt.iss == areq.client_id # Issuer = the client
assert bjwt["iss"] == areq["client_id"] # Issuer = the client
# Is this true bjwt.iss == areq.client_id
assert str(bjwt.iss) in self.cdb # It's a client I know
assert str(bjwt.aud) == geturl(environ,
query=False) # audience = me
assert str(bjwt["iss"]) in self.cdb # It's a client I know
assert str(bjwt["aud"]) == geturl(environ, query=False)
return True
except AssertionError:
pass
@ -424,29 +414,29 @@ class Provider(AProvider):
if self.debug:
_log_info("body: %s" % body)
areq = AccessTokenRequest.set_urlencoded(body, extended=True)
areq = msg_deser(body, "urlencoded", "AccessTokenRequest")
if self.debug:
_log_info("environ: %s" % environ)
if not self.verify_client(environ, areq, _log_info):
_log_info("could not verify client")
err = TokenErrorResponse(error="unathorized_client")
resp = Unauthorized(err.get_json(), content="application/json")
err = message("TokenErrorResponse", error="unathorized_client")
resp = Unauthorized(err.to_json(), content="application/json")
return resp(environ, start_response)
if self.debug:
_log_info("AccessTokenRequest: %s" % areq)
assert areq.grant_type == "authorization_code"
assert areq["grant_type"] == "authorization_code"
# assert that the code is valid
_info = _sdb[areq.code]
_info = _sdb[areq["code"]]
# If redirect_uri was in the initial authorization request
# verify that the one given here is the correct one.
if "redirect_uri" in _info:
assert areq.redirect_uri == _info["redirect_uri"]
assert areq["redirect_uri"] == _info["redirect_uri"]
if self.debug:
_log_info("All checks OK")
@ -461,7 +451,7 @@ class Provider(AProvider):
_idtoken = None
try:
_tinfo = _sdb.update_to_token(areq.code, id_token=_idtoken)
_tinfo = _sdb.update_to_token(areq["code"], id_token=_idtoken)
except Exception,err:
_log_info("Error: %s" % err)
raise
@ -469,7 +459,8 @@ class Provider(AProvider):
if self.debug:
_log_info("_tinfo: %s" % _tinfo)
atr = oauth2.factory(AccessTokenResponse, **_tinfo)
atr = message("AccessTokenResponse",
**by_schema(SCHEMA["AccessTokenResponse"], **_tinfo))
if self.debug:
_log_info("AccessTokenResponse: %s" % atr)
@ -510,7 +501,7 @@ class Provider(AProvider):
else:
uireq = self.server.parse_user_info_request(data=query)
_log_info("user_info_request: %s" % uireq)
_token = uireq.access_token
_token = uireq["access_token"]
# should be an access token
typ, key = self.sdb.token.type_and_key(_token)
@ -534,23 +525,25 @@ class Provider(AProvider):
except KeyError:
pass
try:
_req = session["oidreq"]
_log_info("OIDREQ: %s" % _req)
oidreq = OpenIDRequest.from_json(_req)
userinfo_claims = oidreq.userinfo
if userinfo_claims:
_claim = userinfo_claims.claims
if "oidreq" in session:
oidreq = msg_deser(session["oidreq"], "json", "OpenIDRequest")
_log_info("OIDREQ: %s" % oidreq.to_dict())
if "userinfo" in oidreq:
userinfo_claims = oidreq["userinfo"]
_claim = oidreq["userinfo"]["claims"]
for key, val in uic.items():
if key not in _claim:
setattr(_claim, key, val)
except KeyError:
if uic:
userinfo_claims = UserInfoClaim(claims=uic)
_claim[key] = val
elif uic:
userinfo_claims = message("UserInfoClaim", claims=uic)
else:
userinfo_claims = None
userinfo_claims = None
elif uic:
userinfo_claims = message("UserInfoClaim", claims=uic)
else:
userinfo_claims = None
_log_info("userinfo_claim: %s" % userinfo_claims)
_log_info("userinfo_claim: %s" % userinfo_claims.to_dict())
_log_info("userdb: %s" % self.userdb.keys())
#logger.info("oidreq: %s[%s]" % (oidreq, type(oidreq)))
info = self.function["userinfo"](self, self.userdb,
@ -559,7 +552,7 @@ class Provider(AProvider):
userinfo_claims)
_log_info("info: %s" % (info,))
resp = Response(info.get_json(), content="application/json")
resp = Response(info.to_json(), content="application/json")
return resp(environ, start_response)
#noinspection PyUnusedLocal
@ -576,7 +569,7 @@ class Provider(AProvider):
idt = self.server.parse_check_session_request(query=info)
resp = Response(idt.get_json(), content="application/json")
resp = Response(idt.to_json(), content="application/json")
return resp(environ, start_response)
#noinspection PyUnusedLocal
@ -588,12 +581,17 @@ class Provider(AProvider):
resp = BadRequest("Unsupported method")
return resp(environ, start_response)
logger.info("environ: %s" % environ)
logger.info("info: '%s'" % info)
if not info:
logger.info("HTTP_AUTHORIZATION: %s" %
environ["HTTP_AUTHORIZATION"])
info = "access_token=%s" % self._bearer_auth(environ)
logger.info("check_id_endpoint: query=%s" % info)
idt = self.server.parse_check_id_request(query=info)
resp = Response(idt.get_json(), content="application/json")
resp = Response(idt.to_json(), content="application/json")
return resp(environ, start_response)
#noinspection PyUnusedLocal
@ -605,11 +603,11 @@ class Provider(AProvider):
resp = BadRequest("Unsupported method")
return resp(environ, start_response)
request = RegistrationRequest.from_urlencoded(query)
logger.info("RegistrationRequest:%s" % request.dictionary())
request = msg_deser(query, "urlencoded", "RegistrationRequest")
logger.info("RegistrationRequest:%s" % request.to_dict())
_keystore = self.server.keystore
if request.type == "client_associate":
if request["type"] == "client_associate":
# create new id och secret
client_id = rndstr(12)
while client_id in self.cdb:
@ -621,15 +619,15 @@ class Provider(AProvider):
}
_cinfo = self.cdb[client_id]
for key,val in request.dictionary().items():
for key,val in request.items():
_cinfo[key] = val
self.keystore.load_keys(request, client_id)
logger.info("KEYSTORE: %s" % self.keystore._store)
elif request.type == "client_update":
elif request["type"] == "client_update":
# that these are an id,secret pair I know about
client_id = request.client_id
client_id = request["client_id"]
try:
_cinfo = self.cdb[client_id]
except KeyError:
@ -637,7 +635,7 @@ class Provider(AProvider):
resp = BadRequest()
return resp(environ, start_response)
if _cinfo["client_secret"] != request.client_secret:
if _cinfo["client_secret"] != request["client_secret"]:
logger.info("Wrong secret")
resp = BadRequest()
return resp(environ, start_response)
@ -646,12 +644,12 @@ class Provider(AProvider):
client_secret = secret(self.seed, client_id)
_cinfo["client_secret"] = client_secret
old_key = request.client_secret
old_key = request["client_secret"]
_keystore.remove_key(old_key, client_id, type="hmac", usage="sign")
_keystore.remove_key(old_key, client_id, type="hmac",
usage="verify")
for key,val in request.dictionary().items():
for key,val in request.items():
if key in ["client_id", "client_secret"]:
continue
@ -670,10 +668,11 @@ class Provider(AProvider):
# set expiration time
_cinfo["registration_expires"] = time_util.time_sans_frac()+3600
response = RegistrationResponse(client_id, client_secret,
expires_in=3600)
response = message("RegistrationResponse", client_id=client_id,
client_secret=client_secret,
expires_at=_cinfo["registration_expires"])
logger.info("Registration response: %s" % response.dictionary())
logger.info("Registration response: %s" % response.to_dict())
resp = Response(response.to_json(), content="application/json",
headers=[("Cache-Control", "no-store")])
@ -681,7 +680,7 @@ class Provider(AProvider):
#noinspection PyUnusedLocal
def providerinfo_endpoint(self, environ, start_response, logger, *args):
_response = ProviderConfigurationResponse(
_response = message("ProviderConfigurationResponse",
issuer=self.baseurl,
token_endpoint_auth_types_supported=["client_secret_post",
"client_secret_basic",
@ -696,9 +695,9 @@ class Provider(AProvider):
#keys = self.keystore.keys_by_owner(owner=".")
for cert in self.cert:
setattr(_response, "x509_url", "%s%s" % (self.baseurl, cert))
_response["x509_url"] = "%s%s" % (self.baseurl, cert)
for jwk in self.jwk:
setattr(_response, "jwk_url", "%s%s" % (self.baseurl, jwk))
_response["jwk_url"] = "%s%s" % (self.baseurl, jwk)
if not self.baseurl.endswith("/"):
self.baseurl += "/"
@ -707,8 +706,8 @@ class Provider(AProvider):
logger.info("# %s, %s" % (endp, endp.name))
_response[endp.name] = "%s%s" % (self.baseurl, endp.type)
logger.info("provider_info_response: %s" % _response.dictionary(True))
resp = Response(_response.to_json(True), content="application/json",
logger.info("provider_info_response: %s" % _response.to_dict())
resp = Response(_response.to_json(), content="application/json",
headers=[("Cache-Control", "no-store")])
return resp(environ, start_response)
@ -761,8 +760,8 @@ class Provider(AProvider):
#_log_info( "type: %s" % type(asession["authzreq"]))
# pick up the original request
areq = AuthorizationRequest.set_json(asession["authzreq"],
extended=True)
areq = msg_deser(asession["authzreq"], "json",
"AuthorizationRequest")
if self.debug:
_log_info("areq: %s" % areq)
@ -775,64 +774,67 @@ class Provider(AProvider):
except Exception:
raise
_log_info("response type: %s" % areq.response_type)
_log_info("response type: %s" % areq["response_type"])
# create the response
aresp = AuthorizationResponse()
if areq.state:
aresp.state = areq.state
aresp = message("AuthorizationResponse")
try:
aresp["state"] = areq["state"]
except KeyError:
pass
if len(areq.response_type) == 1 and "none" in areq.response_type:
if "response_type" in areq and \
len(areq["response_type"]) == 1 and \
"none" in areq["response_type"]:
pass
else:
_sinfo = self.sdb[scode]
if areq.scope:
aresp.scope = areq.scope
try:
aresp["scope"] = areq["scope"]
except KeyError:
pass
if self.debug:
_log_info("_dic: %s" % _sinfo)
rtype = set(areq.response_type[:])
if "code" in areq.response_type:
aresp.code = _sinfo["code"]
aresp.c_extension = areq.c_extension
rtype = set(areq["response_type"][:])
if "code" in areq["response_type"]:
aresp["code"] = _sinfo["code"]
rtype.remove("code")
else:
self.sdb[scode]["code"] = None
if "id_token" in areq.response_type:
if "id_token" in areq["response_type"]:
id_token = self._id_token(_sinfo, info_log=_log_info)
aresp.id_token = id_token
aresp["id_token"] = id_token
_sinfo["id_token"] = id_token
rtype.remove("id_token")
if "token" in areq.response_type:
if "token" in areq["response_type"]:
_dic = self.sdb.update_to_token(issue_refresh=False,
key=scode)
if self.debug:
_log_info("_dic: %s" % _dic)
for key, val in _dic.items():
if key in aresp.c_attributes:
setattr(aresp, key, val)
if key in aresp.parameters():
aresp[key] = val
aresp.c_extension = areq.c_extension
rtype.remove("token")
if len(rtype):
resp = BadRequest("Unknown response type")
return resp(environ, start_response)
if areq.redirect_uri:
assert areq.redirect_uri in self.cdb[
areq.client_id]["redirect_uris"]
redirect_uri = areq.redirect_uri
if "redirect_uri" in areq:
assert areq["redirect_uri"] in self.cdb[
areq["client_id"]]["redirect_uris"]
redirect_uri = areq["redirect_uri"]
else:
redirect_uri = self.cdb[areq.client_id]["redirect_uris"][0]
location = "%s?%s" % (redirect_uri, aresp.get_urlencoded())
redirect_uri = self.cdb[areq["client_id"]]["redirect_uris"][0]
location = aresp.request(redirect_uri)
if self.debug:
_log_info("Redirected to: '%s' (%s)" % (location, type(location)))
@ -874,4 +876,4 @@ class CheckIDEndpoint(Endpoint):
type = "check_id"
class RegistrationEndpoint(Endpoint) :
type = "registration"
type = "registration"

View File

@ -64,13 +64,20 @@ class Token(object):
csum.update("%f" % random.random())
if user:
csum.update(user)
if areq:
csum.update(areq.state)
if areq.scope:
for val in areq.scope:
csum.update(areq["state"])
try:
for val in areq["scope"]:
csum.update(val)
if areq.redirect_uri:
csum.update(areq.redirect_uri)
except KeyError:
pass
try:
csum.update(areq["redirect_uri"])
except KeyError:
pass
return csum.digest() # 28 bytes long, 224 bits
def _split_token(self, token):
@ -161,33 +168,30 @@ class SessionDB(object):
"user_id": user_id,
"code": access_grant,
"code_used": False,
"authzreq": areq.get_json(),
"client_id": areq.client_id,
"authzreq": areq.to_json(),
"client_id": areq["client_id"],
"expires_in": self.grant_expires_in,
"expires_at": utc_time_sans_frac()+self.grant_expires_in,
"issued": time.time()
}
try:
_val = areq.nonce
_val = areq["nonce"]
if _val:
_dic["nonce"] = _val
except (AttributeError, KeyError):
pass
if areq.redirect_uri:
_dic["redirect_uri"] = areq.redirect_uri
if areq.state:
_dic["state"] = areq.state
# Just an assumption
if areq.scope:
_dic["scope"] = areq.scope
for key in ["redirect_uri", "state", "scope"]:
try:
_dic[key] = areq[key]
except KeyError:
pass
if id_token:
_dic["id_token"] = id_token
if oidreq:
_dic["oidreq"] = oidreq.get_json()
_dic["oidreq"] = oidreq.to_json()
self._db[sid] = _dic
return sid