Skip to content

Commit

Permalink
Fix advance indexing bug and zeros_like sbp bug (#7238)
Browse files Browse the repository at this point in the history
* fix(*): fix advance indexing bug and zeros_like sbp bug

* fix(*): fix index sbp

* fix(*): fix nd sbp infer for ones_like_op

* fix ones like nd sbp infer

Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com>
Co-authored-by: Houjiang Chen <chenhoujiangcug@gmail.com>
  • Loading branch information
3 people authored Jan 26, 2022
1 parent 0aa355a commit 115186a
Show file tree
Hide file tree
Showing 5 changed files with 29 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ Maybe<void> Interpret(const UserOpExpr& user_op_expr, const TensorTuple& inputs,
std::shared_ptr<const ConsistentTensorInferResult> result;
NonRecursiveMetaInfoConsistencyCheckScope scope;
if (inputs.empty()) {
// check consistency placment and nd_sbp, do not check in non-src op because it is assumed that
// check consistency placement and nd_sbp, do not check in non-src op because it is assumed that
// InferSbp in op is a deterministic algorithm
JUST(MetaInfoConsistencyCheck(parallel_desc, ctx.nd_sbp));
const auto& infer_args =
Expand Down
5 changes: 4 additions & 1 deletion oneflow/core/functional/tensor_index.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -332,8 +332,11 @@ Maybe<Tensor> ApplyAdvancedIndexing(const std::shared_ptr<Tensor>& input,
if (transposed_input->is_consistent()) {
const auto& placement = JUST(transposed_input->parallel_desc());
const auto& broadcast_sbp = JUST(MakeBroadcastSbpParallel());
int n = JUST(input->nd_sbp())->sbp_parallel_size();
std::vector<Symbol<cfg::SbpParallel>> grad_sbp_tuple;
packed_indices = JUST(ToConsistent(packed_indices, placement, {broadcast_sbp}, grad_sbp_tuple));
packed_indices =
JUST(ToConsistent(packed_indices, placement,
std::vector<Symbol<cfg::SbpParallel>>(n, broadcast_sbp), grad_sbp_tuple));
} else {
Symbol<Device> device = JUST(transposed_input->device());
if (JUST(packed_indices->device()) != device) {
Expand Down
2 changes: 2 additions & 0 deletions oneflow/ir/include/OneFlow/OneFlowUserOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -2489,6 +2489,7 @@ def OneFlow_FloorOp : OneFlow_IdempotentBaseOp<"floor", [NoSideEffect, DeclareOp

def OneFlow_OnesLikeOp : OneFlow_IdempotentBaseOp<"ones_like", [NoSideEffect, NoGrad, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {
let same_output_regst_num = 1;
let has_nd_sbp_infer_fn = 1;
}

def OneFlow_ReluOp : OneFlow_IdempotentBaseOp<"relu", [NoSideEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {}
Expand Down Expand Up @@ -8556,6 +8557,7 @@ def OneFlow_ZeroLikeOp : OneFlow_BaseOp<"zero_like", [NoSideEffect, NoGrad, Decl
let has_physical_tensor_desc_infer_fn = 1;
let has_get_sbp_fn = 1;
let has_data_type_infer_fn = 1;
let has_nd_sbp_infer_fn = 1;
}

#endif // GET_ONEFLOW_UNARY_OP_DEFINITIONS
Expand Down
14 changes: 14 additions & 0 deletions oneflow/user/ops/ones_like_op.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,5 +43,19 @@ namespace oneflow {
*ctx->OutputDType("out", 0) = ctx->InputDType("like", 0);
return Maybe<void>::Ok();
}
/*static*/ Maybe<void> OnesLikeOp::InferNdSbp(user_op::InferNdSbpFnContext* ctx) {
const cfg::NdSbp& in_sbp = ctx->NdSbpHint4InputArgNameAndIndex("like", 0);
cfg::NdSbp* like_distribution = ctx->NdSbp4ArgNameAndIndex("like", 0);
cfg::NdSbp* out_distribution = ctx->NdSbp4ArgNameAndIndex("out", 0);
*like_distribution = in_sbp;
*out_distribution = in_sbp;
for (auto& sbp : *out_distribution->mutable_sbp_parallel()) {
if (sbp.has_partial_sum_parallel()) {
sbp.Clear();
*sbp.mutable_broadcast_parallel() = cfg::BroadcastParallel();
}
}
return Maybe<void>::Ok();
}

} // namespace oneflow
8 changes: 8 additions & 0 deletions oneflow/user/ops/zero_like_op.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,5 +43,13 @@ namespace oneflow {
*ctx->OutputDType("out", 0) = ctx->InputDType("like", 0);
return Maybe<void>::Ok();
}
/*static*/ Maybe<void> ZeroLikeOp::InferNdSbp(user_op::InferNdSbpFnContext* ctx) {
const cfg::NdSbp& in_sbp = ctx->NdSbpHint4InputArgNameAndIndex("like", 0);
cfg::NdSbp* like_distribution = ctx->NdSbp4ArgNameAndIndex("like", 0);
cfg::NdSbp* out_distribution = ctx->NdSbp4ArgNameAndIndex("out", 0);
*like_distribution = in_sbp;
*out_distribution = in_sbp;
return Maybe<void>::Ok();
}

} // namespace oneflow

0 comments on commit 115186a

Please sign in to comment.