Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add sun397 prototype datapipe #5667

Open
wants to merge 15 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 49 additions & 1 deletion test/builtin_dataset_mocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import PIL.Image
import pytest
import torch
from datasets_utils import make_zip, make_tar, create_image_folder, create_image_file
from datasets_utils import make_zip, make_tar, create_image_folder, create_image_file, random_subsets
from torch.nn.functional import one_hot
from torch.testing import make_tensor as _make_tensor
from torchvision._utils import sequence_to_str
Expand Down Expand Up @@ -1454,3 +1454,51 @@ def usps(info, root, config):
fh.write("\n".join(lines).encode())

return num_samples


@register_mock
def sun397(info, root, config):
    """Build a fake SUN397 dataset below ``root`` and report the expected sample count.

    The mock consists of an image tree under ``SUN397`` and ten train / test partition
    listings under ``Partitions``. Both folders receive a ``ClassName.txt`` and are
    archived afterwards (``SUN397.tar.gz`` and ``Partitions.zip``).

    Args:
        info: Dataset info object (not used by this mock).
        root: Directory in which the mock data is created.
        config: Dataset configuration; only ``config.split`` is read. It is compared
            against ``"all"`` and ``"{split}-{fold}"`` with ``split`` in
            ``("train", "test")`` and ``fold`` in ``1..10``.

    Returns:
        The number of samples for ``config.split``: all generated keys for ``"all"``,
        otherwise the size of the matching partition.
    """

    def image_file_name(idx):
        return f"sun_{idx:05d}.jpg"

    images_root = root / "SUN397"
    images_root.mkdir()

    categories = ["abbey", "airplane_cabin", "wrestling_ring/indoor"]
    # Keys mimic the dataset layout: "/<first letter>/<category>".
    category_keys = [f"/{category[0]}/{category}" for category in categories]

    keys = []
    for category_key in category_keys:
        # The leading "/" yields an empty first component that is discarded. The last
        # component is the folder holding the images, everything in between its parents.
        _, *parent_parts, folder_name = category_key.split("/")
        image_files = create_image_folder(
            root=images_root.joinpath(*parent_parts),
            name=folder_name,
            file_name_fn=image_file_name,
            num_examples=5,
        )

        for image_file in image_files:
            keys.append(f"/{image_file.relative_to(images_root).as_posix()}")

    partitions_root = root / "Partitions"
    partitions_root.mkdir()

    splits = ["train", "test"]
    for fold in range(1, 11):
        random.shuffle(keys)
        subsets = random_subsets(keys, len(splits))

        for split, keys_in_split in zip(splits, subsets):
            if config.split == "all":
                num_samples = len(keys)
            elif config.split == f"{split}-{fold}":
                num_samples = len(keys_in_split)

            # File names follow the pattern "Training_01.txt" / "Testing_01.txt".
            partition_file = partitions_root / f"{split.capitalize()}ing_{fold:02d}.txt"
            with open(partition_file, "w") as fh:
                fh.write("\n".join(sorted(keys_in_split)))

    # Both archives contain this file. Although it shouldn't be used at runtime, it serves as sentinel to test the
    # filtering of the correct files
    class_names = "\n".join(category_keys)
    for folder in (images_root, partitions_root):
        with open(folder / "ClassName.txt", "w") as fh:
            fh.write(class_names)

    make_tar(root, f"{images_root.name}.tar.gz", compression="gz")
    make_zip(root, f"{partitions_root.name}.zip")

    return num_samples
29 changes: 29 additions & 0 deletions test/datasets_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -952,3 +952,32 @@ def make_fake_flo_file(h, w, file_name):
)
with open(file_name, "wb") as f:
f.write(content)


def random_subsets(collection, n):
    """Splits collection into non-overlapping subsets.

    Args:
        collection: Collection of items to be split. Must support ``len()`` and iteration.
        n: Number of subsets. Must satisfy ``1 <= n <= len(collection)``.

    Returns:
        Tuple of subsets. Each subset is a list of random items from ``collection`` without overlap
        to the other subsets. Each subset contains at least one element.

    Raises:
        ValueError: If ``n`` is smaller than 1 or larger than ``len(collection)``. In that case it is
            impossible to create ``n`` non-empty subsets.

    Examples:
        >>> collection = range(10)
        >>> random_subsets(collection, 2)
        ([3, 5, 6, 7], [0, 1, 2, 4, 8, 9])
        >>> random_subsets(collection, 3)
        ([0, 1, 4, 8, 9], [2], [3, 5, 6, 7])
    """
    num_items = len(collection)
    # Without this guard the rejection sampling below would never terminate for n > num_items, since
    # n distinct subset indices cannot be drawn for fewer than n items.
    if not 1 <= n <= num_items:
        raise ValueError(
            f"Cannot split a collection of {num_items} item(s) into {n} non-empty subset(s): "
            f"n has to be between 1 and the size of the collection."
        )

    # Draw a subset index for every item and retry until every subset received at least one item.
    while True:
        idcs = torch.randint(n, (num_items,)).tolist()
        if len(set(idcs)) == n:
            break

    subsets = tuple([] for _ in range(n))
    for idx, item in zip(idcs, collection):
        subsets[idx].append(item)
    return subsets
1 change: 1 addition & 0 deletions torchvision/prototype/datasets/_builtin/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from .sbd import SBD
from .semeion import SEMEION
from .stanford_cars import StanfordCars
from .sun397 import SUN397
from .svhn import SVHN
from .usps import USPS
from .voc import VOC
Loading