diff --git a/torch/csrc/jit/codegen/cuda/compute_at.cpp b/torch/csrc/jit/codegen/cuda/compute_at.cpp index 35a35acfe0a97..ed0f9fb271d57 100644 --- a/torch/csrc/jit/codegen/cuda/compute_at.cpp +++ b/torch/csrc/jit/codegen/cuda/compute_at.cpp @@ -165,7 +165,7 @@ void ComputeAt::runAt( TensorView* consumer, unsigned int consumer_position, ComputeAtMode mode) { - FUSER_PERF_SCOPE("ComputeAt::run"); + FUSER_PERF_SCOPE("ComputeAt::runAt"); // Make sure the correct fusion is setup between this and consumer. TORCH_CHECK( @@ -175,6 +175,10 @@ void ComputeAt::runAt( consumer, " are not in the same fusion."); + if (mode == ComputeAtMode::MostInlined) { + consumer_position = consumer->nDims(); + } + FusionGuard fg(producer->fusion()); auto selected = getPropagationSubgraph(producer, consumer); @@ -206,6 +210,10 @@ void ComputeAt::runWith( consumer, " are not in the same fusion."); + if (mode == ComputeAtMode::MostInlined) { + producer_position = producer->nDims(); + } + FusionGuard fg(producer->fusion()); auto selected = getPropagationSubgraph(producer, consumer); diff --git a/torch/csrc/jit/codegen/cuda/inline_propagator.cpp b/torch/csrc/jit/codegen/cuda/inline_propagator.cpp index 74e2e6b407c1a..195ef3e67a188 100644 --- a/torch/csrc/jit/codegen/cuda/inline_propagator.cpp +++ b/torch/csrc/jit/codegen/cuda/inline_propagator.cpp @@ -309,9 +309,15 @@ void InlinePropagator::propagateTvPasC(TensorView* from, TensorView* to) { int pos = getReplayPosPasC(to, from); auto to_pos = TransformReplay::getMatchedLeafPosWithoutReplayPasC(to, from, pos); - // TODO: Can we make TransformPropagator do the transformation, and - // InlinePropagator only set the CA positions? 
- // TORCH_CHECK(to_pos >= 0); + if (mode_ != ComputeAtMode::MostInlined) { + TORCH_CHECK( + to_pos >= 0, + "Unable to propagate CA position from consumer ", + from, + " to producer ", + to, + " because this would require replay."); + } if (to_pos < 0) { auto replay = TransformReplay::replayPasC(to, from, pos); TORCH_INTERNAL_ASSERT( @@ -335,9 +341,15 @@ void InlinePropagator::propagateTvCasP(TensorView* from, TensorView* to) { int pos = getReplayPosCasP(to, from); auto to_pos = TransformReplay::getMatchedLeafPosWithoutReplayCasP(to, from, pos); - // TODO: Can we make TransformPropagator do the transformation, and - // InlinePropagator only set the CA positions? - // TORCH_CHECK(to_pos >= 0); + if (mode_ != ComputeAtMode::MostInlined) { + TORCH_CHECK( + to_pos >= 0, + "Unable to propagate CA position from producer ", + from, + " to consumer ", + to, + " because this would require replay."); + } if (to_pos < 0) { auto replay = TransformReplay::replayCasP(to, from, pos); TORCH_INTERNAL_ASSERT( @@ -373,23 +385,71 @@ void InlinePropagator::propagateTvSibling(TensorView* from, TensorView* to) { recordReplayedPos(to, from_pos); } +namespace { + // Try to find the aligned position on consumer's domain corresponding to the -// compute at position of producer domain. -void MaxProducerPosUpdater::handle(TensorView* consumer) { +// compute at position of producer domain. Used in computeAt pass only. No +// checking on actual producer-consumer relationship. +unsigned int getConsumerPosAlignedToProducerCA( + TensorView* consumer, + TensorView* producer) { + // Locate consumer's position that aligns with + // the producer's new compute at axis. We need broadcast axes forwarded so we + // need to replay PasC as CasP will not forward broadcast dims. 
For example + // if we have: + // T2[ iS22{( 3 * 1 )} ] ca_pos( 1 ) = broadcast( T1[ iS1{3} ] ca_pos( 1 ) + // produce_pos( 1) ) CasP will have the mapping iS1{3} -> iS2{3} and PasC will + // have the mapping iS22{( 3 * 1 )} <- iS1{3} We need the latter. Refer to + // NVFuserTest.FusionComplexBCast1_CUDA + + auto c2p_map = + BestEffortReplay::replayPasC( + producer, + consumer, + -1, + // Compute at root domain may not be valid here, as all + // producers don't have to be able to map into consumer at + // max producer position. Since computeAt should be valid + // and this mechanism is only intended to lower produce + // position of consumer, we can simply use the pairwise map. + PairwiseRootDomainMap(producer, consumer)) + .getReplay(); + + // Find the innermost position of consumer that has + // been mapped within the producer ca axis. unsigned int consumer_pos = consumer->nDims(); while (consumer_pos > 0) { - for (auto producer : ir_utils::producerTvsOf(consumer)) { - auto producer_pos = TransformReplay::getMatchedLeafPosWithoutReplayPasC( - producer, consumer, consumer_pos); - if (producer_pos >= 0 && - producer_pos <= producer->getComputeAtPosition()) { - goto finished; - } + auto consumer_id = consumer->axis((int)consumer_pos - 1); + auto p_dom = producer->domain()->domain(); + if (std::any_of( + p_dom.begin(), + p_dom.begin() + producer->getComputeAtPosition(), + [&consumer_id, &c2p_map](IterDomain* p_id) { + auto c_id_it = c2p_map.find(consumer_id); + if (c_id_it != c2p_map.end()) { + return c_id_it->second == p_id; + } + return false; + })) { + break; } consumer_pos--; } -finished: - consumer->setMaxProducer(consumer_pos, true); + + return consumer_pos; +} + +} // namespace + +// Try to find the aligned position on consumer's domain corresponding to the +// compute at position of producer domain. 
+void MaxProducerPosUpdater::handle(TensorView* consumer) { + unsigned int consumer_pos = 0; + for (auto producer : ir_utils::producerTvsOf(consumer)) { + consumer_pos = std::max( + consumer_pos, getConsumerPosAlignedToProducerCA(consumer, producer)); + } + consumer->setMaxProducer(consumer_pos); } void MaxProducerPosUpdater::propagateTvPasC(TensorView* from, TensorView* to) {