Skip to content

Commit

Permalink
Prune parallel dims whose value equals one in parallel dim reduce (#7257)
Browse files: browse the repository at this point in the history
* fix Resource::DumpCudnnConf

* prune_parallel_dim_with_val_eq_one_in_parallel_dim_reduce

* minor fix

* refine Prune

* refine

* refine

* minor fix

* fix bug

Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com>
  • Loading branch information
clackhan and oneflow-ci-bot authored Jan 19, 2022
1 parent 6137422 commit 0dbb976
Showing 1 changed file with 38 additions and 25 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -34,16 +34,22 @@ void ParallelDimReduce(const ParallelDesc& parallel_desc, const cfg::NdSbp& nd_s
ParallelDesc* reduced_parallel_desc, cfg::NdSbp* reduced_nd_sbp) {
const auto& hierarchy = parallel_desc.hierarchy();
DimVector reduced_hierarchy;
reduced_hierarchy.emplace_back(hierarchy->At(0));
*reduced_nd_sbp->add_sbp_parallel() = nd_sbp.sbp_parallel(0);
FOR_RANGE(int64_t, i, 1, hierarchy->NumAxes()) {
if (nd_sbp.sbp_parallel(i) == nd_sbp.sbp_parallel(i - 1)) {
reduced_hierarchy.back() *= hierarchy->At(i);
} else {
reduced_hierarchy.emplace_back(hierarchy->At(i));
*reduced_nd_sbp->add_sbp_parallel() = nd_sbp.sbp_parallel(i);
FOR_RANGE(int64_t, i, 0, hierarchy->NumAxes()) {
if (hierarchy->At(i) != 1) {
if (reduced_nd_sbp->sbp_parallel().empty()
|| (nd_sbp.sbp_parallel(i)
!= reduced_nd_sbp->sbp_parallel(reduced_nd_sbp->sbp_parallel_size() - 1))) {
reduced_hierarchy.emplace_back(hierarchy->At(i));
*reduced_nd_sbp->add_sbp_parallel() = nd_sbp.sbp_parallel(i);
} else {
reduced_hierarchy.back() *= hierarchy->At(i);
}
}
}
if (reduced_hierarchy.empty()) {
reduced_hierarchy.emplace_back(hierarchy->At(0));
*reduced_nd_sbp->add_sbp_parallel() = nd_sbp.sbp_parallel(0);
}
ParallelConf reduced_parallel_conf = parallel_desc.parallel_conf();
Shape(reduced_hierarchy).ToProto(reduced_parallel_conf.mutable_hierarchy());
*reduced_parallel_desc = ParallelDesc(reduced_parallel_conf);
Expand All @@ -60,26 +66,33 @@ void CollaborativeParallelDimReduce(const ParallelDesc& in_parallel_desc,
CHECK_EQ(in_hierarchy->NumAxes(), out_hierarchy->NumAxes());

DimVector reduced_in_hierarchy;
reduced_in_hierarchy.emplace_back(in_hierarchy->At(0));
*reduced_in_nd_sbp->add_sbp_parallel() = in_nd_sbp.sbp_parallel(0);

DimVector reduced_out_hierarchy;
reduced_out_hierarchy.emplace_back(out_hierarchy->At(0));
*reduced_out_nd_sbp->add_sbp_parallel() = out_nd_sbp.sbp_parallel(0);

FOR_RANGE(int64_t, i, 1, in_hierarchy->NumAxes()) {
if ((in_nd_sbp.sbp_parallel(i) == in_nd_sbp.sbp_parallel(i - 1))
&& (out_nd_sbp.sbp_parallel(i) == out_nd_sbp.sbp_parallel(i - 1))) {
reduced_in_hierarchy.back() *= in_hierarchy->At(i);
reduced_out_hierarchy.back() *= out_hierarchy->At(i);
} else {
reduced_in_hierarchy.emplace_back(in_hierarchy->At(i));
*reduced_in_nd_sbp->add_sbp_parallel() = in_nd_sbp.sbp_parallel(i);

reduced_out_hierarchy.emplace_back(out_hierarchy->At(i));
*reduced_out_nd_sbp->add_sbp_parallel() = out_nd_sbp.sbp_parallel(i);
FOR_RANGE(int64_t, i, 0, in_hierarchy->NumAxes()) {
if (in_hierarchy->At(i) != 1 || out_hierarchy->At(i) != 1) {
if (reduced_in_nd_sbp->sbp_parallel().empty()
|| (in_nd_sbp.sbp_parallel(i)
!= reduced_in_nd_sbp->sbp_parallel(reduced_in_nd_sbp->sbp_parallel_size() - 1)
|| out_nd_sbp.sbp_parallel(i)
!= reduced_out_nd_sbp->sbp_parallel(reduced_out_nd_sbp->sbp_parallel_size()
- 1))) {
reduced_in_hierarchy.emplace_back(in_hierarchy->At(i));
*reduced_in_nd_sbp->add_sbp_parallel() = in_nd_sbp.sbp_parallel(i);

reduced_out_hierarchy.emplace_back(out_hierarchy->At(i));
*reduced_out_nd_sbp->add_sbp_parallel() = out_nd_sbp.sbp_parallel(i);
} else {
reduced_in_hierarchy.back() *= in_hierarchy->At(i);
reduced_out_hierarchy.back() *= out_hierarchy->At(i);
}
}
}
if (reduced_in_hierarchy.empty()) {
reduced_in_hierarchy.emplace_back(in_hierarchy->At(0));
*reduced_in_nd_sbp->add_sbp_parallel() = in_nd_sbp.sbp_parallel(0);

reduced_out_hierarchy.emplace_back(out_hierarchy->At(0));
*reduced_out_nd_sbp->add_sbp_parallel() = out_nd_sbp.sbp_parallel(0);
}

ParallelConf reduced_in_parallel_conf = in_parallel_desc.parallel_conf();
Shape(reduced_in_hierarchy).ToProto(reduced_in_parallel_conf.mutable_hierarchy());
Expand Down

0 comments on commit 0dbb976

Please sign in to comment.