Make sure I have the necessary keys available.

Added some tests.
This commit is contained in:
Roland Hedberg 2015-01-10 13:28:24 +01:00
parent 8493602411
commit c2ef2eecb3
5 changed files with 495 additions and 54 deletions

View File

@ -821,11 +821,13 @@ class Client(PBase):
elif resp.only_extras():
resp = None
else:
if "key" not in kwargs and "keyjar" not in kwargs:
kwargs["keyjar"] = self.keyjar
verf = resp.verify(**kwargs)
if not verf:
raise PyoidcError("Verification of the response failed")
if resp.type() == "AuthorizationResponse" and \
"scope" not in resp:
"scope" not in resp:
try:
resp["scope"] = kwargs["scope"]
except KeyError:
@ -950,7 +952,7 @@ class Client(PBase):
except Exception:
raise
if not "keyjar" in kwargs:
if "keyjar" not in kwargs:
kwargs["keyjar"] = self.keyjar
return self.parse_request_response(resp, response, body_type, state,

View File

@ -311,11 +311,41 @@ class Consumer(Client):
return sid, location
def _parse_authz(self, query="", **kwargs):
_log_info = logger.info
# Might be an error response
_log_info("Expect Authorization Response")
aresp = self.parse_response(AuthorizationResponse,
info=query,
sformat="urlencoded",
keyjar=self.keyjar)
if aresp.type() == "ErrorResponse":
_log_info("ErrorResponse: %s" % aresp)
raise AuthzError(aresp.error, aresp)
_log_info("Aresp: %s" % aresp)
_state = aresp["state"]
try:
self.update(_state)
except KeyError:
raise UnknownState(_state, aresp)
self.redirect_uris = [self.sdb[_state]["redirect_uris"]]
return aresp, _state
#noinspection PyUnusedLocal
def parse_authz(self, query="", **kwargs):
"""
This is where we get redirect back to after authorization at the
authorization server has happened.
Couple of cases
["code"]
["code", "token"]
["code", "id_token", "token"]
["id_token"]
["id_token", "token"]
["token"]
:return: A AccessTokenResponse instance
"""
@ -329,25 +359,7 @@ class Consumer(Client):
_log_info("response: %s" % query)
if "code" in self.config["response_type"]:
# Might be an error response
_log_info("Expect Authorization Response")
aresp = self.parse_response(AuthorizationResponse,
info=query,
sformat="urlencoded",
keyjar=self.keyjar)
if aresp.type() == "ErrorResponse":
_log_info("ErrorResponse: %s" % aresp)
raise AuthzError(aresp.error, aresp)
_log_info("Aresp: %s" % aresp)
_state = aresp["state"]
try:
self.update(_state)
except KeyError:
raise UnknownState(_state, aresp)
self.redirect_uris = [self.sdb[_state]["redirect_uris"]]
aresp, _state = self._parse_authz(query, **kwargs)
# May have token and id_token information too
if "access_token" in aresp:
@ -366,7 +378,7 @@ class Consumer(Client):
idt = None
return aresp, atr, idt
else: # implicit flow
elif "token" in self.config["response_type"]: # implicit flow
_log_info("Expect Access Token Response")
atr = self.parse_response(AccessTokenResponse, info=query,
sformat="urlencoded",
@ -376,6 +388,14 @@ class Consumer(Client):
idt = None
return None, atr, idt
else: # only id_token
aresp, _state = self._parse_authz(query, **kwargs)
try:
idt = aresp["id_token"]
except KeyError:
idt = None
return None, None, idt
def complete(self, state):
"""

317
tests/mitmsrv.py Normal file
View File

@ -0,0 +1,317 @@
#!/usr/bin/env python
from urlparse import parse_qs
from jwkest.jws import alg2keytype
from oic.oauth2 import rndstr
from oic.oauth2.message import by_schema
from oic.oic import Server
from oic.oic.message import *
from oic.utils.sdb import SessionDB, AuthnEvent
from oic.utils.time_util import utc_time_sans_frac
from oic.utils.webfinger import WebFinger
__author__ = 'rohe0002'
class Response():
def __init__(self, base=None):
self.status_code = 200
if base:
for key, val in base.items():
self.__setitem__(key, val)
def __setitem__(self, key, value):
setattr(self, key, value)
def __getitem__(self, item):
return getattr(self, item)
ENDPOINT = {
"authorization_endpoint": "/authorization",
"token_endpoint": "/token",
"user_info_endpoint": "/userinfo",
"check_session_endpoint": "/check_session",
"refresh_session_endpoint": "/refresh_session",
"end_session_endpoint": "/end_session",
"registration_endpoint": "/registration",
"discovery_endpoint": "/discovery",
"register_endpoint": "/register"
}
class MITMServer(Server):
def __init__(self, name=""):
Server.__init__(self)
self.sdb = SessionDB(name)
self.name = name
self.client = {}
self.registration_expires_in = 3600
self.host = ""
self.webfinger = WebFinger()
self.userinfo_signed_response_alg = ""
# noinspection PyUnusedLocal
def http_request(self, path, method="GET", **kwargs):
part = urlparse(path)
path = part[2]
query = part[4]
self.host = "%s://%s" % (part.scheme, part.netloc)
response = Response
response.status_code = 500
response.text = ""
if path == ENDPOINT["authorization_endpoint"]:
assert method == "GET"
response = self.authorization_endpoint(query)
elif path == ENDPOINT["token_endpoint"]:
assert method == "POST"
response = self.token_endpoint(kwargs["data"])
elif path == ENDPOINT["user_info_endpoint"]:
assert method == "POST"
response = self.userinfo_endpoint(kwargs["data"])
elif path == ENDPOINT["refresh_session_endpoint"]:
assert method == "GET"
response = self.refresh_session_endpoint(query)
elif path == ENDPOINT["check_session_endpoint"]:
assert method == "GET"
response = self.check_session_endpoint(query)
elif path == ENDPOINT["end_session_endpoint"]:
assert method == "GET"
response = self.end_session_endpoint(query)
elif path == ENDPOINT["registration_endpoint"]:
if method == "POST":
response = self.registration_endpoint(kwargs["data"])
elif path == "/.well-known/webfinger":
assert method == "GET"
qdict = parse_qs(query)
response.status_code = 200
response.text = self.webfinger.response(qdict["resource"][0],
"%s/" % self.name)
elif path == "/.well-known/openid-configuration":
assert method == "GET"
response = self.openid_conf()
return response
def authorization_endpoint(self, query):
req = self.parse_authorization_request(query=query)
aevent = AuthnEvent("user", authn_info="acr")
sid = self.sdb.create_authz_session(aevent, areq=req)
_ = self.sdb.do_sub(sid)
_info = self.sdb[sid]
if "code" in req["response_type"]:
if "token" in req["response_type"]:
grant = _info["code"]
_dict = self.sdb.upgrade_to_token(grant)
_dict["oauth_state"] = "authz",
_dict = by_schema(AuthorizationResponse(), **_dict)
resp = AuthorizationResponse(**_dict)
# resp.code = grant
else:
_state = req["state"]
resp = AuthorizationResponse(state=_state,
code=_info["code"])
else: # "implicit" in req.response_type:
grant = _info["code"]
params = AccessTokenResponse.c_param.keys()
if "token" in req["response_type"]:
_dict = dict([(k, v) for k, v in
self.sdb.upgrade_to_token(grant).items() if k in
params])
try:
del _dict["refresh_token"]
except KeyError:
pass
else:
_dict = {"state": req["state"]}
if "id_token" in req["response_type"]:
_idt = self.make_id_token(_info, issuer=self.name)
alg = "RS256"
ckey = self.keyjar.get_signing_key(alg2keytype(alg),
_info["client_id"])
_signed_jwt = _idt.to_jwt(key=ckey, algorithm=alg)
p = _signed_jwt.split(".")
p[2] = "aaa"
_dict["id_token"] = ".".join(p)
resp = AuthorizationResponse(**_dict)
location = resp.request(req["redirect_uri"])
response = Response()
response.headers = {"location": location}
response.status_code = 302
response.text = ""
return response
def token_endpoint(self, data):
if "grant_type=refresh_token" in data:
req = self.parse_refresh_token_request(body=data)
_info = self.sdb.refresh_token(req["refresh_token"])
elif "grant_type=authorization_code":
req = self.parse_token_request(body=data)
_info = self.sdb.upgrade_to_token(req["code"])
else:
response = TokenErrorResponse(error="unsupported_grant_type")
return response, ""
resp = AccessTokenResponse(**by_schema(AccessTokenResponse, **_info))
response = Response()
response.headers = {"content-type": "application/json"}
response.text = resp.to_json()
return response
def userinfo_endpoint(self, data):
_ = self.parse_user_info_request(data)
_info = {
"sub": "melgar",
"name": "Melody Gardot",
"nickname": "Mel",
"email": "mel@example.com",
"verified": True,
}
resp = OpenIDSchema(**_info)
response = Response()
if self.userinfo_signed_response_alg:
alg = self.userinfo_signed_response_alg
response.headers = {"content-type": "application/jwt"}
key = self.keyjar.get_signing_key(alg2keytype(alg), "", alg=alg)
response.text = resp.to_jwt(key, alg)
else:
response.headers = {"content-type": "application/json"}
response.text = resp.to_json()
return response
def registration_endpoint(self, data):
try:
req = self.parse_registration_request(data, "json")
except ValueError:
req = self.parse_registration_request(data)
client_secret = rndstr()
expires = utc_time_sans_frac() + self.registration_expires_in
kwargs = {}
if "client_id" not in req:
client_id = rndstr(10)
registration_access_token = rndstr(20)
_client_info = req.to_dict()
kwargs.update(_client_info)
_client_info.update({
"client_secret": client_secret,
"info": req.to_dict(),
"expires": expires,
"registration_access_token": registration_access_token,
"registration_client_uri": "register_endpoint"
})
self.client[client_id] = _client_info
kwargs["registration_access_token"] = registration_access_token
kwargs["registration_client_uri"] = "register_endpoint"
try:
del kwargs["operation"]
except KeyError:
pass
else:
client_id = req.client_id
_cinfo = self.client[req.client_id]
_cinfo["info"].update(req.to_dict())
_cinfo["client_secret"] = client_secret
_cinfo["expires"] = expires
resp = RegistrationResponse(client_id=client_id,
client_secret=client_secret,
client_secret_expires_at=expires,
**kwargs)
response = Response()
response.headers = {"content-type": "application/json"}
response.text = resp.to_json()
return response
def check_session_endpoint(self, query):
try:
idtoken = self.parse_check_session_request(query=query)
except Exception:
raise
response = Response()
response.text = idtoken.to_json()
response.headers = {"content-type": "application/json"}
return response
# noinspection PyUnusedLocal
def refresh_session_endpoint(self, query):
try:
req = self.parse_refresh_session_request(query=query)
except Exception:
raise
resp = RegistrationResponse(client_id="anonymous",
client_secret="hemligt")
response = Response()
response.headers = {"content-type": "application/json"}
response.text = resp.to_json()
return response
def end_session_endpoint(self, query):
try:
req = self.parse_end_session_request(query=query)
except Exception:
raise
# redirect back
resp = EndSessionResponse(state=req["state"])
url = resp.request(req["redirect_url"])
response = Response()
response.headers = {"location": url}
response.status_code = 302 # redirect
response.text = ""
return response
# noinspection PyUnusedLocal
@staticmethod
def add_credentials(user, passwd):
return
def openid_conf(self):
endpoint = {}
for point, path in ENDPOINT.items():
endpoint[point] = "%s%s" % (self.host, path)
signing_algs = jws.SIGNER_ALGS.keys()
resp = ProviderConfigurationResponse(
issuer=self.name,
scopes_supported=["openid", "profile", "email", "address"],
identifiers_supported=["public", "PPID"],
flows_supported=["code", "token", "code token", "id_token",
"code id_token", "token id_token"],
subject_types_supported=["pairwise", "public"],
response_types_supported=["code", "token", "id_token",
"code token", "code id_token",
"token id_token", "code token id_token"],
jwks_uri="http://example.com/oidc/jwks",
id_token_signing_alg_values_supported=signing_algs,
grant_types_supported=["authorization_code", "implicit"],
**endpoint)
response = Response()
response.headers = {"content-type": "application/json"}
response.text = resp.to_json()
return response

View File

@ -1,8 +1,12 @@
import json
import os
import shutil
import tempfile
from jwkest import BadSignature
from jwkest.jwk import SYMKey
from oic.oauth2.message import MissingSigningKey
from oic.oic.message import AccessTokenResponse, AuthorizationResponse
from oic.oic.message import AccessTokenResponse, AuthorizationResponse, IdToken
from oic.utils.keyio import KeyBundle, keybundle_from_local_file
from oic.utils.keyio import KeyJar
@ -19,6 +23,7 @@ from oic.utils.time_util import utc_time_sans_frac
from oic.utils.sdb import SessionDB
from fakeoicsrv import MyFakeOICServer
from mitmsrv import MITMServer
from utils_for_tests import _eq
@ -543,16 +548,6 @@ def test_discover():
res = c.discover(principal)
assert res == "http://localhost:8088/"
#def test_discover_redirect():
# c = Consumer(None, None)
# mfos = MyFakeOICServer(name="http://example.com/")
# c.http_request = mfos.http_request
#
# principal = "bar@example.org"
#
# res = c.discover(principal)
# assert res == "http://example.net/providerconf"
def test_provider_config():
c = Consumer(None, None)
@ -606,8 +601,103 @@ def test_client_register():
assert c.registration_expires > utc_time_sans_frac()
SYMKEY = SYMKey(key="TestPassword")
def _faulty_id_token():
idval = {'nonce': 'KUEYfRM2VzKDaaKD', 'sub': 'EndUserSubject',
'iss': 'https://alpha.cloud.nds.rub.de', 'exp': 1420823073,
'iat': 1420822473, 'aud': 'TestClient'}
idts = IdToken(**idval)
_signed_jwt = idts.to_jwt(key=[SYMKEY], algorithm="HS256")
#Mess with the signed id_token
p = _signed_jwt.split(".")
p[2] = "aaa"
return ".".join(p)
def test_faulty_id_token():
_faulty_signed_jwt = _faulty_id_token()
try:
_ = IdToken().from_jwt(_faulty_signed_jwt, key=[SYMKEY])
except BadSignature:
pass
else:
assert False
# What if no verification key is given ?
# Should also result in an exception
try:
_ = IdToken().from_jwt(_faulty_signed_jwt)
except MissingSigningKey:
pass
else:
assert False
def test_faulty_id_token_in_access_token_response():
c = Consumer(None, None)
c.keyjar.add_symmetric("", "TestPassword", ["sig"])
_info = {"access_token": "accessTok", "id_token": _faulty_id_token(),
"token_type": "Bearer", "expires_in": 3600}
_json = json.dumps(_info)
try:
resp = c.parse_response(AccessTokenResponse, _json, sformat="json")
except BadSignature:
pass
else:
assert False
def test_faulty_idtoken_from_accesstoken_endpoint():
consumer = Consumer(SessionDB(SERVER_INFO["issuer"]), CONFIG,
CLIENT_CONFIG, SERVER_INFO)
consumer.keyjar = CLIKEYS
mfos = MITMServer("http://localhost:8088")
mfos.keyjar = SRVKEYS
consumer.http_request = mfos.http_request
consumer.redirect_uris = ["http://example.com/authz"]
_state = "state0"
consumer.nonce = rndstr()
consumer.client_secret = "hemlig"
consumer.secret_type = "basic"
consumer.config["response_type"] = ["id_token"]
args = {
"client_id": consumer.client_id,
"response_type": consumer.config["response_type"],
"scope": ["openid"],
}
result = consumer.do_authorization_request(state=_state,
request_args=args)
consumer._backup("state0")
assert result.status_code == 302
#assert result.location.startswith(consumer.redirect_uri[0])
_, query = result.headers["location"].split("?")
print query
part = consumer.parse_authz(query=query)
print part
auth = part[0]
acc = part[1]
assert part[2] is None
#print auth.dictionary()
#print acc.dictionary()
assert auth is None
assert acc.type() == "AccessTokenResponse"
assert _eq(acc.keys(), ['access_token', 'id_token', 'expires_in',
'token_type', 'state', 'scope'])
if __name__ == "__main__":
# t = TestOICConsumer()
# t.setup_class()
# t.test_complete()
test_sign_userinfo()
test_faulty_idtoken_from_accesstoken_endpoint()

View File

@ -1,9 +1,13 @@
# -*- coding: utf-8 -*-
from jwkest import BadSignature
from jwkest.jwk import SYMKey
__author__ = 'rohe0002'
import json
from oic.oic.message import ProviderConfigurationResponse, RegistrationResponse, AuthorizationRequest
from oic.oic.message import ProviderConfigurationResponse, RegistrationResponse, AuthorizationRequest, \
IdToken, AccessTokenResponse
from oic.oic.message import msg_ser
from oic.oic.message import claims_ser
from oic.oic.message import RegistrationRequest
@ -178,24 +182,6 @@ def test_authz_request():
assert req["scope"] == ["openid", "profile"]
# def test_idtokenclaim_deser():
# claims = Claims(weather={"acr": "2"})
# pre = IDTokenClaim(claims=claims, max_age=3600)
# idt = idtokenclaim_deser(pre.to_json(), sformat="json")
# assert _eq(idt.keys(), ['claims', "max_age"])
#
#
# def test_userinfo_deser():
# CLAIM = Claims(name={"essential": True}, nickname=None,
# email={"essential": True},
# email_verified={"essential": True}, picture=None)
#
# pre_uic = UserInfoClaim(claims=CLAIM, format="signed")
#
# uic = userinfo_deser(pre_uic.to_json(), sformat="json")
# assert _eq(uic.keys(), ["claims", "format"])
def test_claims_deser_0():
_dic = {
"userinfo": {
@ -307,5 +293,31 @@ def test_registration_request():
assert _eq(ue_splits, expected_ue_splits)
def test_faulty_idtoken():
idval = {'nonce': 'KUEYfRM2VzKDaaKD', 'sub': 'EndUserSubject',
'iss': 'https://alpha.cloud.nds.rub.de', 'exp': 1420823073,
'iat': 1420822473, 'aud': 'TestClient'}
idts = IdToken(**idval)
key = SYMKey(key="TestPassword")
_signed_jwt = idts.to_jwt(key=[key], algorithm="HS256")
#Mess with the signed id_token
p = _signed_jwt.split(".")
p[2] = "aaa"
_faulty_signed_jwt = ".".join(p)
_info = {"access_token": "accessTok", "id_token": _faulty_signed_jwt,
"token_type": "Bearer", "expires_in": 3600}
# Should fail
at = AccessTokenResponse(**_info)
try:
at.verify(key=[key])
except BadSignature:
pass
else:
raise
if __name__ == "__main__":
test_claims_deser_0()
test_faulty_idtoken()