From 0f570019b677ba762de1287631901baf6a75032f Mon Sep 17 00:00:00 2001 From: Yinggang Wang Date: Tue, 25 Jan 2022 17:33:36 +0800 Subject: [PATCH] Fix boxing parallel desc (#7356) * fix(AutoParallel): fix lazy copy cost between op * fix(*): fix nd_sbp_infer_hint bug * fix(*): infer blob parallel desc before inferring sbp * refine format * refine param name --- oneflow/core/auto_parallel/sbp_edge.h | 7 ++++-- .../consistent_tensor_infer_cache.cpp | 1 + oneflow/core/framework/sbp_infer_util.cpp | 24 ++++++------------- oneflow/core/graph/op_graph.cpp | 4 +++- oneflow/core/job/job_build_and_infer_ctx.cpp | 2 +- oneflow/core/operator/operator.cpp | 3 ++- 6 files changed, 19 insertions(+), 22 deletions(-) diff --git a/oneflow/core/auto_parallel/sbp_edge.h b/oneflow/core/auto_parallel/sbp_edge.h index 58f5e716aa9..81834f2010f 100644 --- a/oneflow/core/auto_parallel/sbp_edge.h +++ b/oneflow/core/auto_parallel/sbp_edge.h @@ -412,8 +412,11 @@ void SbpEdge::InitializeCopyCost(const std::string& ibn, bool comp if (use_sbp_collector_ && compute_cost && !SearchLbi(lbi)) { return; } oneflow::OpNode* producer = StartNode->op_node; - const oneflow::ParallelDesc& producer_parallel_desc = producer->parallel_desc(); - const oneflow::ParallelDesc& consumer_parallel_desc = consumer->parallel_desc(); + const std::string& producer_lbn = *CHECK_JUST(producer->op().obn4lbi(lbi)); + const oneflow::ParallelDesc& producer_parallel_desc = + *CHECK_JUST(producer->op().GetParallelDesc4BnInOp(producer_lbn)); + const oneflow::ParallelDesc& consumer_parallel_desc = + *CHECK_JUST(consumer->op().GetParallelDesc4BnInOp(ibn)); // Need to be careful, the logical blob description should be independent to current // SbpParallel. Use producer or op_node? diff --git a/oneflow/core/framework/consistent_tensor_infer_cache.cpp b/oneflow/core/framework/consistent_tensor_infer_cache.cpp index 6914f8c8795..908f10ea8cd 100644 --- a/oneflow/core/framework/consistent_tensor_infer_cache.cpp +++ b/oneflow/core/framework/consistent_tensor_infer_cache.cpp @@ -262,6 +262,7 @@ class UserOpExprOpDeviceInferContext final : public user_op::DeviceInferContext } const auto& op = JUST(MakeOp(user_op_expr, infer_args.attrs(), parallel_desc->device_tag())); JUST(op->FillOpParallelDesc(parallel_desc.shared_from_symbol())); + JUST(op->InferParallelSignatureIf()); { // Infer parallel distribution. cfg::NdSbpSignature nd_sbp_constraints; diff --git a/oneflow/core/framework/sbp_infer_util.cpp b/oneflow/core/framework/sbp_infer_util.cpp index 155ccba49c2..3157f9a6742 100644 --- a/oneflow/core/framework/sbp_infer_util.cpp +++ b/oneflow/core/framework/sbp_infer_util.cpp @@ -200,7 +200,7 @@ Maybe ComputeEagerCopyCostBetweenNdSbp(const cfg::NdSbp& producer_sbp_pa const BlobDesc& logical_blob_desc, const ParallelDesc& producer_parallel_desc, const ParallelDesc& consumer_parallel_desc, - bool is_same_sbp) { + bool requires_same_sbp) { if (!(CheckNdSbp(producer_sbp_parallel) && CheckNdSbp(consumer_sbp_parallel))) { return Error::RuntimeError() << "Illegal sbp parallel has been found."; } @@ -227,13 +227,7 @@ Maybe ComputeEagerCopyCostBetweenNdSbp(const cfg::NdSbp& producer_sbp_pa bool same_nd_sbp = reduced_in_nd_sbp == reduced_out_nd_sbp; // Same sbp is always supported. if (same_nd_sbp && reduced_in_parallel_desc == reduced_out_parallel_desc) { return 0.0; } - - // Will directly modify output blob of source op. Requiring data having same sbp_parallel. - if (is_same_sbp - && !(reduced_in_parallel_desc.EqualsIgnoringDeviceType(reduced_out_parallel_desc) - && same_nd_sbp)) { - return kUnsupportedBoxing; - } + if (requires_same_sbp) { return kUnsupportedBoxing; } int32_t in_dim = in_hierarchy->NumAxes(); int32_t out_dim = out_hierarchy->NumAxes(); @@ -285,7 +279,7 @@ Maybe ComputeLazyCopyCostBetweenNdSbp(const cfg::NdSbp& producer_sbp_par const BlobDesc& logical_blob_desc, const ParallelDesc& producer_parallel_desc, const ParallelDesc& consumer_parallel_desc, - bool is_same_sbp) { + bool requires_same_sbp) { if (!(CheckNdSbp(producer_sbp_parallel) && CheckNdSbp(consumer_sbp_parallel))) { return Error::RuntimeError() << "Illegal sbp parallel has been found."; } @@ -308,13 +302,9 @@ Maybe ComputeLazyCopyCostBetweenNdSbp(const cfg::NdSbp& producer_sbp_par } bool same_nd_sbp = reduced_in_nd_sbp == reduced_out_nd_sbp; + // Same sbp is always supported. if (same_nd_sbp && reduced_in_parallel_desc == reduced_out_parallel_desc) { return 0.0; } - // Will directly modify output blob of source op. Requiring data having same sbp_parallel - if (is_same_sbp - && !(reduced_in_parallel_desc.EqualsIgnoringDeviceType(reduced_out_parallel_desc) - && same_nd_sbp)) { - return kUnsupportedBoxing; - } + if (requires_same_sbp) { return kUnsupportedBoxing; } // We support different hierarchy for 1D sbp if (in_dim == 1 && out_dim == 1) { @@ -447,10 +437,10 @@ Maybe ComputeCopyCostBetweenNdSbp(const cfg::NdSbp& producer_sbp_paralle const BlobDesc& logical_blob_desc, const ParallelDesc& producer_parallel_desc, const ParallelDesc& consumer_parallel_desc, - bool is_same_sbp) { + bool requires_same_sbp) { return JUST(GetComputeCopyCostFunc())(producer_sbp_parallel, consumer_sbp_parallel, logical_blob_desc, producer_parallel_desc, - consumer_parallel_desc, is_same_sbp); + consumer_parallel_desc, requires_same_sbp); } } // namespace oneflow diff --git a/oneflow/core/graph/op_graph.cpp b/oneflow/core/graph/op_graph.cpp index 0e9b42e3589..97898c50b7e 100644 --- a/oneflow/core/graph/op_graph.cpp +++ b/oneflow/core/graph/op_graph.cpp @@ -306,7 +306,9 @@ void OpGraph::InferOpNodeNdSbpSignature(OpNode* op_node, for (const std::string& ibn : op_node->op().input_bns()) { const LogicalBlobId& lbi = op_node->op().BnInOp2Lbi(ibn); OpNode* producer = op_node->MutSrcNode4Ibn(ibn); - const ParallelDesc* parallel_desc = &producer->parallel_desc(); + const std::string& producer_lbn = *CHECK_JUST(producer->op().obn4lbi(lbi)); + const ParallelDesc* parallel_desc = + CHECK_JUST(producer->op().GetParallelDesc4BnInOp(producer_lbn)).get(); const BlobDesc* logical_blob_desc = &producer->LogicalBlobDesc4Lbi(lbi); const cfg::NdSbp* nd_sbp = &producer->NdSbp4Lbi(lbi); ibn2nd_sbp_infer_hint.emplace(ibn, NdSbpInferHint(parallel_desc, logical_blob_desc, nd_sbp)); diff --git a/oneflow/core/job/job_build_and_infer_ctx.cpp b/oneflow/core/job/job_build_and_infer_ctx.cpp index 82529adfe0f..3626c645bd8 100644 --- a/oneflow/core/job/job_build_and_infer_ctx.cpp +++ b/oneflow/core/job/job_build_and_infer_ctx.cpp @@ -589,6 +589,7 @@ Maybe JobBuildAndInferCtx::AddAndInferOp(const OperatorConf& op_con return nullptr; }; JUST(op->FillLogicalInBlobDesc(GetBlobDesc4BnInOp)); + JUST(op->InferParallelSignatureIf()); // infer mirrored signature JUST(InferMirroredSignature(op, is_mirrored_parallel_view, parallel_desc)); @@ -618,7 +619,6 @@ Maybe JobBuildAndInferCtx::AddAndInferOp(const OperatorConf& op_con } return &iter->second; }; - JUST(op->InferParallelSignatureIf()); for (const auto& bn : op->output_bns()) { lbi2parallel_desc_from_producer_view_.emplace(op->BnInOp2Lbi(bn), *JUST(op->GetParallelDesc4BnInOp(bn))); diff --git a/oneflow/core/operator/operator.cpp b/oneflow/core/operator/operator.cpp index 2aebb0e2665..838c30869ea 100644 --- a/oneflow/core/operator/operator.cpp +++ b/oneflow/core/operator/operator.cpp @@ -705,7 +705,8 @@ Maybe Operator::GreedilyFindMinCopyCostNdSbp( total_copy_cost += JUST(ComputeCopyCostBetweenNdSbp( JUST(NdSbpInferHint4Ibn(ibn))->nd_sbp(), nd_sbp_sig_list.at(i).bn_in_op2nd_sbp()[ibn], JUST(NdSbpInferHint4Ibn(ibn))->logical_blob_desc(), - JUST(NdSbpInferHint4Ibn(ibn))->parallel_desc(), *JUST(GetOpParallelDesc()), is_same_sbp)); + JUST(NdSbpInferHint4Ibn(ibn))->parallel_desc(), *JUST(GetParallelDesc4BnInOp(ibn)), + is_same_sbp)); } if (total_copy_cost <= min_copy_cost) { select_sbp_idx = i;