Skip to content

Commit

Permalink
Add MovingMNIST dataset (#7042)
Browse files Browse the repository at this point in the history
* add moving mnist dataset

Signed-off-by: tsugumi-sys <tidemark0105@gmail.com>

* remove unused modules

Signed-off-by: tsugumi-sys <tidemark0105@gmail.com>

* modify docstring

Signed-off-by: tsugumi-sys <tidemark0105@gmail.com>

* modify docstring and docs

Signed-off-by: tsugumi-sys <tidemark0105@gmail.com>

* add split and split ratio kwargs

Signed-off-by: tsugumi-sys <tidemark0105@gmail.com>

* fix checking split argument

Signed-off-by: tsugumi-sys <tidemark0105@gmail.com>

* remove unused package

* delete lines

Signed-off-by: tsugumi-sys <tidemark0105@gmail.com>

* fix filename property

Signed-off-by: tsugumi-sys <tidemark0105@gmail.com>

* fix reviews

Signed-off-by: tsugumi-sys <tidemark0105@gmail.com>

* modify docstrings

Signed-off-by: tsugumi-sys <tidemark0105@gmail.com>

* add split tests and etc

Signed-off-by: tsugumi-sys <tidemark0105@gmail.com>

* fix tests

Signed-off-by: tsugumi-sys <tidemark0105@gmail.com>

Signed-off-by: tsugumi-sys <tidemark0105@gmail.com>
  • Loading branch information
tsugumi-sys authored Jan 4, 2023
1 parent 32d254b commit 46b7e27
Show file tree
Hide file tree
Showing 5 changed files with 137 additions and 0 deletions.
8 changes: 8 additions & 0 deletions docs/source/datasets.rst
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,14 @@ Video classification
Kinetics
UCF101

Video prediction
~~~~~~~~~~~~~~~~~~~~

.. autosummary::
    :toctree: generated/
    :template: class_dataset.rst

    MovingMNIST

.. _base_classes_datasets:

Expand Down
31 changes: 31 additions & 0 deletions test/test_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -1494,6 +1494,37 @@ def test_num_examples_test50k(self):
assert len(dataset) == info["num_examples"] - 10000


class MovingMNISTTestCase(datasets_utils.DatasetTestCase):
    """Dataset test case for ``datasets.MovingMNIST`` over every split / split_ratio combination."""

    DATASET_CLASS = datasets.MovingMNIST
    FEATURE_TYPES = (torch.Tensor,)

    ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(split=(None, "train", "test"), split_ratio=(10, 1, 19))

    def inject_fake_data(self, tmpdir, config):
        """Write a fake ``mnist_test_seq.npy`` and return the number of samples it holds.

        The leading axis of the saved array has 20 entries: the first
        ``split_ratio`` of them are all zeros and the remaining ones are all
        ones, so a test can tell which part of the array a split kept.
        """
        num_frames = 20
        num_samples = 20
        frame_size = (64, 64)

        target_dir = os.path.join(tmpdir, self.DATASET_CLASS.__name__)
        os.makedirs(target_dir, exist_ok=True)

        boundary = config["split_ratio"]
        zeros_part = np.zeros((boundary, num_samples, *frame_size))
        ones_part = np.ones((num_frames - boundary, num_samples, *frame_size))
        np.save(os.path.join(target_dir, "mnist_test_seq.npy"), np.concatenate([zeros_part, ones_part]))
        return num_samples

    @datasets_utils.test_all_configs
    def test_split(self, config):
        """Check that "train" keeps only the zero-filled part and "test" only the one-filled part."""
        split = config["split"]
        if split is None:
            # Nothing is asserted for the unsplit configuration.
            return

        expected_fill = 0 if split == "train" else 1
        with self.create_dataset(config) as (dataset, _):
            assert (dataset.data == expected_fill).all()


class DatasetFolderTestCase(datasets_utils.ImageDatasetTestCase):
DATASET_CLASS = datasets.DatasetFolder

Expand Down
4 changes: 4 additions & 0 deletions test/test_datasets_download.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,6 +296,10 @@ def qmnist():
)


def moving_mnist():
    """Return the result of ``collect_download_configs`` for ``MovingMNIST`` built with ``download=True``."""

    def load_dataset():
        return datasets.MovingMNIST(ROOT, download=True)

    return collect_download_configs(load_dataset, name="MovingMNIST")


