This repository has been archived on 2023-02-21. You can view files and clone it, but cannot push or open issues or pull requests.
pyoidc-ozwillo/src/oic/oauth2/dynreg.py

655 lines
23 KiB
Python

import logging
import urllib
import urlparse
import requests
from oic.oic import OIDCONF_PATTERN
from oic.oic.message import ProviderConfigurationResponse, AuthorizationResponse
from oic.utils.keyio import KeyJar
from oic.utils.time_util import utc_time_sans_frac
from oic.oic.provider import secret
from oic.oic.provider import RegistrationEndpoint
from oic.oic.provider import Endpoint
from oic import oauth2
from oic.oauth2 import provider
from oic.oauth2 import VerificationError
from oic.oauth2 import rndstr
from oic.oauth2 import ErrorResponse
from oic.oauth2 import UnSupported
from oic.oauth2 import Message
from oic.oauth2 import message
from oic.oauth2 import SINGLE_REQUIRED_STRING
from oic.oauth2 import OPTIONAL_LIST_OF_SP_SEP_STRINGS
from oic.oauth2 import REQUIRED_LIST_OF_STRINGS
from oic.oauth2 import OPTIONAL_LIST_OF_STRINGS
from oic.oauth2 import SINGLE_OPTIONAL_STRING
from oic.oauth2 import SINGLE_OPTIONAL_INT
from oic.exception import UnknownAssertionType
from oic.exception import PyoidcError
from oic.exception import AuthzError
from oic.utils.authn.client import AuthnFailure
from oic.utils.http_util import Unauthorized, NoContent
from oic.utils.http_util import Response
from oic.utils.http_util import BadRequest
from oic.utils.http_util import Forbidden
# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)
__author__ = 'roland'
# -----------------------------------------------------------------------------
class InvalidRedirectUri(Exception):
    """Raised when a registered redirect URI is not acceptable.

    In this module it is raised by ``RegistrationRequest.verify`` when a
    redirect URI carries a fragment component.
    """
class MissingPage(Exception):
    """Raised when a URL given in the client metadata cannot be fetched.

    ``RegistrationRequest.verify`` raises it with the offending URL when
    the GET request fails or answers with a status other than 200/201.
    """
class ModificationForbidden(Exception):
    """Raised when an update tries to change an immutable client value.

    ``Provider.client_info_update`` raises it when the request carries a
    ``client_id`` or ``client_secret`` that differs from the stored one.
    """
