From 3de1f1ce1ae835e61f6483c7b79eb54bb482091b Mon Sep 17 00:00:00 2001 From: Xiang Gao Date: Fri, 8 Jul 2022 15:33:37 -0700 Subject: [PATCH 01/16] Use scheduler_utils to cache inputs and outputs in schedulePointwise --- .../codegen/cuda/scheduler/normalization.cpp | 5 +-- .../jit/codegen/cuda/scheduler/pointwise.cpp | 43 +++---------------- .../jit/codegen/cuda/scheduler/reduction.cpp | 5 +-- 3 files changed, 11 insertions(+), 42 deletions(-) diff --git a/torch/csrc/jit/codegen/cuda/scheduler/normalization.cpp b/torch/csrc/jit/codegen/cuda/scheduler/normalization.cpp index 6f3b736579d93..656db1c0ed805 100644 --- a/torch/csrc/jit/codegen/cuda/scheduler/normalization.cpp +++ b/torch/csrc/jit/codegen/cuda/scheduler/normalization.cpp @@ -983,9 +983,8 @@ TORCH_CUDA_CU_API void schedulePersistentKernel( // Cache inputs if unrolled auto cached_inputs = scheduler_utils::cacheInputs(fusion, unroll); - // Cache and fork outputs - std::vector> cached_outputs = - scheduler_utils::cacheAndForkOutputs(fusion, unroll); + // Cache and fork outputs + auto cached_outputs = scheduler_utils::cacheAndForkOutputs(fusion, unroll); // Make sure we don't have global memory set on intermediate tensors from // fusion segmentation diff --git a/torch/csrc/jit/codegen/cuda/scheduler/pointwise.cpp b/torch/csrc/jit/codegen/cuda/scheduler/pointwise.cpp index 75172a4357331..8186986394d36 100644 --- a/torch/csrc/jit/codegen/cuda/scheduler/pointwise.cpp +++ b/torch/csrc/jit/codegen/cuda/scheduler/pointwise.cpp @@ -592,15 +592,11 @@ void schedulePointwise(Fusion* fusion, const PointwiseParams& params) { ir_utils::getReductionOps(fusion /*, ignore_trivial=true */).empty(), "This scheduler only handles pointwise ops."); - // For intermediate outputs, apply cacheFork - auto outs = fusion->outputs(); - for (const auto output : outs) { - if (!output->uses().empty() && output->definition() != nullptr) { - if (output->getValType().value() == ValType::TensorView) { - output->as()->cacheFork(); - } - } - } + // Cache inputs + auto cached_inputs = scheduler_utils::cacheInputs(fusion, true); + + // Cache and fork outputs + auto cached_outputs = scheduler_utils::cacheAndForkOutputs(fusion, true); std::vector input_tvs; { @@ -653,31 +649,6 @@ void schedulePointwise(Fusion* fusion, const PointwiseParams& params) { TORCH_INTERNAL_ASSERT(inner_most_id != nullptr); - // Caches of inputs - std::vector cached_inputs; - - // Output, cacheBefore of output - std::vector> cached_outputs; - - // Track what should be vectorized versus unrolled - std::unordered_set vectorized_tensor; - - // Figure out which inputs to cache for unrolling or vectorization - for (auto inp : input_tvs) { - if (inp->uses().empty() || inp->isFusionOutput()) { - continue; - } - cached_inputs.emplace_back(inp->cacheAfter()); - } - - // Figure out which outputs to cache for unrolling or vectorization - for (auto out : output_tvs) { - if (out->definition() == nullptr) { - continue; - } - cached_outputs.emplace_back(std::make_pair(out, out->cacheBefore())); - } - auto all_tvs = ir_utils::allTvs(fusion); // Merge right side of break point @@ -945,8 +916,8 @@ void schedulePointwise(Fusion* fusion, const PointwiseParams& params) { // Compute at cached outputs //[BIDx, Unswitch, Vectorization, TIDx] for (auto entry : cached_outputs) { - auto cached_output = entry.second; - auto output = entry.first; + auto cached_output = entry.first; + auto output = entry.second; auto unswitch_it = std::find_if( output->domain()->domain().begin(), diff --git a/torch/csrc/jit/codegen/cuda/scheduler/reduction.cpp b/torch/csrc/jit/codegen/cuda/scheduler/reduction.cpp index 1696242f4ff28..84a78bcf927d1 100644 --- a/torch/csrc/jit/codegen/cuda/scheduler/reduction.cpp +++ b/torch/csrc/jit/codegen/cuda/scheduler/reduction.cpp @@ -1002,9 +1002,8 @@ void scheduleReduction(Fusion* fusion, const ReductionParams& rparams) { // Cache inputs if unrolled auto cached_inputs = scheduler_utils::cacheInputs(fusion, unroll); - // Cache and fork outputs - std::vector> cached_outputs = - scheduler_utils::cacheAndForkOutputs(fusion, unroll); + // Cache and fork outputs + auto cached_outputs = scheduler_utils::cacheAndForkOutputs(fusion, unroll); // Make sure we don't have global memory set on intermediate tensors from // fusion segmentation From 586c0cdfbe4fd0d200cd36cf8a65dd7091741aa9 Mon Sep 17 00:00:00 2001 From: Xiang Gao Date: Sat, 9 Jul 2022 00:02:37 -0700 Subject: [PATCH 02/16] save --- .../jit/codegen/cuda/scheduler/pointwise.cpp | 193 ++++++++++++------ 1 file changed, 126 insertions(+), 67 deletions(-) diff --git a/torch/csrc/jit/codegen/cuda/scheduler/pointwise.cpp b/torch/csrc/jit/codegen/cuda/scheduler/pointwise.cpp index 75172a4357331..e1608cd4ebc28 100644 --- a/torch/csrc/jit/codegen/cuda/scheduler/pointwise.cpp +++ b/torch/csrc/jit/codegen/cuda/scheduler/pointwise.cpp @@ -1,6 +1,7 @@ #include #include +#include #include #include #include @@ -13,6 +14,9 @@ #include +#include +#include + namespace torch { namespace jit { namespace fuser { @@ -851,7 +855,8 @@ void schedulePointwise(Fusion* fusion, const PointwiseParams& params) { } TransformPropagator propagator(reference_tv); - MaxRootDomainInfoSpanningTree(reference_tv).traverse(&propagator); + MaxRootDomainInfoSpanningTree spanning_tree(reference_tv); + spanning_tree.traverse(&propagator); scheduler_utils::parallelizeAllLike(reference_tv, all_tvs); if (params.vectorize) { @@ -886,84 +891,138 @@ void schedulePointwise(Fusion* fusion, const PointwiseParams& params) { } } - // Compute at into cached inputs - std::vector consumers_of_cached_inputs; - // Cache of input, and one of its consumers - std::vector> input_cache_and_consumer; - { - // Avoid duplicate additions, so track what we add - std::unordered_set added; - for (auto cached_input : cached_inputs) { - auto consumer_tvs = ir_utils::consumerTvsOf(cached_input); - TORCH_INTERNAL_ASSERT( - consumer_tvs.size(), - "Input was not succesfully filtered out for scheduling but wasn't used."); - - // Grab a consumer which will be used for computeAt structure of cached - // input into a consumer - input_cache_and_consumer.emplace_back( - std::make_pair(cached_input, consumer_tvs[0])); - - // Grab all consumers which will be used for inlining computeAt for the - // body of the computation (excluding caching inputs/outputs) - for (auto consumer_tv : consumer_tvs) { - // Don't duplicate - if (added.insert(consumer_tv).second) { - consumers_of_cached_inputs.emplace_back(consumer_tv); - } - } - } - } - - for (auto entry : input_cache_and_consumer) { - // Compute at inside unswitch position: - auto input_cache = entry.first; - auto input_cache_consumer = entry.second; - + if (true) { auto unswitch_it = std::find_if( - input_cache_consumer->domain()->domain().begin(), - input_cache_consumer->domain()->domain().end(), + reference_tv->domain()->domain().begin(), + reference_tv->domain()->domain().end(), [](IterDomain* id) { return id->getParallelType() == ParallelType::Unswitch; }); - auto unswitch_pos = - unswitch_it == input_cache_consumer->domain()->domain().end() + auto unswitch_pos = unswitch_it == reference_tv->domain()->domain().end() ? -1 - : std::distance( - input_cache_consumer->domain()->domain().begin(), unswitch_it) + + : std::distance(reference_tv->domain()->domain().begin(), unswitch_it) + 1; - input_cache->computeAt( - input_cache_consumer, unswitch_pos, ComputeAtMode::BestEffort); - } + std::unordered_set cached_outputs_set; + std::transform( + cached_outputs.begin(), + cached_outputs.end(), + std::inserter(cached_outputs_set, cached_outputs_set.begin()), + [](auto pair) { return pair->first; }); + + // inline cached inputs and cached outputs at unswitch position + std::unordered_set cached_inputs_and_outputs; + cached_inputs_and_outputs.insert( + cached_inputs.begin(), cached_inputs.end()); + cached_inputs_and_outputs.insert( + cached_outputs_set.begin(), cached_outputs_set.end()); + InlinePropagator inline_unswitch( + cached_inputs_and_outputs, + reference_tv, + unswitch_pos, + ComputeAtMode::BestEffort); + spanning_tree.traverse(&inline_unswitch); + + // inline all tensors between cached inputs and cached outputs at inner most + std::unordered_set consumers_of_cached_inputs; + for (auto cached_input : cached_inputs) { + auto consumer_tvs = ir_utils::consumerTvsOf(cached_input); + consumers_of_cached_inputs.insert( + consumer_tvs.begin(), consumer_tvs.end()); + } + auto all_vals_between = DependencyCheck::getAllValsBetween( + consumers_of_cached_inputs, + {cached_outputs_set.begin(), cached_outputs_set.end()}); + auto all_tvs_between = ir_utils::filterByType(all_vals_between); + std::unordered_set all_tvs_between_set( + all_tvs_between.begin(), all_tvs_between.end()); + InlinePropagator inline_inner( + all_tvs_between_set, reference_tv, -1, ComputeAtMode::BestEffort); + spanning_tree.traverse(&inline_inner); + + // update max producer positions to ensure they are consistent + MaxProducerPosUpdater updater; + spanning_tree.traverse(&updater); + } else { + // Compute at into cached inputs + std::vector consumers_of_cached_inputs; + // Cache of input, and one of its consumers + std::vector> input_cache_and_consumer; + { + // Avoid duplicate additions, so track what we add + std::unordered_set added; + for (auto cached_input : cached_inputs) { + auto consumer_tvs = ir_utils::consumerTvsOf(cached_input); + TORCH_INTERNAL_ASSERT( + consumer_tvs.size(), + "Input was not succesfully filtered out for scheduling but wasn't used."); + + // Grab a consumer which will be used for computeAt structure of cached + // input into a consumer + input_cache_and_consumer.emplace_back( + std::make_pair(cached_input, consumer_tvs[0])); + + // Grab all consumers which will be used for inlining computeAt for the + // body of the computation (excluding caching inputs/outputs) + for (auto consumer_tv : consumer_tvs) { + // Don't duplicate + if (added.insert(consumer_tv).second) { + consumers_of_cached_inputs.emplace_back(consumer_tv); + } + } + } + } - // Producers for inlined computeAt - std::vector compute_from = consumers_of_cached_inputs; + for (auto entry : input_cache_and_consumer) { + // Compute at inside unswitch position: + auto input_cache = entry.first; + auto input_cache_consumer = entry.second; - // Consumers for inlined computeAt - std::vector compute_to; - // Compute at cached outputs - //[BIDx, Unswitch, Vectorization, TIDx] - for (auto entry : cached_outputs) { - auto cached_output = entry.second; - auto output = entry.first; + auto unswitch_it = std::find_if( + input_cache_consumer->domain()->domain().begin(), + input_cache_consumer->domain()->domain().end(), + [](IterDomain* id) { + return id->getParallelType() == ParallelType::Unswitch; + }); + auto unswitch_pos = + unswitch_it == input_cache_consumer->domain()->domain().end() + ? -1 + : std::distance( + input_cache_consumer->domain()->domain().begin(), unswitch_it) + + 1; + + input_cache->computeAt( + input_cache_consumer, unswitch_pos, ComputeAtMode::BestEffort); + } - auto unswitch_it = std::find_if( - output->domain()->domain().begin(), - output->domain()->domain().end(), - [](IterDomain* id) { - return id->getParallelType() == ParallelType::Unswitch; - }); - auto unswitch_pos = unswitch_it == output->domain()->domain().end() - ? -1 - : std::distance(output->domain()->domain().begin(), unswitch_it) + 1; + // Producers for inlined computeAt + std::vector compute_from = consumers_of_cached_inputs; + + // Consumers for inlined computeAt + std::vector compute_to; + // Compute at cached outputs + //[BIDx, Unswitch, Vectorization, TIDx] + for (auto entry : cached_outputs) { + auto cached_output = entry.second; + auto output = entry.first; + + auto unswitch_it = std::find_if( + output->domain()->domain().begin(), + output->domain()->domain().end(), + [](IterDomain* id) { + return id->getParallelType() == ParallelType::Unswitch; + }); + auto unswitch_pos = unswitch_it == output->domain()->domain().end() + ? -1 + : std::distance(output->domain()->domain().begin(), unswitch_it) + 1; - cached_output->computeAt(output, unswitch_pos, ComputeAtMode::BestEffort); - compute_to.push_back(cached_output); - } + cached_output->computeAt(output, unswitch_pos, ComputeAtMode::BestEffort); + compute_to.push_back(cached_output); + } - scheduler_utils::computeAtBetween( - compute_from, compute_to, -1, ComputeAtMode::BestEffort); + scheduler_utils::computeAtBetween( + compute_from, compute_to, -1, ComputeAtMode::BestEffort); + } } } // namespace cuda From 153889a38ef9bf674a17ac68bd48352894c2a3ba Mon Sep 17 00:00:00 2001 From: Xiang Gao Date: Sat, 9 Jul 2022 00:26:05 -0700 Subject: [PATCH 03/16] save --- torch/csrc/jit/codegen/cuda/inline_propagator.cpp | 12 ++++++++---- torch/csrc/jit/codegen/cuda/inline_propagator.h | 2 +- 2 files changed, 9 insertions(+), 5 deletions(-) diff --git a/torch/csrc/jit/codegen/cuda/inline_propagator.cpp b/torch/csrc/jit/codegen/cuda/inline_propagator.cpp index 34eb8a94948ec..1ebef5e5dca19 100644 --- a/torch/csrc/jit/codegen/cuda/inline_propagator.cpp +++ b/torch/csrc/jit/codegen/cuda/inline_propagator.cpp @@ -180,13 +180,13 @@ InlinePropagator::InlinePropagator( reference_(reference), reference_pos_(reference_pos), mode_(mode) { - if (reference_pos < 0) { - reference_pos += int64_t(reference->nDims()) + 1; + if (reference_pos_ < 0) { + reference_pos_ += int64_t(reference->nDims()) + 1; } TORCH_INTERNAL_ASSERT( - reference_pos >= 0 && reference_pos <= reference->nDims(), + reference_pos_ >= 0 && reference_pos_ <= reference->nDims(), "Invalid computeAt axis, received ", - reference_pos, + reference_pos_, " but should be > -", reference->nDims(), " and <= ", @@ -213,6 +213,8 @@ void InlinePropagator::propagateC2P(TensorView* from, TensorView* to) { to_pos >= 0, "Unable to propagate CA position from consumer ", from, + " at ", + from_pos, " to producer ", to, " because this would require replay."); @@ -240,6 +242,8 @@ void InlinePropagator::propagateP2C(TensorView* from, TensorView* to) { to_pos >= 0, "Unable to propagate CA position from producer ", from, + " at ", + from_pos, " to consumer ", to, " because this would require replay."); diff --git a/torch/csrc/jit/codegen/cuda/inline_propagator.h b/torch/csrc/jit/codegen/cuda/inline_propagator.h index 2ed137ac5955e..c6a96e4e9fb14 100644 --- a/torch/csrc/jit/codegen/cuda/inline_propagator.h +++ b/torch/csrc/jit/codegen/cuda/inline_propagator.h @@ -88,7 +88,7 @@ class InlinePropagator : public MaxInfoSpanningTree::Propagator { const MaxPosCalculator max_pos_calc; std::unordered_set selected_; TensorView* reference_; - size_t reference_pos_; + int64_t reference_pos_; ComputeAtMode mode_ = ComputeAtMode::Standard; bool is_first_ = true; From fee489c6a34ec60a59b8b84d4f14b8e1b751fac0 Mon Sep 17 00:00:00 2001 From: Xiang Gao Date: Sat, 9 Jul 2022 01:53:55 -0700 Subject: [PATCH 04/16] save --- torch/csrc/jit/codegen/cuda/scheduler/pointwise.cpp | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/torch/csrc/jit/codegen/cuda/scheduler/pointwise.cpp b/torch/csrc/jit/codegen/cuda/scheduler/pointwise.cpp index 83e0d267ca0df..3f114dd9d6dcd 100644 --- a/torch/csrc/jit/codegen/cuda/scheduler/pointwise.cpp +++ b/torch/csrc/jit/codegen/cuda/scheduler/pointwise.cpp @@ -862,7 +862,7 @@ void schedulePointwise(Fusion* fusion, const PointwiseParams& params) { } } - if (true) { + if (false) { auto unswitch_it = std::find_if( reference_tv->domain()->domain().begin(), reference_tv->domain()->domain().end(), @@ -887,6 +887,11 @@ void schedulePointwise(Fusion* fusion, const PointwiseParams& params) { cached_inputs.begin(), cached_inputs.end()); cached_inputs_and_outputs.insert( cached_outputs_set.begin(), cached_outputs_set.end()); + // std::transform( + // fusion->outputs().begin(), + // fusion->outputs().end(), + // std::inserter(cached_inputs_and_outputs, cached_inputs_and_outputs.begin()), + // [](Val* v) { return v->as(); }); InlinePropagator inline_unswitch( cached_inputs_and_outputs, reference_tv, @@ -903,6 +908,7 @@ void schedulePointwise(Fusion* fusion, const PointwiseParams& params) { } auto all_vals_between = DependencyCheck::getAllValsBetween( consumers_of_cached_inputs, + // fusion->outputs()); {cached_outputs_set.begin(), cached_outputs_set.end()}); auto all_tvs_between = ir_utils::filterByType(all_vals_between); std::unordered_set all_tvs_between_set( From 986216217561d357224eb5267e97734a13eb3d02 Mon Sep 17 00:00:00 2001 From: Xiang Gao Date: Sat, 9 Jul 2022 02:36:35 -0700 Subject: [PATCH 05/16] save --- torch/csrc/jit/codegen/cuda/compute_at.cpp | 4 +- .../jit/codegen/cuda/inline_propagator.cpp | 16 +- .../csrc/jit/codegen/cuda/inline_propagator.h | 4 +- .../jit/codegen/cuda/scheduler/pointwise.cpp | 190 ++++++------------ 4 files changed, 72 insertions(+), 142 deletions(-) diff --git a/torch/csrc/jit/codegen/cuda/compute_at.cpp b/torch/csrc/jit/codegen/cuda/compute_at.cpp index 3793cb1923772..9f842988423bd 100644 --- a/torch/csrc/jit/codegen/cuda/compute_at.cpp +++ b/torch/csrc/jit/codegen/cuda/compute_at.cpp @@ -185,7 +185,7 @@ void ComputeAt::runAt( InlinePropagatorSelector selector(selected); InlinePropagator inline_propagator( - selector.selected(), consumer, consumer_position, mode); + consumer, consumer_position, mode, selector.selected()); MaxProducerPosUpdater updater; MaxRootDomainInfoSpanningTree path(consumer, consumer_position, &selector); @@ -227,7 +227,7 @@ void ComputeAt::runWith( InlinePropagatorSelector selector(selected); InlinePropagator inline_propagator( - selector.selected(), producer, producer_position, mode); + producer, producer_position, mode, selector.selected()); MaxProducerPosUpdater updater; MaxRootDomainInfoSpanningTree path(producer, producer_position, &selector); diff --git a/torch/csrc/jit/codegen/cuda/inline_propagator.cpp b/torch/csrc/jit/codegen/cuda/inline_propagator.cpp index 1ebef5e5dca19..d35c72e3a61d3 100644 --- a/torch/csrc/jit/codegen/cuda/inline_propagator.cpp +++ b/torch/csrc/jit/codegen/cuda/inline_propagator.cpp @@ -148,7 +148,7 @@ size_t InlinePropagator::getMaxPosAll(TensorView* tv, bool check_siblings) { void InlinePropagator::setCAPos(TensorView* tv) { size_t pos = mapped_reference_pos_.at(tv); - if (selected_.count(tv) && !tv->isFusionInput()) { + if ((selected_.empty() || selected_.count(tv)) && !tv->isFusionInput()) { auto max_pos = getMaxPosAll(tv); if (mode_ == ComputeAtMode::Standard) { TORCH_INTERNAL_ASSERT( @@ -171,27 +171,27 @@ void InlinePropagator::setCAPos(TensorView* tv) { } InlinePropagator::InlinePropagator( - std::unordered_set selected, TensorView* reference, int64_t reference_pos, - ComputeAtMode mode) + ComputeAtMode mode, + std::unordered_set selected) : max_pos_calc(mode), selected_(std::move(selected)), reference_(reference), - reference_pos_(reference_pos), mode_(mode) { - if (reference_pos_ < 0) { - reference_pos_ += int64_t(reference->nDims()) + 1; + if (reference_pos < 0) { + reference_pos += int64_t(reference->nDims()) + 1; } TORCH_INTERNAL_ASSERT( - reference_pos_ >= 0 && reference_pos_ <= reference->nDims(), + reference_pos >= 0 && reference_pos <= reference->nDims(), "Invalid computeAt axis, received ", - reference_pos_, + reference_pos, " but should be > -", reference->nDims(), " and <= ", reference->nDims(), "."); + reference_pos_ = reference_pos; } void InlinePropagator::propagateC2P(TensorView* from, TensorView* to) { diff --git a/torch/csrc/jit/codegen/cuda/inline_propagator.h b/torch/csrc/jit/codegen/cuda/inline_propagator.h index c6a96e4e9fb14..b52aa684e2568 100644 --- a/torch/csrc/jit/codegen/cuda/inline_propagator.h +++ b/torch/csrc/jit/codegen/cuda/inline_propagator.h @@ -94,10 +94,10 @@ class InlinePropagator : public MaxInfoSpanningTree::Propagator { public: InlinePropagator( - std::unordered_set selected, TensorView* reference, int64_t reference_pos, - ComputeAtMode mode); + ComputeAtMode mode, + std::unordered_set selected = {}); ~InlinePropagator() = default; diff --git a/torch/csrc/jit/codegen/cuda/scheduler/pointwise.cpp b/torch/csrc/jit/codegen/cuda/scheduler/pointwise.cpp index 3f114dd9d6dcd..d5151f9c717b6 100644 --- a/torch/csrc/jit/codegen/cuda/scheduler/pointwise.cpp +++ b/torch/csrc/jit/codegen/cuda/scheduler/pointwise.cpp @@ -691,6 +691,7 @@ void schedulePointwise(Fusion* fusion, const PointwiseParams& params) { } } + int64_t unswitch_pos; if (params.break_point) { // 2D parallelization scheme TORCH_INTERNAL_ASSERT(rhs_i >= 0 && lhs_i >= 0); @@ -743,8 +744,10 @@ void schedulePointwise(Fusion* fusion, const PointwiseParams& params) { // [i-remainder, BIDy{65535} | BIDx, TIDy | Unswitch, Unroll, TIDx] reference_tv->split(0, 65535); reference_tv->axis(1)->parallelize(ParallelType::BIDy); + unswitch_pos = 5; } else { reference_tv->axis(0)->parallelize(ParallelType::BIDy); + unswitch_pos = 4; } } else { // [BIDx | BIDy TIDy | Unswitch, Unroll, TIDx] @@ -754,8 +757,10 @@ void schedulePointwise(Fusion* fusion, const PointwiseParams& params) { // [BIDx | i-remainder, BIDy{65535}, TIDy | Unswitch, Unroll, TIDx] reference_tv->split(1, 65535); reference_tv->axis(2)->parallelize(ParallelType::BIDy); + unswitch_pos = 5; } else { reference_tv->axis(1)->parallelize(ParallelType::BIDy); + unswitch_pos = 4; } } } else { @@ -767,8 +772,10 @@ void schedulePointwise(Fusion* fusion, const PointwiseParams& params) { // [i-remainder, BIDy{65535} | BIDx | Unswitch, Unroll, TIDx] reference_tv->split(0, 65535); reference_tv->axis(1)->parallelize(ParallelType::BIDy); + unswitch_pos = 4; } else { reference_tv->axis(0)->parallelize(ParallelType::BIDy); + unswitch_pos = 3; } } else { // [BIDx | BIDy | Unswitch, Unroll, TIDx] @@ -777,8 +784,10 @@ void schedulePointwise(Fusion* fusion, const PointwiseParams& params) { // [BIDx | i-remainder, BIDy{65535} | Unswitch, Unroll, TIDx] reference_tv->split(1, 65535); reference_tv->axis(2)->parallelize(ParallelType::BIDy); + unswitch_pos = 4; } else { reference_tv->axis(1)->parallelize(ParallelType::BIDy); + unswitch_pos = 3; } } } @@ -823,6 +832,7 @@ void schedulePointwise(Fusion* fusion, const PointwiseParams& params) { reference_tv->axis(1)->parallelize(ParallelType::Unswitch); reference_tv->axis(3)->parallelize(ParallelType::TIDx); } + unswitch_pos = 2; } TransformPropagator propagator(reference_tv); @@ -862,144 +872,64 @@ void schedulePointwise(Fusion* fusion, const PointwiseParams& params) { } } - if (false) { - auto unswitch_it = std::find_if( - reference_tv->domain()->domain().begin(), - reference_tv->domain()->domain().end(), - [](IterDomain* id) { - return id->getParallelType() == ParallelType::Unswitch; - }); - auto unswitch_pos = unswitch_it == reference_tv->domain()->domain().end() - ? -1 - : std::distance(reference_tv->domain()->domain().begin(), unswitch_it) + - 1; - - std::unordered_set cached_outputs_set; - std::transform( - cached_outputs.begin(), - cached_outputs.end(), - std::inserter(cached_outputs_set, cached_outputs_set.begin()), - [](auto pair) { return pair.first; }); - - // inline cached inputs and cached outputs at unswitch position - std::unordered_set cached_inputs_and_outputs; - cached_inputs_and_outputs.insert( - cached_inputs.begin(), cached_inputs.end()); - cached_inputs_and_outputs.insert( - cached_outputs_set.begin(), cached_outputs_set.end()); - // std::transform( - // fusion->outputs().begin(), - // fusion->outputs().end(), - // std::inserter(cached_inputs_and_outputs, cached_inputs_and_outputs.begin()), - // [](Val* v) { return v->as(); }); - InlinePropagator inline_unswitch( - cached_inputs_and_outputs, - reference_tv, - unswitch_pos, - ComputeAtMode::BestEffort); - spanning_tree.traverse(&inline_unswitch); - - // inline all tensors between cached inputs and cached outputs at inner most dimension - std::unordered_set consumers_of_cached_inputs; + // auto unswitch_it = std::find_if( + // reference_tv->domain()->domain().begin(), + // reference_tv->domain()->domain().end(), + // [](IterDomain* id) { + // return id->getParallelType() == ParallelType::Unswitch; + // }); + // auto unswitch_pos = unswitch_it == reference_tv->domain()->domain().end() + // ? -1 + // : std::distance(reference_tv->domain()->domain().begin(), unswitch_it) + // + + // 1; + + // Begin by inlining at the unswitch position for the entire DAG. The cached + // inputs will keep this inline position, but other tensors will get a higher + // position in later inline propagation. + // TODO: update on who will keep who will be overwritten + InlinePropagator inline_unswitch( + reference_tv, unswitch_pos, ComputeAtMode::BestEffort); + spanning_tree.traverse(&inline_unswitch); + MaxProducerPosUpdater updater; + spanning_tree.traverse(&updater); + + // Compute at into cached inputs + std::vector consumers_of_cached_inputs; + { + // Avoid duplicate additions, so track what we add + std::unordered_set added; for (auto cached_input : cached_inputs) { auto consumer_tvs = ir_utils::consumerTvsOf(cached_input); - consumers_of_cached_inputs.insert( - consumer_tvs.begin(), consumer_tvs.end()); - } - auto all_vals_between = DependencyCheck::getAllValsBetween( - consumers_of_cached_inputs, - // fusion->outputs()); - {cached_outputs_set.begin(), cached_outputs_set.end()}); - auto all_tvs_between = ir_utils::filterByType(all_vals_between); - std::unordered_set all_tvs_between_set( - all_tvs_between.begin(), all_tvs_between.end()); - InlinePropagator inline_inner( - all_tvs_between_set, reference_tv, -1, ComputeAtMode::BestEffort); - spanning_tree.traverse(&inline_inner); - - // update max producer positions to ensure they are consistent - MaxProducerPosUpdater updater; - spanning_tree.traverse(&updater); - } else { - // Compute at into cached inputs - std::vector consumers_of_cached_inputs; - // Cache of input, and one of its consumers - std::vector> input_cache_and_consumer; - { - // Avoid duplicate additions, so track what we add - std::unordered_set added; - for (auto cached_input : cached_inputs) { - auto consumer_tvs = ir_utils::consumerTvsOf(cached_input); - TORCH_INTERNAL_ASSERT( - consumer_tvs.size(), - "Input was not succesfully filtered out for scheduling but wasn't used."); - - // Grab a consumer which will be used for computeAt structure of cached - // input into a consumer - input_cache_and_consumer.emplace_back( - std::make_pair(cached_input, consumer_tvs[0])); - - // Grab all consumers which will be used for inlining computeAt for the - // body of the computation (excluding caching inputs/outputs) - for (auto consumer_tv : consumer_tvs) { - // Don't duplicate - if (added.insert(consumer_tv).second) { - consumers_of_cached_inputs.emplace_back(consumer_tv); - } + TORCH_INTERNAL_ASSERT( + consumer_tvs.size(), + "Input was not succesfully filtered out for scheduling but wasn't used."); + + // Grab all consumers which will be used for inlining computeAt for the + // body of the computation (excluding caching inputs/outputs) + for (auto consumer_tv : consumer_tvs) { + // Don't duplicate + if (added.insert(consumer_tv).second) { + consumers_of_cached_inputs.emplace_back(consumer_tv); } } } + } - for (auto entry : input_cache_and_consumer) { - // Compute at inside unswitch position: - auto input_cache = entry.first; - auto input_cache_consumer = entry.second; - - auto unswitch_it = std::find_if( - input_cache_consumer->domain()->domain().begin(), - input_cache_consumer->domain()->domain().end(), - [](IterDomain* id) { - return id->getParallelType() == ParallelType::Unswitch; - }); - auto unswitch_pos = - unswitch_it == input_cache_consumer->domain()->domain().end() - ? -1 - : std::distance( - input_cache_consumer->domain()->domain().begin(), unswitch_it) + - 1; - - input_cache->computeAt( - input_cache_consumer, unswitch_pos, ComputeAtMode::BestEffort); - } - - // Producers for inlined computeAt - std::vector compute_from = consumers_of_cached_inputs; - - // Consumers for inlined computeAt - std::vector compute_to; - // Compute at cached outputs - //[BIDx, Unswitch, Vectorization, TIDx] - for (auto entry : cached_outputs) { - auto cached_output = entry.first; - auto output = entry.second; - - auto unswitch_it = std::find_if( - output->domain()->domain().begin(), - output->domain()->domain().end(), - [](IterDomain* id) { - return id->getParallelType() == ParallelType::Unswitch; - }); - auto unswitch_pos = unswitch_it == output->domain()->domain().end() - ? -1 - : std::distance(output->domain()->domain().begin(), unswitch_it) + 1; - - cached_output->computeAt(output, unswitch_pos, ComputeAtMode::BestEffort); - compute_to.push_back(cached_output); - } + // Producers for inlined computeAt + std::vector compute_from = consumers_of_cached_inputs; - scheduler_utils::computeAtBetween( - compute_from, compute_to, -1, ComputeAtMode::BestEffort); + // Consumers for inlined computeAt + std::vector compute_to; + // Compute at cached outputs + //[BIDx, Unswitch, Vectorization, TIDx] + for (auto entry : cached_outputs) { + auto cached_output = entry.first; + compute_to.push_back(cached_output); } + + scheduler_utils::computeAtBetween( + compute_from, compute_to, -1, ComputeAtMode::BestEffort); } } // namespace cuda From 6284750145524f14505aa223ff5b12894a682a6e Mon Sep 17 00:00:00 2001 From: Xiang Gao Date: Sat, 9 Jul 2022 02:41:24 -0700 Subject: [PATCH 06/16] cleanup --- torch/csrc/jit/codegen/cuda/scheduler/pointwise.cpp | 12 ------------ 1 file changed, 12 deletions(-) diff --git a/torch/csrc/jit/codegen/cuda/scheduler/pointwise.cpp b/torch/csrc/jit/codegen/cuda/scheduler/pointwise.cpp index d5151f9c717b6..090d0f1ccc877 100644 --- a/torch/csrc/jit/codegen/cuda/scheduler/pointwise.cpp +++ b/torch/csrc/jit/codegen/cuda/scheduler/pointwise.cpp @@ -872,18 +872,6 @@ void schedulePointwise(Fusion* fusion, const PointwiseParams& params) { } } - // auto unswitch_it = std::find_if( - // reference_tv->domain()->domain().begin(), - // reference_tv->domain()->domain().end(), - // [](IterDomain* id) { - // return id->getParallelType() == ParallelType::Unswitch; - // }); - // auto unswitch_pos = unswitch_it == reference_tv->domain()->domain().end() - // ? -1 - // : std::distance(reference_tv->domain()->domain().begin(), unswitch_it) - // + - // 1; - // Begin by inlining at the unswitch position for the entire DAG. The cached // inputs will keep this inline position, but other tensors will get a higher // position in later inline propagation. From 22a8cad1bec7dd0882e2751138e0733b3a59f910 Mon Sep 17 00:00:00 2001 From: Xiang Gao Date: Sat, 9 Jul 2022 14:28:37 -0700 Subject: [PATCH 07/16] save --- .../jit/codegen/cuda/scheduler/pointwise.cpp | 30 +++++++++++++++++++ 1 file changed, 30 insertions(+) diff --git a/torch/csrc/jit/codegen/cuda/scheduler/pointwise.cpp b/torch/csrc/jit/codegen/cuda/scheduler/pointwise.cpp index 090d0f1ccc877..bbec1e9f3de55 100644 --- a/torch/csrc/jit/codegen/cuda/scheduler/pointwise.cpp +++ b/torch/csrc/jit/codegen/cuda/scheduler/pointwise.cpp @@ -882,6 +882,36 @@ void schedulePointwise(Fusion* fusion, const PointwiseParams& params) { MaxProducerPosUpdater updater; spanning_tree.traverse(&updater); + if (true) { + std::unordered_set inner_most_tensors( + all_tvs.begin(), all_tvs.end()); + for (auto cached_input : cached_inputs) { + inner_most_tensors.erase(cached_input); + } + for (auto input : fusion->inputs()) { + if (input->isA()) { + inner_most_tensors.erase(input->as()); + } + } + // for (auto output : fusion->outputs()) { + // if (output->isA()) { + // inner_most_tensors.erase(output->as()); + // } + // } + for (auto entry : cached_outputs) { + auto cached_output = entry.first; + auto output = entry.second; + inner_most_tensors.erase(cached_output); + inner_most_tensors.erase(output); + } + InlinePropagator inline_inner_most( + reference_tv, -1, ComputeAtMode::BestEffort, inner_most_tensors); + spanning_tree.traverse(&inline_inner_most); + MaxProducerPosUpdater updater; + spanning_tree.traverse(&updater); + return; + } + // Compute at into cached inputs std::vector consumers_of_cached_inputs; { From 3b39284fbb9943a8b2452670e7a786558b8d4da1 Mon Sep 17 00:00:00 2001 From: Xiang Gao Date: Sat, 9 Jul 2022 14:37:43 -0700 Subject: [PATCH 08/16] cleanup --- .../jit/codegen/cuda/scheduler/pointwise.cpp | 86 +++++-------------- 1 file changed, 21 insertions(+), 65 deletions(-) diff --git a/torch/csrc/jit/codegen/cuda/scheduler/pointwise.cpp b/torch/csrc/jit/codegen/cuda/scheduler/pointwise.cpp index bbec1e9f3de55..c8e3180ade435 100644 --- a/torch/csrc/jit/codegen/cuda/scheduler/pointwise.cpp +++ b/torch/csrc/jit/codegen/cuda/scheduler/pointwise.cpp @@ -872,82 +872,38 @@ void schedulePointwise(Fusion* fusion, const PointwiseParams& params) { } } - // Begin by inlining at the unswitch position for the entire DAG. The cached - // inputs will keep this inline position, but other tensors will get a higher - // position in later inline propagation. - // TODO: update on who will keep who will be overwritten + // Begin by inlining at the unswitch position for the entire DAG. The inputs, + // cached inputs, outputs and cached outputs will keep this inline position, + // but other tensors will get a higher position in later inline propagation. InlinePropagator inline_unswitch( reference_tv, unswitch_pos, ComputeAtMode::BestEffort); spanning_tree.traverse(&inline_unswitch); - MaxProducerPosUpdater updater; - spanning_tree.traverse(&updater); - if (true) { - std::unordered_set inner_most_tensors( - all_tvs.begin(), all_tvs.end()); - for (auto cached_input : cached_inputs) { - inner_most_tensors.erase(cached_input); - } - for (auto input : fusion->inputs()) { - if (input->isA()) { - inner_most_tensors.erase(input->as()); - } - } - // for (auto output : fusion->outputs()) { - // if (output->isA()) { - // inner_most_tensors.erase(output->as()); - // } - // } - for (auto entry : cached_outputs) { - auto cached_output = entry.first; - auto output = entry.second; - inner_most_tensors.erase(cached_output); - inner_most_tensors.erase(output); - } - InlinePropagator inline_inner_most( - reference_tv, -1, ComputeAtMode::BestEffort, inner_most_tensors); - spanning_tree.traverse(&inline_inner_most); - MaxProducerPosUpdater updater; - spanning_tree.traverse(&updater); - return; + // Inline at the inner most position. The CA position of all tensors except + // inputs, cached inputs, outputs and cached outputs will be updated. + std::unordered_set inner_most_tensors( + all_tvs.begin(), all_tvs.end()); + for (auto cached_input : cached_inputs) { + inner_most_tensors.erase(cached_input); } - - // Compute at into cached inputs - std::vector consumers_of_cached_inputs; - { - // Avoid duplicate additions, so track what we add - std::unordered_set added; - for (auto cached_input : cached_inputs) { - auto consumer_tvs = ir_utils::consumerTvsOf(cached_input); - TORCH_INTERNAL_ASSERT( - consumer_tvs.size(), - "Input was not succesfully filtered out for scheduling but wasn't used."); - - // Grab all consumers which will be used for inlining computeAt for the - // body of the computation (excluding caching inputs/outputs) - for (auto consumer_tv : consumer_tvs) { - // Don't duplicate - if (added.insert(consumer_tv).second) { - consumers_of_cached_inputs.emplace_back(consumer_tv); - } - } + for (auto input : fusion->inputs()) { + if (input->isA()) { + inner_most_tensors.erase(input->as()); } } - - // Producers for inlined computeAt - std::vector compute_from = consumers_of_cached_inputs; - - // Consumers for inlined computeAt - std::vector compute_to; - // Compute at cached outputs - //[BIDx, Unswitch, Vectorization, TIDx] for (auto entry : cached_outputs) { auto cached_output = entry.first; - compute_to.push_back(cached_output); + auto output = entry.second; + inner_most_tensors.erase(cached_output); + inner_most_tensors.erase(output); } + InlinePropagator inline_inner_most( + reference_tv, -1, ComputeAtMode::BestEffort, inner_most_tensors); + spanning_tree.traverse(&inline_inner_most); - scheduler_utils::computeAtBetween( - compute_from, compute_to, -1, ComputeAtMode::BestEffort); + // Fix max producer position + MaxProducerPosUpdater updater; + spanning_tree.traverse(&updater); } } // namespace cuda From dc176d563761c9acbdf2643741218785aaf1a4d9 Mon Sep 17 00:00:00 2001 From: Xiang Gao Date: Sat, 9 Jul 2022 15:08:14 -0700 Subject: [PATCH 09/16] more cleanup --- torch/csrc/jit/codegen/cuda/inline_propagator.h | 2 +- torch/csrc/jit/codegen/cuda/scheduler/pointwise.cpp | 11 +++-------- 2 files changed, 4 insertions(+), 9 deletions(-) diff --git a/torch/csrc/jit/codegen/cuda/inline_propagator.h b/torch/csrc/jit/codegen/cuda/inline_propagator.h index b52aa684e2568..59431de91fec4 100644 --- a/torch/csrc/jit/codegen/cuda/inline_propagator.h +++ b/torch/csrc/jit/codegen/cuda/inline_propagator.h @@ -88,7 +88,7 @@ class InlinePropagator : public MaxInfoSpanningTree::Propagator { const MaxPosCalculator max_pos_calc; std::unordered_set selected_; TensorView* reference_; - int64_t reference_pos_; + size_t reference_pos_; ComputeAtMode mode_ = ComputeAtMode::Standard; bool is_first_ = true; diff --git a/torch/csrc/jit/codegen/cuda/scheduler/pointwise.cpp b/torch/csrc/jit/codegen/cuda/scheduler/pointwise.cpp index c8e3180ade435..d385fa3808cc5 100644 --- a/torch/csrc/jit/codegen/cuda/scheduler/pointwise.cpp +++ b/torch/csrc/jit/codegen/cuda/scheduler/pointwise.cpp @@ -872,9 +872,9 @@ void schedulePointwise(Fusion* fusion, const PointwiseParams& params) { } } - // Begin by inlining at the unswitch position for the entire DAG. The inputs, - // cached inputs, outputs and cached outputs will keep this inline position, - // but other tensors will get a higher position in later inline propagation. + // Begin by inlining at the unswitch position for the entire DAG. The cached + // inputs, outputs and cached outputs will keep this inline position, but + // other tensors will get a higher position in later inline propagation. InlinePropagator inline_unswitch( reference_tv, unswitch_pos, ComputeAtMode::BestEffort); spanning_tree.traverse(&inline_unswitch); @@ -886,11 +886,6 @@ void schedulePointwise(Fusion* fusion, const PointwiseParams& params) { for (auto cached_input : cached_inputs) { inner_most_tensors.erase(cached_input); } - for (auto input : fusion->inputs()) { - if (input->isA()) { - inner_most_tensors.erase(input->as()); - } - } for (auto entry : cached_outputs) { auto cached_output = entry.first; auto output = entry.second; From 4301cb4e1d86a3a86cad8f8eb43b44fd6af0ffe3 Mon Sep 17 00:00:00 2001 From: Xiang Gao Date: Sat, 9 Jul 2022 15:10:57 -0700 Subject: [PATCH 10/16] doc --- torch/csrc/jit/codegen/cuda/inline_propagator.h | 2 ++ 1 file changed, 2 insertions(+) diff --git a/torch/csrc/jit/codegen/cuda/inline_propagator.h b/torch/csrc/jit/codegen/cuda/inline_propagator.h index 59431de91fec4..3c18f531c241f 100644 --- a/torch/csrc/jit/codegen/cuda/inline_propagator.h +++ b/torch/csrc/jit/codegen/cuda/inline_propagator.h @@ -67,6 +67,8 @@ class MaxPosCalculator { MaxPosCalculator(ComputeAtMode mode); }; +// Propagate inline position to the `selected` tensors in the DAG. If `selected` +// is not specified or empty, then propagate to the entire DAG. class InlinePropagator : public MaxInfoSpanningTree::Propagator { // Checks producers and consumers to see what the maximum position in tv is // that can be shared across both directions. From be235237d6ae4b217945084ad418cdb5a6212d97 Mon Sep 17 00:00:00 2001 From: Xiang Date: Mon, 11 Jul 2022 02:57:10 -0400 Subject: [PATCH 11/16] fix merge error --- .../jit/codegen/cuda/scheduler/pointwise.cpp | 16 ---------------- 1 file changed, 16 deletions(-) diff --git a/torch/csrc/jit/codegen/cuda/scheduler/pointwise.cpp b/torch/csrc/jit/codegen/cuda/scheduler/pointwise.cpp index 8faa5dc5c1c56..37f9f06e306f8 100644 --- a/torch/csrc/jit/codegen/cuda/scheduler/pointwise.cpp +++ b/torch/csrc/jit/codegen/cuda/scheduler/pointwise.cpp @@ -642,22 +642,6 @@ void schedulePointwise(Fusion* fusion, const PointwiseParams& params) { // Track what should be vectorized versus unrolled std::unordered_set vectorized_tensor; - // Figure out which inputs to cache for unrolling or vectorization - for (auto inp : input_tvs) { - if (inp->uses().empty() || inp->isFusionOutput()) { - continue; - } - cached_inputs.emplace_back(inp->cacheAfter()); - } - - // Figure out which outputs to cache for unrolling or vectorization - for (auto out : output_tvs) { - if (out->definition() == nullptr) { - continue; - } - cached_outputs.emplace_back(std::make_pair(out, out->cacheBefore())); - } - auto all_tvs = ir_utils::allTvs(fusion); // Merge right side of break point From 2517630d31612240e7b3d77f02e6b08c01ec04a2 Mon Sep 17 00:00:00 2001 From: Xiang Gao Date: Mon, 11 Jul 2022 00:03:26 -0700 Subject: [PATCH 12/16] fix more merge error --- torch/csrc/jit/codegen/cuda/scheduler/pointwise.cpp | 9 --------- 1 file changed, 9 deletions(-) diff --git a/torch/csrc/jit/codegen/cuda/scheduler/pointwise.cpp b/torch/csrc/jit/codegen/cuda/scheduler/pointwise.cpp index 37f9f06e306f8..2f00863776d7f 100644 --- a/torch/csrc/jit/codegen/cuda/scheduler/pointwise.cpp +++ b/torch/csrc/jit/codegen/cuda/scheduler/pointwise.cpp @@ -633,15 +633,6 @@ void schedulePointwise(Fusion* fusion, const PointwiseParams& params) { reference_tv != nullptr, "Could not find a fully broadcasted output to reference schedule on."); - // Caches of inputs - std::vector cached_inputs; - - // Output, cacheBefore of output - std::vector> cached_outputs; - - // Track what should be vectorized versus unrolled - std::unordered_set vectorized_tensor; - auto all_tvs = ir_utils::allTvs(fusion); // Merge right side of break point From 5299efba483d4e6083beb4501fc8d32abb5e2fc0 Mon Sep 17 00:00:00 2001 From: Xiang Gao Date: Wed, 13 Jul 2022 00:18:21 -0700 Subject: [PATCH 13/16] save --- torch/csrc/jit/codegen/cuda/inline_propagator.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch/csrc/jit/codegen/cuda/inline_propagator.h b/torch/csrc/jit/codegen/cuda/inline_propagator.h index 3c18f531c241f..8c8449bd1b44f 100644 --- a/torch/csrc/jit/codegen/cuda/inline_propagator.h +++ b/torch/csrc/jit/codegen/cuda/inline_propagator.h @@ -98,7 +98,7 @@ class InlinePropagator : public MaxInfoSpanningTree::Propagator { InlinePropagator( TensorView* reference, int64_t reference_pos, - ComputeAtMode mode, + ComputeAtMode mode = ComputeAtMode::Standard, std::unordered_set selected = {}); ~InlinePropagator() = default; From 72e4a4e6f7ed41e64d0c8794f2b907bebe19d4a9 Mon Sep 17 00:00:00 2001 From: Xiang Gao Date: Wed, 13 Jul 2022 00:23:05 -0700 Subject: [PATCH 14/16] TORCH_CUDA_CU_API --- torch/csrc/jit/codegen/cuda/inline_propagator.h | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/torch/csrc/jit/codegen/cuda/inline_propagator.h b/torch/csrc/jit/codegen/cuda/inline_propagator.h index 8c8449bd1b44f..e06939b31902b 100644 --- a/torch/csrc/jit/codegen/cuda/inline_propagator.h +++ b/torch/csrc/jit/codegen/cuda/inline_propagator.h @@ -14,7 +14,8 @@ namespace cuda { // Simple selector that only propagates across tensor views in the provided // unordered_set. Will also propagate to all consumers of those tensors, and the // siblings of those tensors. -class InlinePropagatorSelector : public MaxInfoSpanningTree::Selector { +class TORCH_CUDA_CU_API InlinePropagatorSelector + : public MaxInfoSpanningTree::Selector { std::unordered_set selected_; public: @@ -29,7 +30,7 @@ class InlinePropagatorSelector : public MaxInfoSpanningTree::Selector { } }; -class MaxPosCalculator { +class TORCH_CUDA_CU_API MaxPosCalculator { ComputeAtMode mode_ = ComputeAtMode::Standard; // Root domains in producer that's unmappable to any of its consumers @@ -69,7 +70,8 @@ class MaxPosCalculator { // Propagate inline position to the `selected` tensors in the DAG. If `selected` // is not specified or empty, then propagate to the entire DAG. -class InlinePropagator : public MaxInfoSpanningTree::Propagator { +class TORCH_CUDA_CU_API InlinePropagator + : public MaxInfoSpanningTree::Propagator { // Checks producers and consumers to see what the maximum position in tv is // that can be shared across both directions. size_t getMaxPosAll(TensorView* tv, bool check_siblings = true); @@ -114,7 +116,8 @@ class InlinePropagator : public MaxInfoSpanningTree::Propagator { // the tensors, and it is not needed to compute the max producer position in a // specific order. But MaxInfoSpanningTree provides a very convenient API to // visit the tensors, so I just use it for cleaner code. -class MaxProducerPosUpdater : public MaxInfoSpanningTree::Propagator { +class TORCH_CUDA_CU_API MaxProducerPosUpdater + : public MaxInfoSpanningTree::Propagator { std::unordered_set updated_; void handle(TensorView* tv); From 56b36d84d9ebf0d47df08ae62c676f119b376185 Mon Sep 17 00:00:00 2001 From: Xiang Gao Date: Wed, 13 Jul 2022 13:05:36 -0700 Subject: [PATCH 15/16] a new ctor for InlinePropagator --- torch/csrc/jit/codegen/cuda/inline_propagator.h | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/torch/csrc/jit/codegen/cuda/inline_propagator.h b/torch/csrc/jit/codegen/cuda/inline_propagator.h index e06939b31902b..d7cd1f82a8d90 100644 --- a/torch/csrc/jit/codegen/cuda/inline_propagator.h +++ b/torch/csrc/jit/codegen/cuda/inline_propagator.h @@ -103,6 +103,16 @@ class TORCH_CUDA_CU_API InlinePropagator ComputeAtMode mode = ComputeAtMode::Standard, std::unordered_set selected = {}); + InlinePropagator( + TensorView* reference, + int64_t reference_pos, + std::unordered_set selected) + : InlinePropagator( + reference, + reference_pos, + ComputeAtMode::Standard, + selected) {} + ~InlinePropagator() = default; // Actually propagate the transformations for the inlining pass. Uses the From 24397fbfb3712a38d8af1f9266f28fd3052bdb39 Mon Sep 17 00:00:00 2001 From: Xiang Gao Date: Wed, 13 Jul 2022 16:13:36 -0700 Subject: [PATCH 16/16] do not exclude cached_outputs --- .../jit/codegen/cuda/scheduler/pointwise.cpp | 8 +++--- torch/csrc/jit/codegen/cuda/test/test_gpu.cpp | 26 +++++++++---------- 2 files changed, 16 insertions(+), 18 deletions(-) diff --git a/torch/csrc/jit/codegen/cuda/scheduler/pointwise.cpp b/torch/csrc/jit/codegen/cuda/scheduler/pointwise.cpp index 4d91929312ff6..454f9433d6dfa 100644 --- a/torch/csrc/jit/codegen/cuda/scheduler/pointwise.cpp +++ b/torch/csrc/jit/codegen/cuda/scheduler/pointwise.cpp @@ -857,23 +857,21 @@ void schedulePointwise(Fusion* fusion, const PointwiseParams& params) { } // Begin by inlining at the unswitch position for the entire DAG. The cached - // inputs, outputs and cached outputs will keep this inline position, but - // other tensors will get a higher position in later inline propagation. + // inputs, and outputs will keep this inline position, but other tensors will + // get a higher position in later inline propagation. InlinePropagator inline_unswitch( reference_tv, unswitch_pos, ComputeAtMode::BestEffort); spanning_tree.traverse(&inline_unswitch); // Inline at the inner most position. The CA position of all tensors except - // inputs, cached inputs, outputs and cached outputs will be updated. + // inputs, cached inputs and outputs will be updated. std::unordered_set inner_most_tensors( all_tvs.begin(), all_tvs.end()); for (auto cached_input : cached_inputs) { inner_most_tensors.erase(cached_input); } for (auto entry : cached_outputs) { - auto cached_output = entry.first; auto output = entry.second; - inner_most_tensors.erase(cached_output); inner_most_tensors.erase(output); } InlinePropagator inline_inner_most( diff --git a/torch/csrc/jit/codegen/cuda/test/test_gpu.cpp b/torch/csrc/jit/codegen/cuda/test/test_gpu.cpp index e16beb70302d2..234d1c35c649f 100644 --- a/torch/csrc/jit/codegen/cuda/test/test_gpu.cpp +++ b/torch/csrc/jit/codegen/cuda/test/test_gpu.cpp @@ -1362,26 +1362,26 @@ TEST_F(NVFuserTest, FusionParser_CUDA) { // 2. use a fuzzy compare (ignore non-significant whitespaces for example) const std::string expected_kernel = R"( __global__ void CUDAGeneratedKernel(Tensor T0, Tensor T1, Tensor T3) { - int64_t i51; - i51 = (((nvfuser_index_t)blockIdx.x) * 128) + ((nvfuser_index_t)threadIdx.x); - if ((i51 < T0.size[0])) { + int64_t i50; + i50 = (((nvfuser_index_t)blockIdx.x) * 128) + ((nvfuser_index_t)threadIdx.x); + if ((i50 < T0.size[0])) { float T5[1]; T5[0] = 0; T5[0] - = T1[i51]; + = T1[i50]; float T4[1]; T4[0] = 0; T4[0] - = T0[i51]; - float T6[1]; + = T0[i50]; float T2[1]; T2[0] = T4[0] * T5[0]; + float T6[1]; T6[0] = T2[0] * T4[0]; - T3[i51] + T3[i50] = T6[0]; } } @@ -19086,9 +19086,9 @@ TEST_F(NVFuserTest, FusionChannelsLastParser_CUDA) { // 2. use a fuzzy compare (ignore non-significant whitespaces for example) const std::string expected_kernel = R"( __global__ void CUDAGeneratedKernel(Tensor<__half, 4> T0, Tensor<__half, 4> T2, Tensor<__half, 4> T7) { - int64_t i172; - i172 = (((nvfuser_index_t)blockIdx.x) * 128) + ((nvfuser_index_t)threadIdx.x); - if ((i172 < (T0.size[0] * (T0.size[1] * (T0.size[2] * T0.size[3]))))) { + int64_t i171; + i171 = (((nvfuser_index_t)blockIdx.x) * 128) + ((nvfuser_index_t)threadIdx.x); + if ((i171 < (T0.size[0] * (T0.size[1] * (T0.size[2] * T0.size[3]))))) { __half T9[1]; T9[0] = 0; T9[0] @@ -19096,8 +19096,7 @@ __global__ void CUDAGeneratedKernel(Tensor<__half, 4> T0, Tensor<__half, 4> T2, __half T8[1]; T8[0] = 0; T8[0] - = T0[i172]; - __half T10[1]; + = T0[i171]; float T3[1]; T3[0] = __half2float(T9[0]); @@ -19114,9 +19113,10 @@ __global__ void CUDAGeneratedKernel(Tensor<__half, 4> T0, Tensor<__half, 4> T2, float T6[1]; T6[0] = relu(T5[0]); + __half T10[1]; T10[0] = __float2half(T6[0]); - T7[i172] + T7[i171] = T10[0]; } }