diff --git a/oneflow/core/auto_parallel/boxing_collector.cpp b/oneflow/core/auto_parallel/boxing_collector.cpp index c8210c2e744..a0c2f44b21e 100644 --- a/oneflow/core/auto_parallel/boxing_collector.cpp +++ b/oneflow/core/auto_parallel/boxing_collector.cpp @@ -18,6 +18,7 @@ limitations under the License. #include #include "oneflow/core/auto_parallel/boxing_collector.h" #include "oneflow/core/common/data_type.h" +#include "oneflow/core/common/device_type.pb.h" #include "oneflow/core/common/maybe.h" #include "oneflow/core/framework/nd_sbp.h" #include "oneflow/core/job/global_for.h" @@ -49,7 +50,7 @@ void DfsSetNdSbp(const std::vector<::oneflow::SbpParallel>& id2sbp_parallel, int } // Let a nd sbp be consistent with the given hierarchy number -Maybe SetNdSbpDim(NdSbp nd_sbp, int32_t hierarchy_num) { +Maybe SetNdSbpDim(const NdSbp& nd_sbp, int32_t hierarchy_num) { // Do not need to change if (nd_sbp.sbp_parallel_size() == hierarchy_num) { return nd_sbp; } // (S0, S0) -> S0 @@ -71,6 +72,60 @@ Maybe SetNdSbpDim(NdSbp nd_sbp, int32_t hierarchy_num) { return new_sbp; } +int32_t TotalNumSplit(const NdSbp& nd_sbp, const ParallelDesc& parallel_desc) { + int32_t total_num_split = 1; + for (int32_t i = 0; i < nd_sbp.sbp_parallel_size(); i++) { + if (nd_sbp.sbp_parallel(i).has_split_parallel()) { + total_num_split *= parallel_desc.hierarchy()->At(i); + } + } + return total_num_split; +} + +// Dealing with 1D sbp to 1D sbp +// Specifically, S -> P. +Maybe AskSbpCombinationFor1DSbp(const NdSbp& sbp_producer, const NdSbp& sbp_consumer, + const ParallelDesc& producer_parallel_desc, + const ParallelDesc& consumer_parallel_desc, + std::vector& middle_sbps, int32_t* diag_node_pos) { + if (sbp_consumer.sbp_parallel(0).has_partial_sum_parallel()) { + // Support [4]: P <--> [2, 2]: (P, P) + // Support {0, 1, 2, 3}: P <--> {2, 0, 6, 7}: (P, P) + if (producer_parallel_desc.parallel_num() == consumer_parallel_desc.parallel_num() + && sbp_producer.sbp_parallel(0).has_partial_sum_parallel()) { + return Maybe::Ok(); + } + + if (!sbp_producer.sbp_parallel(0).has_broadcast_parallel()) { + // S -> B -> P (Large cost!) + // TODO: Please implement S -> P directly. + // We do not support [3]: P <--> [2, 2]: (P, P) as well. 
+ + int32_t hierarchy_size = 0; + if (producer_parallel_desc.hierarchy()->elem_cnt() + < consumer_parallel_desc.hierarchy()->elem_cnt()) { + // The diagonal node uses the parallel description from producer + // (S, S) -> (B, B) -> P/(P, P) or S -> B -> P/(P, P) + *diag_node_pos = 1; + hierarchy_size = producer_parallel_desc.hierarchy()->NumAxes(); + } else { + // The diagonal node uses the parallel description from consumer + // S/(S, S) -> B -> P or S/(S, S) -> (B, B) -> (P, P) + *diag_node_pos = 0; + hierarchy_size = consumer_parallel_desc.hierarchy()->NumAxes(); + } + + NdSbp broadcast_nd; + for (int32_t i = 0; i < hierarchy_size; i++) { + broadcast_nd.add_sbp_parallel(); + broadcast_nd.mutable_sbp_parallel(i)->mutable_broadcast_parallel(); + } + middle_sbps.emplace_back(broadcast_nd); + } + } + return Maybe::Ok(); +} + } // namespace // A constructor with init, designed for uncustomized boxing collector @@ -92,6 +147,8 @@ Maybe BoxingCollector::Init(int32_t max_axis) { JUST(GenerateCombination4SamePlacement(3)); JUST(GenerateCombination4DiffHierarchy(this, this)); JUST(GenerateCombination4DiffPlacement(this, this)); + init_type_ = int32_t(enable_general_basic_communication + || Singleton::Get()->nccl_use_compute_stream()); return Maybe::Ok(); } @@ -106,6 +163,8 @@ Maybe BoxingCollector::Init(const BlobDesc& logical_blob_desc, // Get copy cost in lazy mode LazyMode::Guard enable_lazy_mode(true); JUST(GenerateCombination4SamePlacement(5, logical_blob_desc, parallel_desc)); + init_type_ = int32_t(enable_general_basic_communication + || Singleton::Get()->nccl_use_compute_stream()); return Maybe::Ok(); } @@ -173,6 +232,7 @@ void BoxingCollector::GenerateMap1d2nd() { // Generate the id Map from 1d sbp to nd sbp NdSbp nd_sbp; for (int32_t dim_sbp = 0; dim_sbp < hierarchy_num_; dim_sbp++) { nd_sbp.add_sbp_parallel(); } + id_1d_2_nd_.clear(); id_1d_2_nd_.resize(m, -1); for (int32_t id_1d = 0; id_1d < m; id_1d++) { for (int32_t dim_sbp = 0; dim_sbp < hierarchy_num_; dim_sbp++) { @@ -190,10 +250,13 @@ Maybe BoxingCollector::GenerateCombination4SamePlacement(int32_t max_middl // NOTE: The performance of this function are all the same with different hierarchy int32_t world_size = GlobalProcessCtx::WorldSize(); Shape hierarchy44({4 * world_size, 4 * world_size}); + int32_t virtual_range_size = hierarchy44.elem_cnt(); std::shared_ptr virtual_hierarchy = std::make_shared(hierarchy44); auto parallel_desc = JUST(ParallelDesc::New( "cpu", {"0:0-" + std::to_string(hierarchy44.elem_cnt() - 1)}, virtual_hierarchy)); - BlobDesc blob_desc({16, 16, 16, 16}, DataType::kInt8, /*is_dynamic=*/false); + BlobDesc blob_desc({virtual_range_size, virtual_range_size, virtual_range_size, + virtual_range_size, virtual_range_size, virtual_range_size}, + DataType::kInt8, /*is_dynamic=*/false); JUST(GenerateCombination4SamePlacement(max_middle_node_num, blob_desc, *parallel_desc)); return Maybe::Ok(); } @@ -204,7 +267,9 @@ Maybe BoxingCollector::GenerateCombination4SamePlacement(int32_t max_middl const ParallelDesc& parallel_desc) { // Store the origin transfer cost information int32_t n = nd_sbp_lists_.size(); + minimum_copy_cost_.clear(); minimum_copy_cost_.resize(n); + middle_nodes_.clear(); middle_nodes_.resize(n); for (int32_t i = 0; i < n; i++) { minimum_copy_cost_[i].resize(n); @@ -291,6 +356,7 @@ Maybe BoxingCollector::GenerateCombination4DiffHierarchy( // Search the path that contains one of the diagonal sbp int32_t n = nd_sbp_lists_.size(); + diag_node_diff_hierarchy_.clear(); diag_node_diff_hierarchy_.resize(n); 
for (int32_t i = 0; i < n; i++) { diag_node_diff_hierarchy_[i].resize(n); @@ -309,7 +375,10 @@ Maybe BoxingCollector::GenerateCombination4DiffPlacement( BoxingCollector* boxing_collector_producer, BoxingCollector* boxing_collector_consumer) { // Virtual parallel and blob description int32_t world_size = GlobalProcessCtx::WorldSize(); - BlobDesc blob_desc({16, 16, 16, 16}, DataType::kInt8, /*is_dynamic=*/false); + int32_t virtual_range_size = 4 * world_size * (4 * world_size + 1); + BlobDesc blob_desc({virtual_range_size, virtual_range_size, virtual_range_size, + virtual_range_size, virtual_range_size, virtual_range_size}, + DataType::kInt8, /*is_dynamic=*/false); // Virtual placements before transfer Shape in_hierarchy44({4 * world_size + 1, 4 * world_size}); std::shared_ptr in_hierarchy = std::make_shared(in_hierarchy44); @@ -334,6 +403,7 @@ Maybe BoxingCollector::ComputeCostFor1DSbpDiffPlacement( // Number of 1d sbp int32_t m = id2sbp_parallel_.size(); // Compute the cost while transferring a 1D sbp between different placements + cost_4_diff_placement.clear(); cost_4_diff_placement.resize(m); for (int32_t id_1d_producer = 0; id_1d_producer < m; id_1d_producer++) { cost_4_diff_placement[id_1d_producer].resize(m, GetMaxVal()); @@ -364,6 +434,7 @@ Maybe BoxingCollector::GenerateCombination4DiffPlacement( // Search the path that contains two of the diagonal sbp int32_t n = nd_sbp_lists_.size(); + diag_node_diff_placement_.clear(); diag_node_diff_placement_.resize(n); for (int32_t i = 0; i < n; i++) { diag_node_diff_placement_[i].resize(n); @@ -496,64 +567,53 @@ Maybe BoxingCollector::AskSbpCombination(const NdSbp& sbp_producer, const if (ParseBooleanFromEnv("ONEFLOW_BOXING_DISABLE_MIDDLE_NODE_AND_CHECK", false)) { return Maybe::Ok(); } - // If compute_cost==false + 2D sbp + same placment + nccl logical + not (p->b), - // Use nccl logical send recv instead of middle node. - // Note that in op sbp inference, cost of middle nodes is still used for the moment. -#ifdef WITH_CUDA - if (compute_cost == false && producer_parallel_desc.hierarchy()->NumAxes() == 2 - && producer_parallel_desc == consumer_parallel_desc - && !(NdSbpHasPartialParallel(sbp_consumer)) && - // TODO(): When same dim 0 finished dealing with (*, P) -> (*, S) in nccl logical pass, open - // this condition. When dealing with (P, P) -> (B, S0), middle node will change it to (P, P) - // -> (P, S0) -> (B, S0), neither same dim 0 or send recv in nccl logical pass can deal with - // (P, P) -> (P, S0) at the moment. - // !(NdSbpHasPartialParallel(sbp_producer) && NdSbpHasBroadcastParallel(sbp_consumer)) && - Singleton::Get()->nccl_use_compute_stream()) { - VLOG(3) << "Middle node insertion is skipped when src sbp is " << NdSbpToString(sbp_producer) - << " dst sbp is " << NdSbpToString(sbp_consumer) - << ", because nccl logical send/recv can handle this."; + if (producer_parallel_desc == consumer_parallel_desc && sbp_producer == sbp_consumer) { return Maybe::Ok(); } -#endif // WITH_CUDA // Dealing with 1D sbp to 1D sbp - // Specifically, S -> P. 
if (Is1dSbp(sbp_producer) && Is1dSbp(sbp_consumer)) { - if (sbp_consumer.sbp_parallel(0).has_partial_sum_parallel()) { - // Support [4]: P <--> [2, 2]: (P, P) - // Support {0, 1, 2, 3}: P <--> {2, 0, 6, 7}: (P, P) - if (producer_parallel_desc.parallel_num() == consumer_parallel_desc.parallel_num() - && sbp_producer.sbp_parallel(0).has_partial_sum_parallel()) { - return Maybe::Ok(); - } + JUST(AskSbpCombinationFor1DSbp(sbp_producer, sbp_consumer, producer_parallel_desc, + consumer_parallel_desc, middle_sbps, diag_node_pos)); + // No middle nodes for the other 1d-sbp combinations + return Maybe::Ok(); + } - if (!sbp_producer.sbp_parallel(0).has_broadcast_parallel()) { - // S -> B -> P (Large cost!) - // TODO: Please implement S -> P directly. - // We do not support [3]: P <--> [2, 2]: (P, P) as well. - - int32_t hierarchy_size = 0; - if (producer_parallel_desc.hierarchy()->elem_cnt() - < consumer_parallel_desc.hierarchy()->elem_cnt()) { - // The diagonal node uses the parallel description from producer - // (S, S) -> (B, B) -> P/(P, P) or S -> B -> P/(P, P) - *diag_node_pos = 1; - hierarchy_size = producer_parallel_desc.hierarchy()->NumAxes(); - } else { - // The diagonal node uses the parallel description from consumer - // S/(S, S) -> B -> P or S/(S, S) -> (B, B) -> (P, P) - *diag_node_pos = 0; - hierarchy_size = consumer_parallel_desc.hierarchy()->NumAxes(); - } +#ifdef WITH_CUDA + // Use a general basic communication if no P in the consumer + if (((Singleton::Get()->nccl_use_compute_stream() + && producer_parallel_desc == consumer_parallel_desc) + || enable_general_basic_communication) + && (!NdSbpHasPartialParallel(sbp_consumer)) + && producer_parallel_desc.device_type() == DeviceType::kCUDA + && consumer_parallel_desc.device_type() == DeviceType::kCUDA) { + if (NdSbpHasPartialParallel(sbp_producer) && NdSbpHasBroadcastParallel(sbp_consumer)) { + // (?, P, ?)->(Si, Sj)->(?, B, ?), two-step transfer + // Directly applying general basic communication would have O(n^2) time complexity for P->B + // Using two-step transfer would reduce it to a linear cost + JUST(AskSbpCombination4GeneralBasicCommunication( + sbp_producer, sbp_consumer, logical_blob_desc, producer_parallel_desc, + consumer_parallel_desc, middle_sbps, diag_node_pos)); + } + // Otherwise, one-step transfer + return Maybe::Ok(); + } +#endif // WITH_CUDA - NdSbp broadcast_nd; - for (int32_t i = 0; i < hierarchy_size; i++) { - broadcast_nd.add_sbp_parallel(); - broadcast_nd.mutable_sbp_parallel(i)->mutable_broadcast_parallel(); - } - middle_sbps.emplace_back(broadcast_nd); - } - return Maybe::Ok(); + if (JUST(ComputeLazyCopyCostBetweenNdSbp(sbp_producer, sbp_consumer, logical_blob_desc, + producer_parallel_desc, consumer_parallel_desc, + /*requires_same_sbp=*/false)) + < GetValidMaxCopyCost()) { + return Maybe::Ok(); + } else { + int32_t require_init_type = + int32_t(enable_general_basic_communication + || Singleton::Get()->nccl_use_compute_stream()); + if (init_type_ != require_init_type) { + // We assemble the boxing table from S(0) to S(5). + // Those splitting in higher axes are considered in the customized boxing. + constexpr int32_t kRegularMaxSplitAxes = 6; + JUST(Init(kRegularMaxSplitAxes)); } } @@ -568,6 +628,7 @@ Maybe BoxingCollector::AskSbpCombination(const NdSbp& sbp_producer, const // Transfer for the same machines, devices and hierarchy. 
   if (sbp_producer == sbp_consumer) { return Maybe::Ok(); }
   const auto& parallel_hierarchy = producer_parallel_desc.hierarchy();
+  *diag_node_pos = 0;
   // Dealing with nD sbp, n>2
   if (parallel_hierarchy->NumAxes() > 2) {
@@ -1007,4 +1068,105 @@ Maybe BoxingCollector::FilterNdSbpList4LogicalShape(const BlobDesc& logica
   return Maybe::Ok();
 }
 
+// Ask for sbp combination for general basic communication
+Maybe<void> BoxingCollector::AskSbpCombination4GeneralBasicCommunication(
+    const NdSbp& sbp_producer, const NdSbp& sbp_consumer, const BlobDesc& logical_blob_desc,
+    const ParallelDesc& producer_parallel_desc, const ParallelDesc& consumer_parallel_desc,
+    std::vector<NdSbp>& middle_sbps, int32_t* diag_node_pos) {
+  // (P, X) -> (B, X) || (X, P) -> (X, B), X is any SBP
+  // One-step transfer: at most a 50% reduction in the transfer cost; do not use middle nodes
+  if (producer_parallel_desc == consumer_parallel_desc
+      && producer_parallel_desc.hierarchy()->NumAxes() == 2
+      && (sbp_producer.sbp_parallel(0) == sbp_consumer.sbp_parallel(0)
+          || sbp_producer.sbp_parallel(1) == sbp_consumer.sbp_parallel(1))) {
+    return Maybe<void>::Ok();
+  }
+
+  // Not enough gain in transfer cost; do not use middle nodes
+  int32_t partial_ratio4producer = PartialRatio4Producer(sbp_producer, producer_parallel_desc);
+  int32_t broadcast_ratio4consumer = BroadcastRatio4Consumer(sbp_consumer, consumer_parallel_desc);
+  if (2 * (partial_ratio4producer + broadcast_ratio4consumer)
+      >= partial_ratio4producer * broadcast_ratio4consumer) {
+    return Maybe<void>::Ok();
+  }
+
+  bool close2producer = true;
+  if (producer_parallel_desc.parallel_num() == consumer_parallel_desc.parallel_num()) {
+    // Get close to the one with more splits
+    close2producer = TotalNumSplit(sbp_producer, producer_parallel_desc)
+                     > TotalNumSplit(sbp_consumer, consumer_parallel_desc);
+  } else {
+    // Get close to the one with more machines
+    close2producer = producer_parallel_desc.parallel_num() > consumer_parallel_desc.parallel_num();
+  }
+  // Get the contiguous sbp
+  if (close2producer) {
+    JUST(AskCloseAllSplitSbp(sbp_producer, producer_parallel_desc, logical_blob_desc, middle_sbps));
+    *diag_node_pos = 1;
+  } else {
+    JUST(AskCloseAllSplitSbp(sbp_consumer, consumer_parallel_desc, logical_blob_desc, middle_sbps));
+    *diag_node_pos = 0;
+  }
+  return Maybe<void>::Ok();
+}
+
+// Ask for an all-split sbp which is close to the original one
+Maybe<void> BoxingCollector::AskCloseAllSplitSbp(const NdSbp& nd_sbp,
+                                                 const ParallelDesc& parallel_desc,
+                                                 const BlobDesc& logical_blob_desc,
+                                                 std::vector<NdSbp>& middle_sbps) {
+  Shape remain_shape = logical_blob_desc.shape();
+  Shape rest_split_shape = logical_blob_desc.shape();
+  int32_t dim_shape = remain_shape.NumAxes();
+  // Initialize the remains and splitting
+  // logical_blob_desc.shape() == remain_shape .* rest_split_shape;
+  for (int32_t i = 0; i < dim_shape; i++) { rest_split_shape.Set(i, 1); }
+  for (int32_t sbp_id = 0; sbp_id < nd_sbp.sbp_parallel_size(); sbp_id++) {
+    const auto& sbp = nd_sbp.sbp_parallel(sbp_id);
+    if (sbp.has_split_parallel()) {
+      int32_t axis = sbp.split_parallel().axis();
+      int32_t split_num = parallel_desc.hierarchy()->At(sbp_id);
+      remain_shape.Set(axis, remain_shape.At(axis) / split_num);
+      rest_split_shape.Set(axis, rest_split_shape.At(axis) * split_num);
+    }
+  }
+  // Get the contiguous sbp
+  NdSbp new_sbp = nd_sbp;
+  for (int32_t sbp_id = 0; sbp_id < nd_sbp.sbp_parallel_size(); sbp_id++) {
+    const auto& sbp = nd_sbp.sbp_parallel(sbp_id);
+    int32_t split_num = parallel_desc.hierarchy()->At(sbp_id);
+    if (sbp.has_split_parallel()) {
+      int32_t axis = sbp.split_parallel().axis();
+      // rest_split_shape holds the total split number from sbp_id to the end
+      rest_split_shape.Set(axis, rest_split_shape.At(axis) / split_num);
+    } else {
+      // Change P or B to S(axis)
+      int32_t axis = -1;
+      // 4096 is large enough; we are unlikely to have that many devices
+      int32_t min_split_num = 4096;
+      // We need to pick a suitable axis
+      for (int32_t i = 0; i < remain_shape.NumAxes(); i++) {
+        if (remain_shape.At(i) % split_num == 0) {
+          if (rest_split_shape.At(i) < min_split_num) {
+            // Pick the axis with the smallest split number among the rest of the sbp
+            min_split_num = rest_split_shape.At(i);
+            axis = i;
+          }
+        }
+      }
+      // P, B -> S(axis)
+      if (axis >= 0) {
+        new_sbp.mutable_sbp_parallel(sbp_id)->mutable_split_parallel()->set_axis(axis);
+        remain_shape.Set(axis, remain_shape.At(axis) / split_num);
+      } else {
+        // Cannot find a suitable contiguous sbp
+        return Maybe<void>::Ok();
+      }
+    }
+  }
+  // Add the new sbp into the middle node list
+  middle_sbps.emplace_back(new_sbp);
+  return Maybe<void>::Ok();
+}
+
 }  // namespace oneflow
diff --git a/oneflow/core/auto_parallel/boxing_collector.h b/oneflow/core/auto_parallel/boxing_collector.h
index 09ddfd48f13..4661d6feb32 100644
--- a/oneflow/core/auto_parallel/boxing_collector.h
+++ b/oneflow/core/auto_parallel/boxing_collector.h
@@ -129,6 +129,15 @@ class BoxingCollector final {
                                          BoxingCollector* boxing_collector_producer,
                                          BoxingCollector* boxing_collector_consumer,
                                          const std::vector>& diag_nodes);
+  // Ask for sbp combination for general basic communication
+  Maybe<void> AskSbpCombination4GeneralBasicCommunication(
+      const NdSbp& sbp_producer, const NdSbp& sbp_consumer, const BlobDesc& logical_blob_desc,
+      const ParallelDesc& producer_parallel_desc, const ParallelDesc& consumer_parallel_desc,
+      std::vector<NdSbp>& middle_sbps, int32_t* diag_node_pos);
+  // Ask for an all-split sbp which is close to the original one
+  Maybe<void> AskCloseAllSplitSbp(const NdSbp& nd_sbp, const ParallelDesc& parallel_desc,
+                                  const BlobDesc& logical_blob_desc,
+                                  std::vector<NdSbp>& middle_sbps);
   // Stores all the possible SbpParallel.
   HashMap<::oneflow::SbpParallel, int32_t> sbp_parallel_universe_;
   // Relationship between id and Sbp Parallel
@@ -154,6 +163,11 @@ class BoxingCollector final {
   std::vector id_1d_2_nd_;
   // The sbp size in the combination table
   int32_t hierarchy_num_;
+  // How the boxing collector is initialized
+  int32_t init_type_ = -1;
+  // Enable general basic communication or not
+  const bool enable_general_basic_communication =
+      ParseBooleanFromEnv("ONEFLOW_BOXING_ENABLE_GENERAL_BASIC_COMMUNICATION", false);
 };  // class BoxingCollector
 
 }  // namespace oneflow
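For intuition, the axis-picking heuristic of AskCloseAllSplitSbp above can be exercised in isolation. The following is a minimal standalone sketch of the same idea, with plain vectors standing in for Shape/NdSbp and all names ours rather than OneFlow's: each P/B dimension becomes S(axis), preferring an evenly divisible axis whose accumulated split count is smallest.

#include <cstdint>
#include <iostream>
#include <vector>

// Toy stand-in (ours, not OneFlow's): axis >= 0 means S(axis); axis == -1 means P or B.
struct SbpDim {
  int axis;
};

// Mirrors the heuristic of AskCloseAllSplitSbp: turn every P/B dimension into
// S(axis), preferring an axis that divides evenly and carries the smallest
// accumulated split count so far.
std::vector<SbpDim> CloseAllSplit(std::vector<int64_t> shape, std::vector<SbpDim> nd_sbp,
                                  const std::vector<int64_t>& hierarchy) {
  std::vector<int64_t> rest_split(shape.size(), 1);
  // Account for the splits that are already present.
  for (size_t i = 0; i < nd_sbp.size(); ++i) {
    if (nd_sbp[i].axis >= 0) {
      shape[nd_sbp[i].axis] /= hierarchy[i];
      rest_split[nd_sbp[i].axis] *= hierarchy[i];
    }
  }
  for (size_t i = 0; i < nd_sbp.size(); ++i) {
    const int64_t split_num = hierarchy[i];
    if (nd_sbp[i].axis >= 0) {
      rest_split[nd_sbp[i].axis] /= split_num;  // this split is consumed from here onwards
      continue;
    }
    int axis = -1;
    int64_t min_split = 4096;  // "large enough", as in the original comment
    for (size_t a = 0; a < shape.size(); ++a) {
      if (shape[a] % split_num == 0 && rest_split[a] < min_split) {
        min_split = rest_split[a];
        axis = static_cast<int>(a);
      }
    }
    if (axis < 0) { return {}; }  // no suitable contiguous all-split sbp
    nd_sbp[i].axis = axis;
    shape[axis] /= split_num;
  }
  return nd_sbp;
}

int main() {
  // (P, S0) over hierarchy [2, 2] for a [4, 6] blob: axis 0 already carries a
  // split, so the P dimension becomes S1, giving the middle sbp (S1, S0).
  for (const SbpDim& d : CloseAllSplit({4, 6}, {{-1}, {0}}, {2, 2})) {
    std::cout << "S" << d.axis << " ";  // prints: S1 S0
  }
  std::cout << "\n";
}

Tracking the already-consumed split counts in rest_split is what steers new splits away from axes that are already heavily divided, so the resulting all-split sbp stays close to the original data distribution.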
#include "oneflow/core/framework/sbp_infer_util.h" #include "oneflow/core/auto_parallel/boxing_collector.h" #include "oneflow/core/boxing/eager_boxing_interpreter_mgr.h" +#include "oneflow/core/common/device_type.pb.h" +#include "oneflow/core/common/nd_index_offset_helper.h" #include "oneflow/core/common/util.h" +#include "oneflow/core/job/global_for.h" #include "oneflow/core/job/lazy_mode.h" +#include "oneflow/core/job/nd_sbp_util.h" #include "oneflow/core/job/parallel_desc.h" +#include "oneflow/core/job/resource_desc.h" +#include "oneflow/core/job/sbp_parallel.pb.h" namespace oneflow { @@ -55,6 +61,15 @@ double Penalty4PartialInConsumer(double logical_blob_size, int32_t producer_para } } +int32_t Ratio4Sbp(const NdSbp& nd_sbp, const ParallelDesc& parallel_desc, + const std::function& classifier) { + int32_t ratio = 1; + for (int32_t sbp_id = 0; sbp_id < nd_sbp.sbp_parallel_size(); sbp_id++) { + if (classifier(nd_sbp.sbp_parallel(sbp_id))) { ratio *= parallel_desc.hierarchy()->At(sbp_id); } + } + return ratio; +} + Maybe ComputCopyCostBetweenTwoSbpParallel(const SbpParallel& producer_sbp_parallel, const SbpParallel& consumer_sbp_parallel, const BlobDesc& logical_blob_desc, @@ -409,6 +424,16 @@ void CollaborativeParallelDimReduce(const ParallelDesc& in_parallel_desc, } // namespace +int32_t PartialRatio4Producer(const NdSbp& sbp_producer, + const ParallelDesc& producer_parallel_desc) { + return Ratio4Sbp(sbp_producer, producer_parallel_desc, &SbpParallel::has_partial_sum_parallel); +} + +int32_t BroadcastRatio4Consumer(const NdSbp& sbp_consumer, + const ParallelDesc& consumer_parallel_desc) { + return Ratio4Sbp(sbp_consumer, consumer_parallel_desc, &SbpParallel::has_broadcast_parallel); +} + void NdSbpDimReduce(const ParallelDesc& parallel_desc, const NdSbp& nd_sbp, ParallelDesc* reduced_parallel_desc, NdSbp* reduced_nd_sbp) { const auto& hierarchy = parallel_desc.hierarchy(); @@ -496,14 +521,31 @@ Maybe ComputeLazyCopyCostBetweenNdSbp(const NdSbp& producer_sbp_parallel reduced_in_nd_sbp.sbp_parallel(0), reduced_out_nd_sbp.sbp_parallel(0), logical_blob_desc, reduced_in_parallel_desc, reduced_out_parallel_desc)); } - // Not supporting different hierarchy - // TODO: Support it in the future + +#ifdef WITH_CUDA + static const bool enable_general_basic_communication = + ParseBooleanFromEnv("ONEFLOW_BOXING_ENABLE_GENERAL_BASIC_COMMUNICATION", false); + // Use a general basic communication if no P in the consumer + if ((((Singleton::Get()->nccl_use_compute_stream() + && producer_parallel_desc == consumer_parallel_desc) + || enable_general_basic_communication) + && !NdSbpHasPartialParallel(consumer_sbp_parallel)) + && producer_parallel_desc.device_type() == DeviceType::kCUDA + && consumer_parallel_desc.device_type() == DeviceType::kCUDA) { + return Cost4GeneralBasicCommunication(producer_sbp_parallel, consumer_sbp_parallel, + logical_blob_desc, producer_parallel_desc, + consumer_parallel_desc) + + GetTransferCost(); + } +#endif // WITH_CUDA + + // Not supporting different hierarchy without general basic communication if (in_hierarchy->elem_cnt() != out_hierarchy->elem_cnt()) { return kUnsupportedBoxing; } - double logical_blob_size = - logical_blob_desc.shape().elem_cnt() * GetSizeOfDataType(logical_blob_desc.data_type()); bool on_same_devices = reduced_in_parallel_desc.EqualsIgnoringHierarchy(reduced_out_parallel_desc); + double logical_blob_size = + logical_blob_desc.shape().elem_cnt() * GetSizeOfDataType(logical_blob_desc.data_type()); if (in_dim == 2 && out_dim == 2) { // Not supporting 
different hierarchy @@ -629,6 +671,39 @@ Maybe ComputeCopyCostWithMiddleNodes(const NdSbp& producer_sbp_parallel, const ParallelDesc& producer_parallel_desc, const ParallelDesc& consumer_parallel_desc, bool requires_same_sbp) { + // Reduce before cost computation + ParallelDesc reduced_in_parallel_desc = producer_parallel_desc; + NdSbp reduced_in_nd_sbp; + NdSbpDimReduce(producer_parallel_desc, producer_sbp_parallel, &reduced_in_parallel_desc, + &reduced_in_nd_sbp); + + ParallelDesc reduced_out_parallel_desc = consumer_parallel_desc; + NdSbp reduced_out_nd_sbp; + NdSbpDimReduce(consumer_parallel_desc, consumer_sbp_parallel, &reduced_out_parallel_desc, + &reduced_out_nd_sbp); + // In 90% of the transfer, we would have the same parallel description for producer and consumer + // We need to speed it up and give an approximation of the cost + if (reduced_in_parallel_desc == reduced_out_parallel_desc + && reduced_in_nd_sbp == reduced_out_nd_sbp) { + return 0.0; + } +#ifdef WITH_CUDA + static const bool enable_general_basic_communication = + ParseBooleanFromEnv("ONEFLOW_BOXING_ENABLE_GENERAL_BASIC_COMMUNICATION", false); + // Use a general basic communication if no P in the consumer + if ((((Singleton::Get()->nccl_use_compute_stream() + && producer_parallel_desc == consumer_parallel_desc) + || enable_general_basic_communication) + && !NdSbpHasPartialParallel(consumer_sbp_parallel)) + && producer_parallel_desc.device_type() == DeviceType::kCUDA + && consumer_parallel_desc.device_type() == DeviceType::kCUDA) { + return Cost4GeneralBasicCommunication(producer_sbp_parallel, consumer_sbp_parallel, + logical_blob_desc, producer_parallel_desc, + consumer_parallel_desc) + + GetTransferCost(); + } +#endif // WITH_CUDA + // Initialize boxing collector constexpr int32_t kRegularMaxSplitAxes = 6; static thread_local BoxingCollector boxing_collector(kRegularMaxSplitAxes); @@ -727,4 +802,98 @@ double ComputeSbpInferPriority(const NdSbp& producer_nd_sbp, const NdSbp& consum } } +// The transfer ratio for general basic communication +// Cost = ratio * data amount +// When we get the this function, either producer_sbp_parallel != consumer_sbp_parallel +// or producer_parallel_desc != consumer_parallel_desc +double Cost4GeneralBasicCommunication(const NdSbp& producer_sbp_parallel, + const NdSbp& consumer_sbp_parallel, + const BlobDesc& logical_blob_desc, + const ParallelDesc& producer_parallel_desc, + const ParallelDesc& consumer_parallel_desc) { + // The upper bound of the amount of the transferred data + int32_t producer_partial_ratio = + PartialRatio4Producer(producer_sbp_parallel, producer_parallel_desc); + int32_t consumer_broadcast_ratio = + BroadcastRatio4Consumer(consumer_sbp_parallel, consumer_parallel_desc); + // More intersection on the same devices + bool on_same_devices = producer_parallel_desc.EqualsIgnoringHierarchy(consumer_parallel_desc); + // approximate intersection ratio + double intersection_ratio = 1.0; + // (?, P, ?)->(Si, Sj)->(?, B, ?), two-step transfer + if (producer_partial_ratio > 1 && consumer_broadcast_ratio > 1) { + if (on_same_devices) { + // Pure P in the producer or B in the consumer + // (P, P, P) -> ? or ? 
-> (B, B) + if (producer_partial_ratio == producer_parallel_desc.parallel_num() + || consumer_broadcast_ratio == consumer_parallel_desc.parallel_num()) { + // There some cases which is not applicable to this ratio + // We just take the one with the largest possibility + // For example: (P, S0) -> (B, B) for 1-D blob with machine hierarchy [n, m] + // The path should be (P, S0) -> (S0, S0) -> (B, B) + // true intersection ratio = 1/m + 1 + intersection_ratio = 2.0; + } else { + // sbp_consumer = (B, Si) or (Si, B) + for (int32_t sbp_id = 0; sbp_id < std::min(producer_sbp_parallel.sbp_parallel_size(), + consumer_sbp_parallel.sbp_parallel_size()); + sbp_id++) { + if (consumer_sbp_parallel.sbp_parallel(sbp_id).has_split_parallel()) { + const auto& producer_sbp4sbp_id = producer_sbp_parallel.sbp_parallel(sbp_id); + // (B, P) or (Si, P) -> (Si, B) + // (P, B) or (P, Si) -> (B, Si) + if (producer_sbp4sbp_id.has_broadcast_parallel() + || producer_sbp4sbp_id == consumer_sbp_parallel.sbp_parallel(sbp_id)) { + intersection_ratio = 2.0; + break; + } + } + } + // Judge whether the intersection ratio is given a value (2.0) + if (intersection_ratio == 1.0) { + // The true intersection ratio range from 0 to 2, + // we just take a middle point of the range as the approximation + // For example: (P, S0) -> (S0, B), Path: (P, S0) -> (S1, S0) -> (S0, B) + // true intersection ratio = 1 + 1/m + // For example: (P, S0) -> (S1, B), Path: (P, S0) -> (S1, S0) -> (S1, B) + // true intersection ratio = 1 + 1 + // For example: (P, S0) -> (B, S0), with a 1D blob + // true intersection ratio = (n+p-1)/nm + (n+p-1)/nm + // For example: (S0, P) -> (B, S0), Path: (S0, P) -> (S0, S1) -> (B, S0) + // true intersection ratio = 1 + 1/n + + // We use the approximation 1 + (1/n + 1/m)/2 + intersection_ratio = 1.0 + 0.5 / producer_parallel_desc.hierarchy()->At(0) + + 0.5 / producer_parallel_desc.hierarchy()->At(1); + } + } + } + // Otherwise, on different devices + // intersection_ratio = 1.0; + } else { + // No P in the producer or no B in the consumer, one-step transfer + if (on_same_devices) { + // We use simulation for nD sbp with n=1,2,3,... + TensorSliceView in_second_slice = + GetTensorSliceView4ParallelId(*producer_parallel_desc.hierarchy(), producer_sbp_parallel, + logical_blob_desc.shape(), /*parallel_id=*/1); + TensorSliceView out_second_slice = + GetTensorSliceView4ParallelId(*consumer_parallel_desc.hierarchy(), consumer_sbp_parallel, + logical_blob_desc.shape(), /*parallel_id=*/1); + const TensorSliceView& intersection = in_second_slice.Intersect(out_second_slice); + // The intersection ratio is design for two steps. 
+ // However, we only have one step here, we would increase the ratio by 1.0 + // to eliminate the unused step + intersection_ratio += std::min( + 1.0, (double)(intersection.shape().elem_cnt() * producer_parallel_desc.parallel_num()) + / logical_blob_desc.shape().elem_cnt()); + } + // Otherwise, on different devices + // intersection_ratio = 1.0; + } + // Subtract the intersection part + return (producer_partial_ratio + consumer_broadcast_ratio - intersection_ratio) + * logical_blob_desc.shape().elem_cnt() * GetSizeOfDataType(logical_blob_desc.data_type()); +} + } // namespace oneflow diff --git a/oneflow/core/framework/sbp_infer_util.h b/oneflow/core/framework/sbp_infer_util.h index 6af5f84faab..21d7da6ae90 100644 --- a/oneflow/core/framework/sbp_infer_util.h +++ b/oneflow/core/framework/sbp_infer_util.h @@ -33,6 +33,16 @@ enum Penalty4PartialInConsumerTag : int { kStrict = 3 // Not allow a transfer to P }; +// [2, 3, 4, 5, 9, 100, 8]: (P, S0, P, P, B, S1, P) +// partial ratio = 2 * 4 * 5 * 8 +int32_t PartialRatio4Producer(const NdSbp& sbp_producer, + const ParallelDesc& producer_parallel_desc); + +// [2, 3, 4, 5, 9, 100, 8]: (P, S0, B, P, B, S1, P) +// broadcast ratio = 4 * 9 +int32_t BroadcastRatio4Consumer(const NdSbp& sbp_consumer, + const ParallelDesc& consumer_parallel_desc); + void NdSbpDimReduce(const ParallelDesc& parallel_desc, const NdSbp& nd_sbp, ParallelDesc* reduced_parallel_desc, NdSbp* reduced_nd_sbp); @@ -96,6 +106,14 @@ double ComputeSbpInferPriority(const NdSbp& producer_sbp_parallel, const ParallelDesc& producer_parallel_desc, const ParallelDesc& consumer_parallel_desc, bool requires_same_sbp); +// The transfer ratio for general basic communication +// Cost = ratio * data amount +double Cost4GeneralBasicCommunication(const NdSbp& producer_sbp_parallel, + const NdSbp& consumer_sbp_parallel, + const BlobDesc& logical_blob_desc, + const ParallelDesc& producer_parallel_desc, + const ParallelDesc& consumer_parallel_desc); + } // namespace oneflow #endif // ONEFLOW_CORE_FRAMEWORK_SBP_INFER_UTIL_H_ diff --git a/oneflow/core/graph/boxing/hierarchical_sub_task_graph_builder_impl.cpp b/oneflow/core/graph/boxing/hierarchical_sub_task_graph_builder_impl.cpp index 618db1e23c4..7592e50c9f2 100644 --- a/oneflow/core/graph/boxing/hierarchical_sub_task_graph_builder_impl.cpp +++ b/oneflow/core/graph/boxing/hierarchical_sub_task_graph_builder_impl.cpp @@ -27,6 +27,10 @@ limitations under the License. 
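Before the graph-builder changes that follow, the ratio arithmetic documented in sbp_infer_util.h above is worth a quick sanity check. This standalone sketch mirrors Ratio4Sbp with a toy character-based sbp encoding (our own stand-ins, not the OneFlow types):

#include <cassert>
#include <cstdint>
#include <vector>

// Toy encoding (ours): one char per hierarchy dimension,
// 'S' = split, 'P' = partial-sum, 'B' = broadcast.
int32_t Ratio(const std::vector<char>& nd_sbp, const std::vector<int32_t>& hierarchy, char kind) {
  // Mirrors Ratio4Sbp: multiply the hierarchy extents of the matching dimensions.
  int32_t ratio = 1;
  for (size_t i = 0; i < nd_sbp.size(); ++i) {
    if (nd_sbp[i] == kind) { ratio *= hierarchy[i]; }
  }
  return ratio;
}

int main() {
  const std::vector<int32_t> hierarchy = {2, 3, 4, 5, 9, 100, 8};
  // (P, S0, P, P, B, S1, P): partial ratio = 2 * 4 * 5 * 8 = 320
  assert(Ratio({'P', 'S', 'P', 'P', 'B', 'S', 'P'}, hierarchy, 'P') == 320);
  // (P, S0, B, P, B, S1, P): broadcast ratio = 4 * 9 = 36
  assert(Ratio({'P', 'S', 'B', 'P', 'B', 'S', 'P'}, hierarchy, 'B') == 36);
  // Cost4GeneralBasicCommunication then bounds the transferred data by
  // (partial_ratio + broadcast_ratio - intersection_ratio) * blob_bytes.
  return 0;
}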
#include "oneflow/core/graph/boxing/sub_task_graph_builder_util.h" #include "oneflow/core/framework/sbp_infer_util.h" #include "oneflow/core/job/sbp_parallel.h" +#include "oneflow/core/graph/nccl_send_recv_boxing_task_node.h" +#include "oneflow/core/job/nd_sbp_util.h" +#include "oneflow/core/graph/task_stream_id.h" +#include "oneflow/core/job/job_desc.h" namespace oneflow { @@ -46,6 +50,37 @@ std::shared_ptr Make1DSubTskGphBuilder() { return std::make_shared(builders); } +void MergeParallelConf(const ParallelDesc& parallel_desc_0, const ParallelDesc& parallel_desc_1, + ParallelConf* parallel_conf) { + CHECK_EQ(parallel_desc_0.device_tag(), parallel_desc_1.device_tag()); + std::set> machine_device_ids; + for (int64_t machine_id : parallel_desc_0.sorted_machine_ids()) { + for (int64_t device_id : parallel_desc_0.sorted_dev_phy_ids(machine_id)) { + machine_device_ids.insert(std::make_pair(machine_id, device_id)); + } + } + for (int64_t machine_id : parallel_desc_1.sorted_machine_ids()) { + for (int64_t device_id : parallel_desc_1.sorted_dev_phy_ids(machine_id)) { + machine_device_ids.insert(std::make_pair(machine_id, device_id)); + } + } + parallel_conf->set_device_tag(parallel_desc_0.device_tag()); + for (const auto& pair : machine_device_ids) { + parallel_conf->add_device_name("@" + std::to_string(pair.first) + ":" + + std::to_string(pair.second)); + } +} + +inline std::string NewUniqueIdGbc() { + static std::atomic counter(0); + static std::atomic curr_job_id(0); + if (curr_job_id != GlobalJobDesc().job_id()) { + curr_job_id = GlobalJobDesc().job_id(); + counter = 0; + } + return std::to_string(counter.fetch_add(1, std::memory_order_relaxed)); +} + } // namespace class FlatSubTskGphBuilder final : public HierarchicalSubTskGphBuilder { @@ -78,6 +113,68 @@ class FlatSubTskGphBuilder final : public HierarchicalSubTskGphBuilder { std::shared_ptr sub_tsk_gph_builder_; }; +class NDNcclSendRecvBoxingSubTskGphBuilder final : public HierarchicalSubTskGphBuilder { + public: + OF_DISALLOW_COPY_AND_MOVE(NDNcclSendRecvBoxingSubTskGphBuilder); + NDNcclSendRecvBoxingSubTskGphBuilder() {} + ~NDNcclSendRecvBoxingSubTskGphBuilder() override = default; + + Maybe Build(SubTskGphBuilderCtx* ctx, + const std::vector& sorted_in_tasks, + std::vector* sorted_out_tasks, + std::vector>* sorted_ctrl_tasks, + const ParallelDesc& in_parallel_desc, + const ParallelDesc& out_parallel_desc, + const LogicalBlobId& lbi, const BlobDesc& logical_blob_desc, + const NdSbp& in_nd_sbp, const NdSbp& out_nd_sbp, + const Shape& time_shape) const override { + if (in_parallel_desc.device_type() == DeviceType::kCUDA + && out_parallel_desc.device_type() == DeviceType::kCUDA + && !NdSbpHasPartialParallel(out_nd_sbp)) { +#if defined(WITH_CUDA) && NCCL_VERSION_CODE > 2700 + ParallelConf merged_parallel_conf; + MergeParallelConf(in_parallel_desc.parallel_conf(), out_parallel_desc.parallel_conf(), + &merged_parallel_conf); + ParallelDesc merged_parallel_desc(merged_parallel_conf); + TaskNode* first_in_node = sorted_in_tasks.front(); + sorted_ctrl_tasks->resize(out_parallel_desc.parallel_num()); + std::string stream_name = "NCCL_SEND_RECV_BOXING" + NewUniqueIdGbc(); + FOR_RANGE(int64_t, id, 0, merged_parallel_desc.parallel_num()) { + NcclSendRecvBoxingTaskNode* node = ctx->task_graph()->NewNode(); + const int64_t machine_id = JUST(merged_parallel_desc.MachineId4ParallelId(id)); + int64_t device_index = JUST(merged_parallel_desc.DeviceId4ParallelId(id)); + int64_t thrd_id = EncodeStreamIdToInt64(GenerateNamedTaskStreamId( + machine_id, 
merged_parallel_desc.device_type(), device_index, stream_name)); + bool has_input = in_parallel_desc.Containing(machine_id, device_index); + bool has_output = out_parallel_desc.Containing(machine_id, device_index); + node->Init(machine_id, thrd_id, lbi, logical_blob_desc.shape(), + logical_blob_desc.data_type(), in_nd_sbp, out_nd_sbp, in_parallel_desc, + out_parallel_desc, id, merged_parallel_desc, has_input, has_output, stream_name); + if (has_input) { + int64_t in_id = + JUST(in_parallel_desc.ParallelId4MachineDeviceId(machine_id, device_index)); + ctx->task_graph()->ConnectWithLbi(sorted_in_tasks.at(in_id), node, lbi); + } else { + // TODO: find nearest + std::string regst_desc_name; + first_in_node->BuildCtrlRegstDesc(node, ®st_desc_name); + TaskEdge* edge = ctx->task_graph()->NewEdge(); + Connect(first_in_node, edge, node); + first_in_node->BindEdgeWithProducedRegst(edge, regst_desc_name); + } + if (has_output) { sorted_out_tasks->push_back(node); } + } + return BuildSubTskGphBuilderStatus("NDNcclSendRecvBoxingSubTskGphBuilder", ""); +#else + return Error::BoxingNotSupportedError() << "No CUDA or low NCCL version"; +#endif + } else { + return Error::BoxingNotSupportedError() + << "Partial SBP in the consumer or not running on CUDA"; + } + } +}; + class IntraGroupSubTskGphBuilder final : public HierarchicalSubTskGphBuilder { public: OF_DISALLOW_COPY_AND_MOVE(IntraGroupSubTskGphBuilder); @@ -257,21 +354,22 @@ class Dim0NdSbpMismatchedSubTskGphBuilder final : public HierarchicalSubTskGphBu if (in_parallel_desc.hierarchy()->NumAxes() == 2 && (*in_parallel_desc.hierarchy() == *out_parallel_desc.hierarchy()) && in_nd_sbp.sbp_parallel(0) != out_nd_sbp.sbp_parallel(0) - && in_nd_sbp.sbp_parallel(1) == out_nd_sbp.sbp_parallel(1)) { - if (!(NdSbpAllSameSplitParallel(in_nd_sbp) || NdSbpAllSameSplitParallel(out_nd_sbp))) { - return inter_group_sub_tsk_gph_builder_->Build( - ctx, sorted_in_tasks, sorted_out_tasks, sorted_ctrl_tasks, in_parallel_desc, - out_parallel_desc, lbi, logical_blob_desc, in_nd_sbp, out_nd_sbp, time_shape); - } else { - return Error::BoxingNotSupportedError(); - } + && in_nd_sbp.sbp_parallel(1) == out_nd_sbp.sbp_parallel(1) + && !(NdSbpAllSameSplitParallel(in_nd_sbp) || NdSbpAllSameSplitParallel(out_nd_sbp))) { + return inter_group_sub_tsk_gph_builder_->Build( + ctx, sorted_in_tasks, sorted_out_tasks, sorted_ctrl_tasks, in_parallel_desc, + out_parallel_desc, lbi, logical_blob_desc, in_nd_sbp, out_nd_sbp, time_shape); } else { - return Error::BoxingNotSupportedError(); + return nd_nccl_send_recv_boxing_sub_tsk_gph_builder_->Build( + ctx, sorted_in_tasks, sorted_out_tasks, sorted_ctrl_tasks, in_parallel_desc, + out_parallel_desc, lbi, logical_blob_desc, in_nd_sbp, out_nd_sbp, time_shape); } } private: std::unique_ptr inter_group_sub_tsk_gph_builder_; + std::unique_ptr + nd_nccl_send_recv_boxing_sub_tsk_gph_builder_; }; class Same2DHierarchySubTskGphBuilder final : public HierarchicalSubTskGphBuilder { @@ -298,12 +396,10 @@ class Same2DHierarchySubTskGphBuilder final : public HierarchicalSubTskGphBuilde return intra_group_sub_tsk_gph_builder_->Build( ctx, sorted_in_tasks, sorted_out_tasks, sorted_ctrl_tasks, in_parallel_desc, out_parallel_desc, lbi, logical_blob_desc, in_nd_sbp, out_nd_sbp, time_shape); - } else if (in_nd_sbp.sbp_parallel(1) == out_nd_sbp.sbp_parallel(1)) { + } else { return dim0_nd_sbp_mismatched_sub_tsk_gph_builder_->Build( ctx, sorted_in_tasks, sorted_out_tasks, sorted_ctrl_tasks, in_parallel_desc, out_parallel_desc, lbi, logical_blob_desc, in_nd_sbp, 
out_nd_sbp, time_shape); - } else { - return Error::BoxingNotSupportedError(); } } else { return Error::BoxingNotSupportedError(); @@ -371,6 +467,8 @@ struct DispatchHierarchicalSubTskGphBuilder::Impl { std::unique_ptr same_2d_hierarchy_sub_tsk_gph_builder_; std::unique_ptr expand_to_same_2d_hierarchy_sub_tsk_gph_builder_; + std::unique_ptr + nd_nccl_send_recv_boxing_sub_tsk_gph_builder_; }; DispatchHierarchicalSubTskGphBuilder::Impl::Impl() { @@ -378,6 +476,7 @@ DispatchHierarchicalSubTskGphBuilder::Impl::Impl() { same_2d_hierarchy_sub_tsk_gph_builder_.reset(new Same2DHierarchySubTskGphBuilder()); expand_to_same_2d_hierarchy_sub_tsk_gph_builder_.reset( new ExpandToSame2DHierarchySubTskGphBuilder()); + nd_nccl_send_recv_boxing_sub_tsk_gph_builder_.reset(new NDNcclSendRecvBoxingSubTskGphBuilder()); } DispatchHierarchicalSubTskGphBuilder::DispatchHierarchicalSubTskGphBuilder() { @@ -402,6 +501,14 @@ Maybe DispatchHierarchicalSubTskGphBuilder::Build( &reduced_out_nd_sbp); const auto& in_hierarchy = reduced_in_parallel_desc.hierarchy(); const auto& out_hierarchy = reduced_out_parallel_desc.hierarchy(); + if ((in_hierarchy->NumAxes() > 2 || out_hierarchy->NumAxes() > 2) + && reduced_in_parallel_desc.device_type() == DeviceType::kCUDA + && reduced_out_parallel_desc.device_type() == DeviceType::kCUDA) { + return impl_->nd_nccl_send_recv_boxing_sub_tsk_gph_builder_->Build( + ctx, sorted_in_tasks, sorted_out_tasks, sorted_ctrl_tasks, reduced_in_parallel_desc, + reduced_out_parallel_desc, lbi, logical_blob_desc, reduced_in_nd_sbp, reduced_out_nd_sbp, + time_shape); + } if (in_hierarchy->NumAxes() <= 2 && out_hierarchy->NumAxes() <= 2) { if (in_hierarchy->NumAxes() == 1 && out_hierarchy->NumAxes() == 1) { return impl_->flat_sub_tsk_gph_builder_->Build( @@ -420,6 +527,12 @@ Maybe DispatchHierarchicalSubTskGphBuilder::Build( ctx, sorted_in_tasks, sorted_out_tasks, sorted_ctrl_tasks, reduced_in_parallel_desc, reduced_out_parallel_desc, lbi, logical_blob_desc, reduced_in_nd_sbp, reduced_out_nd_sbp, time_shape); + } else if (reduced_in_parallel_desc.device_type() == DeviceType::kCUDA + && reduced_out_parallel_desc.device_type() == DeviceType::kCUDA) { + return impl_->nd_nccl_send_recv_boxing_sub_tsk_gph_builder_->Build( + ctx, sorted_in_tasks, sorted_out_tasks, sorted_ctrl_tasks, reduced_in_parallel_desc, + reduced_out_parallel_desc, lbi, logical_blob_desc, reduced_in_nd_sbp, reduced_out_nd_sbp, + time_shape); } else { return Error::BoxingNotSupportedError(); } diff --git a/oneflow/core/graph/nccl_send_recv_boxing_task_node.cpp b/oneflow/core/graph/nccl_send_recv_boxing_task_node.cpp new file mode 100644 index 00000000000..e6ab2530c36 --- /dev/null +++ b/oneflow/core/graph/nccl_send_recv_boxing_task_node.cpp @@ -0,0 +1,96 @@ +/* +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/ +#include "oneflow/core/framework/to_string.h" +#include "oneflow/core/graph/nccl_send_recv_boxing_task_node.h" + +namespace oneflow { + +void NcclSendRecvBoxingTaskNode::Init(int64_t machine_id, int64_t thrd_id, const LogicalBlobId& lbi, + const Shape& logical_shape, const DataType& data_type, + const NdSbp& src_nd_sbp, const NdSbp& dst_nd_sbp, + const ParallelDesc& src_parallel_desc, + const ParallelDesc& dst_parallel_desc, + const int64_t parallel_id, const ParallelDesc& parallel_desc, + const bool has_input, const bool has_output, + const std::string& stream_name) { + set_machine_id(machine_id); + set_thrd_id(thrd_id); + set_lbi(lbi); + logical_shape_ = logical_shape; + src_nd_sbp_ = src_nd_sbp; + dst_nd_sbp_ = dst_nd_sbp; + src_parallel_conf_ = src_parallel_desc.parallel_conf(); + dst_parallel_conf_ = dst_parallel_desc.parallel_conf(); + parallel_conf_ = parallel_desc.parallel_conf(); + parallel_ctx_.set_parallel_id(parallel_id); + parallel_ctx_.set_parallel_num(parallel_desc.parallel_num()); + has_input_ = has_input; + has_output_ = has_output; + data_type_ = data_type; + stream_name_ = stream_name; +} + +void NcclSendRecvBoxingTaskNode::ProduceAllRegstsAndBindEdges() { + if (has_output_) { + std::shared_ptr out_regst = ProduceRegst("out", true, 1, 1); + this->ForEachOutDataEdge([&](TaskEdge* out_dege) { out_dege->AddRegst("out", out_regst); }); + } + ProduceRegst("tmp", true); +} + +void NcclSendRecvBoxingTaskNode::ConsumeAllRegsts() { + this->ForEachInDataEdge( + [&](TaskEdge* in_edge) { ConsumeRegst("in", SoleInDataEdge()->GetSoleRegst()); }); +} + +void NcclSendRecvBoxingTaskNode::BuildExecGphAndRegst() { + ExecNode* node = mut_exec_gph().NewNode(); + OperatorConf op_conf; + op_conf.set_name("System-Nccl-Send-Recv-Boxing-" + NewUniqueId()); + op_conf.set_device_tag(*CHECK_JUST(DeviceTag4DeviceType(this->device_type()))); + op_conf.set_stream_name_hint(stream_name_); + auto* nccl_send_recv_boxing_conf = op_conf.mutable_nccl_send_recv_boxing_conf(); + *nccl_send_recv_boxing_conf->mutable_lbi() = lbi(); + logical_shape_.ToProto(nccl_send_recv_boxing_conf->mutable_logical_shape()); + nccl_send_recv_boxing_conf->set_data_type(data_type_); + *nccl_send_recv_boxing_conf->mutable_src_nd_sbp() = src_nd_sbp_; + *nccl_send_recv_boxing_conf->mutable_dst_nd_sbp() = dst_nd_sbp_; + *nccl_send_recv_boxing_conf->mutable_parallel_conf() = parallel_conf_; + *nccl_send_recv_boxing_conf->mutable_src_parallel_conf() = src_parallel_conf_; + *nccl_send_recv_boxing_conf->mutable_dst_parallel_conf() = dst_parallel_conf_; + nccl_send_recv_boxing_conf->set_has_input(has_input_); + nccl_send_recv_boxing_conf->set_has_output(has_output_); + std::shared_ptr sole_op = CHECK_JUST(ConstructOp(op_conf)); + node->mut_op() = sole_op; + CHECK_JUST(sole_op->FillOpParallelDesc(parallel_conf_)); + if (has_input_) { node->BindBnWithRegst(sole_op->SoleIbn(), GetSoleConsumedRegst("in")); } + if (has_output_) { + std::shared_ptr out_regst = GetProducedRegst("out"); + out_regst->AddLbi(sole_op->BnInOp2Lbi(sole_op->SoleObn())); + node->BindBnWithRegst(sole_op->SoleObn(), out_regst); + } + node->AddBnToRegstAndBindIt(&Operator::tmp_bns, GetProducedRegst("tmp")); + node->InferBlobDescs(parallel_ctx()); +} + +void NcclSendRecvBoxingTaskNode::InferProducedDataRegstTimeShape() { + auto out_regst = GetProducedRegst("out"); + if (out_regst != nullptr) { out_regst->mut_data_regst_time_shape()->reset(new Shape({1, 1})); } + auto tmp_regst = GetProducedRegst("tmp"); + tmp_regst->mut_data_regst_time_shape()->reset(new Shape({1, 
1})); +} + +} // namespace oneflow diff --git a/oneflow/core/graph/nccl_send_recv_boxing_task_node.h b/oneflow/core/graph/nccl_send_recv_boxing_task_node.h new file mode 100644 index 00000000000..1fcc4482f0e --- /dev/null +++ b/oneflow/core/graph/nccl_send_recv_boxing_task_node.h @@ -0,0 +1,59 @@ +/* +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +#ifndef ONEFLOW_CORE_GRAPH_NCCL_SEND_RECV_BOXING_TASK_NODE_H_ +#define ONEFLOW_CORE_GRAPH_NCCL_SEND_RECV_BOXING_TASK_NODE_H_ + +#include "oneflow/core/graph/transport_task_node.h" + +namespace oneflow { + +class NcclSendRecvBoxingTaskNode : public TransportTaskNode { + public: + OF_DISALLOW_COPY_AND_MOVE(NcclSendRecvBoxingTaskNode); + NcclSendRecvBoxingTaskNode() = default; + ~NcclSendRecvBoxingTaskNode() override = default; + + void Init(int64_t machine_id, int64_t thrd_id, const LogicalBlobId& lbi, + const Shape& logical_shape, const DataType& data_type, const NdSbp& src_nd_sbp, + const NdSbp& dst_nd_sbp, const ParallelDesc& src_parallel_desc, + const ParallelDesc& dst_parallel_desc, const int64_t parallel_id, + const ParallelDesc& parallel_desc, const bool has_input, const bool has_output, + const std::string& stream_name); + TaskType GetTaskType() const override { return TaskType::kNcclSendRecvBoxing; } + const ParallelContext* parallel_ctx() const override { return ¶llel_ctx_; } + + private: + void BuildExecGphAndRegst() override; + void ProduceAllRegstsAndBindEdges() override; + void ConsumeAllRegsts() final; + void InferProducedDataRegstTimeShape() final; + + Shape logical_shape_; + DataType data_type_; + NdSbp src_nd_sbp_; + NdSbp dst_nd_sbp_; + ParallelConf src_parallel_conf_; + ParallelConf dst_parallel_conf_; + ParallelConf parallel_conf_; + ParallelContext parallel_ctx_; + bool has_input_; + bool has_output_; + std::string stream_name_; +}; + +} // namespace oneflow + +#endif // ONEFLOW_CORE_GRAPH_NCCL_SEND_RECV_BOXING_TASK_NODE_H_ diff --git a/oneflow/core/graph/straighten_nodes.cpp b/oneflow/core/graph/straighten_nodes.cpp index c6f27d73d15..88b9de6b9b5 100644 --- a/oneflow/core/graph/straighten_nodes.cpp +++ b/oneflow/core/graph/straighten_nodes.cpp @@ -104,6 +104,7 @@ bool IsTransferNode(TaskType task_type) { switch (task_type) { // We mark the number of occurrences in bert case TaskType::kCollectiveBoxingGeneric: // 76 + case TaskType::kNcclSendRecvBoxing: // ? 
case TaskType::kCopyHd: // 27 case TaskType::kSliceBoxing: // 16 case TaskType::kCopyCommNet: // 12 diff --git a/oneflow/core/graph/task_graph.cpp b/oneflow/core/graph/task_graph.cpp index 70a7cd34343..8b97e158090 100644 --- a/oneflow/core/graph/task_graph.cpp +++ b/oneflow/core/graph/task_graph.cpp @@ -727,6 +727,12 @@ DEFINE_BLD_SUB_TASK_GRAPH_METHOD(BldSubTskGphByBoxing) { const ParallelDesc& src_parallel_desc = src_op_node->parallel_desc(); const ParallelDesc& dst_parallel_desc = dst_op_node->parallel_desc(); const BlobDesc& blob_desc = src_op_node->LogicalBlobDesc4Lbi(lbi); + VLOG(3) << "src op: " << src_op_node->op().op_name() + << " dst op: " << dst_op_node->op().op_name() + << " src_parallel_conf: " << src_parallel_desc.parallel_conf().DebugString() + << " dst parallel conf: " << dst_parallel_desc.parallel_conf().DebugString() + << " src_nd_sbp " << src_nd_sbp.DebugString() << " dst nd_sbp " + << dst_nd_sbp.DebugString(); auto status = CHECK_JUST(hierarchical_sub_tsk_gph_builder_->Build( sub_tsk_gph_builder_ctx_.get(), in_nodes, &out_nodes, &sorted_ctrl_tasks, src_parallel_desc, dst_parallel_desc, lbi, blob_desc, src_nd_sbp, dst_nd_sbp, diff --git a/oneflow/core/job/eager_nccl_comm_manager.cpp b/oneflow/core/job/eager_nccl_comm_manager.cpp index 2fa0ab540f3..00ffc0bbb74 100644 --- a/oneflow/core/job/eager_nccl_comm_manager.cpp +++ b/oneflow/core/job/eager_nccl_comm_manager.cpp @@ -14,12 +14,14 @@ See the License for the specific language governing permissions and limitations under the License. */ #include +#include #include "oneflow/core/control/ctrl_client.h" #include "oneflow/core/control/global_process_ctx.h" #include "oneflow/core/job/eager_nccl_comm_manager.h" #include "oneflow/core/device/nccl_util.h" #include "oneflow/core/job/id_manager.h" #include "oneflow/core/job/parallel_desc.h" +#include "oneflow/core/operator/op_conf.pb.h" #include "oneflow/core/vm/vm_util.h" #ifdef WITH_CUDA @@ -78,8 +80,15 @@ void CreateNcclComm(ncclComm_t* comm, const int dev, const std::string& key, << ", key = {" << key << "}\n"; } -bool NeedUnifiedNcclCommInit(const std::string& op_type_name) { - return UserKernelUnifiedNcclCommInitRegistry::Instance().IsRegistered(op_type_name); +bool NeedUnifiedNcclCommInit(const OperatorConf& op_conf) { + if (op_conf.has_user_conf()) { + return UserKernelUnifiedNcclCommInitRegistry::Instance().IsRegistered( + op_conf.user_conf().op_type_name()); + } else { + // Please check the .h file for hard-coding of the name + return UserKernelUnifiedNcclCommInitRegistry::Instance().IsRegistered( + kSystemOpPrefix + std::to_string(op_conf.op_type_case())); + } } } // namespace @@ -169,9 +178,7 @@ void EagerNcclCommMgr::CreateCommFromPlan(const Plan& plan) { continue; } const auto& op_conf = op_attr->op_conf(); - if (!op_conf.has_user_conf()) { continue; } - if (!NeedUnifiedNcclCommInit(op_conf.user_conf().op_type_name())) { continue; } - + if (!NeedUnifiedNcclCommInit(op_conf)) { continue; } if (!op_attr->has_parallel_conf_signature()) { continue; } if (!op_attr->parallel_conf_signature().has_op_parallel_conf()) { continue; } diff --git a/oneflow/core/job/eager_nccl_comm_manager.h b/oneflow/core/job/eager_nccl_comm_manager.h index b57a2cd92fe..33b27e930a8 100644 --- a/oneflow/core/job/eager_nccl_comm_manager.h +++ b/oneflow/core/job/eager_nccl_comm_manager.h @@ -83,12 +83,19 @@ class UserKernelUnifiedNcclCommInitRegistry final { std::set reg_set_; }; +static const std::string kSystemOpPrefix = "sys_op_"; + } // namespace oneflow #define 
REGISTER_USER_KERNEL_UNIFIED_NCCL_COMM_INIT(op_type_name) \ static auto OF_PP_CAT(g_nccl_comm_reg_, __COUNTER__) = \ ::oneflow::UserKernelUnifiedNcclCommInitRegistry::Trigger(op_type_name) +#define REGISTER_SYSTEM_OP_KERNEL_UNIFIED_NCCL_COMM_INIT(op_type_case) \ + static auto OF_PP_CAT(g_nccl_comm_reg_, __COUNTER__) = \ + ::oneflow::UserKernelUnifiedNcclCommInitRegistry::Trigger(::oneflow::kSystemOpPrefix \ + + std::to_string(op_type_case)) + #endif // WITH_CUDA #endif // ONEFLOW_CORE_JOB_EAGER_NCCL_COMM_MANAGER_H_ diff --git a/oneflow/core/job/nd_sbp_util.cpp b/oneflow/core/job/nd_sbp_util.cpp index 4bbab195e01..c93974acc18 100644 --- a/oneflow/core/job/nd_sbp_util.cpp +++ b/oneflow/core/job/nd_sbp_util.cpp @@ -19,48 +19,6 @@ limitations under the License. #include "oneflow/core/common/nd_index_offset_helper.h" namespace oneflow { -namespace { -// Go through all the ranks while transfer between two nd sbps with no PartialSum under the same -// placement. -// NOTE: We need to make sure no partial sums in the sbps of the producer and consumer. -void DfsTraverseRanks4NdSbp( - int32_t depth, std::vector& in_parallel_ids, - const std::vector& out_parallel_ids, const Shape& parallel_hierarchy, - const NdIndexOffsetHelper& hierarchy_index_helper, - const NdSbp& in_nd_sbp, const std::function& visit) { - if (depth >= parallel_hierarchy.NumAxes()) { - visit(hierarchy_index_helper.NdIndexToOffset(in_parallel_ids.data(), - parallel_hierarchy.NumAxes())); - return; - } - if (in_nd_sbp.sbp_parallel(depth).has_broadcast_parallel()) { - // If Broadcast in the sbp of the producer, only visit those ranks with the same id as the - // current rank along the depth-dimension. - in_parallel_ids[depth] = out_parallel_ids[depth]; - DfsTraverseRanks4NdSbp(depth + 1, in_parallel_ids, out_parallel_ids, parallel_hierarchy, - hierarchy_index_helper, in_nd_sbp, visit); - } else { - // If Split or PartialSum, go through all the ranks along the depth-dimension. 
- for (int64_t i = 0; i < parallel_hierarchy.dim_vec().at(depth); i++) { - in_parallel_ids[depth] = i; - DfsTraverseRanks4NdSbp(depth + 1, in_parallel_ids, out_parallel_ids, parallel_hierarchy, - hierarchy_index_helper, in_nd_sbp, visit); - } - } -} - -void DfsTraverse4NdSbp(int64_t recv_id, const std::shared_ptr& parallel_hierarchy, - const NdSbp& in_nd_sbp, const std::function& visit) { - int32_t hierarchy_dimension = parallel_hierarchy->NumAxes(); - const NdIndexOffsetHelper hierarchy_index_helper( - parallel_hierarchy->dim_vec().data(), hierarchy_dimension); - std::vector in_parallel_ids(hierarchy_dimension); - std::vector out_parallel_ids(hierarchy_dimension); - hierarchy_index_helper.OffsetToNdIndex(recv_id, out_parallel_ids.data(), hierarchy_dimension); - DfsTraverseRanks4NdSbp(0, in_parallel_ids, out_parallel_ids, *parallel_hierarchy, - hierarchy_index_helper, in_nd_sbp, visit); -} -} // namespace std::vector GetTensorSliceView(const int64_t parallel_num, const SbpParallel& sbp_parallel, @@ -199,45 +157,4 @@ bool NdSbpIsAllSplit(const NdSbp& nd_sbp, int64_t axis) { return true; } -void GetRankSendRecvIntersection(int64_t parallel_id, - const std::shared_ptr& parallel_hierarchy, - const NdSbp& src_nd_sbp, const NdSbp& dst_nd_sbp, - const Shape& logical_shape, - std::vector* send_intersections, - std::vector* recv_intersections) { - CHECK(parallel_hierarchy != nullptr); - const int64_t parallel_num = parallel_hierarchy->elem_cnt(); - CHECK_LT(parallel_id, parallel_num); - - const std::vector& in_slices = - GetTensorSliceView(*parallel_hierarchy, src_nd_sbp, logical_shape); - const std::vector& out_slices = - GetTensorSliceView(*parallel_hierarchy, dst_nd_sbp, logical_shape); - - // cur rank recv from - recv_intersections->resize(parallel_num); - const TensorSliceView& cur_rank_out_slice = out_slices.at(parallel_id); - const auto& add_to_recv_intersections = [&](int32_t send_id) { - const TensorSliceView& in_slice = in_slices.at(send_id); - const TensorSliceView& intersection = cur_rank_out_slice.Intersect(in_slice); - if (intersection.IsEmpty()) { return; } - recv_intersections->at(send_id) = intersection; - }; - DfsTraverse4NdSbp(parallel_id, parallel_hierarchy, src_nd_sbp, add_to_recv_intersections); - - // cur rank send to - send_intersections->resize(parallel_num); - const TensorSliceView& cur_rank_in_slice = in_slices.at(parallel_id); - for (int64_t recv_i = 0; recv_i < parallel_num; ++recv_i) { - const auto& add_to_send_intersections = [&](int32_t send_id) { - if (send_id != parallel_id) { return; } - const TensorSliceView& out_slice = out_slices.at(recv_i); - const TensorSliceView& intersection = out_slice.Intersect(cur_rank_in_slice); - if (intersection.IsEmpty()) { return; } - send_intersections->at(recv_i) = intersection; - }; - DfsTraverse4NdSbp(recv_i, parallel_hierarchy, src_nd_sbp, add_to_send_intersections); - } -} - } // namespace oneflow diff --git a/oneflow/core/job/nd_sbp_util.h b/oneflow/core/job/nd_sbp_util.h index 7eac44a52fc..be8b72c7746 100644 --- a/oneflow/core/job/nd_sbp_util.h +++ b/oneflow/core/job/nd_sbp_util.h @@ -39,12 +39,6 @@ bool NdSbpIsAllSplit(const NdSbp& nd_sbp, int64_t axis); bool NdSbpHasPartialParallel(const NdSbp& nd_sbp); bool NdSbpHasBroadcastParallel(const NdSbp& nd_sbp); -void GetRankSendRecvIntersection(int64_t parallel_id, - const std::shared_ptr& parallel_hierarchy, - const NdSbp& src_nd_sbp, const NdSbp& dst_nd_sbp, - const Shape& logical_shape, - std::vector* send_intersections, - std::vector* recv_intersections); } // namespace 
oneflow #endif // ONEFLOW_CORE_JOB_SBP_PARALLEL_H_ diff --git a/oneflow/core/job/task.proto b/oneflow/core/job/task.proto index e4df1c4a0db..2fb82cc1ab9 100644 --- a/oneflow/core/job/task.proto +++ b/oneflow/core/job/task.proto @@ -38,6 +38,7 @@ enum TaskType { kSspVariableProxy = 63; kBoxingZeros = 64; kCriticalSectionWaitTick = 65; + kNcclSendRecvBoxing = 66; }; message RegstDescIdSet { diff --git a/oneflow/core/job_rewriter/boxing_with_middle_nodes.cpp b/oneflow/core/job_rewriter/boxing_with_middle_nodes.cpp index 91ed0f77f87..79fb1fb429d 100644 --- a/oneflow/core/job_rewriter/boxing_with_middle_nodes.cpp +++ b/oneflow/core/job_rewriter/boxing_with_middle_nodes.cpp @@ -14,8 +14,10 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/job_rewriter/boxing_with_middle_nodes.h" +#include "oneflow/core/common/just.h" #include "oneflow/core/common/util.h" #include "oneflow/core/framework/nd_sbp.h" +#include "oneflow/core/framework/sbp_infer_util.h" #include "oneflow/core/job/job_desc.h" #include "oneflow/core/common/protobuf.h" #include "oneflow/core/auto_parallel/boxing_collector.h" @@ -30,10 +32,6 @@ Maybe BoxingWithMiddleNodes(const OpGraph& op_graph, JobBuilder* job_build } // Initialize boxing collector BoxingCollector boxing_collector; - // We assemble the boxing table from S(0) to S(5). - // Those splitting in higher axes are considered in the customized boxing. - constexpr int32_t kRegularMaxSplitAxes = 6; - JUST(boxing_collector.Init(kRegularMaxSplitAxes)); std::vector middle_sbps; HashMap op_node2op_conf; // Fill other unsupported combinations diff --git a/oneflow/core/kernel/nccl_send_recv_boxing_kernel.cpp b/oneflow/core/kernel/nccl_send_recv_boxing_kernel.cpp new file mode 100644 index 00000000000..6bb52bedbd6 --- /dev/null +++ b/oneflow/core/kernel/nccl_send_recv_boxing_kernel.cpp @@ -0,0 +1,256 @@ +/* +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/
+#include "oneflow/core/kernel/kernel.h"
+#include "oneflow/core/device/nccl_util.h"
+#include "oneflow/core/job/eager_nccl_comm_manager.h"
+#include "oneflow/core/register/tensor_slice_copier.h"
+#include "oneflow/core/ep/include/primitive/memset.h"
+#include "oneflow/core/ep/include/primitive/add.h"
+#include "oneflow/core/operator/nccl_send_recv_boxing_op_util.h"
+
+#if defined(WITH_CUDA) && NCCL_VERSION_CODE > 2700
+
+namespace oneflow {
+
+class NcclSendRecvBoxingKernel final : public Kernel {
+ public:
+  OF_DISALLOW_COPY_AND_MOVE(NcclSendRecvBoxingKernel);
+  NcclSendRecvBoxingKernel() = default;
+  ~NcclSendRecvBoxingKernel() override = default;
+
+  const std::vector<std::shared_ptr<TensorSliceCopier>>& in_tensor_slice_copier_vec() const {
+    return in_tensor_slice_copier_vec_;
+  }
+  const std::vector<std::shared_ptr<TensorSliceCopier>>& out_tensor_slice_copier_vec() const {
+    return out_tensor_slice_copier_vec_;
+  }
+  const std::vector<int64_t>& send_elem_cnts() const { return send_elem_cnts_; }
+  const std::vector<int64_t>& recv_elem_cnts() const { return recv_elem_cnts_; }
+  const bool has_input() const { return has_input_; }
+  const bool has_output() const { return has_output_; }
+  ncclComm_t comm() const { return GetOrCreate().comm; }
+
+ private:
+  struct Comm {
+    Comm(ncclComm_t comm) : comm(comm) {}
+    ncclComm_t comm;
+  };
+
+  void Init() const {
+    ParallelDesc parallel_desc(parallel_conf_);
+    std::set<std::pair<int64_t, int64_t>> device_set;
+    for (int64_t parallel_id = 0; parallel_id < parallel_desc.parallel_num(); ++parallel_id) {
+      int64_t machine_id = CHECK_JUST(parallel_desc.MachineId4ParallelId(parallel_id));
+      int64_t device_id = CHECK_JUST(parallel_desc.DeviceId4ParallelId(parallel_id));
+      device_set.emplace(std::make_pair(machine_id, device_id));
+    }
+    EagerNcclCommMgr* comm_mgr = CHECK_NOTNULL(Singleton<EagerNcclCommMgr>::Get());
+    ncclComm_t comm = comm_mgr->GetCommForDeviceAndStreamName(device_set, stream_name_);
+    comm_.reset(new Comm(comm));
+  }
+
+  const Comm& GetOrCreate() const {
+    if (!comm_) { Init(); }
+    return *comm_;
+  }
+
+  void VirtualKernelInit(KernelContext* ctx) override;
+  void ForwardDataContent(KernelContext* ctx) const override;
+
+  std::string stream_name_;
+  ParallelConf parallel_conf_;
+  mutable std::unique_ptr<Comm> comm_;
+  bool src_nd_sbp_no_partial_parallel_;
+  std::vector<std::shared_ptr<TensorSliceCopier>> in_tensor_slice_copier_vec_;
+  std::vector<std::shared_ptr<TensorSliceCopier>> out_tensor_slice_copier_vec_;
+  std::vector<int64_t> send_elem_cnts_;
+  std::vector<int64_t> recv_elem_cnts_;
+  bool has_input_;
+  bool has_output_;
+};
+
+void NcclSendRecvBoxingKernel::ForwardDataContent(KernelContext* ctx) const {
+  Blob* buf = ctx->BnInOp2Blob("buf");
+  ncclComm_t comm = this->comm();
+  cudaStream_t cuda_stream = ctx->stream()->As<ep::CudaStream>()->cuda_stream();
+  const std::vector<int64_t>& send_elem_cnts = this->send_elem_cnts();
+  const std::vector<int64_t>& recv_elem_cnts = this->recv_elem_cnts();
+  const int64_t parallel_num = this->kernel_conf().parallel_ctx().parallel_num();
+  const DataType data_type = buf->data_type();
+  std::vector<void*> send_in_ptr;
+  std::vector<void*> recv_out_ptr;
+  char* buf_ptr = buf->mut_dptr<char>();
+  int64_t offset = 0;
+  if (this->has_input()) {
+    for (int64_t i = 0; i < parallel_num; ++i) {
+      void* send_ptr = reinterpret_cast<void*>(buf_ptr + offset);
+      send_in_ptr.push_back(send_ptr);
+      offset += send_elem_cnts.at(i) * GetSizeOfDataType(data_type);
+    }
+  }
+  if (this->has_output()) {
+    for (int64_t i = 0; i < parallel_num; ++i) {
+      void* recv_ptr = reinterpret_cast<void*>(buf_ptr + offset);
+      recv_out_ptr.push_back(recv_ptr);
+      offset += recv_elem_cnts.at(i) * GetSizeOfDataType(data_type);
+    }
+  }
+  if (this->has_input()) {
+    const Blob* in = ctx->BnInOp2Blob("in");
+    const std::vector<std::shared_ptr<TensorSliceCopier>>& in_tensor_slice_copier_vec =
+        this->in_tensor_slice_copier_vec();
+    for (int64_t i = 0; i < parallel_num; ++i) {
+      if (in_tensor_slice_copier_vec.at(i)) {
+        in_tensor_slice_copier_vec.at(i)->Copy(ctx->stream(), send_in_ptr.at(i), in->dptr());
+      }
+    }
+  }
+  OF_NCCL_CHECK(ncclGroupStart());
+  for (int64_t i = 0; i < parallel_num; ++i) {
+    if (this->has_input() && send_elem_cnts.at(i) != 0) {
+      OF_NCCL_CHECK(ncclSend(send_in_ptr.at(i), send_elem_cnts.at(i), GetNcclDataType(data_type),
+                             i, comm, cuda_stream));
+    }
+    if (this->has_output() && recv_elem_cnts.at(i) != 0) {
+      OF_NCCL_CHECK(ncclRecv(recv_out_ptr.at(i), recv_elem_cnts.at(i), GetNcclDataType(data_type),
+                             i, comm, cuda_stream));
+    }
+  }
+  OF_NCCL_CHECK(ncclGroupEnd());
+  if (!this->has_output()) { return; }
+  Blob* out = ctx->BnInOp2Blob("out");
+  const std::vector<std::shared_ptr<TensorSliceCopier>>& out_tensor_slice_copier_vec =
+      this->out_tensor_slice_copier_vec();
+
+  if (src_nd_sbp_no_partial_parallel_) {
+    for (int64_t i = 0; i < parallel_num; ++i) {
+      if (out_tensor_slice_copier_vec.at(i)) {
+        out_tensor_slice_copier_vec.at(i)->Copy(ctx->stream(), out->mut_dptr(), recv_out_ptr.at(i));
+      }
+    }
+  } else {
+    std::unique_ptr<ep::primitive::Add> primitive =
+        ep::primitive::NewPrimitive<ep::primitive::AddFactory>(ctx->stream()->device_type(),
+                                                               out->data_type());
+    CHECK(primitive);
+    std::unique_ptr<ep::primitive::Memset> memset_primitive =
+        ep::primitive::NewPrimitive<ep::primitive::MemsetFactory>(ctx->stream()->device_type());
+    CHECK(memset_primitive);
+    bool is_first_slice = true;
+    for (int64_t i = 0; i < parallel_num; ++i) {
+      if (out_tensor_slice_copier_vec.at(i)) {
+        if (is_first_slice) {
+          is_first_slice = false;
+          if (recv_elem_cnts.at(i) != out->shape().elem_cnt()) {
+            // if the received slice does not cover the whole output, memset out first
+            memset_primitive->Launch(ctx->stream(), out->mut_dptr(), 0,
+                                     out->shape().elem_cnt() * GetSizeOfDataType(data_type));
+          }
+          out_tensor_slice_copier_vec.at(i)->Copy(ctx->stream(), out->mut_dptr(),
+                                                  recv_out_ptr.at(i));
+        } else {
+          if (recv_elem_cnts.at(i) == out->shape().elem_cnt()) {
+            primitive->Launch(ctx->stream(), out->dptr(), recv_out_ptr.at(i), out->mut_dptr(),
+                              out->shape().elem_cnt());
+          } else {
+            void* out_buf = reinterpret_cast<void*>(buf_ptr + offset);
+            memset_primitive->Launch(ctx->stream(), out_buf, 0,
+                                     out->shape().elem_cnt() * GetSizeOfDataType(data_type));
+            out_tensor_slice_copier_vec.at(i)->Copy(ctx->stream(), out_buf, recv_out_ptr.at(i));
+            primitive->Launch(ctx->stream(), out->dptr(), out_buf, out->mut_dptr(),
+                              out->shape().elem_cnt());
+          }
+        }
+      }
+    }
+  }
+}
+
+void NcclSendRecvBoxingKernel::VirtualKernelInit(KernelContext* ctx) {
+  const NcclSendRecvBoxingOpConf& conf = this->op_conf().nccl_send_recv_boxing_conf();
+  if (this->op_conf().has_stream_name_hint()) {
+    stream_name_ = this->op_conf().stream_name_hint();
+  } else {
+    stream_name_ = EagerNcclCommMgr::kDefaultStreamName;
+  }
+  parallel_conf_ = conf.parallel_conf();
+  const int64_t parallel_id = this->kernel_conf().parallel_ctx().parallel_id();
+  ParallelDesc parallel_desc(parallel_conf_);
+  ParallelDesc src_parallel_desc(conf.src_parallel_conf());
+  ParallelDesc dst_parallel_desc(conf.dst_parallel_conf());
+  const NdSbp& src_nd_sbp = conf.src_nd_sbp();
+  const NdSbp& dst_nd_sbp = conf.dst_nd_sbp();
+  has_input_ = conf.has_input();
+  has_output_ = conf.has_output();
+  src_nd_sbp_no_partial_parallel_ = !NdSbpHasPartialParallel(src_nd_sbp);
+  const DataType data_type = this->kernel_conf().data_type();
+  const DeviceType device_type = parallel_desc.device_type();
+  const Shape& logical_shape = Shape(conf.logical_shape());
+  const int64_t parallel_num = parallel_desc.parallel_num();
+
+  std::vector<TensorSliceView> src_send_intersections;
+  std::vector<TensorSliceView> dst_recv_intersections;
+  GetRankSendRecvIntersection(parallel_id, parallel_desc, src_parallel_desc, dst_parallel_desc,
+                              src_nd_sbp, dst_nd_sbp, logical_shape, &src_send_intersections,
+                              &dst_recv_intersections);
+  // if parallel_id exists in src parallel desc, has send
+  int64_t src_parallel_id = GetMappedParallelId(parallel_id, parallel_desc, src_parallel_desc);
+  if (src_parallel_id != -1) {
+    CHECK_EQ(src_send_intersections.size(), parallel_num);
+    send_elem_cnts_.resize(parallel_num);
+    in_tensor_slice_copier_vec_.resize(parallel_num);
+    const TensorSliceView& cur_rank_in_slice = GetTensorSliceView4ParallelId(
+        *src_parallel_desc.hierarchy(), src_nd_sbp, logical_shape, src_parallel_id);
+    for (int64_t i = 0; i < parallel_num; ++i) {
+      const TensorSliceView& intersection = src_send_intersections.at(i);
+      if (!intersection.IsEmpty()) {
+        send_elem_cnts_.at(i) = intersection.shape().elem_cnt();
+        in_tensor_slice_copier_vec_.at(i).reset(
+            new TensorSliceCopier(intersection, cur_rank_in_slice, data_type, device_type));
+      }
+    }
+  } else {
+    CHECK_EQ(src_send_intersections.size(), 0);
+  }
+
+  // if parallel_id exists in dst parallel desc, has recv
+  int64_t dst_parallel_id = GetMappedParallelId(parallel_id, parallel_desc, dst_parallel_desc);
+  if (dst_parallel_id != -1) {
+    CHECK_EQ(dst_recv_intersections.size(), parallel_num);
+    recv_elem_cnts_.resize(parallel_num);
+    out_tensor_slice_copier_vec_.resize(parallel_num);
+    const TensorSliceView& cur_rank_out_slice = GetTensorSliceView4ParallelId(
+        *dst_parallel_desc.hierarchy(), dst_nd_sbp, logical_shape, dst_parallel_id);
+    for (int64_t i = 0; i < parallel_num; ++i) {
+      const TensorSliceView& intersection = dst_recv_intersections.at(i);
+      if (!intersection.IsEmpty()) {
+        recv_elem_cnts_.at(i) = intersection.shape().elem_cnt();
+        out_tensor_slice_copier_vec_.at(i).reset(
+            new TensorSliceCopier(cur_rank_out_slice, intersection, data_type, device_type));
+      }
+    }
+  } else {
+    CHECK_EQ(dst_recv_intersections.size(), 0);
+  }
+}
+
+REGISTER_KERNEL(OperatorConf::kNcclSendRecvBoxingConf, NcclSendRecvBoxingKernel);
+
+REGISTER_SYSTEM_OP_KERNEL_UNIFIED_NCCL_COMM_INIT(OperatorConf::kNcclSendRecvBoxingConf);
+
+}  // namespace oneflow
+
+#endif  // WITH_CUDA && NCCL_VERSION_CODE > 2700
diff --git a/oneflow/core/lazy/actor/naive_actor.cpp b/oneflow/core/lazy/actor/naive_actor.cpp
index ed1e52166ad..e691e77a424 100644
--- a/oneflow/core/lazy/actor/naive_actor.cpp
+++ b/oneflow/core/lazy/actor/naive_actor.cpp
@@ -34,6 +34,7 @@ REGISTER_ACTOR(TaskType::kSliceBoxing, NaiveActor);
 REGISTER_ACTOR(TaskType::kBoxingIdentity, NaiveActor);
 REGISTER_ACTOR(TaskType::kCollectiveBoxingPack, NaiveActor);
 REGISTER_ACTOR(TaskType::kCollectiveBoxingUnpack, NaiveActor);
+REGISTER_ACTOR(TaskType::kNcclSendRecvBoxing, NaiveActor);
 REGISTER_ACTOR(TaskType::kDecodeH2D, NaiveActor);
 REGISTER_ACTOR(TaskType::kCriticalSectionWaitTick, NaiveActor);
 REGISTER_ACTOR(TaskType::kCopyHd, NaiveActor);
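A note on the buffer layout used by NcclSendRecvBoxingKernel::ForwardDataContent above: the single flat "buf" blob is carved into per-rank send staging regions first, then per-rank recv staging regions, and whatever follows is scratch for the memset/add path. The following standalone sketch mirrors that offset arithmetic; BufLayout, PartitionBuf, and sizeof_dtype are illustrative names, not part of the patch.

#include <cstdint>
#include <vector>

// Minimal sketch of the "buf" partitioning: per-rank send regions first,
// then per-rank recv regions. Empty intersections contribute a count of 0.
struct BufLayout {
  std::vector<char*> send_ptr;  // send_ptr[i]: staging region for data sent to rank i
  std::vector<char*> recv_ptr;  // recv_ptr[i]: staging region for data received from rank i
  int64_t used_bytes = 0;       // bytes consumed; the remainder of buf is add/memset scratch
};

BufLayout PartitionBuf(char* buf, const std::vector<int64_t>& send_elem_cnts,
                       const std::vector<int64_t>& recv_elem_cnts, int64_t sizeof_dtype) {
  BufLayout layout;
  int64_t offset = 0;
  for (int64_t cnt : send_elem_cnts) {
    layout.send_ptr.push_back(buf + offset);
    offset += cnt * sizeof_dtype;
  }
  for (int64_t cnt : recv_elem_cnts) {
    layout.recv_ptr.push_back(buf + offset);
    offset += cnt * sizeof_dtype;
  }
  layout.used_bytes = offset;
  return layout;
}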
diff --git a/oneflow/core/operator/nccl_send_recv_boxing_op.cpp b/oneflow/core/operator/nccl_send_recv_boxing_op.cpp
new file mode 100644
index 00000000000..d0d3417c413
--- /dev/null
+++ b/oneflow/core/operator/nccl_send_recv_boxing_op.cpp
@@ -0,0 +1,142 @@
+/*
+Copyright 2020 The OneFlow Authors. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+#include "oneflow/core/common/container_util.h"
+#include "oneflow/core/operator/operator.h"
+#include "oneflow/core/common/protobuf.h"
+#include "oneflow/core/operator/nccl_send_recv_boxing_op_util.h"
+
+namespace oneflow {
+
+class NcclSendRecvBoxingOp : public Operator {
+ public:
+  OF_DISALLOW_COPY_AND_MOVE(NcclSendRecvBoxingOp);
+  NcclSendRecvBoxingOp() = default;
+  ~NcclSendRecvBoxingOp() override = default;
+
+  Maybe<void> InitFromOpConf() override;
+  Maybe<void> InferInternalBlobDescs(
+      const std::function<BlobDesc*(const std::string&)>& GetBlobDesc4BnInOp,
+      const ParallelContext* parallel_ctx, const JobDesc* job_desc) const override;
+  Maybe<void> InferLogicalOutBlobDescs(
+      const std::function<BlobDesc*(const std::string&)>& BlobDesc4BnInOp,
+      const ParallelDesc& parallel_desc) const override {
+    UNIMPLEMENTED_THEN_RETURN();
+  }
+  Maybe<void> InferOutBlobDescs(
+      const std::function<BlobDesc*(const std::string&)>& GetBlobDesc4BnInOp,
+      const ParallelContext* parallel_ctx) const override;
+
+ private:
+  LogicalBlobId lbi4ibn(const std::string& input_bn) const override;
+  LogicalBlobId lbi4obn(const std::string& output_bn) const override;
+};
+
+Maybe<void> NcclSendRecvBoxingOp::InitFromOpConf() {
+  const NcclSendRecvBoxingOpConf& conf = this->op_conf().nccl_send_recv_boxing_conf();
+  if (conf.has_input()) { EnrollInputBn("in", false); }
+  if (conf.has_output()) { EnrollOutputBn("out", false); }
+  EnrollTmpBn("buf");
+  return Maybe<void>::Ok();
+}
+
+Maybe<void> NcclSendRecvBoxingOp::InferInternalBlobDescs(
+    const std::function<BlobDesc*(const std::string&)>& GetBlobDesc4BnInOp,
+    const ParallelContext* parallel_ctx, const JobDesc* job_desc) const {
+  BlobDesc* buf = GetBlobDesc4BnInOp("buf");
+  const NcclSendRecvBoxingOpConf& conf = this->op_conf().nccl_send_recv_boxing_conf();
+  const NdSbp& src_nd_sbp = conf.src_nd_sbp();
+  const NdSbp& dst_nd_sbp = conf.dst_nd_sbp();
+  ParallelDesc parallel_desc(conf.parallel_conf());
+  ParallelDesc in_parallel_desc(conf.src_parallel_conf());
+  ParallelDesc out_parallel_desc(conf.dst_parallel_conf());
+  const int64_t parallel_num = parallel_desc.parallel_num();
+  const int64_t parallel_id = parallel_ctx->parallel_id();
+  const Shape& logical_shape = Shape(conf.logical_shape());
+  std::vector<TensorSliceView> src_send_intersections;
+  std::vector<TensorSliceView> dst_recv_intersections;
+  GetRankSendRecvIntersection(parallel_id, parallel_desc, in_parallel_desc, out_parallel_desc,
+                              src_nd_sbp, dst_nd_sbp, logical_shape, &src_send_intersections,
+                              &dst_recv_intersections);
+  int64_t buf_count = 0;
+  if (conf.has_input()) {
+    const BlobDesc* in = GetBlobDesc4BnInOp("in");
+    buf->set_data_type(in->data_type());
+    CHECK_EQ(src_send_intersections.size(), parallel_num);
+    for (int64_t i = 0; i < parallel_num; ++i) {
+      const TensorSliceView& intersection = JUST(VectorAt(src_send_intersections, i));
+      if (!intersection.IsEmpty()) { buf_count += intersection.shape().elem_cnt(); }
+    }
+  }
+  if (conf.has_output()) {
+    const BlobDesc* out = GetBlobDesc4BnInOp("out");
+    buf->set_data_type(out->data_type());
+    for (int64_t i = 0; i < parallel_num; ++i) {
+      const TensorSliceView& intersection = JUST(VectorAt(dst_recv_intersections, i));
+      if (!intersection.IsEmpty()) { buf_count += intersection.shape().elem_cnt(); }
+    }
+    if (NdSbpHasPartialParallel(src_nd_sbp)) {
+      // Note: when src_nd_sbp has partial_sum, an extra out-sized buffer is needed to copy
+      // each received piece into before adding it to out.
+      buf_count += out->shape().elem_cnt();
+    }
+  }
+  buf->mut_shape() = Shape({buf_count});
+  return Maybe<void>::Ok();
+}
+
+LogicalBlobId NcclSendRecvBoxingOp::lbi4ibn(const std::string& input_bn) const {
+  return this->op_conf().nccl_send_recv_boxing_conf().lbi();
+}
+
+LogicalBlobId NcclSendRecvBoxingOp::lbi4obn(const std::string& output_bn) const {
+  return this->op_conf().nccl_send_recv_boxing_conf().lbi();
+}
+
+Maybe<void> NcclSendRecvBoxingOp::InferOutBlobDescs(
+    const std::function<BlobDesc*(const std::string&)>& GetBlobDesc4BnInOp,
+    const ParallelContext* parallel_ctx) const {
+  const NcclSendRecvBoxingOpConf& conf = this->op_conf().nccl_send_recv_boxing_conf();
+  const Shape& logical_shape = Shape(conf.logical_shape());
+  const ParallelDesc& parallel_desc = ParallelDesc(conf.parallel_conf());
+  const int64_t machine_id = JUST(parallel_desc.MachineId4ParallelId(parallel_ctx->parallel_id()));
+  const int64_t device_index = JUST(parallel_desc.DeviceId4ParallelId(parallel_ctx->parallel_id()));
+  if (conf.has_input()) {
+    const BlobDesc* in_blob_desc = GetBlobDesc4BnInOp("in");
+    const NdSbp& src_nd_sbp = conf.src_nd_sbp();
+    const ParallelDesc& src_parallel_desc = ParallelDesc(conf.src_parallel_conf());
+    int64_t src_parallel_id =
+        JUST(src_parallel_desc.ParallelId4MachineDeviceId(machine_id, device_index));
+    std::shared_ptr<Shape> in_shape =
+        JUST(GetPhysicalShape(logical_shape, src_nd_sbp, src_parallel_desc, src_parallel_id));
+    CHECK_EQ_OR_RETURN(*in_shape, in_blob_desc->shape())
+        << "Non-matching shape of blobs for pieces of nccl send recv";
+  }
+  if (conf.has_output()) {
+    BlobDesc* out_blob_desc = GetBlobDesc4BnInOp("out");
+    const NdSbp& dst_nd_sbp = conf.dst_nd_sbp();
+    const ParallelDesc& dst_parallel_desc = ParallelDesc(conf.dst_parallel_conf());
+    int64_t dst_parallel_id =
+        JUST(dst_parallel_desc.ParallelId4MachineDeviceId(machine_id, device_index));
+    std::shared_ptr<Shape> out_shape =
+        JUST(GetPhysicalShape(logical_shape, dst_nd_sbp, dst_parallel_desc, dst_parallel_id));
+    out_blob_desc->mut_shape() = *out_shape;
+    out_blob_desc->set_data_type(conf.data_type());
+  }
+  return Maybe<void>::Ok();
+}
+
+REGISTER_OP(OperatorConf::kNcclSendRecvBoxingConf, NcclSendRecvBoxingOp);
+
+}  // namespace oneflow
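The "buf" element count inferred above is simply the total of all non-empty send intersections, plus all non-empty recv intersections, plus one extra output-sized region when the producer nd-SBP carries a partial sum (received pieces then overlap and must be accumulated). A hedged sketch of the same arithmetic with plain element counts, where a zero count stands in for an empty TensorSliceView and InferBufElemCnt is a hypothetical helper:

#include <cstdint>
#include <vector>

// Sketch of the buf_count computation in InferInternalBlobDescs above.
int64_t InferBufElemCnt(const std::vector<int64_t>& send_elem_cnts,
                        const std::vector<int64_t>& recv_elem_cnts,
                        bool src_has_partial_sum, int64_t out_elem_cnt) {
  int64_t buf_count = 0;
  for (int64_t cnt : send_elem_cnts) { buf_count += cnt; }  // staging for sends
  for (int64_t cnt : recv_elem_cnts) { buf_count += cnt; }  // staging for recvs
  // A partial-sum producer means received pieces overlap and must be
  // accumulated, so one extra out-sized scratch region is reserved.
  if (src_has_partial_sum) { buf_count += out_elem_cnt; }
  return buf_count;
}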
diff --git a/oneflow/core/operator/nccl_send_recv_boxing_op_util.cpp b/oneflow/core/operator/nccl_send_recv_boxing_op_util.cpp
new file mode 100644
index 00000000000..a0be3320256
--- /dev/null
+++ b/oneflow/core/operator/nccl_send_recv_boxing_op_util.cpp
@@ -0,0 +1,170 @@
+/*
+Copyright 2020 The OneFlow Authors. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+#include "oneflow/core/common/nd_index_offset_helper.h"
+#include "oneflow/core/operator/nccl_send_recv_boxing_op_util.h"
+
+namespace oneflow {
+
+namespace {
+// Go through all the ranks while transferring between two nd SBPs with no PartialSum under the
+// same placement.
+// NOTE: We need to make sure there are no partial sums in the SBPs of the producer and consumer.
+void DfsTraverseRanks4NdSbp(
+    int32_t depth, std::vector<int64_t>& in_parallel_ids,
+    const std::vector<int64_t>& out_parallel_ids, const Shape& in_parallel_hierarchy,
+    const NdIndexOffsetHelper<int64_t, SHAPE_MAX_AXIS_SIZE>& in_hierarchy_index_helper,
+    const NdSbp& in_nd_sbp, const std::function<void(int32_t)>& visit) {
+  if (depth >= in_parallel_hierarchy.NumAxes()) {
+    visit(in_hierarchy_index_helper.NdIndexToOffset(in_parallel_ids.data(),
+                                                    in_parallel_hierarchy.NumAxes()));
+    return;
+  }
+  if (in_nd_sbp.sbp_parallel(depth).has_broadcast_parallel()) {
+    // If Broadcast in the sbp of the producer, only visit those ranks with the same id as the
+    // current rank along the depth-dimension.
+    in_parallel_ids[depth] = out_parallel_ids[depth];
+    DfsTraverseRanks4NdSbp(depth + 1, in_parallel_ids, out_parallel_ids, in_parallel_hierarchy,
+                           in_hierarchy_index_helper, in_nd_sbp, visit);
+  } else {
+    // If Split or PartialSum, go through all the ranks along the depth-dimension.
+    for (int64_t i = 0; i < in_parallel_hierarchy.dim_vec().at(depth); i++) {
+      in_parallel_ids[depth] = i;
+      DfsTraverseRanks4NdSbp(depth + 1, in_parallel_ids, out_parallel_ids, in_parallel_hierarchy,
+                             in_hierarchy_index_helper, in_nd_sbp, visit);
+    }
+  }
+}
+
+bool NdSbpNoPartialParallel(const NdSbp& nd_sbp) {
+  CHECK_GT(nd_sbp.sbp_parallel_size(), 0);
+  FOR_RANGE(int64_t, i, 0, nd_sbp.sbp_parallel_size()) {
+    if (nd_sbp.sbp_parallel(i).has_partial_sum_parallel()) { return false; }
+  }
+  return true;
+}
+
+}  // namespace
+
+int64_t GetMappedParallelId(const int64_t from_parallel_id, const ParallelDesc& from_parallel_desc,
+                            const ParallelDesc& to_parallel_desc) {
+  const int64_t machine_id = CHECK_JUST(from_parallel_desc.MachineId4ParallelId(from_parallel_id));
+  const int64_t device_index = CHECK_JUST(from_parallel_desc.DeviceId4ParallelId(from_parallel_id));
+  if (to_parallel_desc.Containing(machine_id, device_index)) {
+    return CHECK_JUST(to_parallel_desc.ParallelId4MachineDeviceId(machine_id, device_index));
+  } else {
+    return -1;
+  }
+}
+
+void GetRankSendRecvIntersection(int64_t parallel_id, const ParallelDesc& parallel_desc,
+                                 const ParallelDesc& in_parallel_desc,
+                                 const ParallelDesc& out_parallel_desc, const NdSbp& in_nd_sbp,
+                                 const NdSbp& out_nd_sbp, const Shape& logical_shape,
+                                 std::vector<TensorSliceView>* send_intersections,
+                                 std::vector<TensorSliceView>* recv_intersections) {
+  const int64_t parallel_num = parallel_desc.parallel_num();
+  CHECK_LT(parallel_id, parallel_num);
+
+  const std::vector<TensorSliceView>& in_slices =
+      GetTensorSliceView(*in_parallel_desc.hierarchy(), in_nd_sbp, logical_shape);
+  const std::vector<TensorSliceView>& out_slices =
+      GetTensorSliceView(*out_parallel_desc.hierarchy(), out_nd_sbp, logical_shape);
+
+  const auto& in_parallel_hierarchy = in_parallel_desc.hierarchy();
+  int32_t in_hierarchy_dimension = in_parallel_hierarchy->NumAxes();
+  const NdIndexOffsetHelper<int64_t, SHAPE_MAX_AXIS_SIZE> in_hierarchy_index_helper(
+      in_parallel_hierarchy->dim_vec().data(), in_hierarchy_dimension);
+
+  const int64_t machine_id = CHECK_JUST(parallel_desc.MachineId4ParallelId(parallel_id));
+  const int64_t device_index = CHECK_JUST(parallel_desc.DeviceId4ParallelId(parallel_id));
+  const int64_t in_parallel_num = in_parallel_desc.parallel_num();
+  const int64_t out_parallel_num = out_parallel_desc.parallel_num();
+  // cur rank recv from
+  // cur rank has output
+  if (out_parallel_desc.Containing(machine_id, device_index)) {
+    recv_intersections->resize(parallel_num);
+    int64_t out_id =
+        CHECK_JUST(out_parallel_desc.ParallelId4MachineDeviceId(machine_id, device_index));
+    const TensorSliceView& cur_rank_out_slice = out_slices.at(out_id);
+    const auto& add_to_recv_intersections = [&](int32_t send_id) {
+      const TensorSliceView& in_slice = in_slices.at(send_id);
+      const TensorSliceView& intersection = cur_rank_out_slice.Intersect(in_slice);
+      if (intersection.IsEmpty()) { return; }
+      const int64_t merged_id = GetMappedParallelId(send_id, in_parallel_desc, parallel_desc);
+      recv_intersections->at(merged_id) = intersection;
+    };
+    int64_t corresponding_in_id = 0;
+    // For example [[0, 1], [2, 3]] -> [[1, 3], [5, 6]]
+    if (in_parallel_desc.Containing(machine_id, device_index)) {
+      // 1 and 3 are in [[0, 1], [2, 3]], use the same id in the producer parallel description
+      // The id of 1 is (0, 1), the id of 3 is (1, 1)
+      corresponding_in_id =
+          CHECK_JUST(in_parallel_desc.ParallelId4MachineDeviceId(machine_id, device_index));
+    } else {
+      // 5 and 6 are not in [[0, 1], [2, 3]]
+      // Then the id does not matter
+      corresponding_in_id = out_id % in_parallel_num;
+    }
+    std::vector<int64_t> in_parallel_ids(in_hierarchy_dimension);
+    // The corresponding parallel id of a consumer rank in the producer parallel description
+    std::vector<int64_t> out_parallel_ids(in_hierarchy_dimension);
+    in_hierarchy_index_helper.OffsetToNdIndex(corresponding_in_id, out_parallel_ids.data(),
+                                              in_hierarchy_dimension);
+    DfsTraverseRanks4NdSbp(0, in_parallel_ids, out_parallel_ids, *in_parallel_hierarchy,
+                           in_hierarchy_index_helper, in_nd_sbp, add_to_recv_intersections);
+  }
+
+  // cur rank send to
+  if (in_parallel_desc.Containing(machine_id, device_index)) {
+    send_intersections->resize(parallel_num);
+    int64_t in_id =
+        CHECK_JUST(in_parallel_desc.ParallelId4MachineDeviceId(machine_id, device_index));
+    const TensorSliceView& cur_rank_in_slice = in_slices.at(in_id);
+    for (int64_t recv_i = 0; recv_i < out_parallel_num; ++recv_i) {
+      const auto& add_to_send_intersections = [&](int32_t send_id) {
+        if (send_id != in_id) { return; }
+        const TensorSliceView& out_slice = out_slices.at(recv_i);
+        const TensorSliceView& intersection = out_slice.Intersect(cur_rank_in_slice);
+        if (intersection.IsEmpty()) { return; }
+        const int64_t merged_id = GetMappedParallelId(recv_i, out_parallel_desc, parallel_desc);
+        send_intersections->at(merged_id) = intersection;
+      };
+      int64_t out_device_id = CHECK_JUST(out_parallel_desc.DeviceId4ParallelId(recv_i));
+      int64_t out_machine_id = CHECK_JUST(out_parallel_desc.MachineId4ParallelId(recv_i));
+      int64_t corresponding_in_id = 0;
+      // For example [[0, 1], [2, 3]] -> [[1, 3], [5, 6]]
+      if (in_parallel_desc.Containing(out_machine_id, out_device_id)) {
+        // 1 and 3 are in [[0, 1], [2, 3]], use the same id in the producer parallel description
+        // The id of 1 is (0, 1), the id of 3 is (1, 1)
+        corresponding_in_id =
+            CHECK_JUST(in_parallel_desc.ParallelId4MachineDeviceId(out_machine_id, out_device_id));
+      } else {
+        // 5 and 6 are not in [[0, 1], [2, 3]]
+        // Then the id does not matter
+        corresponding_in_id = recv_i % in_parallel_num;
+      }
+      std::vector<int64_t> in_parallel_ids(in_hierarchy_dimension);
+      // The corresponding parallel id of a consumer rank in the producer parallel description
+      std::vector<int64_t> out_parallel_ids(in_hierarchy_dimension);
+      in_hierarchy_index_helper.OffsetToNdIndex(corresponding_in_id, out_parallel_ids.data(),
+                                                in_hierarchy_dimension);
+      DfsTraverseRanks4NdSbp(0, in_parallel_ids, out_parallel_ids, *in_parallel_hierarchy,
+                             in_hierarchy_index_helper, in_nd_sbp, add_to_send_intersections);
+    }
+  }
+}
+
+}  // namespace oneflow
diff --git a/oneflow/core/operator/nccl_send_recv_boxing_op_util.h b/oneflow/core/operator/nccl_send_recv_boxing_op_util.h
new file mode 100644
index 00000000000..f491a50e91b
--- /dev/null
+++ b/oneflow/core/operator/nccl_send_recv_boxing_op_util.h
@@ -0,0 +1,31 @@
+/*
+Copyright 2020 The OneFlow Authors. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+#include "oneflow/core/register/tensor_slice_view.h"
+#include "oneflow/core/job/nd_sbp_util.h"
+
+namespace oneflow {
+
+int64_t GetMappedParallelId(const int64_t from_parallel_id, const ParallelDesc& from_parallel_desc,
+                            const ParallelDesc& to_parallel_desc);
+
+void GetRankSendRecvIntersection(int64_t parallel_id, const ParallelDesc& parallel_desc,
+                                 const ParallelDesc& in_parallel_desc,
+                                 const ParallelDesc& out_parallel_desc, const NdSbp& in_nd_sbp,
+                                 const NdSbp& out_nd_sbp, const Shape& logical_shape,
+                                 std::vector<TensorSliceView>* send_intersections,
+                                 std::vector<TensorSliceView>* recv_intersections);
+
+}  // namespace oneflow
diff --git a/oneflow/core/operator/op_conf.proto b/oneflow/core/operator/op_conf.proto
index 94379291558..c94ad6d9fa1 100644
--- a/oneflow/core/operator/op_conf.proto
+++ b/oneflow/core/operator/op_conf.proto
@@ -13,6 +13,7 @@ import "oneflow/core/job/sbp_parallel.proto";
 import "oneflow/core/graph/boxing/collective_boxing.proto";
 import "oneflow/core/job/initializer_conf.proto";
 import "oneflow/core/job/regularizer_conf.proto";
+import "oneflow/core/job/placement.proto";
 import "oneflow/core/job/learning_rate_schedule_conf.proto";
 import "oneflow/core/operator/interface_blob_conf.proto";
 import "oneflow/core/register/blob_desc.proto";
@@ -401,6 +402,19 @@ message BoxingZerosOpConf {
   required DataType data_type = 3;
 }
 
+message NcclSendRecvBoxingOpConf {
+  required LogicalBlobId lbi = 1;
+  required NdSbp src_nd_sbp = 2;
+  required NdSbp dst_nd_sbp = 3;
+  required ParallelConf parallel_conf = 4;
+  required ParallelConf src_parallel_conf = 5;
+  required ParallelConf dst_parallel_conf = 6;
+  required ShapeProto logical_shape = 7;
+  required DataType data_type = 8;
+  required bool has_input = 9;
+  required bool has_output = 10;
+}
+
 message OperatorConf {
   required string name = 1;
   optional string device_tag = 4 [default = "invalid_device"];
@@ -446,6 +460,7 @@ message OperatorConf {
     CollectiveBoxingPackOpConf collective_boxing_pack_conf = 174;
     CollectiveBoxingUnpackOpConf collective_boxing_unpack_conf = 175;
     BoxingZerosOpConf boxing_zeros_conf = 176;
+    NcclSendRecvBoxingOpConf nccl_send_recv_boxing_conf = 177;
     UserOpConf user_conf = 199;
     // domain op
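To make the traversal in DfsTraverseRanks4NdSbp concrete: on a [2, 2] hierarchy with producer nd-SBP (B, S0), a receiver at coordinate (1, 0) only needs senders whose broadcast coordinate matches its own, so the visited sender coordinates are (1, 0) and (1, 1). The toy model below reproduces that pin-or-enumerate logic in isolation; it is a simplified illustration under those assumptions, not the file's actual code.

#include <cstdint>
#include <functional>
#include <iostream>
#include <vector>

// Toy model of DfsTraverseRanks4NdSbp: sbp_is_broadcast[d] == true pins
// dimension d to the receiver's coordinate; otherwise every index is visited.
void DfsVisitSenders(int32_t depth, const std::vector<int64_t>& hierarchy,
                     const std::vector<bool>& sbp_is_broadcast,
                     const std::vector<int64_t>& recv_ids, std::vector<int64_t>& cur_ids,
                     const std::function<void(const std::vector<int64_t>&)>& visit) {
  if (depth >= static_cast<int32_t>(hierarchy.size())) { return visit(cur_ids); }
  if (sbp_is_broadcast[depth]) {
    cur_ids[depth] = recv_ids[depth];
    DfsVisitSenders(depth + 1, hierarchy, sbp_is_broadcast, recv_ids, cur_ids, visit);
  } else {
    for (int64_t i = 0; i < hierarchy[depth]; ++i) {
      cur_ids[depth] = i;
      DfsVisitSenders(depth + 1, hierarchy, sbp_is_broadcast, recv_ids, cur_ids, visit);
    }
  }
}

int main() {
  // Producer nd-SBP (B, S0) on a [2, 2] hierarchy; the receiver sits at (1, 0).
  std::vector<int64_t> hierarchy{2, 2}, recv_ids{1, 0}, cur_ids(2);
  std::vector<bool> is_broadcast{true, false};
  DfsVisitSenders(0, hierarchy, is_broadcast, recv_ids, cur_ids,
                  [](const std::vector<int64_t>& ids) {
                    std::cout << "(" << ids[0] << ", " << ids[1] << ")\n";  // (1, 0), (1, 1)
                  });
  return 0;
}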
diff --git a/oneflow/user/kernels/nccl_logical_send_recv_kernel.cpp b/oneflow/user/kernels/nccl_logical_send_recv_kernel.cpp
index 714c9a5cbd3..6ef75d9a993 100644
--- a/oneflow/user/kernels/nccl_logical_send_recv_kernel.cpp
+++ b/oneflow/user/kernels/nccl_logical_send_recv_kernel.cpp
@@ -26,6 +26,7 @@ limitations under the License.
 #include "oneflow/core/register/tensor_slice_copier.h"
 #include "oneflow/core/ep/include/primitive/memset.h"
 #include "oneflow/core/ep/include/primitive/add.h"
+#include "oneflow/core/operator/nccl_send_recv_boxing_op_util.h"
 
 #if defined(WITH_CUDA) && NCCL_VERSION_CODE > 2700
 
@@ -87,7 +88,9 @@ NcclLogicalSendRecvState::NcclLogicalSendRecvState(user_op::KernelInitContext* ctx)
   std::vector<TensorSliceView> src_send_intersections;
   std::vector<TensorSliceView> dst_recv_intersections;
-  GetRankSendRecvIntersection(parallel_id, parallel_desc_->hierarchy(), src_nd_sbp, dst_nd_sbp,
+  GetRankSendRecvIntersection(parallel_id, /*merge_parallel_desc=*/*parallel_desc_,
+                              /*in_parallel_desc=*/*parallel_desc_,
+                              /*out_parallel_desc=*/*parallel_desc_, src_nd_sbp, dst_nd_sbp,
                               logical_shape, &src_send_intersections, &dst_recv_intersections);
   CHECK_EQ(src_send_intersections.size(), parallel_num);
@@ -264,7 +267,10 @@ size_t InferTmpBufferSize(user_op::InferContext* ctx) {
   std::vector<TensorSliceView> src_send_intersections;
   std::vector<TensorSliceView> dst_recv_intersections;
-  GetRankSendRecvIntersection(parallel_id, ctx->parallel_desc().hierarchy(), src_nd_sbp, dst_nd_sbp,
+  const auto& parallel_desc = ctx->parallel_desc();
+  GetRankSendRecvIntersection(parallel_id, /*merge_parallel_desc=*/parallel_desc,
+                              /*in_parallel_desc=*/parallel_desc,
+                              /*out_parallel_desc=*/parallel_desc, src_nd_sbp, dst_nd_sbp,
                               logical_shape, &src_send_intersections, &dst_recv_intersections);
   int64_t buf_count = 0;
   CHECK_EQ(src_send_intersections.size(), parallel_num);
diff --git a/python/oneflow/test/graph/test_comb1to2d.py b/python/oneflow/test/graph/test_comb1to2d.py
index eae8c04ec1d..cce4d3292de 100644
--- a/python/oneflow/test/graph/test_comb1to2d.py
+++ b/python/oneflow/test/graph/test_comb1to2d.py
@@ -24,6 +24,10 @@
 import oneflow.unittest
 
+os.environ["ONEFLOW_BOXING_DISABLE_MIDDLE_NODE_AND_CHECK"] = "0"
+os.environ["ONEFLOW_BOXING_ENABLE_GENERAL_BASIC_COMMUNICATION"] = "0"
+
+
 class _TestModuleDiffHierarchy(nn.Module):
     def forward(self, x):
         sbp_1ds = [
@@ -32,7 +36,6 @@ def forward(self, x):
             flow.sbp.split(0),
             flow.sbp.split(1),
             flow.sbp.split(2),
-            flow.sbp.split(3),
         ]
 
         for sbp1 in sbp_1ds:
@@ -63,7 +66,6 @@ def forward(self, x):
             flow.sbp.split(0),
            flow.sbp.split(1),
             flow.sbp.split(2),
-            flow.sbp.split(3),
         ]
 
         for sbp1 in sbp_1ds:
@@ -106,13 +108,14 @@ def test_lazy_boxing_2d_all_combination(test_case):
         4,
         12,
         4,
-        12,
         sbp=[flow.sbp.broadcast, flow.sbp.broadcast],
         placement=flow.placement(
             type="cuda", ranks=np.array(range(4)).reshape(2, 2)
         ),
     )
+    flow.boxing.nccl.enable_use_compute_stream(False)
+
     model_diff_hierarchy = _TestModuleDiffHierarchy()
     graph_diff_hierarchy = _TestGraph(model_diff_hierarchy)
     y = graph_diff_hierarchy(x)
diff --git a/python/oneflow/test/graph/test_comb2d.py b/python/oneflow/test/graph/test_comb2d.py
index 7b746017bdb..f4ea5fa2d37 100644
--- a/python/oneflow/test/graph/test_comb2d.py
+++ b/python/oneflow/test/graph/test_comb2d.py
@@ -24,6 +24,12 @@
 import oneflow.unittest
 
+os.environ["ONEFLOW_BOXING_DISABLE_MIDDLE_NODE_AND_CHECK"] = "0"
+os.environ["ONEFLOW_BOXING_ENABLE_GENERAL_BASIC_COMMUNICATION"] = "0"
+
+flow.boxing.nccl.enable_use_compute_stream(False)
+
+
 class _TestModule(nn.Module):
     def forward(self, x):
         sbp_1ds = [
@@ -32,7 +38,6 @@ def forward(self, x):
             flow.sbp.split(0),
             flow.sbp.split(1),
             flow.sbp.split(2),
-            flow.sbp.split(3),
         ]
 
         y = x
@@ -40,6 +45,9 @@ def forward(self, x):
         for sbp2 in sbp_1ds:
             for sbp3 in sbp_1ds:
+                # in this case, use intra group boxing
+                if sbp1 == sbp3:
+                    continue
                 for sbp4 in sbp_1ds:
                     # (2, 2) -> (2, 2)
                     x = x.to_global(sbp=[sbp1,
sbp2]) @@ -69,7 +77,6 @@ def test_lazy_boxing_2d_all_combination(test_case): 4, 4, 4, - 4, sbp=[flow.sbp.broadcast, flow.sbp.broadcast], placement=flow.placement( type="cuda", ranks=np.array(range(4)).reshape(2, 2) diff --git a/python/oneflow/test/graph/test_gbc1to2d.py b/python/oneflow/test/graph/test_gbc1to2d.py new file mode 100644 index 00000000000..4025b81e69b --- /dev/null +++ b/python/oneflow/test/graph/test_gbc1to2d.py @@ -0,0 +1,96 @@ +""" +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +import unittest +from collections import OrderedDict +import oneflow +import numpy as np +import oneflow as flow +import oneflow.unittest +from oneflow.test_utils.test_util import GenArgList + +from oneflow.test_utils.automated_test_util import * +import time +import os + +os.environ["ONEFLOW_BOXING_DISABLE_MIDDLE_NODE_AND_CHECK"] = "0" +os.environ["ONEFLOW_BOXING_ENABLE_GENERAL_BASIC_COMMUNICATION"] = "1" + + +def _test_general_basic_communication_1d_to_2d(test_case, src_nd_sbp, dst_nd_sbp): + # can not process p in dst + if flow.sbp.partial_sum() in dst_nd_sbp: + return + + # input + placement_x = flow.placement("cuda", ranks=[0, 1, 2]) + placement_y = flow.placement("cuda", ranks=[[3, 4], [1, 2]]) + local_np = np.arange(4 * 12).reshape(4, 12) + x = flow.tensor(local_np, sbp=src_nd_sbp, placement=placement_x) + + # check eager boxing + eager_out = x.to_global(sbp=dst_nd_sbp, placement=placement_y) + test_case.assertTrue(np.array_equal(eager_out.numpy(), x.numpy())) + + # check graph boxing + flow.boxing.nccl.enable_use_compute_stream(False) + + class TestGeneralBasicCommunicationGraph(flow.nn.Graph): + def __init__(self): + super().__init__() + + def build(self, x): + y = x.to_global(sbp=dst_nd_sbp, placement=placement_y) + return y + + graph = TestGeneralBasicCommunicationGraph() + y = graph(x) + out_np = y.numpy() + in_np = x.numpy() + test_case.assertTrue(np.array_equal(out_np, in_np)) + + +def gen_nd_sbp_1d(): + sbp_list = [ + flow.sbp.partial_sum(), + flow.sbp.broadcast(), + flow.sbp.split(0), + flow.sbp.split(1), + ] + return sbp_list + + +def gen_nd_sbp_2d(): + nd_sbp_list = [] + for sbp0 in gen_nd_sbp_1d(): + for sbp1 in gen_nd_sbp_1d(): + nd_sbp_list.append([sbp0, sbp1]) + return nd_sbp_list + + +@flow.unittest.skip_unless_2n4d() +@unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") +class TestGeneralBasicCommunication(flow.unittest.TestCase): + def test_general_basic_communication(test_case): + arg_dict = OrderedDict() + arg_dict["src_nd_sbp"] = gen_nd_sbp_1d() + arg_dict["dst_nd_sbp"] = gen_nd_sbp_2d() + for arg in GenArgList(arg_dict): + _test_general_basic_communication_1d_to_2d(test_case, *arg) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/oneflow/test/graph/test_gbc2d.py b/python/oneflow/test/graph/test_gbc2d.py new file mode 100644 index 00000000000..d08ce287d17 --- /dev/null +++ b/python/oneflow/test/graph/test_gbc2d.py @@ -0,0 +1,107 @@ +""" +Copyright 2020 The OneFlow Authors. 
All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +import unittest +from collections import OrderedDict +import oneflow +import numpy as np +import oneflow as flow +import oneflow.unittest +from oneflow.test_utils.test_util import GenArgList + +from oneflow.test_utils.automated_test_util import * +import time +import os + +os.environ["ONEFLOW_BOXING_DISABLE_MIDDLE_NODE_AND_CHECK"] = "0" +os.environ["ONEFLOW_BOXING_ENABLE_GENERAL_BASIC_COMMUNICATION"] = "1" + + +def _test_general_basic_communication_same_placement(test_case, src_nd_sbp, dst_nd_sbp): + # can not process p in dst + if flow.sbp.partial_sum() in dst_nd_sbp: + return + + # skip src == dst + if src_nd_sbp == dst_nd_sbp: + return + + # in this case, use intra group boxing + if src_nd_sbp[0] == dst_nd_sbp[0]: + return + + # in this case, use inter group boxing + if ( + src_nd_sbp[1] == dst_nd_sbp[1] + and src_nd_sbp[0] != src_nd_sbp[1] + and dst_nd_sbp[0] != dst_nd_sbp[1] + ): + return + + # input + placement = flow.placement("cuda", ranks=[[0, 1], [2, 3]]) + local_np = np.arange(4 * 4).reshape(4, 4) + x = flow.tensor(local_np, sbp=src_nd_sbp, placement=placement) + + # check eager boxing + eager_out = x.to_global(sbp=dst_nd_sbp, placement=placement) + test_case.assertTrue(np.array_equal(eager_out.numpy(), x.numpy())) + + # check graph boxing + flow.boxing.nccl.enable_use_compute_stream(False) + + class TestGeneralBasicCommunicationGraph(flow.nn.Graph): + def __init__(self): + super().__init__() + + def build(self, x): + y = x.to_global(sbp=dst_nd_sbp, placement=placement) + return y + + graph = TestGeneralBasicCommunicationGraph() + y = graph(x) + out_np = y.numpy() + in_np = x.numpy() + test_case.assertTrue(np.array_equal(out_np, in_np)) + + +def gen_nd_sbp(): + sbp_list = [ + flow.sbp.partial_sum(), + flow.sbp.broadcast(), + flow.sbp.split(0), + flow.sbp.split(1), + ] + nd_sbp_list = [] + for sbp0 in sbp_list: + for sbp1 in sbp_list: + nd_sbp_list.append([sbp0, sbp1]) + return nd_sbp_list + + +@flow.unittest.skip_unless_1n4d() +@unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") +class TestGeneralBasicCommunication(flow.unittest.TestCase): + def test_general_basic_communication(test_case): + arg_dict = OrderedDict() + arg_dict["src_nd_sbp"] = gen_nd_sbp() + arg_dict["dst_nd_sbp"] = gen_nd_sbp() + for arg in GenArgList(arg_dict): + _test_general_basic_communication_same_placement(test_case, *arg) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/oneflow/test/graph/test_gbc2to1d.py b/python/oneflow/test/graph/test_gbc2to1d.py new file mode 100644 index 00000000000..95f74f97661 --- /dev/null +++ b/python/oneflow/test/graph/test_gbc2to1d.py @@ -0,0 +1,96 @@ +""" +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +import unittest +from collections import OrderedDict +import oneflow +import numpy as np +import oneflow as flow +import oneflow.unittest +from oneflow.test_utils.test_util import GenArgList + +from oneflow.test_utils.automated_test_util import * +import time +import os + +os.environ["ONEFLOW_BOXING_DISABLE_MIDDLE_NODE_AND_CHECK"] = "0" +os.environ["ONEFLOW_BOXING_ENABLE_GENERAL_BASIC_COMMUNICATION"] = "1" + + +def _test_general_basic_communication_2d_to_1d(test_case, src_nd_sbp, dst_nd_sbp): + # can not process p in dst + if flow.sbp.partial_sum() == dst_nd_sbp: + return + + # input + placement_x = flow.placement("cuda", ranks=[[0, 1], [2, 3]]) + placement_y = flow.placement("cuda", ranks=[0, 3, 4]) + local_np = np.arange(12 * 4).reshape(12, 4) + x = flow.tensor(local_np, sbp=src_nd_sbp, placement=placement_x) + + # check eager boxing + eager_out = x.to_global(sbp=dst_nd_sbp, placement=placement_y) + test_case.assertTrue(np.array_equal(eager_out.numpy(), x.numpy())) + + # check graph boxing + flow.boxing.nccl.enable_use_compute_stream(False) + + class TestGeneralBasicCommunicationGraph(flow.nn.Graph): + def __init__(self): + super().__init__() + + def build(self, x): + y = x.to_global(sbp=dst_nd_sbp, placement=placement_y) + return y + + graph = TestGeneralBasicCommunicationGraph() + y = graph(x) + out_np = y.numpy() + in_np = x.numpy() + test_case.assertTrue(np.array_equal(out_np, in_np)) + + +def gen_nd_sbp_1d(): + sbp_list = [ + flow.sbp.partial_sum(), + flow.sbp.broadcast(), + flow.sbp.split(0), + flow.sbp.split(1), + ] + return sbp_list + + +def gen_nd_sbp_2d(): + nd_sbp_list = [] + for sbp0 in gen_nd_sbp_1d(): + for sbp1 in gen_nd_sbp_1d(): + nd_sbp_list.append([sbp0, sbp1]) + return nd_sbp_list + + +@flow.unittest.skip_unless_2n4d() +@unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") +class TestGeneralBasicCommunication(flow.unittest.TestCase): + def test_general_basic_communication(test_case): + arg_dict = OrderedDict() + arg_dict["src_nd_sbp"] = gen_nd_sbp_2d() + arg_dict["dst_nd_sbp"] = gen_nd_sbp_1d() + for arg in GenArgList(arg_dict): + _test_general_basic_communication_2d_to_1d(test_case, *arg) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/oneflow/test/graph/test_gbc2to2d.py b/python/oneflow/test/graph/test_gbc2to2d.py new file mode 100644 index 00000000000..5a2d00809e8 --- /dev/null +++ b/python/oneflow/test/graph/test_gbc2to2d.py @@ -0,0 +1,95 @@ +""" +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+""" + +import unittest +from collections import OrderedDict +import oneflow +import numpy as np +import oneflow as flow +import oneflow.unittest +from oneflow.test_utils.test_util import GenArgList + +from oneflow.test_utils.automated_test_util import * +import time +import os + +os.environ["ONEFLOW_BOXING_DISABLE_MIDDLE_NODE_AND_CHECK"] = "0" +os.environ["ONEFLOW_BOXING_ENABLE_GENERAL_BASIC_COMMUNICATION"] = "1" + + +def _test_general_basic_communication_2d_to_2d(test_case, src_nd_sbp, dst_nd_sbp): + # can not process p in dst + if flow.sbp.partial_sum() in dst_nd_sbp: + return + + if dst_nd_sbp[0] == dst_nd_sbp[1] and src_nd_sbp[0] == src_nd_sbp[1]: + return + + # input + placement_x = flow.placement("cuda", ranks=[[0, 1], [2, 3]]) + placement_y = flow.placement("cuda", ranks=[[0, 3, 4], [2, 5, 6]]) + local_np = np.arange(12 * 12).reshape(12, 12) + x = flow.tensor(local_np, sbp=src_nd_sbp, placement=placement_x) + + # check eager boxing + eager_out = x.to_global(sbp=dst_nd_sbp, placement=placement_y) + test_case.assertTrue(np.array_equal(eager_out.numpy(), x.numpy())) + + # check graph boxing + flow.boxing.nccl.enable_use_compute_stream(False) + + class TestGeneralBasicCommunicationGraph(flow.nn.Graph): + def __init__(self): + super().__init__() + + def build(self, x): + y = x.to_global(sbp=dst_nd_sbp, placement=placement_y) + return y + + graph = TestGeneralBasicCommunicationGraph() + y = graph(x) + out_np = y.numpy() + in_np = x.numpy() + test_case.assertTrue(np.array_equal(out_np, in_np)) + + +def gen_nd_sbp(): + sbp_list = [ + flow.sbp.partial_sum(), + flow.sbp.broadcast(), + flow.sbp.split(0), + flow.sbp.split(1), + ] + nd_sbp_list = [] + for sbp0 in sbp_list: + for sbp1 in sbp_list: + nd_sbp_list.append([sbp0, sbp1]) + return nd_sbp_list + + +@flow.unittest.skip_unless_2n4d() +@unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") +class TestGeneralBasicCommunication(flow.unittest.TestCase): + def test_general_basic_communication(test_case): + arg_dict = OrderedDict() + arg_dict["src_nd_sbp"] = gen_nd_sbp() + arg_dict["dst_nd_sbp"] = gen_nd_sbp() + for arg in GenArgList(arg_dict): + _test_general_basic_communication_2d_to_2d(test_case, *arg) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/oneflow/test/modules/test_comb2to2d.py b/python/oneflow/test/modules/test_comb2to2d.py index dc05016242a..670f20885c4 100644 --- a/python/oneflow/test/modules/test_comb2to2d.py +++ b/python/oneflow/test/modules/test_comb2to2d.py @@ -24,6 +24,12 @@ import oneflow.unittest +os.environ["ONEFLOW_BOXING_DISABLE_MIDDLE_NODE_AND_CHECK"] = "0" +os.environ["ONEFLOW_BOXING_ENABLE_GENERAL_BASIC_COMMUNICATION"] = "0" + +flow.boxing.nccl.enable_use_compute_stream(False) + + class _TestModuleDiffHierarchy(nn.Module): def forward(self, x): sbp_1ds = [