Make all exceptions thrown by pyoidc code have a common ancestor.

This commit is contained in:
Roland Hedberg 2014-12-28 16:56:04 +01:00
parent 75a5293e5a
commit 35229fead4
11 changed files with 137 additions and 55 deletions

View File

@ -61,11 +61,31 @@ class ResponseError(PyoidcError):
pass
class TimeFormatError(Exception):
class TimeFormatError(PyoidcError):
pass
class CapabilitiesMisMatch(Exception):
class CapabilitiesMisMatch(PyoidcError):
pass
class MissingEndpoint(PyoidcError):
pass
class TokenError(PyoidcError):
pass
class GrantError(PyoidcError):
pass
class ParseError(PyoidcError):
pass
class OtherError(PyoidcError):
pass
@ -510,10 +530,10 @@ class Client(PBase):
try:
uri = getattr(self, endpoint)
except Exception:
raise Exception("No '%s' specified" % endpoint)
raise MissingEndpoint("No '%s' specified" % endpoint)
if not uri:
raise Exception("No '%s' specified" % endpoint)
raise MissingEndpoint("No '%s' specified" % endpoint)
return uri
@ -528,7 +548,7 @@ class Client(PBase):
try:
return self.grant[state]
except:
raise Exception("No grant found for state:'%s'" % state)
raise GrantError("No grant found for state:'%s'" % state)
def get_token(self, also_expired=False, **kwargs):
try:
@ -544,17 +564,17 @@ class Client(PBase):
try:
token = self.grant[kwargs["state"]].get_token("")
except KeyError:
raise Exception("No token found for scope")
raise TokenError("No token found for scope")
if token is None:
raise Exception("No suitable token found")
raise TokenError("No suitable token found")
if also_expired:
return token
elif token.is_valid():
return token
else:
raise ExpiredToken("Token has expired")
raise TokenError("Token has expired")
def construct_request(self, request, request_args=None, extra_args=None):
if request_args is None:
@ -696,7 +716,7 @@ class Client(PBase):
else:
kwargs["headers"] = header_ext
else:
raise Exception("Unsupported HTTP method: '%s'" % method)
raise UnSupported("Unsupported HTTP method: '%s'" % method)
return path, body, kwargs
@ -889,14 +909,14 @@ class Client(PBase):
pass
elif reqresp.status_code == 500:
logger.error("(%d) %s" % (reqresp.status_code, reqresp.text))
raise Exception("ERROR: Something went wrong: %s" % reqresp.text)
raise ParseError("ERROR: Something went wrong: %s" % reqresp.text)
elif reqresp.status_code in [400, 401]:
#expecting an error response
if issubclass(response, ErrorResponse):
pass
else:
logger.error("(%d) %s" % (reqresp.status_code, reqresp.text))
raise Exception("HTTP ERROR: %s [%s] on %s" % (
raise HTTP_ERROR("HTTP ERROR: %s [%s] on %s" % (
reqresp.text, reqresp.status_code, reqresp.url))
if body_type:
@ -904,7 +924,7 @@ class Client(PBase):
return self.parse_response(response, reqresp.text, body_type,
state, **kwargs)
else:
raise Exception("Didn't expect a response body")
raise OtherError("Didn't expect a response body")
else:
return reqresp

View File

@ -16,6 +16,10 @@ from oic.exception import MessageException
logger = logging.getLogger(__name__)
class FormatError(PyoidcError):
pass
class MissingRequiredAttribute(MessageException):
def __init__(self, attr, message=""):
Exception.__init__(self, attr)
@ -180,7 +184,7 @@ class Message(object):
try:
func = getattr(self, "from_%s" % method)
except AttributeError, err:
raise Exception("Unknown method (%s)" % method)
raise FormatError("Unknown serialization method (%s)" % method)
else:
return func(info, **kwargs)
@ -474,28 +478,29 @@ class Message(object):
if isinstance(jso, basestring):
jso = json.loads(jso)
if "jku" in header:
if not keyjar.find(header["jku"], jso["iss"]):
# This is really questionable
if keyjar:
if "jku" in header:
if not keyjar.find(header["jku"], jso["iss"]):
# This is really questionable
try:
if kwargs["trusting"]:
keyjar.add(jso["iss"], header["jku"])
except KeyError:
pass
if _kid:
try:
if kwargs["trusting"]:
keyjar.add(jso["iss"], header["jku"])
_key = keyjar.get_key_by_kid(_kid, jso["iss"])
if _key:
key.append(_key)
except KeyError:
pass
if _kid:
try:
_key = keyjar.get_key_by_kid(_kid, jso["iss"])
if _key:
key.append(_key)
self._add_key(keyjar, kwargs["opponent_id"], key)
except KeyError:
pass
try:
self._add_key(keyjar, kwargs["opponent_id"], key)
except KeyError:
pass
if verify:
if keyjar:
for ent in ["iss", "aud", "client_id"]:
@ -901,7 +906,7 @@ def factory(msgtype):
try:
return MSG[msgtype]
except KeyError:
raise Exception("Unknown message type: %s" % msgtype)
raise FormatError("Unknown message type: %s" % msgtype)
if __name__ == "__main__":

View File

@ -30,6 +30,14 @@ from jwkest import jws
logger = logging.getLogger(__name__)
class AtHashError(VerificationError):
pass
class CHashError(VerificationError):
pass
#noinspection PyUnusedLocal
def json_ser(val, sformat=None, lev=0):
return json.dumps(val)
@ -292,7 +300,7 @@ class AuthorizationResponse(message.AuthorizationResponse,
assert idt["at_hash"] == jws.left_hash(
self["access_token"], hfunc)
except AssertionError:
raise VerificationError(
raise AtHashError(
"Failed to verify access_token hash", idt)
if "code" in self:
@ -304,7 +312,7 @@ class AuthorizationResponse(message.AuthorizationResponse,
try:
assert idt["c_hash"] == jws.left_hash(self["code"], hfunc)
except AssertionError:
raise VerificationError("Failed to verify code hash", idt)
raise CHashError("Failed to verify code hash", idt)
self["id_token"] = idt

View File

@ -15,6 +15,10 @@ POSTFIX_MODE = {
BLOCK_SIZE = 16
class AESError(Exception):
pass
def build_cipher(key, iv, alg="aes_128_cbc"):
"""
:param key: encryption key
@ -30,16 +34,16 @@ def build_cipher(key, iv, alg="aes_128_cbc"):
assert len(iv) == AES.block_size
if bits not in ["128", "192", "256"]:
raise Exception("Unsupported key length")
raise AESError("Unsupported key length")
try:
assert len(key) == int(bits) >> 3
except AssertionError:
raise Exception("Wrong Key length")
raise AESError("Wrong Key length")
try:
return AES.new(key, POSTFIX_MODE[cmode], iv), iv
except KeyError:
raise Exception("Unsupported chaining mode")
raise AESError("Unsupported chaining mode")
def encrypt(key, msg, iv=None, alg="aes_128_cbc", padding="PKCS#7",

View File

@ -223,7 +223,7 @@ class BearerBody(ClientAuthnMethod):
_ = kwargs["state"]
except KeyError:
if not self.cli.state:
raise Exception("Missing state specification")
raise AuthnFailure("Missing state specification")
kwargs["state"] = self.cli.state
cis["access_token"] = self.cli.get_token(**kwargs).access_token
@ -255,7 +255,7 @@ class JWSAuthnMethod(ClientAuthnMethod):
except KeyError:
algorithm = DEF_SIGN_ALG[entity]
if not algorithm:
raise Exception("Missing algorithm specification")
raise AuthnFailure("Missing algorithm specification")
return algorithm
def get_signing_key(self, algorithm):

View File

@ -1,4 +1,5 @@
import ldap
from oic.exception import PyoidcError
from oic.utils.authn.user import UsernamePasswordMako
@ -10,6 +11,10 @@ SCOPE_MAP = {
}
class LDAPCError(PyoidcError):
pass
class LDAPAuthn(UsernamePasswordMako):
def __init__(self, srv, ldapsrv, return_to, pattern, mako_template,
template_lookup, ldap_user="", ldap_pwd="",
@ -52,7 +57,7 @@ class LDAPAuthn(UsernamePasswordMako):
try:
_pat = self.pattern["search"]
except:
raise Exception("unknown pattern")
raise LDAPCError("unknown search pattern")
else:
args = {
"filterstr": _pat["filterstr"] % user,

View File

@ -6,6 +6,7 @@ from urllib import urlencode
import urllib
from urlparse import parse_qs
from urlparse import urlsplit
from oic.exception import PyoidcError
from oic.utils import aes
from oic.utils.http_util import Response
@ -36,19 +37,23 @@ LOC = {
}
class NoSuchAuthentication(Exception):
class NoSuchAuthentication(PyoidcError):
pass
class TamperAllert(Exception):
class TamperAllert(PyoidcError):
pass
class ToOld(Exception):
class ToOld(PyoidcError):
pass
class FailedAuthentication(Exception):
class FailedAuthentication(PyoidcError):
pass
class InstantiationError(PyoidcError):
pass
@ -121,7 +126,7 @@ class UserAuthnMethod(CookieDealer):
def url_encode_params(params=None):
if not isinstance(params, dict):
raise Exception("You must pass in a dictionary!")
raise InstantiationError("You must pass in a dictionary!")
params_list = []
for k, v in params.items():
if isinstance(v, list):
@ -368,3 +373,23 @@ class SymKeyAuthn(UserAuthnMethod):
raise FailedAuthentication()
return {"uid": user}
class NoAuthn(UserAuthnMethod):
# Just for testing allows anyone it without authentication
def __init__(self, srv, user):
UserAuthnMethod.__init__(self, srv)
self.user = user
def authenticated_as(self, cookie=None, authorization="", **kwargs):
"""
:param cookie: A HTTP Cookie
:param authorization: The HTTP Authorization header
:param args: extra args
:param kwargs: extra key word arguments
:return:
"""
return {"uid": self.user}

View File

@ -2,7 +2,7 @@ import json
import time
from Crypto.PublicKey import RSA
from cryptlib.ecc import NISTEllipticCurve
from oic.exception import MessageException
from oic.exception import MessageException, PyoidcError
__author__ = 'rohe0002'
@ -27,11 +27,15 @@ logger = logging.getLogger(__name__)
traceback.format_exception(*sys.exc_info())
class UnknownKeyType(Exception):
class KeyIOError(PyoidcError):
pass
class UpdateFailed(Exception):
class UnknownKeyType(KeyIOError):
pass
class UpdateFailed(KeyIOError):
pass
@ -83,7 +87,7 @@ class KeyBundle(object):
elif source == "":
return
else:
raise Exception("Unsupported source type: %s" % source)
raise KeyIOError("Unsupported source type: %s" % source)
if not self.remote: # local file
if self.fileformat == "jwk":
@ -271,7 +275,7 @@ def keybundle_from_local_file(filename, typ, usage):
elif typ.lower() == "jwk":
kb = KeyBundle(source=filename, fileformat="jwk", keyusage=usage)
else:
raise Exception("Unsupported key type")
raise UnknownKeyType("Unsupported key type")
return kb
@ -526,7 +530,7 @@ class KeyJar(object):
if url.startswith(owner):
return owner
raise Exception("No keys for '%s'" % url)
raise KeyIOError("No keys for '%s'" % url)
def __str__(self):
_res = {}

View File

@ -31,6 +31,11 @@ TIME_FORMAT = "%Y-%m-%dT%H:%M:%SZ"
TIME_FORMAT_WITH_FRAGMENT = re.compile(
"^(\d{4,4}-\d{2,2}-\d{2,2}T\d{2,2}:\d{2,2}:\d{2,2})\.\d*Z$")
class TimeUtilError(Exception):
pass
# ---------------------------------------------------------------------------
# I'm sure this is implemented somewhere else can't find it now though, so I
# made an attempt.
@ -86,14 +91,14 @@ def parse_duration(duration):
for code, typ in D_FORMAT:
#print duration[index:], code
if duration[index] == '-':
raise Exception("Negation not allowed on individual items")
raise TimeUtilError("Negation not allowed on individual items")
if code == "T":
if duration[index] == "T":
index += 1
if index == len(duration):
raise Exception("Not allowed to end with 'T'")
raise TimeUtilError("Not allowed to end with 'T'")
else:
raise Exception("Missing T")
raise TimeUtilError("Missing T")
else:
try:
mod = duration[index:].index(code)
@ -104,9 +109,9 @@ def parse_duration(duration):
try:
dic[typ] = float(duration[index:index + mod])
except ValueError:
raise Exception("Not a float")
raise TimeUtilError("Not a float")
else:
raise Exception(
raise TimeUtilError(
"Fractions not allow on anything byt seconds")
index = mod + index + 1
except ValueError:

View File

@ -1,5 +1,6 @@
import copy
import logging
from oic.exception import MissingAttribute
from oic.oic import OpenIDSchema
from oic.oic.claims_provider import ClaimsClient
@ -116,7 +117,8 @@ class DistributedAggregatedUserInfo(UserInfo):
pass
if remaining:
raise Exception("Missing properties '%s'" % remaining)
raise MissingAttribute(
"Missing properties '%s'" % remaining)
for srv, what in cpoints.items():
cc = self.oidcsrv.claims_clients[srv]

View File

@ -4,6 +4,7 @@ import logging
import re
from urllib import urlencode
import urlparse
from oic.exception import PyoidcError
import requests
from oic.utils.time_util import in_a_while
@ -16,6 +17,9 @@ logger = logging.getLogger(__name__)
WF_URL = "https://%s/.well-known/webfinger"
OIC_ISSUER = "http://openid.net/specs/connect/1.0/issuer"
class WebFingerError(PyoidcError):
pass
class Base(object):
c_param = {}
@ -243,7 +247,7 @@ class WebFinger(object):
elif resource.startswith("device:"):
host = resource.split(':')[1]
else:
raise Exception("Unknown schema")
raise WebFingerError("Unknown schema")
return "%s?%s" % (WF_URL % host, urlencode(info))
@ -288,7 +292,7 @@ class WebFinger(object):
elif rsp.status_code in [302, 301, 307]:
return self.discovery_query(rsp.headers["location"])
else:
raise Exception(rsp.status_code)
raise WebFingerError(rsp.status_code)
def response(self, subject, base):
self.jrd = JRD()