diff --git a/torch/csrc/jit/codegen/cuda/inline_propagator.cpp b/torch/csrc/jit/codegen/cuda/inline_propagator.cpp index b6e8d1ed35af0..20667a9e7f87e 100644 --- a/torch/csrc/jit/codegen/cuda/inline_propagator.cpp +++ b/torch/csrc/jit/codegen/cuda/inline_propagator.cpp @@ -104,129 +104,56 @@ size_t MaxPosCalculator::getMaxPosSelf( return std::distance(dom.begin(), iter); } -// Return the max position in consumer that producer can be inlined to -// Cannot inline: -// Reduction dimensions in producer -// Block broadcast dimensions in producer -// Vectorized dimensions in producer or consumer -// Unrolled dimensions in producer or consumer -// Dimensions derived from root dimensions that exist in both but are -// unmappable -size_t MaxPosCalculator::getMaxPosC2P( - TensorView* consumer, - TensorView* producer) const { - // Limit max position based on vectorized dims in consumer. - auto max_consumer_pos = getMaxPosSelf(consumer, true, false, true); - - auto pairwise_root_map = PairwiseRootDomainMap(producer, consumer); - auto replay_PasC = - BestEffortReplay::replayPasC(producer, consumer, -1, pairwise_root_map); - auto c2p_replay_map = replay_PasC.getReplay(); - - for (size_t consumer_pos = max_consumer_pos; consumer_pos > 0; - consumer_pos--) { - auto map_it = c2p_replay_map.find(consumer->axis((int)consumer_pos - 1)); - if (map_it != c2p_replay_map.end()) { - auto p_id = map_it->second; - if (!isAllowedID(p_id, producer, true, false, false)) { - max_consumer_pos = consumer_pos - 1; - } - } - } - - return max_consumer_pos; -} - // Return the max position in producer that can be inlined to consumer // Cannot inline: -// Reduction dimensions in producer -// Vectorized dimensions in producer or consumer -// Unrolled dimensions in producer or consumer -// Dimensions derived from root dimensions that exist in both but are -// unmappable -size_t MaxPosCalculator::getMaxPosP2C( +// Vectorized dimensions in consumer +// Unrolled dimensions in consumer +size_t MaxPosCalculator::getMaxProducerPosFromConsumer( TensorView* producer, TensorView* consumer) const { - auto max_producer_pos = getMaxPosSelf(producer, false, false, false); - auto pairwise_root_map = PairwiseRootDomainMap(producer, consumer); auto replay_CasP = BestEffortReplay::replayCasP(consumer, producer, -1, pairwise_root_map); auto p2c_replay_map = replay_CasP.getReplay(); - for (size_t producer_pos = max_producer_pos; producer_pos > 0; - producer_pos--) { - auto map_it = p2c_replay_map.find(producer->axis((int)producer_pos - 1)); + for (size_t producer_pos = 0; producer_pos < producer->nDims(); + producer_pos++) { + auto map_it = p2c_replay_map.find(producer->axis(producer_pos)); if (map_it != p2c_replay_map.end()) { auto c_id = map_it->second; if (!isAllowedID(c_id, consumer, true, false, true)) { - max_producer_pos = producer_pos - 1; + return producer_pos; } } } - - return max_producer_pos; + return producer->nDims(); } size_t InlinePropagator::getMaxPosAll(TensorView* tv) { auto max_pos = max_pos_calc.getMaxPosSelf(tv, false, false, false); for (auto consumer_tv : ir_utils::consumerTvsOf(tv)) { - // consumers are always replayed consistently - max_pos = - std::min(max_pos, max_pos_calc.getMaxPosP2C(tv, consumer_tv)); + max_pos = std::min( + max_pos, max_pos_calc.getMaxProducerPosFromConsumer(tv, consumer_tv)); } return max_pos; } -size_t InlinePropagator::getFromPosC2P(TensorView* from, TensorView* to) { - size_t max_pos = max_pos_calc.getMaxPosC2P(from, to); - size_t pos = mapped_reference_pos_.at(from); - - if (mode_ == ComputeAtMode::BestEffort) { - return std::min(pos, max_pos); - } else if (mode_ == ComputeAtMode::MostInlined) { - return max_pos; - } - - TORCH_INTERNAL_ASSERT( - pos <= max_pos, - "Invalid compute at position detected in compute at when trying to propagate the CA position from consumer: ", - from, - " to producer: ", - to, - " tried to do this at position: ", - pos, - " but max position that's allowed is ", - max_pos); - return pos; -} - -size_t InlinePropagator::getFromPosP2C(TensorView* from, TensorView* to) { - size_t max_pos = max_pos_calc.getMaxPosP2C(from, to); - size_t pos = mapped_reference_pos_.at(from); - - if (mode_ == ComputeAtMode::BestEffort) { - return std::min(pos, max_pos); - } else if (mode_ == ComputeAtMode::MostInlined) { - return max_pos; - } - - TORCH_INTERNAL_ASSERT( - pos <= max_pos, - "Invalid compute at position detected in compute at when trying to propagate the CA position from producer: ", - from, - " to consumer: ", - to, - " tried to do this at position: ", - pos, - " but max position that's allowed is ", - max_pos); - return pos; -} - -void InlinePropagator::setCAPos(TensorView* tv, size_t pos) { +void InlinePropagator::setCAPos(TensorView* tv) { + size_t pos = mapped_reference_pos_.at(tv); if (selected_.count(tv) && !tv->isFusionInput()) { - pos = std::min(pos, getMaxPosAll(tv)); + auto max_pos = getMaxPosAll(tv); + if (mode_ == ComputeAtMode::Standard) { + TORCH_INTERNAL_ASSERT( + pos <= max_pos, + "Invalid compute at position detected in InlinePropagator when trying to set the CA position of: ", + tv, + " to ", + pos, + ", max position that's allowed is ", + max_pos); + } else { + pos = std::min(pos, max_pos); + } // hoist inner most broadcast while (pos > 0 && tv->axis(pos - 1)->isBroadcast()) { pos--; @@ -262,10 +189,16 @@ InlinePropagator::InlinePropagator( void InlinePropagator::propagateC2P(TensorView* from, TensorView* to) { if (is_first_) { is_first_ = false; - setCAPos(reference_, reference_pos_); mapped_reference_pos_[reference_] = reference_pos_; + setCAPos(reference_); + } + // Step 1: find mapped_reference_pos_[to] + int from_pos; + if (mode_ != ComputeAtMode::MostInlined) { + from_pos = mapped_reference_pos_.at(from); + } else { + from_pos = from->nDims(); } - int from_pos = getFromPosC2P(from, to); auto to_pos = TransformReplay::getMatchedLeafPosWithoutReplayPasC(to, from, from_pos); TORCH_CHECK( @@ -275,17 +208,24 @@ void InlinePropagator::propagateC2P(TensorView* from, TensorView* to) { " to producer ", to, " because this would require replay."); - setCAPos(to, to_pos); mapped_reference_pos_[to] = to_pos; + // Step 2: set CA position of `to` + setCAPos(to); } void InlinePropagator::propagateP2C(TensorView* from, TensorView* to) { if (is_first_) { is_first_ = false; - setCAPos(reference_, reference_pos_); mapped_reference_pos_[reference_] = reference_pos_; + setCAPos(reference_); + } + // Step 1: find mapped_reference_pos_[to] + int from_pos; + if (mode_ != ComputeAtMode::MostInlined) { + from_pos = mapped_reference_pos_.at(from); + } else { + from_pos = from->nDims(); } - int from_pos = getFromPosP2C(from, to); auto to_pos = TransformReplay::getMatchedLeafPosWithoutReplayCasP(to, from, from_pos); TORCH_CHECK( @@ -295,16 +235,18 @@ void InlinePropagator::propagateP2C(TensorView* from, TensorView* to) { " to consumer ", to, " because this would require replay."); - setCAPos(to, to_pos); mapped_reference_pos_[to] = to_pos; + // Step 2: set CA position of `to` + setCAPos(to); } void InlinePropagator::propagateSibling(TensorView* from, TensorView* to) { if (is_first_) { is_first_ = false; - setCAPos(reference_, reference_pos_); mapped_reference_pos_[reference_] = reference_pos_; + setCAPos(reference_); } + // Step 1: find mapped_reference_pos_[to] auto from_pos = mapped_reference_pos_.at(from); TORCH_CHECK( TransformReplay::fullSelfMatching(to, from), @@ -313,8 +255,9 @@ void InlinePropagator::propagateSibling(TensorView* from, TensorView* to) { " to sibling ", to, " because this would require replay."); - setCAPos(to, from_pos); mapped_reference_pos_[to] = from_pos; + // Step 2: set CA position of `to` + setCAPos(to); } namespace { diff --git a/torch/csrc/jit/codegen/cuda/inline_propagator.h b/torch/csrc/jit/codegen/cuda/inline_propagator.h index 7dc41511ae8ed..46af175f6e8e8 100644 --- a/torch/csrc/jit/codegen/cuda/inline_propagator.h +++ b/torch/csrc/jit/codegen/cuda/inline_propagator.h @@ -60,11 +60,9 @@ class MaxPosCalculator { // Returns the maximum position producer can be inlined based on consumer // given the set ComputeAtMode - size_t getMaxPosC2P(TensorView* from, TensorView* to) const; - - // Returns the maximum position consumer can be inlined based on producer - // given the set ComputeAtMode - size_t getMaxPosP2C(TensorView* from, TensorView* to) const; + size_t getMaxProducerPosFromConsumer( + TensorView* producer, + TensorView* consumer) const; MaxPosCalculator(ComputeAtMode mode); }; @@ -74,16 +72,6 @@ class InlinePropagator : public MaxInfoSpanningTree::Propagator { // that can be shared across both directions. size_t getMaxPosAll(TensorView* tv); - // Returns the inline position in consumer that producer should be inlined as - // based on consumer, taking into consideration the max possible returned by - // getMaxPos{P2C, C2P}, the compute at mode type. - size_t getFromPosC2P(TensorView* from, TensorView* to); - - // Returns the inline position in producer that consumer should be inlined as - // based on producer, taking into consideration the max possible returned by - // getMaxPos{P2C, C2P}, the compute at mode type. - size_t getFromPosP2C(TensorView* from, TensorView* to); - // We use mapped_reference_pos_ to keep track of the outer axes information of // the reference tensor. That is, mapped_reference_pos_[tv] answers the // question "What outer axes in tv are shared with the specified reference @@ -95,7 +83,7 @@ class InlinePropagator : public MaxInfoSpanningTree::Propagator { // Actually set the computeAt position. This does not necessarily equal to // mapped_reference_pos_[tv] because we don't want to inline certain things. - void setCAPos(TensorView* tv, size_t pos); + void setCAPos(TensorView* tv); const MaxPosCalculator max_pos_calc; std::unordered_set selected_;