From 38c7f3cf69ea58cc9480b0621506bbfd90a7c9d3 Mon Sep 17 00:00:00 2001 From: "Gao, Xiang" Date: Fri, 1 Jul 2022 13:02:43 -0700 Subject: [PATCH] InlinePropagator please don't replay (#1797) This PR makes `InlinePropagator` just set compute-at positions. It will not replay any tensor. If you want to replay, please use `TransformPropagator` and friends to do so. Currently, `InlinePropagator` is already asserting no replay for standard and best effort compute at. So this PR is mostly about making most inlined compute at works as well. This PR also does a lot of cleanups to remove the word "replay" from comments and variable and function names from `InlinePropagator`. I also cleaned up `recordReplayedPos` and `retrieveReplayedPos`, now the logic is much easier to understand. --- torch/csrc/jit/codegen/cuda/compute_at.cpp | 21 ++- .../jit/codegen/cuda/inline_propagator.cpp | 174 ++++++------------ .../csrc/jit/codegen/cuda/inline_propagator.h | 36 ++-- .../jit/codegen/cuda/ir_interface_nodes.h | 2 + .../jit/codegen/cuda/transform_replay.cpp | 57 ++++++ .../csrc/jit/codegen/cuda/transform_replay.h | 7 + 6 files changed, 156 insertions(+), 141 deletions(-) diff --git a/torch/csrc/jit/codegen/cuda/compute_at.cpp b/torch/csrc/jit/codegen/cuda/compute_at.cpp index ed0f9fb271d574..3793cb19237727 100644 --- a/torch/csrc/jit/codegen/cuda/compute_at.cpp +++ b/torch/csrc/jit/codegen/cuda/compute_at.cpp @@ -184,13 +184,20 @@ void ComputeAt::runAt( auto selected = getPropagationSubgraph(producer, consumer); InlinePropagatorSelector selector(selected); - TransformPropagator propagator(consumer, consumer_position); InlinePropagator inline_propagator( selector.selected(), consumer, consumer_position, mode); MaxProducerPosUpdater updater; MaxRootDomainInfoSpanningTree path(consumer, consumer_position, &selector); - path.traverse(&propagator); + + if (mode == ComputeAtMode::MostInlined) { + MostInlinedTransformPropagator propagator; + path.traverse(&propagator); + } else { + 
TransformPropagator propagator(consumer, consumer_position); + path.traverse(&propagator); + } + path.traverse(&inline_propagator); path.traverse(&updater); } @@ -219,13 +226,19 @@ void ComputeAt::runWith( auto selected = getPropagationSubgraph(producer, consumer); InlinePropagatorSelector selector(selected); - TransformPropagator propagator(producer, producer_position); InlinePropagator inline_propagator( selector.selected(), producer, producer_position, mode); MaxProducerPosUpdater updater; MaxRootDomainInfoSpanningTree path(producer, producer_position, &selector); - path.traverse(&propagator); + + if (mode == ComputeAtMode::MostInlined) { + MostInlinedTransformPropagator propagator; + path.traverse(&propagator); + } else { + TransformPropagator propagator(producer, producer_position); + path.traverse(&propagator); + } path.traverse(&inline_propagator); path.traverse(&updater); } diff --git a/torch/csrc/jit/codegen/cuda/inline_propagator.cpp b/torch/csrc/jit/codegen/cuda/inline_propagator.cpp index 195ef3e67a188d..5b0cb66d1d7331 100644 --- a/torch/csrc/jit/codegen/cuda/inline_propagator.cpp +++ b/torch/csrc/jit/codegen/cuda/inline_propagator.cpp @@ -178,22 +178,11 @@ size_t InlinePropagator::getMaxPosAll(TensorView* tv) { return max_pos; } -size_t InlinePropagator::adjustComputeAtPos(TensorView* tv, size_t pos) { - pos = std::min(pos, getMaxPosAll(tv)); - - // hoist inner most broadcast - while (pos > 0 && tv->axis(pos - 1)->isBroadcast()) { - pos--; - } - - return pos; -} - -size_t InlinePropagator::getReplayPosPasC( +size_t InlinePropagator::getFromPosPasC( TensorView* producer, TensorView* consumer) { size_t max_pos = max_pos_calc.getMaxPosPasC(producer, consumer); - size_t pos = retrieveReplayedPos(consumer); + size_t pos = mapped_reference_pos_.at(consumer); if (mode_ == ComputeAtMode::BestEffort) { return std::min(pos, max_pos); @@ -203,10 +192,10 @@ size_t InlinePropagator::getReplayPosPasC( TORCH_INTERNAL_ASSERT( pos <= max_pos, - "Invalid compute at 
position detected in compute at when trying to replay producer: ", - producer, - " as consumer: ", + "Invalid compute at position detected in compute at when trying to propagate the CA position from consumer: ", consumer, + " to producer: ", + producer, " tried to do this at position: ", pos, " but max position that's allowed is ", @@ -214,11 +203,11 @@ size_t InlinePropagator::getReplayPosPasC( return pos; } -size_t InlinePropagator::getReplayPosCasP( +size_t InlinePropagator::getFromPosCasP( TensorView* consumer, TensorView* producer) { size_t max_pos = max_pos_calc.getMaxPosCasP(consumer, producer); - size_t pos = retrieveReplayedPos(producer); + size_t pos = mapped_reference_pos_.at(producer); if (mode_ == ComputeAtMode::BestEffort) { return std::min(pos, max_pos); @@ -228,10 +217,10 @@ size_t InlinePropagator::getReplayPosCasP( TORCH_INTERNAL_ASSERT( pos <= max_pos, - "Invalid compute at position detected in compute at when trying to replay consumer: ", - consumer, - " as producer: ", + "Invalid compute at position detected in compute at when trying to propagate the CA position from producer: ", producer, + " to consumer: ", + consumer, " tried to do this at position: ", pos, " but max position that's allowed is ", @@ -239,31 +228,17 @@ size_t InlinePropagator::getReplayPosCasP( return pos; } -void InlinePropagator::recordReplayedPos(TensorView* tv, size_t pos) { - if (selected_.count(tv)) { - auto new_pos = adjustComputeAtPos(tv, pos); - if (pos != new_pos) { - replayed_pos_[tv] = pos; - pos = new_pos; +void InlinePropagator::setCAPos(TensorView* tv, size_t pos) { + if (selected_.count(tv) && !tv->isFusionInput()) { + pos = std::min(pos, getMaxPosAll(tv)); + // hoist inner most broadcast + while (pos > 0 && tv->axis(pos - 1)->isBroadcast()) { + pos--; } - if (!tv->isFusionInput()) { - tv->setComputeAt(pos); - } else { - replayed_pos_[tv] = pos; - } - } else { - replayed_pos_[tv] = pos; + tv->setComputeAt(pos); } } -size_t 
InlinePropagator::retrieveReplayedPos(TensorView* tv) { - auto it = replayed_pos_.find(tv); - if (it != replayed_pos_.end()) { - return it->second; - } - return tv->getComputeAtPosition(); -} - InlinePropagator::InlinePropagator( std::unordered_set selected, TensorView* reference, @@ -288,101 +263,62 @@ InlinePropagator::InlinePropagator( "."); } -namespace { - -// Make sure if tv is set to new_td it doesn't violate set compute at and max -// produce at positions. -bool validateDomain(TensorView* tv, TensorDomain* new_td) { - auto first_mismatch = - BestEffortReplay::findFirstMismatchedID(tv->domain(), new_td); - return first_mismatch >= (int)tv->getMaxProducerPosition() && - first_mismatch >= (int)tv->getComputeAtPosition(); -} - -} // namespace - void InlinePropagator::propagateTvPasC(TensorView* from, TensorView* to) { if (is_first_) { is_first_ = false; - recordReplayedPos(reference_, reference_pos_); + setCAPos(reference_, reference_pos_); + mapped_reference_pos_[reference_] = reference_pos_; } - int pos = getReplayPosPasC(to, from); + int from_pos = getFromPosPasC(to, from); auto to_pos = - TransformReplay::getMatchedLeafPosWithoutReplayPasC(to, from, pos); - if (mode_ != ComputeAtMode::MostInlined) { - TORCH_CHECK( - to_pos >= 0, - "Unable to propagate CA position from consumer ", - from, - " to producer ", - to, - " because this would require replay."); - } - if (to_pos < 0) { - auto replay = TransformReplay::replayPasC(to, from, pos); - TORCH_INTERNAL_ASSERT( - validateDomain(to, replay.first), - "Tried to set the domain of ", - to, - " to ", - replay.first, - " but that would invalidate previously compute at position or max producer position."); - to->setDomain(replay.first); - to_pos = replay.second; - } - recordReplayedPos(to, to_pos); + TransformReplay::getMatchedLeafPosWithoutReplayPasC(to, from, from_pos); + TORCH_CHECK( + to_pos >= 0, + "Unable to propagate CA position from consumer ", + from, + " to producer ", + to, + " because this would require 
replay."); + setCAPos(to, to_pos); + mapped_reference_pos_[to] = to_pos; } void InlinePropagator::propagateTvCasP(TensorView* from, TensorView* to) { if (is_first_) { is_first_ = false; - recordReplayedPos(reference_, reference_pos_); + setCAPos(reference_, reference_pos_); + mapped_reference_pos_[reference_] = reference_pos_; } - int pos = getReplayPosCasP(to, from); + int from_pos = getFromPosCasP(to, from); auto to_pos = - TransformReplay::getMatchedLeafPosWithoutReplayCasP(to, from, pos); - if (mode_ != ComputeAtMode::MostInlined) { - TORCH_CHECK( - to_pos >= 0, - "Unable to propagate CA position from producer ", - from, - " to consumer ", - to, - " because this would require replay."); - } - if (to_pos < 0) { - auto replay = TransformReplay::replayCasP(to, from, pos); - TORCH_INTERNAL_ASSERT( - validateDomain(to, replay.first), - "Tried to set the domain of ", - to, - " to ", - replay.first, - " but that would invalidate previously compute at position or max producer position."); - to->setDomain(replay.first); - to_pos = replay.second; - } - recordReplayedPos(to, to_pos); + TransformReplay::getMatchedLeafPosWithoutReplayCasP(to, from, from_pos); + TORCH_CHECK( + to_pos >= 0, + "Unable to propagate CA position from producer ", + from, + " to consumer ", + to, + " because this would require replay."); + setCAPos(to, to_pos); + mapped_reference_pos_[to] = to_pos; } void InlinePropagator::propagateTvSibling(TensorView* from, TensorView* to) { if (is_first_) { is_first_ = false; - recordReplayedPos(reference_, reference_pos_); - } - auto from_pos = retrieveReplayedPos(from); - if (!TransformReplay::fullSelfMatching(to, from)) { - auto replay = TransformReplay::fullSelfReplay(to->domain(), from->domain()); - TORCH_INTERNAL_ASSERT( - validateDomain(to, replay), - "Tried to set the domain of ", - to, - " to ", - replay, - " but that would invalidate previously compute at position or max producer position."); - to->setDomain(replay); + setCAPos(reference_, 
reference_pos_); + mapped_reference_pos_[reference_] = reference_pos_; } - recordReplayedPos(to, from_pos); + auto from_pos = mapped_reference_pos_.at(from); + TORCH_CHECK( + TransformReplay::fullSelfMatching(to, from), + "Unable to propagate CA position from ", + from, + " to sibling ", + to, + " because this would require replay."); + setCAPos(to, from_pos); + mapped_reference_pos_[to] = from_pos; } namespace { diff --git a/torch/csrc/jit/codegen/cuda/inline_propagator.h b/torch/csrc/jit/codegen/cuda/inline_propagator.h index 40df6548add0d1..5b07ac1fd30881 100644 --- a/torch/csrc/jit/codegen/cuda/inline_propagator.h +++ b/torch/csrc/jit/codegen/cuda/inline_propagator.h @@ -51,18 +51,18 @@ class MaxPosCalculator { bool allow_unmappable) const; public: - // Returns the position at which tv can be relayed within. + // Returns the position at which tv can be inlined within. size_t getMaxPosSelf( TensorView* tv, bool allow_reduction, bool allow_vectorize, bool allow_unmappable) const; - // Returns the maximum position producer can be replayed based on consumer + // Returns the maximum position producer can be inlined based on consumer // given the set ComputeAtMode size_t getMaxPosPasC(TensorView* producer, TensorView* consumer) const; - // Returns the maximum position consumer can be replayed based on producer + // Returns the maximum position consumer can be inlined based on producer // given the set ComputeAtMode size_t getMaxPosCasP(TensorView* consumer, TensorView* producer) const; @@ -74,34 +74,34 @@ class InlinePropagator : public MaxInfoSpanningTree::Propagator { // that can be shared across both directions. size_t getMaxPosAll(TensorView* tv); - // Returns position of getMaxPosAll while also hoisting outside broadcast - // dimensions. 
- size_t adjustComputeAtPos(TensorView* tv, size_t pos); - - // Returns the replay position in consumer that producer should be replayed as + // Returns the inline position in consumer that producer should be inlined as // based on consumer, taking into consideration the max possible returned by // getMaxPos{PasC, CasP}, the compute at mode type. - size_t getReplayPosPasC(TensorView* producer, TensorView* consumer); + size_t getFromPosPasC(TensorView* producer, TensorView* consumer); - // Returns the replay position in producer that consumer should be replayed as + // Returns the inline position in producer that consumer should be inlined as // based on producer, taking into consideration the max possible returned by // getMaxPos{PasC, CasP}, the compute at mode type. - size_t getReplayPosCasP(TensorView* consumer, TensorView* producer); + size_t getFromPosCasP(TensorView* consumer, TensorView* producer); - // Sets the compute at position of tv and records the position in - // replayed_pos_ - void recordReplayedPos(TensorView* tv, size_t pos); + // We use mapped_reference_pos_ to keep track of the outer axes information of + // the reference tensor. That is, mapped_reference_pos_[tv] answers the + // question "What outer axes in tv are shared with the specified reference + // tensor's outer axes?". However, when we actually set the CA position of tv, + // we might not want to set it as mapped_reference_pos_[tv] because we + // don't want to inline certain things (such as vectorized dimensions, inner + // most broadcasting, etc.). + std::unordered_map mapped_reference_pos_; - // Returns the entry for tv in replayed_pos_ if it exists, else returns the - // compute at position of tv. - size_t retrieveReplayedPos(TensorView* tv); + // Actually set the computeAt position. This is not necessarily equal to + // mapped_reference_pos_[tv] because we don't want to inline certain things. 
+ void setCAPos(TensorView* tv, size_t pos); const MaxPosCalculator max_pos_calc; std::unordered_set selected_; TensorView* reference_; size_t reference_pos_; ComputeAtMode mode_ = ComputeAtMode::Standard; - std::unordered_map replayed_pos_; bool is_first_ = true; public: diff --git a/torch/csrc/jit/codegen/cuda/ir_interface_nodes.h b/torch/csrc/jit/codegen/cuda/ir_interface_nodes.h index bb4484f7db6c22..db68a9339948c0 100644 --- a/torch/csrc/jit/codegen/cuda/ir_interface_nodes.h +++ b/torch/csrc/jit/codegen/cuda/ir_interface_nodes.h @@ -157,6 +157,7 @@ enum class ComputeAtMode { Standard, BestEffort, MostInlined }; class InlinePropagator; class MaxProducerPosUpdater; class TransformPropagator; +struct MostInlinedTransformPropagator; class TransformIter; class TransformReplay; class OptOutMutator; @@ -457,6 +458,7 @@ class TORCH_CUDA_CU_API TensorView : public Val { void applyMmaSwizzle(MmaOptions options); friend TORCH_CUDA_CU_API TransformPropagator; + friend TORCH_CUDA_CU_API MostInlinedTransformPropagator; friend TORCH_CUDA_CU_API TransformReplay; friend TORCH_CUDA_CU_API OptOutMutator; friend TORCH_CUDA_CU_API InlinePropagator; diff --git a/torch/csrc/jit/codegen/cuda/transform_replay.cpp b/torch/csrc/jit/codegen/cuda/transform_replay.cpp index 301b693503023a..3b1493ee684d6c 100644 --- a/torch/csrc/jit/codegen/cuda/transform_replay.cpp +++ b/torch/csrc/jit/codegen/cuda/transform_replay.cpp @@ -922,6 +922,63 @@ TransformPropagator::TransformPropagator(TensorView* from, int64_t pos) { replayed_pos_[from] = pos; } +void MostInlinedTransformPropagator::propagateTvPasC( + TensorView* from, + TensorView* to) { + int pos = from->nDims(); + // See note [Using multiple TransformPropagators] + int new_pos = + TransformReplay::getMatchedLeafPosWithoutReplayPasC(to, from, pos); + if (new_pos < 0) { + auto replay = TransformReplay::replayPasC(to, from, pos); + TORCH_INTERNAL_ASSERT( + validateDomain(to, replay.first), + "Tried to set the domain of ", + to, + " to ", + 
replay.first, + " but that would invalidate previously compute at position or max producer position."); + to->setDomain(replay.first); + } +} + +void MostInlinedTransformPropagator::propagateTvCasP( + TensorView* from, + TensorView* to) { + int pos = from->nDims(); + // See note [Using multiple TransformPropagators] + int new_pos = + TransformReplay::getMatchedLeafPosWithoutReplayCasP(to, from, pos); + if (new_pos < 0) { + auto replay = TransformReplay::replayCasP(to, from, pos); + TORCH_INTERNAL_ASSERT( + validateDomain(to, replay.first), + "Tried to set the domain of ", + to, + " to ", + replay.first, + " but that would invalidate previously compute at position or max producer position."); + to->setDomain(replay.first); + } +} + +void MostInlinedTransformPropagator::propagateTvSibling( + TensorView* from, + TensorView* to) { + // See note [Using multiple TransformPropagators] + if (!TransformReplay::fullSelfMatching(to, from)) { + auto replay = TransformReplay::fullSelfReplay(to->domain(), from->domain()); + TORCH_INTERNAL_ASSERT( + validateDomain(to, replay), + "Tried to set the domain of ", + to, + " to ", + replay, + " but that would invalidate previously compute at position or max producer position."); + to->setDomain(replay); + } +} + } // namespace cuda } // namespace fuser } // namespace jit diff --git a/torch/csrc/jit/codegen/cuda/transform_replay.h b/torch/csrc/jit/codegen/cuda/transform_replay.h index 1a5433bef6c3a3..1ad4a8d2331b73 100644 --- a/torch/csrc/jit/codegen/cuda/transform_replay.h +++ b/torch/csrc/jit/codegen/cuda/transform_replay.h @@ -194,6 +194,13 @@ class TORCH_CUDA_CU_API TransformPropagator TransformPropagator(TensorView* from, int64_t pos = -1); }; +struct TORCH_CUDA_CU_API MostInlinedTransformPropagator + : public MaxRootDomainInfoSpanningTree::Propagator { + virtual void propagateTvPasC(TensorView* from, TensorView* to) override; + virtual void propagateTvCasP(TensorView* from, TensorView* to) override; + virtual void 
propagateTvSibling(TensorView* from, TensorView* to) override; +}; + } // namespace cuda } // namespace fuser } // namespace jit