1073 lines
30 KiB
Python
1073 lines
30 KiB
Python
# Copyright (C) 2016 JWCrypto Project Contributors - see LICENSE file
|
|
|
|
import abc
|
|
import os
|
|
import struct
|
|
from binascii import hexlify, unhexlify
|
|
|
|
from cryptography.exceptions import InvalidSignature
|
|
from cryptography.hazmat.backends import default_backend
|
|
from cryptography.hazmat.primitives import constant_time, hashes, hmac
|
|
from cryptography.hazmat.primitives.asymmetric import ec
|
|
from cryptography.hazmat.primitives.asymmetric import padding
|
|
from cryptography.hazmat.primitives.asymmetric import utils as ec_utils
|
|
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
|
|
from cryptography.hazmat.primitives.kdf.concatkdf import ConcatKDFHash
|
|
from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC
|
|
from cryptography.hazmat.primitives.keywrap import aes_key_unwrap, aes_key_wrap
|
|
from cryptography.hazmat.primitives.padding import PKCS7
|
|
|
|
import six
|
|
|
|
from jwcrypto.common import InvalidCEKeyLength
|
|
from jwcrypto.common import InvalidJWAAlgorithm
|
|
from jwcrypto.common import InvalidJWEKeyLength
|
|
from jwcrypto.common import InvalidJWEKeyType
|
|
from jwcrypto.common import InvalidJWEOperation
|
|
from jwcrypto.common import base64url_decode, base64url_encode
|
|
from jwcrypto.common import json_decode
|
|
from jwcrypto.jwk import JWK
|
|
|
|
# Implements RFC 7518 - JSON Web Algorithms (JWA)
|
|
|
|
|
|
@six.add_metaclass(abc.ABCMeta)
class JWAAlgorithm(object):
    """Abstract interface describing a single JWA algorithm.

    Every member below is an abstract property; the concrete algorithm
    classes in this module satisfy them with plain class attributes.
    """

    @abc.abstractproperty
    def name(self):
        """The algorithm name."""

    @abc.abstractproperty
    def description(self):
        """A short, human-readable description."""

    @abc.abstractproperty
    def keysize(self):
        """The actual/recommended/minimum key size."""

    @abc.abstractproperty
    def algorithm_usage_location(self):
        """Where the algorithm is used: one of 'alg', 'enc' or 'JWK'."""

    @abc.abstractproperty
    def algorithm_use(self):
        """What the algorithm is used for: one of 'sig', 'kex', 'enc'."""
