Skip to content

Implement ZipFolder #3510

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 8 commits into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
95 changes: 87 additions & 8 deletions torchvision/datasets/folder.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,11 @@

from PIL import Image

import io
import os
import os.path
from typing import Any, Callable, cast, Dict, List, Optional, Tuple
import zipfile
from typing import Any, Callable, cast, Dict, List, Optional, Tuple, Union


def has_file_allowed_extension(filename: str, extensions: Tuple[str, ...]) -> bool:
Expand Down Expand Up @@ -139,8 +141,8 @@ def __init__(
self.samples = samples
self.targets = [s[1] for s in samples]

@staticmethod
def make_dataset(
self,
directory: str,
class_to_idx: Dict[str, int],
extensions: Optional[Tuple[str, ...]] = None,
Expand Down Expand Up @@ -190,15 +192,16 @@ def __len__(self) -> int:
IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif', '.tiff', '.webp')


def pil_loader(path: Union[str, io.BytesIO]) -> Image.Image:
    """Load an image with PIL and return it converted to RGB.

    Args:
        path: Either a filesystem path, which is opened in binary mode, or an
            already-open in-memory ``io.BytesIO`` buffer holding the encoded image.

    Returns:
        The decoded image converted to mode ``'RGB'``.

    Note:
        The underlying file object is always closed before returning, including
        a ``BytesIO`` buffer passed in by the caller.
    """
    # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835)
    f = open(path, 'rb') if isinstance(path, str) else path
    # ``with`` guarantees the handle is released even when decoding fails;
    # ``convert`` forces the pixel data to be read before the file is closed.
    with f:
        return Image.open(f).convert('RGB')


# TODO: specify the return type
def accimage_loader(path: str) -> Any:
def accimage_loader(path: Union[str, io.BytesIO]) -> Any:
import accimage
try:
return accimage.Image(path)
Expand All @@ -207,7 +210,7 @@ def accimage_loader(path: str) -> Any:
return pil_loader(path)


def default_loader(path: str) -> Any:
def default_loader(path: Union[str, io.BytesIO]) -> Any:
from torchvision import get_image_backend
if get_image_backend() == 'accimage':
return accimage_loader(path)
Expand Down Expand Up @@ -255,3 +258,79 @@ def __init__(
target_transform=target_transform,
is_valid_file=is_valid_file)
self.imgs = self.samples


class ZipFolder(DatasetFolder):
    """Image dataset whose samples are read from a ZIP archive instead of a directory tree.

    A member of the archive becomes a sample when it passes the file filter and
    its parent directory is a known class. Classes are the names of directories
    that contain entries and are themselves nested inside at least one other
    directory, i.e. the layout ``<top>/<class_name>/<file>`` (this is the layout
    written by :meth:`initialize_from_folder`).

    Args:
        root (string): Path to the archive. Must end with ``.zip`` (case-insensitive).
        transform (callable, optional): Forwarded unchanged to ``DatasetFolder``.
        target_transform (callable, optional): Forwarded unchanged to ``DatasetFolder``.
        is_valid_file (callable, optional): Predicate on the member name. When omitted,
            members are filtered by ``IMG_EXTENSIONS`` instead.
        memory (bool): If ``True`` (default) the whole archive is read into memory once
            and served from there; otherwise the archive file is kept open on disk.

    Attributes:
        root_zip (zipfile.ZipFile): The open archive that samples are read from.
        imgs (list): Alias of ``samples``.
    """

    def __init__(self, root: str, transform: Optional[Callable] = None, target_transform: Optional[Callable] = None,
                 is_valid_file: Optional[Callable[[str], bool]] = None, memory: bool = True) -> None:
        if not root.lower().endswith('.zip'):
            raise TypeError(f"Need ZIP file for data source: {root}")
        if memory:
            with open(root, 'rb') as z:
                data = z.read()
            self.root_zip = zipfile.ZipFile(io.BytesIO(data), 'r')
        else:
            self.root_zip = zipfile.ZipFile(root, 'r')
        try:
            # The archive must be open before this call: the overridden
            # _find_classes / make_dataset below read self.root_zip.
            super().__init__(root, self.zip_loader, IMG_EXTENSIONS if is_valid_file is None else None,
                             transform=transform, target_transform=target_transform, is_valid_file=is_valid_file)
        except BaseException:
            # Do not leak the archive handle when dataset construction fails.
            self.root_zip.close()
            raise
        self.imgs = self.samples

    @staticmethod
    def initialize_from_folder(root: str, zip_path: Optional[str] = None) -> str:
        """Pack an image folder into an uncompressed ZIP archive usable by ``ZipFolder``.

        Member names are relative to the parent of ``root``, so the archive
        contains ``<basename(root)>/<sub dirs>/<file>``.

        Args:
            root: Directory to archive.
            zip_path: Destination archive. Defaults to ``<root>_store.zip`` next to ``root``.

        Returns:
            The path of the archive that was written.
        """
        root = os.path.normpath(root)
        folder_dir, folder_base = os.path.split(root)
        if zip_path is None:
            zip_path = os.path.join(folder_dir, f'{folder_base}_store.zip')
        archive_abs = os.path.abspath(zip_path)
        # relpath needs a non-empty start; an empty folder_dir means "current directory".
        base_dir = folder_dir or os.curdir
        with zipfile.ZipFile(zip_path, mode='w', compression=zipfile.ZIP_STORED) as zf:
            for walk_root, walk_dirs, walk_files in os.walk(root):
                # Sorting in place fixes the traversal order so the archive layout is reproducible.
                walk_dirs.sort()
                for _file in sorted(walk_files):
                    org_path = os.path.join(walk_root, _file)
                    if os.path.abspath(org_path) == archive_abs:
                        # The destination lives inside ``root``; never add the archive to itself.
                        continue
                    zf.write(org_path, os.path.relpath(org_path, base_dir))
        return zip_path

    def make_dataset(
        self,
        directory: str,
        class_to_idx: Dict[str, int],
        extensions: Optional[Tuple[str, ...]] = None,
        is_valid_file: Optional[Callable[[str], bool]] = None,
    ) -> List[Tuple[str, int]]:
        """Build the ``(member_name, class_index)`` list from the archive contents.

        Args:
            directory: Unused; kept for signature compatibility with ``DatasetFolder``.
            class_to_idx: Mapping from class name to class index.
            extensions: Allowed file extensions. Mutually exclusive with ``is_valid_file``.
            is_valid_file: Predicate on the member name. Mutually exclusive with ``extensions``.

        Returns:
            One tuple per accepted member, in archive order.

        Raises:
            ValueError: If both or neither of ``extensions`` and ``is_valid_file`` are given.
        """
        both_none = extensions is None and is_valid_file is None
        both_something = extensions is not None and is_valid_file is not None
        if both_none or both_something:
            raise ValueError("Both extensions and is_valid_file cannot be None or not None at the same time")
        if extensions is not None:
            def is_valid_file(x: str) -> bool:
                return has_file_allowed_extension(x, cast(Tuple[str, ...], extensions))
        is_valid_file = cast(Callable[[str], bool], is_valid_file)
        instances = []
        for filepath in self.root_zip.namelist():
            if not is_valid_file(filepath):
                continue
            target_class = os.path.basename(os.path.dirname(filepath))
            if target_class not in class_to_idx:
                # Valid file outside any class directory (e.g. directly under the
                # top-level folder): skip it instead of failing with KeyError.
                continue
            instances.append((filepath, class_to_idx[target_class]))
        return instances

    def zip_loader(self, path: str) -> Any:
        """Read member ``path`` from the archive and decode it with ``default_loader``."""
        return default_loader(io.BytesIO(self.root_zip.read(path)))

    def _find_classes(self, *args, **kwargs) -> Tuple[List[str], Dict[str, int]]:
        """
        Finds the class folders in the archive.

        All arguments are accepted for compatibility with the base class and ignored;
        classes are derived from the member names of ``self.root_zip``.

        Returns:
            tuple: (classes, class_to_idx) where classes is the sorted list of class
            directory names and class_to_idx maps each name to its position in that list.
        """
        classes = set()
        for filepath in self.root_zip.namelist():
            root, target_class = os.path.split(os.path.dirname(filepath))
            # Only directories nested inside another directory count as classes;
            # the top-level folder itself is not a class.
            if root:
                classes.add(target_class)
        classes = sorted(classes)
        class_to_idx = {name: idx for idx, name in enumerate(classes)}
        return classes, class_to_idx