Feat general basic communication #8437

Merged on Jul 25, 2022 (76 commits)
9b8664c
Add a slight cost for B->S and B->P in 2d sbp
Yipeng1994 Jun 5, 2022
c32adbb
Add penalty for P in consumer
Yipeng1994 Jun 6, 2022
362a426
Merge branch 'master' into patch-sbp_cost
Yipeng1994 Jun 7, 2022
403d429
Fix a slight bug
Yipeng1994 Jun 7, 2022
3ff0d59
Add at most 1 middle node for general basic communication
Yipeng1994 Jun 7, 2022
998f883
Add the cost for general basic communication
Yipeng1994 Jun 8, 2022
e53ffbc
Add the slight penalty for eager
Yipeng1994 Jun 8, 2022
9145dc6
Merge branch 'patch-sbp_cost' of github.com:Oneflow-Inc/oneflow into …
Yipeng1994 Jun 8, 2022
0373803
Skip initialization of boxing collector if not needed
Yipeng1994 Jun 9, 2022
e7164a7
Fix a bug
Yipeng1994 Jun 9, 2022
2b16f1b
Dev nd nccl send recv boxing (#8467)
Yipeng1994 Jun 23, 2022
2a1810c
Support different hierarchy
Yipeng1994 Jun 23, 2022
f46efa1
Merge branch 'master' into feat-general_basic_communication (#8477)
Yipeng1994 Jun 23, 2022
2c30a03
Merge branch 'master' into feat-general_basic_communication
Yipeng1994 Jun 23, 2022
900dcc8
Ask general basic communication before middle nodes
Yipeng1994 Jun 23, 2022
5eb7510
Add a task type for general basic communication
Yipeng1994 Jun 23, 2022
8f39d6d
Fix a bug
Yipeng1994 Jun 23, 2022
0c9fd15
Fix a bug
Yipeng1994 Jun 24, 2022
0c95d76
Fix the bug of transfer from 1d sbp to 2d sbp
Yipeng1994 Jun 24, 2022
d4cf04c
Use the intersection to approximate the ratio
Yipeng1994 Jun 27, 2022
e843e7e
Use a suitable virtual blob description
Yipeng1994 Jun 27, 2022
7e7724a
Remove the checking for balanced splitting
Yipeng1994 Jun 28, 2022
e7d7fe3
Fix the previous bug, still have another one
Yipeng1994 Jun 29, 2022
afee00a
Fix another bug
Yipeng1994 Jun 29, 2022
3b6baad
Update oneflow/core/graph/boxing/hierarchical_sub_task_graph_builder_…
guo-ran Jun 30, 2022
90b0a5d
Use machine 4-7 for hierarchy [2, 2] in the consumer
Yipeng1994 Jun 30, 2022
bcede72
Add a switch for general basic communication
Yipeng1994 Jun 30, 2022
79c905f
Add test script and of format
Yipeng1994 Jun 30, 2022
600c384
Merge branch 'master' into feat-general_basic_communication
Yipeng1994 Jun 30, 2022
fe8fd38
Fix conflit of master and remove print-out information
Yipeng1994 Jun 30, 2022
3605117
Skip middle nodes if not enough gains
Yipeng1994 Jul 1, 2022
e9e2d42
Fix a typo
Yipeng1994 Jul 1, 2022
3cf45b2
fix nccl send recv bug for different stream
guo-ran Jul 2, 2022
1b9705d
Merge branch 'feat-general_basic_communication' of https://github.com…
guo-ran Jul 2, 2022
27aef14
hot fix for ncclComm init
guo-ran Jul 4, 2022
d837d73
Reuse streams for different jobs
Yipeng1994 Jul 8, 2022
ba065fd
Merge branch 'feat-general_basic_communication' of github.com:Oneflow…
Yipeng1994 Jul 8, 2022
c3c0074
Rename and of format
Yipeng1994 Jul 8, 2022
ea2ce43
Merge branch 'master' into feat-general_basic_communication
Yipeng1994 Jul 8, 2022
0076c1d
Skip general basic communication for transfer
Yipeng1994 Jul 11, 2022
bd8444d
Merge branch 'master' into feat-general_basic_communication
strint Jul 13, 2022
f09dc85
Address suggestion
Yipeng1994 Jul 14, 2022
f4ea3c2
Use the more powerful GetRankSendRecvIntersection
Yipeng1994 Jul 18, 2022
e426c52
Register nccl send recv op for comm init before
Yipeng1994 Jul 18, 2022
41e2e57
Merge branch 'master' into feat-general_basic_communication
Yipeng1994 Jul 20, 2022
9d2808e
Remove irrelevant scripts
Yipeng1994 Jul 20, 2022
b32c133
Address suggestion and of format
Yipeng1994 Jul 21, 2022
d65ae0d
Address suggestion
Yipeng1994 Jul 21, 2022
96487cf
Merge branch 'master' into feat-general_basic_communication
Yipeng1994 Jul 21, 2022
921d0d4
Merge branch 'master' into feat-general_basic_communication
mergify[bot] Jul 21, 2022
5869aa5
Static analysis
Yipeng1994 Jul 21, 2022
e21597e
Merge branch 'master' into feat-general_basic_communication
mergify[bot] Jul 21, 2022
983ae75
Static analysis. Still have another one
Yipeng1994 Jul 21, 2022
e87770c
Merge branch 'feat-general_basic_communication' of github.com:Oneflow…
Yipeng1994 Jul 21, 2022
5b01f68
Static analysis
Yipeng1994 Jul 22, 2022
7cbbeac
Merge branch 'master' into feat-general_basic_communication
mergify[bot] Jul 22, 2022
b23f6b3
Merge branch 'master' into feat-general_basic_communication
mergify[bot] Jul 22, 2022
ae3598e
Merge branch 'master' into feat-general_basic_communication
mergify[bot] Jul 22, 2022
60e421f
Merge branch 'master' into feat-general_basic_communication
mergify[bot] Jul 22, 2022
cb44617
Alleviate on test time
Yipeng1994 Jul 22, 2022
185feb0
Merge branch 'master' into feat-general_basic_communication
mergify[bot] Jul 22, 2022
3e91733
Merge branch 'master' into feat-general_basic_communication
mergify[bot] Jul 22, 2022
f5b267a
Merge branch 'master' into feat-general_basic_communication
mergify[bot] Jul 22, 2022
f40df8d
Merge branch 'master' into feat-general_basic_communication
mergify[bot] Jul 23, 2022
d4fc161
Merge branch 'master' into feat-general_basic_communication
mergify[bot] Jul 23, 2022
830d889
Merge branch 'master' into feat-general_basic_communication
mergify[bot] Jul 23, 2022
8d11afb
Merge branch 'master' into feat-general_basic_communication
mergify[bot] Jul 23, 2022
b9b7870
Merge branch 'master' into feat-general_basic_communication
mergify[bot] Jul 23, 2022
1deb764
nccl logical send recv do not support different hierarchy
Yipeng1994 Jul 24, 2022
2fb1122
Merge branch 'master' into feat-general_basic_communication
Yipeng1994 Jul 24, 2022
9b54b56
Merge branch 'master' into feat-general_basic_communication
mergify[bot] Jul 24, 2022
bb1b2d1
Merge branch 'master' into feat-general_basic_communication
Yipeng1994 Jul 24, 2022
088e754
Merge branch 'master' into feat-general_basic_communication
mergify[bot] Jul 24, 2022
2abf194
Init boxing collector when asked
Yipeng1994 Jul 25, 2022
6c4a674
Merge branch 'feat-general_basic_communication' of github.com:Oneflow…
Yipeng1994 Jul 25, 2022
6407bb5
Merge branch 'master' into feat-general_basic_communication
Yipeng1994 Jul 25, 2022
270 changes: 216 additions & 54 deletions oneflow/core/auto_parallel/boxing_collector.cpp

Large diffs are not rendered by default.

14 changes: 14 additions & 0 deletions oneflow/core/auto_parallel/boxing_collector.h
@@ -129,6 +129,15 @@ class BoxingCollector final {
BoxingCollector* boxing_collector_producer,
BoxingCollector* boxing_collector_consumer,
const std::vector<std::vector<int32_t>>& diag_nodes);
// Ask for sbp combination for general basic communication
Maybe<void> AskSbpCombination4GeneralBasicCommunication(
const NdSbp& sbp_producer, const NdSbp& sbp_consumer, const BlobDesc& logical_blob_desc,
const ParallelDesc& producer_parallel_desc, const ParallelDesc& consumer_parallel_desc,
std::vector<NdSbp>& middle_sbps, int32_t* diag_node_pos);
// Ask for an all-split sbp which is close to the original one
Maybe<void> AskCloseAllSplitSbp(const NdSbp& nd_sbp, const ParallelDesc& parallel_desc,
const BlobDesc& logical_blob_desc,
std::vector<NdSbp>& middle_sbps);
// Stores all the possible SbpParallel.
HashMap<::oneflow::SbpParallel, int32_t> sbp_parallel_universe_;
// Relationship between id and Sbp Parallel
@@ -154,6 +163,11 @@ class BoxingCollector final {
std::vector<int32_t> id_1d_2_nd_;
// The sbp size in the combination table
int32_t hierarchy_num_;
// How the boxing collector is initialized
int32_t init_type_ = -1;
// Enable general basic communication or not
const bool enable_general_basic_communication =
ParseBooleanFromEnv("ONEFLOW_BOXING_ENABLE_GENERAL_BASIC_COMMUNICATION", false);
}; // class BoxingCollector

} // namespace oneflow
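
The switch declared above defaults to off and is read from the environment. In `sbp_infer_util.cpp` the same variable is read again and combined with the NCCL compute-stream setting, the partial-sum check on the consumer, and the device types. The following is a minimal sketch of that logic using plain values. `ParseBoolEnvSketch` and `UseGeneralBasicCommunicationSketch` are hypothetical names introduced here for illustration, and the accepted spellings of a true value are an assumption, not the behavior of OneFlow's `ParseBooleanFromEnv`.

```cpp
#include <algorithm>
#include <cassert>
#include <cctype>
#include <cstdlib>
#include <string>

// Hypothetical stand-in for ParseBooleanFromEnv. The accepted spellings
// ("1", "true", "yes", "on", case-insensitive) are an assumption of this sketch.
bool ParseBoolEnvSketch(const char* name, bool default_value) {
  assert(name != nullptr);
  const char* raw = std::getenv(name);
  if (raw == nullptr) { return default_value; }  // unset: fall back to the default
  std::string value(raw);
  std::transform(value.begin(), value.end(), value.begin(),
                 [](unsigned char c) { return static_cast<char>(std::tolower(c)); });
  return value == "1" || value == "true" || value == "yes" || value == "on";
}

// Plain-bool mirror of the condition that selects general basic communication
// in the cost functions of this PR.
bool UseGeneralBasicCommunicationSketch(bool nccl_use_compute_stream, bool same_parallel_desc,
                                        bool env_switch_enabled, bool consumer_has_partial,
                                        bool producer_on_cuda, bool consumer_on_cuda) {
  // Either the NCCL compute stream is used on identical placements,
  // or the environment switch forces the feature on.
  const bool allowed = (nccl_use_compute_stream && same_parallel_desc) || env_switch_enabled;
  // A partial-sum (P) consumer is never handled by this path, and both sides must be CUDA.
  return allowed && !consumer_has_partial && producer_on_cuda && consumer_on_cuda;
}
```

With the switch off, the path is taken only when the NCCL compute stream is in use and producer and consumer share one parallel description. Turning the switch on lifts that restriction, which is what allows different hierarchies.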
177 changes: 173 additions & 4 deletions oneflow/core/framework/sbp_infer_util.cpp
@@ -17,9 +17,15 @@ limitations under the License.
#include "oneflow/core/framework/sbp_infer_util.h"
#include "oneflow/core/auto_parallel/boxing_collector.h"
#include "oneflow/core/boxing/eager_boxing_interpreter_mgr.h"
#include "oneflow/core/common/device_type.pb.h"
#include "oneflow/core/common/nd_index_offset_helper.h"
#include "oneflow/core/common/util.h"
#include "oneflow/core/job/global_for.h"
#include "oneflow/core/job/lazy_mode.h"
#include "oneflow/core/job/nd_sbp_util.h"
#include "oneflow/core/job/parallel_desc.h"
#include "oneflow/core/job/resource_desc.h"
#include "oneflow/core/job/sbp_parallel.pb.h"

namespace oneflow {

@@ -55,6 +61,15 @@ double Penalty4PartialInConsumer(double logical_blob_size, int32_t producer_para
}
}

int32_t Ratio4Sbp(const NdSbp& nd_sbp, const ParallelDesc& parallel_desc,
const std::function<bool(const SbpParallel&)>& classifier) {
int32_t ratio = 1;
for (int32_t sbp_id = 0; sbp_id < nd_sbp.sbp_parallel_size(); sbp_id++) {
if (classifier(nd_sbp.sbp_parallel(sbp_id))) { ratio *= parallel_desc.hierarchy()->At(sbp_id); }
}
return ratio;
}

Maybe<double> ComputCopyCostBetweenTwoSbpParallel(const SbpParallel& producer_sbp_parallel,
const SbpParallel& consumer_sbp_parallel,
const BlobDesc& logical_blob_desc,
@@ -409,6 +424,16 @@ void CollaborativeParallelDimReduce(const ParallelDesc& in_parallel_desc,

} // namespace

int32_t PartialRatio4Producer(const NdSbp& sbp_producer,
const ParallelDesc& producer_parallel_desc) {
return Ratio4Sbp(sbp_producer, producer_parallel_desc, &SbpParallel::has_partial_sum_parallel);
}

int32_t BroadcastRatio4Consumer(const NdSbp& sbp_consumer,
const ParallelDesc& consumer_parallel_desc) {
return Ratio4Sbp(sbp_consumer, consumer_parallel_desc, &SbpParallel::has_broadcast_parallel);
}

void NdSbpDimReduce(const ParallelDesc& parallel_desc, const NdSbp& nd_sbp,
ParallelDesc* reduced_parallel_desc, NdSbp* reduced_nd_sbp) {
const auto& hierarchy = parallel_desc.hierarchy();
@@ -496,14 +521,31 @@ Maybe<double> ComputeLazyCopyCostBetweenNdSbp(const NdSbp& producer_sbp_parallel
reduced_in_nd_sbp.sbp_parallel(0), reduced_out_nd_sbp.sbp_parallel(0),
logical_blob_desc, reduced_in_parallel_desc, reduced_out_parallel_desc));
}
// Not supporting different hierarchy
// TODO: Support it in the future

#ifdef WITH_CUDA
static const bool enable_general_basic_communication =
ParseBooleanFromEnv("ONEFLOW_BOXING_ENABLE_GENERAL_BASIC_COMMUNICATION", false);
// Use a general basic communication if no P in the consumer
if ((((Singleton<ResourceDesc, ForSession>::Get()->nccl_use_compute_stream()
&& producer_parallel_desc == consumer_parallel_desc)
|| enable_general_basic_communication)
&& !NdSbpHasPartialParallel(consumer_sbp_parallel))
&& producer_parallel_desc.device_type() == DeviceType::kCUDA
&& consumer_parallel_desc.device_type() == DeviceType::kCUDA) {
return Cost4GeneralBasicCommunication(producer_sbp_parallel, consumer_sbp_parallel,
logical_blob_desc, producer_parallel_desc,
consumer_parallel_desc)
+ GetTransferCost();
}
#endif // WITH_CUDA

// Not supporting different hierarchy without general basic communication
if (in_hierarchy->elem_cnt() != out_hierarchy->elem_cnt()) { return kUnsupportedBoxing; }

bool on_same_devices =
reduced_in_parallel_desc.EqualsIgnoringHierarchy(reduced_out_parallel_desc);
double logical_blob_size =
logical_blob_desc.shape().elem_cnt() * GetSizeOfDataType(logical_blob_desc.data_type());

if (in_dim == 2 && out_dim == 2) {
// Not supporting different hierarchy
@@ -629,6 +671,39 @@ Maybe<double> ComputeCopyCostWithMiddleNodes(const NdSbp& producer_sbp_parallel,
const ParallelDesc& producer_parallel_desc,
const ParallelDesc& consumer_parallel_desc,
bool requires_same_sbp) {
// Reduce before cost computation
ParallelDesc reduced_in_parallel_desc = producer_parallel_desc;
NdSbp reduced_in_nd_sbp;
NdSbpDimReduce(producer_parallel_desc, producer_sbp_parallel, &reduced_in_parallel_desc,
&reduced_in_nd_sbp);

ParallelDesc reduced_out_parallel_desc = consumer_parallel_desc;
NdSbp reduced_out_nd_sbp;
NdSbpDimReduce(consumer_parallel_desc, consumer_sbp_parallel, &reduced_out_parallel_desc,
&reduced_out_nd_sbp);
// In 90% of the transfer, we would have the same parallel description for producer and consumer
// We need to speed it up and give an approximation of the cost
if (reduced_in_parallel_desc == reduced_out_parallel_desc
&& reduced_in_nd_sbp == reduced_out_nd_sbp) {
return 0.0;
}
#ifdef WITH_CUDA
static const bool enable_general_basic_communication =
ParseBooleanFromEnv("ONEFLOW_BOXING_ENABLE_GENERAL_BASIC_COMMUNICATION", false);
// Use a general basic communication if no P in the consumer
if ((((Singleton<ResourceDesc, ForSession>::Get()->nccl_use_compute_stream()
&& producer_parallel_desc == consumer_parallel_desc)
|| enable_general_basic_communication)
&& !NdSbpHasPartialParallel(consumer_sbp_parallel))
&& producer_parallel_desc.device_type() == DeviceType::kCUDA
&& consumer_parallel_desc.device_type() == DeviceType::kCUDA) {
return Cost4GeneralBasicCommunication(producer_sbp_parallel, consumer_sbp_parallel,
logical_blob_desc, producer_parallel_desc,
consumer_parallel_desc)
+ GetTransferCost();
}
#endif // WITH_CUDA

// Initialize boxing collector
constexpr int32_t kRegularMaxSplitAxes = 6;
static thread_local BoxingCollector boxing_collector(kRegularMaxSplitAxes);
@@ -727,4 +802,98 @@ double ComputeSbpInferPriority(const NdSbp& producer_nd_sbp, const NdSbp& consum
}
}

// The transfer ratio for general basic communication
// Cost = ratio * data amount
// When we get to this function, either producer_sbp_parallel != consumer_sbp_parallel
// or producer_parallel_desc != consumer_parallel_desc
double Cost4GeneralBasicCommunication(const NdSbp& producer_sbp_parallel,
const NdSbp& consumer_sbp_parallel,
const BlobDesc& logical_blob_desc,
const ParallelDesc& producer_parallel_desc,
const ParallelDesc& consumer_parallel_desc) {
// The upper bound of the amount of the transferred data
int32_t producer_partial_ratio =
PartialRatio4Producer(producer_sbp_parallel, producer_parallel_desc);
int32_t consumer_broadcast_ratio =
BroadcastRatio4Consumer(consumer_sbp_parallel, consumer_parallel_desc);
// More intersection on the same devices
bool on_same_devices = producer_parallel_desc.EqualsIgnoringHierarchy(consumer_parallel_desc);
// approximate intersection ratio
double intersection_ratio = 1.0;
// (?, P, ?)->(Si, Sj)->(?, B, ?), two-step transfer
if (producer_partial_ratio > 1 && consumer_broadcast_ratio > 1) {
if (on_same_devices) {
// Pure P in the producer or B in the consumer
// (P, P, P) -> ? or ? -> (B, B)
if (producer_partial_ratio == producer_parallel_desc.parallel_num()
|| consumer_broadcast_ratio == consumer_parallel_desc.parallel_num()) {
// There are some cases to which this ratio is not applicable
// We just take the one with the largest possibility
// For example: (P, S0) -> (B, B) for 1-D blob with machine hierarchy [n, m]
// The path should be (P, S0) -> (S0, S0) -> (B, B)
// true intersection ratio = 1/m + 1
intersection_ratio = 2.0;
} else {
// sbp_consumer = (B, Si) or (Si, B)
for (int32_t sbp_id = 0; sbp_id < std::min(producer_sbp_parallel.sbp_parallel_size(),
consumer_sbp_parallel.sbp_parallel_size());
sbp_id++) {
if (consumer_sbp_parallel.sbp_parallel(sbp_id).has_split_parallel()) {
const auto& producer_sbp4sbp_id = producer_sbp_parallel.sbp_parallel(sbp_id);
// (B, P) or (Si, P) -> (Si, B)
// (P, B) or (P, Si) -> (B, Si)
if (producer_sbp4sbp_id.has_broadcast_parallel()
|| producer_sbp4sbp_id == consumer_sbp_parallel.sbp_parallel(sbp_id)) {
intersection_ratio = 2.0;
break;
}
}
}
// Judge whether the intersection ratio is given a value (2.0)
if (intersection_ratio == 1.0) {
// The true intersection ratio ranges from 0 to 2,
// we just take a middle point of the range as the approximation
// For example: (P, S0) -> (S0, B), Path: (P, S0) -> (S1, S0) -> (S0, B)
// true intersection ratio = 1 + 1/m
// For example: (P, S0) -> (S1, B), Path: (P, S0) -> (S1, S0) -> (S1, B)
// true intersection ratio = 1 + 1
// For example: (P, S0) -> (B, S0), with a 1D blob
// true intersection ratio = (n+p-1)/nm + (n+p-1)/nm
// For example: (S0, P) -> (B, S0), Path: (S0, P) -> (S0, S1) -> (B, S0)
// true intersection ratio = 1 + 1/n

// We use the approximation 1 + (1/n + 1/m)/2
intersection_ratio = 1.0 + 0.5 / producer_parallel_desc.hierarchy()->At(0)
+ 0.5 / producer_parallel_desc.hierarchy()->At(1);
}
}
}
// Otherwise, on different devices
// intersection_ratio = 1.0;
} else {
// No P in the producer or no B in the consumer, one-step transfer
if (on_same_devices) {
// We use simulation for nD sbp with n=1,2,3,...
TensorSliceView in_second_slice =
GetTensorSliceView4ParallelId(*producer_parallel_desc.hierarchy(), producer_sbp_parallel,
logical_blob_desc.shape(), /*parallel_id=*/1);
TensorSliceView out_second_slice =
GetTensorSliceView4ParallelId(*consumer_parallel_desc.hierarchy(), consumer_sbp_parallel,
logical_blob_desc.shape(), /*parallel_id=*/1);
const TensorSliceView& intersection = in_second_slice.Intersect(out_second_slice);
// The intersection ratio is designed for two steps.
// However, we only have one step here, so we increase the ratio by 1.0
// to eliminate the unused step
intersection_ratio += std::min(
1.0, (double)(intersection.shape().elem_cnt() * producer_parallel_desc.parallel_num())
/ logical_blob_desc.shape().elem_cnt());
}
// Otherwise, on different devices
// intersection_ratio = 1.0;
}
// Subtract the intersection part
return (producer_partial_ratio + consumer_broadcast_ratio - intersection_ratio)
* logical_blob_desc.shape().elem_cnt() * GetSizeOfDataType(logical_blob_desc.data_type());
}

} // namespace oneflow
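
The one-step branch of `Cost4GeneralBasicCommunication` estimates the overlap by intersecting the producer and consumer slices owned by `parallel_id` 1. A minimal sketch of that simulation follows, restricted to a 2-D tensor, a 1-D hierarchy, and evenly divisible extents. Those restrictions, and the names `Range`, `SplitRange`, `Overlap`, and `OneStepIntersectionRatio`, are assumptions made here for illustration; the real code works on `TensorSliceView` for nD sbp.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>

// Half-open index range [begin, end) along one tensor axis.
struct Range {
  int64_t begin;
  int64_t end;
};

// Part of an axis owned by `rank` when `extent` is split evenly across `parts` ranks.
Range SplitRange(int64_t extent, int64_t parts, int64_t rank) {
  assert(parts > 0 && extent % parts == 0);
  const int64_t step = extent / parts;
  return {rank * step, (rank + 1) * step};
}

// Number of indices shared by two ranges (0 if they are disjoint).
int64_t Overlap(Range a, Range b) {
  return std::max<int64_t>(0, std::min(a.end, b.end) - std::max(a.begin, b.begin));
}

// Intersection ratio for S(in_axis) -> S(out_axis) on a rows x cols tensor
// placed on `parts` devices, evaluated for parallel_id = 1 as in the PR.
double OneStepIntersectionRatio(int64_t rows, int64_t cols, int64_t parts, int in_axis,
                                int out_axis) {
  assert(parts >= 2);  // rank 1 must exist
  const Range full_rows{0, rows};
  const Range full_cols{0, cols};
  const Range in_rows = in_axis == 0 ? SplitRange(rows, parts, 1) : full_rows;
  const Range in_cols = in_axis == 1 ? SplitRange(cols, parts, 1) : full_cols;
  const Range out_rows = out_axis == 0 ? SplitRange(rows, parts, 1) : full_rows;
  const Range out_cols = out_axis == 1 ? SplitRange(cols, parts, 1) : full_cols;
  const int64_t intersection = Overlap(in_rows, out_rows) * Overlap(in_cols, out_cols);
  // Start from 1.0 to cancel the unused second step, then add the local share, capped at 1.0.
  return 1.0
         + std::min(1.0, static_cast<double>(intersection * parts)
                             / static_cast<double>(rows * cols));
}
```

For an 8 x 8 tensor on 4 devices, S0 -> S1 gives a ratio of 1.25, so with both the partial and broadcast ratios equal to 1 the cost factor is 1 + 1 - 1.25 = 0.75: each device already holds a quarter of what it needs and the remaining three quarters are transferred.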
18 changes: 18 additions & 0 deletions oneflow/core/framework/sbp_infer_util.h
@@ -33,6 +33,16 @@ enum Penalty4PartialInConsumerTag : int {
kStrict = 3 // Not allow a transfer to P
};

// [2, 3, 4, 5, 9, 100, 8]: (P, S0, P, P, B, S1, P)
// partial ratio = 2 * 4 * 5 * 8
int32_t PartialRatio4Producer(const NdSbp& sbp_producer,
const ParallelDesc& producer_parallel_desc);

// [2, 3, 4, 5, 9, 100, 8]: (P, S0, B, P, B, S1, P)
// broadcast ratio = 4 * 9
int32_t BroadcastRatio4Consumer(const NdSbp& sbp_consumer,
const ParallelDesc& consumer_parallel_desc);

void NdSbpDimReduce(const ParallelDesc& parallel_desc, const NdSbp& nd_sbp,
ParallelDesc* reduced_parallel_desc, NdSbp* reduced_nd_sbp);
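
The two examples in the comments on `PartialRatio4Producer` and `BroadcastRatio4Consumer` can be checked with a small stand-alone sketch. It mirrors the loop in `Ratio4Sbp` but uses plain vectors in place of `NdSbp` and `ParallelDesc`. `SbpKind` and `RatioSketch` are hypothetical names, and the split axis (S0 versus S1) is dropped because it does not affect either ratio.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Simplified sbp label for one hierarchy dimension (split axis omitted).
enum class SbpKind { kSplit, kBroadcast, kPartialSum };

// Product of the hierarchy sizes over all dimensions whose sbp equals `wanted`.
int32_t RatioSketch(const std::vector<SbpKind>& nd_sbp, const std::vector<int32_t>& hierarchy,
                    SbpKind wanted) {
  assert(nd_sbp.size() == hierarchy.size());
  int32_t ratio = 1;
  for (std::size_t i = 0; i < nd_sbp.size(); ++i) {
    if (nd_sbp[i] == wanted) { ratio *= hierarchy[i]; }
  }
  return ratio;
}
```

With hierarchy [2, 3, 4, 5, 9, 100, 8], the producer (P, S0, P, P, B, S1, P) yields a partial ratio of 2 * 4 * 5 * 8 = 320, and the consumer (P, S0, B, P, B, S1, P) yields a broadcast ratio of 4 * 9 = 36.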

@@ -96,6 +106,14 @@ double ComputeSbpInferPriority(const NdSbp& producer_sbp_parallel,
const ParallelDesc& producer_parallel_desc,
const ParallelDesc& consumer_parallel_desc, bool requires_same_sbp);

// The transfer ratio for general basic communication
// Cost = ratio * data amount
double Cost4GeneralBasicCommunication(const NdSbp& producer_sbp_parallel,
const NdSbp& consumer_sbp_parallel,
const BlobDesc& logical_blob_desc,
const ParallelDesc& producer_parallel_desc,
const ParallelDesc& consumer_parallel_desc);

} // namespace oneflow

#endif // ONEFLOW_CORE_FRAMEWORK_SBP_INFER_UTIL_H_
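
As a worked example of "Cost = ratio * data amount", the sketch below reproduces the final return statement of `Cost4GeneralBasicCommunication` and the fallback approximation 1 + (1/n + 1/m)/2 with plain numbers. `CostSketch` and `ApproxIntersectionRatio2D` are hypothetical names, and the sketch assumes a 2-D hierarchy [n, m], which is also what the fallback in the PR indexes.

```cpp
#include <cassert>
#include <cstdint>

// Fallback approximation of the intersection ratio for a 2-D hierarchy [n, m],
// used when neither special case (pure P, pure B, or a matching split) applies.
double ApproxIntersectionRatio2D(int64_t n, int64_t m) {
  assert(n > 0 && m > 0);
  return 1.0 + 0.5 / static_cast<double>(n) + 0.5 / static_cast<double>(m);
}

// Transfer cost in bytes: (partial ratio + broadcast ratio - intersection ratio) * data size.
double CostSketch(int32_t producer_partial_ratio, int32_t consumer_broadcast_ratio,
                  double intersection_ratio, int64_t elem_cnt, int64_t bytes_per_elem) {
  return (producer_partial_ratio + consumer_broadcast_ratio - intersection_ratio)
         * static_cast<double>(elem_cnt) * static_cast<double>(bytes_per_elem);
}
```

Take (P, S0) -> (S0, B) on hierarchy [2, 4] with 1024 four-byte elements on the same devices. The partial ratio is 2 and the broadcast ratio is 4. No special case matches, so the intersection ratio is approximated as 1 + 0.5/2 + 0.5/4 = 1.375, and the cost is (2 + 4 - 1.375) * 1024 * 4 = 18944 bytes.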