Skip to content

Commit

Permalink
add test split for imagenet (#4866)
Browse files Browse the repository at this point in the history
* add test split for imagenet

* add infinite buffer size to shuffler
  • Loading branch information
pmeier authored Nov 8, 2021
1 parent f093d08 commit 6d9a42c
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 29 deletions.
28 changes: 18 additions & 10 deletions test/builtin_dataset_mocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -452,11 +452,7 @@ def caltech256(info, root, config):

@dataset_mocks.register_mock_data_fn
def imagenet(info, root, config):
devkit_root = root / "ILSVRC2012_devkit_t12"
devkit_root.mkdir()

wnids = tuple(info.extra.wnid_to_category.keys())

if config.split == "train":
images_root = root / "ILSVRC2012_img_train"

Expand All @@ -470,7 +466,7 @@ def imagenet(info, root, config):
num_examples=1,
)
make_tar(images_root, f"{wnid}.tar", files[0].parent)
else:
elif config.split == "val":
num_samples = 3
files = create_image_folder(
root=root,
Expand All @@ -479,14 +475,26 @@ def imagenet(info, root, config):
num_examples=num_samples,
)
images_root = files[0].parent
else: # config.split == "test"
images_root = root / "ILSVRC2012_img_test_v10102019"

data_root = devkit_root / "data"
data_root.mkdir()
with open(data_root / "ILSVRC2012_validation_ground_truth.txt", "w") as file:
for label in torch.randint(0, len(wnids), (num_samples,)).tolist():
file.write(f"{label}\n")
num_samples = 3

create_image_folder(
root=images_root,
name="test",
file_name_fn=lambda image_idx: f"ILSVRC2012_test_{image_idx + 1:08d}.JPEG",
num_examples=num_samples,
)
make_tar(root, f"{images_root.name}.tar", images_root)

devkit_root = root / "ILSVRC2012_devkit_t12"
devkit_root.mkdir()
data_root = devkit_root / "data"
data_root.mkdir()
with open(data_root / "ILSVRC2012_validation_ground_truth.txt", "w") as file:
for label in torch.randint(0, len(wnids), (num_samples,)).tolist():
file.write(f"{label}\n")
make_tar(root, f"{devkit_root}.tar.gz", devkit_root, compression="gz")

return num_samples
48 changes: 29 additions & 19 deletions torchvision/prototype/datasets/_builtin/imagenet.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,11 +34,17 @@ def _make_info(self) -> DatasetInfo:
type=DatasetType.IMAGE,
categories=categories,
homepage="https://www.image-net.org/",
valid_options=dict(split=("train", "val")),
valid_options=dict(split=("train", "val", "test")),
extra=dict(
wnid_to_category=FrozenMapping(zip(wnids, categories)),
category_to_wnid=FrozenMapping(zip(categories, wnids)),
sizes=FrozenMapping([(DatasetConfig(split="train"), 1281167), (DatasetConfig(split="val"), 50000)]),
sizes=FrozenMapping(
[
(DatasetConfig(split="train"), 1_281_167),
(DatasetConfig(split="val"), 50_000),
(DatasetConfig(split="test"), 100_000),
]
),
),
)

Expand All @@ -53,17 +59,15 @@ def category_to_wnid(self) -> Dict[str, str]:
def wnid_to_category(self) -> Dict[str, str]:
return cast(Dict[str, str], self.info.extra.wnid_to_category)

_IMAGES_CHECKSUMS = {
"train": "b08200a27a8e34218a0e58fde36b0fe8f73bc377f4acea2d91602057c3ca45bb",
"val": "c7e06a6c0baccf06d8dbeb6577d71efff84673a5dbdd50633ab44f8ea0456ae0",
"test_v10102019": "9cf7f8249639510f17d3d8a0deb47cd22a435886ba8e29e2b3223e65a4079eb4",
}

