Move Dataloader Wrappers to OSS (facebookresearch#455)
Summary:
Pull Request resolved: facebookresearch#455

This will be helpful for OSS users who implement their own Iterable datasets.

Differential Revision: D20605900

fbshipit-source-id: c5039f948d60cb160d6cc11a68d182c5a8ddefeb
mannatsingh authored and facebook-github-bot committed Mar 24, 2020
1 parent ffbad54 commit 882a224
Showing 6 changed files with 310 additions and 0 deletions.
88 changes: 88 additions & 0 deletions classy_vision/dataset/classy_synthetic_image_streaming.py
@@ -0,0 +1,88 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.


import torchvision.transforms as transforms
from classy_vision.dataset import register_dataset
from classy_vision.dataset.classy_dataset import ClassyDataset
from classy_vision.dataset.core import RandomImageBinaryClassDataset
from classy_vision.dataset.dataloader_wrappers import DataloaderLimitWrapper
from classy_vision.dataset.transforms.util import (
ImagenetConstants,
build_field_transform_default_imagenet,
)


@register_dataset("synthetic_image_streaming")
class SyntheticImageClassificationStreamingDataset(ClassyDataset):
"""
Synthetic image dataset that behaves like a streaming dataset.
Requires a "num_samples" argument which decides the number of samples in the
phase. Also takes an optional "length" input which sets the length of the
dataset.
"""

def __init__(
self,
batchsize_per_replica,
shuffle,
transform,
num_samples,
crop_size,
class_ratio,
seed,
length=None,
):
if length is None:
# If length is not provided, default it to num_samples
length = num_samples

dataset = RandomImageBinaryClassDataset(crop_size, class_ratio, length, seed)
super().__init__(
dataset, batchsize_per_replica, shuffle, transform, num_samples
)

@classmethod
def from_config(cls, config):
assert all(key in config for key in ["crop_size", "class_ratio", "seed"])
length = config.get("length")
crop_size = config["crop_size"]
class_ratio = config["class_ratio"]
seed = config["seed"]
(
transform_config,
batchsize_per_replica,
shuffle,
num_samples,
) = cls.parse_config(config)
default_transform = transforms.Compose(
[
transforms.ToTensor(),
transforms.Normalize(
mean=ImagenetConstants.MEAN, std=ImagenetConstants.STD
),
]
)
transform = build_field_transform_default_imagenet(
transform_config, default_transform=default_transform
)
return cls(
batchsize_per_replica,
shuffle,
transform,
num_samples,
crop_size,
class_ratio,
seed,
length=length,
)

def iterator(self, *args, **kwargs):
return DataloaderLimitWrapper(
super().iterator(*args, **kwargs),
self.num_samples // self.get_global_batchsize(),
)
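For context, a minimal sketch (not part of this diff) of how the limit above works out: the wrapper is handed num_samples // global_batchsize, so the number of batches per phase is the floor of that division. Assuming a single replica, so the global batch size equals batchsize_per_replica, and using the values from the unit test below:

num_samples = 2000
global_batchsize = 32  # batchsize_per_replica on a single replica (assumption)
print(num_samples // global_batchsize)  # 62 batches per phase; the remainder is dropped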
12 changes: 12 additions & 0 deletions classy_vision/dataset/dataloader_wrappers/__init__.py
@@ -0,0 +1,12 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from .dataloader_limit_wrapper import DataloaderLimitWrapper
from .dataloader_skip_none_wrapper import DataloaderSkipNoneWrapper
from .dataloader_wrapper import DataloaderWrapper


__all__ = ["DataloaderLimitWrapper", "DataloaderSkipNoneWrapper", "DataloaderWrapper"]
77 changes: 77 additions & 0 deletions classy_vision/dataset/dataloader_wrappers/dataloader_limit_wrapper.py
@@ -0,0 +1,77 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import logging
from typing import Any, Iterable, Iterator

from .dataloader_wrapper import DataloaderWrapper


class DataloaderLimitWrapper(DataloaderWrapper):
"""
Dataloader which wraps another dataloader and only returns a limited
number of items.
This is useful for Iterable datasets where the length of the dataset isn't known.
Such datasets can wrap their returned iterators with this class. See
:func:`SyntheticImageClassificationStreamingDataset.iterator` for an example.
Attribute accesses are passed to the wrapped dataloader.
"""

def __init__(
self, dataloader: Iterable, limit: int, wrap_around: bool = True
) -> None:
"""Constructor for DataloaderLimitWrapper.
Args:
dataloader: The dataloader to wrap
limit: The number of calls to make to the underlying dataloader. The wrapper
will raise a `StopIteration` after `limit` calls.
wrap_around: Whether to wrap around the original dataloader if the
dataloader is exhausted before `limit` calls.
Raises:
RuntimeError: If `wrap_around` is set to `False` and the underlying
dataloader is exhausted before `limit` calls.
"""
super().__init__(dataloader)
# we use self.__dict__ to set the attributes since the __setattr__ method
# is overridden
attributes = {"limit": limit, "wrap_around": wrap_around, "_count": None}
self.__dict__.update(attributes)

def __iter__(self) -> Iterator[Any]:
self._iter = iter(self.dataloader)
self._count = 0
return self

def __next__(self) -> Any:
if self._count >= self.limit:
raise StopIteration
self._count += 1
try:
return next(self._iter)
except StopIteration:
if self.wrap_around:
# create a new iterator to load data from the beginning
logging.info(
f"Wrapping around after {self._count} calls. Limit: {self.limit}"
)
try:
self._iter = iter(self.dataloader)
return next(self._iter)
except StopIteration:
raise RuntimeError(
"Looks like the dataset is empty, "
"have you configured it properly?"
)
else:
raise RuntimeError(
f"StopIteration raised before {self.limit} items were returned"
)

def __len__(self) -> int:
return self.limit
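For context, a minimal usage sketch (not part of this diff) that wraps a plain iterable in place of a real dataloader to show the limit and wrap-around behavior:

from classy_vision.dataset.dataloader_wrappers import DataloaderLimitWrapper

# Any iterable can stand in for a dataloader; a range stands in for the batches here.
loader = DataloaderLimitWrapper(range(3), limit=5, wrap_around=True)
print(list(loader))  # [0, 1, 2, 0, 1] -- wraps around once the range is exhausted
print(len(loader))   # 5 -- __len__ reports the limit, not the underlying length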
33 changes: 33 additions & 0 deletions classy_vision/dataset/dataloader_wrappers/dataloader_skip_none_wrapper.py
@@ -0,0 +1,33 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from typing import Any, Iterable, Iterator

from .dataloader_wrapper import DataloaderWrapper


class DataloaderSkipNoneWrapper(DataloaderWrapper):
"""
Dataloader which wraps another dataloader and skips `None` batches.
Attribute accesses are passed to the wrapped dataloader.
"""

def __init__(self, dataloader: Iterable) -> None:
super().__init__(dataloader)

def __iter__(self) -> Iterator[Any]:
self._iter = iter(self.dataloader)
return self

def __next__(self) -> Any:
# We may get a `None` batch when all the images/videos in the batch
# are corrupted. In that case, keep fetching the next batch until we
# get a valid one.
next_batch = None
while next_batch is None:
next_batch = next(self._iter)
return next_batch
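Similarly, a minimal sketch (not part of this diff) of the skipping behavior, with an in-memory list standing in for a dataloader:

from classy_vision.dataset.dataloader_wrappers import DataloaderSkipNoneWrapper

# None entries stand in for batches whose samples were all corrupted.
loader = DataloaderSkipNoneWrapper([{"input": 1}, None, {"input": 2}])
print(list(loader))  # [{'input': 1}, {'input': 2}] -- the None batch is skipped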
47 changes: 47 additions & 0 deletions classy_vision/dataset/dataloader_wrappers/dataloader_wrapper.py
@@ -0,0 +1,47 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from abc import ABC, abstractmethod
from typing import Any, Iterable, Iterator


class DataloaderWrapper(ABC):
"""
Abstract class representing a dataloader which wraps another dataloader.
Attribute accesses are passed to the wrapped dataloader.
"""

def __init__(self, dataloader: Iterable) -> None:
# we use self.__dict__ to set the attributes since the __setattr__ method
# is overridden
attributes = {"dataloader": dataloader, "_iter": None}
self.__dict__.update(attributes)

@abstractmethod
def __iter__(self) -> Iterator[Any]:
pass

@abstractmethod
def __next__(self) -> Any:
pass

def __getattr__(self, attr) -> Any:
"""
Pass the getattr call to the wrapped dataloader
"""
if attr in self.__dict__:
return self.__dict__[attr]
return getattr(self.dataloader, attr)

def __setattr__(self, attr, value) -> None:
"""
Pass the setattr call to the wrapped dataloader
"""
if attr in self.__dict__:
self.__dict__[attr] = value
else:
setattr(self.dataloader, attr, value)
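To illustrate the attribute forwarding, a hypothetical pass-through subclass (not part of this diff) wrapping a standard PyTorch DataLoader:

from torch.utils.data import DataLoader

from classy_vision.dataset.dataloader_wrappers import DataloaderWrapper


class PassthroughWrapper(DataloaderWrapper):
    # Hypothetical subclass for illustration only; it forwards iteration unchanged.
    def __iter__(self):
        self._iter = iter(self.dataloader)
        return self

    def __next__(self):
        return next(self._iter)


wrapped = PassthroughWrapper(DataLoader(list(range(10)), batch_size=2))
print(wrapped.batch_size)  # 2 -- __getattr__ falls through to the wrapped DataLoader
print([batch.tolist() for batch in wrapped])  # [[0, 1], [2, 3], ...] -- iteration is unchanged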
53 changes: 53 additions & 0 deletions test/dataset_dataloader_wrappers_dataloader_limit_wrapper_test.py
@@ -0,0 +1,53 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import unittest
from test.generic.config_utils import get_test_task_config

from classy_vision.tasks import build_task


class TestDataloaderLimitWrapper(unittest.TestCase):
def _test_number_of_batches(self, data_iterator, expected_batches):
num_batches = 0
for _ in data_iterator:
num_batches += 1
self.assertEqual(num_batches, expected_batches)

def test_streaming_dataset(self):
"""
Test that streaming datasets return the correct number of batches, and that
the length is also calculated correctly.
"""
config = get_test_task_config()
dataset_config = {
"name": "synthetic_image_streaming",
"split": "train",
"crop_size": 224,
"class_ratio": 0.5,
"num_samples": 2000,
"length": 4000,
"seed": 0,
"batchsize_per_replica": 32,
"use_shuffle": True,
}
expected_batches = 62
config["dataset"]["train"] = dataset_config
task = build_task(config)
task.prepare()
task.advance_phase()
# test that the number of batches expected is correct
self.assertEqual(task.num_batches_per_phase, expected_batches)

# test that the data iterator returns the expected number of batches
data_iterator = task.get_data_iterator()
self._test_number_of_batches(data_iterator, expected_batches)

# test that the dataloader can be rebuilt from the dataset inside it
task._recreate_data_loader_from_dataset()
task.create_data_iterator()
data_iterator = task.get_data_iterator()
self._test_number_of_batches(data_iterator, expected_batches)
