diff --git a/docs/source/torchdata.datapipes.iter.rst b/docs/source/torchdata.datapipes.iter.rst index 597af295c..f68f17a89 100644 --- a/docs/source/torchdata.datapipes.iter.rst +++ b/docs/source/torchdata.datapipes.iter.rst @@ -193,8 +193,8 @@ These DataPipes helps you select specific samples within a DataPipe. Filter Header Dropper - ISlicer - Flatten + Slicer + Flattener Text DataPipes ----------------------------- diff --git a/test/test_iterdatapipe.py b/test/test_iterdatapipe.py index 43264ee38..ff30ddcf3 100644 --- a/test/test_iterdatapipe.py +++ b/test/test_iterdatapipe.py @@ -993,71 +993,59 @@ def test_drop_iterdatapipe(self): drop_dp = input_dp.drop([0, 1]) self.assertEqual(3, len(drop_dp)) - def test_islice_iterdatapipe(self): + def test_slice_iterdatapipe(self): # tuple tests input_dp = IterableWrapper([(0, 1, 2), (3, 4, 5), (6, 7, 8)]) # Functional Test: slice with no stop and no step for tuple - islice_dp = input_dp.islice(1) - self.assertEqual([(1, 2), (4, 5), (7, 8)], list(islice_dp)) + slice_dp = input_dp.slice(1) + self.assertEqual([(1, 2), (4, 5), (7, 8)], list(slice_dp)) # Functional Test: slice with no step for tuple - islice_dp = input_dp.islice(0, 2) - self.assertEqual([(0, 1), (3, 4), (6, 7)], list(islice_dp)) + slice_dp = input_dp.slice(0, 2) + self.assertEqual([(0, 1), (3, 4), (6, 7)], list(slice_dp)) # Functional Test: slice with step for tuple - islice_dp = input_dp.islice(0, 2, 2) - self.assertEqual([(0,), (3,), (6,)], list(islice_dp)) + slice_dp = input_dp.slice(0, 2, 2) + self.assertEqual([(0,), (3,), (6,)], list(slice_dp)) # Functional Test: filter with list of indices for tuple - islice_dp = input_dp.islice([0, 1]) - self.assertEqual([(0, 1), (3, 4), (6, 7)], list(islice_dp)) + slice_dp = input_dp.slice([0, 1]) + self.assertEqual([(0, 1), (3, 4), (6, 7)], list(slice_dp)) # list tests input_dp = IterableWrapper([[0, 1, 2], [3, 4, 5], [6, 7, 8]]) # Functional Test: slice with no stop and no step for list - islice_dp = 
input_dp.islice(1) - self.assertEqual([[1, 2], [4, 5], [7, 8]], list(islice_dp)) + slice_dp = input_dp.slice(1) + self.assertEqual([[1, 2], [4, 5], [7, 8]], list(slice_dp)) # Functional Test: slice with no step for list - islice_dp = input_dp.islice(0, 2) - self.assertEqual([[0, 1], [3, 4], [6, 7]], list(islice_dp)) + slice_dp = input_dp.slice(0, 2) + self.assertEqual([[0, 1], [3, 4], [6, 7]], list(slice_dp)) # Functional Test: filter with list of indices for list - islice_dp = input_dp.islice(0, 2) - self.assertEqual([[0, 1], [3, 4], [6, 7]], list(islice_dp)) + slice_dp = input_dp.slice(0, 2) + self.assertEqual([[0, 1], [3, 4], [6, 7]], list(slice_dp)) # dict tests input_dp = IterableWrapper([{"a": 1, "b": 2, "c": 3}, {"a": 3, "b": 4, "c": 5}, {"a": 5, "b": 6, "c": 7}]) - # Functional Test: slice with no stop and no step for dict - islice_dp = input_dp.islice(1) - self.assertEqual([{"b": 2, "c": 3}, {"b": 4, "c": 5}, {"b": 6, "c": 7}], list(islice_dp)) - - # Functional Test: slice with no step for dict - islice_dp = input_dp.islice(0, 2) - self.assertEqual([{"a": 1, "b": 2}, {"a": 3, "b": 4}, {"a": 5, "b": 6}], list(islice_dp)) - - # Functional Test: slice with step for dict - islice_dp = input_dp.islice(0, 2, 2) - self.assertEqual([{"a": 1}, {"a": 3}, {"a": 5}], list(islice_dp)) - # Functional Test: filter with list of indices for dict - islice_dp = input_dp.islice(["a", "b"]) - self.assertEqual([{"a": 1, "b": 2}, {"a": 3, "b": 4}, {"a": 5, "b": 6}], list(islice_dp)) + slice_dp = input_dp.slice(["a", "b"]) + self.assertEqual([{"a": 1, "b": 2}, {"a": 3, "b": 4}, {"a": 5, "b": 6}], list(slice_dp)) # __len__ Test: input_dp = IterableWrapper([[0, 1, 2], [3, 4, 5], [6, 7, 8]]) - islice_dp = input_dp.islice(0, 2) - self.assertEqual(3, len(islice_dp)) + slice_dp = input_dp.slice(0, 2) + self.assertEqual(3, len(slice_dp)) # Reset Test: n_elements_before_reset = 2 input_dp = IterableWrapper([[0, 1, 2], [3, 4, 5], [6, 7, 8]]) - islice_dp = input_dp.islice([2]) + slice_dp = 
input_dp.slice([2]) expected_res = [[2], [5], [8]] - res_before_reset, res_after_reset = reset_after_n_next_calls(islice_dp, n_elements_before_reset) + res_before_reset, res_after_reset = reset_after_n_next_calls(slice_dp, n_elements_before_reset) self.assertEqual(expected_res[:n_elements_before_reset], res_before_reset) self.assertEqual(expected_res, res_after_reset) @@ -1075,9 +1063,9 @@ def test_flatten_iterdatapipe(self): self.assertEqual([(0, 10, 1, 2, 3), (4, 14, 5, 6, 7), (8, 18, 9, 10, 11)], list(flatten_dp)) # Functional Test: flatten all iters in the datapipe one level (no argument) - input_dp = IterableWrapper([(0, 1, 2), (3, 4, 5), (6, 7, 8)]) + input_dp = IterableWrapper([(0, (1, 2)), (3, (4, 5)), (6, (7, 8))]) flatten_dp = input_dp.flatten() - self.assertEqual([0, 1, 2, 3, 4, 5, 6, 7, 8], list(flatten_dp)) + self.assertEqual([(0, 1, 2), (3, 4, 5), (6, 7, 8)], list(flatten_dp)) # list tests @@ -1092,9 +1080,14 @@ def test_flatten_iterdatapipe(self): self.assertEqual([[0, 10, 1, 2, 3], [4, 14, 5, 6, 7], [8, 18, 9, 10, 11]], list(flatten_dp)) # Functional Test: flatten all iters in the datapipe one level (no argument) - input_dp = IterableWrapper([[0, 1, 2], [3, 4, 5], [6, 7, 8]]) + input_dp = IterableWrapper([[0, [1, 2]], [3, [4, 5]], [6, [7, 8]]]) flatten_dp = input_dp.flatten() - self.assertEqual([0, 1, 2, 3, 4, 5, 6, 7, 8], list(flatten_dp)) + self.assertEqual([[0, 1, 2], [3, 4, 5], [6, 7, 8]], list(flatten_dp)) + + # Functional Test: string test, flatten all iters in the datapipe one level (no argument) + input_dp = IterableWrapper([["zero", ["one", "2"]], ["3", ["4", "5"]], ["6", ["7", "8"]]]) + flatten_dp = input_dp.flatten() + self.assertEqual([["zero", "one", "2"], ["3", "4", "5"], ["6", "7", "8"]], list(flatten_dp)) # dict tests @@ -1103,6 +1096,13 @@ def test_flatten_iterdatapipe(self): flatten_dp = input_dp.flatten("c") self.assertEqual([{"a": 1, "b": 2, "d": 3, "e": 4}, {"a": 5, "b": 6, "d": 7, "e": 8}], list(flatten_dp)) + # Functional 
Test: flatten for an index already flat + input_dp = IterableWrapper([{"a": 1, "b": 2, "c": {"d": 9, "e": 10}}, {"a": 5, "b": 6, "c": {"d": 7, "e": 8}}]) + flatten_dp = input_dp.flatten("a") + self.assertEqual( + [{"a": 1, "b": 2, "c": {"d": 9, "e": 10}}, {"a": 5, "b": 6, "c": {"d": 7, "e": 8}}], list(flatten_dp) + ) + # Functional Test: flatten for list of indices input_dp = IterableWrapper( [ @@ -1116,10 +1116,20 @@ def test_flatten_iterdatapipe(self): ) # Functional Test: flatten all iters in the datapipe one level (no argument) - input_dp = IterableWrapper([{"a": 1, "b": 2, "c": 3, "d": 4}, {"a": 5, "b": 6, "c": 7, "d": 8}]) + input_dp = IterableWrapper([{"a": 1, "b": 2, "c": {"d": 3, "e": 4}}, {"a": 5, "b": 6, "c": {"d": 7, "e": 8}}]) + flatten_dp = input_dp.flatten() + self.assertEqual([{"a": 1, "b": 2, "d": 3, "e": 4}, {"a": 5, "b": 6, "d": 7, "e": 8}], list(flatten_dp)) + + # Functional Test: flatten all iters one level, multiple iters + input_dp = IterableWrapper( + [ + {"a": {"f": 10, "g": 11}, "b": 2, "c": {"d": 3, "e": 4}}, + {"a": {"f": 10, "g": 11}, "b": 6, "c": {"d": 7, "e": 8}}, + ] + ) flatten_dp = input_dp.flatten() self.assertEqual( - [("a", 1), ("b", 2), ("c", 3), ("d", 4), ("a", 5), ("b", 6), ("c", 7), ("d", 8)], list(flatten_dp) + [{"f": 10, "g": 11, "b": 2, "d": 3, "e": 4}, {"f": 10, "g": 11, "b": 6, "d": 7, "e": 8}], list(flatten_dp) ) # __len__ Test: diff --git a/test/test_serialization.py b/test/test_serialization.py index 443fafa13..938f1a62d 100644 --- a/test/test_serialization.py +++ b/test/test_serialization.py @@ -189,10 +189,9 @@ def test_serializable(self): (iterdp.DataFrameMaker, IterableWrapper([(i,) for i in range(3)]), (), {"dtype": DTYPE}), (iterdp.Decompressor, None, (), {}), (iterdp.Dropper, IterableWrapper([(0, 0), (0, 0), (0, 0), (0, 0)]), ([1]), {}), - (iterdp.ISlicer, IterableWrapper([(0, 0), (0, 0), (0, 0), (0, 0)]), ([1]), {}), (iterdp.Enumerator, None, (2,), {}), (iterdp.FlatMapper, None, (_fake_fn_ls,), {}), - 
(iterdp.Flatten, IterableWrapper([(0, (0, 1)), (0, (0, 1)), (0, (0, 1)), (0, (0, 1))]), ([1]), {}), + (iterdp.Flattener, IterableWrapper([(0, (0, 1)), (0, (0, 1)), (0, (0, 1)), (0, (0, 1))]), ([1]), {}), (iterdp.FSSpecFileLister, ".", (), {}), (iterdp.FSSpecFileOpener, None, (), {}), ( @@ -285,6 +284,7 @@ def test_serializable(self): (), {"mode": "wb", "filepath_fn": partial(_filepath_fn, dir=self.temp_dir.name)}, ), + (iterdp.Slicer, IterableWrapper([(0, 0), (0, 0), (0, 0), (0, 0)]), ([1]), {}), (iterdp.TarArchiveLoader, None, (), {}), # TODO(594): Add serialization tests for optional DataPipe # (iterdp.TFRecordLoader, None, (), {}), diff --git a/torchdata/datapipes/iter/__init__.py b/torchdata/datapipes/iter/__init__.py index 540a60d5b..86849b950 100644 --- a/torchdata/datapipes/iter/__init__.py +++ b/torchdata/datapipes/iter/__init__.py @@ -69,8 +69,8 @@ BatchMapperIterDataPipe as BatchMapper, DropperIterDataPipe as Dropper, FlatMapperIterDataPipe as FlatMapper, - FlattenIterDataPipe as Flatten, - ISliceIterDataPipe as ISlicer, + FlattenIterDataPipe as Flattener, + SliceIterDataPipe as Slicer, ) from torchdata.datapipes.iter.util.bz2fileloader import Bz2FileLoaderIterDataPipe as Bz2FileLoader from torchdata.datapipes.iter.util.cacheholder import ( @@ -158,7 +158,7 @@ "FileOpener", "Filter", "FlatMapper", - "Flatten", + "Flattener", "Forker", "FullSync", "GDriveReader", @@ -167,7 +167,6 @@ "Header", "HttpReader", "HuggingFaceHubReader", - "ISlicer", "InBatchShuffler", "InMemoryCacheHolder", "IndexAdder", @@ -199,6 +198,7 @@ "Saver", "ShardingFilter", "Shuffler", + "Slicer", "StreamReader", "TFRecordLoader", "TarArchiveLoader", diff --git a/torchdata/datapipes/iter/transform/callable.py b/torchdata/datapipes/iter/transform/callable.py index 2efd87c2c..75065679f 100644 --- a/torchdata/datapipes/iter/transform/callable.py +++ b/torchdata/datapipes/iter/transform/callable.py @@ -5,8 +5,7 @@ # LICENSE file in the root directory of this source tree. 
import warnings -from itertools import chain -from typing import Callable, Hashable, Iterator, List, Optional, Sized, TypeVar, Union +from typing import Callable, Hashable, Iterator, List, Optional, Set, Sized, TypeVar, Union from torch.utils.data import functional_datapipe, IterDataPipe from torch.utils.data.datapipes.utils.common import _check_unpickable_fn @@ -208,10 +207,10 @@ def __len__(self) -> int: raise TypeError(f"{type(self).__name__} instance doesn't have valid length") -@functional_datapipe("islice") -class ISliceIterDataPipe(IterDataPipe[T_co]): +@functional_datapipe("slice") +class SliceIterDataPipe(IterDataPipe[T_co]): r""" - returns a slice of elements in input DataPipe via start/stop/step or indices (functional name: ``islice``). + returns a slice of elements in input DataPipe via start/stop/step or indices (functional name: ``slice``). Args: datapipe: IterDataPipe with iterable elements @@ -222,8 +221,8 @@ class ISliceIterDataPipe(IterDataPipe[T_co]): Example: >>> from torchdata.datapipes.iter import IterableWrapper >>> dp = IterableWrapper([(0, 10, 100), (1, 11, 111), (2, 12, 122), (3, 13, 133), (4, 14, 144)]) - >>> islice_dp = dp.islice(0, 2) - >>> list(islice_dp) + >>> slice_dp = dp.slice(0, 2) + >>> list(slice_dp) [(0, 10), (1, 11), (2, 12), (3, 13), (4, 14)] """ datapipe: IterDataPipe @@ -242,6 +241,13 @@ def __init__( self.stop = stop self.step = step + if isinstance(index, list): + if stop or step: + warnings.warn( + "A list of indices was passed as well as a stop or step for the slice, " + "these arguments can't be used together so only the indices list will be used."
+ ) + def __iter__(self) -> Iterator[T_co]: for old_item in self.datapipe: if isinstance(old_item, tuple): @@ -258,10 +264,13 @@ def __iter__(self) -> Iterator[T_co]: if isinstance(self.index, list): new_item = {k: v for (k, v) in old_item.items() if k in self.index} # type: ignore[assignment] else: - new_keys = list(old_item.keys())[self.index : self.stop : self.step] - new_item = {k: v for (k, v) in old_item.items() if k in new_keys} # type: ignore[assignment] + new_item = old_item # type: ignore[assignment] + warnings.warn( + "Dictionaries are not sliced by steps, only direct index. " + "Please be aware that no filter was done or new item created." + ) else: - new_item = old_item + new_item = old_item # type: ignore[assignment] warnings.warn( "The next item was not an iterable and cannot be filtered, " "please be aware that no filter was done or new item created." @@ -291,6 +300,9 @@ class FlattenIterDataPipe(IterDataPipe[T_co]): r""" returns a flattened copy of the input DataPipe based on provided indices (functional name: ``flatten``). 
+ Note: + no args will flatten each item in the datapipe 1 level + Args: datapipe: IterDataPipe with iterable elements indices: a single index/key for the item to flatten from an iterator item or a list of indices/keys to be flattened @@ -303,7 +315,7 @@ class FlattenIterDataPipe(IterDataPipe[T_co]): [(0, 10, 100, 1000), (1, 11, 111, 1001), (2, 12, 122, 1002), (3, 13, 133, 1003), (4, 14, 144, 1004)] """ datapipe: IterDataPipe - indices = None + indices: Set[Hashable] = set() def __init__( self, @@ -319,47 +331,57 @@ def __init__( self.indices = {indices} def __iter__(self) -> Iterator[T_co]: + flatten_all = False + if not self.indices: + flatten_all = True for old_item in self.datapipe: - if self.indices: - if isinstance(old_item, dict): - new_item = {} # type: ignore[assignment] - for k, v in old_item.items(): - if k in self.indices: - for k_sub, v_sub in v.items(): + if isinstance(old_item, dict): + new_item = {} # type: ignore[assignment] + for k, v in old_item.items(): + if k in self.indices: + pass + if (flatten_all or (k in self.indices)) and isinstance(v, dict): + for k_sub, v_sub in v.items(): + if k_sub not in new_item: new_item[k_sub] = v_sub - else: + else: + warnings.warn( + "Flattener tried to insert the same key twice into the same dict, " + "the second key, value pair has been dropped." + ) + else: + if k not in new_item: new_item[k] = v - else: - is_tuple = False - new_item = [] # type: ignore[assignment] - if isinstance(old_item, tuple): - is_tuple = True - old_item = list(old_item) - for i, item in enumerate(old_item): - if i in self.indices: - new_item.extend(list(item)) # type: ignore[attr-defined] else: - new_item.append(item) # type: ignore[attr-defined] - if is_tuple: - new_item = tuple(new_item) # type: ignore[assignment] + warnings.warn( + "Flattener tried to insert the same key twice into the same dict, " + "the second key, value pair has been dropped."
+ ) + else: + is_tuple = False + new_item = [] # type: ignore[assignment] + if isinstance(old_item, tuple): + is_tuple = True + old_item = list(old_item) + for i, item in enumerate(old_item): + if (flatten_all or (i in self.indices)) and isinstance(item, (list, tuple)): + new_item.extend(list(item)) # type: ignore[attr-defined] + else: + new_item.append(item) # type: ignore[attr-defined] + if is_tuple: + new_item = tuple(new_item) # type: ignore[assignment] - # check to make sure all indices requested were in the item. warn if not - try: + # check to make sure all indices requested were in the item. warn if not + try: + if self.indices: for index in self.indices: old_item[index] - except (IndexError, KeyError): - warnings.warn( - "At least one index in the filter is not present in the item being returned," - " please be aware that expected columns/keys may be missing." - ) - - yield new_item # type: ignore[misc] - - else: - if isinstance(old_item, dict): - yield from chain(old_item.items()) - else: - yield from chain(old_item) + except (IndexError, KeyError): + warnings.warn( + "At least one index in the filter is not present in the item being returned," + " please be aware that expected columns/keys may be missing." + ) + yield new_item # type: ignore[misc] def __len__(self) -> int: if isinstance(self.datapipe, Sized):