class RegistrationRequest(Message):
    """Client metadata sent to the dynamic client registration endpoint.

    ``redirect_uris`` is the only required parameter; everything else is
    optional.
    """
    c_param = {
        "redirect_uris": REQUIRED_LIST_OF_STRINGS,
        "client_name": SINGLE_OPTIONAL_STRING,
        "client_uri": SINGLE_OPTIONAL_STRING,
        "logo_uri": SINGLE_OPTIONAL_STRING,
        "contacts": OPTIONAL_LIST_OF_STRINGS,
        "tos_uri": SINGLE_OPTIONAL_STRING,
        "policy_uri": SINGLE_OPTIONAL_STRING,
        "token_endpoint_auth_method": SINGLE_OPTIONAL_STRING,
        "scope": OPTIONAL_LIST_OF_SP_SEP_STRINGS,
        "grant_types": OPTIONAL_LIST_OF_STRINGS,
        "response_types": OPTIONAL_LIST_OF_STRINGS,
        "jwks_uri": SINGLE_OPTIONAL_STRING,
        "software_id": SINGLE_OPTIONAL_STRING,
        "software_version": SINGLE_OPTIONAL_STRING,
    }

    def verify(self, **kwargs):
        """Check the metadata for consistency before the generic checks.

        The following is verified, in order:

        * ``initiate_login_uri``, when present, starts with ``https:``.
        * No redirect URI contains a fragment.
        * ``client_uri``, ``logo_uri``, ``tos_uri`` and ``policy_uri``,
          when present, can be fetched with an HTTP GET (status 200 or
          201).  Note that this performs outbound HTTP requests.
        * When both ``grant_types`` and ``response_types`` are given, the
          response type matching each grant type ("code" for
          "authorization_code", "token" for "implicit") is appended to
          ``response_types`` if it is missing.

        :param kwargs: passed on unchanged to the parent class ``verify``
        :return: whatever the parent class ``verify`` returns
        :raises AssertionError: ``initiate_login_uri`` is not an https URL
        :raises InvalidRedirectUri: a redirect URI contains a fragment
        :raises MissingPage: one of the informational URLs can not be
            fetched
        """
        if "initiate_login_uri" in self:
            # An explicit raise instead of an ``assert`` statement so the
            # check is not stripped when running with ``python -O``.
            # AssertionError is kept so callers see the same exception type.
            if not self["initiate_login_uri"].startswith("https:"):
                raise AssertionError("initiate_login_uri must be a https URL")

        if "redirect_uris" in self:
            for uri in self["redirect_uris"]:
                if urlparse.urlparse(uri).fragment:
                    raise InvalidRedirectUri(
                        "redirect_uri contains fragment: %s" % uri)

        for param in ("client_uri", "logo_uri", "tos_uri", "policy_uri"):
            if param not in self:
                continue
            try:
                resp = requests.request("GET", self[param],
                                        allow_redirects=True)
            except requests.RequestException:
                # RequestException is the base class of ConnectionError;
                # catching it also covers malformed client supplied URLs
                # which would otherwise escape as unhandled errors.
                raise MissingPage(self[param])
            if resp.status_code not in (200, 201):
                raise MissingPage(self[param])

        if "grant_types" in self and "response_types" in self:
            for typ in self["grant_types"]:
                if typ == "authorization_code":
                    needed = "code"
                elif typ == "implicit":
                    needed = "token"
                else:
                    continue
                # Plain membership test; the previous try/assert/except
                # construct did nothing at all under ``python -O``.
                if needed not in self["response_types"]:
                    self["response_types"].append(needed)

        return super(RegistrationRequest, self).verify(**kwargs)
class ClientInfoResponse(RegistrationRequest):
    """Registration response: the request metadata plus issued credentials.

    Extends the ``RegistrationRequest`` parameters with the client
    identifier, the optional secret and its timestamps, and the token and
    URI used to manage the registration.
    """
    c_param = dict(
        RegistrationRequest.c_param,
        client_id=SINGLE_REQUIRED_STRING,
        client_secret=SINGLE_OPTIONAL_STRING,
        client_id_issued_at=SINGLE_OPTIONAL_INT,
        client_secret_expires_at=SINGLE_OPTIONAL_INT,
        registration_access_token=SINGLE_REQUIRED_STRING,
        registration_client_uri=SINGLE_REQUIRED_STRING,
    )
class ClientRegistrationError(ErrorResponse):
    """Error response returned by the registration related endpoints.

    Adds an optional ``state`` parameter and restricts ``error`` to the
    registration specific error codes listed below.
    """
    c_param = dict(ErrorResponse.c_param, state=SINGLE_OPTIONAL_STRING)
    c_allowed_values = dict(
        ErrorResponse.c_allowed_values,
        error=["invalid_redirect_uri",
               "invalid_client_metadata",
               "invalid_client_id"],
    )
class ClientUpdateRequest(RegistrationRequest):
    """Request to update an existing registration.

    Carries the same metadata as a ``RegistrationRequest`` plus the
    required ``client_id`` and, optionally, the ``client_secret``.
    """
    c_param = dict(
        RegistrationRequest.c_param,
        client_id=SINGLE_REQUIRED_STRING,
        client_secret=SINGLE_OPTIONAL_STRING,
    )
# Message classes defined in this module, keyed by class name.
MSG = {
    "RegistrationRequest": RegistrationRequest,
    "ClientInfoResponse": ClientInfoResponse,
    "ClientRegistrationError": ClientRegistrationError,
    "ClientUpdateRequest": ClientUpdateRequest
}


