diff --git a/torch/csrc/jit/codegen/cuda/compute_at.cpp b/torch/csrc/jit/codegen/cuda/compute_at.cpp
index 125e1bbf73071..0c7df3354da43 100644
--- a/torch/csrc/jit/codegen/cuda/compute_at.cpp
+++ b/torch/csrc/jit/codegen/cuda/compute_at.cpp
@@ -185,7 +185,7 @@ void ComputeAt::runAt(
   InlinePropagatorSelector selector(selected);
 
   InlinePropagator inline_propagator(
-      selector.selected(), consumer, consumer_position, mode);
+      consumer, consumer_position, mode, selector.selected());
   MaxProducerPosUpdater updater;
 
   MaxRootDomainInfoSpanningTree path(consumer, consumer_position, &selector);
@@ -227,7 +227,7 @@ void ComputeAt::runWith(
   InlinePropagatorSelector selector(selected);
 
   InlinePropagator inline_propagator(
-      selector.selected(), producer, producer_position, mode);
+      producer, producer_position, mode, selector.selected());
   MaxProducerPosUpdater updater;
 
   MaxRootDomainInfoSpanningTree path(producer, producer_position, &selector);
diff --git a/torch/csrc/jit/codegen/cuda/inline_propagator.cpp b/torch/csrc/jit/codegen/cuda/inline_propagator.cpp
index 00269c68cdc90..d35c72e3a61d3 100644
--- a/torch/csrc/jit/codegen/cuda/inline_propagator.cpp
+++ b/torch/csrc/jit/codegen/cuda/inline_propagator.cpp
@@ -148,7 +148,7 @@ size_t InlinePropagator::getMaxPosAll(TensorView* tv, bool check_siblings) {
 
 void InlinePropagator::setCAPos(TensorView* tv) {
   size_t pos = mapped_reference_pos_.at(tv);
-  if (selected_.count(tv) && !tv->isFusionInput()) {
+  if ((selected_.empty() || selected_.count(tv)) && !tv->isFusionInput()) {
     auto max_pos = getMaxPosAll(tv);
     if (mode_ == ComputeAtMode::Standard) {
       TORCH_INTERNAL_ASSERT(
@@ -171,10 +171,10 @@ void InlinePropagator::setCAPos(TensorView* tv) {
 }
 
 InlinePropagator::InlinePropagator(
-    std::unordered_set<TensorView*> selected,
     TensorView* reference,
     int64_t reference_pos,
-    ComputeAtMode mode)
+    ComputeAtMode mode,
+    std::unordered_set<TensorView*> selected)
     : max_pos_calc(mode),
       selected_(std::move(selected)),
       reference_(reference),
@@ -213,6 +213,8 @@ void InlinePropagator::propagateC2P(TensorView* from, TensorView* to) {
       to_pos >= 0,
       "Unable to propagate CA position from consumer ",
       from,
+      " at ",
+      from_pos,
       " to producer ",
       to,
       " because this would require replay.");
@@ -240,6 +242,8 @@ void InlinePropagator::propagateP2C(TensorView* from, TensorView* to) {
       to_pos >= 0,
      "Unable to propagate CA position from producer ",
       from,
+      " at ",
+      from_pos,
       " to consumer ",
       to,
       " because this would require replay.");
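The `setCAPos` change above gives `selected_` a new meaning: an empty set no longer selects nothing, it selects everything. A minimal standalone sketch of the predicate (`isSelected` is an invented helper name for illustration, not part of the PR):

```cpp
#include <unordered_set>

struct TensorView;  // opaque stand-in for nvfuser's TensorView

// Empty set = no filtering (propagate to the whole DAG);
// non-empty set = whitelist, as in InlinePropagator::setCAPos above.
bool isSelected(
    const std::unordered_set<TensorView*>& selected,
    TensorView* tv) {
  return selected.empty() || selected.count(tv) > 0;
}
```

This is what lets the pointwise scheduler further down construct an `InlinePropagator` without naming any tensors and still cover the entire fusion.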
diff --git a/torch/csrc/jit/codegen/cuda/inline_propagator.h b/torch/csrc/jit/codegen/cuda/inline_propagator.h
index 2ed137ac5955e..d7cd1f82a8d90 100644
--- a/torch/csrc/jit/codegen/cuda/inline_propagator.h
+++ b/torch/csrc/jit/codegen/cuda/inline_propagator.h
@@ -14,7 +14,8 @@ namespace cuda {
 // Simple selector that only propagates across tensor views in the provided
 // unordered_set. Will also propagate to all consumers of those tensors, and the
 // siblings of those tensors.
-class InlinePropagatorSelector : public MaxInfoSpanningTree::Selector {
+class TORCH_CUDA_CU_API InlinePropagatorSelector
+    : public MaxInfoSpanningTree::Selector {
   std::unordered_set<TensorView*> selected_;
 
  public:
@@ -29,7 +30,7 @@ class InlinePropagatorSelector : public MaxInfoSpanningTree::Selector {
   }
 };
 
-class MaxPosCalculator {
+class TORCH_CUDA_CU_API MaxPosCalculator {
   ComputeAtMode mode_ = ComputeAtMode::Standard;
 
   // Root domains in producer that's unmappable to any of its consumers
@@ -67,7 +68,10 @@ class MaxPosCalculator {
   MaxPosCalculator(ComputeAtMode mode);
 };
 
-class InlinePropagator : public MaxInfoSpanningTree::Propagator {
+// Propagate inline position to the `selected` tensors in the DAG. If `selected`
+// is not specified or empty, then propagate to the entire DAG.
+class TORCH_CUDA_CU_API InlinePropagator
+    : public MaxInfoSpanningTree::Propagator {
   // Checks producers and consumers to see what the maximum position in tv is
   // that can be shared across both directions.
   size_t getMaxPosAll(TensorView* tv, bool check_siblings = true);
@@ -94,10 +98,20 @@ class InlinePropagator : public MaxInfoSpanningTree::Propagator {
 
  public:
   InlinePropagator(
-      std::unordered_set<TensorView*> selected,
       TensorView* reference,
       int64_t reference_pos,
-      ComputeAtMode mode);
+      ComputeAtMode mode = ComputeAtMode::Standard,
+      std::unordered_set<TensorView*> selected = {});
+
+  InlinePropagator(
+      TensorView* reference,
+      int64_t reference_pos,
+      std::unordered_set<TensorView*> selected)
+      : InlinePropagator(
+            reference,
+            reference_pos,
+            ComputeAtMode::Standard,
+            selected) {}
 
   ~InlinePropagator() = default;
 
@@ -112,7 +126,8 @@ class InlinePropagator : public MaxInfoSpanningTree::Propagator {
 // the tensors, and it is not needed to compute the max producer position in a
 // specific order. But MaxInfoSpanningTree provides a very convenient API to
 // visit the tensors, so I just use it for cleaner code.
-class MaxProducerPosUpdater : public MaxInfoSpanningTree::Propagator {
+class TORCH_CUDA_CU_API MaxProducerPosUpdater
+    : public MaxInfoSpanningTree::Propagator {
   std::unordered_set<TensorView*> updated_;
 
   void handle(TensorView* tv);
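The header now defaults both `mode` and `selected` on the primary constructor and keeps a delegating overload for call sites that pass a selection set but no mode. A standalone sketch of that overload pattern, under assumed simplified types (`PropagatorSketch`, the trimmed `ComputeAtMode`, and the empty `TensorView` are stand-ins, not the real declarations):

```cpp
#include <cstdint>
#include <unordered_set>
#include <utility>

enum class ComputeAtMode { Standard, BestEffort, MostInlined };
struct TensorView {};  // stand-in for the real class

class PropagatorSketch {
 public:
  // Primary constructor: the trailing parameters are optional, so
  // PropagatorSketch(tv, pos) propagates Standard mode over the whole DAG.
  PropagatorSketch(
      TensorView* reference,
      int64_t reference_pos,
      ComputeAtMode mode = ComputeAtMode::Standard,
      std::unordered_set<TensorView*> selected = {})
      : reference_(reference),
        reference_pos_(reference_pos),
        mode_(mode),
        selected_(std::move(selected)) {}

  // Selection set without a mode: delegate with ComputeAtMode::Standard,
  // mirroring the new InlinePropagator overload above.
  PropagatorSketch(
      TensorView* reference,
      int64_t reference_pos,
      std::unordered_set<TensorView*> selected)
      : PropagatorSketch(
            reference,
            reference_pos,
            ComputeAtMode::Standard,
            std::move(selected)) {}

 private:
  TensorView* reference_;
  int64_t reference_pos_;
  ComputeAtMode mode_;
  std::unordered_set<TensorView*> selected_;
};
```

The delegating overload keeps existing call sites of the form `(reference, pos, selected)` compiling without having to spell out a mode.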
diff --git a/torch/csrc/jit/codegen/cuda/scheduler/pointwise.cpp b/torch/csrc/jit/codegen/cuda/scheduler/pointwise.cpp
index 2f00863776d7f..454f9433d6dfa 100644
--- a/torch/csrc/jit/codegen/cuda/scheduler/pointwise.cpp
+++ b/torch/csrc/jit/codegen/cuda/scheduler/pointwise.cpp
@@ -1,6 +1,7 @@
 #include
 #include
+#include <torch/csrc/jit/codegen/cuda/inline_propagator.h>
 #include
 #include
 #include
@@ -13,6 +14,9 @@
 #include
 
+#include
+#include
+
 namespace torch {
 namespace jit {
 namespace fuser {
@@ -671,6 +675,7 @@ void schedulePointwise(Fusion* fusion, const PointwiseParams& params) {
     }
   }
 
+  int64_t unswitch_pos;
   if (params.break_point) {
     // 2D parallelization scheme
     TORCH_INTERNAL_ASSERT(rhs_i >= 0 && lhs_i >= 0);
@@ -723,8 +728,10 @@ void schedulePointwise(Fusion* fusion, const PointwiseParams& params) {
         // [i-remainder, BIDy{65535} | BIDx, TIDy | Unswitch, Unroll, TIDx]
         reference_tv->split(0, 65535);
         reference_tv->axis(1)->parallelize(ParallelType::BIDy);
+        unswitch_pos = 5;
       } else {
         reference_tv->axis(0)->parallelize(ParallelType::BIDy);
+        unswitch_pos = 4;
       }
     } else {
       // [BIDx | BIDy TIDy | Unswitch, Unroll, TIDx]
@@ -734,8 +741,10 @@ void schedulePointwise(Fusion* fusion, const PointwiseParams& params) {
         // [BIDx | i-remainder, BIDy{65535}, TIDy | Unswitch, Unroll, TIDx]
         reference_tv->split(1, 65535);
         reference_tv->axis(2)->parallelize(ParallelType::BIDy);
+        unswitch_pos = 5;
       } else {
         reference_tv->axis(1)->parallelize(ParallelType::BIDy);
+        unswitch_pos = 4;
       }
     }
   } else {
@@ -747,8 +756,10 @@ void schedulePointwise(Fusion* fusion, const PointwiseParams& params) {
       // [i-remainder, BIDy{65535} | BIDx | Unswitch, Unroll, TIDx]
       reference_tv->split(0, 65535);
       reference_tv->axis(1)->parallelize(ParallelType::BIDy);
+      unswitch_pos = 4;
     } else {
       reference_tv->axis(0)->parallelize(ParallelType::BIDy);
+      unswitch_pos = 3;
     }
   } else {
     // [BIDx | BIDy | Unswitch, Unroll, TIDx]
@@ -757,8 +768,10 @@ void schedulePointwise(Fusion* fusion, const PointwiseParams& params) {
      // [BIDx | i-remainder, BIDy{65535} | Unswitch, Unroll, TIDx]
      reference_tv->split(1, 65535);
      reference_tv->axis(2)->parallelize(ParallelType::BIDy);
+     unswitch_pos = 4;
    } else {
      reference_tv->axis(1)->parallelize(ParallelType::BIDy);
+     unswitch_pos = 3;
    }
  }
 }
@@ -803,10 +816,12 @@ void schedulePointwise(Fusion* fusion, const PointwiseParams& params) {
       reference_tv->axis(1)->parallelize(ParallelType::Unswitch);
       reference_tv->axis(3)->parallelize(ParallelType::TIDx);
     }
+    unswitch_pos = 2;
   }
 
   TransformPropagator propagator(reference_tv);
-  MaxRootDomainInfoSpanningTree(reference_tv).traverse(&propagator);
+  MaxRootDomainInfoSpanningTree spanning_tree(reference_tv);
+  spanning_tree.traverse(&propagator);
   scheduler_utils::parallelizeAllLike(reference_tv, all_tvs);
 
   if (params.vectorize) {
@@ -841,84 +856,31 @@ void schedulePointwise(Fusion* fusion, const PointwiseParams& params) {
     }
   }
 
-  // Compute at into cached inputs
-  std::vector<TensorView*> consumers_of_cached_inputs;
-  // Cache of input, and one of its consumers
-  std::vector<std::pair<TensorView*, TensorView*>> input_cache_and_consumer;
-  {
-    // Avoid duplicate additions, so track what we add
-    std::unordered_set<TensorView*> added;
-    for (auto cached_input : cached_inputs) {
-      auto consumer_tvs = ir_utils::consumerTvsOf(cached_input);
-      TORCH_INTERNAL_ASSERT(
-          consumer_tvs.size(),
-          "Input was not succesfully filtered out for scheduling but wasn't used.");
-
-      // Grab a consumer which will be used for computeAt structure of cached
-      // input into a consumer
-      input_cache_and_consumer.emplace_back(
-          std::make_pair(cached_input, consumer_tvs[0]));
-
-      // Grab all consumers which will be used for inlining computeAt for the
-      // body of the computation (excluding caching inputs/outputs)
-      for (auto consumer_tv : consumer_tvs) {
-        // Don't duplicate
-        if (added.insert(consumer_tv).second) {
-          consumers_of_cached_inputs.emplace_back(consumer_tv);
-        }
-      }
-    }
-  }
-
-  for (auto entry : input_cache_and_consumer) {
-    // Compute at inside unswitch position:
-    auto input_cache = entry.first;
-    auto input_cache_consumer = entry.second;
-
-    auto unswitch_it = std::find_if(
-        input_cache_consumer->domain()->domain().begin(),
-        input_cache_consumer->domain()->domain().end(),
-        [](IterDomain* id) {
-          return id->getParallelType() == ParallelType::Unswitch;
-        });
-    auto unswitch_pos =
-        unswitch_it == input_cache_consumer->domain()->domain().end()
-        ? -1
-        : std::distance(
-              input_cache_consumer->domain()->domain().begin(), unswitch_it) +
-            1;
+  // Begin by inlining at the unswitch position for the entire DAG. The cached
+  // inputs and outputs will keep this inline position, but other tensors will
+  // get a higher position in later inline propagation.
+  InlinePropagator inline_unswitch(
+      reference_tv, unswitch_pos, ComputeAtMode::BestEffort);
+  spanning_tree.traverse(&inline_unswitch);
 
-    input_cache->computeAt(
-        input_cache_consumer, unswitch_pos, ComputeAtMode::BestEffort);
+  // Inline at the innermost position. The CA position of all tensors except
+  // inputs, cached inputs and outputs will be updated.
+  std::unordered_set<TensorView*> inner_most_tensors(
+      all_tvs.begin(), all_tvs.end());
+  for (auto cached_input : cached_inputs) {
+    inner_most_tensors.erase(cached_input);
   }
-
-  // Producers for inlined computeAt
-  std::vector<TensorView*> compute_from = consumers_of_cached_inputs;
-
-  // Consumers for inlined computeAt
-  std::vector<TensorView*> compute_to;
-  // Compute at cached outputs
-  //[BIDx, Unswitch, Vectorization, TIDx]
   for (auto entry : cached_outputs) {
-    auto cached_output = entry.first;
     auto output = entry.second;
-
-    auto unswitch_it = std::find_if(
-        output->domain()->domain().begin(),
-        output->domain()->domain().end(),
-        [](IterDomain* id) {
-          return id->getParallelType() == ParallelType::Unswitch;
-        });
-    auto unswitch_pos = unswitch_it == output->domain()->domain().end()
-        ? -1
-        : std::distance(output->domain()->domain().begin(), unswitch_it) + 1;
-
-    cached_output->computeAt(output, unswitch_pos, ComputeAtMode::BestEffort);
-    compute_to.push_back(cached_output);
+    inner_most_tensors.erase(output);
   }
+  InlinePropagator inline_inner_most(
+      reference_tv, -1, ComputeAtMode::BestEffort, inner_most_tensors);
+  spanning_tree.traverse(&inline_inner_most);
 
-  scheduler_utils::computeAtBetween(
-      compute_from, compute_to, -1, ComputeAtMode::BestEffort);
+  // Fix max producer position
+  MaxProducerPosUpdater updater;
+  spanning_tree.traverse(&updater);
 }
 
 } // namespace cuda
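The deleted scheduler code recovered the unswitch position separately for every cached input and output by scanning the tensor's domain for the `Unswitch` parallel type; the rewrite records `unswitch_pos` once, at the point where the reference tensor is split, and reuses it for one DAG-wide `InlinePropagator` pass. A runnable distillation of the removed lookup, simplified so that a domain is just a vector of parallel types:

```cpp
#include <algorithm>
#include <cstdint>
#include <iostream>
#include <vector>

enum class ParallelType { Serial, BIDx, TIDx, Unswitch, Unroll };

// Same convention as the removed code: one past the Unswitch axis,
// or -1 when the domain has no Unswitch axis.
int64_t unswitchPos(const std::vector<ParallelType>& domain) {
  auto it = std::find_if(domain.begin(), domain.end(), [](ParallelType pt) {
    return pt == ParallelType::Unswitch;
  });
  return it == domain.end() ? -1 : std::distance(domain.begin(), it) + 1;
}

int main() {
  // [BIDx | Unswitch, Unroll, TIDx] -> inline position 2
  std::vector<ParallelType> domain{
      ParallelType::BIDx,
      ParallelType::Unswitch,
      ParallelType::Unroll,
      ParallelType::TIDx};
  std::cout << unswitchPos(domain) << "\n";  // prints 2
}
```

Recording the position up front works because `TransformPropagator` has just replayed every tensor like the reference, so a single position on `reference_tv` is enough for the propagator to map across the DAG.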
diff --git a/torch/csrc/jit/codegen/cuda/test/test_gpu.cpp b/torch/csrc/jit/codegen/cuda/test/test_gpu.cpp
index e16beb70302d2..234d1c35c649f 100644
--- a/torch/csrc/jit/codegen/cuda/test/test_gpu.cpp
+++ b/torch/csrc/jit/codegen/cuda/test/test_gpu.cpp
@@ -1362,26 +1362,26 @@ TEST_F(NVFuserTest, FusionParser_CUDA) {
   // 2. use a fuzzy compare (ignore non-significant whitespaces for example)
   const std::string expected_kernel = R"(
 __global__ void CUDAGeneratedKernel(Tensor<float, 1> T0, Tensor<float, 1> T1, Tensor<float, 1> T3) {
-  int64_t i51;
-  i51 = (((nvfuser_index_t)blockIdx.x) * 128) + ((nvfuser_index_t)threadIdx.x);
-  if ((i51 < T0.size[0])) {
+  int64_t i50;
+  i50 = (((nvfuser_index_t)blockIdx.x) * 128) + ((nvfuser_index_t)threadIdx.x);
+  if ((i50 < T0.size[0])) {
     float T5[1];
     T5[0] = 0;
     T5[0]
-       = T1[i51];
+       = T1[i50];
     float T4[1];
     T4[0] = 0;
     T4[0]
-       = T0[i51];
-    float T6[1];
+       = T0[i50];
     float T2[1];
     T2[0]
       = T4[0]
       * T5[0];
+    float T6[1];
     T6[0]
       = T2[0]
       * T4[0];
-    T3[i51]
+    T3[i50]
       = T6[0];
   }
 }
@@ -19086,9 +19086,9 @@ TEST_F(NVFuserTest, FusionChannelsLastParser_CUDA) {
   // 2. use a fuzzy compare (ignore non-significant whitespaces for example)
   const std::string expected_kernel = R"(
 __global__ void CUDAGeneratedKernel(Tensor<__half, 4> T0, Tensor<__half, 4> T2, Tensor<__half, 4> T7) {
-  int64_t i172;
-  i172 = (((nvfuser_index_t)blockIdx.x) * 128) + ((nvfuser_index_t)threadIdx.x);
-  if ((i172 < (T0.size[0] * (T0.size[1] * (T0.size[2] * T0.size[3]))))) {
+  int64_t i171;
+  i171 = (((nvfuser_index_t)blockIdx.x) * 128) + ((nvfuser_index_t)threadIdx.x);
+  if ((i171 < (T0.size[0] * (T0.size[1] * (T0.size[2] * T0.size[3]))))) {
     __half T9[1];
     T9[0] = 0;
     T9[0]
@@ -19096,8 +19096,7 @@ __global__ void CUDAGeneratedKernel(Tensor<__half, 4> T0, Tensor<__half, 4> T2,
     __half T8[1];
     T8[0] = 0;
     T8[0]
-       = T0[i172];
-    __half T10[1];
+       = T0[i171];
     float T3[1];
     T3[0]
       = __half2float(T9[0]);
@@ -19114,9 +19113,10 @@ __global__ void CUDAGeneratedKernel(Tensor<__half, 4> T0, Tensor<__half, 4> T2,
     float T6[1];
     T6[0]
       = relu(T5[0]);
+    __half T10[1];
     T10[0]
       = __float2half(T6[0]);
-    T7[i172]
+    T7[i171]
       = T10[0];
   }
 }
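Two things change in the golden kernel strings: the index variable numbering shifts down by one (i51 to i50, i172 to i171), and the `T6`/`T10` locals are now declared at their first use rather than earlier in the block. A standalone rendering of the FusionParser kernel's body, with arrays collapsed to scalars but names kept from the expected string, showing the sunk declaration:

```cpp
// Simplified from the updated expected kernel above: T6's allocation now
// appears immediately before its first write instead of ahead of T2's.
void fusedBody(float t0, float t1, float& t3) {
  float T5 = t1;       // load from T1
  float T4 = t0;       // load from T0
  float T2 = T4 * T5;
  float T6 = T2 * T4;  // declared here, at first use
  t3 = T6;             // store to T3
}
```

The renumbered indices and moved declarations appear to be byproducts of the new inline positions; the computed values are unchanged, which is why only the expected strings needed updating.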