406 lines
13 KiB
Python
406 lines
13 KiB
Python
import functools
|
|
import os
|
|
import struct
|
|
|
|
from cryptography.exceptions import InvalidTag
|
|
from cryptography.hazmat.backends import default_backend
|
|
from cryptography.hazmat.primitives import hashes
|
|
from cryptography.hazmat.primitives.kdf.hkdf import HKDF
|
|
from cryptography.hazmat.primitives.ciphers import (
|
|
Cipher, algorithms, modes
|
|
)
|
|
from cryptography.hazmat.primitives.serialization import (
|
|
Encoding, PublicFormat
|
|
)
|
|
from cryptography.hazmat.primitives.asymmetric import ec
|
|
|
|
# Record-size bounds and fixed field lengths, all expressed in octets.
MAX_RECORD_SIZE = 2 ** 31 - 1
MIN_RECORD_SIZE = 3
KEY_LENGTH = 16
NONCE_LENGTH = 12
TAG_LENGTH = 16

# Valid content types (ordered from newest, to most obsolete)
# "pad" is the number of octets each record spends on padding bookkeeping:
# the trailing delimiter for aes128gcm, the leading pad-length field for
# the legacy encodings.
versions = {
    "aes128gcm": {"pad": 1},
    "aesgcm": {"pad": 2},
    "aesgcm128": {"pad": 1},
}
|
|
|
|
|
|
class ECEException(Exception):
    """Exception for ECE encryption functions

    :param message: human readable description of the failure; kept on
        the ``message`` attribute and also used as the exception text
    :type message: str

    """
    def __init__(self, message):
        # Hand the message to the base class explicitly so that ``args``
        # and ``str(exc)`` are populated even when the exception is built
        # with a keyword argument (``ECEException(message="...")``).
        super(ECEException, self).__init__(message)
        self.message = message