def factory(msgtype):
    """Return the message class registered under the name *msgtype*.

    Names not defined in this module are delegated to
    ``oic.oauth2.message.factory``.

    :param msgtype: name of the message class
    :return: the message class
    """
    if msgtype in MSG:
        return MSG[msgtype]
    return message.factory(msgtype)
# -----------------------------------------------------------------------------
class ClientInfoEndpoint(Endpoint):
    """Endpoint type for reading, updating and deleting a registration.

    ``etype`` is also used by ``Provider.create_new_client`` as the path
    part of the ``registration_client_uri``.
    """
    etype = "clientinfo"
class Provider(provider.Provider):
    """OAuth2 provider extended with dynamic client registration.

    On top of the base provider it offers a registration endpoint, which
    creates new clients, and a client info endpoint, through which a
    registered client can read, update or delete its own registration.
    """

    # Stored values an update request neither has to repeat nor is able
    # to remove.  ``client_id`` and ``client_secret`` are included because
    # ``client_secret`` is optional in ClientUpdateRequest; leaving it out
    # must not wipe the stored secret.
    _PRESERVED_ON_UPDATE = frozenset([
        "client_id", "client_secret",
        "client_id_issued_at", "client_secret_expires_at",
        "registration_access_token", "registration_client_uri",
    ])

    def __init__(self, name, sdb, cdb, authn_broker, authz, client_authn,
                 symkey="", urlmap=None, iv=0, default_scope="",
                 ca_bundle=None, seed="", client_authn_methods=None,
                 authn_at_registration="", client_info_url="",
                 secret_lifetime=86400):
        """
        The first eleven arguments are passed on unchanged to the base
        class constructor.

        :param seed: seed handed to ``secret()`` when creating client
            secrets
        :param client_authn_methods: dictionary of client authentication
            methods, keyed by method name
        :param authn_at_registration: name of the client authentication
            method required at registration; empty means none
        :param client_info_url: base of the ``registration_client_uri``
            handed out to clients
        :param secret_lifetime: number added to the current time to get
            ``client_secret_expires_at``
        :raises AssertionError: ``authn_at_registration`` is not one of
            ``client_authn_methods``
        """
        provider.Provider.__init__(self, name, sdb, cdb, authn_broker, authz,
                                   client_authn, symkey, urlmap, iv,
                                   default_scope, ca_bundle)

        self.endp.extend([RegistrationEndpoint, ClientInfoEndpoint])

        # dictionary of client authentication methods
        self.client_authn_methods = client_authn_methods
        if authn_at_registration:
            # Explicit raise (not ``assert``) so the check survives
            # ``python -O``; AssertionError is kept for compatibility.
            if (not client_authn_methods or
                    authn_at_registration not in client_authn_methods):
                raise AssertionError(
                    "Unknown client authentication method: %s" %
                    authn_at_registration)

        self.authn_at_registration = authn_at_registration
        self.seed = seed
        self.client_info_url = client_info_url
        self.secret_lifetime = secret_lifetime

    @staticmethod
    def _uris_to_tuples(uris):
        """Split each URI into a ``(base, query)`` tuple.

        :param uris: iterable of URI strings
        :return: list of tuples; the query is "" when the URI has none
        """
        pairs = []
        for uri in uris:
            base, query = urllib.splitquery(uri)
            pairs.append((base, query or ""))
        return pairs

    @staticmethod
    def _tuples_to_uris(items):
        """Inverse of ``_uris_to_tuples``.

        :param items: iterable of ``(base, query)`` tuples
        :return: list of URI strings
        """
        return ["%s?%s" % (url, query) if query else url
                for url, query in items]

    def create_new_client(self, request):
        """
        Create and store a new client from a registration request.

        :param request: The Client registration request
        :return: The client_id
        """
        _cinfo = request.to_dict()

        # create new id and secret; retry until the id is unused
        _id = rndstr(12)
        while _id in self.cdb:
            _id = rndstr(12)

        _cinfo["client_id"] = _id
        _cinfo["client_secret"] = secret(self.seed, _id)
        _cinfo["client_id_issued_at"] = utc_time_sans_frac()
        _cinfo["client_secret_expires_at"] = utc_time_sans_frac() + \
            self.secret_lifetime

        # If I support client info endpoint
        if ClientInfoEndpoint in self.endp:
            _cinfo["registration_access_token"] = rndstr(32)
            _cinfo["registration_client_uri"] = "%s%s?client_id=%s" % (
                self.client_info_url, ClientInfoEndpoint.etype, _id)

        # redirect URIs are stored as (base, query) tuples
        if "redirect_uris" in request:
            _cinfo["redirect_uris"] = self._uris_to_tuples(
                request["redirect_uris"])

        self.cdb[_id] = _cinfo
        return _id

    def _stored_client_info(self, client_id):
        """Return a copy of the stored client info with the redirect URIs
        converted back from tuples to strings."""
        _cinfo = self.cdb[client_id].copy()
        try:
            _cinfo["redirect_uris"] = self._tuples_to_uris(
                _cinfo["redirect_uris"])
        except KeyError:
            pass
        return _cinfo

    def client_info(self, client_id):
        """
        Build the response describing a registered client.

        :param client_id: identifier of a client present in ``self.cdb``
        :return: A Response instance with the ClientInfoResponse as JSON
        """
        msg = ClientInfoResponse(**self._stored_client_info(client_id))
        return Response(msg.to_json(), content="application/json")

    def client_info_update(self, client_id, request):
        """
        Replace the stored metadata of a client with that of the request.

        Values in the request overwrite the stored ones, stored values
        absent from the request are removed.  The values issued by this
        provider (see ``_PRESERVED_ON_UPDATE``) are always kept.

        :param client_id: identifier of a client present in ``self.cdb``
        :param request: The update request
        :raises ModificationForbidden: the request carries a client_id or
            client_secret that differs from the stored one
        """
        _cinfo = self._stored_client_info(client_id)

        for key, value in request.items():
            if key in ("client_secret", "client_id"):
                # assure it's the same; a plain comparison so the check is
                # still performed under ``python -O``
                if value != _cinfo[key]:
                    raise ModificationForbidden("Not allowed to change")
            else:
                _cinfo[key] = value

        # Iterate over a snapshot of the keys since entries are deleted
        for key in list(_cinfo.keys()):
            if key in self._PRESERVED_ON_UPDATE:
                continue
            if key not in request:
                del _cinfo[key]

        if "redirect_uris" in request:
            _cinfo["redirect_uris"] = self._uris_to_tuples(
                request["redirect_uris"])

        self.cdb[client_id] = _cinfo

    def verify_client(self, environ, areq, authn_method, client_id=""):
        """
        Authenticate a client with the given authentication method.

        :param environ: WSGI environ
        :param areq: The request
        :param authn_method: client authentication method
        :param client_id: the client id; when empty it is looked up using
            the request and the HTTP_AUTHORIZATION value of the environ
        :return: the result of the authentication method's ``verify``
        :raises UnSupported: ``authn_method`` is not a known method
        """
        if not client_id:
            client_id = self.get_client_id(areq, environ["HTTP_AUTHORIZATION"])

        try:
            method = self.client_authn_methods[authn_method]
        except KeyError:
            raise UnSupported()
        return method(self).verify(environ, client_id=client_id)

    @staticmethod
    def _metadata_error(request):
        """Verify the request; return a BadRequest response describing the
        problem, or None when the request is fine."""
        try:
            request.verify()
        except InvalidRedirectUri as err:
            error = "invalid_redirect_uri"
            description = "%s" % err
        except (MissingPage, VerificationError) as err:
            error = "invalid_client_metadata"
            description = "%s" % err
        else:
            return None

        msg = ClientRegistrationError(error=error,
                                      error_description=description)
        return BadRequest(msg.to_json(), content="application/json")

    def registration_endpoint(self, request, environ, **kwargs):
        """
        Register a new client.

        :param request: The request, a JSON document
        :param environ: WSGI environ
        :param kwargs: extra keyword arguments
        :return: A Response instance
        """
        # Same treatment of unparsable input as the PUT branch of
        # client_info_endpoint.
        try:
            _request = RegistrationRequest().deserialize(request, "json")
        except ValueError:
            return BadRequest()

        error_response = self._metadata_error(_request)
        if error_response is not None:
            return error_response

        # authenticated client
        if self.authn_at_registration:
            try:
                _ = self.verify_client(environ, _request,
                                       self.authn_at_registration)
            except (AuthnFailure, UnknownAssertionType):
                return Unauthorized()

        client_id = self.create_new_client(_request)
        return self.client_info(client_id)

    def client_info_endpoint(self, request, environ,
                             method="GET", query="", **kwargs):
        """
        Operations on this endpoint are switched through the use of different
        HTTP methods: GET reads, PUT updates and DELETE removes the
        registration.

        :param request: The request
        :param environ: WSGI environ
        :param method: HTTP method used for the request
        :param query: The query part of the URL used, this is where the
            client_id is supposed to reside.
        :param kwargs: extra keyword arguments
        :return: A Response instance
        """
        _query = urlparse.parse_qs(query)
        try:
            _id = _query["client_id"][0]
        except KeyError:
            return BadRequest("Missing query component")

        # A plain membership test, so unknown clients are rejected under
        # ``python -O`` as well.
        if _id not in self.cdb:
            return Unauthorized()

        # authenticated client
        try:
            _ = self.verify_client(environ, request, "bearer_header",
                                   client_id=_id)
        except (AuthnFailure, UnknownAssertionType):
            return Unauthorized()

        if method == "GET":
            return self.client_info(_id)
        elif method == "PUT":
            try:
                _request = ClientUpdateRequest().from_json(request)
            except ValueError:
                return BadRequest()

            error_response = self._metadata_error(_request)
            if error_response is not None:
                return error_response

            try:
                self.client_info_update(_id, _request)
                return self.client_info(_id)
            except ModificationForbidden:
                return Forbidden()
        elif method == "DELETE":
            try:
                del self.cdb[_id]
            except KeyError:
                return Unauthorized()
            else:
                return NoContent()

        # Any other method used to fall through and return None.
        return BadRequest("Unsupported HTTP method")

    def providerinfo_endpoint(self):
        """Placeholder; does nothing and returns None."""
        pass
