The new API
This commit is contained in:
parent
e66da03a90
commit
1b7ad6c4e0
|
@ -8,8 +8,7 @@ import logging
|
|||
import re
|
||||
|
||||
from oic.utils.http_util import *
|
||||
from oic.oic.message import OpenIDSchema
|
||||
#from oic.oic.provider import AuthnFailure
|
||||
from oic.oic.message import message
|
||||
|
||||
LOGGER = logging.getLogger("oicServer")
|
||||
hdlr = logging.FileHandler('oc3cp.log')
|
||||
|
@ -49,7 +48,7 @@ def user_info(oicsrv, userdb, user_id, client_id="", user_info_claims=None):
|
|||
else:
|
||||
result = identity
|
||||
|
||||
return OpenIDSchema(**result)
|
||||
return message("OpenIDSchema", **result)
|
||||
|
||||
FUNCTIONS = {
|
||||
"verify_client": verify_client,
|
||||
|
|
|
@ -0,0 +1,32 @@
|
|||
<%inherit file="root.mako" />
|
||||
<%def name="title()">Log in</%def>
|
||||
|
||||
<div class="login_form" class="block">
|
||||
<form action="${action}" method="post" class="login form">
|
||||
<input type="hidden" name="sid" value="${sid}"/>
|
||||
<table>
|
||||
<tr>
|
||||
<td>Username</td>
|
||||
<td><input type="text" name="login" value="${login}"/></td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td>Password</td>
|
||||
<td><input type="password" name="password"
|
||||
value="${password}"/></td>
|
||||
</tr>
|
||||
<tr>
|
||||
</td>
|
||||
<td><input type="submit" name="form.commit"
|
||||
value="Log In"/></td>
|
||||
</tr>
|
||||
</table>
|
||||
</form>
|
||||
</div>
|
||||
|
||||
<%def name="add_js()">
|
||||
<script type="text/javascript">
|
||||
$(document).ready(function() {
|
||||
bookie.login.init();
|
||||
});
|
||||
</script>
|
||||
</%def>
|
|
@ -0,0 +1 @@
|
|||
__author__ = 'rohe0002'
|
|
@ -0,0 +1,79 @@
|
|||
# -*- encoding:utf-8 -*-
|
||||
from mako import runtime, filters, cache
|
||||
UNDEFINED = runtime.UNDEFINED
|
||||
__M_dict_builtin = dict
|
||||
__M_locals_builtin = locals
|
||||
_magic_number = 6
|
||||
_modified_time = 1330607174.483153
|
||||
_template_filename='../htdocs/login.mako'
|
||||
_template_uri='login.mako'
|
||||
_template_cache=cache.Cache(__name__, _modified_time)
|
||||
_source_encoding='utf-8'
|
||||
_exports = ['add_js', 'title']
|
||||
|
||||
|
||||
def _mako_get_namespace(context, name):
|
||||
try:
|
||||
return context.namespaces[(__name__, name)]
|
||||
except KeyError:
|
||||
_mako_generate_namespaces(context)
|
||||
return context.namespaces[(__name__, name)]
|
||||
def _mako_generate_namespaces(context):
|
||||
pass
|
||||
def _mako_inherit(template, context):
|
||||
_mako_generate_namespaces(context)
|
||||
return runtime._inherit_from(context, u'root.mako', _template_uri)
|
||||
def render_body(context,**pageargs):
|
||||
context.caller_stack._push_frame()
|
||||
try:
|
||||
__M_locals = __M_dict_builtin(pageargs=pageargs)
|
||||
action = context.get('action', UNDEFINED)
|
||||
login = context.get('login', UNDEFINED)
|
||||
password = context.get('password', UNDEFINED)
|
||||
sid = context.get('sid', UNDEFINED)
|
||||
__M_writer = context.writer()
|
||||
# SOURCE LINE 1
|
||||
__M_writer(u'\n')
|
||||
# SOURCE LINE 2
|
||||
__M_writer(u'\n\n<div class="login_form" class="block">\n <form action="')
|
||||
# SOURCE LINE 5
|
||||
__M_writer(unicode(action))
|
||||
__M_writer(u'" method="post" class="login form">\n <input type="hidden" name="sid" value="')
|
||||
# SOURCE LINE 6
|
||||
__M_writer(unicode(sid))
|
||||
__M_writer(u'"/>\n <table>\n <tr>\n <td>Username</td>\n <td><input type="text" name="login" value="')
|
||||
# SOURCE LINE 10
|
||||
__M_writer(unicode(login))
|
||||
__M_writer(u'"/></td>\n </tr>\n <tr>\n <td>Password</td>\n <td><input type="password" name="password"\n value="')
|
||||
# SOURCE LINE 15
|
||||
__M_writer(unicode(password))
|
||||
__M_writer(u'"/></td>\n </tr>\n <tr>\n </td>\n <td><input type="submit" name="form.commit"\n value="Log In"/></td>\n </tr>\n </table>\n </form>\n</div>\n\n')
|
||||
# SOURCE LINE 32
|
||||
__M_writer(u'\n')
|
||||
return ''
|
||||
finally:
|
||||
context.caller_stack._pop_frame()
|
||||
|
||||
|
||||
def render_add_js(context):
|
||||
context.caller_stack._push_frame()
|
||||
try:
|
||||
__M_writer = context.writer()
|
||||
# SOURCE LINE 26
|
||||
__M_writer(u'\n <script type="text/javascript">\n $(document).ready(function() {\n bookie.login.init();\n });\n </script>\n')
|
||||
return ''
|
||||
finally:
|
||||
context.caller_stack._pop_frame()
|
||||
|
||||
|
||||
def render_title(context):
|
||||
context.caller_stack._push_frame()
|
||||
try:
|
||||
__M_writer = context.writer()
|
||||
# SOURCE LINE 2
|
||||
__M_writer(u'Log in')
|
||||
return ''
|
||||
finally:
|
||||
context.caller_stack._pop_frame()
|
||||
|
||||
|
|
@ -0,0 +1,133 @@
|
|||
# -*- encoding:utf-8 -*-
|
||||
from mako import runtime, filters, cache
|
||||
UNDEFINED = runtime.UNDEFINED
|
||||
__M_dict_builtin = dict
|
||||
__M_locals_builtin = locals
|
||||
_magic_number = 6
|
||||
_modified_time = 1332142870.815408
|
||||
_template_filename=u'templates/root.mako'
|
||||
_template_uri=u'root.mako'
|
||||
_template_cache=cache.Cache(__name__, _modified_time)
|
||||
_source_encoding='utf-8'
|
||||
_exports = ['css_link', 'pre', 'post', 'css']
|
||||
|
||||
|
||||
def render_body(context,**pageargs):
|
||||
context.caller_stack._push_frame()
|
||||
try:
|
||||
__M_locals = __M_dict_builtin(pageargs=pageargs)
|
||||
def pre():
|
||||
return render_pre(context.locals_(__M_locals))
|
||||
self = context.get('self', UNDEFINED)
|
||||
set = context.get('set', UNDEFINED)
|
||||
def post():
|
||||
return render_post(context.locals_(__M_locals))
|
||||
next = context.get('next', UNDEFINED)
|
||||
__M_writer = context.writer()
|
||||
# SOURCE LINE 1
|
||||
self.seen_css = set()
|
||||
|
||||
__M_writer(u'\n')
|
||||
# SOURCE LINE 7
|
||||
__M_writer(u'\n')
|
||||
# SOURCE LINE 10
|
||||
__M_writer(u'\n')
|
||||
# SOURCE LINE 15
|
||||
__M_writer(u'\n')
|
||||
# SOURCE LINE 22
|
||||
__M_writer(u'\n')
|
||||
# SOURCE LINE 25
|
||||
__M_writer(u'<html>\n<head><title>OAuth test</title>\n')
|
||||
# SOURCE LINE 27
|
||||
__M_writer(unicode(self.css()))
|
||||
__M_writer(u'\n<meta http-equiv="Content-Type" content="text/html; charset=UTF-8" />\n</head>\n<body>\n')
|
||||
# SOURCE LINE 31
|
||||
__M_writer(unicode(pre()))
|
||||
__M_writer(u'\n')
|
||||
# SOURCE LINE 34
|
||||
__M_writer(unicode(next.body()))
|
||||
__M_writer(u'\n')
|
||||
# SOURCE LINE 35
|
||||
__M_writer(unicode(post()))
|
||||
__M_writer(u'\n</body>\n</html>\n')
|
||||
return ''
|
||||
finally:
|
||||
context.caller_stack._pop_frame()
|
||||
|
||||
|
||||
def render_css_link(context,path,media=''):
|
||||
context.caller_stack._push_frame()
|
||||
try:
|
||||
context._push_buffer()
|
||||
self = context.get('self', UNDEFINED)
|
||||
__M_writer = context.writer()
|
||||
# SOURCE LINE 2
|
||||
__M_writer(u'\n')
|
||||
# SOURCE LINE 3
|
||||
if path not in self.seen_css:
|
||||
# SOURCE LINE 4
|
||||
__M_writer(u' <link rel="stylesheet" type="text/css" href="')
|
||||
__M_writer(filters.html_escape(unicode(path)))
|
||||
__M_writer(u'" media="')
|
||||
__M_writer(unicode(media))
|
||||
__M_writer(u'">\n')
|
||||
pass
|
||||
# SOURCE LINE 6
|
||||
__M_writer(u' ')
|
||||
self.seen_css.add(path)
|
||||
|
||||
__M_writer(u'\n')
|
||||
finally:
|
||||
__M_buf, __M_writer = context._pop_buffer_and_writer()
|
||||
context.caller_stack._pop_frame()
|
||||
__M_writer(filters.trim(__M_buf.getvalue()))
|
||||
return ''
|
||||
|
||||
|
||||
def render_pre(context):
|
||||
context.caller_stack._push_frame()
|
||||
try:
|
||||
context._push_buffer()
|
||||
__M_writer = context.writer()
|
||||
# SOURCE LINE 11
|
||||
__M_writer(u'\n<div class="header">\n <h1><a href="/">Login</a></h1>\n</div>\n')
|
||||
finally:
|
||||
__M_buf, __M_writer = context._pop_buffer_and_writer()
|
||||
context.caller_stack._pop_frame()
|
||||
__M_writer(filters.trim(__M_buf.getvalue()))
|
||||
return ''
|
||||
|
||||
|
||||
def render_post(context):
|
||||
context.caller_stack._push_frame()
|
||||
try:
|
||||
context._push_buffer()
|
||||
__M_writer = context.writer()
|
||||
# SOURCE LINE 16
|
||||
__M_writer(u'\n<div>\n <div class="footer">\n <p>© Copyright 2011 Umeå Universitet </p>\n </div>\n</div>\n')
|
||||
finally:
|
||||
__M_buf, __M_writer = context._pop_buffer_and_writer()
|
||||
context.caller_stack._pop_frame()
|
||||
__M_writer(filters.trim(__M_buf.getvalue()))
|
||||
return ''
|
||||
|
||||
|
||||
def render_css(context):
|
||||
context.caller_stack._push_frame()
|
||||
try:
|
||||
context._push_buffer()
|
||||
def css_link(path,media=''):
|
||||
return render_css_link(context,path,media)
|
||||
__M_writer = context.writer()
|
||||
# SOURCE LINE 8
|
||||
__M_writer(u'\n ')
|
||||
# SOURCE LINE 9
|
||||
__M_writer(unicode(css_link('/css/main.css', 'screen')))
|
||||
__M_writer(u'\n')
|
||||
finally:
|
||||
__M_buf, __M_writer = context._pop_buffer_and_writer()
|
||||
context.caller_stack._pop_frame()
|
||||
__M_writer(filters.trim(__M_buf.getvalue()))
|
||||
return ''
|
||||
|
||||
|
|
@ -9,7 +9,8 @@ import logging
|
|||
import re
|
||||
|
||||
from oic.utils.http_util import *
|
||||
from oic.oic.message import OpenIDSchema, AuthnToken
|
||||
#from oic.oic.message import OpenIDSchema, AuthnToken
|
||||
from oic.oic.message import message, msg_deser
|
||||
from oic.oic.provider import AuthnFailure
|
||||
from oic.oic.claims_provider import ClaimsClient
|
||||
from oic.oic import JWT_BEARER
|
||||
|
@ -70,16 +71,17 @@ def do_authorization(user, session=None):
|
|||
|
||||
#noinspection PyUnusedLocal
|
||||
def verify_client(environ, areq, cdb):
|
||||
if areq.client_secret: # client_secret_post
|
||||
identity = areq.client_id
|
||||
if "client_secret" in areq: # client_secret_post
|
||||
identity = areq["client_id"]
|
||||
if identity in cdb:
|
||||
if cdb[identity]["client_secret"] == areq.client_secret:
|
||||
if cdb[identity]["client_secret"] == areq["client_secret"]:
|
||||
return True
|
||||
elif areq.client_assertion: # client_secret_jwt or public_key_jwt
|
||||
assert areq.client_assertion_type == JWT_BEARER
|
||||
secret = cdb[areq.client_id]["client_secret"]
|
||||
elif "client_assertion" in areq: # client_secret_jwt or public_key_jwt
|
||||
assert areq["client_assertion_type"] == JWT_BEARER
|
||||
secret = cdb[areq["client_id"]]["client_secret"]
|
||||
key_col = {"hmac": secret}
|
||||
bjwt = AuthnToken.set_jwt(areq.client_assertion, key_col)
|
||||
bjwt = msg_deser(areq["client_assertion"], "jwt", "AuthnToken",
|
||||
key=key_col)
|
||||
return False
|
||||
|
||||
#import sys
|
||||
|
@ -118,15 +120,15 @@ def _collect_distributed(srv, cc, user_id, what, alias=""):
|
|||
if not alias:
|
||||
alias = srv
|
||||
|
||||
for key in resp.claims_names:
|
||||
for key in resp["claims_names"]:
|
||||
result["_claims_names"][key] = alias
|
||||
|
||||
if resp.jwt:
|
||||
if "jwt" in resp:
|
||||
result["_claims_sources"][alias] = {"JWT": resp.jwt}
|
||||
else:
|
||||
result["_claims_sources"][alias] = {"endpoint": resp.endpoint}
|
||||
result["_claims_sources"][alias] = {"endpoint": resp["endpoint"]}
|
||||
if "access_token" in resp:
|
||||
result["_claims_sources"][alias]["access_token"] = resp.access_token
|
||||
result["_claims_sources"][alias]["access_token"] = resp["access_token"]
|
||||
|
||||
return result
|
||||
|
||||
|
@ -140,7 +142,7 @@ def user_info(oicsrv, userdb, user_id, client_id="", user_info_claims=None):
|
|||
result = {}
|
||||
missing = []
|
||||
optional = []
|
||||
for key, restr in user_info_claims.claims.items():
|
||||
for key, restr in user_info_claims["claims"].items():
|
||||
try:
|
||||
result[key] = identity[key]
|
||||
except KeyError:
|
||||
|
@ -183,7 +185,7 @@ def user_info(oicsrv, userdb, user_id, client_id="", user_info_claims=None):
|
|||
#result = identity
|
||||
result = {"user_id": user_id}
|
||||
|
||||
return OpenIDSchema(**result)
|
||||
return message("OpenIDSchema", **result)
|
||||
|
||||
FUNCTIONS = {
|
||||
"authenticate": do_authentication,
|
||||
|
@ -304,7 +306,7 @@ for endp in ENDPOINTS:
|
|||
|
||||
# ----------------------------------------------------------------------------
|
||||
|
||||
ROOT = '../'
|
||||
ROOT = './'
|
||||
|
||||
LOOKUP = TemplateLookup(directories=[ROOT + 'templates', ROOT + 'htdocs'],
|
||||
module_directory=ROOT + 'modules',
|
||||
|
|
|
@ -0,0 +1,37 @@
|
|||
<% self.seen_css = set() %>
|
||||
<%def name="css_link(path, media='')" filter="trim">
|
||||
% if path not in self.seen_css:
|
||||
<link rel="stylesheet" type="text/css" href="${path|h}" media="${media}">
|
||||
% endif
|
||||
<% self.seen_css.add(path) %>
|
||||
</%def>
|
||||
<%def name="css()" filter="trim">
|
||||
${css_link('/css/main.css', 'screen')}
|
||||
</%def>
|
||||
<%def name="pre()" filter="trim">
|
||||
<div class="header">
|
||||
<h1><a href="/">Login</a></h1>
|
||||
</div>
|
||||
</%def>
|
||||
<%def name="post()" filter="trim">
|
||||
<div>
|
||||
<div class="footer">
|
||||
<p>© Copyright 2011 Umeå Universitet </p>
|
||||
</div>
|
||||
</div>
|
||||
</%def>
|
||||
##<!DOCTYPE html PUBLIC "-//W3C//DTD HTML 4.01//EN "
|
||||
##"http://www.w3.org/TR/xhtml1/DTD/xhtml1-strict.dtd">
|
||||
<html>
|
||||
<head><title>OAuth test</title>
|
||||
${self.css()}
|
||||
<meta http-equiv="Content-Type" content="text/html; charset=UTF-8" />
|
||||
</head>
|
||||
<body>
|
||||
${pre()}
|
||||
## ${comps.dict_to_table(pageargs)}
|
||||
## <hr><hr>
|
||||
${next.body()}
|
||||
${post()}
|
||||
</body>
|
||||
</html>
|
|
@ -1,3 +1 @@
|
|||
# Complete OpenID Connect implementation
|
||||
# __author__ = 'rohe0002'
|
||||
|
||||
__author__ = 'rohe0002'
|
||||
|
|
|
@ -3,7 +3,6 @@
|
|||
__author__ = 'rohe0002'
|
||||
|
||||
import httplib2
|
||||
import inspect
|
||||
import random
|
||||
import string
|
||||
|
||||
|
@ -22,15 +21,15 @@ DEFAULT_POST_CONTENT_TYPE = 'application/x-www-form-urlencoded'
|
|||
REQUEST2ENDPOINT = {
|
||||
"AuthorizationRequest": "authorization_endpoint",
|
||||
"AccessTokenRequest": "token_endpoint",
|
||||
# ROPCAccessTokenRequest: "authorization_endpoint",
|
||||
# CCAccessTokenRequest: "authorization_endpoint",
|
||||
# ROPCAccessTokenRequest: "authorization_endpoint",
|
||||
# CCAccessTokenRequest: "authorization_endpoint",
|
||||
"RefreshAccessTokenRequest": "token_endpoint",
|
||||
"TokenRevocationRequest": "token_endpoint",
|
||||
}
|
||||
}
|
||||
|
||||
RESPONSE2ERROR = {
|
||||
AuthorizationResponse: [AuthorizationErrorResponse, TokenErrorResponse],
|
||||
AccessTokenResponse: [TokenErrorResponse]
|
||||
"AuthorizationResponse": ["AuthorizationErrorResponse", "TokenErrorResponse"],
|
||||
"AccessTokenResponse": ["TokenErrorResponse"]
|
||||
}
|
||||
|
||||
ENDPOINTS = ["authorization_endpoint", "token_endpoint",
|
||||
|
@ -67,24 +66,24 @@ def client_secret_post(cli, cis, request_args=None, http_args=None, **kwargs):
|
|||
if request_args is None:
|
||||
request_args = {}
|
||||
|
||||
if not cis.client_secret:
|
||||
if "client_secret" not in cis:
|
||||
try:
|
||||
cis.client_secret = http_args["client_secret"]
|
||||
cis["client_secret"] = http_args["client_secret"]
|
||||
del http_args["client_secret"]
|
||||
except (KeyError, TypeError):
|
||||
cis.client_secret = cli.client_secret
|
||||
cis["client_secret"] = cli.client_secret
|
||||
|
||||
cis.client_id = cli.client_id
|
||||
cis["client_id"] = cli.client_id
|
||||
|
||||
return http_args
|
||||
|
||||
#noinspection PyUnusedLocal
|
||||
def bearer_header(cli, cis, request_args=None, http_args=None, **kwargs):
|
||||
if cis.access_token:
|
||||
_acc_token = cis.access_token
|
||||
cis.access_token = None
|
||||
if "access_token" in cis:
|
||||
_acc_token = cis["access_token"]
|
||||
del cis["access_token"]
|
||||
# Required under certain circumstances :-) not under other
|
||||
cis.c_attributes["access_token"] = SINGLE_OPTIONAL_STRING
|
||||
cis._schema["param"]["access_token"] = SINGLE_OPTIONAL_STRING
|
||||
else:
|
||||
try:
|
||||
_acc_token = request_args["access_token"]
|
||||
|
@ -118,11 +117,11 @@ def bearer_body(cli, cis, request_args=None, http_args=None, **kwargs):
|
|||
if request_args is None:
|
||||
request_args = {}
|
||||
|
||||
if cis.access_token:
|
||||
if "access_token" in cis:
|
||||
pass
|
||||
else:
|
||||
try:
|
||||
cis.access_token = request_args["access_token"]
|
||||
cis["access_token"] = request_args["access_token"]
|
||||
except KeyError:
|
||||
try:
|
||||
_state = kwargs["state"]
|
||||
|
@ -131,7 +130,7 @@ def bearer_body(cli, cis, request_args=None, http_args=None, **kwargs):
|
|||
raise Exception("Missing state specification")
|
||||
kwargs["state"] = cli.state
|
||||
|
||||
cis.access_token = cli.get_token(**kwargs).access_token
|
||||
cis["access_token"] = cli.get_token(**kwargs).access_token
|
||||
|
||||
return http_args
|
||||
|
||||
|
@ -140,7 +139,7 @@ AUTHN_METHOD = {
|
|||
"client_secret_post" : client_secret_post,
|
||||
"bearer_header": bearer_header,
|
||||
"bearer_body": bearer_body,
|
||||
}
|
||||
}
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
@ -150,7 +149,7 @@ class ExpiredToken(Exception):
|
|||
# -----------------------------------------------------------------------------
|
||||
|
||||
class Token(object):
|
||||
_class = AccessTokenResponse
|
||||
_schema = SCHEMA["AccessTokenResponse"]
|
||||
|
||||
def __init__(self, resp=None):
|
||||
self.scope = []
|
||||
|
@ -161,19 +160,11 @@ class Token(object):
|
|||
self.replaced = False
|
||||
|
||||
if resp:
|
||||
for prop in self._class.c_attributes.keys():
|
||||
try:
|
||||
_val = getattr(resp, prop)
|
||||
except KeyError:
|
||||
continue
|
||||
if _val:
|
||||
setattr(self, prop, _val)
|
||||
|
||||
for key, val in resp.c_extension.items():
|
||||
setattr(self, key, val)
|
||||
for prop, val in resp.items():
|
||||
setattr(self, prop, val)
|
||||
|
||||
try:
|
||||
_expires_in = resp.expires_in
|
||||
_expires_in = resp["expires_in"]
|
||||
except KeyError:
|
||||
return
|
||||
|
||||
|
@ -210,10 +201,10 @@ class Token(object):
|
|||
return True
|
||||
|
||||
class Grant(object):
|
||||
_authz_resp = AuthorizationResponse
|
||||
_acc_resp = AccessTokenResponse
|
||||
_authz_resp = "AuthorizationResponse"
|
||||
_acc_resp = "AccessTokenResponse"
|
||||
_token_class = Token
|
||||
|
||||
|
||||
def __init__(self, exp_in=600, resp=None, seed=""):
|
||||
self.grant_expiration_time = 0
|
||||
self.exp_in = exp_in
|
||||
|
@ -232,12 +223,15 @@ class Grant(object):
|
|||
|
||||
def add_code(self, resp):
|
||||
try:
|
||||
self.code = resp.code
|
||||
self.code = resp["code"]
|
||||
self.grant_expiration_time = utc_time_sans_frac() + self.exp_in
|
||||
except KeyError:
|
||||
pass
|
||||
|
||||
def add_token(self, resp):
|
||||
"""
|
||||
:param resp: A Authorization Response instance
|
||||
"""
|
||||
tok = self._token_class(resp)
|
||||
if tok.access_token:
|
||||
self.tokens.append(tok)
|
||||
|
@ -255,7 +249,7 @@ class Grant(object):
|
|||
return self.__dict__.keys()
|
||||
|
||||
def update(self, resp):
|
||||
if isinstance(resp, self._acc_resp):
|
||||
if resp.type() == self._acc_resp:
|
||||
if "access_token" in resp or "id_token" in resp:
|
||||
tok = self._token_class(resp)
|
||||
if tok not in self.tokens:
|
||||
|
@ -265,7 +259,7 @@ class Grant(object):
|
|||
self.tokens.append(tok)
|
||||
else:
|
||||
self.add_code(resp)
|
||||
elif isinstance(resp, self._authz_resp):
|
||||
elif resp.type() == self._authz_resp:
|
||||
self.add_code(resp)
|
||||
|
||||
def get_token(self, scope=""):
|
||||
|
@ -477,36 +471,37 @@ class KeyStore(object):
|
|||
|
||||
if "x509_url" in inst:
|
||||
try:
|
||||
_verkey = self.load_x509_cert(inst.x509_url, "verify",
|
||||
_verkey = self.load_x509_cert(inst["x509_url"], "verify",
|
||||
_issuer)
|
||||
except Exception:
|
||||
raise Exception(KEYLOADERR % ('x509', inst.x509_url))
|
||||
raise Exception(KEYLOADERR % ('x509', inst["x509_url"]))
|
||||
else:
|
||||
_verkey = None
|
||||
|
||||
if "x509_encryption_url" in inst:
|
||||
try:
|
||||
self.load_x509_cert(inst.x509_encryption_url, "enc",
|
||||
self.load_x509_cert(inst["x509_encryption_url"], "enc",
|
||||
_issuer)
|
||||
except Exception:
|
||||
raise Exception(KEYLOADERR % ('x509_encryption',
|
||||
inst.x509_encryption_url))
|
||||
inst["x509_encryption_url"]))
|
||||
elif _verkey:
|
||||
self.set_decrypt_key(_verkey, "rsa", _issuer)
|
||||
|
||||
if "jwk_url" in inst:
|
||||
try:
|
||||
_verkeys = self.load_jwk(inst.jwk_url, "verify", _issuer)
|
||||
_verkeys = self.load_jwk(inst["jwk_url"], "verify", _issuer)
|
||||
except Exception, err:
|
||||
raise Exception(KEYLOADERR % ('jwk', inst.jwk_url, err))
|
||||
raise Exception(KEYLOADERR % ('jwk', inst["jwk_url"], err))
|
||||
else:
|
||||
_verkeys = []
|
||||
|
||||
if "jwk_encryption_url" in inst:
|
||||
try:
|
||||
self.load_jwk(inst.jwk_url, "enc", _issuer)
|
||||
self.load_jwk(inst["jwk_encryption_url"], "enc", _issuer)
|
||||
except Exception:
|
||||
raise Exception(KEYLOADERR % ('jwk', inst.jwk_encryption_url))
|
||||
raise Exception(KEYLOADERR % ('jwk',
|
||||
inst["jwk_encryption_url"]))
|
||||
elif _verkeys:
|
||||
for key in _verkeys:
|
||||
self.set_decrypt_key(key, "rsa", _issuer)
|
||||
|
@ -550,8 +545,8 @@ class Client(PBase):
|
|||
jwt_keys=None):
|
||||
|
||||
PBase.__init__(self, cache, time_out, proxy_info, follow_redirects,
|
||||
disable_ssl_certificate_validation, ca_certs,
|
||||
httpclass, jwt_keys)
|
||||
disable_ssl_certificate_validation, ca_certs,
|
||||
httpclass, jwt_keys)
|
||||
|
||||
self.client_id = client_id
|
||||
self.client_timeout = client_timeout
|
||||
|
@ -608,30 +603,21 @@ class Client(PBase):
|
|||
|
||||
return None
|
||||
|
||||
# def scope_from_state(self, state):
|
||||
#
|
||||
# def grant_from_state_or_scope(self, state, scope):
|
||||
def _parse_args(self, schema, **kwargs):
|
||||
ar_args = kwargs.copy()
|
||||
|
||||
def _parse_args(self, klass, **kwargs):
|
||||
ar_args = {}
|
||||
for prop, val in kwargs.items():
|
||||
if prop in klass.c_attributes:
|
||||
ar_args[prop] = val
|
||||
elif prop.startswith("extra_"):
|
||||
if prop[6:] not in klass.c_attributes:
|
||||
ar_args[prop[6:]] = val
|
||||
|
||||
# Used to not overwrite defaults
|
||||
argspec = inspect.getargspec(klass.__init__)
|
||||
for prop in klass.c_attributes.keys():
|
||||
if prop not in ar_args:
|
||||
index = argspec[0].index(prop) -1 # skip self
|
||||
if not argspec[3][index]:
|
||||
if prop == "redirect_uri":
|
||||
ar_args[prop] = getattr(self, "redirect_uris",
|
||||
[None])[0]
|
||||
else:
|
||||
ar_args[prop] = getattr(self, prop, None)
|
||||
for prop in schema["param"].keys():
|
||||
if prop in ar_args:
|
||||
continue
|
||||
else:
|
||||
if prop == "redirect_uri":
|
||||
_val = getattr(self, "redirect_uris", [None])[0]
|
||||
if _val:
|
||||
ar_args[prop] = _val
|
||||
else:
|
||||
_val = getattr(self, prop, None)
|
||||
if _val:
|
||||
ar_args[prop] = _val
|
||||
|
||||
return ar_args
|
||||
|
||||
|
@ -693,18 +679,19 @@ class Client(PBase):
|
|||
else:
|
||||
raise ExpiredToken()
|
||||
|
||||
def construct_request(self, reqclass, request_args=None, extra_args=None):
|
||||
def construct_request(self, schema, request_args=None, extra_args=None):
|
||||
if request_args is None:
|
||||
request_args = {}
|
||||
|
||||
args = self._parse_args(reqclass, **request_args)
|
||||
args = self._parse_args(schema, **request_args)
|
||||
|
||||
if extra_args:
|
||||
args.update(extra_args)
|
||||
return reqclass(**args)
|
||||
return Message(schema["name"], schema, **args)
|
||||
|
||||
#noinspection PyUnusedLocal
|
||||
def construct_AuthorizationRequest(self, reqclass=AuthorizationRequest,
|
||||
def construct_AuthorizationRequest(self,
|
||||
schema=SCHEMA["AuthorizationRequest"],
|
||||
request_args=None, extra_args=None,
|
||||
**kwargs):
|
||||
|
||||
|
@ -716,10 +703,11 @@ class Client(PBase):
|
|||
else:
|
||||
request_args = {}
|
||||
|
||||
return self.construct_request(reqclass, request_args, extra_args)
|
||||
return self.construct_request(schema, request_args, extra_args)
|
||||
|
||||
#noinspection PyUnusedLocal
|
||||
def construct_AccessTokenRequest(self, cls=AccessTokenRequest,
|
||||
def construct_AccessTokenRequest(self,
|
||||
schema=SCHEMA["AccessTokenRequest"],
|
||||
request_args=None, extra_args=None,
|
||||
**kwargs):
|
||||
|
||||
|
@ -727,8 +715,8 @@ class Client(PBase):
|
|||
|
||||
if not grant.is_valid():
|
||||
raise GrantExpired("Authorization Code to old %s > %s" % (
|
||||
utc_time_sans_frac(),
|
||||
grant.grant_expiration_time))
|
||||
utc_time_sans_frac(),
|
||||
grant.grant_expiration_time))
|
||||
|
||||
if request_args is None:
|
||||
request_args = {}
|
||||
|
@ -743,12 +731,12 @@ class Client(PBase):
|
|||
elif not request_args["client_id"]:
|
||||
request_args["client_id"] = self.client_id
|
||||
|
||||
return self.construct_request(cls, request_args, extra_args)
|
||||
return self.construct_request(schema, request_args, extra_args)
|
||||
|
||||
def construct_RefreshAccessTokenRequest(self,
|
||||
cls=RefreshAccessTokenRequest,
|
||||
request_args=None, extra_args=None,
|
||||
**kwargs):
|
||||
schema=SCHEMA["RefreshAccessTokenRequest"],
|
||||
request_args=None, extra_args=None,
|
||||
**kwargs):
|
||||
|
||||
if request_args is None:
|
||||
request_args = {}
|
||||
|
@ -762,9 +750,10 @@ class Client(PBase):
|
|||
except AttributeError:
|
||||
pass
|
||||
|
||||
return self.construct_request(cls, request_args, extra_args)
|
||||
return self.construct_request(schema, request_args, extra_args)
|
||||
|
||||
def construct_TokenRevocationRequest(self, cls=TokenRevocationRequest,
|
||||
def construct_TokenRevocationRequest(self,
|
||||
schema=SCHEMA["TokenRevocationRequest"],
|
||||
request_args=None, extra_args=None,
|
||||
**kwargs):
|
||||
|
||||
|
@ -774,15 +763,15 @@ class Client(PBase):
|
|||
token = self.get_token(**kwargs)
|
||||
|
||||
request_args["token"] = token.access_token
|
||||
return self.construct_request(cls, request_args, extra_args)
|
||||
return self.construct_request(schema, request_args, extra_args)
|
||||
|
||||
def get_or_post(self, uri, method, req, extend=False, **kwargs):
|
||||
def get_or_post(self, uri, method, req, **kwargs):
|
||||
if method == "GET":
|
||||
path = uri + '?' + req.get_urlencoded(extended=extend)
|
||||
path = uri + '?' + req.to_urlencoded()
|
||||
body = None
|
||||
elif method == "POST":
|
||||
path = uri
|
||||
body = req.get_urlencoded(extended=extend)
|
||||
body = req.to_urlencoded()
|
||||
header_ext = {"content-type": DEFAULT_POST_CONTENT_TYPE}
|
||||
if "headers" in kwargs.keys():
|
||||
kwargs["headers"].update(header_ext)
|
||||
|
@ -793,13 +782,12 @@ class Client(PBase):
|
|||
|
||||
return path, body, kwargs
|
||||
|
||||
def uri_and_body(self, cls, cis, method="POST", request_args=None,
|
||||
extend=False, **kwargs):
|
||||
def uri_and_body(self, reqmsg, cis, method="POST", request_args=None,
|
||||
**kwargs):
|
||||
|
||||
uri = self._endpoint(self.request2endpoint[cls.__name__],
|
||||
**request_args)
|
||||
uri = self._endpoint(self.request2endpoint[reqmsg], **request_args)
|
||||
|
||||
uri, body, kwargs = self.get_or_post(uri, method, cis, extend, **kwargs)
|
||||
uri, body, kwargs = self.get_or_post(uri, method, cis, **kwargs)
|
||||
try:
|
||||
h_args = {"headers": kwargs["headers"]}
|
||||
except KeyError:
|
||||
|
@ -807,15 +795,15 @@ class Client(PBase):
|
|||
|
||||
return uri, body, h_args, cis
|
||||
|
||||
def request_info(self, cls, method="POST", request_args=None,
|
||||
def request_info(self, schema, method="POST", request_args=None,
|
||||
extra_args=None, **kwargs):
|
||||
|
||||
if request_args is None:
|
||||
request_args = {}
|
||||
|
||||
cis = getattr(self, "construct_%s" % cls.__name__)(cls, request_args,
|
||||
extra_args,
|
||||
**kwargs)
|
||||
cis = getattr(self, "construct_%s" % schema["name"])(schema,
|
||||
request_args,
|
||||
extra_args, **kwargs)
|
||||
|
||||
if "authn_method" in kwargs:
|
||||
h_arg = self.init_authentication_method(cis,
|
||||
|
@ -830,81 +818,51 @@ class Client(PBase):
|
|||
else:
|
||||
kwargs["headers"] = h_arg
|
||||
|
||||
if extra_args:
|
||||
extend = True
|
||||
else:
|
||||
extend = False
|
||||
return self.uri_and_body(schema["name"], cis, method, request_args,
|
||||
**kwargs)
|
||||
|
||||
return self.uri_and_body(cls, cis, method, request_args,
|
||||
extend=extend, **kwargs)
|
||||
|
||||
def parse_response(self, cls, info="", format="json", state="",
|
||||
extended=False, **kwargs):
|
||||
def parse_response(self, schema, info="", format="json", state="",
|
||||
**kwargs):
|
||||
"""
|
||||
Parse a response
|
||||
|
||||
:param cls: Which class to use when parsing the response
|
||||
:param schema: Which schema the response should adhere to
|
||||
:param info: The response, can be either an JSON code or an urlencoded
|
||||
form:
|
||||
:param format: Which serialization that was used
|
||||
:param extended: If non-standard parameters should be honored
|
||||
:return: The parsed and to some extend verified response
|
||||
"""
|
||||
|
||||
_r2e = self.response2error
|
||||
|
||||
err = None
|
||||
if format == "json":
|
||||
try:
|
||||
resp = cls.set_json(info, extended)
|
||||
assert resp.verify(**kwargs)
|
||||
except Exception, err:
|
||||
resp = None
|
||||
|
||||
eresp = None
|
||||
try:
|
||||
for errcls in _r2e[cls]:
|
||||
try:
|
||||
eresp = errcls.set_json(info, extended)
|
||||
eresp.verify()
|
||||
break
|
||||
except Exception:
|
||||
eresp = None
|
||||
except KeyError:
|
||||
pass
|
||||
|
||||
elif format == "urlencoded":
|
||||
if format == "urlencoded":
|
||||
if '?' in info or '#' in info:
|
||||
parts = urlparse.urlparse(info)
|
||||
scheme, netloc, path, params, query, fragment = parts[:6]
|
||||
# either query of fragment
|
||||
if query:
|
||||
pass
|
||||
info = query
|
||||
else:
|
||||
query = fragment
|
||||
else:
|
||||
query = info
|
||||
info = fragment
|
||||
|
||||
try:
|
||||
resp = cls.set_urlencoded(query, extended)
|
||||
assert resp.verify(**kwargs)
|
||||
except Exception, err:
|
||||
resp = None
|
||||
try:
|
||||
resp = msg_deser(info, format, schema=schema)
|
||||
assert resp.verify(**kwargs)
|
||||
except Exception, err:
|
||||
resp = None
|
||||
|
||||
eresp = None
|
||||
try:
|
||||
for errcls in _r2e[cls]:
|
||||
try:
|
||||
eresp = errcls.set_urlencoded(query, extended)
|
||||
eresp.verify()
|
||||
break
|
||||
except Exception:
|
||||
eresp = None
|
||||
except KeyError:
|
||||
pass
|
||||
|
||||
else:
|
||||
raise Exception("Unknown package format: '%s'" % format)
|
||||
eresp = None
|
||||
try:
|
||||
for errmsg in _r2e[schema["name"]]:
|
||||
try:
|
||||
eresp = msg_deser(info, format, typ=errmsg)
|
||||
eresp.verify()
|
||||
break
|
||||
except Exception:
|
||||
eresp = None
|
||||
except KeyError:
|
||||
pass
|
||||
|
||||
# Error responses has higher precedence
|
||||
if eresp:
|
||||
|
@ -913,9 +871,9 @@ class Client(PBase):
|
|||
if not resp:
|
||||
raise err
|
||||
|
||||
if isinstance(resp, (AuthorizationResponse, AccessTokenResponse)):
|
||||
if resp._name in ["AuthorizationResponse", "AccessTokenResponse"]:
|
||||
try:
|
||||
_state = resp.state
|
||||
_state = resp["state"]
|
||||
except (AttributeError, KeyError):
|
||||
_state = ""
|
||||
|
||||
|
@ -931,7 +889,7 @@ class Client(PBase):
|
|||
|
||||
#noinspection PyUnusedLocal
|
||||
def init_authentication_method(self, cis, authn_method, request_args=None,
|
||||
http_args=None, **kwargs):
|
||||
http_args=None, **kwargs):
|
||||
|
||||
if http_args is None:
|
||||
http_args = {}
|
||||
|
@ -944,16 +902,15 @@ class Client(PBase):
|
|||
else:
|
||||
return http_args
|
||||
|
||||
def request_and_return(self, url, respcls=None, method="GET", body=None,
|
||||
body_type="json", extended=True,
|
||||
state="", http_args=None, **kwargs):
|
||||
def request_and_return(self, url, schema=None, method="GET", body=None,
|
||||
body_type="json", state="", http_args=None,
|
||||
**kwargs):
|
||||
"""
|
||||
:param url: The URL to which the request should be sent
|
||||
:param respcls: The class the should represent the response
|
||||
:param schema: The schema the response should adhere to
|
||||
:param method: Which HTTP method to use
|
||||
:param body: A message body if any
|
||||
:param body_type: The format of the body of the return message
|
||||
:param extended: If non-standard parameters should be honored
|
||||
:param http_args: Arguments for the HTTP client
|
||||
:return: A cls or ErrorResponse instance or the HTTP response
|
||||
instance if no response body was expected.
|
||||
|
@ -985,41 +942,44 @@ class Client(PBase):
|
|||
raise Exception("ERROR: Something went wrong [%s]" % response.status)
|
||||
|
||||
if body_type:
|
||||
return self.parse_response(respcls, content, body_type,
|
||||
state, extended, **kwargs)
|
||||
return self.parse_response(schema, content, body_type, state,
|
||||
**kwargs)
|
||||
else:
|
||||
return response
|
||||
|
||||
def do_authorization_request(self, cls=AuthorizationRequest,
|
||||
def do_authorization_request(self, schema=SCHEMA["AuthorizationRequest"],
|
||||
state="", body_type="", method="GET",
|
||||
request_args=None, extra_args=None,
|
||||
http_args=None, resp_cls=None):
|
||||
http_args=None,
|
||||
resp_schema=SCHEMA["AuthorizationResponse"]):
|
||||
|
||||
url, body, ht_args, csi = self.request_info(cls, method, request_args,
|
||||
extra_args)
|
||||
url, body, ht_args, csi = self.request_info(schema, method,
|
||||
request_args, extra_args)
|
||||
|
||||
if http_args is None:
|
||||
http_args = ht_args
|
||||
else:
|
||||
http_args.update(http_args)
|
||||
|
||||
resp = self.request_and_return(url, resp_cls, method, body,
|
||||
body_type, extended=False,
|
||||
state=state, http_args=http_args)
|
||||
resp = self.request_and_return(url, resp_schema, method, body,
|
||||
body_type, state=state,
|
||||
http_args=http_args)
|
||||
|
||||
if isinstance(resp, ErrorResponse):
|
||||
resp.state = csi.state
|
||||
if isinstance(resp, Message):
|
||||
if resp.type() in RESPONSE2ERROR["AuthorizationRequest"]:
|
||||
resp.state = csi.state
|
||||
|
||||
return resp
|
||||
|
||||
def do_access_token_request(self, cls=AccessTokenRequest, scope="",
|
||||
state="", body_type="json", method="POST",
|
||||
request_args=None, extra_args=None,
|
||||
http_args=None, resp_cls=AccessTokenResponse,
|
||||
def do_access_token_request(self, schema=SCHEMA["AccessTokenRequest"],
|
||||
scope="", state="", body_type="json",
|
||||
method="POST", request_args=None,
|
||||
extra_args=None, http_args=None,
|
||||
resp_schema=SCHEMA["AccessTokenResponse"],
|
||||
authn_method="", **kwargs):
|
||||
|
||||
# method is default POST
|
||||
url, body, ht_args, csi = self.request_info(cls, method=method,
|
||||
url, body, ht_args, csi = self.request_info(schema, method=method,
|
||||
request_args=request_args,
|
||||
extra_args=extra_args,
|
||||
scope=scope, state=state,
|
||||
|
@ -1031,19 +991,20 @@ class Client(PBase):
|
|||
else:
|
||||
http_args.update(http_args)
|
||||
|
||||
return self.request_and_return(url, resp_cls, method, body,
|
||||
body_type, extended=False,
|
||||
state=state, http_args=http_args)
|
||||
return self.request_and_return(url, resp_schema, method, body,
|
||||
body_type, state=state,
|
||||
http_args=http_args)
|
||||
|
||||
def do_access_token_refresh(self, cls=RefreshAccessTokenRequest,
|
||||
def do_access_token_refresh(self, schema=SCHEMA["RefreshAccessTokenRequest"],
|
||||
state="", body_type="json", method="POST",
|
||||
request_args=None, extra_args=None,
|
||||
http_args=None, resp_cls=AccessTokenResponse,
|
||||
http_args=None,
|
||||
resp_schema=SCHEMA["AccessTokenResponse"],
|
||||
authn_method="", **kwargs):
|
||||
|
||||
token = self.get_token(also_expired=True, state=state, **kwargs)
|
||||
|
||||
url, body, ht_args, csi = self.request_info(cls, method=method,
|
||||
url, body, ht_args, csi = self.request_info(schema, method=method,
|
||||
request_args=request_args,
|
||||
extra_args=extra_args,
|
||||
token=token,
|
||||
|
@ -1054,16 +1015,16 @@ class Client(PBase):
|
|||
else:
|
||||
http_args.update(http_args)
|
||||
|
||||
return self.request_and_return(url, resp_cls, method, body,
|
||||
body_type, extended=False,
|
||||
state=state, http_args=http_args)
|
||||
return self.request_and_return(url, resp_schema, method, body,
|
||||
body_type, state=state,
|
||||
http_args=http_args)
|
||||
|
||||
def do_revocate_token(self, cls=TokenRevocationRequest, scope="", state="",
|
||||
body_type="json", method="POST",
|
||||
def do_revocate_token(self, schema=SCHEMA["TokenRevocationRequest"],
|
||||
scope="", state="", body_type="json", method="POST",
|
||||
request_args=None, extra_args=None, http_args=None,
|
||||
resp_cls=None, authn_method=""):
|
||||
resp_schema=SCHEMA[""], authn_method=""):
|
||||
|
||||
url, body, ht_args, csi = self.request_info(cls, method=method,
|
||||
url, body, ht_args, csi = self.request_info(schema, method=method,
|
||||
request_args=request_args,
|
||||
extra_args=extra_args,
|
||||
scope=scope, state=state,
|
||||
|
@ -1074,9 +1035,9 @@ class Client(PBase):
|
|||
else:
|
||||
http_args.update(http_args)
|
||||
|
||||
return self.request_and_return(url, resp_cls, method, body,
|
||||
body_type, extended=False,
|
||||
state=state, http_args=http_args)
|
||||
return self.request_and_return(url, resp_schema, method, body,
|
||||
body_type, state=state,
|
||||
http_args=http_args)
|
||||
|
||||
def fetch_protected_resource(self, uri, method="GET", headers=None,
|
||||
state="", **kwargs):
|
||||
|
@ -1111,56 +1072,48 @@ class Server(PBase):
|
|||
httpclass=None):
|
||||
|
||||
PBase.__init__(self, cache, time_out, proxy_info, follow_redirects,
|
||||
disable_ssl_certificate_validation, ca_certs,
|
||||
httpclass, jwt_keys)
|
||||
disable_ssl_certificate_validation, ca_certs,
|
||||
httpclass, jwt_keys)
|
||||
|
||||
|
||||
def parse_url_request(self, cls, url=None, query=None, extended=False):
|
||||
def parse_url_request(self, msg, url=None, query=None):
|
||||
if url:
|
||||
parts = urlparse.urlparse(url)
|
||||
scheme, netloc, path, params, query, fragment = parts[:6]
|
||||
|
||||
req = cls.set_urlencoded(query, extended)
|
||||
req = msg_deser(query, "urlencoded", msg)
|
||||
#req = message(msg).from_urlencoded(query)
|
||||
req.verify()
|
||||
return req
|
||||
|
||||
def parse_authorization_request(self, rcls=AuthorizationRequest,
|
||||
url=None, query=None, extended=False):
|
||||
|
||||
return self.parse_url_request(rcls, url, query, extended)
|
||||
def parse_authorization_request(self, reqmsg="AuthorizationRequest",
|
||||
url=None, query=None):
|
||||
|
||||
def parse_jwt_request(self, rcls=AuthorizationRequest, txt="", keystore="",
|
||||
verify=True, extend=False):
|
||||
return self.parse_url_request(reqmsg, url, query)
|
||||
|
||||
def parse_jwt_request(self, reqmsg="AuthorizationRequest", txt="",
|
||||
keystore="", verify=True):
|
||||
if not keystore:
|
||||
keystore = self.keystore
|
||||
|
||||
keys = keystore.get_keys("verify", owner=None)
|
||||
areq = rcls.set_jwt(txt, keys, verify, extend)
|
||||
#areq = message().from_(txt, keys, verify)
|
||||
areq = msg_deser(txt, "jwt", reqmsg, key=keys, verify=verify)
|
||||
areq.verify()
|
||||
return areq
|
||||
|
||||
def parse_body_request(self, cls=AccessTokenRequest, body=None,
|
||||
extend=False):
|
||||
req = cls.set_urlencoded(body, extend)
|
||||
def parse_body_request(self, reqmsg="AccessTokenRequest", body=None):
|
||||
#req = message(reqmsg).from_urlencoded(body)
|
||||
req = msg_deser(body, "urlencoded", reqmsg)
|
||||
req.verify()
|
||||
return req
|
||||
|
||||
def parse_token_request(self, rcls=AccessTokenRequest, body=None,
|
||||
extend=False):
|
||||
return self.parse_body_request(rcls, body, extend)
|
||||
def parse_token_request(self, reqmsg="AccessTokenRequest", body=None):
|
||||
return self.parse_body_request(reqmsg, body)
|
||||
|
||||
def parse_refresh_token_request(self, rcls=RefreshAccessTokenRequest,
|
||||
body=None, extend=False):
|
||||
return self.parse_body_request(rcls, body, extend)
|
||||
|
||||
# def is_authorized(self, path, authorization=None):
|
||||
# if not authorization:
|
||||
# return False
|
||||
#
|
||||
# if authorization.startswith("Bearer"):
|
||||
# parts = authorization.split(" ")
|
||||
#
|
||||
# return True
|
||||
def parse_refresh_token_request(self, reqmsg="RefreshAccessTokenRequest",
|
||||
body=None):
|
||||
return self.parse_body_request(reqmsg, body)
|
||||
|
||||
|
||||
|
||||
|
|
|
@ -7,17 +7,15 @@ import time
|
|||
from hashlib import md5
|
||||
|
||||
from oic.utils import http_util
|
||||
from oic.oauth2 import AuthorizationRequest
|
||||
from oic.oauth2 import AccessTokenRequest
|
||||
from oic.oauth2 import AuthorizationResponse
|
||||
from oic.oauth2 import AccessTokenResponse
|
||||
from oic.oauth2 import Client
|
||||
from oic.oauth2 import ErrorResponse
|
||||
from oic.oauth2 import Grant
|
||||
from oic.oauth2 import rndstr
|
||||
from oic.oauth2.message import SCHEMA
|
||||
from oic.oauth2.message import Message
|
||||
|
||||
ENDPOINTS = ["authorization_endpoint", "token_endpoint", "userinfo_endpoint",
|
||||
"check_id_endpoint", "registration_endpoint", "token_revokation_endpoint"]
|
||||
"check_id_endpoint", "registration_endpoint",
|
||||
"token_revokation_endpoint"]
|
||||
|
||||
def stateID(url, seed):
|
||||
"""The hash of the time + server path + a seed makes an unique
|
||||
|
@ -44,7 +42,7 @@ def factory(kaka, sdb, client_id, **kwargs):
|
|||
part = http_util.cookie_parts(client_id, kaka)
|
||||
if part is None:
|
||||
return None
|
||||
|
||||
|
||||
cons = Consumer(sdb, **kwargs)
|
||||
cons.restore(part[0])
|
||||
http_util.parse_cookie(client_id, cons.seed, kaka)
|
||||
|
@ -81,7 +79,7 @@ class Consumer(Client):
|
|||
"""
|
||||
if client_config is None:
|
||||
client_config = {}
|
||||
|
||||
|
||||
Client.__init__(self, **client_config)
|
||||
|
||||
self.authz_page = authz_page
|
||||
|
@ -149,7 +147,7 @@ class Consumer(Client):
|
|||
"grant": self.grant,
|
||||
"seed": self.seed,
|
||||
"redirect_uris": self.redirect_uris,
|
||||
}
|
||||
}
|
||||
|
||||
for endpoint in ENDPOINTS:
|
||||
res[endpoint] = getattr(self, endpoint, None)
|
||||
|
@ -185,9 +183,9 @@ class Consumer(Client):
|
|||
self._backup(sid)
|
||||
self.sdb["seed:%s" % self.seed] = sid
|
||||
|
||||
location = self.request_info(AuthorizationRequest, method="GET",
|
||||
scope=self.scope,
|
||||
request_args={"state": sid})[0]
|
||||
location = self.request_info(SCHEMA["AuthorizationRequest"],
|
||||
method="GET", scope=self.scope,
|
||||
request_args={"state": sid})[0]
|
||||
|
||||
|
||||
if self.debug:
|
||||
|
@ -220,44 +218,47 @@ class Consumer(Client):
|
|||
if "code" in self.response_type:
|
||||
# Might be an error response
|
||||
try:
|
||||
aresp = self.parse_response(AuthorizationResponse,
|
||||
aresp = self.parse_response(SCHEMA["AuthorizationResponse"],
|
||||
info=_query, format="urlencoded")
|
||||
except Exception, err:
|
||||
logger.error("%s" % err)
|
||||
raise
|
||||
|
||||
if isinstance(aresp, ErrorResponse):
|
||||
raise AuthzError(aresp.error)
|
||||
|
||||
if isinstance(aresp, Message):
|
||||
if aresp.type().endswith("ErrorResponse"):
|
||||
raise AuthzError(aresp["error"])
|
||||
|
||||
try:
|
||||
self.update(aresp.state)
|
||||
self.update(aresp["state"])
|
||||
except KeyError:
|
||||
raise UnknownState(aresp.state)
|
||||
|
||||
self._backup(aresp.state)
|
||||
raise UnknownState(aresp["state"])
|
||||
|
||||
self._backup(aresp["state"])
|
||||
|
||||
return aresp
|
||||
else: # implicit flow
|
||||
atr = self.parse_response(AccessTokenResponse, info=_query,
|
||||
format="urlencoded", extended=True)
|
||||
atr = self.parse_response(SCHEMA["AccessTokenResponse"],
|
||||
info=_query, format="urlencoded",
|
||||
extended=True)
|
||||
|
||||
if isinstance(atr, ErrorResponse):
|
||||
raise TokenError(atr.error)
|
||||
if isinstance(atr, Message):
|
||||
if atr.type().endswith("ErrorResponse"):
|
||||
raise TokenError(atr["error"])
|
||||
|
||||
try:
|
||||
self.update(atr.state)
|
||||
self.update(atr["state"])
|
||||
except KeyError:
|
||||
raise UnknownState(atr.state)
|
||||
raise UnknownState(atr["state"])
|
||||
|
||||
self.seed = self.grant[self.state].seed
|
||||
|
||||
|
||||
return atr
|
||||
|
||||
def complete(self, environ, start_response, logger):
|
||||
resp = self.handle_authorization_response(environ, start_response,
|
||||
logger)
|
||||
|
||||
if isinstance(resp, AuthorizationResponse):
|
||||
if resp.type() == "AuthorizationResponse":
|
||||
# Go get the access token
|
||||
resp = self.do_access_token_request(state=self.state)
|
||||
|
||||
|
@ -273,9 +274,9 @@ class Consumer(Client):
|
|||
elif self.client_secret:
|
||||
http_args = {}
|
||||
request_args = {
|
||||
"client_secret":self.client_secret,
|
||||
"client_id": self.client_id,
|
||||
"auth_method":"request_body"}
|
||||
"client_secret":self.client_secret,
|
||||
"client_id": self.client_id,
|
||||
"auth_method":"request_body"}
|
||||
else:
|
||||
raise Exception("Nothing to authenticate with")
|
||||
|
||||
|
@ -286,7 +287,7 @@ class Consumer(Client):
|
|||
|
||||
request_args, http_args = self.client_auth_info()
|
||||
|
||||
url, body, ht_args, csi = self.request_info(AccessTokenRequest,
|
||||
url, body, ht_args, csi = self.request_info(SCHEMA["AccessTokenRequest"],
|
||||
request_args=request_args,
|
||||
state=self.state)
|
||||
|
||||
|
|
File diff suppressed because it is too large
Load Diff
|
@ -1,4 +1,5 @@
|
|||
#!/usr/bin/env python
|
||||
from oic.oauth2.message import message, add_non_standard, msg_deser, by_schema, SCHEMA
|
||||
|
||||
__author__ = 'rohe0002'
|
||||
|
||||
|
@ -9,17 +10,8 @@ from urlparse import parse_qs
|
|||
from oic.utils.http_util import *
|
||||
|
||||
from oic.oauth2 import rndstr
|
||||
|
||||
from oic.oauth2 import Server as SrvMethod
|
||||
|
||||
from oic.oauth2 import MissingRequiredAttribute
|
||||
from oic.oauth2 import AuthorizationResponse
|
||||
from oic.oauth2 import AuthorizationRequest
|
||||
from oic.oauth2 import AccessTokenResponse
|
||||
from oic.oauth2 import AccessTokenRequest
|
||||
from oic.oauth2 import TokenErrorResponse
|
||||
from oic.oauth2 import NoneResponse
|
||||
from oic import oauth2
|
||||
|
||||
class AuthnFailure(Exception):
|
||||
pass
|
||||
|
@ -44,10 +36,11 @@ def get_post(environ):
|
|||
def code_response(**kwargs):
|
||||
_areq = kwargs["areq"]
|
||||
_scode = kwargs["scode"]
|
||||
aresp = AuthorizationResponse()
|
||||
if _areq.state:
|
||||
aresp.state = _areq.state
|
||||
aresp.code = _scode
|
||||
aresp = message("AuthorizationResponse")
|
||||
if "state" in _areq:
|
||||
aresp["state"] = _areq["state"]
|
||||
aresp["code"] = _scode
|
||||
add_non_standard(_areq, aresp)
|
||||
return aresp
|
||||
|
||||
def token_response(**kwargs):
|
||||
|
@ -56,18 +49,19 @@ def token_response(**kwargs):
|
|||
_sdb = kwargs["sdb"]
|
||||
_dic = _sdb.update_to_token(_scode, issue_refresh=False)
|
||||
|
||||
aresp = oauth2.factory(AccessTokenResponse, **_dic)
|
||||
if _areq.scope:
|
||||
aresp.scope = _areq.scope
|
||||
aresp.c_extension = _areq.c_extension
|
||||
aresp = message("AccessTokenResponse", **_dic)
|
||||
if "state" in _areq:
|
||||
aresp["state"] = _areq["state"]
|
||||
|
||||
return aresp
|
||||
|
||||
#noinspection PyUnusedLocal
|
||||
def none_response(**kwargs):
|
||||
_areq = kwargs["areq"]
|
||||
aresp = NoneResponse()
|
||||
if _areq.state:
|
||||
aresp.state = _areq.state
|
||||
aresp = message("NoneResponse")
|
||||
if "state" in _areq:
|
||||
aresp["state"] = _areq["state"]
|
||||
|
||||
return aresp
|
||||
|
||||
def location_url(response_type, redirect_uri, query):
|
||||
|
@ -77,8 +71,6 @@ def location_url(response_type, redirect_uri, query):
|
|||
return "%s#%s" % (redirect_uri, query)
|
||||
|
||||
class Provider(object):
|
||||
authorization_request = AuthorizationRequest
|
||||
|
||||
def __init__(self, name, sdb, cdb, function, urlmap=None, debug=0):
|
||||
self.name = name
|
||||
self.sdb = sdb
|
||||
|
@ -136,8 +128,8 @@ class Provider(object):
|
|||
#_log_info( "type: %s" % type(session["authzreq"]))
|
||||
|
||||
# pick up the original request
|
||||
areq = self.authorization_request.set_json(session["authzreq"],
|
||||
extended=True)
|
||||
areq = msg_deser(session["authzreq"], "json",
|
||||
schema=SCHEMA["AuthorizationRequest"])
|
||||
|
||||
if self.debug:
|
||||
_log_info("areq: %s" % areq)
|
||||
|
@ -149,21 +141,21 @@ class Provider(object):
|
|||
except Exception:
|
||||
raise
|
||||
|
||||
_log_info("response type: %s" % areq.response_type)
|
||||
_log_info("response type: %s" % areq["response_type"])
|
||||
|
||||
return areq, session
|
||||
|
||||
def authn_reply(self, areq, aresp, environ, start_response, logger):
|
||||
_log_info = logger.info
|
||||
|
||||
if areq.redirect_uri:
|
||||
if "redirect_uri" in areq:
|
||||
# TODO verify that the uri is reasonable
|
||||
redirect_uri = areq.redirect_uri
|
||||
redirect_uri = areq["redirect_uri"]
|
||||
else:
|
||||
redirect_uri = self.urlmap[areq.client_id]
|
||||
redirect_uri = self.urlmap[areq["client_id"]]
|
||||
|
||||
location = location_url(areq.response_type, redirect_uri,
|
||||
aresp.get_urlencoded())
|
||||
location = location_url(areq["response_type"], redirect_uri,
|
||||
aresp.to_urlencoded())
|
||||
|
||||
if self.debug:
|
||||
_log_info("Redirected to: '%s' (%s)" % (location, type(location)))
|
||||
|
@ -173,13 +165,12 @@ class Provider(object):
|
|||
|
||||
def authn_response(self, areq, session):
|
||||
scode = session["code"]
|
||||
areq.response_type.sort()
|
||||
_rtype = " ".join(areq.response_type)
|
||||
areq["response_type"].sort()
|
||||
_rtype = " ".join(areq["response_type"])
|
||||
return self.response_type_map[_rtype](areq=areq, scode=scode,
|
||||
sdb=self.sdb)
|
||||
|
||||
#noinspection PyUnusedLocal
|
||||
def authenticated(self, environ, start_response, logger, **kwargs):
|
||||
def authenticated(self, environ, start_response, logger):
|
||||
_log_info = logger.info
|
||||
|
||||
if self.debug:
|
||||
|
@ -203,13 +194,11 @@ class Provider(object):
|
|||
resp = BadRequest("Unknown response type")
|
||||
return resp(environ, start_response)
|
||||
|
||||
aresp.c_extension = areq.c_extension
|
||||
add_non_standard(aresp, areq)
|
||||
|
||||
return self.authn_reply(areq, aresp, environ, start_response, logger)
|
||||
|
||||
#noinspection PyUnusedLocal
|
||||
def authorization_endpoint(self, environ, start_response, logger,
|
||||
**kwargs):
|
||||
def authorization_endpoint(self, environ, start_response, logger):
|
||||
# The AuthorizationRequest endpoint
|
||||
|
||||
_log_info = logger.info
|
||||
|
@ -230,8 +219,7 @@ class Provider(object):
|
|||
_log_info("Query: '%s'" % query)
|
||||
|
||||
try:
|
||||
areq = self.srvmethod.parse_authorization_request(query=query,
|
||||
extended=True)
|
||||
areq = self.srvmethod.parse_authorization_request(query=query)
|
||||
except MissingRequiredAttribute, err:
|
||||
resp = BadRequest("%s" % err)
|
||||
return resp(environ, start_response)
|
||||
|
@ -239,11 +227,11 @@ class Provider(object):
|
|||
resp = BadRequest("%s" % err)
|
||||
return resp(environ, start_response)
|
||||
|
||||
if areq.redirect_uri:
|
||||
_redirect = areq.redirect_uri
|
||||
else:
|
||||
# A list, so pick one (==the first)
|
||||
_redirect = self.urlmap[areq.client_id][0]
|
||||
# if "redirect_uri" in areq:
|
||||
# _redirect = areq["redirect_uri"]
|
||||
# else:
|
||||
# # A list, so pick one (==the first)
|
||||
# _redirect = self.urlmap[areq["client_id"]][0]
|
||||
|
||||
sid = _sdb.create_authz_session("", areq)
|
||||
bsid = base64.b64encode(sid)
|
||||
|
@ -254,8 +242,7 @@ class Provider(object):
|
|||
|
||||
return self.function["authenticate"](environ, start_response, bsid)
|
||||
|
||||
#noinspection PyUnusedLocal
|
||||
def token_endpoint(self, environ, start_response, logger, handle):
|
||||
def token_endpoint(self, environ, start_response, logger):
|
||||
"""
|
||||
This is where clients come to get their access tokens
|
||||
"""
|
||||
|
@ -269,39 +256,40 @@ class Provider(object):
|
|||
if self.debug:
|
||||
_log_info("body: %s" % body)
|
||||
|
||||
areq = AccessTokenRequest.set_urlencoded(body, extended=True)
|
||||
areq = msg_deser(body, "urlencoded", typ="AccessTokenRequest")
|
||||
|
||||
# Client is from basic auth or ...
|
||||
client = environ["REMOTE_USER"]
|
||||
if not self.function["verify client"](environ, client, self.cdb):
|
||||
err = TokenErrorResponse(error="unathorized_client")
|
||||
resp = Response(err.get_json(), content="application/json",
|
||||
err = message("TokenErrorResponse", error="unathorized_client")
|
||||
resp = Response(err.to_json(), content="application/json",
|
||||
status="401 Unauthorized")
|
||||
return resp(environ, start_response)
|
||||
|
||||
if self.debug:
|
||||
_log_info("AccessTokenRequest: %s" % areq)
|
||||
|
||||
assert areq.grant_type == "authorization_code"
|
||||
assert areq["grant_type"] == "authorization_code"
|
||||
|
||||
# assert that the code is valid
|
||||
_info = _sdb[areq.code]
|
||||
_info = _sdb[areq["code"]]
|
||||
|
||||
# If redirect_uri was in the initial authorization request
|
||||
# verify that the one given here is the correct one.
|
||||
if "redirect_uri" in _info:
|
||||
assert areq.redirect_uri == _info["redirect_uri"]
|
||||
assert areq["redirect_uri"] == _info["redirect_uri"]
|
||||
|
||||
_tinfo = _sdb.update_to_token(areq.code)
|
||||
_tinfo = _sdb.update_to_token(areq["code"])
|
||||
|
||||
if self.debug:
|
||||
_log_info("_tinfo: %s" % _tinfo)
|
||||
|
||||
atr = oauth2.factory(AccessTokenResponse, **_tinfo)
|
||||
atr = message("AccessTokenResponse",
|
||||
**by_schema("AccessTokenResponse", **_tinfo))
|
||||
|
||||
if self.debug:
|
||||
_log_info("AccessTokenResponse: %s" % atr)
|
||||
|
||||
resp = Response(atr.get_json(), content="application/json")
|
||||
resp = Response(atr.to_json(), content="application/json")
|
||||
return resp(environ, start_response)
|
||||
|
||||
|
|
|
@ -1,11 +1,14 @@
|
|||
|
||||
__author__ = 'rohe0002'
|
||||
|
||||
import urlparse
|
||||
|
||||
from oic import oauth2
|
||||
|
||||
from oic.oauth2 import AUTHN_METHOD as OAUTH2_AUTHN_METHOD
|
||||
from oic.oauth2 import DEF_SIGN_ALG
|
||||
from oic.oauth2 import HTTP_ARGS
|
||||
from oic.oauth2 import rndstr
|
||||
from oic.oauth2.message import ErrorResponse
|
||||
from oic.oauth2.message import message_from_schema
|
||||
|
||||
from oic.oic.message import *
|
||||
from oic.utils import jwt
|
||||
|
@ -20,11 +23,12 @@ ENDPOINTS = ["authorization_endpoint", "token_endpoint",
|
|||
"registration_endpoint", "check_id_endpoint"]
|
||||
|
||||
RESPONSE2ERROR = {
|
||||
AuthorizationResponse: [AuthorizationErrorResponse, TokenErrorResponse],
|
||||
AccessTokenResponse: [TokenErrorResponse],
|
||||
IdToken: [ErrorResponse],
|
||||
RegistrationResponse: [ClientRegistrationErrorResponse],
|
||||
OpenIDSchema: [UserInfoErrorResponse]
|
||||
"AuthorizationResponse": ["AuthorizationErrorResponse",
|
||||
"TokenErrorResponse"],
|
||||
"AccessTokenResponse": ["TokenErrorResponse"],
|
||||
"IdToken": ["ErrorResponse"],
|
||||
"RegistrationResponse": ["ClientRegistrationErrorResponse"],
|
||||
"OpenIDSchema": ["UserInfoErrorResponse"]
|
||||
}
|
||||
|
||||
REQUEST2ENDPOINT = {
|
||||
|
@ -48,15 +52,10 @@ OIDCONF_PATTERN = "%s/.well-known/openid-configuration"
|
|||
AUTHN_METHOD = OAUTH2_AUTHN_METHOD.copy()
|
||||
|
||||
def assertion_jwt(cli, keys, audience, algorithm=DEF_SIGN_ALG):
|
||||
at = AuthnToken(
|
||||
iss = cli.client_id,
|
||||
prn = cli.client_id,
|
||||
aud = audience,
|
||||
jti = rndstr(8),
|
||||
exp = int(epoch_in_a_while(minutes=10)),
|
||||
iat = utc_now()
|
||||
)
|
||||
return at.get_jwt(key=keys, algorithm=algorithm)
|
||||
at = message("AuthnToken", iss = cli.client_id, prn = cli.client_id,
|
||||
aud = audience, jti = rndstr(8),
|
||||
exp = int(epoch_in_a_while(minutes=10)), iat = utc_now())
|
||||
return at.to_jwt(key=keys, algorithm=algorithm)
|
||||
|
||||
#noinspection PyUnusedLocal
|
||||
def client_secret_jwt(cli, cis, authn_method, request_args=None,
|
||||
|
@ -66,30 +65,37 @@ def client_secret_jwt(cli, cis, authn_method, request_args=None,
|
|||
signing_key = cli.keystore.get_sign_key()
|
||||
|
||||
# audience is the OP endpoint
|
||||
audience = cli._endpoint(REQUEST2ENDPOINT[cis.__class__.__name__])
|
||||
audience = cli._endpoint(REQUEST2ENDPOINT[cis.type()])
|
||||
|
||||
cis.client_assertion = assertion_jwt(cli, signing_key, audience,
|
||||
"RS256")
|
||||
cis.client_assertion_type = JWT_BEARER
|
||||
cis["client_assertion"] = assertion_jwt(cli, signing_key, audience,
|
||||
"RS256")
|
||||
cis["client_assertion_type"] = JWT_BEARER
|
||||
|
||||
if cis.client_secret:
|
||||
cis.client_secret = None
|
||||
try:
|
||||
del cis["client_secret"]
|
||||
except KeyError:
|
||||
pass
|
||||
|
||||
return {}
|
||||
|
||||
#noinspection PyUnusedLocal
|
||||
def private_key_jwt(cli, cis, authn_method, request_args=None,
|
||||
http_args=None, req=None):
|
||||
http_args=None, req=None):
|
||||
|
||||
# signing key is the clients rsa key for instance
|
||||
signing_key = cli.keystore.get_sign_key()
|
||||
|
||||
# audience is the OP endpoint
|
||||
audience = cli._endpoint(REQUEST2ENDPOINT[cis.__class__.__name__])
|
||||
audience = cli._endpoint(REQUEST2ENDPOINT[cis.type()])
|
||||
|
||||
cis.client_assertion = assertion_jwt(cli, signing_key, audience,
|
||||
cis["client_assertion"] = assertion_jwt(cli, signing_key, audience,
|
||||
algorithm="RS512")
|
||||
cis.client_assertion_type = JWT_BEARER
|
||||
cis["client_assertion_type"] = JWT_BEARER
|
||||
|
||||
try:
|
||||
del cis["client_secret"]
|
||||
except KeyError:
|
||||
pass
|
||||
|
||||
return {}
|
||||
|
||||
|
@ -99,12 +105,12 @@ AUTHN_METHOD.update({"client_secret_jwt": client_secret_jwt,
|
|||
# -----------------------------------------------------------------------------
|
||||
|
||||
class Token(oauth2.Token):
|
||||
_class = AccessTokenResponse
|
||||
_schema = SCHEMA["AccessTokenResponse"]
|
||||
|
||||
|
||||
class Grant(oauth2.Grant):
|
||||
_authz_resp = AuthorizationResponse
|
||||
_acc_resp = AccessTokenResponse
|
||||
_authz_resp = "AuthorizationResponse"
|
||||
_acc_resp = "AccessTokenResponse"
|
||||
_token_class = Token
|
||||
|
||||
def add_token(self, resp):
|
||||
|
@ -130,8 +136,8 @@ class Client(oauth2.Client):
|
|||
client_timeout = time_sans_frac() + expire_in
|
||||
|
||||
oauth2.Client.__init__(self, client_id, cache, timeout, proxy_info,
|
||||
follow_redirects, disable_ssl_certificate_validation,
|
||||
ca_certs, grant_expire_in, client_timeout, httpclass)
|
||||
follow_redirects, disable_ssl_certificate_validation,
|
||||
ca_certs, grant_expire_in, client_timeout, httpclass)
|
||||
|
||||
self.file_store = "./file/"
|
||||
self.file_uri = "http://localhost/"
|
||||
|
@ -186,21 +192,26 @@ class Client(oauth2.Client):
|
|||
The request will be signed
|
||||
"""
|
||||
|
||||
oir_args = {}
|
||||
|
||||
if userinfo_claims is not None:
|
||||
# UserInfoClaims
|
||||
claim = Claims(**userinfo_claims["claims"])
|
||||
claim = message("Claims", **userinfo_claims["claims"])
|
||||
|
||||
uic_args = {}
|
||||
for prop, val in userinfo_claims.items():
|
||||
if prop == "claims":
|
||||
continue
|
||||
if prop in UserInfoClaim.c_attributes.keys():
|
||||
if prop in SCHEMA["UserInfoClaim"]["param"].keys():
|
||||
uic_args[prop] = val
|
||||
|
||||
uic = UserInfoClaim(claim, **uic_args)
|
||||
uic = message("UserInfoClaim", claims=claim, **uic_args)
|
||||
else:
|
||||
uic = None
|
||||
|
||||
if uic:
|
||||
oir_args["userinfo"] = uic
|
||||
|
||||
if idtoken_claims is not None:
|
||||
#IdTokenClaims
|
||||
try:
|
||||
|
@ -208,28 +219,29 @@ class Client(oauth2.Client):
|
|||
except KeyError:
|
||||
_max_age=MAX_AUTHENTICATION_AGE
|
||||
|
||||
id_token = IDTokenClaim(max_age=_max_age)
|
||||
id_token = message("IDTokenClaim", max_age=_max_age)
|
||||
if "claims" in idtoken_claims:
|
||||
idtclaims = Claims(**idtoken_claims["claims"])
|
||||
id_token.claims = idtclaims
|
||||
idtclaims = message("Claims", **idtoken_claims["claims"])
|
||||
id_token["claims"] = idtclaims
|
||||
else: # uic must be != None
|
||||
id_token = IDTokenClaim(max_age=MAX_AUTHENTICATION_AGE)
|
||||
id_token = message("IDTokenClaim", max_age=MAX_AUTHENTICATION_AGE)
|
||||
|
||||
oir_args = {"userinfo":uic, "id_token":id_token}
|
||||
for prop in arq.keys():
|
||||
_val = getattr(arq, prop)
|
||||
if _val:
|
||||
oir_args[prop] = _val
|
||||
if id_token:
|
||||
oir_args["id_token"] = id_token
|
||||
|
||||
for prop, val in arq.items():
|
||||
oir_args[prop] = val
|
||||
|
||||
for attr in ["scope", "prompt", "response_type"]:
|
||||
if attr in oir_args:
|
||||
oir_args[attr] = " ".join(oir_args[attr])
|
||||
|
||||
oir = OpenIDRequest(**oir_args)
|
||||
oir = message("OpenIDRequest", **oir_args)
|
||||
|
||||
return oir.get_jwt(extended=True, key=keys, algorithm=algorithm)
|
||||
return oir.to_jwt(key=keys, algorithm=algorithm)
|
||||
|
||||
def construct_AuthorizationRequest(self, cls=AuthorizationRequest,
|
||||
def construct_AuthorizationRequest(self,
|
||||
schema=SCHEMA["AuthorizationRequest"],
|
||||
request_args=None, extra_args=None,
|
||||
**kwargs):
|
||||
|
||||
|
@ -239,13 +251,13 @@ class Client(oauth2.Client):
|
|||
else:
|
||||
request_args = {"nonce": rndstr(12)}
|
||||
|
||||
return oauth2.Client.construct_AuthorizationRequest(self, cls,
|
||||
return oauth2.Client.construct_AuthorizationRequest(self, schema,
|
||||
request_args,
|
||||
extra_args,
|
||||
**kwargs)
|
||||
|
||||
def construct_OpenIDRequest(self, cls=OpenIDRequest, request_args=None,
|
||||
extra_args=None, **kwargs):
|
||||
def construct_OpenIDRequest(self, schema=SCHEMA["OpenIDRequest"],
|
||||
request_args=None, extra_args=None, **kwargs):
|
||||
|
||||
if request_args is not None:
|
||||
for arg in ["idtoken_claims", "userinfo_claims"]:
|
||||
|
@ -257,7 +269,7 @@ class Client(oauth2.Client):
|
|||
else:
|
||||
request_args = {"nonce": rndstr(12)}
|
||||
|
||||
areq = oauth2.Client.construct_AuthorizationRequest(self, cls,
|
||||
areq = oauth2.Client.construct_AuthorizationRequest(self, schema,
|
||||
request_args,
|
||||
extra_args,
|
||||
**kwargs)
|
||||
|
@ -266,29 +278,29 @@ class Client(oauth2.Client):
|
|||
kwargs["keys"] = self.keystore.get_sign_key()
|
||||
|
||||
if "userinfo_claims" in kwargs or "idtoken_claims" in kwargs:
|
||||
areq.request = self.make_openid_request(areq, **kwargs)
|
||||
areq["request"] = self.make_openid_request(areq, **kwargs)
|
||||
|
||||
return areq
|
||||
|
||||
#noinspection PyUnusedLocal
|
||||
def construct_AccessTokenRequest(self, cls=AccessTokenRequest,
|
||||
def construct_AccessTokenRequest(self, schema=SCHEMA["AccessTokenRequest"],
|
||||
request_args=None, extra_args=None,
|
||||
**kwargs):
|
||||
|
||||
return oauth2.Client.construct_AccessTokenRequest(self, cls,
|
||||
return oauth2.Client.construct_AccessTokenRequest(self, schema,
|
||||
request_args,
|
||||
extra_args, **kwargs)
|
||||
|
||||
def construct_RefreshAccessTokenRequest(self,
|
||||
cls=RefreshAccessTokenRequest,
|
||||
request_args=None, extra_args=None,
|
||||
**kwargs):
|
||||
schema=SCHEMA["RefreshAccessTokenRequest"],
|
||||
request_args=None, extra_args=None,
|
||||
**kwargs):
|
||||
|
||||
return oauth2.Client.construct_RefreshAccessTokenRequest(self, cls,
|
||||
request_args,
|
||||
extra_args, **kwargs)
|
||||
return oauth2.Client.construct_RefreshAccessTokenRequest(self, schema,
|
||||
request_args,
|
||||
extra_args, **kwargs)
|
||||
|
||||
def construct_UserInfoRequest(self, cls=UserInfoRequest,
|
||||
def construct_UserInfoRequest(self, schema=SCHEMA["UserInfoRequest"],
|
||||
request_args=None, extra_args=None,
|
||||
**kwargs):
|
||||
|
||||
|
@ -304,23 +316,25 @@ class Client(oauth2.Client):
|
|||
|
||||
request_args["access_token"] = token.access_token
|
||||
|
||||
return self.construct_request(cls, request_args, extra_args)
|
||||
return self.construct_request(schema, request_args, extra_args)
|
||||
|
||||
#noinspection PyUnusedLocal
|
||||
def construct_RegistrationRequest(self, cls=RegistrationRequest,
|
||||
def construct_RegistrationRequest(self,
|
||||
schema=SCHEMA["RegistrationRequest"],
|
||||
request_args=None, extra_args=None,
|
||||
**kwargs):
|
||||
|
||||
return self.construct_request(cls, request_args, extra_args)
|
||||
return self.construct_request(schema, request_args, extra_args)
|
||||
|
||||
#noinspection PyUnusedLocal
|
||||
def construct_RefreshSessionRequest(self, cls=RefreshSessionRequest,
|
||||
def construct_RefreshSessionRequest(self,
|
||||
schema=SCHEMA["RefreshSessionRequest"],
|
||||
request_args=None, extra_args=None,
|
||||
**kwargs):
|
||||
|
||||
return self.construct_request(cls, request_args, extra_args)
|
||||
return self.construct_request(schema, request_args, extra_args)
|
||||
|
||||
def _id_token_based(self, cls, request_args=None, extra_args=None,
|
||||
def _id_token_based(self, schema, request_args=None, extra_args=None,
|
||||
**kwargs):
|
||||
|
||||
if request_args is None:
|
||||
|
@ -340,82 +354,88 @@ class Client(oauth2.Client):
|
|||
|
||||
request_args[_prop] = id_token
|
||||
|
||||
return self.construct_request(cls, request_args, extra_args)
|
||||
return self.construct_request(schema, request_args, extra_args)
|
||||
|
||||
def construct_CheckSessionRequest(self, cls=CheckSessionRequest,
|
||||
request_args=None, extra_args=None,
|
||||
**kwargs):
|
||||
def construct_CheckSessionRequest(self,
|
||||
schema=SCHEMA["CheckSessionRequest"],
|
||||
request_args=None, extra_args=None,
|
||||
**kwargs):
|
||||
|
||||
return self._id_token_based(cls, request_args, extra_args, **kwargs)
|
||||
return self._id_token_based(schema, request_args, extra_args, **kwargs)
|
||||
|
||||
def construct_CheckIDRequest(self, cls=CheckIDRequest, request_args=None,
|
||||
def construct_CheckIDRequest(self, schema=SCHEMA["CheckIDRequest"],
|
||||
request_args=None,
|
||||
extra_args=None, **kwargs):
|
||||
|
||||
# access_token is where the id_token will be placed
|
||||
return self._id_token_based(cls, request_args, extra_args,
|
||||
return self._id_token_based(schema, request_args, extra_args,
|
||||
prop="access_token", **kwargs)
|
||||
|
||||
def construct_EndSessionRequest(self, cls=EndSessionRequest,
|
||||
def construct_EndSessionRequest(self, schema=SCHEMA["EndSessionRequest"],
|
||||
request_args=None, extra_args=None,
|
||||
**kwargs):
|
||||
|
||||
if request_args is None:
|
||||
request_args = {}
|
||||
|
||||
|
||||
if "state" in kwargs:
|
||||
request_args["state"] = kwargs["state"]
|
||||
elif "state" in request_args:
|
||||
kwargs["state"] = request_args["state"]
|
||||
|
||||
# if "redirect_url" not in request_args:
|
||||
# request_args["redirect_url"] = self.redirect_url
|
||||
|
||||
return self._id_token_based(cls, request_args, extra_args, **kwargs)
|
||||
# if "redirect_url" not in request_args:
|
||||
# request_args["redirect_url"] = self.redirect_url
|
||||
|
||||
return self._id_token_based(schema, request_args, extra_args, **kwargs)
|
||||
|
||||
# ------------------------------------------------------------------------
|
||||
|
||||
def do_authorization_request(self, cls=AuthorizationRequest,
|
||||
def do_authorization_request(self, schema=SCHEMA["AuthorizationRequest"],
|
||||
state="", body_type="", method="GET",
|
||||
request_args=None, extra_args=None,
|
||||
http_args=None, resp_cls=None):
|
||||
http_args=None,
|
||||
resp_schema=SCHEMA["AuthorizationResponse"]):
|
||||
|
||||
return oauth2.Client.do_authorization_request(self, cls, state,
|
||||
return oauth2.Client.do_authorization_request(self, schema, state,
|
||||
body_type, method,
|
||||
request_args,
|
||||
extra_args, http_args,
|
||||
resp_cls)
|
||||
resp_schema)
|
||||
|
||||
|
||||
def do_access_token_request(self, cls=AccessTokenRequest, scope="",
|
||||
state="", body_type="json", method="POST",
|
||||
request_args=None, extra_args=None,
|
||||
http_args=None, resp_cls=AccessTokenResponse,
|
||||
def do_access_token_request(self, schema=SCHEMA["AccessTokenRequest"],
|
||||
scope="", state="", body_type="json",
|
||||
method="POST", request_args=None,
|
||||
extra_args=None, http_args=None,
|
||||
resp_schema=SCHEMA["AccessTokenResponse"],
|
||||
authn_method="", **kwargs):
|
||||
|
||||
return oauth2.Client.do_access_token_request(self, cls, scope, state,
|
||||
return oauth2.Client.do_access_token_request(self, schema, scope, state,
|
||||
body_type, method,
|
||||
request_args, extra_args,
|
||||
http_args, resp_cls,
|
||||
http_args, resp_schema,
|
||||
authn_method, **kwargs)
|
||||
|
||||
def do_access_token_refresh(self, cls=RefreshAccessTokenRequest,
|
||||
def do_access_token_refresh(self, schema=SCHEMA["RefreshAccessTokenRequest"],
|
||||
state="", body_type="json", method="POST",
|
||||
request_args=None, extra_args=None,
|
||||
http_args=None, resp_cls=AccessTokenResponse,
|
||||
http_args=None,
|
||||
resp_schema=SCHEMA["AccessTokenResponse"],
|
||||
**kwargs):
|
||||
|
||||
return oauth2.Client.do_access_token_refresh(self, cls, state,
|
||||
return oauth2.Client.do_access_token_refresh(self, schema, state,
|
||||
body_type, method,
|
||||
request_args,
|
||||
extra_args, http_args,
|
||||
resp_cls, **kwargs)
|
||||
resp_schema, **kwargs)
|
||||
|
||||
def do_registration_request(self, cls=RegistrationRequest, scope="",
|
||||
state="", body_type="json", method="POST",
|
||||
request_args=None, extra_args=None,
|
||||
http_args=None, resp_cls=RegistrationResponse):
|
||||
def do_registration_request(self, schema=SCHEMA["RegistrationRequest"],
|
||||
scope="", state="", body_type="json",
|
||||
method="POST", request_args=None,
|
||||
extra_args=None, http_args=None,
|
||||
resp_schema=SCHEMA["RegistrationResponse"]):
|
||||
|
||||
url, body, ht_args, csi = self.request_info(cls, method=method,
|
||||
url, body, ht_args, csi = self.request_info(schema, method=method,
|
||||
request_args=request_args,
|
||||
extra_args=extra_args,
|
||||
scope=scope, state=state)
|
||||
|
@ -425,17 +445,18 @@ class Client(oauth2.Client):
|
|||
else:
|
||||
http_args.update(http_args)
|
||||
|
||||
return self.request_and_return(url, resp_cls, method, body,
|
||||
body_type, extended=False,
|
||||
state=state, http_args=http_args)
|
||||
return self.request_and_return(url, resp_schema, method, body,
|
||||
body_type, state=state,
|
||||
http_args=http_args)
|
||||
|
||||
def do_check_session_request(self, cls=CheckSessionRequest, scope="",
|
||||
def do_check_session_request(self, schema=SCHEMA["CheckSessionRequest"],
|
||||
scope="",
|
||||
state="", body_type="json", method="GET",
|
||||
request_args=None, extra_args=None,
|
||||
http_args=None,
|
||||
resp_cls=IdToken):
|
||||
resp_schema=SCHEMA["IdToken"]):
|
||||
|
||||
url, body, ht_args, csi = self.request_info(cls, method=method,
|
||||
url, body, ht_args, csi = self.request_info(schema, method=method,
|
||||
request_args=request_args,
|
||||
extra_args=extra_args,
|
||||
scope=scope, state=state)
|
||||
|
@ -445,17 +466,17 @@ class Client(oauth2.Client):
|
|||
else:
|
||||
http_args.update(http_args)
|
||||
|
||||
return self.request_and_return(url, resp_cls, method, body,
|
||||
body_type, extended=False,
|
||||
state=state, http_args=http_args)
|
||||
return self.request_and_return(url, resp_schema, method, body,
|
||||
body_type, state=state,
|
||||
http_args=http_args)
|
||||
|
||||
def do_check_id_request(self, cls=CheckIDRequest, scope="",
|
||||
state="", body_type="json", method="GET",
|
||||
request_args=None, extra_args=None,
|
||||
http_args=None,
|
||||
resp_cls=IdToken):
|
||||
def do_check_id_request(self, schema=SCHEMA["CheckIDRequest"], scope="",
|
||||
state="", body_type="json", method="GET",
|
||||
request_args=None, extra_args=None,
|
||||
http_args=None,
|
||||
resp_schema=SCHEMA["IdToken"]):
|
||||
|
||||
url, body, ht_args, csi = self.request_info(cls, method=method,
|
||||
url, body, ht_args, csi = self.request_info(schema, method=method,
|
||||
request_args=request_args,
|
||||
extra_args=extra_args,
|
||||
scope=scope, state=state)
|
||||
|
@ -465,16 +486,16 @@ class Client(oauth2.Client):
|
|||
else:
|
||||
http_args.update(http_args)
|
||||
|
||||
return self.request_and_return(url, resp_cls, method, body,
|
||||
body_type, extended=False,
|
||||
state=state, http_args=http_args)
|
||||
return self.request_and_return(url, resp_schema, method, body,
|
||||
body_type, state=state,
|
||||
http_args=http_args)
|
||||
|
||||
def do_end_session_request(self, cls=EndSessionRequest, scope="",
|
||||
state="", body_type="", method="GET",
|
||||
request_args=None, extra_args=None,
|
||||
http_args=None, resp_cls=None):
|
||||
def do_end_session_request(self, schema=SCHEMA["EndSessionRequest"], scope="",
|
||||
state="", body_type="", method="GET",
|
||||
request_args=None, extra_args=None,
|
||||
http_args=None, resp_schema=SCHEMA[""]):
|
||||
|
||||
url, body, ht_args, csi = self.request_info(cls, method=method,
|
||||
url, body, ht_args, csi = self.request_info(schema, method=method,
|
||||
request_args=request_args,
|
||||
extra_args=extra_args,
|
||||
scope=scope, state=state)
|
||||
|
@ -484,15 +505,15 @@ class Client(oauth2.Client):
|
|||
else:
|
||||
http_args.update(http_args)
|
||||
|
||||
return self.request_and_return(url, resp_cls, method, body,
|
||||
body_type, extended=False,
|
||||
state=state, http_args=http_args)
|
||||
return self.request_and_return(url, resp_schema, method, body,
|
||||
body_type, state=state,
|
||||
http_args=http_args)
|
||||
|
||||
def user_info_request(self, method="GET", state="", scope="", **kwargs):
|
||||
uir = UserInfoRequest()
|
||||
uir = message("UserInfoRequest")
|
||||
if "token" in kwargs:
|
||||
if kwargs["token"]:
|
||||
uir.access_token = kwargs["token"]
|
||||
uir["access_token"] = kwargs["token"]
|
||||
token = Token()
|
||||
token.type = "Bearer"
|
||||
token.access_token = kwargs["token"]
|
||||
|
@ -504,7 +525,7 @@ class Client(oauth2.Client):
|
|||
token = self.grant[state].get_token(scope)
|
||||
|
||||
if token.is_valid():
|
||||
uir.access_token = token.access_token
|
||||
uir["access_token"] = token.access_token
|
||||
else:
|
||||
# raise oauth2.OldAccessToken
|
||||
if self.log:
|
||||
|
@ -512,12 +533,12 @@ class Client(oauth2.Client):
|
|||
try:
|
||||
self.do_access_token_refresh(token=token)
|
||||
token = self.grant[state].get_token(scope)
|
||||
uir.access_token = token.access_token
|
||||
uir["access_token"] = token.access_token
|
||||
except Exception:
|
||||
raise
|
||||
|
||||
try:
|
||||
uir.schema = kwargs["schema"]
|
||||
uir["schema"] = kwargs["schema"]
|
||||
except KeyError:
|
||||
pass
|
||||
|
||||
|
@ -538,7 +559,7 @@ class Client(oauth2.Client):
|
|||
kwargs["headers"] = [("Authorization", token.access_token)]
|
||||
if not "token_in_message_body" in _behav:
|
||||
# remove the token from the request
|
||||
uir.access_token = None
|
||||
uir["access_token"] = None
|
||||
|
||||
path, body, kwargs = self.get_or_post(uri, method, uir, **kwargs)
|
||||
|
||||
|
@ -551,7 +572,7 @@ class Client(oauth2.Client):
|
|||
|
||||
kwargs["schema"] = schema
|
||||
path, body, method, h_args = self.user_info_request(method, state,
|
||||
scope, **kwargs)
|
||||
scope, **kwargs)
|
||||
|
||||
try:
|
||||
response, content = self.http.request(path, method, body, **h_args)
|
||||
|
@ -565,7 +586,7 @@ class Client(oauth2.Client):
|
|||
else:
|
||||
raise Exception("ERROR: Something went wrong [%s]" % response.status)
|
||||
|
||||
return OpenIDSchema.set_json(txt=content, extended=True)
|
||||
return message("OpenIDSchema").from_json(txt=content)
|
||||
|
||||
|
||||
def provider_config(self, issuer, keys=True, endpoints=True):
|
||||
|
@ -578,21 +599,21 @@ class Client(oauth2.Client):
|
|||
|
||||
(response, content) = self.http.request(url)
|
||||
if response.status == 200:
|
||||
pcr = ProviderConfigurationResponse.from_json(content,
|
||||
extended=True)
|
||||
pcr = message("ProviderConfigurationResponse").from_json(content)
|
||||
else:
|
||||
raise Exception("%s" % response.status)
|
||||
|
||||
if pcr.issuer:
|
||||
if pcr.issuer.endswith("/"):
|
||||
_pcr_issuer = pcr.issuer[:-1]
|
||||
if "issuer" in pcr:
|
||||
if pcr["issuer"].endswith("/"):
|
||||
_pcr_issuer = pcr["issuer"][:-1]
|
||||
else:
|
||||
_pcr_issuer = pcr.issuer
|
||||
_pcr_issuer = pcr["issuer"]
|
||||
|
||||
try:
|
||||
assert _issuer == _pcr_issuer
|
||||
except AssertionError:
|
||||
raise Exception("provider info issuer mismatch")
|
||||
raise Exception("provider info issuer mismatch '%s' != '%s'" % (
|
||||
_issuer, _pcr_issuer))
|
||||
|
||||
if endpoints:
|
||||
for key, val in pcr.items():
|
||||
|
@ -614,7 +635,7 @@ class Client(oauth2.Client):
|
|||
keycol = self.keystore.pairkeys(csrc)["verify"]
|
||||
info = json.loads(jwt.verify(str(spec["JWT"]), keycol))
|
||||
attr = [n for n, s in userinfo._claim_names.items() if s ==
|
||||
csrc]
|
||||
csrc]
|
||||
assert attr == info.keys()
|
||||
|
||||
for key, vals in info.items():
|
||||
|
@ -629,14 +650,14 @@ class Client(oauth2.Client):
|
|||
|
||||
if "access_token" in spec:
|
||||
_uinfo = self.do_user_info_request(
|
||||
token=spec["access_token"],
|
||||
userinfo_endpoint=spec["endpoint"])
|
||||
token=spec["access_token"],
|
||||
userinfo_endpoint=spec["endpoint"])
|
||||
else:
|
||||
_uinfo = self.do_user_info_request(token=callback(csrc),
|
||||
userinfo_endpoint=spec["endpoint"])
|
||||
userinfo_endpoint=spec["endpoint"])
|
||||
|
||||
attr = [n for n, s in userinfo._claim_names.items() if s ==
|
||||
csrc]
|
||||
csrc]
|
||||
assert attr == _uinfo.keys()
|
||||
|
||||
for key, vals in _uinfo.items():
|
||||
|
@ -662,30 +683,27 @@ class Server(oauth2.Server):
|
|||
|
||||
return urlparse.parse_qs(query)
|
||||
|
||||
def parse_token_request(self, cls=AccessTokenRequest, body=None,
|
||||
extended=False):
|
||||
return oauth2.Server.parse_token_request(self, cls, body, extended)
|
||||
def parse_token_request(self, msg="AccessTokenRequest", body=None):
|
||||
return oauth2.Server.parse_token_request(self, msg, body)
|
||||
|
||||
def parse_authorization_request(self, rcls=AuthorizationRequest,
|
||||
url=None, query=None, extended=False):
|
||||
return oauth2.Server.parse_authorization_request(self, rcls, url,
|
||||
query, extended)
|
||||
def parse_authorization_request(self, msg="AuthorizationRequest",
|
||||
url=None, query=None):
|
||||
return oauth2.Server.parse_authorization_request(self, msg, url, query)
|
||||
|
||||
def parse_jwt_request(self, rcls=AuthorizationRequest, txt="",
|
||||
keys=None, verify=True, extended=False):
|
||||
def parse_jwt_request(self, msg="AuthorizationRequest", txt="",
|
||||
keys=None, verify=True):
|
||||
|
||||
return oauth2.Server.parse_jwt_request(self, rcls, txt,
|
||||
keys, verify, extended)
|
||||
return oauth2.Server.parse_jwt_request(self, msg, txt,
|
||||
keys, verify)
|
||||
|
||||
def parse_refresh_token_request(self, cls=RefreshAccessTokenRequest,
|
||||
body=None, extended=False):
|
||||
return oauth2.Server.parse_refresh_token_request(self, cls, body,
|
||||
extended)
|
||||
def parse_refresh_token_request(self, msg="RefreshAccessTokenRequest",
|
||||
body=None):
|
||||
return oauth2.Server.parse_refresh_token_request(self, msg, body)
|
||||
|
||||
def _deser_id_token(self, str=""):
|
||||
if not str:
|
||||
return None
|
||||
|
||||
|
||||
# have to start decoding the jwt without verifying in order to find
|
||||
# out which key to verify the JWT signature with
|
||||
_ = json.loads(jwt.unpack(str)[1])
|
||||
|
@ -695,7 +713,7 @@ class Server(oauth2.Server):
|
|||
|
||||
keys = self.keystore.get_keys("verify", owner=None)
|
||||
|
||||
return IdToken.set_jwt(str, key=keys)
|
||||
return message("IdToken").from_jwt(str, key=keys)
|
||||
|
||||
def parse_check_session_request(self, url=None, query=None):
|
||||
"""
|
||||
|
@ -713,46 +731,44 @@ class Server(oauth2.Server):
|
|||
assert "access_token" in param # ignore the rest
|
||||
return self._deser_id_token(param["access_token"][0])
|
||||
|
||||
def _parse_request(self, cls, data, format, extended):
|
||||
def _parse_request(self, schema, data, format):
|
||||
if format == "json":
|
||||
request = cls.set_json(data, extended)
|
||||
request = message_from_schema(schema).from_json(data)
|
||||
elif format == "urlencoded":
|
||||
if '?' in data:
|
||||
parts = urlparse.urlparse(data)
|
||||
scheme, netloc, path, params, query, fragment = parts[:6]
|
||||
else:
|
||||
query = data
|
||||
request = cls.set_urlencoded(query, extended)
|
||||
request = message_from_schema(schema).from_urlencoded(query)
|
||||
else:
|
||||
raise Exception("Unknown package format: '%s'" % format)
|
||||
|
||||
request.verify()
|
||||
return request
|
||||
|
||||
def parse_open_id_request(self, data, format="urlencoded", extended=False):
|
||||
return self._parse_request(OpenIDRequest, data, format, extended)
|
||||
|
||||
def parse_user_info_request(self, data, format="urlencoded", extended=False):
|
||||
return self._parse_request(UserInfoRequest, data, format, extended)
|
||||
def parse_open_id_request(self, data, format="urlencoded"):
|
||||
return self._parse_request(SCHEMA["OpenIDRequest"], data, format)
|
||||
|
||||
def parse_refresh_session_request(self, url=None, query=None,
|
||||
extended=False):
|
||||
def parse_user_info_request(self, data, format="urlencoded"):
|
||||
return self._parse_request(SCHEMA["UserInfoRequest"], data, format)
|
||||
|
||||
def parse_refresh_session_request(self, url=None, query=None):
|
||||
if url:
|
||||
parts = urlparse.urlparse(url)
|
||||
scheme, netloc, path, params, query, fragment = parts[:6]
|
||||
|
||||
return RefreshSessionRequest.set_urlencoded(query, extended)
|
||||
return message("RefreshSessionRequest").from_urlencoded(query)
|
||||
|
||||
def parse_registration_request(self, data, format="urlencoded",
|
||||
extended=True):
|
||||
return self._parse_request(RegistrationRequest, data, format, extended)
|
||||
def parse_registration_request(self, data, format="urlencoded"):
|
||||
return self._parse_request(SCHEMA["RegistrationRequest"], data, format)
|
||||
|
||||
def parse_end_session_request(self, query, extended=True):
|
||||
esr = EndSessionRequest.set_urlencoded(query, extended)
|
||||
def parse_end_session_request(self, query):
|
||||
esr = message("EndSessionRequest").from_urlencoded(query)
|
||||
# if there is a id_token in there it is as a string
|
||||
esr.id_token = self._deser_id_token(esr.id_token)
|
||||
esr["id_token"] = self._deser_id_token(esr["id_token"])
|
||||
return esr
|
||||
|
||||
def parse_issuer_request(self, info, format="urlencoded", extended=True):
|
||||
return self._parse_request(IssuerRequest, info, format, extended)
|
||||
|
||||
def parse_issuer_request(self, info, format="urlencoded"):
|
||||
return self._parse_request(SCHEMA["IssuerRequest"], info, format)
|
||||
|
||||
|
|
|
@ -1,83 +1,65 @@
|
|||
|
||||
__author__ = 'rohe0002'
|
||||
|
||||
from oic import oauth2
|
||||
from oic.oic import Server as OicServer
|
||||
from oic.oic import Client
|
||||
from oic.oic.provider import Provider, get_or_post, Endpoint
|
||||
from oic.oauth2.message import SINGLE_REQUIRED_STRING
|
||||
from oic.oauth2.message import SINGLE_OPTIONAL_STRING
|
||||
from oic.oauth2.message import REQUIRED_LIST_OF_STRINGS
|
||||
from oic.oauth2.message import ErrorResponse
|
||||
|
||||
from oic.utils.http_util import Response, Unauthorized
|
||||
from oic.oic import REQUEST2ENDPOINT
|
||||
from oic.oic import RESPONSE2ERROR
|
||||
|
||||
#from oic.oic.message import IdToken, OpenIDSchema
|
||||
from oic.oic.message import Claims, TokenErrorResponse
|
||||
from oic.oic.message import OpenIDSchema
|
||||
from oic.oic.message import UserInfoClaim
|
||||
from oic.oic.message import message
|
||||
#from oic.oic.message import SCHEMA
|
||||
|
||||
from oic.oic.provider import Provider, get_or_post, Endpoint
|
||||
|
||||
from oic.oauth2.message import SINGLE_REQUIRED_STRING, Message
|
||||
from oic.oauth2.message import SINGLE_OPTIONAL_STRING
|
||||
from oic.oauth2.message import REQUIRED_LIST_OF_STRINGS
|
||||
|
||||
from oic.utils.http_util import Response, Unauthorized
|
||||
|
||||
# Used in claims.py
|
||||
from oic.oic.message import RegistrationRequest
|
||||
from oic.oic.message import RegistrationResponse
|
||||
#from oic.oic.message import RegistrationRequest
|
||||
#from oic.oic.message import RegistrationResponse
|
||||
|
||||
class UserClaimsRequest(oauth2.Base):
|
||||
c_attributes = oauth2.Base.c_attributes.copy()
|
||||
c_attributes["user_id"] = SINGLE_REQUIRED_STRING
|
||||
c_attributes["client_id"] = SINGLE_REQUIRED_STRING
|
||||
c_attributes["client_secret"] = SINGLE_REQUIRED_STRING
|
||||
c_attributes["claims_names"] = REQUIRED_LIST_OF_STRINGS
|
||||
def verify(self, **kwargs):
|
||||
if self.jwt:
|
||||
# Try to decode the JWT, checks the signature
|
||||
try:
|
||||
item = message("OpenIDSchema").set_jwt(str(self.jwt),
|
||||
kwargs["key"])
|
||||
except Exception, _err:
|
||||
raise Exception(_err.__class__.__name__)
|
||||
|
||||
def __init__(self,
|
||||
user_id=None,
|
||||
client_id=None,
|
||||
client_secret=None,
|
||||
claims_names=None,
|
||||
**kwargs):
|
||||
oauth2.Base.__init__(self, **kwargs)
|
||||
self.user_id = user_id
|
||||
self.client_id = client_id
|
||||
self.client_secret = client_secret
|
||||
self.claims_names = claims_names
|
||||
if not item.verify(**kwargs):
|
||||
return False
|
||||
|
||||
class UserClaimsResponse(oauth2.Base):
|
||||
c_attributes = oauth2.Base.c_attributes.copy()
|
||||
c_attributes["claims_names"] = REQUIRED_LIST_OF_STRINGS
|
||||
c_attributes["jwt"] = SINGLE_OPTIONAL_STRING
|
||||
c_attributes["endpoint"] = SINGLE_OPTIONAL_STRING
|
||||
c_attributes["access_token"] = SINGLE_OPTIONAL_STRING
|
||||
return super(self.__class__, self).verify(**kwargs)
|
||||
|
||||
def __init__(self,
|
||||
jwt=None,
|
||||
claims_names=None,
|
||||
endpoint=None,
|
||||
access_token=None,
|
||||
**kwargs):
|
||||
oauth2.Base.__init__(self, **kwargs)
|
||||
self.jwt = jwt
|
||||
self.claims_names = claims_names
|
||||
self.endpoint = endpoint
|
||||
self.access_token = access_token
|
||||
|
||||
def verify(self, **kwargs):
|
||||
if self.jwt:
|
||||
# Try to decode the JWT, checks the signature
|
||||
try:
|
||||
item = OpenIDSchema.set_jwt(str(self.jwt), kwargs["key"])
|
||||
except Exception, _err:
|
||||
raise Exception(_err.__class__.__name__)
|
||||
|
||||
if not item.verify(**kwargs):
|
||||
return False
|
||||
|
||||
return oauth2.Base.verify(self, **kwargs)
|
||||
SCHEMA = {
|
||||
"": {"param": {}},
|
||||
"UserClaimsRequest": {
|
||||
"param": {
|
||||
"user_id": SINGLE_REQUIRED_STRING,
|
||||
"client_id": SINGLE_REQUIRED_STRING,
|
||||
"client_secret": SINGLE_REQUIRED_STRING,
|
||||
"claims_names": REQUIRED_LIST_OF_STRINGS
|
||||
},
|
||||
},
|
||||
"UserClaimsResponse": {
|
||||
"param": {
|
||||
"claims_names": REQUIRED_LIST_OF_STRINGS,
|
||||
"jwt": SINGLE_OPTIONAL_STRING,
|
||||
"endpoint": SINGLE_OPTIONAL_STRING,
|
||||
"access_token": SINGLE_OPTIONAL_STRING
|
||||
},
|
||||
"verify": verify,
|
||||
},
|
||||
}
|
||||
|
||||
class OICCServer(OicServer):
|
||||
|
||||
def parse_user_claims_request(self, info, format="urlencoded",
|
||||
extended=True):
|
||||
return self._parse_request(UserClaimsRequest, info, format, extended)
|
||||
def parse_user_claims_request(self, info, format="urlencoded"):
|
||||
return self._parse_request(SCHEMA["UserClaimsRequest"], info, format)
|
||||
|
||||
class ClaimsServer(Provider):
|
||||
|
||||
|
@ -85,8 +67,8 @@ class ClaimsServer(Provider):
|
|||
debug=0, cache=None, timeout=None, proxy_info=None,
|
||||
follow_redirects=True, ca_certs="", jwt_keys=None):
|
||||
Provider.__init__(self, name, sdb, cdb, function, userdb, urlmap,
|
||||
debug, cache, timeout, proxy_info,
|
||||
follow_redirects, ca_certs, jwt_keys)
|
||||
debug, cache, timeout, proxy_info,
|
||||
follow_redirects, ca_certs, jwt_keys)
|
||||
|
||||
if jwt_keys is None:
|
||||
jwt_keys = []
|
||||
|
@ -102,19 +84,16 @@ class ClaimsServer(Provider):
|
|||
def _aggregation(self, info, logger):
|
||||
|
||||
jwt_key = self.keystore.get_sign_key()
|
||||
cresp = UserClaimsResponse(jwt=info.get_jwt(key=jwt_key,
|
||||
algorithm="RS256"),
|
||||
claims_names=info.keys())
|
||||
cresp = Message("UserClaimsResponse", SCHEMA["UserClaimsResponse"],
|
||||
jwt=info.get_jwt(key=jwt_key, algorithm="RS256"),
|
||||
claims_names=info.keys())
|
||||
|
||||
logger.info("RESPONSE: %s" % (cresp.dictionary(),))
|
||||
logger.info("RESPONSE: %s" % (cresp.to_dict(),))
|
||||
return cresp
|
||||
|
||||
#noinspection PyUnusedLocal
|
||||
def _distributed(self, ucreq, logger):
|
||||
cresp = UserClaimsResponse()
|
||||
cresp.endpoint = ""
|
||||
cresp.access_token = ""
|
||||
return cresp
|
||||
return Message("UserClaimsResponse", SCHEMA["UserClaimsResponse"])
|
||||
|
||||
#noinspection PyUnusedLocal
|
||||
def do_aggregation(self, info, uid):
|
||||
|
@ -134,13 +113,13 @@ class ClaimsServer(Provider):
|
|||
|
||||
if not self.function["verify_client"](environ, ucreq, self.cdb):
|
||||
_log_info("could not verify client")
|
||||
err = TokenErrorResponse(error="unathorized_client")
|
||||
err = message("TokenErrorResponse", error="unathorized_client")
|
||||
resp = Unauthorized(err.get_json(), content="application/json")
|
||||
return resp(environ, start_response)
|
||||
|
||||
if ucreq.claims_names:
|
||||
args = dict([(n, {"optional": True}) for n in ucreq.claims_names])
|
||||
uic = UserInfoClaim(claims=Claims(**args))
|
||||
uic = message("UserInfoClaim", claims=message("Claims", **args))
|
||||
else:
|
||||
uic = None
|
||||
|
||||
|
@ -156,7 +135,7 @@ class ClaimsServer(Provider):
|
|||
else:
|
||||
cresp = self._distributed(info, logger)
|
||||
|
||||
resp = Response(cresp.get_json(), content="application/json")
|
||||
resp = Response(cresp.to_json(), content="application/json")
|
||||
return resp(environ, start_response)
|
||||
|
||||
class ClaimsClient(Client):
|
||||
|
@ -176,22 +155,23 @@ class ClaimsClient(Client):
|
|||
self.request2endpoint = REQUEST2ENDPOINT.copy()
|
||||
self.request2endpoint["UserClaimsRequest"] = "userclaims_endpoint"
|
||||
self.response2error = RESPONSE2ERROR.copy()
|
||||
self.response2error[UserClaimsResponse] = [ErrorResponse]
|
||||
self.response2error["UserClaimsResponse"] = ["ErrorResponse"]
|
||||
|
||||
#noinspection PyUnusedLocal
|
||||
def construct_UserClaimsRequest(self, cls=UserClaimsRequest,
|
||||
request_args=None, extra_args=None,
|
||||
**kwargs):
|
||||
def construct_UserClaimsRequest(self, schema=SCHEMA["UserClaimsRequest"],
|
||||
request_args=None, extra_args=None,
|
||||
**kwargs):
|
||||
|
||||
return self.construct_request(cls, request_args, extra_args)
|
||||
return self.construct_request(schema, request_args, extra_args)
|
||||
|
||||
|
||||
def do_claims_request(self, cls=UserClaimsRequest,
|
||||
resp_cls=UserClaimsResponse, body_type="json",
|
||||
def do_claims_request(self, schema=SCHEMA["UserClaimsRequest"],
|
||||
resp_schema=SCHEMA["UserClaimsResponse"],
|
||||
body_type="json",
|
||||
method="POST", request_args=None, extra_args=None,
|
||||
http_args=None):
|
||||
|
||||
url, body, ht_args, csi = self.request_info(cls, method=method,
|
||||
url, body, ht_args, csi = self.request_info(schema, method=method,
|
||||
request_args=request_args,
|
||||
extra_args=extra_args)
|
||||
|
||||
|
@ -200,11 +180,11 @@ class ClaimsClient(Client):
|
|||
else:
|
||||
http_args.update(http_args)
|
||||
|
||||
return self.request_and_return(url, resp_cls, method, body,
|
||||
return self.request_and_return(url, resp_schema, method, body,
|
||||
body_type, extended=False,
|
||||
http_args=http_args,
|
||||
key=self.keystore.pairkeys(
|
||||
self.keystore.match_owner(url)))
|
||||
self.keystore.match_owner(url)))
|
||||
|
||||
class UserClaimsEndpoint(Endpoint) :
|
||||
type = "userclaims"
|
||||
|
|
|
@ -10,26 +10,14 @@ import httplib2
|
|||
from hashlib import md5
|
||||
|
||||
from oic.utils import http_util
|
||||
#from oic.utils.time_util import time_sans_frac
|
||||
|
||||
from oic.oic import Client
|
||||
from oic.oic import ENDPOINTS
|
||||
from oic.oic.message import AuthorizationRequest
|
||||
#from oic.oic.message import IDTokenClaim
|
||||
from oic.oic.message import UserInfoClaim
|
||||
from oic.oic.message import Claims
|
||||
#from oic.oic.message import OpenIDRequest
|
||||
from oic.oic.message import AuthorizationResponse
|
||||
from oic.oic.message import AccessTokenResponse
|
||||
#from oic.oic.message import ProviderConfigurationResponse
|
||||
from oic.oic.message import RegistrationRequest
|
||||
from oic.oic.message import RegistrationResponse
|
||||
from oic.oic.message import IssuerRequest
|
||||
from oic.oic.message import IssuerResponse
|
||||
|
||||
from oic.oauth2.message import ErrorResponse
|
||||
from oic.oic.message import SCHEMA, msg_deser
|
||||
from oic.oic.message import message
|
||||
|
||||
from oic.oauth2 import Grant
|
||||
#from oic.oauth2 import DEF_SIGN_ALG
|
||||
from oic.oauth2 import rndstr
|
||||
|
||||
from oic.oauth2.consumer import TokenError
|
||||
|
@ -85,9 +73,9 @@ def build_userinfo_claims(claims, format="signed", locale="us-en"):
|
|||
}
|
||||
}
|
||||
"""
|
||||
claim = Claims(**claims)
|
||||
claim = message("Claims", **claims)
|
||||
|
||||
return UserInfoClaim(claim, format=format, locale=locale)
|
||||
return message("UserInfoClaim", claims=claim, format=format, locale=locale)
|
||||
|
||||
|
||||
#def construct_openid_request(arq, keys, algorithm=DEF_SIGN_ALG, iss=None,
|
||||
|
@ -127,9 +115,12 @@ def clean_response(aresp):
|
|||
:param aresp: The original AccessTokenResponse
|
||||
:return: An AccessTokenResponse instance
|
||||
"""
|
||||
atr = AccessTokenResponse()
|
||||
for prop in AccessTokenResponse.c_attributes.keys():
|
||||
setattr(atr, prop, getattr(aresp, prop))
|
||||
atr = message("AccessTokenResponse")
|
||||
for prop in atr.parameters():
|
||||
try:
|
||||
atr[prop] = aresp[prop]
|
||||
except KeyError:
|
||||
pass
|
||||
|
||||
return atr
|
||||
|
||||
|
@ -233,7 +224,17 @@ class Consumer(Client):
|
|||
_log_info("- begin -")
|
||||
|
||||
_path = http_util.geturl(environ, False, False)
|
||||
self.redirect_uris = [_path + self.config["authz_page"]]
|
||||
_page = self.config["authz_page"]
|
||||
if not _path.endswith("/"):
|
||||
if _page.startswith("/"):
|
||||
self.redirect_uris = [_path + _page]
|
||||
else:
|
||||
self.redirect_uris = ["%s/%s" % (_path, _page)]
|
||||
else:
|
||||
if _page.startswith("/"):
|
||||
self.redirect_uris = [_path + _page[1:]]
|
||||
else:
|
||||
self.redirect_uris = ["%s/%s" % (_path, _page)]
|
||||
|
||||
# Put myself in the dictionary of sessions, keyed on session-id
|
||||
if not self.seed:
|
||||
|
@ -274,8 +275,8 @@ class Consumer(Client):
|
|||
extra_args=None)
|
||||
|
||||
if self.config["request_method"] == "file":
|
||||
id_request = areq.request
|
||||
areq.request = None
|
||||
id_request = areq["request"]
|
||||
del areq["request"]
|
||||
_filedir = self.config["temp_dir"]
|
||||
_webpath = self.config["temp_path"]
|
||||
_name = rndstr(10)
|
||||
|
@ -287,16 +288,16 @@ class Consumer(Client):
|
|||
fid.write(id_request)
|
||||
fid.close()
|
||||
_webname = "%s%s%s" % (_path,_webpath,_name)
|
||||
areq.request_uri = _webname
|
||||
areq["request_uri"] = _webname
|
||||
self.request_uri = _webname
|
||||
self._backup(sid)
|
||||
else:
|
||||
areq = self.construct_AuthorizationRequest(AuthorizationRequest,
|
||||
request_args=args)
|
||||
areq = self.construct_AuthorizationRequest(
|
||||
SCHEMA["AuthorizationRequest"],
|
||||
request_args=args)
|
||||
|
||||
|
||||
location = "%s?%s" % (self.authorization_endpoint,
|
||||
areq.get_urlencoded())
|
||||
location = areq.request(self.authorization_endpoint)
|
||||
|
||||
if self.debug:
|
||||
_log_info("Redirecting to: %s" % location)
|
||||
|
@ -335,39 +336,41 @@ class Consumer(Client):
|
|||
if "code" in self.config["response_type"]:
|
||||
# Might be an error response
|
||||
_log_info("Expect Authorization Response")
|
||||
aresp = self.parse_response(AuthorizationResponse, info=_query,
|
||||
aresp = self.parse_response(SCHEMA["AuthorizationResponse"],
|
||||
info=_query,
|
||||
format="urlencoded")
|
||||
if isinstance(aresp, ErrorResponse):
|
||||
if aresp.type() == "ErrorResponse":
|
||||
_log_info("ErrorResponse: %s" % aresp)
|
||||
raise AuthzError(aresp.error)
|
||||
|
||||
_log_info("Aresp: %s" % aresp)
|
||||
|
||||
_state = aresp["state"]
|
||||
try:
|
||||
self.update(aresp.state)
|
||||
self.update(_state)
|
||||
except KeyError:
|
||||
raise UnknownState(aresp.state)
|
||||
raise UnknownState(_state)
|
||||
|
||||
self.redirect_uris = [self.sdb[aresp.state]["redirect_uris"]]
|
||||
self.redirect_uris = [self.sdb[_state]["redirect_uris"]]
|
||||
|
||||
# May have token and id_token information too
|
||||
if aresp.access_token:
|
||||
if "access_token" in aresp:
|
||||
atr = clean_response(aresp)
|
||||
self.access_token = atr
|
||||
# update the grant object
|
||||
self.get_grant(state=aresp.state).add_token(atr)
|
||||
self.get_grant(state=_state).add_token(atr)
|
||||
else:
|
||||
atr = None
|
||||
|
||||
self._backup(aresp.state)
|
||||
self._backup(_state)
|
||||
|
||||
idt = None
|
||||
return aresp, atr, idt
|
||||
else: # implicit flow
|
||||
_log_info("Expect Access Token Response")
|
||||
atr = self.parse_response(AccessTokenResponse, info=_query,
|
||||
format="urlencoded", extended=True)
|
||||
if isinstance(atr, ErrorResponse):
|
||||
atr = self.parse_response(SCHEMA["AccessTokenResponse"],
|
||||
info=_query, format="urlencoded")
|
||||
if atr.type() == "ErrorResponse":
|
||||
raise TokenError(atr.error)
|
||||
|
||||
idt = None
|
||||
|
@ -397,7 +400,7 @@ class Consumer(Client):
|
|||
|
||||
logger.info("Access Token Response: %s" % resp)
|
||||
|
||||
if isinstance(resp, ErrorResponse):
|
||||
if resp.type() == "ErrorResponse":
|
||||
raise TokenError(resp.error)
|
||||
|
||||
#self._backup(self.sdb["seed:%s" % _cli.seed])
|
||||
|
@ -413,7 +416,7 @@ class Consumer(Client):
|
|||
self.log = logger
|
||||
uinfo = self.do_user_info_request(state=self.state)
|
||||
|
||||
if isinstance(uinfo, ErrorResponse):
|
||||
if uinfo.type() == "ErrorResponse":
|
||||
raise TokenError(uinfo.error)
|
||||
|
||||
self.user_info = uinfo
|
||||
|
@ -441,10 +444,11 @@ class Consumer(Client):
|
|||
raise
|
||||
|
||||
if response.status == 200:
|
||||
result = IssuerResponse.from_json(content)
|
||||
if result.SWD_service_redirect:
|
||||
_loc = result.SWD_service_redirect.location
|
||||
_uri = IssuerRequest(ISSUER_URL, principal).request(_loc)
|
||||
result = msg_deser(content, "json", "IssuerResponse")
|
||||
if "SWD_service_redirect" in result:
|
||||
_loc = result["SWD_service_redirect"]["location"]
|
||||
_uri = message("IssuerRequest", service=ISSUER_URL,
|
||||
principal=principal).request(_loc)
|
||||
return self.discovery_query(_uri, principal)
|
||||
else:
|
||||
return result
|
||||
|
@ -463,32 +467,35 @@ class Consumer(Client):
|
|||
|
||||
def discover(self, principal, idtype="mail"):
|
||||
_loc = SWD_PATTERN % self.get_domain(principal, idtype)
|
||||
uri = IssuerRequest(ISSUER_URL, principal).request(_loc)
|
||||
uri = message("IssuerRequest", service=ISSUER_URL,
|
||||
principal=principal).request(_loc)
|
||||
|
||||
result = self.discovery_query(uri, principal)
|
||||
return result.locations[0]
|
||||
return result["locations"][0]
|
||||
|
||||
def register(self, server, type="client_associate", **kwargs):
|
||||
req = RegistrationRequest(type=type)
|
||||
req = message("RegistrationRequest", type=type)
|
||||
|
||||
if type == "client_update":
|
||||
req.client_id = self.client_id
|
||||
req.client_secret = self.client_secret
|
||||
req["client_id"] = self.client_id
|
||||
req["client_secret"] = self.client_secret
|
||||
|
||||
for prop in RegistrationRequest.c_attributes.keys():
|
||||
for prop in req.parameters():
|
||||
if prop in ["type", "client_id", "client_secret"]:
|
||||
continue
|
||||
|
||||
try:
|
||||
val = getattr(self, prop)
|
||||
if val:
|
||||
setattr(req, prop, val)
|
||||
req[prop] = val
|
||||
except Exception:
|
||||
val = None
|
||||
|
||||
if not val:
|
||||
if prop in kwargs:
|
||||
setattr(req, prop, kwargs[prop])
|
||||
try:
|
||||
req[prop] = kwargs[prop]
|
||||
except KeyError:
|
||||
pass
|
||||
|
||||
headers = {"content-type": "application/x-www-form-urlencoded"}
|
||||
(response, content) = self.http.request(server, "POST",
|
||||
|
@ -496,12 +503,12 @@ class Consumer(Client):
|
|||
headers=headers)
|
||||
|
||||
if response.status == 200:
|
||||
resp = RegistrationResponse.from_json(content)
|
||||
self.client_secret = resp.client_secret
|
||||
self.client_id = resp.client_id
|
||||
self.registration_expires = resp.expires_at
|
||||
resp = msg_deser(content, "json", "RegistrationResponse")
|
||||
self.client_secret = resp["client_secret"]
|
||||
self.client_id = resp["client_id"]
|
||||
self.registration_expires = resp["expires_at"]
|
||||
else:
|
||||
err = ErrorResponse.from_json(content)
|
||||
err = msg_deser(content, "json", "ErrorResponse")
|
||||
raise Exception("Registration failed: %s" % err.get_json())
|
||||
|
||||
return resp
|
File diff suppressed because it is too large
Load Diff
|
@ -3,11 +3,8 @@
|
|||
__author__ = 'rohe0002'
|
||||
|
||||
import random
|
||||
#import httplib2
|
||||
import base64
|
||||
|
||||
#from random import SystemRandom
|
||||
|
||||
from urlparse import parse_qs
|
||||
|
||||
from oic.oauth2.provider import Provider as AProvider
|
||||
|
@ -16,27 +13,17 @@ from oic.utils.http_util import *
|
|||
from oic.utils import time_util
|
||||
|
||||
from oic.oauth2 import MissingRequiredAttribute
|
||||
from oic.oauth2.provider import AuthnFailure
|
||||
from oic.oauth2 import rndstr
|
||||
from oic.oauth2.message import ErrorResponse
|
||||
from oic.oauth2.provider import AuthnFailure
|
||||
from oic.oauth2.message import by_schema
|
||||
from oic.oauth2.message import SCHEMA as OA2_SCHEMA
|
||||
|
||||
from oic.oic import Server
|
||||
|
||||
from oic.oic.message import AuthorizationResponse, AuthnToken
|
||||
from oic.oic.message import AuthorizationErrorResponse
|
||||
from oic.oic.message import msg_deser
|
||||
from oic.oic.message import SCOPE2CLAIMS
|
||||
from oic.oic.message import AuthorizationRequest
|
||||
from oic.oic.message import AccessTokenResponse
|
||||
from oic.oic.message import AccessTokenRequest
|
||||
from oic.oic.message import TokenErrorResponse
|
||||
from oic.oic.message import OpenIDRequest
|
||||
from oic.oic.message import IdToken
|
||||
from oic.oic.message import RegistrationRequest
|
||||
from oic.oic.message import RegistrationResponse
|
||||
from oic.oic.message import ProviderConfigurationResponse
|
||||
from oic.oic.message import UserInfoClaim
|
||||
|
||||
from oic import oauth2
|
||||
from oic.oic.message import message
|
||||
from oic.oic.message import SCHEMA
|
||||
from oic.oic import JWT_BEARER
|
||||
|
||||
class OICError(Exception):
|
||||
|
@ -93,53 +80,34 @@ def secret(seed, id):
|
|||
csum.update(id)
|
||||
return csum.hexdigest()
|
||||
|
||||
##noinspection PyUnusedLocal
|
||||
#def code_response(**kwargs):
|
||||
# _areq = kwargs["areq"]
|
||||
# _scode = kwargs["scode"]
|
||||
# aresp = AuthorizationResponse()
|
||||
# if _areq.state:
|
||||
# aresp.state = _areq.state
|
||||
# if _areq.nonce:
|
||||
# aresp.nonce = _areq.nonce
|
||||
# aresp.code = _scode
|
||||
# return aresp
|
||||
#
|
||||
#def token_response(**kwargs):
|
||||
# _areq = kwargs["areq"]
|
||||
# _scode = kwargs["scode"]
|
||||
# _sdb = kwargs["sdb"]
|
||||
# _dic = _sdb.update_to_token(_scode, issue_refresh=False)
|
||||
#
|
||||
# aresp = oauth2.factory(AccessTokenResponse, **_dic)
|
||||
# if _areq.scope:
|
||||
# aresp.scope = _areq.scope
|
||||
# return aresp
|
||||
|
||||
def add_token_info(aresp, sdict):
|
||||
for prop in AccessTokenResponse.c_attributes.keys():
|
||||
try:
|
||||
if sdict[prop]:
|
||||
setattr(aresp, prop, sdict[prop])
|
||||
except KeyError:
|
||||
pass
|
||||
#def update_info(aresp, sdict):
|
||||
# for prop in aresp._schema["param"].keys():
|
||||
# try:
|
||||
# aresp[prop] = sdict[prop]
|
||||
# except KeyError:
|
||||
# pass
|
||||
|
||||
def code_token_response(**kwargs):
|
||||
_areq = kwargs["areq"]
|
||||
_scode = kwargs["scode"]
|
||||
_sdb = kwargs["sdb"]
|
||||
aresp = AuthorizationResponse()
|
||||
if _areq.state:
|
||||
aresp.state = _areq.state
|
||||
if _areq.nonce:
|
||||
aresp.nonce = _areq.nonce
|
||||
if _areq.scope:
|
||||
aresp.scope = _areq.scope
|
||||
|
||||
aresp.code = _scode
|
||||
aresp = message("AuthorizationResponse")
|
||||
|
||||
for key in ["state", "nonce", "scope"]:
|
||||
try:
|
||||
aresp[key] = _areq[key]
|
||||
except KeyError:
|
||||
pass
|
||||
|
||||
aresp["code"] = _scode
|
||||
|
||||
_dic = _sdb.update_to_token(_scode, issue_refresh=False)
|
||||
add_token_info(aresp, _dic)
|
||||
for prop in SCHEMA["AccessTokenResponse"]["param"].keys():
|
||||
try:
|
||||
aresp[prop] = _dic[prop]
|
||||
except KeyError:
|
||||
pass
|
||||
|
||||
return aresp
|
||||
|
||||
|
@ -170,8 +138,6 @@ def verify_acr_level(req, level):
|
|||
raise AccessDenied
|
||||
|
||||
class Provider(AProvider):
|
||||
authorization_request = AuthorizationRequest
|
||||
|
||||
def __init__(self, name, sdb, cdb, function, userdb, urlmap=None,
|
||||
debug=0, cache=None, timeout=None, proxy_info=None,
|
||||
follow_redirects=True, ca_certs="", jwt_keys=None):
|
||||
|
@ -204,13 +170,15 @@ class Provider(AProvider):
|
|||
# Handle the idtoken_claims
|
||||
extra = {}
|
||||
try:
|
||||
oidreq = OpenIDRequest.from_json(session["oidreq"])
|
||||
itc = oidreq.id_token
|
||||
info_log("ID Token claims: %s" % itc.dictionary())
|
||||
if itc.max_age:
|
||||
inawhile = {"seconds": itc.max_age}
|
||||
if itc.claims:
|
||||
for key, val in itc.claims.items():
|
||||
oidreq = msg_deser(session["oidreq"], "json", "OpenIDRequest")
|
||||
itc = oidreq["id_token"]
|
||||
info_log("ID Token claims: %s" % itc.to_dict())
|
||||
try:
|
||||
inawhile = {"seconds": itc["max_age"]}
|
||||
except KeyError:
|
||||
inawhile = {}
|
||||
if "claims" in itc:
|
||||
for key, val in itc["claims"].items():
|
||||
if key == "auth_time":
|
||||
extra["auth_time"] = time_util.utc_time_sans_frac()
|
||||
elif key == "acr":
|
||||
|
@ -219,14 +187,12 @@ class Provider(AProvider):
|
|||
except KeyError:
|
||||
pass
|
||||
|
||||
idt = IdToken(iss=self.name,
|
||||
user_id=session["user_id"],
|
||||
aud = session["client_id"],
|
||||
exp = time_util.epoch_in_a_while(**inawhile),
|
||||
acr=loa,
|
||||
)
|
||||
idt = message("IdToken", iss=self.name, user_id=session["user_id"],
|
||||
aud = session["client_id"],
|
||||
exp = time_util.epoch_in_a_while(**inawhile), acr=loa)
|
||||
|
||||
for key, val in extra.items():
|
||||
setattr(idt, key, val)
|
||||
idt[key] = val
|
||||
|
||||
if "nonce" in session:
|
||||
idt.nonce = session["nonce"]
|
||||
|
@ -241,17 +207,19 @@ class Provider(AProvider):
|
|||
if info_log:
|
||||
info_log("Sign idtoken with '%s'" % ckey)
|
||||
|
||||
return idt.get_jwt(key=ckey)
|
||||
return idt.to_jwt(key=ckey)
|
||||
|
||||
def _error(self, environ, start_response, error, descr=None):
|
||||
response = ErrorResponse(error=error, error_description=descr)
|
||||
resp = Response(response.get_json(), content="application/json")
|
||||
response = message(OA2_SCHEMA["ErrorResponse"], error=error,
|
||||
error_description=descr)
|
||||
resp = Response(response.to_json(), content="application/json")
|
||||
return resp(environ, start_response)
|
||||
|
||||
def _authz_error(self, environ, start_response, error, descr=None):
|
||||
response = AuthorizationErrorResponse(error=error,
|
||||
error_description=descr)
|
||||
resp = Response(response.get_json(), content="application/json")
|
||||
response = message("AuthorizationErrorResponse", error=error,
|
||||
error_description=descr)
|
||||
|
||||
resp = Response(response.to_json(), content="application/json")
|
||||
return resp(environ, start_response)
|
||||
|
||||
def authorization_endpoint(self, environ, start_response, logger,
|
||||
|
@ -276,8 +244,7 @@ class Provider(AProvider):
|
|||
|
||||
# Same serialization used for GET and POST
|
||||
try:
|
||||
areq = self.server.parse_authorization_request(query=query,
|
||||
extended=True)
|
||||
areq = self.server.parse_authorization_request(query=query)
|
||||
except MissingRequiredAttribute, err:
|
||||
resp = BadRequest("%s" % err)
|
||||
return resp(environ, start_response)
|
||||
|
@ -285,54 +252,64 @@ class Provider(AProvider):
|
|||
resp = BadRequest("%s" % err)
|
||||
return resp(environ, start_response)
|
||||
|
||||
if self.debug:
|
||||
_log_info("Prompt: '%s'" % areq.prompt)
|
||||
|
||||
if "none" in areq.prompt:
|
||||
if len(areq.prompt) > 1:
|
||||
return self._error(environ, start_response, "invalid_request")
|
||||
else:
|
||||
return self._authz_error(environ, start_response,
|
||||
"login_required")
|
||||
if "prompt" in areq:
|
||||
if self.debug:
|
||||
_log_info("Prompt: '%s'" % areq["prompt"])
|
||||
|
||||
if "none" in areq["prompt"]:
|
||||
if len(areq["prompt"]) > 1:
|
||||
return self._error(environ, start_response,
|
||||
"invalid_request")
|
||||
else:
|
||||
return self._authz_error(environ, start_response,
|
||||
"login_required")
|
||||
|
||||
|
||||
if areq.client_id not in self.cdb:
|
||||
raise UnknownClient(areq.client_id)
|
||||
if areq["client_id"] not in self.cdb:
|
||||
raise UnknownClient(areq["client_id"])
|
||||
|
||||
# verify that the redirect URI is resonable
|
||||
if areq.redirect_uri:
|
||||
if "redirect_uri" in areq:
|
||||
try:
|
||||
assert areq.redirect_uri in self.cdb[
|
||||
areq.client_id]["redirect_uris"]
|
||||
assert areq["redirect_uri"] in self.cdb[
|
||||
areq["client_id"]]["redirect_uris"]
|
||||
except AssertionError:
|
||||
return self._authz_error(environ, start_response,
|
||||
"invalid_request_redirect_uri")
|
||||
|
||||
if self.debug:
|
||||
_log_info("AREQ keys: %s" % areq.keys())
|
||||
# Is there an request decode it
|
||||
openid_req = None
|
||||
if "request" in areq or "request_uri" in areq:
|
||||
if self.debug:
|
||||
_log_info("OpenID request")
|
||||
try:
|
||||
_keystore = self.server.keystore
|
||||
jwt_key = _keystore.get_keys("verify", owner=None)
|
||||
except KeyError: # TODO
|
||||
raise KeyError("Missing verifying key")
|
||||
|
||||
if areq.request:
|
||||
if "request" in areq:
|
||||
try:
|
||||
openid_req = OpenIDRequest.set_jwt(areq.request, jwt_key)
|
||||
openid_req = message("OpenIDRequest").from_jwt(
|
||||
areq["request"],
|
||||
jwt_key)
|
||||
except Exception:
|
||||
return self._authz_error(environ, start_response,
|
||||
"invalid_openid_request_object")
|
||||
|
||||
elif areq.request_uri:
|
||||
elif "request_uri" in areq:
|
||||
# Do a HTTP get
|
||||
_req = self.http.request(areq.request_uri)
|
||||
_req = self.http.request(areq["request_uri"])
|
||||
if not _req:
|
||||
return self._authz_error(environ, start_response,
|
||||
"invalid_request_uri")
|
||||
|
||||
try:
|
||||
openid_req = OpenIDRequest.set_jwt(_req, jwt_key)
|
||||
openid_req = message("OpenIDRequest").from_jwt(_req,
|
||||
jwt_key)
|
||||
except Exception:
|
||||
return self._authz_error(environ, start_response,
|
||||
"invalid_openid_request_object")
|
||||
|
@ -346,10 +323,10 @@ class Provider(AProvider):
|
|||
_log_info("SID:%s" % bsid)
|
||||
|
||||
if openid_req:
|
||||
_max_age = -1
|
||||
if openid_req.id_token:
|
||||
if openid_req.id_token.max_age:
|
||||
_max_age = openid_req.id_token.max_age
|
||||
try:
|
||||
_max_age = openid_req["id_token"]["max_age"]
|
||||
except KeyError:
|
||||
_max_age = -1
|
||||
|
||||
if _max_age >= 0:
|
||||
if "handle" in kwargs:
|
||||
|
@ -365,6 +342,19 @@ class Provider(AProvider):
|
|||
areq=areq, user=user)
|
||||
except ValueError:
|
||||
pass
|
||||
else:
|
||||
if "handle" in kwargs:
|
||||
try:
|
||||
(b64sid, timestamp) = kwargs["handle"]
|
||||
_log_info("- SSO -")
|
||||
_scode = base64.b64decode(b64sid)
|
||||
user = self.sdb[_scode]["user_id"]
|
||||
_sdb.update(sid, "user_id", user)
|
||||
return self.authenticated(environ, start_response,
|
||||
logger, active_auth=bsid,
|
||||
areq=areq, user=user)
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
# DEFAULT: start the authentication process
|
||||
return self.function["authenticate"](environ, start_response, bsid)
|
||||
|
@ -377,32 +367,32 @@ class Provider(AProvider):
|
|||
except AuthnFailure:
|
||||
pass
|
||||
|
||||
if areq.client_id not in self.cdb:
|
||||
if areq["client_id"] not in self.cdb:
|
||||
return False
|
||||
|
||||
if areq.client_secret: # client_secret_post
|
||||
identity = areq.client_id
|
||||
if self.cdb[identity]["client_secret"] == areq.client_secret:
|
||||
if "client_secret" in areq: # client_secret_post
|
||||
identity = areq["client_id"]
|
||||
if self.cdb[identity]["client_secret"] == areq["client_secret"]:
|
||||
return True
|
||||
elif areq.client_assertion: # client_secret_jwt or public_key_jwt
|
||||
if areq.client_assertion_type != JWT_BEARER:
|
||||
elif "client_assertion" in areq: # client_secret_jwt or public_key_jwt
|
||||
if areq["client_assertion_type"] != JWT_BEARER:
|
||||
return False
|
||||
|
||||
key_col = {areq.client_id:
|
||||
self.keystore.get_verify_key(owner=areq.client_id)}
|
||||
key_col = {areq["client_id"]:
|
||||
self.keystore.get_verify_key(owner=areq["client_id"])}
|
||||
key_col.update({".":self.keystore.get_verify_key()})
|
||||
|
||||
if log_info:
|
||||
log_info("key_col: %s" % (key_col,))
|
||||
|
||||
bjwt = AuthnToken.set_jwt(areq.client_assertion, key_col)
|
||||
bjwt = message("AuthnToken").from_jwt(areq["client_assertion"],
|
||||
key_col)
|
||||
|
||||
try:
|
||||
assert bjwt.iss == areq.client_id # Issuer = the client
|
||||
assert bjwt["iss"] == areq["client_id"] # Issuer = the client
|
||||
# Is this true bjwt.iss == areq.client_id
|
||||
assert str(bjwt.iss) in self.cdb # It's a client I know
|
||||
assert str(bjwt.aud) == geturl(environ,
|
||||
query=False) # audience = me
|
||||
assert str(bjwt["iss"]) in self.cdb # It's a client I know
|
||||
assert str(bjwt["aud"]) == geturl(environ, query=False)
|
||||
return True
|
||||
except AssertionError:
|
||||
pass
|
||||
|
@ -424,29 +414,29 @@ class Provider(AProvider):
|
|||
if self.debug:
|
||||
_log_info("body: %s" % body)
|
||||
|
||||
areq = AccessTokenRequest.set_urlencoded(body, extended=True)
|
||||
areq = msg_deser(body, "urlencoded", "AccessTokenRequest")
|
||||
|
||||
if self.debug:
|
||||
_log_info("environ: %s" % environ)
|
||||
|
||||
if not self.verify_client(environ, areq, _log_info):
|
||||
_log_info("could not verify client")
|
||||
err = TokenErrorResponse(error="unathorized_client")
|
||||
resp = Unauthorized(err.get_json(), content="application/json")
|
||||
err = message("TokenErrorResponse", error="unathorized_client")
|
||||
resp = Unauthorized(err.to_json(), content="application/json")
|
||||
return resp(environ, start_response)
|
||||
|
||||
if self.debug:
|
||||
_log_info("AccessTokenRequest: %s" % areq)
|
||||
|
||||
assert areq.grant_type == "authorization_code"
|
||||
assert areq["grant_type"] == "authorization_code"
|
||||
|
||||
# assert that the code is valid
|
||||
_info = _sdb[areq.code]
|
||||
_info = _sdb[areq["code"]]
|
||||
|
||||
# If redirect_uri was in the initial authorization request
|
||||
# verify that the one given here is the correct one.
|
||||
if "redirect_uri" in _info:
|
||||
assert areq.redirect_uri == _info["redirect_uri"]
|
||||
assert areq["redirect_uri"] == _info["redirect_uri"]
|
||||
|
||||
if self.debug:
|
||||
_log_info("All checks OK")
|
||||
|
@ -461,7 +451,7 @@ class Provider(AProvider):
|
|||
_idtoken = None
|
||||
|
||||
try:
|
||||
_tinfo = _sdb.update_to_token(areq.code, id_token=_idtoken)
|
||||
_tinfo = _sdb.update_to_token(areq["code"], id_token=_idtoken)
|
||||
except Exception,err:
|
||||
_log_info("Error: %s" % err)
|
||||
raise
|
||||
|
@ -469,7 +459,8 @@ class Provider(AProvider):
|
|||
if self.debug:
|
||||
_log_info("_tinfo: %s" % _tinfo)
|
||||
|
||||
atr = oauth2.factory(AccessTokenResponse, **_tinfo)
|
||||
atr = message("AccessTokenResponse",
|
||||
**by_schema(SCHEMA["AccessTokenResponse"], **_tinfo))
|
||||
|
||||
if self.debug:
|
||||
_log_info("AccessTokenResponse: %s" % atr)
|
||||
|
@ -510,7 +501,7 @@ class Provider(AProvider):
|
|||
else:
|
||||
uireq = self.server.parse_user_info_request(data=query)
|
||||
_log_info("user_info_request: %s" % uireq)
|
||||
_token = uireq.access_token
|
||||
_token = uireq["access_token"]
|
||||
|
||||
# should be an access token
|
||||
typ, key = self.sdb.token.type_and_key(_token)
|
||||
|
@ -534,23 +525,25 @@ class Provider(AProvider):
|
|||
except KeyError:
|
||||
pass
|
||||
|
||||
try:
|
||||
_req = session["oidreq"]
|
||||
_log_info("OIDREQ: %s" % _req)
|
||||
oidreq = OpenIDRequest.from_json(_req)
|
||||
userinfo_claims = oidreq.userinfo
|
||||
if userinfo_claims:
|
||||
_claim = userinfo_claims.claims
|
||||
if "oidreq" in session:
|
||||
oidreq = msg_deser(session["oidreq"], "json", "OpenIDRequest")
|
||||
_log_info("OIDREQ: %s" % oidreq.to_dict())
|
||||
if "userinfo" in oidreq:
|
||||
userinfo_claims = oidreq["userinfo"]
|
||||
_claim = oidreq["userinfo"]["claims"]
|
||||
for key, val in uic.items():
|
||||
if key not in _claim:
|
||||
setattr(_claim, key, val)
|
||||
except KeyError:
|
||||
if uic:
|
||||
userinfo_claims = UserInfoClaim(claims=uic)
|
||||
_claim[key] = val
|
||||
elif uic:
|
||||
userinfo_claims = message("UserInfoClaim", claims=uic)
|
||||
else:
|
||||
userinfo_claims = None
|
||||
userinfo_claims = None
|
||||
elif uic:
|
||||
userinfo_claims = message("UserInfoClaim", claims=uic)
|
||||
else:
|
||||
userinfo_claims = None
|
||||
|
||||
_log_info("userinfo_claim: %s" % userinfo_claims)
|
||||
_log_info("userinfo_claim: %s" % userinfo_claims.to_dict())
|
||||
_log_info("userdb: %s" % self.userdb.keys())
|
||||
#logger.info("oidreq: %s[%s]" % (oidreq, type(oidreq)))
|
||||
info = self.function["userinfo"](self, self.userdb,
|
||||
|
@ -559,7 +552,7 @@ class Provider(AProvider):
|
|||
userinfo_claims)
|
||||
|
||||
_log_info("info: %s" % (info,))
|
||||
resp = Response(info.get_json(), content="application/json")
|
||||
resp = Response(info.to_json(), content="application/json")
|
||||
return resp(environ, start_response)
|
||||
|
||||
#noinspection PyUnusedLocal
|
||||
|
@ -576,7 +569,7 @@ class Provider(AProvider):
|
|||
|
||||
idt = self.server.parse_check_session_request(query=info)
|
||||
|
||||
resp = Response(idt.get_json(), content="application/json")
|
||||
resp = Response(idt.to_json(), content="application/json")
|
||||
return resp(environ, start_response)
|
||||
|
||||
#noinspection PyUnusedLocal
|
||||
|
@ -588,12 +581,17 @@ class Provider(AProvider):
|
|||
resp = BadRequest("Unsupported method")
|
||||
return resp(environ, start_response)
|
||||
|
||||
logger.info("environ: %s" % environ)
|
||||
logger.info("info: '%s'" % info)
|
||||
if not info:
|
||||
logger.info("HTTP_AUTHORIZATION: %s" %
|
||||
environ["HTTP_AUTHORIZATION"])
|
||||
info = "access_token=%s" % self._bearer_auth(environ)
|
||||
|
||||
logger.info("check_id_endpoint: query=%s" % info)
|
||||
idt = self.server.parse_check_id_request(query=info)
|
||||
|
||||
resp = Response(idt.get_json(), content="application/json")
|
||||
resp = Response(idt.to_json(), content="application/json")
|
||||
return resp(environ, start_response)
|
||||
|
||||
#noinspection PyUnusedLocal
|
||||
|
@ -605,11 +603,11 @@ class Provider(AProvider):
|
|||
resp = BadRequest("Unsupported method")
|
||||
return resp(environ, start_response)
|
||||
|
||||
request = RegistrationRequest.from_urlencoded(query)
|
||||
logger.info("RegistrationRequest:%s" % request.dictionary())
|
||||
request = msg_deser(query, "urlencoded", "RegistrationRequest")
|
||||
logger.info("RegistrationRequest:%s" % request.to_dict())
|
||||
|
||||
_keystore = self.server.keystore
|
||||
if request.type == "client_associate":
|
||||
if request["type"] == "client_associate":
|
||||
# create new id och secret
|
||||
client_id = rndstr(12)
|
||||
while client_id in self.cdb:
|
||||
|
@ -621,15 +619,15 @@ class Provider(AProvider):
|
|||
}
|
||||
_cinfo = self.cdb[client_id]
|
||||
|
||||
for key,val in request.dictionary().items():
|
||||
for key,val in request.items():
|
||||
_cinfo[key] = val
|
||||
|
||||
self.keystore.load_keys(request, client_id)
|
||||
logger.info("KEYSTORE: %s" % self.keystore._store)
|
||||
|
||||
elif request.type == "client_update":
|
||||
elif request["type"] == "client_update":
|
||||
# that these are an id,secret pair I know about
|
||||
client_id = request.client_id
|
||||
client_id = request["client_id"]
|
||||
try:
|
||||
_cinfo = self.cdb[client_id]
|
||||
except KeyError:
|
||||
|
@ -637,7 +635,7 @@ class Provider(AProvider):
|
|||
resp = BadRequest()
|
||||
return resp(environ, start_response)
|
||||
|
||||
if _cinfo["client_secret"] != request.client_secret:
|
||||
if _cinfo["client_secret"] != request["client_secret"]:
|
||||
logger.info("Wrong secret")
|
||||
resp = BadRequest()
|
||||
return resp(environ, start_response)
|
||||
|
@ -646,12 +644,12 @@ class Provider(AProvider):
|
|||
client_secret = secret(self.seed, client_id)
|
||||
_cinfo["client_secret"] = client_secret
|
||||
|
||||
old_key = request.client_secret
|
||||
old_key = request["client_secret"]
|
||||
_keystore.remove_key(old_key, client_id, type="hmac", usage="sign")
|
||||
_keystore.remove_key(old_key, client_id, type="hmac",
|
||||
usage="verify")
|
||||
|
||||
for key,val in request.dictionary().items():
|
||||
for key,val in request.items():
|
||||
if key in ["client_id", "client_secret"]:
|
||||
continue
|
||||
|
||||
|
@ -670,10 +668,11 @@ class Provider(AProvider):
|
|||
|
||||
# set expiration time
|
||||
_cinfo["registration_expires"] = time_util.time_sans_frac()+3600
|
||||
response = RegistrationResponse(client_id, client_secret,
|
||||
expires_in=3600)
|
||||
response = message("RegistrationResponse", client_id=client_id,
|
||||
client_secret=client_secret,
|
||||
expires_at=_cinfo["registration_expires"])
|
||||
|
||||
logger.info("Registration response: %s" % response.dictionary())
|
||||
logger.info("Registration response: %s" % response.to_dict())
|
||||
|
||||
resp = Response(response.to_json(), content="application/json",
|
||||
headers=[("Cache-Control", "no-store")])
|
||||
|
@ -681,7 +680,7 @@ class Provider(AProvider):
|
|||
|
||||
#noinspection PyUnusedLocal
|
||||
def providerinfo_endpoint(self, environ, start_response, logger, *args):
|
||||
_response = ProviderConfigurationResponse(
|
||||
_response = message("ProviderConfigurationResponse",
|
||||
issuer=self.baseurl,
|
||||
token_endpoint_auth_types_supported=["client_secret_post",
|
||||
"client_secret_basic",
|
||||
|
@ -696,9 +695,9 @@ class Provider(AProvider):
|
|||
|
||||
#keys = self.keystore.keys_by_owner(owner=".")
|
||||
for cert in self.cert:
|
||||
setattr(_response, "x509_url", "%s%s" % (self.baseurl, cert))
|
||||
_response["x509_url"] = "%s%s" % (self.baseurl, cert)
|
||||
for jwk in self.jwk:
|
||||
setattr(_response, "jwk_url", "%s%s" % (self.baseurl, jwk))
|
||||
_response["jwk_url"] = "%s%s" % (self.baseurl, jwk)
|
||||
|
||||
if not self.baseurl.endswith("/"):
|
||||
self.baseurl += "/"
|
||||
|
@ -707,8 +706,8 @@ class Provider(AProvider):
|
|||
logger.info("# %s, %s" % (endp, endp.name))
|
||||
_response[endp.name] = "%s%s" % (self.baseurl, endp.type)
|
||||
|
||||
logger.info("provider_info_response: %s" % _response.dictionary(True))
|
||||
resp = Response(_response.to_json(True), content="application/json",
|
||||
logger.info("provider_info_response: %s" % _response.to_dict())
|
||||
resp = Response(_response.to_json(), content="application/json",
|
||||
headers=[("Cache-Control", "no-store")])
|
||||
return resp(environ, start_response)
|
||||
|
||||
|
@ -761,8 +760,8 @@ class Provider(AProvider):
|
|||
#_log_info( "type: %s" % type(asession["authzreq"]))
|
||||
|
||||
# pick up the original request
|
||||
areq = AuthorizationRequest.set_json(asession["authzreq"],
|
||||
extended=True)
|
||||
areq = msg_deser(asession["authzreq"], "json",
|
||||
"AuthorizationRequest")
|
||||
|
||||
if self.debug:
|
||||
_log_info("areq: %s" % areq)
|
||||
|
@ -775,64 +774,67 @@ class Provider(AProvider):
|
|||
except Exception:
|
||||
raise
|
||||
|
||||
_log_info("response type: %s" % areq.response_type)
|
||||
_log_info("response type: %s" % areq["response_type"])
|
||||
|
||||
# create the response
|
||||
aresp = AuthorizationResponse()
|
||||
if areq.state:
|
||||
aresp.state = areq.state
|
||||
aresp = message("AuthorizationResponse")
|
||||
try:
|
||||
aresp["state"] = areq["state"]
|
||||
except KeyError:
|
||||
pass
|
||||
|
||||
if len(areq.response_type) == 1 and "none" in areq.response_type:
|
||||
if "response_type" in areq and \
|
||||
len(areq["response_type"]) == 1 and \
|
||||
"none" in areq["response_type"]:
|
||||
pass
|
||||
else:
|
||||
_sinfo = self.sdb[scode]
|
||||
|
||||
if areq.scope:
|
||||
aresp.scope = areq.scope
|
||||
try:
|
||||
aresp["scope"] = areq["scope"]
|
||||
except KeyError:
|
||||
pass
|
||||
|
||||
if self.debug:
|
||||
_log_info("_dic: %s" % _sinfo)
|
||||
|
||||
rtype = set(areq.response_type[:])
|
||||
if "code" in areq.response_type:
|
||||
aresp.code = _sinfo["code"]
|
||||
aresp.c_extension = areq.c_extension
|
||||
rtype = set(areq["response_type"][:])
|
||||
if "code" in areq["response_type"]:
|
||||
aresp["code"] = _sinfo["code"]
|
||||
rtype.remove("code")
|
||||
else:
|
||||
self.sdb[scode]["code"] = None
|
||||
|
||||
if "id_token" in areq.response_type:
|
||||
if "id_token" in areq["response_type"]:
|
||||
id_token = self._id_token(_sinfo, info_log=_log_info)
|
||||
aresp.id_token = id_token
|
||||
aresp["id_token"] = id_token
|
||||
_sinfo["id_token"] = id_token
|
||||
rtype.remove("id_token")
|
||||
|
||||
if "token" in areq.response_type:
|
||||
if "token" in areq["response_type"]:
|
||||
_dic = self.sdb.update_to_token(issue_refresh=False,
|
||||
key=scode)
|
||||
|
||||
if self.debug:
|
||||
_log_info("_dic: %s" % _dic)
|
||||
for key, val in _dic.items():
|
||||
if key in aresp.c_attributes:
|
||||
setattr(aresp, key, val)
|
||||
if key in aresp.parameters():
|
||||
aresp[key] = val
|
||||
|
||||
aresp.c_extension = areq.c_extension
|
||||
rtype.remove("token")
|
||||
|
||||
if len(rtype):
|
||||
resp = BadRequest("Unknown response type")
|
||||
return resp(environ, start_response)
|
||||
|
||||
if areq.redirect_uri:
|
||||
assert areq.redirect_uri in self.cdb[
|
||||
areq.client_id]["redirect_uris"]
|
||||
redirect_uri = areq.redirect_uri
|
||||
if "redirect_uri" in areq:
|
||||
assert areq["redirect_uri"] in self.cdb[
|
||||
areq["client_id"]]["redirect_uris"]
|
||||
redirect_uri = areq["redirect_uri"]
|
||||
else:
|
||||
redirect_uri = self.cdb[areq.client_id]["redirect_uris"][0]
|
||||
|
||||
location = "%s?%s" % (redirect_uri, aresp.get_urlencoded())
|
||||
redirect_uri = self.cdb[areq["client_id"]]["redirect_uris"][0]
|
||||
|
||||
location = aresp.request(redirect_uri)
|
||||
|
||||
if self.debug:
|
||||
_log_info("Redirected to: '%s' (%s)" % (location, type(location)))
|
||||
|
@ -874,4 +876,4 @@ class CheckIDEndpoint(Endpoint):
|
|||
type = "check_id"
|
||||
|
||||
class RegistrationEndpoint(Endpoint) :
|
||||
type = "registration"
|
||||
type = "registration"
|
|
@ -64,13 +64,20 @@ class Token(object):
|
|||
csum.update("%f" % random.random())
|
||||
if user:
|
||||
csum.update(user)
|
||||
|
||||
if areq:
|
||||
csum.update(areq.state)
|
||||
if areq.scope:
|
||||
for val in areq.scope:
|
||||
csum.update(areq["state"])
|
||||
try:
|
||||
for val in areq["scope"]:
|
||||
csum.update(val)
|
||||
if areq.redirect_uri:
|
||||
csum.update(areq.redirect_uri)
|
||||
except KeyError:
|
||||
pass
|
||||
|
||||
try:
|
||||
csum.update(areq["redirect_uri"])
|
||||
except KeyError:
|
||||
pass
|
||||
|
||||
return csum.digest() # 28 bytes long, 224 bits
|
||||
|
||||
def _split_token(self, token):
|
||||
|
@ -161,33 +168,30 @@ class SessionDB(object):
|
|||
"user_id": user_id,
|
||||
"code": access_grant,
|
||||
"code_used": False,
|
||||
"authzreq": areq.get_json(),
|
||||
"client_id": areq.client_id,
|
||||
"authzreq": areq.to_json(),
|
||||
"client_id": areq["client_id"],
|
||||
"expires_in": self.grant_expires_in,
|
||||
"expires_at": utc_time_sans_frac()+self.grant_expires_in,
|
||||
"issued": time.time()
|
||||
}
|
||||
|
||||
try:
|
||||
_val = areq.nonce
|
||||
_val = areq["nonce"]
|
||||
if _val:
|
||||
_dic["nonce"] = _val
|
||||
except (AttributeError, KeyError):
|
||||
pass
|
||||
|
||||
if areq.redirect_uri:
|
||||
_dic["redirect_uri"] = areq.redirect_uri
|
||||
if areq.state:
|
||||
_dic["state"] = areq.state
|
||||
|
||||
# Just an assumption
|
||||
if areq.scope:
|
||||
_dic["scope"] = areq.scope
|
||||
for key in ["redirect_uri", "state", "scope"]:
|
||||
try:
|
||||
_dic[key] = areq[key]
|
||||
except KeyError:
|
||||
pass
|
||||
|
||||
if id_token:
|
||||
_dic["id_token"] = id_token
|
||||
if oidreq:
|
||||
_dic["oidreq"] = oidreq.get_json()
|
||||
_dic["oidreq"] = oidreq.to_json()
|
||||
|
||||
self._db[sid] = _dic
|
||||
return sid
|
||||
|
|
Reference in New Issue