Skip to content

Add MovingMNIST dataset #7042

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 15 commits into from
Jan 4, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions docs/source/datasets.rst
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,14 @@ Video classification
Kinetics
UCF101

Video prediction
~~~~~~~~~~~~~~~~~~~~

.. autosummary::
:toctree: generated/
:template: class_dataset.rst

MovingMNIST

.. _base_classes_datasets:

Expand Down
31 changes: 31 additions & 0 deletions test/test_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -1494,6 +1494,37 @@ def test_num_examples_test50k(self):
assert len(dataset) == info["num_examples"] - 10000


class MovingMNISTTestCase(datasets_utils.DatasetTestCase):
DATASET_CLASS = datasets.MovingMNIST
FEATURE_TYPES = (torch.Tensor,)

ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(split=(None, "train", "test"), split_ratio=(10, 1, 19))

def inject_fake_data(self, tmpdir, config):
base_folder = os.path.join(tmpdir, self.DATASET_CLASS.__name__)
os.makedirs(base_folder, exist_ok=True)
num_samples = 20
data = np.concatenate(
[
np.zeros((config["split_ratio"], num_samples, 64, 64)),
np.ones((20 - config["split_ratio"], num_samples, 64, 64)),
]
)
np.save(os.path.join(base_folder, "mnist_test_seq.npy"), data)
return num_samples

@datasets_utils.test_all_configs
def test_split(self, config):
if config["split"] is None:
return

with self.create_dataset(config) as (dataset, info):
if config["split"] == "train":
assert (dataset.data == 0).all()
else:
assert (dataset.data == 1).all()


class DatasetFolderTestCase(datasets_utils.ImageDatasetTestCase):
DATASET_CLASS = datasets.DatasetFolder

Expand Down
4 changes: 4 additions & 0 deletions test/test_datasets_download.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,6 +296,10 @@ def qmnist():
)


def moving_mnist():
return collect_download_configs(lambda: datasets.MovingMNIST(ROOT, download=True), name="MovingMNIST")


def omniglot():
return itertools.chain(
*[
Expand Down
1 change: 1 addition & 0 deletions torchvision/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
from .lfw import LFWPairs, LFWPeople
from .lsun import LSUN, LSUNClass
from .mnist import EMNIST, FashionMNIST, KMNIST, MNIST, QMNIST
from .moving_mnist import MovingMNIST
from .omniglot import Omniglot
from .oxford_iiit_pet import OxfordIIITPet
from .pcam import PCAM
Expand Down
93 changes: 93 additions & 0 deletions torchvision/datasets/moving_mnist.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
import os.path
from typing import Callable, Optional

import numpy as np
import torch
from torchvision.datasets.utils import download_url, verify_str_arg
from torchvision.datasets.vision import VisionDataset


class MovingMNIST(VisionDataset):
"""`MovingMNIST <http://www.cs.toronto.edu/~nitish/unsupervised_video/>`_ Dataset.

Args:
root (string): Root directory of dataset where ``MovingMNIST/mnist_test_seq.npy`` exists.
split (string, optional): The dataset split, supports ``None`` (default), ``"train"`` and ``"test"``.
If ``split=None``, the full data is returned.
split_ratio (int, optional): The split ratio of number of frames. If ``split="train"``, the first split
frames ``data[:, :split_ratio]`` is returned. If ``split="test"``, the last split frames ``data[:, split_ratio:]``
is returned. If ``split=None``, this parameter is ignored and the all frames data is returned.
transform (callable, optional): A function/transform that takes in an torch Tensor
and returns a transformed version. E.g, ``transforms.RandomCrop``
download (bool, optional): If true, downloads the dataset from the internet and
puts it in root directory. If dataset is already downloaded, it is not
downloaded again.
"""

_URL = "http://www.cs.toronto.edu/~nitish/unsupervised_video/mnist_test_seq.npy"

def __init__(
self,
root: str,
split: Optional[str] = None,
split_ratio: int = 10,
download: bool = False,
transform: Optional[Callable] = None,
) -> None:
super().__init__(root, transform=transform)

self._base_folder = os.path.join(self.root, self.__class__.__name__)
self._filename = self._URL.split("/")[-1]

if split is not None:
verify_str_arg(split, "split", ("train", "test"))
self.split = split

if not isinstance(split_ratio, int):
raise TypeError(f"`split_ratio` should be an integer, but got {type(split_ratio)}")
elif not (1 <= split_ratio <= 19):
raise ValueError(f"`split_ratio` should be `1 <= split_ratio <= 19`, but got {split_ratio} instead.")
self.split_ratio = split_ratio

if download:
self.download()

if not self._check_exists():
raise RuntimeError("Dataset not found. You can use download=True to download it.")

data = torch.from_numpy(np.load(os.path.join(self._base_folder, self._filename)))
if self.split == "train":
data = data[: self.split_ratio]
else:
data = data[self.split_ratio :]
self.data = data.transpose(0, 1).unsqueeze(2).contiguous()

def __getitem__(self, idx: int) -> torch.Tensor:
"""
Args:
index (int): Index
Returns:
torch.Tensor: Video frames (torch Tensor[T, C, H, W]). The `T` is the number of frames.
"""
data = self.data[idx]
if self.transform is not None:
data = self.transform(data)

return data

def __len__(self) -> int:
return len(self.data)

def _check_exists(self) -> bool:
return os.path.exists(os.path.join(self._base_folder, self._filename))

def download(self) -> None:
if self._check_exists():
return

download_url(
url=self._URL,
root=self._base_folder,
filename=self._filename,
md5="be083ec986bfe91a449d63653c411eb2",
)