|
|
|
|
|
|
def _bitsize(x):
|
|
return len(x) * 8
|
|
|
|
|
|
def _inbytes(x):
|
|
return x // 8
|
|
|
|
|
|
def _randombits(x):
|
|
if x % 8 != 0:
|
|
raise ValueError("lenght must be a multiple of 8")
|
|
return os.urandom(_inbytes(x))
|
|
|
|
|
|
# Note: the number of bits should be a multiple of 16
|
|
def _encode_int(n, bits):
|
|
e = '{:x}'.format(n)
|
|
ilen = ((bits + 7) // 8) * 2 # number of bytes rounded up times 2 bytes
|
|
return unhexlify(e.rjust(ilen, '0')[:ilen])
|
|
|
|
|
|
def _decode_int(n):
|
|
return int(hexlify(n), 16)
|
|
|
|
|
|
class _RawJWS(object):
|
|
|
|
def sign(self, key, payload):
|
|
raise NotImplementedError
|
|
|
|
def verify(self, key, payload, signature):
|
|
raise NotImplementedError
|
|
|
|
|
|
class _RawHMAC(_RawJWS):
    """HMAC signer/verifier parameterized by a hash algorithm instance."""

    def __init__(self, hashfn):
        self.hashfn = hashfn
        self.backend = default_backend()

    def _hmac_setup(self, key, payload):
        """Return an HMAC context keyed with *key* and fed with *payload*."""
        ctx = hmac.HMAC(key, self.hashfn, backend=self.backend)
        ctx.update(payload)
        return ctx

    def sign(self, key, payload):
        """Return the HMAC of *payload* under the key's 'sign' secret.

        The value from ``key.get_op_key('sign')`` is base64url-decoded
        before use.
        """
        secret = base64url_decode(key.get_op_key('sign'))
        return self._hmac_setup(secret, payload).finalize()

    def verify(self, key, payload, signature):
        """Verify *signature* against the HMAC of *payload*.

        Returns nothing; a mismatch surfaces as whatever exception the
        HMAC context's ``verify`` raises.
        """
        secret = base64url_decode(key.get_op_key('verify'))
        self._hmac_setup(secret, payload).verify(signature)
|
|
|
|
|
|
class _RawRSA(_RawJWS):
    """RSA signer/verifier parameterized by padding and hash objects."""

    def __init__(self, padfn, hashfn):
        self.padfn, self.hashfn = padfn, hashfn

    def sign(self, key, payload):
        """Sign *payload* with the key object returned for the 'sign' op."""
        signing_key = key.get_op_key('sign')
        return signing_key.sign(payload, self.padfn, self.hashfn)

    def verify(self, key, payload, signature):
        """Verify *signature* over *payload*; errors propagate as raised."""
        verifying_key = key.get_op_key('verify')
        verifying_key.verify(signature, payload, self.padfn, self.hashfn)
|
|
|
|
|
|
class _RawEC(_RawJWS):
    """ECDSA signer/verifier bound to a named curve and a hash object.

    Signatures are exchanged as the fixed-width concatenation ``r || s``
    and converted to/from the DSS encoding used by the key objects.
    """

    def __init__(self, curve, hashfn):
        self.hashfn = hashfn
        self._curve = curve

    @property
    def curve(self):
        """The curve name this algorithm was constructed with."""
        return self._curve

    def sign(self, key, payload):
        """Sign *payload* and return ``r || s``, each padded to the curve size."""
        signing_key = key.get_op_key('sign', self._curve)
        dss_signature = signing_key.sign(payload, ec.ECDSA(self.hashfn))
        r, s = ec_utils.decode_dss_signature(dss_signature)
        bits = key.get_curve(self._curve).key_size
        return b''.join(_encode_int(part, bits) for part in (r, s))

    def verify(self, key, payload, signature):
        """Verify an ``r || s`` *signature* over *payload*.

        The signature is split in half; each half is read as a big-endian
        integer before being re-encoded for the key object's ``verify``.
        """
        verifying_key = key.get_op_key('verify', self._curve)
        half = len(signature) // 2
        r_bytes, s_bytes = signature[:half], signature[half:]
        dss_signature = ec_utils.encode_dss_signature(
            _decode_int(r_bytes), _decode_int(s_bytes))
        verifying_key.verify(dss_signature, payload, ec.ECDSA(self.hashfn))
|
|
|
|
|
|
class _RawNone(_RawJWS):
|
|
|
|
def sign(self, key, payload):
|
|
return ''
|
|
|
|
def verify(self, key, payload, signature):
|
|
raise InvalidSignature('The "none" signature cannot be verified')
|
|
|
|
|
|
class _HS256(_RawHMAC, JWAAlgorithm):
    """JWS "HS256": HMAC using SHA-256."""

    algorithm_use = 'sig'
    algorithm_usage_location = 'alg'
    keysize = 256
    description = "HMAC using SHA-256"
    name = "HS256"

    def __init__(self):
        digest = hashes.SHA256()
        super(_HS256, self).__init__(digest)
|
|
|
|
|
|
class _HS384(_RawHMAC, JWAAlgorithm):
    """JWS "HS384": HMAC using SHA-384."""

    algorithm_use = 'sig'
    algorithm_usage_location = 'alg'
    keysize = 384
    description = "HMAC using SHA-384"
    name = "HS384"

    def __init__(self):
        digest = hashes.SHA384()
        super(_HS384, self).__init__(digest)
|
|
|
|
|
|
class _HS512(_RawHMAC, JWAAlgorithm):
    """JWS "HS512": HMAC using SHA-512."""

    algorithm_use = 'sig'
    algorithm_usage_location = 'alg'
    keysize = 512
    description = "HMAC using SHA-512"
    name = "HS512"

    def __init__(self):
        digest = hashes.SHA512()
        super(_HS512, self).__init__(digest)
|
|
|
|
|
|
class _RS256(_RawRSA, JWAAlgorithm):
    """JWS "RS256": RSASSA-PKCS1-v1_5 using SHA-256."""

    algorithm_use = 'sig'
    algorithm_usage_location = 'alg'
    keysize = 2048
    description = "RSASSA-PKCS1-v1_5 using SHA-256"
    name = "RS256"

    def __init__(self):
        pad, digest = padding.PKCS1v15(), hashes.SHA256()
        super(_RS256, self).__init__(pad, digest)
|
|
|
|
|
|
class _RS384(_RawRSA, JWAAlgorithm):
    """JWS "RS384": RSASSA-PKCS1-v1_5 using SHA-384."""

    algorithm_use = 'sig'
    algorithm_usage_location = 'alg'
    keysize = 2048
    description = "RSASSA-PKCS1-v1_5 using SHA-384"
    name = "RS384"

    def __init__(self):
        pad, digest = padding.PKCS1v15(), hashes.SHA384()
        super(_RS384, self).__init__(pad, digest)
|
|
|
|
|
|
class _RS512(_RawRSA, JWAAlgorithm):
    """JWS "RS512": RSASSA-PKCS1-v1_5 using SHA-512."""

    algorithm_use = 'sig'
    algorithm_usage_location = 'alg'
    keysize = 2048
    description = "RSASSA-PKCS1-v1_5 using SHA-512"
    name = "RS512"

    def __init__(self):
        pad, digest = padding.PKCS1v15(), hashes.SHA512()
        super(_RS512, self).__init__(pad, digest)
|
|
|
|
|
|
class _ES256(_RawEC, JWAAlgorithm):
    """JWS "ES256": ECDSA using P-256 and SHA-256."""

    algorithm_use = 'sig'
    algorithm_usage_location = 'alg'
    keysize = 256
    description = "ECDSA using P-256 and SHA-256"
    name = "ES256"

    def __init__(self):
        digest = hashes.SHA256()
        super(_ES256, self).__init__('P-256', digest)
|
|
|
|
|
|
class _ES384(_RawEC, JWAAlgorithm):
    """JWS "ES384": ECDSA using P-384 and SHA-384."""

    algorithm_use = 'sig'
    algorithm_usage_location = 'alg'
    keysize = 384
    description = "ECDSA using P-384 and SHA-384"
    name = "ES384"

    def __init__(self):
        digest = hashes.SHA384()
        super(_ES384, self).__init__('P-384', digest)
|
|
|
|
|
|
class _ES512(_RawEC, JWAAlgorithm):
    """JWS "ES512": ECDSA using P-521 and SHA-512."""

    algorithm_use = 'sig'
    algorithm_usage_location = 'alg'
    keysize = 512
    description = "ECDSA using P-521 and SHA-512"
    name = "ES512"

    def __init__(self):
        digest = hashes.SHA512()
        super(_ES512, self).__init__('P-521', digest)
|
|
|
|
|
|
class _PS256(_RawRSA, JWAAlgorithm):
    """JWS "PS256": RSASSA-PSS using SHA-256 and MGF1 with SHA-256."""

    algorithm_use = 'sig'
    algorithm_usage_location = 'alg'
    keysize = 2048
    description = "RSASSA-PSS using SHA-256 and MGF1 with SHA-256"
    name = "PS256"

    def __init__(self):
        digest_cls = hashes.SHA256
        mgf = padding.MGF1(digest_cls())
        # The PSS salt length is set to the digest size of the hash.
        pss = padding.PSS(mgf, digest_cls.digest_size)
        super(_PS256, self).__init__(pss, digest_cls())
|
|
|
|
|
|
class _PS384(_RawRSA, JWAAlgorithm):
    """JWS "PS384": RSASSA-PSS using SHA-384 and MGF1 with SHA-384."""

    algorithm_use = 'sig'
    algorithm_usage_location = 'alg'
    keysize = 2048
    description = "RSASSA-PSS using SHA-384 and MGF1 with SHA-384"
    name = "PS384"

    def __init__(self):
        digest_cls = hashes.SHA384
        mgf = padding.MGF1(digest_cls())
        # The PSS salt length is set to the digest size of the hash.
        pss = padding.PSS(mgf, digest_cls.digest_size)
        super(_PS384, self).__init__(pss, digest_cls())
|
|
|
|
|
|
class _PS512(_RawRSA, JWAAlgorithm):
    """JWS "PS512": RSASSA-PSS using SHA-512 and MGF1 with SHA-512."""

    algorithm_use = 'sig'
    algorithm_usage_location = 'alg'
    keysize = 2048
    description = "RSASSA-PSS using SHA-512 and MGF1 with SHA-512"
    name = "PS512"

    def __init__(self):
        digest_cls = hashes.SHA512
        mgf = padding.MGF1(digest_cls())
        # The PSS salt length is set to the digest size of the hash.
        pss = padding.PSS(mgf, digest_cls.digest_size)
        super(_PS512, self).__init__(pss, digest_cls())
|
|
|
|
|
|
class _None(_RawNone, JWAAlgorithm):
    """JWS "none": no digital signature or MAC performed."""

    algorithm_use = 'sig'
    algorithm_usage_location = 'alg'
    keysize = 0
    description = "No digital signature or MAC performed"
    name = "none"
|
|
|
|
|
|
class _RawKeyMgmt(object):
|
|
|
|
def wrap(self, key, bitsize, cek, headers):
|
|
raise NotImplementedError
|
|
|
|
def unwrap(self, key, bitsize, ek, headers):
|
|
raise NotImplementedError
|
|
|
|
|
|
class _RSA(_RawKeyMgmt):
    """RSA key encryption parameterized by a padding object."""

    def __init__(self, padfn):
        self.padfn = padfn

    def _check_key(self, key):
        """Require *key* to be a JWK of type 'RSA'.

        :raises ValueError: if *key* is not a JWK instance.
        :raises InvalidJWEKeyType: if the JWK is not an RSA key.
        """
        if not isinstance(key, JWK):
            raise ValueError('key is not a JWK object')
        if key.key_type != 'RSA':
            raise InvalidJWEKeyType('RSA', key.key_type)

    # FIXME: get key size and insure > 2048 bits
    def wrap(self, key, bitsize, cek, headers):
        """Encrypt the CEK, generating a random one when none is given.

        Returns a dict with the plain 'cek' and the encrypted key 'ek'.
        """
        self._check_key(key)
        cek = cek or _randombits(bitsize)
        wrapping_key = key.get_op_key('wrapKey')
        return {'cek': cek, 'ek': wrapping_key.encrypt(cek, self.padfn)}

    def unwrap(self, key, bitsize, ek, headers):
        """Decrypt *ek* and return the CEK.

        :raises InvalidJWEKeyLength: if the decrypted key is not exactly
            *bitsize* bits long.
        """
        self._check_key(key)
        unwrapping_key = key.get_op_key('decrypt')
        cek = unwrapping_key.decrypt(ek, self.padfn)
        cek_bits = _bitsize(cek)
        if cek_bits != bitsize:
            raise InvalidJWEKeyLength(bitsize, cek_bits)
        return cek
|
|
|
|
|
|
class _Rsa15(_RSA, JWAAlgorithm):
    """JWE "RSA1_5" key encryption (RSAES-PKCS1-v1_5).

    ``unwrap`` is overridden so that a failed decryption does not raise:
    the caller gets a random CEK instead of an error.
    """

    name = 'RSA1_5'
    description = "RSAES-PKCS1-v1_5"
    keysize = 2048
    algorithm_usage_location = 'alg'
    algorithm_use = 'kex'

    def __init__(self):
        super(_Rsa15, self).__init__(padding.PKCS1v15())

    def unwrap(self, key, bitsize, ek, headers):
        """Return the decrypted CEK, or a random one if decryption fails.

        The key type check still raises; any exception from the parent
        ``unwrap`` (including a key length mismatch) is swallowed and
        the pre-generated random CEK is returned in its place.
        """
        self._check_key(key)
        # Address MMA attack by implementing RFC 3218 - 2.3.2. Random Filling
        # provides a random cek that will cause the decryption engine to
        # run to the end, but will fail decryption later.

        # always generate a random cek so we spend roughly the
        # same time as in the exception side of the branch
        cek = _randombits(bitsize)
        try:
            # On success this overwrites the random cek with the real one.
            cek = super(_Rsa15, self).unwrap(key, bitsize, ek, headers)
            # always raise so we always run through the exception handling
            # code in all cases
            raise Exception('Dummy')
        except Exception:  # pylint: disable=broad-except
            # Reached on both paths: cek is the real key if the parent
            # unwrap succeeded, otherwise still the random filler.
            return cek
|
|
|
|
|
|
class _RsaOaep(_RSA, JWAAlgorithm):
    """JWE "RSA-OAEP": RSAES OAEP with SHA-1 and MGF1 with SHA-1."""

    algorithm_use = 'kex'
    algorithm_usage_location = 'alg'
    keysize = 2048
    description = "RSAES OAEP using default parameters"
    name = 'RSA-OAEP'

    def __init__(self):
        mgf = padding.MGF1(hashes.SHA1())
        # Third argument is the OAEP label; none is used.
        oaep = padding.OAEP(mgf, hashes.SHA1(), None)
        super(_RsaOaep, self).__init__(oaep)
|
|
|
|
|
|
class _RsaOaep256(_RSA, JWAAlgorithm):  # noqa: ignore=N801
    """JWE "RSA-OAEP-256": RSAES OAEP with SHA-256 and MGF1 with SHA-256."""

    algorithm_use = 'kex'
    algorithm_usage_location = 'alg'
    keysize = 2048
    description = "RSAES OAEP using SHA-256 and MGF1 with SHA-256"
    name = 'RSA-OAEP-256'

    def __init__(self):
        mgf = padding.MGF1(hashes.SHA256())
        # Third argument is the OAEP label; none is used.
        oaep = padding.OAEP(mgf, hashes.SHA256(), None)
        super(_RsaOaep256, self).__init__(oaep)
|
|
|
|
|
|
class _AesKw(_RawKeyMgmt):
    """AES Key Wrap key management; subclasses set ``keysize`` in bits."""

    keysize = None

    def __init__(self):
        self.backend = default_backend()

    def _get_key(self, key, op):
        """Return the raw symmetric key bytes of *key* for operation *op*.

        :raises ValueError: if *key* is not a JWK instance.
        :raises InvalidJWEKeyType: if the JWK is not of type 'oct'.
        :raises InvalidJWEKeyLength: if the decoded key is not exactly
            ``self.keysize`` bits long.
        """
        if not isinstance(key, JWK):
            raise ValueError('key is not a JWK object')
        if key.key_type != 'oct':
            raise InvalidJWEKeyType('oct', key.key_type)
        raw_key = base64url_decode(key.get_op_key(op))
        raw_bits = _bitsize(raw_key)
        if raw_bits != self.keysize:
            raise InvalidJWEKeyLength(self.keysize, raw_bits)
        return raw_key

    def wrap(self, key, bitsize, cek, headers):
        """Wrap the CEK, generating a random one when none is given.

        Returns a dict with the plain 'cek' and the wrapped key 'ek'.
        """
        kek = self._get_key(key, 'encrypt')
        cek = cek or _randombits(bitsize)
        wrapped = aes_key_wrap(kek, cek, default_backend())
        return {'cek': cek, 'ek': wrapped}

    def unwrap(self, key, bitsize, ek, headers):
        """Unwrap *ek* and return the CEK.

        :raises InvalidJWEKeyLength: if the unwrapped key is not exactly
            *bitsize* bits long.
        """
        kek = self._get_key(key, 'decrypt')
        cek = aes_key_unwrap(kek, ek, default_backend())
        cek_bits = _bitsize(cek)
        if cek_bits != bitsize:
            raise InvalidJWEKeyLength(bitsize, cek_bits)
        return cek
|
|
|
|
|
|
class _A128KW(_AesKw, JWAAlgorithm):
    """JWE "A128KW": AES Key Wrap using a 128-bit key."""

    algorithm_use = 'kex'
    algorithm_usage_location = 'alg'
    keysize = 128
    description = "AES Key Wrap using 128-bit key"
    name = 'A128KW'
|
|
|
|
|
|
class _A192KW(_AesKw, JWAAlgorithm):
    """JWE "A192KW": AES Key Wrap using a 192-bit key."""

    algorithm_use = 'kex'
    algorithm_usage_location = 'alg'
    keysize = 192
    description = "AES Key Wrap using 192-bit key"
    name = 'A192KW'
|
|
|
|
|
|
class _A256KW(_AesKw, JWAAlgorithm):
    """JWE "A256KW": AES Key Wrap using a 256-bit key."""

    algorithm_use = 'kex'
    algorithm_usage_location = 'alg'
    keysize = 256
    description = "AES Key Wrap using 256-bit key"
    name = 'A256KW'
|
|
|
|
|
|
class _AesGcmKw(_RawKeyMgmt):
    """Key wrapping with AES GCM; subclasses set ``keysize`` in bits.

    The IV and authentication tag travel in the 'iv' and 'tag' header
    parameters, base64url-encoded.
    """

    keysize = None

    def __init__(self):
        self.backend = default_backend()

    def _get_key(self, key, op):
        """Return the raw symmetric key bytes of *key* for operation *op*.

        :raises ValueError: if *key* is not a JWK instance.
        :raises InvalidJWEKeyType: if the JWK is not of type 'oct'.
        :raises InvalidJWEKeyLength: if the decoded key is not exactly
            ``self.keysize`` bits long.
        """
        if not isinstance(key, JWK):
            raise ValueError('key is not a JWK object')
        if key.key_type != 'oct':
            raise InvalidJWEKeyType('oct', key.key_type)
        raw_key = base64url_decode(key.get_op_key(op))
        raw_bits = _bitsize(raw_key)
        if raw_bits != self.keysize:
            raise InvalidJWEKeyLength(self.keysize, raw_bits)
        return raw_key

    def wrap(self, key, bitsize, cek, headers):
        """Encrypt the CEK with AES GCM under a fresh 96-bit IV.

        A random CEK is generated when none is given. Returns a dict with
        'cek', 'ek' and a 'header' dict carrying the encoded 'iv' and 'tag'.
        """
        kek = self._get_key(key, 'encrypt')
        cek = cek or _randombits(bitsize)

        iv = _randombits(96)
        gcm = Cipher(algorithms.AES(kek), modes.GCM(iv),
                     backend=self.backend)
        enc = gcm.encryptor()
        ek = enc.update(cek) + enc.finalize()

        extra_header = {'iv': base64url_encode(iv),
                        'tag': base64url_encode(enc.tag)}
        return {'cek': cek, 'ek': ek, 'header': extra_header}

    def unwrap(self, key, bitsize, ek, headers):
        """Decrypt *ek* using the 'iv' and 'tag' found in *headers*.

        :raises ValueError: if 'iv' or 'tag' is missing from *headers*.
        :raises InvalidJWEKeyLength: if the decrypted key is not exactly
            *bitsize* bits long.
        """
        kek = self._get_key(key, 'decrypt')

        # Each parameter is checked and decoded in turn, 'iv' first.
        if 'iv' not in headers:
            raise ValueError('Invalid Header, missing "iv" parameter')
        iv = base64url_decode(headers['iv'])
        if 'tag' not in headers:
            raise ValueError('Invalid Header, missing "tag" parameter')
        tag = base64url_decode(headers['tag'])

        gcm = Cipher(algorithms.AES(kek), modes.GCM(iv, tag),
                     backend=self.backend)
        dec = gcm.decryptor()
        cek = dec.update(ek) + dec.finalize()
        cek_bits = _bitsize(cek)
        if cek_bits != bitsize:
            raise InvalidJWEKeyLength(bitsize, cek_bits)
        return cek
|
|
|
|
|
|
class _A128GcmKw(_AesGcmKw, JWAAlgorithm):
    """JWE "A128GCMKW": key wrapping with AES GCM using a 128-bit key."""

    algorithm_use = 'kex'
    algorithm_usage_location = 'alg'
    keysize = 128
    description = "Key wrapping with AES GCM using 128-bit key"
    name = 'A128GCMKW'
|
|
|
|
|
|
class _A192GcmKw(_AesGcmKw, JWAAlgorithm):
    """JWE "A192GCMKW": key wrapping with AES GCM using a 192-bit key."""

    algorithm_use = 'kex'
    algorithm_usage_location = 'alg'
    keysize = 192
    description = "Key wrapping with AES GCM using 192-bit key"
    name = 'A192GCMKW'
|
|
|
|
|
|
class _A256GcmKw(_AesGcmKw, JWAAlgorithm):
    """JWE "A256GCMKW": key wrapping with AES GCM using a 256-bit key."""

    algorithm_use = 'kex'
    algorithm_usage_location = 'alg'
    keysize = 256
    description = "Key wrapping with AES GCM using 256-bit key"
    name = 'A256GCMKW'
|
|
|
|
|
|
class _Pbes2HsAesKw(_RawKeyMgmt):
    """PBES2 key management: PBKDF2-derived key used with AES Key Wrap.

    Subclasses set ``name``, ``keysize`` (AES key size in bits) and
    ``hashsize`` (256, 384 or 512, selecting the PBKDF2 HMAC hash).
    """

    name = None
    keysize = None
    hashsize = None

    def __init__(self):
        self.backend = default_backend()
        # AES Key Wrap implementation to delegate to, by key size in bits.
        self.aeskwmap = {128: _A128KW, 192: _A192KW, 256: _A256KW}

    def _get_key(self, alg, key, p2s, p2c):
        """Derive the key-encryption key from a password.

        :param alg: accepted for the callers' convenience; not used here.
        :param key: the password, as bytes or as text (encoded as UTF-8).
        :param p2s: the PBES2 salt input bytes.
        :param p2c: the PBKDF2 iteration count.

        Returns a JWK of type 'oct' holding the derived key.

        :raises ValueError: if ``self.hashsize`` is not 256, 384 or 512.
        :raises InvalidJWEKeyLength: if the derived key is not exactly
            ``self.keysize`` bits long.
        """
        if isinstance(key, bytes):
            plain = key
        else:
            plain = key.encode('utf8')
        # Salt is the algorithm name, a zero byte, then the salt input.
        salt = bytes(self.name.encode('utf8')) + b'\x00' + p2s

        hash_classes = {256: hashes.SHA256,
                        384: hashes.SHA384,
                        512: hashes.SHA512}
        hash_cls = hash_classes.get(self.hashsize)
        if hash_cls is None:
            raise ValueError('Unknown Hash Size')

        kdf = PBKDF2HMAC(algorithm=hash_cls(),
                         length=_inbytes(self.keysize),
                         salt=salt, iterations=p2c, backend=self.backend)
        rk = kdf.derive(plain)
        if _bitsize(rk) != self.keysize:
            # Report the actual size in bits, like every other length
            # check in this module (this used to pass the byte count).
            raise InvalidJWEKeyLength(self.keysize, _bitsize(rk))
        return JWK(kty="oct", use="enc", k=base64url_encode(rk))

    def wrap(self, key, bitsize, cek, headers):
        """Wrap the CEK under a key derived from the password *key*.

        Uses a fresh 128-bit salt input and 8192 PBKDF2 iterations.
        Returns the AES Key Wrap result dict with 'header' set to the
        'p2s' and 'p2c' parameters needed for unwrapping.
        """
        p2s = _randombits(128)
        p2c = 8192
        kek = self._get_key(headers['alg'], key, p2s, p2c)

        aeskw = self.aeskwmap[self.keysize]()
        ret = aeskw.wrap(kek, bitsize, cek, headers)
        ret['header'] = {'p2s': base64url_encode(p2s), 'p2c': p2c}
        return ret

    def unwrap(self, key, bitsize, ek, headers):
        """Unwrap *ek* using the 'p2s' and 'p2c' values in *headers*.

        :raises ValueError: if 'p2s' or 'p2c' is missing from *headers*.
        """
        if 'p2s' not in headers:
            raise ValueError('Invalid Header, missing "p2s" parameter')
        if 'p2c' not in headers:
            raise ValueError('Invalid Header, missing "p2c" parameter')
        p2s = base64url_decode(headers['p2s'])
        # NOTE(review): the iteration count comes straight from the header
        # and is not bounded here; consider whether callers need a limit.
        p2c = headers['p2c']
        kek = self._get_key(headers['alg'], key, p2s, p2c)

        aeskw = self.aeskwmap[self.keysize]()
        return aeskw.unwrap(kek, bitsize, ek, headers)
|
|
|
|
|
|
class _Pbes2Hs256A128Kw(_Pbes2HsAesKw, JWAAlgorithm):
    """JWE "PBES2-HS256+A128KW": PBES2 with HMAC SHA-256 and A128KW."""

    hashsize = 256
    algorithm_use = 'kex'
    algorithm_usage_location = 'alg'
    keysize = 128
    description = 'PBES2 with HMAC SHA-256 and "A128KW" wrapping'
    name = 'PBES2-HS256+A128KW'
|
|
|
|
|
|
class _Pbes2Hs384A192Kw(_Pbes2HsAesKw, JWAAlgorithm):
    """JWE "PBES2-HS384+A192KW": PBES2 with HMAC SHA-384 and A192KW."""

    hashsize = 384
    algorithm_use = 'kex'
    algorithm_usage_location = 'alg'
    keysize = 192
    description = 'PBES2 with HMAC SHA-384 and "A192KW" wrapping'
    name = 'PBES2-HS384+A192KW'
|
|
|
|
|
|
class _Pbes2Hs512A256Kw(_Pbes2HsAesKw, JWAAlgorithm):
    """JWE "PBES2-HS512+A256KW": PBES2 with HMAC SHA-512 and A256KW."""

    hashsize = 512
    algorithm_use = 'kex'
    algorithm_usage_location = 'alg'
    keysize = 256
    description = 'PBES2 with HMAC SHA-512 and "A256KW" wrapping'
    name = 'PBES2-HS512+A256KW'
|
|
|
|
|
|
class _Direct(_RawKeyMgmt, JWAAlgorithm):
    """JWE "dir": direct use of a shared symmetric key as the CEK."""

    name = 'dir'
    description = "Direct use of a shared symmetric key"
    keysize = 128
    algorithm_usage_location = 'alg'
    algorithm_use = 'kex'

    def _check_key(self, key):
        """Require *key* to be a JWK of type 'oct'.

        :raises ValueError: if *key* is not a JWK instance.
        :raises InvalidJWEKeyType: if the JWK is not of type 'oct'.
        """
        if not isinstance(key, JWK):
            raise ValueError('key is not a JWK object')
        if key.key_type != 'oct':
            raise InvalidJWEKeyType('oct', key.key_type)

    def wrap(self, key, bitsize, cek, headers):
        """Return the CEK to use; direct mode produces no encrypted key.

        If *cek* is already set it is returned unchanged, otherwise the
        shared key itself becomes the CEK. The result is always a dict
        with a 'cek' entry and no 'ek' entry.

        :raises InvalidCEKeyLength: if the shared key is not exactly
            *bitsize* bits long.
        """
        self._check_key(key)
        if cek:
            # Return a dict like every other wrap() in this module; this
            # branch used to return the tuple (cek, None), which breaks
            # callers that index the result by 'cek'.
            return {'cek': cek}
        k = base64url_decode(key.get_op_key('encrypt'))
        if _bitsize(k) != bitsize:
            raise InvalidCEKeyLength(bitsize, _bitsize(k))
        return {'cek': k}

    def unwrap(self, key, bitsize, ek, headers):
        """Return the shared key as the CEK.

        :raises ValueError: if *ek* is not the empty byte string.
        :raises InvalidJWEKeyLength: if the shared key is not exactly
            *bitsize* bits long.
        """
        self._check_key(key)
        if ek != b'':
            raise ValueError('Invalid Encryption Key.')
        cek = base64url_decode(key.get_op_key('decrypt'))
        if _bitsize(cek) != bitsize:
            raise InvalidJWEKeyLength(bitsize, _bitsize(cek))
        return cek
|
|
|
|
|
|
class _EcdhEs(_RawKeyMgmt, JWAAlgorithm):
    """JWE "ECDH-ES": ECDH key agreement with the Concat KDF.

    With ``keysize`` left as None the derived key is used directly as
    the CEK. Subclasses that set ``keysize`` use the derived key to
    AES-Key-Wrap the CEK instead.
    """

    name = 'ECDH-ES'
    description = "ECDH-ES using Concat KDF"
    algorithm_usage_location = 'alg'
    algorithm_use = 'kex'
    keysize = None

    def __init__(self):
        self.backend = default_backend()
        # AES Key Wrap implementation to delegate to, by key size in bits.
        self.aeskwmap = {128: _A128KW, 192: _A192KW, 256: _A256KW}

    def _check_key(self, key):
        """Require *key* to be a JWK of type 'EC'.

        :raises ValueError: if *key* is not a JWK instance.
        :raises InvalidJWEKeyType: if the JWK is not of type 'EC'.
        """
        if not isinstance(key, JWK):
            raise ValueError('key is not a JWK object')
        if key.key_type != 'EC':
            raise InvalidJWEKeyType('EC', key.key_type)

    def _derive(self, privkey, pubkey, alg, bitsize, headers):
        """Run ECDH and the Concat KDF (SHA-256), returning the key bytes.

        :param privkey: private key object providing ``exchange``.
        :param pubkey: the peer's public key object.
        :param alg: algorithm identifier text placed in OtherInfo.
        :param bitsize: size of the key to derive, in bits.
        :param headers: JOSE headers; optional 'apu' and 'apv' are read.
        """
        # OtherInfo is defined in NIST SP 56A 5.8.1.2.1

        # AlgorithmID: the length prefix counts the encoded bytes, not
        # the characters of the text (they differ for non-ASCII input).
        alg_id = bytes(alg.encode('utf8'))
        otherinfo = struct.pack('>I', len(alg_id))
        otherinfo += alg_id

        # PartyUInfo
        apu = base64url_decode(headers['apu']) if 'apu' in headers else b''
        otherinfo += struct.pack('>I', len(apu))
        otherinfo += apu

        # PartyVInfo
        apv = base64url_decode(headers['apv']) if 'apv' in headers else b''
        otherinfo += struct.pack('>I', len(apv))
        otherinfo += apv

        # SuppPubInfo
        otherinfo += struct.pack('>I', bitsize)

        # no SuppPrivInfo

        shared_key = privkey.exchange(ec.ECDH(), pubkey)
        ckdf = ConcatKDFHash(algorithm=hashes.SHA256(),
                             length=_inbytes(bitsize),
                             otherinfo=otherinfo,
                             backend=self.backend)
        return ckdf.derive(shared_key)

    def wrap(self, key, bitsize, cek, headers):
        """Agree on a key with a fresh ephemeral key pair.

        In direct mode (``keysize`` is None) the derived key is the CEK
        and the KDF is bound to ``headers['enc']``. Otherwise the
        derived key wraps the CEK and the KDF is bound to
        ``headers['alg']``. The result's 'header' carries the ephemeral
        public key as 'epk'.

        :raises InvalidJWEOperation: in direct mode, if *cek* is not None.
        """
        self._check_key(key)
        dk_size = self.keysize
        if self.keysize is None:
            if cek is not None:
                raise InvalidJWEOperation('ECDH-ES cannot use an existing CEK')
            alg = headers['enc']
            dk_size = bitsize
        else:
            alg = headers['alg']

        # Ephemeral key of the same type and curve as the recipient key.
        epk = JWK.generate(kty=key.key_type, crv=key.key_curve)
        dk = self._derive(epk.get_op_key('unwrapKey'),
                          key.get_op_key('wrapKey'),
                          alg, dk_size, headers)

        if self.keysize is None:
            ret = {'cek': dk}
        else:
            aeskw = self.aeskwmap[self.keysize]()
            kek = JWK(kty="oct", use="enc", k=base64url_encode(dk))
            ret = aeskw.wrap(kek, bitsize, cek, headers)

        ret['header'] = {'epk': json_decode(epk.export_public())}
        return ret

    def unwrap(self, key, bitsize, ek, headers):
        """Recover the CEK using the sender's 'epk' header.

        In direct mode the derived key is returned as the CEK; otherwise
        it is used to AES-Key-Unwrap *ek*.

        :raises ValueError: if 'epk' is missing from *headers*.
        """
        if 'epk' not in headers:
            raise ValueError('Invalid Header, missing "epk" parameter')
        self._check_key(key)
        dk_size = self.keysize
        if self.keysize is None:
            alg = headers['enc']
            dk_size = bitsize
        else:
            alg = headers['alg']

        epk = JWK(**headers['epk'])
        dk = self._derive(key.get_op_key('unwrapKey'),
                          epk.get_op_key('wrapKey'),
                          alg, dk_size, headers)
        if self.keysize is None:
            return dk
        else:
            aeskw = self.aeskwmap[self.keysize]()
            kek = JWK(kty="oct", use="enc", k=base64url_encode(dk))
            cek = aeskw.unwrap(kek, bitsize, ek, headers)
            return cek
|
|
|
|
|
|
class _EcdhEsAes128Kw(_EcdhEs):
    """JWE "ECDH-ES+A128KW": ECDH-ES with Concat KDF and A128KW wrapping."""

    algorithm_use = 'kex'
    algorithm_usage_location = 'alg'
    keysize = 128
    description = 'ECDH-ES using Concat KDF and "A128KW" wrapping'
    name = 'ECDH-ES+A128KW'
|
|
|
|
|
|
class _EcdhEsAes192Kw(_EcdhEs):
    """JWE "ECDH-ES+A192KW": ECDH-ES with Concat KDF and A192KW wrapping."""

    algorithm_use = 'kex'
    algorithm_usage_location = 'alg'
    keysize = 192
    description = 'ECDH-ES using Concat KDF and "A192KW" wrapping'
    name = 'ECDH-ES+A192KW'
|
|
|
|
|
|
class _EcdhEsAes256Kw(_EcdhEs):
    """JWE "ECDH-ES+A256KW": ECDH-ES with Concat KDF and A256KW wrapping."""

    algorithm_use = 'kex'
    algorithm_usage_location = 'alg'
    keysize = 256
    description = 'ECDH-ES using Concat KDF and "A256KW" wrapping'
    name = 'ECDH-ES+A256KW'
|
|
|
|
|
|
class _RawJWE(object):
|
|
|
|
def encrypt(self, k, a, m):
|
|
raise NotImplementedError
|
|
|
|
def decrypt(self, k, a, iv, e, t):
|
|
raise NotImplementedError
|
|
|
|
|
|
class _AesCbcHmacSha2(_RawJWE):
    """AES-CBC encryption authenticated with a truncated HMAC.

    Subclasses set ``keysize`` in bits. The combined key is twice that
    size: the first half keys the HMAC, the second half keys AES.
    """

    keysize = None

    def __init__(self, hashfn):
        self.backend = default_backend()
        self.hashfn = hashfn
        self.blocksize = algorithms.AES.block_size
        self.wrap_key_size = self.keysize * 2

    def _mac(self, k, a, iv, e):
        """Return the authentication tag over (a, iv, e, bit length of a).

        The HMAC output is truncated to ``keysize`` bits.
        """
        aad_bits = _encode_int(_bitsize(a), 64)
        mac = hmac.HMAC(k, self.hashfn, backend=self.backend)
        for chunk in (a, iv, e, aad_bits):
            mac.update(chunk)
        digest = mac.finalize()
        return digest[:_inbytes(self.keysize)]

    # RFC 7518 - 5.2.2
    def encrypt(self, k, a, m):
        """Encrypt and authenticate the plaintext.

        :param k: Combined key (MAC key followed by encryption key)
        :param a: Additional Authenticated Data
        :param m: Plaintext

        Returns a tuple ``(iv, ciphertext, tag)``.
        """
        split = _inbytes(self.keysize)
        mac_key, enc_key = k[:split], k[split:]

        # encrypt: random IV, PKCS7 padding, AES-CBC
        iv = _randombits(self.blocksize)
        aes_cbc = Cipher(algorithms.AES(enc_key), modes.CBC(iv),
                         backend=self.backend)
        enc = aes_cbc.encryptor()
        pad = PKCS7(self.blocksize).padder()
        padded = pad.update(m) + pad.finalize()
        ciphertext = enc.update(padded) + enc.finalize()

        # mac
        tag = self._mac(mac_key, a, iv, ciphertext)

        return (iv, ciphertext, tag)

    def decrypt(self, k, a, iv, e, t):
        """Verify the tag, then decrypt.

        :param k: Combined key (MAC key followed by encryption key)
        :param a: Additional Authenticated Data
        :param iv: Initialization Vector
        :param e: Ciphertext
        :param t: Authentication Tag

        Returns the plaintext, or raises ``InvalidSignature`` if the tag
        does not match.
        """
        split = _inbytes(self.keysize)
        mac_key, dec_key = k[:split], k[split:]

        # verify mac before touching the ciphertext
        expected_tag = self._mac(mac_key, a, iv, e)
        if not constant_time.bytes_eq(t, expected_tag):
            raise InvalidSignature('Failed to verify MAC')

        # decrypt
        aes_cbc = Cipher(algorithms.AES(dec_key), modes.CBC(iv),
                         backend=self.backend)
        dec = aes_cbc.decryptor()
        padded = dec.update(e) + dec.finalize()
        unpad = PKCS7(self.blocksize).unpadder()
        return unpad.update(padded) + unpad.finalize()
|
|
|
|
|
|
class _A128CbcHs256(_AesCbcHmacSha2, JWAAlgorithm):
    """JWE "A128CBC-HS256": AES_128_CBC_HMAC_SHA_256."""

    algorithm_use = 'enc'
    algorithm_usage_location = 'enc'
    keysize = 128
    description = "AES_128_CBC_HMAC_SHA_256 authenticated"
    name = 'A128CBC-HS256'

    def __init__(self):
        digest = hashes.SHA256()
        super(_A128CbcHs256, self).__init__(digest)
|
|
|
|
|
|
class _A192CbcHs384(_AesCbcHmacSha2, JWAAlgorithm):
    """JWE "A192CBC-HS384": AES_192_CBC_HMAC_SHA_384."""

    algorithm_use = 'enc'
    algorithm_usage_location = 'enc'
    keysize = 192
    description = "AES_192_CBC_HMAC_SHA_384 authenticated"
    name = 'A192CBC-HS384'

    def __init__(self):
        digest = hashes.SHA384()
        super(_A192CbcHs384, self).__init__(digest)
|
|
|
|
|
|
class _A256CbcHs512(_AesCbcHmacSha2, JWAAlgorithm):
    """JWE "A256CBC-HS512": AES_256_CBC_HMAC_SHA_512."""

    algorithm_use = 'enc'
    algorithm_usage_location = 'enc'
    keysize = 256
    description = "AES_256_CBC_HMAC_SHA_512 authenticated"
    name = 'A256CBC-HS512'

    def __init__(self):
        digest = hashes.SHA512()
        super(_A256CbcHs512, self).__init__(digest)
|
|
|
|
|
|
class _AesGcm(_RawJWE):
    """AES GCM content encryption; subclasses set ``keysize`` in bits."""

    keysize = None

    def __init__(self):
        self.backend = default_backend()
        self.wrap_key_size = self.keysize

    # RFC 7518 - 5.3
    def encrypt(self, k, a, m):
        """Encrypt the plaintext with AES GCM under a fresh 96-bit IV.

        :param k: Encryption key
        :param a: Additional Authenticated Data
        :param m: Plaintext

        Returns a tuple ``(iv, ciphertext, tag)``.
        """
        iv = _randombits(96)
        gcm = Cipher(algorithms.AES(k), modes.GCM(iv),
                     backend=self.backend)
        enc = gcm.encryptor()
        enc.authenticate_additional_data(a)
        ciphertext = enc.update(m) + enc.finalize()

        return (iv, ciphertext, enc.tag)

    def decrypt(self, k, a, iv, e, t):
        """Decrypt the ciphertext with AES GCM, checking the tag.

        :param k: Encryption key
        :param a: Additional Authenticated Data
        :param iv: Initialization Vector
        :param e: Ciphertext
        :param t: Authentication Tag

        Returns plaintext or raises an error
        """
        gcm = Cipher(algorithms.AES(k), modes.GCM(iv, t),
                     backend=self.backend)
        dec = gcm.decryptor()
        dec.authenticate_additional_data(a)
        plaintext = dec.update(e) + dec.finalize()
        return plaintext
|
|
|
|
|
|
class _A128Gcm(_AesGcm, JWAAlgorithm):
    """JWE "A128GCM": AES GCM using a 128-bit key."""

    algorithm_use = 'enc'
    algorithm_usage_location = 'enc'
    keysize = 128
    description = "AES GCM using 128-bit key"
    name = 'A128GCM'
|
|
|
|
|
|
class _A192Gcm(_AesGcm, JWAAlgorithm):
    """JWE "A192GCM": AES GCM using a 192-bit key."""

    algorithm_use = 'enc'
    algorithm_usage_location = 'enc'
    keysize = 192
    description = "AES GCM using 192-bit key"
    name = 'A192GCM'
|
|
|
|
|
|
class _A256Gcm(_AesGcm, JWAAlgorithm):
    """JWE "A256GCM": AES GCM using a 256-bit key."""

    algorithm_use = 'enc'
    algorithm_usage_location = 'enc'
    keysize = 256
    description = "AES GCM using 256-bit key"
    name = 'A256GCM'
|
|
|
|
|
|
class JWA(object):
    """JWA Signing Algorithms.

    This class provides access to all JWA algorithms.

    ``algorithms_registry`` maps each algorithm name to its implementing
    class; the class methods look a name up and return a new instance.
    """

    algorithms_registry = {
        'HS256': _HS256,
        'HS384': _HS384,
        'HS512': _HS512,
        'RS256': _RS256,
        'RS384': _RS384,
        'RS512': _RS512,
        'ES256': _ES256,
        'ES384': _ES384,
        'ES512': _ES512,
        'PS256': _PS256,
        'PS384': _PS384,
        'PS512': _PS512,
        'none': _None,
        'RSA1_5': _Rsa15,
        'RSA-OAEP': _RsaOaep,
        'RSA-OAEP-256': _RsaOaep256,
        'A128KW': _A128KW,
        'A192KW': _A192KW,
        'A256KW': _A256KW,
        'dir': _Direct,
        'ECDH-ES': _EcdhEs,
        'ECDH-ES+A128KW': _EcdhEsAes128Kw,
        'ECDH-ES+A192KW': _EcdhEsAes192Kw,
        'ECDH-ES+A256KW': _EcdhEsAes256Kw,
        'A128GCMKW': _A128GcmKw,
        'A192GCMKW': _A192GcmKw,
        'A256GCMKW': _A256GcmKw,
        'PBES2-HS256+A128KW': _Pbes2Hs256A128Kw,
        'PBES2-HS384+A192KW': _Pbes2Hs384A192Kw,
        'PBES2-HS512+A256KW': _Pbes2Hs512A256Kw,
        'A128CBC-HS256': _A128CbcHs256,
        'A192CBC-HS384': _A192CbcHs384,
        'A256CBC-HS512': _A256CbcHs512,
        'A128GCM': _A128Gcm,
        'A192GCM': _A192Gcm,
        'A256GCM': _A256Gcm
    }

    @classmethod
    def instantiate_alg(cls, name, use=None):
        """Return a new instance of the algorithm registered as *name*.

        :param name: the algorithm name to look up in the registry.
        :param use: if given, the required ``algorithm_use`` value
            ('sig', 'kex' or 'enc').

        :raises KeyError: if *name* is not registered, or if *use* is
            given and does not match the algorithm's ``algorithm_use``.
        """
        alg = cls.algorithms_registry[name]
        if use is not None and alg.algorithm_use != use:
            # Same exception as an unknown name, so callers handle both
            # cases in one place.
            raise KeyError
        return alg()

    @classmethod
    def signing_alg(cls, name):
        """Return an instance of the signing algorithm *name*.

        :raises InvalidJWAAlgorithm: if *name* is unknown or is not a
            signing ('sig') algorithm.
        """
        try:
            return cls.instantiate_alg(name, use='sig')
        except KeyError:
            # Message previously misspelled "Signing" as "Signign".
            raise InvalidJWAAlgorithm(
                '%s is not a valid Signing algorithm name' % name)

    @classmethod
    def keymgmt_alg(cls, name):
        """Return an instance of the key management algorithm *name*.

        :raises InvalidJWAAlgorithm: if *name* is unknown or is not a
            key management ('kex') algorithm.
        """
        try:
            return cls.instantiate_alg(name, use='kex')
        except KeyError:
            raise InvalidJWAAlgorithm(
                '%s is not a valid Key Management algorithm name' % name)

    @classmethod
    def encryption_alg(cls, name):
        """Return an instance of the content encryption algorithm *name*.

        :raises InvalidJWAAlgorithm: if *name* is unknown or is not an
            encryption ('enc') algorithm.
        """
        try:
            return cls.instantiate_alg(name, use='enc')
        except KeyError:
            raise InvalidJWAAlgorithm(
                '%s is not a valid Encryption algorithm name' % name)
|