diff --git a/oneflow/core/graph/graph.h b/oneflow/core/graph/graph.h index 640318c1ec3..b9f62e01696 100644 --- a/oneflow/core/graph/graph.h +++ b/oneflow/core/graph/graph.h @@ -34,11 +34,15 @@ class Graph { // For Each void ForEachNode(std::function NodeHandler) const; Maybe MaybeForEachNode(std::function(NodeType*)> NodeHandler) const; + // In case you want to change the topological structure during the node handler. + // For example, adding/deleting a node or an edge. + // Still, it might have bugs even if you use TopoForEachNodeDynamic. + void TopoForEachNodeDynamic(std::function NodeHandler) const; void TopoForEachNode(std::function NodeHandler) const; - void TopoForEachNodeFast(std::function NodeHandler) const; + Maybe TopoForEachNodeDynamicWithErrorCaptured( + std::function(NodeType*)> NodeHandler) const; Maybe TopoForEachNodeWithErrorCaptured( std::function(NodeType*)> NodeHandler) const; - Maybe TopoForEachNodeFastMaybe(std::function(NodeType*)> NodeHandler) const; void ReverseTopoForEachNode(std::function NodeHandler) const; void ForEachEdge(std::function EdgeHandler) const; @@ -55,19 +59,36 @@ class Graph { const std::function&)>& ForEachNext, const std::function& Handler) const; + void TopoForEachNodeDynamic( + const std::list& starts, + const std::function&)>& ForEachInNode, + const std::function&)>& ForEachOutNode, + const std::function& Handler) const; + void TopoForEachNode( const std::list& starts, const std::function&)>& ForEachInNode, const std::function&)>& ForEachOutNode, const std::function& Handler) const; + void TopoForEachNode( + const std::function&)>& ForEachInNode, + const std::function&)>& ForEachOutNode, + const std::function& Handler) const; + + Maybe TopoForEachNodeDynamicWithErrorCaptured( + const std::list& starts, + const std::function&)>& ForEachInNode, + const std::function&)>& ForEachOutNode, + const std::function(NodeType*)>& Handler) const; + Maybe TopoForEachNodeWithErrorCaptured( const std::list& starts, const std::function&)>& ForEachInNode, const std::function&)>& ForEachOutNode, const std::function(NodeType*)>& Handler) const; - Maybe TopoForEachNodeFastMaybe( + Maybe TopoForEachNodeWithErrorCaptured( const std::function&)>& ForEachInNode, const std::function&)>& ForEachOutNode, const std::function(NodeType*)>& Handler) const; @@ -219,33 +240,33 @@ NodeType* Graph::SoleSinkNode() const { } template -void Graph::TopoForEachNode(std::function NodeHandler) const { - TopoForEachNode(source_nodes(), &NodeType::ForEachNodeOnInEdge, &NodeType::ForEachNodeOnOutEdge, - NodeHandler); +void Graph::TopoForEachNodeDynamic( + std::function NodeHandler) const { + TopoForEachNodeDynamic(source_nodes(), &NodeType::ForEachNodeOnInEdge, + &NodeType::ForEachNodeOnOutEdge, NodeHandler); } template -void Graph::TopoForEachNodeFast( - std::function NodeHandler) const { - CHECK_JUST(TopoForEachNodeFastMaybe(&NodeType::ForEachNodeOnInEdge, - &NodeType::ForEachNodeOnOutEdge, [&](NodeType* node) { - NodeHandler(node); - return Maybe::Ok(); - })); +void Graph::TopoForEachNode(std::function NodeHandler) const { + CHECK_JUST(TopoForEachNodeWithErrorCaptured(&NodeType::ForEachNodeOnInEdge, + &NodeType::ForEachNodeOnOutEdge, [&](NodeType* node) { + NodeHandler(node); + return Maybe::Ok(); + })); } template -Maybe Graph::TopoForEachNodeWithErrorCaptured( +Maybe Graph::TopoForEachNodeDynamicWithErrorCaptured( std::function(NodeType*)> NodeHandler) const { - return TopoForEachNodeWithErrorCaptured(source_nodes(), &NodeType::ForEachNodeOnInEdge, - &NodeType::ForEachNodeOnOutEdge, NodeHandler); + return TopoForEachNodeDynamicWithErrorCaptured(source_nodes(), &NodeType::ForEachNodeOnInEdge, + &NodeType::ForEachNodeOnOutEdge, NodeHandler); } template -Maybe Graph::TopoForEachNodeFastMaybe( +Maybe Graph::TopoForEachNodeWithErrorCaptured( std::function(NodeType*)> NodeHandler) const { - return TopoForEachNodeFastMaybe(&NodeType::ForEachNodeOnInEdge, &NodeType::ForEachNodeOnOutEdge, - NodeHandler); + return TopoForEachNodeWithErrorCaptured(&NodeType::ForEachNodeOnInEdge, + &NodeType::ForEachNodeOnOutEdge, NodeHandler); } template @@ -253,15 +274,14 @@ void Graph::SortedTopoForEachNode( std::function LessThan, std::function NodeHandler) const { ForEachNode([&](NodeType* node) { node->SortInOutEdges(LessThan); }); - TopoForEachNode(source_nodes(), &NodeType::ForEachNodeOnSortedInEdge, - &NodeType::ForEachNodeOnSortedOutEdge, NodeHandler); + TopoForEachNode(&NodeType::ForEachNodeOnSortedInEdge, &NodeType::ForEachNodeOnSortedOutEdge, + NodeHandler); } template void Graph::ReverseTopoForEachNode( std::function NodeHandler) const { - TopoForEachNode(sink_nodes(), &NodeType::ForEachNodeOnOutEdge, &NodeType::ForEachNodeOnInEdge, - NodeHandler); + TopoForEachNode(&NodeType::ForEachNodeOnOutEdge, &NodeType::ForEachNodeOnInEdge, NodeHandler); } template @@ -517,6 +537,19 @@ std::unique_ptr> Graph::FindFirstNontrivi return std::unique_ptr>(); } +template +void Graph::TopoForEachNodeDynamic( + const std::list& starts, + const std::function&)>& ForEachInNode, + const std::function&)>& ForEachOutNode, + const std::function& Handler) const { + CHECK_JUST(TopoForEachNodeDynamicWithErrorCaptured(starts, ForEachInNode, ForEachOutNode, + [&](NodeType* node) { + Handler(node); + return Maybe::Ok(); + })); +} + template void Graph::TopoForEachNode( const std::list& starts, @@ -531,7 +564,18 @@ void Graph::TopoForEachNode( } template -Maybe Graph::TopoForEachNodeWithErrorCaptured( +void Graph::TopoForEachNode( + const std::function&)>& ForEachInNode, + const std::function&)>& ForEachOutNode, + const std::function& Handler) const { + CHECK_JUST(TopoForEachNodeWithErrorCaptured(ForEachInNode, ForEachOutNode, [&](NodeType* node) { + Handler(node); + return Maybe::Ok(); + })); +} + +template +Maybe Graph::TopoForEachNodeDynamicWithErrorCaptured( const std::list& starts, const std::function&)>& ForEachInNode, const std::function&)>& ForEachOutNode, @@ -562,7 +606,40 @@ Maybe Graph::TopoForEachNodeWithErrorCaptured( } template -Maybe Graph::TopoForEachNodeFastMaybe( +Maybe Graph::TopoForEachNodeWithErrorCaptured( + const std::list& starts, + const std::function&)>& ForEachInNode, + const std::function&)>& ForEachOutNode, + const std::function(NodeType*)>& Handler) const { + HashMap counter_in; + std::queue queue; + for (NodeType* start : starts) { + queue.push(start); + counter_in[start] = 0; + ForEachInNode(start, [&](NodeType*) { LOG(FATAL) << "not a source"; }); + } + while (!queue.empty()) { + NodeType* cur_node = queue.front(); + queue.pop(); + JUST(Handler(cur_node)); + ForEachOutNode(cur_node, [&](NodeType* out) { + auto it = counter_in.find(out); + // Move the initialization here + if (it == counter_in.end()) { + int32_t count = 0; + ForEachInNode(out, [&](NodeType* out_in) { count++; }); + counter_in[out] = count; + it = counter_in.find(out); + } + it->second--; + if (it->second == 0) { queue.push(out); } + }); + } + return Maybe::Ok(); +} + +template +Maybe Graph::TopoForEachNodeWithErrorCaptured( const std::function&)>& ForEachInNode, const std::function&)>& ForEachOutNode, const std::function(NodeType*)>& Handler) const { @@ -595,7 +672,7 @@ void Graph::DfsTopoForEachNodeSortByDistanceToSink( HashMap node2distance_to_sink; { std::list nodes; - TopoForEachNode(starts, ForEachInNode, ForEachOutNode, + TopoForEachNode(ForEachInNode, ForEachOutNode, [&](NodeType* node) { nodes.emplace_back(node); }); std::list sinks; for (NodeType* node : nodes) { @@ -603,7 +680,7 @@ void Graph::DfsTopoForEachNodeSortByDistanceToSink( ForEachOutNode(node, [&](NodeType* out_node) { is_sink = false; }); if (is_sink) { sinks.emplace_back(node); } } - TopoForEachNode(sinks, ForEachOutNode, ForEachInNode, [&](NodeType* node) { + TopoForEachNode(ForEachOutNode, ForEachInNode, [&](NodeType* node) { int64_t distance_to_sink = -1; ForEachOutNode(node, [&](NodeType* out_node) { distance_to_sink = std::max(distance_to_sink, node2distance_to_sink[out_node]); @@ -698,12 +775,12 @@ Graph::MakePredicatorIsReachable( std::shared_ptr id2ancestor(new Id2Ancestor(node_num())); int64_t id = 0; node2id->reserve(node_num()); - TopoForEachNode(starts, ForEachInNode, ForEachOutNode, [&](NodeType* node) { + TopoForEachNode(ForEachInNode, ForEachOutNode, [&](NodeType* node) { node2id->emplace(node, id); id2ancestor->at(id).Resize(node_num()); id += 1; }); - TopoForEachNode(starts, ForEachInNode, ForEachOutNode, [&](NodeType* node) { + TopoForEachNode(ForEachInNode, ForEachOutNode, [&](NodeType* node) { const int64_t node_id = node2id->at(node); auto& ancestor_bitset_vec = id2ancestor->at(node_id); ForEachInNode(node, [&](NodeType* in_node) { diff --git a/oneflow/core/graph/op_graph.cpp b/oneflow/core/graph/op_graph.cpp index 09df4962fcb..82e8b20088d 100644 --- a/oneflow/core/graph/op_graph.cpp +++ b/oneflow/core/graph/op_graph.cpp @@ -466,8 +466,7 @@ void OpGraph::TopoForEachNodeWithCtrlEdge(const std::function& No const std::function& Handler) { ForEachDataAndCtrlOutNode(node, Handler); }; - TopoForEachNode(DataOrCtrlSourceNodes(), OpGraphForEachInDataAndCtrlNode, - OpGraphForEachOutDataAndCtrlNode, NodeHandler); + TopoForEachNode(OpGraphForEachInDataAndCtrlNode, OpGraphForEachOutDataAndCtrlNode, NodeHandler); } std::function diff --git a/oneflow/core/graph/straighten_nodes.cpp b/oneflow/core/graph/straighten_nodes.cpp index 404fcb7dd11..ca5cc163fe6 100644 --- a/oneflow/core/graph/straighten_nodes.cpp +++ b/oneflow/core/graph/straighten_nodes.cpp @@ -233,7 +233,7 @@ void StraightenNodes(TaskGraph* task_graph, std::vector* ordered_task task_type2machine_id2node_id2topo_structs; std::map min_node_id2topo_struct; int32_t previous_min_layer = 0; - task_graph->TopoForEachNodeFast([&](TaskNode* node) { + task_graph->TopoForEachNode([&](TaskNode* node) { auto& topo_struct = task_node2topo_struct[node]; topo_struct.node = node; if (node->in_edges().empty()) {