diff --git a/test/builtin_dataset_mocks.py b/test/builtin_dataset_mocks.py index c5608377d97..768d286e890 100644 --- a/test/builtin_dataset_mocks.py +++ b/test/builtin_dataset_mocks.py @@ -10,6 +10,7 @@ import pathlib import pickle import random +import unittest.mock import warnings import xml.etree.ElementTree as ET from collections import defaultdict, Counter @@ -18,11 +19,11 @@ import PIL.Image import pytest import torch -from datasets_utils import make_zip, make_tar, create_image_folder, create_image_file +from datasets_utils import make_zip, make_tar, create_image_folder, create_image_file, combinations_grid from torch.nn.functional import one_hot from torch.testing import make_tensor as _make_tensor from torchvision._utils import sequence_to_str -from torchvision.prototype.datasets._api import find +from torchvision.prototype import datasets make_tensor = functools.partial(_make_tensor, device="cpu") make_scalar = functools.partial(make_tensor, ()) @@ -32,13 +33,11 @@ class DatasetMock: - def __init__(self, name, mock_data_fn): - self.dataset = find(name) - self.info = self.dataset.info - self.name = self.info.name - + def __init__(self, name, *, mock_data_fn, configs): + # FIXME: error handling for unknown names + self.name = name self.mock_data_fn = mock_data_fn - self.configs = self.info._configs + self.configs = configs def _parse_mock_info(self, mock_info): if mock_info is None: @@ -67,10 +66,13 @@ def prepare(self, home, config): root = home / self.name root.mkdir(exist_ok=True) - mock_info = self._parse_mock_info(self.mock_data_fn(self.info, root, config)) + mock_info = self._parse_mock_info(self.mock_data_fn(root, config)) + with unittest.mock.patch.object(datasets.utils.Dataset, "__init__"): + required_file_names = { + resource.file_name for resource in datasets.load(self.name, root=root, **config)._resources() + } available_file_names = {path.name for path in root.glob("*")} - required_file_names = {resource.file_name for resource in self.dataset.resources(config)} missing_file_names = required_file_names - available_file_names if missing_file_names: raise pytest.UsageError( @@ -125,10 +127,16 @@ def parametrize_dataset_mocks(*dataset_mocks, marks=None): DATASET_MOCKS = {} -def register_mock(fn): - name = fn.__name__.replace("_", "-") - DATASET_MOCKS[name] = DatasetMock(name, fn) - return fn +def register_mock(name=None, *, configs): + def wrapper(mock_data_fn): + nonlocal name + if name is None: + name = mock_data_fn.__name__ + DATASET_MOCKS[name] = DatasetMock(name, mock_data_fn=mock_data_fn, configs=configs) + + return mock_data_fn + + return wrapper class MNISTMockData: @@ -206,58 +214,64 @@ def generate( return num_samples -@register_mock -def mnist(info, root, config): - train = config.split == "train" - images_file = f"{'train' if train else 't10k'}-images-idx3-ubyte.gz" - labels_file = f"{'train' if train else 't10k'}-labels-idx1-ubyte.gz" +def mnist(root, config): + prefix = "train" if config["split"] == "train" else "t10k" return MNISTMockData.generate( root, - num_categories=len(info.categories), - images_file=images_file, - labels_file=labels_file, + num_categories=10, + images_file=f"{prefix}-images-idx3-ubyte.gz", + labels_file=f"{prefix}-labels-idx1-ubyte.gz", ) -DATASET_MOCKS.update({name: DatasetMock(name, mnist) for name in ["fashionmnist", "kmnist"]}) +DATASET_MOCKS.update( + { + name: DatasetMock(name, mock_data_fn=mnist, configs=combinations_grid(split=("train", "test"))) + for name in ["mnist", "fashionmnist", "kmnist"] + } +) -@register_mock -def 
emnist(info, root, config): - # The image sets that merge some lower case letters in their respective upper case variant, still use dense - # labels in the data files. Thus, num_categories != len(categories) there. - num_categories = defaultdict( - lambda: len(info.categories), {image_set: 47 for image_set in ("Balanced", "By_Merge")} +@register_mock( + configs=combinations_grid( + split=("train", "test"), + image_set=("Balanced", "By_Merge", "By_Class", "Letters", "Digits", "MNIST"), ) - +) +def emnist(root, config): num_samples_map = {} file_names = set() - for config_ in info._configs: - prefix = f"emnist-{config_.image_set.replace('_', '').lower()}-{config_.split}" + for split, image_set in itertools.product( + ("train", "test"), + ("Balanced", "By_Merge", "By_Class", "Letters", "Digits", "MNIST"), + ): + prefix = f"emnist-{image_set.replace('_', '').lower()}-{split}" images_file = f"{prefix}-images-idx3-ubyte.gz" labels_file = f"{prefix}-labels-idx1-ubyte.gz" file_names.update({images_file, labels_file}) - num_samples_map[config_] = MNISTMockData.generate( + num_samples_map[(split, image_set)] = MNISTMockData.generate( root, - num_categories=num_categories[config_.image_set], + # The image sets that merge some lower case letters in their respective upper case variant, still use dense + # labels in the data files. Thus, num_categories != len(categories) there. + num_categories=47 if config["image_set"] in ("Balanced", "By_Merge") else 62, images_file=images_file, labels_file=labels_file, ) make_zip(root, "emnist-gzip.zip", *file_names) - return num_samples_map[config] + return num_samples_map[(config["split"], config["image_set"])] -@register_mock -def qmnist(info, root, config): - num_categories = len(info.categories) - if config.split == "train": +@register_mock(configs=combinations_grid(split=("train", "test", "test10k", "test50k", "nist"))) +def qmnist(root, config): + num_categories = 10 + if config["split"] == "train": num_samples = num_samples_gen = num_categories + 2 prefix = "qmnist-train" suffix = ".gz" compressor = gzip.open - elif config.split.startswith("test"): + elif config["split"].startswith("test"): # The split 'test50k' is defined as the last 50k images beginning at index 10000. Thus, we need to create # more than 10000 images for the dataset to not be empty. 
num_samples_gen = 10001 @@ -265,11 +279,11 @@ def qmnist(info, root, config): "test": num_samples_gen, "test10k": min(num_samples_gen, 10_000), "test50k": num_samples_gen - 10_000, - }[config.split] + }[config["split"]] prefix = "qmnist-test" suffix = ".gz" compressor = gzip.open - else: # config.split == "nist" + else: # config["split"] == "nist" num_samples = num_samples_gen = num_categories + 3 prefix = "xnist" suffix = ".xz" @@ -326,8 +340,8 @@ def generate( make_tar(root, name, folder, compression="gz") -@register_mock -def cifar10(info, root, config): +@register_mock(configs=combinations_grid(split=("train", "test"))) +def cifar10(root, config): train_files = [f"data_batch_{idx}" for idx in range(1, 6)] test_files = ["test_batch"] @@ -341,11 +355,11 @@ def cifar10(info, root, config): labels_key="labels", ) - return len(train_files if config.split == "train" else test_files) + return len(train_files if config["split"] == "train" else test_files) -@register_mock -def cifar100(info, root, config): +@register_mock(configs=combinations_grid(split=("train", "test"))) +def cifar100(root, config): train_files = ["train"] test_files = ["test"] @@ -359,11 +373,11 @@ def cifar100(info, root, config): labels_key="fine_labels", ) - return len(train_files if config.split == "train" else test_files) + return len(train_files if config["split"] == "train" else test_files) -@register_mock -def caltech101(info, root, config): +@register_mock(configs=[dict()]) +def caltech101(root, config): def create_ann_file(root, name): import scipy.io @@ -382,15 +396,17 @@ def create_ann_folder(root, name, file_name_fn, num_examples): images_root = root / "101_ObjectCategories" anns_root = root / "Annotations" - ann_category_map = { - "Faces_2": "Faces", - "Faces_3": "Faces_easy", - "Motorbikes_16": "Motorbikes", - "Airplanes_Side_2": "airplanes", + image_category_map = { + "Faces": "Faces_2", + "Faces_easy": "Faces_3", + "Motorbikes": "Motorbikes_16", + "airplanes": "Airplanes_Side_2", } + categories = ["Faces", "Faces_easy", "Motorbikes", "airplanes", "yin_yang"] + num_images_per_category = 2 - for category in info.categories: + for category in categories: create_image_folder( root=images_root, name=category, @@ -399,7 +415,7 @@ def create_ann_folder(root, name, file_name_fn, num_examples): ) create_ann_folder( root=anns_root, - name=ann_category_map.get(category, category), + name=image_category_map.get(category, category), file_name_fn=lambda idx: f"annotation_{idx + 1:04d}.mat", num_examples=num_images_per_category, ) @@ -409,19 +425,26 @@ def create_ann_folder(root, name, file_name_fn, num_examples): make_tar(root, f"{anns_root.name}.tar", anns_root) - return num_images_per_category * len(info.categories) + return num_images_per_category * len(categories) -@register_mock -def caltech256(info, root, config): +@register_mock(configs=[dict()]) +def caltech256(root, config): dir = root / "256_ObjectCategories" num_images_per_category = 2 - for idx, category in enumerate(info.categories, 1): + categories = [ + (1, "ak47"), + (127, "laptop-101"), + (198, "spider"), + (257, "clutter"), + ] + + for category_idx, category in categories: files = create_image_folder( dir, - name=f"{idx:03d}.{category}", - file_name_fn=lambda image_idx: f"{idx:03d}_{image_idx + 1:04d}.jpg", + name=f"{category_idx:03d}.{category}", + file_name_fn=lambda image_idx: f"{category_idx:03d}_{image_idx + 1:04d}.jpg", num_examples=num_images_per_category, ) if category == "spider": @@ -429,21 +452,21 @@ def caltech256(info, root, config): 
make_tar(root, f"{dir.name}.tar", dir) - return num_images_per_category * len(info.categories) + return num_images_per_category * len(categories) -@register_mock -def imagenet(info, root, config): +@register_mock(configs=combinations_grid(split=("train", "val", "test"))) +def imagenet(root, config): from scipy.io import savemat - categories = info.categories - wnids = [info.extra.category_to_wnid[category] for category in categories] - if config.split == "train": - num_samples = len(wnids) + info = datasets.info("imagenet") + + if config["split"] == "train": + num_samples = len(info["wnids"]) archive_name = "ILSVRC2012_img_train.tar" files = [] - for wnid in wnids: + for wnid in info["wnids"]: create_image_folder( root=root, name=wnid, @@ -451,7 +474,7 @@ def imagenet(info, root, config): num_examples=1, ) files.append(make_tar(root, f"{wnid}.tar")) - elif config.split == "val": + elif config["split"] == "val": num_samples = 3 archive_name = "ILSVRC2012_img_val.tar" files = [create_image_file(root, f"ILSVRC2012_val_{idx + 1:08d}.JPEG") for idx in range(num_samples)] @@ -461,13 +484,13 @@ def imagenet(info, root, config): data_root.mkdir(parents=True) with open(data_root / "ILSVRC2012_validation_ground_truth.txt", "w") as file: - for label in torch.randint(0, len(wnids), (num_samples,)).tolist(): + for label in torch.randint(0, len(info["wnids"]), (num_samples,)).tolist(): file.write(f"{label}\n") num_children = 0 synsets = [ (idx, wnid, category, "", num_children, [], 0, 0) - for idx, (category, wnid) in enumerate(zip(categories, wnids), 1) + for idx, (category, wnid) in enumerate(zip(info["categories"], info["wnids"]), 1) ] num_children = 1 synsets.extend((0, "", "", "", num_children, [], 0, 0) for _ in range(5)) @@ -477,7 +500,7 @@ def imagenet(info, root, config): savemat(data_root / "meta.mat", dict(synsets=synsets)) make_tar(root, devkit_root.with_suffix(".tar.gz").name, compression="gz") - else: # config.split == "test" + else: # config["split"] == "test" num_samples = 5 archive_name = "ILSVRC2012_img_test_v10102019.tar" files = [create_image_file(root, f"ILSVRC2012_test_{idx + 1:08d}.JPEG") for idx in range(num_samples)] @@ -592,9 +615,15 @@ def generate( return num_samples -@register_mock -def coco(info, root, config): - return CocoMockData.generate(root, year=config.year, num_samples=5) +@register_mock( + configs=combinations_grid( + split=("train", "val"), + year=("2017", "2014"), + annotations=("instances", "captions", None), + ) +) +def coco(root, config): + return CocoMockData.generate(root, year=config["year"], num_samples=5) class SBDMockData: @@ -666,15 +695,15 @@ def generate(cls, root): return num_samples_map -@register_mock -def sbd(info, root, config): - return SBDMockData.generate(root)[config.split] +@register_mock(configs=combinations_grid(split=("train", "val", "train_noval"))) +def sbd(root, config): + return SBDMockData.generate(root)[config["split"]] -@register_mock -def semeion(info, root, config): +@register_mock(configs=[dict()]) +def semeion(root, config): num_samples = 3 - num_categories = len(info.categories) + num_categories = 10 images = torch.rand(num_samples, 256) labels = one_hot(torch.randint(num_categories, size=(num_samples,)), num_classes=num_categories) @@ -784,10 +813,23 @@ def generate(cls, root, *, year, trainval): return num_samples_map -@register_mock -def voc(info, root, config): - trainval = config.split != "test" - return VOCMockData.generate(root, year=config.year, trainval=trainval)[config.split] +@register_mock( + configs=[ + 
*combinations_grid( + split=("train", "val", "trainval"), + year=("2007", "2008", "2009", "2010", "2011", "2012"), + task=("detection", "segmentation"), + ), + *combinations_grid( + split=("test",), + year=("2007",), + task=("detection", "segmentation"), + ), + ], +) +def voc(root, config): + trainval = config["split"] != "test" + return VOCMockData.generate(root, year=config["year"], trainval=trainval)[config["split"]] class CelebAMockData: @@ -878,19 +920,14 @@ def generate(cls, root): return num_samples_map -@register_mock -def celeba(info, root, config): - return CelebAMockData.generate(root)[config.split] +@register_mock(configs=combinations_grid(split=("train", "val", "test"))) +def celeba(root, config): + return CelebAMockData.generate(root)[config["split"]] -@register_mock -def country211(info, root, config): - split_name_mapper = { - "train": "train", - "val": "valid", - "test": "test", - } - split_folder = pathlib.Path(root, "country211", split_name_mapper[config["split"]]) +@register_mock(configs=combinations_grid(split=("train", "val", "test"))) +def country211(root, config): + split_folder = pathlib.Path(root, "country211", "valid" if config["split"] == "val" else config["split"]) split_folder.mkdir(parents=True, exist_ok=True) num_examples = { @@ -911,8 +948,8 @@ def country211(info, root, config): return num_examples * len(classes) -@register_mock -def food101(info, root, config): +@register_mock(configs=combinations_grid(split=("train", "test"))) +def food101(root, config): data_folder = root / "food-101" num_images_per_class = 3 @@ -946,11 +983,11 @@ def food101(info, root, config): make_tar(root, f"{data_folder.name}.tar.gz", compression="gz") - return num_samples_map[config.split] + return num_samples_map[config["split"]] -@register_mock -def dtd(info, root, config): +@register_mock(configs=combinations_grid(split=("train", "val", "test"), fold=(1, 4, 10))) +def dtd(root, config): data_folder = root / "dtd" num_images_per_class = 3 @@ -990,20 +1027,21 @@ def dtd(info, root, config): with open(meta_folder / f"{split}{fold}.txt", "w") as file: file.write("\n".join(image_ids_in_config) + "\n") - num_samples_map[info.make_config(split=split, fold=str(fold))] = len(image_ids_in_config) + num_samples_map[(split, fold)] = len(image_ids_in_config) make_tar(root, "dtd-r1.0.1.tar.gz", data_folder, compression="gz") - return num_samples_map[config] + return num_samples_map[config["split"], config["fold"]] -@register_mock -def fer2013(info, root, config): - num_samples = 5 if config.split == "train" else 3 +@register_mock(configs=combinations_grid(split=("train", "test"))) +def fer2013(root, config): + split = config["split"] + num_samples = 5 if split == "train" else 3 - path = root / f"{config.split}.csv" + path = root / f"{split}.csv" with open(path, "w", newline="") as file: - field_names = ["emotion"] if config.split == "train" else [] + field_names = ["emotion"] if split == "train" else [] field_names.append("pixels") file.write(",".join(field_names) + "\n") @@ -1013,7 +1051,7 @@ def fer2013(info, root, config): rowdict = { "pixels": " ".join([str(int(pixel)) for pixel in torch.randint(256, (48 * 48,), dtype=torch.uint8)]) } - if config.split == "train": + if split == "train": rowdict["emotion"] = int(torch.randint(7, ())) writer.writerow(rowdict) @@ -1022,9 +1060,9 @@ def fer2013(info, root, config): return num_samples -@register_mock -def gtsrb(info, root, config): - num_examples_per_class = 5 if config.split == "train" else 3 
+@register_mock(configs=combinations_grid(split=("train", "test"))) +def gtsrb(root, config): + num_examples_per_class = 5 if config["split"] == "train" else 3 classes = ("00000", "00042", "00012") num_examples = num_examples_per_class * len(classes) @@ -1092,8 +1130,8 @@ def _make_ann_file(path, num_examples, class_idx): return num_examples -@register_mock -def clevr(info, root, config): +@register_mock(configs=combinations_grid(split=("train", "val", "test"))) +def clevr(root, config): data_folder = root / "CLEVR_v1.0" num_samples_map = { @@ -1134,7 +1172,7 @@ def clevr(info, root, config): make_zip(root, f"{data_folder.name}.zip", data_folder) - return num_samples_map[config.split] + return num_samples_map[config["split"]] class OxfordIIITPetMockData: @@ -1198,9 +1236,9 @@ def generate(self, root): return num_samples_map -@register_mock -def oxford_iiit_pet(info, root, config): - return OxfordIIITPetMockData.generate(root)[config.split] +@register_mock(name="oxford-iiit-pet", configs=combinations_grid(split=("trainval", "test"))) +def oxford_iiit_pet(root, config): + return OxfordIIITPetMockData.generate(root)[config["split"]] class _CUB200MockData: @@ -1364,14 +1402,14 @@ def generate(cls, root): return num_samples_map -@register_mock -def cub200(info, root, config): - num_samples_map = (CUB2002011MockData if config.year == "2011" else CUB2002010MockData).generate(root) - return num_samples_map[config.split] +@register_mock(configs=combinations_grid(split=("train", "test"), year=("2010", "2011"))) +def cub200(root, config): + num_samples_map = (CUB2002011MockData if config["year"] == "2011" else CUB2002010MockData).generate(root) + return num_samples_map[config["split"]] -@register_mock -def eurosat(info, root, config): +@register_mock(configs=[dict()]) +def eurosat(root, config): data_folder = root / "2750" data_folder.mkdir(parents=True) @@ -1388,18 +1426,18 @@ def eurosat(info, root, config): return len(categories) * num_examples_per_class -@register_mock -def svhn(info, root, config): +@register_mock(configs=combinations_grid(split=("train", "test", "extra"))) +def svhn(root, config): import scipy.io as sio num_samples = { "train": 2, "test": 3, "extra": 4, - }[config.split] + }[config["split"]] sio.savemat( - root / f"{config.split}_32x32.mat", + root / f"{config['split']}_32x32.mat", { "X": np.random.randint(256, size=(32, 32, 3, num_samples), dtype=np.uint8), "y": np.random.randint(10, size=(num_samples,), dtype=np.uint8), @@ -1408,13 +1446,13 @@ def svhn(info, root, config): return num_samples -@register_mock -def pcam(info, root, config): +@register_mock(configs=combinations_grid(split=("train", "val", "test"))) +def pcam(root, config): import h5py - num_images = {"train": 2, "test": 3, "val": 4}[config.split] + num_images = {"train": 2, "test": 3, "val": 4}[config["split"]] - split = "valid" if config.split == "val" else config.split + split = "valid" if config["split"] == "val" else config["split"] images_io = io.BytesIO() with h5py.File(images_io, "w") as f: @@ -1435,18 +1473,19 @@ def pcam(info, root, config): return num_images -@register_mock -def stanford_cars(info, root, config): +@register_mock(name="stanford-cars", configs=combinations_grid(split=("train", "test"))) +def stanford_cars(root, config): import scipy.io as io from numpy.core.records import fromarrays - num_samples = {"train": 5, "test": 7}[config["split"]] + split = config["split"] + num_samples = {"train": 5, "test": 7}[split] num_categories = 3 devkit = root / "devkit" devkit.mkdir(parents=True) - if 
config["split"] == "train": + if split == "train": images_folder_name = "cars_train" annotations_mat_path = devkit / "cars_train_annos.mat" else: @@ -1460,7 +1499,7 @@ def stanford_cars(info, root, config): num_examples=num_samples, ) - make_tar(root, f"cars_{config.split}.tgz", images_folder_name) + make_tar(root, f"cars_{split}.tgz", images_folder_name) bbox = np.random.randint(1, 200, num_samples, dtype=np.uint8) classes = np.random.randint(1, num_categories + 1, num_samples, dtype=np.uint8) fnames = [f"{i:5d}.jpg" for i in range(num_samples)] @@ -1470,17 +1509,17 @@ def stanford_cars(info, root, config): ) io.savemat(annotations_mat_path, {"annotations": rec_array}) - if config.split == "train": + if split == "train": make_tar(root, "car_devkit.tgz", devkit, compression="gz") return num_samples -@register_mock -def usps(info, root, config): - num_samples = {"train": 15, "test": 7}[config.split] +@register_mock(configs=combinations_grid(split=("train", "test"))) +def usps(root, config): + num_samples = {"train": 15, "test": 7}[config["split"]] - with bz2.open(root / f"usps{'.t' if not config.split == 'train' else ''}.bz2", "wb") as fh: + with bz2.open(root / f"usps{'.t' if not config['split'] == 'train' else ''}.bz2", "wb") as fh: lines = [] for _ in range(num_samples): label = make_tensor(1, low=1, high=11, dtype=torch.int) diff --git a/test/test_prototype_builtin_datasets.py b/test/test_prototype_builtin_datasets.py index 8d51125f41c..fc2ebd9aa38 100644 --- a/test/test_prototype_builtin_datasets.py +++ b/test/test_prototype_builtin_datasets.py @@ -7,9 +7,10 @@ import torch from builtin_dataset_mocks import parametrize_dataset_mocks, DATASET_MOCKS from torch.testing._comparison import assert_equal, TensorLikePair, ObjectPair +from torch.utils.data import DataLoader from torch.utils.data.graph import traverse from torch.utils.data.graph_settings import get_all_graph_pipes -from torchdata.datapipes.iter import IterDataPipe, Shuffler, ShardingFilter +from torchdata.datapipes.iter import Shuffler, ShardingFilter from torchvision._utils import sequence_to_str from torchvision.prototype import transforms, datasets from torchvision.prototype.datasets.utils._internal import INFINITE_BUFFER_SIZE @@ -42,14 +43,24 @@ def test_coverage(): @pytest.mark.filterwarnings("error") class TestCommon: + @pytest.mark.parametrize("name", datasets.list_datasets()) + def test_info(self, name): + try: + info = datasets.info(name) + except ValueError: + raise AssertionError("No info available.") from None + + if not (isinstance(info, dict) and all(isinstance(key, str) for key in info.keys())): + raise AssertionError("Info should be a dictionary with string keys.") + @parametrize_dataset_mocks(DATASET_MOCKS) def test_smoke(self, test_home, dataset_mock, config): dataset_mock.prepare(test_home, config) dataset = datasets.load(dataset_mock.name, **config) - if not isinstance(dataset, IterDataPipe): - raise AssertionError(f"Loading the dataset should return an IterDataPipe, but got {type(dataset)} instead.") + if not isinstance(dataset, datasets.utils.Dataset): + raise AssertionError(f"Loading the dataset should return an Dataset, but got {type(dataset)} instead.") @parametrize_dataset_mocks(DATASET_MOCKS) def test_sample(self, test_home, dataset_mock, config): @@ -76,24 +87,7 @@ def test_num_samples(self, test_home, dataset_mock, config): dataset = datasets.load(dataset_mock.name, **config) - num_samples = 0 - for _ in dataset: - num_samples += 1 - - assert num_samples == mock_info["num_samples"] - - 
@parametrize_dataset_mocks(DATASET_MOCKS) - def test_decoding(self, test_home, dataset_mock, config): - dataset_mock.prepare(test_home, config) - - dataset = datasets.load(dataset_mock.name, **config) - - undecoded_features = {key for key, value in next(iter(dataset)).items() if isinstance(value, io.IOBase)} - if undecoded_features: - raise AssertionError( - f"The values of key(s) " - f"{sequence_to_str(sorted(undecoded_features), separate_last='and ')} were not decoded." - ) + assert len(list(dataset)) == mock_info["num_samples"] @parametrize_dataset_mocks(DATASET_MOCKS) def test_no_vanilla_tensors(self, test_home, dataset_mock, config): @@ -116,14 +110,36 @@ def test_transformable(self, test_home, dataset_mock, config): next(iter(dataset.map(transforms.Identity()))) + @pytest.mark.parametrize("only_datapipe", [False, True]) @parametrize_dataset_mocks(DATASET_MOCKS) - def test_serializable(self, test_home, dataset_mock, config): + def test_traversable(self, test_home, dataset_mock, config, only_datapipe): dataset_mock.prepare(test_home, config) + dataset = datasets.load(dataset_mock.name, **config) + + traverse(dataset, only_datapipe=only_datapipe) + @parametrize_dataset_mocks(DATASET_MOCKS) + def test_serializable(self, test_home, dataset_mock, config): + dataset_mock.prepare(test_home, config) dataset = datasets.load(dataset_mock.name, **config) pickle.dumps(dataset) + @pytest.mark.parametrize("num_workers", [0, 1]) + @parametrize_dataset_mocks(DATASET_MOCKS) + def test_data_loader(self, test_home, dataset_mock, config, num_workers): + dataset_mock.prepare(test_home, config) + dataset = datasets.load(dataset_mock.name, **config) + + dl = DataLoader( + dataset, + batch_size=2, + num_workers=num_workers, + collate_fn=lambda batch: batch, + ) + + next(iter(dl)) + # TODO: we need to enforce not only that both a Shuffler and a ShardingFilter are part of the datapipe, but also # that the Shuffler comes before the ShardingFilter. Early commits in https://github.com/pytorch/vision/pull/5680 # contain a custom test for that, but we opted to wait for a potential solution / test from torchdata for now. @@ -132,7 +148,6 @@ def test_serializable(self, test_home, dataset_mock, config): def test_has_annotations(self, test_home, dataset_mock, config, annotation_dp_type): dataset_mock.prepare(test_home, config) - dataset = datasets.load(dataset_mock.name, **config) if not any(isinstance(dp, annotation_dp_type) for dp in extract_datapipes(dataset)): @@ -160,6 +175,13 @@ def test_infinite_buffer_size(self, test_home, dataset_mock, config): # resolved assert dp.buffer_size == INFINITE_BUFFER_SIZE + @parametrize_dataset_mocks(DATASET_MOCKS) + def test_has_length(self, test_home, dataset_mock, config): + dataset_mock.prepare(test_home, config) + dataset = datasets.load(dataset_mock.name, **config) + + assert len(dataset) > 0 + @parametrize_dataset_mocks(DATASET_MOCKS["qmnist"]) class TestQMNIST: @@ -186,7 +208,7 @@ class TestGTSRB: def test_label_matches_path(self, test_home, dataset_mock, config): # We read the labels from the csv files instead. But for the trainset, the labels are also part of the path. 
# This test makes sure that they're both the same - if config.split != "train": + if config["split"] != "train": return dataset_mock.prepare(test_home, config) diff --git a/test/test_prototype_datasets_api.py b/test/test_prototype_datasets_api.py deleted file mode 100644 index 70a2707d050..00000000000 --- a/test/test_prototype_datasets_api.py +++ /dev/null @@ -1,231 +0,0 @@ -import unittest.mock - -import pytest -from torchvision.prototype import datasets -from torchvision.prototype.utils._internal import FrozenMapping, FrozenBunch - - -def make_minimal_dataset_info(name="name", categories=None, **kwargs): - return datasets.utils.DatasetInfo(name, categories=categories or [], **kwargs) - - -class TestFrozenMapping: - @pytest.mark.parametrize( - ("args", "kwargs"), - [ - pytest.param((dict(foo="bar", baz=1),), dict(), id="from_dict"), - pytest.param((), dict(foo="bar", baz=1), id="from_kwargs"), - pytest.param((dict(foo="bar"),), dict(baz=1), id="mixed"), - ], - ) - def test_instantiation(self, args, kwargs): - FrozenMapping(*args, **kwargs) - - def test_unhashable_items(self): - with pytest.raises(TypeError, match="unhashable type"): - FrozenMapping(foo=[]) - - def test_getitem(self): - options = dict(foo="bar", baz=1) - config = FrozenMapping(options) - - for key, value in options.items(): - assert config[key] == value - - def test_getitem_unknown(self): - with pytest.raises(KeyError): - FrozenMapping()["unknown"] - - def test_iter(self): - options = dict(foo="bar", baz=1) - assert set(iter(FrozenMapping(options))) == set(options.keys()) - - def test_len(self): - options = dict(foo="bar", baz=1) - assert len(FrozenMapping(options)) == len(options) - - def test_immutable_setitem(self): - frozen_mapping = FrozenMapping() - - with pytest.raises(RuntimeError, match="immutable"): - frozen_mapping["foo"] = "bar" - - def test_immutable_delitem( - self, - ): - frozen_mapping = FrozenMapping(foo="bar") - - with pytest.raises(RuntimeError, match="immutable"): - del frozen_mapping["foo"] - - def test_eq(self): - options = dict(foo="bar", baz=1) - assert FrozenMapping(options) == FrozenMapping(options) - - def test_ne(self): - options1 = dict(foo="bar", baz=1) - options2 = options1.copy() - options2["baz"] += 1 - - assert FrozenMapping(options1) != FrozenMapping(options2) - - def test_repr(self): - options = dict(foo="bar", baz=1) - output = repr(FrozenMapping(options)) - - assert isinstance(output, str) - for key, value in options.items(): - assert str(key) in output and str(value) in output - - -class TestFrozenBunch: - def test_getattr(self): - options = dict(foo="bar", baz=1) - config = FrozenBunch(options) - - for key, value in options.items(): - assert getattr(config, key) == value - - def test_getattr_unknown(self): - with pytest.raises(AttributeError, match="no attribute 'unknown'"): - datasets.utils.DatasetConfig().unknown - - def test_immutable_setattr(self): - frozen_bunch = FrozenBunch() - - with pytest.raises(RuntimeError, match="immutable"): - frozen_bunch.foo = "bar" - - def test_immutable_delattr( - self, - ): - frozen_bunch = FrozenBunch(foo="bar") - - with pytest.raises(RuntimeError, match="immutable"): - del frozen_bunch.foo - - def test_repr(self): - options = dict(foo="bar", baz=1) - output = repr(FrozenBunch(options)) - - assert isinstance(output, str) - assert output.startswith("FrozenBunch") - for key, value in options.items(): - assert f"{key}={value}" in output - - -class TestDatasetInfo: - @pytest.fixture - def info(self): - return 
make_minimal_dataset_info(valid_options=dict(split=("train", "test"), foo=("bar", "baz"))) - - def test_default_config(self, info): - valid_options = info._valid_options - default_config = datasets.utils.DatasetConfig({key: values[0] for key, values in valid_options.items()}) - - assert info.default_config == default_config - - @pytest.mark.parametrize( - ("valid_options", "options", "expected_error_msg"), - [ - (dict(), dict(any_option=None), "does not take any options"), - (dict(split="train"), dict(unknown_option=None), "Unknown option 'unknown_option'"), - (dict(split="train"), dict(split="invalid_argument"), "Invalid argument 'invalid_argument'"), - ], - ) - def test_make_config_invalid_inputs(self, info, valid_options, options, expected_error_msg): - info = make_minimal_dataset_info(valid_options=valid_options) - - with pytest.raises(ValueError, match=expected_error_msg): - info.make_config(**options) - - def test_check_dependencies(self): - dependency = "fake_dependency" - info = make_minimal_dataset_info(dependencies=(dependency,)) - with pytest.raises(ModuleNotFoundError, match=dependency): - info.check_dependencies() - - def test_repr(self, info): - output = repr(info) - - assert isinstance(output, str) - assert "DatasetInfo" in output - for key, value in info._valid_options.items(): - assert f"{key}={str(value)[1:-1]}" in output - - @pytest.mark.parametrize("optional_info", ("citation", "homepage", "license")) - def test_repr_optional_info(self, optional_info): - sentinel = "sentinel" - info = make_minimal_dataset_info(**{optional_info: sentinel}) - - assert f"{optional_info}={sentinel}" in repr(info) - - -class TestDataset: - class DatasetMock(datasets.utils.Dataset): - def __init__(self, info=None, *, resources=None): - self._info = info or make_minimal_dataset_info(valid_options=dict(split=("train", "test"))) - self.resources = unittest.mock.Mock(return_value=[]) if resources is None else lambda config: resources - self._make_datapipe = unittest.mock.Mock() - super().__init__() - - def _make_info(self): - return self._info - - def resources(self, config): - # This method is just defined to appease the ABC, but will be overwritten at instantiation - pass - - def _make_datapipe(self, resource_dps, *, config): - # This method is just defined to appease the ABC, but will be overwritten at instantiation - pass - - def test_name(self): - name = "sentinel" - dataset = self.DatasetMock(make_minimal_dataset_info(name=name)) - - assert dataset.name == name - - def test_default_config(self): - sentinel = "sentinel" - dataset = self.DatasetMock(info=make_minimal_dataset_info(valid_options=dict(split=(sentinel, "train")))) - - assert dataset.default_config == datasets.utils.DatasetConfig(split=sentinel) - - @pytest.mark.parametrize( - ("config", "kwarg"), - [ - pytest.param(*(datasets.utils.DatasetConfig(split="test"),) * 2, id="specific"), - pytest.param(DatasetMock().default_config, None, id="default"), - ], - ) - def test_load_config(self, config, kwarg): - dataset = self.DatasetMock() - - dataset.load("", config=kwarg) - - dataset.resources.assert_called_with(config) - - _, call_kwargs = dataset._make_datapipe.call_args - assert call_kwargs["config"] == config - - def test_missing_dependencies(self): - dependency = "fake_dependency" - dataset = self.DatasetMock(make_minimal_dataset_info(dependencies=(dependency,))) - with pytest.raises(ModuleNotFoundError, match=dependency): - dataset.load("root") - - def test_resources(self, mocker): - resource_mock = mocker.Mock(spec=["load"]) - 
sentinel = object() - resource_mock.load.return_value = sentinel - dataset = self.DatasetMock(resources=[resource_mock]) - - root = "root" - dataset.load(root) - - (call_args, _) = resource_mock.load.call_args - assert call_args[0] == root - - (call_args, _) = dataset._make_datapipe.call_args - assert call_args[0][0] is sentinel diff --git a/test/test_prototype_datasets_utils.py b/test/test_prototype_datasets_utils.py index bd857abf02f..b1c95844574 100644 --- a/test/test_prototype_datasets_utils.py +++ b/test/test_prototype_datasets_utils.py @@ -5,7 +5,7 @@ import torch from datasets_utils import make_fake_flo_file from torchvision.datasets._optical_flow import _read_flo as read_flo_ref -from torchvision.prototype.datasets.utils import HttpResource, GDriveResource +from torchvision.prototype.datasets.utils import HttpResource, GDriveResource, Dataset from torchvision.prototype.datasets.utils._internal import read_flo, fromfile @@ -101,3 +101,21 @@ def preprocess_sentinel(path): assert redirected_resource.file_name == file_name assert redirected_resource.sha256 == sha256_sentinel assert redirected_resource._preprocess is preprocess_sentinel + + +def test_missing_dependency_error(): + class DummyDataset(Dataset): + def __init__(self): + super().__init__(root="root", dependencies=("fake_dependency",)) + + def _resources(self): + pass + + def _datapipe(self, resource_dps): + pass + + def __len__(self): + pass + + with pytest.raises(ModuleNotFoundError, match="depends on the third-party package 'fake_dependency'"): + DummyDataset() diff --git a/torchvision/prototype/datasets/__init__.py b/torchvision/prototype/datasets/__init__.py index 1253635d51e..848d9135c2f 100644 --- a/torchvision/prototype/datasets/__init__.py +++ b/torchvision/prototype/datasets/__init__.py @@ -10,5 +10,6 @@ from ._home import home # Load this last, since some parts depend on the above being loaded first -from ._api import list_datasets, info, load # usort: skip +from ._api import list_datasets, info, load, register_info, register_dataset # usort: skip from ._folder import from_data_folder, from_image_folder +from ._builtin import * diff --git a/torchvision/prototype/datasets/_api.py b/torchvision/prototype/datasets/_api.py index 13ee920cea2..407dc23f64b 100644 --- a/torchvision/prototype/datasets/_api.py +++ b/torchvision/prototype/datasets/_api.py @@ -1,39 +1,50 @@ -import os -from typing import Any, Dict, List +import pathlib +from typing import Any, Dict, List, Callable, Type, Optional, Union, TypeVar -from torch.utils.data import IterDataPipe from torchvision.prototype.datasets import home -from torchvision.prototype.datasets.utils import Dataset, DatasetInfo +from torchvision.prototype.datasets.utils import Dataset from torchvision.prototype.utils._internal import add_suggestion -from . 
import _builtin -DATASETS: Dict[str, Dataset] = {} +T = TypeVar("T") +D = TypeVar("D", bound=Type[Dataset]) +BUILTIN_INFOS: Dict[str, Dict[str, Any]] = {} -def register(dataset: Dataset) -> None: - DATASETS[dataset.name] = dataset +def register_info(name: str) -> Callable[[Callable[[], Dict[str, Any]]], Callable[[], Dict[str, Any]]]: + def wrapper(fn: Callable[[], Dict[str, Any]]) -> Callable[[], Dict[str, Any]]: + BUILTIN_INFOS[name] = fn() + return fn -for name, obj in _builtin.__dict__.items(): - if not name.startswith("_") and isinstance(obj, type) and issubclass(obj, Dataset) and obj is not Dataset: - register(obj()) + return wrapper + + +BUILTIN_DATASETS = {} + + +def register_dataset(name: str) -> Callable[[D], D]: + def wrapper(dataset_cls: D) -> D: + BUILTIN_DATASETS[name] = dataset_cls + return dataset_cls + + return wrapper def list_datasets() -> List[str]: - return sorted(DATASETS.keys()) + return sorted(BUILTIN_DATASETS.keys()) -def find(name: str) -> Dataset: +def find(dct: Dict[str, T], name: str) -> T: name = name.lower() try: - return DATASETS[name] + return dct[name] except KeyError as error: raise ValueError( add_suggestion( f"Unknown dataset '{name}'.", word=name, - possibilities=DATASETS.keys(), + possibilities=dct.keys(), alternative_hint=lambda _: ( "You can use torchvision.datasets.list_datasets() to get a list of all available datasets." ), @@ -41,19 +52,14 @@ def find(name: str) -> Dataset: ) from error -def info(name: str) -> DatasetInfo: - return find(name).info +def info(name: str) -> Dict[str, Any]: + return find(BUILTIN_INFOS, name) -def load( - name: str, - *, - skip_integrity_check: bool = False, - **options: Any, -) -> IterDataPipe[Dict[str, Any]]: - dataset = find(name) +def load(name: str, *, root: Optional[Union[str, pathlib.Path]] = None, **config: Any) -> Dataset: + dataset_cls = find(BUILTIN_DATASETS, name) - config = dataset.info.make_config(**options) - root = os.path.join(home(), dataset.name) + if root is None: + root = pathlib.Path(home()) / name - return dataset.load(root, config=config, skip_integrity_check=skip_integrity_check) + return dataset_cls(root, **config) diff --git a/torchvision/prototype/datasets/_builtin/README.md b/torchvision/prototype/datasets/_builtin/README.md index c20c0241fac..05d61c6870e 100644 --- a/torchvision/prototype/datasets/_builtin/README.md +++ b/torchvision/prototype/datasets/_builtin/README.md @@ -12,51 +12,66 @@ Finally, `from torchvision.prototype import datasets` is implied below. Before we start with the actual implementation, you should create a module in `torchvision/prototype/datasets/_builtin` that hints at the dataset you are going to add. For example `caltech.py` for `caltech101` and `caltech256`. In that -module create a class that inherits from `datasets.utils.Dataset` and overwrites at minimum three methods that will be -discussed in detail below: +module create a class that inherits from `datasets.utils.Dataset` and overwrites four methods that will be discussed in +detail below: ```python -from typing import Any, Dict, List +import pathlib +from typing import Any, BinaryIO, Dict, List, Tuple, Union from torchdata.datapipes.iter import IterDataPipe -from torchvision.prototype.datasets.utils import Dataset, DatasetInfo, DatasetConfig, OnlineResource +from torchvision.prototype.datasets.utils import Dataset, OnlineResource +from .._api import register_dataset, register_info + +NAME = "my-dataset" + +@register_info(NAME) +def _info() -> Dict[str, Any]: + return dict( + ... 
+ ) + +@register_dataset(NAME) class MyDataset(Dataset): - def _make_info(self) -> DatasetInfo: + def __init__(self, root: Union[str, pathlib.Path], *, ..., skip_integrity_check: bool = False) -> None: ... + super().__init__(root, skip_integrity_check=skip_integrity_check) - def resources(self, config: DatasetConfig) -> List[OnlineResource]: + def _resources(self) -> List[OnlineResource]: ... - def _make_datapipe( - self, resource_dps: List[IterDataPipe], *, config: DatasetConfig, - ) -> IterDataPipe[Dict[str, Any]]: + def _datapipe(self, resource_dps: List[IterDataPipe[Tuple[str, BinaryIO]]]) -> IterDataPipe[Dict[str, Any]]: + ... + + def __len__(self) -> int: ... ``` -### `_make_info(self)` +In addition to the dataset, you also need to implement an `_info()` function that takes no arguments and returns a +dictionary of static information. The most common use case is to provide human-readable categories. +[See below](#how-do-i-handle-a-dataset-that-defines-many-categories) how to handle cases with many categories. -The `DatasetInfo` carries static information about the dataset. There are two required fields: +Finally, both the dataset class and the info function need to be registered on the API with the respective decorators. +With that they are loadable through `datasets.load("my-dataset")` and `datasets.info("my-dataset")`, respectively. -- `name`: Name of the dataset. This will be used to load the dataset with `datasets.load(name)`. Should only contain - lowercase characters. +### `__init__(self, root, *, ..., skip_integrity_check = False)` -There are more optional parameters that can be passed: +Constructor of the dataset that will be called when the dataset is instantiated. In addition to the parameters of the +base class, it can take arbitrary keyword-only parameters with defaults. The checking of these parameters as well as +setting them as instance attributes has to happen before the call of `super().__init__(...)`, because that will invoke +the other methods, which possibly depend on the parameters. All instance attributes must be private, i.e. prefixed with +an underscore. -- `dependencies`: Collection of third-party dependencies that are needed to load the dataset, e.g. `("scipy",)`. Their - availability will be automatically checked if a user tries to load the dataset. Within the implementation, import - these packages lazily to avoid missing dependencies at import time. -- `categories`: Sequence of human-readable category names for each label. The index of each category has to match the - corresponding label returned in the dataset samples. - [See below](#how-do-i-handle-a-dataset-that-defines-many-categories) how to handle cases with many categories. -- `valid_options`: Configures valid options that can be passed to the dataset. It should be `Dict[str, Sequence[Any]]`. - The options are accessible through the `config` namespace in the other two functions. First value of the sequence is - taken as default if the user passes no option to `torchvision.prototype.datasets.load()`. +If the implementation of the dataset depends on third-party packages, pass them as a collection of strings to the base +class constructor, e.g. `super().__init__(..., dependencies=("scipy",))`. Their availability will be automatically +checked if a user tries to load the dataset. Within the implementation of the dataset, import these packages lazily to +avoid missing dependencies at import time. 
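+
+For illustration, a constructor for a dataset with a `split` parameter might look like the following sketch. The
+`split` values and the `scipy` dependency are just examples and have to be adapted to the dataset at hand:
+
+```py
+def __init__(
+    self,
+    root: Union[str, pathlib.Path],
+    *,
+    split: str = "train",
+    skip_integrity_check: bool = False,
+) -> None:
+    # Check and set all custom parameters before calling the base class constructor, since it invokes the other
+    # methods, which might depend on them.
+    self._split = self._verify_str_arg(split, "split", ("train", "val", "test"))
+
+    # Third-party dependencies are declared here, so their availability is checked at instantiation time.
+    super().__init__(root, dependencies=("scipy",), skip_integrity_check=skip_integrity_check)
+```
+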
-## `resources(self, config)`
+### `_resources(self)`
 
-Returns `List[datasets.utils.OnlineResource]` of all the files that need to be present locally before the dataset with a
-specific `config` can be build. The download will happen automatically.
+Returns `List[datasets.utils.OnlineResource]` of all the files that need to be present locally before the dataset can be
+built. The download will happen automatically.
 
 Currently, the following `OnlineResource`'s are supported:
 
@@ -81,7 +96,7 @@ def sha256sum(path, chunk_size=1024 * 1024):
     print(checksum.hexdigest())
 ```
 
-### `_make_datapipe(resource_dps, *, config)`
+### `_datapipe(self, resource_dps)`
 
 This method is the heart of the dataset, where we transform the raw data into a usable form. A major difference compared
 to the current stable datasets is that everything is performed through `IterDataPipe`'s. From the perspective of someone
@@ -99,60 +114,112 @@ All of them can be imported `from torchdata.datapipes.iter`. In addition, use `f
 needs extra arguments. If the provided `IterDataPipe`'s are not sufficient for the use case, it is also not complicated
 to add one. See the MNIST or CelebA datasets for example.
 
-`make_datapipe()` receives `resource_dps`, which is a list of datapipes that has a 1-to-1 correspondence with the return
-value of `resources()`. In case of archives with regular suffixes (`.tar`, `.zip`, ...), the datapipe will contain
-tuples comprised of the path and the handle for every file in the archive. Otherwise the datapipe will only contain one
+`_datapipe()` receives `resource_dps`, which is a list of datapipes that has a 1-to-1 correspondence with the return
+value of `_resources()`. In case of archives with regular suffixes (`.tar`, `.zip`, ...), the datapipe will contain
+tuples comprised of the path and the handle for every file in the archive. Otherwise, the datapipe will only contain one
 of such tuples for the file specified by the resource.
 
 Since the datapipes are iterable in nature, some datapipes feature an in-memory buffer, e.g. `IterKeyZipper` and
-`Grouper`. There are two issues with that: 1. If not used carefully, this can easily overflow the host memory, since
-most datasets will not fit in completely. 2. This can lead to unnecessarily long warm-up times when data is buffered
-that is only needed at runtime.
+`Grouper`. There are two issues with that:
+
+1. If not used carefully, this can easily overflow the host memory, since most datasets will not fit in completely.
+2. This can lead to unnecessarily long warm-up times when data is buffered that is only needed at runtime.
 
 Thus, all buffered datapipes should be used as early as possible, e.g. zipping two datapipes of file handles rather than
 trying to zip already loaded images.
 
 There are two special datapipes that are not used through their class, but through the functions `hint_shuffling` and
-`hint_sharding`. As the name implies they only hint part in the datapipe graph where shuffling and sharding should take
-place, but are no-ops by default. They can be imported from `torchvision.prototype.datasets.utils._internal` and are
-required in each dataset. `hint_shuffling` has to be placed before `hint_sharding`.
+`hint_sharding`. As the name implies, they only hint at a location in the datapipe graph where shuffling and sharding
+should take place, but are no-ops by default. They can be imported from `torchvision.prototype.datasets.utils._internal`
+and are required in each dataset. `hint_shuffling` has to be placed before `hint_sharding`.
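+
+For example, a dataset with a single archive resource could implement `_datapipe()` roughly like this, where
+`self._filter_images` and `self._prepare_sample` are placeholders for dataset-specific logic:
+
+```py
+def _datapipe(self, resource_dps: List[IterDataPipe[Tuple[str, BinaryIO]]]) -> IterDataPipe[Dict[str, Any]]:
+    # only one resource was returned from _resources(), e.g. a single archive containing all images
+    dp = resource_dps[0]
+    # filter on file handles as early as possible rather than on decoded images
+    dp = Filter(dp, self._filter_images)
+    # both hints are no-ops by default; hint_shuffling has to come before hint_sharding
+    dp = hint_shuffling(dp)
+    dp = hint_sharding(dp)
+    # each sample is a dictionary with str keys
+    return Mapper(dp, self._prepare_sample)
+```
+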
Finally, each item in the final datapipe should be a dictionary with `str` keys. There is no standardization of the names (yet!). +### `__len__` + +This returns an integer denoting the number of samples that can be drawn from the dataset. Please use +[underscores](https://peps.python.org/pep-0515/) after every three digits starting from the right to enhance the +readability. For example, `1_281_167` vs. `1281167`. + +If there are only two different numbers, a simple `if` / `else` is fine: + +```py +def __len__(self): + return 12_345 if self._split == "train" else 6_789 +``` + +If there are more options, using a dictionary usually is the most readable option: + +```py +def __len__(self): + return { + "train": 3, + "val": 2, + "test": 1, + }[self._split] +``` + +If the number of samples depends on more than one parameter, you can use tuples as dictionary keys: + +```py +def __len__(self): + return { + ("train", "bar"): 4, + ("train", "baz"): 3, + ("test", "bar"): 2, + ("test", "baz"): 1, + }[(self._split, self._foo)] +``` + +The length of the datapipe is only an annotation for subsequent processing of the datapipe and not needed during the +development process. Since it is an `@abstractmethod` you still have to implement it from the start. The canonical way +is to define a dummy method like + +```py +def __len__(self): + return 1 +``` + +and only fill it with the correct data if the implementation is otherwise finished. +[See below](#how-do-i-compute-the-number-of-samples) for a possible way to compute the number of samples. + ## Tests To test the dataset implementation, you usually don't need to add any tests, but need to provide a mock-up of the data. This mock-up should resemble the original data as close as necessary, while containing only few examples. To do this, add a new function in [`test/builtin_dataset_mocks.py`](../../../../test/builtin_dataset_mocks.py) with the -same name as you have defined in `_make_config()` (if the name includes hyphens `-`, replace them with underscores `_`) -and decorate it with `@register_mock`: +same name as you have used in `@register_info` and `@register_dataset`. This function is called "mock data function". +Decorate it with `@register_mock(configs=[dict(...), ...])`. Each dictionary denotes one configuration that the dataset +will be loaded with, e.g. `datasets.load("my-dataset", **config)`. For the most common case of a product of all options, +you can use the `combinations_grid()` helper function, e.g. +`configs=combinations_grid(split=("train", "test"), foo=("bar", "baz"))`. + +In case the name of the dataset includes hyphens `-`, replace them with underscores `_` in the function name and pass +the `name` parameter to `@register_mock` ```py # this is defined in torchvision/prototype/datasets/_builtin +@register_dataset("my-dataset") class MyDataset(Dataset): - def _make_info(self) -> DatasetInfo: - return DatasetInfo( - "my-dataset", - ... - ) - -@register_mock -def my_dataset(info, root, config): + ... + +@register_mock(name="my-dataset", configs=...) +def my_dataset(root, config): ... ``` -The function receives three arguments: +The mock data function receives two arguments: -- `info`: The return value of `_make_info()`. - `root`: A [`pathlib.Path`](https://docs.python.org/3/library/pathlib.html#pathlib.Path) of a folder, in which the data needs to be placed. -- `config`: The configuration to generate the data for. This is the same value that `_make_datapipe()` receives. +- `config`: The configuration to generate the data for. 
This is one of the dictionaries defined in + `@register_mock(configs=...)` The function should generate all files that are needed for the current `config`. Each file should be complete, e.g. if -the dataset only has a single archive that contains multiple splits, you need to generate all regardless of the current -`config`. Although this seems odd at first, this is important. Consider the following original data setup: +the dataset only has a single archive that contains multiple splits, you need to generate the full archive regardless of +the current `config`. Although this seems odd at first, this is important. Consider the following original data setup: ``` root @@ -167,9 +234,8 @@ root For map-style datasets (like the one currently in `torchvision.datasets`), one explicitly selects the files they want to load. For example, something like `(root / split).iterdir()` works fine even if only the specific split folder is present. With iterable-style datasets though, we get something like `root.iterdir()` from `resource_dps` in -`_make_datapipe()` and need to manually `Filter` it to only keep the files we want. If we would only generate the data -for the current `config`, the test would also pass if the dataset is missing the filtering, but would fail on the real -data. +`_datapipe()` and need to manually `Filter` it to only keep the files we want. If we would only generate the data for +the current `config`, the test would also pass if the dataset is missing the filtering, but would fail on the real data. For datasets that are ported from the old API, we already have some mock data in [`test/test_datasets.py`](../../../../test/test_datasets.py). You can find the test case corresponding test case there @@ -178,8 +244,6 @@ and have a look at the `inject_fake_data` function. There are a few differences - `tmp_dir` corresponds to `root`, but is a `str` rather than a [`pathlib.Path`](https://docs.python.org/3/library/pathlib.html#pathlib.Path). Thus, you often see something like `folder = pathlib.Path(tmp_dir)`. This is not needed. -- Although both parameters are called `config`, the value in the new tests is a namespace. Thus, please use `config.foo` - over `config["foo"]` to enhance readability. - The data generated by `inject_fake_data` was supposed to be in an extracted state. This is no longer the case for the new mock-ups. Thus, you need to use helper functions like `make_zip` or `make_tar` to actually generate the files specified in the dataset. @@ -196,9 +260,9 @@ Finally, you can run the tests with `pytest test/test_prototype_builtin_datasets ### How do I start? -Get the skeleton of your dataset class ready with all 3 methods. For `_make_datapipe()`, you can just do +Get the skeleton of your dataset class ready with all 4 methods. For `_datapipe()`, you can just do `return resources_dp[0]` to get started. Then import the dataset class in -`torchvision/prototype/datasets/_builtin/__init__.py`: this will automatically register the dataset and it will be +`torchvision/prototype/datasets/_builtin/__init__.py`: this will automatically register the dataset, and it will be instantiable via `datasets.load("mydataset")`. 
On a separate script, try something like ```py @@ -206,7 +270,7 @@ from torchvision.prototype import datasets dataset = datasets.load("mydataset") for sample in dataset: - print(sample) # this is the content of an item in datapipe returned by _make_datapipe() + print(sample) # this is the content of an item in datapipe returned by _datapipe() break # Or you can also inspect the sample in a debugger ``` @@ -217,15 +281,24 @@ datapipes and return the appropriate dictionary format. ### How do I handle a dataset that defines many categories? -As a rule of thumb, `datasets.utils.DatasetInfo(..., categories=)` should only be set directly for ten categories or -fewer. If more categories are needed, you can add a `$NAME.categories` file to the `_builtin` folder in which each line -specifies a category. If `$NAME` matches the name of the dataset (which it definitively should!) it will be -automatically loaded if `categories=` is not set. +As a rule of thumb, `categories` in the info dictionary should only be set manually for ten categories or fewer. If more +categories are needed, you can add a `$NAME.categories` file to the `_builtin` folder in which each line specifies a +category. To load such a file, use the `from torchvision.prototype.datasets.utils._internal import read_categories_file` +function and pass it `$NAME`. In case the categories can be generated from the dataset files, e.g. the dataset follows an image folder approach where -each folder denotes the name of the category, the dataset can overwrite the `_generate_categories` method. It gets -passed the `root` path to the resources, but they have to be manually loaded, e.g. -`self.resources(config)[0].load(root)`. The method should return a sequence of strings representing the category names. +each folder denotes the name of the category, the dataset can overwrite the `_generate_categories` method. The method +should return a sequence of strings representing the category names. In the method body, you'll have to manually load +the resources, e.g. + +```py +resources = self._resources() +dp = resources[0].load(self._root) +``` + +Note that it is not necessary here to keep a datapipe until the final step. Stick with datapipes as long as it makes +sense and afterwards materialize the data with `next(iter(dp))` or `list(dp)` and proceed with that. + To generate the `$NAME.categories` file, run `python -m torchvision.prototype.datasets.generate_category_files $NAME`. ### What if a resource file forms an I/O bottleneck? @@ -235,3 +308,33 @@ the performance hit becomes significant, the archives can still be preprocessed. `preprocess` parameter that can be a `Callable[[pathlib.Path], pathlib.Path]` where the input points to the file to be preprocessed and the return value should be the result of the preprocessing to load. For convenience, `preprocess` also accepts `"decompress"` and `"extract"` to handle these common scenarios. + +### How do I compute the number of samples? 
+ +Unless the authors of the dataset published the exact numbers (even in this case we should check), there is no other way +than to iterate over the dataset and count the number of samples: + +```py +import itertools +from torchvision.prototype import datasets + + +def combinations_grid(**kwargs): + return [dict(zip(kwargs.keys(), values)) for values in itertools.product(*kwargs.values())] + + +# If you have implemented the mock data function for the dataset tests, you can simply copy-paste from there +configs = combinations_grid(split=("train", "test"), foo=("bar", "baz")) + +for config in configs: + dataset = datasets.load("my-dataset", **config) + + num_samples = 0 + for _ in dataset: + num_samples += 1 + + print(", ".join(f"{key}={value}" for key, value in config.items()), num_samples) +``` + +To speed this up, it is useful to temporarily comment out all unnecessary I/O, such as loading of images or annotation +files. diff --git a/torchvision/prototype/datasets/_builtin/__init__.py b/torchvision/prototype/datasets/_builtin/__init__.py index b2beddc7f2b..4acc1d53b4d 100644 --- a/torchvision/prototype/datasets/_builtin/__init__.py +++ b/torchvision/prototype/datasets/_builtin/__init__.py @@ -12,7 +12,7 @@ from .gtsrb import GTSRB from .imagenet import ImageNet from .mnist import MNIST, FashionMNIST, KMNIST, EMNIST, QMNIST -from .oxford_iiit_pet import OxfordIITPet +from .oxford_iiit_pet import OxfordIIITPet from .pcam import PCAM from .sbd import SBD from .semeion import SEMEION diff --git a/torchvision/prototype/datasets/_builtin/caltech.py b/torchvision/prototype/datasets/_builtin/caltech.py index 4a409835b5e..7010ab9503d 100644 --- a/torchvision/prototype/datasets/_builtin/caltech.py +++ b/torchvision/prototype/datasets/_builtin/caltech.py @@ -1,6 +1,6 @@ import pathlib import re -from typing import Any, Dict, List, Tuple, BinaryIO +from typing import Any, Dict, List, Tuple, BinaryIO, Union import numpy as np from torchdata.datapipes.iter import ( @@ -9,26 +9,46 @@ Filter, IterKeyZipper, ) -from torchvision.prototype.datasets.utils import ( - Dataset, - DatasetConfig, - DatasetInfo, - HttpResource, - OnlineResource, +from torchvision.prototype.datasets.utils import Dataset, HttpResource, OnlineResource +from torchvision.prototype.datasets.utils._internal import ( + INFINITE_BUFFER_SIZE, + read_mat, + hint_sharding, + hint_shuffling, + read_categories_file, ) -from torchvision.prototype.datasets.utils._internal import INFINITE_BUFFER_SIZE, read_mat, hint_sharding, hint_shuffling from torchvision.prototype.features import Label, BoundingBox, _Feature, EncodedImage +from .._api import register_dataset, register_info + +@register_info("caltech101") +def _caltech101_info() -> Dict[str, Any]: + return dict(categories=read_categories_file("caltech101")) + + +@register_dataset("caltech101") class Caltech101(Dataset): - def _make_info(self) -> DatasetInfo: - return DatasetInfo( - "caltech101", + """ + - **homepage**: http://www.vision.caltech.edu/Image_Datasets/Caltech101 + - **dependencies**: + - _ + """ + + def __init__( + self, + root: Union[str, pathlib.Path], + skip_integrity_check: bool = False, + ) -> None: + self._categories = _caltech101_info()["categories"] + + super().__init__( + root, dependencies=("scipy",), - homepage="http://www.vision.caltech.edu/Image_Datasets/Caltech101", + skip_integrity_check=skip_integrity_check, ) - def resources(self, config: DatasetConfig) -> List[OnlineResource]: + def _resources(self) -> List[OnlineResource]: images = HttpResource( 
"http://www.vision.caltech.edu/Image_Datasets/Caltech101/101_ObjectCategories.tar.gz", sha256="af6ece2f339791ca20f855943d8b55dd60892c0a25105fcd631ee3d6430f9926", @@ -88,7 +108,7 @@ def _prepare_sample( ann = read_mat(ann_buffer) return dict( - label=Label.from_category(category, categories=self.categories), + label=Label.from_category(category, categories=self._categories), image_path=image_path, image=image, ann_path=ann_path, @@ -98,12 +118,7 @@ def _prepare_sample( contour=_Feature(ann["obj_contour"].T), ) - def _make_datapipe( - self, - resource_dps: List[IterDataPipe], - *, - config: DatasetConfig, - ) -> IterDataPipe[Dict[str, Any]]: + def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]: images_dp, anns_dp = resource_dps images_dp = Filter(images_dp, self._is_not_background_image) @@ -122,23 +137,39 @@ def _make_datapipe( ) return Mapper(dp, self._prepare_sample) - def _generate_categories(self, root: pathlib.Path) -> List[str]: - resources = self.resources(self.default_config) + def __len__(self) -> int: + return 8677 + + def _generate_categories(self) -> List[str]: + resources = self._resources() - dp = resources[0].load(root) + dp = resources[0].load(self._root) dp = Filter(dp, self._is_not_background_image) return sorted({pathlib.Path(path).parent.name for path, _ in dp}) +@register_info("caltech256") +def _caltech256_info() -> Dict[str, Any]: + return dict(categories=read_categories_file("caltech256")) + + +@register_dataset("caltech256") class Caltech256(Dataset): - def _make_info(self) -> DatasetInfo: - return DatasetInfo( - "caltech256", - homepage="http://www.vision.caltech.edu/Image_Datasets/Caltech256", - ) + """ + - **homepage**: http://www.vision.caltech.edu/Image_Datasets/Caltech256 + """ + + def __init__( + self, + root: Union[str, pathlib.Path], + skip_integrity_check: bool = False, + ) -> None: + self._categories = _caltech256_info()["categories"] + + super().__init__(root, skip_integrity_check=skip_integrity_check) - def resources(self, config: DatasetConfig) -> List[OnlineResource]: + def _resources(self) -> List[OnlineResource]: return [ HttpResource( "http://www.vision.caltech.edu/Image_Datasets/Caltech256/256_ObjectCategories.tar", @@ -156,25 +187,23 @@ def _prepare_sample(self, data: Tuple[str, BinaryIO]) -> Dict[str, Any]: return dict( path=path, image=EncodedImage.from_file(buffer), - label=Label(int(pathlib.Path(path).parent.name.split(".", 1)[0]) - 1, categories=self.categories), + label=Label(int(pathlib.Path(path).parent.name.split(".", 1)[0]) - 1, categories=self._categories), ) - def _make_datapipe( - self, - resource_dps: List[IterDataPipe], - *, - config: DatasetConfig, - ) -> IterDataPipe[Dict[str, Any]]: + def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]: dp = resource_dps[0] dp = Filter(dp, self._is_not_rogue_file) dp = hint_shuffling(dp) dp = hint_sharding(dp) return Mapper(dp, self._prepare_sample) - def _generate_categories(self, root: pathlib.Path) -> List[str]: - resources = self.resources(self.default_config) + def __len__(self) -> int: + return 30607 + + def _generate_categories(self) -> List[str]: + resources = self._resources() - dp = resources[0].load(root) + dp = resources[0].load(self._root) dir_names = {pathlib.Path(path).parent.name for path, _ in dp} return [name.split(".")[1] for name in sorted(dir_names)] diff --git a/torchvision/prototype/datasets/_builtin/celeba.py b/torchvision/prototype/datasets/_builtin/celeba.py index 854c705b746..46ccf8de6f7 100644 --- 
a/torchvision/prototype/datasets/_builtin/celeba.py +++ b/torchvision/prototype/datasets/_builtin/celeba.py @@ -1,6 +1,6 @@ import csv -import functools -from typing import Any, Dict, List, Optional, Tuple, Iterator, Sequence, BinaryIO +import pathlib +from typing import Any, Dict, List, Optional, Tuple, Iterator, Sequence, BinaryIO, Union from torchdata.datapipes.iter import ( IterDataPipe, @@ -11,8 +11,6 @@ ) from torchvision.prototype.datasets.utils import ( Dataset, - DatasetConfig, - DatasetInfo, GDriveResource, OnlineResource, ) @@ -25,6 +23,7 @@ ) from torchvision.prototype.features import EncodedImage, _Feature, Label, BoundingBox +from .._api import register_dataset, register_info csv.register_dialect("celeba", delimiter=" ", skipinitialspace=True) @@ -60,15 +59,32 @@ def __iter__(self) -> Iterator[Tuple[str, Dict[str, str]]]: yield line.pop("image_id"), line +NAME = "celeba" + + +@register_info(NAME) +def _info() -> Dict[str, Any]: + return dict() + + +@register_dataset(NAME) class CelebA(Dataset): - def _make_info(self) -> DatasetInfo: - return DatasetInfo( - "celeba", - homepage="https://mmlab.ie.cuhk.edu.hk/projects/CelebA.html", - valid_options=dict(split=("train", "val", "test")), - ) + """ + - **homepage**: https://mmlab.ie.cuhk.edu.hk/projects/CelebA.html + """ - def resources(self, config: DatasetConfig) -> List[OnlineResource]: + def __init__( + self, + root: Union[str, pathlib.Path], + *, + split: str = "train", + skip_integrity_check: bool = False, + ) -> None: + self._split = self._verify_str_arg(split, "split", ("train", "val", "test")) + + super().__init__(root, skip_integrity_check=skip_integrity_check) + + def _resources(self) -> List[OnlineResource]: splits = GDriveResource( "0B7EVK8r0v71pY0NSMzRuSXJEVkk", sha256="fc955bcb3ef8fbdf7d5640d9a8693a8431b5f2ee291a5c1449a1549e7e073fe7", @@ -101,14 +117,13 @@ def resources(self, config: DatasetConfig) -> List[OnlineResource]: ) return [splits, images, identities, attributes, bounding_boxes, landmarks] - _SPLIT_ID_TO_NAME = { - "0": "train", - "1": "val", - "2": "test", - } - - def _filter_split(self, data: Tuple[str, Dict[str, str]], *, split: str) -> bool: - return self._SPLIT_ID_TO_NAME[data[1]["split_id"]] == split + def _filter_split(self, data: Tuple[str, Dict[str, str]]) -> bool: + split_id = { + "train": "0", + "val": "1", + "test": "2", + }[self._split] + return data[1]["split_id"] == split_id def _prepare_sample( self, @@ -145,16 +160,11 @@ def _prepare_sample( }, ) - def _make_datapipe( - self, - resource_dps: List[IterDataPipe], - *, - config: DatasetConfig, - ) -> IterDataPipe[Dict[str, Any]]: + def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]: splits_dp, images_dp, identities_dp, attributes_dp, bounding_boxes_dp, landmarks_dp = resource_dps splits_dp = CelebACSVParser(splits_dp, fieldnames=("image_id", "split_id")) - splits_dp = Filter(splits_dp, functools.partial(self._filter_split, split=config.split)) + splits_dp = Filter(splits_dp, self._filter_split) splits_dp = hint_shuffling(splits_dp) splits_dp = hint_sharding(splits_dp) @@ -186,3 +196,10 @@ def _make_datapipe( buffer_size=INFINITE_BUFFER_SIZE, ) return Mapper(dp, self._prepare_sample) + + def __len__(self) -> int: + return { + "train": 162_770, + "val": 19_867, + "test": 19_962, + }[self._split] diff --git a/torchvision/prototype/datasets/_builtin/cifar.py b/torchvision/prototype/datasets/_builtin/cifar.py index 3d7acefb903..514938d6e5f 100644 --- a/torchvision/prototype/datasets/_builtin/cifar.py +++ 
b/torchvision/prototype/datasets/_builtin/cifar.py @@ -1,9 +1,8 @@ import abc -import functools import io import pathlib import pickle -from typing import Any, Dict, List, Optional, Tuple, Iterator, cast, BinaryIO +from typing import Any, Dict, List, Optional, Tuple, Iterator, cast, BinaryIO, Union import numpy as np from torchdata.datapipes.iter import ( @@ -11,20 +10,17 @@ Filter, Mapper, ) -from torchvision.prototype.datasets.utils import ( - Dataset, - DatasetConfig, - DatasetInfo, - HttpResource, - OnlineResource, -) +from torchvision.prototype.datasets.utils import Dataset, HttpResource, OnlineResource from torchvision.prototype.datasets.utils._internal import ( hint_shuffling, path_comparator, hint_sharding, + read_categories_file, ) from torchvision.prototype.features import Label, Image +from .._api import register_dataset, register_info + class CifarFileReader(IterDataPipe[Tuple[np.ndarray, int]]): def __init__(self, datapipe: IterDataPipe[Dict[str, Any]], *, labels_key: str) -> None: @@ -44,19 +40,23 @@ class _CifarBase(Dataset): _LABELS_KEY: str _META_FILE_NAME: str _CATEGORIES_KEY: str + _categories: List[str] + + def __init__( + self, + root: Union[str, pathlib.Path], + *, + split: str = "train", + skip_integrity_check: bool = False, + ) -> None: + self._split = self._verify_str_arg(split, "split", ("train", "test")) + super().__init__(root, skip_integrity_check=skip_integrity_check) @abc.abstractmethod - def _is_data_file(self, data: Tuple[str, BinaryIO], *, split: str) -> Optional[int]: + def _is_data_file(self, data: Tuple[str, BinaryIO]) -> Optional[int]: pass - def _make_info(self) -> DatasetInfo: - return DatasetInfo( - type(self).__name__.lower(), - homepage="https://www.cs.toronto.edu/~kriz/cifar.html", - valid_options=dict(split=("train", "test")), - ) - - def resources(self, config: DatasetConfig) -> List[OnlineResource]: + def _resources(self) -> List[OnlineResource]: return [ HttpResource( f"https://www.cs.toronto.edu/~kriz/{self._FILE_NAME}", @@ -72,52 +72,72 @@ def _prepare_sample(self, data: Tuple[np.ndarray, int]) -> Dict[str, Any]: image_array, category_idx = data return dict( image=Image(image_array), - label=Label(category_idx, categories=self.categories), + label=Label(category_idx, categories=self._categories), ) - def _make_datapipe( - self, - resource_dps: List[IterDataPipe], - *, - config: DatasetConfig, - ) -> IterDataPipe[Dict[str, Any]]: + def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]: dp = resource_dps[0] - dp = Filter(dp, functools.partial(self._is_data_file, split=config.split)) + dp = Filter(dp, self._is_data_file) dp = Mapper(dp, self._unpickle) dp = CifarFileReader(dp, labels_key=self._LABELS_KEY) dp = hint_shuffling(dp) dp = hint_sharding(dp) return Mapper(dp, self._prepare_sample) - def _generate_categories(self, root: pathlib.Path) -> List[str]: - resources = self.resources(self.default_config) + def __len__(self) -> int: + return 50_000 if self._split == "train" else 10_000 + + def _generate_categories(self) -> List[str]: + resources = self._resources() - dp = resources[0].load(root) + dp = resources[0].load(self._root) dp = Filter(dp, path_comparator("name", self._META_FILE_NAME)) dp = Mapper(dp, self._unpickle) return cast(List[str], next(iter(dp))[self._CATEGORIES_KEY]) +@register_info("cifar10") +def _cifar10_info() -> Dict[str, Any]: + return dict(categories=read_categories_file("cifar10")) + + +@register_dataset("cifar10") class Cifar10(_CifarBase): + """ + - **homepage**: 
https://www.cs.toronto.edu/~kriz/cifar.html + """ + _FILE_NAME = "cifar-10-python.tar.gz" _SHA256 = "6d958be074577803d12ecdefd02955f39262c83c16fe9348329d7fe0b5c001ce" _LABELS_KEY = "labels" _META_FILE_NAME = "batches.meta" _CATEGORIES_KEY = "label_names" + _categories = _cifar10_info()["categories"] - def _is_data_file(self, data: Tuple[str, Any], *, split: str) -> bool: + def _is_data_file(self, data: Tuple[str, Any]) -> bool: path = pathlib.Path(data[0]) - return path.name.startswith("data" if split == "train" else "test") + return path.name.startswith("data" if self._split == "train" else "test") + +@register_info("cifar100") +def _cifar100_info() -> Dict[str, Any]: + return dict(categories=read_categories_file("cifar100")) + +@register_dataset("cifar100") class Cifar100(_CifarBase): + """ + - **homepage**: https://www.cs.toronto.edu/~kriz/cifar.html + """ + _FILE_NAME = "cifar-100-python.tar.gz" _SHA256 = "85cd44d02ba6437773c5bbd22e183051d648de2e7d6b014e1ef29b855ba677a7" _LABELS_KEY = "fine_labels" _META_FILE_NAME = "meta" _CATEGORIES_KEY = "fine_label_names" + _categories = _cifar100_info()["categories"] - def _is_data_file(self, data: Tuple[str, Any], *, split: str) -> bool: + def _is_data_file(self, data: Tuple[str, Any]) -> bool: path = pathlib.Path(data[0]) - return path.name == split + return path.name == self._split diff --git a/torchvision/prototype/datasets/_builtin/clevr.py b/torchvision/prototype/datasets/_builtin/clevr.py index dd08a257a5b..3a139787c6f 100644 --- a/torchvision/prototype/datasets/_builtin/clevr.py +++ b/torchvision/prototype/datasets/_builtin/clevr.py @@ -1,14 +1,8 @@ import pathlib -from typing import Any, Dict, List, Optional, Tuple, BinaryIO +from typing import Any, Dict, List, Optional, Tuple, BinaryIO, Union from torchdata.datapipes.iter import IterDataPipe, Mapper, Filter, IterKeyZipper, Demultiplexer, JsonParser, UnBatcher -from torchvision.prototype.datasets.utils import ( - Dataset, - DatasetConfig, - DatasetInfo, - HttpResource, - OnlineResource, -) +from torchvision.prototype.datasets.utils import Dataset, HttpResource, OnlineResource from torchvision.prototype.datasets.utils._internal import ( INFINITE_BUFFER_SIZE, hint_sharding, @@ -19,16 +13,30 @@ ) from torchvision.prototype.features import Label, EncodedImage +from .._api import register_dataset, register_info + +NAME = "clevr" + + +@register_info(NAME) +def _info() -> Dict[str, Any]: + return dict() + +@register_dataset(NAME) class CLEVR(Dataset): - def _make_info(self) -> DatasetInfo: - return DatasetInfo( - "clevr", - homepage="https://cs.stanford.edu/people/jcjohns/clevr/", - valid_options=dict(split=("train", "val", "test")), - ) + """ + - **homepage**: https://cs.stanford.edu/people/jcjohns/clevr/ + """ + + def __init__( + self, root: Union[str, pathlib.Path], *, split: str = "train", skip_integrity_check: bool = False + ) -> None: + self._split = self._verify_str_arg(split, "split", ("train", "val", "test")) - def resources(self, config: DatasetConfig) -> List[OnlineResource]: + super().__init__(root, skip_integrity_check=skip_integrity_check) + + def _resources(self) -> List[OnlineResource]: archive = HttpResource( "https://dl.fbaipublicfiles.com/clevr/CLEVR_v1.0.zip", sha256="5cd61cf1096ed20944df93c9adb31e74d189b8459a94f54ba00090e5c59936d1", @@ -61,12 +69,7 @@ def _prepare_sample(self, data: Tuple[Tuple[str, BinaryIO], Optional[Dict[str, A label=Label(len(scenes_data["objects"])) if scenes_data else None, ) - def _make_datapipe( - self, - resource_dps: List[IterDataPipe], - *, - 
config: DatasetConfig, - ) -> IterDataPipe[Dict[str, Any]]: + def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]: archive_dp = resource_dps[0] images_dp, scenes_dp = Demultiplexer( archive_dp, @@ -76,12 +79,12 @@ def _make_datapipe( buffer_size=INFINITE_BUFFER_SIZE, ) - images_dp = Filter(images_dp, path_comparator("parent.name", config.split)) + images_dp = Filter(images_dp, path_comparator("parent.name", self._split)) images_dp = hint_shuffling(images_dp) images_dp = hint_sharding(images_dp) - if config.split != "test": - scenes_dp = Filter(scenes_dp, path_comparator("name", f"CLEVR_{config.split}_scenes.json")) + if self._split != "test": + scenes_dp = Filter(scenes_dp, path_comparator("name", f"CLEVR_{self._split}_scenes.json")) scenes_dp = JsonParser(scenes_dp) scenes_dp = Mapper(scenes_dp, getitem(1, "scenes")) scenes_dp = UnBatcher(scenes_dp) @@ -97,3 +100,6 @@ def _make_datapipe( dp = Mapper(images_dp, self._add_empty_anns) return Mapper(dp, self._prepare_sample) + + def __len__(self) -> int: + return 70_000 if self._split == "train" else 15_000 diff --git a/torchvision/prototype/datasets/_builtin/coco.py b/torchvision/prototype/datasets/_builtin/coco.py index 1005c7b3130..ff3b5f37c96 100644 --- a/torchvision/prototype/datasets/_builtin/coco.py +++ b/torchvision/prototype/datasets/_builtin/coco.py @@ -1,8 +1,8 @@ -import functools import pathlib import re from collections import OrderedDict -from typing import Any, Dict, List, Optional, Tuple, cast, BinaryIO +from collections import defaultdict +from typing import Any, Dict, List, Optional, Tuple, cast, BinaryIO, Union import torch from torchdata.datapipes.iter import ( @@ -16,43 +16,65 @@ UnBatcher, ) from torchvision.prototype.datasets.utils import ( - Dataset, - DatasetConfig, - DatasetInfo, HttpResource, OnlineResource, + Dataset, ) from torchvision.prototype.datasets.utils._internal import ( MappingIterator, INFINITE_BUFFER_SIZE, - BUILTIN_DIR, getitem, + read_categories_file, path_accessor, hint_sharding, hint_shuffling, ) from torchvision.prototype.features import BoundingBox, Label, _Feature, EncodedImage -from torchvision.prototype.utils._internal import FrozenMapping +from .._api import register_dataset, register_info + + +NAME = "coco" + + +@register_info(NAME) +def _info() -> Dict[str, Any]: + categories, super_categories = zip(*read_categories_file(NAME)) + return dict(categories=categories, super_categories=super_categories) + +@register_dataset(NAME) class Coco(Dataset): - def _make_info(self) -> DatasetInfo: - name = "coco" - categories, super_categories = zip(*DatasetInfo.read_categories_file(BUILTIN_DIR / f"{name}.categories")) - - return DatasetInfo( - name, - dependencies=("pycocotools",), - categories=categories, - homepage="https://cocodataset.org/", - valid_options=dict( - split=("train", "val"), - year=("2017", "2014"), - annotations=(*self._ANN_DECODERS.keys(), None), - ), - extra=dict(category_to_super_category=FrozenMapping(zip(categories, super_categories))), + """ + - **homepage**: https://cocodataset.org/ + - **dependencies**: + - _ + """ + + def __init__( + self, + root: Union[str, pathlib.Path], + *, + split: str = "train", + year: str = "2017", + annotations: Optional[str] = "instances", + skip_integrity_check: bool = False, + ) -> None: + self._split = self._verify_str_arg(split, "split", {"train", "val"}) + self._year = self._verify_str_arg(year, "year", {"2017", "2014"}) + self._annotations = ( + self._verify_str_arg(annotations, "annotations", 
self._ANN_DECODERS.keys()) + if annotations is not None + else None ) + info = _info() + categories, super_categories = info["categories"], info["super_categories"] + self._categories = categories + self._category_to_super_category = dict(zip(categories, super_categories)) + + super().__init__(root, dependencies=("pycocotools",), skip_integrity_check=skip_integrity_check) + _IMAGE_URL_BASE = "http://images.cocodataset.org/zips" _IMAGES_CHECKSUMS = { @@ -69,14 +91,14 @@ def _make_info(self) -> DatasetInfo: "2017": "113a836d90195ee1f884e704da6304dfaaecff1f023f49b6ca93c4aaae470268", } - def resources(self, config: DatasetConfig) -> List[OnlineResource]: + def _resources(self) -> List[OnlineResource]: images = HttpResource( - f"{self._IMAGE_URL_BASE}/{config.split}{config.year}.zip", - sha256=self._IMAGES_CHECKSUMS[(config.year, config.split)], + f"{self._IMAGE_URL_BASE}/{self._split}{self._year}.zip", + sha256=self._IMAGES_CHECKSUMS[(self._year, self._split)], ) meta = HttpResource( - f"{self._META_URL_BASE}/annotations_trainval{config.year}.zip", - sha256=self._META_CHECKSUMS[config.year], + f"{self._META_URL_BASE}/annotations_trainval{self._year}.zip", + sha256=self._META_CHECKSUMS[self._year], ) return [images, meta] @@ -110,10 +132,8 @@ def _decode_instances_anns(self, anns: List[Dict[str, Any]], image_meta: Dict[st format="xywh", image_size=image_size, ), - labels=Label(labels, categories=self.categories), - super_categories=[ - self.info.extra.category_to_super_category[self.info.categories[label]] for label in labels - ], + labels=Label(labels, categories=self._categories), + super_categories=[self._category_to_super_category[self._categories[label]] for label in labels], ann_ids=[ann["id"] for ann in anns], ) @@ -134,9 +154,14 @@ def _decode_captions_ann(self, anns: List[Dict[str, Any]], image_meta: Dict[str, fr"(?P({'|'.join(_ANN_DECODERS.keys())}))_(?P[a-zA-Z]+)(?P\d+)[.]json" ) - def _filter_meta_files(self, data: Tuple[str, Any], *, split: str, year: str, annotations: str) -> bool: + def _filter_meta_files(self, data: Tuple[str, Any]) -> bool: match = self._META_FILE_PATTERN.match(pathlib.Path(data[0]).name) - return bool(match and match["split"] == split and match["year"] == year and match["annotations"] == annotations) + return bool( + match + and match["split"] == self._split + and match["year"] == self._year + and match["annotations"] == self._annotations + ) def _classify_meta(self, data: Tuple[str, Any]) -> Optional[int]: key, _ = data @@ -157,38 +182,26 @@ def _prepare_image(self, data: Tuple[str, BinaryIO]) -> Dict[str, Any]: def _prepare_sample( self, data: Tuple[Tuple[List[Dict[str, Any]], Dict[str, Any]], Tuple[str, BinaryIO]], - *, - annotations: str, ) -> Dict[str, Any]: ann_data, image_data = data anns, image_meta = ann_data sample = self._prepare_image(image_data) + # this method is only called if we have annotations + annotations = cast(str, self._annotations) sample.update(self._ANN_DECODERS[annotations](self, anns, image_meta)) return sample - def _make_datapipe( - self, - resource_dps: List[IterDataPipe], - *, - config: DatasetConfig, - ) -> IterDataPipe[Dict[str, Any]]: + def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]: images_dp, meta_dp = resource_dps - if config.annotations is None: + if self._annotations is None: dp = hint_shuffling(images_dp) dp = hint_sharding(dp) + dp = hint_shuffling(dp) return Mapper(dp, self._prepare_image) - meta_dp = Filter( - meta_dp, - functools.partial( - self._filter_meta_files, - 
split=config.split, - year=config.year, - annotations=config.annotations, - ), - ) + meta_dp = Filter(meta_dp, self._filter_meta_files) meta_dp = JsonParser(meta_dp) meta_dp = Mapper(meta_dp, getitem(1)) meta_dp: IterDataPipe[Dict[str, Dict[str, Any]]] = MappingIterator(meta_dp) @@ -216,7 +229,6 @@ def _make_datapipe( ref_key_fn=getitem("id"), buffer_size=INFINITE_BUFFER_SIZE, ) - dp = IterKeyZipper( anns_dp, images_dp, @@ -224,18 +236,24 @@ def _make_datapipe( ref_key_fn=path_accessor("name"), buffer_size=INFINITE_BUFFER_SIZE, ) + return Mapper(dp, self._prepare_sample) + + def __len__(self) -> int: + return { + ("train", "2017"): defaultdict(lambda: 118_287, instances=117_266), + ("train", "2014"): defaultdict(lambda: 82_783, instances=82_081), + ("val", "2017"): defaultdict(lambda: 5_000, instances=4_952), + ("val", "2014"): defaultdict(lambda: 40_504, instances=40_137), + }[(self._split, self._year)][ + self._annotations # type: ignore[index] + ] - return Mapper(dp, functools.partial(self._prepare_sample, annotations=config.annotations)) - - def _generate_categories(self, root: pathlib.Path) -> Tuple[Tuple[str, str]]: - config = self.default_config - resources = self.resources(config) + def _generate_categories(self) -> Tuple[Tuple[str, str]]: + self._annotations = "instances" + resources = self._resources() - dp = resources[1].load(root) - dp = Filter( - dp, - functools.partial(self._filter_meta_files, split=config.split, year=config.year, annotations="instances"), - ) + dp = resources[1].load(self._root) + dp = Filter(dp, self._filter_meta_files) dp = JsonParser(dp) _, meta = next(iter(dp)) diff --git a/torchvision/prototype/datasets/_builtin/country211.py b/torchvision/prototype/datasets/_builtin/country211.py index 0b4dc306734..012ecae19e2 100644 --- a/torchvision/prototype/datasets/_builtin/country211.py +++ b/torchvision/prototype/datasets/_builtin/country211.py @@ -1,21 +1,47 @@ import pathlib -from typing import Any, Dict, List, Tuple +from typing import Any, Dict, List, Tuple, Union from torchdata.datapipes.iter import IterDataPipe, Mapper, Filter -from torchvision.prototype.datasets.utils import Dataset, DatasetConfig, DatasetInfo, HttpResource, OnlineResource -from torchvision.prototype.datasets.utils._internal import path_comparator, hint_sharding, hint_shuffling +from torchvision.prototype.datasets.utils import Dataset, HttpResource, OnlineResource +from torchvision.prototype.datasets.utils._internal import ( + path_comparator, + hint_sharding, + hint_shuffling, + read_categories_file, +) from torchvision.prototype.features import EncodedImage, Label +from .._api import register_dataset, register_info +NAME = "country211" + + +@register_info(NAME) +def _info() -> Dict[str, Any]: + return dict(categories=read_categories_file(NAME)) + + +@register_dataset(NAME) class Country211(Dataset): - def _make_info(self) -> DatasetInfo: - return DatasetInfo( - "country211", - homepage="https://github.com/openai/CLIP/blob/main/data/country211.md", - valid_options=dict(split=("train", "val", "test")), - ) + """ + - **homepage**: https://github.com/openai/CLIP/blob/main/data/country211.md + """ - def resources(self, config: DatasetConfig) -> List[OnlineResource]: + def __init__( + self, + root: Union[str, pathlib.Path], + *, + split: str = "train", + skip_integrity_check: bool = False, + ) -> None: + self._split = self._verify_str_arg(split, "split", ("train", "val", "test")) + self._split_folder_name = "valid" if split == "val" else split + + self._categories = _info()["categories"] + + 
super().__init__(root, skip_integrity_check=skip_integrity_check) + + def _resources(self) -> List[OnlineResource]: return [ HttpResource( "https://openaipublic.azureedge.net/clip/data/country211.tgz", @@ -23,17 +49,11 @@ def resources(self, config: DatasetConfig) -> List[OnlineResource]: ) ] - _SPLIT_NAME_MAPPER = { - "train": "train", - "val": "valid", - "test": "test", - } - def _prepare_sample(self, data: Tuple[str, Any]) -> Dict[str, Any]: path, buffer = data category = pathlib.Path(path).parent.name return dict( - label=Label.from_category(category, categories=self.categories), + label=Label.from_category(category, categories=self._categories), path=path, image=EncodedImage.from_file(buffer), ) @@ -41,16 +61,21 @@ def _prepare_sample(self, data: Tuple[str, Any]) -> Dict[str, Any]: def _filter_split(self, data: Tuple[str, Any], *, split: str) -> bool: return pathlib.Path(data[0]).parent.parent.name == split - def _make_datapipe( - self, resource_dps: List[IterDataPipe], *, config: DatasetConfig - ) -> IterDataPipe[Dict[str, Any]]: + def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]: dp = resource_dps[0] - dp = Filter(dp, path_comparator("parent.parent.name", self._SPLIT_NAME_MAPPER[config.split])) + dp = Filter(dp, path_comparator("parent.parent.name", self._split_folder_name)) dp = hint_shuffling(dp) dp = hint_sharding(dp) return Mapper(dp, self._prepare_sample) - def _generate_categories(self, root: pathlib.Path) -> List[str]: - resources = self.resources(self.default_config) - dp = resources[0].load(root) + def __len__(self) -> int: + return { + "train": 31_650, + "val": 10_550, + "test": 21_100, + }[self._split] + + def _generate_categories(self) -> List[str]: + resources = self._resources() + dp = resources[0].load(self._root) return sorted({pathlib.Path(path).parent.name for path, _ in dp}) diff --git a/torchvision/prototype/datasets/_builtin/cub200.py b/torchvision/prototype/datasets/_builtin/cub200.py index 1b90b476aa7..1e4db7cef73 100644 --- a/torchvision/prototype/datasets/_builtin/cub200.py +++ b/torchvision/prototype/datasets/_builtin/cub200.py @@ -1,7 +1,7 @@ import csv import functools import pathlib -from typing import Any, Dict, List, Optional, Tuple, BinaryIO, Callable +from typing import Any, Dict, List, Optional, Tuple, BinaryIO, Callable, Union from torchdata.datapipes.iter import ( IterDataPipe, @@ -15,8 +15,6 @@ ) from torchvision.prototype.datasets.utils import ( Dataset, - DatasetConfig, - DatasetInfo, HttpResource, OnlineResource, ) @@ -27,27 +25,52 @@ hint_shuffling, getitem, path_comparator, + read_categories_file, path_accessor, ) from torchvision.prototype.features import Label, BoundingBox, _Feature, EncodedImage +from .._api import register_dataset, register_info + csv.register_dialect("cub200", delimiter=" ") +NAME = "cub200" + + +@register_info(NAME) +def _info() -> Dict[str, Any]: + return dict(categories=read_categories_file(NAME)) + + +@register_dataset(NAME) class CUB200(Dataset): - def _make_info(self) -> DatasetInfo: - return DatasetInfo( - "cub200", - homepage="http://www.vision.caltech.edu/visipedia/CUB-200-2011.html", - dependencies=("scipy",), - valid_options=dict( - split=("train", "test"), - year=("2011", "2010"), - ), + """ + - **homepage**: http://www.vision.caltech.edu/visipedia/CUB-200.html + """ + + def __init__( + self, + root: Union[str, pathlib.Path], + *, + split: str = "train", + year: str = "2011", + skip_integrity_check: bool = False, + ) -> None: + self._split = self._verify_str_arg(split, 
"split", ("train", "test")) + self._year = self._verify_str_arg(year, "year", ("2010", "2011")) + + self._categories = _info()["categories"] + + super().__init__( + root, + # TODO: this will only be available after https://github.com/pytorch/vision/pull/5473 + # dependencies=("scipy",), + skip_integrity_check=skip_integrity_check, ) - def resources(self, config: DatasetConfig) -> List[OnlineResource]: - if config.year == "2011": + def _resources(self) -> List[OnlineResource]: + if self._year == "2011": archive = HttpResource( "http://www.vision.caltech.edu/visipedia-data/CUB-200-2011/CUB_200_2011.tgz", sha256="0c685df5597a8b24909f6a7c9db6d11e008733779a671760afef78feb49bf081", @@ -59,7 +82,7 @@ def resources(self, config: DatasetConfig) -> List[OnlineResource]: preprocess="decompress", ) return [archive, segmentations] - else: # config.year == "2010" + else: # self._year == "2010" split = HttpResource( "http://www.vision.caltech.edu/visipedia-data/CUB-200/lists.tgz", sha256="aeacbd5e3539ae84ea726e8a266a9a119c18f055cd80f3836d5eb4500b005428", @@ -90,12 +113,12 @@ def _2011_classify_archive(self, data: Tuple[str, Any]) -> Optional[int]: else: return None - def _2011_filter_split(self, row: List[str], *, split: str) -> bool: + def _2011_filter_split(self, row: List[str]) -> bool: _, split_id = row return { "0": "test", "1": "train", - }[split_id] == split + }[split_id] == self._split def _2011_segmentation_key(self, data: Tuple[str, Any]) -> str: path = pathlib.Path(data[0]) @@ -149,17 +172,12 @@ def _prepare_sample( return dict( prepare_ann_fn(anns_data, image.image_size), image=image, - label=Label(int(pathlib.Path(path).parent.name.rsplit(".", 1)[0]), categories=self.categories), + label=Label(int(pathlib.Path(path).parent.name.rsplit(".", 1)[0]), categories=self._categories), ) - def _make_datapipe( - self, - resource_dps: List[IterDataPipe], - *, - config: DatasetConfig, - ) -> IterDataPipe[Dict[str, Any]]: + def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]: prepare_ann_fn: Callable - if config.year == "2011": + if self._year == "2011": archive_dp, segmentations_dp = resource_dps images_dp, split_dp, image_files_dp, bounding_boxes_dp = Demultiplexer( archive_dp, 4, self._2011_classify_archive, drop_none=True, buffer_size=INFINITE_BUFFER_SIZE @@ -171,7 +189,7 @@ def _make_datapipe( ) split_dp = CSVParser(split_dp, dialect="cub200") - split_dp = Filter(split_dp, functools.partial(self._2011_filter_split, split=config.split)) + split_dp = Filter(split_dp, self._2011_filter_split) split_dp = Mapper(split_dp, getitem(0)) split_dp = Mapper(split_dp, image_files_map.get) @@ -188,10 +206,10 @@ def _make_datapipe( ) prepare_ann_fn = self._2011_prepare_ann - else: # config.year == "2010" + else: # self._year == "2010" split_dp, images_dp, anns_dp = resource_dps - split_dp = Filter(split_dp, path_comparator("name", f"{config.split}.txt")) + split_dp = Filter(split_dp, path_comparator("name", f"{self._split}.txt")) split_dp = LineReader(split_dp, decode=True, return_path=False) split_dp = Mapper(split_dp, self._2010_split_key) @@ -217,11 +235,19 @@ def _make_datapipe( ) return Mapper(dp, functools.partial(self._prepare_sample, prepare_ann_fn=prepare_ann_fn)) - def _generate_categories(self, root: pathlib.Path) -> List[str]: - config = self.info.make_config(year="2011") - resources = self.resources(config) + def __len__(self) -> int: + return { + ("train", "2010"): 3_000, + ("test", "2010"): 3_033, + ("train", "2011"): 5_994, + ("test", "2011"): 5_794, + 
}[(self._split, self._year)] + + def _generate_categories(self) -> List[str]: + self._year = "2011" + resources = self._resources() - dp = resources[0].load(root) + dp = resources[0].load(self._root) dp = Filter(dp, path_comparator("name", "classes.txt")) dp = CSVDictParser(dp, fieldnames=("label", "category"), dialect="cub200") diff --git a/torchvision/prototype/datasets/_builtin/dtd.py b/torchvision/prototype/datasets/_builtin/dtd.py index 682fed2d9c2..b082ada19ce 100644 --- a/torchvision/prototype/datasets/_builtin/dtd.py +++ b/torchvision/prototype/datasets/_builtin/dtd.py @@ -1,12 +1,10 @@ import enum import pathlib -from typing import Any, Dict, List, Optional, Tuple, BinaryIO +from typing import Any, Dict, List, Optional, Tuple, BinaryIO, Union from torchdata.datapipes.iter import IterDataPipe, Mapper, Filter, IterKeyZipper, Demultiplexer, LineReader, CSVParser from torchvision.prototype.datasets.utils import ( Dataset, - DatasetConfig, - DatasetInfo, HttpResource, OnlineResource, ) @@ -15,10 +13,16 @@ hint_sharding, path_comparator, getitem, + read_categories_file, hint_shuffling, ) from torchvision.prototype.features import Label, EncodedImage +from .._api import register_dataset, register_info + + +NAME = "dtd" + class DTDDemux(enum.IntEnum): SPLIT = 0 @@ -26,18 +30,36 @@ class DTDDemux(enum.IntEnum): IMAGES = 2 +@register_info(NAME) +def _info() -> Dict[str, Any]: + return dict(categories=read_categories_file(NAME)) + + +@register_dataset(NAME) class DTD(Dataset): - def _make_info(self) -> DatasetInfo: - return DatasetInfo( - "dtd", - homepage="https://www.robots.ox.ac.uk/~vgg/data/dtd/", - valid_options=dict( - split=("train", "test", "val"), - fold=tuple(str(fold) for fold in range(1, 11)), - ), - ) + """DTD Dataset. + homepage="https://www.robots.ox.ac.uk/~vgg/data/dtd/", + """ - def resources(self, config: DatasetConfig) -> List[OnlineResource]: + def __init__( + self, + root: Union[str, pathlib.Path], + *, + split: str = "train", + fold: int = 1, + skip_validation_check: bool = False, + ) -> None: + self._split = self._verify_str_arg(split, "split", {"train", "val", "test"}) + + if not (1 <= fold <= 10): + raise ValueError(f"The fold parameter should be an integer in [1, 10]. 
Got {fold}") + self._fold = fold + + self._categories = _info()["categories"] + + super().__init__(root, skip_integrity_check=skip_validation_check) + + def _resources(self) -> List[OnlineResource]: archive = HttpResource( "https://www.robots.ox.ac.uk/~vgg/data/dtd/download/dtd-r1.0.1.tar.gz", sha256="e42855a52a4950a3b59612834602aa253914755c95b0cff9ead6d07395f8e205", @@ -71,24 +93,19 @@ def _prepare_sample(self, data: Tuple[Tuple[str, List[str]], Tuple[str, BinaryIO return dict( joint_categories={category for category in joint_categories if category}, - label=Label.from_category(category, categories=self.categories), + label=Label.from_category(category, categories=self._categories), path=path, image=EncodedImage.from_file(buffer), ) - def _make_datapipe( - self, - resource_dps: List[IterDataPipe], - *, - config: DatasetConfig, - ) -> IterDataPipe[Dict[str, Any]]: + def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]: archive_dp = resource_dps[0] splits_dp, joint_categories_dp, images_dp = Demultiplexer( archive_dp, 3, self._classify_archive, drop_none=True, buffer_size=INFINITE_BUFFER_SIZE ) - splits_dp = Filter(splits_dp, path_comparator("name", f"{config.split}{config.fold}.txt")) + splits_dp = Filter(splits_dp, path_comparator("name", f"{self._split}{self._fold}.txt")) splits_dp = LineReader(splits_dp, decode=True, return_path=False) splits_dp = hint_shuffling(splits_dp) splits_dp = hint_sharding(splits_dp) @@ -114,10 +131,13 @@ def _make_datapipe( def _filter_images(self, data: Tuple[str, Any]) -> bool: return self._classify_archive(data) == DTDDemux.IMAGES - def _generate_categories(self, root: pathlib.Path) -> List[str]: - resources = self.resources(self.default_config) + def _generate_categories(self) -> List[str]: + resources = self._resources() - dp = resources[0].load(root) + dp = resources[0].load(self._root) dp = Filter(dp, self._filter_images) return sorted({pathlib.Path(path).parent.name for path, _ in dp}) + + def __len__(self) -> int: + return 1_880 # All splits have the same length diff --git a/torchvision/prototype/datasets/_builtin/eurosat.py b/torchvision/prototype/datasets/_builtin/eurosat.py index 336f35de968..ab31aaf6f42 100644 --- a/torchvision/prototype/datasets/_builtin/eurosat.py +++ b/torchvision/prototype/datasets/_builtin/eurosat.py @@ -1,31 +1,44 @@ import pathlib -from typing import Any, Dict, List, Tuple +from typing import Any, Dict, List, Tuple, Union from torchdata.datapipes.iter import IterDataPipe, Mapper -from torchvision.prototype.datasets.utils import Dataset, DatasetConfig, DatasetInfo, HttpResource, OnlineResource +from torchvision.prototype.datasets.utils import Dataset, HttpResource, OnlineResource from torchvision.prototype.datasets.utils._internal import hint_sharding, hint_shuffling from torchvision.prototype.features import EncodedImage, Label +from .._api import register_dataset, register_info -class EuroSAT(Dataset): - def _make_info(self) -> DatasetInfo: - return DatasetInfo( - "eurosat", - homepage="https://github.com/phelber/eurosat", - categories=( - "AnnualCrop", - "Forest", - "HerbaceousVegetation", - "Highway", - "Industrial," "Pasture", - "PermanentCrop", - "Residential", - "River", - "SeaLake", - ), +NAME = "eurosat" + + +@register_info(NAME) +def _info() -> Dict[str, Any]: + return dict( + categories=( + "AnnualCrop", + "Forest", + "HerbaceousVegetation", + "Highway", + "Industrial," "Pasture", + "PermanentCrop", + "Residential", + "River", + "SeaLake", ) + ) + - def resources(self, config: 
DatasetConfig) -> List[OnlineResource]: +@register_dataset(NAME) +class EuroSAT(Dataset): + """EuroSAT Dataset. + homepage="https://github.com/phelber/eurosat", + """ + + def __init__(self, root: Union[str, pathlib.Path], *, skip_integrity_check: bool = False) -> None: + self._categories = _info()["categories"] + super().__init__(root, skip_integrity_check=skip_integrity_check) + + def _resources(self) -> List[OnlineResource]: return [ HttpResource( "https://madm.dfki.de/files/sentinel/EuroSAT.zip", @@ -37,15 +50,16 @@ def _prepare_sample(self, data: Tuple[str, Any]) -> Dict[str, Any]: path, buffer = data category = pathlib.Path(path).parent.name return dict( - label=Label.from_category(category, categories=self.categories), + label=Label.from_category(category, categories=self._categories), path=path, image=EncodedImage.from_file(buffer), ) - def _make_datapipe( - self, resource_dps: List[IterDataPipe], *, config: DatasetConfig - ) -> IterDataPipe[Dict[str, Any]]: + def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]: dp = resource_dps[0] dp = hint_shuffling(dp) dp = hint_sharding(dp) return Mapper(dp, self._prepare_sample) + + def __len__(self) -> int: + return 27_000 diff --git a/torchvision/prototype/datasets/_builtin/fer2013.py b/torchvision/prototype/datasets/_builtin/fer2013.py index a5bfa681d02..c1a914c6f63 100644 --- a/torchvision/prototype/datasets/_builtin/fer2013.py +++ b/torchvision/prototype/datasets/_builtin/fer2013.py @@ -1,11 +1,10 @@ -from typing import Any, Dict, List, cast +import pathlib +from typing import Any, Dict, List, Union import torch from torchdata.datapipes.iter import IterDataPipe, Mapper, CSVDictParser from torchvision.prototype.datasets.utils import ( Dataset, - DatasetConfig, - DatasetInfo, OnlineResource, KaggleDownloadResource, ) @@ -15,26 +14,40 @@ ) from torchvision.prototype.features import Label, Image +from .._api import register_dataset, register_info +NAME = "fer2013" + + +@register_info(NAME) +def _info() -> Dict[str, Any]: + return dict(categories=("angry", "disgust", "fear", "happy", "sad", "surprise", "neutral")) + + +@register_dataset(NAME) class FER2013(Dataset): - def _make_info(self) -> DatasetInfo: - return DatasetInfo( - "fer2013", - homepage="https://www.kaggle.com/c/challenges-in-representation-learning-facial-expression-recognition-challenge", - categories=("angry", "disgust", "fear", "happy", "sad", "surprise", "neutral"), - valid_options=dict(split=("train", "test")), - ) + """FER 2013 Dataset + homepage="https://www.kaggle.com/c/challenges-in-representation-learning-facial-expression-recognition-challenge" + """ + + def __init__( + self, root: Union[str, pathlib.Path], *, split: str = "train", skip_integrity_check: bool = False + ) -> None: + self._split = self._verify_str_arg(split, "split", {"train", "test"}) + self._categories = _info()["categories"] + + super().__init__(root, skip_integrity_check=skip_integrity_check) _CHECKSUMS = { "train": "a2b7c9360cc0b38d21187e5eece01c2799fce5426cdeecf746889cc96cda2d10", "test": "dec8dfe8021e30cd6704b85ec813042b4a5d99d81cb55e023291a94104f575c3", } - def resources(self, config: DatasetConfig) -> List[OnlineResource]: + def _resources(self) -> List[OnlineResource]: archive = KaggleDownloadResource( - cast(str, self.info.homepage), - file_name=f"{config.split}.csv.zip", - sha256=self._CHECKSUMS[config.split], + "https://www.kaggle.com/c/challenges-in-representation-learning-facial-expression-recognition-challenge", + file_name=f"{self._split}.csv.zip", + 
sha256=self._CHECKSUMS[self._split], ) return [archive] @@ -43,17 +56,15 @@ def _prepare_sample(self, data: Dict[str, Any]) -> Dict[str, Any]: return dict( image=Image(torch.tensor([int(idx) for idx in data["pixels"].split()], dtype=torch.uint8).reshape(48, 48)), - label=Label(int(label_id), categories=self.categories) if label_id is not None else None, + label=Label(int(label_id), categories=self._categories) if label_id is not None else None, ) - def _make_datapipe( - self, - resource_dps: List[IterDataPipe], - *, - config: DatasetConfig, - ) -> IterDataPipe[Dict[str, Any]]: + def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]: dp = resource_dps[0] dp = CSVDictParser(dp) dp = hint_shuffling(dp) dp = hint_sharding(dp) return Mapper(dp, self._prepare_sample) + + def __len__(self) -> int: + return 28_709 if self._split == "train" else 3_589 diff --git a/torchvision/prototype/datasets/_builtin/food101.py b/torchvision/prototype/datasets/_builtin/food101.py index cb720f137d9..5100e5d8c74 100644 --- a/torchvision/prototype/datasets/_builtin/food101.py +++ b/torchvision/prototype/datasets/_builtin/food101.py @@ -1,5 +1,5 @@ from pathlib import Path -from typing import Any, Tuple, List, Dict, Optional, BinaryIO +from typing import Any, Tuple, List, Dict, Optional, BinaryIO, Union from torchdata.datapipes.iter import ( IterDataPipe, @@ -9,26 +9,41 @@ Demultiplexer, IterKeyZipper, ) -from torchvision.prototype.datasets.utils import Dataset, DatasetInfo, DatasetConfig, HttpResource, OnlineResource +from torchvision.prototype.datasets.utils import Dataset, HttpResource, OnlineResource from torchvision.prototype.datasets.utils._internal import ( hint_shuffling, hint_sharding, path_comparator, getitem, INFINITE_BUFFER_SIZE, + read_categories_file, ) from torchvision.prototype.features import Label, EncodedImage +from .._api import register_dataset, register_info + +NAME = "food101" + + +@register_info(NAME) +def _info() -> Dict[str, Any]: + return dict(categories=read_categories_file(NAME)) + + +@register_dataset(NAME) class Food101(Dataset): - def _make_info(self) -> DatasetInfo: - return DatasetInfo( - "food101", - homepage="https://data.vision.ee.ethz.ch/cvl/datasets_extra/food-101", - valid_options=dict(split=("train", "test")), - ) + """Food 101 dataset + homepage="https://data.vision.ee.ethz.ch/cvl/datasets_extra/food-101", + """ + + def __init__(self, root: Union[str, Path], *, split: str = "train", skip_integrity_check: bool = False) -> None: + self._split = self._verify_str_arg(split, "split", {"train", "test"}) + self._categories = _info()["categories"] - def resources(self, config: DatasetConfig) -> List[OnlineResource]: + super().__init__(root, skip_integrity_check=skip_integrity_check) + + def _resources(self) -> List[OnlineResource]: return [ HttpResource( url="http://data.vision.ee.ethz.ch/cvl/food-101.tar.gz", @@ -49,7 +64,7 @@ def _classify_archive(self, data: Tuple[str, Any]) -> Optional[int]: def _prepare_sample(self, data: Tuple[str, Tuple[str, BinaryIO]]) -> Dict[str, Any]: id, (path, buffer) = data return dict( - label=Label.from_category(id.split("/", 1)[0], categories=self.categories), + label=Label.from_category(id.split("/", 1)[0], categories=self._categories), path=path, image=EncodedImage.from_file(buffer), ) @@ -58,17 +73,12 @@ def _image_key(self, data: Tuple[str, Any]) -> str: path = Path(data[0]) return path.relative_to(path.parents[1]).with_suffix("").as_posix() - def _make_datapipe( - self, - resource_dps: List[IterDataPipe], - *, - 
config: DatasetConfig, - ) -> IterDataPipe[Dict[str, Any]]: + def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]: archive_dp = resource_dps[0] images_dp, split_dp = Demultiplexer( archive_dp, 2, self._classify_archive, drop_none=True, buffer_size=INFINITE_BUFFER_SIZE ) - split_dp = Filter(split_dp, path_comparator("name", f"{config.split}.txt")) + split_dp = Filter(split_dp, path_comparator("name", f"{self._split}.txt")) split_dp = LineReader(split_dp, decode=True, return_path=False) split_dp = hint_sharding(split_dp) split_dp = hint_shuffling(split_dp) @@ -83,9 +93,12 @@ def _make_datapipe( return Mapper(dp, self._prepare_sample) - def _generate_categories(self, root: Path) -> List[str]: - resources = self.resources(self.default_config) - dp = resources[0].load(root) + def _generate_categories(self) -> List[str]: + resources = self._resources() + dp = resources[0].load(self._root) dp = Filter(dp, path_comparator("name", "classes.txt")) dp = LineReader(dp, decode=True, return_path=False) return list(dp) + + def __len__(self) -> int: + return 75_750 if self._split == "train" else 25_250 diff --git a/torchvision/prototype/datasets/_builtin/gtsrb.py b/torchvision/prototype/datasets/_builtin/gtsrb.py index c08d8947292..01f754208e2 100644 --- a/torchvision/prototype/datasets/_builtin/gtsrb.py +++ b/torchvision/prototype/datasets/_builtin/gtsrb.py @@ -1,11 +1,9 @@ import pathlib -from typing import Any, Dict, List, Optional, Tuple +from typing import Any, Dict, List, Optional, Tuple, Union from torchdata.datapipes.iter import IterDataPipe, Mapper, Filter, CSVDictParser, Zipper, Demultiplexer from torchvision.prototype.datasets.utils import ( Dataset, - DatasetConfig, - DatasetInfo, OnlineResource, HttpResource, ) @@ -17,15 +15,31 @@ ) from torchvision.prototype.features import Label, BoundingBox, EncodedImage +from .._api import register_dataset, register_info +NAME = "gtsrb" + + +@register_info(NAME) +def _info() -> Dict[str, Any]: + return dict( + categories=[f"{label:05d}" for label in range(43)], + ) + + +@register_dataset(NAME) class GTSRB(Dataset): - def _make_info(self) -> DatasetInfo: - return DatasetInfo( - "gtsrb", - homepage="https://benchmark.ini.rub.de", - categories=[f"{label:05d}" for label in range(43)], - valid_options=dict(split=("train", "test")), - ) + """GTSRB Dataset + + homepage="https://benchmark.ini.rub.de" + """ + + def __init__( + self, root: Union[str, pathlib.Path], *, split: str = "train", skip_integrity_check: bool = False + ) -> None: + self._split = self._verify_str_arg(split, "split", {"train", "test"}) + self._categories = _info()["categories"] + super().__init__(root, skip_integrity_check=skip_integrity_check) _URL_ROOT = "https://sid.erda.dk/public/archives/daaeac0d7ce1152aea9b61d9f1e19370/" _URLS = { @@ -39,10 +53,10 @@ def _make_info(self) -> DatasetInfo: "test_ground_truth": "f94e5a7614d75845c74c04ddb26b8796b9e483f43541dd95dd5b726504e16d6d", } - def resources(self, config: DatasetConfig) -> List[OnlineResource]: - rsrcs: List[OnlineResource] = [HttpResource(self._URLS[config.split], sha256=self._CHECKSUMS[config.split])] + def _resources(self) -> List[OnlineResource]: + rsrcs: List[OnlineResource] = [HttpResource(self._URLS[self._split], sha256=self._CHECKSUMS[self._split])] - if config.split == "test": + if self._split == "test": rsrcs.append( HttpResource( self._URLS["test_ground_truth"], @@ -74,14 +88,12 @@ def _prepare_sample(self, data: Tuple[Tuple[str, Any], Dict[str, Any]]) -> Dict[ return { "path": path, "image": 
EncodedImage.from_file(buffer), - "label": Label(label, categories=self.categories), + "label": Label(label, categories=self._categories), "bounding_box": bounding_box, } - def _make_datapipe( - self, resource_dps: List[IterDataPipe], *, config: DatasetConfig - ) -> IterDataPipe[Dict[str, Any]]: - if config.split == "train": + def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]: + if self._split == "train": images_dp, ann_dp = Demultiplexer( resource_dps[0], 2, self._classify_train_archive, drop_none=True, buffer_size=INFINITE_BUFFER_SIZE ) @@ -98,3 +110,6 @@ def _make_datapipe( dp = hint_sharding(dp) return Mapper(dp, self._prepare_sample) + + def __len__(self) -> int: + return 26_640 if self._split == "train" else 12_630 diff --git a/torchvision/prototype/datasets/_builtin/imagenet.py b/torchvision/prototype/datasets/_builtin/imagenet.py index 86bab4515e1..1307757cef6 100644 --- a/torchvision/prototype/datasets/_builtin/imagenet.py +++ b/torchvision/prototype/datasets/_builtin/imagenet.py @@ -1,7 +1,8 @@ +import enum import functools import pathlib import re -from typing import Any, Dict, List, Optional, Tuple, BinaryIO, Match, cast +from typing import Any, Dict, List, Optional, Tuple, BinaryIO, Match, cast, Union from torchdata.datapipes.iter import ( IterDataPipe, @@ -14,23 +15,30 @@ Enumerator, ) from torchvision.prototype.datasets.utils import ( - Dataset, - DatasetConfig, - DatasetInfo, OnlineResource, ManualDownloadResource, + Dataset, ) from torchvision.prototype.datasets.utils._internal import ( INFINITE_BUFFER_SIZE, - BUILTIN_DIR, - path_comparator, getitem, read_mat, hint_sharding, hint_shuffling, + read_categories_file, + path_accessor, ) from torchvision.prototype.features import Label, EncodedImage -from torchvision.prototype.utils._internal import FrozenMapping + +from .._api import register_dataset, register_info + +NAME = "imagenet" + + +@register_info(NAME) +def _info() -> Dict[str, Any]: + categories, wnids = zip(*read_categories_file(NAME)) + return dict(categories=categories, wnids=wnids) class ImageNetResource(ManualDownloadResource): @@ -38,32 +46,33 @@ def __init__(self, **kwargs: Any) -> None: super().__init__("Register on https://image-net.org/ and follow the instructions there.", **kwargs) +class ImageNetDemux(enum.IntEnum): + META = 0 + LABEL = 1 + + +@register_dataset(NAME) class ImageNet(Dataset): - def _make_info(self) -> DatasetInfo: - name = "imagenet" - categories, wnids = zip(*DatasetInfo.read_categories_file(BUILTIN_DIR / f"{name}.categories")) - - return DatasetInfo( - name, - dependencies=("scipy",), - categories=categories, - homepage="https://www.image-net.org/", - valid_options=dict(split=("train", "val", "test")), - extra=dict( - wnid_to_category=FrozenMapping(zip(wnids, categories)), - category_to_wnid=FrozenMapping(zip(categories, wnids)), - sizes=FrozenMapping( - [ - (DatasetConfig(split="train"), 1_281_167), - (DatasetConfig(split="val"), 50_000), - (DatasetConfig(split="test"), 100_000), - ] - ), - ), - ) + """ + - **homepage**: https://www.image-net.org/ + """ - def supports_sharded(self) -> bool: - return True + def __init__( + self, + root: Union[str, pathlib.Path], + *, + split: str = "train", + skip_integrity_check: bool = False, + ) -> None: + self._split = self._verify_str_arg(split, "split", {"train", "val", "test"}) + + info = _info() + categories, wnids = info["categories"], info["wnids"] + self._categories = categories + self._wnids = wnids + self._wnid_to_category = dict(zip(wnids, categories)) + 
+ super().__init__(root, skip_integrity_check=skip_integrity_check) _IMAGES_CHECKSUMS = { "train": "b08200a27a8e34218a0e58fde36b0fe8f73bc377f4acea2d91602057c3ca45bb", @@ -71,15 +80,15 @@ def supports_sharded(self) -> bool: "test_v10102019": "9cf7f8249639510f17d3d8a0deb47cd22a435886ba8e29e2b3223e65a4079eb4", } - def resources(self, config: DatasetConfig) -> List[OnlineResource]: - name = "test_v10102019" if config.split == "test" else config.split + def _resources(self) -> List[OnlineResource]: + name = "test_v10102019" if self._split == "test" else self._split images = ImageNetResource( file_name=f"ILSVRC2012_img_{name}.tar", sha256=self._IMAGES_CHECKSUMS[name], ) resources: List[OnlineResource] = [images] - if config.split == "val": + if self._split == "val": devkit = ImageNetResource( file_name="ILSVRC2012_devkit_t12.tar.gz", sha256="b59243268c0d266621fd587d2018f69e906fb22875aca0e295b48cafaa927953", @@ -88,19 +97,12 @@ def resources(self, config: DatasetConfig) -> List[OnlineResource]: return resources - def num_samples(self, config: DatasetConfig) -> int: - return { - "train": 1_281_167, - "val": 50_000, - "test": 100_000, - }[config.split] - _TRAIN_IMAGE_NAME_PATTERN = re.compile(r"(?Pn\d{8})_\d+[.]JPEG") def _prepare_train_data(self, data: Tuple[str, BinaryIO]) -> Tuple[Tuple[Label, str], Tuple[str, BinaryIO]]: path = pathlib.Path(data[0]) wnid = cast(Match[str], self._TRAIN_IMAGE_NAME_PATTERN.match(path.name))["wnid"] - label = Label.from_category(self.info.extra.wnid_to_category[wnid], categories=self.categories) + label = Label.from_category(self._wnid_to_category[wnid], categories=self._categories) return (label, wnid), data def _prepare_test_data(self, data: Tuple[str, BinaryIO]) -> Tuple[None, Tuple[str, BinaryIO]]: @@ -108,10 +110,17 @@ def _prepare_test_data(self, data: Tuple[str, BinaryIO]) -> Tuple[None, Tuple[st def _classifiy_devkit(self, data: Tuple[str, BinaryIO]) -> Optional[int]: return { - "meta.mat": 0, - "ILSVRC2012_validation_ground_truth.txt": 1, + "meta.mat": ImageNetDemux.META, + "ILSVRC2012_validation_ground_truth.txt": ImageNetDemux.LABEL, }.get(pathlib.Path(data[0]).name) + # Although the WordNet IDs (wnids) are unique, the corresponding categories are not. 
For example, both n02012849 + # and n03126707 are labeled 'crane' while the first means the bird and the latter means the construction equipment + _WNID_MAP = { + "n03126707": "construction crane", + "n03710721": "tank suit", + } + def _extract_categories_and_wnids(self, data: Tuple[str, BinaryIO]) -> List[Tuple[str, str]]: synsets = read_mat(data[1], squeeze_me=True)["synsets"] return [ @@ -121,21 +130,20 @@ def _extract_categories_and_wnids(self, data: Tuple[str, BinaryIO]) -> List[Tupl if num_children == 0 ] - def _imagenet_label_to_wnid(self, imagenet_label: str, *, wnids: List[str]) -> str: + def _imagenet_label_to_wnid(self, imagenet_label: str, *, wnids: Tuple[str, ...]) -> str: return wnids[int(imagenet_label) - 1] _VAL_TEST_IMAGE_NAME_PATTERN = re.compile(r"ILSVRC2012_(val|test)_(?P\d{8})[.]JPEG") - def _val_test_image_key(self, data: Tuple[str, Any]) -> int: - path = pathlib.Path(data[0]) - return int(self._VAL_TEST_IMAGE_NAME_PATTERN.match(path.name).group("id")) # type: ignore[union-attr] + def _val_test_image_key(self, path: pathlib.Path) -> int: + return int(self._VAL_TEST_IMAGE_NAME_PATTERN.match(path.name)["id"]) # type: ignore[index] def _prepare_val_data( self, data: Tuple[Tuple[int, str], Tuple[str, BinaryIO]] ) -> Tuple[Tuple[Label, str], Tuple[str, BinaryIO]]: label_data, image_data = data _, wnid = label_data - label = Label.from_category(self.info.extra.wnid_to_category[wnid], categories=self.categories) + label = Label.from_category(self._wnid_to_category[wnid], categories=self._categories) return (label, wnid), image_data def _prepare_sample( @@ -150,19 +158,17 @@ def _prepare_sample( image=EncodedImage.from_file(buffer), ) - def _make_datapipe( - self, resource_dps: List[IterDataPipe], *, config: DatasetConfig - ) -> IterDataPipe[Dict[str, Any]]: - if config.split in {"train", "test"}: + def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]: + if self._split in {"train", "test"}: dp = resource_dps[0] # the train archive is a tar of tars - if config.split == "train": + if self._split == "train": dp = TarArchiveLoader(dp) dp = hint_shuffling(dp) dp = hint_sharding(dp) - dp = Mapper(dp, self._prepare_train_data if config.split == "train" else self._prepare_test_data) + dp = Mapper(dp, self._prepare_train_data if self._split == "train" else self._prepare_test_data) else: # config.split == "val": images_dp, devkit_dp = resource_dps @@ -174,6 +180,7 @@ def _make_datapipe( _, wnids = zip(*next(iter(meta_dp))) label_dp = LineReader(label_dp, decode=True, return_path=False) + # We cannot use self._wnids here, since we use a different order than the dataset label_dp = Mapper(label_dp, functools.partial(self._imagenet_label_to_wnid, wnids=wnids)) label_dp: IterDataPipe[Tuple[int, str]] = Enumerator(label_dp, 1) label_dp = hint_shuffling(label_dp) @@ -183,26 +190,29 @@ def _make_datapipe( label_dp, images_dp, key_fn=getitem(0), - ref_key_fn=self._val_test_image_key, + ref_key_fn=path_accessor(self._val_test_image_key), buffer_size=INFINITE_BUFFER_SIZE, ) dp = Mapper(dp, self._prepare_val_data) return Mapper(dp, self._prepare_sample) - # Although the WordNet IDs (wnids) are unique, the corresponding categories are not. 
For example, both n02012849 - # and n03126707 are labeled 'crane' while the first means the bird and the latter means the construction equipment - _WNID_MAP = { - "n03126707": "construction crane", - "n03710721": "tank suit", - } + def __len__(self) -> int: + return { + "train": 1_281_167, + "val": 50_000, + "test": 100_000, + }[self._split] + + def _filter_meta(self, data: Tuple[str, Any]) -> bool: + return self._classifiy_devkit(data) == ImageNetDemux.META - def _generate_categories(self, root: pathlib.Path) -> List[Tuple[str, ...]]: - config = self.info.make_config(split="val") - resources = self.resources(config) + def _generate_categories(self) -> List[Tuple[str, ...]]: + self._split = "val" + resources = self._resources() - devkit_dp = resources[1].load(root) - meta_dp = Filter(devkit_dp, path_comparator("name", "meta.mat")) + devkit_dp = resources[1].load(self._root) + meta_dp = Filter(devkit_dp, self._filter_meta) meta_dp = Mapper(meta_dp, self._extract_categories_and_wnids) categories_and_wnids = cast(List[Tuple[str, ...]], next(iter(meta_dp))) diff --git a/torchvision/prototype/datasets/_builtin/mnist.py b/torchvision/prototype/datasets/_builtin/mnist.py index 1e14e6dfc58..e5537a1ef66 100644 --- a/torchvision/prototype/datasets/_builtin/mnist.py +++ b/torchvision/prototype/datasets/_builtin/mnist.py @@ -7,12 +7,13 @@ import torch from torchdata.datapipes.iter import IterDataPipe, Demultiplexer, Mapper, Zipper, Decompressor -from torchvision.prototype.datasets.utils import Dataset, DatasetConfig, DatasetInfo, HttpResource, OnlineResource +from torchvision.prototype.datasets.utils import Dataset, HttpResource, OnlineResource from torchvision.prototype.datasets.utils._internal import INFINITE_BUFFER_SIZE, hint_sharding, hint_shuffling from torchvision.prototype.features import Image, Label from torchvision.prototype.utils._internal import fromfile -__all__ = ["MNIST", "FashionMNIST", "KMNIST", "EMNIST", "QMNIST"] +from .._api import register_dataset, register_info + prod = functools.partial(functools.reduce, operator.mul) @@ -61,14 +62,14 @@ class _MNISTBase(Dataset): _URL_BASE: Union[str, Sequence[str]] @abc.abstractmethod - def _files_and_checksums(self, config: DatasetConfig) -> Tuple[Tuple[str, str], Tuple[str, str]]: + def _files_and_checksums(self) -> Tuple[Tuple[str, str], Tuple[str, str]]: pass - def resources(self, config: DatasetConfig) -> List[OnlineResource]: + def _resources(self) -> List[OnlineResource]: (images_file, images_sha256), ( labels_file, labels_sha256, - ) = self._files_and_checksums(config) + ) = self._files_and_checksums() url_bases = self._URL_BASE if isinstance(url_bases, str): @@ -82,21 +83,21 @@ def resources(self, config: DatasetConfig) -> List[OnlineResource]: return [images, labels] - def start_and_stop(self, config: DatasetConfig) -> Tuple[Optional[int], Optional[int]]: + def start_and_stop(self) -> Tuple[Optional[int], Optional[int]]: return None, None - def _prepare_sample(self, data: Tuple[torch.Tensor, torch.Tensor], *, config: DatasetConfig) -> Dict[str, Any]: + _categories: List[str] + + def _prepare_sample(self, data: Tuple[torch.Tensor, torch.Tensor]) -> Dict[str, Any]: image, label = data return dict( image=Image(image), - label=Label(label, dtype=torch.int64, categories=self.categories), + label=Label(label, dtype=torch.int64, categories=self._categories), ) - def _make_datapipe( - self, resource_dps: List[IterDataPipe], *, config: DatasetConfig - ) -> IterDataPipe[Dict[str, Any]]: + def _datapipe(self, resource_dps: List[IterDataPipe]) 
-> IterDataPipe[Dict[str, Any]]: images_dp, labels_dp = resource_dps - start, stop = self.start_and_stop(config) + start, stop = self.start_and_stop() images_dp = Decompressor(images_dp) images_dp = MNISTFileReader(images_dp, start=start, stop=stop) @@ -107,19 +108,31 @@ def _make_datapipe( dp = Zipper(images_dp, labels_dp) dp = hint_shuffling(dp) dp = hint_sharding(dp) - return Mapper(dp, functools.partial(self._prepare_sample, config=config)) + return Mapper(dp, self._prepare_sample) + + +@register_info("mnist") +def _mnist_info() -> Dict[str, Any]: + return dict( + categories=[str(label) for label in range(10)], + ) +@register_dataset("mnist") class MNIST(_MNISTBase): - def _make_info(self) -> DatasetInfo: - return DatasetInfo( - "mnist", - categories=10, - homepage="http://yann.lecun.com/exdb/mnist", - valid_options=dict( - split=("train", "test"), - ), - ) + """ + - **homepage**: http://yann.lecun.com/exdb/mnist + """ + + def __init__( + self, + root: Union[str, pathlib.Path], + *, + split: str = "train", + skip_integrity_check: bool = False, + ) -> None: + self._split = self._verify_str_arg(split, "split", ("train", "test")) + super().__init__(root, skip_integrity_check=skip_integrity_check) _URL_BASE: Union[str, Sequence[str]] = ( "http://yann.lecun.com/exdb/mnist", @@ -132,8 +145,8 @@ def _make_info(self) -> DatasetInfo: "t10k-labels-idx1-ubyte.gz": "f7ae60f92e00ec6debd23a6088c31dbd2371eca3ffa0defaefb259924204aec6", } - def _files_and_checksums(self, config: DatasetConfig) -> Tuple[Tuple[str, str], Tuple[str, str]]: - prefix = "train" if config.split == "train" else "t10k" + def _files_and_checksums(self) -> Tuple[Tuple[str, str], Tuple[str, str]]: + prefix = "train" if self._split == "train" else "t10k" images_file = f"{prefix}-images-idx3-ubyte.gz" labels_file = f"{prefix}-labels-idx1-ubyte.gz" return (images_file, self._CHECKSUMS[images_file]), ( @@ -141,28 +154,35 @@ def _files_and_checksums(self, config: DatasetConfig) -> Tuple[Tuple[str, str], self._CHECKSUMS[labels_file], ) + _categories = _mnist_info()["categories"] + + def __len__(self) -> int: + return 60_000 if self._split == "train" else 10_000 + + +@register_info("fashionmnist") +def _fashionmnist_info() -> Dict[str, Any]: + return dict( + categories=[ + "T-shirt/top", + "Trouser", + "Pullover", + "Dress", + "Coat", + "Sandal", + "Shirt", + "Sneaker", + "Bag", + "Ankle boot", + ], + ) + +@register_dataset("fashionmnist") class FashionMNIST(MNIST): - def _make_info(self) -> DatasetInfo: - return DatasetInfo( - "fashionmnist", - categories=( - "T-shirt/top", - "Trouser", - "Pullover", - "Dress", - "Coat", - "Sandal", - "Shirt", - "Sneaker", - "Bag", - "Ankle boot", - ), - homepage="https://github.com/zalandoresearch/fashion-mnist", - valid_options=dict( - split=("train", "test"), - ), - ) + """ + - **homepage**: https://github.com/zalandoresearch/fashion-mnist + """ _URL_BASE = "http://fashion-mnist.s3-website.eu-central-1.amazonaws.com" _CHECKSUMS = { @@ -172,17 +192,21 @@ def _make_info(self) -> DatasetInfo: "t10k-labels-idx1-ubyte.gz": "67da17c76eaffca5446c3361aaab5c3cd6d1c2608764d35dfb1850b086bf8dd5", } + _categories = _fashionmnist_info()["categories"] + + +@register_info("kmnist") +def _kmnist_info() -> Dict[str, Any]: + return dict( + categories=["o", "ki", "su", "tsu", "na", "ha", "ma", "ya", "re", "wo"], + ) + +@register_dataset("kmnist") class KMNIST(MNIST): - def _make_info(self) -> DatasetInfo: - return DatasetInfo( - "kmnist", - categories=["o", "ki", "su", "tsu", "na", "ha", "ma", "ya", "re", "wo"], - 
homepage="http://codh.rois.ac.jp/kmnist/index.html.en", - valid_options=dict( - split=("train", "test"), - ), - ) + """ + - **homepage**: http://codh.rois.ac.jp/kmnist/index.html.en + """ _URL_BASE = "http://codh.rois.ac.jp/kmnist/dataset/kmnist" _CHECKSUMS = { @@ -192,36 +216,46 @@ def _make_info(self) -> DatasetInfo: "t10k-labels-idx1-ubyte.gz": "20bb9a0ef54c7db3efc55a92eef5582c109615df22683c380526788f98e42a1c", } + _categories = _kmnist_info()["categories"] + + +@register_info("emnist") +def _emnist_info() -> Dict[str, Any]: + return dict( + categories=list(string.digits + string.ascii_uppercase + string.ascii_lowercase), + ) + +@register_dataset("emnist") class EMNIST(_MNISTBase): - def _make_info(self) -> DatasetInfo: - return DatasetInfo( - "emnist", - categories=list(string.digits + string.ascii_uppercase + string.ascii_lowercase), - homepage="https://www.westernsydney.edu.au/icns/reproducible_research/publication_support_materials/emnist", - valid_options=dict( - split=("train", "test"), - image_set=( - "Balanced", - "By_Merge", - "By_Class", - "Letters", - "Digits", - "MNIST", - ), - ), + """ + - **homepage**: https://www.westernsydney.edu.au/icns/reproducible_research/publication_support_materials/emnist + """ + + def __init__( + self, + root: Union[str, pathlib.Path], + *, + split: str = "train", + image_set: str = "Balanced", + skip_integrity_check: bool = False, + ) -> None: + self._split = self._verify_str_arg(split, "split", ("train", "test")) + self._image_set = self._verify_str_arg( + image_set, "image_set", ("Balanced", "By_Merge", "By_Class", "Letters", "Digits", "MNIST") ) + super().__init__(root, skip_integrity_check=skip_integrity_check) _URL_BASE = "https://rds.westernsydney.edu.au/Institutes/MARCS/BENS/EMNIST" - def _files_and_checksums(self, config: DatasetConfig) -> Tuple[Tuple[str, str], Tuple[str, str]]: - prefix = f"emnist-{config.image_set.replace('_', '').lower()}-{config.split}" + def _files_and_checksums(self) -> Tuple[Tuple[str, str], Tuple[str, str]]: + prefix = f"emnist-{self._image_set.replace('_', '').lower()}-{self._split}" images_file = f"{prefix}-images-idx3-ubyte.gz" labels_file = f"{prefix}-labels-idx1-ubyte.gz" - # Since EMNIST provides the data files inside an archive, we don't need provide checksums for them + # Since EMNIST provides the data files inside an archive, we don't need to provide checksums for them return (images_file, ""), (labels_file, "") - def resources(self, config: Optional[DatasetConfig] = None) -> List[OnlineResource]: + def _resources(self) -> List[OnlineResource]: return [ HttpResource( f"{self._URL_BASE}/emnist-gzip.zip", @@ -229,9 +263,9 @@ def resources(self, config: Optional[DatasetConfig] = None) -> List[OnlineResour ) ] - def _classify_archive(self, data: Tuple[str, Any], *, config: DatasetConfig) -> Optional[int]: + def _classify_archive(self, data: Tuple[str, Any]) -> Optional[int]: path = pathlib.Path(data[0]) - (images_file, _), (labels_file, _) = self._files_and_checksums(config) + (images_file, _), (labels_file, _) = self._files_and_checksums() if path.name == images_file: return 0 elif path.name == labels_file: @@ -239,6 +273,8 @@ def _classify_archive(self, data: Tuple[str, Any], *, config: DatasetConfig) -> else: return None + _categories = _emnist_info()["categories"] + _LABEL_OFFSETS = { 38: 1, 39: 1, @@ -251,45 +287,71 @@ def _classify_archive(self, data: Tuple[str, Any], *, config: DatasetConfig) -> 46: 9, } - def _prepare_sample(self, data: Tuple[torch.Tensor, torch.Tensor], *, config: DatasetConfig) -> 
Dict[str, Any]: + def _prepare_sample(self, data: Tuple[torch.Tensor, torch.Tensor]) -> Dict[str, Any]: # In these two splits, some lowercase letters are merged into their uppercase ones (see Fig 2. in the paper). # That means for example that there is 'D', 'd', and 'C', but not 'c'. Since the labels are nevertheless dense, - # i.e. no gaps between 0 and 46 for 47 total classes, we need to add an offset to create this gaps. For example, - # since there is no 'c', 'd' corresponds to + # i.e. no gaps between 0 and 46 for 47 total classes, we need to add an offset to create these gaps. For + # example, since there is no 'c', 'd' corresponds to # label 38 (10 digits + 26 uppercase letters + 3rd unmerged lower case letter - 1 for zero indexing), # and at the same time corresponds to # index 39 (10 digits + 26 uppercase letters + 4th lower case letter - 1 for zero indexing) - # in self.categories. Thus, we need to add 1 to the label to correct this. - if config.image_set in ("Balanced", "By_Merge"): + # in self._categories. Thus, we need to add 1 to the label to correct this. + if self._image_set in ("Balanced", "By_Merge"): image, label = data label += self._LABEL_OFFSETS.get(int(label), 0) data = (image, label) - return super()._prepare_sample(data, config=config) + return super()._prepare_sample(data) - def _make_datapipe( - self, resource_dps: List[IterDataPipe], *, config: DatasetConfig - ) -> IterDataPipe[Dict[str, Any]]: + def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]: archive_dp = resource_dps[0] images_dp, labels_dp = Demultiplexer( archive_dp, 2, - functools.partial(self._classify_archive, config=config), + self._classify_archive, drop_none=True, buffer_size=INFINITE_BUFFER_SIZE, ) - return super()._make_datapipe([images_dp, labels_dp], config=config) + return super()._datapipe([images_dp, labels_dp]) + + def __len__(self) -> int: + return { + ("train", "Balanced"): 112_800, + ("train", "By_Merge"): 697_932, + ("train", "By_Class"): 697_932, + ("train", "Letters"): 124_800, + ("train", "Digits"): 240_000, + ("train", "MNIST"): 60_000, + ("test", "Balanced"): 18_800, + ("test", "By_Merge"): 116_323, + ("test", "By_Class"): 116_323, + ("test", "Letters"): 20_800, + ("test", "Digits"): 40_000, + ("test", "MNIST"): 10_000, + }[(self._split, self._image_set)] + + +@register_info("qmnist") +def _qmnist_info() -> Dict[str, Any]: + return dict( + categories=[str(label) for label in range(10)], + ) +@register_dataset("qmnist") class QMNIST(_MNISTBase): - def _make_info(self) -> DatasetInfo: - return DatasetInfo( - "qmnist", - categories=10, - homepage="https://github.com/facebookresearch/qmnist", - valid_options=dict( - split=("train", "test", "test10k", "test50k", "nist"), - ), - ) + """ + - **homepage**: https://github.com/facebookresearch/qmnist + """ + + def __init__( + self, + root: Union[str, pathlib.Path], + *, + split: str = "train", + skip_integrity_check: bool = False, + ) -> None: + self._split = self._verify_str_arg(split, "split", ("train", "test", "test10k", "test50k", "nist")) + super().__init__(root, skip_integrity_check=skip_integrity_check) _URL_BASE = "https://raw.githubusercontent.com/facebookresearch/qmnist/master" _CHECKSUMS = { @@ -301,9 +363,9 @@ def _make_info(self) -> DatasetInfo: "xnist-labels-idx2-int.xz": "db042968723ec2b7aed5f1beac25d2b6e983b9286d4f4bf725f1086e5ae55c4f", } - def _files_and_checksums(self, config: DatasetConfig) -> Tuple[Tuple[str, str], Tuple[str, str]]: - prefix = "xnist" if config.split == "nist" else 
f"qmnist-{'train' if config.split== 'train' else 'test'}" - suffix = "xz" if config.split == "nist" else "gz" + def _files_and_checksums(self) -> Tuple[Tuple[str, str], Tuple[str, str]]: + prefix = "xnist" if self._split == "nist" else f"qmnist-{'train' if self._split == 'train' else 'test'}" + suffix = "xz" if self._split == "nist" else "gz" images_file = f"{prefix}-images-idx3-ubyte.{suffix}" labels_file = f"{prefix}-labels-idx2-int.{suffix}" return (images_file, self._CHECKSUMS[images_file]), ( @@ -311,13 +373,13 @@ def _files_and_checksums(self, config: DatasetConfig) -> Tuple[Tuple[str, str], self._CHECKSUMS[labels_file], ) - def start_and_stop(self, config: DatasetConfig) -> Tuple[Optional[int], Optional[int]]: + def start_and_stop(self) -> Tuple[Optional[int], Optional[int]]: start: Optional[int] stop: Optional[int] - if config.split == "test10k": + if self._split == "test10k": start = 0 stop = 10000 - elif config.split == "test50k": + elif self._split == "test50k": start = 10000 stop = None else: @@ -325,10 +387,12 @@ def start_and_stop(self, config: DatasetConfig) -> Tuple[Optional[int], Optional return start, stop - def _prepare_sample(self, data: Tuple[torch.Tensor, torch.Tensor], *, config: DatasetConfig) -> Dict[str, Any]: + _categories = _emnist_info()["categories"] + + def _prepare_sample(self, data: Tuple[torch.Tensor, torch.Tensor]) -> Dict[str, Any]: image, ann = data label, *extra_anns = ann - sample = super()._prepare_sample((image, label), config=config) + sample = super()._prepare_sample((image, label)) sample.update( dict( @@ -340,3 +404,12 @@ def _prepare_sample(self, data: Tuple[torch.Tensor, torch.Tensor], *, config: Da ) sample.update(dict(zip(("duplicate", "unused"), [bool(value) for value in extra_anns[-2:]]))) return sample + + def __len__(self) -> int: + return { + "train": 60_000, + "test": 60_000, + "test10k": 10_000, + "test50k": 50_000, + "nist": 402_953, + }[self._split] diff --git a/torchvision/prototype/datasets/_builtin/oxford_iiit_pet.py b/torchvision/prototype/datasets/_builtin/oxford_iiit_pet.py index 8d4fc00dbb0..f7da02a4765 100644 --- a/torchvision/prototype/datasets/_builtin/oxford_iiit_pet.py +++ b/torchvision/prototype/datasets/_builtin/oxford_iiit_pet.py @@ -1,12 +1,10 @@ import enum import pathlib -from typing import Any, Dict, List, Optional, Tuple, BinaryIO +from typing import Any, Dict, List, Optional, Tuple, BinaryIO, Union from torchdata.datapipes.iter import IterDataPipe, Mapper, Filter, IterKeyZipper, Demultiplexer, CSVDictParser from torchvision.prototype.datasets.utils import ( Dataset, - DatasetConfig, - DatasetInfo, HttpResource, OnlineResource, ) @@ -16,27 +14,41 @@ hint_shuffling, getitem, path_accessor, + read_categories_file, path_comparator, ) from torchvision.prototype.features import Label, EncodedImage +from .._api import register_dataset, register_info -class OxfordIITPetDemux(enum.IntEnum): + +NAME = "oxford-iiit-pet" + + +class OxfordIIITPetDemux(enum.IntEnum): SPLIT_AND_CLASSIFICATION = 0 SEGMENTATIONS = 1 -class OxfordIITPet(Dataset): - def _make_info(self) -> DatasetInfo: - return DatasetInfo( - "oxford-iiit-pet", - homepage="https://www.robots.ox.ac.uk/~vgg/data/pets/", - valid_options=dict( - split=("trainval", "test"), - ), - ) +@register_info(NAME) +def _info() -> Dict[str, Any]: + return dict(categories=read_categories_file(NAME)) + - def resources(self, config: DatasetConfig) -> List[OnlineResource]: +@register_dataset(NAME) +class OxfordIIITPet(Dataset): + """Oxford IIIT Pet Dataset + 
homepage="https://www.robots.ox.ac.uk/~vgg/data/pets/", + """ + + def __init__( + self, root: Union[str, pathlib.Path], *, split: str = "trainval", skip_integrity_check: bool = False + ) -> None: + self._split = self._verify_str_arg(split, "split", {"trainval", "test"}) + self._categories = _info()["categories"] + super().__init__(root, skip_integrity_check=skip_integrity_check) + + def _resources(self) -> List[OnlineResource]: images = HttpResource( "https://www.robots.ox.ac.uk/~vgg/data/pets/data/images.tar.gz", sha256="67195c5e1c01f1ab5f9b6a5d22b8c27a580d896ece458917e61d459337fa318d", @@ -51,8 +63,8 @@ def resources(self, config: DatasetConfig) -> List[OnlineResource]: def _classify_anns(self, data: Tuple[str, Any]) -> Optional[int]: return { - "annotations": OxfordIITPetDemux.SPLIT_AND_CLASSIFICATION, - "trimaps": OxfordIITPetDemux.SEGMENTATIONS, + "annotations": OxfordIIITPetDemux.SPLIT_AND_CLASSIFICATION, + "trimaps": OxfordIIITPetDemux.SEGMENTATIONS, }.get(pathlib.Path(data[0]).parent.name) def _filter_images(self, data: Tuple[str, Any]) -> bool: @@ -70,7 +82,7 @@ def _prepare_sample( image_path, image_buffer = image_data return dict( - label=Label(int(classification_data["label"]) - 1, categories=self.categories), + label=Label(int(classification_data["label"]) - 1, categories=self._categories), species="cat" if classification_data["species"] == "1" else "dog", segmentation_path=segmentation_path, segmentation=EncodedImage.from_file(segmentation_buffer), @@ -78,9 +90,7 @@ def _prepare_sample( image=EncodedImage.from_file(image_buffer), ) - def _make_datapipe( - self, resource_dps: List[IterDataPipe], *, config: DatasetConfig - ) -> IterDataPipe[Dict[str, Any]]: + def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]: images_dp, anns_dp = resource_dps images_dp = Filter(images_dp, self._filter_images) @@ -93,9 +103,7 @@ def _make_datapipe( buffer_size=INFINITE_BUFFER_SIZE, ) - split_and_classification_dp = Filter( - split_and_classification_dp, path_comparator("name", f"{config.split}.txt") - ) + split_and_classification_dp = Filter(split_and_classification_dp, path_comparator("name", f"{self._split}.txt")) split_and_classification_dp = CSVDictParser( split_and_classification_dp, fieldnames=("image_id", "label", "species"), delimiter=" " ) @@ -122,15 +130,14 @@ def _make_datapipe( return Mapper(dp, self._prepare_sample) def _filter_split_and_classification_anns(self, data: Tuple[str, Any]) -> bool: - return self._classify_anns(data) == OxfordIITPetDemux.SPLIT_AND_CLASSIFICATION + return self._classify_anns(data) == OxfordIIITPetDemux.SPLIT_AND_CLASSIFICATION - def _generate_categories(self, root: pathlib.Path) -> List[str]: - config = self.default_config - resources = self.resources(config) + def _generate_categories(self) -> List[str]: + resources = self._resources() - dp = resources[1].load(root) + dp = resources[1].load(self._root) dp = Filter(dp, self._filter_split_and_classification_anns) - dp = Filter(dp, path_comparator("name", f"{config.split}.txt")) + dp = Filter(dp, path_comparator("name", "trainval.txt")) dp = CSVDictParser(dp, fieldnames=("image_id", "label"), delimiter=" ") raw_categories_and_labels = {(data["image_id"].rsplit("_", 1)[0], data["label"]) for data in dp} @@ -138,3 +145,6 @@ def _generate_categories(self, root: pathlib.Path) -> List[str]: *sorted(raw_categories_and_labels, key=lambda raw_category_and_label: int(raw_category_and_label[1])) ) return [" ".join(part.title() for part in raw_category.split("_")) for raw_category 
in raw_categories] + + def __len__(self) -> int: + return 3_680 if self._split == "trainval" else 3_669 diff --git a/torchvision/prototype/datasets/_builtin/pcam.py b/torchvision/prototype/datasets/_builtin/pcam.py index 3d7b9547a76..7cd31469139 100644 --- a/torchvision/prototype/datasets/_builtin/pcam.py +++ b/torchvision/prototype/datasets/_builtin/pcam.py @@ -1,13 +1,12 @@ import io +import pathlib from collections import namedtuple -from typing import Any, Dict, List, Optional, Tuple, Iterator +from typing import Any, Dict, List, Optional, Tuple, Iterator, Union from torchdata.datapipes.iter import IterDataPipe, Mapper, Zipper from torchvision.prototype import features from torchvision.prototype.datasets.utils import ( Dataset, - DatasetConfig, - DatasetInfo, OnlineResource, GDriveResource, ) @@ -17,6 +16,11 @@ ) from torchvision.prototype.features import Label +from .._api import register_dataset, register_info + + +NAME = "pcam" + class PCAMH5Reader(IterDataPipe[Tuple[str, io.IOBase]]): def __init__( @@ -40,15 +44,25 @@ def __iter__(self) -> Iterator[Tuple[str, io.IOBase]]: _Resource = namedtuple("_Resource", ("file_name", "gdrive_id", "sha256")) +@register_info(NAME) +def _info() -> Dict[str, Any]: + return dict(categories=["0", "1"]) + + +@register_dataset(NAME) class PCAM(Dataset): - def _make_info(self) -> DatasetInfo: - return DatasetInfo( - "pcam", - homepage="https://github.com/basveeling/pcam", - categories=2, - valid_options=dict(split=("train", "test", "val")), - dependencies=["h5py"], - ) + # TODO write proper docstring + """PCAM Dataset + + homepage="https://github.com/basveeling/pcam" + """ + + def __init__( + self, root: Union[str, pathlib.Path], split: str = "train", *, skip_integrity_check: bool = False + ) -> None: + self._split = self._verify_str_arg(split, "split", {"train", "val", "test"}) + self._categories = _info()["categories"] + super().__init__(root, skip_integrity_check=skip_integrity_check, dependencies=("h5py",)) _RESOURCES = { "train": ( @@ -89,10 +103,10 @@ def _make_info(self) -> DatasetInfo: ), } - def resources(self, config: DatasetConfig) -> List[OnlineResource]: + def _resources(self) -> List[OnlineResource]: return [ # = [images resource, targets resource] GDriveResource(file_name=file_name, id=gdrive_id, sha256=sha256, preprocess="decompress") - for file_name, gdrive_id, sha256 in self._RESOURCES[config.split] + for file_name, gdrive_id, sha256 in self._RESOURCES[self._split] ] def _prepare_sample(self, data: Tuple[Any, Any]) -> Dict[str, Any]: @@ -100,12 +114,10 @@ def _prepare_sample(self, data: Tuple[Any, Any]) -> Dict[str, Any]: return { "image": features.Image(image.transpose(2, 0, 1)), - "label": Label(target.item()), + "label": Label(target.item(), categories=self._categories), } - def _make_datapipe( - self, resource_dps: List[IterDataPipe], *, config: DatasetConfig - ) -> IterDataPipe[Dict[str, Any]]: + def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]: images_dp, targets_dp = resource_dps @@ -116,3 +128,6 @@ def _make_datapipe( dp = hint_shuffling(dp) dp = hint_sharding(dp) return Mapper(dp, self._prepare_sample) + + def __len__(self) -> int: + return 262_144 if self._split == "train" else 32_768 diff --git a/torchvision/prototype/datasets/_builtin/sbd.py b/torchvision/prototype/datasets/_builtin/sbd.py index 7fd47b6c991..0c806fe098c 100644 --- a/torchvision/prototype/datasets/_builtin/sbd.py +++ b/torchvision/prototype/datasets/_builtin/sbd.py @@ -1,6 +1,6 @@ import pathlib import re -from typing 
import Any, Dict, List, Optional, Tuple, cast, BinaryIO +from typing import Any, Dict, List, Optional, Tuple, cast, BinaryIO, Union import numpy as np from torchdata.datapipes.iter import ( @@ -11,13 +11,7 @@ IterKeyZipper, LineReader, ) -from torchvision.prototype.datasets.utils import ( - Dataset, - DatasetConfig, - DatasetInfo, - HttpResource, - OnlineResource, -) +from torchvision.prototype.datasets.utils import Dataset, HttpResource, OnlineResource from torchvision.prototype.datasets.utils._internal import ( INFINITE_BUFFER_SIZE, read_mat, @@ -26,22 +20,42 @@ path_comparator, hint_sharding, hint_shuffling, + read_categories_file, ) from torchvision.prototype.features import _Feature, EncodedImage +from .._api import register_dataset, register_info + +NAME = "sbd" + + +@register_info(NAME) +def _info() -> Dict[str, Any]: + return dict(categories=read_categories_file(NAME)) + +@register_dataset(NAME) class SBD(Dataset): - def _make_info(self) -> DatasetInfo: - return DatasetInfo( - "sbd", - dependencies=("scipy",), - homepage="http://home.bharathh.info/pubs/codes/SBD/download.html", - valid_options=dict( - split=("train", "val", "train_noval"), - ), - ) + """ + - **homepage**: http://home.bharathh.info/pubs/codes/SBD/download.html + - **dependencies**: + - _ + """ + + def __init__( + self, + root: Union[str, pathlib.Path], + *, + split: str = "train", + skip_integrity_check: bool = False, + ) -> None: + self._split = self._verify_str_arg(split, "split", ("train", "val", "train_noval")) + + self._categories = _info()["categories"] + + super().__init__(root, dependencies=("scipy",), skip_integrity_check=skip_integrity_check) - def resources(self, config: DatasetConfig) -> List[OnlineResource]: + def _resources(self) -> List[OnlineResource]: archive = HttpResource( "https://www2.eecs.berkeley.edu/Research/Projects/CS/vision/grouping/semantic_contours/benchmark.tgz", sha256="6a5a2918d5c73ce032fdeba876574d150d9d04113ab87540a1304cbcc715be53", @@ -85,12 +99,7 @@ def _prepare_sample(self, data: Tuple[Tuple[Any, Tuple[str, BinaryIO]], Tuple[st segmentation=_Feature(anns["Segmentation"].item()), ) - def _make_datapipe( - self, - resource_dps: List[IterDataPipe], - *, - config: DatasetConfig, - ) -> IterDataPipe[Dict[str, Any]]: + def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]: archive_dp, extra_split_dp = resource_dps archive_dp = resource_dps[0] @@ -101,10 +110,10 @@ def _make_datapipe( buffer_size=INFINITE_BUFFER_SIZE, drop_none=True, ) - if config.split == "train_noval": + if self._split == "train_noval": split_dp = extra_split_dp - split_dp = Filter(split_dp, path_comparator("name", f"{config.split}.txt")) + split_dp = Filter(split_dp, path_comparator("name", f"{self._split}.txt")) split_dp = LineReader(split_dp, decode=True) split_dp = hint_shuffling(split_dp) split_dp = hint_sharding(split_dp) @@ -120,10 +129,17 @@ def _make_datapipe( ) return Mapper(dp, self._prepare_sample) - def _generate_categories(self, root: pathlib.Path) -> Tuple[str, ...]: - resources = self.resources(self.default_config) + def __len__(self) -> int: + return { + "train": 8_498, + "val": 2_857, + "train_noval": 5_623, + }[self._split] + + def _generate_categories(self) -> Tuple[str, ...]: + resources = self._resources() - dp = resources[0].load(root) + dp = resources[0].load(self._root) dp = Filter(dp, path_comparator("name", "category_names.m")) dp = LineReader(dp) dp = Mapper(dp, bytes.decode, input_col=1) diff --git a/torchvision/prototype/datasets/_builtin/semeion.py 
b/torchvision/prototype/datasets/_builtin/semeion.py index fb64c051d6c..5051bde4047 100644 --- a/torchvision/prototype/datasets/_builtin/semeion.py +++ b/torchvision/prototype/datasets/_builtin/semeion.py @@ -1,4 +1,5 @@ -from typing import Any, Dict, List, Tuple +import pathlib +from typing import Any, Dict, List, Tuple, Union import torch from torchdata.datapipes.iter import ( @@ -8,24 +9,34 @@ ) from torchvision.prototype.datasets.utils import ( Dataset, - DatasetConfig, - DatasetInfo, HttpResource, OnlineResource, ) from torchvision.prototype.datasets.utils._internal import hint_sharding, hint_shuffling from torchvision.prototype.features import Image, OneHotLabel +from .._api import register_dataset, register_info +NAME = "semeion" + + +@register_info(NAME) +def _info() -> Dict[str, Any]: + return dict(categories=[str(i) for i in range(10)]) + + +@register_dataset(NAME) class SEMEION(Dataset): - def _make_info(self) -> DatasetInfo: - return DatasetInfo( - "semeion", - categories=10, - homepage="https://archive.ics.uci.edu/ml/datasets/Semeion+Handwritten+Digit", - ) + """Semeion dataset + homepage="https://archive.ics.uci.edu/ml/datasets/Semeion+Handwritten+Digit", + """ + + def __init__(self, root: Union[str, pathlib.Path], *, skip_integrity_check: bool = False) -> None: - def resources(self, config: DatasetConfig) -> List[OnlineResource]: + self._categories = _info()["categories"] + super().__init__(root, skip_integrity_check=skip_integrity_check) + + def _resources(self) -> List[OnlineResource]: data = HttpResource( "http://archive.ics.uci.edu/ml/machine-learning-databases/semeion/semeion.data", sha256="f43228ae3da5ea6a3c95069d53450b86166770e3b719dcc333182128fe08d4b1", @@ -36,18 +47,16 @@ def _prepare_sample(self, data: Tuple[str, ...]) -> Dict[str, Any]: image_data, label_data = data[:256], data[256:-1] return dict( - image=Image(torch.tensor([float(pixel) for pixel in image_data], dtype=torch.uint8).reshape(16, 16)), - label=OneHotLabel([int(label) for label in label_data], categories=self.categories), + image=Image(torch.tensor([float(pixel) for pixel in image_data], dtype=torch.float).reshape(16, 16)), + label=OneHotLabel([int(label) for label in label_data], categories=self._categories), ) - def _make_datapipe( - self, - resource_dps: List[IterDataPipe], - *, - config: DatasetConfig, - ) -> IterDataPipe[Dict[str, Any]]: + def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]: dp = resource_dps[0] dp = CSVParser(dp, delimiter=" ") dp = hint_shuffling(dp) dp = hint_sharding(dp) return Mapper(dp, self._prepare_sample) + + def __len__(self) -> int: + return 1_593 diff --git a/torchvision/prototype/datasets/_builtin/stanford_cars.py b/torchvision/prototype/datasets/_builtin/stanford_cars.py index 51c0b6152e6..465d753c2e5 100644 --- a/torchvision/prototype/datasets/_builtin/stanford_cars.py +++ b/torchvision/prototype/datasets/_builtin/stanford_cars.py @@ -1,11 +1,19 @@ import pathlib -from typing import Any, Dict, List, Tuple, Iterator, BinaryIO +from typing import Any, Dict, List, Tuple, Iterator, BinaryIO, Union from torchdata.datapipes.iter import Filter, IterDataPipe, Mapper, Zipper -from torchvision.prototype.datasets.utils import Dataset, DatasetConfig, DatasetInfo, HttpResource, OnlineResource -from torchvision.prototype.datasets.utils._internal import hint_sharding, hint_shuffling, path_comparator, read_mat +from torchvision.prototype.datasets.utils import Dataset, HttpResource, OnlineResource +from 
torchvision.prototype.datasets.utils._internal import ( + hint_sharding, + hint_shuffling, + path_comparator, + read_mat, + read_categories_file, +) from torchvision.prototype.features import BoundingBox, EncodedImage, Label +from .._api import register_dataset, register_info + class StanfordCarsLabelReader(IterDataPipe[Tuple[int, int, int, int, int, str]]): def __init__(self, datapipe: IterDataPipe[Dict[str, Any]]) -> None: @@ -18,16 +26,31 @@ def __iter__(self) -> Iterator[Tuple[int, int, int, int, int, str]]: yield tuple(ann) # type: ignore[misc] +NAME = "stanford-cars" + + +@register_info(NAME) +def _info() -> Dict[str, Any]: + return dict(categories=read_categories_file(NAME)) + + +@register_dataset(NAME) class StanfordCars(Dataset): - def _make_info(self) -> DatasetInfo: - return DatasetInfo( - name="stanford-cars", - homepage="https://ai.stanford.edu/~jkrause/cars/car_dataset.html", - dependencies=("scipy",), - valid_options=dict( - split=("test", "train"), - ), - ) + """Stanford Cars dataset. + homepage="https://ai.stanford.edu/~jkrause/cars/car_dataset.html", + dependencies=scipy + """ + + def __init__( + self, + root: Union[str, pathlib.Path], + *, + split: str = "train", + skip_integrity_check: bool = False, + ) -> None: + self._split = self._verify_str_arg(split, "split", {"train", "test"}) + self._categories = _info()["categories"] + super().__init__(root, skip_integrity_check=skip_integrity_check, dependencies=("scipy",)) _URL_ROOT = "https://ai.stanford.edu/~jkrause/" _URLS = { @@ -44,9 +67,9 @@ def _make_info(self) -> DatasetInfo: "car_devkit": "512b227b30e2f0a8aab9e09485786ab4479582073a144998da74d64b801fd288", } - def resources(self, config: DatasetConfig) -> List[OnlineResource]: - resources: List[OnlineResource] = [HttpResource(self._URLS[config.split], sha256=self._CHECKSUM[config.split])] - if config.split == "train": + def _resources(self) -> List[OnlineResource]: + resources: List[OnlineResource] = [HttpResource(self._URLS[self._split], sha256=self._CHECKSUM[self._split])] + if self._split == "train": resources.append(HttpResource(url=self._URLS["car_devkit"], sha256=self._CHECKSUM["car_devkit"])) else: @@ -65,19 +88,14 @@ def _prepare_sample(self, data: Tuple[Tuple[str, BinaryIO], Tuple[int, int, int, return dict( path=path, image=image, - label=Label(target[4] - 1, categories=self.categories), + label=Label(target[4] - 1, categories=self._categories), bounding_box=BoundingBox(target[:4], format="xyxy", image_size=image.image_size), ) - def _make_datapipe( - self, - resource_dps: List[IterDataPipe], - *, - config: DatasetConfig, - ) -> IterDataPipe[Dict[str, Any]]: + def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]: images_dp, targets_dp = resource_dps - if config.split == "train": + if self._split == "train": targets_dp = Filter(targets_dp, path_comparator("name", "cars_train_annos.mat")) targets_dp = StanfordCarsLabelReader(targets_dp) dp = Zipper(images_dp, targets_dp) @@ -85,12 +103,14 @@ def _make_datapipe( dp = hint_sharding(dp) return Mapper(dp, self._prepare_sample) - def _generate_categories(self, root: pathlib.Path) -> List[str]: - config = self.info.make_config(split="train") - resources = self.resources(config) + def _generate_categories(self) -> List[str]: + resources = self._resources() - devkit_dp = resources[1].load(root) + devkit_dp = resources[1].load(self._root) meta_dp = Filter(devkit_dp, path_comparator("name", "cars_meta.mat")) _, meta_file = next(iter(meta_dp)) return list(read_mat(meta_file, 
squeeze_me=True)["class_names"]) + + def __len__(self) -> int: + return 8_144 if self._split == "train" else 8_041 diff --git a/torchvision/prototype/datasets/_builtin/svhn.py b/torchvision/prototype/datasets/_builtin/svhn.py index 70daece4f86..175aa6c0a51 100644 --- a/torchvision/prototype/datasets/_builtin/svhn.py +++ b/torchvision/prototype/datasets/_builtin/svhn.py @@ -1,4 +1,5 @@ -from typing import Any, Dict, List, Tuple, BinaryIO +import pathlib +from typing import Any, Dict, List, Tuple, BinaryIO, Union import numpy as np from torchdata.datapipes.iter import ( @@ -8,8 +9,6 @@ ) from torchvision.prototype.datasets.utils import ( Dataset, - DatasetConfig, - DatasetInfo, HttpResource, OnlineResource, ) @@ -20,16 +19,33 @@ ) from torchvision.prototype.features import Label, Image +from .._api import register_dataset, register_info +NAME = "svhn" + + +@register_info(NAME) +def _info() -> Dict[str, Any]: + return dict(categories=[str(c) for c in range(10)]) + + +@register_dataset(NAME) class SVHN(Dataset): - def _make_info(self) -> DatasetInfo: - return DatasetInfo( - "svhn", - dependencies=("scipy",), - categories=10, - homepage="http://ufldl.stanford.edu/housenumbers/", - valid_options=dict(split=("train", "test", "extra")), - ) + """SVHN Dataset. + homepage="http://ufldl.stanford.edu/housenumbers/", + dependencies = scipy + """ + + def __init__( + self, + root: Union[str, pathlib.Path], + *, + split: str = "train", + skip_integrity_check: bool = False, + ) -> None: + self._split = self._verify_str_arg(split, "split", {"train", "test", "extra"}) + self._categories = _info()["categories"] + super().__init__(root, skip_integrity_check=skip_integrity_check, dependencies=("scipy",)) _CHECKSUMS = { "train": "435e94d69a87fde4fd4d7f3dd208dfc32cb6ae8af2240d066de1df7508d083b8", @@ -37,10 +53,10 @@ def _make_info(self) -> DatasetInfo: "extra": "a133a4beb38a00fcdda90c9489e0c04f900b660ce8a316a5e854838379a71eb3", } - def resources(self, config: DatasetConfig) -> List[OnlineResource]: + def _resources(self) -> List[OnlineResource]: data = HttpResource( - f"http://ufldl.stanford.edu/housenumbers/{config.split}_32x32.mat", - sha256=self._CHECKSUMS[config.split], + f"http://ufldl.stanford.edu/housenumbers/{self._split}_32x32.mat", + sha256=self._CHECKSUMS[self._split], ) return [data] @@ -60,18 +76,20 @@ def _prepare_sample(self, data: Tuple[np.ndarray, np.ndarray]) -> Dict[str, Any] return dict( image=Image(image_array.transpose((2, 0, 1))), - label=Label(int(label_array) % 10, categories=self.categories), + label=Label(int(label_array) % 10, categories=self._categories), ) - def _make_datapipe( - self, - resource_dps: List[IterDataPipe], - *, - config: DatasetConfig, - ) -> IterDataPipe[Dict[str, Any]]: + def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]: dp = resource_dps[0] dp = Mapper(dp, self._read_images_and_labels) dp = UnBatcher(dp) dp = hint_shuffling(dp) dp = hint_sharding(dp) return Mapper(dp, self._prepare_sample) + + def __len__(self) -> int: + return { + "train": 73_257, + "test": 26_032, + "extra": 531_131, + }[self._split] diff --git a/torchvision/prototype/datasets/_builtin/usps.py b/torchvision/prototype/datasets/_builtin/usps.py index 155fbff5dbb..e732f3b788a 100644 --- a/torchvision/prototype/datasets/_builtin/usps.py +++ b/torchvision/prototype/datasets/_builtin/usps.py @@ -1,22 +1,39 @@ -from typing import Any, Dict, List +import pathlib +from typing import Any, Dict, List, Union import torch from torchdata.datapipes.iter import IterDataPipe, 
LineReader, Mapper, Decompressor -from torchvision.prototype.datasets.utils import Dataset, DatasetInfo, DatasetConfig, OnlineResource, HttpResource +from torchvision.prototype.datasets.utils import Dataset, OnlineResource, HttpResource from torchvision.prototype.datasets.utils._internal import hint_sharding, hint_shuffling from torchvision.prototype.features import Image, Label +from .._api import register_dataset, register_info +NAME = "usps" + + +@register_info(NAME) +def _info() -> Dict[str, Any]: + return dict(categories=[str(c) for c in range(10)]) + + +@register_dataset(NAME) class USPS(Dataset): - def _make_info(self) -> DatasetInfo: - return DatasetInfo( - "usps", - homepage="https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multiclass.html#usps", - valid_options=dict( - split=("train", "test"), - ), - categories=10, - ) + """USPS Dataset + homepage="https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multiclass.html#usps", + """ + + def __init__( + self, + root: Union[str, pathlib.Path], + *, + split: str = "train", + skip_integrity_check: bool = False, + ) -> None: + self._split = self._verify_str_arg(split, "split", {"train", "test"}) + + self._categories = _info()["categories"] + super().__init__(root, skip_integrity_check=skip_integrity_check) _URL = "https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multiclass" @@ -29,8 +46,8 @@ def _make_info(self) -> DatasetInfo: ), } - def resources(self, config: DatasetConfig) -> List[OnlineResource]: - return [USPS._RESOURCES[config.split]] + def _resources(self) -> List[OnlineResource]: + return [USPS._RESOURCES[self._split]] def _prepare_sample(self, line: str) -> Dict[str, Any]: label, *values = line.strip().split(" ") @@ -38,17 +55,15 @@ def _prepare_sample(self, line: str) -> Dict[str, Any]: pixels = torch.tensor(values).add_(1).div_(2) return dict( image=Image(pixels.reshape(16, 16)), - label=Label(int(label) - 1, categories=self.categories), + label=Label(int(label) - 1, categories=self._categories), ) - def _make_datapipe( - self, - resource_dps: List[IterDataPipe], - *, - config: DatasetConfig, - ) -> IterDataPipe[Dict[str, Any]]: + def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]: dp = Decompressor(resource_dps[0]) dp = LineReader(dp, decode=True, return_path=False) dp = hint_shuffling(dp) dp = hint_sharding(dp) return Mapper(dp, self._prepare_sample) + + def __len__(self) -> int: + return 7_291 if self._split == "train" else 2_007 diff --git a/torchvision/prototype/datasets/_builtin/voc.py b/torchvision/prototype/datasets/_builtin/voc.py index 5c1d3f8c3a3..05a3c2e8622 100644 --- a/torchvision/prototype/datasets/_builtin/voc.py +++ b/torchvision/prototype/datasets/_builtin/voc.py @@ -1,6 +1,7 @@ +import enum import functools import pathlib -from typing import Any, Dict, List, Optional, Tuple, BinaryIO, cast, Callable +from typing import Any, Dict, List, Optional, Tuple, BinaryIO, cast, Union from xml.etree import ElementTree from torchdata.datapipes.iter import ( @@ -12,13 +13,7 @@ LineReader, ) from torchvision.datasets import VOCDetection -from torchvision.prototype.datasets.utils import ( - Dataset, - DatasetConfig, - DatasetInfo, - HttpResource, - OnlineResource, -) +from torchvision.prototype.datasets.utils import OnlineResource, HttpResource, Dataset from torchvision.prototype.datasets.utils._internal import ( path_accessor, getitem, @@ -26,34 +21,48 @@ path_comparator, hint_sharding, hint_shuffling, + read_categories_file, ) from torchvision.prototype.features import 
BoundingBox, Label, EncodedImage +from .._api import register_dataset, register_info -class VOCDatasetInfo(DatasetInfo): - def __init__(self, *args: Any, **kwargs: Any): - super().__init__(*args, **kwargs) - self._configs = tuple(config for config in self._configs if config.split != "test" or config.year == "2007") +NAME = "voc" - def make_config(self, **options: Any) -> DatasetConfig: - config = super().make_config(**options) - if config.split == "test" and config.year != "2007": - raise ValueError("`split='test'` is only available for `year='2007'`") - return config +@register_info(NAME) +def _info() -> Dict[str, Any]: + return dict(categories=read_categories_file(NAME)) +@register_dataset(NAME) class VOC(Dataset): - def _make_info(self) -> DatasetInfo: - return VOCDatasetInfo( - "voc", - homepage="http://host.robots.ox.ac.uk/pascal/VOC/", - valid_options=dict( - split=("train", "val", "trainval", "test"), - year=("2012", "2007", "2008", "2009", "2010", "2011"), - task=("detection", "segmentation"), - ), - ) + """ + - **homepage**: http://host.robots.ox.ac.uk/pascal/VOC/ + """ + + def __init__( + self, + root: Union[str, pathlib.Path], + *, + split: str = "train", + year: str = "2012", + task: str = "detection", + skip_integrity_check: bool = False, + ) -> None: + self._year = self._verify_str_arg(year, "year", ("2007", "2008", "2009", "2010", "2011", "2012")) + if split == "test" and year != "2007": + raise ValueError("`split='test'` is only available for `year='2007'`") + else: + self._split = self._verify_str_arg(split, "split", ("train", "val", "trainval", "test")) + self._task = self._verify_str_arg(task, "task", ("detection", "segmentation")) + + self._anns_folder = "Annotations" if task == "detection" else "SegmentationClass" + self._split_folder = "Main" if task == "detection" else "Segmentation" + + self._categories = _info()["categories"] + + super().__init__(root, skip_integrity_check=skip_integrity_check) _TRAIN_VAL_ARCHIVES = { "2007": ("VOCtrainval_06-Nov-2007.tar", "7d8cd951101b0957ddfd7a530bdc8a94f06121cfc1e511bb5937e973020c7508"), @@ -67,31 +76,27 @@ def _make_info(self) -> DatasetInfo: "2007": ("VOCtest_06-Nov-2007.tar", "6836888e2e01dca84577a849d339fa4f73e1e4f135d312430c4856b5609b4892") } - def resources(self, config: DatasetConfig) -> List[OnlineResource]: - file_name, sha256 = (self._TEST_ARCHIVES if config.split == "test" else self._TRAIN_VAL_ARCHIVES)[config.year] - archive = HttpResource(f"http://host.robots.ox.ac.uk/pascal/VOC/voc{config.year}/{file_name}", sha256=sha256) + def _resources(self) -> List[OnlineResource]: + file_name, sha256 = (self._TEST_ARCHIVES if self._split == "test" else self._TRAIN_VAL_ARCHIVES)[self._year] + archive = HttpResource(f"http://host.robots.ox.ac.uk/pascal/VOC/voc{self._year}/{file_name}", sha256=sha256) return [archive] - _ANNS_FOLDER = dict( - detection="Annotations", - segmentation="SegmentationClass", - ) - _SPLIT_FOLDER = dict( - detection="Main", - segmentation="Segmentation", - ) - def _is_in_folder(self, data: Tuple[str, Any], *, name: str, depth: int = 1) -> bool: path = pathlib.Path(data[0]) return name in path.parent.parts[-depth:] - def _classify_archive(self, data: Tuple[str, Any], *, config: DatasetConfig) -> Optional[int]: + class _Demux(enum.IntEnum): + SPLIT = 0 + IMAGES = 1 + ANNS = 2 + + def _classify_archive(self, data: Tuple[str, Any]) -> Optional[int]: if self._is_in_folder(data, name="ImageSets", depth=2): - return 0 + return self._Demux.SPLIT elif self._is_in_folder(data, name="JPEGImages"): - return 1 - 
elif self._is_in_folder(data, name=self._ANNS_FOLDER[config.task]): - return 2 + return self._Demux.IMAGES + elif self._is_in_folder(data, name=self._anns_folder): + return self._Demux.ANNS else: return None @@ -111,7 +116,7 @@ def _prepare_detection_ann(self, buffer: BinaryIO) -> Dict[str, Any]: image_size=cast(Tuple[int, int], tuple(int(anns["size"][dim]) for dim in ("height", "width"))), ), labels=Label( - [self.categories.index(instance["name"]) for instance in instances], categories=self.categories + [self._categories.index(instance["name"]) for instance in instances], categories=self._categories ), ) @@ -121,8 +126,6 @@ def _prepare_segmentation_ann(self, buffer: BinaryIO) -> Dict[str, Any]: def _prepare_sample( self, data: Tuple[Tuple[Tuple[str, str], Tuple[str, BinaryIO]], Tuple[str, BinaryIO]], - *, - prepare_ann_fn: Callable[[BinaryIO], Dict[str, Any]], ) -> Dict[str, Any]: split_and_image_data, ann_data = data _, image_data = split_and_image_data @@ -130,29 +133,24 @@ def _prepare_sample( ann_path, ann_buffer = ann_data return dict( - prepare_ann_fn(ann_buffer), + (self._prepare_detection_ann if self._task == "detection" else self._prepare_segmentation_ann)(ann_buffer), image_path=image_path, image=EncodedImage.from_file(image_buffer), ann_path=ann_path, ) - def _make_datapipe( - self, - resource_dps: List[IterDataPipe], - *, - config: DatasetConfig, - ) -> IterDataPipe[Dict[str, Any]]: + def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]: archive_dp = resource_dps[0] split_dp, images_dp, anns_dp = Demultiplexer( archive_dp, 3, - functools.partial(self._classify_archive, config=config), + self._classify_archive, drop_none=True, buffer_size=INFINITE_BUFFER_SIZE, ) - split_dp = Filter(split_dp, functools.partial(self._is_in_folder, name=self._SPLIT_FOLDER[config.task])) - split_dp = Filter(split_dp, path_comparator("name", f"{config.split}.txt")) + split_dp = Filter(split_dp, functools.partial(self._is_in_folder, name=self._split_folder)) + split_dp = Filter(split_dp, path_comparator("name", f"{self._split}.txt")) split_dp = LineReader(split_dp, decode=True) split_dp = hint_shuffling(split_dp) split_dp = hint_sharding(split_dp) @@ -166,25 +164,59 @@ def _make_datapipe( ref_key_fn=path_accessor("stem"), buffer_size=INFINITE_BUFFER_SIZE, ) - return Mapper( - dp, - functools.partial( - self._prepare_sample, - prepare_ann_fn=self._prepare_detection_ann - if config.task == "detection" - else self._prepare_segmentation_ann, - ), - ) - - def _filter_detection_anns(self, data: Tuple[str, Any], *, config: DatasetConfig) -> bool: - return self._classify_archive(data, config=config) == 2 - - def _generate_categories(self, root: pathlib.Path) -> List[str]: - config = self.info.make_config(task="detection") - - resource = self.resources(config)[0] - dp = resource.load(pathlib.Path(root) / self.name) - dp = Filter(dp, self._filter_detection_anns, fn_kwargs=dict(config=config)) + return Mapper(dp, self._prepare_sample) + + def __len__(self) -> int: + return { + ("train", "2007", "detection"): 2_501, + ("train", "2007", "segmentation"): 209, + ("train", "2008", "detection"): 2_111, + ("train", "2008", "segmentation"): 511, + ("train", "2009", "detection"): 3_473, + ("train", "2009", "segmentation"): 749, + ("train", "2010", "detection"): 4_998, + ("train", "2010", "segmentation"): 964, + ("train", "2011", "detection"): 5_717, + ("train", "2011", "segmentation"): 1_112, + ("train", "2012", "detection"): 5_717, + ("train", "2012", "segmentation"): 1_464, + ("val", 
"2007", "detection"): 2_510, + ("val", "2007", "segmentation"): 213, + ("val", "2008", "detection"): 2_221, + ("val", "2008", "segmentation"): 512, + ("val", "2009", "detection"): 3_581, + ("val", "2009", "segmentation"): 750, + ("val", "2010", "detection"): 5_105, + ("val", "2010", "segmentation"): 964, + ("val", "2011", "detection"): 5_823, + ("val", "2011", "segmentation"): 1_111, + ("val", "2012", "detection"): 5_823, + ("val", "2012", "segmentation"): 1_449, + ("trainval", "2007", "detection"): 5_011, + ("trainval", "2007", "segmentation"): 422, + ("trainval", "2008", "detection"): 4_332, + ("trainval", "2008", "segmentation"): 1_023, + ("trainval", "2009", "detection"): 7_054, + ("trainval", "2009", "segmentation"): 1_499, + ("trainval", "2010", "detection"): 10_103, + ("trainval", "2010", "segmentation"): 1_928, + ("trainval", "2011", "detection"): 11_540, + ("trainval", "2011", "segmentation"): 2_223, + ("trainval", "2012", "detection"): 11_540, + ("trainval", "2012", "segmentation"): 2_913, + ("test", "2007", "detection"): 4_952, + ("test", "2007", "segmentation"): 210, + }[(self._split, self._year, self._task)] + + def _filter_anns(self, data: Tuple[str, Any]) -> bool: + return self._classify_archive(data) == self._Demux.ANNS + + def _generate_categories(self) -> List[str]: + self._task = "detection" + resources = self._resources() + + archive_dp = resources[0].load(self._root) + dp = Filter(archive_dp, self._filter_anns) dp = Mapper(dp, self._parse_detection_ann, input_col=1) return sorted({instance["name"] for _, anns in dp for instance in anns["object"]}) diff --git a/torchvision/prototype/datasets/generate_category_files.py b/torchvision/prototype/datasets/generate_category_files.py index 3c2bf7e73cb..6d4e854fe34 100644 --- a/torchvision/prototype/datasets/generate_category_files.py +++ b/torchvision/prototype/datasets/generate_category_files.py @@ -2,25 +2,21 @@ import argparse import csv -import pathlib import sys from torchvision.prototype import datasets -from torchvision.prototype.datasets._api import find from torchvision.prototype.datasets.utils._internal import BUILTIN_DIR def main(*names, force=False): - home = pathlib.Path(datasets.home()) - for name in names: path = BUILTIN_DIR / f"{name}.categories" if path.exists() and not force: continue - dataset = find(name) + dataset = datasets.load(name) try: - categories = dataset._generate_categories(home / name) + categories = dataset._generate_categories() except NotImplementedError: continue diff --git a/torchvision/prototype/datasets/utils/__init__.py b/torchvision/prototype/datasets/utils/__init__.py index 9423b65a8ee..e7ef72f28a9 100644 --- a/torchvision/prototype/datasets/utils/__init__.py +++ b/torchvision/prototype/datasets/utils/__init__.py @@ -1,4 +1,4 @@ from . 
import _internal # usort: skip -from ._dataset import DatasetConfig, DatasetInfo, Dataset +from ._dataset import Dataset from ._query import SampleQuery from ._resource import OnlineResource, HttpResource, GDriveResource, ManualDownloadResource, KaggleDownloadResource diff --git a/torchvision/prototype/datasets/utils/_dataset.py b/torchvision/prototype/datasets/utils/_dataset.py index b5c6d7acb60..528d0a0f25f 100644 --- a/torchvision/prototype/datasets/utils/_dataset.py +++ b/torchvision/prototype/datasets/utils/_dataset.py @@ -1,184 +1,57 @@ import abc -import csv import importlib -import itertools -import os import pathlib -from typing import Any, Dict, List, Optional, Sequence, Union, Tuple, Collection +from typing import Any, Dict, List, Optional, Sequence, Union, Collection, Iterator from torch.utils.data import IterDataPipe -from torchvision._utils import sequence_to_str -from torchvision.prototype.utils._internal import FrozenBunch, make_repr, add_suggestion +from torchvision.datasets.utils import verify_str_arg -from .._home import use_sharded_dataset -from ._internal import BUILTIN_DIR, _make_sharded_datapipe from ._resource import OnlineResource -class DatasetConfig(FrozenBunch): - # This needs to be Frozen because we often pass configs as partial(func, config=config) - # and partial() requires the parameters to be hashable. - pass - +class Dataset(IterDataPipe[Dict[str, Any]], abc.ABC): + @staticmethod + def _verify_str_arg( + value: str, + arg: Optional[str] = None, + valid_values: Optional[Collection[str]] = None, + *, + custom_msg: Optional[str] = None, + ) -> str: + return verify_str_arg(value, arg, valid_values, custom_msg=custom_msg) -class DatasetInfo: def __init__( - self, - name: str, - *, - dependencies: Collection[str] = (), - categories: Optional[Union[int, Sequence[str], str, pathlib.Path]] = None, - citation: Optional[str] = None, - homepage: Optional[str] = None, - license: Optional[str] = None, - valid_options: Optional[Dict[str, Sequence[Any]]] = None, - extra: Optional[Dict[str, Any]] = None, + self, root: Union[str, pathlib.Path], *, skip_integrity_check: bool = False, dependencies: Collection[str] = () ) -> None: - self.name = name.lower() - - self.dependecies = dependencies - - if categories is None: - path = BUILTIN_DIR / f"{self.name}.categories" - categories = path if path.exists() else [] - if isinstance(categories, int): - categories = [str(label) for label in range(categories)] - elif isinstance(categories, (str, pathlib.Path)): - path = pathlib.Path(categories).expanduser().resolve() - categories, *_ = zip(*self.read_categories_file(path)) - self.categories = tuple(categories) - - self.citation = citation - self.homepage = homepage - self.license = license - - self._valid_options = valid_options or dict() - self._configs = tuple( - DatasetConfig(**dict(zip(self._valid_options.keys(), combination))) - for combination in itertools.product(*self._valid_options.values()) - ) - - self.extra = FrozenBunch(extra or dict()) - - @property - def default_config(self) -> DatasetConfig: - return self._configs[0] - - @staticmethod - def read_categories_file(path: pathlib.Path) -> List[List[str]]: - with open(path, newline="") as file: - return [row for row in csv.reader(file)] - - def make_config(self, **options: Any) -> DatasetConfig: - if not self._valid_options and options: - raise ValueError( - f"Dataset {self.name} does not take any options, " - f"but got {sequence_to_str(list(options), separate_last=' and')}." 
- ) - - for name, arg in options.items(): - if name not in self._valid_options: - raise ValueError( - add_suggestion( - f"Unknown option '{name}' of dataset {self.name}.", - word=name, - possibilities=sorted(self._valid_options.keys()), - ) - ) - - valid_args = self._valid_options[name] - - if arg not in valid_args: - raise ValueError( - add_suggestion( - f"Invalid argument '{arg}' for option '{name}' of dataset {self.name}.", - word=arg, - possibilities=valid_args, - ) - ) - - return DatasetConfig(self.default_config, **options) - - def check_dependencies(self) -> None: - for dependency in self.dependecies: + for dependency in dependencies: try: importlib.import_module(dependency) - except ModuleNotFoundError as error: + except ModuleNotFoundError: raise ModuleNotFoundError( - f"Dataset '{self.name}' depends on the third-party package '{dependency}'. " + f"{type(self).__name__}() depends on the third-party package '{dependency}'. " f"Please install it, for example with `pip install {dependency}`." - ) from error - - def __repr__(self) -> str: - items = [("name", self.name)] - for key in ("citation", "homepage", "license"): - value = getattr(self, key) - if value is not None: - items.append((key, value)) - items.extend(sorted((key, sequence_to_str(value)) for key, value in self._valid_options.items())) - return make_repr(type(self).__name__, items) + ) from None + self._root = pathlib.Path(root).expanduser().resolve() + resources = [ + resource.load(self._root, skip_integrity_check=skip_integrity_check) for resource in self._resources() + ] + self._dp = self._datapipe(resources) -class Dataset(abc.ABC): - def __init__(self) -> None: - self._info = self._make_info() + def __iter__(self) -> Iterator[Dict[str, Any]]: + yield from self._dp @abc.abstractmethod - def _make_info(self) -> DatasetInfo: + def _resources(self) -> List[OnlineResource]: pass - @property - def info(self) -> DatasetInfo: - return self._info - - @property - def name(self) -> str: - return self.info.name - - @property - def default_config(self) -> DatasetConfig: - return self.info.default_config - - @property - def categories(self) -> Tuple[str, ...]: - return self.info.categories - @abc.abstractmethod - def resources(self, config: DatasetConfig) -> List[OnlineResource]: + def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]: pass @abc.abstractmethod - def _make_datapipe( - self, - resource_dps: List[IterDataPipe], - *, - config: DatasetConfig, - ) -> IterDataPipe[Dict[str, Any]]: + def __len__(self) -> int: pass - def supports_sharded(self) -> bool: - return False - - def load( - self, - root: Union[str, pathlib.Path], - *, - config: Optional[DatasetConfig] = None, - skip_integrity_check: bool = False, - ) -> IterDataPipe[Dict[str, Any]]: - if not config: - config = self.info.default_config - - if use_sharded_dataset() and self.supports_sharded(): - root = os.path.join(root, *config.values()) - dataset_size = self.info.extra["sizes"][config] - return _make_sharded_datapipe(root, dataset_size) # type: ignore[no-any-return] - - self.info.check_dependencies() - resource_dps = [ - resource.load(root, skip_integrity_check=skip_integrity_check) for resource in self.resources(config) - ] - return self._make_datapipe(resource_dps, config=config) - - def _generate_categories(self, root: pathlib.Path) -> Sequence[Union[str, Sequence[str]]]: + def _generate_categories(self) -> Sequence[Union[str, Sequence[str]]]: raise NotImplementedError diff --git a/torchvision/prototype/datasets/utils/_internal.py 
diff --git a/torchvision/prototype/datasets/utils/_internal.py b/torchvision/prototype/datasets/utils/_internal.py
index fa48218fe02..007e91eb657 100644
--- a/torchvision/prototype/datasets/utils/_internal.py
+++ b/torchvision/prototype/datasets/utils/_internal.py
@@ -1,3 +1,4 @@
+import csv
 import functools
 import pathlib
 import pickle
@@ -9,6 +10,7 @@
     Any,
     Tuple,
     TypeVar,
+    List,
     Iterator,
     Dict,
     IO,
@@ -198,3 +200,11 @@ def hint_sharding(datapipe: IterDataPipe) -> ShardingFilter:
 
 def hint_shuffling(datapipe: IterDataPipe[D]) -> Shuffler[D]:
     return Shuffler(datapipe, buffer_size=INFINITE_BUFFER_SIZE).set_shuffle(False)
+
+
+def read_categories_file(name: str) -> List[Union[str, Sequence[str]]]:
+    path = BUILTIN_DIR / f"{name}.categories"
+    with open(path, newline="") as file:
+        rows = list(csv.reader(file))
+        rows = [row[0] if len(row) == 1 else row for row in rows]
+        return rows
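
The new `read_categories_file` helper reads `BUILTIN_DIR / "<name>.categories"` as CSV and collapses single-column rows into plain strings while keeping multi-column rows as lists. A small self-contained illustration of that mapping; the rows below are invented and not taken from any shipped `.categories` file:

import csv
import io

# stand-in for the contents of a hypothetical "<name>.categories" file
content = "airplane\ncar\nn01440764,tench\n"

rows = list(csv.reader(io.StringIO(content)))
rows = [row[0] if len(row) == 1 else row for row in rows]
print(rows)  # ['airplane', 'car', ['n01440764', 'tench']]
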
diff --git a/torchvision/prototype/utils/_internal.py b/torchvision/prototype/utils/_internal.py
index fe5284394cb..233128880e3 100644
--- a/torchvision/prototype/utils/_internal.py
+++ b/torchvision/prototype/utils/_internal.py
@@ -2,20 +2,13 @@
 import difflib
 import io
 import mmap
-import os
-import os.path
 import platform
-import textwrap
 from typing import (
     Any,
     BinaryIO,
     Callable,
-    cast,
     Collection,
-    Iterable,
     Iterator,
-    Mapping,
-    NoReturn,
     Sequence,
     Tuple,
     TypeVar,
@@ -30,9 +23,6 @@
 
 __all__ = [
     "add_suggestion",
-    "FrozenMapping",
-    "make_repr",
-    "FrozenBunch",
     "fromfile",
     "ReadOnlyTensorBuffer",
     "apply_recursively",
@@ -60,82 +50,9 @@ def add_suggestion(
     return f"{msg.strip()} {hint}"
 
 
-K = TypeVar("K")
 D = TypeVar("D")
 
 
-class FrozenMapping(Mapping[K, D]):
-    def __init__(self, *args: Any, **kwargs: Any) -> None:
-        data = dict(*args, **kwargs)
-        self.__dict__["__data__"] = data
-        self.__dict__["__final_hash__"] = hash(tuple(data.items()))
-
-    def __getitem__(self, item: K) -> D:
-        return cast(Mapping[K, D], self.__dict__["__data__"])[item]
-
-    def __iter__(self) -> Iterator[K]:
-        return iter(self.__dict__["__data__"].keys())
-
-    def __len__(self) -> int:
-        return len(self.__dict__["__data__"])
-
-    def __immutable__(self) -> NoReturn:
-        raise RuntimeError(f"'{type(self).__name__}' object is immutable")
-
-    def __setitem__(self, key: K, value: Any) -> NoReturn:
-        self.__immutable__()
-
-    def __delitem__(self, key: K) -> NoReturn:
-        self.__immutable__()
-
-    def __hash__(self) -> int:
-        return cast(int, self.__dict__["__final_hash__"])
-
-    def __eq__(self, other: Any) -> bool:
-        if not isinstance(other, FrozenMapping):
-            return NotImplemented
-
-        return hash(self) == hash(other)
-
-    def __repr__(self) -> str:
-        return repr(self.__dict__["__data__"])
-
-
-def make_repr(name: str, items: Iterable[Tuple[str, Any]]) -> str:
-    def to_str(sep: str) -> str:
-        return sep.join([f"{key}={value}" for key, value in items])
-
-    prefix = f"{name}("
-    postfix = ")"
-    body = to_str(", ")
-
-    line_length = int(os.environ.get("COLUMNS", 80))
-    body_too_long = (len(prefix) + len(body) + len(postfix)) > line_length
-    multiline_body = len(str(body).splitlines()) > 1
-    if not (body_too_long or multiline_body):
-        return prefix + body + postfix
-
-    body = textwrap.indent(to_str(",\n"), " " * 2)
-    return f"{prefix}\n{body}\n{postfix}"
-
-
-class FrozenBunch(FrozenMapping):
-    def __getattr__(self, name: str) -> Any:
-        try:
-            return self[name]
-        except KeyError as error:
-            raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'") from error
-
-    def __setattr__(self, key: Any, value: Any) -> NoReturn:
-        self.__immutable__()
-
-    def __delattr__(self, item: Any) -> NoReturn:
-        self.__immutable__()
-
-    def __repr__(self) -> str:
-        return make_repr(type(self).__name__, self.items())
-
-
 def _read_mutable_buffer_fallback(file: BinaryIO, count: int, item_size: int) -> bytearray:
     # A plain file.read() will give a read-only bytes, so we convert it to bytearray to make it mutable
     return bytearray(file.read(-1 if count == -1 else count * item_size))
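
`add_suggestion` stays in `__all__` while `FrozenMapping`, `make_repr`, and `FrozenBunch` are dropped; its former call sites are visible in the removed `DatasetInfo` validation earlier in this diff. A quick usage sketch, assuming only the keyword arguments used there (`word`, `possibilities`); the exact hint wording comes from the helper's defaults and may differ:

from torchvision.prototype.utils._internal import add_suggestion

msg = add_suggestion(
    "Unknown split 'tets'.",
    word="tets",
    possibilities=["train", "test", "val"],
)
print(msg)  # e.g. "Unknown split 'tets'. Did you mean 'test'?"
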