diff --git a/oneflow/core/framework/op_interpreter/eager_consistent_op_interpreter.cpp b/oneflow/core/framework/op_interpreter/eager_consistent_op_interpreter.cpp index c056157d339..0614eb729db 100644 --- a/oneflow/core/framework/op_interpreter/eager_consistent_op_interpreter.cpp +++ b/oneflow/core/framework/op_interpreter/eager_consistent_op_interpreter.cpp @@ -94,7 +94,7 @@ Maybe Interpret(const UserOpExpr& user_op_expr, const TensorTuple& inputs, std::shared_ptr result; NonRecursiveMetaInfoConsistencyCheckScope scope; if (inputs.empty()) { - // check consistency placment and nd_sbp, do not check in non-src op because it is assumed that + // check consistency placement and nd_sbp, do not check in non-src op because it is assumed that // InferSbp in op is a deterministic algorithm JUST(MetaInfoConsistencyCheck(parallel_desc, ctx.nd_sbp)); const auto& infer_args = diff --git a/oneflow/core/functional/tensor_index.cpp b/oneflow/core/functional/tensor_index.cpp index 3d4daea72e8..4d83f1ea137 100644 --- a/oneflow/core/functional/tensor_index.cpp +++ b/oneflow/core/functional/tensor_index.cpp @@ -332,8 +332,11 @@ Maybe ApplyAdvancedIndexing(const std::shared_ptr& input, if (transposed_input->is_consistent()) { const auto& placement = JUST(transposed_input->parallel_desc()); const auto& broadcast_sbp = JUST(MakeBroadcastSbpParallel()); + int n = JUST(input->nd_sbp())->sbp_parallel_size(); std::vector> grad_sbp_tuple; - packed_indices = JUST(ToConsistent(packed_indices, placement, {broadcast_sbp}, grad_sbp_tuple)); + packed_indices = + JUST(ToConsistent(packed_indices, placement, + std::vector>(n, broadcast_sbp), grad_sbp_tuple)); } else { Symbol device = JUST(transposed_input->device()); if (JUST(packed_indices->device()) != device) { diff --git a/oneflow/ir/include/OneFlow/OneFlowUserOps.td b/oneflow/ir/include/OneFlow/OneFlowUserOps.td index a5086360e74..0d254767589 100644 --- a/oneflow/ir/include/OneFlow/OneFlowUserOps.td +++ b/oneflow/ir/include/OneFlow/OneFlowUserOps.td @@ -2489,6 +2489,7 @@ def OneFlow_FloorOp : OneFlow_IdempotentBaseOp<"floor", [NoSideEffect, DeclareOp def OneFlow_OnesLikeOp : OneFlow_IdempotentBaseOp<"ones_like", [NoSideEffect, NoGrad, DeclareOpInterfaceMethods]> { let same_output_regst_num = 1; + let has_nd_sbp_infer_fn = 1; } def OneFlow_ReluOp : OneFlow_IdempotentBaseOp<"relu", [NoSideEffect, DeclareOpInterfaceMethods]> {} @@ -8556,6 +8557,7 @@ def OneFlow_ZeroLikeOp : OneFlow_BaseOp<"zero_like", [NoSideEffect, NoGrad, Decl let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; let has_data_type_infer_fn = 1; + let has_nd_sbp_infer_fn = 1; } #endif // GET_ONEFLOW_UNARY_OP_DEFINITIONS diff --git a/oneflow/user/ops/ones_like_op.cpp b/oneflow/user/ops/ones_like_op.cpp index cf05b880f87..488c30a0593 100644 --- a/oneflow/user/ops/ones_like_op.cpp +++ b/oneflow/user/ops/ones_like_op.cpp @@ -43,5 +43,19 @@ namespace oneflow { *ctx->OutputDType("out", 0) = ctx->InputDType("like", 0); return Maybe::Ok(); } +/*static*/ Maybe OnesLikeOp::InferNdSbp(user_op::InferNdSbpFnContext* ctx) { + const cfg::NdSbp& in_sbp = ctx->NdSbpHint4InputArgNameAndIndex("like", 0); + cfg::NdSbp* like_distribution = ctx->NdSbp4ArgNameAndIndex("like", 0); + cfg::NdSbp* out_distribution = ctx->NdSbp4ArgNameAndIndex("out", 0); + *like_distribution = in_sbp; + *out_distribution = in_sbp; + for (auto& sbp : *out_distribution->mutable_sbp_parallel()) { + if (sbp.has_partial_sum_parallel()) { + sbp.Clear(); + *sbp.mutable_broadcast_parallel() = cfg::BroadcastParallel(); + } + } + return Maybe::Ok(); +} } // namespace oneflow diff --git a/oneflow/user/ops/zero_like_op.cpp b/oneflow/user/ops/zero_like_op.cpp index 6e650556069..e67bac631db 100644 --- a/oneflow/user/ops/zero_like_op.cpp +++ b/oneflow/user/ops/zero_like_op.cpp @@ -43,5 +43,13 @@ namespace oneflow { *ctx->OutputDType("out", 0) = ctx->InputDType("like", 0); return Maybe::Ok(); } +/*static*/ Maybe ZeroLikeOp::InferNdSbp(user_op::InferNdSbpFnContext* ctx) { + const cfg::NdSbp& in_sbp = ctx->NdSbpHint4InputArgNameAndIndex("like", 0); + cfg::NdSbp* like_distribution = ctx->NdSbp4ArgNameAndIndex("like", 0); + cfg::NdSbp* out_distribution = ctx->NdSbp4ArgNameAndIndex("out", 0); + *like_distribution = in_sbp; + *out_distribution = in_sbp; + return Maybe::Ok(); +} } // namespace oneflow