diff --git a/docs/source/graph.rst b/docs/source/graph.rst index 270e5a01cf0..b51c38d5807 100644 --- a/docs/source/graph.rst +++ b/docs/source/graph.rst @@ -26,6 +26,7 @@ Base class for running neural networks in Static Graph Mode. allow_fuse_cast_scale, set_gradient_accumulation_steps, enable_cudnn_conv_heuristic_search_algo, + disable_straighten_algorithm, :member-order: bysource diff --git a/oneflow/core/graph/graph.h b/oneflow/core/graph/graph.h index a72f728c1d8..b9f62e01696 100644 --- a/oneflow/core/graph/graph.h +++ b/oneflow/core/graph/graph.h @@ -34,7 +34,13 @@ class Graph { // For Each void ForEachNode(std::function NodeHandler) const; Maybe MaybeForEachNode(std::function(NodeType*)> NodeHandler) const; + // In case you want to change the topological structure during the node handler. + // For example, adding/deleting a node or an edge. + // Still, it might have bugs even if you use TopoForEachNodeDynamic. + void TopoForEachNodeDynamic(std::function NodeHandler) const; void TopoForEachNode(std::function NodeHandler) const; + Maybe TopoForEachNodeDynamicWithErrorCaptured( + std::function(NodeType*)> NodeHandler) const; Maybe TopoForEachNodeWithErrorCaptured( std::function(NodeType*)> NodeHandler) const; void ReverseTopoForEachNode(std::function NodeHandler) const; @@ -53,18 +59,40 @@ class Graph { const std::function&)>& ForEachNext, const std::function& Handler) const; + void TopoForEachNodeDynamic( + const std::list& starts, + const std::function&)>& ForEachInNode, + const std::function&)>& ForEachOutNode, + const std::function& Handler) const; + void TopoForEachNode( const std::list& starts, const std::function&)>& ForEachInNode, const std::function&)>& ForEachOutNode, const std::function& Handler) const; + void TopoForEachNode( + const std::function&)>& ForEachInNode, + const std::function&)>& ForEachOutNode, + const std::function& Handler) const; + + Maybe TopoForEachNodeDynamicWithErrorCaptured( + const std::list& starts, + const std::function&)>& ForEachInNode, + const std::function&)>& ForEachOutNode, + const std::function(NodeType*)>& Handler) const; + Maybe TopoForEachNodeWithErrorCaptured( const std::list& starts, const std::function&)>& ForEachInNode, const std::function&)>& ForEachOutNode, const std::function(NodeType*)>& Handler) const; + Maybe TopoForEachNodeWithErrorCaptured( + const std::function&)>& ForEachInNode, + const std::function&)>& ForEachOutNode, + const std::function(NodeType*)>& Handler) const; + void DfsTopoForEachNode( const std::list& starts, const std::function&)>& ForEachInNode, @@ -211,16 +239,33 @@ NodeType* Graph::SoleSinkNode() const { return sink_nodes_list.front(); } +template +void Graph::TopoForEachNodeDynamic( + std::function NodeHandler) const { + TopoForEachNodeDynamic(source_nodes(), &NodeType::ForEachNodeOnInEdge, + &NodeType::ForEachNodeOnOutEdge, NodeHandler); +} + template void Graph::TopoForEachNode(std::function NodeHandler) const { - TopoForEachNode(source_nodes(), &NodeType::ForEachNodeOnInEdge, &NodeType::ForEachNodeOnOutEdge, - NodeHandler); + CHECK_JUST(TopoForEachNodeWithErrorCaptured(&NodeType::ForEachNodeOnInEdge, + &NodeType::ForEachNodeOnOutEdge, [&](NodeType* node) { + NodeHandler(node); + return Maybe::Ok(); + })); +} + +template +Maybe Graph::TopoForEachNodeDynamicWithErrorCaptured( + std::function(NodeType*)> NodeHandler) const { + return TopoForEachNodeDynamicWithErrorCaptured(source_nodes(), &NodeType::ForEachNodeOnInEdge, + &NodeType::ForEachNodeOnOutEdge, NodeHandler); } template Maybe 
Graph::TopoForEachNodeWithErrorCaptured( std::function(NodeType*)> NodeHandler) const { - return TopoForEachNodeWithErrorCaptured(source_nodes(), &NodeType::ForEachNodeOnInEdge, + return TopoForEachNodeWithErrorCaptured(&NodeType::ForEachNodeOnInEdge, &NodeType::ForEachNodeOnOutEdge, NodeHandler); } @@ -229,15 +274,14 @@ void Graph::SortedTopoForEachNode( std::function LessThan, std::function NodeHandler) const { ForEachNode([&](NodeType* node) { node->SortInOutEdges(LessThan); }); - TopoForEachNode(source_nodes(), &NodeType::ForEachNodeOnSortedInEdge, - &NodeType::ForEachNodeOnSortedOutEdge, NodeHandler); + TopoForEachNode(&NodeType::ForEachNodeOnSortedInEdge, &NodeType::ForEachNodeOnSortedOutEdge, + NodeHandler); } template void Graph::ReverseTopoForEachNode( std::function NodeHandler) const { - TopoForEachNode(sink_nodes(), &NodeType::ForEachNodeOnOutEdge, &NodeType::ForEachNodeOnInEdge, - NodeHandler); + TopoForEachNode(&NodeType::ForEachNodeOnOutEdge, &NodeType::ForEachNodeOnInEdge, NodeHandler); } template @@ -493,6 +537,19 @@ std::unique_ptr> Graph::FindFirstNontrivi return std::unique_ptr>(); } +template +void Graph::TopoForEachNodeDynamic( + const std::list& starts, + const std::function&)>& ForEachInNode, + const std::function&)>& ForEachOutNode, + const std::function& Handler) const { + CHECK_JUST(TopoForEachNodeDynamicWithErrorCaptured(starts, ForEachInNode, ForEachOutNode, + [&](NodeType* node) { + Handler(node); + return Maybe::Ok(); + })); +} + template void Graph::TopoForEachNode( const std::list& starts, @@ -507,7 +564,18 @@ void Graph::TopoForEachNode( } template -Maybe Graph::TopoForEachNodeWithErrorCaptured( +void Graph::TopoForEachNode( + const std::function&)>& ForEachInNode, + const std::function&)>& ForEachOutNode, + const std::function& Handler) const { + CHECK_JUST(TopoForEachNodeWithErrorCaptured(ForEachInNode, ForEachOutNode, [&](NodeType* node) { + Handler(node); + return Maybe::Ok(); + })); +} + +template +Maybe Graph::TopoForEachNodeDynamicWithErrorCaptured( const std::list& starts, const std::function&)>& ForEachInNode, const std::function&)>& ForEachOutNode, @@ -537,6 +605,64 @@ Maybe Graph::TopoForEachNodeWithErrorCaptured( return Maybe::Ok(); } +template +Maybe Graph::TopoForEachNodeWithErrorCaptured( + const std::list& starts, + const std::function&)>& ForEachInNode, + const std::function&)>& ForEachOutNode, + const std::function(NodeType*)>& Handler) const { + HashMap counter_in; + std::queue queue; + for (NodeType* start : starts) { + queue.push(start); + counter_in[start] = 0; + ForEachInNode(start, [&](NodeType*) { LOG(FATAL) << "not a source"; }); + } + while (!queue.empty()) { + NodeType* cur_node = queue.front(); + queue.pop(); + JUST(Handler(cur_node)); + ForEachOutNode(cur_node, [&](NodeType* out) { + auto it = counter_in.find(out); + // Move the initialization here + if (it == counter_in.end()) { + int32_t count = 0; + ForEachInNode(out, [&](NodeType* out_in) { count++; }); + counter_in[out] = count; + it = counter_in.find(out); + } + it->second--; + if (it->second == 0) { queue.push(out); } + }); + } + return Maybe::Ok(); +} + +template +Maybe Graph::TopoForEachNodeWithErrorCaptured( + const std::function&)>& ForEachInNode, + const std::function&)>& ForEachOutNode, + const std::function(NodeType*)>& Handler) const { + HashMap counter_in; + std::queue queue; + ForEachNode([&](NodeType* node) { + int32_t count = 0; + ForEachInNode(node, [&](NodeType*) { count++; }); + counter_in[node] = count; + if (count == 0) { queue.push(node); } + }); + 
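+  // The loop below is a standard Kahn-style topological traversal: pop a node whose
+  // remaining in-degree counter is zero, run the handler on it, then decrement the
+  // counter of each of its out-nodes, enqueueing any out-node whose counter reaches zero.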
while (!queue.empty()) { + NodeType* cur_node = queue.front(); + queue.pop(); + JUST(Handler(cur_node)); + ForEachOutNode(cur_node, [&](NodeType* out) { + --counter_in[out]; + if (counter_in[out] == 0) { queue.push(out); } + }); + } + return Maybe::Ok(); +} + template void Graph::DfsTopoForEachNodeSortByDistanceToSink( const std::list& starts, @@ -546,7 +672,7 @@ void Graph::DfsTopoForEachNodeSortByDistanceToSink( HashMap node2distance_to_sink; { std::list nodes; - TopoForEachNode(starts, ForEachInNode, ForEachOutNode, + TopoForEachNode(ForEachInNode, ForEachOutNode, [&](NodeType* node) { nodes.emplace_back(node); }); std::list sinks; for (NodeType* node : nodes) { @@ -554,7 +680,7 @@ void Graph::DfsTopoForEachNodeSortByDistanceToSink( ForEachOutNode(node, [&](NodeType* out_node) { is_sink = false; }); if (is_sink) { sinks.emplace_back(node); } } - TopoForEachNode(sinks, ForEachOutNode, ForEachInNode, [&](NodeType* node) { + TopoForEachNode(ForEachOutNode, ForEachInNode, [&](NodeType* node) { int64_t distance_to_sink = -1; ForEachOutNode(node, [&](NodeType* out_node) { distance_to_sink = std::max(distance_to_sink, node2distance_to_sink[out_node]); @@ -649,12 +775,12 @@ Graph::MakePredicatorIsReachable( std::shared_ptr id2ancestor(new Id2Ancestor(node_num())); int64_t id = 0; node2id->reserve(node_num()); - TopoForEachNode(starts, ForEachInNode, ForEachOutNode, [&](NodeType* node) { + TopoForEachNode(ForEachInNode, ForEachOutNode, [&](NodeType* node) { node2id->emplace(node, id); id2ancestor->at(id).Resize(node_num()); id += 1; }); - TopoForEachNode(starts, ForEachInNode, ForEachOutNode, [&](NodeType* node) { + TopoForEachNode(ForEachInNode, ForEachOutNode, [&](NodeType* node) { const int64_t node_id = node2id->at(node); auto& ancestor_bitset_vec = id2ancestor->at(node_id); ForEachInNode(node, [&](NodeType* in_node) { diff --git a/oneflow/core/graph/op_graph.cpp b/oneflow/core/graph/op_graph.cpp index 4bd88e55f5f..45e5eba9166 100644 --- a/oneflow/core/graph/op_graph.cpp +++ b/oneflow/core/graph/op_graph.cpp @@ -472,8 +472,7 @@ void OpGraph::TopoForEachNodeWithCtrlEdge(const std::function& No const std::function& Handler) { ForEachDataAndCtrlOutNode(node, Handler); }; - TopoForEachNode(DataOrCtrlSourceNodes(), OpGraphForEachInDataAndCtrlNode, - OpGraphForEachOutDataAndCtrlNode, NodeHandler); + TopoForEachNode(OpGraphForEachInDataAndCtrlNode, OpGraphForEachOutDataAndCtrlNode, NodeHandler); } std::function diff --git a/oneflow/core/graph/straighten_nodes.cpp b/oneflow/core/graph/straighten_nodes.cpp new file mode 100644 index 00000000000..1e708e19df0 --- /dev/null +++ b/oneflow/core/graph/straighten_nodes.cpp @@ -0,0 +1,485 @@ +/* +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/ +#include "oneflow/core/graph/straighten_nodes.h" +#include "oneflow/core/graph/op_graph.h" +#include "oneflow/core/graph/task_node.h" +#include "oneflow/core/job/job_desc.h" +#include "oneflow/core/common/protobuf.h" +#include "oneflow/core/job/task.pb.h" + +namespace oneflow { + +namespace { + +enum TaskClassifier : int { + kWaitingTransfer = 0, + kWaitingComputation = 1, + kRunASAP = 2, + kRunALAP = 3 +}; + +class TopoStruct { + public: + TaskNode* node = nullptr; + int32_t min_layer = -1; + int32_t tributary_layer = -1; + bool on_mainstem = false; + int32_t counter = 0; + int32_t min_distance2transfer = -1; + TopoStruct* next_same_node = nullptr; + // We can have some other nodes in it for example + // SbpNode* node; + // SbpEdge* node; + // Or we can omit all the pointers and leave all the useful parameters. + + // Drop down the tributary layer + void DropTributaryLayer(int32_t upper_bound); + + void SpreadTributaryLayer(HashMap* task_node2topo_struct); + + void SpreadMainstem(HashMap* task_node2topo_struct); + + // The minimum computation distance from the beginning of this op to the next transfer + int32_t GetMinDistance2Transfer(HashMap* task_node2topo_struct); + + // deciding parameter + // i = 0: those with small tributary layers go first + // i = 1: those with small minimum distance to transfer go first + // i = 2: first in first out + // i = 3: those with large tributary layers go first + // i = 4: those with long distance to transfer go first + // i = 5: last in first out + int32_t GetDecidingParameter(int32_t i) const; +}; + +// move the head from source to target +void MoveFrontBetweenMaps(std::map& source, + std::map& target) { + if (!source.empty()) { + const auto& front = source.begin(); + target[front->first] = front->second; + source.erase(front); + } +}; + +bool ShouldRunASAP(TaskType task_type) { + // They are sorted according to frequency of occurrences + switch (task_type) { + // We mark the number of occurrences in bert + case TaskType::kDeviceTick: // 38 + case TaskType::kTick: // 8 + case TaskType::kSrcSubsetTick: // 6 + case TaskType::kDstSubsetTick: // 6 + case TaskType::kCriticalSectionWaitTick: // 4 + case TaskType::kWaitAndSendIds: // 2 + case TaskType::kPack: // 0 + case TaskType::kUnpack: // 0 + case TaskType::kRepeat: // 0 + case TaskType::kAcc: // 0 + case TaskType::kSourceTick: // 0 + case TaskType::kAccTick: // 0 + case TaskType::kCase: // 0 + case TaskType::kEsac: // 0 + case TaskType::kReentrantLock: return true; // 0 + default: return false; + } +} + +bool IsTransferNode(TaskType task_type) { + // return task_type == 12 || task_type == 13 || (48 <= task_type && task_type <= 64); + // They are sorted according to frequency of occurrences + switch (task_type) { + // We mark the number of occurrences in bert + case TaskType::kCollectiveBoxingGeneric: // 76 + case TaskType::kCopyHd: // 27 + case TaskType::kSliceBoxing: // 16 + case TaskType::kCopyCommNet: // 12 + case TaskType::kCollectiveBoxingPack: // 8 + case TaskType::kCollectiveBoxingUnpack: // 8 + case TaskType::kBoxingZeros: // 3 + case TaskType::kForeignInput: // 0 + case TaskType::kForeignOutput: // 0 + case TaskType::kDistributeConcat: // 0 + case TaskType::kDistributeSplit: // 0 + case TaskType::kBoxingIdentity: // 0 + case TaskType::kDecodeH2D: // 0 + case TaskType::kSspVariableProxy: return true; // 0 + default: return false; + } +} + +// Classifier for the set according to the task type +TaskClassifier GetTaskClassifier(const TaskNode* node) { + // Check task.pb.h for detail + // They 
are sorted according to frequency of judgement + // frequency of judgement = the number of occurrences / the times of judgement + TaskType task_type = node->GetTaskType(); + if (task_type == TaskType::kNormalForward) { return TaskClassifier::kWaitingComputation; } + if (IsTransferNode(task_type)) { return TaskClassifier::kWaitingTransfer; } + if (task_type == TaskType::kCallbackNotify) { return TaskClassifier::kRunALAP; } + if (ShouldRunASAP(task_type)) { return TaskClassifier::kRunASAP; } + CHECK(false) << "Unclassified or invalid task type (" << task_type << ") showing up"; + // Throw a kRunASAP which means ignoring this node in the algorithm + return TaskClassifier::kRunASAP; +} + +// Drop down the maximum layer with the minimum layer form consumer +void TopoStruct::DropTributaryLayer(int32_t upper_bound) { + if (upper_bound < tributary_layer || tributary_layer < 0) { tributary_layer = upper_bound; } +} + +// Should initialize the counter to be the number of out edges +// Compute maximum layer for tributaries +void TopoStruct::SpreadTributaryLayer(HashMap* task_node2topo_struct) { + if (counter || min_layer <= 0) { return; } + int32_t producer_max_lay = 0; + if (on_mainstem) { + producer_max_lay = min_layer - 1; + } else { + // On a tributary, the operator could be run later. + producer_max_lay = tributary_layer; + } + node->ForEachNodeOnInEdge([&](TaskNode* in) { + auto& topo_struct_in = task_node2topo_struct->at(in); + topo_struct_in.DropTributaryLayer(producer_max_lay); + --topo_struct_in.counter; + if (topo_struct_in.counter == 0) { topo_struct_in.SpreadTributaryLayer(task_node2topo_struct); } + }); + // Reduce counter to -1 to avoid visiting again + counter--; +} + +// Judge if this node is on the mainstem +// If so, judge it for its producer/upstream nodes +void TopoStruct::SpreadMainstem(HashMap* task_node2topo_struct) { + // Skip it if this node is already judged. 
+ if (on_mainstem) { return; } + CHECK_GE(min_layer, 0) << "TopoStruct not initialized!"; + on_mainstem = true; + // If I am in the mainstem, then all the children with (min_layer >= my layer id - 1) would be + // considered as in the mainstem + node->ForEachNodeOnInEdge([&](TaskNode* in) { + auto& topo_struct_in = task_node2topo_struct->at(in); + if (topo_struct_in.min_layer == min_layer - 1) { + topo_struct_in.SpreadTributaryLayer(task_node2topo_struct); + } + }); +} + +// The minimum computation distance from the beginning of this op to the next transfer +int32_t TopoStruct::GetMinDistance2Transfer(HashMap* task_node2topo_struct) { + if (min_distance2transfer >= 0) { return min_distance2transfer; } + // if this node is a transfer node + if (IsTransferNode(node->GetTaskType())) { + min_distance2transfer = 0; + return min_distance2transfer; + } + // Otherwise, initialize it with a large number + // Well, the total number in the task graph is large enough + min_distance2transfer = task_node2topo_struct->size(); + node->ForEachNodeOnOutEdge([&](TaskNode* out) { + min_distance2transfer = + std::min(min_distance2transfer, + task_node2topo_struct->at(out).GetMinDistance2Transfer(task_node2topo_struct)); + }); + ++min_distance2transfer; + return min_distance2transfer; +} + +// deciding parameter +// i = 0: those with small tributary layers go first +// i = 1: those with small minimum distance to transfer go first +// i = 2: first in first out +// i = 3: those with large tributary layers go first +// i = 4: those with long distance to transfer go first +// i = 5: last in first out +int32_t TopoStruct::GetDecidingParameter(int32_t i) const { + int32_t sign = 1; + if (i >= 3) { + i -= 3; + sign = -1; + } + switch (i) { + case 0: return sign * tributary_layer; + case 1: return sign * min_distance2transfer; + case 2: return sign * min_layer; + } + return 0; +} + +// Find the mainstem of the task graph, then reduce the wait time for tributaries +void FindMainstem(HashMap* task_node2topo_struct) { + // Find the maximum layer number + int32_t max_min_layer = -1; + for (const auto& pair : *task_node2topo_struct) { + if (max_min_layer < pair.second.min_layer) { max_min_layer = pair.second.min_layer; } + } + // All the nodes with min_layer>=mainstem_end_id would be considered as mainstem nodes + // The last 5 layers would be considered as in mainstem anyway. + int32_t mainstem_end_id = max_min_layer - 4; + for (auto& pair : *task_node2topo_struct) { + auto& topo_struct = pair.second; + // Initialize the counter and Tributary Layer + topo_struct.counter = pair.first->out_edges().size(); + topo_struct.tributary_layer = max_min_layer; + // Find out all the nodes on the mainstem. 
+ if (topo_struct.min_layer >= mainstem_end_id) { + topo_struct.SpreadMainstem(task_node2topo_struct); + } + } + + for (auto& pair : *task_node2topo_struct) { + // Compute maximum layer for tributaries + pair.second.SpreadTributaryLayer(task_node2topo_struct); + // Set the min_distance2transfer for each topological structure + pair.second.GetMinDistance2Transfer(task_node2topo_struct); + } +} + +} // anonymous namespace + +void StraightenNodes(TaskGraph* task_graph, std::vector* ordered_task_nodes) { + // The function for settle the order in the graph + int64_t order_in_graph = 0; + + // Generate topological data structure for each task node + HashMap task_node2topo_struct; + // Determine the same nodes which should run simultaneously + HashMap>> + task_type2machine_id2node_id2topo_structs; + std::map min_node_id2topo_struct; + int32_t previous_min_layer = 0; + task_graph->TopoForEachNode([&](TaskNode* node) { + auto& topo_struct = task_node2topo_struct[node]; + topo_struct.node = node; + if (node->in_edges().empty()) { + topo_struct.min_layer = 0; + } else { + int32_t max_min_layer = 0; + node->ForEachNodeOnInEdge([&](TaskNode* in) { + max_min_layer = std::max(max_min_layer, task_node2topo_struct[in].min_layer); + }); + topo_struct.min_layer = max_min_layer + 1; + // Deal with all the nodes with min_layer=previous_min_layer + if (max_min_layer >= previous_min_layer) { + // Using "7" to represent "and" + // a7b means a pair (a, b) + for (auto& task_type7machine_id2node_id2topo_structs : + task_type2machine_id2node_id2topo_structs) { + auto& machine_id2node_id2topo_structs = task_type7machine_id2node_id2topo_structs.second; + // Initializing the smallest node id for each machine + for (auto& machine_id7node_id2topo_structs : machine_id2node_id2topo_structs) { + MoveFrontBetweenMaps(machine_id7node_id2topo_structs.second, min_node_id2topo_struct); + } + + while (!min_node_id2topo_struct.empty()) { + // auto* topo_struct_min_node_id = min_node_id2topo_struct.begin()->second; + // Store the same nodes in different machines + std::vector same_nodes; + for (auto& min_node_id7topo_struct : min_node_id2topo_struct) { + auto* curr_topo_struct = min_node_id7topo_struct.second; + // Find out all the same nodes + // Stop using Visual string before we find a better key + // Currently we can use the topological structure and node id to decide the same nodes + same_nodes.push_back(curr_topo_struct); + } + // Cyclize them + for (int32_t i = 1; i < same_nodes.size(); i++) { + same_nodes[i - 1]->next_same_node = same_nodes[i]; + } + (*same_nodes.rbegin())->next_same_node = same_nodes[0]; + // Delete them and add new candidates + for (auto* same_node_topo_struct : same_nodes) { + // Erase them from min_node_id2topo_struct + min_node_id2topo_struct.erase(same_node_topo_struct->node->node_id()); + // Add new candidate + MoveFrontBetweenMaps( + machine_id2node_id2topo_structs[same_node_topo_struct->node->machine_id()], + min_node_id2topo_struct); + } + } + } + // Renew the previous min_layer at the end + previous_min_layer = topo_struct.min_layer; + } + } + // Put the topo structure into the map, waiting for determine the same nodes + task_type2machine_id2node_id2topo_structs[node->GetTaskType()][node->machine_id()] + [node->node_id()] = &topo_struct; + }); + + // Generate other parameters in the topological data structure + FindMainstem(&task_node2topo_struct); + + VLOG(3) << "Straightening order: " << 5 << ", " << 3; + + // Order in the waiting sets + // Decide which node should run first + struct comp { + 
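+    // With the default deciding parameters {5, 3}, nodes with a larger min_layer compare
+    // smaller and therefore come first (last in, first out); ties are broken by a larger
+    // tributary_layer, and finally by node_id, so begin() of the std::set is the next
+    // candidate to execute.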
bool operator()(const TopoStruct* a, const TopoStruct* b) const { + // NOTE: Leave these code for debugging in the future + // static std::vector decide_parameters({ParseIntegerFromEnv("Parameter0", 0), + // ParseIntegerFromEnv("Parameter1", 1), + // ParseIntegerFromEnv("Parameter2", 2)}); + // The best parameter set is {5, 3} + static std::vector decide_parameters({5, 3}); + for (int32_t decide_parameter : decide_parameters) { + int32_t decide_parameter_a = a->GetDecidingParameter(decide_parameter); + int32_t decide_parameter_b = b->GetDecidingParameter(decide_parameter); + if (decide_parameter_a != decide_parameter_b) { + return decide_parameter_a < decide_parameter_b; + } + } + return a->node->node_id() < b->node->node_id(); + } + }; + + // Classify sets for the task nodes + // std::set waiting_transfer; // 0, TaskClassifier::kWaitingTransfer + // std::set waiting_computation; // 1, TaskClassifier::kWaitingComputation + // std::set run_asap; // 2, TaskClassifier::kRunASAP , run as soon as possible + // std::set run_alap; // 3, TaskClassifier::kRunALAP , run as late as possible + const int32_t num_classifier = 4; + std::vector> waiting_lists(num_classifier); + + std::vector remain_task_nums(num_classifier, 0); + + auto SetOrderInGraph = [&](TaskNode* task_node) { + task_node->set_order_in_graph(order_in_graph); + ordered_task_nodes->emplace_back(task_node); + ++order_in_graph; + }; + + // wait in the list + auto wait = [&](TaskNode* node) { + TopoStruct* first_topo_struct = &task_node2topo_struct[node]; + // Check if all the same nodes are ready simultaneously + TopoStruct* curr_topo_struct = first_topo_struct->next_same_node; + while (curr_topo_struct && curr_topo_struct != first_topo_struct) { + if (curr_topo_struct->counter) { return; } + curr_topo_struct = curr_topo_struct->next_same_node; + } + // Add all the same nodes at the same time + curr_topo_struct = first_topo_struct; + auto& waiting_list = waiting_lists[GetTaskClassifier(node)]; + while (true) { + waiting_list.insert(curr_topo_struct); + // Reduce counter then this node will never be added again + // Though inserting into a map twice does not matter because of the same keys + curr_topo_struct->counter--; + curr_topo_struct = curr_topo_struct->next_same_node; + if ((!curr_topo_struct) || (curr_topo_struct == first_topo_struct)) { break; } + } + }; + + // initialization + task_graph->ForEachNode([&](TaskNode* node) { + int32_t count = node->in_edges().size(); + task_node2topo_struct[node].counter = count; + if (count == 0) { wait(node); } + remain_task_nums[GetTaskClassifier(node)]++; + }); + + // Finish execution + auto finish_execution = [&](TaskNode* node) { + node->ForEachNodeOnOutEdge([&](TaskNode* out) { + --(task_node2topo_struct[out].counter); + if (task_node2topo_struct[out].counter == 0) { wait(out); } + }); + }; + + // Move the first node of the waiting list to the execution list + auto move2execution_list = [&](std::set& waiting_list, + std::vector& execution_list) { + TaskNode* first_node = (*waiting_list.begin())->node; + int32_t execution_num = 0; + TopoStruct* first_topo_struct = &task_node2topo_struct[first_node]; + // Find all the same nodes in different machine + // They should be run simultaneously + TopoStruct* curr_topo_struct = first_topo_struct; + while (true) { + execution_num++; + execution_list.push_back(curr_topo_struct->node); + waiting_list.erase(curr_topo_struct); + // move and maybe leave + curr_topo_struct = curr_topo_struct->next_same_node; + if ((!curr_topo_struct) || (curr_topo_struct == 
first_topo_struct)) { break; } + } + CHECK_GT(execution_num, 0) << "Error, no task nodes are moved to the execution list"; + }; + + // Execute the first n nodes in the waiting list + auto execute = [&](int32_t list_classifier, int32_t n, bool if_reverse = false) { + // n > 0 + if (n <= 0) { return; } + auto& waiting_list = waiting_lists[list_classifier]; + std::vector execution_list; + int32_t count = 0; + // Move to the execution list + while (!waiting_list.empty()) { + move2execution_list(waiting_list, execution_list); + count++; + if (count >= n) { break; } + } + remain_task_nums[list_classifier] -= execution_list.size(); + // Set the order and then remove from the execution list + for (auto* node : execution_list) { + SetOrderInGraph(node); + finish_execution(node); + } + }; + + // straightening + while (true) { + if (waiting_lists[TaskClassifier::kRunASAP].empty()) { + if (waiting_lists[TaskClassifier::kWaitingTransfer].empty()) { + if (waiting_lists[TaskClassifier::kWaitingComputation].empty()) { + if (waiting_lists[TaskClassifier::kRunALAP].empty()) { + // All the waiting lists are empty + break; + } else { + // Execute all the nodes left + execute(TaskClassifier::kRunALAP, waiting_lists[TaskClassifier::kRunALAP].size()); + } + } else { + // Execute one computation node + execute(TaskClassifier::kWaitingComputation, 1); + } + } else { + int32_t computation_num = + std::min(int32_t(waiting_lists[TaskClassifier::kWaitingComputation].size() + / (waiting_lists[TaskClassifier::kWaitingTransfer].size())), + remain_task_nums[TaskClassifier::kWaitingComputation] + / remain_task_nums[TaskClassifier::kWaitingTransfer]); + // Holding the transfer + std::vector transfer_execution_list; + move2execution_list(waiting_lists[TaskClassifier::kWaitingTransfer], + transfer_execution_list); + remain_task_nums[TaskClassifier::kWaitingTransfer] -= transfer_execution_list.size(); + for (auto* transfer_node : transfer_execution_list) { SetOrderInGraph(transfer_node); } + // Overlap transfer with computation + execute(TaskClassifier::kWaitingComputation, computation_num); + + // Release the transfer + for (auto* transfer_node : transfer_execution_list) { finish_execution(transfer_node); } + } + } else { + execute(TaskClassifier::kRunASAP, waiting_lists[TaskClassifier::kRunASAP].size()); + } + } +} + +} // namespace oneflow diff --git a/oneflow/core/graph/straighten_nodes.h b/oneflow/core/graph/straighten_nodes.h new file mode 100644 index 00000000000..e68a03c698c --- /dev/null +++ b/oneflow/core/graph/straighten_nodes.h @@ -0,0 +1,27 @@ +/* +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/
+#ifndef ONEFLOW_CORE_GRAPH_STRAIGHTEN_NODES_H_
+#define ONEFLOW_CORE_GRAPH_STRAIGHTEN_NODES_H_
+
+#include "oneflow/core/graph/task_graph.h"
+
+namespace oneflow {
+
+void StraightenNodes(TaskGraph* task_graph, std::vector<TaskNode*>* ordered_task_nodes);
+
+}  // namespace oneflow
+
+#endif  // ONEFLOW_CORE_GRAPH_STRAIGHTEN_NODES_H_
diff --git a/oneflow/core/graph/task_graph.cpp b/oneflow/core/graph/task_graph.cpp
index 5fd69c40274..404b93a455a 100644
--- a/oneflow/core/graph/task_graph.cpp
+++ b/oneflow/core/graph/task_graph.cpp
@@ -29,6 +29,7 @@ limitations under the License.
 #include "oneflow/core/graph/boxing/hierarchical_sub_task_graph_builder_impl.h"
 #include "oneflow/core/graph/task_stream_index_manager.h"
 #include "oneflow/core/ep/include/primitive/memcpy.h"
+#include "oneflow/core/graph/straighten_nodes.h"
 
 namespace oneflow {
 
@@ -419,7 +420,7 @@ void ForEachOpGraphNecessaryCtrlEdge(
 
 }  // namespace
 
-TaskGraph::TaskGraph() {
+TaskGraph::TaskGraph(bool disable_straighten_algorithm) {
   OpGraph* op_graph = Global<OpGraph>::Get();
   sub_tsk_gph_builder_ctx_.reset(new SubTskGphBuilderCtx(this));
   boxing_logger_ = CreateBoxingLogger();
@@ -450,7 +451,11 @@ TaskGraph::TaskGraph() {
     }
   });
 
-  SetOrderInGraphForEachNode();
+  if (disable_straighten_algorithm) {
+    SetOrderInGraphForEachNode();
+  } else {
+    StraightenNodes(this, &ordered_task_nodes_);
+  }
   if (Global<ResourceDesc, ForSession>::Get()->enable_debug_mode()) { ToDotWithAutoFilePath(); }
 }
 
diff --git a/oneflow/core/graph/task_graph.h b/oneflow/core/graph/task_graph.h
index 71593a834f1..2ec3e15f18e 100644
--- a/oneflow/core/graph/task_graph.h
+++ b/oneflow/core/graph/task_graph.h
@@ -43,7 +43,7 @@ class TaskGraph final : public Graph<TaskNode, TaskEdge> {
   OF_DISALLOW_COPY_AND_MOVE(TaskGraph);
   ~TaskGraph() override;
 
-  explicit TaskGraph();
+  explicit TaskGraph(bool disable_straighten_algorithm);
 
   const char* TypeName() const override { return "TaskGraph"; }
   void RemoveEmptyRegsts();
diff --git a/oneflow/core/job/compiler.cpp b/oneflow/core/job/compiler.cpp
index 7cdcbb9a5e1..a2d47a1d38a 100644
--- a/oneflow/core/job/compiler.cpp
+++ b/oneflow/core/job/compiler.cpp
@@ -61,7 +61,8 @@ void Compiler::Compile(Job* job, Plan* plan, bool need_job_complete) const {
 
   // Step3: build task_gph.
   // TODO(levi): we can rewrite this part of code in visitor pattern.
-  auto task_gph = std::make_unique<TaskGraph>();
+  auto task_gph =
+      std::make_unique<TaskGraph>(job->job_conf().disable_straighten_algorithm_in_task_graph());
   using std::placeholders::_1;
   task_gph->ForEachNode(std::bind(&TaskNode::ProduceAllRegstsAndBindEdges, _1));
   task_gph->ForEachNode(std::bind(&TaskNode::ConsumeAllRegsts, _1));
diff --git a/oneflow/core/job/job_conf.proto b/oneflow/core/job/job_conf.proto
index 03638feec30..18dcb92e41b 100644
--- a/oneflow/core/job/job_conf.proto
+++ b/oneflow/core/job/job_conf.proto
@@ -240,6 +240,8 @@ message JobConfigProto {
   optional bool cudnn_conv_enable_pseudo_half = 600 [default = true];
   optional bool enable_auto_mixed_precision = 602 [default = false];
   optional bool enable_quantization_aware_training = 603 [default = false];
+
+  optional bool disable_straighten_algorithm_in_task_graph = 700 [default = false];
 
   optional int64 concurrency_width = 1000 [default = 128];
 
diff --git a/python/oneflow/nn/graph/graph_config.py b/python/oneflow/nn/graph/graph_config.py
index ea48ad8d957..d367ca5c333 100644
--- a/python/oneflow/nn/graph/graph_config.py
+++ b/python/oneflow/nn/graph/graph_config.py
@@ -278,6 +278,16 @@ def build(self, x):
         """
         self.proto.cudnn_conv_heuristic_search_algo = mode
 
+    def disable_straighten_algorithm(self, mode: bool = False):
+        r"""Whether to disable the straighten algorithm.
+
+        If the nccl compute stream is used, turning this switch on might not speed up training.
+        If the nccl compute stream is not used, turning it on might slow down data parallelism by 0.6% and model parallelism by 6%.
+
+        The switch is off by default (i.e. the straighten algorithm is used by default).
+        """
+        self.proto.disable_straighten_algorithm_in_task_graph = mode
+
     def _generate_optimizer_and_variable_configs(
         self, opt_dict: OptDict = None, variables_conf: OrderedDict = None,
     ):
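
For reference, a minimal usage sketch of the new switch from the nn.Graph Python API. The Linear module and random input below are placeholder assumptions; only config.disable_straighten_algorithm comes from this patch.

import oneflow as flow

class MyGraph(flow.nn.Graph):
    def __init__(self, model):
        super().__init__()
        self.model = model
        # Fall back to the original SetOrderInGraphForEachNode ordering;
        # omit this call (or pass False) to keep the straighten algorithm enabled.
        self.config.disable_straighten_algorithm(True)

    def build(self, x):
        return self.model(x)

model = flow.nn.Linear(4, 4)
graph = MyGraph(model)
y = graph(flow.randn(2, 4))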