class for an indexable pickle object
This commit is contained in:
parent
5b2f4a76c8
commit
6af35222db
|
@ -0,0 +1,65 @@
|
|||
import numbers
import pickle
import struct
|
||||
|
||||
|
||||
class RamPickleWrite(object):
    """Serialise a sliceable sequence as separately pickled batches plus an index.

    Stream layout produced by :meth:`pickle`, starting at the file object's
    current position:

    1. a ``struct`` header (``HEADER_FORMAT``) holding the file offset of
       the index,
    2. one pickle per batch of ``BATCH_SIZE`` consecutive items,
    3. a pickled list holding the file offset of every batch.

    All stored offsets are the values returned by ``fd.tell()``, i.e. they
    are absolute positions in the file, not relative to the header.
    """

    # Number of items per pickled batch. The value is not stored in the
    # stream, so a reader has to assume the same batch size.
    BATCH_SIZE = 100

    # NOTE: 'L' uses the platform's native size and byte order, so the
    # header is not portable between platforms whose C ``unsigned long``
    # differs. Kept unchanged to stay compatible with existing streams.
    HEADER_FORMAT = 'L'

    def __init__(self, sequence):
        """Remember the sequence to serialise.

        Args:
            sequence: object supporting slicing (``sequence[a:b]``) whose
                slices are picklable and falsy when empty.
        """
        self.sequence = sequence

    def pickle(self, fd):
        """Write the sequence to ``fd`` in the batched, indexed layout.

        Args:
            fd: seekable binary file object open for writing. Writing
                starts at its current position; on return the position is
                at the end of the written stream, so the caller can keep
                appending without clobbering what was just written.
        """
        base_offset = fd.tell()
        # Placeholder header; back-patched once the index offset is known.
        fd.write(struct.pack(self.HEADER_FORMAT, 0))

        batch_offsets = []
        start = 0
        while True:
            batch = self.sequence[start:start + self.BATCH_SIZE]
            if not batch:
                break
            batch_offsets.append(fd.tell())
            pickle.dump(batch, fd)
            start += self.BATCH_SIZE

        index_offset = fd.tell()
        pickle.dump(batch_offsets, fd)
        end_offset = fd.tell()

        # Back-patch the header, then return to the end of the stream;
        # leaving the position right after the header would make any later
        # write overwrite the first batch.
        fd.seek(base_offset)
        fd.write(struct.pack(self.HEADER_FORMAT, index_offset))
        fd.seek(end_offset)
|
||||
|
||||
|
||||
class RamPickleRead(object):
    """Read-only, lazily loaded view over a batched, indexed pickle stream.

    The stream is expected to start with a ``struct`` header
    (``HEADER_FORMAT``) giving the file offset of a pickled list of batch
    offsets; each batch is a pickled chunk of up to ``BATCH_SIZE`` items.
    Batches are unpickled on first access and then kept in ``self.batches``.

    NOTE: ``pickle.load`` can execute arbitrary code while unpickling; only
    open streams that come from a trusted source.
    """

    # Items per batch; must match the value used when the stream was
    # written, because the stream itself does not record it.
    BATCH_SIZE = 100

    # Native-size header; see the struct documentation for platform caveats.
    HEADER_FORMAT = 'L'

    def __init__(self, fd):
        """Read the header and the batch index from ``fd``.

        Args:
            fd: seekable binary file object positioned at the start of the
                stream. It is kept and used for later batch loads.
        """
        self.fd = fd
        header = fd.read(struct.calcsize(self.HEADER_FORMAT))
        index_offset, = struct.unpack(self.HEADER_FORMAT, header)
        fd.seek(index_offset)
        # List of file offsets, one per batch.
        self.index = pickle.load(fd)
        # Cache of already unpickled batches, keyed by page number.
        self.batches = {}

    def load_batch(self, index):
        """Return the batch containing item ``index``.

        Args:
            index: non-negative item position.

        Returns:
            The (cached) batch, or ``None`` when ``index`` lies outside the
            pages listed in the stream's index.
        """
        page = index // self.BATCH_SIZE
        # Negative pages would otherwise wrap around via list indexing and
        # pollute the cache with duplicate entries under negative keys.
        if page < 0 or page >= len(self.index):
            return None
        if page not in self.batches:
            self.fd.seek(self.index[page])
            self.batches[page] = pickle.load(self.fd)
        return self.batches[page]

    def _length(self):
        """Return the total item count.

        Assumes every batch except the last holds exactly ``BATCH_SIZE``
        items, so only the last batch has to be loaded.
        """
        if not self.index:
            return 0
        full_items = (len(self.index) - 1) * self.BATCH_SIZE
        return full_items + len(self.load_batch(full_items))

    def __getitem__(self, index):
        """Return one item, or a list of items for a slice.

        Args:
            index: an integer (negative values count from the end) or a
                slice with the usual list semantics.

        Raises:
            IndexError: integer index out of range.
            TypeError: ``index`` is neither an integer nor a slice.
            ValueError: slice step is zero.
        """
        if isinstance(index, slice):
            # slice.indices() clamps the bounds and resolves negative
            # start/stop/step exactly like list slicing does.
            positions = range(*index.indices(self._length()))
            return [self[position] for position in positions]

        if not isinstance(index, numbers.Integral):
            raise TypeError(index)

        position = index
        if position < 0:
            # The length is only needed (and the last batch only loaded)
            # when counting from the end.
            position += self._length()
            if position < 0:
                raise IndexError(index)

        batch = self.load_batch(position)
        offset = position % self.BATCH_SIZE
        if not batch or len(batch) <= offset:
            raise IndexError(index)
        return batch[offset]
|
|
@ -0,0 +1,35 @@
|
|||
import pytest
|
||||
|
||||
from ram_pickle import RamPickleWrite, RamPickleRead
|
||||
|
||||
|
||||
def test_ram():
    """Round-trip 1000 integers through RamPickleWrite / RamPickleRead.

    Each section builds a fresh reader so it starts with an empty batch
    cache. Slice results are compared against ``list(range(...))`` because
    the reader returns lists, and a list never equals a ``range`` object
    on Python 3.
    """
    import io

    sequence = list(range(1000))
    fd = io.BytesIO()
    write = RamPickleWrite(sequence)
    write.pickle(fd)

    # Every in-range integer index returns the stored item.
    fd.seek(0)
    read = RamPickleRead(fd)
    for i in range(1000):
        assert read[i] == i

    # Indices past the end raise IndexError.
    fd.seek(0)
    read = RamPickleRead(fd)
    for i in range(1000, 2000):
        with pytest.raises(IndexError):
            read[i]

    # Open-ended slice running to the end of the data.
    fd.seek(0)
    read = RamPickleRead(fd)
    assert read[100:] == list(range(100, 1000))

    # Slice from the start.
    fd.seek(0)
    read = RamPickleRead(fd)
    assert read[:100] == list(range(100))

    # Slice with an explicit step.
    fd.seek(0)
    read = RamPickleRead(fd)
    assert read[100:200:2] == list(range(100, 200, 2))
|
Loading…
Reference in New Issue