diff --git a/docs/source/dataloader2.rst b/docs/source/dataloader2.rst index d98f3cefe..1c0d6f588 100644 --- a/docs/source/dataloader2.rst +++ b/docs/source/dataloader2.rst @@ -57,8 +57,9 @@ And, graph utility functions are provided in ``torchdata.dataloader.graph`` to h traverse_dps find_dps - replace_dp + list_dps remove_dp + replace_dp Adapter -------- diff --git a/test/test_graph.py b/test/test_graph.py index a8642f1a0..dba62f097 100644 --- a/test/test_graph.py +++ b/test/test_graph.py @@ -13,7 +13,7 @@ from _utils._common_utils_for_test import IS_WINDOWS from torch.utils.data import IterDataPipe from torchdata.dataloader2 import DataLoader2, MultiProcessingReadingService, ReadingServiceInterface -from torchdata.dataloader2.graph import find_dps, remove_dp, replace_dp, traverse_dps +from torchdata.dataloader2.graph import find_dps, list_dps, remove_dp, replace_dp, traverse_dps from torchdata.datapipes.iter import IterableWrapper, Mapper from torchdata.datapipes.utils import to_graph @@ -97,6 +97,44 @@ def test_find_dps(self) -> None: for dp in dps: self.assertTrue(dp in expected_dps) + def test_list_dps(self) -> None: + def _validate_fn(dps, exp_dps): + self.assertEqual(len(dps), len(exp_dps)) + exp_set = {*exp_dps} + for dp in dps: + self.assertTrue(dp in exp_set) + + graph, exp_all_dps = self._get_datapipes() + ( + src_dp, + m1, + ub, + dm, + c1, + c2, + m2, + dp, + ) = exp_all_dps + + # List all DataPipes + dps = list_dps(graph) + _validate_fn(dps, exp_all_dps) + + # List all DataPipes excluding a single DataPipe + dps = list_dps(graph, exclude_dps=m1) + _, _, *exp_dps = exp_all_dps + _validate_fn(dps, exp_dps) + + dps = list_dps(graph, exclude_dps=m2) + *exp_dps_1, _, c2, _, dp = exp_all_dps + exp_dps = list(exp_dps_1) + [c2, dp] + _validate_fn(dps, exp_dps) + + # List all DataPipes excluding multiple DataPipes + dps = list_dps(graph, exclude_dps=[m1, m2]) + exp_dps = [ub, dm, c2, dp] + _validate_fn(dps, exp_dps) + def _validate_graph(self, graph, 
nested_dp): self.assertEqual(len(graph), len(nested_dp)) for dp_id, sub_nested_dp in zip(graph, nested_dp): diff --git a/torchdata/dataloader2/graph/__init__.py b/torchdata/dataloader2/graph/__init__.py new file mode 100644 index 000000000..84154943b --- /dev/null +++ b/torchdata/dataloader2/graph/__init__.py @@ -0,0 +1,24 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + + +from torch.utils.data.graph import DataPipe, DataPipeGraph, traverse_dps + +from torchdata.dataloader2.graph.utils import find_dps, list_dps, remove_dp, replace_dp + + +__all__ = [ + "DataPipe", + "DataPipeGraph", + "find_dps", + "list_dps", + "remove_dp", + "replace_dp", + "traverse_dps", +] + + +assert __all__ == sorted(__all__) diff --git a/torchdata/dataloader2/graph.py b/torchdata/dataloader2/graph/utils.py similarity index 84% rename from torchdata/dataloader2/graph.py rename to torchdata/dataloader2/graph/utils.py index d3a8836a0..3ce616b06 100644 --- a/torchdata/dataloader2/graph.py +++ b/torchdata/dataloader2/graph/utils.py @@ -5,17 +5,13 @@ # LICENSE file in the root directory of this source tree. 
-from typing import List, Type - -from torch.utils.data.graph import DataPipe, DataPipeGraph, traverse_dps +from typing import List, Optional, Set, Type, Union +from torchdata.dataloader2.graph import DataPipe, DataPipeGraph, traverse_dps from torchdata.datapipes.iter import IterDataPipe from torchdata.datapipes.map import MapDataPipe -__all__ = ["find_dps", "replace_dp", "remove_dp"] - - def find_dps(graph: DataPipeGraph, dp_type: Type[DataPipe]) -> List[DataPipe]: r""" Given the graph of DataPipe generated by ``traverse_dps`` function, return DataPipe @@ -34,6 +30,34 @@ def helper(g) -> None: # pyre-ignore return dps +def list_dps(graph: DataPipeGraph, exclude_dps: Optional[Union[DataPipe, List[DataPipe]]] = None) -> List[DataPipe]: + r""" + Given the graph of DataPipe generated by ``traverse_dps`` function, return a list + of all DataPipe instances without duplication. If ``exclude_dps`` is provided, + the provided ``DataPipes`` and the predecessors reachable only through them will be ignored. + """ + dps: List[DataPipe] = [] + cache: Set[int] = set() + + if exclude_dps is not None: + if isinstance(exclude_dps, (IterDataPipe, MapDataPipe)): + cache.add(id(exclude_dps)) + else: + for dp in exclude_dps: # type: ignore[union-attr] + cache.add(id(dp)) + + def helper(g) -> None: # pyre-ignore + for dp_id, (dp, src_graph) in g.items(): + if dp_id not in cache: + cache.add(dp_id) + dps.append(dp) + helper(src_graph) + + helper(graph) + + return dps + + # Given the DataPipe needs to be replaced and the expected DataPipe, return a new graph def replace_dp(graph: DataPipeGraph, old_datapipe: DataPipe, new_datapipe: DataPipe) -> DataPipeGraph: r"""