Skip to content

Commit

Permalink
Use TopoForEachNodeFast as default. (#8436)
Browse files Browse the repository at this point in the history
* Use TopoForEachNodeFast as default.
Rename the original one as TopoForEachNodeDynamic

* Speed up TopoForEachNodeFast when traversing a subgraph
  • Loading branch information
Yipeng1994 authored Jun 17, 2022
1 parent 60e7800 commit 2f0b93b
Show file tree
Hide file tree
Showing 3 changed files with 108 additions and 32 deletions.
135 changes: 106 additions & 29 deletions oneflow/core/graph/graph.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,11 +34,15 @@ class Graph {
// For Each
void ForEachNode(std::function<void(NodeType*)> NodeHandler) const;
Maybe<void> MaybeForEachNode(std::function<Maybe<void>(NodeType*)> NodeHandler) const;
// In case you want to change the topological structure during the node handler.
// For example, adding/deleting a node or an edge.
// Still, it might have bugs even if you use TopoForEachNodeDynamic.
void TopoForEachNodeDynamic(std::function<void(NodeType*)> NodeHandler) const;
void TopoForEachNode(std::function<void(NodeType*)> NodeHandler) const;
void TopoForEachNodeFast(std::function<void(NodeType*)> NodeHandler) const;
Maybe<void> TopoForEachNodeDynamicWithErrorCaptured(
std::function<Maybe<void>(NodeType*)> NodeHandler) const;
Maybe<void> TopoForEachNodeWithErrorCaptured(
std::function<Maybe<void>(NodeType*)> NodeHandler) const;
Maybe<void> TopoForEachNodeFastMaybe(std::function<Maybe<void>(NodeType*)> NodeHandler) const;
void ReverseTopoForEachNode(std::function<void(NodeType*)> NodeHandler) const;
void ForEachEdge(std::function<void(EdgeType*)> EdgeHandler) const;

Expand All @@ -55,19 +59,36 @@ class Graph {
const std::function<void(NodeType*, const std::function<void(NodeType*)>&)>& ForEachNext,
const std::function<void(NodeType*)>& Handler) const;

void TopoForEachNodeDynamic(
const std::list<NodeType*>& starts,
const std::function<void(NodeType*, const std::function<void(NodeType*)>&)>& ForEachInNode,
const std::function<void(NodeType*, const std::function<void(NodeType*)>&)>& ForEachOutNode,
const std::function<void(NodeType*)>& Handler) const;

void TopoForEachNode(
const std::list<NodeType*>& starts,
const std::function<void(NodeType*, const std::function<void(NodeType*)>&)>& ForEachInNode,
const std::function<void(NodeType*, const std::function<void(NodeType*)>&)>& ForEachOutNode,
const std::function<void(NodeType*)>& Handler) const;

void TopoForEachNode(
const std::function<void(NodeType*, const std::function<void(NodeType*)>&)>& ForEachInNode,
const std::function<void(NodeType*, const std::function<void(NodeType*)>&)>& ForEachOutNode,
const std::function<void(NodeType*)>& Handler) const;

Maybe<void> TopoForEachNodeDynamicWithErrorCaptured(
const std::list<NodeType*>& starts,
const std::function<void(NodeType*, const std::function<void(NodeType*)>&)>& ForEachInNode,
const std::function<void(NodeType*, const std::function<void(NodeType*)>&)>& ForEachOutNode,
const std::function<Maybe<void>(NodeType*)>& Handler) const;

Maybe<void> TopoForEachNodeWithErrorCaptured(
const std::list<NodeType*>& starts,
const std::function<void(NodeType*, const std::function<void(NodeType*)>&)>& ForEachInNode,
const std::function<void(NodeType*, const std::function<void(NodeType*)>&)>& ForEachOutNode,
const std::function<Maybe<void>(NodeType*)>& Handler) const;

Maybe<void> TopoForEachNodeFastMaybe(
Maybe<void> TopoForEachNodeWithErrorCaptured(
const std::function<void(NodeType*, const std::function<void(NodeType*)>&)>& ForEachInNode,
const std::function<void(NodeType*, const std::function<void(NodeType*)>&)>& ForEachOutNode,
const std::function<Maybe<void>(NodeType*)>& Handler) const;
Expand Down Expand Up @@ -219,49 +240,48 @@ NodeType* Graph<NodeType, EdgeType>::SoleSinkNode() const {
}

template<typename NodeType, typename EdgeType>
void Graph<NodeType, EdgeType>::TopoForEachNode(std::function<void(NodeType*)> NodeHandler) const {
TopoForEachNode(source_nodes(), &NodeType::ForEachNodeOnInEdge, &NodeType::ForEachNodeOnOutEdge,
NodeHandler);
void Graph<NodeType, EdgeType>::TopoForEachNodeDynamic(
std::function<void(NodeType*)> NodeHandler) const {
TopoForEachNodeDynamic(source_nodes(), &NodeType::ForEachNodeOnInEdge,
&NodeType::ForEachNodeOnOutEdge, NodeHandler);
}

template<typename NodeType, typename EdgeType>
void Graph<NodeType, EdgeType>::TopoForEachNodeFast(
std::function<void(NodeType*)> NodeHandler) const {
CHECK_JUST(TopoForEachNodeFastMaybe(&NodeType::ForEachNodeOnInEdge,
&NodeType::ForEachNodeOnOutEdge, [&](NodeType* node) {
NodeHandler(node);
return Maybe<void>::Ok();
}));
void Graph<NodeType, EdgeType>::TopoForEachNode(std::function<void(NodeType*)> NodeHandler) const {
CHECK_JUST(TopoForEachNodeWithErrorCaptured(&NodeType::ForEachNodeOnInEdge,
&NodeType::ForEachNodeOnOutEdge, [&](NodeType* node) {
NodeHandler(node);
return Maybe<void>::Ok();
}));
}

template<typename NodeType, typename EdgeType>
Maybe<void> Graph<NodeType, EdgeType>::TopoForEachNodeWithErrorCaptured(
Maybe<void> Graph<NodeType, EdgeType>::TopoForEachNodeDynamicWithErrorCaptured(
std::function<Maybe<void>(NodeType*)> NodeHandler) const {
return TopoForEachNodeWithErrorCaptured(source_nodes(), &NodeType::ForEachNodeOnInEdge,
&NodeType::ForEachNodeOnOutEdge, NodeHandler);
return TopoForEachNodeDynamicWithErrorCaptured(source_nodes(), &NodeType::ForEachNodeOnInEdge,
&NodeType::ForEachNodeOnOutEdge, NodeHandler);
}

template<typename NodeType, typename EdgeType>
Maybe<void> Graph<NodeType, EdgeType>::TopoForEachNodeFastMaybe(
Maybe<void> Graph<NodeType, EdgeType>::TopoForEachNodeWithErrorCaptured(
std::function<Maybe<void>(NodeType*)> NodeHandler) const {
return TopoForEachNodeFastMaybe(&NodeType::ForEachNodeOnInEdge, &NodeType::ForEachNodeOnOutEdge,
NodeHandler);
return TopoForEachNodeWithErrorCaptured(&NodeType::ForEachNodeOnInEdge,
&NodeType::ForEachNodeOnOutEdge, NodeHandler);
}

template<typename NodeType, typename EdgeType>
void Graph<NodeType, EdgeType>::SortedTopoForEachNode(
std::function<bool(const EdgeType* lhs, const EdgeType* rhs)> LessThan,
std::function<void(NodeType*)> NodeHandler) const {
ForEachNode([&](NodeType* node) { node->SortInOutEdges(LessThan); });
TopoForEachNode(source_nodes(), &NodeType::ForEachNodeOnSortedInEdge,
&NodeType::ForEachNodeOnSortedOutEdge, NodeHandler);
TopoForEachNode(&NodeType::ForEachNodeOnSortedInEdge, &NodeType::ForEachNodeOnSortedOutEdge,
NodeHandler);
}

template<typename NodeType, typename EdgeType>
void Graph<NodeType, EdgeType>::ReverseTopoForEachNode(
std::function<void(NodeType*)> NodeHandler) const {
TopoForEachNode(sink_nodes(), &NodeType::ForEachNodeOnOutEdge, &NodeType::ForEachNodeOnInEdge,
NodeHandler);
TopoForEachNode(&NodeType::ForEachNodeOnOutEdge, &NodeType::ForEachNodeOnInEdge, NodeHandler);
}

template<typename NodeType, typename EdgeType>
Expand Down Expand Up @@ -517,6 +537,19 @@ std::unique_ptr<HashSet<NodeType*>> Graph<NodeType, EdgeType>::FindFirstNontrivi
return std::unique_ptr<HashSet<NodeType*>>();
}

template<typename NodeType, typename EdgeType>
void Graph<NodeType, EdgeType>::TopoForEachNodeDynamic(
const std::list<NodeType*>& starts,
const std::function<void(NodeType*, const std::function<void(NodeType*)>&)>& ForEachInNode,
const std::function<void(NodeType*, const std::function<void(NodeType*)>&)>& ForEachOutNode,
const std::function<void(NodeType*)>& Handler) const {
CHECK_JUST(TopoForEachNodeDynamicWithErrorCaptured(starts, ForEachInNode, ForEachOutNode,
[&](NodeType* node) {
Handler(node);
return Maybe<void>::Ok();
}));
}

template<typename NodeType, typename EdgeType>
void Graph<NodeType, EdgeType>::TopoForEachNode(
const std::list<NodeType*>& starts,
Expand All @@ -531,7 +564,18 @@ void Graph<NodeType, EdgeType>::TopoForEachNode(
}

template<typename NodeType, typename EdgeType>
Maybe<void> Graph<NodeType, EdgeType>::TopoForEachNodeWithErrorCaptured(
void Graph<NodeType, EdgeType>::TopoForEachNode(
const std::function<void(NodeType*, const std::function<void(NodeType*)>&)>& ForEachInNode,
const std::function<void(NodeType*, const std::function<void(NodeType*)>&)>& ForEachOutNode,
const std::function<void(NodeType*)>& Handler) const {
CHECK_JUST(TopoForEachNodeWithErrorCaptured(ForEachInNode, ForEachOutNode, [&](NodeType* node) {
Handler(node);
return Maybe<void>::Ok();
}));
}

template<typename NodeType, typename EdgeType>
Maybe<void> Graph<NodeType, EdgeType>::TopoForEachNodeDynamicWithErrorCaptured(
const std::list<NodeType*>& starts,
const std::function<void(NodeType*, const std::function<void(NodeType*)>&)>& ForEachInNode,
const std::function<void(NodeType*, const std::function<void(NodeType*)>&)>& ForEachOutNode,
Expand Down Expand Up @@ -562,7 +606,40 @@ Maybe<void> Graph<NodeType, EdgeType>::TopoForEachNodeWithErrorCaptured(
}

template<typename NodeType, typename EdgeType>
Maybe<void> Graph<NodeType, EdgeType>::TopoForEachNodeFastMaybe(
Maybe<void> Graph<NodeType, EdgeType>::TopoForEachNodeWithErrorCaptured(
const std::list<NodeType*>& starts,
const std::function<void(NodeType*, const std::function<void(NodeType*)>&)>& ForEachInNode,
const std::function<void(NodeType*, const std::function<void(NodeType*)>&)>& ForEachOutNode,
const std::function<Maybe<void>(NodeType*)>& Handler) const {
HashMap<NodeType*, int32_t> counter_in;
std::queue<NodeType*> queue;
for (NodeType* start : starts) {
queue.push(start);
counter_in[start] = 0;
ForEachInNode(start, [&](NodeType*) { LOG(FATAL) << "not a source"; });
}
while (!queue.empty()) {
NodeType* cur_node = queue.front();
queue.pop();
JUST(Handler(cur_node));
ForEachOutNode(cur_node, [&](NodeType* out) {
auto it = counter_in.find(out);
// Move the initialization here
if (it == counter_in.end()) {
int32_t count = 0;
ForEachInNode(out, [&](NodeType* out_in) { count++; });
counter_in[out] = count;
it = counter_in.find(out);
}
it->second--;
if (it->second == 0) { queue.push(out); }
});
}
return Maybe<void>::Ok();
}

template<typename NodeType, typename EdgeType>
Maybe<void> Graph<NodeType, EdgeType>::TopoForEachNodeWithErrorCaptured(
const std::function<void(NodeType*, const std::function<void(NodeType*)>&)>& ForEachInNode,
const std::function<void(NodeType*, const std::function<void(NodeType*)>&)>& ForEachOutNode,
const std::function<Maybe<void>(NodeType*)>& Handler) const {
Expand Down Expand Up @@ -595,15 +672,15 @@ void Graph<NodeType, EdgeType>::DfsTopoForEachNodeSortByDistanceToSink(
HashMap<NodeType*, int64_t> node2distance_to_sink;
{
std::list<NodeType*> nodes;
TopoForEachNode(starts, ForEachInNode, ForEachOutNode,
TopoForEachNode(ForEachInNode, ForEachOutNode,
[&](NodeType* node) { nodes.emplace_back(node); });
std::list<NodeType*> sinks;
for (NodeType* node : nodes) {
bool is_sink = true;
ForEachOutNode(node, [&](NodeType* out_node) { is_sink = false; });
if (is_sink) { sinks.emplace_back(node); }
}
TopoForEachNode(sinks, ForEachOutNode, ForEachInNode, [&](NodeType* node) {
TopoForEachNode(ForEachOutNode, ForEachInNode, [&](NodeType* node) {
int64_t distance_to_sink = -1;
ForEachOutNode(node, [&](NodeType* out_node) {
distance_to_sink = std::max(distance_to_sink, node2distance_to_sink[out_node]);
Expand Down Expand Up @@ -698,12 +775,12 @@ Graph<NodeType, EdgeType>::MakePredicatorIsReachable(
std::shared_ptr<Id2Ancestor> id2ancestor(new Id2Ancestor(node_num()));
int64_t id = 0;
node2id->reserve(node_num());
TopoForEachNode(starts, ForEachInNode, ForEachOutNode, [&](NodeType* node) {
TopoForEachNode(ForEachInNode, ForEachOutNode, [&](NodeType* node) {
node2id->emplace(node, id);
id2ancestor->at(id).Resize(node_num());
id += 1;
});
TopoForEachNode(starts, ForEachInNode, ForEachOutNode, [&](NodeType* node) {
TopoForEachNode(ForEachInNode, ForEachOutNode, [&](NodeType* node) {
const int64_t node_id = node2id->at(node);
auto& ancestor_bitset_vec = id2ancestor->at(node_id);
ForEachInNode(node, [&](NodeType* in_node) {
Expand Down
3 changes: 1 addition & 2 deletions oneflow/core/graph/op_graph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -466,8 +466,7 @@ void OpGraph::TopoForEachNodeWithCtrlEdge(const std::function<void(OpNode*)>& No
const std::function<void(OpNode*)>& Handler) {
ForEachDataAndCtrlOutNode(node, Handler);
};
TopoForEachNode(DataOrCtrlSourceNodes(), OpGraphForEachInDataAndCtrlNode,
OpGraphForEachOutDataAndCtrlNode, NodeHandler);
TopoForEachNode(OpGraphForEachInDataAndCtrlNode, OpGraphForEachOutDataAndCtrlNode, NodeHandler);
}

std::function<bool(const std::string&, const std::string&)>
Expand Down
2 changes: 1 addition & 1 deletion oneflow/core/graph/straighten_nodes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,7 @@ void StraightenNodes(TaskGraph* task_graph, std::vector<TaskNode*>* ordered_task
task_type2machine_id2node_id2topo_structs;
std::map<int32_t, TopoStruct*> min_node_id2topo_struct;
int32_t previous_min_layer = 0;
task_graph->TopoForEachNodeFast([&](TaskNode* node) {
task_graph->TopoForEachNode([&](TaskNode* node) {
auto& topo_struct = task_node2topo_struct[node];
topo_struct.node = node;
if (node->in_edges().empty()) {
Expand Down

0 comments on commit 2f0b93b

Please sign in to comment.