forked from facebookresearch/ClassyVision
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Move Dataloader Wrappers to OSS (facebookresearch#455)
Summary: Pull Request resolved: facebookresearch#455 This will be helpful for OSS users who implement their own Iterable datasets. Differential Revision: D20605900 fbshipit-source-id: c5039f948d60cb160d6cc11a68d182c5a8ddefeb
- Loading branch information
1 parent
ffbad54
commit 882a224
Showing
6 changed files
with
310 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,88 @@ | ||
#!/usr/bin/env python3 | ||
# Copyright (c) Facebook, Inc. and its affiliates. | ||
# | ||
# This source code is licensed under the MIT license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
|
||
import torchvision.transforms as transforms | ||
from classy_vision.dataset import register_dataset | ||
from classy_vision.dataset.classy_dataset import ClassyDataset | ||
from classy_vision.dataset.core import RandomImageBinaryClassDataset | ||
from classy_vision.dataset.dataloader_wrappers import DataloaderLimitWrapper | ||
from classy_vision.dataset.transforms.util import ( | ||
ImagenetConstants, | ||
build_field_transform_default_imagenet, | ||
) | ||
|
||
|
||
@register_dataset("synthetic_image_streaming") | ||
class SyntheticImageClassificationStreamingDataset(ClassyDataset): | ||
""" | ||
Synthetic image dataset that behaves like a streaming dataset. | ||
Requires a "num_samples" argument which decides the number of samples in the | ||
phase. Also takes an optional "length" input which sets the length of the | ||
dataset. | ||
""" | ||
|
||
def __init__( | ||
self, | ||
batchsize_per_replica, | ||
shuffle, | ||
transform, | ||
num_samples, | ||
crop_size, | ||
class_ratio, | ||
seed, | ||
length=None, | ||
): | ||
if length is None: | ||
# If length not provided, set to be same as num_samples | ||
length = num_samples | ||
|
||
dataset = RandomImageBinaryClassDataset(crop_size, class_ratio, length, seed) | ||
super().__init__( | ||
dataset, batchsize_per_replica, shuffle, transform, num_samples | ||
) | ||
|
||
@classmethod | ||
def from_config(cls, config): | ||
assert all(key in config for key in ["crop_size", "class_ratio", "seed"]) | ||
length = config.get("length") | ||
crop_size = config["crop_size"] | ||
class_ratio = config["class_ratio"] | ||
seed = config["seed"] | ||
( | ||
transform_config, | ||
batchsize_per_replica, | ||
shuffle, | ||
num_samples, | ||
) = cls.parse_config(config) | ||
default_transform = transforms.Compose( | ||
[ | ||
transforms.ToTensor(), | ||
transforms.Normalize( | ||
mean=ImagenetConstants.MEAN, std=ImagenetConstants.STD | ||
), | ||
] | ||
) | ||
transform = build_field_transform_default_imagenet( | ||
transform_config, default_transform=default_transform | ||
) | ||
return cls( | ||
batchsize_per_replica, | ||
shuffle, | ||
transform, | ||
num_samples, | ||
crop_size, | ||
class_ratio, | ||
seed, | ||
length=length, | ||
) | ||
|
||
def iterator(self, *args, **kwargs): | ||
return DataloaderLimitWrapper( | ||
super().iterator(*args, **kwargs), | ||
self.num_samples // self.get_global_batchsize(), | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
#!/usr/bin/env python3 | ||
# Copyright (c) Facebook, Inc. and its affiliates. | ||
# | ||
# This source code is licensed under the MIT license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
from .dataloader_limit_wrapper import DataloaderLimitWrapper | ||
from .dataloader_skip_none_wrapper import DataloaderSkipNoneWrapper | ||
from .dataloader_wrapper import DataloaderWrapper | ||
|
||
|
||
__all__ = ["DataloaderLimitWrapper", "DataloaderSkipNoneWrapper", "DataloaderWrapper"] |
77 changes: 77 additions & 0 deletions
77
classy_vision/dataset/dataloader_wrappers/dataloader_limit_wrapper.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,77 @@ | ||
#!/usr/bin/env python3 | ||
# Copyright (c) Facebook, Inc. and its affiliates. | ||
# | ||
# This source code is licensed under the MIT license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
import logging | ||
from typing import Any, Iterable, Iterator | ||
|
||
from .dataloader_wrapper import DataloaderWrapper | ||
|
||
|
||
class DataloaderLimitWrapper(DataloaderWrapper): | ||
""" | ||
Dataloader which wraps another dataloader and only returns a limited | ||
number of items. | ||
This is useful for Iterable datasets where the length of the datasets isn't known. | ||
Such datasets can wrap their returned iterators with this class. See | ||
:func:`SyntheticImageClassificationStreamingDataset.iterator` for an example. | ||
Attribute accesses are passed to the wrapped dataloader. | ||
""" | ||
|
||
def __init__( | ||
self, dataloader: Iterable, limit: int, wrap_around: bool = True | ||
) -> None: | ||
"""Constructor for DataloaderLimitWrapper. | ||
Args: | ||
dataloader: The dataloader to wrap around | ||
limit: Specify the number of calls to the underlying dataloader. The wrapper | ||
will raise a `StopIteration` after `limit` calls. | ||
wrap_around: Whether to wrap around the original datatloader if the | ||
dataloader is exhausted before `limit` calls. | ||
Raises: | ||
RuntimeError: If `wrap_around` is set to `False` and the underlying | ||
dataloader is exhausted before `limit` calls. | ||
""" | ||
super().__init__(dataloader) | ||
# we use self.__dict__ to set the attributes since the __setattr__ method | ||
# is overridden | ||
attributes = {"limit": limit, "wrap_around": wrap_around, "_count": None} | ||
self.__dict__.update(attributes) | ||
|
||
def __iter__(self) -> Iterator[Any]: | ||
self._iter = iter(self.dataloader) | ||
self._count = 0 | ||
return self | ||
|
||
def __next__(self) -> Any: | ||
if self._count >= self.limit: | ||
raise StopIteration | ||
self._count += 1 | ||
try: | ||
return next(self._iter) | ||
except StopIteration: | ||
if self.wrap_around: | ||
# create a new iterator to load data from the beginning | ||
logging.info( | ||
f"Wrapping around after {self._count} calls. Limit: {self.limit}" | ||
) | ||
try: | ||
self._iter = iter(self.dataloader) | ||
return next(self._iter) | ||
except StopIteration: | ||
raise RuntimeError( | ||
"Looks like the dataset is empty, " | ||
"have you configured it properly?" | ||
) | ||
else: | ||
raise RuntimeError( | ||
f"StopIteration raised before {self.limit} items were returned" | ||
) | ||
|
||
def __len__(self) -> int: | ||
return self.limit |
33 changes: 33 additions & 0 deletions
33
classy_vision/dataset/dataloader_wrappers/dataloader_skip_none_wrapper.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
#!/usr/bin/env python3 | ||
# Copyright (c) Facebook, Inc. and its affiliates. | ||
# | ||
# This source code is licensed under the MIT license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
from typing import Any, Iterable, Iterator | ||
|
||
from .dataloader_wrapper import DataloaderWrapper | ||
|
||
|
||
class DataloaderSkipNoneWrapper(DataloaderWrapper): | ||
""" | ||
Dataloader which wraps another dataloader and skip `None` batch data. | ||
Attribute accesses are passed to the wrapped dataloader. | ||
""" | ||
|
||
def __init__(self, dataloader: Iterable) -> None: | ||
super().__init__(dataloader) | ||
|
||
def __iter__(self) -> Iterator[Any]: | ||
self._iter = iter(self.dataloader) | ||
return self | ||
|
||
def __next__(self) -> Any: | ||
# we may get `None` batch data when all the images/videos in the batch | ||
# are corrupted. In such case, we keep getting the next batch until | ||
# meeting a good batch. | ||
next_batch = None | ||
while next_batch is None: | ||
next_batch = next(self._iter) | ||
return next_batch |
47 changes: 47 additions & 0 deletions
47
classy_vision/dataset/dataloader_wrappers/dataloader_wrapper.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
#!/usr/bin/env python3 | ||
# Copyright (c) Facebook, Inc. and its affiliates. | ||
# | ||
# This source code is licensed under the MIT license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
from abc import ABC, abstractmethod | ||
from typing import Any, Iterable, Iterator | ||
|
||
|
||
class DataloaderWrapper(ABC): | ||
""" | ||
Abstract class representing dataloader which wraps another dataloader. | ||
Attribute accesses are passed to the wrapped dataloader. | ||
""" | ||
|
||
def __init__(self, dataloader: Iterable) -> None: | ||
# we use self.__dict__ to set the attributes since the __setattr__ method | ||
# is overridden | ||
attributes = {"dataloader": dataloader, "_iter": None} | ||
self.__dict__.update(attributes) | ||
|
||
@abstractmethod | ||
def __iter__(self) -> Iterator[Any]: | ||
pass | ||
|
||
@abstractmethod | ||
def __next__(self) -> Any: | ||
pass | ||
|
||
def __getattr__(self, attr) -> Any: | ||
""" | ||
Pass the getattr call to the wrapped dataloader | ||
""" | ||
if attr in self.__dict__: | ||
return self.__dict__[attr] | ||
return getattr(self.dataloader, attr) | ||
|
||
def __setattr__(self, attr, value) -> None: | ||
""" | ||
Pass the setattr call to the wrapped dataloader | ||
""" | ||
if attr in self.__dict__: | ||
self.__dict__[attr] = value | ||
else: | ||
setattr(self.dataloader, attr, value) |
53 changes: 53 additions & 0 deletions
53
test/dataset_dataloader_wrappers_dataloader_limit_wrapper_test.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,53 @@ | ||
#!/usr/bin/env python3 | ||
# Copyright (c) Facebook, Inc. and its affiliates. | ||
# | ||
# This source code is licensed under the MIT license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
import unittest | ||
from test.generic.config_utils import get_test_task_config | ||
|
||
from classy_vision.tasks import build_task | ||
|
||
|
||
class TestDataloaderLimitWrapper(unittest.TestCase): | ||
def _test_number_of_batches(self, data_iterator, expected_batches): | ||
num_batches = 0 | ||
for _ in data_iterator: | ||
num_batches += 1 | ||
self.assertEqual(num_batches, expected_batches) | ||
|
||
def test_streaming_dataset(self): | ||
""" | ||
Test that streaming datasets return the correct number of batches, and that | ||
the length is also calculated correctly. | ||
""" | ||
config = get_test_task_config() | ||
dataset_config = { | ||
"name": "synthetic_image_streaming", | ||
"split": "train", | ||
"crop_size": 224, | ||
"class_ratio": 0.5, | ||
"num_samples": 2000, | ||
"length": 4000, | ||
"seed": 0, | ||
"batchsize_per_replica": 32, | ||
"use_shuffle": True, | ||
} | ||
expected_batches = 62 | ||
config["dataset"]["train"] = dataset_config | ||
task = build_task(config) | ||
task.prepare() | ||
task.advance_phase() | ||
# test that the number of batches expected is correct | ||
self.assertEqual(task.num_batches_per_phase, expected_batches) | ||
|
||
# test that the data iterator returns the expected number of batches | ||
data_iterator = task.get_data_iterator() | ||
self._test_number_of_batches(data_iterator, expected_batches) | ||
|
||
# test that the dataloader can be rebuilt from the dataset inside it | ||
task._recreate_data_loader_from_dataset() | ||
task.create_data_iterator() | ||
data_iterator = task.get_data_iterator() | ||
self._test_number_of_batches(data_iterator, expected_batches) |