
[Feature Request] Datasets Should Use New torchvision.io Image Loader APIs and Return TVTensor Images by Default #8762

Open · fang-d opened this issue Nov 28, 2024 · 1 comment

@fang-d (Contributor) commented Nov 28, 2024

🚀 The feature

  1. Add a "torchvision" image loader backend based on the new torchvision.io APIs (see: Release Notes v0.20) and enable it by default; a sketch of such a loader follows this list.
  2. VisionDataset subclasses should return TVTensor images by default instead of PIL.Image.
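
A minimal sketch of what such a backend loader could look like (the function name `torchvision_loader` is hypothetical; `torchvision.io.decode_image` and `tv_tensors.Image` are existing APIs as of v0.20):

```python
from torchvision import tv_tensors
from torchvision.io import decode_image


def torchvision_loader(path: str) -> tv_tensors.Image:
    # Since v0.20, decode_image accepts a file path directly and returns
    # a uint8 CHW tensor; wrapping it in tv_tensors.Image makes the
    # transforms.v2 pipeline treat it as an image TVTensor.
    return tv_tensors.Image(decode_image(path))
```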

Motivation, pitch

  1. TorchVision v0.20 introduces new torchvision.io APIs that improve its encoding/decoding capabilities.
  2. VisionDataset subclasses currently return PIL.Image by default, yet the first step of a transform pipeline is usually transforms.ToImage(), so the PIL round trip is redundant work (illustrated after this list).
  3. PIL is slow (see: Pillow-SIMD), especially compared with the new torchvision.io APIs.
  4. The current TorchVision image loader backends are based on PIL or accimage and do not include the new torchvision.io APIs.
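
As a minimal illustration of point 2 (CIFAR10 is used here only as an example of a dataset that returns PIL images by default):

```python
import torch
from torchvision import datasets
from torchvision.transforms import v2

# Today the dataset decodes to PIL.Image, and the very first transform
# converts that PIL image to a TVTensor, making the PIL step a detour.
transform = v2.Compose([
    v2.ToImage(),                           # PIL.Image -> tv_tensors.Image
    v2.ToDtype(torch.float32, scale=True),  # uint8 -> float32 in [0, 1]
])
dataset = datasets.CIFAR10(root="data", download=True, transform=transform)
```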

Alternatives

  1. The return type of datasets could remain PIL.Image when using the PIL or accimage backends and be TVTensor when using the new APIs, at the cost of consistency across backends.

Additional context

I would like to make a pull request if the community likes this feature.

@NicolasHug (Member) commented

Hi @fang-d, thank you for the feature request. This is a great idea, and I think the torchvision decoders are in a stable enough state to enable this now.

We already support the loader parameter for some datasets (mostly ImageFolder, I think: https://pytorch.org/vision/stable/generated/torchvision.datasets.ImageFolder.html#torchvision.datasets.ImageFolder), but we should enable the same for all existing datasets.

I think the way to go would probably be to add that loader parameter to all datasets that currently call Image.open(...).

```
~/dev/vision (main*) » git grep Image.open torchvision/datasets
torchvision/datasets/_optical_flow.py:        img = Image.open(file_name)
torchvision/datasets/_stereo_matching.py:        img = Image.open(file_path)
torchvision/datasets/_stereo_matching.py:        disparity_map = np.asarray(Image.open(file_path)) / 256.0
torchvision/datasets/_stereo_matching.py:        disparity_map = np.asarray(Image.open(file_path)) / 256.0
torchvision/datasets/_stereo_matching.py:        disparity_map = np.asarray(Image.open(file_path), dtype=np.float32)
torchvision/datasets/_stereo_matching.py:        depth = np.asarray(Image.open(file_path))
torchvision/datasets/_stereo_matching.py:        disparity_map = np.asarray(Image.open(file_path), dtype=np.float32)
torchvision/datasets/_stereo_matching.py:        valid_mask = np.asarray(Image.open(occlued_mask_path)) == 0
torchvision/datasets/_stereo_matching.py:        off_mask = np.asarray(Image.open(out_of_frame_mask_path)) == 0
torchvision/datasets/_stereo_matching.py:        disparity_map = np.asarray(Image.open(file_path), dtype=np.float32)
torchvision/datasets/_stereo_matching.py:        valid_mask = Image.open(mask_path)
torchvision/datasets/caltech.py:        img = Image.open(
torchvision/datasets/caltech.py:        img = Image.open(
torchvision/datasets/celeba.py:        X = PIL.Image.open(os.path.join(self.root, self.base_folder, "img_align_celeba", self.filename[index]))
torchvision/datasets/cityscapes.py:        image = Image.open(self.images[index]).convert("RGB")
torchvision/datasets/cityscapes.py:                target = Image.open(self.targets[index][i])  # type: ignore[assignment]
torchvision/datasets/clevr.py:        image = Image.open(image_file).convert("RGB")
torchvision/datasets/coco.py:        return Image.open(os.path.join(self.root, path)).convert("RGB")
torchvision/datasets/dtd.py:        image = PIL.Image.open(image_file).convert("RGB")
torchvision/datasets/fgvc_aircraft.py:        image = PIL.Image.open(image_file).convert("RGB")
torchvision/datasets/flickr.py:        img = Image.open(img_id).convert("RGB")
torchvision/datasets/flickr.py:        img = Image.open(filename).convert("RGB")
torchvision/datasets/flowers102.py:        image = PIL.Image.open(image_file).convert("RGB")
torchvision/datasets/folder.py:        img = Image.open(f)
torchvision/datasets/food101.py:        image = PIL.Image.open(image_file).convert("RGB")
torchvision/datasets/gtsrb.py:        sample = PIL.Image.open(path).convert("RGB")
torchvision/datasets/imagenette.py:        image = Image.open(path).convert("RGB")
torchvision/datasets/inaturalist.py:        img = Image.open(os.path.join(self.root, self.all_categories[cat_id], fname))
torchvision/datasets/kitti.py:        image = Image.open(self.images[index])
torchvision/datasets/lfw.py:            img = Image.open(f)
torchvision/datasets/lsun.py:        img = Image.open(buf).convert("RGB")
torchvision/datasets/omniglot.py:        image = Image.open(image_path, mode="r").convert("L")
torchvision/datasets/oxford_iiit_pet.py:        image = Image.open(self._images[idx]).convert("RGB")
torchvision/datasets/oxford_iiit_pet.py:                target.append(Image.open(self._segs[idx]))
torchvision/datasets/phototour.py:        img = Image.open(fpath)
torchvision/datasets/rendered_sst2.py:        image = PIL.Image.open(image_file).convert("RGB")
torchvision/datasets/sbd.py:        img = Image.open(self.images[index]).convert("RGB")
torchvision/datasets/sbu.py:        img = Image.open(filename).convert("RGB")
torchvision/datasets/stanford_cars.py:        pil_image = Image.open(image_path).convert("RGB")
torchvision/datasets/sun397.py:        image = PIL.Image.open(image_file).convert("RGB")
torchvision/datasets/voc.py:        img = Image.open(self.images[index]).convert("RGB")
torchvision/datasets/voc.py:        target = Image.open(self.masks[index])
torchvision/datasets/voc.py:        img = Image.open(self.images[index]).convert("RGB")
torchvision/datasets/widerface.py:        img = Image.open(self.img_info[index]["img_path"])  # type: ignore[arg-type]
```
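
As a usage sketch ("path/to/images" is a placeholder): since decode_image accepts a file path as of v0.20, it can already serve as a drop-in loader wherever the parameter exists, and the proposal would extend that parameter to every dataset listed above.

```python
from torchvision import datasets
from torchvision.io import decode_image

# decode_image takes a path and returns a uint8 CHW tensor, so it can
# replace the default PIL-based loader where a loader parameter exists.
dataset = datasets.ImageFolder("path/to/images", loader=decode_image)
img, label = dataset[0]  # img is a torch.Tensor rather than a PIL.Image
```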