|
|
|
|
|
|
def derive_key(mode, version, salt, key,
               private_key, dh, auth_secret,
               keyid, keylabel="P-256"):
    """Derive the encryption key and nonce

    :param mode: operational mode (encrypt or decrypt)
    :type mode: enumerate('encrypt', 'decrypt')
    :param version: Content Type identifier
    :type version: enumerate('aes128gcm', 'aesgcm', 'aesgcm128')
    :param salt: encryption salt value (must be 16 octets)
    :type salt: bytes
    :param key: raw key, used as the secret when no ``dh`` is supplied
    :type key: bytes
    :param private_key: DH private key (required when ``dh`` is supplied)
    :type private_key: object
    :param dh: Diffie Hellman public key value, either as an encoded
        point or as an ``ec.EllipticCurvePublicKey``
    :type dh: bytes
    :param auth_secret: authorization secret
    :type auth_secret: bytes
    :param keyid: key identifier label (accepted for interface
        compatibility; it does not take part in the derivation)
    :type keyid: str
    :param keylabel: label mixed into the DH context for the
        non-aes128gcm encodings
    :type keylabel: str
    :return: tuple of (key, nonce)
    :rtype: tuple
    :raises ECEException: on an unknown version or mode, a salt that is
        not 16 octets, ``dh`` without ``private_key``, or when no secret
        can be determined

    """
    context = b""
    # Bytes placeholders; every entry in ``versions`` overrides them below.
    keyinfo = b""
    nonceinfo = b""

    def build_info(base, info_context):
        return b"Content-Encoding: " + base + b"\0" + info_context

    def derive_dh(mode, version, private_key, dh, keylabel):
        def length_prefix(key):
            return struct.pack("!H", len(key)) + key

        if isinstance(dh, ec.EllipticCurvePublicKey):
            pubkey = dh
            dh = dh.public_bytes(
                Encoding.X962,
                PublicFormat.UncompressedPoint)
        else:
            pubkey = ec.EllipticCurvePublicKey.from_encoded_point(
                ec.SECP256R1(),
                dh
            )

        encoded = private_key.public_key().public_bytes(
            Encoding.X962,
            PublicFormat.UncompressedPoint)
        # The local key is the sender's when encrypting and the
        # receiver's when decrypting.
        if mode == "encrypt":
            sender_pub_key = encoded
            receiver_pub_key = dh
        else:
            sender_pub_key = dh
            receiver_pub_key = encoded

        if version == "aes128gcm":
            context = b"WebPush: info\x00" + receiver_pub_key + sender_pub_key
        else:
            context = (keylabel.encode('utf-8') + b"\0" +
                       length_prefix(receiver_pub_key) +
                       length_prefix(sender_pub_key))

        return private_key.exchange(ec.ECDH(), pubkey), context

    if version not in versions:
        raise ECEException(u"Invalid version")
    if mode not in ['encrypt', 'decrypt']:
        # format() rather than "+" so that a non-string mode (e.g. None)
        # is still reported as an ECEException instead of a TypeError.
        raise ECEException(u"unknown 'mode' specified: {}".format(mode))
    if salt is None or len(salt) != KEY_LENGTH:
        raise ECEException(u"'salt' must be a 16 octet value")
    if dh is not None:
        if private_key is None:
            raise ECEException(u"DH requires a private_key")
        (secret, context) = derive_dh(mode=mode, version=version,
                                      private_key=private_key, dh=dh,
                                      keylabel=keylabel)
    else:
        secret = key

    if secret is None:
        raise ECEException(u"unable to determine the secret")

    if version == "aesgcm":
        keyinfo = build_info(b"aesgcm", context)
        nonceinfo = build_info(b"nonce", context)
    elif version == "aesgcm128":
        keyinfo = b"Content-Encoding: aesgcm128"
        nonceinfo = b"Content-Encoding: nonce"
    elif version == "aes128gcm":
        keyinfo = b"Content-Encoding: aes128gcm\x00"
        nonceinfo = b"Content-Encoding: nonce\x00"
        if dh is None:
            # Only mix the authentication secret when using DH for aes128gcm
            auth_secret = None

    if auth_secret is not None:
        if version == "aes128gcm":
            info = context
        else:
            info = build_info(b'auth', b'')
        hkdf_auth = HKDF(
            algorithm=hashes.SHA256(),
            length=32,
            salt=auth_secret,
            info=info,
            backend=default_backend()
        )
        secret = hkdf_auth.derive(secret)

    hkdf_key = HKDF(
        algorithm=hashes.SHA256(),
        length=KEY_LENGTH,
        salt=salt,
        info=keyinfo,
        backend=default_backend()
    )
    hkdf_nonce = HKDF(
        algorithm=hashes.SHA256(),
        length=NONCE_LENGTH,
        salt=salt,
        info=nonceinfo,
        backend=default_backend()
    )
    return hkdf_key.derive(secret), hkdf_nonce.derive(secret)
|
|
|
|
|
|
def iv(base, counter):
    """Generate an initialization vector.

    The first four octets of ``base`` are kept as they are; the remaining
    eight are read as a big-endian integer and XORed with ``counter``.

    :param base: base nonce (four octets followed by exactly eight more)
    :type base: bytes
    :param counter: record sequence number, must fit in 64 bits
    :type counter: int
    :return: the per-record initialization vector
    :rtype: bytes
    :raises ECEException: if ``counter`` does not fit in 64 bits

    """
    if counter >> 64:
        raise ECEException(u"Counter too big")
    prefix, tail = base[:4], base[4:]
    mask = struct.unpack("!Q", tail)[0]
    return prefix + struct.pack("!Q", mask ^ counter)
