ZipLoader refactoring

This commit is contained in:
Sergey Lavrinenko 2015-03-28 00:18:10 +03:00
parent 95c5b30481
commit 7618f71175
1 changed files with 29 additions and 17 deletions

View File

@ -165,47 +165,59 @@ class FileSystemLoader(BaseLoader):
class ZipLoader(BaseLoader):
"""
Load files from zip file
"""
common_filename_charsets = ['ascii', 'cp866', 'cp1251', 'utf-8']
def __init__(self, file, encoding='utf-8', base_path=None):
self.zipfile = ZipFile(file, 'r')
if not isinstance(file, ZipFile):
file = ZipFile(file, 'r')
self.zipfile = file
self.encoding = encoding
self.base_path = base_path
self._filenames = None
self._decoded_filenames = None
self._original_filenames = None
def _decode_zip_filename(self, name):
for enc in ('cp866', 'cp1251', 'utf-8'):
def _decode_filename(self, name):
for enc in self.common_filename_charsets:
try:
return to_unicode(name, enc)
except UnicodeDecodeError:
pass
return name
def _unpack_zip(self):
if self._filenames is None:
self._filenames = {}
for name in self.zipfile.namelist():
decoded_name = self._decode_zip_filename(name)
self._filenames[decoded_name] = name
def _unpack(self):
if self._decoded_filenames is None:
self._original_filenames = set(self.zipfile.namelist())
self._decoded_filenames = dict([(self._decode_filename(name), name) for name in self._original_filenames])
def get_file(self, name):
if self.base_path:
name = path.join(self.base_path, name)
self._unpack_zip()
self._unpack()
if isinstance(name, str):
name = to_unicode(name, 'utf-8')
original_name = self._filenames.get(name)
if name not in self._original_filenames:
name = self._decoded_filenames.get(name)
if original_name is None:
if name is None:
raise FileNotFound(name)
return self.zipfile.read(original_name), name
return self.zipfile.read(name), name
def list_files(self):
self._unpack_zip()
return sorted(self._filenames)
self._unpack()
return sorted(self._decoded_filenames)
class MsgLoader(BaseLoader):
@ -259,7 +271,7 @@ class MsgLoader(BaseLoader):
def add_text_part(self, part):
self._text_parts.append({'data': self.extract_part_text(part),
'content_type': part.get_content_type()})
'content_type': part.get_content_type()})
def add_attachment_part(self, part):
counter = 1