Feat unbalanced split nd sbp #9310

Merged 39 commits into master from feat-unbalanced_split-nd_sbp on Nov 4, 2022

Commits
d7c11bc
Add a GetSbpSignature with use parallel num
Yipeng1994 Oct 24, 2022
f30f29d
Get sbp_sig_list for each dimension of hierarchy
Yipeng1994 Oct 24, 2022
fdc7ee8
Add test script and print out information
Yipeng1994 Oct 24, 2022
e1b4a96
Remove parallel description in GetSbpSignature()
Yipeng1994 Oct 24, 2022
dc23ff7
Fix small bug
Yipeng1994 Oct 24, 2022
195b0ea
Disable InferNdSbp for reshape op
Yipeng1994 Oct 24, 2022
3f6d981
Merge branch 'master' into refactor-GetSbpSignature
Yipeng1994 Oct 25, 2022
f7d29d1
Revert "Add test script and print out information"
Yipeng1994 Oct 25, 2022
f20e222
Use the same physical shape as eager did
Yipeng1994 Oct 25, 2022
cf332ce
Remove the difference between eager and lazy for physical shape
Yipeng1994 Oct 25, 2022
aa42b6d
Update the filter
Yipeng1994 Oct 25, 2022
6e87762
Revert "Use the same physical shape as eager did"
Yipeng1994 Oct 25, 2022
0f1554d
Compute range for each rank
Yipeng1994 Oct 25, 2022
4d7edc7
Compute position for range
Yipeng1994 Oct 25, 2022
b2249c3
Remove the difference between eager and lazy
Yipeng1994 Oct 25, 2022
3c1362b
Allow unbalanced split for variables
Yipeng1994 Oct 25, 2022
58cdfb4
Add test script and print out information
Yipeng1994 Oct 26, 2022
9a974e1
Pass 2d test cases
Yipeng1994 Oct 26, 2022
85f28a2
Merge branch 'feat-unbalanced_split-nd_sbp' of github.com:Oneflow-Inc…
Yipeng1994 Oct 26, 2022
1c2653c
Merge branch 'master' into feat-unbalanced_split-nd_sbp
Yipeng1994 Oct 27, 2022
909bade
Merge branch 'feat-unbalanced_split-nd_sbp' of github.com:Oneflow-Inc…
Yipeng1994 Oct 27, 2022
9243427
Resolve conflict
Yipeng1994 Oct 27, 2022
cca3adb
Can not merge some split
Yipeng1994 Oct 28, 2022
3d06827
Reduce in and out sbp simultaneously
Yipeng1994 Oct 28, 2022
34c3395
Speed up for 1d sbp
Yipeng1994 Oct 31, 2022
fc6127f
Reduced simultaneously with the same hierarchy
Yipeng1994 Oct 31, 2022
38f405c
Deal with 1to2d and 2to1d in InOutParallelDimReduce()
Yipeng1994 Oct 31, 2022
c414ce7
Pass 1to2d and 2to1d test cases
Yipeng1994 Oct 31, 2022
75d5a82
Remove the old code
Yipeng1994 Oct 31, 2022
575a55f
Revert "Add test script and print out information"
Yipeng1994 Oct 31, 2022
c8216b9
Add the check for split questionary back
Yipeng1994 Oct 31, 2022
eaf6681
Merge branch 'master' into feat-unbalanced_split-nd_sbp
Yipeng1994 Oct 31, 2022
7deae83
Feat speed up cost computation (#9355)
Yipeng1994 Nov 2, 2022
5a2049b
Merge branch 'master' into feat-unbalanced_split-nd_sbp
Yipeng1994 Nov 2, 2022
efbb118
fix comment typeo
Yipeng1994 Nov 4, 2022
9002bb5
Address comment
Yipeng1994 Nov 4, 2022
4233d6d
Merge branch 'master' into feat-unbalanced_split-nd_sbp
Yipeng1994 Nov 4, 2022
3103ba8
Merge branch 'master' into feat-unbalanced_split-nd_sbp
mergify[bot] Nov 4, 2022
2b7c066
Merge branch 'master' into feat-unbalanced_split-nd_sbp
mergify[bot] Nov 4, 2022
10 changes: 5 additions & 5 deletions oneflow/core/boxing/nd_sbp_dim_reduce_boxing.cpp
@@ -27,22 +27,22 @@ namespace oneflow {
namespace {

Maybe<std::tuple<Symbol<PlacedNdSbp>, Symbol<PlacedNdSbp>>> RawInOutPlacedNdSbpDimReduce(
- Symbol<PlacedNdSbp> in, Symbol<PlacedNdSbp> out) {
+ Symbol<PlacedNdSbp> in, Symbol<PlacedNdSbp> out, const Shape& logical_shape) {
// reduce hierarchy
ParallelDesc reduced_in_placement = *in->placement();
ParallelDesc reduced_out_placement = *out->placement();
NdSbp reduced_in_nd_sbp;
NdSbp reduced_out_nd_sbp;
InOutParallelDimReduce(*in->placement(), *out->placement(), *in->nd_sbp(), *out->nd_sbp(),
&reduced_in_placement, &reduced_out_placement, &reduced_in_nd_sbp,
- &reduced_out_nd_sbp);
+ &reduced_out_nd_sbp, logical_shape);
return std::make_tuple(
JUST(PlacedNdSbp::New(SymbolOf(reduced_in_nd_sbp), SymbolOf(reduced_in_placement))),
JUST(PlacedNdSbp::New(SymbolOf(reduced_out_nd_sbp), SymbolOf(reduced_out_placement))));
}

constexpr auto* InOutPlacedNdSbpDimReduce =
- DECORATE(&RawInOutPlacedNdSbpDimReduce, ThreadLocalCached);
+ DECORATE(&RawInOutPlacedNdSbpDimReduce, ThreadLocalCachedCopiable);

// NOLINTBEGIN(maybe-need-error-msg)
Maybe<void> RawCheckParallelDimReduce(Symbol<PlacedNdSbp> in, Symbol<PlacedNdSbp> out,
@@ -51,7 +51,7 @@ Maybe<void> RawCheckParallelDimReduce(Symbol<PlacedNdSbp
CHECK_EQ_OR_RETURN(in->placement()->device_tag(), out->placement()->device_tag());
Symbol<PlacedNdSbp> reduced_in;
Symbol<PlacedNdSbp> reduced_out;
- std::tie(reduced_in, reduced_out) = *JUST(InOutPlacedNdSbpDimReduce(in, out));
+ std::tie(reduced_in, reduced_out) = *JUST(InOutPlacedNdSbpDimReduce(in, out, logical_shape));

for (int64_t in_parallel_id = 0; in_parallel_id < in->placement()->parallel_num();
++in_parallel_id) {
@@ -102,7 +102,7 @@ Maybe<one::Tensor> ParallelDimReduce(const std::shared_ptr<one::Tensor>& tensor,

Symbol<PlacedNdSbp> reduced_in;
Symbol<PlacedNdSbp> reduced_out;
- std::tie(reduced_in, reduced_out) = *JUST(InOutPlacedNdSbpDimReduce(in, out));
+ std::tie(reduced_in, reduced_out) = *JUST(InOutPlacedNdSbpDimReduce(in, out, *tensor->shape()));

const std::shared_ptr<one::Tensor>& local_tensor = JUST(tensor->cur_rank_phy_tensor());

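The reduction result now depends on the tensor's logical shape, which is presumably why logical_shape joins both the arguments and the cached call above: under unbalanced split, whether two hierarchy axes may be merged varies from shape to shape. Below is a standalone sketch (not OneFlow code) of that effect; PartSize is a hypothetical stand-in for BalancedSplitter, assumed to give the first total % k parts one extra element.

#include <cstdint>
#include <iostream>
#include <vector>

// Hypothetical stand-in for BalancedSplitter: size of part i when `total`
// is split into `k` parts, larger parts first.
int64_t PartSize(int64_t total, int64_t k, int64_t i) {
  return total / k + (i < total % k ? 1 : 0);
}

int main() {
  const int64_t dim = 6;
  // Nested split over hierarchy (2, 2): 6 -> {3, 3}, then 3 -> {2, 1} each.
  std::vector<int64_t> nested;
  for (int64_t i0 = 0; i0 < 2; ++i0) {
    for (int64_t i1 = 0; i1 < 2; ++i1) {
      nested.push_back(PartSize(PartSize(dim, 2, i0), 2, i1));
    }
  }
  // Flat split over the merged hierarchy (4): 6 -> {2, 2, 1, 1}.
  std::vector<int64_t> flat;
  for (int64_t i = 0; i < 4; ++i) { flat.push_back(PartSize(dim, 4, i)); }
  // nested == {2, 1, 2, 1} while flat == {2, 2, 1, 1}: merging (2, 2) into
  // (4) would change the per-rank physical shapes for this dimension.
  for (int64_t v : nested) { std::cout << v << ' '; }
  std::cout << "| ";
  for (int64_t v : flat) { std::cout << v << ' '; }
  std::cout << '\n';
}

For dim == 4 the two layouts agree ({1, 1, 1, 1} either way), so the merge would be safe there; a shape-blind cache could not tell these two cases apart.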
470 changes: 313 additions & 157 deletions oneflow/core/framework/sbp_infer_util.cpp

Large diffs are not rendered by default.

20 changes: 16 additions & 4 deletions oneflow/core/framework/sbp_infer_util.h
@@ -43,14 +43,25 @@ int32_t PartialRatio4Producer(const NdSbp& sbp_producer,
int32_t BroadcastRatio4Consumer(const NdSbp& sbp_consumer,
const ParallelDesc& consumer_parallel_desc);

+ void NdSbpDimReduce(const Shape& hierarchy, const NdSbp& nd_sbp, Shape* reduced_hierarchy,
+ NdSbp* reduced_nd_sbp, const Shape& logical_shape);
+ void NdSbpsDimReduce(const Shape& hierarchy, const std::vector<const NdSbp*>& nd_sbps,
+ Shape* reduced_hierarchy, const std::vector<NdSbp*>& reduced_nd_sbps,
+ const Shape& logical_shape);
void NdSbpDimReduce(const ParallelDesc& parallel_desc, const NdSbp& nd_sbp,
- ParallelDesc* reduced_parallel_desc, NdSbp* reduced_nd_sbp);
+ ParallelDesc* reduced_parallel_desc, NdSbp* reduced_nd_sbp,
+ const Shape& logical_shape);

+ void InOutParallelDimReduce(const Shape& in_hierarchy, const Shape& out_hierarchy,
+ const NdSbp& in_nd_sbp, const NdSbp& out_nd_sbp,
+ Shape* reduced_in_hierarchy, Shape* reduced_out_hierarchy,
+ NdSbp* reduced_in_nd_sbp, NdSbp* reduced_out_nd_sbp,
+ const Shape& logical_shape);
void InOutParallelDimReduce(const ParallelDesc& in_parallel_desc,
const ParallelDesc& out_parallel_desc, const NdSbp& in_nd_sbp,
const NdSbp& out_nd_sbp, ParallelDesc* reduced_in_parallel_desc,
ParallelDesc* reduced_out_parallel_desc, NdSbp* reduced_in_nd_sbp,
- NdSbp* reduced_out_nd_sbp);
+ NdSbp* reduced_out_nd_sbp, const Shape& logical_shape);

double GetValidMaxCopyCost();

@@ -105,7 +116,8 @@ Maybe<double> ComputeCopyCostWithMiddleNodes(const NdSbp& producer_sbp_parallel,
double ComputeSbpInferPriority(const NdSbp& producer_sbp_parallel,
const NdSbp& consumer_sbp_parallel,
const ParallelDesc& producer_parallel_desc,
- const ParallelDesc& consumer_parallel_desc, bool requires_same_sbp);
+ const ParallelDesc& consumer_parallel_desc, bool requires_same_sbp,
+ const Shape& logical_shape);

// The transfer ratio for general basic communication
// Cost = ratio * data amount
@@ -411,71 +411,17 @@ class Same2DHierarchySubTskGphBuilder final : public HierarchicalSubTskGphBuilde
std::unique_ptr<Dim0NdSbpMismatchedSubTskGphBuilder> dim0_nd_sbp_mismatched_sub_tsk_gph_builder_;
};

- class ExpandToSame2DHierarchySubTskGphBuilder final : public HierarchicalSubTskGphBuilder {
- public:
- OF_DISALLOW_COPY_AND_MOVE(ExpandToSame2DHierarchySubTskGphBuilder);
- ExpandToSame2DHierarchySubTskGphBuilder() {
- same_2d_hierarchy_sub_tsk_gph_builder_.reset(new Same2DHierarchySubTskGphBuilder());
- }
- ~ExpandToSame2DHierarchySubTskGphBuilder() override = default;
-
- Maybe<SubTskGphBuilderStatus> Build(SubTskGphBuilderCtx* ctx,
- const std::vector<TaskNode*>& sorted_in_tasks,
- std::vector<TaskNode*>* sorted_out_tasks,
- std::vector<std::vector<TaskNode*>>* sorted_ctrl_tasks,
- const ParallelDesc& in_parallel_desc,
- const ParallelDesc& out_parallel_desc,
- const LogicalBlobId& lbi, const BlobDesc& logical_blob_desc,
- const NdSbp& in_nd_sbp, const NdSbp& out_nd_sbp,
- const Shape& time_shape) const override {
- if (in_parallel_desc.hierarchy()->elem_cnt() == out_parallel_desc.hierarchy()->elem_cnt()
- && in_parallel_desc.hierarchy()->NumAxes() == 1
- && out_parallel_desc.hierarchy()->NumAxes() == 2) {
- ParallelConf intermediate_parallel_conf = in_parallel_desc.parallel_conf();
- out_parallel_desc.hierarchy()->ToProto(intermediate_parallel_conf.mutable_hierarchy());
- NdSbp intermediate_nd_sbp;
- *intermediate_nd_sbp.add_sbp_parallel() = in_nd_sbp.sbp_parallel(0);
- *intermediate_nd_sbp.add_sbp_parallel() = in_nd_sbp.sbp_parallel(0);
- return same_2d_hierarchy_sub_tsk_gph_builder_->Build(
- ctx, sorted_in_tasks, sorted_out_tasks, sorted_ctrl_tasks,
- ParallelDesc(intermediate_parallel_conf), out_parallel_desc, lbi, logical_blob_desc,
- intermediate_nd_sbp, out_nd_sbp, time_shape);
- } else if (in_parallel_desc.hierarchy()->elem_cnt() == out_parallel_desc.hierarchy()->elem_cnt()
- && in_parallel_desc.hierarchy()->NumAxes() == 2
- && out_parallel_desc.hierarchy()->NumAxes() == 1) {
- ParallelConf intermediate_parallel_conf = out_parallel_desc.parallel_conf();
- in_parallel_desc.hierarchy()->ToProto(intermediate_parallel_conf.mutable_hierarchy());
- NdSbp intermediate_nd_sbp;
- *intermediate_nd_sbp.add_sbp_parallel() = out_nd_sbp.sbp_parallel(0);
- *intermediate_nd_sbp.add_sbp_parallel() = out_nd_sbp.sbp_parallel(0);
- return same_2d_hierarchy_sub_tsk_gph_builder_->Build(
- ctx, sorted_in_tasks, sorted_out_tasks, sorted_ctrl_tasks, in_parallel_desc,
- ParallelDesc(intermediate_parallel_conf), lbi, logical_blob_desc, in_nd_sbp,
- intermediate_nd_sbp, time_shape);
- } else {
- return Error::BoxingNotSupportedError();
- }
- }
-
- private:
- std::unique_ptr<Same2DHierarchySubTskGphBuilder> same_2d_hierarchy_sub_tsk_gph_builder_;
- };

struct DispatchHierarchicalSubTskGphBuilder::Impl {
Impl();
std::unique_ptr<FlatSubTskGphBuilder> flat_sub_tsk_gph_builder_;
std::unique_ptr<Same2DHierarchySubTskGphBuilder> same_2d_hierarchy_sub_tsk_gph_builder_;
- std::unique_ptr<ExpandToSame2DHierarchySubTskGphBuilder>
- expand_to_same_2d_hierarchy_sub_tsk_gph_builder_;
std::unique_ptr<NDNcclSendRecvBoxingSubTskGphBuilder>
nd_nccl_send_recv_boxing_sub_tsk_gph_builder_;
};

DispatchHierarchicalSubTskGphBuilder::Impl::Impl() {
flat_sub_tsk_gph_builder_.reset(new FlatSubTskGphBuilder());
same_2d_hierarchy_sub_tsk_gph_builder_.reset(new Same2DHierarchySubTskGphBuilder());
- expand_to_same_2d_hierarchy_sub_tsk_gph_builder_.reset(
- new ExpandToSame2DHierarchySubTskGphBuilder());
nd_nccl_send_recv_boxing_sub_tsk_gph_builder_.reset(new NDNcclSendRecvBoxingSubTskGphBuilder());
}

@@ -496,9 +442,12 @@ Maybe<SubTskGphBuilderStatus> DispatchHierarchicalSubTskGphBuilder::Build(
ParallelDesc reduced_out_parallel_desc = out_parallel_desc;
NdSbp reduced_in_nd_sbp;
NdSbp reduced_out_nd_sbp;
+ // The 1d-to-2d and 2d-to-1d cases are considered in this function.
+ // If it gives out a 1d sbp and a 2d sbp simultaneously, then the 2d sbp can not be converted
+ // to a 1d sbp and the 1d sbp can not be expanded to a 2d sbp.
InOutParallelDimReduce(in_parallel_desc, out_parallel_desc, in_nd_sbp, out_nd_sbp,
&reduced_in_parallel_desc, &reduced_out_parallel_desc, &reduced_in_nd_sbp,
- &reduced_out_nd_sbp);
+ &reduced_out_nd_sbp, logical_blob_desc.shape());
const auto& in_hierarchy = reduced_in_parallel_desc.hierarchy();
const auto& out_hierarchy = reduced_out_parallel_desc.hierarchy();
if ((in_hierarchy->NumAxes() > 2 || out_hierarchy->NumAxes() > 2)
@@ -520,13 +469,6 @@ Maybe<SubTskGphBuilderStatus> DispatchHierarchicalSubTskGphBuilder::Build(
ctx, sorted_in_tasks, sorted_out_tasks, sorted_ctrl_tasks, reduced_in_parallel_desc,
reduced_out_parallel_desc, lbi, logical_blob_desc, reduced_in_nd_sbp, reduced_out_nd_sbp,
time_shape);
- } else if (in_hierarchy->elem_cnt() == out_hierarchy->elem_cnt()
- && ((in_hierarchy->NumAxes() == 1 && out_hierarchy->NumAxes() == 2)
- || (in_hierarchy->NumAxes() == 2 && out_hierarchy->NumAxes() == 1))) {
- return impl_->expand_to_same_2d_hierarchy_sub_tsk_gph_builder_->Build(
- ctx, sorted_in_tasks, sorted_out_tasks, sorted_ctrl_tasks, reduced_in_parallel_desc,
- reduced_out_parallel_desc, lbi, logical_blob_desc, reduced_in_nd_sbp, reduced_out_nd_sbp,
- time_shape);
} else if (reduced_in_parallel_desc.device_type() == DeviceType::kCUDA
&& reduced_out_parallel_desc.device_type() == DeviceType::kCUDA) {
return impl_->nd_nccl_send_recv_boxing_sub_tsk_gph_builder_->Build(
13 changes: 9 additions & 4 deletions oneflow/core/job/job_build_and_infer_ctx.cpp
@@ -16,6 +16,7 @@ limitations under the License.
#include "oneflow/core/common/maybe.h"
#include "oneflow/core/common/protobuf.h"
#include "oneflow/core/common/time_util.h"
+ #include "oneflow/core/framework/nd_sbp.h"
#include "oneflow/core/vm/symbol_storage.h"
#include "oneflow/core/framework/config_def.h"
#include "oneflow/core/framework/to_string.h"
@@ -304,7 +305,8 @@ Maybe<void> JobBuildAndInferCtx::CheckOpBlobSplitability(Operator* op, int64_t p
if (logical_blob_desc.shape().NumAxes() > 0) {
CHECK_GE_OR_RETURN(logical_blob_desc.shape().At(axis), blob_parallel_num)
<< "op_name: " << lbi.op_name() << " blob_name: " << lbi.blob_name()
- << " cannot split blob by parallel_num: " << std::to_string(blob_parallel_num);
+ << " shape: " << logical_blob_desc.shape()
+ << " cannot be split by parallel_num: " << blob_parallel_num << " at axis " << axis;
}
}
} else {
Expand All @@ -318,10 +320,13 @@ Maybe<void> JobBuildAndInferCtx::CheckOpBlobSplitability(Operator* op, int64_t p
if (sbp_parallel.has_split_parallel()) {
const int64_t axis = sbp_parallel.split_parallel().axis();
CHECK_GT_OR_RETURN(current_shape.At(axis), 0);
- CHECK_EQ_OR_RETURN(current_shape.At(axis) % parallel_hierarchy->At(i), 0)
+ // Support unbalanced splitting
+ CHECK_GE_OR_RETURN(current_shape.At(axis), parallel_hierarchy->At(i))
<< "op_name: " << lbi.op_name() << " blob_name: " << lbi.blob_name()
- << " cannot split blob by parallel_hierarchy: "
- << std::to_string(parallel_hierarchy->At(i));
+ << " shape: " << logical_blob_desc.shape()
+ << " cannot be split by nd sbp: " << NdSbpToString(pair.second) << " at axis "
+ << axis << " with parallel_hierarchy: " << *parallel_hierarchy;
+ // Split and take the minimum one
current_shape.Set(axis, current_shape.At(axis) / parallel_hierarchy->At(i));
}
}
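The relaxed rule above only requires at least one element per rank at every hierarchy level, with each level keeping the minimum slice for the next. A minimal standalone sketch of that check, under the assumption that the minimum slice is the floor of the division (matching the current_shape.Set line above):

#include <cstdint>
#include <iostream>
#include <vector>

// Hypothetical helper mirroring the check above: splittable if every
// hierarchy level leaves at least one element per rank; the old rule
// required exact divisibility at every level instead.
bool Splittable(int64_t dim_size, const std::vector<int64_t>& hierarchy) {
  for (int64_t h : hierarchy) {
    if (dim_size < h) { return false; }  // was: dim_size % h == 0
    dim_size /= h;  // keep the minimum slice for the next level
  }
  return true;
}

int main() {
  std::cout << Splittable(5, {2, 2}) << ' ' << Splittable(3, {2, 2}) << '\n';  // 1 0
}

Splittable(5, {2, 2}) now passes (slices of sizes {3, 2}, then {2, 1} at worst), where the old divisibility check rejected it because 5 % 2 != 0.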
10 changes: 7 additions & 3 deletions oneflow/core/job/nd_sbp_util.cpp
@@ -78,15 +78,19 @@ TensorSliceView GetTensorSliceView4ParallelRank(const Shape& parallel_hierarchy,
ranges[split_axis] = bs.At(id);
}
} else {
+ Shape physical_shape(logical_shape);
FOR_RANGE(int64_t, i, 0, parallel_hierarchy.NumAxes()) {
const SbpParallel& sbp_parallel = nd_sbp.sbp_parallel(i);
if (sbp_parallel.has_split_parallel()) {
const int64_t split_axis = sbp_parallel.split_parallel().axis();
CHECK_GE(split_axis, 0);
CHECK_LT(split_axis, ranges.size());
- CHECK_EQ(ranges[split_axis].size() % parallel_hierarchy.At(i), 0);
- const int64_t range_size = ranges[split_axis].size() / parallel_hierarchy.At(i);
- const int64_t dim_start = ranges[split_axis].begin() + parallel_rank.at(i) * range_size;
+ CHECK_GE(ranges[split_axis].size(), parallel_hierarchy.At(i));
+ const BalancedSplitter bs(physical_shape.At(split_axis), parallel_hierarchy.At(i));
+ const auto& range = bs.At(parallel_rank.at(i));
+ const int64_t range_size = range.size();
+ const int64_t dim_start = ranges[split_axis].begin() + range.begin();
+ physical_shape.Set(split_axis, range_size);
ranges[split_axis].mut_begin() = dim_start;
ranges[split_axis].mut_end() = dim_start + range_size;
}
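At each hierarchy level the rank's balanced sub-range narrows the current slice: range.begin() shifts the global start and range.size() becomes the physical extent fed to the next level. A standalone sketch of that arithmetic (not the OneFlow classes), where BalancedRange is a hypothetical equivalent of BalancedSplitter::At, assumed to put the larger parts first:

#include <cstdint>
#include <iostream>
#include <utility>

// [begin, end) of part i when `total` is split into `k` parts.
std::pair<int64_t, int64_t> BalancedRange(int64_t total, int64_t k, int64_t i) {
  const int64_t base = total / k;
  const int64_t extra = total % k;  // the first `extra` parts get one more element
  const int64_t begin = i * base + (i < extra ? i : extra);
  const int64_t size = base + (i < extra ? 1 : 0);
  return {begin, begin + size};
}

int main() {
  // Dimension of size 5, hierarchy (2, 2), nd_sbp (S(0), S(0)), rank indices (1, 0).
  const std::pair<int64_t, int64_t> levels[2] = {{2, 1}, {2, 0}};  // {axis size, rank index}
  int64_t begin = 0;
  int64_t size = 5;
  for (const auto& [h, r] : levels) {
    const auto [b, e] = BalancedRange(size, h, r);  // sub-range within the current slice
    begin += b;    // shift the global start, like dim_start above
    size = e - b;  // the physical extent for the next level
  }
  std::cout << '[' << begin << ", " << begin + size << ")\n";  // prints [3, 4)
}

Across all four ranks this yields slices [0, 2), [2, 3), [3, 4), [4, 5), i.e. sizes {2, 1, 1, 1}, which the old even-split code could not express.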
5 changes: 3 additions & 2 deletions oneflow/core/job_rewriter/group_boxing_by_dst_parallel.cpp
@@ -66,22 +66,23 @@ Maybe<void> GroupBoxingByDstParallel(const OpGraph& op_graph, JobBuilder* job_bu
if (blob_modifier_.has_is_mutable() && blob_modifier_.is_mutable()) { continue; }
const LogicalBlobId& lbi = node->op().BnInOp2Lbi(ibn);
const OpNode& producer = node->ProducerOpNode4Lbi(lbi);
+ const auto& logical_shape = node->LogicalBlobDesc4Lbi(lbi).shape();
const NdSbp& producer_nd_sbp = producer.NdSbp4Lbi(lbi);
const std::string& producer_lbn = *CHECK_JUST(producer.op().obn4lbi(lbi));
const ParallelDesc& producer_parallel_desc =
*CHECK_JUST(producer.op().GetParallelDesc4BnInOp(producer_lbn)).get();
ParallelDesc reduced_in_parallel_desc = producer_parallel_desc;
NdSbp reduced_in_nd_sbp;
NdSbpDimReduce(producer_parallel_desc, producer_nd_sbp, &reduced_in_parallel_desc,
- &reduced_in_nd_sbp);
+ &reduced_in_nd_sbp, logical_shape);

const NdSbp& consumer_nd_sbp = node->NdSbp4BnInOp(ibn);
const ParallelDesc& consumer_parallel_desc =
*CHECK_JUST(node->op().GetParallelDesc4BnInOp(ibn));
ParallelDesc reduced_out_parallel_desc = consumer_parallel_desc;
NdSbp reduced_out_nd_sbp;
NdSbpDimReduce(consumer_parallel_desc, consumer_nd_sbp, &reduced_out_parallel_desc,
- &reduced_out_nd_sbp);
+ &reduced_out_nd_sbp, logical_shape);

if (reduced_in_parallel_desc == reduced_out_parallel_desc
&& reduced_in_nd_sbp == reduced_out_nd_sbp) {
2 changes: 1 addition & 1 deletion oneflow/core/job_rewriter/insert_nccl_logical_op_pass.cpp
@@ -422,7 +422,7 @@ bool TryBuildNcclLogicalOpConf(OperatorConf* ret, const OpNode* src_node, const
InOutParallelDimReduce(src_node->parallel_desc(), dst_node->parallel_desc(),
src_node->NdSbp4Lbi(lbi), dst_node->NdSbp4Lbi(lbi),
src_reduced_parallel_desc, dst_reduced_parallel_desc, src_reduced_nd_sbp,
- dst_reduced_nd_sbp);
+ dst_reduced_nd_sbp, logical_blob_desc.shape());

CHECK_EQ(src_reduced_parallel_desc->parallel_num(), dst_reduced_parallel_desc->parallel_num());
std::shared_ptr<Shape> src_reduced_hierarchy = src_reduced_parallel_desc->hierarchy();
@@ -338,8 +338,7 @@ bool IsSplitValid(const Shape& shape, const NdSbp& nd_sbp, const Shape& hierachy
if (sbp.has_split_parallel()) {
const int64_t dim = sbp.split_parallel().axis();
if (dim >= cur_shape.NumAxes()) { return false; }
- // Evenly split.
- if (cur_shape.At(dim) % hierachy.At(i) != 0) { return false; }
+ // Unbalanced split and take the minimum
cur_shape.Set(dim, cur_shape.At(dim) / hierachy.At(i));
// Larger then min size.
if (cur_shape.elem_cnt() < min_size) { return false; }
27 changes: 9 additions & 18 deletions oneflow/core/operator/operator.cpp
@@ -778,7 +778,7 @@ Maybe<void> Operator::GreedilyFindMinCopyCostNdSbp(
producer_infer_hint4ibn->nd_sbp(),
JUST(VectorAt(nd_sbp_sig_list, i)).bn_in_op2nd_sbp().at(ibn),
producer_infer_hint4ibn->parallel_desc(), *JUST(GetParallelDesc4BnInOp(ibn)),
- requires_same_sbp[ibn_id]);
+ requires_same_sbp[ibn_id], producer_infer_hint4ibn->logical_blob_desc().shape());
sum_priority_ratio += priority_ratio;
// We do not accept any blob which has a priority ratio greater than 1
if (priority_ratio > 1.5) {
@@ -1647,26 +1647,17 @@ Maybe<Shape> GetNdHierarchyPhysicalShape(const Shape& logical_shape, const NdSbp
const auto& sbp_parallel = nd_sbp.sbp_parallel(i);
if (sbp_parallel.has_split_parallel()) {
const int64_t split_axis = sbp_parallel.split_parallel().axis();
- if (LazyMode::is_enabled()) {
- CHECK_EQ_OR_RETURN(physical->At(split_axis) % parallel_hierarchy.At(i), 0)
- << Error::RuntimeError() << "In nn.Graph, expected size at split axis (" << split_axis
- << ") of logical shape must be divisible by parallel num, but got logical_shape: "
+ // Both the lazy and eager modes support unbalanced splitting now
+ if (physical->At(split_axis) > 0) {
+ CHECK_GE_OR_RETURN(physical->At(split_axis), parallel_hierarchy.At(i))
+ << Error::RuntimeError() << "Expected size at split axis (" << split_axis
+ << ") of logical shape must be greater than or equal to parallel num, but got "
+ "logical_shape: "
<< logical_shape.ToString()
<< ", placement: " << *JUST(PlacementToString(SymbolOf(parallel_desc)))
<< ", nd_sbp: " << NdSbpToString(SymbolOf(nd_sbp));
- physical->Set(split_axis, physical->At(split_axis) / parallel_hierarchy.At(i));
- } else {
- if (physical->At(split_axis) > 0) {
- CHECK_GE_OR_RETURN(physical->At(split_axis), parallel_hierarchy.At(i))
- << Error::RuntimeError() << "Expected size at split axis (" << split_axis
- << ") of logical shape must be be greater than or equal to parallel num, but got "
- "logical_shape: "
- << logical_shape.ToString()
- << ", placement: " << *JUST(PlacementToString(SymbolOf(parallel_desc)))
- << ", nd_sbp: " << NdSbpToString(SymbolOf(nd_sbp));
- const BalancedSplitter bs(physical->At(split_axis), parallel_hierarchy.At(i));
- physical->Set(split_axis, bs.At(CalcIndex4Axis(parallel_id, hierarch_stride, i)).size());
- }
+ const BalancedSplitter bs(physical->At(split_axis), parallel_hierarchy.At(i));
+ physical->Set(split_axis, bs.At(CalcIndex4Axis(parallel_id, hierarch_stride, i)).size());
}
}
}
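The unified lazy/eager branch above derives each rank's physical extent from its index along every hierarchy axis. A standalone sketch under two assumptions: CalcIndex4Axis(parallel_id, stride, i) behaves as (parallel_id / stride_i) % hierarchy[i] with stride_i the product of the axes after i (row-major rank numbering), and PartSize again stands in for BalancedSplitter:

#include <cstdint>
#include <iostream>

// Assumed BalancedSplitter behavior: size of part i of `total` over `k` parts.
int64_t PartSize(int64_t total, int64_t k, int64_t i) {
  return total / k + (i < total % k ? 1 : 0);
}

int main() {
  // Logical dim 5, hierarchy (2, 2), nd_sbp (S(0), S(0)).
  const int64_t hierarchy[2] = {2, 2};
  const int64_t stride[2] = {2, 1};  // product of the hierarchy axes after each axis
  for (int64_t parallel_id = 0; parallel_id < 4; ++parallel_id) {
    int64_t physical = 5;
    for (int64_t i = 0; i < 2; ++i) {
      // Assumed CalcIndex4Axis: this rank's index along hierarchy axis i.
      const int64_t index = (parallel_id / stride[i]) % hierarchy[i];
      physical = PartSize(physical, hierarchy[i], index);
    }
    std::cout << "rank " << parallel_id << ": dim0 size " << physical << '\n';
  }
  // Prints sizes 2, 1, 1, 1: the shape no longer has to satisfy 5 % 4 == 0.
}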