def resources(self, config: DatasetConfig) -> List[OnlineResource]:
if config.split == "train":
images = HttpResource(
"ILSVRC2012_img_train.tar",
sha256="b08200a27a8e34218a0e58fde36b0fe8f73bc377f4acea2d91602057c3ca45bb",
)
else: # config.split == "val"
images = HttpResource(
"ILSVRC2012_img_val.tar",
sha256="c7e06a6c0baccf06d8dbeb6577d71efff84673a5dbdd50633ab44f8ea0456ae0",
)
name = "test_v10102019" if config.split == "test" else config.split
images = HttpResource(f"ILSVRC2012_img_{name}.tar", sha256=self._IMAGES_CHECKSUMS[name])

devkit = HttpResource(
"ILSVRC2012_devkit_t12.tar.gz",
Expand All @@ -81,11 +85,11 @@ def _collate_train_data(self, data: Tuple[str, io.IOBase]) -> Tuple[Tuple[int, s
label = self.categories.index(category)
return (label, category, wnid), data

_VAL_IMAGE_NAME_PATTERN = re.compile(r"ILSVRC2012_val_(?P<id>\d{8})[.]JPEG")
_VAL_TEST_IMAGE_NAME_PATTERN = re.compile(r"ILSVRC2012_(val|test)_(?P<id>\d{8})[.]JPEG")

def _val_image_key(self, data: Tuple[str, Any]) -> int:
def _val_test_image_key(self, data: Tuple[str, Any]) -> int:
path = pathlib.Path(data[0])
return int(self._VAL_IMAGE_NAME_PATTERN.match(path.name).group("id")) # type: ignore[union-attr]
return int(self._VAL_TEST_IMAGE_NAME_PATTERN.match(path.name).group("id")) # type: ignore[union-attr]

def _collate_val_data(
self, data: Tuple[Tuple[int, int], Tuple[str, io.IOBase]]
Expand All @@ -96,9 +100,12 @@ def _collate_val_data(
wnid = self.category_to_wnid[category]
return (label, category, wnid), image_data

def _collate_test_data(self, data: Tuple[str, io.IOBase]) -> Tuple[Tuple[None, None, None], Tuple[str, io.IOBase]]:
return (None, None, None), data

def _collate_and_decode_sample(
self,
data: Tuple[Tuple[int, str, str], Tuple[str, io.IOBase]],
data: Tuple[Tuple[Optional[int], Optional[str], Optional[str]], Tuple[str, io.IOBase]],
*,
decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
) -> Dict[str, Any]:
Expand All @@ -108,7 +115,7 @@ def _collate_and_decode_sample(
return dict(
path=path,
image=decoder(buffer) if decoder else buffer,
label=torch.tensor(label),
label=label,
category=category,
wnid=wnid,
)
Expand All @@ -129,7 +136,7 @@ def _make_datapipe(
dp = TarArchiveReader(images_dp)
dp = Shuffler(dp, buffer_size=INFINITE_BUFFER_SIZE)
dp = Mapper(dp, self._collate_train_data)
else:
elif config.split == "val":
devkit_dp = TarArchiveReader(devkit_dp)
devkit_dp = Filter(devkit_dp, path_comparator("name", "ILSVRC2012_validation_ground_truth.txt"))
devkit_dp = LineReader(devkit_dp, return_path=False)
Expand All @@ -141,10 +148,13 @@ def _make_datapipe(
devkit_dp,
images_dp,
key_fn=getitem(0),
ref_key_fn=self._val_image_key,
ref_key_fn=self._val_test_image_key,
buffer_size=INFINITE_BUFFER_SIZE,
)
dp = Mapper(dp, self._collate_val_data)
else: # config.split == "test"
dp = Shuffler(images_dp, buffer_size=INFINITE_BUFFER_SIZE)
dp = Mapper(dp, self._collate_test_data)

return Mapper(dp, self._collate_and_decode_sample, fn_kwargs=dict(decoder=decoder))

Expand Down

0 comments on commit 6d9a42c

Please sign in to comment.