diff --git a/src/oic/oauth2/__init__.py b/src/oic/oauth2/__init__.py index b5448fa..43952eb 100644 --- a/src/oic/oauth2/__init__.py +++ b/src/oic/oauth2/__init__.py @@ -424,7 +424,7 @@ class PBase(object): class Client(PBase): _endpoints = ENDPOINTS - def __init__(self, client_id=None, ca_certs=None, client_authn_method=None, + def __init__(self, client_id=None, client_secret=None, ca_certs=None, client_authn_method=None, keyjar=None, verify_ssl=True): """ @@ -436,12 +436,12 @@ class Client(PBase): :param verify_ssl: Whether the SSL certificate should be verfied. :return: Client instance """ - PBase.__init__(self, ca_certs, verify_ssl=verify_ssl) - self.client_id = client_id self.client_authn_method = client_authn_method - self.keyjar = keyjar or KeyJar(verify_ssl=verify_ssl) + self.keyjar = keyjar or KeyJar(verify_ssl=verify_ssl, + client_id=client_id, + client_secret=client_secret) self.verify_ssl = verify_ssl # self.secret_type = "basic " diff --git a/src/oic/oic/__init__.py b/src/oic/oic/__init__.py index 8230746..dc435cc 100644 --- a/src/oic/oic/__init__.py +++ b/src/oic/oic/__init__.py @@ -237,11 +237,10 @@ PROVIDER_DEFAULT = { class Client(oauth2.Client): _endpoints = ENDPOINTS - def __init__(self, client_id=None, ca_certs=None, + def __init__(self, client_id=None, client_secret=None, ca_certs=None, client_prefs=None, client_authn_method=None, keyjar=None, verify_ssl=True): - - oauth2.Client.__init__(self, client_id, ca_certs, + oauth2.Client.__init__(self, client_id, client_secret, ca_certs, client_authn_method=client_authn_method, keyjar=keyjar, verify_ssl=verify_ssl) @@ -841,7 +840,9 @@ class Client(oauth2.Client): if keys: if self.keyjar is None: - self.keyjar = KeyJar(verify_ssl=self.verify_ssl) + self.keyjar = KeyJar(verify_ssl=self.verify_ssl, + client_id=self.client_id, + client_secret=self.client_secret) self.keyjar.load_keys(pcr, _pcr_issuer) @@ -1024,8 +1025,8 @@ class Client(oauth2.Client): self.registration_response = reginfo if "token_endpoint_auth_method" not in self.registration_response: self.registration_response["token_endpoint_auth_method"] = "client_secret_post" - self.client_secret = reginfo["client_secret"] - self.client_id = reginfo["client_id"] + self.client_secret = self.keyjar.client_secret = reginfo["client_secret"] + self.client_id = self.keyjar.client_id = reginfo["client_id"] try: self.registration_expires = reginfo["client_secret_expires_at"] except KeyError: diff --git a/src/oic/oic/message.py b/src/oic/oic/message.py index 4c966a7..b3b41a1 100644 --- a/src/oic/oic/message.py +++ b/src/oic/oic/message.py @@ -857,4 +857,4 @@ if __name__ == "__main__": print atr.verify() uue = atr.serialize() atr = AccessTokenRequest().deserialize(uue, "urlencoded") - print atr \ No newline at end of file + print atr diff --git a/src/oic/utils/keyio.py b/src/oic/utils/keyio.py index b84eb0b..738166e 100644 --- a/src/oic/utils/keyio.py +++ b/src/oic/utils/keyio.py @@ -48,7 +48,7 @@ K2C = { class KeyBundle(object): def __init__(self, keys=None, source="", cache_time=300, verify_ssl=True, - fileformat="jwk", keytype="RSA", keyusage=None): + fileformat="jwk", keytype="RSA", keyusage=None, **kwargs): """ :param keys: A list of dictionaries @@ -71,6 +71,7 @@ class KeyBundle(object): self.keytype = keytype self.keyusage = keyusage self.imp_jwks = None + self.kwargs = kwargs if keys: self.source = None @@ -136,6 +137,8 @@ class KeyBundle(object): args = {"allow_redirects": True, "verify": self.verify_ssl, "timeout": 5.0} + if 'client_id' in self.kwargs and 'client_secret' in self.kwargs: + args.update({'auth': (self.kwargs['client_id'], self.kwargs['client_secret'])}) if self.etag: args["headers"] = {"If-None-Match": self.etag} @@ -161,7 +164,7 @@ class KeyBundle(object): pass return True else: - raise UpdateFailed() + raise UpdateFailed("Remote key update failed") def _uptodate(self): res = False @@ -308,7 +311,8 @@ def dump_jwks(kbl, target): class KeyJar(object): """ A keyjar contains a number of KeyBundles """ - def __init__(self, ca_certs=None, verify_ssl=True): + def __init__(self, ca_certs=None, verify_ssl=True, + client_id=None, client_secret=None): """ :param ca_certs: @@ -319,6 +323,8 @@ class KeyJar(object): self.issuer_keys = {} self.ca_certs = ca_certs self.verify_ssl = verify_ssl + self.client_id = client_id + self.client_secret = client_secret def add_if_unique(self, issuer, use, keys): if use in self.issuer_keys[issuer] and self.issuer_keys[issuer][use]: @@ -344,9 +350,12 @@ class KeyJar(object): raise KeyError("No jwks_uri") if "/localhost:" in url or "/localhost/" in url: + kc = KeyBundle(source=url, verify_ssl=False) else: - kc = KeyBundle(source=url, verify_ssl=self.verify_ssl) + kc = KeyBundle(source=url, verify_ssl=self.verify_ssl, + client_id=self.client_id, + client_secret=self.client_secret) try: self.issuer_keys[issuer].append(kc)