Skip to content

Commit

Permalink
added islice for iterable datapipes
Browse files Browse the repository at this point in the history
  • Loading branch information
Diamond Bishop committed Aug 13, 2022
1 parent 827b13d commit ffb73a7
Show file tree
Hide file tree
Showing 5 changed files with 152 additions and 1 deletion.
1 change: 1 addition & 0 deletions docs/source/torchdata.datapipes.iter.rst
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,7 @@ These DataPipes help you select specific samples within a DataPipe.
Filter
Header
Dropper
ISlicer

Text DataPipes
-----------------------------
Expand Down
68 changes: 68 additions & 0 deletions test/test_iterdatapipe.py
Original file line number Diff line number Diff line change
Expand Up @@ -993,6 +993,74 @@ def test_drop_iterdatapipe(self):
drop_dp = input_dp.drop([0, 1])
self.assertEqual(3, len(drop_dp))

def test_islice_iterdatapipe(self):
    """Exercise ``islice`` on tuple, list and dict elements, plus ``__len__`` and reset."""
    # tuple tests
    input_dp = IterableWrapper([(0, 1, 2), (3, 4, 5), (6, 7, 8)])

    # Functional Test: slice with no stop and no step for tuple
    islice_dp = input_dp.islice(1)
    self.assertEqual([(1, 2), (4, 5), (7, 8)], list(islice_dp))

    # Functional Test: slice with no step for tuple
    islice_dp = input_dp.islice(0, 2)
    self.assertEqual([(0, 1), (3, 4), (6, 7)], list(islice_dp))

    # Functional Test: slice with step for tuple
    islice_dp = input_dp.islice(0, 2, 2)
    self.assertEqual([(0,), (3,), (6,)], list(islice_dp))

    # Functional Test: filter with list of indices for tuple
    islice_dp = input_dp.islice([0, 1])
    self.assertEqual([(0, 1), (3, 4), (6, 7)], list(islice_dp))

    # list tests
    input_dp = IterableWrapper([[0, 1, 2], [3, 4, 5], [6, 7, 8]])

    # Functional Test: slice with no stop and no step for list
    islice_dp = input_dp.islice(1)
    self.assertEqual([[1, 2], [4, 5], [7, 8]], list(islice_dp))

    # Functional Test: slice with no step for list
    islice_dp = input_dp.islice(0, 2)
    self.assertEqual([[0, 1], [3, 4], [6, 7]], list(islice_dp))

    # Functional Test: slice with step for list
    islice_dp = input_dp.islice(0, 2, 2)
    self.assertEqual([[0], [3], [6]], list(islice_dp))

    # Functional Test: filter with list of indices for list
    # (must pass a list here; a plain start/stop call would only repeat the slice test above)
    islice_dp = input_dp.islice([0, 1])
    self.assertEqual([[0, 1], [3, 4], [6, 7]], list(islice_dp))

    # dict tests
    input_dp = IterableWrapper([{"a": 1, "b": 2, "c": 3}, {"a": 3, "b": 4, "c": 5}, {"a": 5, "b": 6, "c": 7}])

    # Functional Test: slice with no stop and no step for dict
    islice_dp = input_dp.islice(1)
    self.assertEqual([{"b": 2, "c": 3}, {"b": 4, "c": 5}, {"b": 6, "c": 7}], list(islice_dp))

    # Functional Test: slice with no step for dict
    islice_dp = input_dp.islice(0, 2)
    self.assertEqual([{"a": 1, "b": 2}, {"a": 3, "b": 4}, {"a": 5, "b": 6}], list(islice_dp))

    # Functional Test: slice with step for dict
    islice_dp = input_dp.islice(0, 2, 2)
    self.assertEqual([{"a": 1}, {"a": 3}, {"a": 5}], list(islice_dp))

    # Functional Test: filter with list of indices for dict
    islice_dp = input_dp.islice(["a", "b"])
    self.assertEqual([{"a": 1, "b": 2}, {"a": 3, "b": 4}, {"a": 5, "b": 6}], list(islice_dp))

    # __len__ Test: slicing elements never changes the number of elements
    input_dp = IterableWrapper([[0, 1, 2], [3, 4, 5], [6, 7, 8]])
    islice_dp = input_dp.islice(0, 2)
    self.assertEqual(3, len(islice_dp))

    # Reset Test:
    n_elements_before_reset = 2
    input_dp = IterableWrapper([[0, 1, 2], [3, 4, 5], [6, 7, 8]])
    islice_dp = input_dp.islice([2])
    expected_res = [[2], [5], [8]]
    res_before_reset, res_after_reset = reset_after_n_next_calls(islice_dp, n_elements_before_reset)
    self.assertEqual(expected_res[:n_elements_before_reset], res_before_reset)
    self.assertEqual(expected_res, res_after_reset)


