Skip to content

Commit

Permalink
Fix boxing parallel desc (#7356)
Browse files Browse the repository at this point in the history
* fix(AutoParallel): fix lazy copy cost between ops

* fix(*): fix nd_sbp_infer_hint bug

* fix(*): infer blob parallel desc before inferring sbp

* refine format

* refine param name
  • Loading branch information
wyg1997 authored Jan 25, 2022
1 parent 2824513 commit 0f57001
Show file tree
Hide file tree
Showing 6 changed files with 19 additions and 22 deletions.
7 changes: 5 additions & 2 deletions oneflow/core/auto_parallel/sbp_edge.h
Original file line number Diff line number Diff line change
Expand Up @@ -412,8 +412,11 @@ void SbpEdge<SbpSignature>::InitializeCopyCost(const std::string& ibn, bool comp
if (use_sbp_collector_ && compute_cost && !SearchLbi(lbi)) { return; }

oneflow::OpNode* producer = StartNode->op_node;
const oneflow::ParallelDesc& producer_parallel_desc = producer->parallel_desc();
const oneflow::ParallelDesc& consumer_parallel_desc = consumer->parallel_desc();
const std::string& producer_lbn = *CHECK_JUST(producer->op().obn4lbi(lbi));
const oneflow::ParallelDesc& producer_parallel_desc =
*CHECK_JUST(producer->op().GetParallelDesc4BnInOp(producer_lbn));
const oneflow::ParallelDesc& consumer_parallel_desc =
*CHECK_JUST(consumer->op().GetParallelDesc4BnInOp(ibn));

// Careful: the logical blob description should be independent of the current
// SbpParallel. Should this use producer or op_node?
Expand Down
1 change: 1 addition & 0 deletions oneflow/core/framework/consistent_tensor_infer_cache.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -262,6 +262,7 @@ class UserOpExprOpDeviceInferContext final : public user_op::DeviceInferContext
}
const auto& op = JUST(MakeOp(user_op_expr, infer_args.attrs(), parallel_desc->device_tag()));
JUST(op->FillOpParallelDesc(parallel_desc.shared_from_symbol()));
JUST(op->InferParallelSignatureIf());
{
// Infer parallel distribution.
cfg::NdSbpSignature nd_sbp_constraints;
Expand Down
24 changes: 7 additions & 17 deletions oneflow/core/framework/sbp_infer_util.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,7 @@ Maybe<double> ComputeEagerCopyCostBetweenNdSbp(const cfg::NdSbp& producer_sbp_pa
const BlobDesc& logical_blob_desc,
const ParallelDesc& producer_parallel_desc,
const ParallelDesc& consumer_parallel_desc,
bool is_same_sbp) {
bool requires_same_sbp) {
if (!(CheckNdSbp(producer_sbp_parallel) && CheckNdSbp(consumer_sbp_parallel))) {
return Error::RuntimeError() << "Illegal sbp parallel has been found.";
}
Expand All @@ -227,13 +227,7 @@ Maybe<double> ComputeEagerCopyCostBetweenNdSbp(const cfg::NdSbp& producer_sbp_pa
bool same_nd_sbp = reduced_in_nd_sbp == reduced_out_nd_sbp;
// Same sbp is always supported.
if (same_nd_sbp && reduced_in_parallel_desc == reduced_out_parallel_desc) { return 0.0; }

// Will directly modify output blob of source op. Requiring data having same sbp_parallel.
if (is_same_sbp
&& !(reduced_in_parallel_desc.EqualsIgnoringDeviceType(reduced_out_parallel_desc)
&& same_nd_sbp)) {
return kUnsupportedBoxing;
}
if (requires_same_sbp) { return kUnsupportedBoxing; }

int32_t in_dim = in_hierarchy->NumAxes();
int32_t out_dim = out_hierarchy->NumAxes();
Expand Down Expand Up @@ -285,7 +279,7 @@ Maybe<double> ComputeLazyCopyCostBetweenNdSbp(const cfg::NdSbp& producer_sbp_par
const BlobDesc& logical_blob_desc,
const ParallelDesc& producer_parallel_desc,
const ParallelDesc& consumer_parallel_desc,
bool is_same_sbp) {
bool requires_same_sbp) {
if (!(CheckNdSbp(producer_sbp_parallel) && CheckNdSbp(consumer_sbp_parallel))) {
return Error::RuntimeError() << "Illegal sbp parallel has been found.";
}
Expand All @@ -308,13 +302,9 @@ Maybe<double> ComputeLazyCopyCostBetweenNdSbp(const cfg::NdSbp& producer_sbp_par
}

bool same_nd_sbp = reduced_in_nd_sbp == reduced_out_nd_sbp;
// Same sbp is always supported.
if (same_nd_sbp && reduced_in_parallel_desc == reduced_out_parallel_desc) { return 0.0; }
// Will directly modify output blob of source op. Requiring data having same sbp_parallel
if (is_same_sbp
&& !(reduced_in_parallel_desc.EqualsIgnoringDeviceType(reduced_out_parallel_desc)
&& same_nd_sbp)) {
return kUnsupportedBoxing;
}
if (requires_same_sbp) { return kUnsupportedBoxing; }

// We support different hierarchy for 1D sbp
if (in_dim == 1 && out_dim == 1) {
Expand Down Expand Up @@ -447,10 +437,10 @@ Maybe<double> ComputeCopyCostBetweenNdSbp(const cfg::NdSbp& producer_sbp_paralle
const BlobDesc& logical_blob_desc,
const ParallelDesc& producer_parallel_desc,
const ParallelDesc& consumer_parallel_desc,
bool is_same_sbp) {
bool requires_same_sbp) {
return JUST(GetComputeCopyCostFunc())(producer_sbp_parallel, consumer_sbp_parallel,
logical_blob_desc, producer_parallel_desc,
consumer_parallel_desc, is_same_sbp);
consumer_parallel_desc, requires_same_sbp);
}

} // namespace oneflow
4 changes: 3 additions & 1 deletion oneflow/core/graph/op_graph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -306,7 +306,9 @@ void OpGraph::InferOpNodeNdSbpSignature(OpNode* op_node,
for (const std::string& ibn : op_node->op().input_bns()) {
const LogicalBlobId& lbi = op_node->op().BnInOp2Lbi(ibn);
OpNode* producer = op_node->MutSrcNode4Ibn(ibn);
const ParallelDesc* parallel_desc = &producer->parallel_desc();
const std::string& producer_lbn = *CHECK_JUST(producer->op().obn4lbi(lbi));
const ParallelDesc* parallel_desc =
CHECK_JUST(producer->op().GetParallelDesc4BnInOp(producer_lbn)).get();
const BlobDesc* logical_blob_desc = &producer->LogicalBlobDesc4Lbi(lbi);
const cfg::NdSbp* nd_sbp = &producer->NdSbp4Lbi(lbi);
ibn2nd_sbp_infer_hint.emplace(ibn, NdSbpInferHint(parallel_desc, logical_blob_desc, nd_sbp));
Expand Down
2 changes: 1 addition & 1 deletion oneflow/core/job/job_build_and_infer_ctx.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -589,6 +589,7 @@ Maybe<OpAttribute> JobBuildAndInferCtx::AddAndInferOp(const OperatorConf& op_con
return nullptr;
};
JUST(op->FillLogicalInBlobDesc(GetBlobDesc4BnInOp));
JUST(op->InferParallelSignatureIf());

// infer mirrored signature
JUST(InferMirroredSignature(op, is_mirrored_parallel_view, parallel_desc));
Expand Down Expand Up @@ -618,7 +619,6 @@ Maybe<OpAttribute> JobBuildAndInferCtx::AddAndInferOp(const OperatorConf& op_con
}
return &iter->second;
};
JUST(op->InferParallelSignatureIf());
for (const auto& bn : op->output_bns()) {
lbi2parallel_desc_from_producer_view_.emplace(op->BnInOp2Lbi(bn),
*JUST(op->GetParallelDesc4BnInOp(bn)));
Expand Down
3 changes: 2 additions & 1 deletion oneflow/core/operator/operator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -705,7 +705,8 @@ Maybe<void> Operator::GreedilyFindMinCopyCostNdSbp(
total_copy_cost += JUST(ComputeCopyCostBetweenNdSbp(
JUST(NdSbpInferHint4Ibn(ibn))->nd_sbp(), nd_sbp_sig_list.at(i).bn_in_op2nd_sbp()[ibn],
JUST(NdSbpInferHint4Ibn(ibn))->logical_blob_desc(),
JUST(NdSbpInferHint4Ibn(ibn))->parallel_desc(), *JUST(GetOpParallelDesc()), is_same_sbp));
JUST(NdSbpInferHint4Ibn(ibn))->parallel_desc(), *JUST(GetParallelDesc4BnInOp(ibn)),
is_same_sbp));
}
if (total_copy_cost <= min_copy_cost) {
select_sbp_idx = i;
Expand Down

0 comments on commit 0f57001

Please sign in to comment.