diff --git a/emails/loader/local_store.py b/emails/loader/local_store.py index 22a385f..e88d297 100644 --- a/emails/loader/local_store.py +++ b/emails/loader/local_store.py @@ -165,47 +165,59 @@ class FileSystemLoader(BaseLoader): class ZipLoader(BaseLoader): + + """ + Load files from zip file + """ + + common_filename_charsets = ['ascii', 'cp866', 'cp1251', 'utf-8'] + + def __init__(self, file, encoding='utf-8', base_path=None): - self.zipfile = ZipFile(file, 'r') + + if not isinstance(file, ZipFile): + file = ZipFile(file, 'r') + + self.zipfile = file self.encoding = encoding self.base_path = base_path - self._filenames = None + self._decoded_filenames = None + self._original_filenames = None - def _decode_zip_filename(self, name): - for enc in ('cp866', 'cp1251', 'utf-8'): + def _decode_filename(self, name): + for enc in self.common_filename_charsets: try: return to_unicode(name, enc) except UnicodeDecodeError: pass return name - def _unpack_zip(self): - if self._filenames is None: - self._filenames = {} - for name in self.zipfile.namelist(): - decoded_name = self._decode_zip_filename(name) - self._filenames[decoded_name] = name + def _unpack(self): + if self._decoded_filenames is None: + self._original_filenames = set(self.zipfile.namelist()) + self._decoded_filenames = dict([(self._decode_filename(name), name) for name in self._original_filenames]) def get_file(self, name): if self.base_path: name = path.join(self.base_path, name) - self._unpack_zip() + self._unpack() if isinstance(name, str): name = to_unicode(name, 'utf-8') - original_name = self._filenames.get(name) + if name not in self._original_filenames: + name = self._decoded_filenames.get(name) - if original_name is None: + if name is None: raise FileNotFound(name) - return self.zipfile.read(original_name), name + return self.zipfile.read(name), name def list_files(self): - self._unpack_zip() - return sorted(self._filenames) + self._unpack() + return sorted(self._decoded_filenames) class MsgLoader(BaseLoader): @@ -259,7 +271,7 @@ class MsgLoader(BaseLoader): def add_text_part(self, part): self._text_parts.append({'data': self.extract_part_text(part), - 'content_type': part.get_content_type()}) + 'content_type': part.get_content_type()}) def add_attachment_part(self, part): counter = 1