diff --git a/oneflow/core/graph/boxing/hierarchical_sub_task_graph_builder_impl.cpp b/oneflow/core/graph/boxing/hierarchical_sub_task_graph_builder_impl.cpp index f6554fc95c1..276657b37ac 100644 --- a/oneflow/core/graph/boxing/hierarchical_sub_task_graph_builder_impl.cpp +++ b/oneflow/core/graph/boxing/hierarchical_sub_task_graph_builder_impl.cpp @@ -34,16 +34,22 @@ void ParallelDimReduce(const ParallelDesc& parallel_desc, const cfg::NdSbp& nd_s ParallelDesc* reduced_parallel_desc, cfg::NdSbp* reduced_nd_sbp) { const auto& hierarchy = parallel_desc.hierarchy(); DimVector reduced_hierarchy; - reduced_hierarchy.emplace_back(hierarchy->At(0)); - *reduced_nd_sbp->add_sbp_parallel() = nd_sbp.sbp_parallel(0); - FOR_RANGE(int64_t, i, 1, hierarchy->NumAxes()) { - if (nd_sbp.sbp_parallel(i) == nd_sbp.sbp_parallel(i - 1)) { - reduced_hierarchy.back() *= hierarchy->At(i); - } else { - reduced_hierarchy.emplace_back(hierarchy->At(i)); - *reduced_nd_sbp->add_sbp_parallel() = nd_sbp.sbp_parallel(i); + FOR_RANGE(int64_t, i, 0, hierarchy->NumAxes()) { + if (hierarchy->At(i) != 1) { + if (reduced_nd_sbp->sbp_parallel().empty() + || (nd_sbp.sbp_parallel(i) + != reduced_nd_sbp->sbp_parallel(reduced_nd_sbp->sbp_parallel_size() - 1))) { + reduced_hierarchy.emplace_back(hierarchy->At(i)); + *reduced_nd_sbp->add_sbp_parallel() = nd_sbp.sbp_parallel(i); + } else { + reduced_hierarchy.back() *= hierarchy->At(i); + } } } + if (reduced_hierarchy.empty()) { + reduced_hierarchy.emplace_back(hierarchy->At(0)); + *reduced_nd_sbp->add_sbp_parallel() = nd_sbp.sbp_parallel(0); + } ParallelConf reduced_parallel_conf = parallel_desc.parallel_conf(); Shape(reduced_hierarchy).ToProto(reduced_parallel_conf.mutable_hierarchy()); *reduced_parallel_desc = ParallelDesc(reduced_parallel_conf); @@ -60,26 +66,33 @@ void CollaborativeParallelDimReduce(const ParallelDesc& in_parallel_desc, CHECK_EQ(in_hierarchy->NumAxes(), out_hierarchy->NumAxes()); DimVector reduced_in_hierarchy; - reduced_in_hierarchy.emplace_back(in_hierarchy->At(0)); - *reduced_in_nd_sbp->add_sbp_parallel() = in_nd_sbp.sbp_parallel(0); - DimVector reduced_out_hierarchy; - reduced_out_hierarchy.emplace_back(out_hierarchy->At(0)); - *reduced_out_nd_sbp->add_sbp_parallel() = out_nd_sbp.sbp_parallel(0); - - FOR_RANGE(int64_t, i, 1, in_hierarchy->NumAxes()) { - if ((in_nd_sbp.sbp_parallel(i) == in_nd_sbp.sbp_parallel(i - 1)) - && (out_nd_sbp.sbp_parallel(i) == out_nd_sbp.sbp_parallel(i - 1))) { - reduced_in_hierarchy.back() *= in_hierarchy->At(i); - reduced_out_hierarchy.back() *= out_hierarchy->At(i); - } else { - reduced_in_hierarchy.emplace_back(in_hierarchy->At(i)); - *reduced_in_nd_sbp->add_sbp_parallel() = in_nd_sbp.sbp_parallel(i); - - reduced_out_hierarchy.emplace_back(out_hierarchy->At(i)); - *reduced_out_nd_sbp->add_sbp_parallel() = out_nd_sbp.sbp_parallel(i); + FOR_RANGE(int64_t, i, 0, in_hierarchy->NumAxes()) { + if (in_hierarchy->At(i) != 1 || out_hierarchy->At(i) != 1) { + if (reduced_in_nd_sbp->sbp_parallel().empty() + || (in_nd_sbp.sbp_parallel(i) + != reduced_in_nd_sbp->sbp_parallel(reduced_in_nd_sbp->sbp_parallel_size() - 1) + || out_nd_sbp.sbp_parallel(i) + != reduced_out_nd_sbp->sbp_parallel(reduced_out_nd_sbp->sbp_parallel_size() + - 1))) { + reduced_in_hierarchy.emplace_back(in_hierarchy->At(i)); + *reduced_in_nd_sbp->add_sbp_parallel() = in_nd_sbp.sbp_parallel(i); + + reduced_out_hierarchy.emplace_back(out_hierarchy->At(i)); + *reduced_out_nd_sbp->add_sbp_parallel() = out_nd_sbp.sbp_parallel(i); + } else { + reduced_in_hierarchy.back() *= in_hierarchy->At(i); + reduced_out_hierarchy.back() *= out_hierarchy->At(i); + } } } + if (reduced_in_hierarchy.empty()) { + reduced_in_hierarchy.emplace_back(in_hierarchy->At(0)); + *reduced_in_nd_sbp->add_sbp_parallel() = in_nd_sbp.sbp_parallel(0); + + reduced_out_hierarchy.emplace_back(out_hierarchy->At(0)); + *reduced_out_nd_sbp->add_sbp_parallel() = out_nd_sbp.sbp_parallel(0); + } ParallelConf reduced_in_parallel_conf = in_parallel_desc.parallel_conf(); Shape(reduced_in_hierarchy).ToProto(reduced_in_parallel_conf.mutable_hierarchy());