class for an indexable pickle object

This commit is contained in:
Benjamin Dauvergne 2019-06-28 00:35:35 +02:00
parent 5b2f4a76c8
commit 6af35222db
2 changed files with 100 additions and 0 deletions

View File

@ -0,0 +1,65 @@
import struct
import pickle
class RamPickleWrite(object):
def __init__(self, sequence):
self.sequence = sequence
def pickle(self, fd):
base_offset = fd.tell()
index = []
fd.write(struct.pack('L', 0))
i = 0
while True:
batch = self.sequence[i:i + 100]
if not batch:
break
index.append(fd.tell())
pickle.dump(batch, fd)
i += 100
index_offset = fd.tell()
pickle.dump(index, fd)
fd.seek(base_offset)
fd.write(struct.pack('L', index_offset))
class RamPickleRead(object):
def __init__(self, fd):
self.fd = fd
buf = fd.read(struct.calcsize('L'))
index_offset, = struct.unpack('L', buf)
fd.seek(index_offset)
self.index = pickle.load(fd)
self.batches = {}
def load_batch(self, index):
page = index // 100
if len(self.index) <= page:
return
if page not in self.batches:
self.fd.seek(self.index[page])
self.batches[page] = pickle.load(self.fd)
return self.batches[page]
def __getitem__(self, index):
if isinstance(index, (long, int)):
batch = self.load_batch(index)
if not batch:
raise IndexError(index)
offset = index % 100
if len(batch) <= offset:
raise IndexError(index)
return batch[offset]
elif isinstance(index, slice):
l = []
i = index.start or 0
while index.stop is None or i < index.stop:
try:
l.append(self[i])
except IndexError:
break
i += index.step or 1
return l
else:
raise TypeError(index)

View File

@ -0,0 +1,35 @@
import pytest
from ram_pickle import RamPickleWrite, RamPickleRead
def test_ram():
import io
sequence = list(range(1000))
fd = io.BytesIO()
write = RamPickleWrite(sequence)
write.pickle(fd)
fd.seek(0)
read = RamPickleRead(fd)
for i in range(1000):
assert read[i] == i
fd.seek(0)
read = RamPickleRead(fd)
for i in range(1000, 2000):
with pytest.raises(IndexError):
read[i]
fd.seek(0)
read = RamPickleRead(fd)
assert read[100:] == range(100, 1000)
fd.seek(0)
read = RamPickleRead(fd)
assert read[:100] == range(100)
fd.seek(0)
read = RamPickleRead(fd)
assert read[100:200:2] == range(100, 200, 2)