diff --git a/oneflow/core/framework/sbp_infer_util.cpp b/oneflow/core/framework/sbp_infer_util.cpp index 8de13f9afa1..4a553db5642 100644 --- a/oneflow/core/framework/sbp_infer_util.cpp +++ b/oneflow/core/framework/sbp_infer_util.cpp @@ -73,8 +73,9 @@ int32_t Ratio4Sbp(const NdSbp& nd_sbp, const ParallelDesc& parallel_desc, Maybe ComputCopyCostBetweenTwoSbpParallel(const SbpParallel& producer_sbp_parallel, const SbpParallel& consumer_sbp_parallel, const BlobDesc& logical_blob_desc, - const ParallelDesc& producer_parallel_desc, - const ParallelDesc& consumer_parallel_desc) { + bool on_same_devices, + int32_t producer_parallel_num, + int32_t consumer_parallel_num) { if (!(CheckSbpParallel(producer_sbp_parallel) && CheckSbpParallel(consumer_sbp_parallel))) { return Error::RuntimeError() << "Illegal sbp parallel has been found."; } @@ -89,7 +90,7 @@ Maybe ComputCopyCostBetweenTwoSbpParallel(const SbpParallel& producer_sb // NOTE: A tensor placed on cpu with a consumer operator that accepts cuda inputs would be // transfered to cuda later. We might not have correct parallel description at this moment. - if (producer_parallel_desc == consumer_parallel_desc) { + if (on_same_devices && producer_parallel_num == consumer_parallel_num) { // Same sbp, no cost: S->S, B->B, P->P if (producer_sbp_parallel == consumer_sbp_parallel) { return 0.0; } double logical_blob_size = @@ -99,8 +100,8 @@ Maybe ComputCopyCostBetweenTwoSbpParallel(const SbpParallel& producer_sb // arbitrary i. // ? -> P if (consumer_sbp_parallel.has_partial_sum_parallel()) { - return Penalty4PartialInConsumer(logical_blob_size, producer_parallel_desc.parallel_num(), - consumer_parallel_desc.parallel_num()); + return Penalty4PartialInConsumer(logical_blob_size, producer_parallel_num, + consumer_parallel_num); } // B->S if (producer_sbp_parallel.has_broadcast_parallel()) { return 1.0; } @@ -110,15 +111,14 @@ Maybe ComputCopyCostBetweenTwoSbpParallel(const SbpParallel& producer_sb if (consumer_sbp_parallel.has_split_parallel() && producer_sbp_parallel.has_split_parallel()) { // S(0)->S(1), S(1)->S(0), etc. - return logical_blob_size * (producer_parallel_desc.parallel_num() - 1) - / producer_parallel_desc.parallel_num(); + return logical_blob_size * (producer_parallel_num - 1) / producer_parallel_num; } else { // P->S, S->B/P - return logical_blob_size * (producer_parallel_desc.parallel_num() - 1); + return logical_blob_size * (producer_parallel_num - 1); } } // P->B - return 2 * logical_blob_size * (producer_parallel_desc.parallel_num() - 1); + return 2 * logical_blob_size * (producer_parallel_num - 1); } else { // Not supporting P->P for different placement if (LazyMode::is_enabled()) { @@ -133,17 +133,16 @@ Maybe ComputCopyCostBetweenTwoSbpParallel(const SbpParallel& producer_sb double overall_cost = logical_blob_size; // ? -> B if (consumer_sbp_parallel.has_broadcast_parallel()) { - overall_cost += (consumer_parallel_desc.parallel_num() - 1) * logical_blob_size; + overall_cost += (consumer_parallel_num - 1) * logical_blob_size; } // P -> ? if (producer_sbp_parallel.has_partial_sum_parallel()) { - overall_cost += (producer_parallel_desc.parallel_num() - 1) * logical_blob_size; + overall_cost += (producer_parallel_num - 1) * logical_blob_size; } // ? -> P if (consumer_sbp_parallel.has_partial_sum_parallel()) { - overall_cost += - Penalty4PartialInConsumer(logical_blob_size, producer_parallel_desc.parallel_num(), - consumer_parallel_desc.parallel_num()); + overall_cost += Penalty4PartialInConsumer(logical_blob_size, producer_parallel_num, + consumer_parallel_num); } // For B->S, S->S, overall_cost == logical_blob_size; return overall_cost; @@ -202,9 +201,8 @@ double ComputCopyCostBetweenTwoDiffSbpParallel(const SbpParallel& producer_sbp_p Maybe ComputCopyCostBetweenTwoNdSbp(const NdSbp& producer_nd_sbp, const NdSbp& consumer_nd_sbp, double logical_blob_size, - const std::shared_ptr& hierarchy, - bool on_same_devices) { - if (hierarchy->NumAxes() != 2) { return kUnsupportedBoxing; } + const Shape& hierarchy, bool on_same_devices) { + if (hierarchy.NumAxes() != 2) { return kUnsupportedBoxing; } const auto& producer_sbp_size = producer_nd_sbp.sbp_parallel_size(); const auto& consumer_sbp_size = consumer_nd_sbp.sbp_parallel_size(); // One of the SBP should have size 2 @@ -221,7 +219,7 @@ Maybe ComputCopyCostBetweenTwoNdSbp(const NdSbp& producer_nd_sbp, // The SBP parallel are the same at dimension (dim_same_sbp) if (producer_nd_sbp.sbp_parallel(dim_producer) == consumer_nd_sbp.sbp_parallel(dim_consumer)) { if (!producer_nd_sbp.sbp_parallel(dim_producer).has_split_parallel()) { - logical_blob_size *= hierarchy->At(dim_same_sbp); + logical_blob_size *= hierarchy.At(dim_same_sbp); } // The SBP parallel are different at dimension (dim_diff_sbp) int32_t dim_diff_sbp = 1 - dim_same_sbp; @@ -241,7 +239,7 @@ Maybe ComputCopyCostBetweenTwoNdSbp(const NdSbp& producer_nd_sbp, } return ComputCopyCostBetweenTwoDiffSbpParallel( producer_nd_sbp.sbp_parallel(dim_producer), consumer_nd_sbp.sbp_parallel(dim_consumer), - logical_blob_size, hierarchy->At(dim_diff_sbp), on_same_devices); + logical_blob_size, hierarchy.At(dim_diff_sbp), on_same_devices); } } return kUnsupportedBoxing; @@ -265,34 +263,36 @@ Maybe ComputeEagerCopyCostBetweenNdSbp(const NdSbp& producer_sbp_paralle return kUnsupportedBoxing; } - ParallelDesc reduced_in_parallel_desc = producer_parallel_desc; - ParallelDesc reduced_out_parallel_desc = consumer_parallel_desc; + bool on_same_devices = producer_parallel_desc.EqualsIgnoringHierarchy(consumer_parallel_desc); + + // Reduce before cost computation + Shape reduced_in_hierarchy; NdSbp reduced_in_nd_sbp; + Shape reduced_out_hierarchy; NdSbp reduced_out_nd_sbp; - InOutParallelDimReduce(producer_parallel_desc, consumer_parallel_desc, producer_sbp_parallel, - consumer_sbp_parallel, &reduced_in_parallel_desc, - &reduced_out_parallel_desc, &reduced_in_nd_sbp, &reduced_out_nd_sbp, + InOutParallelDimReduce(*producer_parallel_desc.hierarchy(), *consumer_parallel_desc.hierarchy(), + producer_sbp_parallel, consumer_sbp_parallel, &reduced_in_hierarchy, + &reduced_out_hierarchy, &reduced_in_nd_sbp, &reduced_out_nd_sbp, logical_blob_desc.shape()); - const auto& in_hierarchy = reduced_in_parallel_desc.hierarchy(); - const auto& out_hierarchy = reduced_out_parallel_desc.hierarchy(); - bool same_nd_sbp = reduced_in_nd_sbp == reduced_out_nd_sbp; // Same sbp is always supported. - if (same_nd_sbp && reduced_in_parallel_desc == reduced_out_parallel_desc) { return 0.0; } + if (same_nd_sbp && on_same_devices && reduced_in_hierarchy == reduced_out_hierarchy) { + return 0.0; + } if (requires_same_sbp) { return kUnsupportedBoxing; } - int32_t in_dim = in_hierarchy->NumAxes(); - int32_t out_dim = out_hierarchy->NumAxes(); + int32_t in_dim = reduced_in_hierarchy.NumAxes(); + int32_t out_dim = reduced_out_hierarchy.NumAxes(); // We support different hierarchy for 1D sbp if (in_dim == 1 && out_dim == 1) { return ComputCopyCostBetweenTwoSbpParallel( reduced_in_nd_sbp.sbp_parallel(0), reduced_out_nd_sbp.sbp_parallel(0), logical_blob_desc, - reduced_in_parallel_desc, reduced_out_parallel_desc); + on_same_devices, reduced_in_hierarchy.elem_cnt(), reduced_out_hierarchy.elem_cnt()); } double total_cost = 1.0; - if (reduced_in_parallel_desc == reduced_out_parallel_desc) { + if (on_same_devices && reduced_in_hierarchy == reduced_out_hierarchy) { // NOTE: After analysis, transfer cost increase if spliting the same dimension. // Example 1: (S(1), S(0), S(1), S(0)) -> (S(0), S(0), S(0), S(0)) // Example 2: (B, S(0)) -> (S(0), S(0)) @@ -301,14 +301,15 @@ Maybe ComputeEagerCopyCostBetweenNdSbp(const NdSbp& producer_sbp_paralle // simplification. bool normal_case = true; // nd to nd - for (int32_t i = 0; i < reduced_in_parallel_desc.hierarchy()->NumAxes(); ++i) { + for (int32_t i = 0; i < in_dim; ++i) { const auto& in_sbp = reduced_in_nd_sbp.sbp_parallel(i); const auto& out_sbp = reduced_out_nd_sbp.sbp_parallel(i); // Have bugs here. (B, S0) -> (S0, S0) will give a cost 0. // Actually it is (1-1/m)T for hierarchy (n, m) // TODO: Fix that after support all sbp combination for eager. total_cost += JUST(ComputCopyCostBetweenTwoSbpParallel( - in_sbp, out_sbp, logical_blob_desc, reduced_in_parallel_desc, reduced_out_parallel_desc)); + in_sbp, out_sbp, logical_blob_desc, on_same_devices, reduced_in_hierarchy.elem_cnt(), + reduced_out_hierarchy.elem_cnt())); // Add the penalty for P in the consumer if (out_sbp.has_partial_sum_parallel() && (in_sbp != out_sbp)) { total_cost += Penalty4PartialInConsumer( @@ -338,20 +339,20 @@ Maybe ComputeEagerCopyCostBetweenNdSbp(const NdSbp& producer_sbp_paralle logical_blob_desc.shape().elem_cnt() * GetSizeOfDataType(logical_blob_desc.data_type()); { double in_cost = 1.0; - for (int32_t i = 0; i < reduced_in_parallel_desc.hierarchy()->NumAxes(); ++i) { + for (int32_t i = 0; i < in_dim; ++i) { // P -> ? if (reduced_in_nd_sbp.sbp_parallel(i).has_partial_sum_parallel()) { - in_cost *= reduced_in_parallel_desc.hierarchy()->At(i); + in_cost *= reduced_in_hierarchy.At(i); } } total_cost += logical_blob_size * in_cost; } { double out_cost = 1.0; - for (int32_t i = 0; i < reduced_out_parallel_desc.hierarchy()->NumAxes(); ++i) { + for (int32_t i = 0; i < out_dim; ++i) { // ? -> B if (reduced_out_nd_sbp.sbp_parallel(i).has_broadcast_parallel()) { - out_cost *= reduced_out_parallel_desc.hierarchy()->At(i); + out_cost *= reduced_out_hierarchy.At(i); } // Add the penalty for P in the consumer if (reduced_out_nd_sbp.sbp_parallel(i).has_partial_sum_parallel()) { @@ -430,7 +431,13 @@ void NdSbpsDimReduce(const Shape& hierarchy, const std::vector& nd // Speed up for 1d sbp if (hierarchy.NumAxes() == 1) { *reduced_hierarchy = hierarchy; - for (int32_t index = 0; index < sbp_num; index++) { *reduced_nd_sbps[index] = *nd_sbps[index]; } + for (int32_t index = 0; index < sbp_num; index++) { + if (hierarchy.elem_cnt() == 1) { + reduced_nd_sbps[index]->add_sbp_parallel()->mutable_broadcast_parallel(); + } else { + *reduced_nd_sbps[index] = *nd_sbps[index]; + } + } return; } reduced_hierarchy->clear(); @@ -632,19 +639,19 @@ Maybe ComputeLazyCopyCostBetweenNdSbp(const NdSbp& producer_sbp_parallel if (!(CheckNdSbp(producer_sbp_parallel) && CheckNdSbp(consumer_sbp_parallel))) { return Error::RuntimeError() << "Illegal sbp parallel has been found."; } - ParallelDesc reduced_in_parallel_desc = producer_parallel_desc; - ParallelDesc reduced_out_parallel_desc = consumer_parallel_desc; + bool on_same_devices = producer_parallel_desc.EqualsIgnoringHierarchy(consumer_parallel_desc); + + // Reduce before cost computation + Shape reduced_in_hierarchy; NdSbp reduced_in_nd_sbp; + Shape reduced_out_hierarchy; NdSbp reduced_out_nd_sbp; - InOutParallelDimReduce(producer_parallel_desc, consumer_parallel_desc, producer_sbp_parallel, - consumer_sbp_parallel, &reduced_in_parallel_desc, - &reduced_out_parallel_desc, &reduced_in_nd_sbp, &reduced_out_nd_sbp, + InOutParallelDimReduce(*producer_parallel_desc.hierarchy(), *consumer_parallel_desc.hierarchy(), + producer_sbp_parallel, consumer_sbp_parallel, &reduced_in_hierarchy, + &reduced_out_hierarchy, &reduced_in_nd_sbp, &reduced_out_nd_sbp, logical_blob_desc.shape()); - - const auto& in_hierarchy = reduced_in_parallel_desc.hierarchy(); - const auto& out_hierarchy = reduced_out_parallel_desc.hierarchy(); - int32_t in_dim = in_hierarchy->NumAxes(); - int32_t out_dim = out_hierarchy->NumAxes(); + int32_t in_dim = reduced_in_hierarchy.NumAxes(); + int32_t out_dim = reduced_out_hierarchy.NumAxes(); // Not supporting n-D sbp with n >= 3 // TODO: Support it in the future if (std::min(in_dim, out_dim) <= 0 || std::max(in_dim, out_dim) >= 3) { @@ -653,7 +660,9 @@ Maybe ComputeLazyCopyCostBetweenNdSbp(const NdSbp& producer_sbp_parallel bool same_nd_sbp = reduced_in_nd_sbp == reduced_out_nd_sbp; // Same sbp is always supported. - if (same_nd_sbp && reduced_in_parallel_desc == reduced_out_parallel_desc) { return 0.0; } + if (same_nd_sbp && on_same_devices && reduced_in_hierarchy == reduced_out_hierarchy) { + return 0.0; + } if (requires_same_sbp) { return kUnsupportedBoxing; } // We support different hierarchy for 1D sbp @@ -661,7 +670,8 @@ Maybe ComputeLazyCopyCostBetweenNdSbp(const NdSbp& producer_sbp_parallel return GetTransferCost() + JUST(ComputCopyCostBetweenTwoSbpParallel( reduced_in_nd_sbp.sbp_parallel(0), reduced_out_nd_sbp.sbp_parallel(0), - logical_blob_desc, reduced_in_parallel_desc, reduced_out_parallel_desc)); + logical_blob_desc, on_same_devices, reduced_in_hierarchy.elem_cnt(), + reduced_out_hierarchy.elem_cnt())); } #ifdef WITH_CUDA @@ -682,33 +692,36 @@ Maybe ComputeLazyCopyCostBetweenNdSbp(const NdSbp& producer_sbp_parallel #endif // WITH_CUDA // Not supporting different hierarchy without general basic communication - if (in_hierarchy->elem_cnt() != out_hierarchy->elem_cnt()) { return kUnsupportedBoxing; } + if (reduced_in_hierarchy.elem_cnt() != reduced_out_hierarchy.elem_cnt()) { + return kUnsupportedBoxing; + } - bool on_same_devices = - reduced_in_parallel_desc.EqualsIgnoringHierarchy(reduced_out_parallel_desc); double logical_blob_size = logical_blob_desc.shape().elem_cnt() * GetSizeOfDataType(logical_blob_desc.data_type()); if (in_dim == 2 && out_dim == 2) { // Not supporting different hierarchy // TODO: Support it in the future - if (*in_hierarchy != *out_hierarchy) { return kUnsupportedBoxing; } + if (reduced_in_hierarchy != reduced_out_hierarchy) { return kUnsupportedBoxing; } return GetTransferCost() + JUST(ComputCopyCostBetweenTwoNdSbp(reduced_in_nd_sbp, reduced_out_nd_sbp, - logical_blob_size, in_hierarchy, on_same_devices)); + logical_blob_size, reduced_in_hierarchy, + on_same_devices)); } // (in_dim == 2 && out_dim == 1) || (in_dim == 1 && out_dim == 2) if (in_dim == 2 && out_dim == 1) { return GetTransferCost() + JUST(ComputCopyCostBetweenTwoNdSbp(reduced_in_nd_sbp, reduced_out_nd_sbp, - logical_blob_size, in_hierarchy, on_same_devices)); + logical_blob_size, reduced_in_hierarchy, + on_same_devices)); } if (in_dim == 1 && out_dim == 2) { return GetTransferCost() + JUST(ComputCopyCostBetweenTwoNdSbp(reduced_in_nd_sbp, reduced_out_nd_sbp, - logical_blob_size, out_hierarchy, on_same_devices)); + logical_blob_size, reduced_out_hierarchy, + on_same_devices)); } return Error::RuntimeError() @@ -813,23 +826,26 @@ Maybe ComputeCopyCostWithMiddleNodes(const NdSbp& producer_sbp_parallel, const ParallelDesc& producer_parallel_desc, const ParallelDesc& consumer_parallel_desc, bool requires_same_sbp) { - // Reduce before cost computation - ParallelDesc reduced_in_parallel_desc = producer_parallel_desc; - NdSbp reduced_in_nd_sbp; - NdSbpDimReduce(producer_parallel_desc, producer_sbp_parallel, &reduced_in_parallel_desc, - &reduced_in_nd_sbp, logical_blob_desc.shape()); - - ParallelDesc reduced_out_parallel_desc = consumer_parallel_desc; - NdSbp reduced_out_nd_sbp; - NdSbpDimReduce(consumer_parallel_desc, consumer_sbp_parallel, &reduced_out_parallel_desc, - &reduced_out_nd_sbp, logical_blob_desc.shape()); // In 90% of the transfer, we would have the same parallel description for producer and consumer // We need to speed it up and give an approximation of the cost - if (reduced_in_parallel_desc == reduced_out_parallel_desc - && reduced_in_nd_sbp == reduced_out_nd_sbp) { - if (producer_sbp_parallel == consumer_sbp_parallel) { + if (producer_parallel_desc.EqualsIgnoringHierarchy(consumer_parallel_desc)) { + // [2, 2]: (S0, S1) -> [2, 2]: (S0, S1) + if (*producer_parallel_desc.hierarchy() == *consumer_parallel_desc.hierarchy() + && producer_sbp_parallel == consumer_sbp_parallel) { return 0.0; - } else { + } + // Reduce before cost computation + Shape reduced_in_hierarchy; + NdSbp reduced_in_nd_sbp; + Shape reduced_out_hierarchy; + NdSbp reduced_out_nd_sbp; + InOutParallelDimReduce(*producer_parallel_desc.hierarchy(), *consumer_parallel_desc.hierarchy(), + producer_sbp_parallel, consumer_sbp_parallel, &reduced_in_hierarchy, + &reduced_out_hierarchy, &reduced_in_nd_sbp, &reduced_out_nd_sbp, + logical_blob_desc.shape()); + + // [2, 2]: (B, B) -> [4]: B + if (reduced_in_hierarchy == reduced_out_hierarchy && reduced_in_nd_sbp == reduced_out_nd_sbp) { return 1.0; } } @@ -903,22 +919,20 @@ double ComputeSbpInferPriority(const NdSbp& producer_nd_sbp, const NdSbp& consum // consumer return 0.0; } - // Dim reduction for producer - ParallelDesc reduced_in_parallel_desc = producer_parallel_desc; + // Reduce before cost computation + Shape reduced_in_hierarchy; NdSbp reduced_in_nd_sbp; - NdSbpDimReduce(producer_parallel_desc, producer_nd_sbp, &reduced_in_parallel_desc, - &reduced_in_nd_sbp, logical_shape); - - // Dim reduction for consumer - ParallelDesc reduced_out_parallel_desc = consumer_parallel_desc; + Shape reduced_out_hierarchy; NdSbp reduced_out_nd_sbp; - NdSbpDimReduce(consumer_parallel_desc, consumer_nd_sbp, &reduced_out_parallel_desc, - &reduced_out_nd_sbp, logical_shape); + InOutParallelDimReduce(*producer_parallel_desc.hierarchy(), *consumer_parallel_desc.hierarchy(), + producer_nd_sbp, consumer_nd_sbp, &reduced_in_hierarchy, + &reduced_out_hierarchy, &reduced_in_nd_sbp, &reduced_out_nd_sbp, + logical_shape); if (requires_same_sbp) { // This blob does not support boxing - if (reduced_in_nd_sbp == reduced_out_nd_sbp - && reduced_in_parallel_desc == reduced_out_parallel_desc) { + if (reduced_in_nd_sbp == reduced_out_nd_sbp && reduced_in_hierarchy == reduced_out_hierarchy + && producer_parallel_desc.EqualsIgnoringHierarchy(consumer_parallel_desc)) { // Normal priority: No transfer occurs but we have different sbp // For example: [1]:S0 -> [1]:B // [1, 2]:(P, S0) -> [1, 2]:(S0, S0)