Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 38 additions & 14 deletions torchvision/prototype/datasets/_builtin/sbu.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,20 @@
import itertools
import pathlib
import warnings
from typing import List, Any, Dict, Optional, Tuple, BinaryIO

from torch.utils.model_zoo import tqdm
from torchdata.datapipes.iter import IterDataPipe, Demultiplexer, LineReader, Zipper, Mapper, IterKeyZipper
from torchdata.datapipes.iter import (
IterDataPipe,
Demultiplexer,
LineReader,
Zipper,
Mapper,
IterKeyZipper,
FileLister,
FileOpener,
IterableWrapper,
)
from torchvision.prototype.datasets.utils import Dataset, DatasetInfo, DatasetConfig, OnlineResource, HttpResource
from torchvision.prototype.datasets.utils._internal import (
hint_sharding,
Expand All @@ -24,8 +35,29 @@ def _make_info(self) -> DatasetInfo:
homepage="http://www.cs.virginia.edu/~vicente/sbucaptions/",
)

def _preprocess(self, path: pathlib.Path) -> pathlib.Path:
folder = OnlineResource._extract(path)
def resources(self, config: DatasetConfig) -> List[OnlineResource]:
return [
HttpResource(
"http://www.cs.virginia.edu/~vicente/sbucaptions/SBUCaptionedPhotoDataset.tar.gz",
sha256="2bf37d5e1c9e1c6eae7d5103030d58a7f2117fc5e8c6aa9620f0df165acebf09",
)
]

def _maybe_download_images(
self, resource_dp: IterDataPipe[Tuple[str, BinaryIO]]
) -> IterDataPipe[Tuple[str, BinaryIO]]:
resource_dp = iter(resource_dp)
data = next(resource_dp)
path = pathlib.Path(data[0])
try:
archive = next(
iter(archive for archive in path.parents if archive.name == "SBUCaptionedPhotoDataset.tar.gz")
)
except StopIteration:
# we already loaded the extracted folder
return IterableWrapper(itertools.chain([data], resource_dp), deepcopy=False)

folder = OnlineResource._extract(archive)
data_folder = folder / "dataset"
image_folder = data_folder / "images"
image_folder.mkdir()
Expand All @@ -50,16 +82,7 @@ def _preprocess(self, path: pathlib.Path) -> pathlib.Path:
with open(broken_urls_file, "w") as fh:
fh.write("\n".join(broken_urls) + "\n")

return folder

def resources(self, config: DatasetConfig) -> List[OnlineResource]:
return [
HttpResource(
"http://www.cs.virginia.edu/~vicente/sbucaptions/SBUCaptionedPhotoDataset.tar.gz",
sha256="2bf37d5e1c9e1c6eae7d5103030d58a7f2117fc5e8c6aa9620f0df165acebf09",
preprocess=self._preprocess,
)
]
return FileOpener(FileLister(str(folder), recursive=True), mode="rb")

def _classify_files(self, data: Tuple[str, Any]) -> Optional[int]:
path = pathlib.Path(data[0])
Expand All @@ -78,9 +101,10 @@ def _make_datapipe(
*,
config: DatasetConfig,
) -> IterDataPipe[Dict[str, Any]]:
resource_dp = self._maybe_download_images(resource_dps[0])

images_dp, urls_dp, captions_dp = Demultiplexer(
resource_dps[0], 3, self._classify_files, drop_none=True, buffer_size=INFINITE_BUFFER_SIZE
resource_dp, 3, self._classify_files, drop_none=True, buffer_size=INFINITE_BUFFER_SIZE
)

images_dp = hint_shuffling(images_dp)
Expand Down