Skip to content

Commit

Permalink
Symmetric API for BestEffortReplay (#1870)
Browse files Browse the repository at this point in the history
  • Loading branch information
zasdfgbnm authored Jul 29, 2022
1 parent d1caf33 commit 440102b
Show file tree
Hide file tree
Showing 5 changed files with 95 additions and 36 deletions.
26 changes: 7 additions & 19 deletions torch/csrc/jit/codegen/cuda/inline_propagator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -225,8 +225,8 @@ void InlinePropagator::setUp() {
namespace {

// Try to find the aligned position on consumer's domain corresponding to the
// compute at position of producer domain. Used in computeAt pass only. No
// checking on actual producer-consumer relationship.
// compute at position of producer domain. Used in InlinePropagator pass only.
// No checking on actual producer-consumer relationship.
unsigned int getConsumerPosAlignedToProducerCA(
TensorView* consumer,
TensorView* producer) {
Expand All @@ -239,18 +239,10 @@ unsigned int getConsumerPosAlignedToProducerCA(
// have the mapping iS22{( 3 * 1 )} <- iS1{3} We need the latter. Refer to
// NVFuserTest.FusionComplexBCast1_CUDA

auto c2p_map =
auto disjoint_sets =
BestEffortReplay::replayPasC(
producer,
consumer,
-1,
// Compute at root domain may not be valid here, as all
// producers don't have to be able to map into consumer at
// max producer position. Since computeAt should be valid
// and this mechanism is only intended to lower produce
// position of consumer, we can simply use the pairwise map.
PairwiseRootDomainMap(producer, consumer))
.getReplay();
producer, consumer, -1, PairwiseRootDomainMap(producer, consumer))
.getDisjointSets();

// Find the innermost position of consumer that has
// been mapped within the producer ca axis.
Expand All @@ -261,12 +253,8 @@ unsigned int getConsumerPosAlignedToProducerCA(
if (std::any_of(
p_dom.begin(),
p_dom.begin() + producer->getComputeAtPosition(),
[&consumer_id, &c2p_map](IterDomain* p_id) {
auto c_id_it = c2p_map.find(consumer_id);
if (c_id_it != c2p_map.end()) {
return c_id_it->second == p_id;
}
return false;
[&consumer_id, &disjoint_sets](IterDomain* p_id) {
return disjoint_sets.permissiveAreMapped(consumer_id, p_id);
})) {
break;
}
Expand Down
66 changes: 66 additions & 0 deletions torch/csrc/jit/codegen/cuda/test/test_gpu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#include <torch/csrc/jit/codegen/cuda/fusion.h>
#include <torch/csrc/jit/codegen/cuda/fusion_segmenter.h>
#include <torch/csrc/jit/codegen/cuda/grouped_reduction.h>
#include <torch/csrc/jit/codegen/cuda/inline_propagator.h>
#include <torch/csrc/jit/codegen/cuda/interface.h>
#include <torch/csrc/jit/codegen/cuda/ir_all_nodes.h>
#include <torch/csrc/jit/codegen/cuda/ir_builder.h>
Expand Down Expand Up @@ -24846,6 +24847,71 @@ TEST_F(NVFuserTest, FusionInsertMagicZero1_CUDA) {
tv2->toString());
}

TEST_F(NVFuserTest, FusionInlinePropagatorBroadcast_CUDA) {
Fusion fusion;
FusionGuard fg(&fusion);

auto tv0 = makeConcreteTensor({2, 3, 4});
fusion.addInput(tv0);
auto tv1 = sin(tv0);
// broadcasting
auto tv2 = broadcast(tv1, {false, true, false, true, false, true});
auto tv3 = cos(tv2);
auto tv4 = tan(tv3);
fusion.addOutput(tv4);

for (auto tv : {tv2, tv3, tv4}) {
tv->merge(0);
tv->merge(1);
tv->merge(2);
}

InlinePropagator inline_propagator(tv0, -1, ComputeAtMode::MostInlined);
MaxRootDomainInfoSpanningTree(tv0).traverse(&inline_propagator);

TORCH_CHECK(tv4->getComputeAtPosition() == 3);
TORCH_CHECK(tv3->getComputeAtPosition() == 3);
TORCH_CHECK(tv2->getComputeAtPosition() == 3);
TORCH_CHECK(tv1->getComputeAtPosition() == 3);

const auto options =
at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
at::Tensor input = at::randn({2, 3, 4}, options);
auto output = input.sin().view({2, 1, 3, 1, 4, 1}).cos().tan();

FusionExecutor fe;
fe.compileFusion(&fusion, {input});
auto cg_outputs = fe.runFusion({input});

testValidate(&fusion, cg_outputs, {input}, {output}, __LINE__, __FILE__);
}

TEST_F(NVFuserTest, FusionMatchedLeafPosWithoutReplayBroadcast_CUDA) {
Fusion fusion;
FusionGuard fg(&fusion);

auto tv0 = makeConcreteTensor({2, 3, 4});
fusion.addInput(tv0);
auto tv1 = broadcast(tv0, {false, true, false, true, false, true});
auto tv2 = sin(tv1);
fusion.addOutput(tv2);

for (auto tv : {tv1, tv2}) {
tv->merge(0);
tv->merge(1);
tv->merge(2);
}

TORCH_CHECK(
TransformReplay::getMatchedLeafPosWithoutReplayPasC(tv0, tv1, 3) == 3);
TORCH_CHECK(
TransformReplay::getMatchedLeafPosWithoutReplayCasP(tv1, tv0, 3) == 3);
TORCH_CHECK(
TransformReplay::getMatchedLeafPosWithoutReplayPasC(tv1, tv2, 3) == 3);
TORCH_CHECK(
TransformReplay::getMatchedLeafPosWithoutReplayCasP(tv2, tv1, 3) == 3);
}

} // namespace jit
} // namespace torch
#endif // #if defined(USE_CUDA)
12 changes: 12 additions & 0 deletions torch/csrc/jit/codegen/cuda/transform_iter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1047,6 +1047,18 @@ void BestEffortReplay::skipSwizzles(
}
}

DisjointSets<IterDomain*> BestEffortReplay::getDisjointSets() {
DisjointSets<IterDomain*> result;
const std::unordered_map<IterDomain*, IterDomain*>* maps[3] = {
&target2replay_id_map_, &replay_forward_id_map_, &target_forward_id_map_};
for (auto map : maps) {
for (auto entry : *map) {
result.mapEntries(entry.first, entry.second);
}
}
return result;
}

} // namespace cuda
} // namespace fuser
} // namespace jit
Expand Down
3 changes: 3 additions & 0 deletions torch/csrc/jit/codegen/cuda/transform_iter.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

