From 3a5d8d42229cd4ef0f126210e5abbac41999aaa9 Mon Sep 17 00:00:00 2001 From: Roland Hedberg Date: Sat, 10 Jan 2015 10:10:12 +0100 Subject: [PATCH] Make all exceptions thrown by pyoidc code have a common ancestor. Also, a missing key for verifying a signature is not the same thing as a faulty signature. --- src/oic/oauth2/message.py | 31 ++++++++++++++++++++++++------- src/oic/oic/message.py | 16 ++++++++++------ 2 files changed, 34 insertions(+), 13 deletions(-) diff --git a/src/oic/oauth2/message.py b/src/oic/oauth2/message.py index 562cd51..2ed4565 100644 --- a/src/oic/oauth2/message.py +++ b/src/oic/oauth2/message.py @@ -33,6 +33,10 @@ class MissingRequiredValue(MessageException): pass +class MissingSigningKey(PyoidcError): + pass + + class TooManyValues(MessageException): pass @@ -53,6 +57,15 @@ class SchemeError(MessageException): pass +class ParameterError(MessageException): + pass + + +class NotAllowedValue(MessageException): + pass + + + ERRTXT = "On '%s': %s" @@ -240,7 +253,7 @@ class Message(object): try: self._dict[key] = typ(val[0]) except KeyError: - raise ValueError() + raise ParameterError(key) else: raise TooManyValues @@ -517,6 +530,10 @@ class Message(object): else: self._add_key(keyjar, jso[ent], key) + if "alg" in header and header["alg"] != "none": + if not key: + raise MissingSigningKey() + _jws.verify_compact(txt, key) except Exception: raise @@ -530,18 +547,18 @@ class Message(object): def _type_check(self, typ, _allowed, val, na=False): if typ is basestring: if val not in _allowed: - raise ValueError("Not allowed value '%s'" % val) + raise NotAllowedValue(val) elif typ is int: if val not in _allowed: - raise ValueError("Not allowed value '%s'" % val) + raise NotAllowedValue(val) elif isinstance(typ, list): if isinstance(val, list): # _typ = typ[0] for item in val: if item not in _allowed: - raise ValueError("Not allowed value '%s'" % val) + raise NotAllowedValue(val) elif val is None and na is False: - raise ValueError("Not allowed value '%s'" % val) + raise NotAllowedValue(val) # noinspection PyUnusedLocal def verify(self, **kwargs): @@ -584,7 +601,7 @@ class Message(object): except ValueError: pass if _ityp is None: - raise ValueError("Not allowed value '%s'" % val) + raise NotAllowedValue(val) else: self._type_check(typ, _allowed[attribute], val, na) @@ -659,7 +676,7 @@ class Message(object): for key, val in item.items(): self._dict[key] = val else: - raise ValueError("Wrong type of value") + raise ValueError("Can't update message using: '%s'" % (item,)) def to_jwe(self, keys, enc, alg, lev=0): """ diff --git a/src/oic/oic/message.py b/src/oic/oic/message.py index dee0fc3..4c966a7 100644 --- a/src/oic/oic/message.py +++ b/src/oic/oic/message.py @@ -12,9 +12,13 @@ from oic.oauth2 import message from oic.oauth2 import MissingRequiredValue from oic.oauth2 import MissingRequiredAttribute from oic.oauth2 import VerificationError -from oic.exception import InvalidRequest, NotForMe +from oic.exception import InvalidRequest +from oic.exception import NotForMe +from oic.exception import MessageException from oic.exception import PyoidcError -from oic.oauth2.message import Message, SchemeError +from oic.oauth2.message import Message +from oic.oauth2.message import SchemeError +from oic.oauth2.message import NotAllowedValue from oic.oauth2.message import REQUIRED_LIST_OF_SP_SEP_STRINGS from oic.oauth2.message import SINGLE_OPTIONAL_JSON from oic.oauth2.message import SINGLE_OPTIONAL_STRING @@ -104,7 +108,7 @@ def msg_ser(inst, sformat, lev=0): elif isinstance(inst, dict): res = inst else: - raise ValueError("%s" % type(inst)) + raise MessageException("Wrong type: %s" % type(inst)) else: raise PyoidcError("Unknown sformat", inst) @@ -119,7 +123,7 @@ def msg_ser_json(inst, sformat="json", lev=0): elif isinstance(inst, dict): res = inst else: - raise ValueError("%s" % type(inst), inst) + raise MessageException("Wrong type: %s" % type(inst)) else: sformat = "json" if isinstance(inst, dict) or isinstance(inst, Message): @@ -157,7 +161,7 @@ def claims_ser(val, sformat="urlencoded", lev=0): if isinstance(item, dict): res = item else: - raise ValueError("%s" % type(item)) + raise MessageException("Wrong type: %s" % type(item)) else: raise PyoidcError("Unknown sformat: %s" % sformat, val) @@ -214,7 +218,7 @@ for char in ['\x21', ('\x23', '\x5b'), ('\x5d', '\x7E')]: def check_char_set(string, allowed): for c in string: if c not in allowed: - raise ValueError("'%c' not in the allowed character set" % c) + raise NotAllowedValue("'%c' not in the allowed character set" % c) # -----------------------------------------------------------------------------