|
|
|
|
|
|
def decrypt(content, salt=None, key=None,
            private_key=None, dh=None, auth_secret=None,
            keyid=None, keylabel="P-256",
            rs=4096, version="aes128gcm"):
    """
    Decrypt a data block

    For ``aes128gcm`` the salt, record size and key id are read from the
    header at the start of ``content`` and replace the ``salt``, ``rs``
    and ``keyid`` arguments.

    :param content: Data to be decrypted
    :type content: bytes
    :param salt: Encryption salt
    :type salt: bytes
    :param key: raw key, used as the secret when no ``dh`` is available
    :type key: bytes
    :param private_key: DH private key
    :type private_key: object
    :param keyid: Internal key identifier for private key info
    :type keyid: str
    :param dh: Remote Diffie Hellman sequence (omit for aes128gcm, where
        the header key id is used when a ``private_key`` is given)
    :type dh: bytes
    :param rs: Record size
    :type rs: int
    :param auth_secret: Authorization secret
    :type auth_secret: bytes
    :param version: ECE Method version
    :type version: enumerate('aes128gcm', 'aesgcm', 'aesgcm128')
    :return: Decrypted message content
    :rtype: bytes
    :raises ECEException: on an invalid version, an unparsable header, a
        record size that is too small, a truncated message, bad padding
        or delimiters, or an authentication failure

    """
    def parse_content_header(content):
        """Parse an aes128gcm content body and extract the header values.

        :param content: The encrypted body of the message
        :type content: bytes

        """
        # Layout: salt(16) | rs(4, big-endian) | idlen(1) | keyid(idlen)
        id_len = struct.unpack("!B", content[20:21])[0]
        return {
            "salt": content[:16],
            "rs": struct.unpack("!L", content[16:20])[0],
            "keyid": content[21:21 + id_len],
            "content": content[21 + id_len:],
        }

    def decrypt_record(key, nonce, counter, content):
        # The GCM tag is carried in the last TAG_LENGTH octets of a record.
        decryptor = Cipher(
            algorithms.AES(key),
            modes.GCM(iv(nonce, counter), tag=content[-TAG_LENGTH:]),
            backend=default_backend()
        ).decryptor()
        return decryptor.update(content[:-TAG_LENGTH]) + decryptor.finalize()

    def unpad_legacy(data):
        pad_size = versions[version]['pad']
        # A record too short to hold its pad-length field is malformed;
        # report it as bad padding rather than leaking a struct.error.
        if len(data) < pad_size:
            raise ECEException(u"Bad padding")
        pad = functools.reduce(
            lambda x, y: x << 8 | y, struct.unpack(
                "!" + ("B" * pad_size), data[0:pad_size])
        )
        if pad_size + pad > len(data) or \
                data[pad_size:pad_size + pad] != (b"\x00" * pad):
            raise ECEException(u"Bad padding")
        return data[pad_size + pad:]

    def unpad(data, last):
        # Zero padding follows the delimiter; strip it to find the
        # delimiter, which is the last non-zero octet of the record.
        stripped = data.rstrip(b"\x00")
        if not stripped:
            raise ECEException(u'all zero record plaintext')
        delimiter = stripped[-1:]
        if not last and delimiter != b"\x01":
            raise ECEException(u'record delimiter != 1')
        if last and delimiter != b"\x02":
            raise ECEException(u'last record delimiter != 2')
        return stripped[:-1]

    if version not in versions:
        raise ECEException(u"Invalid version")

    overhead = versions[version]['pad']
    if version == "aes128gcm":
        try:
            content_header = parse_content_header(content)
        except Exception:
            raise ECEException("Could not parse the content header")
        salt = content_header['salt']
        rs = content_header['rs']
        keyid = content_header['keyid']
        if private_key is not None and not dh:
            # With a private key and no explicit dh, the header key id is
            # used as the remote DH value.
            dh = keyid
        else:
            keyid = keyid.decode('utf-8')
        content = content_header['content']
        overhead += 16

    (key_, nonce_) = derive_key("decrypt", version=version,
                                salt=salt, key=key,
                                private_key=private_key, dh=dh,
                                auth_secret=auth_secret,
                                keyid=keyid, keylabel=keylabel)
    if rs <= overhead:
        raise ECEException(u"Record size too small")
    chunk = rs
    if version != "aes128gcm":
        chunk += 16  # account for tags in old versions
        # The legacy encodings always end with a short record, so a whole
        # number of full records means the tail is missing. aes128gcm
        # marks its last record with the delimiter instead.
        if len(content) % chunk == 0:
            raise ECEException(u"Message truncated")

    # Collect the plaintext records and join once at the end.
    records = []
    try:
        for counter, offset in enumerate(range(0, len(content), chunk)):
            data = decrypt_record(key_, nonce_, counter,
                                  content[offset:offset + chunk])
            if version == 'aes128gcm':
                last = (offset + chunk) >= len(content)
                records.append(unpad(data, last))
            else:
                records.append(unpad_legacy(data))
    except InvalidTag as ex:
        raise ECEException("Decryption error: {}".format(repr(ex)))
    return b''.join(records)
