Make all exceptions thrown by pyoidc code have a common ancestor.
This commit is contained in:
parent
75a5293e5a
commit
35229fead4
|
@ -61,11 +61,31 @@ class ResponseError(PyoidcError):
|
|||
pass
|
||||
|
||||
|
||||
class TimeFormatError(Exception):
|
||||
class TimeFormatError(PyoidcError):
|
||||
pass
|
||||
|
||||
|
||||
class CapabilitiesMisMatch(Exception):
|
||||
class CapabilitiesMisMatch(PyoidcError):
|
||||
pass
|
||||
|
||||
|
||||
class MissingEndpoint(PyoidcError):
|
||||
pass
|
||||
|
||||
|
||||
class TokenError(PyoidcError):
|
||||
pass
|
||||
|
||||
|
||||
class GrantError(PyoidcError):
|
||||
pass
|
||||
|
||||
|
||||
class ParseError(PyoidcError):
|
||||
pass
|
||||
|
||||
|
||||
class OtherError(PyoidcError):
|
||||
pass
|
||||
|
||||
|
||||
|
@ -510,10 +530,10 @@ class Client(PBase):
|
|||
try:
|
||||
uri = getattr(self, endpoint)
|
||||
except Exception:
|
||||
raise Exception("No '%s' specified" % endpoint)
|
||||
raise MissingEndpoint("No '%s' specified" % endpoint)
|
||||
|
||||
if not uri:
|
||||
raise Exception("No '%s' specified" % endpoint)
|
||||
raise MissingEndpoint("No '%s' specified" % endpoint)
|
||||
|
||||
return uri
|
||||
|
||||
|
@ -528,7 +548,7 @@ class Client(PBase):
|
|||
try:
|
||||
return self.grant[state]
|
||||
except:
|
||||
raise Exception("No grant found for state:'%s'" % state)
|
||||
raise GrantError("No grant found for state:'%s'" % state)
|
||||
|
||||
def get_token(self, also_expired=False, **kwargs):
|
||||
try:
|
||||
|
@ -544,17 +564,17 @@ class Client(PBase):
|
|||
try:
|
||||
token = self.grant[kwargs["state"]].get_token("")
|
||||
except KeyError:
|
||||
raise Exception("No token found for scope")
|
||||
raise TokenError("No token found for scope")
|
||||
|
||||
if token is None:
|
||||
raise Exception("No suitable token found")
|
||||
raise TokenError("No suitable token found")
|
||||
|
||||
if also_expired:
|
||||
return token
|
||||
elif token.is_valid():
|
||||
return token
|
||||
else:
|
||||
raise ExpiredToken("Token has expired")
|
||||
raise TokenError("Token has expired")
|
||||
|
||||
def construct_request(self, request, request_args=None, extra_args=None):
|
||||
if request_args is None:
|
||||
|
@ -696,7 +716,7 @@ class Client(PBase):
|
|||
else:
|
||||
kwargs["headers"] = header_ext
|
||||
else:
|
||||
raise Exception("Unsupported HTTP method: '%s'" % method)
|
||||
raise UnSupported("Unsupported HTTP method: '%s'" % method)
|
||||
|
||||
return path, body, kwargs
|
||||
|
||||
|
@ -889,14 +909,14 @@ class Client(PBase):
|
|||
pass
|
||||
elif reqresp.status_code == 500:
|
||||
logger.error("(%d) %s" % (reqresp.status_code, reqresp.text))
|
||||
raise Exception("ERROR: Something went wrong: %s" % reqresp.text)
|
||||
raise ParseError("ERROR: Something went wrong: %s" % reqresp.text)
|
||||
elif reqresp.status_code in [400, 401]:
|
||||
#expecting an error response
|
||||
if issubclass(response, ErrorResponse):
|
||||
pass
|
||||
else:
|
||||
logger.error("(%d) %s" % (reqresp.status_code, reqresp.text))
|
||||
raise Exception("HTTP ERROR: %s [%s] on %s" % (
|
||||
raise HTTP_ERROR("HTTP ERROR: %s [%s] on %s" % (
|
||||
reqresp.text, reqresp.status_code, reqresp.url))
|
||||
|
||||
if body_type:
|
||||
|
@ -904,7 +924,7 @@ class Client(PBase):
|
|||
return self.parse_response(response, reqresp.text, body_type,
|
||||
state, **kwargs)
|
||||
else:
|
||||
raise Exception("Didn't expect a response body")
|
||||
raise OtherError("Didn't expect a response body")
|
||||
else:
|
||||
return reqresp
|
||||
|
||||
|
|
|
@ -16,6 +16,10 @@ from oic.exception import MessageException
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class FormatError(PyoidcError):
|
||||
pass
|
||||
|
||||
|
||||
class MissingRequiredAttribute(MessageException):
|
||||
def __init__(self, attr, message=""):
|
||||
Exception.__init__(self, attr)
|
||||
|
@ -180,7 +184,7 @@ class Message(object):
|
|||
try:
|
||||
func = getattr(self, "from_%s" % method)
|
||||
except AttributeError, err:
|
||||
raise Exception("Unknown method (%s)" % method)
|
||||
raise FormatError("Unknown serialization method (%s)" % method)
|
||||
else:
|
||||
return func(info, **kwargs)
|
||||
|
||||
|
@ -474,28 +478,29 @@ class Message(object):
|
|||
if isinstance(jso, basestring):
|
||||
jso = json.loads(jso)
|
||||
|
||||
if "jku" in header:
|
||||
if not keyjar.find(header["jku"], jso["iss"]):
|
||||
# This is really questionable
|
||||
if keyjar:
|
||||
if "jku" in header:
|
||||
if not keyjar.find(header["jku"], jso["iss"]):
|
||||
# This is really questionable
|
||||
try:
|
||||
if kwargs["trusting"]:
|
||||
keyjar.add(jso["iss"], header["jku"])
|
||||
except KeyError:
|
||||
pass
|
||||
|
||||
if _kid:
|
||||
try:
|
||||
if kwargs["trusting"]:
|
||||
keyjar.add(jso["iss"], header["jku"])
|
||||
_key = keyjar.get_key_by_kid(_kid, jso["iss"])
|
||||
if _key:
|
||||
key.append(_key)
|
||||
except KeyError:
|
||||
pass
|
||||
|
||||
if _kid:
|
||||
try:
|
||||
_key = keyjar.get_key_by_kid(_kid, jso["iss"])
|
||||
if _key:
|
||||
key.append(_key)
|
||||
self._add_key(keyjar, kwargs["opponent_id"], key)
|
||||
except KeyError:
|
||||
pass
|
||||
|
||||
try:
|
||||
self._add_key(keyjar, kwargs["opponent_id"], key)
|
||||
except KeyError:
|
||||
pass
|
||||
|
||||
if verify:
|
||||
if keyjar:
|
||||
for ent in ["iss", "aud", "client_id"]:
|
||||
|
@ -901,7 +906,7 @@ def factory(msgtype):
|
|||
try:
|
||||
return MSG[msgtype]
|
||||
except KeyError:
|
||||
raise Exception("Unknown message type: %s" % msgtype)
|
||||
raise FormatError("Unknown message type: %s" % msgtype)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
@ -30,6 +30,14 @@ from jwkest import jws
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AtHashError(VerificationError):
|
||||
pass
|
||||
|
||||
|
||||
class CHashError(VerificationError):
|
||||
pass
|
||||
|
||||
|
||||
#noinspection PyUnusedLocal
|
||||
def json_ser(val, sformat=None, lev=0):
|
||||
return json.dumps(val)
|
||||
|
@ -292,7 +300,7 @@ class AuthorizationResponse(message.AuthorizationResponse,
|
|||
assert idt["at_hash"] == jws.left_hash(
|
||||
self["access_token"], hfunc)
|
||||
except AssertionError:
|
||||
raise VerificationError(
|
||||
raise AtHashError(
|
||||
"Failed to verify access_token hash", idt)
|
||||
|
||||
if "code" in self:
|
||||
|
@ -304,7 +312,7 @@ class AuthorizationResponse(message.AuthorizationResponse,
|
|||
try:
|
||||
assert idt["c_hash"] == jws.left_hash(self["code"], hfunc)
|
||||
except AssertionError:
|
||||
raise VerificationError("Failed to verify code hash", idt)
|
||||
raise CHashError("Failed to verify code hash", idt)
|
||||
|
||||
self["id_token"] = idt
|
||||
|
||||
|
|
|
@ -15,6 +15,10 @@ POSTFIX_MODE = {
|
|||
BLOCK_SIZE = 16
|
||||
|
||||
|
||||
class AESError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
def build_cipher(key, iv, alg="aes_128_cbc"):
|
||||
"""
|
||||
:param key: encryption key
|
||||
|
@ -30,16 +34,16 @@ def build_cipher(key, iv, alg="aes_128_cbc"):
|
|||
assert len(iv) == AES.block_size
|
||||
|
||||
if bits not in ["128", "192", "256"]:
|
||||
raise Exception("Unsupported key length")
|
||||
raise AESError("Unsupported key length")
|
||||
try:
|
||||
assert len(key) == int(bits) >> 3
|
||||
except AssertionError:
|
||||
raise Exception("Wrong Key length")
|
||||
raise AESError("Wrong Key length")
|
||||
|
||||
try:
|
||||
return AES.new(key, POSTFIX_MODE[cmode], iv), iv
|
||||
except KeyError:
|
||||
raise Exception("Unsupported chaining mode")
|
||||
raise AESError("Unsupported chaining mode")
|
||||
|
||||
|
||||
def encrypt(key, msg, iv=None, alg="aes_128_cbc", padding="PKCS#7",
|
||||
|
|
|
@ -223,7 +223,7 @@ class BearerBody(ClientAuthnMethod):
|
|||
_ = kwargs["state"]
|
||||
except KeyError:
|
||||
if not self.cli.state:
|
||||
raise Exception("Missing state specification")
|
||||
raise AuthnFailure("Missing state specification")
|
||||
kwargs["state"] = self.cli.state
|
||||
|
||||
cis["access_token"] = self.cli.get_token(**kwargs).access_token
|
||||
|
@ -255,7 +255,7 @@ class JWSAuthnMethod(ClientAuthnMethod):
|
|||
except KeyError:
|
||||
algorithm = DEF_SIGN_ALG[entity]
|
||||
if not algorithm:
|
||||
raise Exception("Missing algorithm specification")
|
||||
raise AuthnFailure("Missing algorithm specification")
|
||||
return algorithm
|
||||
|
||||
def get_signing_key(self, algorithm):
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
import ldap
|
||||
from oic.exception import PyoidcError
|
||||
|
||||
from oic.utils.authn.user import UsernamePasswordMako
|
||||
|
||||
|
@ -10,6 +11,10 @@ SCOPE_MAP = {
|
|||
}
|
||||
|
||||
|
||||
class LDAPCError(PyoidcError):
|
||||
pass
|
||||
|
||||
|
||||
class LDAPAuthn(UsernamePasswordMako):
|
||||
def __init__(self, srv, ldapsrv, return_to, pattern, mako_template,
|
||||
template_lookup, ldap_user="", ldap_pwd="",
|
||||
|
@ -52,7 +57,7 @@ class LDAPAuthn(UsernamePasswordMako):
|
|||
try:
|
||||
_pat = self.pattern["search"]
|
||||
except:
|
||||
raise Exception("unknown pattern")
|
||||
raise LDAPCError("unknown search pattern")
|
||||
else:
|
||||
args = {
|
||||
"filterstr": _pat["filterstr"] % user,
|
||||
|
|
|
@ -6,6 +6,7 @@ from urllib import urlencode
|
|||
import urllib
|
||||
from urlparse import parse_qs
|
||||
from urlparse import urlsplit
|
||||
from oic.exception import PyoidcError
|
||||
|
||||
from oic.utils import aes
|
||||
from oic.utils.http_util import Response
|
||||
|
@ -36,19 +37,23 @@ LOC = {
|
|||
}
|
||||
|
||||
|
||||
class NoSuchAuthentication(Exception):
|
||||
class NoSuchAuthentication(PyoidcError):
|
||||
pass
|
||||
|
||||
|
||||
class TamperAllert(Exception):
|
||||
class TamperAllert(PyoidcError):
|
||||
pass
|
||||
|
||||
|
||||
class ToOld(Exception):
|
||||
class ToOld(PyoidcError):
|
||||
pass
|
||||
|
||||
|
||||
class FailedAuthentication(Exception):
|
||||
class FailedAuthentication(PyoidcError):
|
||||
pass
|
||||
|
||||
|
||||
class InstantiationError(PyoidcError):
|
||||
pass
|
||||
|
||||
|
||||
|
@ -121,7 +126,7 @@ class UserAuthnMethod(CookieDealer):
|
|||
|
||||
def url_encode_params(params=None):
|
||||
if not isinstance(params, dict):
|
||||
raise Exception("You must pass in a dictionary!")
|
||||
raise InstantiationError("You must pass in a dictionary!")
|
||||
params_list = []
|
||||
for k, v in params.items():
|
||||
if isinstance(v, list):
|
||||
|
@ -368,3 +373,23 @@ class SymKeyAuthn(UserAuthnMethod):
|
|||
raise FailedAuthentication()
|
||||
|
||||
return {"uid": user}
|
||||
|
||||
|
||||
class NoAuthn(UserAuthnMethod):
|
||||
# Just for testing allows anyone it without authentication
|
||||
|
||||
def __init__(self, srv, user):
|
||||
UserAuthnMethod.__init__(self, srv)
|
||||
self.user = user
|
||||
|
||||
def authenticated_as(self, cookie=None, authorization="", **kwargs):
|
||||
"""
|
||||
|
||||
:param cookie: A HTTP Cookie
|
||||
:param authorization: The HTTP Authorization header
|
||||
:param args: extra args
|
||||
:param kwargs: extra key word arguments
|
||||
:return:
|
||||
"""
|
||||
|
||||
return {"uid": self.user}
|
||||
|
|
|
@ -2,7 +2,7 @@ import json
|
|||
import time
|
||||
from Crypto.PublicKey import RSA
|
||||
from cryptlib.ecc import NISTEllipticCurve
|
||||
from oic.exception import MessageException
|
||||
from oic.exception import MessageException, PyoidcError
|
||||
|
||||
|
||||
__author__ = 'rohe0002'
|
||||
|
@ -27,11 +27,15 @@ logger = logging.getLogger(__name__)
|
|||
traceback.format_exception(*sys.exc_info())
|
||||
|
||||
|
||||
class UnknownKeyType(Exception):
|
||||
class KeyIOError(PyoidcError):
|
||||
pass
|
||||
|
||||
|
||||
class UpdateFailed(Exception):
|
||||
class UnknownKeyType(KeyIOError):
|
||||
pass
|
||||
|
||||
|
||||
class UpdateFailed(KeyIOError):
|
||||
pass
|
||||
|
||||
|
||||
|
@ -83,7 +87,7 @@ class KeyBundle(object):
|
|||
elif source == "":
|
||||
return
|
||||
else:
|
||||
raise Exception("Unsupported source type: %s" % source)
|
||||
raise KeyIOError("Unsupported source type: %s" % source)
|
||||
|
||||
if not self.remote: # local file
|
||||
if self.fileformat == "jwk":
|
||||
|
@ -271,7 +275,7 @@ def keybundle_from_local_file(filename, typ, usage):
|
|||
elif typ.lower() == "jwk":
|
||||
kb = KeyBundle(source=filename, fileformat="jwk", keyusage=usage)
|
||||
else:
|
||||
raise Exception("Unsupported key type")
|
||||
raise UnknownKeyType("Unsupported key type")
|
||||
|
||||
return kb
|
||||
|
||||
|
@ -526,7 +530,7 @@ class KeyJar(object):
|
|||
if url.startswith(owner):
|
||||
return owner
|
||||
|
||||
raise Exception("No keys for '%s'" % url)
|
||||
raise KeyIOError("No keys for '%s'" % url)
|
||||
|
||||
def __str__(self):
|
||||
_res = {}
|
||||
|
|
|
@ -31,6 +31,11 @@ TIME_FORMAT = "%Y-%m-%dT%H:%M:%SZ"
|
|||
TIME_FORMAT_WITH_FRAGMENT = re.compile(
|
||||
"^(\d{4,4}-\d{2,2}-\d{2,2}T\d{2,2}:\d{2,2}:\d{2,2})\.\d*Z$")
|
||||
|
||||
|
||||
class TimeUtilError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# I'm sure this is implemented somewhere else can't find it now though, so I
|
||||
# made an attempt.
|
||||
|
@ -86,14 +91,14 @@ def parse_duration(duration):
|
|||
for code, typ in D_FORMAT:
|
||||
#print duration[index:], code
|
||||
if duration[index] == '-':
|
||||
raise Exception("Negation not allowed on individual items")
|
||||
raise TimeUtilError("Negation not allowed on individual items")
|
||||
if code == "T":
|
||||
if duration[index] == "T":
|
||||
index += 1
|
||||
if index == len(duration):
|
||||
raise Exception("Not allowed to end with 'T'")
|
||||
raise TimeUtilError("Not allowed to end with 'T'")
|
||||
else:
|
||||
raise Exception("Missing T")
|
||||
raise TimeUtilError("Missing T")
|
||||
else:
|
||||
try:
|
||||
mod = duration[index:].index(code)
|
||||
|
@ -104,9 +109,9 @@ def parse_duration(duration):
|
|||
try:
|
||||
dic[typ] = float(duration[index:index + mod])
|
||||
except ValueError:
|
||||
raise Exception("Not a float")
|
||||
raise TimeUtilError("Not a float")
|
||||
else:
|
||||
raise Exception(
|
||||
raise TimeUtilError(
|
||||
"Fractions not allow on anything byt seconds")
|
||||
index = mod + index + 1
|
||||
except ValueError:
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
import copy
|
||||
import logging
|
||||
from oic.exception import MissingAttribute
|
||||
|
||||
from oic.oic import OpenIDSchema
|
||||
from oic.oic.claims_provider import ClaimsClient
|
||||
|
@ -116,7 +117,8 @@ class DistributedAggregatedUserInfo(UserInfo):
|
|||
pass
|
||||
|
||||
if remaining:
|
||||
raise Exception("Missing properties '%s'" % remaining)
|
||||
raise MissingAttribute(
|
||||
"Missing properties '%s'" % remaining)
|
||||
|
||||
for srv, what in cpoints.items():
|
||||
cc = self.oidcsrv.claims_clients[srv]
|
||||
|
|
|
@ -4,6 +4,7 @@ import logging
|
|||
import re
|
||||
from urllib import urlencode
|
||||
import urlparse
|
||||
from oic.exception import PyoidcError
|
||||
import requests
|
||||
|
||||
from oic.utils.time_util import in_a_while
|
||||
|
@ -16,6 +17,9 @@ logger = logging.getLogger(__name__)
|
|||
WF_URL = "https://%s/.well-known/webfinger"
|
||||
OIC_ISSUER = "http://openid.net/specs/connect/1.0/issuer"
|
||||
|
||||
class WebFingerError(PyoidcError):
|
||||
pass
|
||||
|
||||
|
||||
class Base(object):
|
||||
c_param = {}
|
||||
|
@ -243,7 +247,7 @@ class WebFinger(object):
|
|||
elif resource.startswith("device:"):
|
||||
host = resource.split(':')[1]
|
||||
else:
|
||||
raise Exception("Unknown schema")
|
||||
raise WebFingerError("Unknown schema")
|
||||
|
||||
return "%s?%s" % (WF_URL % host, urlencode(info))
|
||||
|
||||
|
@ -288,7 +292,7 @@ class WebFinger(object):
|
|||
elif rsp.status_code in [302, 301, 307]:
|
||||
return self.discovery_query(rsp.headers["location"])
|
||||
else:
|
||||
raise Exception(rsp.status_code)
|
||||
raise WebFingerError(rsp.status_code)
|
||||
|
||||
def response(self, subject, base):
|
||||
self.jrd = JRD()
|
||||
|
|
Reference in New Issue