Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add MovingMNIST dataset #7042

Merged
merged 15 commits into from
Jan 4, 2023
8 changes: 8 additions & 0 deletions docs/source/datasets.rst
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,14 @@ Video classification
Kinetics
UCF101

Video prediction
~~~~~~~~~~~~~~~~~~~~

.. autosummary::
    :toctree: generated/
    :template: class_dataset.rst

    MovingMNIST

.. _base_classes_datasets:

Expand Down
31 changes: 31 additions & 0 deletions test/test_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -1494,6 +1494,37 @@ def test_num_examples_test50k(self):
assert len(dataset) == info["num_examples"] - 10000


class MovingMNISTTestCase(datasets_utils.DatasetTestCase):
    """Tests for ``datasets.MovingMNIST`` against a small synthetic ``.npy`` file."""

    DATASET_CLASS = datasets.MovingMNIST
    FEATURE_TYPES = (torch.Tensor,)

    ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(split=(None, "train", "test"), split_ratio=(10, 1, 19))

    # Number of frames stored in the fake file. Deliberately different from the
    # number of samples so the two leading axes cannot be mistaken for each other.
    _NUM_FRAMES = 20

    def inject_fake_data(self, tmpdir, config):
        """Write a fake ``mnist_test_seq.npy`` and return the number of samples in it.

        The array is laid out as ``(frames, samples, 64, 64)``. The first
        ``config["split_ratio"]`` frames are all zeros and the remaining frames
        are all ones, so each split can be recognised by its fill value.
        """
        base_folder = os.path.join(tmpdir, self.DATASET_CLASS.__name__)
        os.makedirs(base_folder, exist_ok=True)

        num_samples = 5
        num_train_frames = config["split_ratio"]
        data = np.concatenate(
            [
                np.zeros((num_train_frames, num_samples, 64, 64)),
                np.ones((self._NUM_FRAMES - num_train_frames, num_samples, 64, 64)),
            ]
        )
        np.save(os.path.join(base_folder, "mnist_test_seq.npy"), data)
        return num_samples

    @datasets_utils.test_all_configs
    def test_split(self, config):
        """Check that ``"train"`` only holds zero frames and ``"test"`` only one frames."""
        # Without a split there is no single fill value to compare against.
        if config["split"] is None:
            return

        expected_value = 0 if config["split"] == "train" else 1
        with self.create_dataset(config) as (dataset, _):
            assert (dataset.data == expected_value).all()


tsugumi-sys marked this conversation as resolved.
Show resolved Hide resolved
class DatasetFolderTestCase(datasets_utils.ImageDatasetTestCase):
DATASET_CLASS = datasets.DatasetFolder

Expand Down
4 changes: 4 additions & 0 deletions test/test_datasets_download.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,6 +296,10 @@ def qmnist():
)


def moving_mnist():
    """Collect the download configurations requested by ``datasets.MovingMNIST``."""

    def load_dataset():
        # Instantiating with download=True is what triggers the download calls.
        return datasets.MovingMNIST(ROOT, download=True)

    return collect_download_configs(load_dataset, name="MovingMNIST")


def omniglot():
return itertools.chain(
*[
Expand Down
1 change: 1 addition & 0 deletions torchvision/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
from .lfw import LFWPairs, LFWPeople
from .lsun import LSUN, LSUNClass
from .mnist import EMNIST, FashionMNIST, KMNIST, MNIST, QMNIST
from .moving_mnist import MovingMNIST
from .omniglot import Omniglot
from .oxford_iiit_pet import OxfordIIITPet
from .pcam import PCAM
Expand Down
93 changes: 93 additions & 0 deletions torchvision/datasets/moving_mnist.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
import os.path
from typing import Callable, Optional

import numpy as np
import torch
from torchvision.datasets.utils import download_url, verify_str_arg
from torchvision.datasets.vision import VisionDataset


class MovingMNIST(VisionDataset):
    """`MovingMNIST <http://www.cs.toronto.edu/~nitish/unsupervised_video/>`_ Dataset.

    Args:
        root (string): Root directory of dataset where ``MovingMNIST/mnist_test_seq.npy`` exists.
        split (string, optional): The dataset split, supports ``None`` (default), ``"train"`` and ``"test"``.
            If ``split=None``, the full data is returned.
        split_ratio (int, optional): The number of frames that belong to the train split. If ``split="train"``,
            the first frames ``data[:, :split_ratio]`` are returned. If ``split="test"``, the remaining frames
            ``data[:, split_ratio:]`` are returned. If ``split=None``, this parameter is ignored and all frames
            are returned.
        download (bool, optional): If true, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again.
        transform (callable, optional): A function/transform that takes in a torch Tensor
            and returns a transformed version. E.g, ``transforms.RandomCrop``
    """

    _URL = "http://www.cs.toronto.edu/~nitish/unsupervised_video/mnist_test_seq.npy"

    def __init__(
        self,
        root: str,
        split: Optional[str] = None,
        split_ratio: int = 10,
        download: bool = False,
        transform: Optional[Callable] = None,
    ) -> None:
        super().__init__(root, transform=transform)

        self._base_folder = os.path.join(self.root, self.__class__.__name__)
        # The local file keeps the name it has on the server.
        self._filename = self._URL.split("/")[-1]

        if split is not None:
            verify_str_arg(split, "split", ("train", "test"))
        self.split = split

        # ``split_ratio`` is validated even when ``split`` is None and the value goes unused.
        if not isinstance(split_ratio, int):
            raise TypeError(f"`split_ratio` should be an integer, but got {type(split_ratio)}")
        elif not (1 <= split_ratio <= 19):
            raise ValueError(f"`split_ratio` should be `1 <= split_ratio <= 19`, but got {split_ratio} instead.")
        self.split_ratio = split_ratio

        if download:
            self.download()

        if not self._check_exists():
            raise RuntimeError("Dataset not found. You can use download=True to download it.")

        # The first axis of the loaded array is sliced as the frame axis.
        data = torch.from_numpy(np.load(os.path.join(self._base_folder, self._filename)))
        if self.split == "train":
            data = data[: self.split_ratio]
        elif self.split == "test":
            data = data[self.split_ratio :]
        # split=None: keep every frame. A bare ``else`` here would silently
        # return only the test frames.

        # Swap the two leading axes so indexing selects a sample, then insert a
        # singleton axis at position 2 (the ``C`` of ``[T, C, H, W]``).
        self.data = data.transpose(0, 1).unsqueeze(2).contiguous()

    def __getitem__(self, idx: int) -> torch.Tensor:
        """
        Args:
            idx (int): Index
        Returns:
            torch.Tensor: Video frames (torch Tensor[T, C, H, W]). The `T` is the number of frames.
        """
        data = self.data[idx]
        if self.transform is not None:
            data = self.transform(data)

        return data

    def __len__(self) -> int:
        """Return the size of the first dimension of ``self.data``."""
        return len(self.data)

    def _check_exists(self) -> bool:
        """Return whether the dataset file is present in the dataset folder."""
        return os.path.exists(os.path.join(self._base_folder, self._filename))

    def download(self) -> None:
        """Download the dataset file unless it is already present."""
        if self._check_exists():
            return

        download_url(
            url=self._URL,
            root=self._base_folder,
            filename=self._filename,
            md5="be083ec986bfe91a449d63653c411eb2",
        )