stop hardcoding signer digest algorithm to sha1, extract if from the signer info

This commit is contained in:
Benjamin Dauvergne 2013-10-08 23:13:32 +02:00
parent 7e1785a69e
commit 6b7a585552
3 changed files with 42 additions and 20 deletions

View File

@ -20,6 +20,16 @@ id_attribute_messageDigest = univ.ObjectIdentifier((1,2,840,113549,1,9,4,))
def get_hash_oid(hashname):
return rfc3161.__dict__['id_'+hashname]
def get_hash_from_oid(oid):
h = rfc3161.oid_to_hash.get(oid)
if h is None:
raise ValueError('unsupported hash algorithm', oid)
return h
def get_hash_class_from_oid(oid):
h = get_hash_from_oid(oid)
return getattr(hashlib, h)
class TimestampingError(RuntimeError):
pass
@ -42,15 +52,14 @@ def get_timestamp(tst):
except PyAsn1Error, e:
raise ValueError('not a valid TimeStampToken', e)
def check_timestamp(tst, certificate, data=None, sha1=None, hashname=None):
hashobj = hashlib.new(hashname or 'sha1')
if not sha1:
def check_timestamp(tst, certificate, data=None, digest=None, hashname=None):
hashname = hashname or 'sha1'
hashobj = hashlib.new(hashname)
if digest is None:
if not data:
raise ValueError("check_timestamp requires data or sha1 argument")
raise ValueError("check_timestamp requires data or digest argument")
hashobj.update(data)
digest = hashobj.digest()
else:
digest = sha1
if not isinstance(tst, rfc3161.TimeStampToken):
tst, substrate = decoder.decode(tst, asn1Spec=rfc3161.TimeStampToken())
@ -66,7 +75,7 @@ def check_timestamp(tst, certificate, data=None, sha1=None, hashname=None):
return False, "missing certificate"
# check message imprint with respect to locally computed digest
message_imprint = tst.tst_info.message_imprint
if message_imprint.hash_algorithm[0] != get_hash_oid(hashobj.name) or \
if message_imprint.hash_algorithm[0] != get_hash_oid(hashname) or \
str(message_imprint.hashed_message) != digest:
return False, 'Message imprint mismatch'
#
@ -87,7 +96,10 @@ def check_timestamp(tst, certificate, data=None, sha1=None, hashname=None):
# signed data
if len(signer_info['authenticatedAttributes']):
authenticated_attributes = signer_info['authenticatedAttributes']
content_digest = hashlib.sha1(content).digest()
signer_digest_algorithm = signer_info['digestAlgorithm']['algorithm']
signer_hash_class = get_hash_class_from_oid(signer_digest_algorithm)
signer_hash_name = get_hash_from_oid(signer_digest_algorithm)
content_digest = signer_hash_class(content).digest()
for authenticated_attribute in authenticated_attributes:
if authenticated_attribute[0] == id_attribute_messageDigest:
try:
@ -110,6 +122,7 @@ def check_timestamp(tst, certificate, data=None, sha1=None, hashname=None):
# check signature
signature = signer_info['encryptedDigest']
pub_key = certificate.get_pubkey()
pub_key.reset_context(signer_hash_name)
pub_key.verify_init()
pub_key.verify_update(signed_data)
if pub_key.verify_final(str(signature)) != 1:
@ -133,25 +146,25 @@ class RemoteTimestamper(object):
Check validity of a TimeStampResponse
'''
tst = response.time_stamp_token
return check_timestamp(tst, sha1=digest, certificate=self.certificate, hashname=self.hashobj.name)
return check_timestamp(tst, digest=digest, certificate=self.certificate, hashname=self.hashobj.name)
def __call__(self, data=None, sha1=None):
def __call__(self, data=None, digest=None):
algorithm_identifier = rfc2459.AlgorithmIdentifier()
algorithm_identifier.setComponentByPosition(0, get_hash_oid(self.hashobj.name))
message_imprint = rfc3161.MessageImprint()
message_imprint.setComponentByPosition(0, algorithm_identifier)
if data:
self.hashobj.update(data)
sha1 = self.hashobj.digest()
elif sha1:
assert len(sha1) == self.hashobj.digest_size
digest = self.hashobj.digest()
elif digest:
assert len(digest) == self.hashobj.digest_size, 'digest length is wrong'
else:
raise ValueError('You must pass some data to digest, or the sha1 digest')
message_imprint.setComponentByPosition(1, sha1)
raise ValueError('You must pass some data to digest, or the digest')
message_imprint.setComponentByPosition(1, digest)
request = rfc3161.TimeStampReq()
request.setComponentByPosition(0, 'v1')
request.setComponentByPosition(1, message_imprint)
request.setComponentByPosition(4)
request.setComponentByPosition(4, True)
binary_request = encoder.encode(request)
http_request = urllib2.Request(self.url, binary_request,
{ 'Content-Type': 'application/timestamp-query' })
@ -162,11 +175,11 @@ class RemoteTimestamper(object):
response = urllib2.urlopen(http_request).read()
except (IOError, socket.error), e:
raise TimestampingError('Unable to send the request to %s' % self.url, e)
# open('response.tsr', 'w').write(response)
open('response.tsr', 'w').write(response)
tst_response, substrate = decoder.decode(response, asn1Spec=rfc3161.TimeStampResp())
if substrate:
return False, 'Extra data returned'
result, message = self.check_response(tst_response, sha1)
result, message = self.check_response(tst_response, digest)
if result:
return encoder.encode(tst_response.time_stamp_token), ''
else:

View File

@ -1,6 +1,7 @@
from pyasn1.type import univ
__all__ = ('id_kp_timeStamping','id_sha1', 'id_sha256', 'id_sha384', 'id_sha512', 'id_ct_TSTInfo',)
__all__ = ('id_kp_timeStamping','id_sha1', 'id_sha256', 'id_sha384',
'id_sha512', 'id_ct_TSTInfo', 'oid_to_hash',)
id_kp_timeStamping = univ.ObjectIdentifier((1,3,6,1,5,5,7,3,8))
id_sha1 = univ.ObjectIdentifier((1,3,14,3,2,26))
@ -8,3 +9,10 @@ id_sha256 = univ.ObjectIdentifier((2,16,840,1,101,3,4,2,1,))
id_sha384 = univ.ObjectIdentifier((2,16,840,1,101,3,4,2,2,))
id_sha512 = univ.ObjectIdentifier((2,16,840,1,101,3,4,2,3,))
id_ct_TSTInfo = univ.ObjectIdentifier((1,2,840,113549,1,9,16,1,4))
oid_to_hash = {
id_sha1: 'sha1',
id_sha256: 'sha256',
id_sha384: 'sha384',
id_sha512: 'sha512',
}

View File

@ -12,9 +12,10 @@ class Rfc3161(unittest.TestCase):
'../data/certum_certificate.crt')
def test_timestamp(self):
data = 'xx'
certificate = file(self.CERTIFICATE).read()
value, substrate = rfc3161.RemoteTimestamper(
self.PUBLIC_TSA_SERVER, certificate=certificate)(data='xx')
self.PUBLIC_TSA_SERVER, certificate=certificate)(data=data)
self.assertIsInstance(rfc3161.get_timestamp(value), datetime.datetime)
self.assertNotEqual(value, None)
self.assertEqual(substrate, '')