Allow trivial reduction to be merged #1871

Merged (6 commits) on Jul 29, 2022
17 changes: 15 additions & 2 deletions torch/csrc/jit/codegen/cuda/compute_at_map.cpp
@@ -178,6 +178,8 @@ void IterDomainGraph::build(Fusion* fusion) {
BestEffortReplay::replayPasC(p_tv, c_tv, -1, pairwise_map);

const auto& permissive_c2p_map = permissive_replay_PasC.getReplay();
const auto permissive_disjoint_sets =
permissive_replay_PasC.getDisjointSets();

// For exact mappings do not map any broadcast dimensions to
// non-broadcast dimensions. Prevent any broadcasted axes being mapped
@@ -213,6 +215,17 @@ void IterDomainGraph::build(Fusion* fusion) {
auto p_id = entry.second;
if (idIsAComputeAtLeafDomain(p_id, p_tv)) {
loop_nodes_.mapEntries(c_id, p_id);
} else {
// When there are trivial reductions merged with other dims, `p_id`
// might not be a compute at leaf domain of `p_tv`, but it actually
// has an equivalent compute at leaf domain. For that case, we map
// the equivalent compute at leaf domain.
for (int i = 0; i < p_tv->getComputeAtPosition(); i++) {
auto id = p_tv->axis(i);
if (permissive_disjoint_sets.permissiveAreMapped(p_id, id)) {
loop_nodes_.mapEntries(c_id, id);
}
}
Comment on lines +218 to +228 (Collaborator, Author):

Was trying to do

        for (const auto& set : permissive_disjoint_sets.disjointSets()) {
          auto id1 = set->front();
          for (auto id2 : *set) {
            auto is_leaf1 = idIsAComputeAtLeafDomain(id1, p_tv);
            auto is_leaf2 = idIsAComputeAtLeafDomain(id2, p_tv);
            if (is_leaf1 || is_leaf2) {
              loop_nodes_.mapEntries(id1, id2);
            }
            permissive_nodes_.mapEntries(id1, id2);

            // Add the swizzle inputs to the same
            //  disjoint set as well if either c_id
            //  or p_id is swizzle output.
            mapMaybeSwizzleOp(permissive_nodes_, id1);
            mapMaybeSwizzleOp(permissive_nodes_, id2);
          }
        }

But it didn't work. Don't know why.

}
permissive_nodes_.mapEntries(c_id, p_id);
consumers_.at(p_id).pushBack(c_id);
@@ -225,8 +238,8 @@ void IterDomainGraph::build(Fusion* fusion) {
mapMaybeSwizzleOp(permissive_nodes_, c_id);
}

// Make sure we always get root mapping for the permissive map. Because
// of forwarding we could otherwise miss some root mappings.
// Make sure we always get root mapping for the permissive map.
// Because of forwarding we could otherwise miss some root mappings.
for (auto entry : permissive_c2p_root_map) {
auto c_id = entry.first;
auto p_id = entry.second;
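For context, a minimal sketch (mirroring the new tests added below, not part of the diff) of the situation the hunk above handles: a trivial reduction, i.e. a reduction over an axis of extent 1, merged into a neighboring iteration domain.

    Fusion fusion;
    FusionGuard fg(&fusion);

    auto tv0 = makeConcreteTensor({2, 1, 3});
    fusion.addInput(tv0);
    auto tv1 = sum(tv0, {1}); // axis 1 has extent 1: a trivial reduction
    auto tv2 = sin(tv1);
    fusion.addOutput(tv2);

    // Merging the trivial reduction into its neighbor replaces tv1's
    // original axes with a merged leaf domain. The id the consumer maps to
    // is then no longer itself a compute-at leaf domain of tv1, but it is
    // permissively mapped to one, which is what the new `else` branch
    // above recovers.
    tv1->merge(0); // {iS{2}, rS{1}, iS{3}} -> {iS{2}*rS{1}, iS{3}}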
26 changes: 7 additions & 19 deletions torch/csrc/jit/codegen/cuda/inline_propagator.cpp
@@ -240,8 +240,8 @@ void InlinePropagator::setUp() {
namespace {

// Try to find the aligned position on consumer's domain corresponding to the
// compute at position of producer domain. Used in computeAt pass only. No
// checking on actual producer-consumer relationship.
// compute at position of producer domain. Used in InlinePropagator pass only.
// No checking on actual producer-consumer relationship.
unsigned int getConsumerPosAlignedToProducerCA(
TensorView* consumer,
TensorView* producer) {
@@ -254,18 +254,10 @@ unsigned int getConsumerPosAlignedToProducerCA(
// have the mapping iS22{( 3 * 1 )} <- iS1{3} We need the latter. Refer to
// NVFuserTest.FusionComplexBCast1_CUDA

auto c2p_map =
auto disjoint_sets =
BestEffortReplay::replayPasC(
producer,
consumer,
-1,
// Compute at root domain may not be valid here, as all
// producers don't have to be able to map into consumer at
// max producer position. Since computeAt should be valid
// and this mechanism is only intended to lower produce
// position of consumer, we can simply use the pairwise map.
PairwiseRootDomainMap(producer, consumer))
.getReplay();
producer, consumer, -1, PairwiseRootDomainMap(producer, consumer))
.getDisjointSets();

// Find the innermost position of consumer that has
// been mapped within the producer ca axis.
@@ -276,12 +268,8 @@ unsigned int getConsumerPosAlignedToProducerCA(
if (std::any_of(
p_dom.begin(),
p_dom.begin() + producer->getComputeAtPosition(),
[&consumer_id, &c2p_map](IterDomain* p_id) {
auto c_id_it = c2p_map.find(consumer_id);
if (c_id_it != c2p_map.end()) {
return c_id_it->second == p_id;
}
return false;
[&consumer_id, &disjoint_sets](IterDomain* p_id) {
return disjoint_sets.permissiveAreMapped(consumer_id, p_id);
})) {
break;
}
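The switch from the exact c2p_map lookup to permissiveAreMapped matters because disjoint sets are transitive: getDisjointSets() (added in transform_iter.cpp below) unions the replay map with both forwarding maps, so two ids that are connected only through an intermediate id still compare as mapped. A minimal illustration, with hypothetical IterDomains a, b, c:

    DisjointSets<IterDomain*> sets;
    sets.mapEntries(a, b); // e.g. an entry from target2replay_id_map_
    sets.mapEntries(b, c); // e.g. an entry from a forwarding map
    // a and c were never mapped to each other directly, yet they now
    // belong to the same disjoint set:
    TORCH_CHECK(sets.permissiveAreMapped(a, c));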
174 changes: 174 additions & 0 deletions torch/csrc/jit/codegen/cuda/test/test_gpu.cpp
@@ -10,6 +10,7 @@
#include <torch/csrc/jit/codegen/cuda/fusion.h>
#include <torch/csrc/jit/codegen/cuda/fusion_segmenter.h>
#include <torch/csrc/jit/codegen/cuda/grouped_reduction.h>
#include <torch/csrc/jit/codegen/cuda/inline_propagator.h>
#include <torch/csrc/jit/codegen/cuda/interface.h>
#include <torch/csrc/jit/codegen/cuda/ir_all_nodes.h>
#include <torch/csrc/jit/codegen/cuda/ir_builder.h>
@@ -24846,6 +24847,179 @@ TEST_F(NVFuserTest, FusionInsertMagicZero1_CUDA) {
tv2->toString());
}

TEST_F(NVFuserTest, FusionInlinePropagatorBroadcast_CUDA) {
Fusion fusion;
FusionGuard fg(&fusion);

auto tv0 = makeConcreteTensor({2, 3, 4});
fusion.addInput(tv0);
auto tv1 = sin(tv0);
// broadcasting
auto tv2 = broadcast(tv1, {false, true, false, true, false, true});
auto tv3 = cos(tv2);
auto tv4 = tan(tv3);
fusion.addOutput(tv4);

for (auto tv : {tv2, tv3, tv4}) {
tv->merge(0);
tv->merge(1);
tv->merge(2);
}

InlinePropagator inline_propagator(tv0, -1, ComputeAtMode::MostInlined);
MaxRootDomainInfoSpanningTree(tv0).traverse(&inline_propagator);

TORCH_CHECK(tv4->getComputeAtPosition() == 3);
TORCH_CHECK(tv3->getComputeAtPosition() == 3);
TORCH_CHECK(tv2->getComputeAtPosition() == 3);
TORCH_CHECK(tv1->getComputeAtPosition() == 3);

const auto options =
at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
at::Tensor input = at::randn({2, 3, 4}, options);
auto output = input.sin().view({2, 1, 3, 1, 4, 1}).cos().tan();

FusionExecutor fe;
fe.compileFusion(&fusion, {input});
auto cg_outputs = fe.runFusion({input});

testValidate(&fusion, cg_outputs, {input}, {output}, __LINE__, __FILE__);
}

TEST_F(NVFuserTest, FusionInlinePropagatorBroadcastTrivialReduction_CUDA) {
Fusion fusion;
FusionGuard fg(&fusion);

auto tv0 = makeConcreteTensor({2, 3, 4});
fusion.addInput(tv0);
auto tv1 = sin(tv0);
// broadcasting
auto tv2 = broadcast(tv1, {false, true, false, true, false, true});
auto tv3 = tan(tv2);
// trivial reduction
auto tv4 = sum(tv3, {1, 3, 5});
auto tv5 = cos(tv4);
auto tv6 = exp(tv5);
fusion.addOutput(tv6);

for (auto tv : {tv2, tv3, tv4}) {
tv->merge(0);
tv->merge(1);
tv->merge(2);
}

InlinePropagator inline_propagator(tv6, -1, ComputeAtMode::MostInlined);
MaxRootDomainInfoSpanningTree(tv6).traverse(&inline_propagator);

TORCH_CHECK(tv6->getComputeAtPosition() == 3);
TORCH_CHECK(tv5->getComputeAtPosition() == 3);
TORCH_CHECK(tv4->getComputeAtPosition() == 3);
TORCH_CHECK(tv3->getComputeAtPosition() == 3);
TORCH_CHECK(tv2->getComputeAtPosition() == 3);
TORCH_CHECK(tv1->getComputeAtPosition() == 3);

const auto options =
at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
at::Tensor input = at::randn({2, 3, 4}, options);
auto output = input.sin().tan().cos().exp();

FusionExecutor fe;
fe.compileFusion(&fusion, {input});
auto cg_outputs = fe.runFusion({input});

testValidate(&fusion, cg_outputs, {input}, {output}, __LINE__, __FILE__);
}

TEST_F(NVFuserTest, FusionMatchedLeafPosWithoutReplayTrivialReduction_CUDA) {
Fusion fusion;
FusionGuard fg(&fusion);

auto tv0 = makeConcreteTensor({2, 1, 3, 1, 4, 1});
fusion.addInput(tv0);
auto tv1 = sum(tv0, {1, 3, 5});
auto tv2 = sin(tv1);
fusion.addOutput(tv2);

for (auto tv : {tv0, tv1}) {
tv->merge(0);
tv->merge(1);
tv->merge(2);
}

TORCH_CHECK(
TransformReplay::getMatchedLeafPosWithoutReplayPasC(tv0, tv1, 3) == 3);
TORCH_CHECK(
TransformReplay::getMatchedLeafPosWithoutReplayCasP(tv1, tv0, 3) == 3);
TORCH_CHECK(
TransformReplay::getMatchedLeafPosWithoutReplayPasC(tv1, tv2, 3) == 3);
TORCH_CHECK(
TransformReplay::getMatchedLeafPosWithoutReplayCasP(tv2, tv1, 3) == 3);
}

TEST_F(NVFuserTest, FusionMatchedLeafPosWithoutReplayBroadcast_CUDA) {
Fusion fusion;
FusionGuard fg(&fusion);

auto tv0 = makeConcreteTensor({2, 3, 4});
fusion.addInput(tv0);
auto tv1 = broadcast(tv0, {false, true, false, true, false, true});
auto tv2 = sin(tv1);
fusion.addOutput(tv2);

for (auto tv : {tv1, tv2}) {
tv->merge(0);
tv->merge(1);
tv->merge(2);
}

TORCH_CHECK(
TransformReplay::getMatchedLeafPosWithoutReplayPasC(tv0, tv1, 3) == 3);
TORCH_CHECK(
TransformReplay::getMatchedLeafPosWithoutReplayCasP(tv1, tv0, 3) == 3);
TORCH_CHECK(
TransformReplay::getMatchedLeafPosWithoutReplayPasC(tv1, tv2, 3) == 3);
TORCH_CHECK(
TransformReplay::getMatchedLeafPosWithoutReplayCasP(tv2, tv1, 3) == 3);
}

TEST_F(NVFuserTest, FusionIdGraphTrivialReduction_CUDA) {
Fusion fusion;
FusionGuard fg(&fusion);

auto tv0 = makeConcreteTensor({2, 3, 4});
fusion.addInput(tv0);
auto tv1 = broadcast(tv0, {false, true, false, true, false, true});
auto tv2 = sum(tv1, {1, 3, 5});
auto tv3 = sin(tv2);
fusion.addOutput(tv3);

for (auto tv : {tv1, tv2}) {
tv->merge(0);
tv->merge(1);
tv->merge(2);
}

InlinePropagator inline_propagator(tv3, -1, ComputeAtMode::MostInlined);
MaxRootDomainInfoSpanningTree(tv3).traverse(&inline_propagator);

ComputeAtMap ca_map(&fusion);

auto all_tvs = ir_utils::allTvs(&fusion);
for (auto tv1 : all_tvs) {
for (auto tv2 : all_tvs) {
if (tv1->isFusionInput() || tv2->isFusionInput()) {
continue;
}
for (int i : c10::irange(3)) {
auto id1 = tv1->axis(i);
auto id2 = tv2->axis(i);
TORCH_CHECK(ca_map.areMapped(id1, id2, IdMappingMode::LOOP));
TORCH_CHECK(ca_map.areMapped(id1, id2, IdMappingMode::PERMISSIVE));
}
}
}
}

} // namespace jit
} // namespace torch
#endif // #if defined(USE_CUDA)
12 changes: 12 additions & 0 deletions torch/csrc/jit/codegen/cuda/transform_iter.cpp
@@ -1047,6 +1047,18 @@ void BestEffortReplay::skipSwizzles(
}
}

DisjointSets<IterDomain*> BestEffortReplay::getDisjointSets() {
DisjointSets<IterDomain*> result;
const std::unordered_map<IterDomain*, IterDomain*>* maps[3] = {
&target2replay_id_map_, &replay_forward_id_map_, &target_forward_id_map_};
for (auto map : maps) {
for (auto entry : *map) {
result.mapEntries(entry.first, entry.second);
}
}
return result;
}

} // namespace cuda
} // namespace fuser
} // namespace jit
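A sketch of how call sites consume the new accessor (this mirrors the call sites changed elsewhere in this PR; p_tv, c_tv, p_id, and c_id stand for a producer/consumer tensor pair and iter domains on them):

    // Replay producer as consumer, then query the union of the replay and
    // forwarding maps as a single equivalence relation.
    auto disjoint_sets =
        BestEffortReplay::replayPasC(
            p_tv, c_tv, -1, PairwiseRootDomainMap(p_tv, c_tv))
            .getDisjointSets();
    if (disjoint_sets.permissiveAreMapped(p_id, c_id)) {
      // p_id and c_id are connected through the replay and/or forwarding
      // maps, possibly transitively.
    }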
3 changes: 3 additions & 0 deletions torch/csrc/jit/codegen/cuda/transform_iter.h
@@ -2,6 +2,7 @@

#include <c10/macros/Export.h>

#include <torch/csrc/jit/codegen/cuda/disjoint_set.h>
#include <torch/csrc/jit/codegen/cuda/ir_all_nodes.h>
#include <torch/csrc/jit/codegen/cuda/ir_iostream.h>
#include <torch/csrc/jit/codegen/cuda/iter_visitor.h>
@@ -307,6 +308,8 @@ class TORCH_CUDA_CU_API BestEffortReplay {
return leaf_vec_;
}

DisjointSets<IterDomain*> getDisjointSets();

// Runs a best effort replay that ignores broadcast axes that appear in
// consumer that are not mapped to producer in root_map.
static BestEffortReplay replayCasP(
24 changes: 7 additions & 17 deletions torch/csrc/jit/codegen/cuda/transform_replay.cpp
@@ -681,9 +681,9 @@ int TransformReplay::getMatchedLeafPosWithoutReplayPasC(
auto it_consumer = consumer_domain.begin();
auto it_producer = producer_domain.begin();

id_map c2p_map =
auto disjoint_sets =
BestEffortReplay::replayPasC(producer, consumer, -1, pairwise_map)
.getReplay();
.getDisjointSets();

int mismatched_consumer_pos = 0;
int mismatched_producer_pos = 0;
@@ -703,13 +703,8 @@
return -1;
}

auto c2p_it = c2p_map.find(consumer_id);
if (c2p_it == c2p_map.end()) {
return -1;
}

auto producer_id = *it_producer;
if (c2p_it->second == producer_id) {
if (disjoint_sets.permissiveAreMapped(producer_id, consumer_id)) {
++mismatched_consumer_pos;
++mismatched_producer_pos;
++it_consumer;
@@ -759,9 +754,9 @@ int TransformReplay::getMatchedLeafPosWithoutReplayCasP(
auto it_producer = producer_domain.begin();
auto it_consumer = consumer_domain.begin();

id_map replay_map =
BestEffortReplay::replayCasP(consumer, producer, -1, pairwise_map)
.getReplay();
auto disjoint_sets =
BestEffortReplay::replayPasC(producer, consumer, -1, pairwise_map)
.getDisjointSets();

int mismatched_producer_pos = 0;
int mismatched_consumer_pos = 0;
@@ -788,12 +783,7 @@
continue;
}

auto replay_it = replay_map.find(producer_id);
if (replay_it == replay_map.end()) {
return -1;
}

if (replay_it->second == consumer_id) {
if (disjoint_sets.permissiveAreMapped(producer_id, consumer_id)) {
++mismatched_producer_pos;
++mismatched_consumer_pos;
++it_producer;