Skip to content

Commit

Permalink
Feat straighten task nodes (#8347)
Browse files Browse the repository at this point in the history
* Add a fast topological traversal

* Add an initial implementation of straighten nodes

* Add the straighten nodes algorithm

* Change algorithm structure

* Remove some debug information

* Finalize the straighten algorithm after
deciding the parameters by experiments

* Notify the usage of straighten algorithm

* Of format

* Update oneflow/core/graph/straighten_nodes.cpp

Of format

Co-authored-by: daquexian <daquexian566@gmail.com>

* Of format

* Stop using visual string before we find a better key

* Remove magic numbers and Of format

* Remove starts

* Of format

* Fix a bug of using GetMaxVal<int32_t>() as an
initial number for comparing

* Refactor add straighten algo interface (#8435)

* feat(*): export straighten nodes algorithm interface

* export documentation

* Update python/oneflow/nn/graph/graph_config.py

Co-authored-by: Yipeng Li <jamesonli1313@gmail.com>

Co-authored-by: Yipeng Li <jamesonli1313@gmail.com>

* Use TopoForEachNodeFast as default. (#8436)

* Use TopoForEachNodeFast as default.
Rename the original one as TopoForEachNodeDynamic

* Speed up TopoForEachNodeFast when traversing a subgraph

* Rename the switch and code clean up

* Hide the class TopoStruct

* Hide all the other functions

* Grammar

* Of format

Co-authored-by: daquexian <daquexian566@gmail.com>
Co-authored-by: Yinggang Wang <wyg19970408@gmail.com>
  • Loading branch information
3 people authored Jun 17, 2022
1 parent f6c3cb6 commit f7532fd
Show file tree
Hide file tree
Showing 10 changed files with 674 additions and 18 deletions.
1 change: 1 addition & 0 deletions docs/source/graph.rst
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ Base class for running neural networks in Static Graph Mode.
allow_fuse_cast_scale,
set_gradient_accumulation_steps,
enable_cudnn_conv_heuristic_search_algo,
disable_straighten_algorithm,
:member-order: bysource


Expand Down
150 changes: 138 additions & 12 deletions oneflow/core/graph/graph.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,13 @@ class Graph {
// For Each
void ForEachNode(std::function<void(NodeType*)> NodeHandler) const;
Maybe<void> MaybeForEachNode(std::function<Maybe<void>(NodeType*)> NodeHandler) const;
// In case you want to change the topological structure during the node handler.
// For example, adding/deleting a node or an edge.
// Still, it might have bugs even if you use TopoForEachNodeDynamic.
void TopoForEachNodeDynamic(std::function<void(NodeType*)> NodeHandler) const;
void TopoForEachNode(std::function<void(NodeType*)> NodeHandler) const;
Maybe<void> TopoForEachNodeDynamicWithErrorCaptured(
std::function<Maybe<void>(NodeType*)> NodeHandler) const;
Maybe<void> TopoForEachNodeWithErrorCaptured(
std::function<Maybe<void>(NodeType*)> NodeHandler) const;
void ReverseTopoForEachNode(std::function<void(NodeType*)> NodeHandler) const;
Expand All @@ -53,18 +59,40 @@ class Graph {
const std::function<void(NodeType*, const std::function<void(NodeType*)>&)>& ForEachNext,
const std::function<void(NodeType*)>& Handler) const;

void TopoForEachNodeDynamic(
const std::list<NodeType*>& starts,
const std::function<void(NodeType*, const std::function<void(NodeType*)>&)>& ForEachInNode,
const std::function<void(NodeType*, const std::function<void(NodeType*)>&)>& ForEachOutNode,
const std::function<void(NodeType*)>& Handler) const;

void TopoForEachNode(
const std::list<NodeType*>& starts,
const std::function<void(NodeType*, const std::function<void(NodeType*)>&)>& ForEachInNode,
const std::function<void(NodeType*, const std::function<void(NodeType*)>&)>& ForEachOutNode,
const std::function<void(NodeType*)>& Handler) const;

void TopoForEachNode(
const std::function<void(NodeType*, const std::function<void(NodeType*)>&)>& ForEachInNode,
const std::function<void(NodeType*, const std::function<void(NodeType*)>&)>& ForEachOutNode,
const std::function<void(NodeType*)>& Handler) const;

Maybe<void> TopoForEachNodeDynamicWithErrorCaptured(
const std::list<NodeType*>& starts,
const std::function<void(NodeType*, const std::function<void(NodeType*)>&)>& ForEachInNode,
const std::function<void(NodeType*, const std::function<void(NodeType*)>&)>& ForEachOutNode,
const std::function<Maybe<void>(NodeType*)>& Handler) const;

Maybe<void> TopoForEachNodeWithErrorCaptured(
const std::list<NodeType*>& starts,
const std::function<void(NodeType*, const std::function<void(NodeType*)>&)>& ForEachInNode,
const std::function<void(NodeType*, const std::function<void(NodeType*)>&)>& ForEachOutNode,
const std::function<Maybe<void>(NodeType*)>& Handler) const;

Maybe<void> TopoForEachNodeWithErrorCaptured(
const std::function<void(NodeType*, const std::function<void(NodeType*)>&)>& ForEachInNode,
const std::function<void(NodeType*, const std::function<void(NodeType*)>&)>& ForEachOutNode,
const std::function<Maybe<void>(NodeType*)>& Handler) const;

void DfsTopoForEachNode(
const std::list<NodeType*>& starts,
const std::function<void(NodeType*, const std::function<void(NodeType*)>&)>& ForEachInNode,
Expand Down Expand Up @@ -211,16 +239,33 @@ NodeType* Graph<NodeType, EdgeType>::SoleSinkNode() const {
return sink_nodes_list.front();
}

// Convenience overload: forwards to the explicit-starts TopoForEachNodeDynamic overload,
// using source_nodes() as the start set and NodeType's in-edge / out-edge neighbour
// iterators. NodeHandler is passed through unchanged.
template<typename NodeType, typename EdgeType>
void Graph<NodeType, EdgeType>::TopoForEachNodeDynamic(
    std::function<void(NodeType*)> NodeHandler) const {
  const auto& start_nodes = source_nodes();
  TopoForEachNodeDynamic(start_nodes, &NodeType::ForEachNodeOnInEdge,
                         &NodeType::ForEachNodeOnOutEdge, NodeHandler);
}

template<typename NodeType, typename EdgeType>
void Graph<NodeType, EdgeType>::TopoForEachNode(std::function<void(NodeType*)> NodeHandler) const {
TopoForEachNode(source_nodes(), &NodeType::ForEachNodeOnInEdge, &NodeType::ForEachNodeOnOutEdge,
NodeHandler);
CHECK_JUST(TopoForEachNodeWithErrorCaptured(&NodeType::ForEachNodeOnInEdge,
&NodeType::ForEachNodeOnOutEdge, [&](NodeType* node) {
NodeHandler(node);
return Maybe<void>::Ok();
}));
}

// Convenience overload: forwards to the explicit-starts
// TopoForEachNodeDynamicWithErrorCaptured overload, using source_nodes() as the start set
// and NodeType's in-edge / out-edge neighbour iterators. Returns whatever that overload
// returns.
template<typename NodeType, typename EdgeType>
Maybe<void> Graph<NodeType, EdgeType>::TopoForEachNodeDynamicWithErrorCaptured(
    std::function<Maybe<void>(NodeType*)> NodeHandler) const {
  const auto& start_nodes = source_nodes();
  return TopoForEachNodeDynamicWithErrorCaptured(start_nodes, &NodeType::ForEachNodeOnInEdge,
                                                 &NodeType::ForEachNodeOnOutEdge, NodeHandler);
}

template<typename NodeType, typename EdgeType>
Maybe<void> Graph<NodeType, EdgeType>::TopoForEachNodeWithErrorCaptured(
std::function<Maybe<void>(NodeType*)> NodeHandler) const {
return TopoForEachNodeWithErrorCaptured(source_nodes(), &NodeType::ForEachNodeOnInEdge,
return TopoForEachNodeWithErrorCaptured(&NodeType::ForEachNodeOnInEdge,
&NodeType::ForEachNodeOnOutEdge, NodeHandler);
}

Expand All @@ -229,15 +274,14 @@ void Graph<NodeType, EdgeType>::SortedTopoForEachNode(
std::function<bool(const EdgeType* lhs, const EdgeType* rhs)> LessThan,
std::function<void(NodeType*)> NodeHandler) const {
ForEachNode([&](NodeType* node) { node->SortInOutEdges(LessThan); });
TopoForEachNode(source_nodes(), &NodeType::ForEachNodeOnSortedInEdge,
&NodeType::ForEachNodeOnSortedOutEdge, NodeHandler);
TopoForEachNode(&NodeType::ForEachNodeOnSortedInEdge, &NodeType::ForEachNodeOnSortedOutEdge,
NodeHandler);
}

template<typename NodeType, typename EdgeType>
void Graph<NodeType, EdgeType>::ReverseTopoForEachNode(
std::function<void(NodeType*)> NodeHandler) const {
TopoForEachNode(sink_nodes(), &NodeType::ForEachNodeOnOutEdge, &NodeType::ForEachNodeOnInEdge,
NodeHandler);
TopoForEachNode(&NodeType::ForEachNodeOnOutEdge, &NodeType::ForEachNodeOnInEdge, NodeHandler);
}

template<typename NodeType, typename EdgeType>
Expand Down Expand Up @@ -493,6 +537,19 @@ std::unique_ptr<HashSet<NodeType*>> Graph<NodeType, EdgeType>::FindFirstNontrivi
return std::unique_ptr<HashSet<NodeType*>>();
}

// Void-returning wrapper around TopoForEachNodeDynamicWithErrorCaptured.
// Handler is adapted into a Maybe<void>-returning callback that always reports Ok, and
// the Maybe<void> produced by the traversal is handed to CHECK_JUST.
template<typename NodeType, typename EdgeType>
void Graph<NodeType, EdgeType>::TopoForEachNodeDynamic(
    const std::list<NodeType*>& starts,
    const std::function<void(NodeType*, const std::function<void(NodeType*)>&)>& ForEachInNode,
    const std::function<void(NodeType*, const std::function<void(NodeType*)>&)>& ForEachOutNode,
    const std::function<void(NodeType*)>& Handler) const {
  const auto always_ok_handler = [&](NodeType* visited) {
    Handler(visited);
    return Maybe<void>::Ok();
  };
  CHECK_JUST(TopoForEachNodeDynamicWithErrorCaptured(starts, ForEachInNode, ForEachOutNode,
                                                     always_ok_handler));
}

template<typename NodeType, typename EdgeType>
void Graph<NodeType, EdgeType>::TopoForEachNode(
const std::list<NodeType*>& starts,
Expand All @@ -507,7 +564,18 @@ void Graph<NodeType, EdgeType>::TopoForEachNode(
}

template<typename NodeType, typename EdgeType>
Maybe<void> Graph<NodeType, EdgeType>::TopoForEachNodeWithErrorCaptured(
void Graph<NodeType, EdgeType>::TopoForEachNode(
const std::function<void(NodeType*, const std::function<void(NodeType*)>&)>& ForEachInNode,
const std::function<void(NodeType*, const std::function<void(NodeType*)>&)>& ForEachOutNode,
const std::function<void(NodeType*)>& Handler) const {
CHECK_JUST(TopoForEachNodeWithErrorCaptured(ForEachInNode, ForEachOutNode, [&](NodeType* node) {
Handler(node);
return Maybe<void>::Ok();
}));
}

template<typename NodeType, typename EdgeType>
Maybe<void> Graph<NodeType, EdgeType>::TopoForEachNodeDynamicWithErrorCaptured(
const std::list<NodeType*>& starts,
const std::function<void(NodeType*, const std::function<void(NodeType*)>&)>& ForEachInNode,
const std::function<void(NodeType*, const std::function<void(NodeType*)>&)>& ForEachOutNode,
Expand Down Expand Up @@ -537,6 +605,64 @@ Maybe<void> Graph<NodeType, EdgeType>::TopoForEachNodeWithErrorCaptured(
return Maybe<void>::Ok();
}

// Topological traversal from an explicit set of start nodes, using in-degree counting
// and a FIFO queue.
//
//   starts          nodes to begin from; each must have no in-nodes under ForEachInNode,
//                   otherwise LOG(FATAL) << "not a source" fires.
//   ForEachInNode   enumerates the predecessors of a node.
//   ForEachOutNode  enumerates the successors of a node.
//   Handler         invoked once per dequeued node; a failure is propagated via JUST.
//
// A successor's in-degree is computed lazily, the first time it is reached through an
// out-edge, so only nodes reached from `starts` are ever counted.
template<typename NodeType, typename EdgeType>
Maybe<void> Graph<NodeType, EdgeType>::TopoForEachNodeWithErrorCaptured(
    const std::list<NodeType*>& starts,
    const std::function<void(NodeType*, const std::function<void(NodeType*)>&)>& ForEachInNode,
    const std::function<void(NodeType*, const std::function<void(NodeType*)>&)>& ForEachOutNode,
    const std::function<Maybe<void>(NodeType*)>& Handler) const {
  // For every discovered node: how many of its in-nodes have not been handled yet.
  HashMap<NodeType*, int32_t> pending_inputs;
  std::queue<NodeType*> ready;
  for (NodeType* start : starts) {
    ready.push(start);
    pending_inputs[start] = 0;
    ForEachInNode(start, [&](NodeType*) { LOG(FATAL) << "not a source"; });
  }
  while (!ready.empty()) {
    NodeType* current = ready.front();
    ready.pop();
    JUST(Handler(current));
    ForEachOutNode(current, [&](NodeType* successor) {
      auto entry = pending_inputs.find(successor);
      if (entry == pending_inputs.end()) {
        // First visit: count the successor's in-nodes and record the total.
        int32_t in_degree = 0;
        ForEachInNode(successor, [&](NodeType*) { ++in_degree; });
        entry = pending_inputs.emplace(successor, in_degree).first;
      }
      // One more predecessor has been handled; enqueue once none remain.
      if (--entry->second == 0) { ready.push(successor); }
    });
  }
  return Maybe<void>::Ok();
}

// Topological traversal over the nodes enumerated by ForEachNode, with no explicit start
// list.
//
//   ForEachInNode   enumerates the predecessors of a node.
//   ForEachOutNode  enumerates the successors of a node.
//   Handler         invoked once per dequeued node; a failure is propagated via JUST.
//
// The in-degree of every node is computed up front, and nodes whose in-degree is zero seed
// the queue. A node is enqueued when its remaining in-degree drops to zero.
template<typename NodeType, typename EdgeType>
Maybe<void> Graph<NodeType, EdgeType>::TopoForEachNodeWithErrorCaptured(
    const std::function<void(NodeType*, const std::function<void(NodeType*)>&)>& ForEachInNode,
    const std::function<void(NodeType*, const std::function<void(NodeType*)>&)>& ForEachOutNode,
    const std::function<Maybe<void>(NodeType*)>& Handler) const {
  // For every node: how many of its in-nodes have not been handled yet.
  HashMap<NodeType*, int32_t> pending_inputs;
  std::queue<NodeType*> ready;
  ForEachNode([&](NodeType* node) {
    int32_t in_degree = 0;
    ForEachInNode(node, [&](NodeType*) { ++in_degree; });
    pending_inputs[node] = in_degree;
    if (in_degree == 0) { ready.push(node); }
  });
  while (!ready.empty()) {
    NodeType* current = ready.front();
    ready.pop();
    JUST(Handler(current));
    ForEachOutNode(current, [&](NodeType* successor) {
      // Single lookup; operator[] keeps the original insert-if-absent behaviour.
      int32_t& remaining = pending_inputs[successor];
      --remaining;
      if (remaining == 0) { ready.push(successor); }
    });
  }
  return Maybe<void>::Ok();
}

template<typename NodeType, typename EdgeType>
void Graph<NodeType, EdgeType>::DfsTopoForEachNodeSortByDistanceToSink(
const std::list<NodeType*>& starts,
Expand All @@ -546,15 +672,15 @@ void Graph<NodeType, EdgeType>::DfsTopoForEachNodeSortByDistanceToSink(
HashMap<NodeType*, int64_t> node2distance_to_sink;
{
std::list<NodeType*> nodes;
TopoForEachNode(starts, ForEachInNode, ForEachOutNode,
TopoForEachNode(ForEachInNode, ForEachOutNode,
[&](NodeType* node) { nodes.emplace_back(node); });
std::list<NodeType*> sinks;
for (NodeType* node : nodes) {
bool is_sink = true;
ForEachOutNode(node, [&](NodeType* out_node) { is_sink = false; });
if (is_sink) { sinks.emplace_back(node); }
}
TopoForEachNode(sinks, ForEachOutNode, ForEachInNode, [&](NodeType* node) {
TopoForEachNode(ForEachOutNode, ForEachInNode, [&](NodeType* node) {
int64_t distance_to_sink = -1;
ForEachOutNode(node, [&](NodeType* out_node) {
distance_to_sink = std::max(distance_to_sink, node2distance_to_sink[out_node]);
Expand Down Expand Up @@ -649,12 +775,12 @@ Graph<NodeType, EdgeType>::MakePredicatorIsReachable(
std::shared_ptr<Id2Ancestor> id2ancestor(new Id2Ancestor(node_num()));
int64_t id = 0;
node2id->reserve(node_num());
TopoForEachNode(starts, ForEachInNode, ForEachOutNode, [&](NodeType* node) {
TopoForEachNode(ForEachInNode, ForEachOutNode, [&](NodeType* node) {
node2id->emplace(node, id);
id2ancestor->at(id).Resize(node_num());
id += 1;
});
TopoForEachNode(starts, ForEachInNode, ForEachOutNode, [&](NodeType* node) {
TopoForEachNode(ForEachInNode, ForEachOutNode, [&](NodeType* node) {
const int64_t node_id = node2id->at(node);
auto& ancestor_bitset_vec = id2ancestor->at(node_id);
ForEachInNode(node, [&](NodeType* in_node) {
Expand Down
3 changes: 1 addition & 2 deletions oneflow/core/graph/op_graph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -472,8 +472,7 @@ void OpGraph::TopoForEachNodeWithCtrlEdge(const std::function<void(OpNode*)>& No
const std::function<void(OpNode*)>& Handler) {
ForEachDataAndCtrlOutNode(node, Handler);
};
TopoForEachNode(DataOrCtrlSourceNodes(), OpGraphForEachInDataAndCtrlNode,
OpGraphForEachOutDataAndCtrlNode, NodeHandler);
TopoForEachNode(OpGraphForEachInDataAndCtrlNode, OpGraphForEachOutDataAndCtrlNode, NodeHandler);
}

std::function<bool(const std::string&, const std::string&)>
Expand Down
Loading

0 comments on commit f7532fd

Please sign in to comment.