Skip to content

Commit

Permalink
Check siblings in getMaxPosAll (csarofeen#1805)
Browse files Browse the repository at this point in the history
  • Loading branch information
zasdfgbnm authored Jul 6, 2022
1 parent 025c840 commit fa4e6a4
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 2 deletions.
7 changes: 6 additions & 1 deletion torch/csrc/jit/codegen/cuda/inline_propagator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -132,12 +132,17 @@ size_t MaxPosCalculator::getMaxProducerPosFromConsumer(
return producer->nDims();
}

size_t InlinePropagator::getMaxPosAll(TensorView* tv) {
size_t InlinePropagator::getMaxPosAll(TensorView* tv, bool check_siblings) {
  // Start from the maximum position tv itself allows, then clamp it by what
  // every consumer can accept when tv acts as its producer.
  size_t result = max_pos_calc.getMaxPosSelf(tv, false, false, false);
  for (auto consumer : ir_utils::consumerTvsOf(tv)) {
    const size_t consumer_limit =
        max_pos_calc.getMaxProducerPosFromConsumer(tv, consumer);
    if (consumer_limit < result) {
      result = consumer_limit;
    }
  }
  // Sibling outputs of a multi-output op (e.g. the avg/var_sum/n results of a
  // Welford) must agree on their position, so clamp by each sibling's own
  // limit as well. The recursive call passes check_siblings = false, which
  // prevents unbounded mutual recursion between siblings.
  if (check_siblings) {
    for (auto sibling : ir_utils::siblingTvsOf(tv)) {
      result = std::min<size_t>(result, getMaxPosAll(sibling, false));
    }
  }
  return result;
}

Expand Down
2 changes: 1 addition & 1 deletion torch/csrc/jit/codegen/cuda/inline_propagator.h
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ class MaxPosCalculator {
class InlinePropagator : public MaxInfoSpanningTree::Propagator {
// Checks producers and consumers to see what the maximum position in tv is
// that can be shared across both directions.
size_t getMaxPosAll(TensorView* tv);
size_t getMaxPosAll(TensorView* tv, bool check_siblings = true);

// We use mapped_reference_pos_ to keep track of the outer axes information of
// the reference tensor. That is, mapped_reference_pos_[tv] answers the
Expand Down
24 changes: 24 additions & 0 deletions torch/csrc/jit/codegen/cuda/test/test_gpu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24167,6 +24167,30 @@ TEST_F(NVFuserTest, FusionSkipReplay_CUDA) {
}
}

TEST_F(NVFuserTest, FusionInlineRepro1803_CUDA) {
  // Regression test (csarofeen#1803): after a best-effort computeAt, the
  // sibling outputs of a Welford op must all end up at the same computeAt
  // position.
  Fusion fusion;
  FusionGuard fg(&fusion);

  auto input = makeContigTensor(2);
  fusion.addInput(input);

  auto input_copy = set(input);
  auto welford_out = Welford(input_copy, {1});
  auto output = set(welford_out.var_sum);
  fusion.addOutput(output);

  // Schedule only the output tensor, then propagate from the input.
  output->split(0, 16);
  output->axis(1)->parallelize(ParallelType::Unroll);

  input->computeAt(output, -1, ComputeAtMode::BestEffort);

  // avg, var_sum, and n are siblings — their positions must match and be 1.
  TORCH_CHECK(
      welford_out.var_sum->getComputeAtPosition() ==
      welford_out.avg->getComputeAtPosition());
  TORCH_CHECK(
      welford_out.var_sum->getComputeAtPosition() ==
      welford_out.n->getComputeAtPosition());
  TORCH_CHECK(welford_out.var_sum->getComputeAtPosition() == 1);
}

} // namespace jit
} // namespace torch
#endif // #if defined(USE_CUDA)

0 comments on commit fa4e6a4

Please sign in to comment.