# Maps a message class name to the list of error message classes
# associated with it.
# NOTE(review): presumably consulted when parsing a response to recognise
# error replies - confirm against the oauth2 client code.
RESPONSE2ERROR = {
"ClientInfoResponse": [ClientRegistrationError],
"ClientUpdateRequest": [ClientRegistrationError]
}
class Client(oauth2.Client):
    """OAuth2 client extended with dynamic registration support.

    Adds requests for registering with a provider and for reading,
    updating and deleting the registration, plus helpers for fetching and
    handling the provider configuration.
    """

    # Upper bound on the number of 302 redirects followed by
    # provider_config, so a redirect cycle can not loop forever.
    _MAX_CONFIG_REDIRECTS = 10

    def __init__(self, client_id=None, ca_certs=None,
                 client_authn_method=None, keyjar=None, verify_ssl=True):
        """All arguments are passed on unchanged to ``oauth2.Client``."""
        oauth2.Client.__init__(self, client_id=client_id, ca_certs=ca_certs,
                               client_authn_method=client_authn_method,
                               keyjar=keyjar, verify_ssl=verify_ssl)
        # Deviations the client tolerates; handle_provider_config checks
        # for the key "issuer_mismatch".
        self.allow = {}
        self.request2endpoint.update({
            "RegistrationRequest": "registration_endpoint",
            "ClientUpdateRequest": "clientinfo_endpoint"
        })
        self.registration_response = None

    def construct_RegistrationRequest(self, request=RegistrationRequest,
                                      request_args=None, extra_args=None,
                                      **kwargs):
        """
        Construct a registration request message.

        :param request: The message class to instantiate
        :param request_args: arguments for the message, defaults to {}
        :param extra_args: passed on to ``construct_request``
        :param kwargs: accepted but not used
        :return: the result of ``construct_request``
        """
        if request_args is None:
            request_args = {}

        return self.construct_request(request, request_args, extra_args)

    def _do_client_request(self, request, body_type, method, request_args,
                           extra_args, http_args, response_cls, **kwargs):
        """Build the request, send it and return the parsed response.

        Shared implementation of the ``do_client_*`` methods.
        """
        url, body, ht_args, _csi = self.request_info(request, method,
                                                     request_args, extra_args,
                                                     **kwargs)

        if http_args is None:
            http_args = ht_args
        elif ht_args:
            # Merge the arguments produced by request_info into a copy of
            # the caller's.  The former ``http_args.update(http_args)`` was
            # a no-op that silently discarded ``ht_args``.
            merged = dict(http_args)
            merged.update(ht_args)
            http_args = merged

        return self.request_and_return(url, response_cls, method, body,
                                       body_type, http_args=http_args)

    def do_client_registration(self, request=RegistrationRequest,
                               body_type="", method="GET",
                               request_args=None, extra_args=None,
                               http_args=None,
                               response_cls=ClientInfoResponse,
                               **kwargs):
        """
        Send a registration request.

        :param request: The request message class
        :param body_type: body type passed to ``request_and_return``
        :param method: HTTP method to use
        :param request_args: arguments for the request message
        :param extra_args: extra arguments for the request message
        :param http_args: extra arguments for the HTTP request
        :param response_cls: class used to parse the response
        :param kwargs: passed on to ``request_info``
        :return: the result of ``request_and_return``
        """
        return self._do_client_request(request, body_type, method,
                                       request_args, extra_args, http_args,
                                       response_cls, **kwargs)

    def do_client_read_request(self, request=ClientUpdateRequest,
                               body_type="", method="GET",
                               request_args=None, extra_args=None,
                               http_args=None,
                               response_cls=ClientInfoResponse,
                               **kwargs):
        """
        Send a request to read the client registration (default GET).

        Parameters and return value as for ``do_client_registration``.
        """
        return self._do_client_request(request, body_type, method,
                                       request_args, extra_args, http_args,
                                       response_cls, **kwargs)

    def do_client_update_request(self, request=ClientUpdateRequest,
                                 body_type="", method="PUT",
                                 request_args=None, extra_args=None,
                                 http_args=None,
                                 response_cls=ClientInfoResponse,
                                 **kwargs):
        """
        Send a request to update the client registration (default PUT).

        Parameters and return value as for ``do_client_registration``.
        """
        return self._do_client_request(request, body_type, method,
                                       request_args, extra_args, http_args,
                                       response_cls, **kwargs)

    def do_client_delete_request(self, request=ClientUpdateRequest,
                                 body_type="", method="DELETE",
                                 request_args=None, extra_args=None,
                                 http_args=None,
                                 response_cls=ClientInfoResponse,
                                 **kwargs):
        """
        Send a request to delete the client registration (default DELETE).

        Parameters and return value as for ``do_client_registration``.
        """
        return self._do_client_request(request, body_type, method,
                                       request_args, extra_args, http_args,
                                       response_cls, **kwargs)

    def handle_provider_config(self, pcr, issuer, keys=True, endpoints=True):
        """
        Deal with Provider Config Response

        :param pcr: The ProviderConfigResponse instance
        :param issuer: The one I thought should be the issuer of the config
        :param keys: Should I deal with keys
        :param endpoints: Should I deal with endpoints, that is store them
            as attributes in self.
        :raises PyoidcError: the issuer in the response differs from the
            expected one and "issuer_mismatch" is not in ``self.allow``
        """
        if "issuer" in pcr:
            _pcr_issuer = pcr["issuer"]

            # Give the expected issuer the same trailing slash convention
            # as the one in the response before comparing them.
            if _pcr_issuer.endswith("/"):
                _issuer = issuer if issuer.endswith("/") else issuer + "/"
            else:
                _issuer = issuer[:-1] if issuer.endswith("/") else issuer

            if "issuer_mismatch" not in self.allow:
                if _issuer != _pcr_issuer:
                    raise PyoidcError(
                        "provider info issuer mismatch '%s' != '%s'" % (
                            _issuer, _pcr_issuer))

            self.provider_info[_pcr_issuer] = pcr
        else:
            _pcr_issuer = issuer

        if endpoints:
            for key, val in pcr.items():
                if key.endswith("_endpoint"):
                    setattr(self, key, val)

        if keys:
            if self.keyjar is None:
                self.keyjar = KeyJar()

            self.keyjar.load_keys(pcr, _pcr_issuer)

    def provider_config(self, issuer, keys=True, endpoints=True,
                        response_cls=ProviderConfigurationResponse,
                        serv_pattern=OIDCONF_PATTERN):
        """
        Fetch and handle the configuration of a provider.

        :param issuer: The issuer identifier of the provider
        :param keys: passed on to ``handle_provider_config``
        :param endpoints: passed on to ``handle_provider_config``
        :param response_cls: class used to parse the configuration
        :param serv_pattern: format string turning the issuer (without
            trailing slash) into the configuration URL
        :return: The parsed provider configuration
        :raises PyoidcError: no configuration could be fetched
        """
        _issuer = issuer[:-1] if issuer.endswith("/") else issuer
        url = serv_pattern % _issuer

        r = self.http_request(url)
        # Follow 302 redirects, but only a bounded number of them.
        redirects = 0
        while r.status_code == 302 and redirects < self._MAX_CONFIG_REDIRECTS:
            redirects += 1
            r = self.http_request(r.headers["location"])

        if r.status_code != 200:
            raise PyoidcError("Trying '%s', status %s" % (url, r.status_code))

        pcr = response_cls().from_json(r.text)
        if pcr is None:
            raise PyoidcError("Trying '%s', status %s" % (url, r.status_code))

        self.handle_provider_config(pcr, issuer, keys, endpoints)

        return pcr

    def store_registration_info(self, reginfo):
        """
        Remember the registration response and the values it carries.

        :param reginfo: The registration response message
        """
        self.registration_response = reginfo
        # client_secret is optional in ClientInfoResponse; unconditional
        # indexing raised KeyError for responses without it.
        if "client_secret" in reginfo:
            self.client_secret = reginfo["client_secret"]
        self.client_id = reginfo["client_id"]
        self.redirect_uris = reginfo["redirect_uris"]

    def handle_registration_info(self, response):
        """
        Parse and store the response to a registration request.

        :param response: The HTTP response
        :return: The parsed ClientInfoResponse
        :raises PyoidcError: the HTTP status is neither 200 nor 201
        """
        if response.status_code not in (200, 201):
            err = ErrorResponse().deserialize(response.text, "json")
            raise PyoidcError("Registration failed: %s" % err.to_json())

        resp = ClientInfoResponse().deserialize(response.text, "json")
        self.store_registration_info(resp)
        return resp

    def register(self, url, **kwargs):
        """
        Register the client at an OP

        :param url: The OPs registration endpoint
        :param kwargs: parameters to the registration request
        :return: The parsed ClientInfoResponse
        """
        req = self.construct_RegistrationRequest(request_args=kwargs)

        headers = {"content-type": "application/json"}

        rsp = self.http_request(url, "POST", data=req.to_json(),
                                headers=headers)

        return self.handle_registration_info(rsp)

    def parse_authz_response(self, query):
        """
        Parse an urlencoded authorization response.

        :param query: The urlencoded response
        :return: The parsed response
        :raises AuthzError: the response is an error response
        """
        aresp = self.parse_response(AuthorizationResponse,
                                    info=query,
                                    sformat="urlencoded",
                                    keyjar=self.keyjar)
        if aresp.type() == "ErrorResponse":
            logger.info("ErrorResponse: %s", aresp)
            raise AuthzError(aresp.error)

        logger.info("Aresp: %s", aresp)

        return aresp