|
|
|
|
|
|
def encrypt(content, salt=None, key=None,
            private_key=None, dh=None, auth_secret=None,
            keyid=None, keylabel="P-256",
            rs=4096, version="aes128gcm"):
    """
    Encrypt a data block

    The content is split into records, each encrypted with AES-GCM under
    a per-record IV. For ``aes128gcm`` the result is prefixed with a
    header carrying the salt, record size and key id.

    :param content: block of data to encrypt
    :type content: bytes
    :param salt: Encryption salt (16 random octets are generated when
        omitted)
    :type salt: bytes
    :param key: Encryption key data
    :type key: bytes
    :param private_key: DH private key
    :type private_key: object
    :param keyid: Internal key identifier for private key info
    :type keyid: str
    :param dh: Remote Diffie Hellman sequence
    :type dh: bytes
    :param rs: Record size
    :type rs: int
    :param auth_secret: Authorization secret
    :type auth_secret: bytes
    :param version: ECE Method version
    :type version: enumerate('aes128gcm', 'aesgcm', 'aesgcm128')
    :return: Encrypted message content
    :rtype: bytes
    :raises ECEException: on an invalid version, a record size that is
        too small or too large, or a key id longer than 255 octets

    """
    def encrypt_record(key, nonce, counter, buf, last):
        encryptor = Cipher(
            algorithms.AES(key),
            modes.GCM(iv(nonce, counter)),
            backend=default_backend()
        ).encryptor()

        if version == 'aes128gcm':
            # Trailing delimiter: 2 marks the final record, 1 any other.
            plaintext = buf + (b'\x02' if last else b'\x01')
        else:
            # Legacy encodings lead with an all-zero pad-length field.
            plaintext = (b"\x00" * versions[version]['pad']) + buf
        return (encryptor.update(plaintext) +
                encryptor.finalize() +
                encryptor.tag)

    def compose_aes128gcm(salt, content, rs, keyid):
        """Compose the header and content of an aes128gcm encrypted
        message body

        :param salt: The sender's salt value
        :type salt: bytes
        :param content: The encrypted body of the message
        :type content: bytes
        :param rs: Override for the content length
        :type rs: int
        :param keyid: The keyid to use for this message
        :type keyid: bytes

        """
        if len(keyid) > 255:
            raise ECEException("keyid is too long")
        if rs > MAX_RECORD_SIZE:
            raise ECEException("Too much content")
        # salt | rs (4 octets, big-endian) | keyid length (1 octet) | keyid
        return salt + struct.pack("!LB", rs, len(keyid)) + keyid + content

    if version not in versions:
        raise ECEException(u"Invalid version")

    if salt is None:
        salt = os.urandom(16)

    (key_, nonce_) = derive_key("encrypt", version=version,
                                salt=salt, key=key,
                                private_key=private_key, dh=dh,
                                auth_secret=auth_secret,
                                keyid=keyid, keylabel=keylabel)

    is_aes128gcm = version == 'aes128gcm'
    overhead = versions[version]['pad']
    if is_aes128gcm:
        overhead += 16
        end = len(content)
    else:
        # the extra one on the loop ensures that we produce a padding only
        # record if the data length is an exact multiple of the chunk size
        end = len(content) + 1
    if rs <= overhead:
        raise ECEException(u"Record size too small")
    chunk_size = rs - overhead

    records = []
    for counter, start in enumerate(range(0, end, chunk_size)):
        stop = start + chunk_size
        records.append(encrypt_record(key_, nonce_, counter,
                                      content[start:stop],
                                      stop >= end))
    result = b"".join(records)

    if not is_aes128gcm:
        return result

    if keyid is None and private_key is not None:
        # Without an explicit key id, advertise the local public key.
        kid = private_key.public_key().public_bytes(
            Encoding.X962,
            PublicFormat.UncompressedPoint)
    else:
        kid = (keyid or '').encode('utf-8')
    return compose_aes128gcm(salt, result, rs, keyid=kid)
|