diff --git a/src/oic/oauth2/__init__.py b/src/oic/oauth2/__init__.py index bebe268..2608daa 100644 --- a/src/oic/oauth2/__init__.py +++ b/src/oic/oauth2/__init__.py @@ -61,11 +61,31 @@ class ResponseError(PyoidcError): pass -class TimeFormatError(Exception): +class TimeFormatError(PyoidcError): pass -class CapabilitiesMisMatch(Exception): +class CapabilitiesMisMatch(PyoidcError): + pass + + +class MissingEndpoint(PyoidcError): + pass + + +class TokenError(PyoidcError): + pass + + +class GrantError(PyoidcError): + pass + + +class ParseError(PyoidcError): + pass + + +class OtherError(PyoidcError): pass @@ -510,10 +530,10 @@ class Client(PBase): try: uri = getattr(self, endpoint) except Exception: - raise Exception("No '%s' specified" % endpoint) + raise MissingEndpoint("No '%s' specified" % endpoint) if not uri: - raise Exception("No '%s' specified" % endpoint) + raise MissingEndpoint("No '%s' specified" % endpoint) return uri @@ -528,7 +548,7 @@ class Client(PBase): try: return self.grant[state] except: - raise Exception("No grant found for state:'%s'" % state) + raise GrantError("No grant found for state:'%s'" % state) def get_token(self, also_expired=False, **kwargs): try: @@ -544,17 +564,17 @@ class Client(PBase): try: token = self.grant[kwargs["state"]].get_token("") except KeyError: - raise Exception("No token found for scope") + raise TokenError("No token found for scope") if token is None: - raise Exception("No suitable token found") + raise TokenError("No suitable token found") if also_expired: return token elif token.is_valid(): return token else: - raise ExpiredToken("Token has expired") + raise TokenError("Token has expired") def construct_request(self, request, request_args=None, extra_args=None): if request_args is None: @@ -696,7 +716,7 @@ class Client(PBase): else: kwargs["headers"] = header_ext else: - raise Exception("Unsupported HTTP method: '%s'" % method) + raise UnSupported("Unsupported HTTP method: '%s'" % method) return path, body, kwargs @@ -889,14 +909,14 @@ class Client(PBase): pass elif reqresp.status_code == 500: logger.error("(%d) %s" % (reqresp.status_code, reqresp.text)) - raise Exception("ERROR: Something went wrong: %s" % reqresp.text) + raise ParseError("ERROR: Something went wrong: %s" % reqresp.text) elif reqresp.status_code in [400, 401]: #expecting an error response if issubclass(response, ErrorResponse): pass else: logger.error("(%d) %s" % (reqresp.status_code, reqresp.text)) - raise Exception("HTTP ERROR: %s [%s] on %s" % ( + raise HTTP_ERROR("HTTP ERROR: %s [%s] on %s" % ( reqresp.text, reqresp.status_code, reqresp.url)) if body_type: @@ -904,7 +924,7 @@ class Client(PBase): return self.parse_response(response, reqresp.text, body_type, state, **kwargs) else: - raise Exception("Didn't expect a response body") + raise OtherError("Didn't expect a response body") else: return reqresp diff --git a/src/oic/oauth2/message.py b/src/oic/oauth2/message.py index c0b1b97..562cd51 100644 --- a/src/oic/oauth2/message.py +++ b/src/oic/oauth2/message.py @@ -16,6 +16,10 @@ from oic.exception import MessageException logger = logging.getLogger(__name__) +class FormatError(PyoidcError): + pass + + class MissingRequiredAttribute(MessageException): def __init__(self, attr, message=""): Exception.__init__(self, attr) @@ -180,7 +184,7 @@ class Message(object): try: func = getattr(self, "from_%s" % method) except AttributeError, err: - raise Exception("Unknown method (%s)" % method) + raise FormatError("Unknown serialization method (%s)" % method) else: return func(info, **kwargs) @@ -474,28 +478,29 @@ class Message(object): if isinstance(jso, basestring): jso = json.loads(jso) - if "jku" in header: - if not keyjar.find(header["jku"], jso["iss"]): - # This is really questionable + if keyjar: + if "jku" in header: + if not keyjar.find(header["jku"], jso["iss"]): + # This is really questionable + try: + if kwargs["trusting"]: + keyjar.add(jso["iss"], header["jku"]) + except KeyError: + pass + + if _kid: try: - if kwargs["trusting"]: - keyjar.add(jso["iss"], header["jku"]) + _key = keyjar.get_key_by_kid(_kid, jso["iss"]) + if _key: + key.append(_key) except KeyError: pass - if _kid: try: - _key = keyjar.get_key_by_kid(_kid, jso["iss"]) - if _key: - key.append(_key) + self._add_key(keyjar, kwargs["opponent_id"], key) except KeyError: pass - try: - self._add_key(keyjar, kwargs["opponent_id"], key) - except KeyError: - pass - if verify: if keyjar: for ent in ["iss", "aud", "client_id"]: @@ -901,7 +906,7 @@ def factory(msgtype): try: return MSG[msgtype] except KeyError: - raise Exception("Unknown message type: %s" % msgtype) + raise FormatError("Unknown message type: %s" % msgtype) if __name__ == "__main__": diff --git a/src/oic/oic/message.py b/src/oic/oic/message.py index 30cf687..dee0fc3 100644 --- a/src/oic/oic/message.py +++ b/src/oic/oic/message.py @@ -30,6 +30,14 @@ from jwkest import jws logger = logging.getLogger(__name__) +class AtHashError(VerificationError): + pass + + +class CHashError(VerificationError): + pass + + #noinspection PyUnusedLocal def json_ser(val, sformat=None, lev=0): return json.dumps(val) @@ -292,7 +300,7 @@ class AuthorizationResponse(message.AuthorizationResponse, assert idt["at_hash"] == jws.left_hash( self["access_token"], hfunc) except AssertionError: - raise VerificationError( + raise AtHashError( "Failed to verify access_token hash", idt) if "code" in self: @@ -304,7 +312,7 @@ class AuthorizationResponse(message.AuthorizationResponse, try: assert idt["c_hash"] == jws.left_hash(self["code"], hfunc) except AssertionError: - raise VerificationError("Failed to verify code hash", idt) + raise CHashError("Failed to verify code hash", idt) self["id_token"] = idt diff --git a/src/oic/utils/aes.py b/src/oic/utils/aes.py index 50dda3c..4241ec9 100644 --- a/src/oic/utils/aes.py +++ b/src/oic/utils/aes.py @@ -15,6 +15,10 @@ POSTFIX_MODE = { BLOCK_SIZE = 16 +class AESError(Exception): + pass + + def build_cipher(key, iv, alg="aes_128_cbc"): """ :param key: encryption key @@ -30,16 +34,16 @@ def build_cipher(key, iv, alg="aes_128_cbc"): assert len(iv) == AES.block_size if bits not in ["128", "192", "256"]: - raise Exception("Unsupported key length") + raise AESError("Unsupported key length") try: assert len(key) == int(bits) >> 3 except AssertionError: - raise Exception("Wrong Key length") + raise AESError("Wrong Key length") try: return AES.new(key, POSTFIX_MODE[cmode], iv), iv except KeyError: - raise Exception("Unsupported chaining mode") + raise AESError("Unsupported chaining mode") def encrypt(key, msg, iv=None, alg="aes_128_cbc", padding="PKCS#7", diff --git a/src/oic/utils/authn/client.py b/src/oic/utils/authn/client.py index c0cc6f6..eb8c314 100644 --- a/src/oic/utils/authn/client.py +++ b/src/oic/utils/authn/client.py @@ -223,7 +223,7 @@ class BearerBody(ClientAuthnMethod): _ = kwargs["state"] except KeyError: if not self.cli.state: - raise Exception("Missing state specification") + raise AuthnFailure("Missing state specification") kwargs["state"] = self.cli.state cis["access_token"] = self.cli.get_token(**kwargs).access_token @@ -255,7 +255,7 @@ class JWSAuthnMethod(ClientAuthnMethod): except KeyError: algorithm = DEF_SIGN_ALG[entity] if not algorithm: - raise Exception("Missing algorithm specification") + raise AuthnFailure("Missing algorithm specification") return algorithm def get_signing_key(self, algorithm): diff --git a/src/oic/utils/authn/ldapc.py b/src/oic/utils/authn/ldapc.py index 5c8e330..764f1a6 100644 --- a/src/oic/utils/authn/ldapc.py +++ b/src/oic/utils/authn/ldapc.py @@ -1,4 +1,5 @@ import ldap +from oic.exception import PyoidcError from oic.utils.authn.user import UsernamePasswordMako @@ -10,6 +11,10 @@ SCOPE_MAP = { } +class LDAPCError(PyoidcError): + pass + + class LDAPAuthn(UsernamePasswordMako): def __init__(self, srv, ldapsrv, return_to, pattern, mako_template, template_lookup, ldap_user="", ldap_pwd="", @@ -52,7 +57,7 @@ class LDAPAuthn(UsernamePasswordMako): try: _pat = self.pattern["search"] except: - raise Exception("unknown pattern") + raise LDAPCError("unknown search pattern") else: args = { "filterstr": _pat["filterstr"] % user, diff --git a/src/oic/utils/authn/user.py b/src/oic/utils/authn/user.py index d8ca82a..adde265 100644 --- a/src/oic/utils/authn/user.py +++ b/src/oic/utils/authn/user.py @@ -6,6 +6,7 @@ from urllib import urlencode import urllib from urlparse import parse_qs from urlparse import urlsplit +from oic.exception import PyoidcError from oic.utils import aes from oic.utils.http_util import Response @@ -36,19 +37,23 @@ LOC = { } -class NoSuchAuthentication(Exception): +class NoSuchAuthentication(PyoidcError): pass -class TamperAllert(Exception): +class TamperAllert(PyoidcError): pass -class ToOld(Exception): +class ToOld(PyoidcError): pass -class FailedAuthentication(Exception): +class FailedAuthentication(PyoidcError): + pass + + +class InstantiationError(PyoidcError): pass @@ -121,7 +126,7 @@ class UserAuthnMethod(CookieDealer): def url_encode_params(params=None): if not isinstance(params, dict): - raise Exception("You must pass in a dictionary!") + raise InstantiationError("You must pass in a dictionary!") params_list = [] for k, v in params.items(): if isinstance(v, list): @@ -368,3 +373,23 @@ class SymKeyAuthn(UserAuthnMethod): raise FailedAuthentication() return {"uid": user} + + +class NoAuthn(UserAuthnMethod): + # Just for testing allows anyone it without authentication + + def __init__(self, srv, user): + UserAuthnMethod.__init__(self, srv) + self.user = user + + def authenticated_as(self, cookie=None, authorization="", **kwargs): + """ + + :param cookie: A HTTP Cookie + :param authorization: The HTTP Authorization header + :param args: extra args + :param kwargs: extra key word arguments + :return: + """ + + return {"uid": self.user} diff --git a/src/oic/utils/keyio.py b/src/oic/utils/keyio.py index 29a5e08..b84eb0b 100644 --- a/src/oic/utils/keyio.py +++ b/src/oic/utils/keyio.py @@ -2,7 +2,7 @@ import json import time from Crypto.PublicKey import RSA from cryptlib.ecc import NISTEllipticCurve -from oic.exception import MessageException +from oic.exception import MessageException, PyoidcError __author__ = 'rohe0002' @@ -27,11 +27,15 @@ logger = logging.getLogger(__name__) traceback.format_exception(*sys.exc_info()) -class UnknownKeyType(Exception): +class KeyIOError(PyoidcError): pass -class UpdateFailed(Exception): +class UnknownKeyType(KeyIOError): + pass + + +class UpdateFailed(KeyIOError): pass @@ -83,7 +87,7 @@ class KeyBundle(object): elif source == "": return else: - raise Exception("Unsupported source type: %s" % source) + raise KeyIOError("Unsupported source type: %s" % source) if not self.remote: # local file if self.fileformat == "jwk": @@ -271,7 +275,7 @@ def keybundle_from_local_file(filename, typ, usage): elif typ.lower() == "jwk": kb = KeyBundle(source=filename, fileformat="jwk", keyusage=usage) else: - raise Exception("Unsupported key type") + raise UnknownKeyType("Unsupported key type") return kb @@ -526,7 +530,7 @@ class KeyJar(object): if url.startswith(owner): return owner - raise Exception("No keys for '%s'" % url) + raise KeyIOError("No keys for '%s'" % url) def __str__(self): _res = {} diff --git a/src/oic/utils/time_util.py b/src/oic/utils/time_util.py index f1a4052..4d96733 100644 --- a/src/oic/utils/time_util.py +++ b/src/oic/utils/time_util.py @@ -31,6 +31,11 @@ TIME_FORMAT = "%Y-%m-%dT%H:%M:%SZ" TIME_FORMAT_WITH_FRAGMENT = re.compile( "^(\d{4,4}-\d{2,2}-\d{2,2}T\d{2,2}:\d{2,2}:\d{2,2})\.\d*Z$") + +class TimeUtilError(Exception): + pass + + # --------------------------------------------------------------------------- # I'm sure this is implemented somewhere else can't find it now though, so I # made an attempt. @@ -86,14 +91,14 @@ def parse_duration(duration): for code, typ in D_FORMAT: #print duration[index:], code if duration[index] == '-': - raise Exception("Negation not allowed on individual items") + raise TimeUtilError("Negation not allowed on individual items") if code == "T": if duration[index] == "T": index += 1 if index == len(duration): - raise Exception("Not allowed to end with 'T'") + raise TimeUtilError("Not allowed to end with 'T'") else: - raise Exception("Missing T") + raise TimeUtilError("Missing T") else: try: mod = duration[index:].index(code) @@ -104,9 +109,9 @@ def parse_duration(duration): try: dic[typ] = float(duration[index:index + mod]) except ValueError: - raise Exception("Not a float") + raise TimeUtilError("Not a float") else: - raise Exception( + raise TimeUtilError( "Fractions not allow on anything byt seconds") index = mod + index + 1 except ValueError: diff --git a/src/oic/utils/userinfo/distaggr.py b/src/oic/utils/userinfo/distaggr.py index defdfb7..38fa1de 100644 --- a/src/oic/utils/userinfo/distaggr.py +++ b/src/oic/utils/userinfo/distaggr.py @@ -1,5 +1,6 @@ import copy import logging +from oic.exception import MissingAttribute from oic.oic import OpenIDSchema from oic.oic.claims_provider import ClaimsClient @@ -116,7 +117,8 @@ class DistributedAggregatedUserInfo(UserInfo): pass if remaining: - raise Exception("Missing properties '%s'" % remaining) + raise MissingAttribute( + "Missing properties '%s'" % remaining) for srv, what in cpoints.items(): cc = self.oidcsrv.claims_clients[srv] diff --git a/src/oic/utils/webfinger.py b/src/oic/utils/webfinger.py index 8598c3a..50368dd 100644 --- a/src/oic/utils/webfinger.py +++ b/src/oic/utils/webfinger.py @@ -4,6 +4,7 @@ import logging import re from urllib import urlencode import urlparse +from oic.exception import PyoidcError import requests from oic.utils.time_util import in_a_while @@ -16,6 +17,9 @@ logger = logging.getLogger(__name__) WF_URL = "https://%s/.well-known/webfinger" OIC_ISSUER = "http://openid.net/specs/connect/1.0/issuer" +class WebFingerError(PyoidcError): + pass + class Base(object): c_param = {} @@ -243,7 +247,7 @@ class WebFinger(object): elif resource.startswith("device:"): host = resource.split(':')[1] else: - raise Exception("Unknown schema") + raise WebFingerError("Unknown schema") return "%s?%s" % (WF_URL % host, urlencode(info)) @@ -288,7 +292,7 @@ class WebFinger(object): elif rsp.status_code in [302, 301, 307]: return self.discovery_query(rsp.headers["location"]) else: - raise Exception(rsp.status_code) + raise WebFingerError(rsp.status_code) def response(self, subject, base): self.jrd = JRD()