From 369b3b04c9786356265d4e05a6be609d1140eb31 Mon Sep 17 00:00:00 2001 From: Yipeng Li Date: Thu, 16 Jun 2022 18:52:22 +0800 Subject: [PATCH 1/2] Use TopoForEachNodeFast as default. Rename the original one as TopoForEachNodeDynamic --- oneflow/core/graph/graph.h | 95 +++++++++++++++--------- oneflow/core/graph/inplace_lbi_graph.cpp | 6 +- oneflow/core/graph/op_graph.cpp | 3 +- oneflow/core/graph/straighten_nodes.cpp | 2 +- oneflow/core/job/plan_util.cpp | 14 ++-- oneflow/core/job_rewriter/autograd.cpp | 4 +- 6 files changed, 72 insertions(+), 52 deletions(-) diff --git a/oneflow/core/graph/graph.h b/oneflow/core/graph/graph.h index 640318c1ec3..73a4aef6979 100644 --- a/oneflow/core/graph/graph.h +++ b/oneflow/core/graph/graph.h @@ -34,11 +34,15 @@ class Graph { // For Each void ForEachNode(std::function NodeHandler) const; Maybe MaybeForEachNode(std::function(NodeType*)> NodeHandler) const; + // In case you want to change the topological structure during the node handler. + // For example, adding/deleting a node or an edge. + // Still, it might have bugs even if you use TopoForEachNodeDynamic. + void TopoForEachNodeDynamic(std::function NodeHandler) const; void TopoForEachNode(std::function NodeHandler) const; - void TopoForEachNodeFast(std::function NodeHandler) const; + Maybe TopoForEachNodeDynamicWithErrorCaptured( + std::function(NodeType*)> NodeHandler) const; Maybe TopoForEachNodeWithErrorCaptured( std::function(NodeType*)> NodeHandler) const; - Maybe TopoForEachNodeFastMaybe(std::function(NodeType*)> NodeHandler) const; void ReverseTopoForEachNode(std::function NodeHandler) const; void ForEachEdge(std::function EdgeHandler) const; @@ -55,19 +59,26 @@ class Graph { const std::function&)>& ForEachNext, const std::function& Handler) const; - void TopoForEachNode( + // Another reason to keep TopoForEachNodeDynamic is that we can start from a subset of source + // nodes. + void TopoForEachNodeDynamic( const std::list& starts, const std::function&)>& ForEachInNode, const std::function&)>& ForEachOutNode, const std::function& Handler) const; - Maybe TopoForEachNodeWithErrorCaptured( + void TopoForEachNode( + const std::function&)>& ForEachInNode, + const std::function&)>& ForEachOutNode, + const std::function& Handler) const; + + Maybe TopoForEachNodeDynamicWithErrorCaptured( const std::list& starts, const std::function&)>& ForEachInNode, const std::function&)>& ForEachOutNode, const std::function(NodeType*)>& Handler) const; - Maybe TopoForEachNodeFastMaybe( + Maybe TopoForEachNodeWithErrorCaptured( const std::function&)>& ForEachInNode, const std::function&)>& ForEachOutNode, const std::function(NodeType*)>& Handler) const; @@ -219,33 +230,33 @@ NodeType* Graph::SoleSinkNode() const { } template -void Graph::TopoForEachNode(std::function NodeHandler) const { - TopoForEachNode(source_nodes(), &NodeType::ForEachNodeOnInEdge, &NodeType::ForEachNodeOnOutEdge, - NodeHandler); +void Graph::TopoForEachNodeDynamic( + std::function NodeHandler) const { + TopoForEachNodeDynamic(source_nodes(), &NodeType::ForEachNodeOnInEdge, + &NodeType::ForEachNodeOnOutEdge, NodeHandler); } template -void Graph::TopoForEachNodeFast( - std::function NodeHandler) const { - CHECK_JUST(TopoForEachNodeFastMaybe(&NodeType::ForEachNodeOnInEdge, - &NodeType::ForEachNodeOnOutEdge, [&](NodeType* node) { - NodeHandler(node); - return Maybe::Ok(); - })); +void Graph::TopoForEachNode(std::function NodeHandler) const { + CHECK_JUST(TopoForEachNodeWithErrorCaptured(&NodeType::ForEachNodeOnInEdge, + &NodeType::ForEachNodeOnOutEdge, [&](NodeType* node) { + NodeHandler(node); + return Maybe::Ok(); + })); } template -Maybe Graph::TopoForEachNodeWithErrorCaptured( +Maybe Graph::TopoForEachNodeDynamicWithErrorCaptured( std::function(NodeType*)> NodeHandler) const { - return TopoForEachNodeWithErrorCaptured(source_nodes(), &NodeType::ForEachNodeOnInEdge, - &NodeType::ForEachNodeOnOutEdge, NodeHandler); + return TopoForEachNodeDynamicWithErrorCaptured(source_nodes(), &NodeType::ForEachNodeOnInEdge, + &NodeType::ForEachNodeOnOutEdge, NodeHandler); } template -Maybe Graph::TopoForEachNodeFastMaybe( +Maybe Graph::TopoForEachNodeWithErrorCaptured( std::function(NodeType*)> NodeHandler) const { - return TopoForEachNodeFastMaybe(&NodeType::ForEachNodeOnInEdge, &NodeType::ForEachNodeOnOutEdge, - NodeHandler); + return TopoForEachNodeWithErrorCaptured(&NodeType::ForEachNodeOnInEdge, + &NodeType::ForEachNodeOnOutEdge, NodeHandler); } template @@ -253,15 +264,14 @@ void Graph::SortedTopoForEachNode( std::function LessThan, std::function NodeHandler) const { ForEachNode([&](NodeType* node) { node->SortInOutEdges(LessThan); }); - TopoForEachNode(source_nodes(), &NodeType::ForEachNodeOnSortedInEdge, - &NodeType::ForEachNodeOnSortedOutEdge, NodeHandler); + TopoForEachNode(&NodeType::ForEachNodeOnSortedInEdge, &NodeType::ForEachNodeOnSortedOutEdge, + NodeHandler); } template void Graph::ReverseTopoForEachNode( std::function NodeHandler) const { - TopoForEachNode(sink_nodes(), &NodeType::ForEachNodeOnOutEdge, &NodeType::ForEachNodeOnInEdge, - NodeHandler); + TopoForEachNode(&NodeType::ForEachNodeOnOutEdge, &NodeType::ForEachNodeOnInEdge, NodeHandler); } template @@ -518,20 +528,31 @@ std::unique_ptr> Graph::FindFirstNontrivi } template -void Graph::TopoForEachNode( +void Graph::TopoForEachNodeDynamic( const std::list& starts, const std::function&)>& ForEachInNode, const std::function&)>& ForEachOutNode, const std::function& Handler) const { - CHECK_JUST( - TopoForEachNodeWithErrorCaptured(starts, ForEachInNode, ForEachOutNode, [&](NodeType* node) { - Handler(node); - return Maybe::Ok(); - })); + CHECK_JUST(TopoForEachNodeDynamicWithErrorCaptured(starts, ForEachInNode, ForEachOutNode, + [&](NodeType* node) { + Handler(node); + return Maybe::Ok(); + })); } template -Maybe Graph::TopoForEachNodeWithErrorCaptured( +void Graph::TopoForEachNode( + const std::function&)>& ForEachInNode, + const std::function&)>& ForEachOutNode, + const std::function& Handler) const { + CHECK_JUST(TopoForEachNodeWithErrorCaptured(ForEachInNode, ForEachOutNode, [&](NodeType* node) { + Handler(node); + return Maybe::Ok(); + })); +} + +template +Maybe Graph::TopoForEachNodeDynamicWithErrorCaptured( const std::list& starts, const std::function&)>& ForEachInNode, const std::function&)>& ForEachOutNode, @@ -562,7 +583,7 @@ Maybe Graph::TopoForEachNodeWithErrorCaptured( } template -Maybe Graph::TopoForEachNodeFastMaybe( +Maybe Graph::TopoForEachNodeWithErrorCaptured( const std::function&)>& ForEachInNode, const std::function&)>& ForEachOutNode, const std::function(NodeType*)>& Handler) const { @@ -595,7 +616,7 @@ void Graph::DfsTopoForEachNodeSortByDistanceToSink( HashMap node2distance_to_sink; { std::list nodes; - TopoForEachNode(starts, ForEachInNode, ForEachOutNode, + TopoForEachNode(ForEachInNode, ForEachOutNode, [&](NodeType* node) { nodes.emplace_back(node); }); std::list sinks; for (NodeType* node : nodes) { @@ -603,7 +624,7 @@ void Graph::DfsTopoForEachNodeSortByDistanceToSink( ForEachOutNode(node, [&](NodeType* out_node) { is_sink = false; }); if (is_sink) { sinks.emplace_back(node); } } - TopoForEachNode(sinks, ForEachOutNode, ForEachInNode, [&](NodeType* node) { + TopoForEachNode(ForEachOutNode, ForEachInNode, [&](NodeType* node) { int64_t distance_to_sink = -1; ForEachOutNode(node, [&](NodeType* out_node) { distance_to_sink = std::max(distance_to_sink, node2distance_to_sink[out_node]); @@ -698,12 +719,12 @@ Graph::MakePredicatorIsReachable( std::shared_ptr id2ancestor(new Id2Ancestor(node_num())); int64_t id = 0; node2id->reserve(node_num()); - TopoForEachNode(starts, ForEachInNode, ForEachOutNode, [&](NodeType* node) { + TopoForEachNode(ForEachInNode, ForEachOutNode, [&](NodeType* node) { node2id->emplace(node, id); id2ancestor->at(id).Resize(node_num()); id += 1; }); - TopoForEachNode(starts, ForEachInNode, ForEachOutNode, [&](NodeType* node) { + TopoForEachNode(ForEachInNode, ForEachOutNode, [&](NodeType* node) { const int64_t node_id = node2id->at(node); auto& ancestor_bitset_vec = id2ancestor->at(node_id); ForEachInNode(node, [&](NodeType* in_node) { diff --git a/oneflow/core/graph/inplace_lbi_graph.cpp b/oneflow/core/graph/inplace_lbi_graph.cpp index f1fc4320e64..09087568447 100644 --- a/oneflow/core/graph/inplace_lbi_graph.cpp +++ b/oneflow/core/graph/inplace_lbi_graph.cpp @@ -389,7 +389,7 @@ const InplaceLbiEdge* InplaceLbiGraph::FindFirstIntraOpRefConflictMutRefEdge( const auto* root = GetRoot(nodes, IsValidEdge); auto ForEachInNode = GetForEachValidInNode(&nodes, IsValidEdge); auto ForEachOutNode = GetForEachValidOutNode(&nodes, IsValidEdge); - TopoForEachNode({root}, ForEachInNode, ForEachOutNode, [&](const InplaceLbiNode* node) { + TopoForEachNodeDynamic({root}, ForEachInNode, ForEachOutNode, [&](const InplaceLbiNode* node) { if (ret != nullptr) { return; } if (node->IsMutRef(IsValidEdge) && IsOtherIbnBoundToOneOfLbis(lbis, node->SoleInEdge())) { ret = node->SoleInEdge(); @@ -427,7 +427,7 @@ const InplaceLbiEdge* InplaceLbiGraph::FindFirstConstRefConflictMutRefEdge( auto ForEachInNode = GetForEachValidInNode(&nodes, IsValidEdge); auto ForEachOutNode = GetForEachValidOutNode(&nodes, IsValidEdge); const InplaceLbiEdge* ret = nullptr; - TopoForEachNode({root}, ForEachInNode, ForEachOutNode, [&](const InplaceLbiNode* node) { + TopoForEachNodeDynamic({root}, ForEachInNode, ForEachOutNode, [&](const InplaceLbiNode* node) { if (ret != nullptr) { return; } if (node->IsMutRef(IsValidEdge) && IsConstRefConflictMutRefNode(node, nodes, IsValidEdge, @@ -449,7 +449,7 @@ const InplaceLbiEdge* InplaceLbiGraph::FindFirstInterOpRefConflictMutRefEdge( const InplaceLbiNode* root = GetRoot(nodes, IsValidEdge); auto ForEachInNode = GetForEachValidInNode(&nodes, IsValidEdge); auto ForEachOutNode = GetForEachValidOutNode(&nodes, IsValidEdge); - TopoForEachNode({root}, ForEachInNode, ForEachOutNode, [&](const InplaceLbiNode* node) { + TopoForEachNodeDynamic({root}, ForEachInNode, ForEachOutNode, [&](const InplaceLbiNode* node) { if (node->IsMutRef(IsValidEdge)) { mut_ref_nodes.insert(node); } size_t in_edges_size_check = 0; ForEachInNode(node, [&](const InplaceLbiNode* in_node) { diff --git a/oneflow/core/graph/op_graph.cpp b/oneflow/core/graph/op_graph.cpp index 09df4962fcb..82e8b20088d 100644 --- a/oneflow/core/graph/op_graph.cpp +++ b/oneflow/core/graph/op_graph.cpp @@ -466,8 +466,7 @@ void OpGraph::TopoForEachNodeWithCtrlEdge(const std::function& No const std::function& Handler) { ForEachDataAndCtrlOutNode(node, Handler); }; - TopoForEachNode(DataOrCtrlSourceNodes(), OpGraphForEachInDataAndCtrlNode, - OpGraphForEachOutDataAndCtrlNode, NodeHandler); + TopoForEachNode(OpGraphForEachInDataAndCtrlNode, OpGraphForEachOutDataAndCtrlNode, NodeHandler); } std::function diff --git a/oneflow/core/graph/straighten_nodes.cpp b/oneflow/core/graph/straighten_nodes.cpp index 404fcb7dd11..ca5cc163fe6 100644 --- a/oneflow/core/graph/straighten_nodes.cpp +++ b/oneflow/core/graph/straighten_nodes.cpp @@ -233,7 +233,7 @@ void StraightenNodes(TaskGraph* task_graph, std::vector* ordered_task task_type2machine_id2node_id2topo_structs; std::map min_node_id2topo_struct; int32_t previous_min_layer = 0; - task_graph->TopoForEachNodeFast([&](TaskNode* node) { + task_graph->TopoForEachNode([&](TaskNode* node) { auto& topo_struct = task_node2topo_struct[node]; topo_struct.node = node; if (node->in_edges().empty()) { diff --git a/oneflow/core/job/plan_util.cpp b/oneflow/core/job/plan_util.cpp index a2d3fc2f8da..ab62cb7f279 100644 --- a/oneflow/core/job/plan_util.cpp +++ b/oneflow/core/job/plan_util.cpp @@ -768,13 +768,13 @@ void PlanUtil::GenCollectiveBoxingPlan(Job* job, Plan* plan) { }; HashSet visited; std::vector collective_boxing_nodes; - plan_task_graph.TopoForEachNode(src_nodes, ForEachNodeOnInEdge, ForEachNodeOnOutEdge, - [&](const PlanTaskNode* node) { - visited.insert(node); - if (IsCollectiveBoxingNode(node)) { - collective_boxing_nodes.emplace_back(node); - } - }); + plan_task_graph.TopoForEachNodeDynamic(src_nodes, ForEachNodeOnInEdge, ForEachNodeOnOutEdge, + [&](const PlanTaskNode* node) { + visited.insert(node); + if (IsCollectiveBoxingNode(node)) { + collective_boxing_nodes.emplace_back(node); + } + }); if (collective_boxing_nodes.empty()) { break; } HashMap name2request_info; for (const PlanTaskNode* node : collective_boxing_nodes) { diff --git a/oneflow/core/job_rewriter/autograd.cpp b/oneflow/core/job_rewriter/autograd.cpp index 4fdf6f3b50d..24f3b32f335 100644 --- a/oneflow/core/job_rewriter/autograd.cpp +++ b/oneflow/core/job_rewriter/autograd.cpp @@ -909,8 +909,8 @@ Maybe AutoGrad(JobPassCtx* ctx, const OpGraph& op_graph, JobBuilder* job_b HashMap in_oba2in_diff_lbi; HashMap out_oba2clone_bw_add_out_lbi; std::list topo_nodes; - op_graph.TopoForEachNode(loss_nodes, ForEachOutNode, ForEachInNode, - [&](OpNode* op_node) { topo_nodes.emplace_back(op_node); }); + op_graph.TopoForEachNodeDynamic(loss_nodes, ForEachOutNode, ForEachInNode, + [&](OpNode* op_node) { topo_nodes.emplace_back(op_node); }); for (OpNode* op_node : topo_nodes) { const auto& op_name = op_node->op().op_name(); auto DiffLbi4BnInOp = [&](const std::string& bn) -> LogicalBlobId* { From 9f10dc87f237769d49a1d17f48ef028538cc7816 Mon Sep 17 00:00:00 2001 From: Yipeng Li Date: Thu, 16 Jun 2022 20:44:47 +0800 Subject: [PATCH 2/2] Speed up TopoForEachNodeFast when traversing a subgraph --- oneflow/core/graph/graph.h | 60 +++++++++++++++++++++++- oneflow/core/graph/inplace_lbi_graph.cpp | 6 +-- oneflow/core/job/plan_util.cpp | 14 +++--- oneflow/core/job_rewriter/autograd.cpp | 4 +- 4 files changed, 70 insertions(+), 14 deletions(-) diff --git a/oneflow/core/graph/graph.h b/oneflow/core/graph/graph.h index 73a4aef6979..b9f62e01696 100644 --- a/oneflow/core/graph/graph.h +++ b/oneflow/core/graph/graph.h @@ -59,14 +59,18 @@ class Graph { const std::function&)>& ForEachNext, const std::function& Handler) const; - // Another reason to keep TopoForEachNodeDynamic is that we can start from a subset of source - // nodes. void TopoForEachNodeDynamic( const std::list& starts, const std::function&)>& ForEachInNode, const std::function&)>& ForEachOutNode, const std::function& Handler) const; + void TopoForEachNode( + const std::list& starts, + const std::function&)>& ForEachInNode, + const std::function&)>& ForEachOutNode, + const std::function& Handler) const; + void TopoForEachNode( const std::function&)>& ForEachInNode, const std::function&)>& ForEachOutNode, @@ -78,6 +82,12 @@ class Graph { const std::function&)>& ForEachOutNode, const std::function(NodeType*)>& Handler) const; + Maybe TopoForEachNodeWithErrorCaptured( + const std::list& starts, + const std::function&)>& ForEachInNode, + const std::function&)>& ForEachOutNode, + const std::function(NodeType*)>& Handler) const; + Maybe TopoForEachNodeWithErrorCaptured( const std::function&)>& ForEachInNode, const std::function&)>& ForEachOutNode, @@ -540,6 +550,19 @@ void Graph::TopoForEachNodeDynamic( })); } +template +void Graph::TopoForEachNode( + const std::list& starts, + const std::function&)>& ForEachInNode, + const std::function&)>& ForEachOutNode, + const std::function& Handler) const { + CHECK_JUST( + TopoForEachNodeWithErrorCaptured(starts, ForEachInNode, ForEachOutNode, [&](NodeType* node) { + Handler(node); + return Maybe::Ok(); + })); +} + template void Graph::TopoForEachNode( const std::function&)>& ForEachInNode, @@ -582,6 +605,39 @@ Maybe Graph::TopoForEachNodeDynamicWithErrorCaptured( return Maybe::Ok(); } +template +Maybe Graph::TopoForEachNodeWithErrorCaptured( + const std::list& starts, + const std::function&)>& ForEachInNode, + const std::function&)>& ForEachOutNode, + const std::function(NodeType*)>& Handler) const { + HashMap counter_in; + std::queue queue; + for (NodeType* start : starts) { + queue.push(start); + counter_in[start] = 0; + ForEachInNode(start, [&](NodeType*) { LOG(FATAL) << "not a source"; }); + } + while (!queue.empty()) { + NodeType* cur_node = queue.front(); + queue.pop(); + JUST(Handler(cur_node)); + ForEachOutNode(cur_node, [&](NodeType* out) { + auto it = counter_in.find(out); + // Move the initialization here + if (it == counter_in.end()) { + int32_t count = 0; + ForEachInNode(out, [&](NodeType* out_in) { count++; }); + counter_in[out] = count; + it = counter_in.find(out); + } + it->second--; + if (it->second == 0) { queue.push(out); } + }); + } + return Maybe::Ok(); +} + template Maybe Graph::TopoForEachNodeWithErrorCaptured( const std::function&)>& ForEachInNode, diff --git a/oneflow/core/graph/inplace_lbi_graph.cpp b/oneflow/core/graph/inplace_lbi_graph.cpp index 09087568447..f1fc4320e64 100644 --- a/oneflow/core/graph/inplace_lbi_graph.cpp +++ b/oneflow/core/graph/inplace_lbi_graph.cpp @@ -389,7 +389,7 @@ const InplaceLbiEdge* InplaceLbiGraph::FindFirstIntraOpRefConflictMutRefEdge( const auto* root = GetRoot(nodes, IsValidEdge); auto ForEachInNode = GetForEachValidInNode(&nodes, IsValidEdge); auto ForEachOutNode = GetForEachValidOutNode(&nodes, IsValidEdge); - TopoForEachNodeDynamic({root}, ForEachInNode, ForEachOutNode, [&](const InplaceLbiNode* node) { + TopoForEachNode({root}, ForEachInNode, ForEachOutNode, [&](const InplaceLbiNode* node) { if (ret != nullptr) { return; } if (node->IsMutRef(IsValidEdge) && IsOtherIbnBoundToOneOfLbis(lbis, node->SoleInEdge())) { ret = node->SoleInEdge(); @@ -427,7 +427,7 @@ const InplaceLbiEdge* InplaceLbiGraph::FindFirstConstRefConflictMutRefEdge( auto ForEachInNode = GetForEachValidInNode(&nodes, IsValidEdge); auto ForEachOutNode = GetForEachValidOutNode(&nodes, IsValidEdge); const InplaceLbiEdge* ret = nullptr; - TopoForEachNodeDynamic({root}, ForEachInNode, ForEachOutNode, [&](const InplaceLbiNode* node) { + TopoForEachNode({root}, ForEachInNode, ForEachOutNode, [&](const InplaceLbiNode* node) { if (ret != nullptr) { return; } if (node->IsMutRef(IsValidEdge) && IsConstRefConflictMutRefNode(node, nodes, IsValidEdge, @@ -449,7 +449,7 @@ const InplaceLbiEdge* InplaceLbiGraph::FindFirstInterOpRefConflictMutRefEdge( const InplaceLbiNode* root = GetRoot(nodes, IsValidEdge); auto ForEachInNode = GetForEachValidInNode(&nodes, IsValidEdge); auto ForEachOutNode = GetForEachValidOutNode(&nodes, IsValidEdge); - TopoForEachNodeDynamic({root}, ForEachInNode, ForEachOutNode, [&](const InplaceLbiNode* node) { + TopoForEachNode({root}, ForEachInNode, ForEachOutNode, [&](const InplaceLbiNode* node) { if (node->IsMutRef(IsValidEdge)) { mut_ref_nodes.insert(node); } size_t in_edges_size_check = 0; ForEachInNode(node, [&](const InplaceLbiNode* in_node) { diff --git a/oneflow/core/job/plan_util.cpp b/oneflow/core/job/plan_util.cpp index ab62cb7f279..a2d3fc2f8da 100644 --- a/oneflow/core/job/plan_util.cpp +++ b/oneflow/core/job/plan_util.cpp @@ -768,13 +768,13 @@ void PlanUtil::GenCollectiveBoxingPlan(Job* job, Plan* plan) { }; HashSet visited; std::vector collective_boxing_nodes; - plan_task_graph.TopoForEachNodeDynamic(src_nodes, ForEachNodeOnInEdge, ForEachNodeOnOutEdge, - [&](const PlanTaskNode* node) { - visited.insert(node); - if (IsCollectiveBoxingNode(node)) { - collective_boxing_nodes.emplace_back(node); - } - }); + plan_task_graph.TopoForEachNode(src_nodes, ForEachNodeOnInEdge, ForEachNodeOnOutEdge, + [&](const PlanTaskNode* node) { + visited.insert(node); + if (IsCollectiveBoxingNode(node)) { + collective_boxing_nodes.emplace_back(node); + } + }); if (collective_boxing_nodes.empty()) { break; } HashMap name2request_info; for (const PlanTaskNode* node : collective_boxing_nodes) { diff --git a/oneflow/core/job_rewriter/autograd.cpp b/oneflow/core/job_rewriter/autograd.cpp index 24f3b32f335..4fdf6f3b50d 100644 --- a/oneflow/core/job_rewriter/autograd.cpp +++ b/oneflow/core/job_rewriter/autograd.cpp @@ -909,8 +909,8 @@ Maybe AutoGrad(JobPassCtx* ctx, const OpGraph& op_graph, JobBuilder* job_b HashMap in_oba2in_diff_lbi; HashMap out_oba2clone_bw_add_out_lbi; std::list topo_nodes; - op_graph.TopoForEachNodeDynamic(loss_nodes, ForEachOutNode, ForEachInNode, - [&](OpNode* op_node) { topo_nodes.emplace_back(op_node); }); + op_graph.TopoForEachNode(loss_nodes, ForEachOutNode, ForEachInNode, + [&](OpNode* op_node) { topo_nodes.emplace_back(op_node); }); for (OpNode* op_node : topo_nodes) { const auto& op_name = op_node->op().op_name(); auto DiffLbi4BnInOp = [&](const std::string& bn) -> LogicalBlobId* {