#include <c10/macros/Export.h>

#include <torch/csrc/jit/codegen/cuda/disjoint_set.h>
#include <torch/csrc/jit/codegen/cuda/ir_all_nodes.h>
#include <torch/csrc/jit/codegen/cuda/ir_iostream.h>
#include <torch/csrc/jit/codegen/cuda/iter_visitor.h>
Expand Down Expand Up @@ -307,6 +308,8 @@ class TORCH_CUDA_CU_API BestEffortReplay {
return leaf_vec_;
}

DisjointSets<IterDomain*> getDisjointSets();

// Runs a best effort replay that ignores broadcast axes that appear in
// consumer that are not mapped to producer in root_map.
static BestEffortReplay replayCasP(
Expand Down
24 changes: 7 additions & 17 deletions torch/csrc/jit/codegen/cuda/transform_replay.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -681,9 +681,9 @@ int TransformReplay::getMatchedLeafPosWithoutReplayPasC(
auto it_consumer = consumer_domain.begin();
auto it_producer = producer_domain.begin();

id_map c2p_map =
auto disjoint_sets =
BestEffortReplay::replayPasC(producer, consumer, -1, pairwise_map)
.getReplay();
.getDisjointSets();

int mismatched_consumer_pos = 0;
int mismatched_producer_pos = 0;
Expand All @@ -703,13 +703,8 @@ int TransformReplay::getMatchedLeafPosWithoutReplayPasC(
return -1;
}

auto c2p_it = c2p_map.find(consumer_id);
if (c2p_it == c2p_map.end()) {
return -1;
}

auto producer_id = *it_producer;
if (c2p_it->second == producer_id) {
if (disjoint_sets.permissiveAreMapped(producer_id, consumer_id)) {
++mismatched_consumer_pos;
++mismatched_producer_pos;
++it_consumer;
Expand Down Expand Up @@ -759,9 +754,9 @@ int TransformReplay::getMatchedLeafPosWithoutReplayCasP(
auto it_producer = producer_domain.begin();
auto it_consumer = consumer_domain.begin();

id_map replay_map =
BestEffortReplay::replayCasP(consumer, producer, -1, pairwise_map)
.getReplay();
auto disjoint_sets =
BestEffortReplay::replayPasC(producer, consumer, -1, pairwise_map)
.getDisjointSets();

int mismatched_producer_pos = 0;
int mismatched_consumer_pos = 0;
Expand All @@ -788,12 +783,7 @@ int TransformReplay::getMatchedLeafPosWithoutReplayCasP(
continue;
}

auto replay_it = replay_map.find(producer_id);
if (replay_it == replay_map.end()) {
return -1;
}

if (replay_it->second == consumer_id) {
if (disjoint_sets.permissiveAreMapped(producer_id, consumer_id)) {
++mismatched_producer_pos;
++mismatched_consumer_pos;
++it_producer;
Expand Down

0 comments on commit 440102b

Please sign in to comment.