diff --git a/torch/csrc/jit/codegen/cuda/inline_propagator.cpp b/torch/csrc/jit/codegen/cuda/inline_propagator.cpp
index d1f10e56a33aa..34eb8a94948ec 100644
--- a/torch/csrc/jit/codegen/cuda/inline_propagator.cpp
+++ b/torch/csrc/jit/codegen/cuda/inline_propagator.cpp
@@ -132,12 +132,17 @@ size_t MaxPosCalculator::getMaxProducerPosFromConsumer(
   return producer->nDims();
 }
 
-size_t InlinePropagator::getMaxPosAll(TensorView* tv) {
+size_t InlinePropagator::getMaxPosAll(TensorView* tv, bool check_siblings) {
   auto max_pos = max_pos_calc.getMaxPosSelf(tv, false, false, false);
   for (auto consumer_tv : ir_utils::consumerTvsOf(tv)) {
     max_pos = std::min(
         max_pos, max_pos_calc.getMaxProducerPosFromConsumer(tv, consumer_tv));
   }
+  if (check_siblings) {
+    for (auto sibling_tv : ir_utils::siblingTvsOf(tv)) {
+      max_pos = std::min(max_pos, getMaxPosAll(sibling_tv, false));
+    }
+  }
   return max_pos;
 }
 
diff --git a/torch/csrc/jit/codegen/cuda/inline_propagator.h b/torch/csrc/jit/codegen/cuda/inline_propagator.h
index 46af175f6e8e8..2ed137ac5955e 100644
--- a/torch/csrc/jit/codegen/cuda/inline_propagator.h
+++ b/torch/csrc/jit/codegen/cuda/inline_propagator.h
@@ -70,6 +70,6 @@ class MaxPosCalculator {
 class InlinePropagator : public MaxInfoSpanningTree::Propagator {
   // Checks producers and consumers to see what the maximum position in tv is
   // that can be shared across both directions.
-  size_t getMaxPosAll(TensorView* tv);
+  size_t getMaxPosAll(TensorView* tv, bool check_siblings = true);
   // We use mapped_reference_pos_ to keep track of the outer axes information of
   // the reference tensor. That is, mapped_reference_pos_[tv] answers the
diff --git a/torch/csrc/jit/codegen/cuda/test/test_gpu.cpp b/torch/csrc/jit/codegen/cuda/test/test_gpu.cpp
index abaf28dfdf77e..ca9724636e984 100644
--- a/torch/csrc/jit/codegen/cuda/test/test_gpu.cpp
+++ b/torch/csrc/jit/codegen/cuda/test/test_gpu.cpp
@@ -24167,6 +24167,30 @@ TEST_F(NVFuserTest, FusionSkipReplay_CUDA) {
   }
 }
 
+TEST_F(NVFuserTest, FusionInlineRepro1803_CUDA) {
+  Fusion fusion;
+  FusionGuard fg(&fusion);
+
+  TensorView* tv0 = makeContigTensor(2);
+
+  fusion.addInput(tv0);
+  auto tv1 = set(tv0);
+  auto tvs = Welford(tv1, {1});
+  auto tvo = set(tvs.var_sum);
+  fusion.addOutput(tvo);
+
+  tvo->split(0, 16);
+  tvo->axis(1)->parallelize(ParallelType::Unroll);
+
+  tv0->computeAt(tvo, -1, ComputeAtMode::BestEffort);
+
+  TORCH_CHECK(
+      tvs.var_sum->getComputeAtPosition() == tvs.avg->getComputeAtPosition());
+  TORCH_CHECK(
+      tvs.var_sum->getComputeAtPosition() == tvs.n->getComputeAtPosition());
+  TORCH_CHECK(tvs.var_sum->getComputeAtPosition() == 1);
+}
+
 } // namespace jit
 } // namespace torch
 #endif // #if defined(USE_CUDA)