Skip to content

Commit

Permalink
Add graph function to list DataPipes from graph
Browse files Browse the repository at this point in the history
  • Loading branch information
ejguan committed Nov 8, 2022
1 parent d5a5940 commit f65a440
Show file tree
Hide file tree
Showing 3 changed files with 92 additions and 7 deletions.
40 changes: 39 additions & 1 deletion test/test_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from _utils._common_utils_for_test import IS_WINDOWS
from torch.utils.data import IterDataPipe
from torchdata.dataloader2 import DataLoader2, MultiProcessingReadingService, ReadingServiceInterface
from torchdata.dataloader2.graph import find_dps, remove_dp, replace_dp, traverse_dps
from torchdata.dataloader2.graph.utils import find_dps, list_dps, remove_dp, replace_dp, traverse_dps
from torchdata.datapipes.iter import IterableWrapper, Mapper
from torchdata.datapipes.utils import to_graph

Expand Down Expand Up @@ -97,6 +97,44 @@ def test_find_dps(self) -> None:
for dp in dps:
self.assertTrue(dp in expected_dps)

def test_list_dps(self) -> None:
    """Check ``list_dps`` with no exclusion, one excluded DataPipe, and several."""

    def _check(actual, expected):
        # Same size plus membership of every returned DataPipe in the expected set.
        self.assertEqual(len(actual), len(expected))
        expected_set = set(expected)
        for item in actual:
            self.assertTrue(item in expected_set)

    graph, exp_all_dps = self._get_datapipes()
    # Unpacking also asserts that exactly eight DataPipes were returned.
    # Positions: src_dp, m1, ub, dm, c1, c2, m2, dp
    _, m1, ub, dm, _, c2, m2, dp = exp_all_dps
    all_dps = list(exp_all_dps)

    # No exclusion: every DataPipe in the graph is listed.
    _check(list_dps(graph), exp_all_dps)

    # Excluding ``m1`` drops it together with the DataPipe in front of it.
    _check(list_dps(graph, exclude_dps=m1), all_dps[2:])

    # Excluding ``m2`` drops ``m2`` and ``c1``; ``c2`` and ``dp`` are still listed.
    _check(list_dps(graph, exclude_dps=m2), all_dps[:4] + [c2, dp])

    # Excluding several DataPipes at once.
    _check(list_dps(graph, exclude_dps=[m1, m2]), [ub, dm, c2, dp])

def _validate_graph(self, graph, nested_dp):
self.assertEqual(len(graph), len(nested_dp))
for dp_id, sub_nested_dp in zip(graph, nested_dp):
Expand Down
24 changes: 24 additions & 0 deletions torchdata/dataloader2/graph/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.


from torch.utils.data.graph import DataPipe, DataPipeGraph, traverse_dps

from torchdata.dataloader2.graph.utils import find_dps, list_dps, remove_dp, replace_dp


# Public names exported by ``from torchdata.dataloader2.graph import *``:
# the graph types/traversal imported from ``torch.utils.data.graph`` plus the
# helper functions imported from ``torchdata.dataloader2.graph.utils``.
__all__ = [
"DataPipe",
"DataPipeGraph",
"find_dps",
"list_dps",
"remove_dp",
"replace_dp",
"traverse_dps",
]


# Import-time sanity check that the export list stays sorted. ``sorted`` on
# strings is case-sensitive, so capitalised names come before lowercase ones.
assert __all__ == sorted(__all__)
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,13 @@
# LICENSE file in the root directory of this source tree.


from typing import List, Type

from torch.utils.data.graph import DataPipe, DataPipeGraph, traverse_dps
from typing import List, Optional, Set, Type, Union

from torchdata.dataloader2.graph import DataPipe, DataPipeGraph, traverse_dps
from torchdata.datapipes.iter import IterDataPipe
from torchdata.datapipes.map import MapDataPipe


__all__ = ["find_dps", "replace_dp", "remove_dp"]


def find_dps(graph: DataPipeGraph, dp_type: Type[DataPipe]) -> List[DataPipe]:
r"""
Given the graph of DataPipe generated by ``traverse_dps`` function, return DataPipe
Expand All @@ -34,6 +30,33 @@ def helper(g) -> None: # pyre-ignore
return dps


def list_dps(graph: DataPipeGraph, exclude_dps: Optional[Union[DataPipe, List[DataPipe]]] = None) -> List[DataPipe]:
    r"""
    Collect the DataPipe instances found in a graph produced by ``traverse_dps``,
    skipping every excluded DataPipe and the part of the graph that is only
    reachable through it.

    Args:
        graph: Mapping of ``id(datapipe)`` to ``(datapipe, source_graph)``.
        exclude_dps: A single DataPipe or a collection of DataPipes to leave out.
            Traversal does not descend into the source graph of an excluded
            DataPipe, but a DataPipe reachable through another, non-excluded
            path is still returned.

    Returns:
        The DataPipes in depth-first visiting order, each one at most once.
    """
    collected: List[DataPipe] = []
    # Ids that must not be emitted: seeded with the exclusions, then extended
    # with every DataPipe already visited so shared sources appear only once.
    seen_ids: Set[int] = set()

    if exclude_dps is not None:
        # DataPipes are checked with isinstance first; anything else is treated
        # as a collection of DataPipes and iterated.
        if isinstance(exclude_dps, (IterDataPipe, MapDataPipe)):
            excluded = [exclude_dps]
        else:
            excluded = exclude_dps  # type: ignore[assignment]
        seen_ids.update(id(item) for item in excluded)

    def _walk(subgraph) -> None:  # pyre-ignore
        for node_id, (node, sources) in subgraph.items():
            if node_id in seen_ids:
                continue
            seen_ids.add(node_id)
            collected.append(node)
            _walk(sources)

    _walk(graph)

    return collected


# Given the DataPipe needs to be replaced and the expected DataPipe, return a new graph
def replace_dp(graph: DataPipeGraph, old_datapipe: DataPipe, new_datapipe: DataPipe) -> DataPipeGraph:
r"""
Expand Down

0 comments on commit f65a440

Please sign in to comment.