Skip to content

Commit

Permalink
Add graph function to list DataPipes from graph (#888)
Browse files Browse the repository at this point in the history
Summary:
Add a `list_dps` function to list `DataPipes` from the graph.
- It's similar to [`get_all_graph_pipes`](https://github.com/pytorch/pytorch/blob/896fa8c5c9b0191c9621e04ab5e20057614d48ad/torch/utils/data/graph_settings.py#L19) from PyTorch core.
- It takes an extra `exclude_dps` argument to exclude the given `DataPipe`(s) and their prior graph from the result.

Reason to add this function:
- It's required in order to set random states differently for the DataPipes before and after `sharding_filter`:
```py
graph = traverse_dps(datapipe)
sf_dps = find_dps(graph, ShardingFilter)

# DataPipes prior to `sharding_filter`
p_dps = []
for sf_dp in sf_dps:
    p_dps.extend(list_dps(traverse_dps(sf_dp)))

# DataPipes after `sharding_filter`
a_dps = list_dps(graph, exclude_dps=sf_dps)
```

Step 1 for #885

Pull Request resolved: #888

Reviewed By: VitalyFedyunin, NivekT

Differential Revision: D41099171

Pulled By: ejguan

fbshipit-source-id: d9d6e7beb498fea3921d8a3a1020649dd3955ce2
  • Loading branch information
ejguan authored and facebook-github-bot committed Nov 11, 2022
1 parent 59d1462 commit b2ca33d
Show file tree
Hide file tree
Showing 4 changed files with 95 additions and 8 deletions.
3 changes: 2 additions & 1 deletion docs/source/dataloader2.rst
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,9 @@ And, graph utility functions are provided in ``torchdata.dataloader.graph`` to h

traverse_dps
find_dps
replace_dp
list_dps
remove_dp
replace_dp

Adapter
--------
Expand Down
40 changes: 39 additions & 1 deletion test/test_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from _utils._common_utils_for_test import IS_WINDOWS
from torch.utils.data import IterDataPipe
from torchdata.dataloader2 import DataLoader2, MultiProcessingReadingService, ReadingServiceInterface
from torchdata.dataloader2.graph import find_dps, remove_dp, replace_dp, traverse_dps
from torchdata.dataloader2.graph import find_dps, list_dps, remove_dp, replace_dp, traverse_dps
from torchdata.datapipes.iter import IterableWrapper, Mapper
from torchdata.datapipes.utils import to_graph

Expand Down Expand Up @@ -97,6 +97,44 @@ def test_find_dps(self) -> None:
for dp in dps:
self.assertTrue(dp in expected_dps)

def test_list_dps(self) -> None:
    def _assert_same_members(actual, expected):
        # Order-insensitive comparison: same number of entries, and every
        # listed DataPipe must be one of the expected ones.
        self.assertEqual(len(actual), len(expected))
        allowed = set(expected)
        for item in actual:
            self.assertTrue(item in allowed)

    graph, exp_all_dps = self._get_datapipes()
    src_dp, m1, ub, dm, c1, c2, m2, dp = exp_all_dps

    # Without any exclusion, every DataPipe of the graph is listed
    _assert_same_members(list_dps(graph), exp_all_dps)

    # Excluding a single DataPipe: with ``m1`` excluded, neither ``src_dp``
    # nor ``m1`` is expected in the result
    _assert_same_members(
        list_dps(graph, exclude_dps=m1),
        [ub, dm, c1, c2, m2, dp],
    )

    # With ``m2`` excluded, ``c1`` and ``m2`` are expected to be missing
    _assert_same_members(
        list_dps(graph, exclude_dps=m2),
        [src_dp, m1, ub, dm, c2, dp],
    )

    # Excluding several DataPipes at once
    _assert_same_members(
        list_dps(graph, exclude_dps=[m1, m2]),
        [ub, dm, c2, dp],
    )

def _validate_graph(self, graph, nested_dp):
self.assertEqual(len(graph), len(nested_dp))
for dp_id, sub_nested_dp in zip(graph, nested_dp):
Expand Down
24 changes: 24 additions & 0 deletions torchdata/dataloader2/graph/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.


from torch.utils.data.graph import DataPipe, DataPipeGraph, traverse_dps

from torchdata.dataloader2.graph.utils import find_dps, list_dps, remove_dp, replace_dp


# Public names exposed by this package: the graph types and ``traverse_dps``
# imported above, plus the graph helper functions.
__all__ = [
"DataPipe",
"DataPipeGraph",
"find_dps",
"list_dps",
"remove_dp",
"replace_dp",
"traverse_dps",
]


# Import-time sanity check: fails if the names in ``__all__`` are not listed
# in sorted order.
assert __all__ == sorted(__all__)
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,13 @@
# LICENSE file in the root directory of this source tree.


from typing import List, Type

from torch.utils.data.graph import DataPipe, DataPipeGraph, traverse_dps
from typing import List, Optional, Set, Type, Union

from torchdata.dataloader2.graph import DataPipe, DataPipeGraph, traverse_dps
from torchdata.datapipes.iter import IterDataPipe
from torchdata.datapipes.map import MapDataPipe


__all__ = ["find_dps", "replace_dp", "remove_dp"]


def find_dps(graph: DataPipeGraph, dp_type: Type[DataPipe]) -> List[DataPipe]:
r"""
Given the graph of DataPipe generated by ``traverse_dps`` function, return DataPipe
Expand All @@ -34,6 +30,34 @@ def helper(g) -> None: # pyre-ignore
return dps


def list_dps(graph: DataPipeGraph, exclude_dps: Optional[Union[DataPipe, List[DataPipe]]] = None) -> List[DataPipe]:
    r"""
    Flatten a DataPipe graph (as generated by ``traverse_dps``) into a list of the
    unique DataPipe instances it contains.

    The graph is walked depth-first and each DataPipe is reported the first time
    it is reached, so the result contains no duplicates.

    Args:
        graph: Nested mapping of ``id -> (DataPipe, sub-graph)``. The keys are
            compared against ``id()`` of the excluded DataPipes.
        exclude_dps: A single ``DataPipe`` or a list of ``DataPipe`` to leave out
            of the result. The walk does not descend into an excluded DataPipe,
            so DataPipes that are reachable only through it are omitted as well;
            DataPipes that are also reachable through another branch of the
            graph are still listed.

    Returns:
        The DataPipes in the order they were first visited.
    """
    # Normalize ``exclude_dps`` to something iterable
    if exclude_dps is None:
        excluded = []
    elif isinstance(exclude_dps, (IterDataPipe, MapDataPipe)):
        excluded = [exclude_dps]
    else:
        excluded = exclude_dps

    # Pre-marking the excluded DataPipes as "seen" keeps both them and the
    # sub-graphs hanging off them out of the walk.
    seen: Set[int] = {id(excluded_dp) for excluded_dp in excluded}
    found: List[DataPipe] = []

    # Explicit stack of item iterators: a pre-order depth-first walk that visits
    # nodes in the same order as the equivalent recursive traversal.
    stack = [iter(graph.items())]
    while stack:
        entry = next(stack[-1], None)
        if entry is None:
            # Current level is exhausted, resume the parent level
            stack.pop()
            continue
        node_id, (datapipe, upstream) = entry
        if node_id in seen:
            continue
        seen.add(node_id)
        found.append(datapipe)
        stack.append(iter(upstream.items()))

    return found


# Given the DataPipe to be replaced and the DataPipe to replace it with, return a new graph
def replace_dp(graph: DataPipeGraph, old_datapipe: DataPipe, new_datapipe: DataPipe) -> DataPipeGraph:
r"""
Expand Down

0 comments on commit b2ca33d

Please sign in to comment.