From de63f3f37dd06ea418739149d0398549574125a0 Mon Sep 17 00:00:00 2001 From: Yipeng Li Date: Mon, 14 Nov 2022 20:08:26 +0800 Subject: [PATCH] Feat speed up throughput (#9423) * Move the "-expand" and "-cast" ops backward * Hard-coding for stable diffusion, maximize overlaps * Use op_tyep_name instead of visual string * Change transfer nodes to tributary nodes * Rename tributary to overlap * Prepare to test different decide parameters * Prepare to print and test * {7, 5} seems to be one of the best as before * Find the best straighten mode 973 for stable diffusion * Put cpu nodes into overlap node list * Disable overlap between cpu and gpu if no cpu nodes * Update API * Remove magical number * Update comment * Remove std log message * Remove debug code * Static analysis --- oneflow/core/graph/straighten_nodes.cpp | 276 ++++++++++++++++-------- oneflow/core/graph/task_graph.cpp | 11 +- oneflow/core/job/job_conf.proto | 5 +- python/oneflow/nn/graph/graph_config.py | 17 +- 4 files changed, 214 insertions(+), 95 deletions(-) diff --git a/oneflow/core/graph/straighten_nodes.cpp b/oneflow/core/graph/straighten_nodes.cpp index 2664e1f9017..b649382b232 100644 --- a/oneflow/core/graph/straighten_nodes.cpp +++ b/oneflow/core/graph/straighten_nodes.cpp @@ -14,14 +14,18 @@ See the License for the specific language governing permissions and limitations under the License. */ #include +#include "oneflow/core/common/util.h" +#include "oneflow/core/graph/compute_task_node.h" #include "oneflow/core/graph/straighten_nodes.h" #include "oneflow/core/common/shape.h" #include "oneflow/core/graph/op_graph.h" #include "oneflow/core/graph/task_graph.h" #include "oneflow/core/graph/task_node.h" +#include "oneflow/core/job/job_conf.pb.h" #include "oneflow/core/job/job_desc.h" #include "oneflow/core/common/protobuf.h" #include "oneflow/core/job/task.pb.h" +#include "oneflow/core/operator/op_conf.pb.h" #include "oneflow/core/register/runtime_register_desc.h" namespace oneflow { @@ -29,12 +33,31 @@ namespace oneflow { namespace { enum TaskClassifier : int { - kWaitingTransfer = 0, - kWaitingComputation = 1, + kWaitingOverlapNode = 0, + kWaitingMainComputation = 1, kRunASAP = 2, kRunALAP = 3 }; +// deciding parameter +// The sorting order of nodes for the straighten algorithm +enum StraightenOrder : int { + kTributaryLayerAscend = 0, // small tributary layers go first + kDistanceToOverlapAscend = 1, // small minimum distance to overlap go first + kLayerAscend = 2, // first in first out + kMemoryIncrementAscend = 3, // small memory increment go first + kActivationTimeAscend = 4, // small activation time go first + + kTributaryLayerDescend = 100, // large tributary layers go first + kDistanceToOverlapDescend = 101, // long distance to overlap go first + kLayerDescend = 102, // last in first out + kMemoryIncrementDescend = 103, // large memory increment go first + kActivationTimeDescend = 104, // large activation time go first +}; + +// The difference between a descending order and its corresponding ascending order +const int kDiff4AscendDescend = 100; + class TopoStruct { public: TaskNode* node = nullptr; @@ -42,9 +65,10 @@ class TopoStruct { int32_t tributary_layer = -1; bool on_trunk = false; int32_t counter = 0; - int32_t min_distance2transfer = -1; + int32_t min_distance2overlap = -1; int64_t memory_increment = -1; TopoStruct* next_same_node = nullptr; + int32_t activation_time = -1; // We can have some other nodes in it for example // SbpNode* node; // SbpEdge* node; @@ -57,46 +81,57 @@ class TopoStruct { void 
SpreadTrunk(HashMap* task_node2topo_struct); - // The minimum computation distance from the beginning of this op to the next transfer - int32_t GetMinDistance2Transfer(HashMap* task_node2topo_struct); + // The minimum computation distance from the beginning of this op to the next overlap node + int32_t GetMinDistance2Overlap(HashMap* task_node2topo_struct); // Memory increment = (memory of out registers) - (memory of in registers) void ComputeMeomoryIncrement(); + // Activation time = time of cpu - time of gpu + void ComputeActivationTime(); + // TODO: We might design more deciding parameter and choose a right combination of them in the // future. // deciding parameter - // i = 0: those with small tributary layers go first - // i = 1: those with small minimum distance to transfer go first - // i = 2: first in first out - // i = 3: those with small memory increment go first - // i = 4: those with large tributary layers go first - // i = 5: those with long distance to transfer go first - // i = 6: last in first out - // i = 7: those with large memory increment go first - int64_t GetDecidingParameter(int32_t i) const; + // kTributaryLayerAscend = 0, // small tributary layers go first + // kDistanceToOverlapAscend = 1, // small minimum distance to overlap go first + // kLayerAscend = 2, // first in first out + // kMemoryIncrementAscend = 3, // small memory increment go first + // kActivationTimeAscend = 4, // small activation time go first + // kTributaryLayerDescend = 100, // large tributary layers go first + // kDistanceToOverlapDescend = 101, // long distance to overlap go first + // kLayerDescend = 102, // last in first out + // kMemoryIncrementDescend = 103, // large memory increment go first + // kActivationTimeDescend = 104, // large activation time go first + int64_t GetDecidingParameter(StraightenOrder so) const; }; +StraightenAlgorithmTag sat; + // NOTE: Leave these code for debugging in the future -// static std::vector decide_parameters({ParseIntegerFromEnv("Parameter0", 3), -// ParseIntegerFromEnv("Parameter1", 0), -// ParseIntegerFromEnv("Parameter2", 3)}); -// The best parameter set for saving time is {6, 4} +// static std::vector decide_parameters({ParseIntegerFromEnv("Parameter0", 3), +// ParseIntegerFromEnv("Parameter1", 0), +// ParseIntegerFromEnv("Parameter2", 3)}); +// The best parameter set for saving time is {102, 100} // The best parameter set for saving memory is {3, 0} -static std::vector decide_parameters; +static std::vector decide_parameters; -// SAT, a.k.a. Scholastic Aptitude Test, is the college admission test in the United States of -// America. +// SAT, a.k.a. Scholastic Aptitude Test, +// is the college admission test in the United States of America. 
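+// For example, the speed-first set {102, 100} compares two nodes by kLayerDescend
+// (-min_layer, i.e. last in first out) and breaks ties with kTributaryLayerDescend
+// (-tributary_layer), while the memory-first set {3, 0} compares memory_increment and
+// then tributary_layer, both ascending. Any value >= kDiff4AscendDescend (100) denotes the
+// descending twin of the corresponding ascending order; GetDecidingParameter() realizes it
+// by subtracting 100 and flipping the sign.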
 void InitDecideParameters(StraightenAlgorithmTag sat) {
   decide_parameters.clear();
   if (sat == StraightenAlgorithmTag::kCompressMemory) {
-    decide_parameters.push_back(3);
-    decide_parameters.push_back(0);
+    decide_parameters.push_back(StraightenOrder::kMemoryIncrementAscend);
+    decide_parameters.push_back(StraightenOrder::kTributaryLayerAscend);
+  } else if (sat == StraightenAlgorithmTag::kOverlap4Transfer) {
+    decide_parameters.push_back(StraightenOrder::kLayerDescend);
+    decide_parameters.push_back(StraightenOrder::kTributaryLayerDescend);
   } else {
-    // sat==StraightenAlgorithmTag::kOverlap4ModelParallelism
-    decide_parameters.push_back(6);
-    decide_parameters.push_back(4);
+    // sat==StraightenAlgorithmTag::kOverlap4CpuGpu
+    decide_parameters.push_back(StraightenOrder::kActivationTimeDescend);
+    decide_parameters.push_back(StraightenOrder::kLayerDescend);
+    decide_parameters.push_back(StraightenOrder::kMemoryIncrementAscend);
   }
 }
 
@@ -156,14 +191,42 @@ bool IsTransferNode(TaskType task_type) {
   }
 }
 
+// Some operators take a long time on cpu but only a little time on gpu.
+// Running those operators without overlap would cause a large gap during each iteration.
+bool LongerActivationTimeInCpu(const OperatorConf& op_conf) {
+  if (op_conf.has_user_conf()) {
+    const auto& op_type_name = op_conf.user_conf().op_type_name();
+    // They are sorted according to frequency of occurrences in stable diffusion
+    if (op_type_name == "expand_dims"  // 90
+        || op_type_name == "cast"      // 16
+        || op_type_name == "expand"    // 2
+    ) {
+      return true;
+    }
+  }
+  return false;
+}
+
 // Classifier for the set according to the task type
 TaskClassifier GetTaskClassifier(const TaskNode* node) {
   // Check task.pb.h for detail
   // They are sorted according to frequency of judgement
   // frequency of judgement = the number of occurrences / the times of judgement
   TaskType task_type = node->GetTaskType();
-  if (task_type == TaskType::kNormalForward) { return TaskClassifier::kWaitingComputation; }
-  if (IsTransferNode(task_type)) { return TaskClassifier::kWaitingTransfer; }
+  if (task_type == TaskType::kNormalForward) {
+    const auto& op_conf = dynamic_cast<const CompTaskNode*>(node)->op()->op_conf();
+    if (op_conf.has_variable_conf()) {
+      // Variable operators would not be run. They just create tensors.
+      // We do not visualize any execution in NVTX. (Even a tick operator has something in NVTX.)
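+      // Classifying them as kRunASAP keeps them out of the overlap bookkeeping below,
+      // since ordering them as soon as they become available costs nothing.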
+ return TaskClassifier::kRunASAP; + } else if (sat == StraightenAlgorithmTag::kOverlap4CpuGpu + && LongerActivationTimeInCpu(op_conf)) { + return TaskClassifier::kWaitingOverlapNode; + } else { + return TaskClassifier::kWaitingMainComputation; + } + } + if (IsTransferNode(task_type)) { return TaskClassifier::kWaitingOverlapNode; } if (task_type == TaskType::kCallbackNotify) { return TaskClassifier::kRunALAP; } if (ShouldRunASAP(task_type)) { return TaskClassifier::kRunASAP; } CHECK(false) << "Unclassified or invalid task type (" << task_type << ") showing up"; @@ -214,24 +277,24 @@ void TopoStruct::SpreadTrunk(HashMap* task_node2topo_stru }); } -// The minimum computation distance from the beginning of this op to the next transfer -int32_t TopoStruct::GetMinDistance2Transfer(HashMap* task_node2topo_struct) { - if (min_distance2transfer >= 0) { return min_distance2transfer; } - // if this node is a transfer node - if (IsTransferNode(node->GetTaskType())) { - min_distance2transfer = 0; - return min_distance2transfer; +// The minimum computation distance from the beginning of this op to the next overlap +int32_t TopoStruct::GetMinDistance2Overlap(HashMap* task_node2topo_struct) { + if (min_distance2overlap >= 0) { return min_distance2overlap; } + // if this node should be overlapped by main computation nodes + if (GetTaskClassifier(node) == TaskClassifier::kWaitingOverlapNode) { + min_distance2overlap = 0; + return min_distance2overlap; } // Otherwise, initialize it with a large number // Well, the total number in the task graph is large enough - min_distance2transfer = task_node2topo_struct->size(); + min_distance2overlap = task_node2topo_struct->size(); node->ForEachNodeOnOutEdge([&](TaskNode* out) { - min_distance2transfer = - std::min(min_distance2transfer, - task_node2topo_struct->at(out).GetMinDistance2Transfer(task_node2topo_struct)); + min_distance2overlap = + std::min(min_distance2overlap, + task_node2topo_struct->at(out).GetMinDistance2Overlap(task_node2topo_struct)); }); - ++min_distance2transfer; - return min_distance2transfer; + ++min_distance2overlap; + return min_distance2overlap; } // Memory increment = (memory of out registers) - (memory of in registers) @@ -258,28 +321,41 @@ void TopoStruct::ComputeMeomoryIncrement() { } } +// Activation time = time of cpu - time of gpu +void TopoStruct::ComputeActivationTime() { + if (node->GetTaskType() == TaskType::kNormalForward + && LongerActivationTimeInCpu(dynamic_cast(node)->op()->op_conf())) { + activation_time = 1; + } else { + activation_time = 0; + } +} + // deciding parameter -// i = 0: those with small tributary layers go first -// i = 1: those with small minimum distance to transfer go first -// i = 2: first in first out -// i = 3: those with small memory increment go first -// i = 4: those with large tributary layers go first -// i = 5: those with long distance to transfer go first -// i = 6: last in first out -// i = 7: those with large memory increment go first -int64_t TopoStruct::GetDecidingParameter(int32_t i) const { +// kTributaryLayerAscend = 0, // small tributary layers go first +// kDistanceToOverlapAscend = 1, // small minimum distance to overlap go first +// kLayerAscend = 2, // first in first out +// kMemoryIncrementAscend = 3, // small memory increment go first +// kActivationTimeAscend = 4, // small activation time go first +// kTributaryLayerDescend = 100, // large tributary layers go first +// kDistanceToOverlapDescend = 101, // long distance to overlap go first +// kLayerDescend = 102, // last in first out 
+// kMemoryIncrementDescend = 103,  // large memory increment go first
+// kActivationTimeDescend = 104,   // large activation time go first
+int64_t TopoStruct::GetDecidingParameter(StraightenOrder so) const {
   int64_t sign = 1;
-  if (i >= 4) {
-    i -= 4;
+  if (so >= kDiff4AscendDescend) {
+    so = StraightenOrder(int(so) - kDiff4AscendDescend);
     sign = -1;
   }
-  switch (i) {
-    case 0: return sign * tributary_layer;
-    case 1: return sign * min_distance2transfer;
-    case 2: return sign * min_layer;
-    case 3: return sign * memory_increment;
+  switch (so) {
+    case StraightenOrder::kTributaryLayerAscend: return sign * tributary_layer;
+    case StraightenOrder::kDistanceToOverlapAscend: return sign * min_distance2overlap;
+    case StraightenOrder::kLayerAscend: return sign * min_layer;
+    case StraightenOrder::kMemoryIncrementAscend: return sign * memory_increment;
+    case StraightenOrder::kActivationTimeAscend: return sign * activation_time;
+    default: return 0;
   }
-  return 0;
 }
 
 // Find the trunk of the task graph, then reduce the wait time for tributaries
@@ -304,8 +380,28 @@ void FindTrunk(HashMap<TaskNode*, TopoStruct>* task_node2topo_struct) {
   for (auto& pair : *task_node2topo_struct) {
     // Compute maximum layer for tributaries
     pair.second.SpreadTributaryLayer(task_node2topo_struct);
-    // Set the min_distance2transfer for each topological structure
-    pair.second.GetMinDistance2Transfer(task_node2topo_struct);
+    // Set the min_distance2overlap for each topological structure
+    pair.second.GetMinDistance2Overlap(task_node2topo_struct);
+  }
+}
+
+void UpdateSat(const HashMap<TaskNode*, TopoStruct>& task_node2topo_struct) {
+  sat = GlobalJobDesc().job_conf().straighten_algorithm_tag_in_task_graph();
+  if (sat == StraightenAlgorithmTag::kOverlap4CpuGpu) {
+    // If there are no cpu nodes, the overlap between cpu and gpu might consume a lot of memory
+    bool exist_cpu_nodes = false;
+    for (const auto& pair : task_node2topo_struct) {
+      // Found a cpu node
+      if (pair.second.activation_time == 1) {
+        exist_cpu_nodes = true;
+        break;
+      }
+    }
+    if (!exist_cpu_nodes) {
+      // Switch to the compress memory strategy, the default one,
+      // since the overlap strategy for transfer might not work on 1n1d.
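+      // ("1n1d" here means a single node with a single device, where there is no
+      // device-to-device transfer to overlap in the first place.)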
+      sat = StraightenAlgorithmTag::kCompressMemory;
+    }
   }
 }
 
@@ -326,6 +422,7 @@ void StraightenNodes(TaskGraph* task_graph, std::vector<TaskNode*>* ordered_task
     auto& topo_struct = task_node2topo_struct[node];
     topo_struct.node = node;
     topo_struct.ComputeMeomoryIncrement();
+    topo_struct.ComputeActivationTime();
     if (node->in_edges().empty()) {
       topo_struct.min_layer = 0;
     } else {
@@ -385,17 +482,19 @@ void StraightenNodes(TaskGraph* task_graph, std::vector<TaskNode*>* ordered_task
 
   // Generate other parameters in the topological data structure
   FindTrunk(&task_node2topo_struct);
+  // Update sat, since sat might be changed in previous jobs
+  UpdateSat(task_node2topo_struct);
   // Decide which node should run first
-  InitDecideParameters(GlobalJobDesc().job_conf().straighten_algorithm_tag_in_task_graph());
+  InitDecideParameters(sat);
   VLOG(3) << "Straightening order: ";
   for (int32_t decide_parameter : decide_parameters) { VLOG(3) << decide_parameter; }
 
   // Order in the waiting sets
   struct comp {
     bool operator()(const TopoStruct* a, const TopoStruct* b) const {
-      for (int32_t decide_parameter : decide_parameters) {
-        int64_t decide_parameter_a = a->GetDecidingParameter(decide_parameter);
-        int64_t decide_parameter_b = b->GetDecidingParameter(decide_parameter);
+      for (auto decide_parameter : decide_parameters) {
+        auto decide_parameter_a = a->GetDecidingParameter(decide_parameter);
+        auto decide_parameter_b = b->GetDecidingParameter(decide_parameter);
         if (decide_parameter_a != decide_parameter_b) {
           return decide_parameter_a < decide_parameter_b;
         }
@@ -405,10 +504,15 @@ void StraightenNodes(TaskGraph* task_graph, std::vector<TaskNode*>* ordered_task
   };
 
   // Classify sets for the task nodes
-  // std::set<TopoStruct*, comp> waiting_transfer;  // 0, TaskClassifier::kWaitingTransfer
-  // std::set<TopoStruct*, comp> waiting_computation;  // 1, TaskClassifier::kWaitingComputation
-  // std::set<TopoStruct*, comp> run_asap;  // 2, TaskClassifier::kRunASAP , run as soon as possible
-  // std::set<TopoStruct*, comp> run_alap;  // 3, TaskClassifier::kRunALAP , run as late as possible
+  // 0, TaskClassifier::kWaitingOverlapNode
+  // It contains transfer nodes, and those with longer activation time in cpu if requested.
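+  // While a batch of overlap nodes is being held, the straightening loop below releases
+  // computation_num nodes from kWaitingMainComputation, so copies and cpu-heavy kernels
+  // are hidden behind the main gpu computation.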
+ // std::set waiting_overlap_node; + // 1, TaskClassifier::kWaitingMainComputation + // std::set waiting_main_computation; + // 2, TaskClassifier::kRunASAP , run as soon as possible + // std::set run_asap; + // 3, TaskClassifier::kRunALAP , run as late as possible + // std::set run_alap; const int32_t num_classifier = 4; std::vector> waiting_lists(num_classifier); @@ -502,8 +606,8 @@ void StraightenNodes(TaskGraph* task_graph, std::vector* ordered_task // straightening while (true) { if (waiting_lists[TaskClassifier::kRunASAP].empty()) { - if (waiting_lists[TaskClassifier::kWaitingTransfer].empty()) { - if (waiting_lists[TaskClassifier::kWaitingComputation].empty()) { + if (waiting_lists[TaskClassifier::kWaitingOverlapNode].empty()) { + if (waiting_lists[TaskClassifier::kWaitingMainComputation].empty()) { if (waiting_lists[TaskClassifier::kRunALAP].empty()) { // All the waiting lists are empty break; @@ -513,25 +617,25 @@ void StraightenNodes(TaskGraph* task_graph, std::vector* ordered_task } } else { // Execute one computation node - execute(TaskClassifier::kWaitingComputation, 1); + execute(TaskClassifier::kWaitingMainComputation, 1); } } else { int32_t computation_num = - std::min(int32_t(waiting_lists[TaskClassifier::kWaitingComputation].size() - / (waiting_lists[TaskClassifier::kWaitingTransfer].size())), - remain_task_nums[TaskClassifier::kWaitingComputation] - / remain_task_nums[TaskClassifier::kWaitingTransfer]); - // Holding the transfer - std::vector transfer_execution_list; - move2execution_list(waiting_lists[TaskClassifier::kWaitingTransfer], - transfer_execution_list); - remain_task_nums[TaskClassifier::kWaitingTransfer] -= transfer_execution_list.size(); - for (auto* transfer_node : transfer_execution_list) { SetOrderInGraph(transfer_node); } - // Overlap transfer with computation - execute(TaskClassifier::kWaitingComputation, computation_num); - - // Release the transfer - for (auto* transfer_node : transfer_execution_list) { finish_execution(transfer_node); } + std::min(int32_t(waiting_lists[TaskClassifier::kWaitingMainComputation].size() + / (waiting_lists[TaskClassifier::kWaitingOverlapNode].size())), + remain_task_nums[TaskClassifier::kWaitingMainComputation] + / remain_task_nums[TaskClassifier::kWaitingOverlapNode]); + // Holding the node to be overlapped + std::vector overlap_execution_list; + move2execution_list(waiting_lists[TaskClassifier::kWaitingOverlapNode], + overlap_execution_list); + remain_task_nums[TaskClassifier::kWaitingOverlapNode] -= overlap_execution_list.size(); + for (auto* overlap_node : overlap_execution_list) { SetOrderInGraph(overlap_node); } + // Overlap the node with computation from the trunk + execute(TaskClassifier::kWaitingMainComputation, computation_num); + + // Release the overlap node + for (auto* overlap_node : overlap_execution_list) { finish_execution(overlap_node); } } } else { execute(TaskClassifier::kRunASAP, waiting_lists[TaskClassifier::kRunASAP].size()); diff --git a/oneflow/core/graph/task_graph.cpp b/oneflow/core/graph/task_graph.cpp index 76e4eb33e7d..8cf05e8871e 100644 --- a/oneflow/core/graph/task_graph.cpp +++ b/oneflow/core/graph/task_graph.cpp @@ -17,6 +17,7 @@ limitations under the License. 
#include "oneflow/core/common/util.h" #include "oneflow/core/common/env_var/debug_mode.h" #include "oneflow/core/graph/inplace_lbi_graph.h" +#include "oneflow/core/job/job_conf.pb.h" #include "oneflow/core/job/job_desc.h" #include "oneflow/core/register/blob_desc.h" #include "oneflow/core/job/global_for.h" @@ -897,12 +898,12 @@ void TaskGraph::DecideExecutionOrder() { // of memory StraightenAlgorithmTag straighten_algorithm_tag = GlobalJobDesc().job_conf().straighten_algorithm_tag_in_task_graph(); - if (straighten_algorithm_tag == StraightenAlgorithmTag::kCompressMemory - || (straighten_algorithm_tag == StraightenAlgorithmTag::kOverlap4ModelParallelism - && GlobalProcessCtx::WorldSize() > 1)) { - StraightenNodes(this, &ordered_task_nodes_); - } else { + if (straighten_algorithm_tag == StraightenAlgorithmTag::kDisable + || (straighten_algorithm_tag == StraightenAlgorithmTag::kOverlap4Transfer + && GlobalProcessCtx::WorldSize() == 1)) { SetOrderInGraphForEachNode(); + } else { + StraightenNodes(this, &ordered_task_nodes_); } } diff --git a/oneflow/core/job/job_conf.proto b/oneflow/core/job/job_conf.proto index 5ab9e524818..2272b5d635e 100644 --- a/oneflow/core/job/job_conf.proto +++ b/oneflow/core/job/job_conf.proto @@ -207,8 +207,9 @@ message JobSignatureDef { enum StraightenAlgorithmTag { kDisable = 1; - kOverlap4ModelParallelism = 2; - kCompressMemory = 3; + kOverlap4Transfer = 2; + kCompressMemory = 3; + kOverlap4CpuGpu = 4; } message JobConfigProto { diff --git a/python/oneflow/nn/graph/graph_config.py b/python/oneflow/nn/graph/graph_config.py index 0fea3f363d4..9400ea79f71 100644 --- a/python/oneflow/nn/graph/graph_config.py +++ b/python/oneflow/nn/graph/graph_config.py @@ -306,14 +306,27 @@ def enable_straighten_algorithm(self, mode: str = "MemoryFirst"): Under the third configuration, the straighten algorithm would try to compress memory as much as possible. It might save up to 13% of the memory for some models. And might save nothing for some models. + + straighten_algorithm_tag 4: OverlapCpuGpu + Under the forth configuration, the straighten algorithm would try to run the cpu nodes and gpu nodes alternately. + Such procedure would reduce the gaps of the execution on gpus. + It might speed up the training by 2%. + If no cpu nodes exist, the straighten_algorithm_tag would be switch to 3 automatically. """ - assert mode == "Disable" or mode == "SpeedFirst" or mode == "MemoryFirst" + assert ( + mode == "Disable" + or mode == "SpeedFirst" + or mode == "MemoryFirst" + or mode == "OverlapCpuGpu" + ) if mode == "Disable": self.proto.straighten_algorithm_tag_in_task_graph = 1 elif mode == "SpeedFirst": self.proto.straighten_algorithm_tag_in_task_graph = 2 - else: + elif mode == "MemoryFirst": self.proto.straighten_algorithm_tag_in_task_graph = 3 + else: + self.proto.straighten_algorithm_tag_in_task_graph = 4 def enable_auto_parallel(self, mode: bool = True): """If true, then graph will use the auto parallel algorithm to select a parallelism strategy.