diff --git a/torchvision/prototype/datasets/_builtin/sbu.py b/torchvision/prototype/datasets/_builtin/sbu.py index 63ee3ddf827..5677fdc4975 100644 --- a/torchvision/prototype/datasets/_builtin/sbu.py +++ b/torchvision/prototype/datasets/_builtin/sbu.py @@ -1,9 +1,20 @@ +import itertools import pathlib import warnings from typing import List, Any, Dict, Optional, Tuple, BinaryIO from torch.utils.model_zoo import tqdm -from torchdata.datapipes.iter import IterDataPipe, Demultiplexer, LineReader, Zipper, Mapper, IterKeyZipper +from torchdata.datapipes.iter import ( + IterDataPipe, + Demultiplexer, + LineReader, + Zipper, + Mapper, + IterKeyZipper, + FileLister, + FileOpener, + IterableWrapper, +) from torchvision.prototype.datasets.utils import Dataset, DatasetInfo, DatasetConfig, OnlineResource, HttpResource from torchvision.prototype.datasets.utils._internal import ( hint_sharding, @@ -24,8 +35,29 @@ def _make_info(self) -> DatasetInfo: homepage="http://www.cs.virginia.edu/~vicente/sbucaptions/", ) - def _preprocess(self, path: pathlib.Path) -> pathlib.Path: - folder = OnlineResource._extract(path) + def resources(self, config: DatasetConfig) -> List[OnlineResource]: + return [ + HttpResource( + "http://www.cs.virginia.edu/~vicente/sbucaptions/SBUCaptionedPhotoDataset.tar.gz", + sha256="2bf37d5e1c9e1c6eae7d5103030d58a7f2117fc5e8c6aa9620f0df165acebf09", + ) + ] + + def _maybe_download_images( + self, resource_dp: IterDataPipe[Tuple[str, BinaryIO]] + ) -> IterDataPipe[Tuple[str, BinaryIO]]: + resource_dp = iter(resource_dp) + data = next(resource_dp) + path = pathlib.Path(data[0]) + try: + archive = next( + iter(archive for archive in path.parents if archive.name == "SBUCaptionedPhotoDataset.tar.gz") + ) + except StopIteration: + # we already loaded the extracted folder + return IterableWrapper(itertools.chain([data], resource_dp), deepcopy=False) + + folder = OnlineResource._extract(archive) data_folder = folder / "dataset" image_folder = data_folder / "images" image_folder.mkdir() @@ -50,16 +82,7 @@ def _preprocess(self, path: pathlib.Path) -> pathlib.Path: with open(broken_urls_file, "w") as fh: fh.write("\n".join(broken_urls) + "\n") - return folder - - def resources(self, config: DatasetConfig) -> List[OnlineResource]: - return [ - HttpResource( - "http://www.cs.virginia.edu/~vicente/sbucaptions/SBUCaptionedPhotoDataset.tar.gz", - sha256="2bf37d5e1c9e1c6eae7d5103030d58a7f2117fc5e8c6aa9620f0df165acebf09", - preprocess=self._preprocess, - ) - ] + return FileOpener(FileLister(str(folder), recursive=True), mode="rb") def _classify_files(self, data: Tuple[str, Any]) -> Optional[int]: path = pathlib.Path(data[0]) @@ -78,9 +101,10 @@ def _make_datapipe( *, config: DatasetConfig, ) -> IterDataPipe[Dict[str, Any]]: + resource_dp = self._maybe_download_images(resource_dps[0]) images_dp, urls_dp, captions_dp = Demultiplexer( - resource_dps[0], 3, self._classify_files, drop_none=True, buffer_size=INFINITE_BUFFER_SIZE + resource_dp, 3, self._classify_files, drop_none=True, buffer_size=INFINITE_BUFFER_SIZE ) images_dp = hint_shuffling(images_dp)