diff --git a/docs/source/datasets.rst b/docs/source/datasets.rst
index f3f0b466d62..5ecb60d3624 100644
--- a/docs/source/datasets.rst
+++ b/docs/source/datasets.rst
@@ -149,6 +149,14 @@ Video classification
     Kinetics
     UCF101
 
+Video prediction
+~~~~~~~~~~~~~~~~~~~~
+
+.. autosummary::
+    :toctree: generated/
+    :template: class_dataset.rst
+
+    MovingMNIST
 
 .. _base_classes_datasets:
 
diff --git a/test/test_datasets.py b/test/test_datasets.py
index dbce7853eff..bd6d1dcb259 100644
--- a/test/test_datasets.py
+++ b/test/test_datasets.py
@@ -1494,6 +1494,37 @@ def test_num_examples_test50k(self):
             assert len(dataset) == info["num_examples"] - 10000
 
 
+class MovingMNISTTestCase(datasets_utils.DatasetTestCase):
+    DATASET_CLASS = datasets.MovingMNIST
+    FEATURE_TYPES = (torch.Tensor,)
+
+    ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(split=(None, "train", "test"), split_ratio=(10, 1, 19))
+
+    def inject_fake_data(self, tmpdir, config):
+        base_folder = os.path.join(tmpdir, self.DATASET_CLASS.__name__)
+        os.makedirs(base_folder, exist_ok=True)
+        num_samples = 20
+        data = np.concatenate(
+            [
+                np.zeros((config["split_ratio"], num_samples, 64, 64)),
+                np.ones((20 - config["split_ratio"], num_samples, 64, 64)),
+            ]
+        )
+        np.save(os.path.join(base_folder, "mnist_test_seq.npy"), data)
+        return num_samples
+
+    @datasets_utils.test_all_configs
+    def test_split(self, config):
+        with self.create_dataset(config) as (dataset, info):
+            if config["split"] is None:
+                # No split requested: all 20 fake frames (zeros + ones) must be kept.
+                assert dataset.data.shape[1] == 20
+            elif config["split"] == "train":
+                assert (dataset.data == 0).all()
+            else:
+                assert (dataset.data == 1).all()
+
+
 class DatasetFolderTestCase(datasets_utils.ImageDatasetTestCase):
     DATASET_CLASS = datasets.DatasetFolder
 
diff --git a/test/test_datasets_download.py b/test/test_datasets_download.py
index 1e76ba42e53..c748a8a0ff1 100644
--- a/test/test_datasets_download.py
+++ b/test/test_datasets_download.py
@@ -296,6 +296,10 @@ def qmnist():
     )
 
 
+def moving_mnist():
+    return collect_download_configs(lambda: datasets.MovingMNIST(ROOT, download=True), name="MovingMNIST")
+
+
 def omniglot():
     return itertools.chain(
         *[
diff --git a/torchvision/datasets/__init__.py b/torchvision/datasets/__init__.py
index 23eddb236b0..e18a9a54b16 100644
--- a/torchvision/datasets/__init__.py
+++ b/torchvision/datasets/__init__.py
@@ -36,6 +36,7 @@
 from .lfw import LFWPairs, LFWPeople
 from .lsun import LSUN, LSUNClass
 from .mnist import EMNIST, FashionMNIST, KMNIST, MNIST, QMNIST
+from .moving_mnist import MovingMNIST
 from .omniglot import Omniglot
 from .oxford_iiit_pet import OxfordIIITPet
 from .pcam import PCAM
diff --git a/torchvision/datasets/moving_mnist.py b/torchvision/datasets/moving_mnist.py
new file mode 100644
index 00000000000..afff0bfa3b9
--- /dev/null
+++ b/torchvision/datasets/moving_mnist.py
@@ -0,0 +1,96 @@
+import os.path
+from typing import Callable, Optional
+
+import numpy as np
+import torch
+from torchvision.datasets.utils import download_url, verify_str_arg
+from torchvision.datasets.vision import VisionDataset
+
+
+class MovingMNIST(VisionDataset):
+    """`MovingMNIST <http://www.cs.toronto.edu/~nitish/unsupervised_video/>`_ Dataset.
+
+    Args:
+        root (string): Root directory of dataset where ``MovingMNIST/mnist_test_seq.npy`` exists.
+        split (string, optional): The dataset split, supports ``None`` (default), ``"train"`` and ``"test"``.
+            If ``split=None``, the full data is returned.
+        split_ratio (int, optional): The split ratio of number of frames. If ``split="train"``, the first split
+            frames ``data[:, :split_ratio]`` is returned. If ``split="test"``, the last split frames ``data[:, split_ratio:]``
+            is returned. If ``split=None``, this parameter is ignored and all frames are returned.
+        transform (callable, optional): A function/transform that takes in a torch Tensor
+            and returns a transformed version. E.g, ``transforms.RandomCrop``
+        download (bool, optional): If true, downloads the dataset from the internet and
+            puts it in root directory. If dataset is already downloaded, it is not
+            downloaded again.
+    """
+
+    _URL = "http://www.cs.toronto.edu/~nitish/unsupervised_video/mnist_test_seq.npy"
+
+    def __init__(
+        self,
+        root: str,
+        split: Optional[str] = None,
+        split_ratio: int = 10,
+        download: bool = False,
+        transform: Optional[Callable] = None,
+    ) -> None:
+        super().__init__(root, transform=transform)
+
+        self._base_folder = os.path.join(self.root, self.__class__.__name__)
+        self._filename = self._URL.split("/")[-1]
+
+        if split is not None:
+            verify_str_arg(split, "split", ("train", "test"))
+        self.split = split
+
+        if not isinstance(split_ratio, int):
+            raise TypeError(f"`split_ratio` should be an integer, but got {type(split_ratio)}")
+        elif not (1 <= split_ratio <= 19):
+            raise ValueError(f"`split_ratio` should be `1 <= split_ratio <= 19`, but got {split_ratio} instead.")
+        self.split_ratio = split_ratio
+
+        if download:
+            self.download()
+
+        if not self._check_exists():
+            raise RuntimeError("Dataset not found. You can use download=True to download it.")
+
+        data = torch.from_numpy(np.load(os.path.join(self._base_folder, self._filename)))
+        # The file stores frames first, so the split slices dim 0 before the transpose below.
+        # With ``split=None`` neither branch runs and every frame is kept.
+        if self.split == "train":
+            data = data[: self.split_ratio]
+        elif self.split == "test":
+            data = data[self.split_ratio :]
+        # Swap to sample-first order and insert a singleton channel dim at position 2.
+        self.data = data.transpose(0, 1).unsqueeze(2).contiguous()
+
+    def __getitem__(self, idx: int) -> torch.Tensor:
+        """
+        Args:
+            idx (int): Index
+        Returns:
+            torch.Tensor: Video frames (torch Tensor[T, C, H, W]). The `T` is the number of frames.
+        """
+        data = self.data[idx]
+        if self.transform is not None:
+            data = self.transform(data)
+
+        return data
+
+    def __len__(self) -> int:
+        return len(self.data)
+
+    def _check_exists(self) -> bool:
+        return os.path.exists(os.path.join(self._base_folder, self._filename))
+
+    def download(self) -> None:
+        if self._check_exists():
+            return
+
+        download_url(
+            url=self._URL,
+            root=self._base_folder,
+            filename=self._filename,
+            md5="be083ec986bfe91a449d63653c411eb2",
+        )