From fe93bf5a6485696ffb36751606a84080349967b5 Mon Sep 17 00:00:00 2001 From: "Gao, Xiang" Date: Thu, 30 Jun 2022 12:10:28 -0700 Subject: [PATCH] Transform propagator skip replay when possible (#1782) This comment in the code describes what this PR is doing: ```C++ // Note: [Using multiple TransformPropagators] // There are cases that we use multiple TransformPropagators along different // spanning trees with different references in the same fusion. Some of these // spanning trees could overlap. In cases when there are overlapping nodes, // TransformPropagator needs to respect the replay of others, because the // current TransformPropagator might not contain the most amount of // information on how to do the correct transformation. The logic below tells // TransformPropagator to skip the replay when not necessary. ``` --- torch/csrc/jit/codegen/cuda/compute_at.cpp | 102 +---------- torch/csrc/jit/codegen/cuda/test/test_gpu.cpp | 47 ++++- .../jit/codegen/cuda/transform_replay.cpp | 165 +++++++++++++++++- .../csrc/jit/codegen/cuda/transform_replay.h | 21 +++ 4 files changed, 225 insertions(+), 110 deletions(-) diff --git a/torch/csrc/jit/codegen/cuda/compute_at.cpp b/torch/csrc/jit/codegen/cuda/compute_at.cpp index 5603d5e54d577..b56a779c80364 100644 --- a/torch/csrc/jit/codegen/cuda/compute_at.cpp +++ b/torch/csrc/jit/codegen/cuda/compute_at.cpp @@ -342,96 +342,6 @@ void ComputeAt::runWith( ca.runPass(); } -namespace { - -// Checks if producer and consumer are transformed consistently so that to -// satisfy the provided compute at position. This means no replay is actually -// necessary for the compute at requested. If consumer_pos then -// consumer_or_producer_pos is relative to the consumer and skipReplay returns -// the associated position in producer. -// -// If producer and consumer are not transformed consistently with provided -// postition, returns -1. 
-int skipReplay( - const TensorView* producer, - const TensorView* consumer, - int consumer_or_producer_pos, - bool consumer_pos = true) { - FUSER_PERF_SCOPE("transform_replay.cpp::skipReplay"); - - const auto c2p_root_map = - PairwiseRootDomainMap(producer, consumer) - .mapConsumerToProducer(consumer->domain(), producer->domain()); - - // IterDomains in consumer root also in producer root - std::unordered_set mapped_consumer_roots; - for (auto entry : c2p_root_map) { - mapped_consumer_roots.emplace(entry.first); - } - - const auto consumer_domain = consumer->domain()->domain(); - - auto mapped_consumer_domain_ids_vec = DependencyCheck::getAllValsBetween( - mapped_consumer_roots, {consumer_domain.begin(), consumer_domain.end()}); - - std::unordered_set mapped_consumer_domain_ids( - mapped_consumer_domain_ids_vec.begin(), - mapped_consumer_domain_ids_vec.end()); - - const auto producer_domain = producer->domain()->domain(); - - auto it_consumer = consumer_domain.begin(); - auto it_producer = producer_domain.begin(); - - auto best_effort_PasC = BestEffortReplay::replayPasC( - producer, consumer, -1, PairwiseRootDomainMap(producer, consumer)); - - auto c2p_map = best_effort_PasC.getReplay(); - - int mismatched_consumer_pos = 0; - int mismatched_producer_pos = 0; - while (it_consumer != consumer_domain.end()) { - auto consumer_id = *it_consumer; - if (!mapped_consumer_domain_ids.count(consumer_id)) { - ++it_consumer; - mismatched_consumer_pos++; - continue; - } - - auto c2p_it = c2p_map.find(consumer_id); - if (c2p_it == c2p_map.end()) { - break; - } - - if (it_producer == producer_domain.end()) { - break; - } - - auto producer_id = *it_producer; - - if (c2p_it->second == producer_id) { - ++mismatched_consumer_pos; - ++mismatched_producer_pos; - ++it_consumer; - ++it_producer; - if (consumer_pos) { - if (consumer_or_producer_pos == mismatched_consumer_pos) { - return mismatched_producer_pos; - } - } else { - if (consumer_or_producer_pos == mismatched_producer_pos) { - 
return mismatched_consumer_pos; - } - } - } else { - break; - } - } - return -1; -} - -} // namespace - // Actually applies transformation unsigned int ComputeAt::backwardComputeAt_impl( TensorView* producer, @@ -460,9 +370,11 @@ unsigned int ComputeAt::backwardComputeAt_impl( max_consumer_compute_at_pos); } - // Short cut if no replay is necessary - auto maybe_producer_pos = - skipReplay(producer, consumer, (int)consumer_compute_at_pos, true); + // Checks if producer and consumer are transformed consistently so that to + // satisfy the provided compute at position. This means no replay is actually + // necessary for the compute at requested. + auto maybe_producer_pos = TransformReplay::getMatchedLeafPosWithoutReplayPasC( + producer, consumer, consumer_compute_at_pos); if (maybe_producer_pos >= 0) { if (!producer->isFusionInput()) { producer->setComputeAt((unsigned int)maybe_producer_pos); @@ -536,8 +448,8 @@ unsigned int ComputeAt::forwardComputeAt_impl( } // Short cut if no replay is necessary - auto maybe_consumer_pos = - skipReplay(producer, consumer, (int)producer_compute_at_pos, false); + auto maybe_consumer_pos = TransformReplay::getMatchedLeafPosWithoutReplayCasP( + consumer, producer, producer_compute_at_pos); if (maybe_consumer_pos > -1) { if (!producer->isFusionInput()) { producer->setComputeAt(producer_compute_at_pos); diff --git a/torch/csrc/jit/codegen/cuda/test/test_gpu.cpp b/torch/csrc/jit/codegen/cuda/test/test_gpu.cpp index 2db0d385110ad..d543d8dc356ef 100644 --- a/torch/csrc/jit/codegen/cuda/test/test_gpu.cpp +++ b/torch/csrc/jit/codegen/cuda/test/test_gpu.cpp @@ -23710,7 +23710,7 @@ TEST_F(NVFuserTest, FusionTransformPropagateSibling_CUDA) { for (auto tensors : siblings) { for (auto t1 : tensors) { for (auto t2 : tensors) { - checkSiblingConsistency(t1, t2); + TORCH_CHECK(TransformReplay::fullSelfMatching(t1, t2)); } } } @@ -23769,7 +23769,7 @@ TEST_F(NVFuserTest, FusionTransformPropagateSelectorSibling_CUDA) { for (auto tensors : siblings) { 
for (auto t1 : tensors) { for (auto t2 : tensors) { - checkSiblingConsistency(t1, t2); + TORCH_CHECK(TransformReplay::fullSelfMatching(t1, t2)); } } } @@ -23922,7 +23922,7 @@ TEST_F(NVFuserTest, FusionTransformPropagatorSelector) { TORCH_CHECK(tv4->nDims() == 1); } -TEST_F(NVFuserTest, FusionTransormPropagatorPos_CUDA) { +TEST_F(NVFuserTest, FusionTransformPropagatorPos_CUDA) { auto fusion = std::make_unique(); FusionGuard fg(fusion.get()); @@ -23939,10 +23939,9 @@ TEST_F(NVFuserTest, FusionTransormPropagatorPos_CUDA) { TransformPropagator propagator(tv1, 2); MaxRootDomainInfoSpanningTree(tv1, 2).traverse(&propagator); - TORCH_CHECK(tv0->nDims() == 3); - TORCH_CHECK(tv0->axis(0)->extent()->evaluateInt() == 11); - TORCH_CHECK(tv0->axis(1)->extent()->evaluateInt() == 2); - TORCH_CHECK(tv0->axis(2)->extent()->evaluateInt() == 105); + auto expect = makeConcreteTensor({22, 105}); + expect->split(0, 2); + TORCH_CHECK(TransformReplay::fullSelfMatching(expect, tv0)); } TEST_F(NVFuserTest, FusionMaxRootDomainInfoSpanningTreePrintTwice_CUDA) { @@ -23996,6 +23995,40 @@ to: 2 TORCH_CHECK(printer2.ss.str() == expect); } +TEST_F(NVFuserTest, FusionTransformPropagatorNoOverwrite_CUDA) { + auto fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + + auto tv0 = makeSymbolicTensor(1); + fusion->addInput(tv0); + auto tv1 = broadcast(tv0, {true, false, true}); + auto tv2 = sin(tv1); + fusion->addOutput(tv2); + + tv0->split(0, 2); + tv2->split(1, 2); + tv2->split(0, 4); + + MaxRootDomainInfoSpanningTree path1(tv2); + TransformPropagator propagator1(tv2); + path1.traverse(&propagator1); + + MaxRootDomainInfoSpanningTree path2(tv0); + TransformPropagator propagator2(tv0); + path2.traverse(&propagator2); + + TORCH_CHECK(tv1->axis(0)->isBroadcast()); + TORCH_CHECK(tv1->axis(1)->isBroadcast()); + TORCH_CHECK(!tv1->axis(2)->isBroadcast()); + TORCH_CHECK(!tv1->axis(3)->isBroadcast()); + TORCH_CHECK(tv1->axis(4)->isBroadcast()); + + auto expect = makeSymbolicTensor(3); + 
expect->split(1, 2); + expect->split(0, 4); + TORCH_CHECK(TransformReplay::fullSelfMatching(expect, tv1)); +} + TEST_F(NVFuserTest, FusionIssue1785Repro_CUDA) { Fusion fusion; FusionGuard fg(&fusion); diff --git a/torch/csrc/jit/codegen/cuda/transform_replay.cpp b/torch/csrc/jit/codegen/cuda/transform_replay.cpp index 9015331a1417b..e961867865181 100644 --- a/torch/csrc/jit/codegen/cuda/transform_replay.cpp +++ b/torch/csrc/jit/codegen/cuda/transform_replay.cpp @@ -644,24 +644,173 @@ std::pair TransformReplay::replayCasP( return replayCasP(consumer, producer, compute_at_axis, root_map); } +namespace { + +int getMatchedLeafPosWithoutReplay( + const TensorView* producer, + const TensorView* consumer, + int consumer_or_producer_pos, + bool consumer_pos = true) { + FUSER_PERF_SCOPE("transform_replay.cpp::getMatchedLeafPosWithoutReplay"); + + const auto c2p_root_map = + PairwiseRootDomainMap(producer, consumer) + .mapConsumerToProducer(consumer->domain(), producer->domain()); + + // IterDomains in consumer root also in producer root + std::unordered_set mapped_consumer_roots; + for (auto entry : c2p_root_map) { + mapped_consumer_roots.emplace(entry.first); + } + + const auto consumer_domain = consumer->domain()->domain(); + + auto mapped_consumer_domain_ids_vec = DependencyCheck::getAllValsBetween( + mapped_consumer_roots, {consumer_domain.begin(), consumer_domain.end()}); + + std::unordered_set mapped_consumer_domain_ids( + mapped_consumer_domain_ids_vec.begin(), + mapped_consumer_domain_ids_vec.end()); + + const auto producer_domain = producer->domain()->domain(); + + auto it_consumer = consumer_domain.begin(); + auto it_producer = producer_domain.begin(); + + auto best_effort_PasC = BestEffortReplay::replayPasC( + producer, consumer, -1, PairwiseRootDomainMap(producer, consumer)); + + auto c2p_map = best_effort_PasC.getReplay(); + + int mismatched_consumer_pos = 0; + int mismatched_producer_pos = 0; + while (it_consumer != consumer_domain.end()) { + auto consumer_id 
= *it_consumer; + if (!mapped_consumer_domain_ids.count(consumer_id)) { + ++it_consumer; + mismatched_consumer_pos++; + continue; + } + + auto c2p_it = c2p_map.find(consumer_id); + if (c2p_it == c2p_map.end()) { + break; + } + + if (it_producer == producer_domain.end()) { + break; + } + + auto producer_id = *it_producer; + + if (c2p_it->second == producer_id) { + ++mismatched_consumer_pos; + ++mismatched_producer_pos; + ++it_consumer; + ++it_producer; + if (consumer_pos) { + if (consumer_or_producer_pos == mismatched_consumer_pos) { + return mismatched_producer_pos; + } + } else { + if (consumer_or_producer_pos == mismatched_producer_pos) { + return mismatched_consumer_pos; + } + } + } else { + break; + } + } + return -1; +} + +} // namespace + +int TransformReplay::getMatchedLeafPosWithoutReplayPasC( + const TensorView* producer, + const TensorView* consumer, + int consumer_pos) { + return getMatchedLeafPosWithoutReplay(producer, consumer, consumer_pos, true); +} + +int TransformReplay::getMatchedLeafPosWithoutReplayCasP( + const TensorView* consumer, + const TensorView* producer, + int producer_pos) { + return getMatchedLeafPosWithoutReplay( + producer, consumer, producer_pos, false); +} + +bool TransformReplay::fullSelfMatching( + const TensorView* replay, + const TensorView* target) { + auto replay_root = replay->getRootDomain(); + auto replay_dom = replay->domain()->domain(); + auto target_root = target->getRootDomain(); + auto target_dom = target->domain()->domain(); + std::unordered_map target2replay_map; + if (replay_root.size() != target_root.size()) { + return false; + } + target2replay_map.reserve(replay_root.size()); + std::transform( + target_root.begin(), + target_root.end(), + replay_root.begin(), + std::inserter(target2replay_map, target2replay_map.begin()), + [](auto a, auto b) { return std::make_pair(a, b); }); + BestEffortReplay replay_(replay_dom, target_dom, target2replay_map); + auto r = replay_.getReplay(); + for (int64_t i = 0; i < 
replay_dom.size(); i++) { + auto target_id = target_dom[i]; + auto replay_it = r.find(target_id); + if (replay_it == r.end() || replay_it->second != replay_dom[i]) { + return false; + } + } + return true; +} + void TransformPropagator::propagateTvPasC(TensorView* from, TensorView* to) { int pos = replayed_pos_.at(from); - auto replay = TransformReplay::replayPasC(to, from, pos); - to->setDomain(replay.first); - replayed_pos_[to] = replay.second; + // Note: [Using multiple TransformPropagators] + // There are cases that we use multiple TransformPropagators along different + // spanning trees with different references in the same fusion. Some of these + // spanning trees could overlap. In cases when there are overlapping nodes, + // TransformPropagator needs to respect the replay of others, because the + // current TransformPropagator might not contain the most amount of + // information on how to do the correct transformation. The logic below tells + // TransformPropagator to skip the replay when not necessary. 
+ int new_pos = + TransformReplay::getMatchedLeafPosWithoutReplayPasC(to, from, pos); + if (new_pos < 0) { + auto replay = TransformReplay::replayPasC(to, from, pos); + to->setDomain(replay.first); + new_pos = replay.second; + } + replayed_pos_[to] = new_pos; } void TransformPropagator::propagateTvCasP(TensorView* from, TensorView* to) { int pos = replayed_pos_.at(from); - auto replay = TransformReplay::replayCasP(to, from, pos); - to->setDomain(replay.first); - replayed_pos_[to] = replay.second; + // See note [Using multiple TransformPropagators] + int new_pos = + TransformReplay::getMatchedLeafPosWithoutReplayCasP(to, from, pos); + if (new_pos < 0) { + auto replay = TransformReplay::replayCasP(to, from, pos); + to->setDomain(replay.first); + new_pos = replay.second; + } + replayed_pos_[to] = new_pos; } void TransformPropagator::propagateTvSibling(TensorView* from, TensorView* to) { int pos = replayed_pos_.at(from); - auto replay = TransformReplay::fullSelfReplay(to->domain(), from->domain()); - to->setDomain(replay); + // See note [Using multiple TransformPropagators] + if (!TransformReplay::fullSelfMatching(to, from)) { + auto replay = TransformReplay::fullSelfReplay(to->domain(), from->domain()); + to->setDomain(replay); + } replayed_pos_[to] = pos; } diff --git a/torch/csrc/jit/codegen/cuda/transform_replay.h b/torch/csrc/jit/codegen/cuda/transform_replay.h index c24ffa93f2954..d026de618c88f 100644 --- a/torch/csrc/jit/codegen/cuda/transform_replay.h +++ b/torch/csrc/jit/codegen/cuda/transform_replay.h @@ -154,6 +154,27 @@ class TORCH_CUDA_CU_API TransformReplay { static TensorDomain* fullSelfReplay( const TensorDomain* new_self_root, const TensorDomain* self); + + // Returns the leaf position in producer that matches with `consumer_pos` in + // consumer. Returns -1 if matching is impossible. This function can be used + // to test if replay is needed for getting matching outer dims. 
+ static int getMatchedLeafPosWithoutReplayPasC( + const TensorView* producer, + const TensorView* consumer, + int consumer_pos); + + // Returns the leaf position in consumer that matches with `producer_pos` in + // producer. Returns -1 if matching is impossible. This function can be used + // to test if replay is needed for getting matching outer dims. + static int getMatchedLeafPosWithoutReplayCasP( + const TensorView* consumer, + const TensorView* producer, + int producer_pos); + + // tests if two tensors have fully matching transformations + static bool fullSelfMatching( + const TensorView* replay, + const TensorView* target); }; class TORCH_CUDA_CU_API TransformPropagator