def omniglot():
return itertools.chain(
*[
Expand Down
1 change: 1 addition & 0 deletions torchvision/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
from .lfw import LFWPairs, LFWPeople
from .lsun import LSUN, LSUNClass
from .mnist import EMNIST, FashionMNIST, KMNIST, MNIST, QMNIST
from .moving_mnist import MovingMNIST
from .omniglot import Omniglot
from .oxford_iiit_pet import OxfordIIITPet
from .pcam import PCAM
Expand Down
93 changes: 93 additions & 0 deletions torchvision/datasets/moving_mnist.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
import os.path
from typing import Callable, Optional

import numpy as np
import torch
from torchvision.datasets.utils import download_url, verify_str_arg
from torchvision.datasets.vision import VisionDataset


class MovingMNIST(VisionDataset):
    """`MovingMNIST <http://www.cs.toronto.edu/~nitish/unsupervised_video/>`_ Dataset.

    Args:
        root (string): Root directory of dataset where ``MovingMNIST/mnist_test_seq.npy`` exists.
        split (string, optional): The dataset split, supports ``None`` (default), ``"train"`` and ``"test"``.
            If ``split=None``, the full data is returned.
        split_ratio (int, optional): The split ratio of number of frames. If ``split="train"``, the first split
            frames ``data[:, :split_ratio]`` are returned. If ``split="test"``, the last split frames
            ``data[:, split_ratio:]`` are returned. If ``split=None``, the value is still validated but no
            slicing is applied and all frames are returned.
        download (bool, optional): If true, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again.
        transform (callable, optional): A function/transform that takes in a torch Tensor
            and returns a transformed version. E.g, ``transforms.RandomCrop``

    Raises:
        TypeError: If ``split_ratio`` is not an integer.
        ValueError: If ``split_ratio`` is outside ``1 <= split_ratio <= 19``.
        RuntimeError: If the data file is not found (and was not downloaded).
    """

    _URL = "http://www.cs.toronto.edu/~nitish/unsupervised_video/mnist_test_seq.npy"

    def __init__(
        self,
        root: str,
        split: Optional[str] = None,
        split_ratio: int = 10,
        download: bool = False,
        transform: Optional[Callable] = None,
    ) -> None:
        super().__init__(root, transform=transform)

        self._base_folder = os.path.join(self.root, self.__class__.__name__)
        self._filename = self._URL.split("/")[-1]

        if split is not None:
            verify_str_arg(split, "split", ("train", "test"))
        self.split = split

        if not isinstance(split_ratio, int):
            raise TypeError(f"`split_ratio` should be an integer, but got {type(split_ratio)}")
        elif not (1 <= split_ratio <= 19):
            raise ValueError(f"`split_ratio` should be `1 <= split_ratio <= 19`, but got {split_ratio} instead.")
        self.split_ratio = split_ratio

        if download:
            self.download()

        if not self._check_exists():
            raise RuntimeError("Dataset not found. You can use download=True to download it.")

        data = torch.from_numpy(np.load(os.path.join(self._base_folder, self._filename)))
        # The split slices the leading axis of the loaded array, i.e. before
        # the first two axes are swapped below.
        if self.split == "train":
            data = data[: self.split_ratio]
        elif self.split == "test":
            data = data[self.split_ratio :]
        # split=None: keep the full data, as documented. A bare `else` here
        # would silently return only the "test" part.

        # Swap the first two axes and insert a singleton axis at position 2,
        # so that indexing `self.data[idx]` yields a [T, C, H, W] tensor.
        self.data = data.transpose(0, 1).unsqueeze(2).contiguous()

    def __getitem__(self, idx: int) -> torch.Tensor:
        """
        Args:
            idx (int): Index

        Returns:
            torch.Tensor: Video frames (torch Tensor[T, C, H, W]). The `T` is the number of frames.
        """
        data = self.data[idx]
        if self.transform is not None:
            data = self.transform(data)

        return data

    def __len__(self) -> int:
        """Return the size of the first dimension of ``self.data``."""
        return len(self.data)

    def _check_exists(self) -> bool:
        """Return whether the data file is present in the dataset folder."""
        return os.path.exists(os.path.join(self._base_folder, self._filename))

    def download(self) -> None:
        """Download the data file into the dataset folder unless it is already there."""
        if self._check_exists():
            return

        download_url(
            url=self._URL,
            root=self._base_folder,
            filename=self._filename,
            md5="be083ec986bfe91a449d63653c411eb2",
        )

0 comments on commit 46b7e27

Please sign in to comment.