http basic authentication while retreiving the jwt
This commit is contained in:
parent
14b9aa7fa9
commit
a5e11c247f
|
@ -424,7 +424,7 @@ class PBase(object):
|
||||||
class Client(PBase):
|
class Client(PBase):
|
||||||
_endpoints = ENDPOINTS
|
_endpoints = ENDPOINTS
|
||||||
|
|
||||||
def __init__(self, client_id=None, ca_certs=None, client_authn_method=None,
|
def __init__(self, client_id=None, client_secret=None, ca_certs=None, client_authn_method=None,
|
||||||
keyjar=None, verify_ssl=True):
|
keyjar=None, verify_ssl=True):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
@ -436,12 +436,12 @@ class Client(PBase):
|
||||||
:param verify_ssl: Whether the SSL certificate should be verfied.
|
:param verify_ssl: Whether the SSL certificate should be verfied.
|
||||||
:return: Client instance
|
:return: Client instance
|
||||||
"""
|
"""
|
||||||
|
|
||||||
PBase.__init__(self, ca_certs, verify_ssl=verify_ssl)
|
PBase.__init__(self, ca_certs, verify_ssl=verify_ssl)
|
||||||
|
|
||||||
self.client_id = client_id
|
self.client_id = client_id
|
||||||
self.client_authn_method = client_authn_method
|
self.client_authn_method = client_authn_method
|
||||||
self.keyjar = keyjar or KeyJar(verify_ssl=verify_ssl)
|
self.keyjar = keyjar or KeyJar(verify_ssl=verify_ssl,
|
||||||
|
client_id=client_id,
|
||||||
|
client_secret=client_secret)
|
||||||
self.verify_ssl = verify_ssl
|
self.verify_ssl = verify_ssl
|
||||||
# self.secret_type = "basic "
|
# self.secret_type = "basic "
|
||||||
|
|
||||||
|
|
|
@ -237,11 +237,10 @@ PROVIDER_DEFAULT = {
|
||||||
class Client(oauth2.Client):
|
class Client(oauth2.Client):
|
||||||
_endpoints = ENDPOINTS
|
_endpoints = ENDPOINTS
|
||||||
|
|
||||||
def __init__(self, client_id=None, ca_certs=None,
|
def __init__(self, client_id=None, client_secret=None, ca_certs=None,
|
||||||
client_prefs=None, client_authn_method=None, keyjar=None,
|
client_prefs=None, client_authn_method=None, keyjar=None,
|
||||||
verify_ssl=True):
|
verify_ssl=True):
|
||||||
|
oauth2.Client.__init__(self, client_id, client_secret, ca_certs,
|
||||||
oauth2.Client.__init__(self, client_id, ca_certs,
|
|
||||||
client_authn_method=client_authn_method,
|
client_authn_method=client_authn_method,
|
||||||
keyjar=keyjar, verify_ssl=verify_ssl)
|
keyjar=keyjar, verify_ssl=verify_ssl)
|
||||||
|
|
||||||
|
@ -841,7 +840,9 @@ class Client(oauth2.Client):
|
||||||
|
|
||||||
if keys:
|
if keys:
|
||||||
if self.keyjar is None:
|
if self.keyjar is None:
|
||||||
self.keyjar = KeyJar(verify_ssl=self.verify_ssl)
|
self.keyjar = KeyJar(verify_ssl=self.verify_ssl,
|
||||||
|
client_id=self.client_id,
|
||||||
|
client_secret=self.client_secret)
|
||||||
|
|
||||||
self.keyjar.load_keys(pcr, _pcr_issuer)
|
self.keyjar.load_keys(pcr, _pcr_issuer)
|
||||||
|
|
||||||
|
@ -1024,8 +1025,8 @@ class Client(oauth2.Client):
|
||||||
self.registration_response = reginfo
|
self.registration_response = reginfo
|
||||||
if "token_endpoint_auth_method" not in self.registration_response:
|
if "token_endpoint_auth_method" not in self.registration_response:
|
||||||
self.registration_response["token_endpoint_auth_method"] = "client_secret_post"
|
self.registration_response["token_endpoint_auth_method"] = "client_secret_post"
|
||||||
self.client_secret = reginfo["client_secret"]
|
self.client_secret = self.keyjar.client_secret = reginfo["client_secret"]
|
||||||
self.client_id = reginfo["client_id"]
|
self.client_id = self.keyjar.client_id = reginfo["client_id"]
|
||||||
try:
|
try:
|
||||||
self.registration_expires = reginfo["client_secret_expires_at"]
|
self.registration_expires = reginfo["client_secret_expires_at"]
|
||||||
except KeyError:
|
except KeyError:
|
||||||
|
|
|
@ -857,4 +857,4 @@ if __name__ == "__main__":
|
||||||
print atr.verify()
|
print atr.verify()
|
||||||
uue = atr.serialize()
|
uue = atr.serialize()
|
||||||
atr = AccessTokenRequest().deserialize(uue, "urlencoded")
|
atr = AccessTokenRequest().deserialize(uue, "urlencoded")
|
||||||
print atr
|
print atr
|
||||||
|
|
|
@ -48,7 +48,7 @@ K2C = {
|
||||||
|
|
||||||
class KeyBundle(object):
|
class KeyBundle(object):
|
||||||
def __init__(self, keys=None, source="", cache_time=300, verify_ssl=True,
|
def __init__(self, keys=None, source="", cache_time=300, verify_ssl=True,
|
||||||
fileformat="jwk", keytype="RSA", keyusage=None):
|
fileformat="jwk", keytype="RSA", keyusage=None, **kwargs):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
:param keys: A list of dictionaries
|
:param keys: A list of dictionaries
|
||||||
|
@ -71,6 +71,7 @@ class KeyBundle(object):
|
||||||
self.keytype = keytype
|
self.keytype = keytype
|
||||||
self.keyusage = keyusage
|
self.keyusage = keyusage
|
||||||
self.imp_jwks = None
|
self.imp_jwks = None
|
||||||
|
self.kwargs = kwargs
|
||||||
|
|
||||||
if keys:
|
if keys:
|
||||||
self.source = None
|
self.source = None
|
||||||
|
@ -136,6 +137,8 @@ class KeyBundle(object):
|
||||||
args = {"allow_redirects": True,
|
args = {"allow_redirects": True,
|
||||||
"verify": self.verify_ssl,
|
"verify": self.verify_ssl,
|
||||||
"timeout": 5.0}
|
"timeout": 5.0}
|
||||||
|
if 'client_id' in self.kwargs and 'client_secret' in self.kwargs:
|
||||||
|
args.update({'auth': (self.kwargs['client_id'], self.kwargs['client_secret'])})
|
||||||
if self.etag:
|
if self.etag:
|
||||||
args["headers"] = {"If-None-Match": self.etag}
|
args["headers"] = {"If-None-Match": self.etag}
|
||||||
|
|
||||||
|
@ -161,7 +164,7 @@ class KeyBundle(object):
|
||||||
pass
|
pass
|
||||||
return True
|
return True
|
||||||
else:
|
else:
|
||||||
raise UpdateFailed()
|
raise UpdateFailed("Remote key update failed")
|
||||||
|
|
||||||
def _uptodate(self):
|
def _uptodate(self):
|
||||||
res = False
|
res = False
|
||||||
|
@ -308,7 +311,8 @@ def dump_jwks(kbl, target):
|
||||||
class KeyJar(object):
|
class KeyJar(object):
|
||||||
""" A keyjar contains a number of KeyBundles """
|
""" A keyjar contains a number of KeyBundles """
|
||||||
|
|
||||||
def __init__(self, ca_certs=None, verify_ssl=True):
|
def __init__(self, ca_certs=None, verify_ssl=True,
|
||||||
|
client_id=None, client_secret=None):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
:param ca_certs:
|
:param ca_certs:
|
||||||
|
@ -319,6 +323,8 @@ class KeyJar(object):
|
||||||
self.issuer_keys = {}
|
self.issuer_keys = {}
|
||||||
self.ca_certs = ca_certs
|
self.ca_certs = ca_certs
|
||||||
self.verify_ssl = verify_ssl
|
self.verify_ssl = verify_ssl
|
||||||
|
self.client_id = client_id
|
||||||
|
self.client_secret = client_secret
|
||||||
|
|
||||||
def add_if_unique(self, issuer, use, keys):
|
def add_if_unique(self, issuer, use, keys):
|
||||||
if use in self.issuer_keys[issuer] and self.issuer_keys[issuer][use]:
|
if use in self.issuer_keys[issuer] and self.issuer_keys[issuer][use]:
|
||||||
|
@ -344,9 +350,12 @@ class KeyJar(object):
|
||||||
raise KeyError("No jwks_uri")
|
raise KeyError("No jwks_uri")
|
||||||
|
|
||||||
if "/localhost:" in url or "/localhost/" in url:
|
if "/localhost:" in url or "/localhost/" in url:
|
||||||
|
|
||||||
kc = KeyBundle(source=url, verify_ssl=False)
|
kc = KeyBundle(source=url, verify_ssl=False)
|
||||||
else:
|
else:
|
||||||
kc = KeyBundle(source=url, verify_ssl=self.verify_ssl)
|
kc = KeyBundle(source=url, verify_ssl=self.verify_ssl,
|
||||||
|
client_id=self.client_id,
|
||||||
|
client_secret=self.client_secret)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
self.issuer_keys[issuer].append(kc)
|
self.issuer_keys[issuer].append(kc)
|
||||||
|
|
Reference in New Issue