diff --git a/src/oic/oauth2/__init__.py b/src/oic/oauth2/__init__.py index 2608daa..3c5590e 100644 --- a/src/oic/oauth2/__init__.py +++ b/src/oic/oauth2/__init__.py @@ -821,11 +821,13 @@ class Client(PBase): elif resp.only_extras(): resp = None else: + if "key" not in kwargs and "keyjar" not in kwargs: + kwargs["keyjar"] = self.keyjar verf = resp.verify(**kwargs) if not verf: raise PyoidcError("Verification of the response failed") if resp.type() == "AuthorizationResponse" and \ - "scope" not in resp: + "scope" not in resp: try: resp["scope"] = kwargs["scope"] except KeyError: @@ -950,7 +952,7 @@ class Client(PBase): except Exception: raise - if not "keyjar" in kwargs: + if "keyjar" not in kwargs: kwargs["keyjar"] = self.keyjar return self.parse_request_response(resp, response, body_type, state, diff --git a/src/oic/oic/consumer.py b/src/oic/oic/consumer.py index ab14f1d..b8c13e7 100644 --- a/src/oic/oic/consumer.py +++ b/src/oic/oic/consumer.py @@ -311,11 +311,41 @@ class Consumer(Client): return sid, location + def _parse_authz(self, query="", **kwargs): + _log_info = logger.info + # Might be an error response + _log_info("Expect Authorization Response") + aresp = self.parse_response(AuthorizationResponse, + info=query, + sformat="urlencoded", + keyjar=self.keyjar) + if aresp.type() == "ErrorResponse": + _log_info("ErrorResponse: %s" % aresp) + raise AuthzError(aresp.error, aresp) + + _log_info("Aresp: %s" % aresp) + + _state = aresp["state"] + try: + self.update(_state) + except KeyError: + raise UnknownState(_state, aresp) + + self.redirect_uris = [self.sdb[_state]["redirect_uris"]] + return aresp, _state + #noinspection PyUnusedLocal def parse_authz(self, query="", **kwargs): """ This is where we get redirect back to after authorization at the authorization server has happened. + Couple of cases + ["code"] + ["code", "token"] + ["code", "id_token", "token"] + ["id_token"] + ["id_token", "token"] + ["token"] :return: A AccessTokenResponse instance """ @@ -329,25 +359,7 @@ class Consumer(Client): _log_info("response: %s" % query) if "code" in self.config["response_type"]: - # Might be an error response - _log_info("Expect Authorization Response") - aresp = self.parse_response(AuthorizationResponse, - info=query, - sformat="urlencoded", - keyjar=self.keyjar) - if aresp.type() == "ErrorResponse": - _log_info("ErrorResponse: %s" % aresp) - raise AuthzError(aresp.error, aresp) - - _log_info("Aresp: %s" % aresp) - - _state = aresp["state"] - try: - self.update(_state) - except KeyError: - raise UnknownState(_state, aresp) - - self.redirect_uris = [self.sdb[_state]["redirect_uris"]] + aresp, _state = self._parse_authz(query, **kwargs) # May have token and id_token information too if "access_token" in aresp: @@ -366,7 +378,7 @@ class Consumer(Client): idt = None return aresp, atr, idt - else: # implicit flow + elif "token" in self.config["response_type"]: # implicit flow _log_info("Expect Access Token Response") atr = self.parse_response(AccessTokenResponse, info=query, sformat="urlencoded", @@ -376,6 +388,14 @@ class Consumer(Client): idt = None return None, atr, idt + else: # only id_token + aresp, _state = self._parse_authz(query, **kwargs) + + try: + idt = aresp["id_token"] + except KeyError: + idt = None + return None, None, idt def complete(self, state): """ diff --git a/tests/mitmsrv.py b/tests/mitmsrv.py new file mode 100644 index 0000000..4f8a3dd --- /dev/null +++ b/tests/mitmsrv.py @@ -0,0 +1,317 @@ +#!/usr/bin/env python +from urlparse import parse_qs +from jwkest.jws import alg2keytype + +from oic.oauth2 import rndstr +from oic.oauth2.message import by_schema + +from oic.oic import Server +from oic.oic.message import * + +from oic.utils.sdb import SessionDB, AuthnEvent +from oic.utils.time_util import utc_time_sans_frac +from oic.utils.webfinger import WebFinger + + +__author__ = 'rohe0002' + + +class Response(): + def __init__(self, base=None): + self.status_code = 200 + if base: + for key, val in base.items(): + self.__setitem__(key, val) + + def __setitem__(self, key, value): + setattr(self, key, value) + + def __getitem__(self, item): + return getattr(self, item) + + +ENDPOINT = { + "authorization_endpoint": "/authorization", + "token_endpoint": "/token", + "user_info_endpoint": "/userinfo", + "check_session_endpoint": "/check_session", + "refresh_session_endpoint": "/refresh_session", + "end_session_endpoint": "/end_session", + "registration_endpoint": "/registration", + "discovery_endpoint": "/discovery", + "register_endpoint": "/register" +} + + +class MITMServer(Server): + def __init__(self, name=""): + Server.__init__(self) + self.sdb = SessionDB(name) + self.name = name + self.client = {} + self.registration_expires_in = 3600 + self.host = "" + self.webfinger = WebFinger() + self.userinfo_signed_response_alg = "" + + # noinspection PyUnusedLocal + def http_request(self, path, method="GET", **kwargs): + part = urlparse(path) + path = part[2] + query = part[4] + self.host = "%s://%s" % (part.scheme, part.netloc) + + response = Response + response.status_code = 500 + response.text = "" + + if path == ENDPOINT["authorization_endpoint"]: + assert method == "GET" + response = self.authorization_endpoint(query) + elif path == ENDPOINT["token_endpoint"]: + assert method == "POST" + response = self.token_endpoint(kwargs["data"]) + elif path == ENDPOINT["user_info_endpoint"]: + assert method == "POST" + response = self.userinfo_endpoint(kwargs["data"]) + elif path == ENDPOINT["refresh_session_endpoint"]: + assert method == "GET" + response = self.refresh_session_endpoint(query) + elif path == ENDPOINT["check_session_endpoint"]: + assert method == "GET" + response = self.check_session_endpoint(query) + elif path == ENDPOINT["end_session_endpoint"]: + assert method == "GET" + response = self.end_session_endpoint(query) + elif path == ENDPOINT["registration_endpoint"]: + if method == "POST": + response = self.registration_endpoint(kwargs["data"]) + elif path == "/.well-known/webfinger": + assert method == "GET" + qdict = parse_qs(query) + response.status_code = 200 + response.text = self.webfinger.response(qdict["resource"][0], + "%s/" % self.name) + elif path == "/.well-known/openid-configuration": + assert method == "GET" + response = self.openid_conf() + + return response + + def authorization_endpoint(self, query): + req = self.parse_authorization_request(query=query) + aevent = AuthnEvent("user", authn_info="acr") + sid = self.sdb.create_authz_session(aevent, areq=req) + _ = self.sdb.do_sub(sid) + _info = self.sdb[sid] + + if "code" in req["response_type"]: + if "token" in req["response_type"]: + grant = _info["code"] + _dict = self.sdb.upgrade_to_token(grant) + _dict["oauth_state"] = "authz", + + _dict = by_schema(AuthorizationResponse(), **_dict) + resp = AuthorizationResponse(**_dict) + # resp.code = grant + else: + _state = req["state"] + resp = AuthorizationResponse(state=_state, + code=_info["code"]) + + else: # "implicit" in req.response_type: + grant = _info["code"] + params = AccessTokenResponse.c_param.keys() + + if "token" in req["response_type"]: + _dict = dict([(k, v) for k, v in + self.sdb.upgrade_to_token(grant).items() if k in + params]) + try: + del _dict["refresh_token"] + except KeyError: + pass + else: + _dict = {"state": req["state"]} + + if "id_token" in req["response_type"]: + _idt = self.make_id_token(_info, issuer=self.name) + alg = "RS256" + ckey = self.keyjar.get_signing_key(alg2keytype(alg), + _info["client_id"]) + _signed_jwt = _idt.to_jwt(key=ckey, algorithm=alg) + p = _signed_jwt.split(".") + p[2] = "aaa" + _dict["id_token"] = ".".join(p) + + resp = AuthorizationResponse(**_dict) + + location = resp.request(req["redirect_uri"]) + response = Response() + response.headers = {"location": location} + response.status_code = 302 + response.text = "" + return response + + def token_endpoint(self, data): + if "grant_type=refresh_token" in data: + req = self.parse_refresh_token_request(body=data) + _info = self.sdb.refresh_token(req["refresh_token"]) + elif "grant_type=authorization_code": + req = self.parse_token_request(body=data) + _info = self.sdb.upgrade_to_token(req["code"]) + else: + response = TokenErrorResponse(error="unsupported_grant_type") + return response, "" + + resp = AccessTokenResponse(**by_schema(AccessTokenResponse, **_info)) + response = Response() + response.headers = {"content-type": "application/json"} + response.text = resp.to_json() + + return response + + def userinfo_endpoint(self, data): + + _ = self.parse_user_info_request(data) + _info = { + "sub": "melgar", + "name": "Melody Gardot", + "nickname": "Mel", + "email": "mel@example.com", + "verified": True, + } + + resp = OpenIDSchema(**_info) + response = Response() + + if self.userinfo_signed_response_alg: + alg = self.userinfo_signed_response_alg + response.headers = {"content-type": "application/jwt"} + key = self.keyjar.get_signing_key(alg2keytype(alg), "", alg=alg) + response.text = resp.to_jwt(key, alg) + else: + response.headers = {"content-type": "application/json"} + response.text = resp.to_json() + + return response + + def registration_endpoint(self, data): + try: + req = self.parse_registration_request(data, "json") + except ValueError: + req = self.parse_registration_request(data) + + client_secret = rndstr() + expires = utc_time_sans_frac() + self.registration_expires_in + kwargs = {} + if "client_id" not in req: + client_id = rndstr(10) + registration_access_token = rndstr(20) + _client_info = req.to_dict() + kwargs.update(_client_info) + _client_info.update({ + "client_secret": client_secret, + "info": req.to_dict(), + "expires": expires, + "registration_access_token": registration_access_token, + "registration_client_uri": "register_endpoint" + }) + self.client[client_id] = _client_info + kwargs["registration_access_token"] = registration_access_token + kwargs["registration_client_uri"] = "register_endpoint" + try: + del kwargs["operation"] + except KeyError: + pass + else: + client_id = req.client_id + _cinfo = self.client[req.client_id] + _cinfo["info"].update(req.to_dict()) + _cinfo["client_secret"] = client_secret + _cinfo["expires"] = expires + + resp = RegistrationResponse(client_id=client_id, + client_secret=client_secret, + client_secret_expires_at=expires, + **kwargs) + + response = Response() + response.headers = {"content-type": "application/json"} + response.text = resp.to_json() + + return response + + def check_session_endpoint(self, query): + try: + idtoken = self.parse_check_session_request(query=query) + except Exception: + raise + + response = Response() + response.text = idtoken.to_json() + response.headers = {"content-type": "application/json"} + return response + + # noinspection PyUnusedLocal + def refresh_session_endpoint(self, query): + try: + req = self.parse_refresh_session_request(query=query) + except Exception: + raise + + resp = RegistrationResponse(client_id="anonymous", + client_secret="hemligt") + + response = Response() + response.headers = {"content-type": "application/json"} + response.text = resp.to_json() + return response + + def end_session_endpoint(self, query): + try: + req = self.parse_end_session_request(query=query) + except Exception: + raise + + # redirect back + resp = EndSessionResponse(state=req["state"]) + + url = resp.request(req["redirect_url"]) + + response = Response() + response.headers = {"location": url} + response.status_code = 302 # redirect + response.text = "" + return response + + # noinspection PyUnusedLocal + @staticmethod + def add_credentials(user, passwd): + return + + def openid_conf(self): + endpoint = {} + for point, path in ENDPOINT.items(): + endpoint[point] = "%s%s" % (self.host, path) + + signing_algs = jws.SIGNER_ALGS.keys() + resp = ProviderConfigurationResponse( + issuer=self.name, + scopes_supported=["openid", "profile", "email", "address"], + identifiers_supported=["public", "PPID"], + flows_supported=["code", "token", "code token", "id_token", + "code id_token", "token id_token"], + subject_types_supported=["pairwise", "public"], + response_types_supported=["code", "token", "id_token", + "code token", "code id_token", + "token id_token", "code token id_token"], + jwks_uri="http://example.com/oidc/jwks", + id_token_signing_alg_values_supported=signing_algs, + grant_types_supported=["authorization_code", "implicit"], + **endpoint) + + response = Response() + response.headers = {"content-type": "application/json"} + response.text = resp.to_json() + return response diff --git a/tests/test_oic_consumer.py b/tests/test_oic_consumer.py index 4cf5a5d..93b4902 100644 --- a/tests/test_oic_consumer.py +++ b/tests/test_oic_consumer.py @@ -1,8 +1,12 @@ +import json import os import shutil import tempfile +from jwkest import BadSignature +from jwkest.jwk import SYMKey +from oic.oauth2.message import MissingSigningKey -from oic.oic.message import AccessTokenResponse, AuthorizationResponse +from oic.oic.message import AccessTokenResponse, AuthorizationResponse, IdToken from oic.utils.keyio import KeyBundle, keybundle_from_local_file from oic.utils.keyio import KeyJar @@ -19,6 +23,7 @@ from oic.utils.time_util import utc_time_sans_frac from oic.utils.sdb import SessionDB from fakeoicsrv import MyFakeOICServer +from mitmsrv import MITMServer from utils_for_tests import _eq @@ -543,16 +548,6 @@ def test_discover(): res = c.discover(principal) assert res == "http://localhost:8088/" -#def test_discover_redirect(): -# c = Consumer(None, None) -# mfos = MyFakeOICServer(name="http://example.com/") -# c.http_request = mfos.http_request -# -# principal = "bar@example.org" -# -# res = c.discover(principal) -# assert res == "http://example.net/providerconf" - def test_provider_config(): c = Consumer(None, None) @@ -606,8 +601,103 @@ def test_client_register(): assert c.registration_expires > utc_time_sans_frac() +SYMKEY = SYMKey(key="TestPassword") + + +def _faulty_id_token(): + idval = {'nonce': 'KUEYfRM2VzKDaaKD', 'sub': 'EndUserSubject', + 'iss': 'https://alpha.cloud.nds.rub.de', 'exp': 1420823073, + 'iat': 1420822473, 'aud': 'TestClient'} + idts = IdToken(**idval) + + _signed_jwt = idts.to_jwt(key=[SYMKEY], algorithm="HS256") + + #Mess with the signed id_token + p = _signed_jwt.split(".") + p[2] = "aaa" + + return ".".join(p) + + +def test_faulty_id_token(): + _faulty_signed_jwt = _faulty_id_token() + try: + _ = IdToken().from_jwt(_faulty_signed_jwt, key=[SYMKEY]) + except BadSignature: + pass + else: + assert False + + # What if no verification key is given ? + # Should also result in an exception + try: + _ = IdToken().from_jwt(_faulty_signed_jwt) + except MissingSigningKey: + pass + else: + assert False + + +def test_faulty_id_token_in_access_token_response(): + c = Consumer(None, None) + c.keyjar.add_symmetric("", "TestPassword", ["sig"]) + + _info = {"access_token": "accessTok", "id_token": _faulty_id_token(), + "token_type": "Bearer", "expires_in": 3600} + + _json = json.dumps(_info) + try: + resp = c.parse_response(AccessTokenResponse, _json, sformat="json") + except BadSignature: + pass + else: + assert False + + +def test_faulty_idtoken_from_accesstoken_endpoint(): + consumer = Consumer(SessionDB(SERVER_INFO["issuer"]), CONFIG, + CLIENT_CONFIG, SERVER_INFO) + consumer.keyjar = CLIKEYS + mfos = MITMServer("http://localhost:8088") + mfos.keyjar = SRVKEYS + consumer.http_request = mfos.http_request + consumer.redirect_uris = ["http://example.com/authz"] + _state = "state0" + consumer.nonce = rndstr() + consumer.client_secret = "hemlig" + consumer.secret_type = "basic" + consumer.config["response_type"] = ["id_token"] + + args = { + "client_id": consumer.client_id, + "response_type": consumer.config["response_type"], + "scope": ["openid"], + } + + result = consumer.do_authorization_request(state=_state, + request_args=args) + consumer._backup("state0") + + assert result.status_code == 302 + #assert result.location.startswith(consumer.redirect_uri[0]) + _, query = result.headers["location"].split("?") + print query + part = consumer.parse_authz(query=query) + print part + auth = part[0] + acc = part[1] + assert part[2] is None + + #print auth.dictionary() + #print acc.dictionary() + assert auth is None + assert acc.type() == "AccessTokenResponse" + assert _eq(acc.keys(), ['access_token', 'id_token', 'expires_in', + 'token_type', 'state', 'scope']) + + if __name__ == "__main__": # t = TestOICConsumer() # t.setup_class() # t.test_complete() - test_sign_userinfo() + test_faulty_idtoken_from_accesstoken_endpoint() diff --git a/tests/test_oic_message.py b/tests/test_oic_message.py index e213310..d8921e9 100644 --- a/tests/test_oic_message.py +++ b/tests/test_oic_message.py @@ -1,9 +1,13 @@ # -*- coding: utf-8 -*- +from jwkest import BadSignature +from jwkest.jwk import SYMKey + __author__ = 'rohe0002' import json -from oic.oic.message import ProviderConfigurationResponse, RegistrationResponse, AuthorizationRequest +from oic.oic.message import ProviderConfigurationResponse, RegistrationResponse, AuthorizationRequest, \ + IdToken, AccessTokenResponse from oic.oic.message import msg_ser from oic.oic.message import claims_ser from oic.oic.message import RegistrationRequest @@ -178,24 +182,6 @@ def test_authz_request(): assert req["scope"] == ["openid", "profile"] -# def test_idtokenclaim_deser(): -# claims = Claims(weather={"acr": "2"}) -# pre = IDTokenClaim(claims=claims, max_age=3600) -# idt = idtokenclaim_deser(pre.to_json(), sformat="json") -# assert _eq(idt.keys(), ['claims', "max_age"]) -# -# -# def test_userinfo_deser(): -# CLAIM = Claims(name={"essential": True}, nickname=None, -# email={"essential": True}, -# email_verified={"essential": True}, picture=None) -# -# pre_uic = UserInfoClaim(claims=CLAIM, format="signed") -# -# uic = userinfo_deser(pre_uic.to_json(), sformat="json") -# assert _eq(uic.keys(), ["claims", "format"]) - - def test_claims_deser_0(): _dic = { "userinfo": { @@ -307,5 +293,31 @@ def test_registration_request(): assert _eq(ue_splits, expected_ue_splits) +def test_faulty_idtoken(): + idval = {'nonce': 'KUEYfRM2VzKDaaKD', 'sub': 'EndUserSubject', + 'iss': 'https://alpha.cloud.nds.rub.de', 'exp': 1420823073, + 'iat': 1420822473, 'aud': 'TestClient'} + idts = IdToken(**idval) + key = SYMKey(key="TestPassword") + _signed_jwt = idts.to_jwt(key=[key], algorithm="HS256") + + #Mess with the signed id_token + p = _signed_jwt.split(".") + p[2] = "aaa" + _faulty_signed_jwt = ".".join(p) + + _info = {"access_token": "accessTok", "id_token": _faulty_signed_jwt, + "token_type": "Bearer", "expires_in": 3600} + + # Should fail + at = AccessTokenResponse(**_info) + try: + at.verify(key=[key]) + except BadSignature: + pass + else: + raise + + if __name__ == "__main__": - test_claims_deser_0() \ No newline at end of file + test_faulty_idtoken() \ No newline at end of file