From 2bb079963e1e06c534c6002f2cb804a6f60f797d Mon Sep 17 00:00:00 2001 From: Yipeng Li Date: Wed, 26 Jan 2022 15:09:23 +0800 Subject: [PATCH] Auto parallel/cost for middle nodes (#7368) * Add transfer cost for middle nodes * Fix the bug of parallel desc * Add the tag in boxing collector for computing cost only * Deal with different placement in boxing collector * Check different parallel desc * Revert "Add transfer cost for middle nodes" This reverts commit d377f1160577be8c656cecc9ae897bc61ff5b720. * Compute the cost with middle nodes * Replace ComputeCopyCostBetweenNdSbp with ComputeCopyCostWithMiddleNodes * Parse transfer_cost from environment * Speed up * Reduce threshold to avoid strategy explosion * Format * Format --- .../core/auto_parallel/boxing_collector.cpp | 70 ++++++++++++------- oneflow/core/auto_parallel/boxing_collector.h | 11 ++- oneflow/core/auto_parallel/sbp_collector.cpp | 7 +- oneflow/core/auto_parallel/sbp_edge.h | 10 ++- oneflow/core/auto_parallel/sbp_graph.h | 2 +- oneflow/core/common/util.cpp | 9 +++ oneflow/core/common/util.h | 2 + oneflow/core/framework/sbp_infer_util.cpp | 45 ++++++++++++ oneflow/core/framework/sbp_infer_util.h | 11 ++- .../job_rewriter/boxing_with_middle_nodes.cpp | 9 ++- oneflow/core/operator/operator.cpp | 68 ++++++++++-------- 11 files changed, 171 insertions(+), 73 deletions(-) diff --git a/oneflow/core/auto_parallel/boxing_collector.cpp b/oneflow/core/auto_parallel/boxing_collector.cpp index d049fe290b5..9f221055561 100644 --- a/oneflow/core/auto_parallel/boxing_collector.cpp +++ b/oneflow/core/auto_parallel/boxing_collector.cpp @@ -42,6 +42,9 @@ void DfsSetNdSbp(std::vector<::oneflow::cfg::SbpParallel>& id2SbpParallel, int32 } } // namespace +// A constructor with init, designed for uncustomized boxing collector +BoxingCollector::BoxingCollector(int32_t max_axis) { Init(max_axis); } + // Construct a boxing collector with given maximum number of axis void BoxingCollector::Init(int32_t max_axis) { // Set up at least 
two split for op graph. @@ -88,10 +91,10 @@ void BoxingCollector::GenerateNdSbpList() { // Generate the transfer rule for different combinations and hierarchies Maybe BoxingCollector::GenerateCombination(int32_t max_middle_node_num) { // other parameters - // To be noted that the performance of this function are all the same with different hierarchy + // NOTE: The performance of this function is the same with different hierarchies Shape hierarchy44({4, 4}); std::shared_ptr in_hierarchy = std::make_shared(hierarchy44); - auto in_parallel_desc = JUST(ParallelDesc::New("cpu", {"0", "1"}, in_hierarchy)); + auto in_parallel_desc = JUST(ParallelDesc::New("cpu", {"0:0-15"}, in_hierarchy)); BlobDesc blob_desc({16, 16, 16, 16}, DataType::kInt8, /*is_dynamic=*/false); // Store the origin transfer cost information int32_t n = nd_sbp_lists.size(); @@ -243,33 +246,43 @@ void BoxingCollector::PrintBoxingTables() { } // Ask if the boxing algorithm accepts the current sbp combination -Maybe BoxingCollector::AskSbpCombination(const cfg::NdSbp& sbp_producer, - const cfg::NdSbp& sbp_consumer, - const BlobDesc& logical_blob_desc, - const ParallelDesc& producer_parallel_desc, - const ParallelDesc& consumer_parallel_desc, - bool customized, - std::vector& middle_sbps) { - // Check the devices and hierarchy +Maybe BoxingCollector::AskSbpCombination( + const cfg::NdSbp& sbp_producer, const cfg::NdSbp& sbp_consumer, + const BlobDesc& logical_blob_desc, const ParallelDesc& producer_parallel_desc, + const ParallelDesc& consumer_parallel_desc, bool is_customized, + std::vector& middle_sbps, bool compute_cost) { + middle_sbps.clear(); // At this moment, we do not support [2, 3] -> [3, 2] // TODO: support [2, 3] -> [3, 2] - CHECK_OR_RETURN(producer_parallel_desc.EqualsIgnoringDeviceType(consumer_parallel_desc)) - << "Boxing does not support transfer for different machines or devices or hierarchy"; - middle_sbps.clear(); + // Middle nodes do not support transfer for different machines or 
devices or hierarchy + if (producer_parallel_desc != consumer_parallel_desc) { + CHECK_OR_RETURN( + compute_cost + || JUST(ComputeCopyCostBetweenNdSbp(sbp_producer, sbp_consumer, logical_blob_desc, + producer_parallel_desc, consumer_parallel_desc, false)) + < GetValidMaxCopyCost()) + << "Boxing does not support " << NdSbpParallelToString(sbp_producer) << " -> " + << NdSbpParallelToString(sbp_consumer) << " for two different placement "; + return Maybe::Ok(); + } const auto& parallel_hierarchy = producer_parallel_desc.hierarchy(); // Dealing with 1D sbp if (parallel_hierarchy->NumAxes() == 1) { CHECK_OR_RETURN( - JUST(ComputeCopyCostBetweenNdSbp(sbp_producer, sbp_consumer, logical_blob_desc, - producer_parallel_desc, consumer_parallel_desc, false)) - < GetValidMaxCopyCost()) + compute_cost + || JUST(ComputeCopyCostBetweenNdSbp(sbp_producer, sbp_consumer, logical_blob_desc, + producer_parallel_desc, consumer_parallel_desc, false)) + < GetValidMaxCopyCost()) << "Boxing does not support " << NdSbpParallelToString(sbp_producer) << " -> " << NdSbpParallelToString(sbp_consumer) << " for 1D sbp"; return Maybe::Ok(); } // Dealing with nD sbp, n>2 - CHECK_OR_RETURN(parallel_hierarchy->NumAxes() == 2) - << "Boxing does not support a hierarchy with dimension greater than 2"; + if (parallel_hierarchy->NumAxes() > 2) { + CHECK_OR_RETURN(compute_cost) + << "Boxing does not support a hierarchy with dimension greater than 2"; + return Maybe::Ok(); + } // Dealing with 2D sbp const auto& it_producer = NdSbpUniverse.find(sbp_producer); const auto& it_consumer = NdSbpUniverse.find(sbp_consumer); @@ -277,9 +290,12 @@ Maybe BoxingCollector::AskSbpCombination(const cfg::NdSbp& sbp_producer, int32_t i = it_producer->second; int32_t j = it_consumer->second; // Such combination can not be support with limited middle nodes - CHECK(minimum_copy_cost[i][j] < GetValidMaxCopyCost()) - << "Boxing does not support " << NdSbpParallelToString(sbp_producer) << " -> " - << 
NdSbpParallelToString(sbp_consumer) << " for 2D sbp"; + if (minimum_copy_cost[i][j] > GetValidMaxCopyCost()) { + CHECK_OR_RETURN(compute_cost) + << "Boxing does not support " << NdSbpParallelToString(sbp_producer) << " -> " + << NdSbpParallelToString(sbp_consumer) << " for 2D sbp"; + return Maybe::Ok(); + } // Current design can deal with such combination. Do not need to insert middle nodes if (middle_nodes[i][j].size() == 0) { return Maybe::Ok(); } // Find a list of middle nodes with minimum storage @@ -310,9 +326,13 @@ Maybe BoxingCollector::AskSbpCombination(const cfg::NdSbp& sbp_producer, } // // If we can not found a list of middle nodes even after customized boxing collector - CHECK_OR_RETURN(customized) << "Boxing does not support " << NdSbpParallelToString(sbp_producer) - << " -> " << NdSbpParallelToString(sbp_consumer) - << " for Shape: " << logical_blob_desc.shape(); + if (is_customized) { + CHECK_OR_RETURN(compute_cost) << "Boxing does not support " + << NdSbpParallelToString(sbp_producer) << " -> " + << NdSbpParallelToString(sbp_consumer) + << " for Shape: " << logical_blob_desc.shape(); + return Maybe::Ok(); + } // Customized boxing collector and try the algorithm again BoxingCollector customized_boxing_collector; @@ -323,7 +343,7 @@ Maybe BoxingCollector::AskSbpCombination(const cfg::NdSbp& sbp_producer, customized_boxing_collector.GenerateCombination(5); JUST(customized_boxing_collector.AskSbpCombination(sbp_producer, sbp_consumer, logical_blob_desc, producer_parallel_desc, consumer_parallel_desc, - false, middle_sbps)); + false, middle_sbps, compute_cost)); return Maybe::Ok(); } diff --git a/oneflow/core/auto_parallel/boxing_collector.h b/oneflow/core/auto_parallel/boxing_collector.h index 721ff0e45d3..298d2cab315 100644 --- a/oneflow/core/auto_parallel/boxing_collector.h +++ b/oneflow/core/auto_parallel/boxing_collector.h @@ -30,6 +30,9 @@ class BoxingCollector final { ~BoxingCollector() = default; + // A constructor with init, designed for 
uncustomized boxing collector + BoxingCollector(int32_t max_axis); + // Set default Sbp list void CollectUniverse(int32_t max_axis); @@ -43,12 +46,14 @@ class BoxingCollector final { // Print the cost and middle nodes void PrintBoxingTables(); // Ask if the boxing algorithm accepts the current sbp combination - // If customized is true and we can not find a middle node list with + // If is_customized is true and we can not find a middle node list with + // reasonable cost, error occurs. + // If compute_cost is true, then no error occurs even if no suitable middle-node path is found. Maybe AskSbpCombination(const cfg::NdSbp& sbp_producer, const cfg::NdSbp& sbp_consumer, const BlobDesc& logical_blob_desc, const ParallelDesc& producer_parallel_desc, - const ParallelDesc& consumer_parallel_desc, bool customized, - std::vector& middle_sbps); + const ParallelDesc& consumer_parallel_desc, bool is_customized, + std::vector& middle_sbps, bool compute_cost); // Filter nd sbp from nd_sbp_lists with given logical shape Maybe FilterNdSbpList4LogicalShape(const BlobDesc& logical_blob_desc, const Shape& parallel_hierarchy); diff --git a/oneflow/core/auto_parallel/sbp_collector.cpp b/oneflow/core/auto_parallel/sbp_collector.cpp index 7f91ba2981a..316f170a281 100644 --- a/oneflow/core/auto_parallel/sbp_collector.cpp +++ b/oneflow/core/auto_parallel/sbp_collector.cpp @@ -122,9 +122,10 @@ void SbpCollector::InitializeCopyCostFromNode2Proxy(SbpNode // compute copy cost for a specific logical blob // Use the parallel description of producer as those for consumer for now. 
- sbp_edge->Cost[sbp_id_producer][sbp_id_consumer] += CHECK_JUST(ComputeCopyCostBetweenNdSbp( - sbp_producer, sbp_consumer, logical_blob_desc, producer_parallel_desc, - producer_parallel_desc, /*is_same=*/false)); + sbp_edge->Cost[sbp_id_producer][sbp_id_consumer] += + CHECK_JUST(ComputeCopyCostWithMiddleNodes(sbp_producer, sbp_consumer, logical_blob_desc, + producer_parallel_desc, + producer_parallel_desc, /*is_same=*/false)); } } } diff --git a/oneflow/core/auto_parallel/sbp_edge.h b/oneflow/core/auto_parallel/sbp_edge.h index 81834f2010f..bf573f4f31a 100644 --- a/oneflow/core/auto_parallel/sbp_edge.h +++ b/oneflow/core/auto_parallel/sbp_edge.h @@ -426,6 +426,7 @@ void SbpEdge::InitializeCopyCost(const std::string& ibn, bool comp // B->S cause cudaEventSynchronize in current implementation. bool is_same_sbp = (!compute_cost) || IsSameSbp(consumer, ibn); int32_t consumer_sbp_size = EndNode->SbpSignatureList.size(); + LazyMode::Guard enable_lazy_mode(true); // look through sbp signature in producer for (int32_t sbp_id_producer = 0; sbp_id_producer < StartNode->SbpSignatureList.size(); @@ -443,12 +444,9 @@ void SbpEdge::InitializeCopyCost(const std::string& ibn, bool comp const cfg::NdSbp& sbp_consumer = consumer_sbp_bn_in_op2sbp_parallel.at(ibn); // compute copy cost for a specific logical blob - { - LazyMode::Guard enable_lazy_mode(true); - Cost[sbp_id_producer][sbp_id_consumer] += CHECK_JUST(ComputeCopyCostBetweenNdSbp( - sbp_producer, sbp_consumer, logical_blob_desc, producer_parallel_desc, - consumer_parallel_desc, is_same_sbp)); - } + Cost[sbp_id_producer][sbp_id_consumer] += CHECK_JUST(ComputeCopyCostWithMiddleNodes( + sbp_producer, sbp_consumer, logical_blob_desc, producer_parallel_desc, + consumer_parallel_desc, is_same_sbp)); } } } diff --git a/oneflow/core/auto_parallel/sbp_graph.h b/oneflow/core/auto_parallel/sbp_graph.h index 53caeac46b5..ba965360ae6 100644 --- a/oneflow/core/auto_parallel/sbp_graph.h +++ b/oneflow/core/auto_parallel/sbp_graph.h @@ 
-796,7 +796,7 @@ int32_t SbpGraph::PickAndMerge() { SbpEdge* merging_edge = nullptr; for (int32_t i = 0; i < NodeList.size(); i++) { for (SbpEdge* edge_in : NodeList[i]->EdgesIn) { - curr_cut_ratio = edge_in->FindCutRatio(Threshold * 10); + curr_cut_ratio = edge_in->FindCutRatio(Threshold); if (curr_cut_ratio < min_cut_ratio) { min_cut_ratio = curr_cut_ratio; merging_edge = edge_in; diff --git a/oneflow/core/common/util.cpp b/oneflow/core/common/util.cpp index a3fc3941ab2..45a28190af9 100644 --- a/oneflow/core/common/util.cpp +++ b/oneflow/core/common/util.cpp @@ -160,4 +160,13 @@ std::string GetStringFromEnv(const std::string& env_var, const std::string& defa } } +double ParseDoubleFromEnv(const std::string& env_var, double default_value) { + const char* env_p = std::getenv(env_var.c_str()); + if (env_p == nullptr) { + return default_value; + } else { + return strtod(env_p, NULL); + } +} + } // namespace oneflow diff --git a/oneflow/core/common/util.h b/oneflow/core/common/util.h index 395fb6c96ba..8dfe70f157e 100644 --- a/oneflow/core/common/util.h +++ b/oneflow/core/common/util.h @@ -241,6 +241,8 @@ int64_t ParseIntegerFromEnv(const std::string& env_var, int64_t default_value); std::string GetStringFromEnv(const std::string& env_var, const std::string& default_value); +double ParseDoubleFromEnv(const std::string& env_var, double default_value); + #define OF_PREDICT_TRUE likely #define OF_PREDICT_FALSE unlikely diff --git a/oneflow/core/framework/sbp_infer_util.cpp b/oneflow/core/framework/sbp_infer_util.cpp index 3157f9a6742..dd92450d92d 100644 --- a/oneflow/core/framework/sbp_infer_util.cpp +++ b/oneflow/core/framework/sbp_infer_util.cpp @@ -15,10 +15,13 @@ limitations under the License. 
*/ #include "oneflow/core/framework/sbp_infer_util.h" +#include "oneflow/core/auto_parallel/boxing_collector.h" #include "oneflow/core/graph/boxing/hierarchical_sub_task_graph_builder_impl.h" #include "oneflow/core/boxing/eager_boxing_interpreter_mgr.h" #include "oneflow/core/common/multi_client.h" +#include "oneflow/core/common/util.h" #include "oneflow/core/job/lazy_mode.h" +#include "oneflow/core/job/parallel_desc.h" namespace oneflow { @@ -443,4 +446,46 @@ Maybe ComputeCopyCostBetweenNdSbp(const cfg::NdSbp& producer_sbp_paralle consumer_parallel_desc, requires_same_sbp); } +Maybe ComputeCopyCostWithMiddleNodes(const cfg::NdSbp& producer_sbp_parallel, + const cfg::NdSbp& consumer_sbp_parallel, + const BlobDesc& logical_blob_desc, + const ParallelDesc& producer_parallel_desc, + const ParallelDesc& consumer_parallel_desc, + bool requires_same_sbp) { + // Initialize boxing collector + constexpr int32_t kRegularMaxSplitAxes = 6; + static thread_local BoxingCollector boxing_collector(kRegularMaxSplitAxes); + std::vector middle_sbps; + // Ask for middle nodes + boxing_collector.AskSbpCombination(producer_sbp_parallel, consumer_sbp_parallel, + logical_blob_desc, producer_parallel_desc, + consumer_parallel_desc, /*is_customized=*/false, middle_sbps, + /*compute_cost=*/true); + // Parameters + double total_cost = 0.0; + double transfer_cost = ParseDoubleFromEnv("AUTO_PARALLEL_TRANSFER_COST", 1.65e7); + // Set up the information of the first node in the first connection + const cfg::NdSbp* pre_nd_sbp = &producer_sbp_parallel; + const ParallelDesc* pre_parallel_desc = &producer_parallel_desc; + // Connection for the next middle node + for (const auto& middle_sbp : middle_sbps) { + // We use the parallel description of consumer as the parallel description for all the middle + // nodes, following the same procedure in boxing_with_middle_nodes.cpp + // TODO: Needs more effort if dealing with different placement + total_cost += JUST(ComputeCopyCostBetweenNdSbp(*pre_nd_sbp, 
middle_sbp, logical_blob_desc, + *pre_parallel_desc, consumer_parallel_desc, + requires_same_sbp)) + + transfer_cost; + // Set up the information of the first node in the next connection + pre_nd_sbp = &middle_sbp; + pre_parallel_desc = &consumer_parallel_desc; + } + // Connection between the last middle node and consumer + total_cost += JUST(ComputeCopyCostBetweenNdSbp(*pre_nd_sbp, consumer_sbp_parallel, + logical_blob_desc, *pre_parallel_desc, + consumer_parallel_desc, requires_same_sbp)); + + return total_cost; +} + } // namespace oneflow diff --git a/oneflow/core/framework/sbp_infer_util.h b/oneflow/core/framework/sbp_infer_util.h index a5f2f85dd4c..dc248f21e31 100644 --- a/oneflow/core/framework/sbp_infer_util.h +++ b/oneflow/core/framework/sbp_infer_util.h @@ -40,7 +40,7 @@ double Storage4NdSbp(const cfg::NdSbp& nd_sbp, Shape& logical_shape, Maybe FilterNdSbpByLogicalShape(const cfg::NdSbp& nd_sbp, Shape& logical_shape, const Shape& parallel_hierarchy); -// TODO: unified lazy and eager boxing +// TODO: Unify lazy and eager boxing Maybe ComputeCopyCostBetweenNdSbp(const cfg::NdSbp& producer_sbp_parallel, const cfg::NdSbp& consumer_sbp_parallel, const BlobDesc& logical_blob_desc, @@ -48,6 +48,15 @@ Maybe ComputeCopyCostBetweenNdSbp(const cfg::NdSbp& producer_sbp_paralle const ParallelDesc& consumer_parallel_desc, bool is_same_sbp); +// The public interface for computing cost +// It uses the middle nodes algorithm. 
+Maybe ComputeCopyCostWithMiddleNodes(const cfg::NdSbp& producer_sbp_parallel, + const cfg::NdSbp& consumer_sbp_parallel, + const BlobDesc& logical_blob_desc, + const ParallelDesc& producer_parallel_desc, + const ParallelDesc& consumer_parallel_desc, + bool requires_same_sbp); + } // namespace oneflow #endif // ONEFLOW_CORE_FRAMEWORK_SBP_INFER_UTIL_H_ diff --git a/oneflow/core/job_rewriter/boxing_with_middle_nodes.cpp b/oneflow/core/job_rewriter/boxing_with_middle_nodes.cpp index 71ce398ad70..f8fcaf31b87 100644 --- a/oneflow/core/job_rewriter/boxing_with_middle_nodes.cpp +++ b/oneflow/core/job_rewriter/boxing_with_middle_nodes.cpp @@ -47,13 +47,12 @@ Maybe BoxingWithMiddleNodes(const OpGraph& op_graph, JobBuilder* job_build const cfg::NdSbp& consumer_nd_sbp = node->NdSbp4BnInOp(ibn); // Needs more effort if dealing with different placement - if (producer.parallel_desc() == node->parallel_desc() - && (node->parallel_desc().parallel_num() != 1 && producer_nd_sbp != consumer_nd_sbp)) { + if (node->parallel_desc().parallel_num() != 1 && producer_nd_sbp != consumer_nd_sbp) { const auto& logical_blob_desc = producer.LogicalBlobDesc4Lbi(lbi); // Ask for middle nodes - boxing_collector.AskSbpCombination(producer_nd_sbp, consumer_nd_sbp, logical_blob_desc, - producer.parallel_desc(), node->parallel_desc(), true, - middle_sbps); + boxing_collector.AskSbpCombination( + producer_nd_sbp, consumer_nd_sbp, logical_blob_desc, producer.parallel_desc(), + node->parallel_desc(), /*is_customized=*/false, middle_sbps, /*compute_cost=*/false); // move to the next ibn if no middle nodes needed if (middle_sbps.size() <= 0) { continue; } LogicalBlobId middle_node_lbi = lbi; diff --git a/oneflow/core/operator/operator.cpp b/oneflow/core/operator/operator.cpp index 838c30869ea..321212143bc 100644 --- a/oneflow/core/operator/operator.cpp +++ b/oneflow/core/operator/operator.cpp @@ -695,40 +695,50 @@ Maybe Operator::GreedilyFindMinCopyCostNdSbp( const std::vector& nd_sbp_sig_list) const { 
int32_t select_sbp_idx = -1; double min_copy_cost = GetValidMaxCopyCost(); - for (int32_t i = 0; i < nd_sbp_sig_list.size(); ++i) { - double total_copy_cost = 0.0; - for (const auto& ibn : input_bns()) { - const auto& blob_modifier_ = InputBlobModifier4Ibn(ibn); - bool is_same_sbp = - (blob_modifier_.has_is_mutable() && blob_modifier_.is_mutable()) - || (!IsPODDataType(JUST(NdSbpInferHint4Ibn(ibn))->logical_blob_desc().data_type())); - total_copy_cost += JUST(ComputeCopyCostBetweenNdSbp( - JUST(NdSbpInferHint4Ibn(ibn))->nd_sbp(), nd_sbp_sig_list.at(i).bn_in_op2nd_sbp()[ibn], - JUST(NdSbpInferHint4Ibn(ibn))->logical_blob_desc(), - JUST(NdSbpInferHint4Ibn(ibn))->parallel_desc(), *JUST(GetParallelDesc4BnInOp(ibn)), - is_same_sbp)); - } - if (total_copy_cost <= min_copy_cost) { - select_sbp_idx = i; - min_copy_cost = total_copy_cost; + // We notice that we have a lot of inquiries asking for the cost. + // If the candidate list has only one entry, select it to reduce the inquiries. + // Normally, we support all the sbp combinations for boxing. Therefore, we do not need to worry + // about the case that we can not transfer to this sbp signature. Even if we do not support such + // transfer, a report would be sent in boxing_with_middle_nodes.cpp. 
+ if (nd_sbp_sig_list.size() == 1) { + select_sbp_idx = 0; + } else { + for (int32_t i = 0; i < nd_sbp_sig_list.size(); ++i) { + double total_copy_cost = 0.0; + for (const auto& ibn : input_bns()) { + const auto& blob_modifier_ = InputBlobModifier4Ibn(ibn); + bool is_same_sbp = + (blob_modifier_.has_is_mutable() && blob_modifier_.is_mutable()) + || (!IsPODDataType(JUST(NdSbpInferHint4Ibn(ibn))->logical_blob_desc().data_type())); + total_copy_cost += JUST(ComputeCopyCostWithMiddleNodes( + JUST(NdSbpInferHint4Ibn(ibn))->nd_sbp(), nd_sbp_sig_list.at(i).bn_in_op2nd_sbp()[ibn], + JUST(NdSbpInferHint4Ibn(ibn))->logical_blob_desc(), + JUST(NdSbpInferHint4Ibn(ibn))->parallel_desc(), *JUST(GetParallelDesc4BnInOp(ibn)), + is_same_sbp)); + // Reduce inquiries + if (total_copy_cost > min_copy_cost) { break; } + } + if (total_copy_cost <= min_copy_cost) { + select_sbp_idx = i; + min_copy_cost = total_copy_cost; + // Reduce inquiries + if (total_copy_cost == 0.0) { break; } + } } - } - // Can't find any available sbp - if (select_sbp_idx == -1) { - std::ostringstream err; - err << "op: `" << op_name() << "` can't find available sbp signature." << std::endl; - err << "Condidate nd sbp signature are: " - << *JUST(NdSbpSignatureListAsString(nd_sbp_sig_list, input_bns(), output_bns())); - err << ", but inputs sbp are:"; - { - std::ostringstream input_sbp_str; + // Can't find any available sbp + if (select_sbp_idx == -1) { + std::ostringstream err; + err << "op: `" << op_name() << "` can't find available sbp signature." 
<< std::endl; + err << "Condidate nd sbp signature are: " + << *JUST(NdSbpSignatureListAsString(nd_sbp_sig_list, input_bns(), output_bns())); + err << ", but inputs sbp are:"; for (const auto& ibn : input_bns()) { const cfg::NdSbp& nd_sbp = JUST(NdSbpInferHint4Ibn(ibn))->nd_sbp(); - input_sbp_str << " " << ibn << ": " << NdSbpToString(nd_sbp) << ";"; + err << " " << ibn << ": " << NdSbpToString(nd_sbp) << ";"; } - err << input_sbp_str.str() << std::endl; + + return Error::RuntimeError() << err.str(); } - return Error::RuntimeError() << err.str(); } nd_sbp_signature->CopyFrom(nd_sbp_sig_list.at(select_sbp_idx)); return Maybe::Ok();