Feat general basic communication #8437

Merged
merged 76 commits on Jul 25, 2022
Changes from 47 commits
Commits
76 commits
9b8664c
Add a slight cost for B->S and B->P in 2d sbp
Yipeng1994 Jun 5, 2022
c32adbb
Add penalty for P in consumer
Yipeng1994 Jun 6, 2022
362a426
Merge branch 'master' into patch-sbp_cost
Yipeng1994 Jun 7, 2022
403d429
Fix a slight bug
Yipeng1994 Jun 7, 2022
3ff0d59
Add at most 1 middle node for general basic communication
Yipeng1994 Jun 7, 2022
998f883
Add the cost for general basic communication
Yipeng1994 Jun 8, 2022
e53ffbc
Add the slight penalty for eager
Yipeng1994 Jun 8, 2022
9145dc6
Merge branch 'patch-sbp_cost' of github.com:Oneflow-Inc/oneflow into …
Yipeng1994 Jun 8, 2022
0373803
Skip initialization of boxing collector if not needed
Yipeng1994 Jun 9, 2022
e7164a7
Fix a bug
Yipeng1994 Jun 9, 2022
2b16f1b
Dev nd nccl send recv boxing (#8467)
Yipeng1994 Jun 23, 2022
2a1810c
Support different hierarchy
Yipeng1994 Jun 23, 2022
f46efa1
Merge branch 'master' into feat-general_basic_communication (#8477)
Yipeng1994 Jun 23, 2022
2c30a03
Merge branch 'master' into feat-general_basic_communication
Yipeng1994 Jun 23, 2022
900dcc8
Ask general basic communication before middle nodes
Yipeng1994 Jun 23, 2022
5eb7510
Add a task type for general basic communication
Yipeng1994 Jun 23, 2022
8f39d6d
Fix a bug
Yipeng1994 Jun 23, 2022
0c9fd15
Fix a bug
Yipeng1994 Jun 24, 2022
0c95d76
Fix the bug of transfer from 1d sbp to 2d sbp
Yipeng1994 Jun 24, 2022
d4cf04c
Use the intersection to approximate the ratio
Yipeng1994 Jun 27, 2022
e843e7e
Use a suitable virtual blob description
Yipeng1994 Jun 27, 2022
7e7724a
Remove the checking for balanced splitting
Yipeng1994 Jun 28, 2022
e7d7fe3
Fix the previous bug, still have another one
Yipeng1994 Jun 29, 2022
afee00a
Fix another bug
Yipeng1994 Jun 29, 2022
3b6baad
Update oneflow/core/graph/boxing/hierarchical_sub_task_graph_builder_…
guo-ran Jun 30, 2022
90b0a5d
Use machine 4-7 for hierarchy [2, 2] in the consumer
Yipeng1994 Jun 30, 2022
bcede72
Add a switch for general basic communication
Yipeng1994 Jun 30, 2022
79c905f
Add test script and of format
Yipeng1994 Jun 30, 2022
600c384
Merge branch 'master' into feat-general_basic_communication
Yipeng1994 Jun 30, 2022
fe8fd38
Fix conflit of master and remove print-out information
Yipeng1994 Jun 30, 2022
3605117
Skip middle nodes if not enough gains
Yipeng1994 Jul 1, 2022
e9e2d42
Fix a typo
Yipeng1994 Jul 1, 2022
3cf45b2
fix nccl send recv bug for different stream
guo-ran Jul 2, 2022
1b9705d
Merge branch 'feat-general_basic_communication' of https://github.com…
guo-ran Jul 2, 2022
27aef14
hot fix for ncclComm init
guo-ran Jul 4, 2022
d837d73
Reuse streams for different jobs
Yipeng1994 Jul 8, 2022
ba065fd
Merge branch 'feat-general_basic_communication' of github.com:Oneflow…
Yipeng1994 Jul 8, 2022
c3c0074
Rename and of format
Yipeng1994 Jul 8, 2022
ea2ce43
Merge branch 'master' into feat-general_basic_communication
Yipeng1994 Jul 8, 2022
0076c1d
Skip general basic communication for transfer
Yipeng1994 Jul 11, 2022
bd8444d
Merge branch 'master' into feat-general_basic_communication
strint Jul 13, 2022
f09dc85
Address suggestion
Yipeng1994 Jul 14, 2022
f4ea3c2
Use the more powerful GetRankSendRecvIntersection
Yipeng1994 Jul 18, 2022
e426c52
Register nccl send recv op for comm init before
Yipeng1994 Jul 18, 2022
41e2e57
Merge branch 'master' into feat-general_basic_communication
Yipeng1994 Jul 20, 2022
9d2808e
Remove irrelevant scripts
Yipeng1994 Jul 20, 2022
b32c133
Address suggestion and of format
Yipeng1994 Jul 21, 2022
d65ae0d
Address suggestion
Yipeng1994 Jul 21, 2022
96487cf
Merge branch 'master' into feat-general_basic_communication
Yipeng1994 Jul 21, 2022
921d0d4
Merge branch 'master' into feat-general_basic_communication
mergify[bot] Jul 21, 2022
5869aa5
Static analysis
Yipeng1994 Jul 21, 2022
e21597e
Merge branch 'master' into feat-general_basic_communication
mergify[bot] Jul 21, 2022
983ae75
Static analysis. Still have another one
Yipeng1994 Jul 21, 2022
e87770c
Merge branch 'feat-general_basic_communication' of github.com:Oneflow…
Yipeng1994 Jul 21, 2022
5b01f68
Static analysis
Yipeng1994 Jul 22, 2022
7cbbeac
Merge branch 'master' into feat-general_basic_communication
mergify[bot] Jul 22, 2022
b23f6b3
Merge branch 'master' into feat-general_basic_communication
mergify[bot] Jul 22, 2022
ae3598e
Merge branch 'master' into feat-general_basic_communication
mergify[bot] Jul 22, 2022
60e421f
Merge branch 'master' into feat-general_basic_communication
mergify[bot] Jul 22, 2022
cb44617
Alleviate on test time
Yipeng1994 Jul 22, 2022
185feb0
Merge branch 'master' into feat-general_basic_communication
mergify[bot] Jul 22, 2022
3e91733
Merge branch 'master' into feat-general_basic_communication
mergify[bot] Jul 22, 2022
f5b267a
Merge branch 'master' into feat-general_basic_communication
mergify[bot] Jul 22, 2022
f40df8d
Merge branch 'master' into feat-general_basic_communication
mergify[bot] Jul 23, 2022
d4fc161
Merge branch 'master' into feat-general_basic_communication
mergify[bot] Jul 23, 2022
830d889
Merge branch 'master' into feat-general_basic_communication
mergify[bot] Jul 23, 2022
8d11afb
Merge branch 'master' into feat-general_basic_communication
mergify[bot] Jul 23, 2022
b9b7870
Merge branch 'master' into feat-general_basic_communication
mergify[bot] Jul 23, 2022
1deb764
nccl logical send recv do not support different hierarchy
Yipeng1994 Jul 24, 2022
2fb1122
Merge branch 'master' into feat-general_basic_communication
Yipeng1994 Jul 24, 2022
9b54b56
Merge branch 'master' into feat-general_basic_communication
mergify[bot] Jul 24, 2022
bb1b2d1
Merge branch 'master' into feat-general_basic_communication
Yipeng1994 Jul 24, 2022
088e754
Merge branch 'master' into feat-general_basic_communication
mergify[bot] Jul 24, 2022
2abf194
Init boxing collector when asked
Yipeng1994 Jul 25, 2022
6c4a674
Merge branch 'feat-general_basic_communication' of github.com:Oneflow…
Yipeng1994 Jul 25, 2022
6407bb5
Merge branch 'master' into feat-general_basic_communication
Yipeng1994 Jul 25, 2022
246 changes: 191 additions & 55 deletions oneflow/core/auto_parallel/boxing_collector.cpp
@@ -18,6 +18,7 @@ limitations under the License.
#include <string>
#include "oneflow/core/auto_parallel/boxing_collector.h"
#include "oneflow/core/common/data_type.h"
#include "oneflow/core/common/device_type.pb.h"
#include "oneflow/core/common/maybe.h"
#include "oneflow/core/framework/nd_sbp.h"
#include "oneflow/core/job/global_for.h"
@@ -49,7 +50,7 @@ void DfsSetNdSbp(const std::vector<::oneflow::SbpParallel>& id2sbp_parallel, int
}

// Let a nd sbp be consistent with the given hierarchy number
Maybe<NdSbp> SetNdSbpDim(NdSbp nd_sbp, int32_t hierarchy_num) {
Maybe<NdSbp> SetNdSbpDim(const NdSbp& nd_sbp, int32_t hierarchy_num) {
// Do not need to change
if (nd_sbp.sbp_parallel_size() == hierarchy_num) { return nd_sbp; }
// (S0, S0) -> S0
@@ -71,6 +72,60 @@ Maybe<NdSbp> SetNdSbpDim(NdSbp nd_sbp, int32_t hierarchy_num) {
return new_sbp;
}

int32_t TotalNumSplit(const NdSbp& nd_sbp, const ParallelDesc& parallel_desc) {
int32_t total_num_split = 1;
for (int32_t i = 0; i < nd_sbp.sbp_parallel_size(); i++) {
if (nd_sbp.sbp_parallel(i).has_split_parallel()) {
total_num_split *= parallel_desc.hierarchy()->At(i);
}
}
return total_num_split;
}
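
The helper above multiplies the hierarchy extents of every split axis. A minimal standalone sketch of the same arithmetic, with plain vectors standing in for the NdSbp and ParallelDesc types (TotalNumSplitSketch is a hypothetical name):

#include <cassert>
#include <cstdint>
#include <vector>

// is_split[i] stands in for nd_sbp.sbp_parallel(i).has_split_parallel(),
// hierarchy[i] for parallel_desc.hierarchy()->At(i).
int32_t TotalNumSplitSketch(const std::vector<bool>& is_split,
                            const std::vector<int32_t>& hierarchy) {
  int32_t total_num_split = 1;
  for (size_t i = 0; i < is_split.size(); i++) {
    if (is_split[i]) { total_num_split *= hierarchy[i]; }
  }
  return total_num_split;
}

int main() {
  // (S0, B) on hierarchy [2, 4]: only dim 0 splits -> 2 pieces
  assert(TotalNumSplitSketch({true, false}, {2, 4}) == 2);
  // (S0, S1) on hierarchy [2, 4]: both dims split -> 8 pieces
  assert(TotalNumSplitSketch({true, true}, {2, 4}) == 8);
  return 0;
}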

// Dealing with 1D sbp to 1D sbp
// Specifically, S -> P.
Maybe<void> AskSbpCombinationFor1DSbp(const NdSbp& sbp_producer, const NdSbp& sbp_consumer,
const ParallelDesc& producer_parallel_desc,
const ParallelDesc& consumer_parallel_desc,
std::vector<NdSbp>& middle_sbps, int32_t* diag_node_pos) {
if (sbp_consumer.sbp_parallel(0).has_partial_sum_parallel()) {
// Support [4]: P <--> [2, 2]: (P, P)
// Support {0, 1, 2, 3}: P <--> {2, 0, 6, 7}: (P, P)
if (producer_parallel_desc.parallel_num() == consumer_parallel_desc.parallel_num()
&& sbp_producer.sbp_parallel(0).has_partial_sum_parallel()) {
return Maybe<void>::Ok();
}

if (!sbp_producer.sbp_parallel(0).has_broadcast_parallel()) {
// S -> B -> P (Large cost!)
// TODO: Please implement S -> P directly.
// We do not support [3]: P <--> [2, 2]: (P, P) as well.

int32_t hierarchy_size = 0;
if (producer_parallel_desc.hierarchy()->elem_cnt()
< consumer_parallel_desc.hierarchy()->elem_cnt()) {
// The diagonal node uses the parallel description from producer
// (S, S) -> (B, B) -> P/(P, P) or S -> B -> P/(P, P)
*diag_node_pos = 1;
hierarchy_size = producer_parallel_desc.hierarchy()->NumAxes();
} else {
// The diagonal node uses the parallel description from consumer
// S/(S, S) -> B -> P or S/(S, S) -> (B, B) -> (P, P)
*diag_node_pos = 0;
hierarchy_size = consumer_parallel_desc.hierarchy()->NumAxes();
}

NdSbp broadcast_nd;
for (int32_t i = 0; i < hierarchy_size; i++) {
broadcast_nd.add_sbp_parallel();
broadcast_nd.mutable_sbp_parallel(i)->mutable_broadcast_parallel();
}
middle_sbps.emplace_back(broadcast_nd);
}
}
return Maybe<void>::Ok();
}
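
For the S -> P case, the branch above inserts a single all-broadcast middle node on whichever side has the smaller hierarchy element count. A standalone sketch of just that placement decision (PickBroadcastMiddleNode is a hypothetical name; the real code builds an NdSbp):

#include <cassert>
#include <cstdint>

// Decide where the diagonal node sits and how many broadcast axes the middle
// node gets, mirroring the hierarchy comparison above.
void PickBroadcastMiddleNode(int64_t producer_elem_cnt, int32_t producer_axes,
                             int64_t consumer_elem_cnt, int32_t consumer_axes,
                             int32_t* diag_node_pos, int32_t* broadcast_axes) {
  if (producer_elem_cnt < consumer_elem_cnt) {
    *diag_node_pos = 1;               // middle node uses the producer's placement
    *broadcast_axes = producer_axes;  // e.g. S -> B -> (P, P)
  } else {
    *diag_node_pos = 0;               // middle node uses the consumer's placement
    *broadcast_axes = consumer_axes;  // e.g. S -> (B, B) -> (P, P)
  }
}

int main() {
  int32_t diag_node_pos = -1;
  int32_t broadcast_axes = 0;
  // [4]: S(0) -> [2, 2]: (P, P): equal element counts, so the middle node
  // follows the consumer and becomes (B, B).
  PickBroadcastMiddleNode(4, 1, 4, 2, &diag_node_pos, &broadcast_axes);
  assert(diag_node_pos == 0 && broadcast_axes == 2);
  return 0;
}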

} // namespace

// A constructor with init, designed for uncustomized boxing collector
@@ -190,10 +245,13 @@ Maybe<void> BoxingCollector::GenerateCombination4SamePlacement(int32_t max_middl
// NOTE: The performance of this function is the same for different hierarchies
int32_t world_size = GlobalProcessCtx::WorldSize();
Shape hierarchy44({4 * world_size, 4 * world_size});
int32_t virtual_range_size = hierarchy44.elem_cnt();
std::shared_ptr<Shape> virtual_hierarchy = std::make_shared<Shape>(hierarchy44);
auto parallel_desc = JUST(ParallelDesc::New(
"cpu", {"0:0-" + std::to_string(hierarchy44.elem_cnt() - 1)}, virtual_hierarchy));
BlobDesc blob_desc({16, 16, 16, 16}, DataType::kInt8, /*is_dynamic=*/false);
BlobDesc blob_desc({virtual_range_size, virtual_range_size, virtual_range_size,
virtual_range_size, virtual_range_size, virtual_range_size},
DataType::kInt8, /*is_dynamic=*/false);
JUST(GenerateCombination4SamePlacement(max_middle_node_num, blob_desc, *parallel_desc));
return Maybe<void>::Ok();
}
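
The blob description above replaces the old fixed {16, 16, 16, 16} shape: every axis now equals the element count of the virtual hierarchy, so each axis stays divisible no matter how many hierarchy dimensions split it. A quick check of that property, assuming world_size = 2 for illustration:

#include <cassert>
#include <cstdint>

int main() {
  const int64_t world_size = 2;                        // assumption for illustration
  const int64_t extent = 4 * world_size;               // each axis of hierarchy44
  const int64_t virtual_range_size = extent * extent;  // hierarchy44.elem_cnt()
  // A single hierarchy dim splitting one blob axis divides it evenly...
  assert(virtual_range_size % extent == 0);
  // ...and so do both hierarchy dims splitting the same blob axis.
  assert(virtual_range_size % (extent * extent) == 0);
  return 0;
}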
@@ -309,7 +367,10 @@ Maybe<void> BoxingCollector::GenerateCombination4DiffPlacement(
BoxingCollector* boxing_collector_producer, BoxingCollector* boxing_collector_consumer) {
// Virtual parallel and blob description
int32_t world_size = GlobalProcessCtx::WorldSize();
BlobDesc blob_desc({16, 16, 16, 16}, DataType::kInt8, /*is_dynamic=*/false);
int32_t virtual_range_size = 4 * world_size * (4 * world_size + 1);
BlobDesc blob_desc({virtual_range_size, virtual_range_size, virtual_range_size,
virtual_range_size, virtual_range_size, virtual_range_size},
DataType::kInt8, /*is_dynamic=*/false);
// Virtual placements before transfer
Shape in_hierarchy44({4 * world_size + 1, 4 * world_size});
std::shared_ptr<Shape> in_hierarchy = std::make_shared<Shape>(in_hierarchy44);
@@ -496,66 +557,39 @@ Maybe<void> BoxingCollector::AskSbpCombination(const NdSbp& sbp_producer, const
if (ParseBooleanFromEnv("ONEFLOW_BOXING_DISABLE_MIDDLE_NODE_AND_CHECK", false)) {
return Maybe<void>::Ok();
}
// If compute_cost==false + 2D sbp + same placement + nccl logical + not (p->b),
// Use nccl logical send recv instead of middle node.
// Note that in op sbp inference, cost of middle nodes is still used for the moment.
#ifdef WITH_CUDA
if (compute_cost == false && producer_parallel_desc.hierarchy()->NumAxes() == 2
&& producer_parallel_desc == consumer_parallel_desc
&& !(NdSbpHasPartialParallel(sbp_consumer)) &&
// TODO(): When same dim 0 finished dealing with (*, P) -> (*, S) in nccl logical pass, open
// this condition. When dealing with (P, P) -> (B, S0), middle node will change it to (P, P)
// -> (P, S0) -> (B, S0), neither same dim 0 or send recv in nccl logical pass can deal with
// (P, P) -> (P, S0) at the moment.
// !(NdSbpHasPartialParallel(sbp_producer) && NdSbpHasBroadcastParallel(sbp_consumer)) &&
Singleton<ResourceDesc, ForSession>::Get()->nccl_use_compute_stream()) {
VLOG(3) << "Middle node insertion is skipped when src sbp is " << NdSbpToString(sbp_producer)
<< " dst sbp is " << NdSbpToString(sbp_consumer)
<< ", because nccl logical send/recv can handle this.";
if (producer_parallel_desc == consumer_parallel_desc && sbp_producer == sbp_consumer) {
return Maybe<void>::Ok();
}
#endif // WITH_CUDA

// Dealing with 1D sbp to 1D sbp
// Specifically, S -> P.
if (Is1dSbp(sbp_producer) && Is1dSbp(sbp_consumer)) {
if (sbp_consumer.sbp_parallel(0).has_partial_sum_parallel()) {
// Support [4]: P <--> [2, 2]: (P, P)
// Support {0, 1, 2, 3}: P <--> {2, 0, 6, 7}: (P, P)
if (producer_parallel_desc.parallel_num() == consumer_parallel_desc.parallel_num()
&& sbp_producer.sbp_parallel(0).has_partial_sum_parallel()) {
return Maybe<void>::Ok();
}

if (!sbp_producer.sbp_parallel(0).has_broadcast_parallel()) {
// S -> B -> P (Large cost!)
// TODO: Please implement S -> P directly.
// We do not support [3]: P <--> [2, 2]: (P, P) as well.

int32_t hierarchy_size = 0;
if (producer_parallel_desc.hierarchy()->elem_cnt()
< consumer_parallel_desc.hierarchy()->elem_cnt()) {
// The diagonal node uses the parallel description from producer
// (S, S) -> (B, B) -> P/(P, P) or S -> B -> P/(P, P)
*diag_node_pos = 1;
hierarchy_size = producer_parallel_desc.hierarchy()->NumAxes();
} else {
// The diagonal node uses the parallel description from consumer
// S/(S, S) -> B -> P or S/(S, S) -> (B, B) -> (P, P)
*diag_node_pos = 0;
hierarchy_size = consumer_parallel_desc.hierarchy()->NumAxes();
}
JUST(AskSbpCombinationFor1DSbp(sbp_producer, sbp_consumer, producer_parallel_desc,
consumer_parallel_desc, middle_sbps, diag_node_pos));
// No middle nodes for the other 1d-sbp combinations
return Maybe<void>::Ok();
}

NdSbp broadcast_nd;
for (int32_t i = 0; i < hierarchy_size; i++) {
broadcast_nd.add_sbp_parallel();
broadcast_nd.mutable_sbp_parallel(i)->mutable_broadcast_parallel();
}
middle_sbps.emplace_back(broadcast_nd);
}
return Maybe<void>::Ok();
#ifdef WITH_CUDA
static const bool enable_general_basic_communication =
ParseBooleanFromEnv("ONEFLOW_BOXING_ENABLE_GENERAL_BASIC_COMMUNICATION", false);
// Use general basic communication if there is no P in the consumer
if ((Singleton<ResourceDesc, ForSession>::Get()->nccl_use_compute_stream()
|| enable_general_basic_communication)
&& (!NdSbpHasPartialParallel(sbp_consumer))
&& producer_parallel_desc.device_type() == DeviceType::kCUDA
&& consumer_parallel_desc.device_type() == DeviceType::kCUDA) {
if (NdSbpHasPartialParallel(sbp_producer) && NdSbpHasBroadcastParallel(sbp_consumer)) {
// (?, P, ?)->(Si, Sj)->(?, B, ?), two-step transfer
// Directly applying general basic communication would have O(n^2) time complexity for P->B
// Using two-step transfer would reduce it to a linear cost
JUST(AskSbpCombination4GeneralBasicCommunication(
sbp_producer, sbp_consumer, logical_blob_desc, producer_parallel_desc,
consumer_parallel_desc, middle_sbps, diag_node_pos));
}
// Otherwise, one-step transfer
return Maybe<void>::Ok();
}
#endif // WITH_CUDA
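
The O(n^2)-versus-linear claim in the comment above is easy to check with back-of-the-envelope traffic volumes. A sketch, under the simplifying assumptions that producer and consumer both span n ranks and one partial tensor counts as one unit:

#include <cstdio>

int main() {
  const int ranks[] = {2, 4, 8, 16};
  for (int n : ranks) {
    // One-step P -> B: each of the n consumers pulls all n partial tensors.
    int one_step = n * n;
    // Two-step P -> (Si, Sj) -> B: roughly a reduce-scatter plus an all-gather.
    int two_step = 2 * n;
    std::printf("n=%2d  one-step=%4d  two-step=%3d\n", n, one_step, two_step);
  }
  return 0;
}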

// Middle nodes algorithm supports transfer for different machines or devices or hierarchies
if (producer_parallel_desc != consumer_parallel_desc) {
@@ -568,6 +602,7 @@ Maybe<void> BoxingCollector::AskSbpCombination(const NdSbp& sbp_producer, const
// Transfer for the same machines, devices and hierarchy.
if (sbp_producer == sbp_consumer) { return Maybe<void>::Ok(); }
const auto& parallel_hierarchy = producer_parallel_desc.hierarchy();

*diag_node_pos = 0;
// Dealing with nD sbp, n>2
if (parallel_hierarchy->NumAxes() > 2) {
@@ -1007,4 +1042,105 @@ Maybe<void> BoxingCollector::FilterNdSbpList4LogicalShape(const BlobDesc& logica
return Maybe<void>::Ok();
}

// Ask for sbp combination for general basic communication
Maybe<void> BoxingCollector::AskSbpCombination4GeneralBasicCommunication(
const NdSbp& sbp_producer, const NdSbp& sbp_consumer, const BlobDesc& logical_blob_desc,
const ParallelDesc& producer_parallel_desc, const ParallelDesc& consumer_parallel_desc,
std::vector<NdSbp>& middle_sbps, int32_t* diag_node_pos) {
// (P, X) -> (B, X) || (X , P) -> (X, B), X is any SBP
// One-step transfer: middle nodes would cut the transfer cost by at most 50%, so skip them
if (producer_parallel_desc == consumer_parallel_desc
&& producer_parallel_desc.hierarchy()->NumAxes() == 2
&& (sbp_producer.sbp_parallel(0) == sbp_consumer.sbp_parallel(0)
|| sbp_producer.sbp_parallel(1) == sbp_consumer.sbp_parallel(1))) {
return Maybe<void>::Ok();
}

// Not enough gain in transfer cost, do not use middle nodes
int32_t partial_ratio4producer = PartialRatio4Producer(sbp_producer, producer_parallel_desc);
int32_t broadcast_ratio4consumer = BroadcastRatio4Consumer(sbp_consumer, consumer_parallel_desc);
if (2 * (partial_ratio4producer + broadcast_ratio4consumer)
>= partial_ratio4producer * broadcast_ratio4consumer) {
return Maybe<void>::Ok();
}

bool close2producer = true;
if (producer_parallel_desc.parallel_num() == consumer_parallel_desc.parallel_num()) {
// Get close to the one with more splits
close2producer = TotalNumSplit(sbp_producer, producer_parallel_desc)
> TotalNumSplit(sbp_consumer, consumer_parallel_desc);
} else {
// Get close to the one with more machines
close2producer = producer_parallel_desc.parallel_num() > consumer_parallel_desc.parallel_num();
}
// Get the contiguous sbp
if (close2producer) {
JUST(AskCloseAllSplitSbp(sbp_producer, producer_parallel_desc, logical_blob_desc, middle_sbps));
*diag_node_pos = 1;
} else {
JUST(AskCloseAllSplitSbp(sbp_consumer, consumer_parallel_desc, logical_blob_desc, middle_sbps));
*diag_node_pos = 0;
}
return Maybe<void>::Ok();
}
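
The gain check above skips middle nodes whenever 2 * (p + b) >= p * b, i.e. when the estimated two-step cost is not strictly below the one-step cost. A minimal sketch with a few evaluations (EnoughGain4MiddleNodes is a hypothetical name; p and b stand in for PartialRatio4Producer and BroadcastRatio4Consumer):

#include <cassert>
#include <cstdint>

// Middle nodes are only worth inserting when the two-step estimate beats the
// one-step estimate.
bool EnoughGain4MiddleNodes(int32_t p, int32_t b) { return 2 * (p + b) < p * b; }

int main() {
  assert(!EnoughGain4MiddleNodes(2, 2));  // 8 >= 4: skip middle nodes
  assert(!EnoughGain4MiddleNodes(4, 4));  // 16 >= 16: still no strict gain
  assert(EnoughGain4MiddleNodes(4, 8));   // 24 < 32: insert the all-split node
  return 0;
}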

// Ask for an all-split sbp which is close to the original one
Maybe<void> BoxingCollector::AskCloseAllSplitSbp(const NdSbp& nd_sbp,
const ParallelDesc& parallel_desc,
const BlobDesc& logical_blob_desc,
std::vector<NdSbp>& middle_sbps) {
Shape remain_shape = logical_blob_desc.shape();
Shape rest_split_shape = logical_blob_desc.shape();
int32_t dim_shape = remain_shape.NumAxes();
// Initialize the remains and splitting
// logical_blob_desc.shape() == remain_shape .* rest_split_shape;
for (int32_t i = 0; i < dim_shape; i++) { rest_split_shape.Set(i, 1); }
for (int32_t sbp_id = 0; sbp_id < nd_sbp.sbp_parallel_size(); sbp_id++) {
const auto& sbp = nd_sbp.sbp_parallel(sbp_id);
if (sbp.has_split_parallel()) {
int32_t axis = sbp.split_parallel().axis();
int32_t split_num = parallel_desc.hierarchy()->At(sbp_id);
remain_shape.Set(axis, remain_shape.At(axis) / split_num);
rest_split_shape.Set(axis, rest_split_shape.At(axis) * split_num);
}
}
// Get the contiguous sbp
NdSbp new_sbp = nd_sbp;
for (int32_t sbp_id = 0; sbp_id < nd_sbp.sbp_parallel_size(); sbp_id++) {
const auto& sbp = nd_sbp.sbp_parallel(sbp_id);
int32_t split_num = parallel_desc.hierarchy()->At(sbp_id);
if (sbp.has_split_parallel()) {
int32_t axis = sbp.split_parallel().axis();
// split shape is the total splitting number starting from sbp_id to the end
rest_split_shape.Set(axis, rest_split_shape.At(axis) / split_num);
} else {
// change P or B to S(axis)
int32_t axis = -1;
// 4096 is large enough; we are unlikely to have that many devices
int32_t min_split_num = 4096;
// We need to pick a suitable axis
for (int32_t i = 0; i < remain_shape.NumAxes(); i++) {
if (remain_shape.At(i) % split_num == 0) {
if (rest_split_shape.At(i) < min_split_num) {
// Pick the axis with smallest splitting number among the rest of the sbp
min_split_num = rest_split_shape.At(i);
axis = i;
}
}
}
// P, B -> S(axis)
if (axis >= 0) {
new_sbp.mutable_sbp_parallel(sbp_id)->mutable_split_parallel()->set_axis(axis);
remain_shape.Set(axis, remain_shape.At(axis) / split_num);
} else {
// Can not find a suitable contiguous sbp
return Maybe<void>::Ok();
}
}
}
// Add the new sbp into the middle node lists
middle_sbps.emplace_back(new_sbp);
return Maybe<void>::Ok();
}
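
For each P or B hierarchy dim, the loop above picks the tensor axis that is still divisible by that dim's extent and has accumulated the fewest splits. A standalone sketch of the same selection over plain arrays (CloseAllSplit is a hypothetical name; sbp[i] is the split axis for hierarchy dim i, with -1 standing for P or B):

#include <cassert>
#include <cstdint>
#include <vector>

// Returns the all-split sbp closest to `sbp`, or an empty vector when some
// P/B dim finds no divisible tensor axis.
std::vector<int32_t> CloseAllSplit(std::vector<int32_t> sbp,
                                   const std::vector<int64_t>& hierarchy,
                                   std::vector<int64_t> shape) {
  std::vector<int64_t> rest_split(shape.size(), 1);
  // Account for the splits that already exist.
  for (size_t i = 0; i < sbp.size(); i++) {
    if (sbp[i] >= 0) {
      shape[sbp[i]] /= hierarchy[i];
      rest_split[sbp[i]] *= hierarchy[i];
    }
  }
  for (size_t i = 0; i < sbp.size(); i++) {
    if (sbp[i] >= 0) {
      rest_split[sbp[i]] /= hierarchy[i];
      continue;
    }
    int32_t axis = -1;
    int64_t min_split_num = 4096;  // the same "large enough" sentinel as above
    for (size_t a = 0; a < shape.size(); a++) {
      if (shape[a] % hierarchy[i] == 0 && rest_split[a] < min_split_num) {
        min_split_num = rest_split[a];
        axis = static_cast<int32_t>(a);
      }
    }
    if (axis < 0) { return {}; }  // no contiguous all-split sbp found
    sbp[i] = axis;
    shape[axis] /= hierarchy[i];
  }
  return sbp;
}

int main() {
  // (P, S0) on hierarchy [2, 2] with logical shape [12, 8] becomes (S1, S0).
  assert(CloseAllSplit({-1, 0}, {2, 2}, {12, 8}) == (std::vector<int32_t>{1, 0}));
  return 0;
}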

} // namespace oneflow
9 changes: 9 additions & 0 deletions oneflow/core/auto_parallel/boxing_collector.h
@@ -129,6 +129,15 @@ class BoxingCollector final {
BoxingCollector* boxing_collector_producer,
BoxingCollector* boxing_collector_consumer,
const std::vector<std::vector<int32_t>>& diag_nodes);
// Ask for sbp combination for general basic communication
Maybe<void> AskSbpCombination4GeneralBasicCommunication(
const NdSbp& sbp_producer, const NdSbp& sbp_consumer, const BlobDesc& logical_blob_desc,
const ParallelDesc& producer_parallel_desc, const ParallelDesc& consumer_parallel_desc,
std::vector<NdSbp>& middle_sbps, int32_t* diag_node_pos);
// Ask for an all-split sbp which is close to the original one
Maybe<void> AskCloseAllSplitSbp(const NdSbp& nd_sbp, const ParallelDesc& parallel_desc,
const BlobDesc& logical_blob_desc,
std::vector<NdSbp>& middle_sbps);
// Stores all the possible SbpParallel.
HashMap<::oneflow::SbpParallel, int32_t> sbp_parallel_universe_;
// Relationship between id and Sbp Parallel