From 6913bacef4416b1b8a5638556464eda5914d3742 Mon Sep 17 00:00:00 2001 From: Benjamin Dauvergne Date: Thu, 24 Sep 2020 18:18:26 +0200 Subject: [PATCH] misc: rewrite x509utils using modern API (#46984) --- src/authentic2/saml/x509utils.py | 129 ++++++++++++++++--------------- tests/test_saml_x509utils.py | 9 ++- 2 files changed, 72 insertions(+), 66 deletions(-) diff --git a/src/authentic2/saml/x509utils.py b/src/authentic2/saml/x509utils.py index f3a6bd86f..1f6e843bf 100644 --- a/src/authentic2/saml/x509utils.py +++ b/src/authentic2/saml/x509utils.py @@ -19,7 +19,6 @@ import binascii import tempfile import os import subprocess -import stat import six _openssl = 'openssl' @@ -54,108 +53,102 @@ def _call_openssl(args): return 1, None -def _protect_file(fd, filepath): - '''Make a file targeted by a file descriptor readable only by the current user - - It's needed to be sure nobody can read the private key file we manage. - ''' - os.fchmod(fd, stat.S_IRUSR | stat.S_IWUSR) - - def check_key_pair_consistency(publickey=None, privatekey=None): '''Check if two PEM key pair whether they are publickey or certificate, are well formed and related. ''' - if publickey and privatekey: - try: - privatekey_file_fd, privatekey_fn = tempfile.mkstemp() - publickey_file_fd, publickey_fn = tempfile.mkstemp() - _protect_file(privatekey_file_fd, privatekey_fn) - _protect_file(publickey_file_fd, publickey_fn) - os.fdopen(privatekey_file_fd, 'w').write(privatekey) - os.fdopen(publickey_file_fd, 'w').write(publickey) - if 'BEGIN CERTIFICATE' in publickey: - rc1, modulus1 = _call_openssl(['x509', '-in', publickey_fn, '-noout', '-modulus']) - else: - rc1, modulus1 = _call_openssl(['rsa', '-pubin', '-in', publickey_fn, '-noout', '-modulus']) - if rc1 != 0: - rc1, modulus1 = _call_openssl(['dsa', '-pubin', '-in', publickey_fn, '-noout', '-modulus']) + if not publickey or not privatekey: + return None + privatekey_file = tempfile.NamedTemporaryFile(mode='w') + publickey_file = tempfile.NamedTemporaryFile(mode='w') + with privatekey_file, publickey_file: + + privatekey_file.write(privatekey) + privatekey_file.flush() + publickey_file.write(publickey) + publickey_file.flush() + + if 'BEGIN CERTIFICATE' in publickey: + rc1, modulus1 = _call_openssl(['x509', '-in', publickey_file.name, '-noout', '-modulus']) + else: + rc1, modulus1 = _call_openssl(['rsa', '-pubin', '-in', publickey_file.name, '-noout', '-modulus']) if rc1 != 0: - return False + rc1, modulus1 = _call_openssl(['dsa', '-pubin', '-in', publickey_file.name, '-noout', '-modulus']) - rc2, modulus2 = _call_openssl(['rsa', '-in', privatekey_fn, '-noout', '-modulus']) - if rc2 != 0: - rc2, modulus2 = _call_openssl(['dsa', '-in', privatekey_fn, '-noout', '-modulus']) + if rc1 != 0: + return False - if rc1 == 0 and rc2 == 0 and modulus1 == modulus2: - return True - else: - return False - finally: - os.unlink(privatekey_fn) - os.unlink(publickey_fn) - return None + rc2, modulus2 = _call_openssl(['rsa', '-in', privatekey_file.name, '-noout', '-modulus']) + if rc2 != 0: + rc2, modulus2 = _call_openssl(['dsa', '-in', privatekey_file.name, '-noout', '-modulus']) + + if rc1 == 0 and rc2 == 0 and modulus1 == modulus2: + return True + else: + return False def generate_rsa_keypair(numbits=1024): '''Generate simple RSA public and private key files ''' - try: - privatekey_file_fd, privatekey_fn = tempfile.mkstemp() - publickey_file_fd, publickey_fn = tempfile.mkstemp() - _protect_file(privatekey_file_fd, privatekey_fn) - _protect_file(publickey_file_fd, publickey_fn) - rc1, _ = _call_openssl(['genrsa', '-out', privatekey_fn, '-passout', 'pass:', str(numbits)]) - rc2, _ = _call_openssl(['rsa', '-in', privatekey_fn, '-pubout', '-out', publickey_fn]) - if rc1 != 0 or rc2 != 0: + privatekey_file = tempfile.NamedTemporaryFile(mode='r') + publickey_file = tempfile.NamedTemporaryFile(mode='r') + + with privatekey_file, publickey_file: + rc1, _ = _call_openssl(['genrsa', '-out', privatekey_file.name, '-passout', 'pass:', str(numbits)]) + if rc1 != 0: raise Exception('Failed to generate a key') - return (os.fdopen(publickey_file_fd).read(), os.fdopen(privatekey_file_fd).read()) - finally: - os.unlink(privatekey_fn) - os.unlink(publickey_fn) + rc2, _ = _call_openssl(['rsa', '-in', privatekey_file.name, '-pubout', '-out', publickey_file.name]) + if rc2 != 0: + raise Exception('Failed to generate a key') + return (publickey_file.read(), privatekey_file.read()) def get_rsa_public_key_modulus(publickey): - try: - publickey_file_fd, publickey_fn = tempfile.mkstemp() - os.fdopen(publickey_file_fd, 'w').write(publickey) + publickey_file = tempfile.NamedTemporaryFile(mode='w') + + with publickey_file: + publickey_file.write(publickey) + publickey_file.flush() + if 'BEGIN PUBLIC' in publickey: - rc, modulus = _call_openssl(['rsa', '-pubin', '-in', publickey_fn, '-noout', '-modulus']) + rc, modulus = _call_openssl(['rsa', '-pubin', '-in', publickey_file.name, '-noout', '-modulus']) elif 'BEGIN RSA PRIVATE KEY' in publickey: - rc, modulus = _call_openssl(['rsa', '-in', publickey_fn, '-noout', '-modulus']) + rc, modulus = _call_openssl(['rsa', '-in', publickey_file.name, '-noout', '-modulus']) elif 'BEGIN CERTIFICATE' in publickey: - rc, modulus = _call_openssl(['x509', '-in', publickey_fn, '-noout', '-modulus']) + rc, modulus = _call_openssl(['x509', '-in', publickey_file.name, '-noout', '-modulus']) else: return None + i = modulus.find('=') + if rc == 0 and i: return int(modulus[i + 1:].strip(), 16) - finally: - os.unlink(publickey_fn) return None def get_rsa_public_key_exponent(publickey): - try: - publickey_file_fd, publickey_fn = tempfile.mkstemp() - os.fdopen(publickey_file_fd, 'w').write(publickey) + publickey_file = tempfile.NamedTemporaryFile(mode='w') + + with publickey_file: + publickey_file.write(publickey) + publickey_file.flush() + _exponent = 'Exponent: ' if 'BEGIN PUBLIC' in publickey: - rc, modulus = _call_openssl(['rsa', '-pubin', '-in', publickey_fn, '-noout', '-text']) + rc, modulus = _call_openssl(['rsa', '-pubin', '-in', publickey_file.name, '-noout', '-text']) elif 'BEGIN RSA PRIVATE' in publickey: - rc, modulus = _call_openssl(['rsa', '-in', publickey_fn, '-noout', '-text']) + rc, modulus = _call_openssl(['rsa', '-in', publickey_file.name, '-noout', '-text']) _exponent = 'publicExponent: ' elif 'BEGIN CERTIFICATE' in publickey: - rc, modulus = _call_openssl(['x509', '-in', publickey_fn, '-noout', '-text']) + rc, modulus = _call_openssl(['x509', '-in', publickey_file.name, '-noout', '-text']) else: return None i = modulus.find(_exponent) j = modulus.find('(', i) if rc == 0 and i and j: return int(modulus[i + len(_exponent):j].strip()) - finally: - os.unlink(publickey_fn) return None @@ -178,8 +171,16 @@ def get_xmldsig_rsa_key_value(publickey): mod = get_rsa_public_key_modulus(publickey) exp = get_rsa_public_key_exponent(publickey) + mod_byte_length = (mod.bit_length() + 7) // 8 + exp_byte_length = (exp.bit_length() + 7) // 8 + mod_bytes = mod.to_bytes(mod_byte_length, byteorder='big') + exp_bytes = exp.to_bytes(exp_byte_length, byteorder='big') + mod_cryptobinary = base64.b64encode(mod_bytes).decode('ascii') + exp_cryptobinary = base64.b64encode(exp_bytes).decode('ascii') return ( '\n\t' '%s\n\t' '%s\n' % ( - base64.b64encode(int_to_bin(mod)), base64.b64encode(int_to_bin(exp)))) + mod_cryptobinary, + exp_cryptobinary) + ) diff --git a/tests/test_saml_x509utils.py b/tests/test_saml_x509utils.py index 64b33234c..b748525c5 100644 --- a/tests/test_saml_x509utils.py +++ b/tests/test_saml_x509utils.py @@ -68,6 +68,11 @@ pkkt86tIOLEtaNO97CcF/t+Un5QAh9MqLmQv5pwUDo4Lqo7qo1bAfyHjOlr5kdaP 8eM47A92x9uplD/sN550pTKM7XLhHBvEfLujUoGHpWQxGA== -----END RSA PRIVATE KEY-----''' assert check_key_pair_consistency(cert, key) - assert get_xmldsig_rsa_key_value(cert) + assert get_xmldsig_rsa_key_value(cert) == '''\ + + rU2w03vy6w/oLytEoH/657wIOM3HP7yXBHbIhrCd7IU7Huzj3+XyGg+5a8vo5rRV+tZ/EZ9XdYsPsIbmG6xukdzoQ8OdL7n29ka0UhwTgOwnH4ikA+gk9qd9ZrL/goh2xpqB0Rcrdgp0RsthQl9jos3+asX4x2iRF7tLZP0nTdk= + AQAB +''' + assert get_rsa_public_key_modulus(cert) is not None + assert get_rsa_public_key_exponent(cert) is not None assert len(decapsulate_pem_file(key).splitlines()) == len(key.splitlines()) - 2 -