From d7c11bcd1316f9aaf43253321cdd8765bc7881c5 Mon Sep 17 00:00:00 2001
From: Yipeng Li
Date: Mon, 24 Oct 2022 21:33:55 +0800
Subject: [PATCH 01/30] Add a GetSbpSignature that uses parallel num instead of
 parallel description

---
 oneflow/core/operator/operator.cpp | 12 ++++++++++++
 oneflow/core/operator/operator.h   |  9 +++++++++
 oneflow/core/operator/user_op.cpp  | 14 +++++++-------
 oneflow/core/operator/user_op.h    |  2 +-
 4 files changed, 29 insertions(+), 8 deletions(-)

diff --git a/oneflow/core/operator/operator.cpp b/oneflow/core/operator/operator.cpp
index a48f166e278..4e7f44113a9 100644
--- a/oneflow/core/operator/operator.cpp
+++ b/oneflow/core/operator/operator.cpp
@@ -495,6 +495,7 @@ Maybe<const Shape> Operator::GetInputOutputFastestTimeShape() const {
   return input_output_fastest_time_shape_;
 }
 
+// TODO: Delete this function. We never use parallel_desc in GetSbpSignature
 Maybe<void> Operator::GetSbpSignaturesIf(
     const std::function<Maybe<const BlobDesc&>(const std::string&)>& LogicalBlobDesc4Ibn,
     const ParallelDesc& parallel_desc, SbpSignatureList* sbp_sig_list) const {
@@ -506,6 +507,17 @@ Maybe<void> Operator::GetSbpSignaturesIf(
   return Maybe<void>::Ok();
 }
 
+Maybe<void> Operator::GetSbpSignaturesIf(
+    const std::function<Maybe<const BlobDesc&>(const std::string&)>& LogicalBlobDesc4Ibn,
+    int32_t parallel_num, SbpSignatureList* sbp_sig_list) const {
+  JUST(GetSbpSignatures(LogicalBlobDesc4Ibn, parallel_num, sbp_sig_list));
+  SbpSignatureBuilder()
+      .Broadcast(input_bns())
+      .Broadcast(output_bns())
+      .Build(sbp_sig_list->mutable_sbp_signature()->Add());
+  return Maybe<void>::Ok();
+}
+
 Maybe<void> Operator::GetNdSbpSignatureList(
     const std::function<Maybe<const BlobDesc&>(const std::string&)>& LogicalBlobDesc4Ibn,
     const ParallelDesc& parallel_desc, std::vector<NdSbpSignature>* nd_sbp_sig_list) const {
diff --git a/oneflow/core/operator/operator.h b/oneflow/core/operator/operator.h
index afed6ba8d2f..18ef8ed5fc7 100644
--- a/oneflow/core/operator/operator.h
+++ b/oneflow/core/operator/operator.h
@@ -173,6 +173,9 @@ class Operator {
   Maybe<void> GetSbpSignaturesIf(
       const std::function<Maybe<const BlobDesc&>(const std::string&)>& LogicalBlobDesc4Ibn,
       const ParallelDesc& parallel_desc, SbpSignatureList* sbp_sig_list) const;
+  Maybe<void> GetSbpSignaturesIf(
+      const std::function<Maybe<const BlobDesc&>(const std::string&)>& LogicalBlobDesc4Ibn,
+      int32_t parallel_num, SbpSignatureList* sbp_sig_list) const;
   virtual Maybe<void> GetNdSbpSignatureList(
       const std::function<Maybe<const BlobDesc&>(const std::string&)>& LogicalBlobDesc4Ibn,
       const ParallelDesc& parallel_desc, std::vector<NdSbpSignature>* nd_sbp_sig_list) const;
@@ -212,11 +215,17 @@ class Operator {
   virtual Maybe<void> InferInternalBlobDescs(
       const std::function<BlobDesc*(const std::string&)>& GetBlobDesc4BnInOp,
       const ParallelContext* parallel_ctx, const JobDesc* job_desc) const;
+  // TODO: Delete this function. We never use parallel_desc in GetSbpSignature
   virtual Maybe<void> GetSbpSignatures(
       const std::function<Maybe<const BlobDesc&>(const std::string&)>& LogicalBlobDesc4Ibn,
       const ParallelDesc& parallel_desc, SbpSignatureList* sbp_sig_list) const {
     return GetSbpSignatures(LogicalBlobDesc4Ibn, sbp_sig_list);
   }
+  virtual Maybe<void> GetSbpSignatures(
+      const std::function<Maybe<const BlobDesc&>(const std::string&)>& LogicalBlobDesc4Ibn,
+      int32_t parallel_num, SbpSignatureList* sbp_sig_list) const {
+    return GetSbpSignatures(LogicalBlobDesc4Ibn, sbp_sig_list);
+  }
   virtual Maybe<void> GetSbpSignatures(
       const std::function<Maybe<const BlobDesc&>(const std::string&)>& LogicalBlobDesc4Ibn,
       SbpSignatureList* sbp_sig_list) const {
diff --git a/oneflow/core/operator/user_op.cpp b/oneflow/core/operator/user_op.cpp
index db311f600ad..b8765c22b7e 100644
--- a/oneflow/core/operator/user_op.cpp
+++ b/oneflow/core/operator/user_op.cpp
@@ -345,8 +345,9 @@ class UserOpSbpContext : public user_op::SbpContext {
   using ArgVec = std::vector<std::pair<std::string, int32_t>>;
 
   UserOpSbpContext(const UserOp* op, SbpSignatureList* sbp_sig_list,
-                   std::function<Maybe<const BlobDesc&>(const std::string&)> LogicalBlobDesc4Ibn)
-      : op_(op), sbp_sig_list_(sbp_sig_list) {
+                   std::function<Maybe<const BlobDesc&>(const std::string&)> LogicalBlobDesc4Ibn,
+                   int32_t parallel_num)
+      : op_(op), sbp_sig_list_(sbp_sig_list), parallel_num_(parallel_num) {
     const auto& user_op_conf = op->op_conf().user_conf();
     for (auto it = user_op_conf.input().begin(); it != user_op_conf.input().end(); ++it) {
       const std::string& arg_name = it->first;
@@ -375,14 +376,13 @@ class UserOpSbpContext : public user_op::SbpContext {
 
   DeviceType device_type() const override { return op_->device_type(); }
 
-  int64_t parallel_num() const override {
-    return CHECK_JUST(op_->GetOpParallelDesc())->parallel_num();
-  }
+  int64_t parallel_num() const override { return parallel_num_; }
 
  private:
   const UserOp* op_;
   SbpSignatureList* sbp_sig_list_;
   HashMap<std::pair<std::string, int32_t>, user_op::NaiveTensorDesc> arg2tensor_desc_;
+  int32_t parallel_num_;
 };
 
 class UserOpInferSbpSignatureFnContext : public user_op::InferSbpSignatureFnContext {
@@ -876,10 +876,10 @@ Maybe<void> UserOp::InferSbpSignature(
 
 Maybe<void> UserOp::GetSbpSignatures(
     const std::function<Maybe<const BlobDesc&>(const std::string&)>& LogicalBlobDesc4Ibn,
-    const ParallelDesc& parallel_desc, SbpSignatureList* sbp_sig_list) const {
+    int32_t parallel_num, SbpSignatureList* sbp_sig_list) const {
   CHECK_OR_RETURN(val_ != nullptr) << "cannot find op_type: "
                                    << op_conf().user_conf().op_type_name() << " in op registry!";
-  UserOpSbpContext sbp_ctx(this, sbp_sig_list, LogicalBlobDesc4Ibn);
+  UserOpSbpContext sbp_ctx(this, sbp_sig_list, LogicalBlobDesc4Ibn, parallel_num);
   JUST(val_->get_sbp_fn(&sbp_ctx));
   // Add Broadcast for source user op tick input
   if (val_->op_def.input_size() == 1 && input_bns().size() == 1
diff --git a/oneflow/core/operator/user_op.h b/oneflow/core/operator/user_op.h
index d0f39c8fce1..892111524cd 100644
--- a/oneflow/core/operator/user_op.h
+++ b/oneflow/core/operator/user_op.h
@@ -64,7 +64,7 @@ class UserOp final : public Operator {
                                    const ParallelDesc& parallel_desc) const override;
   Maybe<void> GetSbpSignatures(
       const std::function<Maybe<const BlobDesc&>(const std::string&)>& LogicalBlobDesc4Ibn,
-      const ParallelDesc& parallel_desc, SbpSignatureList* sbp_sig_list) const override;
+      int32_t parallel_num, SbpSignatureList* sbp_sig_list) const override;
   Maybe<void> GetNdSbpSignatureList(
       const std::function<Maybe<const BlobDesc&>(const std::string&)>& LogicalBlobDesc4Ibn,
       const ParallelDesc& parallel_desc,

From f30f29dd8b9103b5e44e4f98aed69206ab16554a Mon Sep 17 00:00:00 2001
From: Yipeng Li
Date: Mon, 24 Oct 2022 22:15:49 +0800
Subject: [PATCH 02/30] Get sbp_sig_list for each dimension of hierarchy
---
 oneflow/core/framework/sbp_infer_util.cpp |  9 +++++---
 oneflow/core/framework/sbp_infer_util.h   |  3 ++-
 oneflow/core/operator/operator.cpp        | 26 ++++++++++++++++-------
 3 files changed, 26 insertions(+), 12 deletions(-)

diff --git a/oneflow/core/framework/sbp_infer_util.cpp b/oneflow/core/framework/sbp_infer_util.cpp
index 3ec7562dd51..8211f98caf8 100644
--- a/oneflow/core/framework/sbp_infer_util.cpp
+++ b/oneflow/core/framework/sbp_infer_util.cpp
@@ -603,14 +603,17 @@ void SetNdSbpSignature(NdSbpSignature* nd_sbp_signature, const SbpSignature& sbp
 }
 
 void DfsGetNdSbpSignature(NdSbpSignature& nd_sbp_sig, int32_t depth, int32_t dims,
-                          const SbpSignatureList& sbp_sig_list,
+                          const Shape& hierarchy,
+                          const HashMap<int32_t, SbpSignatureList>& hierarchy_num2sbp_sig_list,
                           std::vector<NdSbpSignature>* nd_sbp_sig_list) {
   if (depth == dims) {
     nd_sbp_sig_list->push_back(nd_sbp_sig);
   } else {
-    for (const auto& sbp_signature : sbp_sig_list.sbp_signature()) {
+    for (const auto& sbp_signature :
+         hierarchy_num2sbp_sig_list.at(hierarchy.At(depth)).sbp_signature()) {
       SetNdSbpSignature(&nd_sbp_sig, sbp_signature, depth);
-      DfsGetNdSbpSignature(nd_sbp_sig, depth + 1, dims, sbp_sig_list, nd_sbp_sig_list);
+      DfsGetNdSbpSignature(nd_sbp_sig, depth + 1, dims, hierarchy, hierarchy_num2sbp_sig_list,
+                           nd_sbp_sig_list);
     }
   }
 }
diff --git a/oneflow/core/framework/sbp_infer_util.h b/oneflow/core/framework/sbp_infer_util.h
index 21d7da6ae90..afff052a4b5 100644
--- a/oneflow/core/framework/sbp_infer_util.h
+++ b/oneflow/core/framework/sbp_infer_util.h
@@ -62,7 +62,8 @@ void SetNdSbpSignature(NdSbpSignature* nd_sbp_signature, const SbpSignature& sbp
                        int32_t sbp_axis);
 
 void DfsGetNdSbpSignature(NdSbpSignature& nd_sbp_sig, int32_t depth, int32_t dims,
-                          const SbpSignatureList& sbp_sig_list,
+                          const Shape& hierarchy,
+                          const HashMap<int32_t, SbpSignatureList>& hierarchy_num2sbp_sig_list,
                           std::vector<NdSbpSignature>* nd_sbp_sig_list);
 
 // Compute storage for given NdSbp
diff --git a/oneflow/core/operator/operator.cpp b/oneflow/core/operator/operator.cpp
index 4e7f44113a9..124c6257111 100644
--- a/oneflow/core/operator/operator.cpp
+++ b/oneflow/core/operator/operator.cpp
@@ -499,7 +499,7 @@ Maybe<const Shape> Operator::GetInputOutputFastestTimeShape() const {
 Maybe<void> Operator::GetSbpSignaturesIf(
     const std::function<Maybe<const BlobDesc&>(const std::string&)>& LogicalBlobDesc4Ibn,
     const ParallelDesc& parallel_desc, SbpSignatureList* sbp_sig_list) const {
-  JUST(GetSbpSignatures(LogicalBlobDesc4Ibn, parallel_desc, sbp_sig_list));
+  JUST(GetSbpSignatures(LogicalBlobDesc4Ibn, parallel_desc.parallel_num(), sbp_sig_list));
   SbpSignatureBuilder()
       .Broadcast(input_bns())
       .Broadcast(output_bns())
@@ -522,18 +522,27 @@ Maybe<void> Operator::GetNdSbpSignatureList(
     const std::function<Maybe<const BlobDesc&>(const std::string&)>& LogicalBlobDesc4Ibn,
     const ParallelDesc& parallel_desc, std::vector<NdSbpSignature>* nd_sbp_sig_list) const {
   // Get 1D sbp signature list
-  SbpSignatureList sbp_sig_list;
-  JUST(GetSbpSignaturesIf(LogicalBlobDesc4Ibn, parallel_desc, &sbp_sig_list));
-  CHECK_GT_OR_RETURN(sbp_sig_list.sbp_signature_size(), 0)
-      << op_name() << " gets no sbp signature from GetSbpSignaturesIf function!";
+  HashMap<int32_t, SbpSignatureList> hierarchy_num2sbp_sig_list;
+  for (int32_t hierarchy_num : *parallel_desc.hierarchy()) {
+    if (hierarchy_num2sbp_sig_list.find(hierarchy_num) == hierarchy_num2sbp_sig_list.end()) {
+      auto* sbp_sig_list = &hierarchy_num2sbp_sig_list[hierarchy_num];
+      JUST(GetSbpSignaturesIf(LogicalBlobDesc4Ibn, parallel_desc.parallel_num(), sbp_sig_list));
+      CHECK_GT_OR_RETURN(sbp_sig_list->sbp_signature_size(), 0)
+          << op_name()
+          << " gets no sbp signature from GetSbpSignaturesIf function for hierarchy num: "
+          << hierarchy_num;
+    }
+  }
   int32_t sbp_dimension = parallel_desc.hierarchy()->NumAxes();
   NdSbpSignature nd_sbp_sig;
-  SbpSignatureToNdSbpSignature(sbp_sig_list.sbp_signature(0), &nd_sbp_sig);
+  SbpSignatureToNdSbpSignature(hierarchy_num2sbp_sig_list.begin()->second.sbp_signature(0),
+                               &nd_sbp_sig);
   ResizeNdSbpSignature(nd_sbp_sig, sbp_dimension);
   // ND sbp signature list would be direct product of 1D sbp signatures
   CHECK_OR_RETURN(nd_sbp_sig_list->empty());
-  DfsGetNdSbpSignature(nd_sbp_sig, 0, sbp_dimension, sbp_sig_list, nd_sbp_sig_list);
+  DfsGetNdSbpSignature(nd_sbp_sig, 0, sbp_dimension, *parallel_desc.hierarchy(),
+                       hierarchy_num2sbp_sig_list, nd_sbp_sig_list);
   return Maybe<void>::Ok();
 }
 
@@ -845,7 +854,8 @@ Maybe<void> Operator::InferSbpSignature(
   SbpSignatureList valid_sbp_sig_list;
   {
     SbpSignatureList sbp_sig_candidates;
-    JUST(GetSbpSignaturesIf(LogicalBlobDesc4Ibn, parallel_desc, &sbp_sig_candidates));
+    JUST(
+        GetSbpSignaturesIf(LogicalBlobDesc4Ibn, parallel_desc.parallel_num(), &sbp_sig_candidates));
     // filter sbp signatures by logical shape
     JUST(FilterAndCheckValidSbpSignatureListByLogicalShape(sbp_sig_candidates, LogicalBlobDesc4Ibn,
                                                            parallel_desc, &valid_sbp_sig_list));

From fdc7ee8558cab68aa9fa152cf1ba2a6dc2b4554e Mon Sep 17 00:00:00 2001
From: Yipeng Li
Date: Mon, 24 Oct 2022 22:20:21 +0800
Subject: [PATCH 03/30] Add test script and print out information

---
 oneflow/core/operator/operator.cpp | 55 +++++++++++++++++++++++-------
 1 file changed, 42 insertions(+), 13 deletions(-)

diff --git a/oneflow/core/operator/operator.cpp b/oneflow/core/operator/operator.cpp
index 124c6257111..e2a05e0f548 100644
--- a/oneflow/core/operator/operator.cpp
+++ b/oneflow/core/operator/operator.cpp
@@ -17,6 +17,7 @@ limitations under the License.
#include "oneflow/core/common/balanced_splitter.h" #include "oneflow/core/common/container_util.h" #include "oneflow/core/common/decorator.h" +#include "oneflow/core/rpc/include/global_process_ctx.h" #include "oneflow/core/vm/symbol_storage.h" #include "oneflow/core/framework/instructions_builder.h" #include "oneflow/core/framework/to_string.h" @@ -37,6 +38,12 @@ namespace oneflow { namespace { +std::string ParallelDesc2String(const ParallelDesc& parallel_desc) { + std::ostringstream out; + out << "hierarchy: " << *parallel_desc.hierarchy() << ", device: " << parallel_desc.device_tag(); + return out.str(); +} + DataType GetDataTypeFromBnInOpVec( std::function GetBlobDesc4BnInOp, const PbRpf& bn_in_ops) { @@ -787,6 +794,23 @@ Maybe Operator::GreedilyFindMinCopyCostNdSbp( producer_infer_hint4ibn->parallel_desc(), *JUST(GetParallelDesc4BnInOp(ibn)), requires_same_sbp[ibn_id]); sum_priority_ratio += priority_ratio; + + if (GlobalProcessCtx::Rank() == 0 + && op_name().find("model.t5_model.encoder.layers.0.self_attention-reshape-29") + != std::string::npos) { + if (i == 0) { + std::cout << "Producer " << NdSbpToString(producer_infer_hint4ibn->nd_sbp()) + << ", placement: " + << ParallelDesc2String(producer_infer_hint4ibn->parallel_desc()) + << std::endl; + std::cout << "Shape: " << producer_infer_hint4ibn->logical_blob_desc().shape() + << std::endl; + } + std::cout << "idx: " << i << ", sbp: " + << NdSbpToString(JUST(VectorAt(nd_sbp_sig_list, i)).bn_in_op2nd_sbp().at(ibn)) + << ", placement: " << ParallelDesc2String(*JUST(GetParallelDesc4BnInOp(ibn))) + << std::endl; + } // We do not accept any blob which has a priority ratio greater than 1 if (priority_ratio > 1.5) { total_copy_cost = GetMaxVal(); @@ -820,21 +844,26 @@ Maybe Operator::GreedilyFindMinCopyCostNdSbp( } } // Can't find any available sbp - if (select_sbp_idx == -1) { - std::ostringstream err; - err << "op: `" << op_name() << "` can't find available sbp signature." << std::endl; - err << "candidate nd sbp signature are: " - << *JUST(NdSbpSignatureListAsString(nd_sbp_sig_list, input_bns(), output_bns())); - err << ", but inputs sbp are:"; - for (int32_t ibn_id = 0; ibn_id < input_bns().size(); ibn_id++) { - const auto& ibn = input_bns().at(ibn_id); - const NdSbp& nd_sbp = JUST(NdSbpInferHint4Ibn(ibn))->nd_sbp(); - err << " " << ibn << ": " << NdSbpToString(nd_sbp); - if (requires_same_sbp[ibn_id]) { err << " [ transfer disabled ]"; } - err << ";"; + std::ostringstream err; + err << "op: `" << op_name() << "` can't find available sbp signature." 
<< std::endl; + err << "candidate nd sbp signature are: " + << *JUST(NdSbpSignatureListAsString(nd_sbp_sig_list, input_bns(), output_bns())); + err << ", but inputs sbp are:"; + for (int32_t ibn_id = 0; ibn_id < input_bns().size(); ibn_id++) { + const auto& ibn = input_bns().at(ibn_id); + const NdSbp& nd_sbp = JUST(NdSbpInferHint4Ibn(ibn))->nd_sbp(); + err << " " << ibn << ": " << NdSbpToString(nd_sbp); + if (requires_same_sbp[ibn_id]) { err << " [ transfer disabled ]"; } + err << ";"; + + if (GlobalProcessCtx::Rank() == 0 + && op_name().find("model.t5_model.encoder.layers.0.self_attention-reshape-29") + != std::string::npos) { + std::cout << err.str() << std::endl; + std::cout << "select idx: " << select_sbp_idx << std::endl; } - return Error::RuntimeError() << err.str(); + if (select_sbp_idx == -1) { return Error::RuntimeError() << err.str(); } } } nd_sbp_signature->CopyFrom(nd_sbp_sig_list.at(select_sbp_idx)); From e1b4a96cf9c14b265b953ac4207d3f77e720a76f Mon Sep 17 00:00:00 2001 From: Yipeng Li Date: Mon, 24 Oct 2022 22:47:33 +0800 Subject: [PATCH 04/30] Remove parallel description in GetSbpSignature() --- .../optimizer_placement_optimization_pass.cpp | 3 ++- oneflow/core/operator/dynamic_reshape_op.cpp | 4 ++-- oneflow/core/operator/operator.cpp | 12 ------------ oneflow/core/operator/operator.h | 9 --------- 4 files changed, 4 insertions(+), 24 deletions(-) diff --git a/oneflow/core/job_rewriter/optimizer_placement_optimization_pass.cpp b/oneflow/core/job_rewriter/optimizer_placement_optimization_pass.cpp index 2c6e16a8bb8..e62d2e7f3c5 100644 --- a/oneflow/core/job_rewriter/optimizer_placement_optimization_pass.cpp +++ b/oneflow/core/job_rewriter/optimizer_placement_optimization_pass.cpp @@ -297,7 +297,8 @@ bool IsS0SignatureSupported(const OpNode* node) { auto LogicalBlobDesc4Ibn = [&](const std::string& bn) -> Maybe { return Maybe(node->LogicalBlobDesc4Lbi(node->op().BnInOp2Lbi(bn))); }; - CHECK_JUST(node->op().GetSbpSignaturesIf(LogicalBlobDesc4Ibn, node->parallel_desc(), &list)); + CHECK_JUST(node->op().GetSbpSignaturesIf(LogicalBlobDesc4Ibn, + node->parallel_desc().parallel_num(), &list)); const auto IsInOutS0Parallel = [&](const SbpSignature& signature) { return IsS0Parallel(signature, node->op().SoleIbn()) && IsS0Parallel(signature, node->op().SoleObn()); diff --git a/oneflow/core/operator/dynamic_reshape_op.cpp b/oneflow/core/operator/dynamic_reshape_op.cpp index 34e90416d96..72ee5dc47c8 100644 --- a/oneflow/core/operator/dynamic_reshape_op.cpp +++ b/oneflow/core/operator/dynamic_reshape_op.cpp @@ -104,7 +104,7 @@ class DynamicReshapeOp final : public Operator { private: Maybe GetSbpSignatures( const std::function(const std::string&)>& LogicalBlobDesc4Ibn, - const ParallelDesc& parallel_desc, SbpSignatureList* sbp_sig_list) const override { + SbpSignatureList* sbp_sig_list) const override { SbpSignatureBuilder() .Split(input_bns(), 0) .Split(output_bns(), 0) @@ -144,7 +144,7 @@ class DynamicReshapeLikeOp final : public Operator { private: Maybe GetSbpSignatures( const std::function(const std::string&)>& LogicalBlobDesc4Ibn, - const ParallelDesc& parallel_desc, SbpSignatureList* sbp_sig_list) const override { + SbpSignatureList* sbp_sig_list) const override { SbpSignatureBuilder() .Split(input_bns(), 0) .Split(output_bns(), 0) diff --git a/oneflow/core/operator/operator.cpp b/oneflow/core/operator/operator.cpp index e2a05e0f548..788c1b95d8f 100644 --- a/oneflow/core/operator/operator.cpp +++ b/oneflow/core/operator/operator.cpp @@ -502,18 +502,6 @@ Maybe 
Operator::GetInputOutputFastestTimeShape() const { return input_output_fastest_time_shape_; } -// TODO: Delete this function. We never use parallel_desc in GetSbpSignature -Maybe Operator::GetSbpSignaturesIf( - const std::function(const std::string&)>& LogicalBlobDesc4Ibn, - const ParallelDesc& parallel_desc, SbpSignatureList* sbp_sig_list) const { - JUST(GetSbpSignatures(LogicalBlobDesc4Ibn, parallel_desc.parallel_num(), sbp_sig_list)); - SbpSignatureBuilder() - .Broadcast(input_bns()) - .Broadcast(output_bns()) - .Build(sbp_sig_list->mutable_sbp_signature()->Add()); - return Maybe::Ok(); -} - Maybe Operator::GetSbpSignaturesIf( const std::function(const std::string&)>& LogicalBlobDesc4Ibn, int32_t parallel_num, SbpSignatureList* sbp_sig_list) const { diff --git a/oneflow/core/operator/operator.h b/oneflow/core/operator/operator.h index 18ef8ed5fc7..6fa19f13068 100644 --- a/oneflow/core/operator/operator.h +++ b/oneflow/core/operator/operator.h @@ -170,9 +170,6 @@ class Operator { Maybe NdSbp4BnInOp(const std::string& bn_in_op) const; Maybe OptLocalParallel4BnInOp(const std::string& bn_in_op) const; - Maybe GetSbpSignaturesIf( - const std::function(const std::string&)>& LogicalBlobDesc4Ibn, - const ParallelDesc& parallel_desc, SbpSignatureList* sbp_sig_list) const; Maybe GetSbpSignaturesIf( const std::function(const std::string&)>& LogicalBlobDesc4Ibn, int32_t parallel_num, SbpSignatureList* sbp_sig_list) const; @@ -215,12 +212,6 @@ class Operator { virtual Maybe InferInternalBlobDescs( const std::function& GetBlobDesc4BnInOp, const ParallelContext* parallel_ctx, const JobDesc* job_desc) const; - // TODO: Delete this function. We never use parallel_desc in GetSbpSignature - virtual Maybe GetSbpSignatures( - const std::function(const std::string&)>& LogicalBlobDesc4Ibn, - const ParallelDesc& parallel_desc, SbpSignatureList* sbp_sig_list) const { - return GetSbpSignatures(LogicalBlobDesc4Ibn, sbp_sig_list); - } virtual Maybe GetSbpSignatures( const std::function(const std::string&)>& LogicalBlobDesc4Ibn, int32_t parallel_num, SbpSignatureList* sbp_sig_list) const { From dc23ff7802dc4227a7204744f06ef555c63a0e76 Mon Sep 17 00:00:00 2001 From: Yipeng Li Date: Mon, 24 Oct 2022 23:09:59 +0800 Subject: [PATCH 05/30] Fix small bug --- oneflow/core/operator/operator.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/oneflow/core/operator/operator.cpp b/oneflow/core/operator/operator.cpp index 788c1b95d8f..fc3300675cb 100644 --- a/oneflow/core/operator/operator.cpp +++ b/oneflow/core/operator/operator.cpp @@ -521,7 +521,7 @@ Maybe Operator::GetNdSbpSignatureList( for (int32_t hierarchy_num : *parallel_desc.hierarchy()) { if (hierarchy_num2sbp_sig_list.find(hierarchy_num) == hierarchy_num2sbp_sig_list.end()) { auto* sbp_sig_list = &hierarchy_num2sbp_sig_list[hierarchy_num]; - JUST(GetSbpSignaturesIf(LogicalBlobDesc4Ibn, parallel_desc.parallel_num(), sbp_sig_list)); + JUST(GetSbpSignaturesIf(LogicalBlobDesc4Ibn, hierarchy_num, sbp_sig_list)); CHECK_GT_OR_RETURN(sbp_sig_list->sbp_signature_size(), 0) << op_name() << " gets no sbp signature from GetSbpSignaturesIf function for hierarchy num: " From 195b0ea149c77374737751356b97f6bf2da240ff Mon Sep 17 00:00:00 2001 From: Yipeng Li Date: Mon, 24 Oct 2022 23:15:49 +0800 Subject: [PATCH 06/30] Disable InferNdSbp for reshape op --- oneflow/user/ops/reshape_user_op_util.cpp | 99 ----------------------- 1 file changed, 99 deletions(-) diff --git a/oneflow/user/ops/reshape_user_op_util.cpp b/oneflow/user/ops/reshape_user_op_util.cpp index 
--- a/oneflow/user/ops/reshape_user_op_util.cpp
+++ b/oneflow/user/ops/reshape_user_op_util.cpp
@@ -174,103 +174,4 @@ Maybe<void> ReshapeUserOpUtil::GetReshapeUserOpSbpSignatures(
   return Maybe<void>::Ok();
 }
 
-namespace {
-
-Maybe<void> GetInputNdSbp(user_op::InferNdSbpFnContext* ctx, const user_op::OpArg& in_arg,
-                          NdSbp* distribution) {
-  *distribution = ctx->NdSbpHint4InputArgNameAndIndex(in_arg.name(), in_arg.index());
-  const auto& constraints = ctx->nd_sbp_constraints();
-  if (constraints.bn_in_op2nd_sbp_size() != 0) {
-    const auto it =
-        constraints.bn_in_op2nd_sbp().find(GenRepeatedBn(in_arg.name(), in_arg.index()));
-    if (it != constraints.bn_in_op2nd_sbp().end()) { *distribution = it->second; }
-  }
-  return Maybe<void>::Ok();
-}
-
-Maybe<void> ApplySbpParallel(const SbpParallel& sbp, const int64_t parallel_num, Shape* shape) {
-  if (sbp.has_split_parallel()) {
-    const int64_t axis = sbp.split_parallel().axis();
-    CHECK_EQ_OR_RETURN(shape->At(axis) % parallel_num, 0)
-        << Error::RuntimeError() << "The size of tensor in the " << axis
-        << " must be an integer multiple of parallel_num, "
-        << "but got " << shape->At(axis) << " and " << parallel_num;
-    shape->Set(axis, shape->At(axis) / parallel_num);
-  }
-  return Maybe<void>::Ok();
-}
-
-}  // namespace
-
-Maybe<void> ReshapeUserOpUtil::InferNdSbp(user_op::InferNdSbpFnContext* ctx,
-                                          const Shape& logical_in_shape,
-                                          const Shape& logical_out_shape) {
-  const std::string& op_type_name = ctx->user_op_conf().op_type_name();
-  CHECK_OR_RETURN(op_type_name == "reshape" || op_type_name == "reshape_like")
-      << Error::RuntimeError() << "The op_type_name must be \"reshape\" or \"reshape_like\", "
-      << "but got " << op_type_name;
-  const bool is_reshape_like = (op_type_name == "reshape_like");
-  std::vector<user_op::OpArg> in_args({{"in", 0}});
-  if (is_reshape_like) { in_args.emplace_back(user_op::OpArg("like", 0)); }
-  HashMap<std::string, NdSbp> ibn2nd_sbp;
-  ibn2nd_sbp.reserve(in_args.size());
-  for (const auto& arg : in_args) {
-    NdSbp* in_distribution = ctx->NdSbp4ArgNameAndIndex(arg.name(), arg.index());
-    JUST(GetInputNdSbp(ctx, arg, in_distribution));
-    CHECK_OR_RETURN(
-        ibn2nd_sbp.emplace(GenRepeatedBn(arg.name(), arg.index()), *in_distribution).second)
-        << "emplace error";  // NOLINT(maybe-need-error-msg)
-  }
-  NdSbp* out_distribution = ctx->NdSbp4ArgNameAndIndex("out", 0);
-
-  Shape in_shape = logical_in_shape;
-  Shape out_shape = logical_out_shape;
-  const Shape& parallel_hierarchy = ctx->parallel_hierarchy();
-  for (int64_t i = 0; i < parallel_hierarchy.NumAxes(); ++i) {
-    SbpSignatureList sbp_sig_list;
-    user_op::UserOpSbpSignatureBuilder builder(&sbp_sig_list);
-    builder.Broadcast(in_args).Broadcast(user_op::OpArg("out", 0)).Build();
-    if (is_reshape_like) {
-      builder.PartialSum(user_op::OpArg("like", 0))
-          .Broadcast(user_op::OpArg("in", 0))
-          .Broadcast(user_op::OpArg("out", 0))
-          .Build();
-      builder.Broadcast(user_op::OpArg("like", 0))
-          .PartialSum(user_op::OpArg("in", 0))
-          .PartialSum(user_op::OpArg("out", 0))
-          .Build();
-      JUST(GetReshapeUserOpSbpSignatures(in_shape, out_shape, {{"in", 0}},
-                                         {{"like", 0}, {"out", 0}}, parallel_hierarchy.At(i),
-                                         &builder));
-    } else {
-      JUST(GetReshapeUserOpSbpSignatures(in_shape, out_shape, {{"in", 0}}, {{"out", 0}},
-                                         parallel_hierarchy.At(i), &builder));
-    }
-
-    const SbpSignature* matched_sbp_signature = nullptr;
-    for (const auto& sbp_signature : sbp_sig_list.sbp_signature()) {
-      bool all_match = true;
-      for (const auto& in_arg : in_args) {
-        std::string ibn = GenRepeatedBn(in_arg.name(), in_arg.index());
-        if (sbp_signature.bn_in_op2sbp_parallel().at(ibn) != ibn2nd_sbp.at(ibn).sbp_parallel(i)) {
-          all_match = false;
-          break;
-        }
-      }
-      if (all_match) {
-        matched_sbp_signature = &sbp_signature;
-        break;
-      }
-    }
-    CHECK_OR_RETURN(matched_sbp_signature != nullptr)
-        << "FusedLstmCellGrad::Pointer to the matched sbp signature is nullptr";
-    SbpParallel out_sbp = matched_sbp_signature->bn_in_op2sbp_parallel().at("out_0");
-    JUST(ApplySbpParallel(matched_sbp_signature->bn_in_op2sbp_parallel().at("in_0"),
-                          parallel_hierarchy.At(i), &in_shape));
-    JUST(ApplySbpParallel(out_sbp, parallel_hierarchy.At(i), &out_shape));
-    *(out_distribution->add_sbp_parallel()) = out_sbp;
-  }
-  return Maybe<void>::Ok();
-}
-
 }  // namespace oneflow

From f7d29d12052fd040a87b9c5dfc2ebeafabb1b78e Mon Sep 17 00:00:00 2001
From: Yipeng Li
Date: Tue, 25 Oct 2022 14:29:05 +0800
Subject: [PATCH 07/30] Revert "Add test script and print out information"

This reverts commit fdc7ee8558cab68aa9fa152cf1ba2a6dc2b4554e.

---
 oneflow/core/operator/operator.cpp | 55 +++++++-----------------------
 1 file changed, 13 insertions(+), 42 deletions(-)

diff --git a/oneflow/core/operator/operator.cpp b/oneflow/core/operator/operator.cpp
index fc3300675cb..d99127a1286 100644
--- a/oneflow/core/operator/operator.cpp
+++ b/oneflow/core/operator/operator.cpp
@@ -17,7 +17,6 @@ limitations under the License.
 #include "oneflow/core/common/balanced_splitter.h"
 #include "oneflow/core/common/container_util.h"
 #include "oneflow/core/common/decorator.h"
-#include "oneflow/core/rpc/include/global_process_ctx.h"
 #include "oneflow/core/vm/symbol_storage.h"
 #include "oneflow/core/framework/instructions_builder.h"
 #include "oneflow/core/framework/to_string.h"
@@ -38,12 +37,6 @@ namespace oneflow {
 
 namespace {
 
-std::string ParallelDesc2String(const ParallelDesc& parallel_desc) {
-  std::ostringstream out;
-  out << "hierarchy: " << *parallel_desc.hierarchy() << ", device: " << parallel_desc.device_tag();
-  return out.str();
-}
-
 DataType GetDataTypeFromBnInOpVec(
     std::function<const BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
     const PbRpf<std::string>& bn_in_ops) {
@@ -782,23 +775,6 @@ Maybe<void> Operator::GreedilyFindMinCopyCostNdSbp(
           producer_infer_hint4ibn->parallel_desc(), *JUST(GetParallelDesc4BnInOp(ibn)),
           requires_same_sbp[ibn_id]);
       sum_priority_ratio += priority_ratio;
-
-      if (GlobalProcessCtx::Rank() == 0
-          && op_name().find("model.t5_model.encoder.layers.0.self_attention-reshape-29")
-                 != std::string::npos) {
-        if (i == 0) {
-          std::cout << "Producer " << NdSbpToString(producer_infer_hint4ibn->nd_sbp())
-                    << ", placement: "
-                    << ParallelDesc2String(producer_infer_hint4ibn->parallel_desc())
-                    << std::endl;
-          std::cout << "Shape: " << producer_infer_hint4ibn->logical_blob_desc().shape()
-                    << std::endl;
-        }
-        std::cout << "idx: " << i << ", sbp: "
-                  << NdSbpToString(JUST(VectorAt(nd_sbp_sig_list, i)).bn_in_op2nd_sbp().at(ibn))
-                  << ", placement: " << ParallelDesc2String(*JUST(GetParallelDesc4BnInOp(ibn)))
-                  << std::endl;
-      }
       // We do not accept any blob which has a priority ratio greater than 1
       if (priority_ratio > 1.5) {
         total_copy_cost = GetMaxVal<float>();
@@ -832,26 +808,21 @@ Maybe<void> Operator::GreedilyFindMinCopyCostNdSbp(
     }
   }
   // Can't find any available sbp
-  std::ostringstream err;
-  err << "op: `" << op_name() << "` can't find available sbp signature." << std::endl;
<< std::endl; - err << "candidate nd sbp signature are: " - << *JUST(NdSbpSignatureListAsString(nd_sbp_sig_list, input_bns(), output_bns())); - err << ", but inputs sbp are:"; - for (int32_t ibn_id = 0; ibn_id < input_bns().size(); ibn_id++) { - const auto& ibn = input_bns().at(ibn_id); - const NdSbp& nd_sbp = JUST(NdSbpInferHint4Ibn(ibn))->nd_sbp(); - err << " " << ibn << ": " << NdSbpToString(nd_sbp); - if (requires_same_sbp[ibn_id]) { err << " [ transfer disabled ]"; } - err << ";"; - - if (GlobalProcessCtx::Rank() == 0 - && op_name().find("model.t5_model.encoder.layers.0.self_attention-reshape-29") - != std::string::npos) { - std::cout << err.str() << std::endl; - std::cout << "select idx: " << select_sbp_idx << std::endl; + if (select_sbp_idx == -1) { + std::ostringstream err; + err << "op: `" << op_name() << "` can't find available sbp signature." << std::endl; + err << "candidate nd sbp signature are: " + << *JUST(NdSbpSignatureListAsString(nd_sbp_sig_list, input_bns(), output_bns())); + err << ", but inputs sbp are:"; + for (int32_t ibn_id = 0; ibn_id < input_bns().size(); ibn_id++) { + const auto& ibn = input_bns().at(ibn_id); + const NdSbp& nd_sbp = JUST(NdSbpInferHint4Ibn(ibn))->nd_sbp(); + err << " " << ibn << ": " << NdSbpToString(nd_sbp); + if (requires_same_sbp[ibn_id]) { err << " [ transfer disabled ]"; } + err << ";"; } - if (select_sbp_idx == -1) { return Error::RuntimeError() << err.str(); } + return Error::RuntimeError() << err.str(); } } nd_sbp_signature->CopyFrom(nd_sbp_sig_list.at(select_sbp_idx)); From f20e222327e21166d5b5325e37c3cbe9ca4f4ac6 Mon Sep 17 00:00:00 2001 From: Yipeng Li Date: Tue, 25 Oct 2022 17:12:12 +0800 Subject: [PATCH 08/30] Use the same physical shape as eager did --- oneflow/core/job/nd_sbp_util.cpp | 88 ++++++++++++++++++++++++++++-- oneflow/core/job/nd_sbp_util.h | 2 + oneflow/core/operator/operator.cpp | 75 ++----------------------- 3 files changed, 89 insertions(+), 76 deletions(-) diff --git a/oneflow/core/job/nd_sbp_util.cpp b/oneflow/core/job/nd_sbp_util.cpp index c93974acc18..55b72ebac9f 100644 --- a/oneflow/core/job/nd_sbp_util.cpp +++ b/oneflow/core/job/nd_sbp_util.cpp @@ -17,9 +17,84 @@ limitations under the License. 
#include "oneflow/core/job/nd_sbp_util.h" #include "oneflow/core/common/balanced_splitter.h" #include "oneflow/core/common/nd_index_offset_helper.h" +#include "oneflow/core/framework/nd_sbp.h" +#include "oneflow/core/framework/placement_sbp_util.h" +#include "oneflow/core/job/lazy_mode.h" +#include "oneflow/core/register/tensor_slice_view.h" namespace oneflow { +namespace { + +template +Maybe Get1dHierarchyPhysicalShape(const Shape& logical_shape, + const SbpParallelT& sbp_parallel, + const int64_t parallel_num, const int64_t parallel_id) { + std::shared_ptr physical = std::make_shared(logical_shape); + + if (sbp_parallel.has_split_parallel()) { + const int64_t axis = sbp_parallel.split_parallel().axis(); + if (logical_shape.At(axis) > 0) { + CHECK_GE_OR_RETURN(logical_shape.At(axis), parallel_num); + const BalancedSplitter bs(logical_shape.At(axis), parallel_num); + physical->Set(axis, bs.At(parallel_id).size()); + } + } else if (sbp_parallel.has_broadcast_parallel() || sbp_parallel.has_partial_sum_parallel()) { + // do nothing + } else { + UNIMPLEMENTED(); + } + return physical; +} + +Maybe GetNdHierarchyPhysicalShape(const Shape& logical_shape, const NdSbp& nd_sbp, + const Shape& parallel_hierarchy, + const int64_t parallel_id) { + std::shared_ptr physical = std::make_shared(logical_shape); + Stride hierarch_stride(parallel_hierarchy); + FOR_RANGE(int64_t, i, 0, parallel_hierarchy.NumAxes()) { + const auto& sbp_parallel = nd_sbp.sbp_parallel(i); + if (sbp_parallel.has_split_parallel()) { + const int64_t split_axis = sbp_parallel.split_parallel().axis(); + if (LazyMode::is_enabled()) { + CHECK_EQ_OR_RETURN(physical->At(split_axis) % parallel_hierarchy.At(i), 0) + << Error::RuntimeError() << "In nn.Graph, expected size at split axis (" << split_axis + << ") of logical shape must be divisible by parallel num, but got logical_shape: " + << logical_shape.ToString() << ", placement: " << parallel_hierarchy + << ", nd_sbp: " << NdSbpToString(nd_sbp); + physical->Set(split_axis, physical->At(split_axis) / parallel_hierarchy.At(i)); + } else { + if (physical->At(split_axis) > 0) { + CHECK_GE_OR_RETURN(physical->At(split_axis), parallel_hierarchy.At(i)) + << Error::RuntimeError() << "Expected size at split axis (" << split_axis + << ") of logical shape must be be greater than or equal to parallel num, but got " + "logical_shape: " + << logical_shape.ToString() << ", placement: " << parallel_hierarchy + << ", nd_sbp: " << NdSbpToString(nd_sbp); + const BalancedSplitter bs(physical->At(split_axis), parallel_hierarchy.At(i)); + physical->Set(split_axis, bs.At(CalcIndex4Axis(parallel_id, hierarch_stride, i)).size()); + } + } + } + } + return physical; +} + +} // namespace + +Maybe GetPhysicalShape(const Shape& logical_shape, const NdSbp& nd_sbp, + const Shape& parallel_hierarchy, int64_t parallel_id) { + CHECK_GE_OR_RETURN(parallel_id, 0); + CHECK_LT_OR_RETURN(parallel_id, parallel_hierarchy.elem_cnt()); + CHECK_EQ_OR_RETURN(parallel_hierarchy.NumAxes(), nd_sbp.sbp_parallel_size()); + if (parallel_hierarchy.NumAxes() == 1) { + return Get1dHierarchyPhysicalShape(logical_shape, nd_sbp.sbp_parallel(0), + parallel_hierarchy.elem_cnt(), parallel_id); + } else { + return GetNdHierarchyPhysicalShape(logical_shape, nd_sbp, parallel_hierarchy, parallel_id); + } +} + std::vector GetTensorSliceView(const int64_t parallel_num, const SbpParallel& sbp_parallel, const BlobDesc& blob_desc) { @@ -97,11 +172,14 @@ TensorSliceView GetTensorSliceView4ParallelRank(const Shape& parallel_hierarchy, TensorSliceView 
GetTensorSliceView4ParallelId(const Shape& parallel_hierarchy, const NdSbp& nd_sbp, const Shape& logical_shape, int64_t parallel_id) { - NdIndexOffsetHelper hierarchy_index_helper( - parallel_hierarchy.dim_vec().data(), parallel_hierarchy.NumAxes()); - std::vector parallel_rank(SHAPE_MAX_AXIS_SIZE); - hierarchy_index_helper.OffsetToNdIndex(parallel_id, parallel_rank.data()); - return GetTensorSliceView4ParallelRank(parallel_hierarchy, nd_sbp, logical_shape, parallel_rank); + // NdIndexOffsetHelper hierarchy_index_helper( + // parallel_hierarchy.dim_vec().data(), parallel_hierarchy.NumAxes()); + // std::vector parallel_rank(SHAPE_MAX_AXIS_SIZE); + // hierarchy_index_helper.OffsetToNdIndex(parallel_id, parallel_rank.data()); + // return GetTensorSliceView4ParallelRank(parallel_hierarchy, nd_sbp, logical_shape, + // parallel_rank); + return TensorSliceView( + *CHECK_JUST(GetPhysicalShape(logical_shape, nd_sbp, parallel_hierarchy, parallel_id))); } std::vector GetTensorSliceView(const Shape& parallel_hierarchy, diff --git a/oneflow/core/job/nd_sbp_util.h b/oneflow/core/job/nd_sbp_util.h index be8b72c7746..09d93758f77 100644 --- a/oneflow/core/job/nd_sbp_util.h +++ b/oneflow/core/job/nd_sbp_util.h @@ -21,6 +21,8 @@ limitations under the License. namespace oneflow { +Maybe GetPhysicalShape(const Shape& logical_shape, const NdSbp& nd_sbp, + const Shape& parallel_hierarchy, int64_t parallel_id); std::vector GetTensorSliceView(int64_t parallel_num, const SbpParallel& sbp_parallel, const BlobDesc& blob_desc); diff --git a/oneflow/core/operator/operator.cpp b/oneflow/core/operator/operator.cpp index d99127a1286..45f98c9face 100644 --- a/oneflow/core/operator/operator.cpp +++ b/oneflow/core/operator/operator.cpp @@ -17,6 +17,7 @@ limitations under the License. 
#include "oneflow/core/common/balanced_splitter.h" #include "oneflow/core/common/container_util.h" #include "oneflow/core/common/decorator.h" +#include "oneflow/core/job/nd_sbp_util.h" #include "oneflow/core/vm/symbol_storage.h" #include "oneflow/core/framework/instructions_builder.h" #include "oneflow/core/framework/to_string.h" @@ -1608,84 +1609,16 @@ Maybe ConstructAndInferOp(const OperatorConf& op_conf, return op; } -namespace { - -template -Maybe Get1dHierarchyPhysicalShape(const Shape& logical_shape, - const SbpParallelT& sbp_parallel, - const int64_t parallel_num, const int64_t parallel_id) { - std::shared_ptr physical = std::make_shared(logical_shape); - - if (sbp_parallel.has_split_parallel()) { - const int64_t axis = sbp_parallel.split_parallel().axis(); - if (logical_shape.At(axis) > 0) { - CHECK_GE_OR_RETURN(logical_shape.At(axis), parallel_num); - const BalancedSplitter bs(logical_shape.At(axis), parallel_num); - physical->Set(axis, bs.At(parallel_id).size()); - } - } else if (sbp_parallel.has_broadcast_parallel() || sbp_parallel.has_partial_sum_parallel()) { - // do nothing - } else { - UNIMPLEMENTED(); - } - return physical; -} - -Maybe GetNdHierarchyPhysicalShape(const Shape& logical_shape, const NdSbp& nd_sbp, - const ParallelDesc& parallel_desc, - const int64_t parallel_id) { - const auto& parallel_hierarchy = *parallel_desc.hierarchy(); - std::shared_ptr physical = std::make_shared(logical_shape); - Stride hierarch_stride(parallel_hierarchy); - FOR_RANGE(int64_t, i, 0, parallel_hierarchy.NumAxes()) { - const auto& sbp_parallel = nd_sbp.sbp_parallel(i); - if (sbp_parallel.has_split_parallel()) { - const int64_t split_axis = sbp_parallel.split_parallel().axis(); - if (LazyMode::is_enabled()) { - CHECK_EQ_OR_RETURN(physical->At(split_axis) % parallel_hierarchy.At(i), 0) - << Error::RuntimeError() << "In nn.Graph, expected size at split axis (" << split_axis - << ") of logical shape must be divisible by parallel num, but got logical_shape: " - << logical_shape.ToString() - << ", placement: " << *JUST(PlacementToString(SymbolOf(parallel_desc))) - << ", nd_sbp: " << NdSbpToString(SymbolOf(nd_sbp)); - physical->Set(split_axis, physical->At(split_axis) / parallel_hierarchy.At(i)); - } else { - if (physical->At(split_axis) > 0) { - CHECK_GE_OR_RETURN(physical->At(split_axis), parallel_hierarchy.At(i)) - << Error::RuntimeError() << "Expected size at split axis (" << split_axis - << ") of logical shape must be be greater than or equal to parallel num, but got " - "logical_shape: " - << logical_shape.ToString() - << ", placement: " << *JUST(PlacementToString(SymbolOf(parallel_desc))) - << ", nd_sbp: " << NdSbpToString(SymbolOf(nd_sbp)); - const BalancedSplitter bs(physical->At(split_axis), parallel_hierarchy.At(i)); - physical->Set(split_axis, bs.At(CalcIndex4Axis(parallel_id, hierarch_stride, i)).size()); - } - } - } - } - return physical; -} - -} // namespace - Maybe GetPhysicalShape(const Shape& logical_shape, const NdSbp& nd_sbp, const ParallelDesc& parallel_desc, int64_t parallel_id) { - CHECK_GE_OR_RETURN(parallel_id, 0); - CHECK_LT_OR_RETURN(parallel_id, parallel_desc.hierarchy()->elem_cnt()); - CHECK_EQ_OR_RETURN(parallel_desc.hierarchy()->NumAxes(), nd_sbp.sbp_parallel_size()); - if (parallel_desc.hierarchy()->NumAxes() == 1) { - return Get1dHierarchyPhysicalShape(logical_shape, nd_sbp.sbp_parallel(0), - parallel_desc.hierarchy()->elem_cnt(), parallel_id); - } else { - return GetNdHierarchyPhysicalShape(logical_shape, nd_sbp, parallel_desc, parallel_id); - } + return 
GetPhysicalShape(logical_shape, nd_sbp, *parallel_desc.hierarchy(), parallel_id); } Maybe GetPhysicalShape(const Shape& logical_shape, const NdSbp& nd_sbp, const ParallelDesc& parallel_desc, const ParallelContext& parallel_ctx) { - return GetPhysicalShape(logical_shape, nd_sbp, parallel_desc, parallel_ctx.parallel_id()); + return GetPhysicalShape(logical_shape, nd_sbp, *parallel_desc.hierarchy(), + parallel_ctx.parallel_id()); } Maybe GetLogicalShape(const Shape& physical_shape, const NdSbp& nd_sbp, From cf332cef945ffa9a957939c6ec12750b85a6f54e Mon Sep 17 00:00:00 2001 From: Yipeng Li Date: Tue, 25 Oct 2022 17:17:03 +0800 Subject: [PATCH 09/30] Remove the difference between eager and lazy for physical shape --- oneflow/core/job/job_build_and_infer_ctx.cpp | 2 +- oneflow/core/job/nd_sbp_util.cpp | 33 ++++++++++---------- 2 files changed, 18 insertions(+), 17 deletions(-) diff --git a/oneflow/core/job/job_build_and_infer_ctx.cpp b/oneflow/core/job/job_build_and_infer_ctx.cpp index b2b2b5c1260..95ddd5b3532 100644 --- a/oneflow/core/job/job_build_and_infer_ctx.cpp +++ b/oneflow/core/job/job_build_and_infer_ctx.cpp @@ -583,7 +583,7 @@ Maybe JobBuildAndInferCtx::AddAndInferOp(const OperatorConf& op_con } JUST(AddLbiParallelConf2BlobPlacement(op, ParallelDesc4Obn)); // Check splitability - JUST(CheckOpBlobSplitability(op, parallel_desc.parallel_num())); + // JUST(CheckOpBlobSplitability(op, parallel_desc.parallel_num())); return op->GetOpAttributeWithoutOpNameAndLbn(); } diff --git a/oneflow/core/job/nd_sbp_util.cpp b/oneflow/core/job/nd_sbp_util.cpp index 55b72ebac9f..cdf5a65f404 100644 --- a/oneflow/core/job/nd_sbp_util.cpp +++ b/oneflow/core/job/nd_sbp_util.cpp @@ -56,25 +56,26 @@ Maybe GetNdHierarchyPhysicalShape(const Shape& logical_shape, const NdSbp const auto& sbp_parallel = nd_sbp.sbp_parallel(i); if (sbp_parallel.has_split_parallel()) { const int64_t split_axis = sbp_parallel.split_parallel().axis(); - if (LazyMode::is_enabled()) { - CHECK_EQ_OR_RETURN(physical->At(split_axis) % parallel_hierarchy.At(i), 0) - << Error::RuntimeError() << "In nn.Graph, expected size at split axis (" << split_axis - << ") of logical shape must be divisible by parallel num, but got logical_shape: " + // if (LazyMode::is_enabled()) { + // CHECK_EQ_OR_RETURN(physical->At(split_axis) % parallel_hierarchy.At(i), 0) + // << Error::RuntimeError() << "In nn.Graph, expected size at split axis (" << + // split_axis + // << ") of logical shape must be divisible by parallel num, but got logical_shape: " + // << logical_shape.ToString() << ", placement: " << parallel_hierarchy + // << ", nd_sbp: " << NdSbpToString(nd_sbp); + // physical->Set(split_axis, physical->At(split_axis) / parallel_hierarchy.At(i)); + // } else { + if (physical->At(split_axis) > 0) { + CHECK_GE_OR_RETURN(physical->At(split_axis), parallel_hierarchy.At(i)) + << Error::RuntimeError() << "Expected size at split axis (" << split_axis + << ") of logical shape must be be greater than or equal to parallel num, but got " + "logical_shape: " << logical_shape.ToString() << ", placement: " << parallel_hierarchy << ", nd_sbp: " << NdSbpToString(nd_sbp); - physical->Set(split_axis, physical->At(split_axis) / parallel_hierarchy.At(i)); - } else { - if (physical->At(split_axis) > 0) { - CHECK_GE_OR_RETURN(physical->At(split_axis), parallel_hierarchy.At(i)) - << Error::RuntimeError() << "Expected size at split axis (" << split_axis - << ") of logical shape must be be greater than or equal to parallel num, but got " - "logical_shape: " - << 
logical_shape.ToString() << ", placement: " << parallel_hierarchy - << ", nd_sbp: " << NdSbpToString(nd_sbp); - const BalancedSplitter bs(physical->At(split_axis), parallel_hierarchy.At(i)); - physical->Set(split_axis, bs.At(CalcIndex4Axis(parallel_id, hierarch_stride, i)).size()); - } + const BalancedSplitter bs(physical->At(split_axis), parallel_hierarchy.At(i)); + physical->Set(split_axis, bs.At(CalcIndex4Axis(parallel_id, hierarch_stride, i)).size()); } + // } } } return physical; From aa42b6d0472305d72f2d49d35f155dd4c9c189f4 Mon Sep 17 00:00:00 2001 From: Yipeng Li Date: Tue, 25 Oct 2022 17:21:45 +0800 Subject: [PATCH 10/30] Update the filter --- oneflow/core/framework/sbp_infer_util.cpp | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/oneflow/core/framework/sbp_infer_util.cpp b/oneflow/core/framework/sbp_infer_util.cpp index 8211f98caf8..dfb9f3b8698 100644 --- a/oneflow/core/framework/sbp_infer_util.cpp +++ b/oneflow/core/framework/sbp_infer_util.cpp @@ -639,10 +639,7 @@ double Storage4NdSbp(const NdSbp& nd_sbp, Shape& logical_shape, const Shape& par const int64_t axis = sbp_parallel.split_parallel().axis(); if (axis >= logical_shape.NumAxes()) { return kUnsupportedBoxing; } // Use completely average split to count the storage - if (logical_shape.At(axis) <= 0 - || (logical_shape.At(axis) % parallel_hierarchy.At(dim_sbp) > 0)) { - return kUnsupportedBoxing; - } + if (logical_shape.At(axis) < parallel_hierarchy.At(dim_sbp)) { return kUnsupportedBoxing; } logical_shape.Set(axis, logical_shape.At(axis) / parallel_hierarchy.At(dim_sbp)); } } From 6e877628007280bbe618183d2dbacc27b1e7133a Mon Sep 17 00:00:00 2001 From: Yipeng Li Date: Tue, 25 Oct 2022 18:28:21 +0800 Subject: [PATCH 11/30] Revert "Use the same physical shape as eager did" This reverts commit f20e222327e21166d5b5325e37c3cbe9ca4f4ac6. --- oneflow/core/job/nd_sbp_util.cpp | 89 ++---------------------------- oneflow/core/job/nd_sbp_util.h | 2 - oneflow/core/operator/operator.cpp | 75 +++++++++++++++++++++++-- 3 files changed, 76 insertions(+), 90 deletions(-) diff --git a/oneflow/core/job/nd_sbp_util.cpp b/oneflow/core/job/nd_sbp_util.cpp index cdf5a65f404..c93974acc18 100644 --- a/oneflow/core/job/nd_sbp_util.cpp +++ b/oneflow/core/job/nd_sbp_util.cpp @@ -17,85 +17,9 @@ limitations under the License. 
#include "oneflow/core/job/nd_sbp_util.h" #include "oneflow/core/common/balanced_splitter.h" #include "oneflow/core/common/nd_index_offset_helper.h" -#include "oneflow/core/framework/nd_sbp.h" -#include "oneflow/core/framework/placement_sbp_util.h" -#include "oneflow/core/job/lazy_mode.h" -#include "oneflow/core/register/tensor_slice_view.h" namespace oneflow { -namespace { - -template -Maybe Get1dHierarchyPhysicalShape(const Shape& logical_shape, - const SbpParallelT& sbp_parallel, - const int64_t parallel_num, const int64_t parallel_id) { - std::shared_ptr physical = std::make_shared(logical_shape); - - if (sbp_parallel.has_split_parallel()) { - const int64_t axis = sbp_parallel.split_parallel().axis(); - if (logical_shape.At(axis) > 0) { - CHECK_GE_OR_RETURN(logical_shape.At(axis), parallel_num); - const BalancedSplitter bs(logical_shape.At(axis), parallel_num); - physical->Set(axis, bs.At(parallel_id).size()); - } - } else if (sbp_parallel.has_broadcast_parallel() || sbp_parallel.has_partial_sum_parallel()) { - // do nothing - } else { - UNIMPLEMENTED(); - } - return physical; -} - -Maybe GetNdHierarchyPhysicalShape(const Shape& logical_shape, const NdSbp& nd_sbp, - const Shape& parallel_hierarchy, - const int64_t parallel_id) { - std::shared_ptr physical = std::make_shared(logical_shape); - Stride hierarch_stride(parallel_hierarchy); - FOR_RANGE(int64_t, i, 0, parallel_hierarchy.NumAxes()) { - const auto& sbp_parallel = nd_sbp.sbp_parallel(i); - if (sbp_parallel.has_split_parallel()) { - const int64_t split_axis = sbp_parallel.split_parallel().axis(); - // if (LazyMode::is_enabled()) { - // CHECK_EQ_OR_RETURN(physical->At(split_axis) % parallel_hierarchy.At(i), 0) - // << Error::RuntimeError() << "In nn.Graph, expected size at split axis (" << - // split_axis - // << ") of logical shape must be divisible by parallel num, but got logical_shape: " - // << logical_shape.ToString() << ", placement: " << parallel_hierarchy - // << ", nd_sbp: " << NdSbpToString(nd_sbp); - // physical->Set(split_axis, physical->At(split_axis) / parallel_hierarchy.At(i)); - // } else { - if (physical->At(split_axis) > 0) { - CHECK_GE_OR_RETURN(physical->At(split_axis), parallel_hierarchy.At(i)) - << Error::RuntimeError() << "Expected size at split axis (" << split_axis - << ") of logical shape must be be greater than or equal to parallel num, but got " - "logical_shape: " - << logical_shape.ToString() << ", placement: " << parallel_hierarchy - << ", nd_sbp: " << NdSbpToString(nd_sbp); - const BalancedSplitter bs(physical->At(split_axis), parallel_hierarchy.At(i)); - physical->Set(split_axis, bs.At(CalcIndex4Axis(parallel_id, hierarch_stride, i)).size()); - } - // } - } - } - return physical; -} - -} // namespace - -Maybe GetPhysicalShape(const Shape& logical_shape, const NdSbp& nd_sbp, - const Shape& parallel_hierarchy, int64_t parallel_id) { - CHECK_GE_OR_RETURN(parallel_id, 0); - CHECK_LT_OR_RETURN(parallel_id, parallel_hierarchy.elem_cnt()); - CHECK_EQ_OR_RETURN(parallel_hierarchy.NumAxes(), nd_sbp.sbp_parallel_size()); - if (parallel_hierarchy.NumAxes() == 1) { - return Get1dHierarchyPhysicalShape(logical_shape, nd_sbp.sbp_parallel(0), - parallel_hierarchy.elem_cnt(), parallel_id); - } else { - return GetNdHierarchyPhysicalShape(logical_shape, nd_sbp, parallel_hierarchy, parallel_id); - } -} - std::vector GetTensorSliceView(const int64_t parallel_num, const SbpParallel& sbp_parallel, const BlobDesc& blob_desc) { @@ -173,14 +97,11 @@ TensorSliceView GetTensorSliceView4ParallelRank(const Shape& 
parallel_hierarchy, TensorSliceView GetTensorSliceView4ParallelId(const Shape& parallel_hierarchy, const NdSbp& nd_sbp, const Shape& logical_shape, int64_t parallel_id) { - // NdIndexOffsetHelper hierarchy_index_helper( - // parallel_hierarchy.dim_vec().data(), parallel_hierarchy.NumAxes()); - // std::vector parallel_rank(SHAPE_MAX_AXIS_SIZE); - // hierarchy_index_helper.OffsetToNdIndex(parallel_id, parallel_rank.data()); - // return GetTensorSliceView4ParallelRank(parallel_hierarchy, nd_sbp, logical_shape, - // parallel_rank); - return TensorSliceView( - *CHECK_JUST(GetPhysicalShape(logical_shape, nd_sbp, parallel_hierarchy, parallel_id))); + NdIndexOffsetHelper hierarchy_index_helper( + parallel_hierarchy.dim_vec().data(), parallel_hierarchy.NumAxes()); + std::vector parallel_rank(SHAPE_MAX_AXIS_SIZE); + hierarchy_index_helper.OffsetToNdIndex(parallel_id, parallel_rank.data()); + return GetTensorSliceView4ParallelRank(parallel_hierarchy, nd_sbp, logical_shape, parallel_rank); } std::vector GetTensorSliceView(const Shape& parallel_hierarchy, diff --git a/oneflow/core/job/nd_sbp_util.h b/oneflow/core/job/nd_sbp_util.h index 09d93758f77..be8b72c7746 100644 --- a/oneflow/core/job/nd_sbp_util.h +++ b/oneflow/core/job/nd_sbp_util.h @@ -21,8 +21,6 @@ limitations under the License. namespace oneflow { -Maybe GetPhysicalShape(const Shape& logical_shape, const NdSbp& nd_sbp, - const Shape& parallel_hierarchy, int64_t parallel_id); std::vector GetTensorSliceView(int64_t parallel_num, const SbpParallel& sbp_parallel, const BlobDesc& blob_desc); diff --git a/oneflow/core/operator/operator.cpp b/oneflow/core/operator/operator.cpp index 45f98c9face..d99127a1286 100644 --- a/oneflow/core/operator/operator.cpp +++ b/oneflow/core/operator/operator.cpp @@ -17,7 +17,6 @@ limitations under the License. 
#include "oneflow/core/common/balanced_splitter.h" #include "oneflow/core/common/container_util.h" #include "oneflow/core/common/decorator.h" -#include "oneflow/core/job/nd_sbp_util.h" #include "oneflow/core/vm/symbol_storage.h" #include "oneflow/core/framework/instructions_builder.h" #include "oneflow/core/framework/to_string.h" @@ -1609,16 +1608,84 @@ Maybe ConstructAndInferOp(const OperatorConf& op_conf, return op; } +namespace { + +template +Maybe Get1dHierarchyPhysicalShape(const Shape& logical_shape, + const SbpParallelT& sbp_parallel, + const int64_t parallel_num, const int64_t parallel_id) { + std::shared_ptr physical = std::make_shared(logical_shape); + + if (sbp_parallel.has_split_parallel()) { + const int64_t axis = sbp_parallel.split_parallel().axis(); + if (logical_shape.At(axis) > 0) { + CHECK_GE_OR_RETURN(logical_shape.At(axis), parallel_num); + const BalancedSplitter bs(logical_shape.At(axis), parallel_num); + physical->Set(axis, bs.At(parallel_id).size()); + } + } else if (sbp_parallel.has_broadcast_parallel() || sbp_parallel.has_partial_sum_parallel()) { + // do nothing + } else { + UNIMPLEMENTED(); + } + return physical; +} + +Maybe GetNdHierarchyPhysicalShape(const Shape& logical_shape, const NdSbp& nd_sbp, + const ParallelDesc& parallel_desc, + const int64_t parallel_id) { + const auto& parallel_hierarchy = *parallel_desc.hierarchy(); + std::shared_ptr physical = std::make_shared(logical_shape); + Stride hierarch_stride(parallel_hierarchy); + FOR_RANGE(int64_t, i, 0, parallel_hierarchy.NumAxes()) { + const auto& sbp_parallel = nd_sbp.sbp_parallel(i); + if (sbp_parallel.has_split_parallel()) { + const int64_t split_axis = sbp_parallel.split_parallel().axis(); + if (LazyMode::is_enabled()) { + CHECK_EQ_OR_RETURN(physical->At(split_axis) % parallel_hierarchy.At(i), 0) + << Error::RuntimeError() << "In nn.Graph, expected size at split axis (" << split_axis + << ") of logical shape must be divisible by parallel num, but got logical_shape: " + << logical_shape.ToString() + << ", placement: " << *JUST(PlacementToString(SymbolOf(parallel_desc))) + << ", nd_sbp: " << NdSbpToString(SymbolOf(nd_sbp)); + physical->Set(split_axis, physical->At(split_axis) / parallel_hierarchy.At(i)); + } else { + if (physical->At(split_axis) > 0) { + CHECK_GE_OR_RETURN(physical->At(split_axis), parallel_hierarchy.At(i)) + << Error::RuntimeError() << "Expected size at split axis (" << split_axis + << ") of logical shape must be be greater than or equal to parallel num, but got " + "logical_shape: " + << logical_shape.ToString() + << ", placement: " << *JUST(PlacementToString(SymbolOf(parallel_desc))) + << ", nd_sbp: " << NdSbpToString(SymbolOf(nd_sbp)); + const BalancedSplitter bs(physical->At(split_axis), parallel_hierarchy.At(i)); + physical->Set(split_axis, bs.At(CalcIndex4Axis(parallel_id, hierarch_stride, i)).size()); + } + } + } + } + return physical; +} + +} // namespace + Maybe GetPhysicalShape(const Shape& logical_shape, const NdSbp& nd_sbp, const ParallelDesc& parallel_desc, int64_t parallel_id) { - return GetPhysicalShape(logical_shape, nd_sbp, *parallel_desc.hierarchy(), parallel_id); + CHECK_GE_OR_RETURN(parallel_id, 0); + CHECK_LT_OR_RETURN(parallel_id, parallel_desc.hierarchy()->elem_cnt()); + CHECK_EQ_OR_RETURN(parallel_desc.hierarchy()->NumAxes(), nd_sbp.sbp_parallel_size()); + if (parallel_desc.hierarchy()->NumAxes() == 1) { + return Get1dHierarchyPhysicalShape(logical_shape, nd_sbp.sbp_parallel(0), + parallel_desc.hierarchy()->elem_cnt(), parallel_id); + } else { + return 
GetNdHierarchyPhysicalShape(logical_shape, nd_sbp, parallel_desc, parallel_id); + } } Maybe GetPhysicalShape(const Shape& logical_shape, const NdSbp& nd_sbp, const ParallelDesc& parallel_desc, const ParallelContext& parallel_ctx) { - return GetPhysicalShape(logical_shape, nd_sbp, *parallel_desc.hierarchy(), - parallel_ctx.parallel_id()); + return GetPhysicalShape(logical_shape, nd_sbp, parallel_desc, parallel_ctx.parallel_id()); } Maybe GetLogicalShape(const Shape& physical_shape, const NdSbp& nd_sbp, From 0f1554dc9b08757f78bb4006e5e8d6fdab96eb8b Mon Sep 17 00:00:00 2001 From: Yipeng Li Date: Tue, 25 Oct 2022 18:40:07 +0800 Subject: [PATCH 12/30] Compute range for each rank --- oneflow/core/job/nd_sbp_util.cpp | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/oneflow/core/job/nd_sbp_util.cpp b/oneflow/core/job/nd_sbp_util.cpp index c93974acc18..f54d1463377 100644 --- a/oneflow/core/job/nd_sbp_util.cpp +++ b/oneflow/core/job/nd_sbp_util.cpp @@ -85,10 +85,11 @@ TensorSliceView GetTensorSliceView4ParallelRank(const Shape& parallel_hierarchy, CHECK_GE(split_axis, 0); CHECK_LT(split_axis, ranges.size()); CHECK_EQ(ranges[split_axis].size() % parallel_hierarchy.At(i), 0); - const int64_t range_size = ranges[split_axis].size() / parallel_hierarchy.At(i); - const int64_t dim_start = ranges[split_axis].begin() + parallel_rank.at(i) * range_size; + const BalancedSplitter bs(logical_shape.At(split_axis), parallel_hierarchy.At(i)); + const auto& range = bs.At(parallel_rank.at(i)); + const int64_t dim_start = ranges[split_axis].begin() + range.begin(); ranges[split_axis].mut_begin() = dim_start; - ranges[split_axis].mut_end() = dim_start + range_size; + ranges[split_axis].mut_end() = dim_start + range.size(); } } } From 4d7edc7ab2708aaca85a0ab45674c85561619ebe Mon Sep 17 00:00:00 2001 From: Yipeng Li Date: Tue, 25 Oct 2022 10:56:25 +0000 Subject: [PATCH 13/30] Compute position for range --- oneflow/core/job/nd_sbp_util.cpp | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/oneflow/core/job/nd_sbp_util.cpp b/oneflow/core/job/nd_sbp_util.cpp index f54d1463377..78409fbf140 100644 --- a/oneflow/core/job/nd_sbp_util.cpp +++ b/oneflow/core/job/nd_sbp_util.cpp @@ -78,18 +78,21 @@ TensorSliceView GetTensorSliceView4ParallelRank(const Shape& parallel_hierarchy, ranges[split_axis] = bs.At(id); } } else { + Shape physical_shape(logical_shape); FOR_RANGE(int64_t, i, 0, parallel_hierarchy.NumAxes()) { const SbpParallel& sbp_parallel = nd_sbp.sbp_parallel(i); if (sbp_parallel.has_split_parallel()) { const int64_t split_axis = sbp_parallel.split_parallel().axis(); CHECK_GE(split_axis, 0); CHECK_LT(split_axis, ranges.size()); - CHECK_EQ(ranges[split_axis].size() % parallel_hierarchy.At(i), 0); - const BalancedSplitter bs(logical_shape.At(split_axis), parallel_hierarchy.At(i)); + CHECK_GE(ranges[split_axis].size(), parallel_hierarchy.At(i)); + const BalancedSplitter bs(physical_shape.At(split_axis), parallel_hierarchy.At(i)); const auto& range = bs.At(parallel_rank.at(i)); + const int64_t range_size = range.size(); const int64_t dim_start = ranges[split_axis].begin() + range.begin(); + physical_shape.Set(split_axis, range_size); ranges[split_axis].mut_begin() = dim_start; - ranges[split_axis].mut_end() = dim_start + range.size(); + ranges[split_axis].mut_end() = dim_start + range_size; } } } From b2249c3da5933ce008f39ac33dd312a483990412 Mon Sep 17 00:00:00 2001 From: Yipeng Li Date: Tue, 25 Oct 2022 10:59:09 +0000 Subject: [PATCH 14/30] Remove the difference between eager 
and lazy --- oneflow/core/operator/operator.cpp | 35 +++++++++++++++--------------- 1 file changed, 18 insertions(+), 17 deletions(-) diff --git a/oneflow/core/operator/operator.cpp b/oneflow/core/operator/operator.cpp index d99127a1286..b225861d985 100644 --- a/oneflow/core/operator/operator.cpp +++ b/oneflow/core/operator/operator.cpp @@ -1641,27 +1641,28 @@ Maybe GetNdHierarchyPhysicalShape(const Shape& logical_shape, const NdSbp const auto& sbp_parallel = nd_sbp.sbp_parallel(i); if (sbp_parallel.has_split_parallel()) { const int64_t split_axis = sbp_parallel.split_parallel().axis(); - if (LazyMode::is_enabled()) { - CHECK_EQ_OR_RETURN(physical->At(split_axis) % parallel_hierarchy.At(i), 0) - << Error::RuntimeError() << "In nn.Graph, expected size at split axis (" << split_axis - << ") of logical shape must be divisible by parallel num, but got logical_shape: " + // if (LazyMode::is_enabled()) { + // CHECK_EQ_OR_RETURN(physical->At(split_axis) % parallel_hierarchy.At(i), 0) + // << Error::RuntimeError() << "In nn.Graph, expected size at split axis (" << + // split_axis + // << ") of logical shape must be divisible by parallel num, but got logical_shape: " + // << logical_shape.ToString() + // << ", placement: " << *JUST(PlacementToString(SymbolOf(parallel_desc))) + // << ", nd_sbp: " << NdSbpToString(SymbolOf(nd_sbp)); + // physical->Set(split_axis, physical->At(split_axis) / parallel_hierarchy.At(i)); + // } else { + if (physical->At(split_axis) > 0) { + CHECK_GE_OR_RETURN(physical->At(split_axis), parallel_hierarchy.At(i)) + << Error::RuntimeError() << "Expected size at split axis (" << split_axis + << ") of logical shape must be greater than or equal to parallel num, but got " + "logical_shape: " << logical_shape.ToString() << ", placement: " << *JUST(PlacementToString(SymbolOf(parallel_desc))) << ", nd_sbp: " << NdSbpToString(SymbolOf(nd_sbp)); - physical->Set(split_axis, physical->At(split_axis) / parallel_hierarchy.At(i)); - } else { - if (physical->At(split_axis) > 0) { - CHECK_GE_OR_RETURN(physical->At(split_axis), parallel_hierarchy.At(i)) - << Error::RuntimeError() << "Expected size at split axis (" << split_axis - << ") of logical shape must be greater than or equal to parallel num, but got " - "logical_shape: " - << logical_shape.ToString() - << ", placement: " << *JUST(PlacementToString(SymbolOf(parallel_desc))) - << ", nd_sbp: " << NdSbpToString(SymbolOf(nd_sbp)); - const BalancedSplitter bs(physical->At(split_axis), parallel_hierarchy.At(i)); - physical->Set(split_axis, bs.At(CalcIndex4Axis(parallel_id, hierarch_stride, i)).size()); - } + const BalancedSplitter bs(physical->At(split_axis), parallel_hierarchy.At(i)); + physical->Set(split_axis, bs.At(CalcIndex4Axis(parallel_id, hierarch_stride, i)).size()); } + // } } } return physical; From 3c1362b1c81930670485532a7f13931e0e85ac1d Mon Sep 17 00:00:00 2001 From: Yipeng Li Date: Tue, 25 Oct 2022 23:43:51 +0800 Subject: [PATCH 15/30] Allow unbalanced split for variables --- .../job_rewriter/optimizer_placement_optimization_pass.cpp | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/oneflow/core/job_rewriter/optimizer_placement_optimization_pass.cpp b/oneflow/core/job_rewriter/optimizer_placement_optimization_pass.cpp index e62d2e7f3c5..2c669e61247 100644 --- a/oneflow/core/job_rewriter/optimizer_placement_optimization_pass.cpp +++ b/oneflow/core/job_rewriter/optimizer_placement_optimization_pass.cpp @@ -338,8 +338,7 @@ bool IsSplitValid(const Shape& shape, const NdSbp& nd_sbp, const Shape& hierachy if 
(sbp.has_split_parallel()) { const int64_t dim = sbp.split_parallel().axis(); if (dim >= cur_shape.NumAxes()) { return false; } - // Evenly split. - if (cur_shape.At(dim) % hierachy.At(i) != 0) { return false; } + // Unbalanced split; take the minimum (floor) slice size. cur_shape.Set(dim, cur_shape.At(dim) / hierachy.At(i)); // Larger than min size. if (cur_shape.elem_cnt() < min_size) { return false; } From 58cdfb40b6536eb74c02174d3a69409676da374f Mon Sep 17 00:00:00 2001 From: Yipeng Li Date: Wed, 26 Oct 2022 12:32:29 +0000 Subject: [PATCH 16/30] Add test script and print out information --- oneflow/core/graph/exec_graph.cpp | 7 ++-- oneflow/core/operator/operator.cpp | 55 ++++++++++++++++++++++-------- 2 files changed, 46 insertions(+), 16 deletions(-) diff --git a/oneflow/core/graph/exec_graph.cpp b/oneflow/core/graph/exec_graph.cpp index 2c47076abc7..cd9668cb05a 100644 --- a/oneflow/core/graph/exec_graph.cpp +++ b/oneflow/core/graph/exec_graph.cpp @@ -76,8 +76,11 @@ namespace { Maybe CheckPhysicalBlobDesc(const BlobDesc& logical, const NdSbp& nd_sbp, const ParallelDesc& parallel_desc, const ParallelContext* parallel_ctx, const BlobDesc& physical) { - CHECK_EQ_OR_RETURN(physical.shape(), *JUST(GetPhysicalShape(logical.shape(), nd_sbp, - parallel_desc, *parallel_ctx))); + const Shape rhs = *JUST(GetPhysicalShape(logical.shape(), nd_sbp, parallel_desc, *parallel_ctx)); + CHECK_EQ_OR_RETURN(physical.shape(), rhs) + << ", parallel id: " << parallel_ctx->parallel_id() << ", logical shape: " << logical.shape() + << ", lhs: " << physical.shape() << ", rhs: " << rhs; return Maybe::Ok(); } diff --git a/oneflow/core/operator/operator.cpp b/oneflow/core/operator/operator.cpp index b225861d985..cdd610ae729 100644 --- a/oneflow/core/operator/operator.cpp +++ b/oneflow/core/operator/operator.cpp @@ -17,6 +17,7 @@ limitations under the License. 
#include "oneflow/core/common/balanced_splitter.h" #include "oneflow/core/common/container_util.h" #include "oneflow/core/common/decorator.h" +#include "oneflow/core/rpc/include/global_process_ctx.h" #include "oneflow/core/vm/symbol_storage.h" #include "oneflow/core/framework/instructions_builder.h" #include "oneflow/core/framework/to_string.h" @@ -37,6 +38,12 @@ namespace oneflow { namespace { +std::string ParallelDesc2String(const ParallelDesc& parallel_desc) { + std::ostringstream out; + out << "hierarchy: " << *parallel_desc.hierarchy() << ", device: " << parallel_desc.device_tag(); + return out.str(); +} + DataType GetDataTypeFromBnInOpVec( std::function GetBlobDesc4BnInOp, const PbRpf& bn_in_ops) { @@ -775,6 +782,22 @@ Maybe Operator::GreedilyFindMinCopyCostNdSbp( producer_infer_hint4ibn->parallel_desc(), *JUST(GetParallelDesc4BnInOp(ibn)), requires_same_sbp[ibn_id]); sum_priority_ratio += priority_ratio; + + if (GlobalProcessCtx::Rank() == 0 + && op_name().find("model.t5_model.embedding.word_embeddings.weight") + != std::string::npos) { + if (i == 0) { + std::cout << "Producer " << NdSbpToString(producer_infer_hint4ibn->nd_sbp()) + << ", placement: " + << ParallelDesc2String(producer_infer_hint4ibn->parallel_desc()) + << ", Shape: " << producer_infer_hint4ibn->logical_blob_desc().shape() + << std::endl; + } + std::cout << "idx: " << i << ", sbp: " + << NdSbpToString(JUST(VectorAt(nd_sbp_sig_list, i)).bn_in_op2nd_sbp().at(ibn)) + << ", placement: " << ParallelDesc2String(*JUST(GetParallelDesc4BnInOp(ibn))) + << std::endl; + } // We do not accept any blob which has a priority ratio greater than 1 if (priority_ratio > 1.5) { total_copy_cost = GetMaxVal(); @@ -808,21 +831,25 @@ Maybe Operator::GreedilyFindMinCopyCostNdSbp( } } // Can't find any available sbp - if (select_sbp_idx == -1) { - std::ostringstream err; - err << "op: `" << op_name() << "` can't find available sbp signature." << std::endl; - err << "candidate nd sbp signature are: " - << *JUST(NdSbpSignatureListAsString(nd_sbp_sig_list, input_bns(), output_bns())); - err << ", but inputs sbp are:"; - for (int32_t ibn_id = 0; ibn_id < input_bns().size(); ibn_id++) { - const auto& ibn = input_bns().at(ibn_id); - const NdSbp& nd_sbp = JUST(NdSbpInferHint4Ibn(ibn))->nd_sbp(); - err << " " << ibn << ": " << NdSbpToString(nd_sbp); - if (requires_same_sbp[ibn_id]) { err << " [ transfer disabled ]"; } - err << ";"; - } + std::ostringstream err; + err << "op: `" << op_name() << "` can't find available sbp signature." 
<< std::endl; + err << "candidate nd sbp signature are: " + << *JUST(NdSbpSignatureListAsString(nd_sbp_sig_list, input_bns(), output_bns())); + err << ", but inputs sbp are:"; + for (int32_t ibn_id = 0; ibn_id < input_bns().size(); ibn_id++) { + const auto& ibn = input_bns().at(ibn_id); + const NdSbp& nd_sbp = JUST(NdSbpInferHint4Ibn(ibn))->nd_sbp(); + err << " " << ibn << ": " << NdSbpToString(nd_sbp); + if (requires_same_sbp[ibn_id]) { err << " [ transfer disabled ]"; } + err << ";"; + } + + if (select_sbp_idx == -1) { return Error::RuntimeError() << err.str(); } - return Error::RuntimeError() << err.str(); + if (GlobalProcessCtx::Rank() == 0 + && op_name().find("model.t5_model.embedding.word_embeddings.weight") != std::string::npos) { + std::cout << err.str() << std::endl; + std::cout << "select idx: " << select_sbp_idx << std::endl; } } nd_sbp_signature->CopyFrom(nd_sbp_sig_list.at(select_sbp_idx)); From 9a974e12abcdead838bc9d1e6d34688fd65fbb96 Mon Sep 17 00:00:00 2001 From: Yipeng Li Date: Wed, 26 Oct 2022 12:33:37 +0000 Subject: [PATCH 17/30] Pass 2d test cases --- python/oneflow/test/graph/test_gbc2d.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/oneflow/test/graph/test_gbc2d.py b/python/oneflow/test/graph/test_gbc2d.py index d08ce287d17..0e459e3e298 100644 --- a/python/oneflow/test/graph/test_gbc2d.py +++ b/python/oneflow/test/graph/test_gbc2d.py @@ -53,7 +53,7 @@ def _test_general_basic_communication_same_placement(test_case, src_nd_sbp, dst_nd_sbp, # input placement = flow.placement("cuda", ranks=[[0, 1], [2, 3]]) - local_np = np.arange(4 * 4).reshape(4, 4) + local_np = np.arange(4 * 5).reshape(4, 5) x = flow.tensor(local_np, sbp=src_nd_sbp, placement=placement) # check eager boxing From 9243427f23f5e2cda7533bc1f6e462eeb27d5a3f Mon Sep 17 00:00:00 2001 From: Yipeng Li Date: Thu, 27 Oct 2022 09:44:30 +0000 Subject: [PATCH 18/30] Resolve conflict --- oneflow/core/operator/user_op.cpp | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/oneflow/core/operator/user_op.cpp b/oneflow/core/operator/user_op.cpp index c0f1c0f38b4..3845da8f566 100644 --- a/oneflow/core/operator/user_op.cpp +++ b/oneflow/core/operator/user_op.cpp @@ -376,7 +376,9 @@ class UserOpSbpContext : public user_op::SbpContext { DeviceType device_type() const override { return op_->device_type(); } - int64_t parallel_num() const override { return parallel_num_; } + int64_t parallel_num() const override { + return CHECK_JUST(op_->GetOpParallelDesc())->parallel_num(); + } int64_t hierarchy_value() const override { return hierarchy_value_; } From cca3adb6c9a77e116e445ab9ca63ca21d303e252 Mon Sep 17 00:00:00 2001 From: Yipeng Li Date: Fri, 28 Oct 2022 18:58:26 +0800 Subject: [PATCH 19/30] Can not merge some splits --- .../core/boxing/nd_sbp_dim_reduce_boxing.cpp | 10 +- oneflow/core/framework/sbp_infer_util.cpp | 142 +++++++++++++++--- oneflow/core/framework/sbp_infer_util.h | 10 +- .../group_boxing_by_dst_parallel.cpp | 5 +- .../insert_nccl_logical_op_pass.cpp | 2 +- oneflow/core/operator/operator.cpp | 2 +- 6 files changed, 134 insertions(+), 37 deletions(-) diff --git a/oneflow/core/boxing/nd_sbp_dim_reduce_boxing.cpp b/oneflow/core/boxing/nd_sbp_dim_reduce_boxing.cpp index 4e028ce538b..4c507a1edbd 100644 --- a/oneflow/core/boxing/nd_sbp_dim_reduce_boxing.cpp +++ b/oneflow/core/boxing/nd_sbp_dim_reduce_boxing.cpp @@ -27,7 +27,7 @@ namespace oneflow { namespace { Maybe, Symbol>> RawInOutPlacedNdSbpDimReduce( Symbol in, Symbol out, const Shape& 
logical_shape) { // reduce hierarchy ParallelDesc reduced_in_placement = *in->placement(); ParallelDesc reduced_out_placement = *out->placement(); @@ -35,14 +35,14 @@ Maybe, Symbol>> RawInOutPlacedNdSbpD NdSbp reduced_out_nd_sbp; InOutParallelDimReduce(*in->placement(), *out->placement(), *in->nd_sbp(), *out->nd_sbp(), &reduced_in_placement, &reduced_out_placement, &reduced_in_nd_sbp, - &reduced_out_nd_sbp); + &reduced_out_nd_sbp, logical_shape); return std::make_tuple( JUST(PlacedNdSbp::New(SymbolOf(reduced_in_nd_sbp), SymbolOf(reduced_in_placement))), JUST(PlacedNdSbp::New(SymbolOf(reduced_out_nd_sbp), SymbolOf(reduced_out_placement)))); } constexpr auto* InOutPlacedNdSbpDimReduce = - DECORATE(&RawInOutPlacedNdSbpDimReduce, ThreadLocalCached); + DECORATE(&RawInOutPlacedNdSbpDimReduce, ThreadLocalCachedCopiable); // NOLINTBEGIN(maybe-need-error-msg) Maybe RawCheckParallelDimReduce(Symbol in, Symbol out, @@ -51,7 +51,7 @@ Maybe RawCheckParallelDimReduce(Symbol in, Symbolplacement()->device_tag(), out->placement()->device_tag()); Symbol reduced_in; Symbol reduced_out; - std::tie(reduced_in, reduced_out) = *JUST(InOutPlacedNdSbpDimReduce(in, out)); + std::tie(reduced_in, reduced_out) = *JUST(InOutPlacedNdSbpDimReduce(in, out, logical_shape)); for (int64_t in_parallel_id = 0; in_parallel_id < in->placement()->parallel_num(); ++in_parallel_id) { @@ -102,7 +102,7 @@ Maybe ParallelDimReduce(const std::shared_ptr& tensor, Symbol reduced_in; Symbol reduced_out; - std::tie(reduced_in, reduced_out) = *JUST(InOutPlacedNdSbpDimReduce(in, out)); + std::tie(reduced_in, reduced_out) = *JUST(InOutPlacedNdSbpDimReduce(in, out, *tensor->shape())); const std::shared_ptr& local_tensor = JUST(tensor->cur_rank_phy_tensor()); diff --git a/oneflow/core/framework/sbp_infer_util.cpp b/oneflow/core/framework/sbp_infer_util.cpp index 95388e102b8..722f7c11f1e 100644 --- a/oneflow/core/framework/sbp_infer_util.cpp +++ b/oneflow/core/framework/sbp_infer_util.cpp @@ -271,7 +271,8 @@ Maybe ComputeEagerCopyCostBetweenNdSbp(const NdSbp& producer_sbp_paralle NdSbp reduced_out_nd_sbp; InOutParallelDimReduce(producer_parallel_desc, consumer_parallel_desc, producer_sbp_parallel, consumer_sbp_parallel, &reduced_in_parallel_desc, - &reduced_out_parallel_desc, &reduced_in_nd_sbp, &reduced_out_nd_sbp); + &reduced_out_parallel_desc, &reduced_in_nd_sbp, &reduced_out_nd_sbp, + logical_blob_desc.shape()); const auto& in_hierarchy = reduced_in_parallel_desc.hierarchy(); const auto& out_hierarchy = reduced_out_parallel_desc.hierarchy(); @@ -422,6 +423,24 @@ void CollaborativeParallelDimReduce(const ParallelDesc& in_parallel_desc, *reduced_out_parallel_desc = ParallelDesc(reduced_out_parallel_conf); } +// We can not just simply merging two same split +// For example, shape = [6], we are trying to merge [2, 2]: (S0, S0) -> [4]: S0 +// For each rank, [4]: S0 has number of data: 2, 2, 1, 1 +// For each rank, [2]: S0 has number of data: 3, 3 +// For each rank, [2, 2]: (S0, S0) has number of data: 2, 1, 2, 1 +// Thus {[2, 2]: (S0, S0)} != {[4]: S0} for shape [6] +// However {[2, 2]: (S0, S0)} == {[4]: S0} for shape [4], [5], [7], [8] +// More specifically, {[a, b]: (Si, Si)} == {[a*b]: Si} if and only if +// shape value % (a * b) == 0, 1, a*b - 1 +bool CanMergeSplit(int32_t shape_value, int32_t merged_split_hierarchy_value) { + int32_t remainder = shape_value % merged_split_hierarchy_value; + if (remainder <= 1 || remainder == merged_split_hierarchy_value - 1) { + return true; + } else { + return false; + } +} + } // namespace int32_t 
PartialRatio4Producer(const NdSbp& sbp_producer, @@ -434,29 +453,98 @@ int32_t BroadcastRatio4Consumer(const NdSbp& sbp_consumer, return Ratio4Sbp(sbp_consumer, consumer_parallel_desc, &SbpParallel::has_broadcast_parallel); } -void NdSbpDimReduce(const ParallelDesc& parallel_desc, const NdSbp& nd_sbp, - ParallelDesc* reduced_parallel_desc, NdSbp* reduced_nd_sbp) { - const auto& hierarchy = parallel_desc.hierarchy(); - DimVector reduced_hierarchy; - FOR_RANGE(int64_t, i, 0, hierarchy->NumAxes()) { - if (hierarchy->At(i) != 1) { - if (reduced_nd_sbp->sbp_parallel().empty() - || (nd_sbp.sbp_parallel(i) - != reduced_nd_sbp->sbp_parallel(reduced_nd_sbp->sbp_parallel_size() - 1))) { - reduced_hierarchy.emplace_back(hierarchy->At(i)); - *reduced_nd_sbp->add_sbp_parallel() = nd_sbp.sbp_parallel(i); +void NdSbpDimReduce(const Shape& hierarchy, const NdSbp& nd_sbp, Shape* reduced_hierarchy, + NdSbp* reduced_nd_sbp, const Shape& logical_shape) { + reduced_hierarchy->clear(); + reduced_nd_sbp->clear_sbp_parallel(); + // At this moment, if we have [2, 4, 3, 7]: (S0, S1, S0, S0) for logical shape [601, 301, 999] + // We hold the split when accessing the current dimension + // Do the true splitting until we reach the next step + // dim = 0, split_axis2holding_reduced_shapes: {(0: 601)} + // dim = 1, split_axis2holding_reduced_shapes: {(0: 300, 301), (1: 601)} + // dim = 2, split_axis2holding_reduced_shapes: {(0: 300, 301), (1: 150, 151)} + // dim = 3, split_axis2holding_reduced_shapes: {(0: 100, 101), (1: 150, 151)} + HashMap> split_axis2holding_reduced_shapes; + std::vector last_holding_reduced_shapes; + int32_t last_split_axis = -1; + auto add_to_reduced_sbp_hierarchy = [&](int32_t hierarchy_dim) { + // Clear the last holding split axis + if (last_split_axis >= 0) { + auto& holding_reduced_shapes = split_axis2holding_reduced_shapes[last_split_axis]; + holding_reduced_shapes.clear(); + for (int32_t last_holding_reduced_shape : last_holding_reduced_shapes) { + int32_t quotient = last_holding_reduced_shape / reduced_hierarchy->back(); + if (last_holding_reduced_shape % reduced_hierarchy->back() != 0) { + holding_reduced_shapes.insert(quotient + 1); + } + holding_reduced_shapes.insert(quotient); + } + } + // Add a new sbp_parallel and a new hierarchy dimension + const auto& curr_sbp_parallel = nd_sbp.sbp_parallel(hierarchy_dim); + reduced_hierarchy->emplace_back(hierarchy.At(hierarchy_dim)); + *reduced_nd_sbp->add_sbp_parallel() = curr_sbp_parallel; + // Hold the current split shape + if (curr_sbp_parallel.has_split_parallel()) { + last_holding_reduced_shapes.clear(); + last_split_axis = curr_sbp_parallel.split_parallel().axis(); + auto it = split_axis2holding_reduced_shapes.find(last_split_axis); + if (it == split_axis2holding_reduced_shapes.end()) { + // Looking at a dimension which is never splitted before + // Shape: [601, ...], sbp: (S0, ...) + last_holding_reduced_shapes.push_back(logical_shape.At(last_split_axis)); } else { - reduced_hierarchy.back() *= hierarchy->At(i); + // This dimension is splitted before + // Shape: [601, 301, ...], sbp: (S0, S1, B, S0, ...), hierarchy: [2, 3, 100, 7, ...] + // Looking at i = 3, we hold the second S0, but 601 is already splitted by the first S0. 
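(Aside for the reader, not part of the patch: the merging rule that guards this reduction is CanMergeSplit above. Here is a minimal, self-contained C++ sketch of that rule, replaying the shape = 6 example from its comment; CanMergeSplitSketch is a hypothetical stand-in, and the per-rank sizes in the comments follow the BalancedSplitter behavior of handing out the larger slices first.)

#include <cassert>
#include <cstdint>

// Mirrors CanMergeSplit: [a, b]: (Si, Si) equals [a*b]: Si exactly when
// shape % (a*b) is 0, 1, or a*b - 1.
bool CanMergeSplitSketch(int32_t shape_value, int32_t merged_split_hierarchy_value) {
  const int32_t remainder = shape_value % merged_split_hierarchy_value;
  return remainder <= 1 || remainder == merged_split_hierarchy_value - 1;
}

int main() {
  assert(!CanMergeSplitSketch(6, 4));  // [2, 2]: (S0, S0) gives 2, 1, 2, 1 but [4]: S0 gives 2, 2, 1, 1
  assert(CanMergeSplitSketch(5, 4));   // both give 2, 1, 1, 1
  assert(CanMergeSplitSketch(7, 4));   // both give 2, 2, 2, 1
  assert(CanMergeSplitSketch(8, 4));   // both give 2, 2, 2, 2
  return 0;
}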
+ // split_axis2holding_reduced_shapes: {(0: 300, 301), (1: 100, 101)} + last_holding_reduced_shapes.assign(it->second.begin(), it->second.end()); + } + } else { + last_split_axis = -1; + } + }; + for (int32_t hierarchy_dim = 0; hierarchy_dim < hierarchy.NumAxes(); hierarchy_dim++) { + // Shrink those dimension with hierarchy value = 1 + if (hierarchy.At(hierarchy_dim) == 1) { continue; } + if (reduced_hierarchy->empty()) { + // Empty hierarchy, add to the back + add_to_reduced_sbp_hierarchy(hierarchy_dim); + continue; + } + const auto& current_sbp_parallel = nd_sbp.sbp_parallel(hierarchy_dim); + if (current_sbp_parallel + == reduced_nd_sbp->sbp_parallel(reduced_nd_sbp->sbp_parallel_size() - 1)) { + int32_t merged_hierarchy_value = reduced_hierarchy->back() * hierarchy.At(hierarchy_dim); + // You can merge two sbp with B or P. + // If sbp = S, then you need to make sure that all the shape value can be splitted + if (!current_sbp_parallel.has_split_parallel() + || std::all_of(last_holding_reduced_shapes.begin(), last_holding_reduced_shapes.end(), + [&](int32_t i) { return CanMergeSplit(i, merged_hierarchy_value); })) { + // Merge sbp and hierarchy + reduced_hierarchy->back() = merged_hierarchy_value; + continue; } } + // Can not merge, add to the back + add_to_reduced_sbp_hierarchy(hierarchy_dim); } // [1, 1, ..., 1]: Any --> [1]: (B) - if (reduced_hierarchy.empty()) { - reduced_hierarchy.emplace_back(hierarchy->At(0)); + if (reduced_hierarchy->empty()) { + reduced_hierarchy->emplace_back(hierarchy.At(0)); reduced_nd_sbp->add_sbp_parallel()->mutable_broadcast_parallel(); } +} + +void NdSbpDimReduce(const ParallelDesc& parallel_desc, const NdSbp& nd_sbp, + ParallelDesc* reduced_parallel_desc, NdSbp* reduced_nd_sbp, + const Shape& logical_shape) { + Shape reduced_hierarchy; + NdSbpDimReduce(*parallel_desc.hierarchy(), nd_sbp, &reduced_hierarchy, reduced_nd_sbp, + logical_shape); + ParallelConf reduced_parallel_conf = parallel_desc.parallel_conf(); - Shape(reduced_hierarchy).ToProto(reduced_parallel_conf.mutable_hierarchy()); + reduced_hierarchy.ToProto(reduced_parallel_conf.mutable_hierarchy()); *reduced_parallel_desc = ParallelDesc(reduced_parallel_conf); } @@ -464,7 +552,7 @@ void InOutParallelDimReduce(const ParallelDesc& in_parallel_desc, const ParallelDesc& out_parallel_desc, const NdSbp& in_nd_sbp, const NdSbp& out_nd_sbp, ParallelDesc* reduced_in_parallel_desc, ParallelDesc* reduced_out_parallel_desc, NdSbp* reduced_in_nd_sbp, - NdSbp* reduced_out_nd_sbp) { + NdSbp* reduced_out_nd_sbp, const Shape& logical_shape) { const int64_t in_hierarchy_axes = in_parallel_desc.hierarchy()->NumAxes(); const int64_t out_hierarchy_axes = out_parallel_desc.hierarchy()->NumAxes(); if (in_hierarchy_axes == 1 && out_hierarchy_axes == 1) { @@ -473,8 +561,10 @@ void InOutParallelDimReduce(const ParallelDesc& in_parallel_desc, *reduced_in_nd_sbp = in_nd_sbp; *reduced_out_nd_sbp = out_nd_sbp; } else if (in_hierarchy_axes != out_hierarchy_axes) { - NdSbpDimReduce(in_parallel_desc, in_nd_sbp, reduced_in_parallel_desc, reduced_in_nd_sbp); - NdSbpDimReduce(out_parallel_desc, out_nd_sbp, reduced_out_parallel_desc, reduced_out_nd_sbp); + NdSbpDimReduce(in_parallel_desc, in_nd_sbp, reduced_in_parallel_desc, reduced_in_nd_sbp, + logical_shape); + NdSbpDimReduce(out_parallel_desc, out_nd_sbp, reduced_out_parallel_desc, reduced_out_nd_sbp, + logical_shape); } else { CollaborativeParallelDimReduce(in_parallel_desc, out_parallel_desc, in_nd_sbp, out_nd_sbp, reduced_in_parallel_desc, reduced_out_parallel_desc, @@ -497,7 
+587,8 @@ Maybe ComputeLazyCopyCostBetweenNdSbp(const NdSbp& producer_sbp_parallel NdSbp reduced_out_nd_sbp; InOutParallelDimReduce(producer_parallel_desc, consumer_parallel_desc, producer_sbp_parallel, consumer_sbp_parallel, &reduced_in_parallel_desc, - &reduced_out_parallel_desc, &reduced_in_nd_sbp, &reduced_out_nd_sbp); + &reduced_out_parallel_desc, &reduced_in_nd_sbp, &reduced_out_nd_sbp, + logical_blob_desc.shape()); const auto& in_hierarchy = reduced_in_parallel_desc.hierarchy(); const auto& out_hierarchy = reduced_out_parallel_desc.hierarchy(); @@ -675,12 +766,12 @@ Maybe ComputeCopyCostWithMiddleNodes(const NdSbp& producer_sbp_parallel, ParallelDesc reduced_in_parallel_desc = producer_parallel_desc; NdSbp reduced_in_nd_sbp; NdSbpDimReduce(producer_parallel_desc, producer_sbp_parallel, &reduced_in_parallel_desc, - &reduced_in_nd_sbp); + &reduced_in_nd_sbp, logical_blob_desc.shape()); ParallelDesc reduced_out_parallel_desc = consumer_parallel_desc; NdSbp reduced_out_nd_sbp; NdSbpDimReduce(consumer_parallel_desc, consumer_sbp_parallel, &reduced_out_parallel_desc, - &reduced_out_nd_sbp); + &reduced_out_nd_sbp, logical_blob_desc.shape()); // In 90% of the transfer, we would have the same parallel description for producer and consumer // We need to speed it up and give an approximation of the cost if (reduced_in_parallel_desc == reduced_out_parallel_desc @@ -754,7 +845,8 @@ Maybe ComputeCopyCostWithMiddleNodes(const NdSbp& producer_sbp_parallel, // Decide the priority to infer sbp double ComputeSbpInferPriority(const NdSbp& producer_nd_sbp, const NdSbp& consumer_nd_sbp, const ParallelDesc& producer_parallel_desc, - const ParallelDesc& consumer_parallel_desc, bool requires_same_sbp) { + const ParallelDesc& consumer_parallel_desc, bool requires_same_sbp, + const Shape& logical_shape) { if (producer_nd_sbp == consumer_nd_sbp && producer_parallel_desc == consumer_parallel_desc) { // Highest priority: this blob have the same placement and sbp on both the producer and // consumer @@ -764,13 +856,13 @@ double ComputeSbpInferPriority(const NdSbp& producer_nd_sbp, const NdSbp& consum ParallelDesc reduced_in_parallel_desc = producer_parallel_desc; NdSbp reduced_in_nd_sbp; NdSbpDimReduce(producer_parallel_desc, producer_nd_sbp, &reduced_in_parallel_desc, - &reduced_in_nd_sbp); + &reduced_in_nd_sbp, logical_shape); // Dim reduction for consumer ParallelDesc reduced_out_parallel_desc = consumer_parallel_desc; NdSbp reduced_out_nd_sbp; NdSbpDimReduce(consumer_parallel_desc, consumer_nd_sbp, &reduced_out_parallel_desc, - &reduced_out_nd_sbp); + &reduced_out_nd_sbp, logical_shape); if (requires_same_sbp) { // This blob does not support boxing diff --git a/oneflow/core/framework/sbp_infer_util.h b/oneflow/core/framework/sbp_infer_util.h index fabb13edbfa..4bca4e5562a 100644 --- a/oneflow/core/framework/sbp_infer_util.h +++ b/oneflow/core/framework/sbp_infer_util.h @@ -43,14 +43,17 @@ int32_t PartialRatio4Producer(const NdSbp& sbp_producer, int32_t BroadcastRatio4Consumer(const NdSbp& sbp_consumer, const ParallelDesc& consumer_parallel_desc); +void NdSbpDimReduce(const Shape& hierarchy, const NdSbp& nd_sbp, Shape* reduced_hierarchy, + NdSbp* reduced_nd_sbp, const Shape& logical_shape); void NdSbpDimReduce(const ParallelDesc& parallel_desc, const NdSbp& nd_sbp, - ParallelDesc* reduced_parallel_desc, NdSbp* reduced_nd_sbp); + ParallelDesc* reduced_parallel_desc, NdSbp* reduced_nd_sbp, + const Shape& logical_shape); void InOutParallelDimReduce(const ParallelDesc& in_parallel_desc, const ParallelDesc& 
out_parallel_desc, const NdSbp& in_nd_sbp, const NdSbp& out_nd_sbp, ParallelDesc* reduced_in_parallel_desc, ParallelDesc* reduced_out_parallel_desc, NdSbp* reduced_in_nd_sbp, - NdSbp* reduced_out_nd_sbp); + NdSbp* reduced_out_nd_sbp, const Shape& logical_shape); double GetValidMaxCopyCost(); @@ -105,7 +108,8 @@ Maybe ComputeCopyCostWithMiddleNodes(const NdSbp& producer_sbp_parallel, double ComputeSbpInferPriority(const NdSbp& producer_sbp_parallel, const NdSbp& consumer_sbp_parallel, const ParallelDesc& producer_parallel_desc, - const ParallelDesc& consumer_parallel_desc, bool requires_same_sbp); + const ParallelDesc& consumer_parallel_desc, bool requires_same_sbp, + const Shape& logical_shape); // The transfer ratio for general basic communication // Cost = ratio * data amount diff --git a/oneflow/core/job_rewriter/group_boxing_by_dst_parallel.cpp b/oneflow/core/job_rewriter/group_boxing_by_dst_parallel.cpp index dd3e8039af7..fe77fbbdd68 100644 --- a/oneflow/core/job_rewriter/group_boxing_by_dst_parallel.cpp +++ b/oneflow/core/job_rewriter/group_boxing_by_dst_parallel.cpp @@ -66,6 +66,7 @@ Maybe GroupBoxingByDstParallel(const OpGraph& op_graph, JobBuilder* job_bu if (blob_modifier_.has_is_mutable() && blob_modifier_.is_mutable()) { continue; } const LogicalBlobId& lbi = node->op().BnInOp2Lbi(ibn); const OpNode& producer = node->ProducerOpNode4Lbi(lbi); + const auto& logical_shape = node->LogicalBlobDesc4Lbi(lbi).shape(); const NdSbp& producer_nd_sbp = producer.NdSbp4Lbi(lbi); const std::string& producer_lbn = *CHECK_JUST(producer.op().obn4lbi(lbi)); const ParallelDesc& producer_parallel_desc = @@ -73,7 +74,7 @@ Maybe GroupBoxingByDstParallel(const OpGraph& op_graph, JobBuilder* job_bu ParallelDesc reduced_in_parallel_desc = producer_parallel_desc; NdSbp reduced_in_nd_sbp; NdSbpDimReduce(producer_parallel_desc, producer_nd_sbp, &reduced_in_parallel_desc, - &reduced_in_nd_sbp); + &reduced_in_nd_sbp, logical_shape); const NdSbp& consumer_nd_sbp = node->NdSbp4BnInOp(ibn); const ParallelDesc& consumer_parallel_desc = @@ -81,7 +82,7 @@ Maybe GroupBoxingByDstParallel(const OpGraph& op_graph, JobBuilder* job_bu ParallelDesc reduced_out_parallel_desc = consumer_parallel_desc; NdSbp reduced_out_nd_sbp; NdSbpDimReduce(consumer_parallel_desc, consumer_nd_sbp, &reduced_out_parallel_desc, - &reduced_out_nd_sbp); + &reduced_out_nd_sbp, logical_shape); if (reduced_in_parallel_desc == reduced_out_parallel_desc && reduced_in_nd_sbp == reduced_out_nd_sbp) { diff --git a/oneflow/core/job_rewriter/insert_nccl_logical_op_pass.cpp b/oneflow/core/job_rewriter/insert_nccl_logical_op_pass.cpp index 20885c6633e..b10e300542e 100644 --- a/oneflow/core/job_rewriter/insert_nccl_logical_op_pass.cpp +++ b/oneflow/core/job_rewriter/insert_nccl_logical_op_pass.cpp @@ -422,7 +422,7 @@ bool TryBuildNcclLogicalOpConf(OperatorConf* ret, const OpNode* src_node, const InOutParallelDimReduce(src_node->parallel_desc(), dst_node->parallel_desc(), src_node->NdSbp4Lbi(lbi), dst_node->NdSbp4Lbi(lbi), src_reduced_parallel_desc, dst_reduced_parallel_desc, src_reduced_nd_sbp, - dst_reduced_nd_sbp); + dst_reduced_nd_sbp, logical_blob_desc.shape()); CHECK_EQ(src_reduced_parallel_desc->parallel_num(), dst_reduced_parallel_desc->parallel_num()); std::shared_ptr src_reduced_hierarchy = src_reduced_parallel_desc->hierarchy(); diff --git a/oneflow/core/operator/operator.cpp b/oneflow/core/operator/operator.cpp index 2eb50bffd66..751af17d166 100644 --- a/oneflow/core/operator/operator.cpp +++ b/oneflow/core/operator/operator.cpp @@ -785,7 +785,7 
@@ Maybe Operator::GreedilyFindMinCopyCostNdSbp( producer_infer_hint4ibn->nd_sbp(), JUST(VectorAt(nd_sbp_sig_list, i)).bn_in_op2nd_sbp().at(ibn), producer_infer_hint4ibn->parallel_desc(), *JUST(GetParallelDesc4BnInOp(ibn)), - requires_same_sbp[ibn_id]); + requires_same_sbp[ibn_id], producer_infer_hint4ibn->logical_blob_desc().shape()); sum_priority_ratio += priority_ratio; if (GlobalProcessCtx::Rank() == 0 From 3d06827e4ae9a75a765eb4fc9788db84ccb8a4f2 Mon Sep 17 00:00:00 2001 From: Yipeng Li Date: Fri, 28 Oct 2022 21:26:09 +0800 Subject: [PATCH 20/30] Reduce in and out sbp simultaneously --- oneflow/core/framework/sbp_infer_util.cpp | 107 ++++++++++++++-------- oneflow/core/framework/sbp_infer_util.h | 3 + 2 files changed, 70 insertions(+), 40 deletions(-) diff --git a/oneflow/core/framework/sbp_infer_util.cpp b/oneflow/core/framework/sbp_infer_util.cpp index 722f7c11f1e..a3b53d71e45 100644 --- a/oneflow/core/framework/sbp_infer_util.cpp +++ b/oneflow/core/framework/sbp_infer_util.cpp @@ -376,6 +376,7 @@ Maybe GetComputeCopyCostFunc() { } } +// TODO: Remove this void CollaborativeParallelDimReduce(const ParallelDesc& in_parallel_desc, const ParallelDesc& out_parallel_desc, const NdSbp& in_nd_sbp, const NdSbp& out_nd_sbp, ParallelDesc* reduced_in_parallel_desc, @@ -455,8 +456,14 @@ int32_t BroadcastRatio4Consumer(const NdSbp& sbp_consumer, void NdSbpDimReduce(const Shape& hierarchy, const NdSbp& nd_sbp, Shape* reduced_hierarchy, NdSbp* reduced_nd_sbp, const Shape& logical_shape) { + NdSbpsDimReduce(hierarchy, {&nd_sbp}, reduced_hierarchy, {reduced_nd_sbp}, logical_shape); +} + +void NdSbpsDimReduce(const Shape& hierarchy, const std::vector& nd_sbps, + Shape* reduced_hierarchy, const std::vector& reduced_nd_sbps, + const Shape& logical_shape) { reduced_hierarchy->clear(); - reduced_nd_sbp->clear_sbp_parallel(); + for (auto& reduced_nd_sbp : reduced_nd_sbps) { reduced_nd_sbp->clear_sbp_parallel(); } // At this moment, if we have [2, 4, 3, 7]: (S0, S1, S0, S0) for logical shape [601, 301, 999] // We hold the split when accessing the current dimension // Do the true splitting until we reach the next step @@ -464,45 +471,56 @@ void NdSbpDimReduce(const Shape& hierarchy, const NdSbp& nd_sbp, Shape* reduced_ // dim = 1, split_axis2holding_reduced_shapes: {(0: 300, 301), (1: 601)} // dim = 2, split_axis2holding_reduced_shapes: {(0: 300, 301), (1: 150, 151)} // dim = 3, split_axis2holding_reduced_shapes: {(0: 100, 101), (1: 150, 151)} - HashMap> split_axis2holding_reduced_shapes; - std::vector last_holding_reduced_shapes; - int32_t last_split_axis = -1; + int32_t sbp_num = nd_sbps.size(); + std::vector>> index2split_axis2holding_reduced_shapes(sbp_num); + std::vector> index2last_holding_reduced_shapes(sbp_num); + std::vector last_split_axises(sbp_num, -1); + std::vector indexes(sbp_num); + for (int32_t index = 0; index < sbp_num; index++) { indexes[index] = index; } auto add_to_reduced_sbp_hierarchy = [&](int32_t hierarchy_dim) { // Clear the last holding split axis - if (last_split_axis >= 0) { - auto& holding_reduced_shapes = split_axis2holding_reduced_shapes[last_split_axis]; - holding_reduced_shapes.clear(); - for (int32_t last_holding_reduced_shape : last_holding_reduced_shapes) { - int32_t quotient = last_holding_reduced_shape / reduced_hierarchy->back(); - if (last_holding_reduced_shape % reduced_hierarchy->back() != 0) { - holding_reduced_shapes.insert(quotient + 1); + for (int32_t index = 0; index < sbp_num; index++) { + auto& split_axis2holding_reduced_shapes = 
index2split_axis2holding_reduced_shapes[index]; + auto& last_holding_reduced_shapes = index2last_holding_reduced_shapes[index]; + auto& last_split_axis = last_split_axises[index]; + auto& nd_sbp = nd_sbps[index]; + auto& reduced_nd_sbp = reduced_nd_sbps[index]; + if (last_split_axis >= 0) { + auto& holding_reduced_shapes = split_axis2holding_reduced_shapes[last_split_axis]; + holding_reduced_shapes.clear(); + for (int32_t last_holding_reduced_shape : last_holding_reduced_shapes) { + int32_t quotient = last_holding_reduced_shape / reduced_hierarchy->back(); + if (last_holding_reduced_shape % reduced_hierarchy->back() != 0) { + holding_reduced_shapes.insert(quotient + 1); + } + holding_reduced_shapes.insert(quotient); } - holding_reduced_shapes.insert(quotient); } - } - // Add a new sbp_parallel and a new hierarchy dimension - const auto& curr_sbp_parallel = nd_sbp.sbp_parallel(hierarchy_dim); - reduced_hierarchy->emplace_back(hierarchy.At(hierarchy_dim)); - *reduced_nd_sbp->add_sbp_parallel() = curr_sbp_parallel; - // Hold the current split shape - if (curr_sbp_parallel.has_split_parallel()) { - last_holding_reduced_shapes.clear(); - last_split_axis = curr_sbp_parallel.split_parallel().axis(); - auto it = split_axis2holding_reduced_shapes.find(last_split_axis); - if (it == split_axis2holding_reduced_shapes.end()) { - // Looking at a dimension which is never splitted before - // Shape: [601, ...], sbp: (S0, ...) - last_holding_reduced_shapes.push_back(logical_shape.At(last_split_axis)); + // Add a new sbp_parallel and a new hierarchy dimension + const auto& curr_sbp_parallel = nd_sbp->sbp_parallel(hierarchy_dim); + *reduced_nd_sbp->add_sbp_parallel() = curr_sbp_parallel; + // Hold the current split shape + if (curr_sbp_parallel.has_split_parallel()) { + last_holding_reduced_shapes.clear(); + last_split_axis = curr_sbp_parallel.split_parallel().axis(); + auto it = split_axis2holding_reduced_shapes.find(last_split_axis); + if (it == split_axis2holding_reduced_shapes.end()) { + // Looking at a dimension which is never splitted before + // Shape: [601, ...], sbp: (S0, ...) + last_holding_reduced_shapes.push_back(logical_shape.At(last_split_axis)); + } else { + // This dimension is splitted before + // Shape: [601, 301, ...], sbp: (S0, S1, B, S0, ...), hierarchy: [2, 3, 100, 7, ...] + // Looking at i = 3, we hold the second S0, but 601 is already splitted by the first S0. + // split_axis2holding_reduced_shapes: {(0: 300, 301), (1: 100, 101)} + last_holding_reduced_shapes.assign(it->second.begin(), it->second.end()); + } } else { - // This dimension is splitted before - // Shape: [601, 301, ...], sbp: (S0, S1, B, S0, ...), hierarchy: [2, 3, 100, 7, ...] - // Looking at i = 3, we hold the second S0, but 601 is already splitted by the first S0. 
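(Aside for the reader, not part of the patch: the reason NdSbpsDimReduce takes all nd_sbps at once is that a hierarchy dimension may only be merged when every nd_sbp repeats the same sbp there; reducing producer and consumer independently could collapse one side's hierarchy but not the other's. A sketch under simplified assumptions, with sbps as plain strings and the CanMergeSplit shape checks left out; ReduceTogether is a hypothetical stand-in for the real function.)

#include <cassert>
#include <string>
#include <vector>

// Merge adjacent hierarchy dims only where every nd_sbp repeats the same sbp.
std::vector<int> ReduceTogether(const std::vector<int>& hierarchy,
                                const std::vector<std::vector<std::string>>& nd_sbps) {
  std::vector<int> reduced;
  std::vector<size_t> kept;  // index of the original dim kept at each reduced position
  for (size_t d = 0; d < hierarchy.size(); ++d) {
    bool mergeable = !reduced.empty();
    for (const auto& sbp : nd_sbps) {
      if (reduced.empty() || sbp[d] != sbp[kept.back()]) { mergeable = false; }
    }
    if (mergeable) {
      reduced.back() *= hierarchy[d];
    } else {
      reduced.push_back(hierarchy[d]);
      kept.push_back(d);
    }
  }
  return reduced;
}

int main() {
  // in = (S0, S0), out = (S0, S1) on hierarchy [2, 2]: alone, `in` would
  // collapse to [4], but jointly both sides stay [2, 2], so producer and
  // consumer ranks keep the same 2d coordinates.
  assert((ReduceTogether({2, 2}, {{"S0", "S0"}, {"S0", "S1"}}) == std::vector<int>{2, 2}));
  assert((ReduceTogether({2, 2}, {{"S0", "S0"}}) == std::vector<int>{4}));
  return 0;
}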
- // split_axis2holding_reduced_shapes: {(0: 300, 301), (1: 100, 101)} - last_holding_reduced_shapes.assign(it->second.begin(), it->second.end()); + last_split_axis = -1; } - } else { - last_split_axis = -1; } + // Add a new hierarchy dimension + reduced_hierarchy->emplace_back(hierarchy.At(hierarchy_dim)); }; for (int32_t hierarchy_dim = 0; hierarchy_dim < hierarchy.NumAxes(); hierarchy_dim++) { // Shrink those dimension with hierarchy value = 1 @@ -512,15 +530,22 @@ void NdSbpDimReduce(const Shape& hierarchy, const NdSbp& nd_sbp, Shape* reduced_ add_to_reduced_sbp_hierarchy(hierarchy_dim); continue; } - const auto& current_sbp_parallel = nd_sbp.sbp_parallel(hierarchy_dim); - if (current_sbp_parallel - == reduced_nd_sbp->sbp_parallel(reduced_nd_sbp->sbp_parallel_size() - 1)) { + if (std::all_of(indexes.begin(), indexes.end(), [&](int32_t index) { + // reduced_hierarchy->size() == reduced_nd_sbps[index]->sbp_parallel_size() + // Basically, current nd sbp == reduced nd sbp.back() + return nd_sbps[index]->sbp_parallel(hierarchy_dim) + == reduced_nd_sbps[index]->sbp_parallel(reduced_hierarchy->size() - 1); + })) { int32_t merged_hierarchy_value = reduced_hierarchy->back() * hierarchy.At(hierarchy_dim); // You can merge two sbp with B or P. // If sbp = S, then you need to make sure that all the shape value can be splitted - if (!current_sbp_parallel.has_split_parallel() - || std::all_of(last_holding_reduced_shapes.begin(), last_holding_reduced_shapes.end(), - [&](int32_t i) { return CanMergeSplit(i, merged_hierarchy_value); })) { + if (std::all_of(indexes.begin(), indexes.end(), [&](int32_t index) { + return !nd_sbps[index]->sbp_parallel(hierarchy_dim).has_split_parallel() + || std::all_of(index2last_holding_reduced_shapes[index].begin(), + index2last_holding_reduced_shapes[index].end(), [&](int32_t i) { + return CanMergeSplit(i, merged_hierarchy_value); + }); + })) { // Merge sbp and hierarchy reduced_hierarchy->back() = merged_hierarchy_value; continue; @@ -532,7 +557,9 @@ void NdSbpDimReduce(const Shape& hierarchy, const NdSbp& nd_sbp, Shape* reduced_ // [1, 1, ..., 1]: Any --> [1]: (B) if (reduced_hierarchy->empty()) { reduced_hierarchy->emplace_back(hierarchy.At(0)); - reduced_nd_sbp->add_sbp_parallel()->mutable_broadcast_parallel(); + for (auto& reduced_nd_sbp : reduced_nd_sbps) { + reduced_nd_sbp->add_sbp_parallel()->mutable_broadcast_parallel(); + } } } diff --git a/oneflow/core/framework/sbp_infer_util.h b/oneflow/core/framework/sbp_infer_util.h index 4bca4e5562a..8454d767c78 100644 --- a/oneflow/core/framework/sbp_infer_util.h +++ b/oneflow/core/framework/sbp_infer_util.h @@ -45,6 +45,9 @@ int32_t BroadcastRatio4Consumer(const NdSbp& sbp_consumer, void NdSbpDimReduce(const Shape& hierarchy, const NdSbp& nd_sbp, Shape* reduced_hierarchy, NdSbp* reduced_nd_sbp, const Shape& logical_shape); +void NdSbpsDimReduce(const Shape& hierarchy, const std::vector& nd_sbps, + Shape* reduced_hierarchy, const std::vector& reduced_nd_sbps, + const Shape& logical_shape); void NdSbpDimReduce(const ParallelDesc& parallel_desc, const NdSbp& nd_sbp, ParallelDesc* reduced_parallel_desc, NdSbp* reduced_nd_sbp, const Shape& logical_shape); From 34c3395c64cb6cb261b36fde6e0d60948142bba7 Mon Sep 17 00:00:00 2001 From: Yipeng Li Date: Mon, 31 Oct 2022 17:56:39 +0800 Subject: [PATCH 21/30] Speed up for 1d sbp Package up the function for replacing hierarchy --- oneflow/core/framework/sbp_infer_util.cpp | 26 +++++++++++++++++++---- 1 file changed, 22 insertions(+), 4 deletions(-) diff --git 
a/oneflow/core/framework/sbp_infer_util.cpp b/oneflow/core/framework/sbp_infer_util.cpp index a3b53d71e45..3a82d58ace9 100644 --- a/oneflow/core/framework/sbp_infer_util.cpp +++ b/oneflow/core/framework/sbp_infer_util.cpp @@ -424,6 +424,14 @@ void CollaborativeParallelDimReduce(const ParallelDesc& in_parallel_desc, *reduced_out_parallel_desc = ParallelDesc(reduced_out_parallel_conf); } +// Replace the hierarchy and then create a new parallel description +void ReplaceHierarchy4ParallelDesc(const ParallelDesc& old_parallel_desc, + const Shape& new_hierarchy, ParallelDesc* new_parallel_desc) { + ParallelConf new_parallel_conf = old_parallel_desc.parallel_conf(); + new_hierarchy.ToProto(new_parallel_conf.mutable_hierarchy()); + *new_parallel_desc = ParallelDesc(new_parallel_conf); +} + // We can not just simply merging two same split // For example, shape = [6], we are trying to merge [2, 2]: (S0, S0) -> [4]: S0 // For each rank, [4]: S0 has number of data: 2, 2, 1, 1 @@ -462,6 +470,13 @@ void NdSbpDimReduce(const Shape& hierarchy, const NdSbp& nd_sbp, Shape* reduced_ void NdSbpsDimReduce(const Shape& hierarchy, const std::vector& nd_sbps, Shape* reduced_hierarchy, const std::vector& reduced_nd_sbps, const Shape& logical_shape) { + int32_t sbp_num = nd_sbps.size(); + // Speed up for 1d sbp + if (hierarchy.NumAxes() == 1) { + *reduced_hierarchy = hierarchy; + for (int32_t index = 0; index < sbp_num; index++) { *reduced_nd_sbps[index] = *nd_sbps[index]; } + return; + } reduced_hierarchy->clear(); for (auto& reduced_nd_sbp : reduced_nd_sbps) { reduced_nd_sbp->clear_sbp_parallel(); } // At this moment, if we have [2, 4, 3, 7]: (S0, S1, S0, S0) for logical shape [601, 301, 999] @@ -471,7 +486,6 @@ void NdSbpsDimReduce(const Shape& hierarchy, const std::vector& nd // dim = 1, split_axis2holding_reduced_shapes: {(0: 300, 301), (1: 601)} // dim = 2, split_axis2holding_reduced_shapes: {(0: 300, 301), (1: 150, 151)} // dim = 3, split_axis2holding_reduced_shapes: {(0: 100, 101), (1: 150, 151)} - int32_t sbp_num = nd_sbps.size(); std::vector>> index2split_axis2holding_reduced_shapes(sbp_num); std::vector> index2last_holding_reduced_shapes(sbp_num); std::vector last_split_axises(sbp_num, -1); @@ -566,13 +580,17 @@ void NdSbpsDimReduce(const Shape& hierarchy, const std::vector& nd void NdSbpDimReduce(const ParallelDesc& parallel_desc, const NdSbp& nd_sbp, ParallelDesc* reduced_parallel_desc, NdSbp* reduced_nd_sbp, const Shape& logical_shape) { + // Speed up for 1d sbp + if (parallel_desc.hierarchy()->NumAxes() == 1) { + *reduced_parallel_desc = parallel_desc; + *reduced_nd_sbp = nd_sbp; + return; + } Shape reduced_hierarchy; NdSbpDimReduce(*parallel_desc.hierarchy(), nd_sbp, &reduced_hierarchy, reduced_nd_sbp, logical_shape); - ParallelConf reduced_parallel_conf = parallel_desc.parallel_conf(); - reduced_hierarchy.ToProto(reduced_parallel_conf.mutable_hierarchy()); - *reduced_parallel_desc = ParallelDesc(reduced_parallel_conf); + ReplaceHierarchy4ParallelDesc(parallel_desc, reduced_hierarchy, reduced_parallel_desc); } void InOutParallelDimReduce(const ParallelDesc& in_parallel_desc, From fc6127f242e4995986263d4fd2a69d36304c2823 Mon Sep 17 00:00:00 2001 From: Yipeng Li Date: Mon, 31 Oct 2022 20:11:18 +0800 Subject: [PATCH 22/30] Reduced simultaneously with the same hierarchy --- oneflow/core/framework/sbp_infer_util.cpp | 130 +++++++++++----------- oneflow/core/framework/sbp_infer_util.h | 5 + 2 files changed, 73 insertions(+), 62 deletions(-) diff --git a/oneflow/core/framework/sbp_infer_util.cpp 
b/oneflow/core/framework/sbp_infer_util.cpp index 3a82d58ace9..8de13f9afa1 100644 --- a/oneflow/core/framework/sbp_infer_util.cpp +++ b/oneflow/core/framework/sbp_infer_util.cpp @@ -376,60 +376,16 @@ Maybe GetComputeCopyCostFunc() { } } -// TODO: Remove this -void CollaborativeParallelDimReduce(const ParallelDesc& in_parallel_desc, - const ParallelDesc& out_parallel_desc, const NdSbp& in_nd_sbp, - const NdSbp& out_nd_sbp, ParallelDesc* reduced_in_parallel_desc, - ParallelDesc* reduced_out_parallel_desc, - NdSbp* reduced_in_nd_sbp, NdSbp* reduced_out_nd_sbp) { - const auto& in_hierarchy = in_parallel_desc.hierarchy(); - const auto& out_hierarchy = out_parallel_desc.hierarchy(); - CHECK_EQ(in_hierarchy->NumAxes(), out_hierarchy->NumAxes()); - - DimVector reduced_in_hierarchy; - DimVector reduced_out_hierarchy; - FOR_RANGE(int64_t, i, 0, in_hierarchy->NumAxes()) { - if (in_hierarchy->At(i) != 1 || out_hierarchy->At(i) != 1) { - if (reduced_in_nd_sbp->sbp_parallel().empty() - || (in_nd_sbp.sbp_parallel(i) - != reduced_in_nd_sbp->sbp_parallel(reduced_in_nd_sbp->sbp_parallel_size() - 1) - || out_nd_sbp.sbp_parallel(i) - != reduced_out_nd_sbp->sbp_parallel(reduced_out_nd_sbp->sbp_parallel_size() - - 1))) { - reduced_in_hierarchy.emplace_back(in_hierarchy->At(i)); - *reduced_in_nd_sbp->add_sbp_parallel() = in_nd_sbp.sbp_parallel(i); - - reduced_out_hierarchy.emplace_back(out_hierarchy->At(i)); - *reduced_out_nd_sbp->add_sbp_parallel() = out_nd_sbp.sbp_parallel(i); - } else { - reduced_in_hierarchy.back() *= in_hierarchy->At(i); - reduced_out_hierarchy.back() *= out_hierarchy->At(i); - } - } - } - if (reduced_in_hierarchy.empty()) { - reduced_in_hierarchy.emplace_back(in_hierarchy->At(0)); - *reduced_in_nd_sbp->add_sbp_parallel() = in_nd_sbp.sbp_parallel(0); - - reduced_out_hierarchy.emplace_back(out_hierarchy->At(0)); - *reduced_out_nd_sbp->add_sbp_parallel() = out_nd_sbp.sbp_parallel(0); - } - - ParallelConf reduced_in_parallel_conf = in_parallel_desc.parallel_conf(); - Shape(reduced_in_hierarchy).ToProto(reduced_in_parallel_conf.mutable_hierarchy()); - *reduced_in_parallel_desc = ParallelDesc(reduced_in_parallel_conf); - - ParallelConf reduced_out_parallel_conf = out_parallel_desc.parallel_conf(); - Shape(reduced_out_hierarchy).ToProto(reduced_out_parallel_conf.mutable_hierarchy()); - *reduced_out_parallel_desc = ParallelDesc(reduced_out_parallel_conf); -} - // Replace the hierarchy and then create a new parallel description void ReplaceHierarchy4ParallelDesc(const ParallelDesc& old_parallel_desc, const Shape& new_hierarchy, ParallelDesc* new_parallel_desc) { - ParallelConf new_parallel_conf = old_parallel_desc.parallel_conf(); - new_hierarchy.ToProto(new_parallel_conf.mutable_hierarchy()); - *new_parallel_desc = ParallelDesc(new_parallel_conf); + if (*old_parallel_desc.hierarchy() == new_hierarchy) { + *new_parallel_desc = old_parallel_desc; + } else { + ParallelConf new_parallel_conf = old_parallel_desc.parallel_conf(); + new_hierarchy.ToProto(new_parallel_conf.mutable_hierarchy()); + *new_parallel_desc = ParallelDesc(new_parallel_conf); + } } // We can not just simply merging two same split @@ -593,27 +549,77 @@ void NdSbpDimReduce(const ParallelDesc& parallel_desc, const NdSbp& nd_sbp, ReplaceHierarchy4ParallelDesc(parallel_desc, reduced_hierarchy, reduced_parallel_desc); } +void InOutParallelDimReduce(const Shape& in_hierarchy, const Shape& out_hierarchy, + const NdSbp& in_nd_sbp, const NdSbp& out_nd_sbp, + Shape* reduced_in_hierarchy, Shape* reduced_out_hierarchy, + NdSbp* 
reduced_in_nd_sbp, NdSbp* reduced_out_nd_sbp, + const Shape& logical_shape) { + if (in_hierarchy == out_hierarchy) { + // [2, 4]: (S0, S0) -> [2, 4]: (S0, S1) + NdSbpsDimReduce(in_hierarchy, {&in_nd_sbp, &out_nd_sbp}, reduced_in_hierarchy, + {reduced_in_nd_sbp, reduced_out_nd_sbp}, logical_shape); + *reduced_out_hierarchy = *reduced_in_hierarchy; + } else { + // [2, 4]: (S0, S0) -> [4, 2]: (S0, S1) + // [2, 4]: (S0, S0) -> [3, 3]: (S0, S1) + NdSbpDimReduce(in_hierarchy, in_nd_sbp, reduced_in_hierarchy, reduced_in_nd_sbp, logical_shape); + NdSbpDimReduce(out_hierarchy, out_nd_sbp, reduced_out_hierarchy, reduced_out_nd_sbp, + logical_shape); + + // Sbp of 3d or higher dimension would use general basic communication + // Only looks at 1d to 2d or 2d to 1d + if (reduced_in_hierarchy->NumAxes() + reduced_out_hierarchy->NumAxes() == 3 + && reduced_in_hierarchy->elem_cnt() == reduced_out_hierarchy->elem_cnt()) { + if (reduced_in_hierarchy->NumAxes() == 1) { + // [8]: S0 -> [4, 2]: (S0, S1) + // [8]: B -> [2, 4]: (S0, S1) + const auto& in_sbp_parallel = reduced_in_nd_sbp->sbp_parallel(0); + if (!in_sbp_parallel.has_split_parallel() + || CanMergeSplit(logical_shape.At(in_sbp_parallel.split_parallel().axis()), + reduced_in_hierarchy->elem_cnt())) { + // Change [8]: S0 -> [4, 2]: (S0, S1) to [4, 2]: (S0, S0) -> [4, 2]: (S0, S1) + // Change [8]: B -> [2, 4]: (S0, S1) to [2, 4]: (B, B) -> [2, 4]: (S0, S1) + *reduced_in_nd_sbp->add_sbp_parallel() = in_sbp_parallel; + *reduced_in_hierarchy = *reduced_out_hierarchy; + } + } else { + // [2, 3]: (S0, P) -> [6]: S0 + // [3, 4]: (B, S1) -> [12]: B + const auto& out_sbp_parallel = reduced_out_nd_sbp->sbp_parallel(0); + if (!out_sbp_parallel.has_split_parallel() + || CanMergeSplit(logical_shape.At(out_sbp_parallel.split_parallel().axis()), + reduced_out_hierarchy->elem_cnt())) { + // Change [2, 3]: (S0, P) -> [6]: S0 to [2, 3]: (S0, P) -> [2, 3]: (S0, S0) + // Change [3, 4]: (B, S1) -> [12]: B to [3, 4]: (B, S1) -> [3, 4]: (B, B) + *reduced_out_nd_sbp->add_sbp_parallel() = out_sbp_parallel; + *reduced_out_hierarchy = *reduced_in_hierarchy; + } + } + } + } +} + void InOutParallelDimReduce(const ParallelDesc& in_parallel_desc, const ParallelDesc& out_parallel_desc, const NdSbp& in_nd_sbp, const NdSbp& out_nd_sbp, ParallelDesc* reduced_in_parallel_desc, ParallelDesc* reduced_out_parallel_desc, NdSbp* reduced_in_nd_sbp, NdSbp* reduced_out_nd_sbp, const Shape& logical_shape) { - const int64_t in_hierarchy_axes = in_parallel_desc.hierarchy()->NumAxes(); - const int64_t out_hierarchy_axes = out_parallel_desc.hierarchy()->NumAxes(); - if (in_hierarchy_axes == 1 && out_hierarchy_axes == 1) { + // Speed up for 1d sbp + if (in_parallel_desc.hierarchy()->NumAxes() == 1 + && out_parallel_desc.hierarchy()->NumAxes() == 1) { *reduced_in_parallel_desc = in_parallel_desc; *reduced_out_parallel_desc = out_parallel_desc; *reduced_in_nd_sbp = in_nd_sbp; *reduced_out_nd_sbp = out_nd_sbp; - } else if (in_hierarchy_axes != out_hierarchy_axes) { - NdSbpDimReduce(in_parallel_desc, in_nd_sbp, reduced_in_parallel_desc, reduced_in_nd_sbp, - logical_shape); - NdSbpDimReduce(out_parallel_desc, out_nd_sbp, reduced_out_parallel_desc, reduced_out_nd_sbp, - logical_shape); } else { - CollaborativeParallelDimReduce(in_parallel_desc, out_parallel_desc, in_nd_sbp, out_nd_sbp, - reduced_in_parallel_desc, reduced_out_parallel_desc, - reduced_in_nd_sbp, reduced_out_nd_sbp); + Shape reduced_in_hierarchy; + Shape reduced_out_hierarchy; + InOutParallelDimReduce(*in_parallel_desc.hierarchy(), 
*out_parallel_desc.hierarchy(), in_nd_sbp, + out_nd_sbp, &reduced_in_hierarchy, &reduced_out_hierarchy, + reduced_in_nd_sbp, reduced_out_nd_sbp, logical_shape); + ReplaceHierarchy4ParallelDesc(in_parallel_desc, reduced_in_hierarchy, reduced_in_parallel_desc); + ReplaceHierarchy4ParallelDesc(out_parallel_desc, reduced_out_hierarchy, + reduced_out_parallel_desc); } } diff --git a/oneflow/core/framework/sbp_infer_util.h b/oneflow/core/framework/sbp_infer_util.h index 8454d767c78..d4dca3cc71e 100644 --- a/oneflow/core/framework/sbp_infer_util.h +++ b/oneflow/core/framework/sbp_infer_util.h @@ -52,6 +52,11 @@ void NdSbpDimReduce(const ParallelDesc& parallel_desc, const NdSbp& nd_sbp, ParallelDesc* reduced_parallel_desc, NdSbp* reduced_nd_sbp, const Shape& logical_shape); +void InOutParallelDimReduce(const Shape& in_hierarchy, const Shape& out_hierarchy, + const NdSbp& in_nd_sbp, const NdSbp& out_nd_sbp, + Shape* reduced_in_hierarchy, Shape* reduced_out_hierarchy, + NdSbp* reduced_in_nd_sbp, NdSbp* reduced_out_nd_sbp, + const Shape& logical_shape); void InOutParallelDimReduce(const ParallelDesc& in_parallel_desc, const ParallelDesc& out_parallel_desc, const NdSbp& in_nd_sbp, const NdSbp& out_nd_sbp, ParallelDesc* reduced_in_parallel_desc, From 38f405ccfa2e18627db049515350e4601c46c658 Mon Sep 17 00:00:00 2001 From: Yipeng Li Date: Mon, 31 Oct 2022 21:57:32 +0800 Subject: [PATCH 23/30] Deal with 1to2d and 2to1d in InOutParallelDimReduce() --- .../hierarchical_sub_task_graph_builder_impl.cpp | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/oneflow/core/graph/boxing/hierarchical_sub_task_graph_builder_impl.cpp b/oneflow/core/graph/boxing/hierarchical_sub_task_graph_builder_impl.cpp index 7592e50c9f2..11f3a0e6983 100644 --- a/oneflow/core/graph/boxing/hierarchical_sub_task_graph_builder_impl.cpp +++ b/oneflow/core/graph/boxing/hierarchical_sub_task_graph_builder_impl.cpp @@ -498,7 +498,7 @@ Maybe DispatchHierarchicalSubTskGphBuilder::Build( NdSbp reduced_out_nd_sbp; InOutParallelDimReduce(in_parallel_desc, out_parallel_desc, in_nd_sbp, out_nd_sbp, &reduced_in_parallel_desc, &reduced_out_parallel_desc, &reduced_in_nd_sbp, - &reduced_out_nd_sbp); + &reduced_out_nd_sbp, logical_blob_desc.shape()); const auto& in_hierarchy = reduced_in_parallel_desc.hierarchy(); const auto& out_hierarchy = reduced_out_parallel_desc.hierarchy(); if ((in_hierarchy->NumAxes() > 2 || out_hierarchy->NumAxes() > 2) @@ -520,13 +520,13 @@ Maybe DispatchHierarchicalSubTskGphBuilder::Build( ctx, sorted_in_tasks, sorted_out_tasks, sorted_ctrl_tasks, reduced_in_parallel_desc, reduced_out_parallel_desc, lbi, logical_blob_desc, reduced_in_nd_sbp, reduced_out_nd_sbp, time_shape); - } else if (in_hierarchy->elem_cnt() == out_hierarchy->elem_cnt() - && ((in_hierarchy->NumAxes() == 1 && out_hierarchy->NumAxes() == 2) - || (in_hierarchy->NumAxes() == 2 && out_hierarchy->NumAxes() == 1))) { - return impl_->expand_to_same_2d_hierarchy_sub_tsk_gph_builder_->Build( - ctx, sorted_in_tasks, sorted_out_tasks, sorted_ctrl_tasks, reduced_in_parallel_desc, - reduced_out_parallel_desc, lbi, logical_blob_desc, reduced_in_nd_sbp, reduced_out_nd_sbp, - time_shape); + // } else if (in_hierarchy->elem_cnt() == out_hierarchy->elem_cnt() + // && ((in_hierarchy->NumAxes() == 1 && out_hierarchy->NumAxes() == 2) + // || (in_hierarchy->NumAxes() == 2 && out_hierarchy->NumAxes() == 1))) { + // return impl_->expand_to_same_2d_hierarchy_sub_tsk_gph_builder_->Build( + // ctx, sorted_in_tasks, sorted_out_tasks, sorted_ctrl_tasks, 
reduced_in_parallel_desc, + // reduced_out_parallel_desc, lbi, logical_blob_desc, reduced_in_nd_sbp, + // reduced_out_nd_sbp, time_shape); } else if (reduced_in_parallel_desc.device_type() == DeviceType::kCUDA && reduced_out_parallel_desc.device_type() == DeviceType::kCUDA) { return impl_->nd_nccl_send_recv_boxing_sub_tsk_gph_builder_->Build( From c414ce71b7a0026b83d3167045c643785e15bfaf Mon Sep 17 00:00:00 2001 From: Yipeng Li Date: Mon, 31 Oct 2022 14:14:28 +0000 Subject: [PATCH 24/30] Pass 1to2d and 2to1d test cases --- python/oneflow/test/graph/test_gbc1to2d.py | 6 +++--- python/oneflow/test/graph/test_gbc2to1d.py | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/python/oneflow/test/graph/test_gbc1to2d.py b/python/oneflow/test/graph/test_gbc1to2d.py index f72e5aa8ac7..18c996d1cbc 100644 --- a/python/oneflow/test/graph/test_gbc1to2d.py +++ b/python/oneflow/test/graph/test_gbc1to2d.py @@ -33,8 +33,8 @@ def _test_general_basic_communication_1d_to_2d(test_case, src_nd_sbp, dst_nd_sbp # input placement_x = flow.placement("cuda", ranks=[0, 1, 2]) - placement_y = flow.placement("cuda", ranks=[[3, 4], [1, 2]]) - local_np = np.arange(4 * 12).reshape(4, 12) + placement_y = flow.placement("cuda", ranks=[[3, 0], [1, 2]]) + local_np = np.arange(4 * 14).reshape(4, 14) x = flow.tensor(local_np, sbp=src_nd_sbp, placement=placement_x) # check eager boxing @@ -77,7 +77,7 @@ def gen_nd_sbp_2d(): return nd_sbp_list -@flow.unittest.skip_unless_2n4d() +@flow.unittest.skip_unless_1n4d() @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") class TestGeneralBasicCommunication(flow.unittest.TestCase): def test_general_basic_communication(test_case): diff --git a/python/oneflow/test/graph/test_gbc2to1d.py b/python/oneflow/test/graph/test_gbc2to1d.py index 0dfe60ef43c..49ae148030b 100644 --- a/python/oneflow/test/graph/test_gbc2to1d.py +++ b/python/oneflow/test/graph/test_gbc2to1d.py @@ -34,7 +34,7 @@ def _test_general_basic_communication_2d_to_1d(test_case, src_nd_sbp, dst_nd_sbp # input placement_x = flow.placement("cuda", ranks=[[0, 1], [2, 3]]) placement_y = flow.placement("cuda", ranks=[0, 3, 4]) - local_np = np.arange(12 * 4).reshape(12, 4) + local_np = np.arange(13 * 5).reshape(13, 5) x = flow.tensor(local_np, sbp=src_nd_sbp, placement=placement_x) # check eager boxing From 75d5a8214b30782348564488ef32b8875426c2c1 Mon Sep 17 00:00:00 2001 From: Yipeng Li Date: Mon, 31 Oct 2022 22:24:10 +0800 Subject: [PATCH 25/30] Remove the old code --- ...erarchical_sub_task_graph_builder_impl.cpp | 64 +------------------ 1 file changed, 3 insertions(+), 61 deletions(-) diff --git a/oneflow/core/graph/boxing/hierarchical_sub_task_graph_builder_impl.cpp b/oneflow/core/graph/boxing/hierarchical_sub_task_graph_builder_impl.cpp index 11f3a0e6983..7f67db4fe68 100644 --- a/oneflow/core/graph/boxing/hierarchical_sub_task_graph_builder_impl.cpp +++ b/oneflow/core/graph/boxing/hierarchical_sub_task_graph_builder_impl.cpp @@ -411,62 +411,10 @@ class Same2DHierarchySubTskGphBuilder final : public HierarchicalSubTskGphBuilde std::unique_ptr dim0_nd_sbp_mismatched_sub_tsk_gph_builder_; }; -class ExpandToSame2DHierarchySubTskGphBuilder final : public HierarchicalSubTskGphBuilder { - public: - OF_DISALLOW_COPY_AND_MOVE(ExpandToSame2DHierarchySubTskGphBuilder); - ExpandToSame2DHierarchySubTskGphBuilder() { - same_2d_hierarchy_sub_tsk_gph_builder_.reset(new Same2DHierarchySubTskGphBuilder()); - } - ~ExpandToSame2DHierarchySubTskGphBuilder() override = default; - - Maybe Build(SubTskGphBuilderCtx* 
ctx, - const std::vector& sorted_in_tasks, - std::vector* sorted_out_tasks, - std::vector>* sorted_ctrl_tasks, - const ParallelDesc& in_parallel_desc, - const ParallelDesc& out_parallel_desc, - const LogicalBlobId& lbi, const BlobDesc& logical_blob_desc, - const NdSbp& in_nd_sbp, const NdSbp& out_nd_sbp, - const Shape& time_shape) const override { - if (in_parallel_desc.hierarchy()->elem_cnt() == out_parallel_desc.hierarchy()->elem_cnt() - && in_parallel_desc.hierarchy()->NumAxes() == 1 - && out_parallel_desc.hierarchy()->NumAxes() == 2) { - ParallelConf intermediate_parallel_conf = in_parallel_desc.parallel_conf(); - out_parallel_desc.hierarchy()->ToProto(intermediate_parallel_conf.mutable_hierarchy()); - NdSbp intermediate_nd_sbp; - *intermediate_nd_sbp.add_sbp_parallel() = in_nd_sbp.sbp_parallel(0); - *intermediate_nd_sbp.add_sbp_parallel() = in_nd_sbp.sbp_parallel(0); - return same_2d_hierarchy_sub_tsk_gph_builder_->Build( - ctx, sorted_in_tasks, sorted_out_tasks, sorted_ctrl_tasks, - ParallelDesc(intermediate_parallel_conf), out_parallel_desc, lbi, logical_blob_desc, - intermediate_nd_sbp, out_nd_sbp, time_shape); - } else if (in_parallel_desc.hierarchy()->elem_cnt() == out_parallel_desc.hierarchy()->elem_cnt() - && in_parallel_desc.hierarchy()->NumAxes() == 2 - && out_parallel_desc.hierarchy()->NumAxes() == 1) { - ParallelConf intermediate_parallel_conf = out_parallel_desc.parallel_conf(); - in_parallel_desc.hierarchy()->ToProto(intermediate_parallel_conf.mutable_hierarchy()); - NdSbp intermediate_nd_sbp; - *intermediate_nd_sbp.add_sbp_parallel() = out_nd_sbp.sbp_parallel(0); - *intermediate_nd_sbp.add_sbp_parallel() = out_nd_sbp.sbp_parallel(0); - return same_2d_hierarchy_sub_tsk_gph_builder_->Build( - ctx, sorted_in_tasks, sorted_out_tasks, sorted_ctrl_tasks, in_parallel_desc, - ParallelDesc(intermediate_parallel_conf), lbi, logical_blob_desc, in_nd_sbp, - intermediate_nd_sbp, time_shape); - } else { - return Error::BoxingNotSupportedError(); - } - } - - private: - std::unique_ptr same_2d_hierarchy_sub_tsk_gph_builder_; -}; - struct DispatchHierarchicalSubTskGphBuilder::Impl { Impl(); std::unique_ptr flat_sub_tsk_gph_builder_; std::unique_ptr same_2d_hierarchy_sub_tsk_gph_builder_; - std::unique_ptr - expand_to_same_2d_hierarchy_sub_tsk_gph_builder_; std::unique_ptr nd_nccl_send_recv_boxing_sub_tsk_gph_builder_; }; @@ -474,8 +422,6 @@ struct DispatchHierarchicalSubTskGphBuilder::Impl { DispatchHierarchicalSubTskGphBuilder::Impl::Impl() { flat_sub_tsk_gph_builder_.reset(new FlatSubTskGphBuilder()); same_2d_hierarchy_sub_tsk_gph_builder_.reset(new Same2DHierarchySubTskGphBuilder()); - expand_to_same_2d_hierarchy_sub_tsk_gph_builder_.reset( - new ExpandToSame2DHierarchySubTskGphBuilder()); nd_nccl_send_recv_boxing_sub_tsk_gph_builder_.reset(new NDNcclSendRecvBoxingSubTskGphBuilder()); } @@ -496,6 +442,9 @@ Maybe DispatchHierarchicalSubTskGphBuilder::Build( ParallelDesc reduced_out_parallel_desc = out_parallel_desc; NdSbp reduced_in_nd_sbp; NdSbp reduced_out_nd_sbp; + // The 1d to 2d and 2d to 1d cases are considered in this function. + // If it gives out a 1d sbp and a 2d sbp simultaneously, then the 2d sbp can not be converted + // to 1d sbp and the 1d sbp can not be expanded to 2d sbp. 
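(Aside for the reader, not part of the patch: the 1d-to-2d expansion handled inside InOutParallelDimReduce is only sound when nesting two balanced splits hands every rank the same slice as one flat split, which is exactly what CanMergeSplit tests. A standalone sketch, assuming only the BalancedSplitter behavior shown earlier in the series, with the larger slices handed out first; BalancedSizes and NestedSizes are hypothetical helpers.)

#include <cassert>
#include <vector>

// Per-rank sizes of a balanced split: the first size % n ranks get one extra.
std::vector<long> BalancedSizes(long size, long n) {
  std::vector<long> out(n, size / n);
  for (long r = 0; r < size % n; ++r) { out[r] += 1; }
  return out;
}

// Per-rank sizes of [a, b]: (S0, S0), i.e. split by a, then split each piece by b.
std::vector<long> NestedSizes(long size, long a, long b) {
  std::vector<long> out;
  for (long s : BalancedSizes(size, a)) {
    for (long t : BalancedSizes(s, b)) { out.push_back(t); }
  }
  return out;
}

int main() {
  // 17 % 8 == 1, so [8]: S0 can be expanded to [4, 2]: (S0, S0) unchanged.
  assert(NestedSizes(17, 4, 2) == BalancedSizes(17, 8));
  // 13 % 8 == 5, so the expansion would reshuffle data and is rejected.
  assert(NestedSizes(13, 4, 2) != BalancedSizes(13, 8));
  return 0;
}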
InOutParallelDimReduce(in_parallel_desc, out_parallel_desc, in_nd_sbp, out_nd_sbp, &reduced_in_parallel_desc, &reduced_out_parallel_desc, &reduced_in_nd_sbp, &reduced_out_nd_sbp, logical_blob_desc.shape()); @@ -520,13 +469,6 @@ Maybe DispatchHierarchicalSubTskGphBuilder::Build( ctx, sorted_in_tasks, sorted_out_tasks, sorted_ctrl_tasks, reduced_in_parallel_desc, reduced_out_parallel_desc, lbi, logical_blob_desc, reduced_in_nd_sbp, reduced_out_nd_sbp, time_shape); - // } else if (in_hierarchy->elem_cnt() == out_hierarchy->elem_cnt() - // && ((in_hierarchy->NumAxes() == 1 && out_hierarchy->NumAxes() == 2) - // || (in_hierarchy->NumAxes() == 2 && out_hierarchy->NumAxes() == 1))) { - // return impl_->expand_to_same_2d_hierarchy_sub_tsk_gph_builder_->Build( - // ctx, sorted_in_tasks, sorted_out_tasks, sorted_ctrl_tasks, reduced_in_parallel_desc, - // reduced_out_parallel_desc, lbi, logical_blob_desc, reduced_in_nd_sbp, - // reduced_out_nd_sbp, time_shape); } else if (reduced_in_parallel_desc.device_type() == DeviceType::kCUDA && reduced_out_parallel_desc.device_type() == DeviceType::kCUDA) { return impl_->nd_nccl_send_recv_boxing_sub_tsk_gph_builder_->Build( From 575a55ff197630756f9adcce34d19582a50d3295 Mon Sep 17 00:00:00 2001 From: Yipeng Li Date: Mon, 31 Oct 2022 22:25:36 +0800 Subject: [PATCH 26/30] Revert "Add test script and print out information" This reverts commit 58cdfb40b6536eb74c02174d3a69409676da374f. --- oneflow/core/graph/exec_graph.cpp | 7 ++-- oneflow/core/operator/operator.cpp | 55 ++++++++---------------------- 2 files changed, 16 insertions(+), 46 deletions(-) diff --git a/oneflow/core/graph/exec_graph.cpp b/oneflow/core/graph/exec_graph.cpp index cd9668cb05a..2c47076abc7 100644 --- a/oneflow/core/graph/exec_graph.cpp +++ b/oneflow/core/graph/exec_graph.cpp @@ -76,11 +76,8 @@ namespace { Maybe CheckPhysicalBlobDesc(const BlobDesc& logical, const NdSbp& nd_sbp, const ParallelDesc& parallel_desc, const ParallelContext* parallel_ctx, const BlobDesc& physical) { - auto& rhs = *JUST(GetPhysicalShape(logical.shape(), nd_sbp, parallel_desc, *parallel_ctx)); - CHECK_EQ_OR_RETURN(physical.shape(), - *JUST(GetPhysicalShape(logical.shape(), nd_sbp, parallel_desc, *parallel_ctx))) - << ", parallel num: " << parallel_ctx->parallel_id() << ", logical shape: " << logical.shape() - << ", lhs: " << physical.shape() << ", rhs: " << rhs; + CHECK_EQ_OR_RETURN(physical.shape(), *JUST(GetPhysicalShape(logical.shape(), nd_sbp, + parallel_desc, *parallel_ctx))); return Maybe::Ok(); } diff --git a/oneflow/core/operator/operator.cpp b/oneflow/core/operator/operator.cpp index 751af17d166..795d07d483e 100644 --- a/oneflow/core/operator/operator.cpp +++ b/oneflow/core/operator/operator.cpp @@ -17,7 +17,6 @@ limitations under the License. 
#include "oneflow/core/common/balanced_splitter.h" #include "oneflow/core/common/container_util.h" #include "oneflow/core/common/decorator.h" -#include "oneflow/core/rpc/include/global_process_ctx.h" #include "oneflow/core/vm/symbol_storage.h" #include "oneflow/core/framework/instructions_builder.h" #include "oneflow/core/framework/to_string.h" @@ -38,12 +37,6 @@ namespace oneflow { namespace { -std::string ParallelDesc2String(const ParallelDesc& parallel_desc) { - std::ostringstream out; - out << "hierarchy: " << *parallel_desc.hierarchy() << ", device: " << parallel_desc.device_tag(); - return out.str(); -} - DataType GetDataTypeFromBnInOpVec( std::function GetBlobDesc4BnInOp, const PbRpf& bn_in_ops) { @@ -787,22 +780,6 @@ Maybe Operator::GreedilyFindMinCopyCostNdSbp( producer_infer_hint4ibn->parallel_desc(), *JUST(GetParallelDesc4BnInOp(ibn)), requires_same_sbp[ibn_id], producer_infer_hint4ibn->logical_blob_desc().shape()); sum_priority_ratio += priority_ratio; - - if (GlobalProcessCtx::Rank() == 0 - && op_name().find("model.t5_model.embedding.word_embeddings.weight") - != std::string::npos) { - if (i == 0) { - std::cout << "Producer " << NdSbpToString(producer_infer_hint4ibn->nd_sbp()) - << ", placement: " - << ParallelDesc2String(producer_infer_hint4ibn->parallel_desc()) - << ", Shape: " << producer_infer_hint4ibn->logical_blob_desc().shape() - << std::endl; - } - std::cout << "idx: " << i << ", sbp: " - << NdSbpToString(JUST(VectorAt(nd_sbp_sig_list, i)).bn_in_op2nd_sbp().at(ibn)) - << ", placement: " << ParallelDesc2String(*JUST(GetParallelDesc4BnInOp(ibn))) - << std::endl; - } // We do not accept any blob which has a priority ratio greater than 1 if (priority_ratio > 1.5) { total_copy_cost = GetMaxVal(); @@ -836,25 +813,21 @@ Maybe Operator::GreedilyFindMinCopyCostNdSbp( } } // Can't find any available sbp - std::ostringstream err; - err << "op: `" << op_name() << "` can't find available sbp signature." << std::endl; - err << "candidate nd sbp signature are: " - << *JUST(NdSbpSignatureListAsString(nd_sbp_sig_list, input_bns(), output_bns())); - err << ", but inputs sbp are:"; - for (int32_t ibn_id = 0; ibn_id < input_bns().size(); ibn_id++) { - const auto& ibn = input_bns().at(ibn_id); - const NdSbp& nd_sbp = JUST(NdSbpInferHint4Ibn(ibn))->nd_sbp(); - err << " " << ibn << ": " << NdSbpToString(nd_sbp); - if (requires_same_sbp[ibn_id]) { err << " [ transfer disabled ]"; } - err << ";"; - - if (select_sbp_idx == -1) { return Error::RuntimeError() << err.str(); } - } + if (select_sbp_idx == -1) { + std::ostringstream err; + err << "op: `" << op_name() << "` can't find available sbp signature." 
<< std::endl;
+    err << "candidate nd sbp signature are: "
+        << *JUST(NdSbpSignatureListAsString(nd_sbp_sig_list, input_bns(), output_bns()));
+    err << ", but inputs sbp are:";
+    for (int32_t ibn_id = 0; ibn_id < input_bns().size(); ibn_id++) {
+      const auto& ibn = input_bns().at(ibn_id);
+      const NdSbp& nd_sbp = JUST(NdSbpInferHint4Ibn(ibn))->nd_sbp();
+      err << " " << ibn << ": " << NdSbpToString(nd_sbp);
+      if (requires_same_sbp[ibn_id]) { err << " [ transfer disabled ]"; }
+      err << ";";
+    }

-  if (GlobalProcessCtx::Rank() == 0
-      && op_name().find("model.t5_model.embedding.word_embeddings.weight") != std::string::npos) {
-    std::cout << err.str() << std::endl;
-    std::cout << "select idx: " << select_sbp_idx << std::endl;
+    return Error::RuntimeError() << err.str();
   }
 }
 nd_sbp_signature->CopyFrom(nd_sbp_sig_list.at(select_sbp_idx));

From c8216b9d72a39e6032f158eacfb79175ab3d00b6 Mon Sep 17 00:00:00 2001
From: Yipeng Li
Date: Mon, 31 Oct 2022 22:41:16 +0800
Subject: [PATCH 27/30] Add the check for blob splitability back

---
 oneflow/core/job/job_build_and_infer_ctx.cpp | 6 ++++--
 oneflow/core/operator/operator.cpp | 12 +-----------
 2 files changed, 5 insertions(+), 13 deletions(-)

diff --git a/oneflow/core/job/job_build_and_infer_ctx.cpp b/oneflow/core/job/job_build_and_infer_ctx.cpp
index d5600c8ec7d..cca19606207 100644
--- a/oneflow/core/job/job_build_and_infer_ctx.cpp
+++ b/oneflow/core/job/job_build_and_infer_ctx.cpp
@@ -318,10 +318,12 @@ Maybe JobBuildAndInferCtx::CheckOpBlobSplitability(Operator* op, int64_t p
     if (sbp_parallel.has_split_parallel()) {
       const int64_t axis = sbp_parallel.split_parallel().axis();
       CHECK_GT_OR_RETURN(current_shape.At(axis), 0);
-      CHECK_EQ_OR_RETURN(current_shape.At(axis) % parallel_hierarchy->At(i), 0)
+      // Support unbalanced splitting
+      CHECK_GE_OR_RETURN(current_shape.At(axis), parallel_hierarchy->At(i))
           << "op_name: " << lbi.op_name() << " blob_name: " << lbi.blob_name()
           << " cannot split blob by parallel_hierarchy: "
           << std::to_string(parallel_hierarchy->At(i));
+      // Split and take the minimum one
       current_shape.Set(axis, current_shape.At(axis) / parallel_hierarchy->At(i));
     }
   }
@@ -583,7 +585,7 @@ Maybe JobBuildAndInferCtx::AddAndInferOp(const OperatorConf& op_con
   }
   JUST(AddLbiParallelConf2BlobPlacement(op, ParallelDesc4Obn));
   // Check splitability
-  // JUST(CheckOpBlobSplitability(op, parallel_desc.parallel_num()));
+  JUST(CheckOpBlobSplitability(op, parallel_desc.parallel_num()));
   return op->GetOpAttributeWithoutOpNameAndLbn();
 }

diff --git a/oneflow/core/operator/operator.cpp b/oneflow/core/operator/operator.cpp
index 795d07d483e..ec5bb4faa10 100644
--- a/oneflow/core/operator/operator.cpp
+++ b/oneflow/core/operator/operator.cpp
@@ -1647,16 +1647,7 @@ Maybe GetNdHierarchyPhysicalShape(const Shape& logical_shape, const NdSbp
     const auto& sbp_parallel = nd_sbp.sbp_parallel(i);
     if (sbp_parallel.has_split_parallel()) {
       const int64_t split_axis = sbp_parallel.split_parallel().axis();
-      // if (LazyMode::is_enabled()) {
-      //   CHECK_EQ_OR_RETURN(physical->At(split_axis) % parallel_hierarchy.At(i), 0)
-      //       << Error::RuntimeError() << "In nn.Graph, expected size at split axis (" <<
-      //       split_axis
-      //       << ") of logical shape must be divisible by parallel num, but got logical_shape: "
-      //       << logical_shape.ToString()
-      //       << ", placement: " << *JUST(PlacementToString(SymbolOf(parallel_desc)))
-      //       << ", nd_sbp: " << NdSbpToString(SymbolOf(nd_sbp));
-      //   physical->Set(split_axis, physical->At(split_axis) / parallel_hierarchy.At(i));
-      // } else {
+      // Both the lazy and
eager mode support unbalanced splitting now if (physical->At(split_axis) > 0) { CHECK_GE_OR_RETURN(physical->At(split_axis), parallel_hierarchy.At(i)) << Error::RuntimeError() << "Expected size at split axis (" << split_axis @@ -1668,7 +1659,6 @@ Maybe GetNdHierarchyPhysicalShape(const Shape& logical_shape, const NdSbp const BalancedSplitter bs(physical->At(split_axis), parallel_hierarchy.At(i)); physical->Set(split_axis, bs.At(CalcIndex4Axis(parallel_id, hierarch_stride, i)).size()); } - // } } } return physical; From 7deae839ba3bb19f3fc2cb22be7bd6894f86c6a4 Mon Sep 17 00:00:00 2001 From: Yipeng Li Date: Wed, 2 Nov 2022 21:07:30 +0800 Subject: [PATCH 28/30] Feat speed up cost computation (#9355) * Compilation speed up * Speed up compilation for cost between 1d sbp --- oneflow/core/framework/sbp_infer_util.cpp | 180 ++++++++++++---------- 1 file changed, 97 insertions(+), 83 deletions(-) diff --git a/oneflow/core/framework/sbp_infer_util.cpp b/oneflow/core/framework/sbp_infer_util.cpp index 8de13f9afa1..4a553db5642 100644 --- a/oneflow/core/framework/sbp_infer_util.cpp +++ b/oneflow/core/framework/sbp_infer_util.cpp @@ -73,8 +73,9 @@ int32_t Ratio4Sbp(const NdSbp& nd_sbp, const ParallelDesc& parallel_desc, Maybe ComputCopyCostBetweenTwoSbpParallel(const SbpParallel& producer_sbp_parallel, const SbpParallel& consumer_sbp_parallel, const BlobDesc& logical_blob_desc, - const ParallelDesc& producer_parallel_desc, - const ParallelDesc& consumer_parallel_desc) { + bool on_same_devices, + int32_t producer_parallel_num, + int32_t consumer_parallel_num) { if (!(CheckSbpParallel(producer_sbp_parallel) && CheckSbpParallel(consumer_sbp_parallel))) { return Error::RuntimeError() << "Illegal sbp parallel has been found."; } @@ -89,7 +90,7 @@ Maybe ComputCopyCostBetweenTwoSbpParallel(const SbpParallel& producer_sb // NOTE: A tensor placed on cpu with a consumer operator that accepts cuda inputs would be // transfered to cuda later. We might not have correct parallel description at this moment. - if (producer_parallel_desc == consumer_parallel_desc) { + if (on_same_devices && producer_parallel_num == consumer_parallel_num) { // Same sbp, no cost: S->S, B->B, P->P if (producer_sbp_parallel == consumer_sbp_parallel) { return 0.0; } double logical_blob_size = @@ -99,8 +100,8 @@ Maybe ComputCopyCostBetweenTwoSbpParallel(const SbpParallel& producer_sb // arbitrary i. // ? -> P if (consumer_sbp_parallel.has_partial_sum_parallel()) { - return Penalty4PartialInConsumer(logical_blob_size, producer_parallel_desc.parallel_num(), - consumer_parallel_desc.parallel_num()); + return Penalty4PartialInConsumer(logical_blob_size, producer_parallel_num, + consumer_parallel_num); } // B->S if (producer_sbp_parallel.has_broadcast_parallel()) { return 1.0; } @@ -110,15 +111,14 @@ Maybe ComputCopyCostBetweenTwoSbpParallel(const SbpParallel& producer_sb if (consumer_sbp_parallel.has_split_parallel() && producer_sbp_parallel.has_split_parallel()) { // S(0)->S(1), S(1)->S(0), etc. 
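      // (Editor's illustration, not part of the patch: for a same-placement S->S reshard each
      //  of the n ranks keeps 1/n of its local slice and exchanges the rest, which is where
      //  the (n - 1) / n factor below comes from; e.g. n = 4 on a 4 GB logical blob moves
      //  about 4 GB * 3 / 4 = 3 GB in total.)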
- return logical_blob_size * (producer_parallel_desc.parallel_num() - 1) - / producer_parallel_desc.parallel_num(); + return logical_blob_size * (producer_parallel_num - 1) / producer_parallel_num; } else { // P->S, S->B/P - return logical_blob_size * (producer_parallel_desc.parallel_num() - 1); + return logical_blob_size * (producer_parallel_num - 1); } } // P->B - return 2 * logical_blob_size * (producer_parallel_desc.parallel_num() - 1); + return 2 * logical_blob_size * (producer_parallel_num - 1); } else { // Not supporting P->P for different placement if (LazyMode::is_enabled()) { @@ -133,17 +133,16 @@ Maybe ComputCopyCostBetweenTwoSbpParallel(const SbpParallel& producer_sb double overall_cost = logical_blob_size; // ? -> B if (consumer_sbp_parallel.has_broadcast_parallel()) { - overall_cost += (consumer_parallel_desc.parallel_num() - 1) * logical_blob_size; + overall_cost += (consumer_parallel_num - 1) * logical_blob_size; } // P -> ? if (producer_sbp_parallel.has_partial_sum_parallel()) { - overall_cost += (producer_parallel_desc.parallel_num() - 1) * logical_blob_size; + overall_cost += (producer_parallel_num - 1) * logical_blob_size; } // ? -> P if (consumer_sbp_parallel.has_partial_sum_parallel()) { - overall_cost += - Penalty4PartialInConsumer(logical_blob_size, producer_parallel_desc.parallel_num(), - consumer_parallel_desc.parallel_num()); + overall_cost += Penalty4PartialInConsumer(logical_blob_size, producer_parallel_num, + consumer_parallel_num); } // For B->S, S->S, overall_cost == logical_blob_size; return overall_cost; @@ -202,9 +201,8 @@ double ComputCopyCostBetweenTwoDiffSbpParallel(const SbpParallel& producer_sbp_p Maybe ComputCopyCostBetweenTwoNdSbp(const NdSbp& producer_nd_sbp, const NdSbp& consumer_nd_sbp, double logical_blob_size, - const std::shared_ptr& hierarchy, - bool on_same_devices) { - if (hierarchy->NumAxes() != 2) { return kUnsupportedBoxing; } + const Shape& hierarchy, bool on_same_devices) { + if (hierarchy.NumAxes() != 2) { return kUnsupportedBoxing; } const auto& producer_sbp_size = producer_nd_sbp.sbp_parallel_size(); const auto& consumer_sbp_size = consumer_nd_sbp.sbp_parallel_size(); // One of the SBP should have size 2 @@ -221,7 +219,7 @@ Maybe ComputCopyCostBetweenTwoNdSbp(const NdSbp& producer_nd_sbp, // The SBP parallel are the same at dimension (dim_same_sbp) if (producer_nd_sbp.sbp_parallel(dim_producer) == consumer_nd_sbp.sbp_parallel(dim_consumer)) { if (!producer_nd_sbp.sbp_parallel(dim_producer).has_split_parallel()) { - logical_blob_size *= hierarchy->At(dim_same_sbp); + logical_blob_size *= hierarchy.At(dim_same_sbp); } // The SBP parallel are different at dimension (dim_diff_sbp) int32_t dim_diff_sbp = 1 - dim_same_sbp; @@ -241,7 +239,7 @@ Maybe ComputCopyCostBetweenTwoNdSbp(const NdSbp& producer_nd_sbp, } return ComputCopyCostBetweenTwoDiffSbpParallel( producer_nd_sbp.sbp_parallel(dim_producer), consumer_nd_sbp.sbp_parallel(dim_consumer), - logical_blob_size, hierarchy->At(dim_diff_sbp), on_same_devices); + logical_blob_size, hierarchy.At(dim_diff_sbp), on_same_devices); } } return kUnsupportedBoxing; @@ -265,34 +263,36 @@ Maybe ComputeEagerCopyCostBetweenNdSbp(const NdSbp& producer_sbp_paralle return kUnsupportedBoxing; } - ParallelDesc reduced_in_parallel_desc = producer_parallel_desc; - ParallelDesc reduced_out_parallel_desc = consumer_parallel_desc; + bool on_same_devices = producer_parallel_desc.EqualsIgnoringHierarchy(consumer_parallel_desc); + + // Reduce before cost computation + Shape reduced_in_hierarchy; NdSbp 
reduced_in_nd_sbp; + Shape reduced_out_hierarchy; NdSbp reduced_out_nd_sbp; - InOutParallelDimReduce(producer_parallel_desc, consumer_parallel_desc, producer_sbp_parallel, - consumer_sbp_parallel, &reduced_in_parallel_desc, - &reduced_out_parallel_desc, &reduced_in_nd_sbp, &reduced_out_nd_sbp, + InOutParallelDimReduce(*producer_parallel_desc.hierarchy(), *consumer_parallel_desc.hierarchy(), + producer_sbp_parallel, consumer_sbp_parallel, &reduced_in_hierarchy, + &reduced_out_hierarchy, &reduced_in_nd_sbp, &reduced_out_nd_sbp, logical_blob_desc.shape()); - const auto& in_hierarchy = reduced_in_parallel_desc.hierarchy(); - const auto& out_hierarchy = reduced_out_parallel_desc.hierarchy(); - bool same_nd_sbp = reduced_in_nd_sbp == reduced_out_nd_sbp; // Same sbp is always supported. - if (same_nd_sbp && reduced_in_parallel_desc == reduced_out_parallel_desc) { return 0.0; } + if (same_nd_sbp && on_same_devices && reduced_in_hierarchy == reduced_out_hierarchy) { + return 0.0; + } if (requires_same_sbp) { return kUnsupportedBoxing; } - int32_t in_dim = in_hierarchy->NumAxes(); - int32_t out_dim = out_hierarchy->NumAxes(); + int32_t in_dim = reduced_in_hierarchy.NumAxes(); + int32_t out_dim = reduced_out_hierarchy.NumAxes(); // We support different hierarchy for 1D sbp if (in_dim == 1 && out_dim == 1) { return ComputCopyCostBetweenTwoSbpParallel( reduced_in_nd_sbp.sbp_parallel(0), reduced_out_nd_sbp.sbp_parallel(0), logical_blob_desc, - reduced_in_parallel_desc, reduced_out_parallel_desc); + on_same_devices, reduced_in_hierarchy.elem_cnt(), reduced_out_hierarchy.elem_cnt()); } double total_cost = 1.0; - if (reduced_in_parallel_desc == reduced_out_parallel_desc) { + if (on_same_devices && reduced_in_hierarchy == reduced_out_hierarchy) { // NOTE: After analysis, transfer cost increase if spliting the same dimension. // Example 1: (S(1), S(0), S(1), S(0)) -> (S(0), S(0), S(0), S(0)) // Example 2: (B, S(0)) -> (S(0), S(0)) @@ -301,14 +301,15 @@ Maybe ComputeEagerCopyCostBetweenNdSbp(const NdSbp& producer_sbp_paralle // simplification. bool normal_case = true; // nd to nd - for (int32_t i = 0; i < reduced_in_parallel_desc.hierarchy()->NumAxes(); ++i) { + for (int32_t i = 0; i < in_dim; ++i) { const auto& in_sbp = reduced_in_nd_sbp.sbp_parallel(i); const auto& out_sbp = reduced_out_nd_sbp.sbp_parallel(i); // Have bugs here. (B, S0) -> (S0, S0) will give a cost 0. // Actually it is (1-1/m)T for hierarchy (n, m) // TODO: Fix that after support all sbp combination for eager. total_cost += JUST(ComputCopyCostBetweenTwoSbpParallel( - in_sbp, out_sbp, logical_blob_desc, reduced_in_parallel_desc, reduced_out_parallel_desc)); + in_sbp, out_sbp, logical_blob_desc, on_same_devices, reduced_in_hierarchy.elem_cnt(), + reduced_out_hierarchy.elem_cnt())); // Add the penalty for P in the consumer if (out_sbp.has_partial_sum_parallel() && (in_sbp != out_sbp)) { total_cost += Penalty4PartialInConsumer( @@ -338,20 +339,20 @@ Maybe ComputeEagerCopyCostBetweenNdSbp(const NdSbp& producer_sbp_paralle logical_blob_desc.shape().elem_cnt() * GetSizeOfDataType(logical_blob_desc.data_type()); { double in_cost = 1.0; - for (int32_t i = 0; i < reduced_in_parallel_desc.hierarchy()->NumAxes(); ++i) { + for (int32_t i = 0; i < in_dim; ++i) { // P -> ? 
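       // (Editor's illustration, not part of the patch: with reduced hierarchy [2, 3] and
       //  reduced_in_nd_sbp (P, S(0)), only axis 0 is partial-sum, so in_cost becomes 2 and
       //  the estimate below adds 2 * logical_blob_size.)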
if (reduced_in_nd_sbp.sbp_parallel(i).has_partial_sum_parallel()) { - in_cost *= reduced_in_parallel_desc.hierarchy()->At(i); + in_cost *= reduced_in_hierarchy.At(i); } } total_cost += logical_blob_size * in_cost; } { double out_cost = 1.0; - for (int32_t i = 0; i < reduced_out_parallel_desc.hierarchy()->NumAxes(); ++i) { + for (int32_t i = 0; i < out_dim; ++i) { // ? -> B if (reduced_out_nd_sbp.sbp_parallel(i).has_broadcast_parallel()) { - out_cost *= reduced_out_parallel_desc.hierarchy()->At(i); + out_cost *= reduced_out_hierarchy.At(i); } // Add the penalty for P in the consumer if (reduced_out_nd_sbp.sbp_parallel(i).has_partial_sum_parallel()) { @@ -430,7 +431,13 @@ void NdSbpsDimReduce(const Shape& hierarchy, const std::vector& nd // Speed up for 1d sbp if (hierarchy.NumAxes() == 1) { *reduced_hierarchy = hierarchy; - for (int32_t index = 0; index < sbp_num; index++) { *reduced_nd_sbps[index] = *nd_sbps[index]; } + for (int32_t index = 0; index < sbp_num; index++) { + if (hierarchy.elem_cnt() == 1) { + reduced_nd_sbps[index]->add_sbp_parallel()->mutable_broadcast_parallel(); + } else { + *reduced_nd_sbps[index] = *nd_sbps[index]; + } + } return; } reduced_hierarchy->clear(); @@ -632,19 +639,19 @@ Maybe ComputeLazyCopyCostBetweenNdSbp(const NdSbp& producer_sbp_parallel if (!(CheckNdSbp(producer_sbp_parallel) && CheckNdSbp(consumer_sbp_parallel))) { return Error::RuntimeError() << "Illegal sbp parallel has been found."; } - ParallelDesc reduced_in_parallel_desc = producer_parallel_desc; - ParallelDesc reduced_out_parallel_desc = consumer_parallel_desc; + bool on_same_devices = producer_parallel_desc.EqualsIgnoringHierarchy(consumer_parallel_desc); + + // Reduce before cost computation + Shape reduced_in_hierarchy; NdSbp reduced_in_nd_sbp; + Shape reduced_out_hierarchy; NdSbp reduced_out_nd_sbp; - InOutParallelDimReduce(producer_parallel_desc, consumer_parallel_desc, producer_sbp_parallel, - consumer_sbp_parallel, &reduced_in_parallel_desc, - &reduced_out_parallel_desc, &reduced_in_nd_sbp, &reduced_out_nd_sbp, + InOutParallelDimReduce(*producer_parallel_desc.hierarchy(), *consumer_parallel_desc.hierarchy(), + producer_sbp_parallel, consumer_sbp_parallel, &reduced_in_hierarchy, + &reduced_out_hierarchy, &reduced_in_nd_sbp, &reduced_out_nd_sbp, logical_blob_desc.shape()); - - const auto& in_hierarchy = reduced_in_parallel_desc.hierarchy(); - const auto& out_hierarchy = reduced_out_parallel_desc.hierarchy(); - int32_t in_dim = in_hierarchy->NumAxes(); - int32_t out_dim = out_hierarchy->NumAxes(); + int32_t in_dim = reduced_in_hierarchy.NumAxes(); + int32_t out_dim = reduced_out_hierarchy.NumAxes(); // Not supporting n-D sbp with n >= 3 // TODO: Support it in the future if (std::min(in_dim, out_dim) <= 0 || std::max(in_dim, out_dim) >= 3) { @@ -653,7 +660,9 @@ Maybe ComputeLazyCopyCostBetweenNdSbp(const NdSbp& producer_sbp_parallel bool same_nd_sbp = reduced_in_nd_sbp == reduced_out_nd_sbp; // Same sbp is always supported. 
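   // (Editor's note, an assumption about intent rather than patch content: the rewritten check
   //  below is meant to be equivalent to the old ParallelDesc equality, with the device-set
   //  comparison hoisted into on_same_devices and the hierarchies compared separately.)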
- if (same_nd_sbp && reduced_in_parallel_desc == reduced_out_parallel_desc) { return 0.0; } + if (same_nd_sbp && on_same_devices && reduced_in_hierarchy == reduced_out_hierarchy) { + return 0.0; + } if (requires_same_sbp) { return kUnsupportedBoxing; } // We support different hierarchy for 1D sbp @@ -661,7 +670,8 @@ Maybe ComputeLazyCopyCostBetweenNdSbp(const NdSbp& producer_sbp_parallel return GetTransferCost() + JUST(ComputCopyCostBetweenTwoSbpParallel( reduced_in_nd_sbp.sbp_parallel(0), reduced_out_nd_sbp.sbp_parallel(0), - logical_blob_desc, reduced_in_parallel_desc, reduced_out_parallel_desc)); + logical_blob_desc, on_same_devices, reduced_in_hierarchy.elem_cnt(), + reduced_out_hierarchy.elem_cnt())); } #ifdef WITH_CUDA @@ -682,33 +692,36 @@ Maybe ComputeLazyCopyCostBetweenNdSbp(const NdSbp& producer_sbp_parallel #endif // WITH_CUDA // Not supporting different hierarchy without general basic communication - if (in_hierarchy->elem_cnt() != out_hierarchy->elem_cnt()) { return kUnsupportedBoxing; } + if (reduced_in_hierarchy.elem_cnt() != reduced_out_hierarchy.elem_cnt()) { + return kUnsupportedBoxing; + } - bool on_same_devices = - reduced_in_parallel_desc.EqualsIgnoringHierarchy(reduced_out_parallel_desc); double logical_blob_size = logical_blob_desc.shape().elem_cnt() * GetSizeOfDataType(logical_blob_desc.data_type()); if (in_dim == 2 && out_dim == 2) { // Not supporting different hierarchy // TODO: Support it in the future - if (*in_hierarchy != *out_hierarchy) { return kUnsupportedBoxing; } + if (reduced_in_hierarchy != reduced_out_hierarchy) { return kUnsupportedBoxing; } return GetTransferCost() + JUST(ComputCopyCostBetweenTwoNdSbp(reduced_in_nd_sbp, reduced_out_nd_sbp, - logical_blob_size, in_hierarchy, on_same_devices)); + logical_blob_size, reduced_in_hierarchy, + on_same_devices)); } // (in_dim == 2 && out_dim == 1) || (in_dim == 1 && out_dim == 2) if (in_dim == 2 && out_dim == 1) { return GetTransferCost() + JUST(ComputCopyCostBetweenTwoNdSbp(reduced_in_nd_sbp, reduced_out_nd_sbp, - logical_blob_size, in_hierarchy, on_same_devices)); + logical_blob_size, reduced_in_hierarchy, + on_same_devices)); } if (in_dim == 1 && out_dim == 2) { return GetTransferCost() + JUST(ComputCopyCostBetweenTwoNdSbp(reduced_in_nd_sbp, reduced_out_nd_sbp, - logical_blob_size, out_hierarchy, on_same_devices)); + logical_blob_size, reduced_out_hierarchy, + on_same_devices)); } return Error::RuntimeError() @@ -813,23 +826,26 @@ Maybe ComputeCopyCostWithMiddleNodes(const NdSbp& producer_sbp_parallel, const ParallelDesc& producer_parallel_desc, const ParallelDesc& consumer_parallel_desc, bool requires_same_sbp) { - // Reduce before cost computation - ParallelDesc reduced_in_parallel_desc = producer_parallel_desc; - NdSbp reduced_in_nd_sbp; - NdSbpDimReduce(producer_parallel_desc, producer_sbp_parallel, &reduced_in_parallel_desc, - &reduced_in_nd_sbp, logical_blob_desc.shape()); - - ParallelDesc reduced_out_parallel_desc = consumer_parallel_desc; - NdSbp reduced_out_nd_sbp; - NdSbpDimReduce(consumer_parallel_desc, consumer_sbp_parallel, &reduced_out_parallel_desc, - &reduced_out_nd_sbp, logical_blob_desc.shape()); // In 90% of the transfer, we would have the same parallel description for producer and consumer // We need to speed it up and give an approximation of the cost - if (reduced_in_parallel_desc == reduced_out_parallel_desc - && reduced_in_nd_sbp == reduced_out_nd_sbp) { - if (producer_sbp_parallel == consumer_sbp_parallel) { + if 
(producer_parallel_desc.EqualsIgnoringHierarchy(consumer_parallel_desc)) {
+    // [2, 2]: (S0, S1) -> [2, 2]: (S0, S1)
+    if (*producer_parallel_desc.hierarchy() == *consumer_parallel_desc.hierarchy()
+        && producer_sbp_parallel == consumer_sbp_parallel) {
       return 0.0;
-    } else {
+    }
+    // Reduce before cost computation
+    Shape reduced_in_hierarchy;
+    NdSbp reduced_in_nd_sbp;
+    Shape reduced_out_hierarchy;
+    NdSbp reduced_out_nd_sbp;
+    InOutParallelDimReduce(*producer_parallel_desc.hierarchy(), *consumer_parallel_desc.hierarchy(),
+                           producer_sbp_parallel, consumer_sbp_parallel, &reduced_in_hierarchy,
+                           &reduced_out_hierarchy, &reduced_in_nd_sbp, &reduced_out_nd_sbp,
+                           logical_blob_desc.shape());
+
+    // [2, 2]: (B, B) -> [4]: B
+    if (reduced_in_hierarchy == reduced_out_hierarchy && reduced_in_nd_sbp == reduced_out_nd_sbp) {
       return 1.0;
     }
   }
@@ -903,22 +919,20 @@ double ComputeSbpInferPriority(const NdSbp& producer_nd_sbp, const NdSbp& consum
     // consumer
     return 0.0;
   }
-  // Dim reduction for producer
-  ParallelDesc reduced_in_parallel_desc = producer_parallel_desc;
+  // Reduce before cost computation
+  Shape reduced_in_hierarchy;
   NdSbp reduced_in_nd_sbp;
-  NdSbpDimReduce(producer_parallel_desc, producer_nd_sbp, &reduced_in_parallel_desc,
-                 &reduced_in_nd_sbp, logical_shape);
-
-  // Dim reduction for consumer
-  ParallelDesc reduced_out_parallel_desc = consumer_parallel_desc;
+  Shape reduced_out_hierarchy;
   NdSbp reduced_out_nd_sbp;
-  NdSbpDimReduce(consumer_parallel_desc, consumer_nd_sbp, &reduced_out_parallel_desc,
-                 &reduced_out_nd_sbp, logical_shape);
+  InOutParallelDimReduce(*producer_parallel_desc.hierarchy(), *consumer_parallel_desc.hierarchy(),
+                         producer_nd_sbp, consumer_nd_sbp, &reduced_in_hierarchy,
+                         &reduced_out_hierarchy, &reduced_in_nd_sbp, &reduced_out_nd_sbp,
+                         logical_shape);

   if (requires_same_sbp) {
     // This blob does not support boxing
-    if (reduced_in_nd_sbp == reduced_out_nd_sbp
-        && reduced_in_parallel_desc == reduced_out_parallel_desc) {
+    if (reduced_in_nd_sbp == reduced_out_nd_sbp && reduced_in_hierarchy == reduced_out_hierarchy
+        && producer_parallel_desc.EqualsIgnoringHierarchy(consumer_parallel_desc)) {
       // Normal priority: No transfer occurs but we have different sbp
       // For example: [1]:S0 -> [1]:B
       // [1, 2]:(P, S0) -> [1, 2]:(S0, S0)

From efbb118d3aad8467e62428bbad73310c8664c4b5 Mon Sep 17 00:00:00 2001
From: Yipeng Li
Date: Fri, 4 Nov 2022 09:47:22 +0000
Subject: [PATCH 29/30] fix comment typo

---
 oneflow/core/framework/sbp_infer_util.cpp | 10 ++++++----
 1 file changed, 6 insertions(+), 4 deletions(-)

diff --git a/oneflow/core/framework/sbp_infer_util.cpp b/oneflow/core/framework/sbp_infer_util.cpp
index 4a553db5642..07ad0f2d53c 100644
--- a/oneflow/core/framework/sbp_infer_util.cpp
+++ b/oneflow/core/framework/sbp_infer_util.cpp
@@ -445,10 +445,12 @@ void NdSbpsDimReduce(const Shape& hierarchy, const std::vector& nd
   // At this moment, if we have [2, 4, 3, 7]: (S0, S1, S0, S0) for logical shape [601, 301, 999]
   // We hold the split when accessing the current dimension
   // Do the true splitting until we reach the next step
-  // dim = 0, split_axis2holding_reduced_shapes: {(0: 601)}
-  // dim = 1, split_axis2holding_reduced_shapes: {(0: 300, 301), (1: 601)}
-  // dim = 2, split_axis2holding_reduced_shapes: {(0: 300, 301), (1: 150, 151)}
-  // dim = 3, split_axis2holding_reduced_shapes: {(0: 100, 101), (1: 150, 151)}
+  // dim = 0, split_axis2holding_reduced_shapes: {(0: 601)}, last split axis = -1
+  // dim = 1, split_axis2holding_reduced_shapes: {(0: 300, 301), (1: 301)}, last split axis = 0
+  // dim = 2, split_axis2holding_reduced_shapes: {(0: 300, 301), (1: 75, 76)}, last split axis = 1
+  // dim = 3, at this moment, last split axis (0) == current split axis (0),
+  // dim = 3, but since 300 % (3 * 7) = 6, CanMergeSplit() fails and we do not merge
+  // dim = 3, split_axis2holding_reduced_shapes: {(0: 100, 101), (1: 75, 76)}, last split axis = 0
   std::vector>> index2split_axis2holding_reduced_shapes(sbp_num);
   std::vector> index2last_holding_reduced_shapes(sbp_num);
   std::vector last_split_axises(sbp_num, -1);

From 9002bb5321fcf0d65760c76aa26face189ce16f3 Mon Sep 17 00:00:00 2001
From: Yipeng Li
Date: Fri, 4 Nov 2022 12:00:27 +0000
Subject: [PATCH 30/30] Address comment

---
 oneflow/core/job/job_build_and_infer_ctx.cpp | 9 ++++++---
 1 file changed, 6 insertions(+), 3 deletions(-)

diff --git a/oneflow/core/job/job_build_and_infer_ctx.cpp b/oneflow/core/job/job_build_and_infer_ctx.cpp
index cca19606207..5326cdd2c67 100644
--- a/oneflow/core/job/job_build_and_infer_ctx.cpp
+++ b/oneflow/core/job/job_build_and_infer_ctx.cpp
@@ -16,6 +16,7 @@ limitations under the License.
 #include "oneflow/core/common/maybe.h"
 #include "oneflow/core/common/protobuf.h"
 #include "oneflow/core/common/time_util.h"
+#include "oneflow/core/framework/nd_sbp.h"
 #include "oneflow/core/vm/symbol_storage.h"
 #include "oneflow/core/framework/config_def.h"
 #include "oneflow/core/framework/to_string.h"
@@ -304,7 +305,8 @@ Maybe JobBuildAndInferCtx::CheckOpBlobSplitability(Operator* op, int64_t p
     if (logical_blob_desc.shape().NumAxes() > 0) {
       CHECK_GE_OR_RETURN(logical_blob_desc.shape().At(axis), blob_parallel_num)
           << "op_name: " << lbi.op_name() << " blob_name: " << lbi.blob_name()
-          << " cannot split blob by parallel_num: " << std::to_string(blob_parallel_num);
+          << " shape: " << logical_blob_desc.shape()
+          << " cannot be split by parallel_num: " << blob_parallel_num << " at axis " << axis;
     }
   }
 } else {
@@ -321,8 +323,9 @@ Maybe JobBuildAndInferCtx::CheckOpBlobSplitability(Operator* op, int64_t p
       // Support unbalanced splitting
       CHECK_GE_OR_RETURN(current_shape.At(axis), parallel_hierarchy->At(i))
           << "op_name: " << lbi.op_name() << " blob_name: " << lbi.blob_name()
-          << " cannot split blob by parallel_hierarchy: "
-          << std::to_string(parallel_hierarchy->At(i));
+          << " shape: " << logical_blob_desc.shape()
+          << " cannot be split by nd sbp: " << NdSbpToString(pair.second) << " at axis "
+          << axis << " with parallel_hierarchy: " << *parallel_hierarchy;
       // Split and take the minimum one
       current_shape.Set(axis, current_shape.At(axis) / parallel_hierarchy->At(i));
     }
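To make the unbalanced-splitting rule used throughout these patches concrete, here is a minimal stand-alone C++ sketch (an editor's addition, not OneFlow code; BalancedPieceSize is a hypothetical stand-in mirroring the behavior of BalancedSplitter). It shows why the checks above only require the split dimension to be at least the parallel number, rather than exactly divisible by it:

#include <cassert>
#include <cstdint>
#include <iostream>

// Size of the piece owned by `rank` when `total` elements are split over `parallel_num` ranks:
// the first (total % parallel_num) ranks each get one extra element.
int64_t BalancedPieceSize(int64_t total, int64_t parallel_num, int64_t rank) {
  assert(total >= parallel_num && rank >= 0 && rank < parallel_num);
  const int64_t base = total / parallel_num;       // every rank gets at least this many
  const int64_t remainder = total % parallel_num;  // first `remainder` ranks get one more
  return base + (rank < remainder ? 1 : 0);
}

int main() {
  // For example, an axis of length 13 (as in the 13 x 5 test blob above) split over 4 ranks.
  for (int64_t rank = 0; rank < 4; ++rank) {
    std::cout << "rank " << rank << " owns " << BalancedPieceSize(13, 4, rank) << " rows\n";
  }
  // Prints 4, 3, 3, 3: unbalanced, but every rank is non-empty because 13 >= 4.
  return 0;
}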