if __name__ == "__main__":
unittest.main()
1 change: 1 addition & 0 deletions test/test_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,7 @@ def test_serializable(self):
(iterdp.DataFrameMaker, IterableWrapper([(i,) for i in range(3)]), (), {"dtype": DTYPE}),
(iterdp.Decompressor, None, (), {}),
(iterdp.Dropper, IterableWrapper([(0, 0), (0, 0), (0, 0), (0, 0)]), ([1]), {}),
(iterdp.ISlicer, IterableWrapper([(0, 0), (0, 0), (0, 0), (0, 0)]), ([1]), {}),
(iterdp.Enumerator, None, (2,), {}),
(iterdp.FlatMapper, None, (_fake_fn_ls,), {}),
(iterdp.FSSpecFileLister, ".", (), {}),
Expand Down
3 changes: 3 additions & 0 deletions torchdata/datapipes/iter/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@
BatchMapperIterDataPipe as BatchMapper,
DropperIterDataPipe as Dropper,
FlatMapperIterDataPipe as FlatMapper,
ISliceIterDataPipe as ISlicer,
)
from torchdata.datapipes.iter.util.bz2fileloader import Bz2FileLoaderIterDataPipe as Bz2FileLoader
from torchdata.datapipes.iter.util.cacheholder import (
Expand Down Expand Up @@ -164,6 +165,7 @@
"Header",
"HttpReader",
"HuggingFaceHubReader",
"ISlicer",
"InBatchShuffler",
"InMemoryCacheHolder",
"IndexAdder",
Expand Down Expand Up @@ -212,3 +214,4 @@

# Please keep this list sorted
assert __all__ == sorted(__all__)

80 changes: 79 additions & 1 deletion torchdata/datapipes/iter/transform/callable.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
# LICENSE file in the root directory of this source tree.

import warnings
from typing import Callable, Hashable, Iterator, List, Sized, TypeVar, Union
from typing import Callable, Hashable, Iterator, List, Optional, Sized, TypeVar, Union

from torch.utils.data import functional_datapipe, IterDataPipe
from torch.utils.data.datapipes.utils.common import _check_unpickable_fn
Expand Down Expand Up @@ -205,3 +205,81 @@ def __len__(self) -> int:
if isinstance(self.datapipe, Sized):
return len(self.datapipe)
raise TypeError(f"{type(self).__name__} instance doesn't have valid length")


@functional_datapipe("islice")
class ISliceIterDataPipe(IterDataPipe[T_co]):
    r"""
    Returns a slice of each element in the input DataPipe, selected either via
    start/stop/step or via an explicit list of indices (functional name: ``islice``).

    Tuples and lists are sliced by position. Dicts are filtered by key when ``index``
    is a list, and by key insertion order when a start/stop/step slice is given.
    Elements of any other type are yielded unchanged and a warning is issued.

    Args:
        datapipe: IterDataPipe with iterable elements
        index: a single start index for the slice or a list of indices to be returned instead of a start/stop slice
        stop: the slice stop. ignored if index is a list
        step: step to be taken from start to stop. ignored if index is a list

    Example:
        >>> from torchdata.datapipes.iter import IterableWrapper
        >>> dp = IterableWrapper([(0, 10, 100), (1, 11, 111), (2, 12, 122), (3, 13, 133), (4, 14, 144)])
        >>> islice_dp = dp.islice(0, 2)
        >>> list(islice_dp)
        [(0, 10), (1, 11), (2, 12), (3, 13), (4, 14)]
    """
    datapipe: IterDataPipe

    def __init__(
        self,
        datapipe: IterDataPipe,
        index: Union[int, List[Hashable]],
        stop: Optional[int] = None,
        step: Optional[int] = None,
    ) -> None:
        super().__init__()
        self.datapipe = datapipe

        self.index = index
        self.stop = stop
        self.step = step

    def __iter__(self) -> Iterator[T_co]:
        # A list selects explicit positions/keys; anything else is the start of a slice.
        by_indices = isinstance(self.index, list)

        for old_item in self.datapipe:
            # Set when a requested index/key cannot be found in a supported container.
            missing = False

            if isinstance(old_item, (tuple, list)):
                if by_indices:
                    kept = [x for i, x in enumerate(old_item) if i in self.index]  # type: ignore[operator]
                    # Preserve the container kind of the incoming element.
                    new_item = tuple(kept) if isinstance(old_item, tuple) else kept
                    # Membership in ``range`` mirrors the enumerate-based filter above, so the
                    # warning fires exactly when a requested index produced no output (this
                    # includes negative or non-integer indices) and never raises.
                    positions = range(len(old_item))
                    missing = any(i not in positions for i in self.index)  # type: ignore[union-attr]
                else:
                    new_item = old_item[self.index : self.stop : self.step]  # type: ignore[assignment, misc]
            elif isinstance(old_item, dict):
                if by_indices:
                    new_item = {k: v for (k, v) in old_item.items() if k in self.index}  # type: ignore[assignment, operator]
                    # Use ``in`` rather than ``old_item[key]`` so that mappings such as
                    # ``defaultdict`` are not mutated by the presence check.
                    missing = any(k not in old_item for k in self.index)  # type: ignore[union-attr]
                else:
                    # Slice the keys by insertion order, then keep the original ordering of
                    # the surviving entries (matters for negative steps).
                    selected = set(list(old_item)[self.index : self.stop : self.step])  # type: ignore[misc]
                    new_item = {k: v for (k, v) in old_item.items() if k in selected}  # type: ignore[assignment]
            else:
                # Unsupported element type: pass it through untouched. No index check is
                # attempted here because subscripting an arbitrary object may raise.
                new_item = old_item
                warnings.warn(
                    "The next item was not an iterable and cannot be filtered, "
                    "please be aware that no filter was done or new item created."
                )

            if missing:
                warnings.warn(
                    "At least one index in the filter is not present in the item being returned,"
                    " please be aware that expected columns/keys may be missing."
                )

            yield new_item  # type: ignore[misc]

    def __len__(self) -> int:
        # Slicing happens inside each element, so the number of elements is unchanged.
        if isinstance(self.datapipe, Sized):
            return len(self.datapipe)
        raise TypeError(f"{type(self).__name__} instance doesn't have valid length")

0 comments on commit ffb73a7

Please sign in to comment.