Skip to content

Commit

Permalink
Make MostInlined and BestEffort inline propagation no longer assert replayed (#1868)
Browse files Browse the repository at this point in the history
  • Loading branch information
zasdfgbnm authored Jul 29, 2022
1 parent a64462a commit 172fb36
Show file tree
Hide file tree
Showing 2 changed files with 168 additions and 31 deletions.
83 changes: 52 additions & 31 deletions torch/csrc/jit/codegen/cuda/inline_propagator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,13 @@ size_t MaxPosCalculator::getMaxProducerPosFromConsumer(
producer_pos++) {
auto map_it = p2c_replay_map.find(producer->axis(producer_pos));
if (map_it != p2c_replay_map.end()) {
// If the producer position is mismatching with the consumer, then we can
// not inline into this position, otherwise the max producer position of
// the consumer will become invalid and expression sort will fail.
if (TransformReplay::getMatchedLeafPosWithoutReplayCasP(
consumer, producer, producer_pos + 1) < 0) {
return producer_pos;
}
auto c_id = map_it->second;
if (!isAllowedID(c_id, consumer, true, false, true)) {
return producer_pos;
Expand Down Expand Up @@ -159,8 +166,10 @@ void InlinePropagator::setCAPos(TensorView* tv) {
pos,
", max position that's allowed is ",
max_pos);
} else {
} else if (mode_ == ComputeAtMode::BestEffort) {
pos = std::min<size_t>(pos, max_pos);
} else {
pos = max_pos;
}
// hoist inner most broadcast
while (pos > 0 && tv->axis(pos - 1)->isBroadcast()) {
Expand Down Expand Up @@ -285,23 +294,29 @@ void InlinePropagator::propagateC2P(TensorView* from, TensorView* to) {
std::cout << " to: " << to << std::endl;
}
// Step 1: find mapped_reference_pos_[to]
int from_pos;
if (mode_ != ComputeAtMode::MostInlined) {
from_pos = mapped_reference_pos_.at(from);
} else {
from_pos = from->nDims();
}
int from_pos = mapped_reference_pos_.at(from);
auto to_pos =
TransformReplay::getMatchedLeafPosWithoutReplayPasC(to, from, from_pos);
TORCH_CHECK(
to_pos >= 0,
"Unable to propagate CA position from consumer ",
from,
" at ",
from_pos,
" to producer ",
to,
" because this would require replay.");
if (mode_ == ComputeAtMode::Standard) {
TORCH_CHECK(
to_pos >= 0,
"Unable to propagate CA position from consumer ",
from,
" at ",
from_pos,
" to producer ",
to,
" because this would require replay.");
} else {
// For MostInlined and BestEffort inline propagation, we allow the DAG to
// be not replayed fully consistently. For such case, we just don't inline
// into the mismatched dimension.
while (to_pos < 0) {
from_pos--;
to_pos = TransformReplay::getMatchedLeafPosWithoutReplayPasC(
to, from, from_pos);
}
}
mapped_reference_pos_[to] = to_pos;
// Step 2: set CA position of `to`
setCAPos(to);
Expand All @@ -315,23 +330,29 @@ void InlinePropagator::propagateP2C(TensorView* from, TensorView* to) {
std::cout << " to: " << to << std::endl;
}
// Step 1: find mapped_reference_pos_[to]
int from_pos;
if (mode_ != ComputeAtMode::MostInlined) {
from_pos = mapped_reference_pos_.at(from);
} else {
from_pos = from->nDims();
}
int from_pos = mapped_reference_pos_.at(from);
auto to_pos =
TransformReplay::getMatchedLeafPosWithoutReplayCasP(to, from, from_pos);
TORCH_CHECK(
to_pos >= 0,
"Unable to propagate CA position from producer ",
from,
" at ",
from_pos,
" to consumer ",
to,
" because this would require replay.");
if (mode_ == ComputeAtMode::Standard) {
TORCH_CHECK(
to_pos >= 0,
"Unable to propagate CA position from producer ",
from,
" at ",
from_pos,
" to consumer ",
to,
" because this would require replay.");
} else {
// For MostInlined and BestEffort inline propagation, we allow the DAG to
// be not replayed fully consistently. For such case, we just don't inline
// into the mismatched dimension.
while (to_pos < 0) {
from_pos--;
to_pos = TransformReplay::getMatchedLeafPosWithoutReplayCasP(
to, from, from_pos);
}
}
mapped_reference_pos_[to] = to_pos;
// Step 2: set CA position of `to`
setCAPos(to);
Expand Down
116 changes: 116 additions & 0 deletions torch/csrc/jit/codegen/cuda/test/test_gpu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24847,6 +24847,122 @@ TEST_F(NVFuserTest, FusionInsertMagicZero1_CUDA) {
tv2->toString());
}

TEST_F(NVFuserTest, FusionInlinePropagatorMismatchedDims1_CUDA) {
  Fusion fusion;
  FusionGuard fg(&fusion);

  // Pointwise chain with a transpose in the middle, so leaf domains on the
  // two sides of the transpose do not line up position-by-position.
  auto input_tv = makeConcreteTensor({2, 3, 4});
  fusion.addInput(input_tv);
  auto sin_tv = sin(input_tv);
  auto cos_tv = cos(sin_tv);
  auto transpose_tv = transpose(cos_tv, 1, 2);
  auto exp_tv = exp(transpose_tv);
  auto tan_tv = tan(exp_tv);
  fusion.addOutput(tan_tv);

  // MostInlined propagation from the output must not assert on the
  // transpose mismatch; it should simply stop inlining at the mismatch.
  InlinePropagator inline_propagator(tan_tv, -1, ComputeAtMode::MostInlined);
  MaxRootDomainInfoSpanningTree(tan_tv).traverse(&inline_propagator);

  // Expected compute-at positions after propagation. cos_tv sits right
  // before the transpose and can only inline its leading dimension.
  TORCH_CHECK(tan_tv->getComputeAtPosition() == 3);
  TORCH_CHECK(exp_tv->getComputeAtPosition() == 3);
  TORCH_CHECK(transpose_tv->getComputeAtPosition() == 3);
  TORCH_CHECK(cos_tv->getComputeAtPosition() == 1);
  TORCH_CHECK(sin_tv->getComputeAtPosition() == 3);

  const auto options =
      at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
  at::Tensor input = at::randn({2, 3, 4}, options);
  auto reference = input.sin().cos().transpose(1, 2).exp().tan();

  FusionExecutor fe;
  fe.compileFusion(&fusion, {input});
  auto cg_outputs = fe.runFusion({input});

  testValidate(
      &fusion, cg_outputs, {input}, {reference}, __LINE__, __FILE__);
}

TEST_F(NVFuserTest, FusionInlinePropagatorMismatchedDims2_CUDA) {
  Fusion fusion;
  FusionGuard fg(&fusion);

  // Same topology as MismatchedDims1: a pointwise chain broken by a
  // transpose, exercised here under BestEffort instead of MostInlined.
  auto input_tv = makeConcreteTensor({2, 3, 4});
  fusion.addInput(input_tv);
  auto sin_tv = sin(input_tv);
  auto cos_tv = cos(sin_tv);
  auto transpose_tv = transpose(cos_tv, 1, 2);
  auto exp_tv = exp(transpose_tv);
  auto tan_tv = tan(exp_tv);
  fusion.addOutput(tan_tv);

  // BestEffort propagation must not assert on the mismatch either.
  InlinePropagator inline_propagator(tan_tv, -1, ComputeAtMode::BestEffort);
  MaxRootDomainInfoSpanningTree(tan_tv).traverse(&inline_propagator);

  // Unlike MostInlined (see MismatchedDims1), BestEffort leaves sin_tv at
  // position 1 as well, matching its consumer across the transpose.
  TORCH_CHECK(tan_tv->getComputeAtPosition() == 3);
  TORCH_CHECK(exp_tv->getComputeAtPosition() == 3);
  TORCH_CHECK(transpose_tv->getComputeAtPosition() == 3);
  TORCH_CHECK(cos_tv->getComputeAtPosition() == 1);
  TORCH_CHECK(sin_tv->getComputeAtPosition() == 1);

  const auto options =
      at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
  at::Tensor input = at::randn({2, 3, 4}, options);
  auto reference = input.sin().cos().transpose(1, 2).exp().tan();

  FusionExecutor fe;
  fe.compileFusion(&fusion, {input});
  auto cg_outputs = fe.runFusion({input});

  testValidate(
      &fusion, cg_outputs, {input}, {reference}, __LINE__, __FILE__);
}

TEST_F(NVFuserTest, FusionInlinePropagatorMismatchedDims3_CUDA) {
  Fusion fusion;
  FusionGuard fg(&fusion);

  auto input_tv = makeConcreteTensor({2, 3, 4});
  fusion.addInput(input_tv);
  auto sin_tv = sin(input_tv);
  // Interleave a broadcast axis after each original axis...
  auto bcast_tv = broadcast(sin_tv, {false, true, false, true, false, true});
  auto relu_tv = relu(bcast_tv);
  // ...then reduce those broadcast axes right back out (trivial reduction).
  auto sum_tv = sum(relu_tv, {1, 3, 5});
  auto cos_tv = cos(sum_tv);
  auto transpose_tv = transpose(cos_tv, 1, 2);
  auto exp_tv = exp(transpose_tv);
  auto tan_tv = tan(exp_tv);
  fusion.addOutput(tan_tv);

  // Merge each broadcast/reduction axis into its neighbor so the 6D tensors
  // end up with three leaf domains like the rest of the DAG.
  for (auto six_dim_tv : {bcast_tv, relu_tv, sum_tv}) {
    six_dim_tv->merge(0);
    six_dim_tv->merge(1);
    six_dim_tv->merge(2);
  }

  InlinePropagator inline_propagator(tan_tv, -1, ComputeAtMode::MostInlined);
  MaxRootDomainInfoSpanningTree(tan_tv).traverse(&inline_propagator);

  // Everything inlines fully except cos_tv, which sits right before the
  // transpose and can only inline its leading dimension.
  TORCH_CHECK(tan_tv->getComputeAtPosition() == 3);
  TORCH_CHECK(exp_tv->getComputeAtPosition() == 3);
  TORCH_CHECK(transpose_tv->getComputeAtPosition() == 3);
  TORCH_CHECK(cos_tv->getComputeAtPosition() == 1);
  TORCH_CHECK(sum_tv->getComputeAtPosition() == 3);
  TORCH_CHECK(relu_tv->getComputeAtPosition() == 3);
  TORCH_CHECK(bcast_tv->getComputeAtPosition() == 3);
  TORCH_CHECK(sin_tv->getComputeAtPosition() == 3);

  const auto options =
      at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
  at::Tensor input = at::randn({2, 3, 4}, options);
  auto reference = input.sin().relu().cos().transpose(1, 2).exp().tan();

  FusionExecutor fe;
  fe.compileFusion(&fusion, {input});
  auto cg_outputs = fe.runFusion({input});

  testValidate(
      &fusion, cg_outputs, {input}, {reference}, __LINE__, __FILE__);
}

TEST_F(NVFuserTest, FusionInlinePropagatorBroadcast_CUDA) {
Fusion fusion;
FusionGuard fg(&fusion);
Expand Down

0 comments on commit 172fb36

Please sign in to comment.