Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add graph function to list DataPipes from graph (#888)
Summary: Add a `list_dps` function to list `DataPipes` from the graph. - It's similar to [`get_all_graph_pipes `](https://github.com/pytorch/pytorch/blob/896fa8c5c9b0191c9621e04ab5e20057614d48ad/torch/utils/data/graph_settings.py#L19) from pytorch core - An extra argument of `exclude_dps` to exclude the `DataPipe` and its prior graph from the result. Reason to add this function: - It's required to set random states differently for DataPipe before/after `sharding_filter` ```py graph = traverse_dps(datapipe) sf_dps = find_dps(graph, ShardingFilter) # DataPipes prior to `sharding_filter` p_dps = [] for sf_dp in sf_dps: p_dps.extend(list_dps(traverse_dps(sf_dp))) # DataPipes after `sharding_filter` a_dps = list_dps(graph, exclude_dps=sf_dps) ``` Step 1 for #885 Pull Request resolved: #888 Reviewed By: VitalyFedyunin, NivekT Differential Revision: D41099171 Pulled By: ejguan fbshipit-source-id: d9d6e7beb498fea3921d8a3a1020649dd3955ce2
- Loading branch information