From 133814622ed8e6b92b5dbd9bdcdf81a547e1be5f Mon Sep 17 00:00:00 2001 From: Xiang Gao Date: Thu, 2 Jun 2022 10:55:33 -0700 Subject: [PATCH 001/100] Compute at refactor --- tools/build_variables.bzl | 1 + torch/csrc/jit/codegen/cuda/compute_at.cpp | 1688 +++++++++-------- torch/csrc/jit/codegen/cuda/compute_at.h | 91 +- torch/csrc/jit/codegen/cuda/consume_at.cpp | 323 ++++ torch/csrc/jit/codegen/cuda/consume_at.h | 51 + .../jit/codegen/cuda/ir_interface_nodes.h | 6 +- .../jit/codegen/cuda/transform_replay.cpp | 65 +- .../csrc/jit/codegen/cuda/transform_replay.h | 46 +- 8 files changed, 1342 insertions(+), 929 deletions(-) create mode 100644 torch/csrc/jit/codegen/cuda/consume_at.cpp create mode 100644 torch/csrc/jit/codegen/cuda/consume_at.h diff --git a/tools/build_variables.bzl b/tools/build_variables.bzl index 2190aca6e3b73..3956e89430c86 100644 --- a/tools/build_variables.bzl +++ b/tools/build_variables.bzl @@ -657,6 +657,7 @@ libtorch_cuda_core_sources = [ "torch/csrc/autograd/functions/comm.cpp", "torch/csrc/jit/codegen/cuda/arith.cpp", "torch/csrc/jit/codegen/cuda/compute_at.cpp", + "torch/csrc/jit/codegen/cuda/consume_at.cpp", "torch/csrc/jit/codegen/cuda/compute_at_map.cpp", "torch/csrc/jit/codegen/cuda/codegen.cpp", "torch/csrc/jit/codegen/cuda/contiguity.cpp", diff --git a/torch/csrc/jit/codegen/cuda/compute_at.cpp b/torch/csrc/jit/codegen/cuda/compute_at.cpp index 77fc513638296..3064477a57690 100644 --- a/torch/csrc/jit/codegen/cuda/compute_at.cpp +++ b/torch/csrc/jit/codegen/cuda/compute_at.cpp @@ -6,7 +6,6 @@ #include #include #include -#include #include @@ -41,264 +40,247 @@ std::deque> tvChains( return tv_chains; } -bool validateDomain(TensorView* tv, TensorDomain* new_td) { - auto first_mismatch = - BestEffortReplay::findFirstMismatchedID(tv->domain(), new_td); - return first_mismatch >= (int)tv->getMaxProducerPosition() && - first_mismatch >= (int)tv->getComputeAtPosition(); -} - -// Return the max position in consumer that producer can be inlined 
to -// Cannot inline: -// Reduction dimensions in producer -// Block broadcast dimensions in producer -// Vectorized dimensions in producer or consumer -// Unrolled dimensions in producer or consumer -// Dimensions derived from root dimensions that exist in both but are -// unmappable -unsigned int getReplayablePosPasC( - TensorView* producer, - TensorView* consumer, - const std::unordered_set& unmappable_producer_dims, - ComputeAtMode mode) { - // Check if any consumer dimensions are marked as vectorize as producer can - // not be inlined to vectorized dimensions in consumer. - auto c_dom = consumer->domain()->domain(); - auto vector_dim_it = - std::find_if(c_dom.begin(), c_dom.end(), [&mode](IterDomain* id) { - return isParallelTypeVectorize(id->getParallelType()) || - ((mode == ComputeAtMode::BestEffort || - mode == ComputeAtMode::MostInlined) && - id->getParallelType() == ParallelType::Unroll); - }); - - // Limit max position based on vectorized dims in consumer. - auto max_consumer_pos = std::distance(c_dom.begin(), vector_dim_it); - - auto pairwise_root_map = PairwiseRootDomainMap(producer, consumer); - auto c2p_root_map = - PairwiseRootDomainMap(producer, consumer) - .mapConsumerToProducer(consumer->domain(), producer->domain()); - - auto replay_PasC = - BestEffortReplay::replayPasC(producer, consumer, -1, pairwise_root_map); - - // Look for id's that map to a consumer id that's vectorized - auto c2p_replay_map = replay_PasC.getReplay(); - - for (size_t consumer_pos = max_consumer_pos; consumer_pos > 0; - consumer_pos--) { - auto map_it = c2p_replay_map.find(consumer->axis((int)consumer_pos - 1)); - if (map_it != c2p_replay_map.end()) { - auto p_id = map_it->second; - // If we find a consumer dim that maps to a producer dim that's - // vectorized or unrolled limit max compute at by it. 
- if (isParallelTypeVectorize(p_id->getParallelType()) || - ((mode == ComputeAtMode::BestEffort || - mode == ComputeAtMode::MostInlined) && - p_id->getParallelType() == ParallelType::Unroll)) { - max_consumer_pos = consumer_pos - 1; - } - } - } - - // Start at max position and work backwards, try to find a location where - // producer can be inlined. - for (size_t consumer_pos = max_consumer_pos; consumer_pos > 0; - consumer_pos--) { - // Grab all root dimensions of consumer as roots must be used to understand - // inlining potential. - auto consumer_root_dim_vals = - IterVisitor::getInputsTo({c_dom.begin(), c_dom.begin() + consumer_pos}); - // convert to iter domains - auto consumer_root_dim_ids = - ir_utils::filterByType(consumer_root_dim_vals); - // If any root dimensions cannot be mapped to producer we can't inline. If - // any root dimension - if (std::any_of( - consumer_root_dim_ids.begin(), - consumer_root_dim_ids.end(), - [&unmappable_producer_dims, &c2p_root_map](IterDomain* c_root_id) { - auto p_root_id_it = c2p_root_map.find(c_root_id); - if (p_root_id_it == c2p_root_map.end()) { - return false; - } - auto p_id = p_root_id_it->second; - return unmappable_producer_dims.find(p_id) != - unmappable_producer_dims.end(); - })) { - continue; - } - return consumer_pos; - } - - return 0; -} - -// Return the max position in producer that can be inlined to consumer -// Cannot inline: -// Reduction dimensions in producer -// Vectorized dimensions in producer or consumer -// Unrolled dimensions in producer or consumer -// Dimensions derived from root dimensions that exist in both but are -// unmappable -unsigned int getReplayablePosCasP( - TensorView* consumer, +// bool validateDomain(TensorView* tv, TensorDomain* new_td) { +// auto first_mismatch = +// BestEffortReplay::findFirstMismatchedID(tv->domain(), new_td); +// return first_mismatch >= (int)tv->getMaxProducerPosition() && +// first_mismatch >= (int)tv->getComputeAtPosition(); +// } + +// // Return the max 
position in consumer that producer can be inlined to +// // Cannot inline: +// // Reduction dimensions in producer +// // Block broadcast dimensions in producer +// // Vectorized dimensions in producer or consumer +// // Unrolled dimensions in producer or consumer +// // Dimensions derived from root dimensions that exist in both but are +// // unmappable +// unsigned int getReplayablePosPasC( +// TensorView* producer, +// TensorView* consumer, +// const std::unordered_set& unmappable_producer_dims, +// ComputeAtMode mode) { +// // Check if any consumer dimensions are marked as vectorize as producer can +// // not be inlined to vectorized dimensions in consumer. +// auto c_dom = consumer->domain()->domain(); +// auto vector_dim_it = +// std::find_if(c_dom.begin(), c_dom.end(), [&mode](IterDomain* id) { +// return isParallelTypeVectorize(id->getParallelType()) || +// ((mode == ComputeAtMode::BestEffort || +// mode == ComputeAtMode::MostInlined) && +// id->getParallelType() == ParallelType::Unroll); +// }); + +// // Limit max position based on vectorized dims in consumer. +// auto max_consumer_pos = std::distance(c_dom.begin(), vector_dim_it); + +// auto pairwise_root_map = PairwiseRootDomainMap(producer, consumer); +// auto c2p_root_map = +// PairwiseRootDomainMap(producer, consumer) +// .mapConsumerToProducer(consumer->domain(), producer->domain()); + +// auto replay_PasC = +// BestEffortReplay::replayPasC(producer, consumer, -1, pairwise_root_map); + +// // Look for id's that map to a consumer id that's vectorized +// auto c2p_replay_map = replay_PasC.getReplay(); + +// for (size_t consumer_pos = max_consumer_pos; consumer_pos > 0; +// consumer_pos--) { +// auto map_it = c2p_replay_map.find(consumer->axis((int)consumer_pos - 1)); +// if (map_it != c2p_replay_map.end()) { +// auto p_id = map_it->second; +// // If we find a consumer dim that maps to a producer dim that's +// // vectorized or unrolled limit max compute at by it. 
+// if (isParallelTypeVectorize(p_id->getParallelType()) || +// ((mode == ComputeAtMode::BestEffort || +// mode == ComputeAtMode::MostInlined) && +// p_id->getParallelType() == ParallelType::Unroll)) { +// max_consumer_pos = consumer_pos - 1; +// } +// } +// } + +// // Start at max position and work backwards, try to find a location where +// // producer can be inlined. +// for (size_t consumer_pos = max_consumer_pos; consumer_pos > 0; +// consumer_pos--) { +// // Grab all root dimensions of consumer as roots must be used to understand +// // inlining potential. +// auto consumer_root_dim_vals = +// IterVisitor::getInputsTo({c_dom.begin(), c_dom.begin() + consumer_pos}); +// // convert to iter domains +// auto consumer_root_dim_ids = +// ir_utils::filterByType(consumer_root_dim_vals); +// // If any root dimensions cannot be mapped to producer we can't inline. If +// // any root dimension +// if (std::any_of( +// consumer_root_dim_ids.begin(), +// consumer_root_dim_ids.end(), +// [&unmappable_producer_dims, &c2p_root_map](IterDomain* c_root_id) { +// auto p_root_id_it = c2p_root_map.find(c_root_id); +// if (p_root_id_it == c2p_root_map.end()) { +// return false; +// } +// auto p_id = p_root_id_it->second; +// return unmappable_producer_dims.find(p_id) != +// unmappable_producer_dims.end(); +// })) { +// continue; +// } +// return consumer_pos; +// } + +// return 0; +// } + +// // Return the max position in producer that can be inlined to consumer +// // Cannot inline: +// // Reduction dimensions in producer +// // Vectorized dimensions in producer or consumer +// // Unrolled dimensions in producer or consumer +// // Dimensions derived from root dimensions that exist in both but are +// // unmappable +// unsigned int getReplayablePosCasP( +// TensorView* consumer, +// TensorView* producer, +// const std::unordered_set& unmappable_producer_dims, +// ComputeAtMode mode) { +// auto p_dom = producer->domain()->domain(); +// auto first_reduction = +// 
std::find_if(p_dom.begin(), p_dom.end(), [](IterDomain* id) { +// return id->isReduction(); +// }); + +// auto first_vectorized_axis = +// std::find_if(p_dom.begin(), first_reduction, [&mode](IterDomain* id) { +// return isParallelTypeVectorize(id->getParallelType()) || +// ((mode == ComputeAtMode::BestEffort || +// mode == ComputeAtMode::MostInlined) && +// id->getParallelType() == ParallelType::Unroll); +// }); + +// auto max_producer_pos = std::distance(p_dom.begin(), first_vectorized_axis); + +// auto pairwise_root_map = PairwiseRootDomainMap(producer, consumer); +// auto p2c_root_map = pairwise_root_map.mapProducerToConsumer( +// producer->domain(), consumer->domain()); + +// auto replay_CasP = +// BestEffortReplay::replayCasP(consumer, producer, -1, pairwise_root_map); + +// // Look for id's that map to a consumer id that's vectorized +// auto p2c_replay_map = replay_CasP.getReplay(); + +// for (size_t producer_pos = max_producer_pos; producer_pos > 0; +// producer_pos--) { +// auto map_it = p2c_replay_map.find(producer->axis((int)producer_pos - 1)); +// if (map_it != p2c_replay_map.end()) { +// auto c_id = map_it->second; +// // If we find a producer dim that maps to a consumer vectorized or +// // unrolled dim, limit max compute at by it +// if (isParallelTypeVectorize(c_id->getParallelType()) || +// ((mode == ComputeAtMode::BestEffort || +// mode == ComputeAtMode::MostInlined) && +// c_id->getParallelType() == ParallelType::Unroll)) { +// max_producer_pos = producer_pos - 1; +// } +// } +// } + +// for (size_t producer_pos = max_producer_pos; producer_pos > 0; +// producer_pos--) { +// auto all_vals = DependencyCheck::getAllValsBetween( +// {producer->getMaybeRFactorDomain().begin(), +// producer->getMaybeRFactorDomain().end()}, +// {p_dom.begin(), p_dom.begin() + producer_pos}); + +// // If any root dims could have mapped to consumer, but don't, then we can't +// // compute at this point +// if (std::any_of( +// producer->getMaybeRFactorDomain().begin(), 
+// producer->getMaybeRFactorDomain().end(), +// [&unmappable_producer_dims, &all_vals](IterDomain* p_root_id) { +// return std::find(all_vals.begin(), all_vals.end(), p_root_id) != +// all_vals.end() && +// unmappable_producer_dims.find(p_root_id) != +// unmappable_producer_dims.end(); +// })) { +// continue; +// } + +// return producer_pos; +// } +// return 0; +// } + +// unsigned int getInnermostNonBroadcastIdFrom(TensorView* tv) { +// unsigned int ret = tv->getComputeAtPosition(); + +// // Still assuming we only have block broadcast for now. +// // This part may change +// while (ret > 0 && tv->axis((int)ret - 1)->isBroadcast()) { +// ret--; +// } + +// return ret; +// } + +// // Try to find the aligned position on consumer's domain corresponding to the +// // compute at position of producer domain. Used in computeAt pass only. No +// // checking on actual producer-consumer relationship. +// unsigned int getConsumerPosAlignedToProducerCA( +// TensorView* consumer, +// TensorView* producer) { +// // Locate consumer's position that aligns with +// // the producer's new compute at axis. We need broadcast axes forwarded so we +// // need to replay PasC as CasP will not forward braodcast dims. For example +// // if we have: +// // T2[ iS22{( 3 * 1 )} ] ca_pos( 1 ) = broadcast( T1[ iS1{3} ] ca_pos( 1 ) +// // produce_pos( 1) ) CasP will have the mapping iS1{3} -> iS2{3} and PasC will +// // have the mapping iS22{( 3 * 1 )} <- iS1{3} We need the latter. Refer to +// // NVFuserTest.FusionComplexBCast1_CUDA + +// auto c2p_map = +// BestEffortReplay::replayPasC( +// producer, +// consumer, +// -1, +// // Compute at root domain may not be valid here, as all +// // producers don't have to be able to map into consumer at +// // max producer position. Since computeAt should be valid +// // and this mechanism is only intended to lower produce +// // position of consumer, we can simply use the pairwise map. 
+// PairwiseRootDomainMap(producer, consumer)) +// .getReplay(); + +// // Find the innermost position of consumer that has +// // been mapped within the producer ca axis. +// unsigned int consumer_pos = consumer->nDims(); +// while (consumer_pos > 0) { +// auto consumer_id = consumer->axis((int)consumer_pos - 1); +// auto p_dom = producer->domain()->domain(); +// if (std::any_of( +// p_dom.begin(), +// p_dom.begin() + producer->getComputeAtPosition(), +// [&consumer_id, &c2p_map](IterDomain* p_id) { +// auto c_id_it = c2p_map.find(consumer_id); +// if (c_id_it != c2p_map.end()) { +// return c_id_it->second == p_id; +// } +// return false; +// })) { +// break; +// } +// consumer_pos--; +// } + +// return consumer_pos; +// } + +std::unordered_set getAllTVsBetween( TensorView* producer, - const std::unordered_set& unmappable_producer_dims, - ComputeAtMode mode) { - auto p_dom = producer->domain()->domain(); - auto first_reduction = - std::find_if(p_dom.begin(), p_dom.end(), [](IterDomain* id) { - return id->isReduction(); - }); - - auto first_vectorized_axis = - std::find_if(p_dom.begin(), first_reduction, [&mode](IterDomain* id) { - return isParallelTypeVectorize(id->getParallelType()) || - ((mode == ComputeAtMode::BestEffort || - mode == ComputeAtMode::MostInlined) && - id->getParallelType() == ParallelType::Unroll); - }); - - auto max_producer_pos = std::distance(p_dom.begin(), first_vectorized_axis); - - auto pairwise_root_map = PairwiseRootDomainMap(producer, consumer); - auto p2c_root_map = pairwise_root_map.mapProducerToConsumer( - producer->domain(), consumer->domain()); - - auto replay_CasP = - BestEffortReplay::replayCasP(consumer, producer, -1, pairwise_root_map); - - // Look for id's that map to a consumer id that's vectorized - auto p2c_replay_map = replay_CasP.getReplay(); - - for (size_t producer_pos = max_producer_pos; producer_pos > 0; - producer_pos--) { - auto map_it = p2c_replay_map.find(producer->axis((int)producer_pos - 1)); - if (map_it != 
p2c_replay_map.end()) { - auto c_id = map_it->second; - // If we find a producer dim that maps to a consumer vectorized or - // unrolled dim, limit max compute at by it - if (isParallelTypeVectorize(c_id->getParallelType()) || - ((mode == ComputeAtMode::BestEffort || - mode == ComputeAtMode::MostInlined) && - c_id->getParallelType() == ParallelType::Unroll)) { - max_producer_pos = producer_pos - 1; - } - } - } - - for (size_t producer_pos = max_producer_pos; producer_pos > 0; - producer_pos--) { - auto all_vals = DependencyCheck::getAllValsBetween( - {producer->getMaybeRFactorDomain().begin(), - producer->getMaybeRFactorDomain().end()}, - {p_dom.begin(), p_dom.begin() + producer_pos}); - - // If any root dims could have mapped to consumer, but don't, then we can't - // compute at this point - if (std::any_of( - producer->getMaybeRFactorDomain().begin(), - producer->getMaybeRFactorDomain().end(), - [&unmappable_producer_dims, &all_vals](IterDomain* p_root_id) { - return std::find(all_vals.begin(), all_vals.end(), p_root_id) != - all_vals.end() && - unmappable_producer_dims.find(p_root_id) != - unmappable_producer_dims.end(); - })) { - continue; - } - - return producer_pos; - } - return 0; -} - -unsigned int getInnermostNonBroadcastIdFrom(TensorView* tv) { - unsigned int ret = tv->getComputeAtPosition(); - - // Still assuming we only have block broadcast for now. - // This part may change - while (ret > 0 && tv->axis((int)ret - 1)->isBroadcast()) { - ret--; - } - - return ret; -} - -// Try to find the aligned position on consumer's domain corresponding to the -// compute at position of producer domain. Used in computeAt pass only. No -// checking on actual producer-consumer relationship. -unsigned int getConsumerPosAlignedToProducerCA( - TensorView* consumer, - TensorView* producer) { - // Locate consumer's position that aligns with - // the producer's new compute at axis. 
We need broadcast axes forwarded so we - // need to replay PasC as CasP will not forward braodcast dims. For example - // if we have: - // T2[ iS22{( 3 * 1 )} ] ca_pos( 1 ) = broadcast( T1[ iS1{3} ] ca_pos( 1 ) - // produce_pos( 1) ) CasP will have the mapping iS1{3} -> iS2{3} and PasC will - // have the mapping iS22{( 3 * 1 )} <- iS1{3} We need the latter. Refer to - // NVFuserTest.FusionComplexBCast1_CUDA - - auto c2p_map = - BestEffortReplay::replayPasC( - producer, - consumer, - -1, - // Compute at root domain may not be valid here, as all - // producers don't have to be able to map into consumer at - // max producer position. Since computeAt should be valid - // and this mechanism is only intended to lower produce - // position of consumer, we can simply use the pairwise map. - PairwiseRootDomainMap(producer, consumer)) - .getReplay(); - - // Find the innermost position of consumer that has - // been mapped within the producer ca axis. - unsigned int consumer_pos = consumer->nDims(); - while (consumer_pos > 0) { - auto consumer_id = consumer->axis((int)consumer_pos - 1); - auto p_dom = producer->domain()->domain(); - if (std::any_of( - p_dom.begin(), - p_dom.begin() + producer->getComputeAtPosition(), - [&consumer_id, &c2p_map](IterDomain* p_id) { - auto c_id_it = c2p_map.find(consumer_id); - if (c_id_it != c2p_map.end()) { - return c_id_it->second == p_id; - } - return false; - })) { - break; - } - consumer_pos--; - } - - return consumer_pos; -} - -} // namespace - -void ComputeAt::runAt( - TensorView* producer, - TensorView* consumer, - unsigned int consumer_position, - ComputeAtMode mode) { - FUSER_PERF_SCOPE("ComputeAt::run"); - - // Make sure the correct fusion is setup between this and consumer. 
- TORCH_CHECK( - producer->fusion() == consumer->fusion(), - producer, - " and ", - consumer, - " are not in the same fusion."); - - // Make sure Fusion Guard is set appropriately - FusionGuard fg(producer->fusion()); - + TensorView* consumer) { TORCH_CHECK( DependencyCheck::isDependencyOf(producer, consumer), "Compute At expects ", @@ -306,280 +288,19 @@ void ComputeAt::runAt( " is a dependency of ", consumer->name(), ", however it is not."); - - // Run computeAt on our potentially modified producer(s) - ComputeAt ca(producer, consumer, consumer, consumer_position, mode); - ca.runPass(); -} - -void ComputeAt::runWith( - TensorView* producer, - TensorView* consumer, - unsigned int producer_position, - ComputeAtMode mode) { - FUSER_PERF_SCOPE("ComputeAt::runWith"); - - // Make sure the correct fusion is setup between this and consumer. - TORCH_CHECK( - producer->fusion() == consumer->fusion(), - producer, - " and ", - consumer, - " are not in the same fusion."); - - TORCH_CHECK( - DependencyCheck::isDependencyOf(producer, consumer), - "Compute At expects ", - producer->name(), - " is a dependency of ", - consumer->name(), - ", however it is not."); - - // Make sure Fusion Guard is set appropriately - FusionGuard fg(producer->fusion()); - - ComputeAt ca(producer, consumer, producer, producer_position, mode); - ca.runPass(); -} - -namespace { - -// Checks if producer and consumer are transformed consistently so that to -// satisfy the provided compute at position. This means no replay is actually -// necessary for the compute at requested. If consumer_pos then -// consumer_or_producer_pos is relative to the consumer and skipReplay returns -// the associated position in producer. -// -// If producer and consumer are not transformed consistently with provided -// postition, returns -1. 
-int skipReplay( - const TensorView* producer, - const TensorView* consumer, - int consumer_or_producer_pos, - bool consumer_pos = true) { - FUSER_PERF_SCOPE("transform_replay.cpp::skipReplay"); - - const auto c2p_root_map = - PairwiseRootDomainMap(producer, consumer) - .mapConsumerToProducer(consumer->domain(), producer->domain()); - - // IterDomains in consumer root also in producer root - std::unordered_set mapped_consumer_roots; - for (auto entry : c2p_root_map) { - mapped_consumer_roots.emplace(entry.first); - } - - const auto consumer_domain = consumer->domain()->domain(); - - auto mapped_consumer_domain_ids_vec = DependencyCheck::getAllValsBetween( - mapped_consumer_roots, {consumer_domain.begin(), consumer_domain.end()}); - - std::unordered_set mapped_consumer_domain_ids( - mapped_consumer_domain_ids_vec.begin(), - mapped_consumer_domain_ids_vec.end()); - - const auto producer_domain = producer->domain()->domain(); - - auto it_consumer = consumer_domain.begin(); - auto it_producer = producer_domain.begin(); - - auto best_effort_PasC = BestEffortReplay::replayPasC( - producer, consumer, -1, PairwiseRootDomainMap(producer, consumer)); - - auto c2p_map = best_effort_PasC.getReplay(); - - int mismatched_consumer_pos = 0; - int mismatched_producer_pos = 0; - while (it_consumer != consumer_domain.end()) { - auto consumer_id = *it_consumer; - if (!mapped_consumer_domain_ids.count(consumer_id)) { - ++it_consumer; - mismatched_consumer_pos++; - continue; - } - - auto c2p_it = c2p_map.find(consumer_id); - if (c2p_it == c2p_map.end()) { - break; - } - - if (it_producer == producer_domain.end()) { - break; - } - - auto producer_id = *it_producer; - - if (c2p_it->second == producer_id) { - ++mismatched_consumer_pos; - ++mismatched_producer_pos; - ++it_consumer; - ++it_producer; - if (consumer_pos) { - if (consumer_or_producer_pos == mismatched_consumer_pos) { - return mismatched_producer_pos; - } - } else { - if (consumer_or_producer_pos == mismatched_producer_pos) { - 
return mismatched_consumer_pos; - } - } - } else { - break; - } - } - return -1; -} - -} // namespace - -// Actually applies transformation -unsigned int ComputeAt::backwardComputeAt_impl( - TensorView* producer, - TensorView* consumer, - unsigned int consumer_compute_at_pos) { - FUSER_PERF_SCOPE("backwardComputeAt_impl"); - - auto max_consumer_compute_at_pos = - getReplayablePosPasC(producer, consumer, unmappable_dims_, mode_); - - if (mode_ == ComputeAtMode::BestEffort) { - consumer_compute_at_pos = - std::min(consumer_compute_at_pos, max_consumer_compute_at_pos); - } else if (mode_ == ComputeAtMode::MostInlined) { - consumer_compute_at_pos = max_consumer_compute_at_pos; - } else { - TORCH_INTERNAL_ASSERT( - consumer_compute_at_pos <= max_consumer_compute_at_pos, - "Invalid compute at position detected in compute at when trying to replay producer: ", - producer, - " as consumer: ", - consumer, - " tried to do this at position: ", - consumer_compute_at_pos, - " but max position that's allowed is ", - max_consumer_compute_at_pos); - } - - // Short cut if no replay is necessary - auto maybe_producer_pos = - skipReplay(producer, consumer, (int)consumer_compute_at_pos, true); - if (maybe_producer_pos >= 0) { - if (!producer->isFusionInput()) { - producer->setComputeAt((unsigned int)maybe_producer_pos); - } - consumer->setMaxProducer(consumer_compute_at_pos); - return (unsigned int)maybe_producer_pos; - } - - auto replay_producer_pair = TransformReplay::replayPasC( - producer, - consumer, - (int)consumer_compute_at_pos, - PairwiseRootDomainMap(producer, consumer)); - - if (replay_producer_pair.second == 0) { - return 0; - } - - if (replay_producer_pair.second >= producer->getComputeAtPosition()) { - const TensorDomain* current_domain = producer->domain(); - TensorDomain* new_domain = replay_producer_pair.first; - - TORCH_INTERNAL_ASSERT( - validateDomain(producer, new_domain), - "Tried to set the domain of ", - producer, - " to ", - new_domain, - " but that would 
invalidate previously compute at position or max producer position."); - - producer->setDomain(new_domain); - if (!producer->isFusionInput()) { - producer->setComputeAt(replay_producer_pair.second); - } - - consumer->setMaxProducer(consumer_compute_at_pos); - root_map_.setAlias(current_domain, new_domain); - } - - return replay_producer_pair.second; + auto between_vals = DependencyCheck::getAllValsBetween({producer}, {consumer}); + auto between_tvs = ir_utils::filterByType(between_vals); + std::unordered_set result(between_tvs.begin(), between_tvs.end()); + result.erase(consumer); + return result; } -// Actually applies transformation, replay consumer based on producer, set -// compute at of producer, set pass position of consumer, return position -// relative to consumer -unsigned int ComputeAt::forwardComputeAt_impl( - TensorView* producer, - TensorView* consumer, - unsigned int producer_compute_at_pos) { - FUSER_PERF_SCOPE("forwardComputeAt_impl"); - - auto max_producer_compute_at_pos = - getReplayablePosCasP(consumer, producer, unmappable_dims_, mode_); - - if (mode_ == ComputeAtMode::BestEffort) { - producer_compute_at_pos = - std::min(producer_compute_at_pos, max_producer_compute_at_pos); - } else if (mode_ == ComputeAtMode::MostInlined) { - producer_compute_at_pos = max_producer_compute_at_pos; - } else { - TORCH_INTERNAL_ASSERT( - producer_compute_at_pos <= max_producer_compute_at_pos, - "Invalid compute at position detected in compute at when trying to replay consumer: ", - consumer, - " as producer: ", - producer, - " tried to do this at position: ", - producer_compute_at_pos, - " but max position that's allowed is ", - max_producer_compute_at_pos); - } - - // Short cut if no replay is necessary - auto maybe_consumer_pos = - skipReplay(producer, consumer, (int)producer_compute_at_pos, false); - if (maybe_consumer_pos > -1) { - if (!producer->isFusionInput()) { - producer->setComputeAt(producer_compute_at_pos); - } - consumer->setMaxProducer((unsigned 
int)maybe_consumer_pos); - return (unsigned int)maybe_consumer_pos; - } - - auto replay_consumer_pair = TransformReplay::replayCasP( - consumer, - producer, - (int)producer_compute_at_pos, - PairwiseRootDomainMap(producer, consumer)); - - if (producer_compute_at_pos > producer->getComputeAtPosition()) { - if (!producer->isFusionInput()) { - producer->setComputeAt((int)producer_compute_at_pos); - } - } - - if (replay_consumer_pair.second > consumer->getMaxProducerPosition()) { - const TensorDomain* current_domain = consumer->domain(); - TensorDomain* new_domain = replay_consumer_pair.first; - - TORCH_INTERNAL_ASSERT( - validateDomain(consumer, new_domain), - "Tried to set the domain of ", - consumer, - " to ", - new_domain, - " but that would invalidate previously compute at position or max producer position."); - - consumer->setDomain(new_domain); - consumer->setMaxProducer(replay_consumer_pair.second); - root_map_.setAlias(current_domain, new_domain); - } - - return replay_consumer_pair.second; -} - -void ComputeAt::setCommonConsumer() { +TensorView * getCommonConsumer( + TensorView *producer, + TensorView *consumer +) { FUSER_PERF_SCOPE("ComputeAt::setCommonConsumer"); + auto producer_use_chains_ = tvChains(DependencyCheck::getAllUseChains(producer)); // Convert the first chain to a set. std::set common_consumers( @@ -594,337 +315,670 @@ void ComputeAt::setCommonConsumer() { } auto all_chains = - tvChains(DependencyCheck::getAllDependencyChains(producer_, consumer_)); + tvChains(DependencyCheck::getAllDependencyChains(producer, consumer)); // Right now we only support compute at if at some point in the graph consumer // is dependent on producer. 
TORCH_CHECK( !all_chains.empty(), "Compute At expects ", - producer_->name(), + producer->name(), " is a dependency of ", - consumer_->name(), + consumer->name(), ", however it is not."); // Remove all TVs from producer to consumer as common consumer must be at or // after consumer for (const auto& tv_chain : all_chains) { for (auto tv : tv_chain) { - if (tv != consumer_) + if (tv != consumer) common_consumers.erase(tv); } } // If there is a common consumer, grab the first one at or after consumer - common_consumer_ = nullptr; + TensorView *common_consumer = nullptr; if (!common_consumers.empty()) { for (auto tv : producer_use_chains_.front()) { if (common_consumers.find(tv) != common_consumers.end()) { - common_consumer_ = tv; + common_consumer = tv; break; } } TORCH_INTERNAL_ASSERT( - common_consumer_ != nullptr, + common_consumer != nullptr, "Hit a logical inconsistency in the computeAt pass."); } + return common_consumer; } -// Similar to backward traversal in traverseAllKnown but we should only apply -// computeAt if it will increase computeAt positions. -void ComputeAt::traverseBackward() { - FUSER_PERF_SCOPE("ComputeAt::traverseBackward"); - if (reference_ == producer_) { - // Forward compute at don't need to run backward traversal - producer_position_ = reference_position_; - return; - } - - // propagate *backward* through all *producer* use_chains or from *producer* - // to common_consumer if common_consumer exists. Only apply transform if - // increases computeAt position. 
- auto chains = - tvChains(DependencyCheck::getAllDependencyChains(producer_, consumer_)); - - for (auto tv_chain : chains) { - TensorView* running_producer = tv_chain.back(); - TensorView* running_consumer = nullptr; - unsigned int running_consumer_pos = reference_position_; - tv_chain.pop_back(); - - TORCH_INTERNAL_ASSERT(running_producer == consumer_); - - while (!tv_chain.empty()) { - running_consumer = running_producer; - running_producer = tv_chain.back(); - tv_chain.pop_back(); - running_consumer_pos = backwardComputeAt_impl( - running_producer, running_consumer, running_consumer_pos); - } - - TORCH_INTERNAL_ASSERT( - running_producer == producer_, - "Compute at backward traversal ended up on something other than the producer."); - producer_position_ = running_consumer_pos; - } -} - -void ComputeAt::traverseForward() { - FUSER_PERF_SCOPE("ComputeAt::traverseForward"); - - // propagate forward through all *producer* use_chains or from *producer* to - // common_consumer if common_consumer exists. 
- auto chains = producer_use_chains_; - if (common_consumer_ != nullptr) { - chains = tvChains( - DependencyCheck::getAllDependencyChains(producer_, common_consumer_)); - } - - // propagate forward through all chains - for (auto tv_dep_chain : chains) { - TensorView* running_producer = nullptr; - TensorView* running_consumer = tv_dep_chain.front(); - tv_dep_chain.pop_front(); - unsigned int running_producer_pos = producer_position_; - - TORCH_INTERNAL_ASSERT(running_consumer == producer_); - - while (!tv_dep_chain.empty()) { - running_producer = running_consumer; - running_consumer = tv_dep_chain.front(); - tv_dep_chain.pop_front(); - running_producer_pos = forwardComputeAt_impl( - running_producer, running_consumer, running_producer_pos); - } - } -} - -void ComputeAt::resetMaxProducerPos(TensorView* consumer_tv) { - if (consumer_tv->definition() == nullptr) { - consumer_tv->setMaxProducer(0, true); - } - - unsigned int new_consummer_pa_pos = 0; - - // Re-compute the max producer position as one or more - // of the producers of this consumer have updated their - // compute at position. - for (auto inp : ir_utils::producerTvsOf(consumer_tv)) { - if (!inp->isFusionInput()) { - // Locate consumer's position that aligns with - // the producer's new compute at axis. - unsigned int inp_ca_pos_to_consumer = - getConsumerPosAlignedToProducerCA(consumer_tv, inp); - - // Populate the max consumer position required by - // producer compute at. 
- new_consummer_pa_pos =
- std::max(new_consummer_pa_pos, inp_ca_pos_to_consumer);
- }
- }
-
- consumer_tv->setMaxProducer(new_consummer_pa_pos, true);
-}
-
-void ComputeAt::hoistInnermostBroadcast() {
- auto fusion = producer_->fusion();
-
- std::unordered_set consumers_to_update;
- auto all_vals = fusion->usedMathVals();
- auto all_tvs = ir_utils::filterByType(all_vals);
-
- for (auto running_producer : all_tvs) {
- if (!running_producer->isFusionInput()) {
- auto producer_ca_pos = running_producer->getComputeAtPosition();
- // Find the innermost iterdomain that is not a broadcast
- auto new_ca_pos = getInnermostNonBroadcastIdFrom(running_producer);
- // Update the compute at pos of this producer if the original
- // compute at is within inner most broadcast axes
- if (new_ca_pos < producer_ca_pos) {
- running_producer->setComputeAt(new_ca_pos, true);
- }
- // Mark all consumers of this producer for later produce
- // position update.
- // This is safe with segmented fusion. TV uses will reset
- // when FusionSegmentGuard try to change the IO.
- auto tv_consumers = ir_utils::consumerTvsOf(running_producer);
- consumers_to_update.insert(tv_consumers.begin(), tv_consumers.end());
- }
+// I am just trying to get the same set of tensors being transformed matching
+// the previous behavior of ComputeAt. The algorithm to compute this set is
+// horrible, but I don't care because I will eventually completely remove ComputeAt,
+// and this algorithm is not worse than the previous ComputeAt.
:) +std::unordered_set getPropagationSubgraph( + TensorView* producer, + TensorView* consumer) { + TORCH_CHECK( + DependencyCheck::isDependencyOf(producer, consumer), + "Compute At expects ", + producer->name(), + " is a dependency of ", + consumer->name(), + ", however it is not."); + TensorView *common_consumer = getCommonConsumer(producer, consumer); + if (common_consumer != nullptr) { + return getAllTVsBetween(producer, consumer); } + auto result_vals = DependencyCheck::getAllDependentVals({producer}); + auto result_tvs = ir_utils::filterByType(result_vals); + return {result_tvs.begin(), result_tvs.end()}; } -void ComputeAt::updateSiblings() { - // Track which consumers may have a wrong produce at position to update - // later - auto updateSiblingsOfTv = [&](TensorView* tv) { - if (tv->definition() == nullptr) { - return; - } - - std::unordered_set consumers_to_update; - - if (tv->definition()->outputs().size() > 1) { - auto outs = tv->definition()->outputs(); - auto out_tvs = ir_utils::filterByType(outs); - for (auto sibling_tv : out_tvs) { - if (sibling_tv == tv) { - continue; - } - - std::unordered_map tv_to_sibling_map; - TORCH_INTERNAL_ASSERT( - tv->getRootDomain().size() == sibling_tv->getRootDomain().size(), - "Error replaying multiple output expressions in computeAt."); - - // Propagate any root parallelization as fullSelfReplay expects it. 
- for (const auto i : c10::irange(sibling_tv->getRootDomain().size())) { - auto id = tv->getRootDomain()[i]; - auto sibling_id = sibling_tv->getRootDomain()[i]; - if (id->getParallelType() != ParallelType::Serial && - sibling_id->getParallelType() == ParallelType::Serial) { - sibling_id->parallelize(id->getParallelType()); - } else if ( - id->getParallelType() == ParallelType::Serial && - sibling_id->getParallelType() != ParallelType::Serial) { - id->parallelize(sibling_id->getParallelType()); - } - } - auto sibling_domain = - TransformReplay::fullSelfReplay(sibling_tv->domain(), tv->domain()); - validateDomain(sibling_tv, sibling_domain); - sibling_tv->setDomain(sibling_domain); - sibling_tv->setComputeAt(tv->getComputeAtPosition()); - sibling_tv->setMaxProducer(tv->getMaxProducerPosition()); - auto consumer_tvs = ir_utils::consumerTvsOf(sibling_tv); - consumers_to_update.insert(consumer_tvs.begin(), consumer_tvs.end()); - } - } - - // Update sibling consumer tv's max producer position - for (auto consumer : consumers_to_update) { - this->resetMaxProducerPos(consumer); - } - }; +} // namespace - // Find all tensor views that may have been modified - auto chains = producer_use_chains_; - if (common_consumer_ != nullptr) { - chains = tvChains( - DependencyCheck::getAllDependencyChains(producer_, common_consumer_)); - } +void ComputeAt::runAt( + TensorView* producer, + TensorView* consumer, + unsigned int consumer_position, + ComputeAtMode mode) { + FUSER_PERF_SCOPE("ComputeAt::run"); - std::unordered_set participating_tvs; - for (auto chain : chains) { - participating_tvs.insert(chain.begin(), chain.end()); - } + // Make sure the correct fusion is setup between this and consumer. 
+ TORCH_CHECK( + producer->fusion() == consumer->fusion(), + producer, + " and ", + consumer, + " are not in the same fusion."); - for (auto tv : participating_tvs) { - updateSiblingsOfTv(tv); - } + FusionGuard fg(producer->fusion()); + ConsumeAt::consumeAllAt( + getPropagationSubgraph(producer, consumer), + consumer, + consumer_position, + mode); } -void ComputeAt::runPass() { - FUSER_PERF_SCOPE("ComputeAt::runPass"); - - // Traverse backward through all dep chains from producer to consumer - traverseBackward(); - - // Start at producer and traverse forward through all chains - traverseForward(); - - // Back off on inlining the inner broadcast axes - hoistInnermostBroadcast(); - - // Update siblings of multi output expressions - updateSiblings(); - - // Update the compute at position of all consumers, this used to be done - // during the compute at pass itself, but its cleaner to do this as a cleanup - // pass similar to hoistInnermostBroadcast and updateSiblings. - std::unordered_set all_consumers; - - // Find all tensor views that may have been modified - auto chains = producer_use_chains_; - if (common_consumer_ != nullptr) { - chains = tvChains( - DependencyCheck::getAllDependencyChains(producer_, common_consumer_)); - } +void ComputeAt::runWith( + TensorView* producer, + TensorView* consumer, + unsigned int producer_position, + ComputeAtMode mode) { + FUSER_PERF_SCOPE("ComputeAt::runWith"); - for (const auto& chain : chains) { - for (auto tv : chain) { - all_consumers.emplace(tv); - } - } + // Make sure the correct fusion is setup between this and consumer. + TORCH_CHECK( + producer->fusion() == consumer->fusion(), + producer, + " and ", + consumer, + " are not in the same fusion."); - // Reset max producer position of all tensor views. 
- for (auto tv : all_consumers) { - resetMaxProducerPos(tv); - } -} + // Make sure Fusion Guard is set appropriately + FusionGuard fg(producer->fusion()); -void ComputeAt::buildUnmappableDims() { - auto all_tvs = ir_utils::allTvs(producer_->fusion()); - for (auto tv : all_tvs) { - auto consumers = ir_utils::consumerTvsOf(tv); - for (auto consumer : consumers) { - // Grab dimensions in producer and consumer that are mappable to eachother - // based on the computeAtRootDomainMap. This will tell us which dimensions - // can be inlined based on avoiding trying to inline non-trivial - // reduction structures. - auto mappable_roots = - root_map_.getMappableDims(tv->domain(), consumer->domain()); - for (auto tv_root_id : tv->getMaybeRFactorDomain()) { - if (mappable_roots.find(tv_root_id) == mappable_roots.end() && - !tv_root_id->isTrivialReduction()) { - unmappable_dims_.emplace(tv_root_id); - } - } - } - } + ConsumeAt::consumeAllAt( + getPropagationSubgraph(producer, consumer), + producer, + producer_position, + mode); } -ComputeAt::ComputeAt( - TensorView* _producer, - TensorView* _consumer, - TensorView* _reference, - unsigned int _reference_position, - ComputeAtMode _mode) - : producer_(_producer), - consumer_(_consumer), - reference_(_reference), - reference_position_(_reference_position), - mode_(_mode) { - TORCH_INTERNAL_ASSERT( - reference_ == producer_ || reference_ == consumer_, - "For compute at reference must be producer or consumer, it's neither.", - " reference: ", - reference_, - " consumer: ", - consumer_, - " producer: ", - producer_); - TORCH_INTERNAL_ASSERT( - reference_position_ >= 0 && reference_position_ <= reference_->nDims(), - "Invalid computeAt axis, received ", - reference_position_, - " but should be > -", - reference_->nDims(), - " and <= ", - reference_->nDims(), - "."); - - producer_use_chains_ = tvChains(DependencyCheck::getAllUseChains(producer_)); - - // Look through all the use chains of producer. 
Check if there's a single - // consumer for all chains at or after the consumer specified in the computeAt - // call. - setCommonConsumer(); - - root_map_.build(); - - buildUnmappableDims(); -} +// namespace { + +// // Checks if producer and consumer are transformed consistently so that to +// // satisfy the provided compute at position. This means no replay is actually +// // necessary for the compute at requested. If consumer_pos then +// // consumer_or_producer_pos is relative to the consumer and skipReplay returns +// // the associated position in producer. +// // +// // If producer and consumer are not transformed consistently with provided +// // postition, returns -1. +// int skipReplay( +// const TensorView* producer, +// const TensorView* consumer, +// int consumer_or_producer_pos, +// bool consumer_pos = true) { +// FUSER_PERF_SCOPE("transform_replay.cpp::skipReplay"); + +// const auto c2p_root_map = +// PairwiseRootDomainMap(producer, consumer) +// .mapConsumerToProducer(consumer->domain(), producer->domain()); + +// // IterDomains in consumer root also in producer root +// std::unordered_set mapped_consumer_roots; +// for (auto entry : c2p_root_map) { +// mapped_consumer_roots.emplace(entry.first); +// } + +// const auto consumer_domain = consumer->domain()->domain(); + +// auto mapped_consumer_domain_ids_vec = DependencyCheck::getAllValsBetween( +// mapped_consumer_roots, {consumer_domain.begin(), consumer_domain.end()}); + +// std::unordered_set mapped_consumer_domain_ids( +// mapped_consumer_domain_ids_vec.begin(), +// mapped_consumer_domain_ids_vec.end()); + +// const auto producer_domain = producer->domain()->domain(); + +// auto it_consumer = consumer_domain.begin(); +// auto it_producer = producer_domain.begin(); + +// auto best_effort_PasC = BestEffortReplay::replayPasC( +// producer, consumer, -1, PairwiseRootDomainMap(producer, consumer)); + +// auto c2p_map = best_effort_PasC.getReplay(); + +// int mismatched_consumer_pos = 0; +// int 
mismatched_producer_pos = 0; +// while (it_consumer != consumer_domain.end()) { +// auto consumer_id = *it_consumer; +// if (!mapped_consumer_domain_ids.count(consumer_id)) { +// ++it_consumer; +// mismatched_consumer_pos++; +// continue; +// } + +// auto c2p_it = c2p_map.find(consumer_id); +// if (c2p_it == c2p_map.end()) { +// break; +// } + +// if (it_producer == producer_domain.end()) { +// break; +// } + +// auto producer_id = *it_producer; + +// if (c2p_it->second == producer_id) { +// ++mismatched_consumer_pos; +// ++mismatched_producer_pos; +// ++it_consumer; +// ++it_producer; +// if (consumer_pos) { +// if (consumer_or_producer_pos == mismatched_consumer_pos) { +// return mismatched_producer_pos; +// } +// } else { +// if (consumer_or_producer_pos == mismatched_producer_pos) { +// return mismatched_consumer_pos; +// } +// } +// } else { +// break; +// } +// } +// return -1; +// } + +// } // namespace + +// // Actually applies transformation +// unsigned int ComputeAt::backwardComputeAt_impl( +// TensorView* producer, +// TensorView* consumer, +// unsigned int consumer_compute_at_pos) { +// FUSER_PERF_SCOPE("backwardComputeAt_impl"); + +// auto max_consumer_compute_at_pos = +// getReplayablePosPasC(producer, consumer, unmappable_dims_, mode_); + +// if (mode_ == ComputeAtMode::BestEffort) { +// consumer_compute_at_pos = +// std::min(consumer_compute_at_pos, max_consumer_compute_at_pos); +// } else if (mode_ == ComputeAtMode::MostInlined) { +// consumer_compute_at_pos = max_consumer_compute_at_pos; +// } else { +// TORCH_INTERNAL_ASSERT( +// consumer_compute_at_pos <= max_consumer_compute_at_pos, +// "Invalid compute at position detected in compute at when trying to replay producer: ", +// producer, +// " as consumer: ", +// consumer, +// " tried to do this at position: ", +// consumer_compute_at_pos, +// " but max position that's allowed is ", +// max_consumer_compute_at_pos); +// } + +// // Short cut if no replay is necessary +// auto maybe_producer_pos = 
+// skipReplay(producer, consumer, (int)consumer_compute_at_pos, true); +// if (maybe_producer_pos >= 0) { +// if (!producer->isFusionInput()) { +// producer->setComputeAt((unsigned int)maybe_producer_pos); +// } +// consumer->setMaxProducer(consumer_compute_at_pos); +// return (unsigned int)maybe_producer_pos; +// } + +// auto replay_producer_pair = TransformReplay::replayPasC( +// producer, +// consumer, +// (int)consumer_compute_at_pos, +// PairwiseRootDomainMap(producer, consumer)); + +// if (replay_producer_pair.second == 0) { +// return 0; +// } + +// if (replay_producer_pair.second >= producer->getComputeAtPosition()) { +// const TensorDomain* current_domain = producer->domain(); +// TensorDomain* new_domain = replay_producer_pair.first; + +// TORCH_INTERNAL_ASSERT( +// validateDomain(producer, new_domain), +// "Tried to set the domain of ", +// producer, +// " to ", +// new_domain, +// " but that would invalidate previously compute at position or max producer position."); + +// producer->setDomain(new_domain); +// if (!producer->isFusionInput()) { +// producer->setComputeAt(replay_producer_pair.second); +// } + +// consumer->setMaxProducer(consumer_compute_at_pos); +// root_map_.setAlias(current_domain, new_domain); +// } + +// return replay_producer_pair.second; +// } + +// // Actually applies transformation, replay consumer based on producer, set +// // compute at of producer, set pass position of consumer, return position +// // relative to consumer +// unsigned int ComputeAt::forwardComputeAt_impl( +// TensorView* producer, +// TensorView* consumer, +// unsigned int producer_compute_at_pos) { +// FUSER_PERF_SCOPE("forwardComputeAt_impl"); + +// auto max_producer_compute_at_pos = +// getReplayablePosCasP(consumer, producer, unmappable_dims_, mode_); + +// if (mode_ == ComputeAtMode::BestEffort) { +// producer_compute_at_pos = +// std::min(producer_compute_at_pos, max_producer_compute_at_pos); +// } else if (mode_ == ComputeAtMode::MostInlined) { +// 
producer_compute_at_pos = max_producer_compute_at_pos; +// } else { +// TORCH_INTERNAL_ASSERT( +// producer_compute_at_pos <= max_producer_compute_at_pos, +// "Invalid compute at position detected in compute at when trying to replay consumer: ", +// consumer, +// " as producer: ", +// producer, +// " tried to do this at position: ", +// producer_compute_at_pos, +// " but max position that's allowed is ", +// max_producer_compute_at_pos); +// } + +// // Short cut if no replay is necessary +// auto maybe_consumer_pos = +// skipReplay(producer, consumer, (int)producer_compute_at_pos, false); +// if (maybe_consumer_pos > -1) { +// if (!producer->isFusionInput()) { +// producer->setComputeAt(producer_compute_at_pos); +// } +// consumer->setMaxProducer((unsigned int)maybe_consumer_pos); +// return (unsigned int)maybe_consumer_pos; +// } + +// auto replay_consumer_pair = TransformReplay::replayCasP( +// consumer, +// producer, +// (int)producer_compute_at_pos, +// PairwiseRootDomainMap(producer, consumer)); + +// if (producer_compute_at_pos > producer->getComputeAtPosition()) { +// if (!producer->isFusionInput()) { +// producer->setComputeAt((int)producer_compute_at_pos); +// } +// } + +// if (replay_consumer_pair.second > consumer->getMaxProducerPosition()) { +// const TensorDomain* current_domain = consumer->domain(); +// TensorDomain* new_domain = replay_consumer_pair.first; + +// TORCH_INTERNAL_ASSERT( +// validateDomain(consumer, new_domain), +// "Tried to set the domain of ", +// consumer, +// " to ", +// new_domain, +// " but that would invalidate previously compute at position or max producer position."); + +// consumer->setDomain(new_domain); +// consumer->setMaxProducer(replay_consumer_pair.second); +// root_map_.setAlias(current_domain, new_domain); +// } + +// return replay_consumer_pair.second; +// } + +// // Similar to backward traversal in traverseAllKnown but we should only apply +// // computeAt if it will increase computeAt positions. 
+// void ComputeAt::traverseBackward() { +// FUSER_PERF_SCOPE("ComputeAt::traverseBackward"); +// if (reference_ == producer_) { +// // Forward compute at don't need to run backward traversal +// producer_position_ = reference_position_; +// return; +// } + +// // propagate *backward* through all *producer* use_chains or from *producer* +// // to common_consumer if common_consumer exists. Only apply transform if +// // increases computeAt position. +// auto chains = +// tvChains(DependencyCheck::getAllDependencyChains(producer_, consumer_)); + +// for (auto tv_chain : chains) { +// TensorView* running_producer = tv_chain.back(); +// TensorView* running_consumer = nullptr; +// unsigned int running_consumer_pos = reference_position_; +// tv_chain.pop_back(); + +// TORCH_INTERNAL_ASSERT(running_producer == consumer_); + +// while (!tv_chain.empty()) { +// running_consumer = running_producer; +// running_producer = tv_chain.back(); +// tv_chain.pop_back(); +// running_consumer_pos = backwardComputeAt_impl( +// running_producer, running_consumer, running_consumer_pos); +// } + +// TORCH_INTERNAL_ASSERT( +// running_producer == producer_, +// "Compute at backward traversal ended up on something other than the producer."); +// producer_position_ = running_consumer_pos; +// } +// } + +// void ComputeAt::traverseForward() { +// FUSER_PERF_SCOPE("ComputeAt::traverseForward"); + +// // propagate forward through all *producer* use_chains or from *producer* to +// // common_consumer if common_consumer exists. 
+// auto chains = producer_use_chains_; +// if (common_consumer_ != nullptr) { +// chains = tvChains( +// DependencyCheck::getAllDependencyChains(producer_, common_consumer_)); +// } + +// // propagate forward through all chains +// for (auto tv_dep_chain : chains) { +// TensorView* running_producer = nullptr; +// TensorView* running_consumer = tv_dep_chain.front(); +// tv_dep_chain.pop_front(); +// unsigned int running_producer_pos = producer_position_; + +// TORCH_INTERNAL_ASSERT(running_consumer == producer_); + +// while (!tv_dep_chain.empty()) { +// running_producer = running_consumer; +// running_consumer = tv_dep_chain.front(); +// tv_dep_chain.pop_front(); +// running_producer_pos = forwardComputeAt_impl( +// running_producer, running_consumer, running_producer_pos); +// } +// } +// } + +// void ComputeAt::resetMaxProducerPos(TensorView* consumer_tv) { +// if (consumer_tv->definition() == nullptr) { +// consumer_tv->setMaxProducer(0, true); +// } + +// unsigned int new_consummer_pa_pos = 0; + +// // Re-compute the max producer position as one or more +// // of the producers of this consumer have updated their +// // compute at position. +// for (auto inp : ir_utils::producerTvsOf(consumer_tv)) { +// if (!inp->isFusionInput()) { +// // Locate consumer's position that aligns with +// // the producer's new compute at axis. +// unsigned int inp_ca_pos_to_consumer = +// getConsumerPosAlignedToProducerCA(consumer_tv, inp); + +// // Populate the max consumer position required by +// // producer compute at. 
+// new_consummer_pa_pos = +// std::max(new_consummer_pa_pos, inp_ca_pos_to_consumer); +// } +// } + +// consumer_tv->setMaxProducer(new_consummer_pa_pos, true); +// } + +// void ComputeAt::hoistInnermostBroadcast() { +// auto fusion = producer_->fusion(); + +// std::unordered_set consumers_to_update; + +// auto all_vals = fusion->usedMathVals(); +// auto all_tvs = ir_utils::filterByType(all_vals); + +// for (auto running_producer : all_tvs) { +// if (!running_producer->isFusionInput()) { +// auto producer_ca_pos = running_producer->getComputeAtPosition(); +// // Find the innermost iterdomain that is not a broadcast +// auto new_ca_pos = getInnermostNonBroadcastIdFrom(running_producer); +// // Update the compute at pos of this producer if the original +// // compute at is within inner most broadcast axes +// if (new_ca_pos < producer_ca_pos) { +// running_producer->setComputeAt(new_ca_pos, true); +// } +// // Mark all consumers of this producer for later produce +// // position update. +// // This is safe with segmented fusion. TV uses will reset +// // when FusionSegmentGuard try to change the IO. 
+// auto tv_consumers = ir_utils::consumerTvsOf(running_producer); +// consumers_to_update.insert(tv_consumers.begin(), tv_consumers.end()); +// } +// } +// } + +// void ComputeAt::updateSiblings() { +// // Track which consumers may have a wrong produce at position to update +// // later +// auto updateSiblingsOfTv = [&](TensorView* tv) { +// if (tv->definition() == nullptr) { +// return; +// } + +// std::unordered_set consumers_to_update; + +// if (tv->definition()->outputs().size() > 1) { +// auto outs = tv->definition()->outputs(); +// auto out_tvs = ir_utils::filterByType(outs); +// for (auto sibling_tv : out_tvs) { +// if (sibling_tv == tv) { +// continue; +// } + +// std::unordered_map tv_to_sibling_map; +// TORCH_INTERNAL_ASSERT( +// tv->getRootDomain().size() == sibling_tv->getRootDomain().size(), +// "Error replaying multiple output expressions in computeAt."); + +// // Propagate any root parallelization as fullSelfReplay expects it. +// for (const auto i : c10::irange(sibling_tv->getRootDomain().size())) { +// auto id = tv->getRootDomain()[i]; +// auto sibling_id = sibling_tv->getRootDomain()[i]; +// if (id->getParallelType() != ParallelType::Serial && +// sibling_id->getParallelType() == ParallelType::Serial) { +// sibling_id->parallelize(id->getParallelType()); +// } else if ( +// id->getParallelType() == ParallelType::Serial && +// sibling_id->getParallelType() != ParallelType::Serial) { +// id->parallelize(sibling_id->getParallelType()); +// } +// } +// auto sibling_domain = +// TransformReplay::fullSelfReplay(sibling_tv->domain(), tv->domain()); +// validateDomain(sibling_tv, sibling_domain); +// sibling_tv->setDomain(sibling_domain); +// sibling_tv->setComputeAt(tv->getComputeAtPosition()); +// sibling_tv->setMaxProducer(tv->getMaxProducerPosition()); +// auto consumer_tvs = ir_utils::consumerTvsOf(sibling_tv); +// consumers_to_update.insert(consumer_tvs.begin(), consumer_tvs.end()); +// } +// } + +// // Update sibling consumer tv's max producer 
position +// for (auto consumer : consumers_to_update) { +// this->resetMaxProducerPos(consumer); +// } +// }; + +// // Find all tensor views that may have been modified +// auto chains = producer_use_chains_; +// if (common_consumer_ != nullptr) { +// chains = tvChains( +// DependencyCheck::getAllDependencyChains(producer_, common_consumer_)); +// } + +// std::unordered_set participating_tvs; +// for (auto chain : chains) { +// participating_tvs.insert(chain.begin(), chain.end()); +// } + +// for (auto tv : participating_tvs) { +// updateSiblingsOfTv(tv); +// } +// } + +// void ComputeAt::runPass() { +// FUSER_PERF_SCOPE("ComputeAt::runPass"); + + // // Traverse backward through all dep chains from producer to consumer + // traverseBackward(); + + // // Start at producer and traverse forward through all chains + // traverseForward(); + + // // Back off on inlining the inner broadcast axes + // hoistInnermostBroadcast(); + + // // Update siblings of multi output expressions + // updateSiblings(); + + // // Update the compute at position of all consumers, this used to be done + // // during the compute at pass itself, but its cleaner to do this as a cleanup + // // pass similar to hoistInnermostBroadcast and updateSiblings. + // std::unordered_set all_consumers; + + // // Find all tensor views that may have been modified + // auto chains = producer_use_chains_; + // if (common_consumer_ != nullptr) { + // chains = tvChains( + // DependencyCheck::getAllDependencyChains(producer_, common_consumer_)); + // } + + // for (const auto& chain : chains) { + // for (auto tv : chain) { + // all_consumers.emplace(tv); + // } + // } + + // // Reset max producer position of all tensor views. 
+ // for (auto tv : all_consumers) { + // resetMaxProducerPos(tv); + // } +// } + +// void ComputeAt::buildUnmappableDims() { +// auto all_tvs = ir_utils::allTvs(producer_->fusion()); +// for (auto tv : all_tvs) { +// auto consumers = ir_utils::consumerTvsOf(tv); +// for (auto consumer : consumers) { +// // Grab dimensions in producer and consumer that are mappable to eachother +// // based on the computeAtRootDomainMap. This will tell us which dimensions +// // can be inlined based on avoiding trying to inline non-trivial +// // reduction structures. +// auto mappable_roots = +// root_map_.getMappableDims(tv->domain(), consumer->domain()); +// for (auto tv_root_id : tv->getMaybeRFactorDomain()) { +// if (mappable_roots.find(tv_root_id) == mappable_roots.end() && +// !tv_root_id->isTrivialReduction()) { +// unmappable_dims_.emplace(tv_root_id); +// } +// } +// } +// } +// } + +// ComputeAt::ComputeAt( +// TensorView* _producer, +// TensorView* _consumer, +// TensorView* _reference, +// unsigned int _reference_position, +// ComputeAtMode _mode) +// : producer_(_producer), +// consumer_(_consumer), +// reference_(_reference), +// reference_position_(_reference_position), +// mode_(_mode) { +// TORCH_INTERNAL_ASSERT( +// reference_ == producer_ || reference_ == consumer_, +// "For compute at reference must be producer or consumer, it's neither.", +// " reference: ", +// reference_, +// " consumer: ", +// consumer_, +// " producer: ", +// producer_); +// TORCH_INTERNAL_ASSERT( +// reference_position_ >= 0 && reference_position_ <= reference_->nDims(), +// "Invalid computeAt axis, received ", +// reference_position_, +// " but should be > -", +// reference_->nDims(), +// " and <= ", +// reference_->nDims(), +// "."); + +// producer_use_chains_ = tvChains(DependencyCheck::getAllUseChains(producer_)); + +// // Look through all the use chains of producer. 
Check if there's a single +// // consumer for all chains at or after the consumer specified in the computeAt +// // call. +// setCommonConsumer(); + +// root_map_.build(); + +// buildUnmappableDims(); +// } + +// ComputeAt::ComputeAt( +// std::unordered_set subgraph, +// TensorView* reference, +// unsigned int reference_position, +// ComputeAtMode mode) +// : TransformPropagator(reference, reference_position), +// subgraph_(std::move(subgraph)), +// mode_(mode) { +// TORCH_INTERNAL_ASSERT( +// subgraph_.count(reference), +// "Reference must be within subgraph."); +// TORCH_INTERNAL_ASSERT( +// reference_position >= 0 && reference_position <= reference->nDims(), +// "Invalid computeAt axis, received ", +// reference_position, +// " but should be > -", +// reference->nDims(), +// " and <= ", +// reference->nDims(), +// "."); +// } } // namespace cuda } // namespace fuser diff --git a/torch/csrc/jit/codegen/cuda/compute_at.h b/torch/csrc/jit/codegen/cuda/compute_at.h index 75fca5705ed9e..9462571c7963d 100644 --- a/torch/csrc/jit/codegen/cuda/compute_at.h +++ b/torch/csrc/jit/codegen/cuda/compute_at.h @@ -1,6 +1,8 @@ #pragma once #include +#include +#include #include #include @@ -18,7 +20,7 @@ namespace cuda { class TensorDomain; class TensorView; -class ComputeAt { +struct ComputeAt { public: // Runs the compute at pass making producer look like consumer, computing // producer relative to consumer @@ -35,93 +37,6 @@ class ComputeAt { TensorView* consumer, unsigned int producer_position, ComputeAtMode mode = ComputeAtMode::Standard); - - ComputeAt() = delete; - ComputeAt(ComputeAt&) = delete; - ComputeAt& operator=(const ComputeAt& other) = delete; - - private: - TensorView* producer_; - TensorView* consumer_; - TensorView* reference_; - unsigned int reference_position_; - ComputeAtMode mode_ = ComputeAtMode::Standard; - - unsigned int producer_position_ = 0; - ComputeAtRootDomainMap root_map_; - - // Runs replayPasC and sets producer computeAt settings. 
Returns - // producer_compute_at_pos. - unsigned int backwardComputeAt_impl( - TensorView* producer, - TensorView* consumer, - unsigned int consumer_compute_at_pos); - - // Runs replayCasP and sets producer computeAt settings. Returns - // consumer_compute_at_pos. - unsigned int forwardComputeAt_impl( - TensorView* producer, - TensorView* consumer, - unsigned int producer_compute_at_pos); - - // Look through all the use chains of producer. Check if there's a single - // consumer for all chains at or after the consumer specified in the computeAt - // call. - void setCommonConsumer(); - - // Iterate through all TVs and collect the dimensions of each TV that don't - // map to all its consumer TVs. - void buildUnmappableDims(); - - // Propagate backward from consumer to producer, check if it increase - // computeAt position on tensors, if so take it! - void traverseBackward(); - - // Traverse from producer to common_consumer if it exists or through all uses - // of producer - void traverseForward(); - - // Looks at producer tensor views of consumer_tv, recomputes its max - // producer position, and sets max producer position. This function can - // only potentially lower the max producer position of consumer_tv. - void resetMaxProducerPos(TensorView* consumer_tv); - - // Undo the inlining of block broadcast at the innermost positions - // to avoid generating repeated block broadcasts - void hoistInnermostBroadcast(); - - // Update multi-output expressions. If one output is modified, all outputs - // should be modified as well. Propagate transformations, compute at, and - // produce at from tv to siblings. Run as final pass as it will invalidate the - // computeAt map originally computed. - void updateSiblings(); - - // Compute at pass requires tracking "maxProducerPosition" even if set simply - // from input tensor views. 
However, when lowering, we need a valid produce at - // position of all tensors, so inputs should never actually set their - // consumers maxProduceAt position. - void updateInputProduceAts(); - - // Run the computeAt pass - void runPass(); - - // Common consumer if it exists - TensorView* common_consumer_ = nullptr; - - // Producer use chains set in, used in a few spots. - std::deque> producer_use_chains_; - - // Root domains in producer that's unmappable to any of its consumers - std::unordered_set unmappable_dims_; - - ComputeAt( - TensorView* _producer, - TensorView* _consumer, - TensorView* _reference, - unsigned int _reference_position, - ComputeAtMode _mode); - - ~ComputeAt() = default; }; } // namespace cuda diff --git a/torch/csrc/jit/codegen/cuda/consume_at.cpp b/torch/csrc/jit/codegen/cuda/consume_at.cpp new file mode 100644 index 0000000000000..1cb77e443cb12 --- /dev/null +++ b/torch/csrc/jit/codegen/cuda/consume_at.cpp @@ -0,0 +1,323 @@ +#include +#include +#include +#include + +#include + +namespace torch { +namespace jit { +namespace fuser { +namespace cuda { + +ConsumeAt::ConsumeAt( + std::unordered_set consume, + TensorView* reference, + size_t reference_pos, + ComputeAtMode mode) + : TransformPropagatorBase(reference, reference_pos), + consume(std::move(consume)), + mode(mode) { + TORCH_INTERNAL_ASSERT( + reference_pos >= 0 && reference_pos <= reference->nDims(), + "Invalid computeAt axis, received ", + reference_pos, + " but should be > -", + reference->nDims(), + " and <= ", + reference->nDims(), + "."); +} + +void ConsumeAt::consumeAllAt( + std::unordered_set consume, + TensorView* reference, + size_t reference_pos, + ComputeAtMode mode) { + ConsumeAt ca(std::move(consume), reference, reference_pos, mode); + ca.run(); +} + +void ConsumeAt::recordReplayedPos(TensorView* tv, size_t pos) { + if (consume.count(tv)) { + tv->setComputeAt(pos); + } else { + replayed_pos[tv] = pos; + } +} + +size_t ConsumeAt::getReplayablePos(TensorView* tv) { + // 
TODO: ComputeAt has something about unmappable dims + // should I do the same? + auto c_dom = tv->domain()->domain(); + auto vector_dim_it = + std::find_if(c_dom.begin(), c_dom.end(), [this](IterDomain* id) { + return isParallelTypeVectorize(id->getParallelType()) || + ((mode == ComputeAtMode::BestEffort || + mode == ComputeAtMode::MostInlined) && + id->getParallelType() == ParallelType::Unroll); + }); + + return std::distance(c_dom.begin(), vector_dim_it); +} + +c10::optional ConsumeAt::retrieveReplayedPos(TensorView* tv) { + size_t max_pos = getReplayablePos(tv); + size_t pos; + + if (consume.count(tv)) { + pos = tv->getComputeAtPosition(); + } + auto it = replayed_pos.find(tv); + if (it == replayed_pos.end()) { + return c10::nullopt; + } + pos = it->second; + + if (mode == ComputeAtMode::BestEffort) { + return std::min(pos, max_pos); + } else if (mode == ComputeAtMode::MostInlined) { + return max_pos; + } + + TORCH_INTERNAL_ASSERT( + pos <= max_pos, + "Invalid compute at position detected in ConsumeAt when trying to replay: ", + tv, + ", tried to do this at position: ", + pos, + " but max position that's allowed is ", + max_pos); + + return pos; +} + +// // Return the max position in consumer that producer can be inlined to +// // Cannot inline: +// // Reduction dimensions in producer +// // Block broadcast dimensions in producer +// // Vectorized dimensions in producer or consumer +// // Unrolled dimensions in producer or consumer +// // Dimensions derived from root dimensions that exist in both but are +// // unmappable +// size_t ConsumeAt::getReplayablePosPasC( +// TensorView* producer, +// TensorView* consumer) { +// // Check if any consumer dimensions are marked as vectorize as producer can +// // not be inlined to vectorized dimensions in consumer. 
+// auto c_dom = consumer->domain()->domain(); +// auto vector_dim_it = +// std::find_if(c_dom.begin(), c_dom.end(), [this](IterDomain* id) { +// return isParallelTypeVectorize(id->getParallelType()) || +// ((mode == ComputeAtMode::BestEffort || +// mode == ComputeAtMode::MostInlined) && +// id->getParallelType() == ParallelType::Unroll); +// }); + +// // Limit max position based on vectorized dims in consumer. +// auto max_consumer_pos = std::distance(c_dom.begin(), vector_dim_it); + +// auto pairwise_root_map = PairwiseRootDomainMap(producer, consumer); +// auto c2p_root_map = +// PairwiseRootDomainMap(producer, consumer) +// .mapConsumerToProducer(consumer->domain(), producer->domain()); + +// auto replay_PasC = +// BestEffortReplay::replayPasC(producer, consumer, -1, +// pairwise_root_map); + +// // Look for id's that map to a consumer id that's vectorized +// auto c2p_replay_map = replay_PasC.getReplay(); + +// for (size_t consumer_pos = max_consumer_pos; consumer_pos > 0; +// consumer_pos--) { +// auto map_it = c2p_replay_map.find(consumer->axis((int)consumer_pos - 1)); +// if (map_it != c2p_replay_map.end()) { +// auto p_id = map_it->second; +// // If we find a consumer dim that maps to a producer dim that's +// // vectorized or unrolled limit max compute at by it. 
+// if (isParallelTypeVectorize(p_id->getParallelType()) || +// ((mode == ComputeAtMode::BestEffort || +// mode == ComputeAtMode::MostInlined) && +// p_id->getParallelType() == ParallelType::Unroll)) { +// max_consumer_pos = consumer_pos - 1; +// } +// } +// } + +// return max_consumer_pos; +// } + +// // Return the max position in producer that can be inlined to consumer +// // Cannot inline: +// // Reduction dimensions in producer +// // Vectorized dimensions in producer or consumer +// // Unrolled dimensions in producer or consumer +// // Dimensions derived from root dimensions that exist in both but are +// // unmappable +// size_t ConsumeAt::getReplayablePosCasP( +// TensorView* consumer, +// TensorView* producer) { +// auto p_dom = producer->domain()->domain(); +// auto first_reduction = +// std::find_if(p_dom.begin(), p_dom.end(), [](IterDomain* id) { +// return id->isReduction(); +// }); + +// auto first_vectorized_axis = +// std::find_if(p_dom.begin(), first_reduction, [this](IterDomain* id) { +// return isParallelTypeVectorize(id->getParallelType()) || +// ((mode == ComputeAtMode::BestEffort || +// mode == ComputeAtMode::MostInlined) && +// id->getParallelType() == ParallelType::Unroll); +// }); + +// auto max_producer_pos = std::distance(p_dom.begin(), +// first_vectorized_axis); + +// auto pairwise_root_map = PairwiseRootDomainMap(producer, consumer); +// auto p2c_root_map = pairwise_root_map.mapProducerToConsumer( +// producer->domain(), consumer->domain()); + +// auto replay_CasP = +// BestEffortReplay::replayCasP(consumer, producer, -1, +// pairwise_root_map); + +// // Look for id's that map to a consumer id that's vectorized +// auto p2c_replay_map = replay_CasP.getReplay(); + +// for (size_t producer_pos = max_producer_pos; producer_pos > 0; +// producer_pos--) { +// auto map_it = p2c_replay_map.find(producer->axis((int)producer_pos - 1)); +// if (map_it != p2c_replay_map.end()) { +// auto c_id = map_it->second; +// // If we find a producer dim 
that maps to a consumer vectorized or +// // unrolled dim, limit max compute at by it +// if (isParallelTypeVectorize(c_id->getParallelType()) || +// ((mode == ComputeAtMode::BestEffort || +// mode == ComputeAtMode::MostInlined) && +// c_id->getParallelType() == ParallelType::Unroll)) { +// max_producer_pos = producer_pos - 1; +// } +// } +// } + +// return max_producer_pos; +// } + +void ConsumeAt::hoistInnermostBroadcast() { + for (auto tv : consume) { + if (!tv->isFusionInput()) { + auto ca_pos = tv->getComputeAtPosition(); + while (ca_pos > 0 && tv->axis(ca_pos - 1)->isBroadcast()) { + ca_pos--; + } + tv->setComputeAt(ca_pos, true); + } + } +} + +// TODO: most of this is copy-pasted code. I need to investigate this to +// see which makes sense and which needs change. +// Try to find the aligned position on consumer's domain corresponding to the +// compute at position of producer domain. Used in computeAt pass only. No +// checking on actual producer-consumer relationship. +size_t getConsumerPosAlignedToProducerCA( + TensorView* consumer, + TensorView* producer) { + // Locate consumer's position that aligns with + // the producer's new compute at axis. We need broadcast axes forwarded so we + // need to replay PasC as CasP will not forward braodcast dims. For example + // if we have: + // T2[ iS22{( 3 * 1 )} ] ca_pos( 1 ) = broadcast( T1[ iS1{3} ] ca_pos( 1 ) + // produce_pos( 1) ) CasP will have the mapping iS1{3} -> iS2{3} and PasC will + // have the mapping iS22{( 3 * 1 )} <- iS1{3} We need the latter. Refer to + // NVFuserTest.FusionComplexBCast1_CUDA + + auto c2p_map = + BestEffortReplay::replayPasC( + producer, + consumer, + -1, + // Compute at root domain may not be valid here, as all + // producers don't have to be able to map into consumer at + // max producer position. Since computeAt should be valid + // and this mechanism is only intended to lower produce + // position of consumer, we can simply use the pairwise map. 
+ PairwiseRootDomainMap(producer, consumer)) + .getReplay(); + + // Find the innermost position of consumer that has + // been mapped within the producer ca axis. + unsigned int consumer_pos = consumer->nDims(); + while (consumer_pos > 0) { + auto consumer_id = consumer->axis((int)consumer_pos - 1); + auto p_dom = producer->domain()->domain(); + if (std::any_of( + p_dom.begin(), + p_dom.begin() + producer->getComputeAtPosition(), + [&consumer_id, &c2p_map](IterDomain* p_id) { + auto c_id_it = c2p_map.find(consumer_id); + if (c_id_it != c2p_map.end()) { + return c_id_it->second == p_id; + } + return false; + })) { + break; + } + consumer_pos--; + } + + return consumer_pos; +} + +void ConsumeAt::computeMaxProducerPos() { + std::unordered_set todo; + for (auto p : consume) { + auto consumers = ir_utils::consumerTvsOf(p); + std::copy( + consumers.begin(), consumers.end(), std::inserter(todo, todo.end())); + } + for (auto tv : todo) { + auto producers = ir_utils::producerTvsOf(tv); + size_t max_pos = 0; + for (auto p : producers) { + max_pos = + std::max(max_pos, getConsumerPosAlignedToProducerCA(tv, p)); + } + tv->setMaxProducer(max_pos, true); + } +} + +bool ConsumeAt::shouldPropagate(TensorView* tv) { + if (consume.count(tv)) { + return true; + } + + // If one of tv's producer is in the consume set, then tv must also be + // replayed to obtain a compatible loop structure so that this producer + // can be consumed in this loop. 
+ auto def = tv->definition(); + if (def != nullptr) { + auto tv_inputs = ir_utils::filterByType(def->inputs()); + for (auto input : tv_inputs) { + if (consume.count(input)) { + return true; + } + } + } + + return false; +} + +void ConsumeAt::run() { + TransformPropagatorBase::run(); + hoistInnermostBroadcast(); + computeMaxProducerPos(); + // TODO: check everyone is visited +} + +} // namespace cuda +} // namespace fuser +} // namespace jit +} // namespace torch diff --git a/torch/csrc/jit/codegen/cuda/consume_at.h b/torch/csrc/jit/codegen/cuda/consume_at.h new file mode 100644 index 0000000000000..850715058c7d3 --- /dev/null +++ b/torch/csrc/jit/codegen/cuda/consume_at.h @@ -0,0 +1,51 @@ +#pragma once + +#include +#include + +#include + +namespace torch { +namespace jit { +namespace fuser { +namespace cuda { + +class ConsumeAt : public TransformPropagatorBase { + private: +// size_t getReplayablePosPasC(TensorView* producer, TensorView* consumer); +// size_t getReplayablePosCasP(TensorView* consumer, TensorView* producer); + size_t getReplayablePos(TensorView* tv); + void hoistInnermostBroadcast(); + void computeMaxProducerPos(); + + protected: + std::unordered_set consume; + ComputeAtMode mode = ComputeAtMode::Standard; + std::unordered_map replayed_pos; + + virtual bool shouldPropagate(TensorView* tv) override; + virtual void recordReplayedPos(TensorView* tv, size_t pos) override; + virtual c10::optional retrieveReplayedPos(TensorView* tv) override; + + ConsumeAt( + std::unordered_set consume, + TensorView* reference, + size_t reference_pos, + ComputeAtMode mode); + + ~ConsumeAt() = default; + + void run(); + + public: + static void consumeAllAt( + std::unordered_set consume, + TensorView* reference, + size_t reference_pos, + ComputeAtMode mode); +}; + +} // namespace cuda +} // namespace fuser +} // namespace jit +} // namespace torch \ No newline at end of file diff --git a/torch/csrc/jit/codegen/cuda/ir_interface_nodes.h 
b/torch/csrc/jit/codegen/cuda/ir_interface_nodes.h index ab37e8fcf6319..019a97179eca6 100644 --- a/torch/csrc/jit/codegen/cuda/ir_interface_nodes.h +++ b/torch/csrc/jit/codegen/cuda/ir_interface_nodes.h @@ -153,7 +153,8 @@ class TORCH_CUDA_CU_API ComplexDouble : public Val { //! the compute at position to maximum possible through traversal. enum class ComputeAtMode { Standard, BestEffort, MostInlined }; -class ComputeAt; +class ConsumeAt; +class TransformPropagatorBase; class TransformPropagator; class TransformIter; class TransformReplay; @@ -453,10 +454,11 @@ class TORCH_CUDA_CU_API TensorView : public Val { //! More detail on usage see [WarpMmaSwizzler] in scheduler/mma_utils.h . void applyMmaSwizzle(MmaOptions options); + friend TORCH_CUDA_CU_API TransformPropagatorBase; friend TORCH_CUDA_CU_API TransformPropagator; friend TORCH_CUDA_CU_API TransformReplay; friend TORCH_CUDA_CU_API OptOutMutator; - friend ComputeAt; + friend TORCH_CUDA_CU_API ConsumeAt; friend class ir_utils::TVDomainGuard; friend TORCH_CUDA_CU_API void groupReductions( const std::vector&); diff --git a/torch/csrc/jit/codegen/cuda/transform_replay.cpp b/torch/csrc/jit/codegen/cuda/transform_replay.cpp index 6ab0df7b47cb9..17929af0a1827 100644 --- a/torch/csrc/jit/codegen/cuda/transform_replay.cpp +++ b/torch/csrc/jit/codegen/cuda/transform_replay.cpp @@ -705,21 +705,22 @@ std::deque producersFor(TensorView* tv) { }; // namespace -bool TransformPropagator::replayPasC( +bool TransformPropagatorBase::replayPasC( TensorView* producer_tv, TensorView* consumer_tv) { if (producer_tv == starting_tv) { return false; } - auto consumer_pos_it = replayed_pos.find(consumer_tv); - if (consumer_pos_it == replayed_pos.end()) { + auto consumer_pos_opt = retrieveReplayedPos(consumer_tv); + if (!consumer_pos_opt.has_value()) { return false; } + auto consumer_pos = consumer_pos_opt.value(); auto pairwiseMap = PairwiseRootDomainMap(producer_tv, consumer_tv); auto replayed_producer = TransformReplay::replayPasC( - 
producer_tv, consumer_tv, consumer_pos_it->second, pairwiseMap); + producer_tv, consumer_tv, consumer_pos, pairwiseMap); auto producer_root = producer_tv->getMaybeRFactorDomain(); auto replayed_domain = replayed_producer.first->domain(); @@ -739,36 +740,38 @@ bool TransformPropagator::replayPasC( return dep_vals_set.find(root_id) != dep_vals_set.end(); }); - if (replayed_pos.find(producer_tv) != replayed_pos.end()) { + auto producer_pos_opt = retrieveReplayedPos(producer_tv); + if (producer_pos_opt.has_value()) { if (n_transformed_root_dims < n_replayed_root_dims.at(producer_tv) || (n_transformed_root_dims == n_replayed_root_dims.at(producer_tv) && - replayed_producer.second <= replayed_pos.at(producer_tv))) { + replayed_producer.second <= producer_pos_opt.value())) { return false; // NOLINT(clang-analyzer-cplusplus.NewDeleteLeaks) } } producer_tv->setDomain(replayed_producer.first); - replayed_pos[producer_tv] = replayed_producer.second; + recordReplayedPos(producer_tv, replayed_producer.second); n_replayed_root_dims[producer_tv] = n_transformed_root_dims; return true; } -bool TransformPropagator::replayCasP( +bool TransformPropagatorBase::replayCasP( TensorView* consumer_tv, TensorView* producer_tv) { if (consumer_tv == starting_tv) { return false; } - auto producer_pos_it = replayed_pos.find(producer_tv); - if (producer_pos_it == replayed_pos.end()) { + auto producer_pos_opt = retrieveReplayedPos(producer_tv); + if (!producer_pos_opt.has_value()) { return false; } + auto producer_pos = producer_pos_opt.value(); auto pairwiseMap = PairwiseRootDomainMap(producer_tv, consumer_tv); auto replayed_consumer = TransformReplay::replayCasP( - consumer_tv, producer_tv, producer_pos_it->second, pairwiseMap); + consumer_tv, producer_tv, producer_pos, pairwiseMap); auto consumer_root = consumer_tv->getRootDomain(); auto replayed_domain = replayed_consumer.first->domain(); @@ -788,26 +791,32 @@ bool TransformPropagator::replayCasP( return dep_vals_set.find(root_id) != 
dep_vals_set.end(); }); - if (replayed_pos.find(consumer_tv) != replayed_pos.end()) { + auto consumer_pos_opt = retrieveReplayedPos(consumer_tv); + if (consumer_pos_opt.has_value()) { if (n_transformed_root_dims < n_replayed_root_dims.at(consumer_tv) || (n_transformed_root_dims == n_replayed_root_dims.at(consumer_tv) && - replayed_consumer.second <= replayed_pos.at(consumer_tv))) { + replayed_consumer.second <= consumer_pos_opt.value())) { return false; // NOLINT(clang-analyzer-cplusplus.NewDeleteLeaks) } } consumer_tv->setDomain(replayed_consumer.first); - replayed_pos[consumer_tv] = replayed_consumer.second; + recordReplayedPos(consumer_tv, replayed_consumer.second); n_replayed_root_dims[consumer_tv] = n_transformed_root_dims; return true; } -TransformPropagator::TransformPropagator(TensorView* from) : starting_tv(from) { +TransformPropagatorBase::TransformPropagatorBase( + TensorView* from, + size_t starting_pos) + : starting_tv(from), starting_pos(starting_pos) {} + +void TransformPropagatorBase::run() { VectorOfUniqueEntries propagation{starting_tv}; // Seed position with local tv - replayed_pos[from] = from->nDims(); + recordReplayedPos(starting_tv, starting_pos); // While tensor views are being replayed, if they're modified, make sure we // propagate back to all producers as well as consumers. This is definitely @@ -818,6 +827,9 @@ TransformPropagator::TransformPropagator(TensorView* from) : starting_tv(from) { // Replay tv forward to its consumers. 
for (auto consumer_tv : consumersOf(tv)) { + if (!shouldPropagate(consumer_tv)) { + continue; + } auto replayed = replayCasP(consumer_tv, tv); // If consumer has changed, mark we should propagate if (replayed) { @@ -826,6 +838,9 @@ TransformPropagator::TransformPropagator(TensorView* from) : starting_tv(from) { } for (auto producer_tv : producersFor(tv)) { + if (!shouldPropagate(producer_tv)) { + continue; + } // If producer has changed, mark we should propagate auto replayed = replayPasC(producer_tv, tv); if (replayed) { @@ -835,8 +850,24 @@ TransformPropagator::TransformPropagator(TensorView* from) : starting_tv(from) { } } +TransformPropagator::TransformPropagator(TensorView* from, size_t starting_pos) + : TransformPropagatorBase(from, starting_pos) {} + void TransformPropagator::from(TensorView* tv) { - TransformPropagator propagate(tv); + TransformPropagator propagate(tv, tv->nDims()); + propagate.run(); +} + +void TransformPropagator::recordReplayedPos(TensorView *tv, size_t pos) { + replayed_pos[tv] = pos; +} + +c10::optional TransformPropagator::retrieveReplayedPos(TensorView *tv) { + auto it = replayed_pos.find(tv); + if (it != replayed_pos.end()) { + return it->second; + } + return c10::nullopt; } } // namespace cuda diff --git a/torch/csrc/jit/codegen/cuda/transform_replay.h b/torch/csrc/jit/codegen/cuda/transform_replay.h index 0f7c7c00c8532..370215280200b 100644 --- a/torch/csrc/jit/codegen/cuda/transform_replay.h +++ b/torch/csrc/jit/codegen/cuda/transform_replay.h @@ -2,6 +2,7 @@ #include #include +#include #include #include @@ -153,14 +154,44 @@ class TORCH_CUDA_CU_API TransformReplay { const TensorDomain* self); }; -class TORCH_CUDA_CU_API TransformPropagator { - private: +class TORCH_CUDA_CU_API TransformPropagatorBase { + // TODO: keep? 
+ // This example comes from a BN kernel, the domain: + // + // [ iS{ceilDiv(ceilDiv(ceilDiv(i4, 128), 4), 1)}, iS{1}, iS{4}, iS{128}, + // iS{i0}, iS{i2}, iS{i3} ] + // + // and + // + // [ iS{ceilDiv(ceilDiv(ceilDiv(i5*i6*i7*i8, 128), 4), 1)}, iS252{1}, + // iS250{4}, iS248{128} ] + // + // Have the same number of replayed dimensions, however the second one + // involves more root domains. The second one is also likely the prefered + // replay. Therefore keep track of how many root domains were part of the + // replay and prefer transformations with more root domains. We could probably + // fix this instances of this occuring by changing the traversal pattern so + // that once propagating towards roots through broadcast axes, it can't come + // back through another broadcast, losing the transformation on those axes. + // However, this should work for existing cases. + std::unordered_map n_replayed_root_dims; + + TensorView* starting_tv = nullptr; + size_t starting_pos; + + protected: + TransformPropagatorBase(TensorView* from, size_t starting_pos); bool replayPasC(TensorView* producer_tv, TensorView* consumer_tv = nullptr); bool replayCasP(TensorView* consumer_tv, TensorView* producer_tv = nullptr); + void run(); + + virtual void recordReplayedPos(TensorView *tv, size_t pos) = 0; + virtual c10::optional retrieveReplayedPos(TensorView *tv) = 0; + virtual bool shouldPropagate(TensorView* tv) { return true; } +}; - TransformPropagator(TensorView* from); +class TORCH_CUDA_CU_API TransformPropagator : public TransformPropagatorBase { - private: std::unordered_map replayed_pos; // This example comes from a BN kernel, the domain: @@ -182,7 +213,12 @@ class TORCH_CUDA_CU_API TransformPropagator { // back through another broadcast, losing the transformation on those axes. // However, this should work for existing cases. 
std::unordered_map n_replayed_root_dims; - TensorView* starting_tv = nullptr; + + protected: + TransformPropagator(TensorView* from, size_t starting_pos); + + virtual void recordReplayedPos(TensorView *tv, size_t pos) override; + virtual c10::optional retrieveReplayedPos(TensorView *tv) override; public: static void from(TensorView* tv); From 28348d0270402bc463f2b57ba85b79a556b718e1 Mon Sep 17 00:00:00 2001 From: Xiang Gao Date: Thu, 2 Jun 2022 11:32:44 -0700 Subject: [PATCH 002/100] save --- torch/csrc/jit/codegen/cuda/consume_at.cpp | 7 +++++-- torch/csrc/jit/codegen/cuda/consume_at.h | 3 +-- torch/csrc/jit/codegen/cuda/test/test_gpu.cpp | 7 +++++++ 3 files changed, 13 insertions(+), 4 deletions(-) diff --git a/torch/csrc/jit/codegen/cuda/consume_at.cpp b/torch/csrc/jit/codegen/cuda/consume_at.cpp index 1cb77e443cb12..6b4b8ddfdbae9 100644 --- a/torch/csrc/jit/codegen/cuda/consume_at.cpp +++ b/torch/csrc/jit/codegen/cuda/consume_at.cpp @@ -17,7 +17,8 @@ ConsumeAt::ConsumeAt( ComputeAtMode mode) : TransformPropagatorBase(reference, reference_pos), consume(std::move(consume)), - mode(mode) { + mode(mode), + unvisited(consume) { TORCH_INTERNAL_ASSERT( reference_pos >= 0 && reference_pos <= reference->nDims(), "Invalid computeAt axis, received ", @@ -291,6 +292,7 @@ void ConsumeAt::computeMaxProducerPos() { bool ConsumeAt::shouldPropagate(TensorView* tv) { if (consume.count(tv)) { + unvisited.erase(tv); return true; } @@ -314,7 +316,8 @@ void ConsumeAt::run() { TransformPropagatorBase::run(); hoistInnermostBroadcast(); computeMaxProducerPos(); - // TODO: check everyone is visited + TORCH_CHECK( + unvisited.empty(), "Unable to propagate to the entire consume set"); } } // namespace cuda diff --git a/torch/csrc/jit/codegen/cuda/consume_at.h b/torch/csrc/jit/codegen/cuda/consume_at.h index 850715058c7d3..456b12b467245 100644 --- a/torch/csrc/jit/codegen/cuda/consume_at.h +++ b/torch/csrc/jit/codegen/cuda/consume_at.h @@ -11,17 +11,16 @@ namespace fuser { namespace cuda { 
class ConsumeAt : public TransformPropagatorBase { - private: // size_t getReplayablePosPasC(TensorView* producer, TensorView* consumer); // size_t getReplayablePosCasP(TensorView* consumer, TensorView* producer); size_t getReplayablePos(TensorView* tv); void hoistInnermostBroadcast(); void computeMaxProducerPos(); - protected: std::unordered_set consume; ComputeAtMode mode = ComputeAtMode::Standard; std::unordered_map replayed_pos; + std::unordered_set unvisited; virtual bool shouldPropagate(TensorView* tv) override; virtual void recordReplayedPos(TensorView* tv, size_t pos) override; diff --git a/torch/csrc/jit/codegen/cuda/test/test_gpu.cpp b/torch/csrc/jit/codegen/cuda/test/test_gpu.cpp index 9c77d1b4d3978..9fb665911c445 100644 --- a/torch/csrc/jit/codegen/cuda/test/test_gpu.cpp +++ b/torch/csrc/jit/codegen/cuda/test/test_gpu.cpp @@ -5062,9 +5062,16 @@ TEST_F(NVFuserTest, FusionSimpleBCast1_CUDA) { tv7->split(-1, 4); tv7->split(0, 8); + fusion.print(); + tv0->computeAt(tv7, -1); + + fusion.print(); + tv2->computeAt(tv7, -1); + fusion.print(); + tv7->axis(0)->parallelize(ParallelType::BIDx); tv7->axis(-1)->parallelize(ParallelType::TIDx); From 1a1e609d71159a7a2522c235b9b4c4932765a30b Mon Sep 17 00:00:00 2001 From: Xiang Gao Date: Thu, 2 Jun 2022 12:09:51 -0700 Subject: [PATCH 003/100] fix --- torch/csrc/jit/codegen/cuda/compute_at.cpp | 4 +++- torch/csrc/jit/codegen/cuda/consume_at.cpp | 22 ++++++++++++------- torch/csrc/jit/codegen/cuda/test/test_gpu.cpp | 7 ------ .../jit/codegen/cuda/transform_replay.cpp | 3 +++ 4 files changed, 20 insertions(+), 16 deletions(-) diff --git a/torch/csrc/jit/codegen/cuda/compute_at.cpp b/torch/csrc/jit/codegen/cuda/compute_at.cpp index 3064477a57690..8a8b853d610f5 100644 --- a/torch/csrc/jit/codegen/cuda/compute_at.cpp +++ b/torch/csrc/jit/codegen/cuda/compute_at.cpp @@ -369,7 +369,9 @@ std::unordered_set getPropagationSubgraph( ", however it is not."); TensorView *common_consumer = getCommonConsumer(producer, consumer); if 
(common_consumer != nullptr) { - return getAllTVsBetween(producer, consumer); + auto result = getAllTVsBetween(producer, common_consumer); + std::cout << "common_consumer != nullptr; returning " << ir_utils::toString(result) << std::endl; + return result; } auto result_vals = DependencyCheck::getAllDependentVals({producer}); auto result_tvs = ir_utils::filterByType(result_vals); diff --git a/torch/csrc/jit/codegen/cuda/consume_at.cpp b/torch/csrc/jit/codegen/cuda/consume_at.cpp index 6b4b8ddfdbae9..fd21c9ac071e7 100644 --- a/torch/csrc/jit/codegen/cuda/consume_at.cpp +++ b/torch/csrc/jit/codegen/cuda/consume_at.cpp @@ -67,13 +67,17 @@ c10::optional ConsumeAt::retrieveReplayedPos(TensorView* tv) { size_t pos; if (consume.count(tv)) { + if (!tv->hasComputeAt()) { + return c10::nullopt; + } pos = tv->getComputeAtPosition(); + } else { + auto it = replayed_pos.find(tv); + if (it == replayed_pos.end()) { + return c10::nullopt; + } + pos = it->second; } - auto it = replayed_pos.find(tv); - if (it == replayed_pos.end()) { - return c10::nullopt; - } - pos = it->second; if (mode == ComputeAtMode::BestEffort) { return std::min(pos, max_pos); @@ -89,7 +93,7 @@ c10::optional ConsumeAt::retrieveReplayedPos(TensorView* tv) { pos, " but max position that's allowed is ", max_pos); - + std::cout << "retrieveReplayedPos: tensor=" << tv << ", max_pos=" << max_pos << ", pos=" << pos << std::endl; return pos; } @@ -292,6 +296,7 @@ void ConsumeAt::computeMaxProducerPos() { bool ConsumeAt::shouldPropagate(TensorView* tv) { if (consume.count(tv)) { + std::cout << "visiting " << tv << std::endl; unvisited.erase(tv); return true; } @@ -314,10 +319,11 @@ bool ConsumeAt::shouldPropagate(TensorView* tv) { void ConsumeAt::run() { TransformPropagatorBase::run(); - hoistInnermostBroadcast(); - computeMaxProducerPos(); TORCH_CHECK( unvisited.empty(), "Unable to propagate to the entire consume set"); + + hoistInnermostBroadcast(); + computeMaxProducerPos(); } } // namespace cuda diff --git 
a/torch/csrc/jit/codegen/cuda/test/test_gpu.cpp b/torch/csrc/jit/codegen/cuda/test/test_gpu.cpp index 9fb665911c445..9c77d1b4d3978 100644 --- a/torch/csrc/jit/codegen/cuda/test/test_gpu.cpp +++ b/torch/csrc/jit/codegen/cuda/test/test_gpu.cpp @@ -5062,16 +5062,9 @@ TEST_F(NVFuserTest, FusionSimpleBCast1_CUDA) { tv7->split(-1, 4); tv7->split(0, 8); - fusion.print(); - tv0->computeAt(tv7, -1); - - fusion.print(); - tv2->computeAt(tv7, -1); - fusion.print(); - tv7->axis(0)->parallelize(ParallelType::BIDx); tv7->axis(-1)->parallelize(ParallelType::TIDx); diff --git a/torch/csrc/jit/codegen/cuda/transform_replay.cpp b/torch/csrc/jit/codegen/cuda/transform_replay.cpp index 17929af0a1827..6cfe8a514d3f3 100644 --- a/torch/csrc/jit/codegen/cuda/transform_replay.cpp +++ b/torch/csrc/jit/codegen/cuda/transform_replay.cpp @@ -843,9 +843,12 @@ void TransformPropagatorBase::run() { } // If producer has changed, mark we should propagate auto replayed = replayPasC(producer_tv, tv); + std::cout << "replay " << producer_tv << " as " << tv; if (replayed) { + std::cout << " replayed" << std::endl; propagation.pushBack(producer_tv); } + std::cout << std::endl; } } } From 9f421d23df9fcc9fa598b69520529e2163248c35 Mon Sep 17 00:00:00 2001 From: Xiang Gao Date: Thu, 2 Jun 2022 12:11:10 -0700 Subject: [PATCH 004/100] cleanup --- torch/csrc/jit/codegen/cuda/consume_at.cpp | 2 -- torch/csrc/jit/codegen/cuda/transform_replay.cpp | 3 --- 2 files changed, 5 deletions(-) diff --git a/torch/csrc/jit/codegen/cuda/consume_at.cpp b/torch/csrc/jit/codegen/cuda/consume_at.cpp index fd21c9ac071e7..b2c3f6c1296fb 100644 --- a/torch/csrc/jit/codegen/cuda/consume_at.cpp +++ b/torch/csrc/jit/codegen/cuda/consume_at.cpp @@ -93,7 +93,6 @@ c10::optional ConsumeAt::retrieveReplayedPos(TensorView* tv) { pos, " but max position that's allowed is ", max_pos); - std::cout << "retrieveReplayedPos: tensor=" << tv << ", max_pos=" << max_pos << ", pos=" << pos << std::endl; return pos; } @@ -296,7 +295,6 @@ void 
ConsumeAt::computeMaxProducerPos() { bool ConsumeAt::shouldPropagate(TensorView* tv) { if (consume.count(tv)) { - std::cout << "visiting " << tv << std::endl; unvisited.erase(tv); return true; } diff --git a/torch/csrc/jit/codegen/cuda/transform_replay.cpp b/torch/csrc/jit/codegen/cuda/transform_replay.cpp index 6cfe8a514d3f3..17929af0a1827 100644 --- a/torch/csrc/jit/codegen/cuda/transform_replay.cpp +++ b/torch/csrc/jit/codegen/cuda/transform_replay.cpp @@ -843,12 +843,9 @@ void TransformPropagatorBase::run() { } // If producer has changed, mark we should propagate auto replayed = replayPasC(producer_tv, tv); - std::cout << "replay " << producer_tv << " as " << tv; if (replayed) { - std::cout << " replayed" << std::endl; propagation.pushBack(producer_tv); } - std::cout << std::endl; } } } From b84d0c4099f9128cb2d412507a452af39b3e4ae0 Mon Sep 17 00:00:00 2001 From: Xiang Gao Date: Thu, 2 Jun 2022 12:12:14 -0700 Subject: [PATCH 005/100] cleanup --- torch/csrc/jit/codegen/cuda/compute_at.cpp | 1 - 1 file changed, 1 deletion(-) diff --git a/torch/csrc/jit/codegen/cuda/compute_at.cpp b/torch/csrc/jit/codegen/cuda/compute_at.cpp index 8a8b853d610f5..6cf25d811bc55 100644 --- a/torch/csrc/jit/codegen/cuda/compute_at.cpp +++ b/torch/csrc/jit/codegen/cuda/compute_at.cpp @@ -370,7 +370,6 @@ std::unordered_set getPropagationSubgraph( TensorView *common_consumer = getCommonConsumer(producer, consumer); if (common_consumer != nullptr) { auto result = getAllTVsBetween(producer, common_consumer); - std::cout << "common_consumer != nullptr; returning " << ir_utils::toString(result) << std::endl; return result; } auto result_vals = DependencyCheck::getAllDependentVals({producer}); From e8545a7640268e48acefcc79a63c4b407dae1e93 Mon Sep 17 00:00:00 2001 From: Xiang Gao Date: Thu, 2 Jun 2022 12:24:10 -0700 Subject: [PATCH 006/100] fix --- torch/csrc/jit/codegen/cuda/transform_replay.cpp | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git 
a/torch/csrc/jit/codegen/cuda/transform_replay.cpp b/torch/csrc/jit/codegen/cuda/transform_replay.cpp index 17929af0a1827..79d9a6094aa83 100644 --- a/torch/csrc/jit/codegen/cuda/transform_replay.cpp +++ b/torch/csrc/jit/codegen/cuda/transform_replay.cpp @@ -741,9 +741,10 @@ bool TransformPropagatorBase::replayPasC( }); auto producer_pos_opt = retrieveReplayedPos(producer_tv); - if (producer_pos_opt.has_value()) { - if (n_transformed_root_dims < n_replayed_root_dims.at(producer_tv) || - (n_transformed_root_dims == n_replayed_root_dims.at(producer_tv) && + auto replayed_root_dims_it = n_replayed_root_dims.find(producer_tv); + if (producer_pos_opt.has_value() && replayed_root_dims_it != n_replayed_root_dims.end()) { + if (n_transformed_root_dims < replayed_root_dims_it->second || + (n_transformed_root_dims == replayed_root_dims_it->second && replayed_producer.second <= producer_pos_opt.value())) { return false; // NOLINT(clang-analyzer-cplusplus.NewDeleteLeaks) } From a441a90ae0a9fc5ab415e300ce7fce87ebf371cf Mon Sep 17 00:00:00 2001 From: Xiang Gao Date: Thu, 2 Jun 2022 12:39:16 -0700 Subject: [PATCH 007/100] fix --- torch/csrc/jit/codegen/cuda/transform_replay.cpp | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/torch/csrc/jit/codegen/cuda/transform_replay.cpp b/torch/csrc/jit/codegen/cuda/transform_replay.cpp index 79d9a6094aa83..bdfd816c9a908 100644 --- a/torch/csrc/jit/codegen/cuda/transform_replay.cpp +++ b/torch/csrc/jit/codegen/cuda/transform_replay.cpp @@ -793,9 +793,10 @@ bool TransformPropagatorBase::replayCasP( }); auto consumer_pos_opt = retrieveReplayedPos(consumer_tv); - if (consumer_pos_opt.has_value()) { - if (n_transformed_root_dims < n_replayed_root_dims.at(consumer_tv) || - (n_transformed_root_dims == n_replayed_root_dims.at(consumer_tv) && + auto replayed_root_dims_it = n_replayed_root_dims.find(consumer_tv); + if (consumer_pos_opt.has_value() && replayed_root_dims_it != n_replayed_root_dims.end()) { + if 
(n_transformed_root_dims < replayed_root_dims_it->second || + (n_transformed_root_dims == replayed_root_dims_it->second && replayed_consumer.second <= consumer_pos_opt.value())) { return false; // NOLINT(clang-analyzer-cplusplus.NewDeleteLeaks) } From 5c7e8463a6da9950debb01762a15a9b84eea3ea0 Mon Sep 17 00:00:00 2001 From: Xiang Gao Date: Thu, 2 Jun 2022 13:35:30 -0700 Subject: [PATCH 008/100] cleanup --- torch/csrc/jit/codegen/cuda/consume_at.cpp | 12 +++------ torch/csrc/jit/codegen/cuda/consume_at.h | 2 +- .../jit/codegen/cuda/transform_replay.cpp | 26 +++++++++---------- .../csrc/jit/codegen/cuda/transform_replay.h | 4 +-- 4 files changed, 19 insertions(+), 25 deletions(-) diff --git a/torch/csrc/jit/codegen/cuda/consume_at.cpp b/torch/csrc/jit/codegen/cuda/consume_at.cpp index b2c3f6c1296fb..d9dd823631458 100644 --- a/torch/csrc/jit/codegen/cuda/consume_at.cpp +++ b/torch/csrc/jit/codegen/cuda/consume_at.cpp @@ -62,21 +62,17 @@ size_t ConsumeAt::getReplayablePos(TensorView* tv) { return std::distance(c_dom.begin(), vector_dim_it); } -c10::optional ConsumeAt::retrieveReplayedPos(TensorView* tv) { +size_t ConsumeAt::retrieveReplayedPos(TensorView* tv) { size_t max_pos = getReplayablePos(tv); - size_t pos; + size_t pos = 0; if (consume.count(tv)) { - if (!tv->hasComputeAt()) { - return c10::nullopt; - } pos = tv->getComputeAtPosition(); } else { auto it = replayed_pos.find(tv); - if (it == replayed_pos.end()) { - return c10::nullopt; + if (it != replayed_pos.end()) { + pos = it->second; } - pos = it->second; } if (mode == ComputeAtMode::BestEffort) { diff --git a/torch/csrc/jit/codegen/cuda/consume_at.h b/torch/csrc/jit/codegen/cuda/consume_at.h index 456b12b467245..904fb4fc37ada 100644 --- a/torch/csrc/jit/codegen/cuda/consume_at.h +++ b/torch/csrc/jit/codegen/cuda/consume_at.h @@ -24,7 +24,7 @@ class ConsumeAt : public TransformPropagatorBase { virtual bool shouldPropagate(TensorView* tv) override; virtual void recordReplayedPos(TensorView* tv, size_t pos) 
override; - virtual c10::optional retrieveReplayedPos(TensorView* tv) override; + virtual size_t retrieveReplayedPos(TensorView* tv) override; ConsumeAt( std::unordered_set consume, diff --git a/torch/csrc/jit/codegen/cuda/transform_replay.cpp b/torch/csrc/jit/codegen/cuda/transform_replay.cpp index bdfd816c9a908..d57484f643dbd 100644 --- a/torch/csrc/jit/codegen/cuda/transform_replay.cpp +++ b/torch/csrc/jit/codegen/cuda/transform_replay.cpp @@ -712,11 +712,10 @@ bool TransformPropagatorBase::replayPasC( return false; } - auto consumer_pos_opt = retrieveReplayedPos(consumer_tv); - if (!consumer_pos_opt.has_value()) { + auto consumer_pos = retrieveReplayedPos(consumer_tv); + if (consumer_pos == 0) { return false; } - auto consumer_pos = consumer_pos_opt.value(); auto pairwiseMap = PairwiseRootDomainMap(producer_tv, consumer_tv); auto replayed_producer = TransformReplay::replayPasC( @@ -740,12 +739,12 @@ bool TransformPropagatorBase::replayPasC( return dep_vals_set.find(root_id) != dep_vals_set.end(); }); - auto producer_pos_opt = retrieveReplayedPos(producer_tv); + auto producer_pos = retrieveReplayedPos(producer_tv); auto replayed_root_dims_it = n_replayed_root_dims.find(producer_tv); - if (producer_pos_opt.has_value() && replayed_root_dims_it != n_replayed_root_dims.end()) { + if (producer_pos > 0 && replayed_root_dims_it != n_replayed_root_dims.end()) { if (n_transformed_root_dims < replayed_root_dims_it->second || (n_transformed_root_dims == replayed_root_dims_it->second && - replayed_producer.second <= producer_pos_opt.value())) { + replayed_producer.second <= producer_pos)) { return false; // NOLINT(clang-analyzer-cplusplus.NewDeleteLeaks) } } @@ -764,11 +763,10 @@ bool TransformPropagatorBase::replayCasP( return false; } - auto producer_pos_opt = retrieveReplayedPos(producer_tv); - if (!producer_pos_opt.has_value()) { + auto producer_pos = retrieveReplayedPos(producer_tv); + if (producer_pos == 0) { return false; } - auto producer_pos = 
producer_pos_opt.value(); auto pairwiseMap = PairwiseRootDomainMap(producer_tv, consumer_tv); auto replayed_consumer = TransformReplay::replayCasP( @@ -792,12 +790,12 @@ bool TransformPropagatorBase::replayCasP( return dep_vals_set.find(root_id) != dep_vals_set.end(); }); - auto consumer_pos_opt = retrieveReplayedPos(consumer_tv); + auto consumer_pos = retrieveReplayedPos(consumer_tv); auto replayed_root_dims_it = n_replayed_root_dims.find(consumer_tv); - if (consumer_pos_opt.has_value() && replayed_root_dims_it != n_replayed_root_dims.end()) { + if (consumer_pos > 0 && replayed_root_dims_it != n_replayed_root_dims.end()) { if (n_transformed_root_dims < replayed_root_dims_it->second || (n_transformed_root_dims == replayed_root_dims_it->second && - replayed_consumer.second <= consumer_pos_opt.value())) { + replayed_consumer.second <= consumer_pos)) { return false; // NOLINT(clang-analyzer-cplusplus.NewDeleteLeaks) } } @@ -864,12 +862,12 @@ void TransformPropagator::recordReplayedPos(TensorView *tv, size_t pos) { replayed_pos[tv] = pos; } -c10::optional TransformPropagator::retrieveReplayedPos(TensorView *tv) { +size_t TransformPropagator::retrieveReplayedPos(TensorView *tv) { auto it = replayed_pos.find(tv); if (it != replayed_pos.end()) { return it->second; } - return c10::nullopt; + return 0; } } // namespace cuda diff --git a/torch/csrc/jit/codegen/cuda/transform_replay.h b/torch/csrc/jit/codegen/cuda/transform_replay.h index 370215280200b..ce796eb85cbff 100644 --- a/torch/csrc/jit/codegen/cuda/transform_replay.h +++ b/torch/csrc/jit/codegen/cuda/transform_replay.h @@ -186,7 +186,7 @@ class TORCH_CUDA_CU_API TransformPropagatorBase { void run(); virtual void recordReplayedPos(TensorView *tv, size_t pos) = 0; - virtual c10::optional retrieveReplayedPos(TensorView *tv) = 0; + virtual size_t retrieveReplayedPos(TensorView *tv) = 0; virtual bool shouldPropagate(TensorView* tv) { return true; } }; @@ -218,7 +218,7 @@ class TORCH_CUDA_CU_API TransformPropagator : 
public TransformPropagatorBase { TransformPropagator(TensorView* from, size_t starting_pos); virtual void recordReplayedPos(TensorView *tv, size_t pos) override; - virtual c10::optional retrieveReplayedPos(TensorView *tv) override; + virtual size_t retrieveReplayedPos(TensorView *tv) override; public: static void from(TensorView* tv); From 938de6f27e9a8ce96712cc662389163d9aa159e0 Mon Sep 17 00:00:00 2001 From: Xiang Gao Date: Thu, 2 Jun 2022 14:43:12 -0700 Subject: [PATCH 009/100] save --- torch/csrc/jit/codegen/cuda/compute_at.cpp | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/torch/csrc/jit/codegen/cuda/compute_at.cpp b/torch/csrc/jit/codegen/cuda/compute_at.cpp index 6cf25d811bc55..5152b82d10d62 100644 --- a/torch/csrc/jit/codegen/cuda/compute_at.cpp +++ b/torch/csrc/jit/codegen/cuda/compute_at.cpp @@ -374,7 +374,13 @@ std::unordered_set getPropagationSubgraph( } auto result_vals = DependencyCheck::getAllDependentVals({producer}); auto result_tvs = ir_utils::filterByType(result_vals); - return {result_tvs.begin(), result_tvs.end()}; + std::unordered_set result; + std::copy_if( + result_tvs.begin(), + result_tvs.end(), + std::inserter(result, result.begin()), + [](TensorView* tv) { return !tv->isFusionOutput(); }); + return result; } } // namespace From fd161987f60e34f14ef9862bcf896286a5c13561 Mon Sep 17 00:00:00 2001 From: Xiang Gao Date: Thu, 2 Jun 2022 15:48:26 -0700 Subject: [PATCH 010/100] fix --- torch/csrc/jit/codegen/cuda/compute_at.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/torch/csrc/jit/codegen/cuda/compute_at.cpp b/torch/csrc/jit/codegen/cuda/compute_at.cpp index 5152b82d10d62..78ed7cb062b09 100644 --- a/torch/csrc/jit/codegen/cuda/compute_at.cpp +++ b/torch/csrc/jit/codegen/cuda/compute_at.cpp @@ -373,6 +373,7 @@ std::unordered_set getPropagationSubgraph( return result; } auto result_vals = DependencyCheck::getAllDependentVals({producer}); + result_vals.emplace(producer); auto result_tvs = 
ir_utils::filterByType(result_vals); std::unordered_set result; std::copy_if( From af1a560631ca9719423a195188871f3e593c5da9 Mon Sep 17 00:00:00 2001 From: Xiang Gao Date: Thu, 2 Jun 2022 16:14:50 -0700 Subject: [PATCH 011/100] fix --- torch/csrc/jit/codegen/cuda/compute_at.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch/csrc/jit/codegen/cuda/compute_at.cpp b/torch/csrc/jit/codegen/cuda/compute_at.cpp index 78ed7cb062b09..9e4ed34b06913 100644 --- a/torch/csrc/jit/codegen/cuda/compute_at.cpp +++ b/torch/csrc/jit/codegen/cuda/compute_at.cpp @@ -380,7 +380,7 @@ std::unordered_set getPropagationSubgraph( result_tvs.begin(), result_tvs.end(), std::inserter(result, result.begin()), - [](TensorView* tv) { return !tv->isFusionOutput(); }); + [](TensorView* tv) { return !tv->uses().empty(); }); return result; } From 62852ff89796321e04ddd800cd2d666e4905763d Mon Sep 17 00:00:00 2001 From: Xiang Gao Date: Thu, 2 Jun 2022 17:24:37 -0700 Subject: [PATCH 012/100] note --- torch/csrc/jit/codegen/cuda/consume_at.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/torch/csrc/jit/codegen/cuda/consume_at.cpp b/torch/csrc/jit/codegen/cuda/consume_at.cpp index d9dd823631458..14b80a2adc784 100644 --- a/torch/csrc/jit/codegen/cuda/consume_at.cpp +++ b/torch/csrc/jit/codegen/cuda/consume_at.cpp @@ -47,6 +47,7 @@ void ConsumeAt::recordReplayedPos(TensorView* tv, size_t pos) { } } +// TODO: this is definitely wrong. We need to split PasC and CasP case size_t ConsumeAt::getReplayablePos(TensorView* tv) { // TODO: ComputeAt has something about unmappable dims // should I do the same? 
From fd9259d5be63d56b1785defb57fbecfbe763b630 Mon Sep 17 00:00:00 2001 From: Xiang Gao Date: Thu, 2 Jun 2022 18:21:07 -0700 Subject: [PATCH 013/100] split PasC and CasP --- torch/csrc/jit/codegen/cuda/consume_at.cpp | 264 +++++++++--------- torch/csrc/jit/codegen/cuda/consume_at.h | 8 +- .../jit/codegen/cuda/transform_replay.cpp | 12 +- .../csrc/jit/codegen/cuda/transform_replay.h | 4 + 4 files changed, 157 insertions(+), 131 deletions(-) diff --git a/torch/csrc/jit/codegen/cuda/consume_at.cpp b/torch/csrc/jit/codegen/cuda/consume_at.cpp index 14b80a2adc784..39822992bf3ec 100644 --- a/torch/csrc/jit/codegen/cuda/consume_at.cpp +++ b/torch/csrc/jit/codegen/cuda/consume_at.cpp @@ -47,11 +47,20 @@ void ConsumeAt::recordReplayedPos(TensorView* tv, size_t pos) { } } -// TODO: this is definitely wrong. We need to split PasC and CasP case -size_t ConsumeAt::getReplayablePos(TensorView* tv) { - // TODO: ComputeAt has something about unmappable dims - // should I do the same? - auto c_dom = tv->domain()->domain(); +// Return the max position in consumer that producer can be inlined to +// Cannot inline: +// Reduction dimensions in producer +// Block broadcast dimensions in producer +// Vectorized dimensions in producer or consumer +// Unrolled dimensions in producer or consumer +// Dimensions derived from root dimensions that exist in both but are +// unmappable +size_t ConsumeAt::getReplayablePosPasC( + TensorView* producer, + TensorView* consumer) { + // Check if any consumer dimensions are marked as vectorize as producer can + // not be inlined to vectorized dimensions in consumer. 
+ auto c_dom = consumer->domain()->domain(); auto vector_dim_it = std::find_if(c_dom.begin(), c_dom.end(), [this](IterDomain* id) { return isParallelTypeVectorize(id->getParallelType()) || @@ -60,21 +69,112 @@ size_t ConsumeAt::getReplayablePos(TensorView* tv) { id->getParallelType() == ParallelType::Unroll); }); - return std::distance(c_dom.begin(), vector_dim_it); + // Limit max position based on vectorized dims in consumer. + auto max_consumer_pos = std::distance(c_dom.begin(), vector_dim_it); + + auto pairwise_root_map = PairwiseRootDomainMap(producer, consumer); + auto c2p_root_map = + PairwiseRootDomainMap(producer, consumer) + .mapConsumerToProducer(consumer->domain(), producer->domain()); + + auto replay_PasC = + BestEffortReplay::replayPasC(producer, consumer, -1, + pairwise_root_map); + + // Look for id's that map to a consumer id that's vectorized + auto c2p_replay_map = replay_PasC.getReplay(); + + for (size_t consumer_pos = max_consumer_pos; consumer_pos > 0; + consumer_pos--) { + auto map_it = c2p_replay_map.find(consumer->axis((int)consumer_pos - 1)); + if (map_it != c2p_replay_map.end()) { + auto p_id = map_it->second; + // If we find a consumer dim that maps to a producer dim that's + // vectorized or unrolled limit max compute at by it. 
+ if (isParallelTypeVectorize(p_id->getParallelType()) || + ((mode == ComputeAtMode::BestEffort || + mode == ComputeAtMode::MostInlined) && + p_id->getParallelType() == ParallelType::Unroll)) { + max_consumer_pos = consumer_pos - 1; + } + } + } + + return max_consumer_pos; } -size_t ConsumeAt::retrieveReplayedPos(TensorView* tv) { - size_t max_pos = getReplayablePos(tv); - size_t pos = 0; +// Return the max position in producer that can be inlined to consumer +// Cannot inline: +// Reduction dimensions in producer +// Vectorized dimensions in producer or consumer +// Unrolled dimensions in producer or consumer +// Dimensions derived from root dimensions that exist in both but are +// unmappable +size_t ConsumeAt::getReplayablePosCasP( + TensorView* consumer, + TensorView* producer) { + auto p_dom = producer->domain()->domain(); + auto first_reduction = + std::find_if(p_dom.begin(), p_dom.end(), [](IterDomain* id) { + return id->isReduction(); + }); + auto first_vectorized_axis = + std::find_if(p_dom.begin(), first_reduction, [this](IterDomain* id) { + return isParallelTypeVectorize(id->getParallelType()) || + ((mode == ComputeAtMode::BestEffort || + mode == ComputeAtMode::MostInlined) && + id->getParallelType() == ParallelType::Unroll); + }); + + auto max_producer_pos = std::distance(p_dom.begin(), + first_vectorized_axis); + + auto pairwise_root_map = PairwiseRootDomainMap(producer, consumer); + auto p2c_root_map = pairwise_root_map.mapProducerToConsumer( + producer->domain(), consumer->domain()); + + auto replay_CasP = + BestEffortReplay::replayCasP(consumer, producer, -1, + pairwise_root_map); + + // Look for id's that map to a consumer id that's vectorized + auto p2c_replay_map = replay_CasP.getReplay(); + + for (size_t producer_pos = max_producer_pos; producer_pos > 0; + producer_pos--) { + auto map_it = p2c_replay_map.find(producer->axis((int)producer_pos - 1)); + if (map_it != p2c_replay_map.end()) { + auto c_id = map_it->second; + // If we find a producer 
dim that maps to a consumer vectorized or + // unrolled dim, limit max compute at by it + if (isParallelTypeVectorize(c_id->getParallelType()) || + ((mode == ComputeAtMode::BestEffort || + mode == ComputeAtMode::MostInlined) && + c_id->getParallelType() == ParallelType::Unroll)) { + max_producer_pos = producer_pos - 1; + } + } + } + + return max_producer_pos; +} + +size_t ConsumeAt::retrieveReplayedPos(TensorView* tv) { if (consume.count(tv)) { - pos = tv->getComputeAtPosition(); + return tv->getComputeAtPosition(); } else { auto it = replayed_pos.find(tv); if (it != replayed_pos.end()) { - pos = it->second; + return it->second; } } + return 0; +} + +size_t ConsumeAt::getReplayPosPasC(TensorView *producer, TensorView *consumer) { + size_t max_pos = getReplayablePosPasC(producer, consumer); + size_t pos = retrieveReplayedPos(consumer); if (mode == ComputeAtMode::BestEffort) { return std::min(pos, max_pos); @@ -84,127 +184,39 @@ size_t ConsumeAt::retrieveReplayedPos(TensorView* tv) { TORCH_INTERNAL_ASSERT( pos <= max_pos, - "Invalid compute at position detected in ConsumeAt when trying to replay: ", - tv, - ", tried to do this at position: ", + "Invalid compute at position detected in compute at when trying to replay producer: ", + producer, + " as consumer: ", + consumer, + " tried to do this at position: ", pos, " but max position that's allowed is ", max_pos); return pos; } -// // Return the max position in consumer that producer can be inlined to -// // Cannot inline: -// // Reduction dimensions in producer -// // Block broadcast dimensions in producer -// // Vectorized dimensions in producer or consumer -// // Unrolled dimensions in producer or consumer -// // Dimensions derived from root dimensions that exist in both but are -// // unmappable -// size_t ConsumeAt::getReplayablePosPasC( -// TensorView* producer, -// TensorView* consumer) { -// // Check if any consumer dimensions are marked as vectorize as producer can -// // not be inlined to vectorized 
dimensions in consumer. -// auto c_dom = consumer->domain()->domain(); -// auto vector_dim_it = -// std::find_if(c_dom.begin(), c_dom.end(), [this](IterDomain* id) { -// return isParallelTypeVectorize(id->getParallelType()) || -// ((mode == ComputeAtMode::BestEffort || -// mode == ComputeAtMode::MostInlined) && -// id->getParallelType() == ParallelType::Unroll); -// }); - -// // Limit max position based on vectorized dims in consumer. -// auto max_consumer_pos = std::distance(c_dom.begin(), vector_dim_it); - -// auto pairwise_root_map = PairwiseRootDomainMap(producer, consumer); -// auto c2p_root_map = -// PairwiseRootDomainMap(producer, consumer) -// .mapConsumerToProducer(consumer->domain(), producer->domain()); - -// auto replay_PasC = -// BestEffortReplay::replayPasC(producer, consumer, -1, -// pairwise_root_map); - -// // Look for id's that map to a consumer id that's vectorized -// auto c2p_replay_map = replay_PasC.getReplay(); - -// for (size_t consumer_pos = max_consumer_pos; consumer_pos > 0; -// consumer_pos--) { -// auto map_it = c2p_replay_map.find(consumer->axis((int)consumer_pos - 1)); -// if (map_it != c2p_replay_map.end()) { -// auto p_id = map_it->second; -// // If we find a consumer dim that maps to a producer dim that's -// // vectorized or unrolled limit max compute at by it. 
-// if (isParallelTypeVectorize(p_id->getParallelType()) || -// ((mode == ComputeAtMode::BestEffort || -// mode == ComputeAtMode::MostInlined) && -// p_id->getParallelType() == ParallelType::Unroll)) { -// max_consumer_pos = consumer_pos - 1; -// } -// } -// } - -// return max_consumer_pos; -// } - -// // Return the max position in producer that can be inlined to consumer -// // Cannot inline: -// // Reduction dimensions in producer -// // Vectorized dimensions in producer or consumer -// // Unrolled dimensions in producer or consumer -// // Dimensions derived from root dimensions that exist in both but are -// // unmappable -// size_t ConsumeAt::getReplayablePosCasP( -// TensorView* consumer, -// TensorView* producer) { -// auto p_dom = producer->domain()->domain(); -// auto first_reduction = -// std::find_if(p_dom.begin(), p_dom.end(), [](IterDomain* id) { -// return id->isReduction(); -// }); - -// auto first_vectorized_axis = -// std::find_if(p_dom.begin(), first_reduction, [this](IterDomain* id) { -// return isParallelTypeVectorize(id->getParallelType()) || -// ((mode == ComputeAtMode::BestEffort || -// mode == ComputeAtMode::MostInlined) && -// id->getParallelType() == ParallelType::Unroll); -// }); - -// auto max_producer_pos = std::distance(p_dom.begin(), -// first_vectorized_axis); - -// auto pairwise_root_map = PairwiseRootDomainMap(producer, consumer); -// auto p2c_root_map = pairwise_root_map.mapProducerToConsumer( -// producer->domain(), consumer->domain()); - -// auto replay_CasP = -// BestEffortReplay::replayCasP(consumer, producer, -1, -// pairwise_root_map); - -// // Look for id's that map to a consumer id that's vectorized -// auto p2c_replay_map = replay_CasP.getReplay(); - -// for (size_t producer_pos = max_producer_pos; producer_pos > 0; -// producer_pos--) { -// auto map_it = p2c_replay_map.find(producer->axis((int)producer_pos - 1)); -// if (map_it != p2c_replay_map.end()) { -// auto c_id = map_it->second; -// // If we find a producer dim 
that maps to a consumer vectorized or -// // unrolled dim, limit max compute at by it -// if (isParallelTypeVectorize(c_id->getParallelType()) || -// ((mode == ComputeAtMode::BestEffort || -// mode == ComputeAtMode::MostInlined) && -// c_id->getParallelType() == ParallelType::Unroll)) { -// max_producer_pos = producer_pos - 1; -// } -// } -// } - -// return max_producer_pos; -// } +size_t ConsumeAt::getReplayPosCasP(TensorView *consumer, TensorView *producer) { + size_t max_pos = getReplayablePosCasP(consumer, producer); + size_t pos = retrieveReplayedPos(producer); + + if (mode == ComputeAtMode::BestEffort) { + return std::min(pos, max_pos); + } else if (mode == ComputeAtMode::MostInlined) { + return max_pos; + } + + TORCH_INTERNAL_ASSERT( + pos <= max_pos, + "Invalid compute at position detected in compute at when trying to replay consumer: ", + consumer, + " as producer: ", + producer, + " tried to do this at position: ", + pos, + " but max position that's allowed is ", + max_pos); + return pos; +} void ConsumeAt::hoistInnermostBroadcast() { for (auto tv : consume) { diff --git a/torch/csrc/jit/codegen/cuda/consume_at.h b/torch/csrc/jit/codegen/cuda/consume_at.h index 904fb4fc37ada..2afafc84d8500 100644 --- a/torch/csrc/jit/codegen/cuda/consume_at.h +++ b/torch/csrc/jit/codegen/cuda/consume_at.h @@ -11,9 +11,9 @@ namespace fuser { namespace cuda { class ConsumeAt : public TransformPropagatorBase { -// size_t getReplayablePosPasC(TensorView* producer, TensorView* consumer); -// size_t getReplayablePosCasP(TensorView* consumer, TensorView* producer); - size_t getReplayablePos(TensorView* tv); + size_t getReplayablePosPasC(TensorView* producer, TensorView* consumer); + size_t getReplayablePosCasP(TensorView* consumer, TensorView* producer); +// size_t getReplayablePos(TensorView* tv); // TODO: delete void hoistInnermostBroadcast(); void computeMaxProducerPos(); @@ -25,6 +25,8 @@ class ConsumeAt : public TransformPropagatorBase { virtual bool 
shouldPropagate(TensorView* tv) override; virtual void recordReplayedPos(TensorView* tv, size_t pos) override; virtual size_t retrieveReplayedPos(TensorView* tv) override; + virtual size_t getReplayPosPasC(TensorView *producer, TensorView *consumer) override; + virtual size_t getReplayPosCasP(TensorView *consumer, TensorView *producer) override; ConsumeAt( std::unordered_set consume, diff --git a/torch/csrc/jit/codegen/cuda/transform_replay.cpp b/torch/csrc/jit/codegen/cuda/transform_replay.cpp index d57484f643dbd..905c7329514af 100644 --- a/torch/csrc/jit/codegen/cuda/transform_replay.cpp +++ b/torch/csrc/jit/codegen/cuda/transform_replay.cpp @@ -712,7 +712,7 @@ bool TransformPropagatorBase::replayPasC( return false; } - auto consumer_pos = retrieveReplayedPos(consumer_tv); + auto consumer_pos = getReplayPosPasC(producer_tv, consumer_tv); if (consumer_pos == 0) { return false; } @@ -763,7 +763,7 @@ bool TransformPropagatorBase::replayCasP( return false; } - auto producer_pos = retrieveReplayedPos(producer_tv); + auto producer_pos = getReplayPosCasP(consumer_tv, producer_tv); if (producer_pos == 0) { return false; } @@ -870,6 +870,14 @@ size_t TransformPropagator::retrieveReplayedPos(TensorView *tv) { return 0; } +size_t TransformPropagator::getReplayPosPasC(TensorView *producer, TensorView *consumer) { + return retrieveReplayedPos(consumer); +} + +size_t TransformPropagator::getReplayPosCasP(TensorView *consumer, TensorView *producer) { + return retrieveReplayedPos(producer); +} + } // namespace cuda } // namespace fuser } // namespace jit diff --git a/torch/csrc/jit/codegen/cuda/transform_replay.h b/torch/csrc/jit/codegen/cuda/transform_replay.h index ce796eb85cbff..6e94d4fb0ce15 100644 --- a/torch/csrc/jit/codegen/cuda/transform_replay.h +++ b/torch/csrc/jit/codegen/cuda/transform_replay.h @@ -187,6 +187,8 @@ class TORCH_CUDA_CU_API TransformPropagatorBase { virtual void recordReplayedPos(TensorView *tv, size_t pos) = 0; virtual size_t 
retrieveReplayedPos(TensorView *tv) = 0; + virtual size_t getReplayPosPasC(TensorView *producer, TensorView *consumer) = 0; + virtual size_t getReplayPosCasP(TensorView *consumer, TensorView *producer) = 0; virtual bool shouldPropagate(TensorView* tv) { return true; } }; @@ -219,6 +221,8 @@ class TORCH_CUDA_CU_API TransformPropagator : public TransformPropagatorBase { virtual void recordReplayedPos(TensorView *tv, size_t pos) override; virtual size_t retrieveReplayedPos(TensorView *tv) override; + virtual size_t getReplayPosPasC(TensorView *producer, TensorView *consumer) override; + virtual size_t getReplayPosCasP(TensorView *consumer, TensorView *producer) override; public: static void from(TensorView* tv); From 2af1b1a8a1db2ac6b7f190febc9631edbdc2b8aa Mon Sep 17 00:00:00 2001 From: Xiang Gao Date: Thu, 2 Jun 2022 21:51:22 -0700 Subject: [PATCH 014/100] fix --- torch/csrc/jit/codegen/cuda/consume_at.cpp | 32 ++++++++++++++++------ torch/csrc/jit/codegen/cuda/consume_at.h | 2 +- 2 files changed, 25 insertions(+), 9 deletions(-) diff --git a/torch/csrc/jit/codegen/cuda/consume_at.cpp b/torch/csrc/jit/codegen/cuda/consume_at.cpp index 39822992bf3ec..ad5a4a4c6a0be 100644 --- a/torch/csrc/jit/codegen/cuda/consume_at.cpp +++ b/torch/csrc/jit/codegen/cuda/consume_at.cpp @@ -41,12 +41,31 @@ void ConsumeAt::consumeAllAt( void ConsumeAt::recordReplayedPos(TensorView* tv, size_t pos) { if (consume.count(tv)) { + pos = std::min(getMaxComputeAtPos(tv), pos); tv->setComputeAt(pos); } else { replayed_pos[tv] = pos; } } +size_t ConsumeAt::getMaxComputeAtPos(TensorView* tv) { + auto dom = tv->domain()->domain(); + auto first_reduction = + std::find_if(dom.begin(), dom.end(), [](IterDomain* id) { + return id->isReduction(); + }); + + auto first_vectorized_axis = + std::find_if(dom.begin(), first_reduction, [this](IterDomain* id) { + return isParallelTypeVectorize(id->getParallelType()) || + ((mode == ComputeAtMode::BestEffort || + mode == ComputeAtMode::MostInlined) && + 
id->getParallelType() == ParallelType::Unroll); + }); + + return std::distance(dom.begin(), first_vectorized_axis); +} + // Return the max position in consumer that producer can be inlined to // Cannot inline: // Reduction dimensions in producer @@ -78,8 +97,7 @@ size_t ConsumeAt::getReplayablePosPasC( .mapConsumerToProducer(consumer->domain(), producer->domain()); auto replay_PasC = - BestEffortReplay::replayPasC(producer, consumer, -1, - pairwise_root_map); + BestEffortReplay::replayPasC(producer, consumer, -1, pairwise_root_map); // Look for id's that map to a consumer id that's vectorized auto c2p_replay_map = replay_PasC.getReplay(); @@ -127,16 +145,14 @@ size_t ConsumeAt::getReplayablePosCasP( id->getParallelType() == ParallelType::Unroll); }); - auto max_producer_pos = std::distance(p_dom.begin(), - first_vectorized_axis); + auto max_producer_pos = std::distance(p_dom.begin(), first_vectorized_axis); auto pairwise_root_map = PairwiseRootDomainMap(producer, consumer); auto p2c_root_map = pairwise_root_map.mapProducerToConsumer( producer->domain(), consumer->domain()); auto replay_CasP = - BestEffortReplay::replayCasP(consumer, producer, -1, - pairwise_root_map); + BestEffortReplay::replayCasP(consumer, producer, -1, pairwise_root_map); // Look for id's that map to a consumer id that's vectorized auto p2c_replay_map = replay_CasP.getReplay(); @@ -172,7 +188,7 @@ size_t ConsumeAt::retrieveReplayedPos(TensorView* tv) { return 0; } -size_t ConsumeAt::getReplayPosPasC(TensorView *producer, TensorView *consumer) { +size_t ConsumeAt::getReplayPosPasC(TensorView* producer, TensorView* consumer) { size_t max_pos = getReplayablePosPasC(producer, consumer); size_t pos = retrieveReplayedPos(consumer); @@ -195,7 +211,7 @@ size_t ConsumeAt::getReplayPosPasC(TensorView *producer, TensorView *consumer) { return pos; } -size_t ConsumeAt::getReplayPosCasP(TensorView *consumer, TensorView *producer) { +size_t ConsumeAt::getReplayPosCasP(TensorView* consumer, TensorView* 
producer) { size_t max_pos = getReplayablePosCasP(consumer, producer); size_t pos = retrieveReplayedPos(producer); diff --git a/torch/csrc/jit/codegen/cuda/consume_at.h b/torch/csrc/jit/codegen/cuda/consume_at.h index 2afafc84d8500..75147ff533868 100644 --- a/torch/csrc/jit/codegen/cuda/consume_at.h +++ b/torch/csrc/jit/codegen/cuda/consume_at.h @@ -13,7 +13,7 @@ namespace cuda { class ConsumeAt : public TransformPropagatorBase { size_t getReplayablePosPasC(TensorView* producer, TensorView* consumer); size_t getReplayablePosCasP(TensorView* consumer, TensorView* producer); -// size_t getReplayablePos(TensorView* tv); // TODO: delete + size_t getMaxComputeAtPos(TensorView* tv); void hoistInnermostBroadcast(); void computeMaxProducerPos(); From 8b4c8a015ff83da644377dc9d4982b340826c996 Mon Sep 17 00:00:00 2001 From: Xiang Gao Date: Thu, 2 Jun 2022 23:10:00 -0700 Subject: [PATCH 015/100] fix --- torch/csrc/jit/codegen/cuda/consume_at.cpp | 22 ++++++++++++++-------- 1 file changed, 14 insertions(+), 8 deletions(-) diff --git a/torch/csrc/jit/codegen/cuda/consume_at.cpp b/torch/csrc/jit/codegen/cuda/consume_at.cpp index ad5a4a4c6a0be..796442efa1e14 100644 --- a/torch/csrc/jit/codegen/cuda/consume_at.cpp +++ b/torch/csrc/jit/codegen/cuda/consume_at.cpp @@ -41,7 +41,11 @@ void ConsumeAt::consumeAllAt( void ConsumeAt::recordReplayedPos(TensorView* tv, size_t pos) { if (consume.count(tv)) { - pos = std::min(getMaxComputeAtPos(tv), pos); + auto max_pos = getMaxComputeAtPos(tv); + if (pos > max_pos) { + replayed_pos[tv] = pos; + pos = max_pos; + } tv->setComputeAt(pos); } else { replayed_pos[tv] = pos; @@ -177,20 +181,21 @@ size_t ConsumeAt::getReplayablePosCasP( } size_t ConsumeAt::retrieveReplayedPos(TensorView* tv) { + size_t pos1 = 0, pos2 = 0; if (consume.count(tv)) { - return tv->getComputeAtPosition(); - } else { - auto it = replayed_pos.find(tv); - if (it != replayed_pos.end()) { - return it->second; - } + pos1 = tv->getComputeAtPosition(); + } + auto it = 
replayed_pos.find(tv); + if (it != replayed_pos.end()) { + pos2 = it->second; } - return 0; + return std::max(pos1, pos2); } size_t ConsumeAt::getReplayPosPasC(TensorView* producer, TensorView* consumer) { size_t max_pos = getReplayablePosPasC(producer, consumer); size_t pos = retrieveReplayedPos(consumer); + // std::cout << "[getReplayPosPasC] producer=" << producer << ", consumer=" << consumer << ", max_pos=" << max_pos << ", pos=" << pos << std::endl; if (mode == ComputeAtMode::BestEffort) { return std::min(pos, max_pos); @@ -214,6 +219,7 @@ size_t ConsumeAt::getReplayPosPasC(TensorView* producer, TensorView* consumer) { size_t ConsumeAt::getReplayPosCasP(TensorView* consumer, TensorView* producer) { size_t max_pos = getReplayablePosCasP(consumer, producer); size_t pos = retrieveReplayedPos(producer); + // std::cout << "[getReplayPosCasP] consumer=" << consumer << ", producer=" << producer << ", max_pos=" << max_pos << ", pos=" << pos << std::endl; if (mode == ComputeAtMode::BestEffort) { return std::min(pos, max_pos); From 89a2a75d63c5ddef95fa6b20136f1fd9fa99e3fc Mon Sep 17 00:00:00 2001 From: Xiang Gao Date: Thu, 2 Jun 2022 23:55:10 -0700 Subject: [PATCH 016/100] don't set ca pos for fusion input --- torch/csrc/jit/codegen/cuda/consume_at.cpp | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/torch/csrc/jit/codegen/cuda/consume_at.cpp b/torch/csrc/jit/codegen/cuda/consume_at.cpp index 796442efa1e14..966243e206c26 100644 --- a/torch/csrc/jit/codegen/cuda/consume_at.cpp +++ b/torch/csrc/jit/codegen/cuda/consume_at.cpp @@ -46,7 +46,9 @@ void ConsumeAt::recordReplayedPos(TensorView* tv, size_t pos) { replayed_pos[tv] = pos; pos = max_pos; } - tv->setComputeAt(pos); + if (!tv->isFusionInput()) { + tv->setComputeAt(pos); + } } else { replayed_pos[tv] = pos; } From b9a923957677261a78ac358d48df3757bc3dcc83 Mon Sep 17 00:00:00 2001 From: Xiang Gao Date: Fri, 3 Jun 2022 00:46:55 -0700 Subject: [PATCH 017/100] fix --- 
torch/csrc/jit/codegen/cuda/consume_at.cpp | 2 ++ 1 file changed, 2 insertions(+) diff --git a/torch/csrc/jit/codegen/cuda/consume_at.cpp b/torch/csrc/jit/codegen/cuda/consume_at.cpp index 966243e206c26..a16ff3417a511 100644 --- a/torch/csrc/jit/codegen/cuda/consume_at.cpp +++ b/torch/csrc/jit/codegen/cuda/consume_at.cpp @@ -48,6 +48,8 @@ void ConsumeAt::recordReplayedPos(TensorView* tv, size_t pos) { } if (!tv->isFusionInput()) { tv->setComputeAt(pos); + } else { + replayed_pos[tv] = pos; } } else { replayed_pos[tv] = pos; From 49d19a4c3c684e7375703ff918384639561ac799 Mon Sep 17 00:00:00 2001 From: Xiang Gao Date: Fri, 3 Jun 2022 01:37:46 -0700 Subject: [PATCH 018/100] siblings --- torch/csrc/jit/codegen/cuda/compute_at.cpp | 811 +-------------------- 1 file changed, 19 insertions(+), 792 deletions(-) diff --git a/torch/csrc/jit/codegen/cuda/compute_at.cpp b/torch/csrc/jit/codegen/cuda/compute_at.cpp index 9e4ed34b06913..c85dbc99a5c9a 100644 --- a/torch/csrc/jit/codegen/cuda/compute_at.cpp +++ b/torch/csrc/jit/codegen/cuda/compute_at.cpp @@ -40,244 +40,6 @@ std::deque> tvChains( return tv_chains; } -// bool validateDomain(TensorView* tv, TensorDomain* new_td) { -// auto first_mismatch = -// BestEffortReplay::findFirstMismatchedID(tv->domain(), new_td); -// return first_mismatch >= (int)tv->getMaxProducerPosition() && -// first_mismatch >= (int)tv->getComputeAtPosition(); -// } - -// // Return the max position in consumer that producer can be inlined to -// // Cannot inline: -// // Reduction dimensions in producer -// // Block broadcast dimensions in producer -// // Vectorized dimensions in producer or consumer -// // Unrolled dimensions in producer or consumer -// // Dimensions derived from root dimensions that exist in both but are -// // unmappable -// unsigned int getReplayablePosPasC( -// TensorView* producer, -// TensorView* consumer, -// const std::unordered_set& unmappable_producer_dims, -// ComputeAtMode mode) { -// // Check if any consumer dimensions are 
marked as vectorize as producer can -// // not be inlined to vectorized dimensions in consumer. -// auto c_dom = consumer->domain()->domain(); -// auto vector_dim_it = -// std::find_if(c_dom.begin(), c_dom.end(), [&mode](IterDomain* id) { -// return isParallelTypeVectorize(id->getParallelType()) || -// ((mode == ComputeAtMode::BestEffort || -// mode == ComputeAtMode::MostInlined) && -// id->getParallelType() == ParallelType::Unroll); -// }); - -// // Limit max position based on vectorized dims in consumer. -// auto max_consumer_pos = std::distance(c_dom.begin(), vector_dim_it); - -// auto pairwise_root_map = PairwiseRootDomainMap(producer, consumer); -// auto c2p_root_map = -// PairwiseRootDomainMap(producer, consumer) -// .mapConsumerToProducer(consumer->domain(), producer->domain()); - -// auto replay_PasC = -// BestEffortReplay::replayPasC(producer, consumer, -1, pairwise_root_map); - -// // Look for id's that map to a consumer id that's vectorized -// auto c2p_replay_map = replay_PasC.getReplay(); - -// for (size_t consumer_pos = max_consumer_pos; consumer_pos > 0; -// consumer_pos--) { -// auto map_it = c2p_replay_map.find(consumer->axis((int)consumer_pos - 1)); -// if (map_it != c2p_replay_map.end()) { -// auto p_id = map_it->second; -// // If we find a consumer dim that maps to a producer dim that's -// // vectorized or unrolled limit max compute at by it. -// if (isParallelTypeVectorize(p_id->getParallelType()) || -// ((mode == ComputeAtMode::BestEffort || -// mode == ComputeAtMode::MostInlined) && -// p_id->getParallelType() == ParallelType::Unroll)) { -// max_consumer_pos = consumer_pos - 1; -// } -// } -// } - -// // Start at max position and work backwards, try to find a location where -// // producer can be inlined. -// for (size_t consumer_pos = max_consumer_pos; consumer_pos > 0; -// consumer_pos--) { -// // Grab all root dimensions of consumer as roots must be used to understand -// // inlining potential. 
-// auto consumer_root_dim_vals = -// IterVisitor::getInputsTo({c_dom.begin(), c_dom.begin() + consumer_pos}); -// // convert to iter domains -// auto consumer_root_dim_ids = -// ir_utils::filterByType(consumer_root_dim_vals); -// // If any root dimensions cannot be mapped to producer we can't inline. If -// // any root dimension -// if (std::any_of( -// consumer_root_dim_ids.begin(), -// consumer_root_dim_ids.end(), -// [&unmappable_producer_dims, &c2p_root_map](IterDomain* c_root_id) { -// auto p_root_id_it = c2p_root_map.find(c_root_id); -// if (p_root_id_it == c2p_root_map.end()) { -// return false; -// } -// auto p_id = p_root_id_it->second; -// return unmappable_producer_dims.find(p_id) != -// unmappable_producer_dims.end(); -// })) { -// continue; -// } -// return consumer_pos; -// } - -// return 0; -// } - -// // Return the max position in producer that can be inlined to consumer -// // Cannot inline: -// // Reduction dimensions in producer -// // Vectorized dimensions in producer or consumer -// // Unrolled dimensions in producer or consumer -// // Dimensions derived from root dimensions that exist in both but are -// // unmappable -// unsigned int getReplayablePosCasP( -// TensorView* consumer, -// TensorView* producer, -// const std::unordered_set& unmappable_producer_dims, -// ComputeAtMode mode) { -// auto p_dom = producer->domain()->domain(); -// auto first_reduction = -// std::find_if(p_dom.begin(), p_dom.end(), [](IterDomain* id) { -// return id->isReduction(); -// }); - -// auto first_vectorized_axis = -// std::find_if(p_dom.begin(), first_reduction, [&mode](IterDomain* id) { -// return isParallelTypeVectorize(id->getParallelType()) || -// ((mode == ComputeAtMode::BestEffort || -// mode == ComputeAtMode::MostInlined) && -// id->getParallelType() == ParallelType::Unroll); -// }); - -// auto max_producer_pos = std::distance(p_dom.begin(), first_vectorized_axis); - -// auto pairwise_root_map = PairwiseRootDomainMap(producer, consumer); -// auto 
p2c_root_map = pairwise_root_map.mapProducerToConsumer( -// producer->domain(), consumer->domain()); - -// auto replay_CasP = -// BestEffortReplay::replayCasP(consumer, producer, -1, pairwise_root_map); - -// // Look for id's that map to a consumer id that's vectorized -// auto p2c_replay_map = replay_CasP.getReplay(); - -// for (size_t producer_pos = max_producer_pos; producer_pos > 0; -// producer_pos--) { -// auto map_it = p2c_replay_map.find(producer->axis((int)producer_pos - 1)); -// if (map_it != p2c_replay_map.end()) { -// auto c_id = map_it->second; -// // If we find a producer dim that maps to a consumer vectorized or -// // unrolled dim, limit max compute at by it -// if (isParallelTypeVectorize(c_id->getParallelType()) || -// ((mode == ComputeAtMode::BestEffort || -// mode == ComputeAtMode::MostInlined) && -// c_id->getParallelType() == ParallelType::Unroll)) { -// max_producer_pos = producer_pos - 1; -// } -// } -// } - -// for (size_t producer_pos = max_producer_pos; producer_pos > 0; -// producer_pos--) { -// auto all_vals = DependencyCheck::getAllValsBetween( -// {producer->getMaybeRFactorDomain().begin(), -// producer->getMaybeRFactorDomain().end()}, -// {p_dom.begin(), p_dom.begin() + producer_pos}); - -// // If any root dims could have mapped to consumer, but don't, then we can't -// // compute at this point -// if (std::any_of( -// producer->getMaybeRFactorDomain().begin(), -// producer->getMaybeRFactorDomain().end(), -// [&unmappable_producer_dims, &all_vals](IterDomain* p_root_id) { -// return std::find(all_vals.begin(), all_vals.end(), p_root_id) != -// all_vals.end() && -// unmappable_producer_dims.find(p_root_id) != -// unmappable_producer_dims.end(); -// })) { -// continue; -// } - -// return producer_pos; -// } -// return 0; -// } - -// unsigned int getInnermostNonBroadcastIdFrom(TensorView* tv) { -// unsigned int ret = tv->getComputeAtPosition(); - -// // Still assuming we only have block broadcast for now. 
-// // This part may change -// while (ret > 0 && tv->axis((int)ret - 1)->isBroadcast()) { -// ret--; -// } - -// return ret; -// } - -// // Try to find the aligned position on consumer's domain corresponding to the -// // compute at position of producer domain. Used in computeAt pass only. No -// // checking on actual producer-consumer relationship. -// unsigned int getConsumerPosAlignedToProducerCA( -// TensorView* consumer, -// TensorView* producer) { -// // Locate consumer's position that aligns with -// // the producer's new compute at axis. We need broadcast axes forwarded so we -// // need to replay PasC as CasP will not forward braodcast dims. For example -// // if we have: -// // T2[ iS22{( 3 * 1 )} ] ca_pos( 1 ) = broadcast( T1[ iS1{3} ] ca_pos( 1 ) -// // produce_pos( 1) ) CasP will have the mapping iS1{3} -> iS2{3} and PasC will -// // have the mapping iS22{( 3 * 1 )} <- iS1{3} We need the latter. Refer to -// // NVFuserTest.FusionComplexBCast1_CUDA - -// auto c2p_map = -// BestEffortReplay::replayPasC( -// producer, -// consumer, -// -1, -// // Compute at root domain may not be valid here, as all -// // producers don't have to be able to map into consumer at -// // max producer position. Since computeAt should be valid -// // and this mechanism is only intended to lower produce -// // position of consumer, we can simply use the pairwise map. -// PairwiseRootDomainMap(producer, consumer)) -// .getReplay(); - -// // Find the innermost position of consumer that has -// // been mapped within the producer ca axis. 
-// unsigned int consumer_pos = consumer->nDims(); -// while (consumer_pos > 0) { -// auto consumer_id = consumer->axis((int)consumer_pos - 1); -// auto p_dom = producer->domain()->domain(); -// if (std::any_of( -// p_dom.begin(), -// p_dom.begin() + producer->getComputeAtPosition(), -// [&consumer_id, &c2p_map](IterDomain* p_id) { -// auto c_id_it = c2p_map.find(consumer_id); -// if (c_id_it != c2p_map.end()) { -// return c_id_it->second == p_id; -// } -// return false; -// })) { -// break; -// } -// consumer_pos--; -// } - -// return consumer_pos; -// } - std::unordered_set getAllTVsBetween( TensorView* producer, TensorView* consumer) { @@ -352,6 +114,23 @@ TensorView * getCommonConsumer( return common_consumer; } +void pullInSiblings(std::unordered_set &s) { + for (auto tv : s) { + auto tvd = tv->definition(); + if (tvd != nullptr) { + auto outs = tvd->outputs(); + auto out_tvs = ir_utils::filterByType(outs); + for (auto sibling_tv : out_tvs) { + if (sibling_tv == tv) { + continue; + } + std::cout << "pulling in " << sibling_tv << std::endl; + s.emplace(sibling_tv); + } + } + } +} + // I am just trying to get the same set of tensors being transformed matching // the previous behavior of ComputeAt. 
The algorithm to compute this set is @@ -370,6 +149,7 @@ std::unordered_set getPropagationSubgraph( TensorView *common_consumer = getCommonConsumer(producer, consumer); if (common_consumer != nullptr) { auto result = getAllTVsBetween(producer, common_consumer); + pullInSiblings(result); return result; } auto result_vals = DependencyCheck::getAllDependentVals({producer}); @@ -381,6 +161,7 @@ std::unordered_set getPropagationSubgraph( result_tvs.end(), std::inserter(result, result.begin()), [](TensorView* tv) { return !tv->uses().empty(); }); + pullInSiblings(result); return result; } @@ -434,560 +215,6 @@ void ComputeAt::runWith( mode); } -// namespace { - -// // Checks if producer and consumer are transformed consistently so that to -// // satisfy the provided compute at position. This means no replay is actually -// // necessary for the compute at requested. If consumer_pos then -// // consumer_or_producer_pos is relative to the consumer and skipReplay returns -// // the associated position in producer. -// // -// // If producer and consumer are not transformed consistently with provided -// // postition, returns -1. 
-// int skipReplay( -// const TensorView* producer, -// const TensorView* consumer, -// int consumer_or_producer_pos, -// bool consumer_pos = true) { -// FUSER_PERF_SCOPE("transform_replay.cpp::skipReplay"); - -// const auto c2p_root_map = -// PairwiseRootDomainMap(producer, consumer) -// .mapConsumerToProducer(consumer->domain(), producer->domain()); - -// // IterDomains in consumer root also in producer root -// std::unordered_set mapped_consumer_roots; -// for (auto entry : c2p_root_map) { -// mapped_consumer_roots.emplace(entry.first); -// } - -// const auto consumer_domain = consumer->domain()->domain(); - -// auto mapped_consumer_domain_ids_vec = DependencyCheck::getAllValsBetween( -// mapped_consumer_roots, {consumer_domain.begin(), consumer_domain.end()}); - -// std::unordered_set mapped_consumer_domain_ids( -// mapped_consumer_domain_ids_vec.begin(), -// mapped_consumer_domain_ids_vec.end()); - -// const auto producer_domain = producer->domain()->domain(); - -// auto it_consumer = consumer_domain.begin(); -// auto it_producer = producer_domain.begin(); - -// auto best_effort_PasC = BestEffortReplay::replayPasC( -// producer, consumer, -1, PairwiseRootDomainMap(producer, consumer)); - -// auto c2p_map = best_effort_PasC.getReplay(); - -// int mismatched_consumer_pos = 0; -// int mismatched_producer_pos = 0; -// while (it_consumer != consumer_domain.end()) { -// auto consumer_id = *it_consumer; -// if (!mapped_consumer_domain_ids.count(consumer_id)) { -// ++it_consumer; -// mismatched_consumer_pos++; -// continue; -// } - -// auto c2p_it = c2p_map.find(consumer_id); -// if (c2p_it == c2p_map.end()) { -// break; -// } - -// if (it_producer == producer_domain.end()) { -// break; -// } - -// auto producer_id = *it_producer; - -// if (c2p_it->second == producer_id) { -// ++mismatched_consumer_pos; -// ++mismatched_producer_pos; -// ++it_consumer; -// ++it_producer; -// if (consumer_pos) { -// if (consumer_or_producer_pos == mismatched_consumer_pos) { -// return 
mismatched_producer_pos; -// } -// } else { -// if (consumer_or_producer_pos == mismatched_producer_pos) { -// return mismatched_consumer_pos; -// } -// } -// } else { -// break; -// } -// } -// return -1; -// } - -// } // namespace - -// // Actually applies transformation -// unsigned int ComputeAt::backwardComputeAt_impl( -// TensorView* producer, -// TensorView* consumer, -// unsigned int consumer_compute_at_pos) { -// FUSER_PERF_SCOPE("backwardComputeAt_impl"); - -// auto max_consumer_compute_at_pos = -// getReplayablePosPasC(producer, consumer, unmappable_dims_, mode_); - -// if (mode_ == ComputeAtMode::BestEffort) { -// consumer_compute_at_pos = -// std::min(consumer_compute_at_pos, max_consumer_compute_at_pos); -// } else if (mode_ == ComputeAtMode::MostInlined) { -// consumer_compute_at_pos = max_consumer_compute_at_pos; -// } else { -// TORCH_INTERNAL_ASSERT( -// consumer_compute_at_pos <= max_consumer_compute_at_pos, -// "Invalid compute at position detected in compute at when trying to replay producer: ", -// producer, -// " as consumer: ", -// consumer, -// " tried to do this at position: ", -// consumer_compute_at_pos, -// " but max position that's allowed is ", -// max_consumer_compute_at_pos); -// } - -// // Short cut if no replay is necessary -// auto maybe_producer_pos = -// skipReplay(producer, consumer, (int)consumer_compute_at_pos, true); -// if (maybe_producer_pos >= 0) { -// if (!producer->isFusionInput()) { -// producer->setComputeAt((unsigned int)maybe_producer_pos); -// } -// consumer->setMaxProducer(consumer_compute_at_pos); -// return (unsigned int)maybe_producer_pos; -// } - -// auto replay_producer_pair = TransformReplay::replayPasC( -// producer, -// consumer, -// (int)consumer_compute_at_pos, -// PairwiseRootDomainMap(producer, consumer)); - -// if (replay_producer_pair.second == 0) { -// return 0; -// } - -// if (replay_producer_pair.second >= producer->getComputeAtPosition()) { -// const TensorDomain* current_domain = 
producer->domain(); -// TensorDomain* new_domain = replay_producer_pair.first; - -// TORCH_INTERNAL_ASSERT( -// validateDomain(producer, new_domain), -// "Tried to set the domain of ", -// producer, -// " to ", -// new_domain, -// " but that would invalidate previously compute at position or max producer position."); - -// producer->setDomain(new_domain); -// if (!producer->isFusionInput()) { -// producer->setComputeAt(replay_producer_pair.second); -// } - -// consumer->setMaxProducer(consumer_compute_at_pos); -// root_map_.setAlias(current_domain, new_domain); -// } - -// return replay_producer_pair.second; -// } - -// // Actually applies transformation, replay consumer based on producer, set -// // compute at of producer, set pass position of consumer, return position -// // relative to consumer -// unsigned int ComputeAt::forwardComputeAt_impl( -// TensorView* producer, -// TensorView* consumer, -// unsigned int producer_compute_at_pos) { -// FUSER_PERF_SCOPE("forwardComputeAt_impl"); - -// auto max_producer_compute_at_pos = -// getReplayablePosCasP(consumer, producer, unmappable_dims_, mode_); - -// if (mode_ == ComputeAtMode::BestEffort) { -// producer_compute_at_pos = -// std::min(producer_compute_at_pos, max_producer_compute_at_pos); -// } else if (mode_ == ComputeAtMode::MostInlined) { -// producer_compute_at_pos = max_producer_compute_at_pos; -// } else { -// TORCH_INTERNAL_ASSERT( -// producer_compute_at_pos <= max_producer_compute_at_pos, -// "Invalid compute at position detected in compute at when trying to replay consumer: ", -// consumer, -// " as producer: ", -// producer, -// " tried to do this at position: ", -// producer_compute_at_pos, -// " but max position that's allowed is ", -// max_producer_compute_at_pos); -// } - -// // Short cut if no replay is necessary -// auto maybe_consumer_pos = -// skipReplay(producer, consumer, (int)producer_compute_at_pos, false); -// if (maybe_consumer_pos > -1) { -// if (!producer->isFusionInput()) { -// 
producer->setComputeAt(producer_compute_at_pos); -// } -// consumer->setMaxProducer((unsigned int)maybe_consumer_pos); -// return (unsigned int)maybe_consumer_pos; -// } - -// auto replay_consumer_pair = TransformReplay::replayCasP( -// consumer, -// producer, -// (int)producer_compute_at_pos, -// PairwiseRootDomainMap(producer, consumer)); - -// if (producer_compute_at_pos > producer->getComputeAtPosition()) { -// if (!producer->isFusionInput()) { -// producer->setComputeAt((int)producer_compute_at_pos); -// } -// } - -// if (replay_consumer_pair.second > consumer->getMaxProducerPosition()) { -// const TensorDomain* current_domain = consumer->domain(); -// TensorDomain* new_domain = replay_consumer_pair.first; - -// TORCH_INTERNAL_ASSERT( -// validateDomain(consumer, new_domain), -// "Tried to set the domain of ", -// consumer, -// " to ", -// new_domain, -// " but that would invalidate previously compute at position or max producer position."); - -// consumer->setDomain(new_domain); -// consumer->setMaxProducer(replay_consumer_pair.second); -// root_map_.setAlias(current_domain, new_domain); -// } - -// return replay_consumer_pair.second; -// } - -// // Similar to backward traversal in traverseAllKnown but we should only apply -// // computeAt if it will increase computeAt positions. -// void ComputeAt::traverseBackward() { -// FUSER_PERF_SCOPE("ComputeAt::traverseBackward"); -// if (reference_ == producer_) { -// // Forward compute at don't need to run backward traversal -// producer_position_ = reference_position_; -// return; -// } - -// // propagate *backward* through all *producer* use_chains or from *producer* -// // to common_consumer if common_consumer exists. Only apply transform if -// // increases computeAt position. 
-// auto chains = -// tvChains(DependencyCheck::getAllDependencyChains(producer_, consumer_)); - -// for (auto tv_chain : chains) { -// TensorView* running_producer = tv_chain.back(); -// TensorView* running_consumer = nullptr; -// unsigned int running_consumer_pos = reference_position_; -// tv_chain.pop_back(); - -// TORCH_INTERNAL_ASSERT(running_producer == consumer_); - -// while (!tv_chain.empty()) { -// running_consumer = running_producer; -// running_producer = tv_chain.back(); -// tv_chain.pop_back(); -// running_consumer_pos = backwardComputeAt_impl( -// running_producer, running_consumer, running_consumer_pos); -// } - -// TORCH_INTERNAL_ASSERT( -// running_producer == producer_, -// "Compute at backward traversal ended up on something other than the producer."); -// producer_position_ = running_consumer_pos; -// } -// } - -// void ComputeAt::traverseForward() { -// FUSER_PERF_SCOPE("ComputeAt::traverseForward"); - -// // propagate forward through all *producer* use_chains or from *producer* to -// // common_consumer if common_consumer exists. 
-// auto chains = producer_use_chains_; -// if (common_consumer_ != nullptr) { -// chains = tvChains( -// DependencyCheck::getAllDependencyChains(producer_, common_consumer_)); -// } - -// // propagate forward through all chains -// for (auto tv_dep_chain : chains) { -// TensorView* running_producer = nullptr; -// TensorView* running_consumer = tv_dep_chain.front(); -// tv_dep_chain.pop_front(); -// unsigned int running_producer_pos = producer_position_; - -// TORCH_INTERNAL_ASSERT(running_consumer == producer_); - -// while (!tv_dep_chain.empty()) { -// running_producer = running_consumer; -// running_consumer = tv_dep_chain.front(); -// tv_dep_chain.pop_front(); -// running_producer_pos = forwardComputeAt_impl( -// running_producer, running_consumer, running_producer_pos); -// } -// } -// } - -// void ComputeAt::resetMaxProducerPos(TensorView* consumer_tv) { -// if (consumer_tv->definition() == nullptr) { -// consumer_tv->setMaxProducer(0, true); -// } - -// unsigned int new_consummer_pa_pos = 0; - -// // Re-compute the max producer position as one or more -// // of the producers of this consumer have updated their -// // compute at position. -// for (auto inp : ir_utils::producerTvsOf(consumer_tv)) { -// if (!inp->isFusionInput()) { -// // Locate consumer's position that aligns with -// // the producer's new compute at axis. -// unsigned int inp_ca_pos_to_consumer = -// getConsumerPosAlignedToProducerCA(consumer_tv, inp); - -// // Populate the max consumer position required by -// // producer compute at. 
-// new_consummer_pa_pos = -// std::max(new_consummer_pa_pos, inp_ca_pos_to_consumer); -// } -// } - -// consumer_tv->setMaxProducer(new_consummer_pa_pos, true); -// } - -// void ComputeAt::hoistInnermostBroadcast() { -// auto fusion = producer_->fusion(); - -// std::unordered_set consumers_to_update; - -// auto all_vals = fusion->usedMathVals(); -// auto all_tvs = ir_utils::filterByType(all_vals); - -// for (auto running_producer : all_tvs) { -// if (!running_producer->isFusionInput()) { -// auto producer_ca_pos = running_producer->getComputeAtPosition(); -// // Find the innermost iterdomain that is not a broadcast -// auto new_ca_pos = getInnermostNonBroadcastIdFrom(running_producer); -// // Update the compute at pos of this producer if the original -// // compute at is within inner most broadcast axes -// if (new_ca_pos < producer_ca_pos) { -// running_producer->setComputeAt(new_ca_pos, true); -// } -// // Mark all consumers of this producer for later produce -// // position update. -// // This is safe with segmented fusion. TV uses will reset -// // when FusionSegmentGuard try to change the IO. 
-// auto tv_consumers = ir_utils::consumerTvsOf(running_producer); -// consumers_to_update.insert(tv_consumers.begin(), tv_consumers.end()); -// } -// } -// } - -// void ComputeAt::updateSiblings() { -// // Track which consumers may have a wrong produce at position to update -// // later -// auto updateSiblingsOfTv = [&](TensorView* tv) { -// if (tv->definition() == nullptr) { -// return; -// } - -// std::unordered_set consumers_to_update; - -// if (tv->definition()->outputs().size() > 1) { -// auto outs = tv->definition()->outputs(); -// auto out_tvs = ir_utils::filterByType(outs); -// for (auto sibling_tv : out_tvs) { -// if (sibling_tv == tv) { -// continue; -// } - -// std::unordered_map tv_to_sibling_map; -// TORCH_INTERNAL_ASSERT( -// tv->getRootDomain().size() == sibling_tv->getRootDomain().size(), -// "Error replaying multiple output expressions in computeAt."); - -// // Propagate any root parallelization as fullSelfReplay expects it. -// for (const auto i : c10::irange(sibling_tv->getRootDomain().size())) { -// auto id = tv->getRootDomain()[i]; -// auto sibling_id = sibling_tv->getRootDomain()[i]; -// if (id->getParallelType() != ParallelType::Serial && -// sibling_id->getParallelType() == ParallelType::Serial) { -// sibling_id->parallelize(id->getParallelType()); -// } else if ( -// id->getParallelType() == ParallelType::Serial && -// sibling_id->getParallelType() != ParallelType::Serial) { -// id->parallelize(sibling_id->getParallelType()); -// } -// } -// auto sibling_domain = -// TransformReplay::fullSelfReplay(sibling_tv->domain(), tv->domain()); -// validateDomain(sibling_tv, sibling_domain); -// sibling_tv->setDomain(sibling_domain); -// sibling_tv->setComputeAt(tv->getComputeAtPosition()); -// sibling_tv->setMaxProducer(tv->getMaxProducerPosition()); -// auto consumer_tvs = ir_utils::consumerTvsOf(sibling_tv); -// consumers_to_update.insert(consumer_tvs.begin(), consumer_tvs.end()); -// } -// } - -// // Update sibling consumer tv's max producer 
position -// for (auto consumer : consumers_to_update) { -// this->resetMaxProducerPos(consumer); -// } -// }; - -// // Find all tensor views that may have been modified -// auto chains = producer_use_chains_; -// if (common_consumer_ != nullptr) { -// chains = tvChains( -// DependencyCheck::getAllDependencyChains(producer_, common_consumer_)); -// } - -// std::unordered_set participating_tvs; -// for (auto chain : chains) { -// participating_tvs.insert(chain.begin(), chain.end()); -// } - -// for (auto tv : participating_tvs) { -// updateSiblingsOfTv(tv); -// } -// } - -// void ComputeAt::runPass() { -// FUSER_PERF_SCOPE("ComputeAt::runPass"); - - // // Traverse backward through all dep chains from producer to consumer - // traverseBackward(); - - // // Start at producer and traverse forward through all chains - // traverseForward(); - - // // Back off on inlining the inner broadcast axes - // hoistInnermostBroadcast(); - - // // Update siblings of multi output expressions - // updateSiblings(); - - // // Update the compute at position of all consumers, this used to be done - // // during the compute at pass itself, but its cleaner to do this as a cleanup - // // pass similar to hoistInnermostBroadcast and updateSiblings. - // std::unordered_set all_consumers; - - // // Find all tensor views that may have been modified - // auto chains = producer_use_chains_; - // if (common_consumer_ != nullptr) { - // chains = tvChains( - // DependencyCheck::getAllDependencyChains(producer_, common_consumer_)); - // } - - // for (const auto& chain : chains) { - // for (auto tv : chain) { - // all_consumers.emplace(tv); - // } - // } - - // // Reset max producer position of all tensor views. 
- // for (auto tv : all_consumers) { - // resetMaxProducerPos(tv); - // } -// } - -// void ComputeAt::buildUnmappableDims() { -// auto all_tvs = ir_utils::allTvs(producer_->fusion()); -// for (auto tv : all_tvs) { -// auto consumers = ir_utils::consumerTvsOf(tv); -// for (auto consumer : consumers) { -// // Grab dimensions in producer and consumer that are mappable to eachother -// // based on the computeAtRootDomainMap. This will tell us which dimensions -// // can be inlined based on avoiding trying to inline non-trivial -// // reduction structures. -// auto mappable_roots = -// root_map_.getMappableDims(tv->domain(), consumer->domain()); -// for (auto tv_root_id : tv->getMaybeRFactorDomain()) { -// if (mappable_roots.find(tv_root_id) == mappable_roots.end() && -// !tv_root_id->isTrivialReduction()) { -// unmappable_dims_.emplace(tv_root_id); -// } -// } -// } -// } -// } - -// ComputeAt::ComputeAt( -// TensorView* _producer, -// TensorView* _consumer, -// TensorView* _reference, -// unsigned int _reference_position, -// ComputeAtMode _mode) -// : producer_(_producer), -// consumer_(_consumer), -// reference_(_reference), -// reference_position_(_reference_position), -// mode_(_mode) { -// TORCH_INTERNAL_ASSERT( -// reference_ == producer_ || reference_ == consumer_, -// "For compute at reference must be producer or consumer, it's neither.", -// " reference: ", -// reference_, -// " consumer: ", -// consumer_, -// " producer: ", -// producer_); -// TORCH_INTERNAL_ASSERT( -// reference_position_ >= 0 && reference_position_ <= reference_->nDims(), -// "Invalid computeAt axis, received ", -// reference_position_, -// " but should be > -", -// reference_->nDims(), -// " and <= ", -// reference_->nDims(), -// "."); - -// producer_use_chains_ = tvChains(DependencyCheck::getAllUseChains(producer_)); - -// // Look through all the use chains of producer. 
Check if there's a single -// // consumer for all chains at or after the consumer specified in the computeAt -// // call. -// setCommonConsumer(); - -// root_map_.build(); - -// buildUnmappableDims(); -// } - -// ComputeAt::ComputeAt( -// std::unordered_set subgraph, -// TensorView* reference, -// unsigned int reference_position, -// ComputeAtMode mode) -// : TransformPropagator(reference, reference_position), -// subgraph_(std::move(subgraph)), -// mode_(mode) { -// TORCH_INTERNAL_ASSERT( -// subgraph_.count(reference), -// "Reference must be within subgraph."); -// TORCH_INTERNAL_ASSERT( -// reference_position >= 0 && reference_position <= reference->nDims(), -// "Invalid computeAt axis, received ", -// reference_position, -// " but should be > -", -// reference->nDims(), -// " and <= ", -// reference->nDims(), -// "."); -// } - } // namespace cuda } // namespace fuser } // namespace jit From 789532e849a88cec11e5b9ceb327b182bd9fefc6 Mon Sep 17 00:00:00 2001 From: Xiang Gao Date: Fri, 3 Jun 2022 01:41:40 -0700 Subject: [PATCH 019/100] cleanup --- torch/csrc/jit/codegen/cuda/compute_at.cpp | 1 - 1 file changed, 1 deletion(-) diff --git a/torch/csrc/jit/codegen/cuda/compute_at.cpp b/torch/csrc/jit/codegen/cuda/compute_at.cpp index c85dbc99a5c9a..b0cabac7fb14b 100644 --- a/torch/csrc/jit/codegen/cuda/compute_at.cpp +++ b/torch/csrc/jit/codegen/cuda/compute_at.cpp @@ -124,7 +124,6 @@ void pullInSiblings(std::unordered_set &s) { if (sibling_tv == tv) { continue; } - std::cout << "pulling in " << sibling_tv << std::endl; s.emplace(sibling_tv); } } From 1d94372c35953d33bf61b0ae553f18e0abb9a37d Mon Sep 17 00:00:00 2001 From: Xiang Gao Date: Fri, 3 Jun 2022 15:15:58 -0700 Subject: [PATCH 020/100] unmappable_dims --- torch/csrc/jit/codegen/cuda/consume_at.cpp | 85 ++++++++++++++++++- torch/csrc/jit/codegen/cuda/consume_at.h | 7 ++ .../csrc/jit/codegen/cuda/transform_replay.h | 2 +- 3 files changed, 90 insertions(+), 4 deletions(-) diff --git 
a/torch/csrc/jit/codegen/cuda/consume_at.cpp b/torch/csrc/jit/codegen/cuda/consume_at.cpp index a16ff3417a511..e828334154735 100644 --- a/torch/csrc/jit/codegen/cuda/consume_at.cpp +++ b/torch/csrc/jit/codegen/cuda/consume_at.cpp @@ -28,6 +28,32 @@ ConsumeAt::ConsumeAt( " and <= ", reference->nDims(), "."); + + buildUnmappableDims(); +} + +void ConsumeAt::buildUnmappableDims() { + ComputeAtRootDomainMap root_map; + root_map.build(); + + auto all_tvs = ir_utils::allTvs(starting_tv->fusion()); + for (auto tv : all_tvs) { + auto consumers = ir_utils::consumerTvsOf(tv); + for (auto consumer : consumers) { + // Grab dimensions in producer and consumer that are mappable to eachother + // based on the computeAtRootDomainMap. This will tell us which dimensions + // can be inlined based on avoiding trying to inline non-trivial + // reduction structures. + auto mappable_roots = + root_map.getMappableDims(tv->domain(), consumer->domain()); + for (auto tv_root_id : tv->getMaybeRFactorDomain()) { + if (mappable_roots.find(tv_root_id) == mappable_roots.end() && + !tv_root_id->isTrivialReduction()) { + unmappable_dims.emplace(tv_root_id); + } + } + } + } } void ConsumeAt::consumeAllAt( @@ -126,7 +152,37 @@ size_t ConsumeAt::getReplayablePosPasC( } } - return max_consumer_pos; + // Start at max position and work backwards, try to find a location where + // producer can be inlined. + for (size_t consumer_pos = max_consumer_pos; consumer_pos > 0; + consumer_pos--) { + // Grab all root dimensions of consumer as roots must be used to understand + // inlining potential. + auto consumer_root_dim_vals = + IterVisitor::getInputsTo({c_dom.begin(), c_dom.begin() + consumer_pos}); + // convert to iter domains + auto consumer_root_dim_ids = + ir_utils::filterByType(consumer_root_dim_vals); + // If any root dimensions cannot be mapped to producer we can't inline. 
If + // any root dimension + if (std::any_of( + consumer_root_dim_ids.begin(), + consumer_root_dim_ids.end(), + [this, &c2p_root_map](IterDomain* c_root_id) { + auto p_root_id_it = c2p_root_map.find(c_root_id); + if (p_root_id_it == c2p_root_map.end()) { + return false; + } + auto p_id = p_root_id_it->second; + return unmappable_dims.find(p_id) != + unmappable_dims.end(); + })) { + continue; + } + return consumer_pos; + } + + return 0; } // Return the max position in producer that can be inlined to consumer @@ -181,7 +237,30 @@ size_t ConsumeAt::getReplayablePosCasP( } } - return max_producer_pos; + for (size_t producer_pos = max_producer_pos; producer_pos > 0; + producer_pos--) { + auto all_vals = DependencyCheck::getAllValsBetween( + {producer->getMaybeRFactorDomain().begin(), + producer->getMaybeRFactorDomain().end()}, + {p_dom.begin(), p_dom.begin() + producer_pos}); + + // If any root dims could have mapped to consumer, but don't, then we can't + // compute at this point + if (std::any_of( + producer->getMaybeRFactorDomain().begin(), + producer->getMaybeRFactorDomain().end(), + [this, &all_vals](IterDomain* p_root_id) { + return std::find(all_vals.begin(), all_vals.end(), p_root_id) != + all_vals.end() && + unmappable_dims.find(p_root_id) != + unmappable_dims.end(); + })) { + continue; + } + + return producer_pos; + } + return 0; } size_t ConsumeAt::retrieveReplayedPos(TensorView* tv) { @@ -197,7 +276,7 @@ size_t ConsumeAt::retrieveReplayedPos(TensorView* tv) { } size_t ConsumeAt::getReplayPosPasC(TensorView* producer, TensorView* consumer) { - size_t max_pos = getReplayablePosPasC(producer, consumer); + size_t max_pos = getReplayablePosPasC(producer, consumer); // TODO: I don't think this makes sense. It should check all consumers, not just the "from" consumer! What is the purpose of this, given that we already have a getMaxComputeAtPos? 
size_t pos = retrieveReplayedPos(consumer); // std::cout << "[getReplayPosPasC] producer=" << producer << ", consumer=" << consumer << ", max_pos=" << max_pos << ", pos=" << pos << std::endl; diff --git a/torch/csrc/jit/codegen/cuda/consume_at.h b/torch/csrc/jit/codegen/cuda/consume_at.h index 75147ff533868..2b8ec3ce67a5a 100644 --- a/torch/csrc/jit/codegen/cuda/consume_at.h +++ b/torch/csrc/jit/codegen/cuda/consume_at.h @@ -22,6 +22,13 @@ class ConsumeAt : public TransformPropagatorBase { std::unordered_map replayed_pos; std::unordered_set unvisited; + // Root domains in producer that's unmappable to any of its consumers + std::unordered_set unmappable_dims; + + // Iterate through all TVs and collect the dimensions of each TV that don't + // map to all its consumer TVs. + void buildUnmappableDims(); + virtual bool shouldPropagate(TensorView* tv) override; virtual void recordReplayedPos(TensorView* tv, size_t pos) override; virtual size_t retrieveReplayedPos(TensorView* tv) override; diff --git a/torch/csrc/jit/codegen/cuda/transform_replay.h b/torch/csrc/jit/codegen/cuda/transform_replay.h index 6e94d4fb0ce15..b57cb5c22ff27 100644 --- a/torch/csrc/jit/codegen/cuda/transform_replay.h +++ b/torch/csrc/jit/codegen/cuda/transform_replay.h @@ -155,6 +155,7 @@ class TORCH_CUDA_CU_API TransformReplay { }; class TORCH_CUDA_CU_API TransformPropagatorBase { + protected: // TODO: keep? 
// This example comes from a BN kernel, the domain: // @@ -179,7 +180,6 @@ class TORCH_CUDA_CU_API TransformPropagatorBase { TensorView* starting_tv = nullptr; size_t starting_pos; - protected: TransformPropagatorBase(TensorView* from, size_t starting_pos); bool replayPasC(TensorView* producer_tv, TensorView* consumer_tv = nullptr); bool replayCasP(TensorView* consumer_tv, TensorView* producer_tv = nullptr); From 436f7a0d5dd4ec6939b36d07e51bb110f806b99b Mon Sep 17 00:00:00 2001 From: Xiang Gao Date: Mon, 13 Jun 2022 15:50:24 -0700 Subject: [PATCH 021/100] tmp test --- torch/csrc/jit/codegen/cuda/test/test_gpu.cpp | 20 +++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/torch/csrc/jit/codegen/cuda/test/test_gpu.cpp b/torch/csrc/jit/codegen/cuda/test/test_gpu.cpp index 9c77d1b4d3978..f9e305853fc5a 100644 --- a/torch/csrc/jit/codegen/cuda/test/test_gpu.cpp +++ b/torch/csrc/jit/codegen/cuda/test/test_gpu.cpp @@ -16743,7 +16743,9 @@ TEST_F(NVFuserTest, FusionTrivialWarpReduction_CUDA) { tv2->axis(-2)->parallelize(ParallelType::TIDx); tv3->axis(-2)->parallelize(ParallelType::TIDx); + fusion->print(); tv0->computeAt(tv3, -1, ComputeAtMode::MostInlined); + fusion->print(); auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); at::Tensor input1 = at::randn({17, 18, 128, 1}, options); @@ -23394,6 +23396,24 @@ TEST_F(NVFuserTest, FusionRepro1713_CUDA) { __FILE__); } +TEST_F(NVFuserTest, TMP) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeSymbolicTensor(3); + fusion.addInput(tv0); + auto tv1 = sin(tv0); + auto tv2 = sum(tv1, {1}); + auto tv3 = broadcast(tv2, {false, true, false}); + auto tv4 = add(tv3, tv1); + fusion.addOutput(tv4); + tv2->reorder({{1, 2}, {2, 1}}); + + fusion.print(); + tv1->computeAt(tv2, 2); + fusion.print(); +} + } // namespace jit } // namespace torch #endif // #if defined(USE_CUDA) From a49313f56455eac87770b0d19f0bdce735eac948 Mon Sep 17 00:00:00 2001 From: Xiang Gao Date: Thu, 23 Jun 2022 10:49:10 
-0700 Subject: [PATCH 022/100] pull new TransformPropagator --- .../jit/codegen/cuda/transform_replay.cpp | 194 +----------------- .../csrc/jit/codegen/cuda/transform_replay.h | 77 +------ 2 files changed, 12 insertions(+), 259 deletions(-) diff --git a/torch/csrc/jit/codegen/cuda/transform_replay.cpp b/torch/csrc/jit/codegen/cuda/transform_replay.cpp index d1bcd5064b394..f3bffdbea8a4a 100644 --- a/torch/csrc/jit/codegen/cuda/transform_replay.cpp +++ b/torch/csrc/jit/codegen/cuda/transform_replay.cpp @@ -672,197 +672,9 @@ void TransformPropagator::propagateTvCasP(TensorView* from, TensorView* to) { replayed_pos[to] = replay.second; } -std::deque consumersOf(TensorView* tv) { - std::deque consumer_tvs; - for (auto def : tv->uses()) { - auto outs = tvOutputs(def); - consumer_tvs.insert(consumer_tvs.end(), outs.begin(), outs.end()); - } - return deduplicate(consumer_tvs); -} - -std::deque producersFor(TensorView* tv) { - auto def = tv->definition(); - if (def == nullptr) { - return {}; - } - - return deduplicate(tvInputs(def)); -} - -}; // namespace - -bool TransformPropagatorBase::replayPasC( - TensorView* producer_tv, - TensorView* consumer_tv) { - if (producer_tv == starting_tv) { - return false; - } - - auto consumer_pos = getReplayPosPasC(producer_tv, consumer_tv); - if (consumer_pos == 0) { - return false; - } - - auto pairwiseMap = PairwiseRootDomainMap(producer_tv, consumer_tv); - auto replayed_producer = TransformReplay::replayPasC( - producer_tv, consumer_tv, consumer_pos, pairwiseMap); - - auto producer_root = producer_tv->getMaybeRFactorDomain(); - auto replayed_domain = replayed_producer.first->domain(); - - // Find the number of root IDs involved in the transformation - auto dep_vals = DependencyCheck::getAllValsBetween( - {producer_root.begin(), producer_root.end()}, - {replayed_domain.begin(), - replayed_domain.begin() + replayed_producer.second}); - - std::unordered_set dep_vals_set{dep_vals.begin(), dep_vals.end()}; - - auto n_transformed_root_dims = 
std::count_if( - producer_root.begin(), - producer_root.end(), - [&dep_vals_set](IterDomain* root_id) { - return dep_vals_set.find(root_id) != dep_vals_set.end(); - }); - - auto producer_pos = retrieveReplayedPos(producer_tv); - auto replayed_root_dims_it = n_replayed_root_dims.find(producer_tv); - if (producer_pos > 0 && replayed_root_dims_it != n_replayed_root_dims.end()) { - if (n_transformed_root_dims < replayed_root_dims_it->second || - (n_transformed_root_dims == replayed_root_dims_it->second && - replayed_producer.second <= producer_pos)) { - return false; // NOLINT(clang-analyzer-cplusplus.NewDeleteLeaks) - } - } - - producer_tv->setDomain(replayed_producer.first); - recordReplayedPos(producer_tv, replayed_producer.second); - n_replayed_root_dims[producer_tv] = n_transformed_root_dims; - - return true; -} - -bool TransformPropagatorBase::replayCasP( - TensorView* consumer_tv, - TensorView* producer_tv) { - if (consumer_tv == starting_tv) { - return false; - } - - auto producer_pos = getReplayPosCasP(consumer_tv, producer_tv); - if (producer_pos == 0) { - return false; - } - - auto pairwiseMap = PairwiseRootDomainMap(producer_tv, consumer_tv); - auto replayed_consumer = TransformReplay::replayCasP( - consumer_tv, producer_tv, producer_pos, pairwiseMap); - - auto consumer_root = consumer_tv->getRootDomain(); - auto replayed_domain = replayed_consumer.first->domain(); - - // Find the number of root IDs involved in the transformation - auto dep_vals = DependencyCheck::getAllValsBetween( - {consumer_root.begin(), consumer_root.end()}, - {replayed_domain.begin(), - replayed_domain.begin() + replayed_consumer.second}); - - std::unordered_set dep_vals_set{dep_vals.begin(), dep_vals.end()}; - - auto n_transformed_root_dims = std::count_if( - consumer_root.begin(), - consumer_root.end(), - [&dep_vals_set](IterDomain* root_id) { - return dep_vals_set.find(root_id) != dep_vals_set.end(); - }); - - auto consumer_pos = retrieveReplayedPos(consumer_tv); - auto 
replayed_root_dims_it = n_replayed_root_dims.find(consumer_tv); - if (consumer_pos > 0 && replayed_root_dims_it != n_replayed_root_dims.end()) { - if (n_transformed_root_dims < replayed_root_dims_it->second || - (n_transformed_root_dims == replayed_root_dims_it->second && - replayed_consumer.second <= consumer_pos)) { - return false; // NOLINT(clang-analyzer-cplusplus.NewDeleteLeaks) - } - } - - consumer_tv->setDomain(replayed_consumer.first); - recordReplayedPos(consumer_tv, replayed_consumer.second); - n_replayed_root_dims[consumer_tv] = n_transformed_root_dims; - - return true; -} - -TransformPropagatorBase::TransformPropagatorBase( - TensorView* from, - size_t starting_pos) - : starting_tv(from), starting_pos(starting_pos) {} - -void TransformPropagatorBase::run() { - VectorOfUniqueEntries propagation{starting_tv}; - - // Seed position with local tv - recordReplayedPos(starting_tv, starting_pos); - - // While tensor views are being replayed, if they're modified, make sure we - // propagate back to all producers as well as consumers. This is definitely - // not the most efficient implementation as what we do is any time a tv is - // changed we propagate both forward and backward. - while (!propagation.empty()) { - auto tv = propagation.popBack(); - - // Replay tv forward to its consumers. 
- for (auto consumer_tv : consumersOf(tv)) { - if (!shouldPropagate(consumer_tv)) { - continue; - } - auto replayed = replayCasP(consumer_tv, tv); - // If consumer has changed, mark we should propagate - if (replayed) { - propagation.pushBack(consumer_tv); - } - } - - for (auto producer_tv : producersFor(tv)) { - if (!shouldPropagate(producer_tv)) { - continue; - } - // If producer has changed, mark we should propagate - auto replayed = replayPasC(producer_tv, tv); - if (replayed) { - propagation.pushBack(producer_tv); - } - } - } -} - -TransformPropagator::TransformPropagator(TensorView* from, size_t starting_pos) - : TransformPropagatorBase(from, starting_pos) {} - -void TransformPropagator::from(TensorView* tv) { - TransformPropagator propagate(tv, tv->nDims()); - propagate.run(); -} - -void TransformPropagator::recordReplayedPos(TensorView *tv, size_t pos) { - replayed_pos[tv] = pos; -} - -size_t TransformPropagator::retrieveReplayedPos(TensorView *tv) { - auto it = replayed_pos.find(tv); - if (it != replayed_pos.end()) { - return it->second; - } - return 0; -} - -size_t TransformPropagator::getReplayPosPasC(TensorView *producer, TensorView *consumer) { - return retrieveReplayedPos(consumer); -} - -size_t TransformPropagator::getReplayPosCasP(TensorView *consumer, TensorView *producer) { - return retrieveReplayedPos(producer); +TransformPropagator::TransformPropagator(TensorView* from) + : MaxRootDomainInfoPropagator(from, getStartingRootIDInfo(from)) { + replayed_pos[from] = from->nDims(); } } // namespace cuda diff --git a/torch/csrc/jit/codegen/cuda/transform_replay.h b/torch/csrc/jit/codegen/cuda/transform_replay.h index 1208ed3ab85de..cb9daf8fff832 100644 --- a/torch/csrc/jit/codegen/cuda/transform_replay.h +++ b/torch/csrc/jit/codegen/cuda/transform_replay.h @@ -2,7 +2,8 @@ #include #include -#include +#include +#include #include #include @@ -155,75 +156,15 @@ class TORCH_CUDA_CU_API TransformReplay { const TensorDomain* self); }; -class TORCH_CUDA_CU_API 
TransformPropagatorBase { - protected: - // TODO: keep? - // This example comes from a BN kernel, the domain: - // - // [ iS{ceilDiv(ceilDiv(ceilDiv(i4, 128), 4), 1)}, iS{1}, iS{4}, iS{128}, - // iS{i0}, iS{i2}, iS{i3} ] - // - // and - // - // [ iS{ceilDiv(ceilDiv(ceilDiv(i5*i6*i7*i8, 128), 4), 1)}, iS252{1}, - // iS250{4}, iS248{128} ] - // - // Have the same number of replayed dimensions, however the second one - // involves more root domains. The second one is also likely the prefered - // replay. Therefore keep track of how many root domains were part of the - // replay and prefer transformations with more root domains. We could probably - // fix this instances of this occuring by changing the traversal pattern so - // that once propagating towards roots through broadcast axes, it can't come - // back through another broadcast, losing the transformation on those axes. - // However, this should work for existing cases. - std::unordered_map n_replayed_root_dims; - - TensorView* starting_tv = nullptr; - size_t starting_pos; - - TransformPropagatorBase(TensorView* from, size_t starting_pos); - bool replayPasC(TensorView* producer_tv, TensorView* consumer_tv = nullptr); - bool replayCasP(TensorView* consumer_tv, TensorView* producer_tv = nullptr); - void run(); - - virtual void recordReplayedPos(TensorView *tv, size_t pos) = 0; - virtual size_t retrieveReplayedPos(TensorView *tv) = 0; - virtual size_t getReplayPosPasC(TensorView *producer, TensorView *consumer) = 0; - virtual size_t getReplayPosCasP(TensorView *consumer, TensorView *producer) = 0; - virtual bool shouldPropagate(TensorView* tv) { return true; } -}; - -class TORCH_CUDA_CU_API TransformPropagator : public TransformPropagatorBase { - - std::unordered_map replayed_pos; - - // This example comes from a BN kernel, the domain: - // - // [ iS{ceilDiv(ceilDiv(ceilDiv(i4, 128), 4), 1)}, iS{1}, iS{4}, iS{128}, - // iS{i0}, iS{i2}, iS{i3} ] - // - // and - // - // [ iS{ceilDiv(ceilDiv(ceilDiv(i5*i6*i7*i8, 128), 
4), 1)}, iS252{1}, - // iS250{4}, iS248{128} ] - // - // Have the same number of replayed dimensions, however the second one - // involves more root domains. The second one is also likely the prefered - // replay. Therefore keep track of how many root domains were part of the - // replay and prefer transformations with more root domains. We could probably - // fix this instances of this occuring by changing the traversal pattern so - // that once propagating towards roots through broadcast axes, it can't come - // back through another broadcast, losing the transformation on those axes. - // However, this should work for existing cases. - std::unordered_map n_replayed_root_dims; +class TORCH_CUDA_CU_API TransformPropagator + : public MaxRootDomainInfoPropagator { + std::unordered_map replayed_pos; + static std::shared_ptr + getStartingRootIDInfo(TensorView* tv); protected: - TransformPropagator(TensorView* from, size_t starting_pos); - - virtual void recordReplayedPos(TensorView *tv, size_t pos) override; - virtual size_t retrieveReplayedPos(TensorView *tv) override; - virtual size_t getReplayPosPasC(TensorView *producer, TensorView *consumer) override; - virtual size_t getReplayPosCasP(TensorView *consumer, TensorView *producer) override; + virtual void propagateTvPasC(TensorView* from, TensorView* to) override; + virtual void propagateTvCasP(TensorView* from, TensorView* to) override; public: TransformPropagator(TensorView* from); From c5ad23953c3d3eea477444d0a9558fb069cce4de Mon Sep 17 00:00:00 2001 From: Xiang Gao Date: Thu, 23 Jun 2022 11:53:36 -0700 Subject: [PATCH 023/100] save --- torch/csrc/jit/codegen/cuda/consume_at.cpp | 126 +++++++++++------- torch/csrc/jit/codegen/cuda/consume_at.h | 22 ++- .../jit/codegen/cuda/maxinfo_propagator.cpp | 4 +- .../jit/codegen/cuda/maxinfo_propagator.h | 15 ++- .../csrc/jit/codegen/cuda/transform_replay.h | 3 +- 5 files changed, 108 insertions(+), 62 deletions(-) diff --git a/torch/csrc/jit/codegen/cuda/consume_at.cpp 
b/torch/csrc/jit/codegen/cuda/consume_at.cpp index e828334154735..8961451ebf618 100644 --- a/torch/csrc/jit/codegen/cuda/consume_at.cpp +++ b/torch/csrc/jit/codegen/cuda/consume_at.cpp @@ -10,12 +10,39 @@ namespace jit { namespace fuser { namespace cuda { +std::shared_ptr ConsumeAt::getStartingRootIDInfo( + TensorView* tv, + size_t pos) { + // TODO: currently I am creating a group for each rfactor ID. Should I create + // groups base on leaf IDs instead? + RootDomainInfo result; + const auto& root_domain = tv->getMaybeRFactorDomain(); + const auto& leaf_domain = tv->domain()->domain(); + std::unordered_set selected( + leaf_domain.begin(), leaf_domain.begin() + pos); + for (auto id : root_domain) { + if (selected.count(id) > 0) { + result.info.emplace_back(RootIDInfo{{id}, true, tv->hasRFactor()}); + continue; + } + for (auto selected_id : selected) { + if (DependencyCheck::isDependencyOf(id, selected_id)) { + result.info.emplace_back(RootIDInfo{{id}, true, tv->hasRFactor()}); + break; + } + } + } + return std::make_shared(std::move(result)); +} + ConsumeAt::ConsumeAt( std::unordered_set consume, TensorView* reference, size_t reference_pos, ComputeAtMode mode) - : TransformPropagatorBase(reference, reference_pos), + : MaxRootDomainInfoPropagator( + reference, + getStartingRootIDInfo(reference, reference_pos)), consume(std::move(consume)), mode(mode), unvisited(consume) { @@ -36,7 +63,7 @@ void ConsumeAt::buildUnmappableDims() { ComputeAtRootDomainMap root_map; root_map.build(); - auto all_tvs = ir_utils::allTvs(starting_tv->fusion()); + auto all_tvs = ir_utils::allTvs(reference->fusion()); for (auto tv : all_tvs) { auto consumers = ir_utils::consumerTvsOf(tv); for (auto consumer : consumers) { @@ -65,29 +92,10 @@ void ConsumeAt::consumeAllAt( ca.run(); } -void ConsumeAt::recordReplayedPos(TensorView* tv, size_t pos) { - if (consume.count(tv)) { - auto max_pos = getMaxComputeAtPos(tv); - if (pos > max_pos) { - replayed_pos[tv] = pos; - pos = max_pos; - } - if 
(!tv->isFusionInput()) { - tv->setComputeAt(pos); - } else { - replayed_pos[tv] = pos; - } - } else { - replayed_pos[tv] = pos; - } -} - size_t ConsumeAt::getMaxComputeAtPos(TensorView* tv) { auto dom = tv->domain()->domain(); - auto first_reduction = - std::find_if(dom.begin(), dom.end(), [](IterDomain* id) { - return id->isReduction(); - }); + auto first_reduction = std::find_if( + dom.begin(), dom.end(), [](IterDomain* id) { return id->isReduction(); }); auto first_vectorized_axis = std::find_if(dom.begin(), first_reduction, [this](IterDomain* id) { @@ -174,8 +182,7 @@ size_t ConsumeAt::getReplayablePosPasC( return false; } auto p_id = p_root_id_it->second; - return unmappable_dims.find(p_id) != - unmappable_dims.end(); + return unmappable_dims.find(p_id) != unmappable_dims.end(); })) { continue; } @@ -252,8 +259,7 @@ size_t ConsumeAt::getReplayablePosCasP( [this, &all_vals](IterDomain* p_root_id) { return std::find(all_vals.begin(), all_vals.end(), p_root_id) != all_vals.end() && - unmappable_dims.find(p_root_id) != - unmappable_dims.end(); + unmappable_dims.find(p_root_id) != unmappable_dims.end(); })) { continue; } @@ -263,6 +269,37 @@ size_t ConsumeAt::getReplayablePosCasP( return 0; } +void ConsumeAt::propagateTvPasC(TensorView* from, TensorView* to) { + int pos = getReplayPosPasC(to, from); + auto replay = TransformReplay::replayPasC(to, from, pos); + to->setDomain(replay.first); + replayed_pos[to] = replay.second; +} + +void ConsumeAt::propagateTvCasP(TensorView* from, TensorView* to) { + int pos = getReplayPosCasP(to, from); + auto replay = TransformReplay::replayCasP(to, from, pos); + to->setDomain(replay.first); + replayed_pos[to] = replay.second; +} + +void ConsumeAt::recordReplayedPos(TensorView* tv, size_t pos) { + if (consume.count(tv)) { + auto max_pos = getMaxComputeAtPos(tv); + if (pos > max_pos) { + replayed_pos[tv] = pos; + pos = max_pos; + } + if (!tv->isFusionInput()) { + tv->setComputeAt(pos); + } else { + replayed_pos[tv] = pos; + } + } 
else { + replayed_pos[tv] = pos; + } +} + size_t ConsumeAt::retrieveReplayedPos(TensorView* tv) { size_t pos1 = 0, pos2 = 0; if (consume.count(tv)) { @@ -276,9 +313,12 @@ size_t ConsumeAt::retrieveReplayedPos(TensorView* tv) { } size_t ConsumeAt::getReplayPosPasC(TensorView* producer, TensorView* consumer) { - size_t max_pos = getReplayablePosPasC(producer, consumer); // TODO: I don't think this makes sense. It should check all consumers, not just the "from" consumer! What is the purpose of this, given that we already have a getMaxComputeAtPos? + size_t max_pos = getReplayablePosPasC( + producer, + consumer); size_t pos = retrieveReplayedPos(consumer); - // std::cout << "[getReplayPosPasC] producer=" << producer << ", consumer=" << consumer << ", max_pos=" << max_pos << ", pos=" << pos << std::endl; + // std::cout << "[getReplayPosPasC] producer=" << producer << ", consumer=" << + // consumer << ", max_pos=" << max_pos << ", pos=" << pos << std::endl; if (mode == ComputeAtMode::BestEffort) { return std::min(pos, max_pos); @@ -302,7 +342,8 @@ size_t ConsumeAt::getReplayPosPasC(TensorView* producer, TensorView* consumer) { size_t ConsumeAt::getReplayPosCasP(TensorView* consumer, TensorView* producer) { size_t max_pos = getReplayablePosCasP(consumer, producer); size_t pos = retrieveReplayedPos(producer); - // std::cout << "[getReplayPosCasP] consumer=" << consumer << ", producer=" << producer << ", max_pos=" << max_pos << ", pos=" << pos << std::endl; + // std::cout << "[getReplayPosCasP] consumer=" << consumer << ", producer=" << + // producer << ", max_pos=" << max_pos << ", pos=" << pos << std::endl; if (mode == ComputeAtMode::BestEffort) { return std::min(pos, max_pos); @@ -407,30 +448,19 @@ void ConsumeAt::computeMaxProducerPos() { } } -bool ConsumeAt::shouldPropagate(TensorView* tv) { - if (consume.count(tv)) { - unvisited.erase(tv); - return true; - } +bool ConsumeAt::allowPasC(TensorView* from, TensorView* to) { + return consume.count(to) > 0; +} - // If one 
of tv's producer is in the consume set, then tv must also be +bool ConsumeAt::allowCasP(TensorView* from, TensorView* to) { + // If the producer is in the consume set, then the consumer must also be // replayed to obtain a compatible loop structure so that this producer // can be consumed in this loop. - auto def = tv->definition(); - if (def != nullptr) { - auto tv_inputs = ir_utils::filterByType(def->inputs()); - for (auto input : tv_inputs) { - if (consume.count(input)) { - return true; - } - } - } - - return false; + return consume.count(from) > 0 || consume.count(to) > 0; } void ConsumeAt::run() { - TransformPropagatorBase::run(); + MaxRootDomainInfoPropagator::run(); TORCH_CHECK( unvisited.empty(), "Unable to propagate to the entire consume set"); diff --git a/torch/csrc/jit/codegen/cuda/consume_at.h b/torch/csrc/jit/codegen/cuda/consume_at.h index 2b8ec3ce67a5a..48e7fd7f1f8a2 100644 --- a/torch/csrc/jit/codegen/cuda/consume_at.h +++ b/torch/csrc/jit/codegen/cuda/consume_at.h @@ -1,6 +1,7 @@ #pragma once #include +#include #include #include @@ -10,13 +11,20 @@ namespace jit { namespace fuser { namespace cuda { -class ConsumeAt : public TransformPropagatorBase { +class ConsumeAt : public MaxRootDomainInfoPropagator { size_t getReplayablePosPasC(TensorView* producer, TensorView* consumer); size_t getReplayablePosCasP(TensorView* consumer, TensorView* producer); + size_t getReplayPosPasC(TensorView *producer, TensorView *consumer); + size_t getReplayPosCasP(TensorView *consumer, TensorView *producer); + void recordReplayedPos(TensorView* tv, size_t pos); + size_t retrieveReplayedPos(TensorView* tv); + size_t getMaxComputeAtPos(TensorView* tv); void hoistInnermostBroadcast(); void computeMaxProducerPos(); + static std::shared_ptr getStartingRootIDInfo(TensorView* tv, size_t pos); + std::unordered_set consume; ComputeAtMode mode = ComputeAtMode::Standard; std::unordered_map replayed_pos; @@ -29,12 +37,6 @@ class ConsumeAt : public TransformPropagatorBase { // map 
to all its consumer TVs. void buildUnmappableDims(); - virtual bool shouldPropagate(TensorView* tv) override; - virtual void recordReplayedPos(TensorView* tv, size_t pos) override; - virtual size_t retrieveReplayedPos(TensorView* tv) override; - virtual size_t getReplayPosPasC(TensorView *producer, TensorView *consumer) override; - virtual size_t getReplayPosCasP(TensorView *consumer, TensorView *producer) override; - ConsumeAt( std::unordered_set consume, TensorView* reference, @@ -45,6 +47,12 @@ class ConsumeAt : public TransformPropagatorBase { void run(); + protected: + virtual bool allowPasC(TensorView* from, TensorView* to) override; + virtual bool allowCasP(TensorView* from, TensorView* to) override; + virtual void propagateTvPasC(TensorView* from, TensorView* to) override; + virtual void propagateTvCasP(TensorView* from, TensorView* to) override; + public: static void consumeAllAt( std::unordered_set consume, diff --git a/torch/csrc/jit/codegen/cuda/maxinfo_propagator.cpp b/torch/csrc/jit/codegen/cuda/maxinfo_propagator.cpp index 46ffef6b3bfb0..b23b0604d53b7 100644 --- a/torch/csrc/jit/codegen/cuda/maxinfo_propagator.cpp +++ b/torch/csrc/jit/codegen/cuda/maxinfo_propagator.cpp @@ -76,7 +76,7 @@ void MaxInfoPropagator::run() { replayed.emplace(next_hop.to); for (auto consumer_tv : ir_utils::consumerTvsOf(next_hop.to)) { - if (replayed.count(consumer_tv)) { + if (replayed.count(consumer_tv) || !allowCasP(next_hop.to, consumer_tv)) { continue; } insertNextHopInfo( @@ -89,7 +89,7 @@ void MaxInfoPropagator::run() { } for (auto producer_tv : ir_utils::producerTvsOf(next_hop.to)) { - if (replayed.count(producer_tv)) { + if (replayed.count(producer_tv) || !allowPasC(next_hop.to, producer_tv)) { continue; } insertNextHopInfo( diff --git a/torch/csrc/jit/codegen/cuda/maxinfo_propagator.h b/torch/csrc/jit/codegen/cuda/maxinfo_propagator.h index aebca46a24bab..736b6341fd755 100644 --- a/torch/csrc/jit/codegen/cuda/maxinfo_propagator.h +++ 
b/torch/csrc/jit/codegen/cuda/maxinfo_propagator.h @@ -69,12 +69,13 @@ class TORCH_CUDA_CU_API MaxInfoPropagator { } }; - TensorView* reference; - std::shared_ptr reference_info; - protected: + TensorView* const reference; + std::shared_ptr const reference_info; + virtual void propagateTvPasC(TensorView* from, TensorView* to) = 0; virtual void propagateTvCasP(TensorView* from, TensorView* to) = 0; + virtual std::shared_ptr computeInfoPasC( TensorView* from, TensorView* to, @@ -84,6 +85,14 @@ class TORCH_CUDA_CU_API MaxInfoPropagator { TensorView* to, std::shared_ptr from_info) = 0; + // methods providing a mechanism to propagate to only a part of the DAG + virtual bool allowPasC(TensorView* from, TensorView* to) { + return true; + } + virtual bool allowCasP(TensorView* from, TensorView* to) { + return true; + } + public: MaxInfoPropagator( TensorView* reference, diff --git a/torch/csrc/jit/codegen/cuda/transform_replay.h b/torch/csrc/jit/codegen/cuda/transform_replay.h index cb9daf8fff832..58938ede7fb92 100644 --- a/torch/csrc/jit/codegen/cuda/transform_replay.h +++ b/torch/csrc/jit/codegen/cuda/transform_replay.h @@ -159,8 +159,7 @@ class TORCH_CUDA_CU_API TransformReplay { class TORCH_CUDA_CU_API TransformPropagator : public MaxRootDomainInfoPropagator { std::unordered_map replayed_pos; - static std::shared_ptr - getStartingRootIDInfo(TensorView* tv); + static std::shared_ptr getStartingRootIDInfo(TensorView* tv); protected: virtual void propagateTvPasC(TensorView* from, TensorView* to) override; From c9e12abc4569196c9269d1c75a4ece445387764d Mon Sep 17 00:00:00 2001 From: "Gao, Xiang" Date: Thu, 23 Jun 2022 14:21:25 -0700 Subject: [PATCH 024/100] Update ir_interface_nodes.h --- torch/csrc/jit/codegen/cuda/ir_interface_nodes.h | 1 - 1 file changed, 1 deletion(-) diff --git a/torch/csrc/jit/codegen/cuda/ir_interface_nodes.h b/torch/csrc/jit/codegen/cuda/ir_interface_nodes.h index 29816f24eb158..4240b2f62a0a3 100644 --- 
a/torch/csrc/jit/codegen/cuda/ir_interface_nodes.h +++ b/torch/csrc/jit/codegen/cuda/ir_interface_nodes.h @@ -456,7 +456,6 @@ class TORCH_CUDA_CU_API TensorView : public Val { //! More detail on usage see [WarpMmaSwizzler] in scheduler/mma_utils.h . void applyMmaSwizzle(MmaOptions options); - friend TORCH_CUDA_CU_API TransformPropagatorBase; friend TORCH_CUDA_CU_API TransformPropagator; friend TORCH_CUDA_CU_API TransformReplay; friend TORCH_CUDA_CU_API OptOutMutator; From 5ad1f8fc10e2fb01de6ed5eaf7c6ac680ac6faeb Mon Sep 17 00:00:00 2001 From: "Gao, Xiang" Date: Thu, 23 Jun 2022 14:26:02 -0700 Subject: [PATCH 025/100] Update ir_interface_nodes.h --- torch/csrc/jit/codegen/cuda/ir_interface_nodes.h | 1 - 1 file changed, 1 deletion(-) diff --git a/torch/csrc/jit/codegen/cuda/ir_interface_nodes.h b/torch/csrc/jit/codegen/cuda/ir_interface_nodes.h index 4240b2f62a0a3..5400c69907bed 100644 --- a/torch/csrc/jit/codegen/cuda/ir_interface_nodes.h +++ b/torch/csrc/jit/codegen/cuda/ir_interface_nodes.h @@ -155,7 +155,6 @@ class TORCH_CUDA_CU_API ComplexDouble : public Val { enum class ComputeAtMode { Standard, BestEffort, MostInlined }; class ConsumeAt; -class TransformPropagatorBase; class TransformPropagator; class TransformIter; class TransformReplay; From ffef5025ffb4b724c9a6591c82ef9a9b2181da22 Mon Sep 17 00:00:00 2001 From: Xiang Gao Date: Thu, 23 Jun 2022 15:41:33 -0700 Subject: [PATCH 026/100] save --- torch/csrc/jit/codegen/cuda/consume_at.cpp | 4 ++-- torch/csrc/jit/codegen/cuda/consume_at.h | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/torch/csrc/jit/codegen/cuda/consume_at.cpp b/torch/csrc/jit/codegen/cuda/consume_at.cpp index 8961451ebf618..f7bc4ab1d37c4 100644 --- a/torch/csrc/jit/codegen/cuda/consume_at.cpp +++ b/torch/csrc/jit/codegen/cuda/consume_at.cpp @@ -273,14 +273,14 @@ void ConsumeAt::propagateTvPasC(TensorView* from, TensorView* to) { int pos = getReplayPosPasC(to, from); auto replay = TransformReplay::replayPasC(to, from, pos); 
to->setDomain(replay.first); - replayed_pos[to] = replay.second; + recordReplayedPos(to, replay.second); } void ConsumeAt::propagateTvCasP(TensorView* from, TensorView* to) { int pos = getReplayPosCasP(to, from); auto replay = TransformReplay::replayCasP(to, from, pos); to->setDomain(replay.first); - replayed_pos[to] = replay.second; + recordReplayedPos(to, replay.second); } void ConsumeAt::recordReplayedPos(TensorView* tv, size_t pos) { diff --git a/torch/csrc/jit/codegen/cuda/consume_at.h b/torch/csrc/jit/codegen/cuda/consume_at.h index 48e7fd7f1f8a2..df68e60276e55 100644 --- a/torch/csrc/jit/codegen/cuda/consume_at.h +++ b/torch/csrc/jit/codegen/cuda/consume_at.h @@ -64,4 +64,4 @@ class ConsumeAt : public MaxRootDomainInfoPropagator { } // namespace cuda } // namespace fuser } // namespace jit -} // namespace torch \ No newline at end of file +} // namespace torch From 587c2c4ff8c842869fb2e0920a710a7b1aedb87d Mon Sep 17 00:00:00 2001 From: Xiang Gao Date: Thu, 23 Jun 2022 16:08:54 -0700 Subject: [PATCH 027/100] save --- torch/csrc/jit/codegen/cuda/consume_at.cpp | 7 ++++--- torch/csrc/jit/codegen/cuda/consume_at.h | 9 ++++++--- 2 files changed, 10 insertions(+), 6 deletions(-) diff --git a/torch/csrc/jit/codegen/cuda/consume_at.cpp b/torch/csrc/jit/codegen/cuda/consume_at.cpp index f7bc4ab1d37c4..d4981c13bd6d6 100644 --- a/torch/csrc/jit/codegen/cuda/consume_at.cpp +++ b/torch/csrc/jit/codegen/cuda/consume_at.cpp @@ -57,6 +57,7 @@ ConsumeAt::ConsumeAt( "."); buildUnmappableDims(); + replayed_pos[reference] = reference_pos; } void ConsumeAt::buildUnmappableDims() { @@ -270,6 +271,7 @@ size_t ConsumeAt::getReplayablePosCasP( } void ConsumeAt::propagateTvPasC(TensorView* from, TensorView* to) { + // std::cout << "propagateTvPasC, from: " << from << ", to:" << to << std::endl; int pos = getReplayPosPasC(to, from); auto replay = TransformReplay::replayPasC(to, from, pos); to->setDomain(replay.first); @@ -277,6 +279,7 @@ void ConsumeAt::propagateTvPasC(TensorView* 
from, TensorView* to) { } void ConsumeAt::propagateTvCasP(TensorView* from, TensorView* to) { + // std::cout << "propagateTvCasP, from: " << from << ", to:" << to << std::endl; int pos = getReplayPosCasP(to, from); auto replay = TransformReplay::replayCasP(to, from, pos); to->setDomain(replay.first); @@ -313,9 +316,7 @@ size_t ConsumeAt::retrieveReplayedPos(TensorView* tv) { } size_t ConsumeAt::getReplayPosPasC(TensorView* producer, TensorView* consumer) { - size_t max_pos = getReplayablePosPasC( - producer, - consumer); + size_t max_pos = getReplayablePosPasC(producer, consumer); size_t pos = retrieveReplayedPos(consumer); // std::cout << "[getReplayPosPasC] producer=" << producer << ", consumer=" << // consumer << ", max_pos=" << max_pos << ", pos=" << pos << std::endl; diff --git a/torch/csrc/jit/codegen/cuda/consume_at.h b/torch/csrc/jit/codegen/cuda/consume_at.h index df68e60276e55..d244230e05eab 100644 --- a/torch/csrc/jit/codegen/cuda/consume_at.h +++ b/torch/csrc/jit/codegen/cuda/consume_at.h @@ -12,10 +12,11 @@ namespace fuser { namespace cuda { class ConsumeAt : public MaxRootDomainInfoPropagator { + // TODO: change arguments to `from`, `to`. 
size_t getReplayablePosPasC(TensorView* producer, TensorView* consumer); size_t getReplayablePosCasP(TensorView* consumer, TensorView* producer); - size_t getReplayPosPasC(TensorView *producer, TensorView *consumer); - size_t getReplayPosCasP(TensorView *consumer, TensorView *producer); + size_t getReplayPosPasC(TensorView* producer, TensorView* consumer); + size_t getReplayPosCasP(TensorView* consumer, TensorView* producer); void recordReplayedPos(TensorView* tv, size_t pos); size_t retrieveReplayedPos(TensorView* tv); @@ -23,7 +24,9 @@ class ConsumeAt : public MaxRootDomainInfoPropagator { void hoistInnermostBroadcast(); void computeMaxProducerPos(); - static std::shared_ptr getStartingRootIDInfo(TensorView* tv, size_t pos); + static std::shared_ptr getStartingRootIDInfo( + TensorView* tv, + size_t pos); std::unordered_set consume; ComputeAtMode mode = ComputeAtMode::Standard; From 8c4cb21659be2c7482c4ff52538c02d5230a3de6 Mon Sep 17 00:00:00 2001 From: Xiang Gao Date: Thu, 23 Jun 2022 16:22:25 -0700 Subject: [PATCH 028/100] save --- torch/csrc/jit/codegen/cuda/consume_at.cpp | 73 ++++++++++--------- torch/csrc/jit/codegen/cuda/consume_at.h | 20 ++--- .../jit/codegen/cuda/maxinfo_propagator.cpp | 4 +- .../jit/codegen/cuda/maxinfo_propagator.h | 6 +- .../jit/codegen/cuda/transform_replay.cpp | 10 +-- .../csrc/jit/codegen/cuda/transform_replay.h | 2 +- 6 files changed, 60 insertions(+), 55 deletions(-) diff --git a/torch/csrc/jit/codegen/cuda/consume_at.cpp b/torch/csrc/jit/codegen/cuda/consume_at.cpp index d4981c13bd6d6..1894680c26da2 100644 --- a/torch/csrc/jit/codegen/cuda/consume_at.cpp +++ b/torch/csrc/jit/codegen/cuda/consume_at.cpp @@ -43,9 +43,10 @@ ConsumeAt::ConsumeAt( : MaxRootDomainInfoPropagator( reference, getStartingRootIDInfo(reference, reference_pos)), - consume(std::move(consume)), - mode(mode), - unvisited(consume) { + consume_(std::move(consume)), + reference_pos_(reference_pos), + mode_(mode), + unvisited_(consume_) { TORCH_INTERNAL_ASSERT( 
reference_pos >= 0 && reference_pos <= reference->nDims(), "Invalid computeAt axis, received ", @@ -57,7 +58,6 @@ ConsumeAt::ConsumeAt( "."); buildUnmappableDims(); - replayed_pos[reference] = reference_pos; } void ConsumeAt::buildUnmappableDims() { @@ -77,7 +77,7 @@ void ConsumeAt::buildUnmappableDims() { for (auto tv_root_id : tv->getMaybeRFactorDomain()) { if (mappable_roots.find(tv_root_id) == mappable_roots.end() && !tv_root_id->isTrivialReduction()) { - unmappable_dims.emplace(tv_root_id); + unmappable_dims_.emplace(tv_root_id); } } } @@ -101,8 +101,8 @@ size_t ConsumeAt::getMaxComputeAtPos(TensorView* tv) { auto first_vectorized_axis = std::find_if(dom.begin(), first_reduction, [this](IterDomain* id) { return isParallelTypeVectorize(id->getParallelType()) || - ((mode == ComputeAtMode::BestEffort || - mode == ComputeAtMode::MostInlined) && + ((mode_ == ComputeAtMode::BestEffort || + mode_ == ComputeAtMode::MostInlined) && id->getParallelType() == ParallelType::Unroll); }); @@ -126,8 +126,8 @@ size_t ConsumeAt::getReplayablePosPasC( auto vector_dim_it = std::find_if(c_dom.begin(), c_dom.end(), [this](IterDomain* id) { return isParallelTypeVectorize(id->getParallelType()) || - ((mode == ComputeAtMode::BestEffort || - mode == ComputeAtMode::MostInlined) && + ((mode_ == ComputeAtMode::BestEffort || + mode_ == ComputeAtMode::MostInlined) && id->getParallelType() == ParallelType::Unroll); }); @@ -153,8 +153,8 @@ size_t ConsumeAt::getReplayablePosPasC( // If we find a consumer dim that maps to a producer dim that's // vectorized or unrolled limit max compute at by it. 
if (isParallelTypeVectorize(p_id->getParallelType()) || - ((mode == ComputeAtMode::BestEffort || - mode == ComputeAtMode::MostInlined) && + ((mode_ == ComputeAtMode::BestEffort || + mode_ == ComputeAtMode::MostInlined) && p_id->getParallelType() == ParallelType::Unroll)) { max_consumer_pos = consumer_pos - 1; } @@ -183,7 +183,7 @@ size_t ConsumeAt::getReplayablePosPasC( return false; } auto p_id = p_root_id_it->second; - return unmappable_dims.find(p_id) != unmappable_dims.end(); + return unmappable_dims_.find(p_id) != unmappable_dims.end(); })) { continue; } @@ -212,8 +212,8 @@ size_t ConsumeAt::getReplayablePosCasP( auto first_vectorized_axis = std::find_if(p_dom.begin(), first_reduction, [this](IterDomain* id) { return isParallelTypeVectorize(id->getParallelType()) || - ((mode == ComputeAtMode::BestEffort || - mode == ComputeAtMode::MostInlined) && + ((mode_ == ComputeAtMode::BestEffort || + mode_ == ComputeAtMode::MostInlined) && id->getParallelType() == ParallelType::Unroll); }); @@ -237,8 +237,8 @@ size_t ConsumeAt::getReplayablePosCasP( // If we find a producer dim that maps to a consumer vectorized or // unrolled dim, limit max compute at by it if (isParallelTypeVectorize(c_id->getParallelType()) || - ((mode == ComputeAtMode::BestEffort || - mode == ComputeAtMode::MostInlined) && + ((mode_ == ComputeAtMode::BestEffort || + mode_ == ComputeAtMode::MostInlined) && c_id->getParallelType() == ParallelType::Unroll)) { max_producer_pos = producer_pos - 1; } @@ -260,7 +260,7 @@ size_t ConsumeAt::getReplayablePosCasP( [this, &all_vals](IterDomain* p_root_id) { return std::find(all_vals.begin(), all_vals.end(), p_root_id) != all_vals.end() && - unmappable_dims.find(p_root_id) != unmappable_dims.end(); + unmappable_dims_.find(p_root_id) != unmappable_dims_.end(); })) { continue; } @@ -271,7 +271,8 @@ size_t ConsumeAt::getReplayablePosCasP( } void ConsumeAt::propagateTvPasC(TensorView* from, TensorView* to) { - // std::cout << "propagateTvPasC, from: " << from << ", 
to:" << to << std::endl; + // std::cout << "propagateTvPasC, from: " << from << ", to:" << to << + // std::endl; int pos = getReplayPosPasC(to, from); auto replay = TransformReplay::replayPasC(to, from, pos); to->setDomain(replay.first); @@ -279,7 +280,8 @@ void ConsumeAt::propagateTvPasC(TensorView* from, TensorView* to) { } void ConsumeAt::propagateTvCasP(TensorView* from, TensorView* to) { - // std::cout << "propagateTvCasP, from: " << from << ", to:" << to << std::endl; + // std::cout << "propagateTvCasP, from: " << from << ", to:" << to << + // std::endl; int pos = getReplayPosCasP(to, from); auto replay = TransformReplay::replayCasP(to, from, pos); to->setDomain(replay.first); @@ -287,29 +289,29 @@ void ConsumeAt::propagateTvCasP(TensorView* from, TensorView* to) { } void ConsumeAt::recordReplayedPos(TensorView* tv, size_t pos) { - if (consume.count(tv)) { + if (consume_.count(tv)) { auto max_pos = getMaxComputeAtPos(tv); if (pos > max_pos) { - replayed_pos[tv] = pos; + replayed_pos_[tv] = pos; pos = max_pos; } if (!tv->isFusionInput()) { tv->setComputeAt(pos); } else { - replayed_pos[tv] = pos; + replayed_pos_[tv] = pos; } } else { - replayed_pos[tv] = pos; + replayed_pos_[tv] = pos; } } size_t ConsumeAt::retrieveReplayedPos(TensorView* tv) { size_t pos1 = 0, pos2 = 0; - if (consume.count(tv)) { + if (consume_.count(tv)) { pos1 = tv->getComputeAtPosition(); } - auto it = replayed_pos.find(tv); - if (it != replayed_pos.end()) { + auto it = replayed_pos_.find(tv); + if (it != replayed_pos_.end()) { pos2 = it->second; } return std::max(pos1, pos2); @@ -321,9 +323,9 @@ size_t ConsumeAt::getReplayPosPasC(TensorView* producer, TensorView* consumer) { // std::cout << "[getReplayPosPasC] producer=" << producer << ", consumer=" << // consumer << ", max_pos=" << max_pos << ", pos=" << pos << std::endl; - if (mode == ComputeAtMode::BestEffort) { + if (mode_ == ComputeAtMode::BestEffort) { return std::min(pos, max_pos); - } else if (mode == ComputeAtMode::MostInlined) { 
+ } else if (mode_ == ComputeAtMode::MostInlined) { return max_pos; } @@ -346,9 +348,9 @@ size_t ConsumeAt::getReplayPosCasP(TensorView* consumer, TensorView* producer) { // std::cout << "[getReplayPosCasP] consumer=" << consumer << ", producer=" << // producer << ", max_pos=" << max_pos << ", pos=" << pos << std::endl; - if (mode == ComputeAtMode::BestEffort) { + if (mode_ == ComputeAtMode::BestEffort) { return std::min(pos, max_pos); - } else if (mode == ComputeAtMode::MostInlined) { + } else if (mode_ == ComputeAtMode::MostInlined) { return max_pos; } @@ -366,7 +368,7 @@ size_t ConsumeAt::getReplayPosCasP(TensorView* consumer, TensorView* producer) { } void ConsumeAt::hoistInnermostBroadcast() { - for (auto tv : consume) { + for (auto tv : consume_) { if (!tv->isFusionInput()) { auto ca_pos = tv->getComputeAtPosition(); while (ca_pos > 0 && tv->axis(ca_pos - 1)->isBroadcast()) { @@ -433,7 +435,7 @@ size_t getConsumerPosAlignedToProducerCA( void ConsumeAt::computeMaxProducerPos() { std::unordered_set todo; - for (auto p : consume) { + for (auto p : consume_) { auto consumers = ir_utils::consumerTvsOf(p); std::copy( consumers.begin(), consumers.end(), std::inserter(todo, todo.end())); @@ -450,20 +452,21 @@ void ConsumeAt::computeMaxProducerPos() { } bool ConsumeAt::allowPasC(TensorView* from, TensorView* to) { - return consume.count(to) > 0; + return consume_.count(to) > 0; } bool ConsumeAt::allowCasP(TensorView* from, TensorView* to) { // If the producer is in the consume set, then the consumer must also be // replayed to obtain a compatible loop structure so that this producer // can be consumed in this loop. 
- return consume.count(from) > 0 || consume.count(to) > 0; + return consume_.count(from) > 0 || consume_.count(to) > 0; } void ConsumeAt::run() { + recordReplayedPos(reference_, reference_pos_); MaxRootDomainInfoPropagator::run(); TORCH_CHECK( - unvisited.empty(), "Unable to propagate to the entire consume set"); + unvisited_.empty(), "Unable to propagate to the entire consume set"); hoistInnermostBroadcast(); computeMaxProducerPos(); diff --git a/torch/csrc/jit/codegen/cuda/consume_at.h b/torch/csrc/jit/codegen/cuda/consume_at.h index d244230e05eab..fa40a5f907b5d 100644 --- a/torch/csrc/jit/codegen/cuda/consume_at.h +++ b/torch/csrc/jit/codegen/cuda/consume_at.h @@ -21,24 +21,26 @@ class ConsumeAt : public MaxRootDomainInfoPropagator { size_t retrieveReplayedPos(TensorView* tv); size_t getMaxComputeAtPos(TensorView* tv); + void hoistInnermostBroadcast(); void computeMaxProducerPos(); + // Iterate through all TVs and collect the dimensions of each TV that don't + // map to all its consumer TVs. + void buildUnmappableDims(); + static std::shared_ptr getStartingRootIDInfo( TensorView* tv, size_t pos); - std::unordered_set consume; - ComputeAtMode mode = ComputeAtMode::Standard; - std::unordered_map replayed_pos; - std::unordered_set unvisited; + std::unordered_set consume_; + size_t reference_pos_; + ComputeAtMode mode_ = ComputeAtMode::Standard; + std::unordered_map replayed_pos_; + std::unordered_set unvisited_; // Root domains in producer that's unmappable to any of its consumers - std::unordered_set unmappable_dims; - - // Iterate through all TVs and collect the dimensions of each TV that don't - // map to all its consumer TVs. 
- void buildUnmappableDims(); + std::unordered_set unmappable_dims_; ConsumeAt( std::unordered_set consume, diff --git a/torch/csrc/jit/codegen/cuda/maxinfo_propagator.cpp b/torch/csrc/jit/codegen/cuda/maxinfo_propagator.cpp index b23b0604d53b7..a085bad009871 100644 --- a/torch/csrc/jit/codegen/cuda/maxinfo_propagator.cpp +++ b/torch/csrc/jit/codegen/cuda/maxinfo_propagator.cpp @@ -30,8 +30,8 @@ void MaxInfoPropagator::run() { // deterministic either. std::list propagation(1); propagation.back().from = nullptr; - propagation.back().to = reference; - propagation.back().info_to = reference_info; + propagation.back().to = reference_; + propagation.back().info_to = reference_info_; // Insert the given next hop the correct position in `propagation`. If there // is an existing next hop that preserves more information, then we will just diff --git a/torch/csrc/jit/codegen/cuda/maxinfo_propagator.h b/torch/csrc/jit/codegen/cuda/maxinfo_propagator.h index 736b6341fd755..22a5c925e2c82 100644 --- a/torch/csrc/jit/codegen/cuda/maxinfo_propagator.h +++ b/torch/csrc/jit/codegen/cuda/maxinfo_propagator.h @@ -70,8 +70,8 @@ class TORCH_CUDA_CU_API MaxInfoPropagator { }; protected: - TensorView* const reference; - std::shared_ptr const reference_info; + TensorView* const reference_; + std::shared_ptr const reference_info_; virtual void propagateTvPasC(TensorView* from, TensorView* to) = 0; virtual void propagateTvCasP(TensorView* from, TensorView* to) = 0; @@ -97,7 +97,7 @@ class TORCH_CUDA_CU_API MaxInfoPropagator { MaxInfoPropagator( TensorView* reference, std::shared_ptr reference_info) - : reference(reference), reference_info(reference_info){}; + : reference_(reference), reference_info_(reference_info){}; void run(); }; diff --git a/torch/csrc/jit/codegen/cuda/transform_replay.cpp b/torch/csrc/jit/codegen/cuda/transform_replay.cpp index f3bffdbea8a4a..d4b5e0d17a2a2 100644 --- a/torch/csrc/jit/codegen/cuda/transform_replay.cpp +++ 
b/torch/csrc/jit/codegen/cuda/transform_replay.cpp @@ -659,22 +659,22 @@ std::shared_ptr TransformPropagator:: } void TransformPropagator::propagateTvPasC(TensorView* from, TensorView* to) { - int pos = replayed_pos.at(from); + int pos = replayed_pos_.at(from); auto replay = TransformReplay::replayPasC(to, from, pos); to->setDomain(replay.first); - replayed_pos[to] = replay.second; + replayed_pos_[to] = replay.second; } void TransformPropagator::propagateTvCasP(TensorView* from, TensorView* to) { - int pos = replayed_pos.at(from); + int pos = replayed_pos_.at(from); auto replay = TransformReplay::replayCasP(to, from, pos); to->setDomain(replay.first); - replayed_pos[to] = replay.second; + replayed_pos_[to] = replay.second; } TransformPropagator::TransformPropagator(TensorView* from) : MaxRootDomainInfoPropagator(from, getStartingRootIDInfo(from)) { - replayed_pos[from] = from->nDims(); + replayed_pos_[from] = from->nDims(); } } // namespace cuda diff --git a/torch/csrc/jit/codegen/cuda/transform_replay.h b/torch/csrc/jit/codegen/cuda/transform_replay.h index 58938ede7fb92..1c551e8e2acc6 100644 --- a/torch/csrc/jit/codegen/cuda/transform_replay.h +++ b/torch/csrc/jit/codegen/cuda/transform_replay.h @@ -158,7 +158,7 @@ class TORCH_CUDA_CU_API TransformReplay { class TORCH_CUDA_CU_API TransformPropagator : public MaxRootDomainInfoPropagator { - std::unordered_map replayed_pos; + std::unordered_map replayed_pos_; static std::shared_ptr getStartingRootIDInfo(TensorView* tv); protected: From ebd943da938c3efdf072ad16195e9f3a65a3c843 Mon Sep 17 00:00:00 2001 From: Xiang Gao Date: Thu, 23 Jun 2022 16:23:48 -0700 Subject: [PATCH 029/100] save --- torch/csrc/jit/codegen/cuda/consume_at.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torch/csrc/jit/codegen/cuda/consume_at.cpp b/torch/csrc/jit/codegen/cuda/consume_at.cpp index 1894680c26da2..73f5fc6896415 100644 --- a/torch/csrc/jit/codegen/cuda/consume_at.cpp +++ 
b/torch/csrc/jit/codegen/cuda/consume_at.cpp @@ -64,7 +64,7 @@ void ConsumeAt::buildUnmappableDims() { ComputeAtRootDomainMap root_map; root_map.build(); - auto all_tvs = ir_utils::allTvs(reference->fusion()); + auto all_tvs = ir_utils::allTvs(reference_->fusion()); for (auto tv : all_tvs) { auto consumers = ir_utils::consumerTvsOf(tv); for (auto consumer : consumers) { @@ -183,7 +183,7 @@ size_t ConsumeAt::getReplayablePosPasC( return false; } auto p_id = p_root_id_it->second; - return unmappable_dims_.find(p_id) != unmappable_dims.end(); + return unmappable_dims_.find(p_id) != unmappable_dims_.end(); })) { continue; } From 1ae4e7092a5bf26127f4bfbb80d2760ed1493de8 Mon Sep 17 00:00:00 2001 From: Xiang Gao Date: Thu, 23 Jun 2022 16:34:49 -0700 Subject: [PATCH 030/100] fix --- torch/csrc/jit/codegen/cuda/consume_at.cpp | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/torch/csrc/jit/codegen/cuda/consume_at.cpp b/torch/csrc/jit/codegen/cuda/consume_at.cpp index 73f5fc6896415..1a6c4ff3ba54e 100644 --- a/torch/csrc/jit/codegen/cuda/consume_at.cpp +++ b/torch/csrc/jit/codegen/cuda/consume_at.cpp @@ -271,20 +271,20 @@ size_t ConsumeAt::getReplayablePosCasP( } void ConsumeAt::propagateTvPasC(TensorView* from, TensorView* to) { - // std::cout << "propagateTvPasC, from: " << from << ", to:" << to << - // std::endl; + // std::cout << "propagateTvPasC, from: " << from << ", to:" << to << std::endl; int pos = getReplayPosPasC(to, from); auto replay = TransformReplay::replayPasC(to, from, pos); to->setDomain(replay.first); + unvisited_.erase(to); recordReplayedPos(to, replay.second); } void ConsumeAt::propagateTvCasP(TensorView* from, TensorView* to) { - // std::cout << "propagateTvCasP, from: " << from << ", to:" << to << - // std::endl; + // std::cout << "propagateTvCasP, from: " << from << ", to:" << to << std::endl; int pos = getReplayPosCasP(to, from); auto replay = TransformReplay::replayCasP(to, from, pos); to->setDomain(replay.first); + 
unvisited_.erase(to); recordReplayedPos(to, replay.second); } @@ -464,7 +464,10 @@ bool ConsumeAt::allowCasP(TensorView* from, TensorView* to) { void ConsumeAt::run() { recordReplayedPos(reference_, reference_pos_); + // std::cout << "before:" << ir_utils::toString(unvisited_) << std::endl; + unvisited_.erase(reference_); MaxRootDomainInfoPropagator::run(); + // std::cout << "after:" << ir_utils::toString(unvisited_) << std::endl; TORCH_CHECK( unvisited_.empty(), "Unable to propagate to the entire consume set"); From 58df2cb9c9c85eee3a6803a046cb9e2c6589b74b Mon Sep 17 00:00:00 2001 From: Xiang Gao Date: Thu, 23 Jun 2022 16:43:36 -0700 Subject: [PATCH 031/100] save --- torch/csrc/jit/codegen/cuda/consume_at.cpp | 16 +++++----------- torch/csrc/jit/codegen/cuda/consume_at.h | 1 - 2 files changed, 5 insertions(+), 12 deletions(-) diff --git a/torch/csrc/jit/codegen/cuda/consume_at.cpp b/torch/csrc/jit/codegen/cuda/consume_at.cpp index 1a6c4ff3ba54e..94321e9ccd54d 100644 --- a/torch/csrc/jit/codegen/cuda/consume_at.cpp +++ b/torch/csrc/jit/codegen/cuda/consume_at.cpp @@ -45,8 +45,7 @@ ConsumeAt::ConsumeAt( getStartingRootIDInfo(reference, reference_pos)), consume_(std::move(consume)), reference_pos_(reference_pos), - mode_(mode), - unvisited_(consume_) { + mode_(mode) { TORCH_INTERNAL_ASSERT( reference_pos >= 0 && reference_pos <= reference->nDims(), "Invalid computeAt axis, received ", @@ -271,20 +270,20 @@ size_t ConsumeAt::getReplayablePosCasP( } void ConsumeAt::propagateTvPasC(TensorView* from, TensorView* to) { - // std::cout << "propagateTvPasC, from: " << from << ", to:" << to << std::endl; + // std::cout << "propagateTvPasC, from: " << from << ", to:" << to << + // std::endl; int pos = getReplayPosPasC(to, from); auto replay = TransformReplay::replayPasC(to, from, pos); to->setDomain(replay.first); - unvisited_.erase(to); recordReplayedPos(to, replay.second); } void ConsumeAt::propagateTvCasP(TensorView* from, TensorView* to) { - // std::cout << 
"propagateTvCasP, from: " << from << ", to:" << to << std::endl; + // std::cout << "propagateTvCasP, from: " << from << ", to:" << to << + // std::endl; int pos = getReplayPosCasP(to, from); auto replay = TransformReplay::replayCasP(to, from, pos); to->setDomain(replay.first); - unvisited_.erase(to); recordReplayedPos(to, replay.second); } @@ -464,12 +463,7 @@ bool ConsumeAt::allowCasP(TensorView* from, TensorView* to) { void ConsumeAt::run() { recordReplayedPos(reference_, reference_pos_); - // std::cout << "before:" << ir_utils::toString(unvisited_) << std::endl; - unvisited_.erase(reference_); MaxRootDomainInfoPropagator::run(); - // std::cout << "after:" << ir_utils::toString(unvisited_) << std::endl; - TORCH_CHECK( - unvisited_.empty(), "Unable to propagate to the entire consume set"); hoistInnermostBroadcast(); computeMaxProducerPos(); diff --git a/torch/csrc/jit/codegen/cuda/consume_at.h b/torch/csrc/jit/codegen/cuda/consume_at.h index fa40a5f907b5d..64b7f7a8bfafc 100644 --- a/torch/csrc/jit/codegen/cuda/consume_at.h +++ b/torch/csrc/jit/codegen/cuda/consume_at.h @@ -37,7 +37,6 @@ class ConsumeAt : public MaxRootDomainInfoPropagator { size_t reference_pos_; ComputeAtMode mode_ = ComputeAtMode::Standard; std::unordered_map replayed_pos_; - std::unordered_set unvisited_; // Root domains in producer that's unmappable to any of its consumers std::unordered_set unmappable_dims_; From 680a520cae5cfe3ff9c1e0c02fee9512d72c49a8 Mon Sep 17 00:00:00 2001 From: Xiang Gao Date: Thu, 23 Jun 2022 16:52:34 -0700 Subject: [PATCH 032/100] save --- torch/csrc/jit/codegen/cuda/consume_at.cpp | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/torch/csrc/jit/codegen/cuda/consume_at.cpp b/torch/csrc/jit/codegen/cuda/consume_at.cpp index 94321e9ccd54d..2ca219fe71853 100644 --- a/torch/csrc/jit/codegen/cuda/consume_at.cpp +++ b/torch/csrc/jit/codegen/cuda/consume_at.cpp @@ -305,15 +305,11 @@ void ConsumeAt::recordReplayedPos(TensorView* tv, size_t pos) { } 
size_t ConsumeAt::retrieveReplayedPos(TensorView* tv) { - size_t pos1 = 0, pos2 = 0; - if (consume_.count(tv)) { - pos1 = tv->getComputeAtPosition(); - } auto it = replayed_pos_.find(tv); if (it != replayed_pos_.end()) { - pos2 = it->second; + return it->second; } - return std::max(pos1, pos2); + return tv->getComputeAtPosition(); } size_t ConsumeAt::getReplayPosPasC(TensorView* producer, TensorView* consumer) { From 889bb27beac821f4d158b01a7c6ff03fca8028f2 Mon Sep 17 00:00:00 2001 From: Xiang Gao Date: Thu, 23 Jun 2022 20:39:50 -0700 Subject: [PATCH 033/100] short cut --- torch/csrc/jit/codegen/cuda/consume_at.cpp | 129 ++++++++++++++++-- .../jit/codegen/cuda/maxinfo_propagator.h | 2 + 2 files changed, 121 insertions(+), 10 deletions(-) diff --git a/torch/csrc/jit/codegen/cuda/consume_at.cpp b/torch/csrc/jit/codegen/cuda/consume_at.cpp index 2ca219fe71853..420d843dadac5 100644 --- a/torch/csrc/jit/codegen/cuda/consume_at.cpp +++ b/torch/csrc/jit/codegen/cuda/consume_at.cpp @@ -13,6 +13,7 @@ namespace cuda { std::shared_ptr ConsumeAt::getStartingRootIDInfo( TensorView* tv, size_t pos) { + pos = tv->nDims(); // TODO: hacking! // TODO: currently I am creating a group for each rfactor ID. Should I create // groups base on leaf IDs instead? 
RootDomainInfo result; @@ -88,8 +89,16 @@ void ConsumeAt::consumeAllAt( TensorView* reference, size_t reference_pos, ComputeAtMode mode) { + // std::cout << "==========================" << std::endl; + // std::cout << "From: " << reference << " at pos " << reference_pos + // << std::endl; + // std::cout << "Before:" << std::endl; + // reference->fusion()->print(); ConsumeAt ca(std::move(consume), reference, reference_pos, mode); ca.run(); + // std::cout << "\n\nAfter:" << std::endl; + // reference->fusion()->print(); + // std::cout << "==========================" << std::endl; } size_t ConsumeAt::getMaxComputeAtPos(TensorView* tv) { @@ -269,22 +278,122 @@ size_t ConsumeAt::getReplayablePosCasP( return 0; } +namespace { + +// Checks if producer and consumer are transformed consistently so that to +// satisfy the provided compute at position. This means no replay is actually +// necessary for the compute at requested. If consumer_pos then +// consumer_or_producer_pos is relative to the consumer and skipReplay returns +// the associated position in producer. +// +// If producer and consumer are not transformed consistently with provided +// postition, returns -1. 
+int skipReplay( + const TensorView* producer, + const TensorView* consumer, + int consumer_or_producer_pos, + bool consumer_pos = true) { + const auto c2p_root_map = + PairwiseRootDomainMap(producer, consumer) + .mapConsumerToProducer(consumer->domain(), producer->domain()); + + // IterDomains in consumer root also in producer root + std::unordered_set mapped_consumer_roots; + for (auto entry : c2p_root_map) { + mapped_consumer_roots.emplace(entry.first); + } + + const auto consumer_domain = consumer->domain()->domain(); + + auto mapped_consumer_domain_ids_vec = DependencyCheck::getAllValsBetween( + mapped_consumer_roots, {consumer_domain.begin(), consumer_domain.end()}); + + std::unordered_set mapped_consumer_domain_ids( + mapped_consumer_domain_ids_vec.begin(), + mapped_consumer_domain_ids_vec.end()); + + const auto producer_domain = producer->domain()->domain(); + + auto it_consumer = consumer_domain.begin(); + auto it_producer = producer_domain.begin(); + + auto best_effort_PasC = BestEffortReplay::replayPasC( + producer, consumer, -1, PairwiseRootDomainMap(producer, consumer)); + + auto c2p_map = best_effort_PasC.getReplay(); + + int mismatched_consumer_pos = 0; + int mismatched_producer_pos = 0; + while (it_consumer != consumer_domain.end()) { + auto consumer_id = *it_consumer; + if (!mapped_consumer_domain_ids.count(consumer_id)) { + ++it_consumer; + mismatched_consumer_pos++; + continue; + } + + auto c2p_it = c2p_map.find(consumer_id); + if (c2p_it == c2p_map.end()) { + break; + } + + if (it_producer == producer_domain.end()) { + break; + } + + auto producer_id = *it_producer; + + if (c2p_it->second == producer_id) { + ++mismatched_consumer_pos; + ++mismatched_producer_pos; + ++it_consumer; + ++it_producer; + if (consumer_pos) { + if (consumer_or_producer_pos == mismatched_consumer_pos) { + return mismatched_producer_pos; + } + } else { + if (consumer_or_producer_pos == mismatched_producer_pos) { + return mismatched_consumer_pos; + } + } + } else { + 
break; + } + } + return -1; +} + +} // namespace + void ConsumeAt::propagateTvPasC(TensorView* from, TensorView* to) { - // std::cout << "propagateTvPasC, from: " << from << ", to:" << to << - // std::endl; + // std::cout << "propagateTvPasC, from: " << from << ", to:" << to; int pos = getReplayPosPasC(to, from); - auto replay = TransformReplay::replayPasC(to, from, pos); - to->setDomain(replay.first); - recordReplayedPos(to, replay.second); + // std::cout << ", at: " << pos; + // Short cut if no replay is necessary + auto to_pos = skipReplay(to, from, (int)pos, true); + if (to_pos < 0) { + auto replay = TransformReplay::replayPasC(to, from, pos); + to->setDomain(replay.first); + to_pos = replay.second; + } + recordReplayedPos(to, to_pos); + // std::cout << ", result: " << to << std::endl; } void ConsumeAt::propagateTvCasP(TensorView* from, TensorView* to) { - // std::cout << "propagateTvCasP, from: " << from << ", to:" << to << - // std::endl; + // std::cout << "propagateTvCasP, from: " << from << ", to:" << to; int pos = getReplayPosCasP(to, from); - auto replay = TransformReplay::replayCasP(to, from, pos); - to->setDomain(replay.first); - recordReplayedPos(to, replay.second); + // std::cout << ", at: " << pos; + // Short cut if no replay is necessary + auto to_pos = skipReplay(from, to, (int)pos, false); + if (to_pos < 0) { + auto replay = TransformReplay::replayCasP(to, from, pos); + to->setDomain(replay.first); + to_pos = replay.second; + } + recordReplayedPos(to, to_pos); + // std::cout << ", result: " << to << std::endl; } void ConsumeAt::recordReplayedPos(TensorView* tv, size_t pos) { diff --git a/torch/csrc/jit/codegen/cuda/maxinfo_propagator.h b/torch/csrc/jit/codegen/cuda/maxinfo_propagator.h index 22a5c925e2c82..abd388e9f5505 100644 --- a/torch/csrc/jit/codegen/cuda/maxinfo_propagator.h +++ b/torch/csrc/jit/codegen/cuda/maxinfo_propagator.h @@ -46,6 +46,8 @@ class TORCH_CUDA_CU_API MaxInfoPropagator { // l == r means it is hard to tell which one of then 
contains more // information bool operator==(const Information& r) const; + // just to avoid compiler warning + virtual ~Information() {} }; private: From db54f57400e68f2c447b5966fc217017efdff97e Mon Sep 17 00:00:00 2001 From: Xiang Gao Date: Thu, 23 Jun 2022 21:47:53 -0700 Subject: [PATCH 034/100] cleanup --- torch/csrc/jit/codegen/cuda/test/test_gpu.cpp | 2 -- 1 file changed, 2 deletions(-) diff --git a/torch/csrc/jit/codegen/cuda/test/test_gpu.cpp b/torch/csrc/jit/codegen/cuda/test/test_gpu.cpp index 4c38bcf44cd46..d4ffb22447caa 100644 --- a/torch/csrc/jit/codegen/cuda/test/test_gpu.cpp +++ b/torch/csrc/jit/codegen/cuda/test/test_gpu.cpp @@ -16709,9 +16709,7 @@ TEST_F(NVFuserTest, FusionTrivialWarpReduction_CUDA) { tv2->axis(-2)->parallelize(ParallelType::TIDx); tv3->axis(-2)->parallelize(ParallelType::TIDx); - fusion->print(); tv0->computeAt(tv3, -1, ComputeAtMode::MostInlined); - fusion->print(); auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); at::Tensor input1 = at::randn({17, 18, 128, 1}, options); From ccd4a00c5d9f7c063a8ed442d7e216c29730b700 Mon Sep 17 00:00:00 2001 From: Xiang Gao Date: Sun, 26 Jun 2022 01:38:46 -0700 Subject: [PATCH 035/100] resolve --- torch/csrc/jit/codegen/cuda/compute_at.cpp | 4 +- torch/csrc/jit/codegen/cuda/consume_at.cpp | 87 +++++++------------ torch/csrc/jit/codegen/cuda/consume_at.h | 23 +++-- .../jit/codegen/cuda/ir_interface_nodes.h | 4 +- 4 files changed, 49 insertions(+), 69 deletions(-) diff --git a/torch/csrc/jit/codegen/cuda/compute_at.cpp b/torch/csrc/jit/codegen/cuda/compute_at.cpp index b0cabac7fb14b..9bdd2ad704d9f 100644 --- a/torch/csrc/jit/codegen/cuda/compute_at.cpp +++ b/torch/csrc/jit/codegen/cuda/compute_at.cpp @@ -182,7 +182,7 @@ void ComputeAt::runAt( " are not in the same fusion."); FusionGuard fg(producer->fusion()); - ConsumeAt::consumeAllAt( + ComputeAtPosPropagator::consumeAllAt( getPropagationSubgraph(producer, consumer), consumer, consumer_position, @@ -207,7 +207,7 @@ void 
ComputeAt::runWith( // Make sure Fusion Guard is set appropriately FusionGuard fg(producer->fusion()); - ConsumeAt::consumeAllAt( + ComputeAtPosPropagator::consumeAllAt( getPropagationSubgraph(producer, consumer), producer, producer_position, diff --git a/torch/csrc/jit/codegen/cuda/consume_at.cpp b/torch/csrc/jit/codegen/cuda/consume_at.cpp index 420d843dadac5..640e090f774e4 100644 --- a/torch/csrc/jit/codegen/cuda/consume_at.cpp +++ b/torch/csrc/jit/codegen/cuda/consume_at.cpp @@ -10,41 +10,13 @@ namespace jit { namespace fuser { namespace cuda { -std::shared_ptr ConsumeAt::getStartingRootIDInfo( - TensorView* tv, - size_t pos) { - pos = tv->nDims(); // TODO: hacking! - // TODO: currently I am creating a group for each rfactor ID. Should I create - // groups base on leaf IDs instead? - RootDomainInfo result; - const auto& root_domain = tv->getMaybeRFactorDomain(); - const auto& leaf_domain = tv->domain()->domain(); - std::unordered_set selected( - leaf_domain.begin(), leaf_domain.begin() + pos); - for (auto id : root_domain) { - if (selected.count(id) > 0) { - result.info.emplace_back(RootIDInfo{{id}, true, tv->hasRFactor()}); - continue; - } - for (auto selected_id : selected) { - if (DependencyCheck::isDependencyOf(id, selected_id)) { - result.info.emplace_back(RootIDInfo{{id}, true, tv->hasRFactor()}); - break; - } - } - } - return std::make_shared(std::move(result)); -} - -ConsumeAt::ConsumeAt( +ComputeAtPosPropagator::ComputeAtPosPropagator( std::unordered_set consume, TensorView* reference, size_t reference_pos, ComputeAtMode mode) - : MaxRootDomainInfoPropagator( - reference, - getStartingRootIDInfo(reference, reference_pos)), - consume_(std::move(consume)), + : consume_(std::move(consume)), + reference_(reference), reference_pos_(reference_pos), mode_(mode) { TORCH_INTERNAL_ASSERT( @@ -60,7 +32,7 @@ ConsumeAt::ConsumeAt( buildUnmappableDims(); } -void ConsumeAt::buildUnmappableDims() { +void ComputeAtPosPropagator::buildUnmappableDims() { 
ComputeAtRootDomainMap root_map; root_map.build(); @@ -84,7 +56,7 @@ void ConsumeAt::buildUnmappableDims() { } } -void ConsumeAt::consumeAllAt( +void ComputeAtPosPropagator::consumeAllAt( std::unordered_set consume, TensorView* reference, size_t reference_pos, @@ -94,14 +66,21 @@ void ConsumeAt::consumeAllAt( // << std::endl; // std::cout << "Before:" << std::endl; // reference->fusion()->print(); - ConsumeAt ca(std::move(consume), reference, reference_pos, mode); - ca.run(); + ComputeAtSubgraphSelector selector(consume); + ComputeAtPosPropagator propagator( + std::move(consume), reference, reference_pos, mode); + + propagator.recordReplayedPos(reference, reference_pos); + MaxRootDomainInfoSpanningTree path(reference, reference_pos, &selector); + path.traverse(&propagator); + propagator.hoistInnermostBroadcast(); + propagator.computeMaxProducerPos(); // std::cout << "\n\nAfter:" << std::endl; // reference->fusion()->print(); // std::cout << "==========================" << std::endl; } -size_t ConsumeAt::getMaxComputeAtPos(TensorView* tv) { +size_t ComputeAtPosPropagator::getMaxComputeAtPos(TensorView* tv) { auto dom = tv->domain()->domain(); auto first_reduction = std::find_if( dom.begin(), dom.end(), [](IterDomain* id) { return id->isReduction(); }); @@ -125,7 +104,7 @@ size_t ConsumeAt::getMaxComputeAtPos(TensorView* tv) { // Unrolled dimensions in producer or consumer // Dimensions derived from root dimensions that exist in both but are // unmappable -size_t ConsumeAt::getReplayablePosPasC( +size_t ComputeAtPosPropagator::getReplayablePosPasC( TensorView* producer, TensorView* consumer) { // Check if any consumer dimensions are marked as vectorize as producer can @@ -208,7 +187,7 @@ size_t ConsumeAt::getReplayablePosPasC( // Unrolled dimensions in producer or consumer // Dimensions derived from root dimensions that exist in both but are // unmappable -size_t ConsumeAt::getReplayablePosCasP( +size_t ComputeAtPosPropagator::getReplayablePosCasP( TensorView* 
consumer, TensorView* producer) { auto p_dom = producer->domain()->domain(); @@ -366,7 +345,7 @@ int skipReplay( } // namespace -void ConsumeAt::propagateTvPasC(TensorView* from, TensorView* to) { +void ComputeAtPosPropagator::propagateTvPasC(TensorView* from, TensorView* to) { // std::cout << "propagateTvPasC, from: " << from << ", to:" << to; int pos = getReplayPosPasC(to, from); // std::cout << ", at: " << pos; @@ -381,7 +360,7 @@ void ConsumeAt::propagateTvPasC(TensorView* from, TensorView* to) { // std::cout << ", result: " << to << std::endl; } -void ConsumeAt::propagateTvCasP(TensorView* from, TensorView* to) { +void ComputeAtPosPropagator::propagateTvCasP(TensorView* from, TensorView* to) { // std::cout << "propagateTvCasP, from: " << from << ", to:" << to; int pos = getReplayPosCasP(to, from); // std::cout << ", at: " << pos; @@ -396,7 +375,7 @@ void ConsumeAt::propagateTvCasP(TensorView* from, TensorView* to) { // std::cout << ", result: " << to << std::endl; } -void ConsumeAt::recordReplayedPos(TensorView* tv, size_t pos) { +void ComputeAtPosPropagator::recordReplayedPos(TensorView* tv, size_t pos) { if (consume_.count(tv)) { auto max_pos = getMaxComputeAtPos(tv); if (pos > max_pos) { @@ -413,7 +392,7 @@ void ConsumeAt::recordReplayedPos(TensorView* tv, size_t pos) { } } -size_t ConsumeAt::retrieveReplayedPos(TensorView* tv) { +size_t ComputeAtPosPropagator::retrieveReplayedPos(TensorView* tv) { auto it = replayed_pos_.find(tv); if (it != replayed_pos_.end()) { return it->second; @@ -421,7 +400,9 @@ size_t ConsumeAt::retrieveReplayedPos(TensorView* tv) { return tv->getComputeAtPosition(); } -size_t ConsumeAt::getReplayPosPasC(TensorView* producer, TensorView* consumer) { +size_t ComputeAtPosPropagator::getReplayPosPasC( + TensorView* producer, + TensorView* consumer) { size_t max_pos = getReplayablePosPasC(producer, consumer); size_t pos = retrieveReplayedPos(consumer); // std::cout << "[getReplayPosPasC] producer=" << producer << ", consumer=" << @@ 
-446,7 +427,9 @@ size_t ConsumeAt::getReplayPosPasC(TensorView* producer, TensorView* consumer) { return pos; } -size_t ConsumeAt::getReplayPosCasP(TensorView* consumer, TensorView* producer) { +size_t ComputeAtPosPropagator::getReplayPosCasP( + TensorView* consumer, + TensorView* producer) { size_t max_pos = getReplayablePosCasP(consumer, producer); size_t pos = retrieveReplayedPos(producer); // std::cout << "[getReplayPosCasP] consumer=" << consumer << ", producer=" << @@ -471,7 +454,7 @@ size_t ConsumeAt::getReplayPosCasP(TensorView* consumer, TensorView* producer) { return pos; } -void ConsumeAt::hoistInnermostBroadcast() { +void ComputeAtPosPropagator::hoistInnermostBroadcast() { for (auto tv : consume_) { if (!tv->isFusionInput()) { auto ca_pos = tv->getComputeAtPosition(); @@ -537,7 +520,7 @@ size_t getConsumerPosAlignedToProducerCA( return consumer_pos; } -void ConsumeAt::computeMaxProducerPos() { +void ComputeAtPosPropagator::computeMaxProducerPos() { std::unordered_set todo; for (auto p : consume_) { auto consumers = ir_utils::consumerTvsOf(p); @@ -555,25 +538,17 @@ void ConsumeAt::computeMaxProducerPos() { } } -bool ConsumeAt::allowPasC(TensorView* from, TensorView* to) { +bool ComputeAtSubgraphSelector::allowPasC(TensorView* from, TensorView* to) { return consume_.count(to) > 0; } -bool ConsumeAt::allowCasP(TensorView* from, TensorView* to) { +bool ComputeAtSubgraphSelector::allowCasP(TensorView* from, TensorView* to) { // If the producer is in the consume set, then the consumer must also be // replayed to obtain a compatible loop structure so that this producer // can be consumed in this loop. 
return consume_.count(from) > 0 || consume_.count(to) > 0; } -void ConsumeAt::run() { - recordReplayedPos(reference_, reference_pos_); - MaxRootDomainInfoPropagator::run(); - - hoistInnermostBroadcast(); - computeMaxProducerPos(); -} - } // namespace cuda } // namespace fuser } // namespace jit diff --git a/torch/csrc/jit/codegen/cuda/consume_at.h b/torch/csrc/jit/codegen/cuda/consume_at.h index 64b7f7a8bfafc..b3d739b8aa14a 100644 --- a/torch/csrc/jit/codegen/cuda/consume_at.h +++ b/torch/csrc/jit/codegen/cuda/consume_at.h @@ -11,7 +11,17 @@ namespace jit { namespace fuser { namespace cuda { -class ConsumeAt : public MaxRootDomainInfoPropagator { +class ComputeAtSubgraphSelector : public MaxInfoSpanningTree::Selector { + std::unordered_set consume_; + + public: + virtual bool allowPasC(TensorView* from, TensorView* to) override; + virtual bool allowCasP(TensorView* from, TensorView* to) override; + ComputeAtSubgraphSelector(std::unordered_set consume) + : consume_(std::move(consume)) {} +}; + +class ComputeAtPosPropagator : public MaxInfoSpanningTree::Propagator { // TODO: change arguments to `from`, `to`. size_t getReplayablePosPasC(TensorView* producer, TensorView* consumer); size_t getReplayablePosCasP(TensorView* consumer, TensorView* producer); @@ -29,11 +39,8 @@ class ConsumeAt : public MaxRootDomainInfoPropagator { // map to all its consumer TVs. 
void buildUnmappableDims(); - static std::shared_ptr getStartingRootIDInfo( - TensorView* tv, - size_t pos); - std::unordered_set consume_; + TensorView* reference_; size_t reference_pos_; ComputeAtMode mode_ = ComputeAtMode::Standard; std::unordered_map replayed_pos_; @@ -41,19 +48,17 @@ class ConsumeAt : public MaxRootDomainInfoPropagator { // Root domains in producer that's unmappable to any of its consumers std::unordered_set unmappable_dims_; - ConsumeAt( + ComputeAtPosPropagator( std::unordered_set consume, TensorView* reference, size_t reference_pos, ComputeAtMode mode); - ~ConsumeAt() = default; + ~ComputeAtPosPropagator() = default; void run(); protected: - virtual bool allowPasC(TensorView* from, TensorView* to) override; - virtual bool allowCasP(TensorView* from, TensorView* to) override; virtual void propagateTvPasC(TensorView* from, TensorView* to) override; virtual void propagateTvCasP(TensorView* from, TensorView* to) override; diff --git a/torch/csrc/jit/codegen/cuda/ir_interface_nodes.h b/torch/csrc/jit/codegen/cuda/ir_interface_nodes.h index 5400c69907bed..2bc36461986f5 100644 --- a/torch/csrc/jit/codegen/cuda/ir_interface_nodes.h +++ b/torch/csrc/jit/codegen/cuda/ir_interface_nodes.h @@ -154,7 +154,7 @@ class TORCH_CUDA_CU_API ComplexDouble : public Val { //! the compute at position to maximum possible through traversal. 
enum class ComputeAtMode { Standard, BestEffort, MostInlined }; -class ConsumeAt; +class ComputeAtPosPropagator; class TransformPropagator; class TransformIter; class TransformReplay; @@ -458,7 +458,7 @@ class TORCH_CUDA_CU_API TensorView : public Val { friend TORCH_CUDA_CU_API TransformPropagator; friend TORCH_CUDA_CU_API TransformReplay; friend TORCH_CUDA_CU_API OptOutMutator; - friend TORCH_CUDA_CU_API ConsumeAt; + friend TORCH_CUDA_CU_API ComputeAtPosPropagator; friend class ir_utils::TVDomainGuard; friend TORCH_CUDA_CU_API void groupReductions( const std::vector&); From 83c3d0ae947e1ffb9824aaf659b32e8df30e7476 Mon Sep 17 00:00:00 2001 From: Xiang Gao Date: Sun, 26 Jun 2022 22:34:21 -0700 Subject: [PATCH 036/100] Adding sibling path for MaxInfoSpanningTree --- torch/csrc/jit/codegen/cuda/ir_utils.cpp | 32 ++++ torch/csrc/jit/codegen/cuda/ir_utils.h | 20 +++ .../jit/codegen/cuda/maxinfo_propagator.cpp | 24 +++ .../jit/codegen/cuda/maxinfo_propagator.h | 3 + torch/csrc/jit/codegen/cuda/test/test_gpu.cpp | 148 +++++++++++++----- .../jit/codegen/cuda/transform_replay.cpp | 66 +++++--- .../csrc/jit/codegen/cuda/transform_replay.h | 6 + 7 files changed, 245 insertions(+), 54 deletions(-) diff --git a/torch/csrc/jit/codegen/cuda/ir_utils.cpp b/torch/csrc/jit/codegen/cuda/ir_utils.cpp index 98912c425c5a0..50871f09da3ea 100644 --- a/torch/csrc/jit/codegen/cuda/ir_utils.cpp +++ b/torch/csrc/jit/codegen/cuda/ir_utils.cpp @@ -529,6 +529,22 @@ TORCH_CUDA_CU_API std::vector consumerValsOf(Val* val) { return uniqueEntries(consumer_vals); } +// Return immediate siblings of val +TORCH_CUDA_CU_API std::vector siblingValsOf(Val* val) { + std::vector sibling_vals; + auto def = val->definition(); + if (def != nullptr) { + auto outs = def->outputs(); + for (auto sibling_val : outs) { + if (sibling_val == val) { + continue; + } + sibling_vals.emplace_back(sibling_val); + } + } + return uniqueEntries(sibling_vals); +} + // Return immediate producers of val TORCH_CUDA_CU_API 
std::vector producerValsOf( const std::vector& vals) { @@ -574,6 +590,22 @@ std::vector consumerTvsOf(TensorView* tv) { return uniqueEntries(consumer_tvs); } +// Return immediate siblings of tv +TORCH_CUDA_CU_API std::vector siblingTvsOf(TensorView* tv) { + std::vector sibling_tvs; + auto def = tv->definition(); + if (def != nullptr) { + auto outs = ir_utils::filterByType(def->outputs()); + for (auto sibling_tv : outs) { + if (sibling_tv == tv) { + continue; + } + sibling_tvs.emplace_back(sibling_tv); + } + } + return uniqueEntries(sibling_tvs); +} + std::vector producerTvsOf(const std::vector& tvs) { std::vector all_producer_tvs; for (auto tv : tvs) { diff --git a/torch/csrc/jit/codegen/cuda/ir_utils.h b/torch/csrc/jit/codegen/cuda/ir_utils.h index b4c96ae147872..dd96eda69d608 100644 --- a/torch/csrc/jit/codegen/cuda/ir_utils.h +++ b/torch/csrc/jit/codegen/cuda/ir_utils.h @@ -181,6 +181,16 @@ TORCH_CUDA_CU_API std::vector producerValsOf(Val* val); // code. TORCH_CUDA_CU_API std::vector consumerValsOf(Val* val); +// Return immediate siblings of val, this function can be used on any Val and +// will return siblings through Exprs. +// +// Warning: returned val's are not guaranteed to be between fusion inputs and +// outputs. This function simply uses val->definition() or val->uses() which is +// limited to not go through fusion inputs/outputs, but if on a path that isn't +// strictly between fusion inputs/outputs, it could effectively return dead +// code. +TORCH_CUDA_CU_API std::vector siblingValsOf(Val* val); + // Return immediate producers of vals, this function can be used on any vals and // will return producers through Exprs. // @@ -223,6 +233,16 @@ TORCH_CUDA_CU_API std::vector producerTvsOf(TensorView* tv); // code. TORCH_CUDA_CU_API std::vector consumerTvsOf(TensorView* tv); +// Return immediate siblings of tv, this function will return all immediate +// siblings of tv through Exprs. 
+// +// Warning: returned tv's are not guaranteed to be between fusion inputs and +// outputs. This function simply uses tv->definition() or tv->uses() which is +// limited to not go through fusion inputs/outputs, but if on a path that isn't +// strictly between fusion inputs/outputs, it could effectively return dead +// code. +TORCH_CUDA_CU_API std::vector siblingTvsOf(TensorView* tv); + // Return immediate producers of tvs, this function will return all immediate // producers of tvs through Exprs. // diff --git a/torch/csrc/jit/codegen/cuda/maxinfo_propagator.cpp b/torch/csrc/jit/codegen/cuda/maxinfo_propagator.cpp index 4f2f212f8d8e0..d0bc3995dc172 100644 --- a/torch/csrc/jit/codegen/cuda/maxinfo_propagator.cpp +++ b/torch/csrc/jit/codegen/cuda/maxinfo_propagator.cpp @@ -80,6 +80,13 @@ void MaxInfoSpanningTree::compute_spanning_tree() { return selector_->allowCasP(from, to); }; + auto allowSibling = [this](TensorView* from, TensorView* to) { + if (selector_ == nullptr) { + return true; + } + return selector_->allowSibling(from, to); + }; + while (!candidates.empty()) { const auto next_hop_info = candidates.back(); const auto& next_hop = next_hop_info.next_hop; @@ -91,6 +98,20 @@ void MaxInfoSpanningTree::compute_spanning_tree() { } replayed.emplace(next_hop.to); + for (auto sibling_tv : ir_utils::siblingTvsOf(next_hop.to)) { + if (replayed.count(sibling_tv) || + !allowSibling(next_hop.to, sibling_tv)) { + continue; + } + insertNextHop( + {.next_hop = + {.type = NextHopType::SIBLING, + .from = next_hop.to, + .to = sibling_tv}, + .info_from = next_hop_info.info_to, + .info_to = next_hop_info.info_to}); + } + for (auto consumer_tv : ir_utils::consumerTvsOf(next_hop.to)) { if (replayed.count(consumer_tv) || !allowCasP(next_hop.to, consumer_tv)) { continue; @@ -127,6 +148,9 @@ void MaxInfoSpanningTree::traverse(Propagator* propagator) { } for (const auto& next_hop : path_) { switch (next_hop.type) { + case NextHopType::SIBLING: + 
propagator->propagateTvSibling(next_hop.from, next_hop.to); + break; case NextHopType::C_AS_P: propagator->propagateTvCasP(next_hop.from, next_hop.to); break; diff --git a/torch/csrc/jit/codegen/cuda/maxinfo_propagator.h b/torch/csrc/jit/codegen/cuda/maxinfo_propagator.h index 980de8eff31e6..b11738391d665 100644 --- a/torch/csrc/jit/codegen/cuda/maxinfo_propagator.h +++ b/torch/csrc/jit/codegen/cuda/maxinfo_propagator.h @@ -40,12 +40,14 @@ class TORCH_CUDA_CU_API MaxInfoSpanningTree { struct Selector { virtual bool allowPasC(TensorView* from, TensorView* to) = 0; virtual bool allowCasP(TensorView* from, TensorView* to) = 0; + virtual bool allowSibling(TensorView* from, TensorView* to) = 0; }; // This is the interface to implement the actual propagation struct Propagator { virtual void propagateTvPasC(TensorView* from, TensorView* to) = 0; virtual void propagateTvCasP(TensorView* from, TensorView* to) = 0; + virtual void propagateTvSibling(TensorView* from, TensorView* to) = 0; }; // This is the interface that specifies the structure of information used to @@ -71,6 +73,7 @@ class TORCH_CUDA_CU_API MaxInfoSpanningTree { private: enum class NextHopType { + SIBLING, C_AS_P, P_AS_C, }; diff --git a/torch/csrc/jit/codegen/cuda/test/test_gpu.cpp b/torch/csrc/jit/codegen/cuda/test/test_gpu.cpp index 4483f3026fc6f..61fedf2a2bcaf 100644 --- a/torch/csrc/jit/codegen/cuda/test/test_gpu.cpp +++ b/torch/csrc/jit/codegen/cuda/test/test_gpu.cpp @@ -23615,6 +23615,46 @@ TEST_F(NVFuserTest, FusionReproNoncontigBroadcast_CUDA) { executor_cache.fusion(), cg_outputs, {t0, t1}, {t2}, __LINE__, __FILE__); } +namespace { + +// check that the resulting tensors in tvs2 are identical +void checkSiblingConsistency(TensorView* replay, TensorView* target) { + auto replay_root = replay->getRootDomain(); + auto replay_dom = replay->domain()->domain(); + auto target_root = target->getRootDomain(); + auto target_dom = target->domain()->domain(); + std::unordered_map target2replay_map; + 
TORCH_CHECK(replay_root.size() == target_root.size()); + target2replay_map.reserve(replay_root.size()); + std::transform( + target_root.begin(), + target_root.end(), + replay_root.begin(), + std::inserter(target2replay_map, target2replay_map.begin()), + [](auto a, auto b) { return std::make_pair(a, b); }); + BestEffortReplay replay_(replay_dom, target_dom, target2replay_map); + auto r = replay_.getReplay(); + for (int64_t i = 0; i < replay_dom.size(); i++) { + auto target_id = target_dom[i]; + auto replay_it = r.find(target_id); + TORCH_CHECK(replay_it != r.end()); + TORCH_CHECK( + replay_it->second == replay_dom[i], + "IterDomain mismatch when checking ", + replay, + " and ", + target, + " at ", + i, + ", got ", + replay_it->second, + " and ", + replay_dom[i]); + } +}; + +} // namespace + TEST_F(NVFuserTest, FusionTransformPropagateSibling_CUDA) { // https://github.com/csarofeen/pytorch/issues/1760 Fusion fusion; @@ -23641,41 +23681,6 @@ TEST_F(NVFuserTest, FusionTransformPropagateSibling_CUDA) { TransformPropagator propagator(tvs2.var_sum); MaxRootDomainInfoSpanningTree(tvs2.var_sum).traverse(&propagator); - // check that the resulting tensors in tvs2 are identical - auto checkSiblingConsistency = [](TensorView* replay, TensorView* target) { - auto replay_root = replay->getRootDomain(); - auto replay_dom = replay->domain()->domain(); - auto target_root = target->getRootDomain(); - auto target_dom = target->domain()->domain(); - std::unordered_map target2replay_map; - TORCH_CHECK(replay_root.size() == target_root.size()); - target2replay_map.reserve(replay_root.size()); - std::transform( - target_root.begin(), - target_root.end(), - replay_root.begin(), - std::inserter(target2replay_map, target2replay_map.begin()), - [](auto a, auto b) { return std::make_pair(a, b); }); - BestEffortReplay replay_(replay_dom, target_dom, target2replay_map); - auto r = replay_.getReplay(); - for (int64_t i = 0; i < replay_dom.size(); i++) { - auto target_id = target_dom[i]; - auto 
replay_it = r.find(target_id); - TORCH_CHECK(replay_it != r.end()); - TORCH_CHECK( - replay_it->second == replay_dom[i], - "IterDomain mismatch when checking ", - replay, - " and ", - target, - " at ", - i, - ", got ", - replay_it->second, - " and ", - replay_dom[i]); - } - }; std::vector siblings[] = { {tvs.avg, tvs.var_sum, tvs.n}, {tvs2.avg, tvs2.var_sum, tvs2.n}}; for (auto tensors : siblings) { @@ -23687,6 +23692,71 @@ TEST_F(NVFuserTest, FusionTransformPropagateSibling_CUDA) { } } +TEST_F(NVFuserTest, FusionTransformPropagateSelectorSibling_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeSymbolicTensor(2); + fusion.addInput(tv0); + + auto tvs = Welford(tv0, {1}); + fusion.addOutput(tvs.var_sum); + + tvs.avg->split(1, 1); + tvs.avg->split(1, 2); + tvs.avg->split(1, 3); + tvs.var_sum->split(1, 1); + tvs.var_sum->split(1, 2); + tvs.var_sum->split(1, 3); + tvs.n->split(1, 1); + tvs.n->split(1, 2); + tvs.n->split(1, 3); + + auto tvs2 = tvs.rFactor({1, 4}); + + struct DisableTv0 : public MaxInfoSpanningTree::Selector { + TensorView* tv0; + virtual bool allowPasC(TensorView* from, TensorView* to) override { + return from != tv0 && to != tv0; + }; + virtual bool allowCasP(TensorView* from, TensorView* to) override { + return from != tv0 && to != tv0; + }; + virtual bool allowSibling(TensorView* from, TensorView* to) override { + return true; + } + DisableTv0(TensorView* tv0) : tv0(tv0) {} + } selector1(tv0); + + struct DisableTv0AndSibling : public DisableTv0 { + virtual bool allowSibling(TensorView* from, TensorView* to) override { + return false; + } + using DisableTv0::DisableTv0; + } selector2(tv0); + + TransformPropagator propagator(tvs2.var_sum); + MaxRootDomainInfoSpanningTree good_path(tvs2.var_sum, &selector1); + MaxRootDomainInfoSpanningTree bad_path(tvs2.var_sum, &selector2); + + auto check = [&]() { + std::vector siblings[] = { + {tvs.avg, tvs.var_sum, tvs.n}, {tvs2.avg, tvs2.var_sum, tvs2.n}}; + for (auto tensors : siblings) { + 
for (auto t1 : tensors) { + for (auto t2 : tensors) { + checkSiblingConsistency(t1, t2); + } + } + } + }; + + bad_path.traverse(&propagator); + ASSERT_ANY_THROW(check()); + good_path.traverse(&propagator); + check(); +} + TEST_F(NVFuserTest, FusionTransformPropagatePosition_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -23812,6 +23882,9 @@ TEST_F(NVFuserTest, FusionTransormPropagatorSelector_CUDA) { virtual bool allowCasP(TensorView* from, TensorView* to) override { return to == tv3; } + virtual bool allowSibling(TensorView* from, TensorView* to) override { + return false; + } Selector(TensorView* tv0, TensorView* tv3) : tv0(tv0), tv3(tv3) {} } selector(tv0, tv3); @@ -23874,6 +23947,11 @@ TEST_F(NVFuserTest, FusionMaxRootDomainInfoSpanningTreePrintTwice_CUDA) { ss << "from: " << from << std::endl; ss << "to: " << to << std::endl; } + virtual void propagateTvSibling(TensorView* from, TensorView* to) override { + ss << "propagateTvSibling" << std::endl; + ss << "from: " << from << std::endl; + ss << "to: " << to << std::endl; + } } printer1, printer2; printer1.ss << std::endl; printer2.ss << std::endl; diff --git a/torch/csrc/jit/codegen/cuda/transform_replay.cpp b/torch/csrc/jit/codegen/cuda/transform_replay.cpp index 2dc9212b3ffb6..aa0f8c5da278c 100644 --- a/torch/csrc/jit/codegen/cuda/transform_replay.cpp +++ b/torch/csrc/jit/codegen/cuda/transform_replay.cpp @@ -137,13 +137,12 @@ class ReplaySelf : public ReplayTransformations { : ReplayTransformations(_target_domain, std::move(_id_map), false) {} }; -} // namespace - // Self replay. 
-TensorDomain* TransformReplay::fullSelfReplay( +TensorDomain* fullSelfReplayImpl( const TensorDomain* new_self_root, - const TensorDomain* self) { - FUSER_PERF_SCOPE("TransformReplay::fullSelfReplay"); + const TensorDomain* self, + bool check) { + FUSER_PERF_SCOPE("fullSelfReplayImpl"); TORCH_INTERNAL_ASSERT( new_self_root->getRootDomain().size() == self->getRootDomain().size(), @@ -154,20 +153,22 @@ TensorDomain* TransformReplay::fullSelfReplay( { size_t i = 0; for (auto id : self->getRootDomain()) { - TORCH_INTERNAL_ASSERT( - new_self_root->getRootDomain()[i]->getParallelType() == - id->getParallelType() && - new_self_root->getRootDomain()[i]->isReduction() == - id->isReduction() && - new_self_root->getRootDomain()[i]->isRFactorProduct() == - id->isRFactorProduct() && - new_self_root->getRootDomain()[i]->isBroadcast() == - id->isBroadcast(), - "Axes ", - id, - " and ", - new_self_root->getRootDomain()[i], - " do not match for self replay."); + if (check) { + TORCH_INTERNAL_ASSERT( + new_self_root->getRootDomain()[i]->getParallelType() == + id->getParallelType() && + new_self_root->getRootDomain()[i]->isReduction() == + id->isReduction() && + new_self_root->getRootDomain()[i]->isRFactorProduct() == + id->isRFactorProduct() && + new_self_root->getRootDomain()[i]->isBroadcast() == + id->isBroadcast(), + "Axes ", + id, + " and ", + new_self_root->getRootDomain()[i], + " do not match for self replay."); + } axis_map[id] = new_self_root->getRootDomain()[i]; i++; } @@ -214,6 +215,26 @@ TensorDomain* TransformReplay::fullSelfReplay( new_self_root->contiguity()); } +} // namespace + +// Self replay. +TensorDomain* TransformReplay::fullSelfReplay( + const TensorDomain* new_self_root, + const TensorDomain* self) { + return fullSelfReplayImpl(new_self_root, self, true); +} + +// Sibling replay. 
+TensorDomain* TransformReplay::siblingReplay( + const TensorView* replay, + const TensorView* target) { + auto new_domain = IrBuilder::create( + replay->container(), + IterDomain::clone(replay->getRootDomain()), + std::vector(replay->getRootDomain().size(), true)); + return fullSelfReplayImpl(new_domain, target->domain(), false); +} + // Producer could have rfactor axes which consumer may want replayed. We can // "replay" them as long as it doesn't modify the root rfactor axes. What we // really want to do is validate if we replayed these axes to the ones they @@ -660,6 +681,13 @@ void TransformPropagator::propagateTvCasP(TensorView* from, TensorView* to) { replayed_pos_[to] = replay.second; } +void TransformPropagator::propagateTvSibling(TensorView* from, TensorView* to) { + int pos = replayed_pos_.at(from); + auto replay = TransformReplay::siblingReplay(to, from); + to->setDomain(replay); + replayed_pos_[to] = pos; +} + TransformPropagator::TransformPropagator(TensorView* from, int64_t pos) { if (pos < 0) { pos += int64_t(from->nDims()) + 1; diff --git a/torch/csrc/jit/codegen/cuda/transform_replay.h b/torch/csrc/jit/codegen/cuda/transform_replay.h index d45454e149b63..25fb1f7341eb2 100644 --- a/torch/csrc/jit/codegen/cuda/transform_replay.h +++ b/torch/csrc/jit/codegen/cuda/transform_replay.h @@ -150,6 +150,11 @@ class TORCH_CUDA_CU_API TransformReplay { int producer_compute_at_axis, const RootDomainMap& root_map); + // Sibling replay. + static TensorDomain* siblingReplay( + const TensorView* replay, + const TensorView* target); + // Self replay. 
static TensorDomain* fullSelfReplay( const TensorDomain* new_self_root, @@ -163,6 +168,7 @@ class TORCH_CUDA_CU_API TransformPropagator public: virtual void propagateTvPasC(TensorView* from, TensorView* to) override; virtual void propagateTvCasP(TensorView* from, TensorView* to) override; + virtual void propagateTvSibling(TensorView* from, TensorView* to) override; TransformPropagator(TensorView* from, int64_t pos = -1); }; From 8210fcef8e7b24e356690ede6d61b9161c92a6d3 Mon Sep 17 00:00:00 2001 From: Xiang Gao Date: Sun, 26 Jun 2022 22:40:46 -0700 Subject: [PATCH 037/100] save --- torch/csrc/jit/codegen/cuda/test/test_gpu.cpp | 7 ------- 1 file changed, 7 deletions(-) diff --git a/torch/csrc/jit/codegen/cuda/test/test_gpu.cpp b/torch/csrc/jit/codegen/cuda/test/test_gpu.cpp index 61fedf2a2bcaf..22d37cf47f1ae 100644 --- a/torch/csrc/jit/codegen/cuda/test/test_gpu.cpp +++ b/torch/csrc/jit/codegen/cuda/test/test_gpu.cpp @@ -23705,13 +23705,6 @@ TEST_F(NVFuserTest, FusionTransformPropagateSelectorSibling_CUDA) { tvs.avg->split(1, 1); tvs.avg->split(1, 2); tvs.avg->split(1, 3); - tvs.var_sum->split(1, 1); - tvs.var_sum->split(1, 2); - tvs.var_sum->split(1, 3); - tvs.n->split(1, 1); - tvs.n->split(1, 2); - tvs.n->split(1, 3); - auto tvs2 = tvs.rFactor({1, 4}); struct DisableTv0 : public MaxInfoSpanningTree::Selector { From 04525db3bf55877509c372812702b8c60bcf8354 Mon Sep 17 00:00:00 2001 From: Xiang Gao Date: Sun, 26 Jun 2022 23:34:40 -0700 Subject: [PATCH 038/100] remove check in fullSelfReplay --- .../jit/codegen/cuda/transform_replay.cpp | 45 +++---------------- .../csrc/jit/codegen/cuda/transform_replay.h | 5 --- 2 files changed, 5 insertions(+), 45 deletions(-) diff --git a/torch/csrc/jit/codegen/cuda/transform_replay.cpp b/torch/csrc/jit/codegen/cuda/transform_replay.cpp index aa0f8c5da278c..406b447257cae 100644 --- a/torch/csrc/jit/codegen/cuda/transform_replay.cpp +++ b/torch/csrc/jit/codegen/cuda/transform_replay.cpp @@ -137,11 +137,12 @@ class ReplaySelf : public 
ReplayTransformations { : ReplayTransformations(_target_domain, std::move(_id_map), false) {} }; +} // namespace + // Self replay. -TensorDomain* fullSelfReplayImpl( +TensorDomain* TransformReplay::fullSelfReplay( const TensorDomain* new_self_root, - const TensorDomain* self, - bool check) { + const TensorDomain* self) { FUSER_PERF_SCOPE("fullSelfReplayImpl"); TORCH_INTERNAL_ASSERT( @@ -153,22 +154,6 @@ TensorDomain* fullSelfReplayImpl( { size_t i = 0; for (auto id : self->getRootDomain()) { - if (check) { - TORCH_INTERNAL_ASSERT( - new_self_root->getRootDomain()[i]->getParallelType() == - id->getParallelType() && - new_self_root->getRootDomain()[i]->isReduction() == - id->isReduction() && - new_self_root->getRootDomain()[i]->isRFactorProduct() == - id->isRFactorProduct() && - new_self_root->getRootDomain()[i]->isBroadcast() == - id->isBroadcast(), - "Axes ", - id, - " and ", - new_self_root->getRootDomain()[i], - " do not match for self replay."); - } axis_map[id] = new_self_root->getRootDomain()[i]; i++; } @@ -215,26 +200,6 @@ TensorDomain* fullSelfReplayImpl( new_self_root->contiguity()); } -} // namespace - -// Self replay. -TensorDomain* TransformReplay::fullSelfReplay( - const TensorDomain* new_self_root, - const TensorDomain* self) { - return fullSelfReplayImpl(new_self_root, self, true); -} - -// Sibling replay. -TensorDomain* TransformReplay::siblingReplay( - const TensorView* replay, - const TensorView* target) { - auto new_domain = IrBuilder::create( - replay->container(), - IterDomain::clone(replay->getRootDomain()), - std::vector(replay->getRootDomain().size(), true)); - return fullSelfReplayImpl(new_domain, target->domain(), false); -} - // Producer could have rfactor axes which consumer may want replayed. We can // "replay" them as long as it doesn't modify the root rfactor axes. 
What we // really want to do is validate if we replayed these axes to the ones they @@ -683,7 +648,7 @@ void TransformPropagator::propagateTvCasP(TensorView* from, TensorView* to) { void TransformPropagator::propagateTvSibling(TensorView* from, TensorView* to) { int pos = replayed_pos_.at(from); - auto replay = TransformReplay::siblingReplay(to, from); + auto replay = TransformReplay::fullSelfReplay(to->domain(), from->domain()); to->setDomain(replay); replayed_pos_[to] = pos; } diff --git a/torch/csrc/jit/codegen/cuda/transform_replay.h b/torch/csrc/jit/codegen/cuda/transform_replay.h index 25fb1f7341eb2..c24ffa93f2954 100644 --- a/torch/csrc/jit/codegen/cuda/transform_replay.h +++ b/torch/csrc/jit/codegen/cuda/transform_replay.h @@ -150,11 +150,6 @@ class TORCH_CUDA_CU_API TransformReplay { int producer_compute_at_axis, const RootDomainMap& root_map); - // Sibling replay. - static TensorDomain* siblingReplay( - const TensorView* replay, - const TensorView* target); - // Self replay. static TensorDomain* fullSelfReplay( const TensorDomain* new_self_root, From b18fa5bab14f6d9e064982e42dd0159385055301 Mon Sep 17 00:00:00 2001 From: Xiang Gao Date: Mon, 27 Jun 2022 00:02:12 -0700 Subject: [PATCH 039/100] save --- torch/csrc/jit/codegen/cuda/test/test_gpu.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch/csrc/jit/codegen/cuda/test/test_gpu.cpp b/torch/csrc/jit/codegen/cuda/test/test_gpu.cpp index 22d37cf47f1ae..58ad72c3a3ca3 100644 --- a/torch/csrc/jit/codegen/cuda/test/test_gpu.cpp +++ b/torch/csrc/jit/codegen/cuda/test/test_gpu.cpp @@ -23617,7 +23617,7 @@ TEST_F(NVFuserTest, FusionReproNoncontigBroadcast_CUDA) { namespace { -// check that the resulting tensors in tvs2 are identical +// check that the resulting sibling are identical void checkSiblingConsistency(TensorView* replay, TensorView* target) { auto replay_root = replay->getRootDomain(); auto replay_dom = replay->domain()->domain(); From 26c1d4e08d26c3edd0f4d46a4dc778ed27788543 Mon 
Sep 17 00:00:00 2001 From: Xiang Gao Date: Mon, 27 Jun 2022 00:07:54 -0700 Subject: [PATCH 040/100] save --- torch/csrc/jit/codegen/cuda/test/test_gpu.cpp | 7 +++++++ torch/csrc/jit/codegen/cuda/transform_replay.cpp | 2 +- 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/torch/csrc/jit/codegen/cuda/test/test_gpu.cpp b/torch/csrc/jit/codegen/cuda/test/test_gpu.cpp index 58ad72c3a3ca3..68a6be40bda6c 100644 --- a/torch/csrc/jit/codegen/cuda/test/test_gpu.cpp +++ b/torch/csrc/jit/codegen/cuda/test/test_gpu.cpp @@ -23705,6 +23705,13 @@ TEST_F(NVFuserTest, FusionTransformPropagateSelectorSibling_CUDA) { tvs.avg->split(1, 1); tvs.avg->split(1, 2); tvs.avg->split(1, 3); + tvs.var_sum->split(1, 1); + tvs.var_sum->split(1, 2); + tvs.var_sum->split(1, 3); + tvs.n->split(1, 1); + tvs.n->split(1, 2); + tvs.n->split(1, 3); + auto tvs2 = tvs.rFactor({1, 4}); struct DisableTv0 : public MaxInfoSpanningTree::Selector { diff --git a/torch/csrc/jit/codegen/cuda/transform_replay.cpp b/torch/csrc/jit/codegen/cuda/transform_replay.cpp index 406b447257cae..0dc2affd4b578 100644 --- a/torch/csrc/jit/codegen/cuda/transform_replay.cpp +++ b/torch/csrc/jit/codegen/cuda/transform_replay.cpp @@ -143,7 +143,7 @@ class ReplaySelf : public ReplayTransformations { TensorDomain* TransformReplay::fullSelfReplay( const TensorDomain* new_self_root, const TensorDomain* self) { - FUSER_PERF_SCOPE("fullSelfReplayImpl"); + FUSER_PERF_SCOPE("TransformReplay::fullSelfReplay"); TORCH_INTERNAL_ASSERT( new_self_root->getRootDomain().size() == self->getRootDomain().size(), From 32312cedbc19e56e32d3c4e5b25f6bb319984b77 Mon Sep 17 00:00:00 2001 From: Xiang Gao Date: Mon, 27 Jun 2022 01:00:25 -0700 Subject: [PATCH 041/100] save? 
--- torch/csrc/jit/codegen/cuda/compute_at.cpp | 31 ++-- torch/csrc/jit/codegen/cuda/consume_at.cpp | 48 ++++++- .../jit/codegen/cuda/ir_interface_nodes.h | 2 + .../jit/codegen/cuda/transform_replay.cpp | 132 ++++++++++++++++++ .../csrc/jit/codegen/cuda/transform_replay.h | 10 ++ 5 files changed, 204 insertions(+), 19 deletions(-) diff --git a/torch/csrc/jit/codegen/cuda/compute_at.cpp b/torch/csrc/jit/codegen/cuda/compute_at.cpp index 9bdd2ad704d9f..469b50ce84bf9 100644 --- a/torch/csrc/jit/codegen/cuda/compute_at.cpp +++ b/torch/csrc/jit/codegen/cuda/compute_at.cpp @@ -50,19 +50,19 @@ std::unordered_set getAllTVsBetween( " is a dependency of ", consumer->name(), ", however it is not."); - auto between_vals = DependencyCheck::getAllValsBetween({producer}, {consumer}); + auto between_vals = + DependencyCheck::getAllValsBetween({producer}, {consumer}); auto between_tvs = ir_utils::filterByType(between_vals); - std::unordered_set result(between_tvs.begin(), between_tvs.end()); + std::unordered_set result( + between_tvs.begin(), between_tvs.end()); result.erase(consumer); return result; } -TensorView * getCommonConsumer( - TensorView *producer, - TensorView *consumer -) { +TensorView* getCommonConsumer(TensorView* producer, TensorView* consumer) { FUSER_PERF_SCOPE("ComputeAt::setCommonConsumer"); - auto producer_use_chains_ = tvChains(DependencyCheck::getAllUseChains(producer)); + auto producer_use_chains_ = + tvChains(DependencyCheck::getAllUseChains(producer)); // Convert the first chain to a set. 
std::set common_consumers( @@ -99,7 +99,7 @@ TensorView * getCommonConsumer( } // If there is a common consumer, grab the first one at or after consumer - TensorView *common_consumer = nullptr; + TensorView* common_consumer = nullptr; if (!common_consumers.empty()) { for (auto tv : producer_use_chains_.front()) { if (common_consumers.find(tv) != common_consumers.end()) { @@ -114,7 +114,7 @@ TensorView * getCommonConsumer( return common_consumer; } -void pullInSiblings(std::unordered_set &s) { +void pullInSiblings(std::unordered_set& s) { for (auto tv : s) { auto tvd = tv->definition(); if (tvd != nullptr) { @@ -130,11 +130,10 @@ void pullInSiblings(std::unordered_set &s) { } } - // I am just trying to get the same set of tensors being transformed matching // the previous behavior of ComputeAt. The algorithm to compute this set is -// horrible, but I don't care because I will eventually completely remove ComputeAt, -// and this algorihtm is not worse than the pervious ComputeAt. :) +// horrible, but I don't care because I will eventually completely remove +// ComputeAt, and this algorihtm is not worse than the pervious ComputeAt. 
:) std::unordered_set getPropagationSubgraph( TensorView* producer, TensorView* consumer) { @@ -145,7 +144,7 @@ std::unordered_set getPropagationSubgraph( " is a dependency of ", consumer->name(), ", however it is not."); - TensorView *common_consumer = getCommonConsumer(producer, consumer); + TensorView* common_consumer = getCommonConsumer(producer, consumer); if (common_consumer != nullptr) { auto result = getAllTVsBetween(producer, common_consumer); pullInSiblings(result); @@ -182,6 +181,9 @@ void ComputeAt::runAt( " are not in the same fusion."); FusionGuard fg(producer->fusion()); + std::cout << "ComputeAt::runAt(producer=" << producer + << ", consumer=" << consumer + << ", consumer_position=" << consumer_position << ")" << std::endl; ComputeAtPosPropagator::consumeAllAt( getPropagationSubgraph(producer, consumer), consumer, @@ -206,6 +208,9 @@ void ComputeAt::runWith( // Make sure Fusion Guard is set appropriately FusionGuard fg(producer->fusion()); + std::cout << "ComputeAt::runWith(producer=" << producer + << ", consumer=" << consumer + << ", producer_position=" << producer_position << ")" << std::endl; ComputeAtPosPropagator::consumeAllAt( getPropagationSubgraph(producer, consumer), diff --git a/torch/csrc/jit/codegen/cuda/consume_at.cpp b/torch/csrc/jit/codegen/cuda/consume_at.cpp index 640e090f774e4..d6b48bcb385cb 100644 --- a/torch/csrc/jit/codegen/cuda/consume_at.cpp +++ b/torch/csrc/jit/codegen/cuda/consume_at.cpp @@ -65,16 +65,38 @@ void ComputeAtPosPropagator::consumeAllAt( // std::cout << "From: " << reference << " at pos " << reference_pos // << std::endl; // std::cout << "Before:" << std::endl; - // reference->fusion()->print(); + reference->fusion()->print(); ComputeAtSubgraphSelector selector(consume); - ComputeAtPosPropagator propagator( - std::move(consume), reference, reference_pos, mode); + WeakTransformPropagator propagator(reference, reference_pos); + ComputeAtPosPropagator ca_propagator( + consume, reference, reference_pos, mode); + + 
struct Printer : public MaxInfoSpanningTree::Propagator { + std::stringstream ss; + virtual void propagateTvPasC(TensorView* from, TensorView* to) override { + ss << "propagateTvPasC" << std::endl; + ss << "from: " << from << std::endl; + ss << "to: " << to << std::endl; + } + virtual void propagateTvCasP(TensorView* from, TensorView* to) override { + ss << "propagateTvCasP" << std::endl; + ss << "from: " << from << std::endl; + ss << "to: " << to << std::endl; + } + } printer; - propagator.recordReplayedPos(reference, reference_pos); MaxRootDomainInfoSpanningTree path(reference, reference_pos, &selector); + path.traverse(&printer); + // std::cout << "consume: " << ir_utils::toString(consume) << std::endl; + // std::cout << "Path:\n" << printer.ss.str() << std::endl; path.traverse(&propagator); - propagator.hoistInnermostBroadcast(); - propagator.computeMaxProducerPos(); + // std::cout << "After TransformPropagator:" << std::endl; + reference->fusion()->print(); + + ca_propagator.recordReplayedPos(reference, reference_pos); + path.traverse(&ca_propagator); + ca_propagator.hoistInnermostBroadcast(); // TODO: this should be inlined to recordReplayedPos + ca_propagator.computeMaxProducerPos(); // std::cout << "\n\nAfter:" << std::endl; // reference->fusion()->print(); // std::cout << "==========================" << std::endl; @@ -351,6 +373,7 @@ void ComputeAtPosPropagator::propagateTvPasC(TensorView* from, TensorView* to) { // std::cout << ", at: " << pos; // Short cut if no replay is necessary auto to_pos = skipReplay(to, from, (int)pos, true); + // TORCH_CHECK(to_pos >= 0); if (to_pos < 0) { auto replay = TransformReplay::replayPasC(to, from, pos); to->setDomain(replay.first); @@ -366,6 +389,7 @@ void ComputeAtPosPropagator::propagateTvCasP(TensorView* from, TensorView* to) { // std::cout << ", at: " << pos; // Short cut if no replay is necessary auto to_pos = skipReplay(from, to, (int)pos, false); + // TORCH_CHECK(to_pos >= 0); if (to_pos < 0) { auto replay = 
TransformReplay::replayCasP(to, from, pos); to->setDomain(replay.first); @@ -404,6 +428,12 @@ size_t ComputeAtPosPropagator::getReplayPosPasC( TensorView* producer, TensorView* consumer) { size_t max_pos = getReplayablePosPasC(producer, consumer); + for (auto consumer_tv : ir_utils::consumerTvsOf(producer)) { + max_pos = std::min(max_pos, getReplayablePosPasC(producer, consumer_tv)); + } + // for (auto producer_tv : ir_utils::producerTvsOf(producer)) { + // max_pos = std::min(max_pos, getReplayablePosCasP(producer, producer_tv)); + // } size_t pos = retrieveReplayedPos(consumer); // std::cout << "[getReplayPosPasC] producer=" << producer << ", consumer=" << // consumer << ", max_pos=" << max_pos << ", pos=" << pos << std::endl; @@ -431,6 +461,12 @@ size_t ComputeAtPosPropagator::getReplayPosCasP( TensorView* consumer, TensorView* producer) { size_t max_pos = getReplayablePosCasP(consumer, producer); + for (auto consumer_tv : ir_utils::consumerTvsOf(consumer)) { + max_pos = std::min(max_pos, getReplayablePosPasC(consumer, consumer_tv)); + } + // for (auto producer_tv : ir_utils::producerTvsOf(consumer)) { + // max_pos = std::min(max_pos, getReplayablePosCasP(consumer, producer_tv)); + // } size_t pos = retrieveReplayedPos(producer); // std::cout << "[getReplayPosCasP] consumer=" << consumer << ", producer=" << // producer << ", max_pos=" << max_pos << ", pos=" << pos << std::endl; diff --git a/torch/csrc/jit/codegen/cuda/ir_interface_nodes.h b/torch/csrc/jit/codegen/cuda/ir_interface_nodes.h index 2bc36461986f5..cb47b28f3c040 100644 --- a/torch/csrc/jit/codegen/cuda/ir_interface_nodes.h +++ b/torch/csrc/jit/codegen/cuda/ir_interface_nodes.h @@ -156,6 +156,7 @@ enum class ComputeAtMode { Standard, BestEffort, MostInlined }; class ComputeAtPosPropagator; class TransformPropagator; +class WeakTransformPropagator; class TransformIter; class TransformReplay; class OptOutMutator; @@ -456,6 +457,7 @@ class TORCH_CUDA_CU_API TensorView : public Val { void 
applyMmaSwizzle(MmaOptions options); friend TORCH_CUDA_CU_API TransformPropagator; + friend TORCH_CUDA_CU_API WeakTransformPropagator; friend TORCH_CUDA_CU_API TransformReplay; friend TORCH_CUDA_CU_API OptOutMutator; friend TORCH_CUDA_CU_API ComputeAtPosPropagator; diff --git a/torch/csrc/jit/codegen/cuda/transform_replay.cpp b/torch/csrc/jit/codegen/cuda/transform_replay.cpp index 2dc9212b3ffb6..210e11082eec9 100644 --- a/torch/csrc/jit/codegen/cuda/transform_replay.cpp +++ b/torch/csrc/jit/codegen/cuda/transform_replay.cpp @@ -670,6 +670,138 @@ TransformPropagator::TransformPropagator(TensorView* from, int64_t pos) { replayed_pos_[from] = pos; } +namespace { + +// Checks if producer and consumer are transformed consistently so that to +// satisfy the provided compute at position. This means no replay is actually +// necessary for the compute at requested. If consumer_pos then +// consumer_or_producer_pos is relative to the consumer and skipReplay returns +// the associated position in producer. +// +// If producer and consumer are not transformed consistently with provided +// postition, returns -1. 
+int skipReplay( + const TensorView* producer, + const TensorView* consumer, + int consumer_or_producer_pos, + bool consumer_pos = true) { + const auto c2p_root_map = + PairwiseRootDomainMap(producer, consumer) + .mapConsumerToProducer(consumer->domain(), producer->domain()); + + // IterDomains in consumer root also in producer root + std::unordered_set mapped_consumer_roots; + for (auto entry : c2p_root_map) { + mapped_consumer_roots.emplace(entry.first); + } + + const auto consumer_domain = consumer->domain()->domain(); + + auto mapped_consumer_domain_ids_vec = DependencyCheck::getAllValsBetween( + mapped_consumer_roots, {consumer_domain.begin(), consumer_domain.end()}); + + std::unordered_set mapped_consumer_domain_ids( + mapped_consumer_domain_ids_vec.begin(), + mapped_consumer_domain_ids_vec.end()); + + const auto producer_domain = producer->domain()->domain(); + + auto it_consumer = consumer_domain.begin(); + auto it_producer = producer_domain.begin(); + + auto best_effort_PasC = BestEffortReplay::replayPasC( + producer, consumer, -1, PairwiseRootDomainMap(producer, consumer)); + + auto c2p_map = best_effort_PasC.getReplay(); + + int mismatched_consumer_pos = 0; + int mismatched_producer_pos = 0; + while (it_consumer != consumer_domain.end()) { + auto consumer_id = *it_consumer; + if (!mapped_consumer_domain_ids.count(consumer_id)) { + ++it_consumer; + mismatched_consumer_pos++; + continue; + } + + auto c2p_it = c2p_map.find(consumer_id); + if (c2p_it == c2p_map.end()) { + break; + } + + if (it_producer == producer_domain.end()) { + break; + } + + auto producer_id = *it_producer; + + if (c2p_it->second == producer_id) { + ++mismatched_consumer_pos; + ++mismatched_producer_pos; + ++it_consumer; + ++it_producer; + if (consumer_pos) { + if (consumer_or_producer_pos == mismatched_consumer_pos) { + return mismatched_producer_pos; + } + } else { + if (consumer_or_producer_pos == mismatched_producer_pos) { + return mismatched_consumer_pos; + } + } + } else { + 
break; + } + } + return -1; +} + +} // namespace + +void WeakTransformPropagator::propagateTvPasC(TensorView* from, TensorView* to) { + std::cout << "propagateTvPasC" << std::endl; + std::cout << "from: " << from << std::endl; + std::cout << "to: " << to << std::endl; + int from_pos = replayed_pos_.at(from); + int to_pos = skipReplay(to, from, from_pos, true); + if (to_pos < 0) { + auto replay = TransformReplay::replayPasC(to, from, from_pos); + to->setDomain(replay.first); + to_pos = replay.second; + std::cout << "result: " << to << std::endl; + } else { + std::cout << "skipped" << std::endl; + } + replayed_pos_[to] = to_pos; +} + +void WeakTransformPropagator::propagateTvCasP(TensorView* from, TensorView* to) { + std::cout << "propagateTvCasP" << std::endl; + std::cout << "from: " << from << std::endl; + std::cout << "to: " << to << std::endl; + int from_pos = replayed_pos_.at(from); + int to_pos = skipReplay(from, to, from_pos, false); + if (to_pos < 0) { + auto replay = TransformReplay::replayCasP(to, from, from_pos); + to->setDomain(replay.first); + to_pos = replay.second; + std::cout << "result: " << to << std::endl; + } else { + std::cout << "skipped" << std::endl; + } + replayed_pos_[to] = to_pos; +} + +WeakTransformPropagator::WeakTransformPropagator(TensorView* from, int64_t pos) { + if (pos < 0) { + pos += int64_t(from->nDims()) + 1; + } + TORCH_CHECK( + pos >= 0 && pos <= from->nDims(), + "TransformPropagator called on an pos outside valid range."); + replayed_pos_[from] = pos; +} + } // namespace cuda } // namespace fuser } // namespace jit diff --git a/torch/csrc/jit/codegen/cuda/transform_replay.h b/torch/csrc/jit/codegen/cuda/transform_replay.h index d45454e149b63..01416b7fad5fe 100644 --- a/torch/csrc/jit/codegen/cuda/transform_replay.h +++ b/torch/csrc/jit/codegen/cuda/transform_replay.h @@ -166,6 +166,16 @@ class TORCH_CUDA_CU_API TransformPropagator TransformPropagator(TensorView* from, int64_t pos = -1); }; +class TORCH_CUDA_CU_API 
WeakTransformPropagator + : public MaxRootDomainInfoSpanningTree::Propagator { + std::unordered_map replayed_pos_; + + public: + virtual void propagateTvPasC(TensorView* from, TensorView* to) override; + virtual void propagateTvCasP(TensorView* from, TensorView* to) override; + WeakTransformPropagator(TensorView* from, int64_t pos = -1); +}; + } // namespace cuda } // namespace fuser } // namespace jit From 96ae406daab83819a2acc659daf8e960d10e098d Mon Sep 17 00:00:00 2001 From: Xiang Gao Date: Mon, 27 Jun 2022 01:35:46 -0700 Subject: [PATCH 042/100] save --- torch/csrc/jit/codegen/cuda/consume_at.cpp | 38 +++++++++++++++---- torch/csrc/jit/codegen/cuda/consume_at.h | 2 + .../jit/codegen/cuda/transform_replay.cpp | 8 ++++ .../csrc/jit/codegen/cuda/transform_replay.h | 1 + 4 files changed, 42 insertions(+), 7 deletions(-) diff --git a/torch/csrc/jit/codegen/cuda/consume_at.cpp b/torch/csrc/jit/codegen/cuda/consume_at.cpp index d6b48bcb385cb..a8d39123b0cd1 100644 --- a/torch/csrc/jit/codegen/cuda/consume_at.cpp +++ b/torch/csrc/jit/codegen/cuda/consume_at.cpp @@ -68,8 +68,7 @@ void ComputeAtPosPropagator::consumeAllAt( reference->fusion()->print(); ComputeAtSubgraphSelector selector(consume); WeakTransformPropagator propagator(reference, reference_pos); - ComputeAtPosPropagator ca_propagator( - consume, reference, reference_pos, mode); + ComputeAtPosPropagator ca_propagator(consume, reference, reference_pos, mode); struct Printer : public MaxInfoSpanningTree::Propagator { std::stringstream ss; @@ -83,6 +82,11 @@ void ComputeAtPosPropagator::consumeAllAt( ss << "from: " << from << std::endl; ss << "to: " << to << std::endl; } + virtual void propagateTvSibling(TensorView* from, TensorView* to) override { + ss << "propagateTvSibling" << std::endl; + ss << "from: " << from << std::endl; + ss << "to: " << to << std::endl; + } } printer; MaxRootDomainInfoSpanningTree path(reference, reference_pos, &selector); @@ -95,7 +99,8 @@ void ComputeAtPosPropagator::consumeAllAt( 
ca_propagator.recordReplayedPos(reference, reference_pos); path.traverse(&ca_propagator); - ca_propagator.hoistInnermostBroadcast(); // TODO: this should be inlined to recordReplayedPos + ca_propagator.hoistInnermostBroadcast(); // TODO: this should be inlined to + // recordReplayedPos ca_propagator.computeMaxProducerPos(); // std::cout << "\n\nAfter:" << std::endl; // reference->fusion()->print(); @@ -399,6 +404,17 @@ void ComputeAtPosPropagator::propagateTvCasP(TensorView* from, TensorView* to) { // std::cout << ", result: " << to << std::endl; } +void ComputeAtPosPropagator::propagateTvSibling( + TensorView* from, + TensorView* to) { + // TODO: copy-paste computeAts + auto from_pos = retrieveReplayedPos(from); + // TODO: check skip + auto replay = TransformReplay::fullSelfReplay(to->domain(), from->domain()); + to->setDomain(replay); + recordReplayedPos(to, from_pos); +} + void ComputeAtPosPropagator::recordReplayedPos(TensorView* tv, size_t pos) { if (consume_.count(tv)) { auto max_pos = getMaxComputeAtPos(tv); @@ -429,10 +445,12 @@ size_t ComputeAtPosPropagator::getReplayPosPasC( TensorView* consumer) { size_t max_pos = getReplayablePosPasC(producer, consumer); for (auto consumer_tv : ir_utils::consumerTvsOf(producer)) { - max_pos = std::min(max_pos, getReplayablePosPasC(producer, consumer_tv)); + max_pos = + std::min(max_pos, getReplayablePosPasC(producer, consumer_tv)); } // for (auto producer_tv : ir_utils::producerTvsOf(producer)) { - // max_pos = std::min(max_pos, getReplayablePosCasP(producer, producer_tv)); + // max_pos = std::min(max_pos, getReplayablePosCasP(producer, + // producer_tv)); // } size_t pos = retrieveReplayedPos(consumer); // std::cout << "[getReplayPosPasC] producer=" << producer << ", consumer=" << @@ -462,10 +480,12 @@ size_t ComputeAtPosPropagator::getReplayPosCasP( TensorView* producer) { size_t max_pos = getReplayablePosCasP(consumer, producer); for (auto consumer_tv : ir_utils::consumerTvsOf(consumer)) { - max_pos = 
std::min(max_pos, getReplayablePosPasC(consumer, consumer_tv)); + max_pos = + std::min(max_pos, getReplayablePosPasC(consumer, consumer_tv)); } // for (auto producer_tv : ir_utils::producerTvsOf(consumer)) { - // max_pos = std::min(max_pos, getReplayablePosCasP(consumer, producer_tv)); + // max_pos = std::min(max_pos, getReplayablePosCasP(consumer, + // producer_tv)); // } size_t pos = retrieveReplayedPos(producer); // std::cout << "[getReplayPosCasP] consumer=" << consumer << ", producer=" << @@ -585,6 +605,10 @@ bool ComputeAtSubgraphSelector::allowCasP(TensorView* from, TensorView* to) { return consume_.count(from) > 0 || consume_.count(to) > 0; } +bool ComputeAtSubgraphSelector::allowSibling(TensorView* from, TensorView* to) { + return true; +} + } // namespace cuda } // namespace fuser } // namespace jit diff --git a/torch/csrc/jit/codegen/cuda/consume_at.h b/torch/csrc/jit/codegen/cuda/consume_at.h index b3d739b8aa14a..868008626be81 100644 --- a/torch/csrc/jit/codegen/cuda/consume_at.h +++ b/torch/csrc/jit/codegen/cuda/consume_at.h @@ -17,6 +17,7 @@ class ComputeAtSubgraphSelector : public MaxInfoSpanningTree::Selector { public: virtual bool allowPasC(TensorView* from, TensorView* to) override; virtual bool allowCasP(TensorView* from, TensorView* to) override; + virtual bool allowSibling(TensorView* from, TensorView* to) override; ComputeAtSubgraphSelector(std::unordered_set consume) : consume_(std::move(consume)) {} }; @@ -61,6 +62,7 @@ class ComputeAtPosPropagator : public MaxInfoSpanningTree::Propagator { protected: virtual void propagateTvPasC(TensorView* from, TensorView* to) override; virtual void propagateTvCasP(TensorView* from, TensorView* to) override; + virtual void propagateTvSibling(TensorView* from, TensorView* to) override; public: static void consumeAllAt( diff --git a/torch/csrc/jit/codegen/cuda/transform_replay.cpp b/torch/csrc/jit/codegen/cuda/transform_replay.cpp index 9246005d10060..d33e8e0d28e7b 100644 --- 
a/torch/csrc/jit/codegen/cuda/transform_replay.cpp +++ b/torch/csrc/jit/codegen/cuda/transform_replay.cpp @@ -785,6 +785,14 @@ void WeakTransformPropagator::propagateTvCasP(TensorView* from, TensorView* to) replayed_pos_[to] = to_pos; } +void WeakTransformPropagator::propagateTvSibling(TensorView* from, TensorView* to) { + int pos = replayed_pos_.at(from); + // TODO: check skip + auto replay = TransformReplay::fullSelfReplay(to->domain(), from->domain()); + to->setDomain(replay); + replayed_pos_[to] = pos; +} + WeakTransformPropagator::WeakTransformPropagator(TensorView* from, int64_t pos) { if (pos < 0) { pos += int64_t(from->nDims()) + 1; diff --git a/torch/csrc/jit/codegen/cuda/transform_replay.h b/torch/csrc/jit/codegen/cuda/transform_replay.h index 589f20839f7e1..7f63438e6ba3d 100644 --- a/torch/csrc/jit/codegen/cuda/transform_replay.h +++ b/torch/csrc/jit/codegen/cuda/transform_replay.h @@ -174,6 +174,7 @@ class TORCH_CUDA_CU_API WeakTransformPropagator public: virtual void propagateTvPasC(TensorView* from, TensorView* to) override; virtual void propagateTvCasP(TensorView* from, TensorView* to) override; + virtual void propagateTvSibling(TensorView* from, TensorView* to) override; WeakTransformPropagator(TensorView* from, int64_t pos = -1); }; From a42daf3ef2e0075b95f383de5bbe60d5f1539908 Mon Sep 17 00:00:00 2001 From: Xiang Gao Date: Mon, 27 Jun 2022 15:37:36 -0700 Subject: [PATCH 043/100] save --- torch/csrc/jit/codegen/cuda/consume_at.cpp | 42 ++++++------------- .../jit/codegen/cuda/transform_replay.cpp | 20 ++++----- 2 files changed, 23 insertions(+), 39 deletions(-) diff --git a/torch/csrc/jit/codegen/cuda/consume_at.cpp b/torch/csrc/jit/codegen/cuda/consume_at.cpp index a8d39123b0cd1..354f0f9942c01 100644 --- a/torch/csrc/jit/codegen/cuda/consume_at.cpp +++ b/torch/csrc/jit/codegen/cuda/consume_at.cpp @@ -61,10 +61,10 @@ void ComputeAtPosPropagator::consumeAllAt( TensorView* reference, size_t reference_pos, ComputeAtMode mode) { - // std::cout << 
"==========================" << std::endl; - // std::cout << "From: " << reference << " at pos " << reference_pos - // << std::endl; - // std::cout << "Before:" << std::endl; + std::cout << "==========================" << std::endl; + std::cout << "From: " << reference << " at pos " << reference_pos + << std::endl; + std::cout << "Before:" << std::endl; reference->fusion()->print(); ComputeAtSubgraphSelector selector(consume); WeakTransformPropagator propagator(reference, reference_pos); @@ -102,9 +102,9 @@ void ComputeAtPosPropagator::consumeAllAt( ca_propagator.hoistInnermostBroadcast(); // TODO: this should be inlined to // recordReplayedPos ca_propagator.computeMaxProducerPos(); - // std::cout << "\n\nAfter:" << std::endl; - // reference->fusion()->print(); - // std::cout << "==========================" << std::endl; + std::cout << "\n\nAfter:" << std::endl; + reference->fusion()->print(); + std::cout << "==========================" << std::endl; } size_t ComputeAtPosPropagator::getMaxComputeAtPos(TensorView* tv) { @@ -373,9 +373,9 @@ int skipReplay( } // namespace void ComputeAtPosPropagator::propagateTvPasC(TensorView* from, TensorView* to) { - // std::cout << "propagateTvPasC, from: " << from << ", to:" << to; + std::cout << "propagateTvPasC\nfrom: " << from << "\nto: " << to << std::endl; int pos = getReplayPosPasC(to, from); - // std::cout << ", at: " << pos; + std::cout << "at: " << pos << std::endl; // Short cut if no replay is necessary auto to_pos = skipReplay(to, from, (int)pos, true); // TORCH_CHECK(to_pos >= 0); @@ -385,13 +385,13 @@ void ComputeAtPosPropagator::propagateTvPasC(TensorView* from, TensorView* to) { to_pos = replay.second; } recordReplayedPos(to, to_pos); - // std::cout << ", result: " << to << std::endl; + std::cout << "result: " << to << std::endl; } void ComputeAtPosPropagator::propagateTvCasP(TensorView* from, TensorView* to) { - // std::cout << "propagateTvCasP, from: " << from << ", to:" << to; + std::cout << 
"propagateTvCasP\nfrom: " << from << "\nto: " << to << std::endl; int pos = getReplayPosCasP(to, from); - // std::cout << ", at: " << pos; + std::cout << "at: " << pos << std::endl; // Short cut if no replay is necessary auto to_pos = skipReplay(from, to, (int)pos, false); // TORCH_CHECK(to_pos >= 0); @@ -401,7 +401,7 @@ void ComputeAtPosPropagator::propagateTvCasP(TensorView* from, TensorView* to) { to_pos = replay.second; } recordReplayedPos(to, to_pos); - // std::cout << ", result: " << to << std::endl; + std::cout << "result: " << to << std::endl; } void ComputeAtPosPropagator::propagateTvSibling( @@ -444,14 +444,6 @@ size_t ComputeAtPosPropagator::getReplayPosPasC( TensorView* producer, TensorView* consumer) { size_t max_pos = getReplayablePosPasC(producer, consumer); - for (auto consumer_tv : ir_utils::consumerTvsOf(producer)) { - max_pos = - std::min(max_pos, getReplayablePosPasC(producer, consumer_tv)); - } - // for (auto producer_tv : ir_utils::producerTvsOf(producer)) { - // max_pos = std::min(max_pos, getReplayablePosCasP(producer, - // producer_tv)); - // } size_t pos = retrieveReplayedPos(consumer); // std::cout << "[getReplayPosPasC] producer=" << producer << ", consumer=" << // consumer << ", max_pos=" << max_pos << ", pos=" << pos << std::endl; @@ -479,14 +471,6 @@ size_t ComputeAtPosPropagator::getReplayPosCasP( TensorView* consumer, TensorView* producer) { size_t max_pos = getReplayablePosCasP(consumer, producer); - for (auto consumer_tv : ir_utils::consumerTvsOf(consumer)) { - max_pos = - std::min(max_pos, getReplayablePosPasC(consumer, consumer_tv)); - } - // for (auto producer_tv : ir_utils::producerTvsOf(consumer)) { - // max_pos = std::min(max_pos, getReplayablePosCasP(consumer, - // producer_tv)); - // } size_t pos = retrieveReplayedPos(producer); // std::cout << "[getReplayPosCasP] consumer=" << consumer << ", producer=" << // producer << ", max_pos=" << max_pos << ", pos=" << pos << std::endl; diff --git 
a/torch/csrc/jit/codegen/cuda/transform_replay.cpp b/torch/csrc/jit/codegen/cuda/transform_replay.cpp index d33e8e0d28e7b..ef4e0465020f1 100644 --- a/torch/csrc/jit/codegen/cuda/transform_replay.cpp +++ b/torch/csrc/jit/codegen/cuda/transform_replay.cpp @@ -752,35 +752,35 @@ int skipReplay( } // namespace void WeakTransformPropagator::propagateTvPasC(TensorView* from, TensorView* to) { - std::cout << "propagateTvPasC" << std::endl; - std::cout << "from: " << from << std::endl; - std::cout << "to: " << to << std::endl; + // std::cout << "propagateTvPasC" << std::endl; + // std::cout << "from: " << from << std::endl; + // std::cout << "to: " << to << std::endl; int from_pos = replayed_pos_.at(from); int to_pos = skipReplay(to, from, from_pos, true); if (to_pos < 0) { auto replay = TransformReplay::replayPasC(to, from, from_pos); to->setDomain(replay.first); to_pos = replay.second; - std::cout << "result: " << to << std::endl; + // std::cout << "result: " << to << std::endl; } else { - std::cout << "skipped" << std::endl; + // std::cout << "skipped" << std::endl; } replayed_pos_[to] = to_pos; } void WeakTransformPropagator::propagateTvCasP(TensorView* from, TensorView* to) { - std::cout << "propagateTvCasP" << std::endl; - std::cout << "from: " << from << std::endl; - std::cout << "to: " << to << std::endl; + // std::cout << "propagateTvCasP" << std::endl; + // std::cout << "from: " << from << std::endl; + // std::cout << "to: " << to << std::endl; int from_pos = replayed_pos_.at(from); int to_pos = skipReplay(from, to, from_pos, false); if (to_pos < 0) { auto replay = TransformReplay::replayCasP(to, from, from_pos); to->setDomain(replay.first); to_pos = replay.second; - std::cout << "result: " << to << std::endl; + // std::cout << "result: " << to << std::endl; } else { - std::cout << "skipped" << std::endl; + // std::cout << "skipped" << std::endl; } replayed_pos_[to] = to_pos; } From 3e277e2aadcc049d01a29b3dec40f832caaaebf0 Mon Sep 17 00:00:00 2001 From: Xiang Gao 
Date: Mon, 27 Jun 2022 16:13:40 -0700 Subject: [PATCH 044/100] no hoistInnermostBroadcast --- torch/csrc/jit/codegen/cuda/consume_at.cpp | 44 ++++++++++++---------- torch/csrc/jit/codegen/cuda/consume_at.h | 3 +- 2 files changed, 25 insertions(+), 22 deletions(-) diff --git a/torch/csrc/jit/codegen/cuda/consume_at.cpp b/torch/csrc/jit/codegen/cuda/consume_at.cpp index 354f0f9942c01..21edc1395b0a2 100644 --- a/torch/csrc/jit/codegen/cuda/consume_at.cpp +++ b/torch/csrc/jit/codegen/cuda/consume_at.cpp @@ -99,18 +99,16 @@ void ComputeAtPosPropagator::consumeAllAt( ca_propagator.recordReplayedPos(reference, reference_pos); path.traverse(&ca_propagator); - ca_propagator.hoistInnermostBroadcast(); // TODO: this should be inlined to - // recordReplayedPos ca_propagator.computeMaxProducerPos(); std::cout << "\n\nAfter:" << std::endl; reference->fusion()->print(); std::cout << "==========================" << std::endl; } -size_t ComputeAtPosPropagator::getMaxComputeAtPos(TensorView* tv) { +size_t ComputeAtPosPropagator::adjustComputeAtPos(TensorView* tv, size_t pos) { auto dom = tv->domain()->domain(); auto first_reduction = std::find_if( - dom.begin(), dom.end(), [](IterDomain* id) { return id->isReduction(); }); + dom.begin(), dom.begin() + pos, [](IterDomain* id) { return id->isReduction(); }); auto first_vectorized_axis = std::find_if(dom.begin(), first_reduction, [this](IterDomain* id) { @@ -120,7 +118,25 @@ size_t ComputeAtPosPropagator::getMaxComputeAtPos(TensorView* tv) { id->getParallelType() == ParallelType::Unroll); }); - return std::distance(dom.begin(), first_vectorized_axis); + // hoist inner most broadcast + auto ca_pos = std::distance(dom.begin(), first_vectorized_axis); + std::cout << "tensor: " << tv << ", ca_pos: " << ca_pos << std::endl; + while (ca_pos > 0 && tv->axis(ca_pos - 1)->isBroadcast()) { + ca_pos--; + std::cout << "ca_pos: " << ca_pos << std::endl; + } + std::cout << "final ca_pos: " << ca_pos << std::endl; + + // TODO: check the following: + 
// Cannot inline: + // Reduction dimensions in producer + // Block broadcast dimensions in producer + // Vectorized dimensions in producer or consumer + // Unrolled dimensions in producer or consumer + // Dimensions derived from root dimensions that exist in both but are + // unmappable + + return ca_pos; } // Return the max position in consumer that producer can be inlined to @@ -417,10 +433,10 @@ void ComputeAtPosPropagator::propagateTvSibling( void ComputeAtPosPropagator::recordReplayedPos(TensorView* tv, size_t pos) { if (consume_.count(tv)) { - auto max_pos = getMaxComputeAtPos(tv); - if (pos > max_pos) { + auto new_pos = adjustComputeAtPos(tv, pos); + if (pos != new_pos) { replayed_pos_[tv] = pos; - pos = max_pos; + pos = new_pos; } if (!tv->isFusionInput()) { tv->setComputeAt(pos); @@ -494,18 +510,6 @@ size_t ComputeAtPosPropagator::getReplayPosCasP( return pos; } -void ComputeAtPosPropagator::hoistInnermostBroadcast() { - for (auto tv : consume_) { - if (!tv->isFusionInput()) { - auto ca_pos = tv->getComputeAtPosition(); - while (ca_pos > 0 && tv->axis(ca_pos - 1)->isBroadcast()) { - ca_pos--; - } - tv->setComputeAt(ca_pos, true); - } - } -} - // TODO: most of this is copy-pasted code. I need to investigate this to // see which makes sense and which needs change. 
// Try to find the aligned position on consumer's domain corresponding to the diff --git a/torch/csrc/jit/codegen/cuda/consume_at.h b/torch/csrc/jit/codegen/cuda/consume_at.h index 868008626be81..a97718c78ecf0 100644 --- a/torch/csrc/jit/codegen/cuda/consume_at.h +++ b/torch/csrc/jit/codegen/cuda/consume_at.h @@ -31,9 +31,8 @@ class ComputeAtPosPropagator : public MaxInfoSpanningTree::Propagator { void recordReplayedPos(TensorView* tv, size_t pos); size_t retrieveReplayedPos(TensorView* tv); - size_t getMaxComputeAtPos(TensorView* tv); + size_t adjustComputeAtPos(TensorView* tv, size_t pos); - void hoistInnermostBroadcast(); void computeMaxProducerPos(); // Iterate through all TVs and collect the dimensions of each TV that don't From 4c3342cbc44940b0c828e49a29161273cdb3e863 Mon Sep 17 00:00:00 2001 From: Xiang Gao Date: Mon, 27 Jun 2022 17:25:34 -0700 Subject: [PATCH 045/100] save --- torch/csrc/jit/codegen/cuda/consume_at.cpp | 30 ++++++++++++---------- 1 file changed, 17 insertions(+), 13 deletions(-) diff --git a/torch/csrc/jit/codegen/cuda/consume_at.cpp b/torch/csrc/jit/codegen/cuda/consume_at.cpp index 21edc1395b0a2..8e9b4e61def04 100644 --- a/torch/csrc/jit/codegen/cuda/consume_at.cpp +++ b/torch/csrc/jit/codegen/cuda/consume_at.cpp @@ -118,23 +118,27 @@ size_t ComputeAtPosPropagator::adjustComputeAtPos(TensorView* tv, size_t pos) { id->getParallelType() == ParallelType::Unroll); }); - // hoist inner most broadcast auto ca_pos = std::distance(dom.begin(), first_vectorized_axis); - std::cout << "tensor: " << tv << ", ca_pos: " << ca_pos << std::endl; + + std::cout << "adjustComputeAtPos: " << tv << std::endl; + for (auto producer_tv : ir_utils::producerTvsOf(tv)) { + std::cout << "producer_tv: " << producer_tv << std::endl; + auto max_pos = getReplayablePosPasC(producer_tv, tv); + std::cout << "max_pos: " << max_pos << std::endl; + ca_pos = std::min(ca_pos, max_pos); + } + for (auto consumer_tv : ir_utils::consumerTvsOf(tv)) { + std::cout << "consumer_tv: " 
<< consumer_tv << std::endl; + auto max_pos = getReplayablePosCasP(consumer_tv, tv); + std::cout << "max_pos: " << max_pos << std::endl; + ca_pos = std::min(ca_pos, max_pos); + } + + // hoist inner most broadcast while (ca_pos > 0 && tv->axis(ca_pos - 1)->isBroadcast()) { ca_pos--; - std::cout << "ca_pos: " << ca_pos << std::endl; } - std::cout << "final ca_pos: " << ca_pos << std::endl; - - // TODO: check the following: - // Cannot inline: - // Reduction dimensions in producer - // Block broadcast dimensions in producer - // Vectorized dimensions in producer or consumer - // Unrolled dimensions in producer or consumer - // Dimensions derived from root dimensions that exist in both but are - // unmappable + return ca_pos; } From 702f2b0d6de658c1303d4ce236abda74902f67f1 Mon Sep 17 00:00:00 2001 From: Xiang Gao Date: Mon, 27 Jun 2022 17:30:30 -0700 Subject: [PATCH 046/100] save --- torch/csrc/jit/codegen/cuda/test/test_gpu.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torch/csrc/jit/codegen/cuda/test/test_gpu.cpp b/torch/csrc/jit/codegen/cuda/test/test_gpu.cpp index 8e4e366fdcb13..d31d250ba5080 100644 --- a/torch/csrc/jit/codegen/cuda/test/test_gpu.cpp +++ b/torch/csrc/jit/codegen/cuda/test/test_gpu.cpp @@ -23949,8 +23949,8 @@ TEST_F(NVFuserTest, FusionMaxRootDomainInfoSpanningTreePrintTwice_CUDA) { } virtual void propagateTvSibling(TensorView* from, TensorView* to) override { ss << "propagateTvSibling" << std::endl; - ss << "from: " << from << std::endl; - ss << "to: " << to << std::endl; + ss << "from: " << from->name() << std::endl; + ss << "to: " << to->name() << std::endl; } } printer1, printer2; printer1.ss << std::endl; From ba0e8af9e97954bd5acf4e14f6743b7fab559840 Mon Sep 17 00:00:00 2001 From: Xiang Gao Date: Mon, 27 Jun 2022 17:51:52 -0700 Subject: [PATCH 047/100] resolve review --- torch/csrc/jit/codegen/cuda/ir_utils.cpp | 35 +++++-------------- .../jit/codegen/cuda/maxinfo_propagator.cpp | 14 +++++++- 
.../jit/codegen/cuda/maxinfo_propagator.h | 13 +++++-- 3 files changed, 33 insertions(+), 29 deletions(-) diff --git a/torch/csrc/jit/codegen/cuda/ir_utils.cpp b/torch/csrc/jit/codegen/cuda/ir_utils.cpp index 50871f09da3ea..4fbff177e7e90 100644 --- a/torch/csrc/jit/codegen/cuda/ir_utils.cpp +++ b/torch/csrc/jit/codegen/cuda/ir_utils.cpp @@ -572,38 +572,21 @@ TORCH_CUDA_CU_API std::vector consumerValsOf( } std::vector producerTvsOf(TensorView* tv) { - if (tv->definition() == nullptr) { - return {}; - } - auto producer_vals = - ir_utils::filterByType(tv->definition()->inputs()); - return uniqueEntries( - {producer_vals.begin(), producer_vals.end()}); + auto producer_vals = producerValsOf(tv); + auto producer_tvs = ir_utils::filterByType(producer_vals); + return {producer_tvs.begin(), producer_tvs.end()}; } std::vector consumerTvsOf(TensorView* tv) { - std::vector consumer_tvs; - for (auto use_expr : tv->uses()) { - auto outputs = ir_utils::filterByType(use_expr->outputs()); - consumer_tvs.insert(consumer_tvs.end(), outputs.begin(), outputs.end()); - } - return uniqueEntries(consumer_tvs); + auto consumer_vals = consumerValsOf(tv); + auto consumer_tvs = ir_utils::filterByType(consumer_vals); + return {consumer_tvs.begin(), consumer_tvs.end()}; } -// Return immediate siblings of tv TORCH_CUDA_CU_API std::vector siblingTvsOf(TensorView* tv) { - std::vector sibling_tvs; - auto def = tv->definition(); - if (def != nullptr) { - auto outs = ir_utils::filterByType(def->outputs()); - for (auto sibling_tv : outs) { - if (sibling_tv == tv) { - continue; - } - sibling_tvs.emplace_back(sibling_tv); - } - } - return uniqueEntries(sibling_tvs); + auto sibling_vals = siblingValsOf(tv); + auto sibling_tvs = ir_utils::filterByType(sibling_vals); + return {sibling_tvs.begin(), sibling_tvs.end()}; } std::vector producerTvsOf(const std::vector& tvs) { diff --git a/torch/csrc/jit/codegen/cuda/maxinfo_propagator.cpp b/torch/csrc/jit/codegen/cuda/maxinfo_propagator.cpp index 
d0bc3995dc172..06c2fcaf01547 100644 --- a/torch/csrc/jit/codegen/cuda/maxinfo_propagator.cpp +++ b/torch/csrc/jit/codegen/cuda/maxinfo_propagator.cpp @@ -109,7 +109,8 @@ void MaxInfoSpanningTree::compute_spanning_tree() { .from = next_hop.to, .to = sibling_tv}, .info_from = next_hop_info.info_to, - .info_to = next_hop_info.info_to}); + .info_to = computeInfoSibling( + next_hop.to, sibling_tv, next_hop_info.info_to)}); } for (auto consumer_tv : ir_utils::consumerTvsOf(next_hop.to)) { @@ -404,6 +405,17 @@ MaxRootDomainInfoSpanningTree::getReferenceRootIDInfo( return std::make_shared(std::move(result)); } +// Given the preserved reference root ID info of a tensor, compute +// the corresponding info in its sibling. Since info has nothing to do with +// replay state, so sibling info is always identical by definition. +std::shared_ptr MaxRootDomainInfoSpanningTree:: + computeInfoSibling( + TensorView* from, + TensorView* to, + std::shared_ptr from_info) const { + return from_info; +} + } // namespace cuda } // namespace fuser } // namespace jit diff --git a/torch/csrc/jit/codegen/cuda/maxinfo_propagator.h b/torch/csrc/jit/codegen/cuda/maxinfo_propagator.h index b11738391d665..db32aaef6d23c 100644 --- a/torch/csrc/jit/codegen/cuda/maxinfo_propagator.h +++ b/torch/csrc/jit/codegen/cuda/maxinfo_propagator.h @@ -29,8 +29,9 @@ namespace cuda { * MaxInfoSpanningTree::Information and implement `operator<` which is used to * tell which path contains more information, and `operator bool` which is used * to tell if there is any information stored. You also need to implement - * computeInfoPasC and computeInfoCasP, which are the functions that compute - * information of the `to` tensor from the information of the `from` tensor. + * computeInfoPasC, computeInfoCasP, and computeInfoSibling, which are the + * functions that compute information of the `to` tensor from the information of + * the `from` tensor. 
*/ // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) class TORCH_CUDA_CU_API MaxInfoSpanningTree { @@ -112,6 +113,10 @@ class TORCH_CUDA_CU_API MaxInfoSpanningTree { TensorView* from, TensorView* to, std::shared_ptr from_info) const = 0; + virtual std::shared_ptr computeInfoSibling( + TensorView* from, + TensorView* to, + std::shared_ptr from_info) const = 0; public: MaxInfoSpanningTree( @@ -193,6 +198,10 @@ class TORCH_CUDA_CU_API MaxRootDomainInfoSpanningTree TensorView* from, TensorView* to, std::shared_ptr from_info) const override; + virtual std::shared_ptr computeInfoSibling( + TensorView* from, + TensorView* to, + std::shared_ptr from_info) const override; private: static std::shared_ptr getReferenceRootIDInfo(TensorView* tv); From bfe66c7478892e7a72241164ba317e848a019f1e Mon Sep 17 00:00:00 2001 From: Xiang Gao Date: Mon, 27 Jun 2022 18:20:39 -0700 Subject: [PATCH 048/100] save --- torch/csrc/jit/codegen/cuda/transform_replay.cpp | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/torch/csrc/jit/codegen/cuda/transform_replay.cpp b/torch/csrc/jit/codegen/cuda/transform_replay.cpp index 0dc2affd4b578..9015331a1417b 100644 --- a/torch/csrc/jit/codegen/cuda/transform_replay.cpp +++ b/torch/csrc/jit/codegen/cuda/transform_replay.cpp @@ -154,6 +154,18 @@ TensorDomain* TransformReplay::fullSelfReplay( { size_t i = 0; for (auto id : self->getRootDomain()) { + TORCH_INTERNAL_ASSERT( + new_self_root->getRootDomain()[i]->isReduction() == + id->isReduction() && + new_self_root->getRootDomain()[i]->isRFactorProduct() == + id->isRFactorProduct() && + new_self_root->getRootDomain()[i]->isBroadcast() == + id->isBroadcast(), + "Axes ", + id, + " and ", + new_self_root->getRootDomain()[i], + " do not match for self replay."); axis_map[id] = new_self_root->getRootDomain()[i]; i++; } From de2f3be3c66ec2e14248f9197b7520c9f61776ea Mon Sep 17 00:00:00 2001 From: Xiang Gao Date: Mon, 27 Jun 2022 18:45:34 -0700 Subject: [PATCH 049/100] save --- 
torch/csrc/jit/codegen/cuda/consume_at.cpp | 13 +++++++------ torch/csrc/jit/codegen/cuda/maxinfo_propagator.cpp | 14 +++++++++++++- torch/csrc/jit/codegen/cuda/maxinfo_propagator.h | 13 +++++++++++-- 3 files changed, 31 insertions(+), 9 deletions(-) diff --git a/torch/csrc/jit/codegen/cuda/consume_at.cpp b/torch/csrc/jit/codegen/cuda/consume_at.cpp index 8e9b4e61def04..dab70b7287fa1 100644 --- a/torch/csrc/jit/codegen/cuda/consume_at.cpp +++ b/torch/csrc/jit/codegen/cuda/consume_at.cpp @@ -121,12 +121,13 @@ size_t ComputeAtPosPropagator::adjustComputeAtPos(TensorView* tv, size_t pos) { auto ca_pos = std::distance(dom.begin(), first_vectorized_axis); std::cout << "adjustComputeAtPos: " << tv << std::endl; - for (auto producer_tv : ir_utils::producerTvsOf(tv)) { - std::cout << "producer_tv: " << producer_tv << std::endl; - auto max_pos = getReplayablePosPasC(producer_tv, tv); - std::cout << "max_pos: " << max_pos << std::endl; - ca_pos = std::min(ca_pos, max_pos); - } + // TODO: why I have to disable this?!! 
+ // for (auto producer_tv : ir_utils::producerTvsOf(tv)) { + // std::cout << "producer_tv: " << producer_tv << std::endl; + // auto max_pos = getReplayablePosPasC(producer_tv, tv); + // std::cout << "max_pos: " << max_pos << std::endl; + // ca_pos = std::min(ca_pos, max_pos); + // } for (auto consumer_tv : ir_utils::consumerTvsOf(tv)) { std::cout << "consumer_tv: " << consumer_tv << std::endl; auto max_pos = getReplayablePosCasP(consumer_tv, tv); diff --git a/torch/csrc/jit/codegen/cuda/maxinfo_propagator.cpp b/torch/csrc/jit/codegen/cuda/maxinfo_propagator.cpp index d0bc3995dc172..06c2fcaf01547 100644 --- a/torch/csrc/jit/codegen/cuda/maxinfo_propagator.cpp +++ b/torch/csrc/jit/codegen/cuda/maxinfo_propagator.cpp @@ -109,7 +109,8 @@ void MaxInfoSpanningTree::compute_spanning_tree() { .from = next_hop.to, .to = sibling_tv}, .info_from = next_hop_info.info_to, - .info_to = next_hop_info.info_to}); + .info_to = computeInfoSibling( + next_hop.to, sibling_tv, next_hop_info.info_to)}); } for (auto consumer_tv : ir_utils::consumerTvsOf(next_hop.to)) { @@ -404,6 +405,17 @@ MaxRootDomainInfoSpanningTree::getReferenceRootIDInfo( return std::make_shared(std::move(result)); } +// Given the preserved reference root ID info of a tensor, compute +// the corresponding info in its sibling. Since info has nothing to do with +// replay state, so sibling info is always identical by definition. 
+std::shared_ptr MaxRootDomainInfoSpanningTree:: + computeInfoSibling( + TensorView* from, + TensorView* to, + std::shared_ptr from_info) const { + return from_info; +} + } // namespace cuda } // namespace fuser } // namespace jit diff --git a/torch/csrc/jit/codegen/cuda/maxinfo_propagator.h b/torch/csrc/jit/codegen/cuda/maxinfo_propagator.h index 41d655ae6daf1..f2e1591e018f1 100644 --- a/torch/csrc/jit/codegen/cuda/maxinfo_propagator.h +++ b/torch/csrc/jit/codegen/cuda/maxinfo_propagator.h @@ -29,8 +29,9 @@ namespace cuda { * MaxInfoSpanningTree::Information and implement `operator<` which is used to * tell which path contains more information, and `operator bool` which is used * to tell if there is any information stored. You also need to implement - * computeInfoPasC and computeInfoCasP, which are the functions that compute - * information of the `to` tensor from the information of the `from` tensor. + * computeInfoPasC, computeInfoCasP, and computeInfoSibling, which are the + * functions that compute information of the `to` tensor from the information of + * the `from` tensor. 
*/ // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) class TORCH_CUDA_CU_API MaxInfoSpanningTree { @@ -112,6 +113,10 @@ class TORCH_CUDA_CU_API MaxInfoSpanningTree { TensorView* from, TensorView* to, std::shared_ptr from_info) const = 0; + virtual std::shared_ptr computeInfoSibling( + TensorView* from, + TensorView* to, + std::shared_ptr from_info) const = 0; // methods providing a mechanism to propagate to only a part of the DAG virtual bool allowPasC(TensorView* from, TensorView* to) { @@ -201,6 +206,10 @@ class TORCH_CUDA_CU_API MaxRootDomainInfoSpanningTree TensorView* from, TensorView* to, std::shared_ptr from_info) const override; + virtual std::shared_ptr computeInfoSibling( + TensorView* from, + TensorView* to, + std::shared_ptr from_info) const override; private: static std::shared_ptr getReferenceRootIDInfo(TensorView* tv); From 195bc340a0200711a2ccc4dbedb5da4d5ecf80d2 Mon Sep 17 00:00:00 2001 From: Xiang Gao Date: Mon, 27 Jun 2022 21:38:02 -0700 Subject: [PATCH 050/100] move skipReplay to TransformReplay --- torch/csrc/jit/codegen/cuda/compute_at.cpp | 102 +------------- torch/csrc/jit/codegen/cuda/test/test_gpu.cpp | 44 +----- .../jit/codegen/cuda/transform_replay.cpp | 127 ++++++++++++++++++ .../csrc/jit/codegen/cuda/transform_replay.h | 19 +++ 4 files changed, 155 insertions(+), 137 deletions(-) diff --git a/torch/csrc/jit/codegen/cuda/compute_at.cpp b/torch/csrc/jit/codegen/cuda/compute_at.cpp index 5603d5e54d577..b56a779c80364 100644 --- a/torch/csrc/jit/codegen/cuda/compute_at.cpp +++ b/torch/csrc/jit/codegen/cuda/compute_at.cpp @@ -342,96 +342,6 @@ void ComputeAt::runWith( ca.runPass(); } -namespace { - -// Checks if producer and consumer are transformed consistently so that to -// satisfy the provided compute at position. This means no replay is actually -// necessary for the compute at requested. If consumer_pos then -// consumer_or_producer_pos is relative to the consumer and skipReplay returns -// the associated position in producer. 
-// -// If producer and consumer are not transformed consistently with provided -// postition, returns -1. -int skipReplay( - const TensorView* producer, - const TensorView* consumer, - int consumer_or_producer_pos, - bool consumer_pos = true) { - FUSER_PERF_SCOPE("transform_replay.cpp::skipReplay"); - - const auto c2p_root_map = - PairwiseRootDomainMap(producer, consumer) - .mapConsumerToProducer(consumer->domain(), producer->domain()); - - // IterDomains in consumer root also in producer root - std::unordered_set mapped_consumer_roots; - for (auto entry : c2p_root_map) { - mapped_consumer_roots.emplace(entry.first); - } - - const auto consumer_domain = consumer->domain()->domain(); - - auto mapped_consumer_domain_ids_vec = DependencyCheck::getAllValsBetween( - mapped_consumer_roots, {consumer_domain.begin(), consumer_domain.end()}); - - std::unordered_set mapped_consumer_domain_ids( - mapped_consumer_domain_ids_vec.begin(), - mapped_consumer_domain_ids_vec.end()); - - const auto producer_domain = producer->domain()->domain(); - - auto it_consumer = consumer_domain.begin(); - auto it_producer = producer_domain.begin(); - - auto best_effort_PasC = BestEffortReplay::replayPasC( - producer, consumer, -1, PairwiseRootDomainMap(producer, consumer)); - - auto c2p_map = best_effort_PasC.getReplay(); - - int mismatched_consumer_pos = 0; - int mismatched_producer_pos = 0; - while (it_consumer != consumer_domain.end()) { - auto consumer_id = *it_consumer; - if (!mapped_consumer_domain_ids.count(consumer_id)) { - ++it_consumer; - mismatched_consumer_pos++; - continue; - } - - auto c2p_it = c2p_map.find(consumer_id); - if (c2p_it == c2p_map.end()) { - break; - } - - if (it_producer == producer_domain.end()) { - break; - } - - auto producer_id = *it_producer; - - if (c2p_it->second == producer_id) { - ++mismatched_consumer_pos; - ++mismatched_producer_pos; - ++it_consumer; - ++it_producer; - if (consumer_pos) { - if (consumer_or_producer_pos == mismatched_consumer_pos) { - 
return mismatched_producer_pos; - } - } else { - if (consumer_or_producer_pos == mismatched_producer_pos) { - return mismatched_consumer_pos; - } - } - } else { - break; - } - } - return -1; -} - -} // namespace - // Actually applies transformation unsigned int ComputeAt::backwardComputeAt_impl( TensorView* producer, @@ -460,9 +370,11 @@ unsigned int ComputeAt::backwardComputeAt_impl( max_consumer_compute_at_pos); } - // Short cut if no replay is necessary - auto maybe_producer_pos = - skipReplay(producer, consumer, (int)consumer_compute_at_pos, true); + // Checks if producer and consumer are transformed consistently so that to + // satisfy the provided compute at position. This means no replay is actually + // necessary for the compute at requested. + auto maybe_producer_pos = TransformReplay::getMatchedLeafPosWithoutReplayPasC( + producer, consumer, consumer_compute_at_pos); if (maybe_producer_pos >= 0) { if (!producer->isFusionInput()) { producer->setComputeAt((unsigned int)maybe_producer_pos); @@ -536,8 +448,8 @@ unsigned int ComputeAt::forwardComputeAt_impl( } // Short cut if no replay is necessary - auto maybe_consumer_pos = - skipReplay(producer, consumer, (int)producer_compute_at_pos, false); + auto maybe_consumer_pos = TransformReplay::getMatchedLeafPosWithoutReplayCasP( + consumer, producer, producer_compute_at_pos); if (maybe_consumer_pos > -1) { if (!producer->isFusionInput()) { producer->setComputeAt(producer_compute_at_pos); diff --git a/torch/csrc/jit/codegen/cuda/test/test_gpu.cpp b/torch/csrc/jit/codegen/cuda/test/test_gpu.cpp index d31d250ba5080..932c8f06b4353 100644 --- a/torch/csrc/jit/codegen/cuda/test/test_gpu.cpp +++ b/torch/csrc/jit/codegen/cuda/test/test_gpu.cpp @@ -23615,46 +23615,6 @@ TEST_F(NVFuserTest, FusionReproNoncontigBroadcast_CUDA) { executor_cache.fusion(), cg_outputs, {t0, t1}, {t2}, __LINE__, __FILE__); } -namespace { - -// check that the resulting sibling are identical -void checkSiblingConsistency(TensorView* replay, 
TensorView* target) { - auto replay_root = replay->getRootDomain(); - auto replay_dom = replay->domain()->domain(); - auto target_root = target->getRootDomain(); - auto target_dom = target->domain()->domain(); - std::unordered_map target2replay_map; - TORCH_CHECK(replay_root.size() == target_root.size()); - target2replay_map.reserve(replay_root.size()); - std::transform( - target_root.begin(), - target_root.end(), - replay_root.begin(), - std::inserter(target2replay_map, target2replay_map.begin()), - [](auto a, auto b) { return std::make_pair(a, b); }); - BestEffortReplay replay_(replay_dom, target_dom, target2replay_map); - auto r = replay_.getReplay(); - for (int64_t i = 0; i < replay_dom.size(); i++) { - auto target_id = target_dom[i]; - auto replay_it = r.find(target_id); - TORCH_CHECK(replay_it != r.end()); - TORCH_CHECK( - replay_it->second == replay_dom[i], - "IterDomain mismatch when checking ", - replay, - " and ", - target, - " at ", - i, - ", got ", - replay_it->second, - " and ", - replay_dom[i]); - } -}; - -} // namespace - TEST_F(NVFuserTest, FusionTransformPropagateSibling_CUDA) { // https://github.com/csarofeen/pytorch/issues/1760 Fusion fusion; @@ -23686,7 +23646,7 @@ TEST_F(NVFuserTest, FusionTransformPropagateSibling_CUDA) { for (auto tensors : siblings) { for (auto t1 : tensors) { for (auto t2 : tensors) { - checkSiblingConsistency(t1, t2); + TORCH_CHECK(TransformReplay::fullyMatching(t1, t2)); } } } @@ -23745,7 +23705,7 @@ TEST_F(NVFuserTest, FusionTransformPropagateSelectorSibling_CUDA) { for (auto tensors : siblings) { for (auto t1 : tensors) { for (auto t2 : tensors) { - checkSiblingConsistency(t1, t2); + TORCH_CHECK(TransformReplay::fullyMatching(t1, t2)); } } } diff --git a/torch/csrc/jit/codegen/cuda/transform_replay.cpp b/torch/csrc/jit/codegen/cuda/transform_replay.cpp index 9015331a1417b..24a198dc0c82f 100644 --- a/torch/csrc/jit/codegen/cuda/transform_replay.cpp +++ b/torch/csrc/jit/codegen/cuda/transform_replay.cpp @@ -644,6 +644,133 
@@ std::pair TransformReplay::replayCasP( return replayCasP(consumer, producer, compute_at_axis, root_map); } +namespace { + +int getMatchedLeafPosWithoutReplay( + const TensorView* producer, + const TensorView* consumer, + int consumer_or_producer_pos, + bool consumer_pos = true) { + FUSER_PERF_SCOPE("transform_replay.cpp::getMatchedLeafPosWithoutReplay"); + + const auto c2p_root_map = + PairwiseRootDomainMap(producer, consumer) + .mapConsumerToProducer(consumer->domain(), producer->domain()); + + // IterDomains in consumer root also in producer root + std::unordered_set mapped_consumer_roots; + for (auto entry : c2p_root_map) { + mapped_consumer_roots.emplace(entry.first); + } + + const auto consumer_domain = consumer->domain()->domain(); + + auto mapped_consumer_domain_ids_vec = DependencyCheck::getAllValsBetween( + mapped_consumer_roots, {consumer_domain.begin(), consumer_domain.end()}); + + std::unordered_set mapped_consumer_domain_ids( + mapped_consumer_domain_ids_vec.begin(), + mapped_consumer_domain_ids_vec.end()); + + const auto producer_domain = producer->domain()->domain(); + + auto it_consumer = consumer_domain.begin(); + auto it_producer = producer_domain.begin(); + + auto best_effort_PasC = BestEffortReplay::replayPasC( + producer, consumer, -1, PairwiseRootDomainMap(producer, consumer)); + + auto c2p_map = best_effort_PasC.getReplay(); + + int mismatched_consumer_pos = 0; + int mismatched_producer_pos = 0; + while (it_consumer != consumer_domain.end()) { + auto consumer_id = *it_consumer; + if (!mapped_consumer_domain_ids.count(consumer_id)) { + ++it_consumer; + mismatched_consumer_pos++; + continue; + } + + auto c2p_it = c2p_map.find(consumer_id); + if (c2p_it == c2p_map.end()) { + break; + } + + if (it_producer == producer_domain.end()) { + break; + } + + auto producer_id = *it_producer; + + if (c2p_it->second == producer_id) { + ++mismatched_consumer_pos; + ++mismatched_producer_pos; + ++it_consumer; + ++it_producer; + if (consumer_pos) { + if 
(consumer_or_producer_pos == mismatched_consumer_pos) { + return mismatched_producer_pos; + } + } else { + if (consumer_or_producer_pos == mismatched_producer_pos) { + return mismatched_consumer_pos; + } + } + } else { + break; + } + } + return -1; +} + +} // namespace + +int TransformReplay::getMatchedLeafPosWithoutReplayPasC( + const TensorView* producer, + const TensorView* consumer, + int consumer_pos) { + return getMatchedLeafPosWithoutReplay(producer, consumer, consumer_pos, true); +} + +int TransformReplay::getMatchedLeafPosWithoutReplayCasP( + const TensorView* consumer, + const TensorView* producer, + int producer_pos) { + return getMatchedLeafPosWithoutReplay( + producer, consumer, producer_pos, false); +} + +bool TransformReplay::fullyMatching( + const TensorView* replay, + const TensorView* target) { + auto replay_root = replay->getRootDomain(); + auto replay_dom = replay->domain()->domain(); + auto target_root = target->getRootDomain(); + auto target_dom = target->domain()->domain(); + std::unordered_map target2replay_map; + if (replay_root.size() != target_root.size()) { + return false; + } + target2replay_map.reserve(replay_root.size()); + std::transform( + target_root.begin(), + target_root.end(), + replay_root.begin(), + std::inserter(target2replay_map, target2replay_map.begin()), + [](auto a, auto b) { return std::make_pair(a, b); }); + BestEffortReplay replay_(replay_dom, target_dom, target2replay_map); + auto r = replay_.getReplay(); + for (int64_t i = 0; i < replay_dom.size(); i++) { + auto target_id = target_dom[i]; + auto replay_it = r.find(target_id); + if (replay_it == r.end() || replay_it->second != replay_dom[i]) { + return false; + } + } + return true; +} + void TransformPropagator::propagateTvPasC(TensorView* from, TensorView* to) { int pos = replayed_pos_.at(from); auto replay = TransformReplay::replayPasC(to, from, pos); diff --git a/torch/csrc/jit/codegen/cuda/transform_replay.h b/torch/csrc/jit/codegen/cuda/transform_replay.h index 
c24ffa93f2954..f1f3232a5c2f1 100644 --- a/torch/csrc/jit/codegen/cuda/transform_replay.h +++ b/torch/csrc/jit/codegen/cuda/transform_replay.h @@ -154,6 +154,25 @@ class TORCH_CUDA_CU_API TransformReplay { static TensorDomain* fullSelfReplay( const TensorDomain* new_self_root, const TensorDomain* self); + + // Returns the leaf position in producer that matches with `consumer_pos` in + // consumer. Returns -1 if matching is impossible. This function can be used + // to test if replay is needed for getting matching outer dims. + static int getMatchedLeafPosWithoutReplayPasC( + const TensorView* producer, + const TensorView* consumer, + int consumer_pos); + + // Returns the leaf position in consumer that matches with `producer_pos` in + // producer. Returns -1 if matching is impossible. This function can be used + // to test if replay is needed for getting matching outer dims. + static int getMatchedLeafPosWithoutReplayCasP( + const TensorView* consumer, + const TensorView* producer, + int producer_pos); + + // tests if two tensors has fully matching transformations + static bool fullyMatching(const TensorView* replay, const TensorView* target); }; class TORCH_CUDA_CU_API TransformPropagator From 91b2f9bca0eb9f9ea563711a6e092e3c6ac3c425 Mon Sep 17 00:00:00 2001 From: Xiang Gao Date: Mon, 27 Jun 2022 21:54:09 -0700 Subject: [PATCH 051/100] TransformPropagator skip replay if possible --- .../jit/codegen/cuda/transform_replay.cpp | 38 +++++++++++++++---- 1 file changed, 30 insertions(+), 8 deletions(-) diff --git a/torch/csrc/jit/codegen/cuda/transform_replay.cpp b/torch/csrc/jit/codegen/cuda/transform_replay.cpp index 24a198dc0c82f..a043c3e03cd3d 100644 --- a/torch/csrc/jit/codegen/cuda/transform_replay.cpp +++ b/torch/csrc/jit/codegen/cuda/transform_replay.cpp @@ -773,22 +773,44 @@ bool TransformReplay::fullyMatching( void TransformPropagator::propagateTvPasC(TensorView* from, TensorView* to) { int pos = replayed_pos_.at(from); - auto replay = 
TransformReplay::replayPasC(to, from, pos); - to->setDomain(replay.first); - replayed_pos_[to] = replay.second; + // Note: [Using multiple TransformPropagators] + // There are cases that we use multiple TransformPropagators along different + // spanning trees with different references in the same fusion. Some of these + // spanning trees could overlap. In cases when there are overlapping nodes, + // TransformPropagator needs to respect the replay of others, because the + // current TransformPropagator might not contain the most amount of + // information on how to do the correct transformation. The logic below tells + // TransformPropagator to skip the replay when not necessary. + int new_pos = + TransformReplay::getMatchedLeafPosWithoutReplayPasC(to, from, pos); + if (new_pos < 0) { + auto replay = TransformReplay::replayPasC(to, from, pos); + to->setDomain(replay.first); + new_pos = replay.second; + } + replayed_pos_[to] = new_pos; } void TransformPropagator::propagateTvCasP(TensorView* from, TensorView* to) { int pos = replayed_pos_.at(from); - auto replay = TransformReplay::replayCasP(to, from, pos); - to->setDomain(replay.first); - replayed_pos_[to] = replay.second; + // See note [Using multiple TransformPropagators] + int new_pos = + TransformReplay::getMatchedLeafPosWithoutReplayCasP(to, from, pos); + if (new_pos < 0) { + auto replay = TransformReplay::replayCasP(to, from, pos); + to->setDomain(replay.first); + new_pos = replay.second; + } + replayed_pos_[to] = new_pos; } void TransformPropagator::propagateTvSibling(TensorView* from, TensorView* to) { int pos = replayed_pos_.at(from); - auto replay = TransformReplay::fullSelfReplay(to->domain(), from->domain()); - to->setDomain(replay); + // See note [Using multiple TransformPropagators] + if (!TransformReplay::fullyMatching(to, from)) { + auto replay = TransformReplay::fullSelfReplay(to->domain(), from->domain()); + to->setDomain(replay); + } replayed_pos_[to] = pos; } From 
2a499a8dcfbc324e92cf5adea34ce66dee16352c Mon Sep 17 00:00:00 2001 From: Xiang Gao Date: Mon, 27 Jun 2022 22:17:09 -0700 Subject: [PATCH 052/100] test --- torch/csrc/jit/codegen/cuda/test/test_gpu.cpp | 29 +++++++++++++++++++ 1 file changed, 29 insertions(+) diff --git a/torch/csrc/jit/codegen/cuda/test/test_gpu.cpp b/torch/csrc/jit/codegen/cuda/test/test_gpu.cpp index 932c8f06b4353..744889a84afeb 100644 --- a/torch/csrc/jit/codegen/cuda/test/test_gpu.cpp +++ b/torch/csrc/jit/codegen/cuda/test/test_gpu.cpp @@ -23932,6 +23932,35 @@ to: 2 TORCH_CHECK(printer2.ss.str() == expect); } +TEST_F(NVFuserTest, FusionTransformPropagatorNoOverwrite_CUDA) { + auto fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + + auto tv0 = makeSymbolicTensor(1); + fusion->addInput(tv0); + auto tv1 = broadcast(tv0, {true, false, true}); + auto tv2 = sin(tv1); + fusion->addOutput(tv2); + + tv0->split(0, 2); + tv2->split(1, 2); + tv2->split(0, 2); + + MaxRootDomainInfoSpanningTree path1(tv2); + TransformPropagator propagator1(tv2); + path1.traverse(&propagator1); + + MaxRootDomainInfoSpanningTree path2(tv0); + TransformPropagator propagator2(tv0); + path2.traverse(&propagator2); + + TORCH_CHECK(tv1->axis(0)->isBroadcast()); + TORCH_CHECK(tv1->axis(1)->isBroadcast()); + TORCH_CHECK(!tv1->axis(2)->isBroadcast()); + TORCH_CHECK(!tv1->axis(3)->isBroadcast()); + TORCH_CHECK(tv1->axis(4)->isBroadcast()); +} + } // namespace jit } // namespace torch #endif // #if defined(USE_CUDA) From f4925b32178ddf4adcac1d57cf043929a184ae88 Mon Sep 17 00:00:00 2001 From: Xiang Gao Date: Mon, 27 Jun 2022 22:46:09 -0700 Subject: [PATCH 053/100] save --- torch/csrc/jit/codegen/cuda/consume_at.cpp | 2 +- .../jit/codegen/cuda/ir_interface_nodes.h | 2 - .../jit/codegen/cuda/transform_replay.cpp | 140 ------------------ .../csrc/jit/codegen/cuda/transform_replay.h | 11 -- 4 files changed, 1 insertion(+), 154 deletions(-) diff --git a/torch/csrc/jit/codegen/cuda/consume_at.cpp 
b/torch/csrc/jit/codegen/cuda/consume_at.cpp index dab70b7287fa1..d4fe5d251c1f0 100644 --- a/torch/csrc/jit/codegen/cuda/consume_at.cpp +++ b/torch/csrc/jit/codegen/cuda/consume_at.cpp @@ -67,7 +67,7 @@ void ComputeAtPosPropagator::consumeAllAt( std::cout << "Before:" << std::endl; reference->fusion()->print(); ComputeAtSubgraphSelector selector(consume); - WeakTransformPropagator propagator(reference, reference_pos); + TransformPropagator propagator(reference, reference_pos); ComputeAtPosPropagator ca_propagator(consume, reference, reference_pos, mode); struct Printer : public MaxInfoSpanningTree::Propagator { diff --git a/torch/csrc/jit/codegen/cuda/ir_interface_nodes.h b/torch/csrc/jit/codegen/cuda/ir_interface_nodes.h index cb47b28f3c040..2bc36461986f5 100644 --- a/torch/csrc/jit/codegen/cuda/ir_interface_nodes.h +++ b/torch/csrc/jit/codegen/cuda/ir_interface_nodes.h @@ -156,7 +156,6 @@ enum class ComputeAtMode { Standard, BestEffort, MostInlined }; class ComputeAtPosPropagator; class TransformPropagator; -class WeakTransformPropagator; class TransformIter; class TransformReplay; class OptOutMutator; @@ -457,7 +456,6 @@ class TORCH_CUDA_CU_API TensorView : public Val { void applyMmaSwizzle(MmaOptions options); friend TORCH_CUDA_CU_API TransformPropagator; - friend TORCH_CUDA_CU_API WeakTransformPropagator; friend TORCH_CUDA_CU_API TransformReplay; friend TORCH_CUDA_CU_API OptOutMutator; friend TORCH_CUDA_CU_API ComputeAtPosPropagator; diff --git a/torch/csrc/jit/codegen/cuda/transform_replay.cpp b/torch/csrc/jit/codegen/cuda/transform_replay.cpp index 41ac0e93452e8..a043c3e03cd3d 100644 --- a/torch/csrc/jit/codegen/cuda/transform_replay.cpp +++ b/torch/csrc/jit/codegen/cuda/transform_replay.cpp @@ -824,146 +824,6 @@ TransformPropagator::TransformPropagator(TensorView* from, int64_t pos) { replayed_pos_[from] = pos; } -namespace { - -// Checks if producer and consumer are transformed consistently so that to -// satisfy the provided compute at position. 
This means no replay is actually -// necessary for the compute at requested. If consumer_pos then -// consumer_or_producer_pos is relative to the consumer and skipReplay returns -// the associated position in producer. -// -// If producer and consumer are not transformed consistently with provided -// postition, returns -1. -int skipReplay( - const TensorView* producer, - const TensorView* consumer, - int consumer_or_producer_pos, - bool consumer_pos = true) { - const auto c2p_root_map = - PairwiseRootDomainMap(producer, consumer) - .mapConsumerToProducer(consumer->domain(), producer->domain()); - - // IterDomains in consumer root also in producer root - std::unordered_set mapped_consumer_roots; - for (auto entry : c2p_root_map) { - mapped_consumer_roots.emplace(entry.first); - } - - const auto consumer_domain = consumer->domain()->domain(); - - auto mapped_consumer_domain_ids_vec = DependencyCheck::getAllValsBetween( - mapped_consumer_roots, {consumer_domain.begin(), consumer_domain.end()}); - - std::unordered_set mapped_consumer_domain_ids( - mapped_consumer_domain_ids_vec.begin(), - mapped_consumer_domain_ids_vec.end()); - - const auto producer_domain = producer->domain()->domain(); - - auto it_consumer = consumer_domain.begin(); - auto it_producer = producer_domain.begin(); - - auto best_effort_PasC = BestEffortReplay::replayPasC( - producer, consumer, -1, PairwiseRootDomainMap(producer, consumer)); - - auto c2p_map = best_effort_PasC.getReplay(); - - int mismatched_consumer_pos = 0; - int mismatched_producer_pos = 0; - while (it_consumer != consumer_domain.end()) { - auto consumer_id = *it_consumer; - if (!mapped_consumer_domain_ids.count(consumer_id)) { - ++it_consumer; - mismatched_consumer_pos++; - continue; - } - - auto c2p_it = c2p_map.find(consumer_id); - if (c2p_it == c2p_map.end()) { - break; - } - - if (it_producer == producer_domain.end()) { - break; - } - - auto producer_id = *it_producer; - - if (c2p_it->second == producer_id) { - 
++mismatched_consumer_pos; - ++mismatched_producer_pos; - ++it_consumer; - ++it_producer; - if (consumer_pos) { - if (consumer_or_producer_pos == mismatched_consumer_pos) { - return mismatched_producer_pos; - } - } else { - if (consumer_or_producer_pos == mismatched_producer_pos) { - return mismatched_consumer_pos; - } - } - } else { - break; - } - } - return -1; -} - -} // namespace - -void WeakTransformPropagator::propagateTvPasC(TensorView* from, TensorView* to) { - // std::cout << "propagateTvPasC" << std::endl; - // std::cout << "from: " << from << std::endl; - // std::cout << "to: " << to << std::endl; - int from_pos = replayed_pos_.at(from); - int to_pos = skipReplay(to, from, from_pos, true); - if (to_pos < 0) { - auto replay = TransformReplay::replayPasC(to, from, from_pos); - to->setDomain(replay.first); - to_pos = replay.second; - // std::cout << "result: " << to << std::endl; - } else { - // std::cout << "skipped" << std::endl; - } - replayed_pos_[to] = to_pos; -} - -void WeakTransformPropagator::propagateTvCasP(TensorView* from, TensorView* to) { - // std::cout << "propagateTvCasP" << std::endl; - // std::cout << "from: " << from << std::endl; - // std::cout << "to: " << to << std::endl; - int from_pos = replayed_pos_.at(from); - int to_pos = skipReplay(from, to, from_pos, false); - if (to_pos < 0) { - auto replay = TransformReplay::replayCasP(to, from, from_pos); - to->setDomain(replay.first); - to_pos = replay.second; - // std::cout << "result: " << to << std::endl; - } else { - // std::cout << "skipped" << std::endl; - } - replayed_pos_[to] = to_pos; -} - -void WeakTransformPropagator::propagateTvSibling(TensorView* from, TensorView* to) { - int pos = replayed_pos_.at(from); - // TODO: check skip - auto replay = TransformReplay::fullSelfReplay(to->domain(), from->domain()); - to->setDomain(replay); - replayed_pos_[to] = pos; -} - -WeakTransformPropagator::WeakTransformPropagator(TensorView* from, int64_t pos) { - if (pos < 0) { - pos += 
int64_t(from->nDims()) + 1; - } - TORCH_CHECK( - pos >= 0 && pos <= from->nDims(), - "TransformPropagator called on an pos outside valid range."); - replayed_pos_[from] = pos; -} - } // namespace cuda } // namespace fuser } // namespace jit diff --git a/torch/csrc/jit/codegen/cuda/transform_replay.h b/torch/csrc/jit/codegen/cuda/transform_replay.h index 206ded35c3fd3..f1f3232a5c2f1 100644 --- a/torch/csrc/jit/codegen/cuda/transform_replay.h +++ b/torch/csrc/jit/codegen/cuda/transform_replay.h @@ -186,17 +186,6 @@ class TORCH_CUDA_CU_API TransformPropagator TransformPropagator(TensorView* from, int64_t pos = -1); }; -class TORCH_CUDA_CU_API WeakTransformPropagator - : public MaxRootDomainInfoSpanningTree::Propagator { - std::unordered_map replayed_pos_; - - public: - virtual void propagateTvPasC(TensorView* from, TensorView* to) override; - virtual void propagateTvCasP(TensorView* from, TensorView* to) override; - virtual void propagateTvSibling(TensorView* from, TensorView* to) override; - WeakTransformPropagator(TensorView* from, int64_t pos = -1); -}; - } // namespace cuda } // namespace fuser } // namespace jit From ca8ef16f6f82e259c8b1284a06fa0760ae36ec82 Mon Sep 17 00:00:00 2001 From: Xiang Gao Date: Mon, 27 Jun 2022 22:56:01 -0700 Subject: [PATCH 054/100] cleanup debugging print --- torch/csrc/jit/codegen/cuda/compute_at.cpp | 6 --- torch/csrc/jit/codegen/cuda/consume_at.cpp | 59 ++-------------------- 2 files changed, 4 insertions(+), 61 deletions(-) diff --git a/torch/csrc/jit/codegen/cuda/compute_at.cpp b/torch/csrc/jit/codegen/cuda/compute_at.cpp index 469b50ce84bf9..a6065bf95d316 100644 --- a/torch/csrc/jit/codegen/cuda/compute_at.cpp +++ b/torch/csrc/jit/codegen/cuda/compute_at.cpp @@ -181,9 +181,6 @@ void ComputeAt::runAt( " are not in the same fusion."); FusionGuard fg(producer->fusion()); - std::cout << "ComputeAt::runAt(producer=" << producer - << ", consumer=" << consumer - << ", consumer_position=" << consumer_position << ")" << std::endl; 
ComputeAtPosPropagator::consumeAllAt( getPropagationSubgraph(producer, consumer), consumer, @@ -208,9 +205,6 @@ void ComputeAt::runWith( // Make sure Fusion Guard is set appropriately FusionGuard fg(producer->fusion()); - std::cout << "ComputeAt::runWith(producer=" << producer - << ", consumer=" << consumer - << ", producer_position=" << producer_position << ")" << std::endl; ComputeAtPosPropagator::consumeAllAt( getPropagationSubgraph(producer, consumer), diff --git a/torch/csrc/jit/codegen/cuda/consume_at.cpp b/torch/csrc/jit/codegen/cuda/consume_at.cpp index d4fe5d251c1f0..466d138136169 100644 --- a/torch/csrc/jit/codegen/cuda/consume_at.cpp +++ b/torch/csrc/jit/codegen/cuda/consume_at.cpp @@ -61,54 +61,24 @@ void ComputeAtPosPropagator::consumeAllAt( TensorView* reference, size_t reference_pos, ComputeAtMode mode) { - std::cout << "==========================" << std::endl; - std::cout << "From: " << reference << " at pos " << reference_pos - << std::endl; - std::cout << "Before:" << std::endl; - reference->fusion()->print(); ComputeAtSubgraphSelector selector(consume); TransformPropagator propagator(reference, reference_pos); ComputeAtPosPropagator ca_propagator(consume, reference, reference_pos, mode); - struct Printer : public MaxInfoSpanningTree::Propagator { - std::stringstream ss; - virtual void propagateTvPasC(TensorView* from, TensorView* to) override { - ss << "propagateTvPasC" << std::endl; - ss << "from: " << from << std::endl; - ss << "to: " << to << std::endl; - } - virtual void propagateTvCasP(TensorView* from, TensorView* to) override { - ss << "propagateTvCasP" << std::endl; - ss << "from: " << from << std::endl; - ss << "to: " << to << std::endl; - } - virtual void propagateTvSibling(TensorView* from, TensorView* to) override { - ss << "propagateTvSibling" << std::endl; - ss << "from: " << from << std::endl; - ss << "to: " << to << std::endl; - } - } printer; - MaxRootDomainInfoSpanningTree path(reference, reference_pos, &selector); - 
path.traverse(&printer); - // std::cout << "consume: " << ir_utils::toString(consume) << std::endl; - // std::cout << "Path:\n" << printer.ss.str() << std::endl; path.traverse(&propagator); - // std::cout << "After TransformPropagator:" << std::endl; - reference->fusion()->print(); ca_propagator.recordReplayedPos(reference, reference_pos); path.traverse(&ca_propagator); ca_propagator.computeMaxProducerPos(); - std::cout << "\n\nAfter:" << std::endl; - reference->fusion()->print(); - std::cout << "==========================" << std::endl; } size_t ComputeAtPosPropagator::adjustComputeAtPos(TensorView* tv, size_t pos) { auto dom = tv->domain()->domain(); - auto first_reduction = std::find_if( - dom.begin(), dom.begin() + pos, [](IterDomain* id) { return id->isReduction(); }); + auto first_reduction = + std::find_if(dom.begin(), dom.begin() + pos, [](IterDomain* id) { + return id->isReduction(); + }); auto first_vectorized_axis = std::find_if(dom.begin(), first_reduction, [this](IterDomain* id) { @@ -120,18 +90,8 @@ size_t ComputeAtPosPropagator::adjustComputeAtPos(TensorView* tv, size_t pos) { auto ca_pos = std::distance(dom.begin(), first_vectorized_axis); - std::cout << "adjustComputeAtPos: " << tv << std::endl; - // TODO: why I have to disable this?!! 
- // for (auto producer_tv : ir_utils::producerTvsOf(tv)) { - // std::cout << "producer_tv: " << producer_tv << std::endl; - // auto max_pos = getReplayablePosPasC(producer_tv, tv); - // std::cout << "max_pos: " << max_pos << std::endl; - // ca_pos = std::min(ca_pos, max_pos); - // } for (auto consumer_tv : ir_utils::consumerTvsOf(tv)) { - std::cout << "consumer_tv: " << consumer_tv << std::endl; auto max_pos = getReplayablePosCasP(consumer_tv, tv); - std::cout << "max_pos: " << max_pos << std::endl; ca_pos = std::min(ca_pos, max_pos); } @@ -140,7 +100,6 @@ size_t ComputeAtPosPropagator::adjustComputeAtPos(TensorView* tv, size_t pos) { ca_pos--; } - return ca_pos; } @@ -394,9 +353,7 @@ int skipReplay( } // namespace void ComputeAtPosPropagator::propagateTvPasC(TensorView* from, TensorView* to) { - std::cout << "propagateTvPasC\nfrom: " << from << "\nto: " << to << std::endl; int pos = getReplayPosPasC(to, from); - std::cout << "at: " << pos << std::endl; // Short cut if no replay is necessary auto to_pos = skipReplay(to, from, (int)pos, true); // TORCH_CHECK(to_pos >= 0); @@ -406,13 +363,10 @@ void ComputeAtPosPropagator::propagateTvPasC(TensorView* from, TensorView* to) { to_pos = replay.second; } recordReplayedPos(to, to_pos); - std::cout << "result: " << to << std::endl; } void ComputeAtPosPropagator::propagateTvCasP(TensorView* from, TensorView* to) { - std::cout << "propagateTvCasP\nfrom: " << from << "\nto: " << to << std::endl; int pos = getReplayPosCasP(to, from); - std::cout << "at: " << pos << std::endl; // Short cut if no replay is necessary auto to_pos = skipReplay(from, to, (int)pos, false); // TORCH_CHECK(to_pos >= 0); @@ -422,7 +376,6 @@ void ComputeAtPosPropagator::propagateTvCasP(TensorView* from, TensorView* to) { to_pos = replay.second; } recordReplayedPos(to, to_pos); - std::cout << "result: " << to << std::endl; } void ComputeAtPosPropagator::propagateTvSibling( @@ -466,8 +419,6 @@ size_t ComputeAtPosPropagator::getReplayPosPasC( TensorView* 
consumer) { size_t max_pos = getReplayablePosPasC(producer, consumer); size_t pos = retrieveReplayedPos(consumer); - // std::cout << "[getReplayPosPasC] producer=" << producer << ", consumer=" << - // consumer << ", max_pos=" << max_pos << ", pos=" << pos << std::endl; if (mode_ == ComputeAtMode::BestEffort) { return std::min(pos, max_pos); @@ -493,8 +444,6 @@ size_t ComputeAtPosPropagator::getReplayPosCasP( TensorView* producer) { size_t max_pos = getReplayablePosCasP(consumer, producer); size_t pos = retrieveReplayedPos(producer); - // std::cout << "[getReplayPosCasP] consumer=" << consumer << ", producer=" << - // producer << ", max_pos=" << max_pos << ", pos=" << pos << std::endl; if (mode_ == ComputeAtMode::BestEffort) { return std::min(pos, max_pos); From c2dadf6dbb0f96aebca7c477234e2f4ac8308551 Mon Sep 17 00:00:00 2001 From: Xiang Gao Date: Mon, 27 Jun 2022 23:24:18 -0700 Subject: [PATCH 055/100] minor cleanup --- torch/csrc/jit/codegen/cuda/consume_at.cpp | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/torch/csrc/jit/codegen/cuda/consume_at.cpp b/torch/csrc/jit/codegen/cuda/consume_at.cpp index 466d138136169..ffafce0ed5854 100644 --- a/torch/csrc/jit/codegen/cuda/consume_at.cpp +++ b/torch/csrc/jit/codegen/cuda/consume_at.cpp @@ -522,13 +522,11 @@ void ComputeAtPosPropagator::computeMaxProducerPos() { std::unordered_set todo; for (auto p : consume_) { auto consumers = ir_utils::consumerTvsOf(p); - std::copy( - consumers.begin(), consumers.end(), std::inserter(todo, todo.end())); + todo.insert(consumers.begin(), consumers.end()); } for (auto tv : todo) { - auto producers = ir_utils::producerTvsOf(tv); size_t max_pos = 0; - for (auto p : producers) { + for (auto p : ir_utils::producerTvsOf(tv)) { max_pos = std::max(max_pos, getConsumerPosAlignedToProducerCA(tv, p)); } From a86f51760a45f2e1af43727f83f7f22ac0a48513 Mon Sep 17 00:00:00 2001 From: Xiang Gao Date: Tue, 28 Jun 2022 01:12:13 -0700 Subject: [PATCH 056/100] more cleanup --- 
torch/csrc/jit/codegen/cuda/consume_at.cpp | 160 ++++++------------ torch/csrc/jit/codegen/cuda/consume_at.h | 17 +- .../jit/codegen/cuda/ir_interface_nodes.h | 2 + 3 files changed, 63 insertions(+), 116 deletions(-) diff --git a/torch/csrc/jit/codegen/cuda/consume_at.cpp b/torch/csrc/jit/codegen/cuda/consume_at.cpp index ffafce0ed5854..a96cd8cb4a622 100644 --- a/torch/csrc/jit/codegen/cuda/consume_at.cpp +++ b/torch/csrc/jit/codegen/cuda/consume_at.cpp @@ -62,15 +62,17 @@ void ComputeAtPosPropagator::consumeAllAt( size_t reference_pos, ComputeAtMode mode) { ComputeAtSubgraphSelector selector(consume); + TransformPropagator propagator(reference, reference_pos); ComputeAtPosPropagator ca_propagator(consume, reference, reference_pos, mode); + MaxProducerPosUpdater updater; MaxRootDomainInfoSpanningTree path(reference, reference_pos, &selector); path.traverse(&propagator); ca_propagator.recordReplayedPos(reference, reference_pos); path.traverse(&ca_propagator); - ca_propagator.computeMaxProducerPos(); + path.traverse(&updater); } size_t ComputeAtPosPropagator::adjustComputeAtPos(TensorView* tv, size_t pos) { @@ -264,98 +266,11 @@ size_t ComputeAtPosPropagator::getReplayablePosCasP( return 0; } -namespace { - -// Checks if producer and consumer are transformed consistently so that to -// satisfy the provided compute at position. This means no replay is actually -// necessary for the compute at requested. If consumer_pos then -// consumer_or_producer_pos is relative to the consumer and skipReplay returns -// the associated position in producer. -// -// If producer and consumer are not transformed consistently with provided -// postition, returns -1. 
-int skipReplay( - const TensorView* producer, - const TensorView* consumer, - int consumer_or_producer_pos, - bool consumer_pos = true) { - const auto c2p_root_map = - PairwiseRootDomainMap(producer, consumer) - .mapConsumerToProducer(consumer->domain(), producer->domain()); - - // IterDomains in consumer root also in producer root - std::unordered_set mapped_consumer_roots; - for (auto entry : c2p_root_map) { - mapped_consumer_roots.emplace(entry.first); - } - - const auto consumer_domain = consumer->domain()->domain(); - - auto mapped_consumer_domain_ids_vec = DependencyCheck::getAllValsBetween( - mapped_consumer_roots, {consumer_domain.begin(), consumer_domain.end()}); - - std::unordered_set mapped_consumer_domain_ids( - mapped_consumer_domain_ids_vec.begin(), - mapped_consumer_domain_ids_vec.end()); - - const auto producer_domain = producer->domain()->domain(); - - auto it_consumer = consumer_domain.begin(); - auto it_producer = producer_domain.begin(); - - auto best_effort_PasC = BestEffortReplay::replayPasC( - producer, consumer, -1, PairwiseRootDomainMap(producer, consumer)); - - auto c2p_map = best_effort_PasC.getReplay(); - - int mismatched_consumer_pos = 0; - int mismatched_producer_pos = 0; - while (it_consumer != consumer_domain.end()) { - auto consumer_id = *it_consumer; - if (!mapped_consumer_domain_ids.count(consumer_id)) { - ++it_consumer; - mismatched_consumer_pos++; - continue; - } - - auto c2p_it = c2p_map.find(consumer_id); - if (c2p_it == c2p_map.end()) { - break; - } - - if (it_producer == producer_domain.end()) { - break; - } - - auto producer_id = *it_producer; - - if (c2p_it->second == producer_id) { - ++mismatched_consumer_pos; - ++mismatched_producer_pos; - ++it_consumer; - ++it_producer; - if (consumer_pos) { - if (consumer_or_producer_pos == mismatched_consumer_pos) { - return mismatched_producer_pos; - } - } else { - if (consumer_or_producer_pos == mismatched_producer_pos) { - return mismatched_consumer_pos; - } - } - } else { - 
break; - } - } - return -1; -} - -} // namespace - void ComputeAtPosPropagator::propagateTvPasC(TensorView* from, TensorView* to) { int pos = getReplayPosPasC(to, from); // Short cut if no replay is necessary - auto to_pos = skipReplay(to, from, (int)pos, true); + auto to_pos = + TransformReplay::getMatchedLeafPosWithoutReplayPasC(to, from, pos); // TORCH_CHECK(to_pos >= 0); if (to_pos < 0) { auto replay = TransformReplay::replayPasC(to, from, pos); @@ -368,7 +283,8 @@ void ComputeAtPosPropagator::propagateTvPasC(TensorView* from, TensorView* to) { void ComputeAtPosPropagator::propagateTvCasP(TensorView* from, TensorView* to) { int pos = getReplayPosCasP(to, from); // Short cut if no replay is necessary - auto to_pos = skipReplay(from, to, (int)pos, false); + auto to_pos = + TransformReplay::getMatchedLeafPosWithoutReplayCasP(to, from, pos); // TORCH_CHECK(to_pos >= 0); if (to_pos < 0) { auto replay = TransformReplay::replayCasP(to, from, pos); @@ -464,6 +380,23 @@ size_t ComputeAtPosPropagator::getReplayPosCasP( return pos; } +bool ComputeAtSubgraphSelector::allowPasC(TensorView* from, TensorView* to) { + return consume_.count(to) > 0; +} + +bool ComputeAtSubgraphSelector::allowCasP(TensorView* from, TensorView* to) { + // If the producer is in the consume set, then the consumer must also be + // replayed to obtain a compatible loop structure so that this producer + // can be consumed in this loop. + return consume_.count(from) > 0 || consume_.count(to) > 0; +} + +bool ComputeAtSubgraphSelector::allowSibling(TensorView* from, TensorView* to) { + return true; +} + +namespace { + // TODO: most of this is copy-pasted code. I need to investigate this to // see which makes sense and which needs change. 
// Try to find the aligned position on consumer's domain corresponding to the @@ -518,35 +451,38 @@ size_t getConsumerPosAlignedToProducerCA( return consumer_pos; } -void ComputeAtPosPropagator::computeMaxProducerPos() { - std::unordered_set todo; - for (auto p : consume_) { - auto consumers = ir_utils::consumerTvsOf(p); - todo.insert(consumers.begin(), consumers.end()); +} // namespace + +void MaxProducerPosUpdater::propagateTvPasC(TensorView* from, TensorView* to) { + if (updated_.empty()) { + // This code path is only taken for the first hop of the traverse. This is + // needed because the starting tensor's CA pos is also set, so we need to + // update its consumers as well. + updated_.insert(nullptr); + propagateTvPasC(nullptr, from); } - for (auto tv : todo) { + for (auto consumer_tv : ir_utils::consumerTvsOf(to)) { + if (updated_.count(consumer_tv) > 0) { + continue; + } size_t max_pos = 0; - for (auto p : ir_utils::producerTvsOf(tv)) { - max_pos = - std::max(max_pos, getConsumerPosAlignedToProducerCA(tv, p)); + for (auto p : ir_utils::producerTvsOf(consumer_tv)) { + max_pos = std::max( + max_pos, getConsumerPosAlignedToProducerCA(consumer_tv, p)); } - tv->setMaxProducer(max_pos, true); + consumer_tv->setMaxProducer(max_pos, true); + updated_.insert(consumer_tv); } } -bool ComputeAtSubgraphSelector::allowPasC(TensorView* from, TensorView* to) { - return consume_.count(to) > 0; +void MaxProducerPosUpdater::propagateTvCasP(TensorView* from, TensorView* to) { + propagateTvPasC(from, to); } -bool ComputeAtSubgraphSelector::allowCasP(TensorView* from, TensorView* to) { - // If the producer is in the consume set, then the consumer must also be - // replayed to obtain a compatible loop structure so that this producer - // can be consumed in this loop. 
- return consume_.count(from) > 0 || consume_.count(to) > 0; -} - -bool ComputeAtSubgraphSelector::allowSibling(TensorView* from, TensorView* to) { - return true; +void MaxProducerPosUpdater::propagateTvSibling( + TensorView* from, + TensorView* to) { + propagateTvPasC(from, to); } } // namespace cuda diff --git a/torch/csrc/jit/codegen/cuda/consume_at.h b/torch/csrc/jit/codegen/cuda/consume_at.h index a97718c78ecf0..b821bb5337d78 100644 --- a/torch/csrc/jit/codegen/cuda/consume_at.h +++ b/torch/csrc/jit/codegen/cuda/consume_at.h @@ -33,8 +33,6 @@ class ComputeAtPosPropagator : public MaxInfoSpanningTree::Propagator { size_t adjustComputeAtPos(TensorView* tv, size_t pos); - void computeMaxProducerPos(); - // Iterate through all TVs and collect the dimensions of each TV that don't // map to all its consumer TVs. void buildUnmappableDims(); @@ -58,12 +56,11 @@ class ComputeAtPosPropagator : public MaxInfoSpanningTree::Propagator { void run(); - protected: + public: virtual void propagateTvPasC(TensorView* from, TensorView* to) override; virtual void propagateTvCasP(TensorView* from, TensorView* to) override; virtual void propagateTvSibling(TensorView* from, TensorView* to) override; - public: static void consumeAllAt( std::unordered_set consume, TensorView* reference, @@ -71,6 +68,18 @@ class ComputeAtPosPropagator : public MaxInfoSpanningTree::Propagator { ComputeAtMode mode); }; +// This is actually not a propagation, and it is not needed to compute the max +// producer position in a specific order. But MaxInfoSpanningTree provides a +// very convenient API to visit the tensors, so I just use it for cleaner code. 
+class MaxProducerPosUpdater : public MaxInfoSpanningTree::Propagator { + std::unordered_set updated_; + + public: + virtual void propagateTvPasC(TensorView* from, TensorView* to) override; + virtual void propagateTvCasP(TensorView* from, TensorView* to) override; + virtual void propagateTvSibling(TensorView* from, TensorView* to) override; +}; + } // namespace cuda } // namespace fuser } // namespace jit diff --git a/torch/csrc/jit/codegen/cuda/ir_interface_nodes.h b/torch/csrc/jit/codegen/cuda/ir_interface_nodes.h index 2bc36461986f5..4c754ad128cee 100644 --- a/torch/csrc/jit/codegen/cuda/ir_interface_nodes.h +++ b/torch/csrc/jit/codegen/cuda/ir_interface_nodes.h @@ -155,6 +155,7 @@ class TORCH_CUDA_CU_API ComplexDouble : public Val { enum class ComputeAtMode { Standard, BestEffort, MostInlined }; class ComputeAtPosPropagator; +class MaxProducerPosUpdater; class TransformPropagator; class TransformIter; class TransformReplay; @@ -459,6 +460,7 @@ class TORCH_CUDA_CU_API TensorView : public Val { friend TORCH_CUDA_CU_API TransformReplay; friend TORCH_CUDA_CU_API OptOutMutator; friend TORCH_CUDA_CU_API ComputeAtPosPropagator; + friend TORCH_CUDA_CU_API MaxProducerPosUpdater; friend class ir_utils::TVDomainGuard; friend TORCH_CUDA_CU_API void groupReductions( const std::vector&); From d8ed31897c2c8a6004a51f50ca813f360012d2d1 Mon Sep 17 00:00:00 2001 From: Xiang Gao Date: Tue, 28 Jun 2022 02:02:16 -0700 Subject: [PATCH 057/100] more cleanup --- torch/csrc/jit/codegen/cuda/consume_at.cpp | 29 ++++++++++++++-------- torch/csrc/jit/codegen/cuda/consume_at.h | 1 + 2 files changed, 19 insertions(+), 11 deletions(-) diff --git a/torch/csrc/jit/codegen/cuda/consume_at.cpp b/torch/csrc/jit/codegen/cuda/consume_at.cpp index a96cd8cb4a622..b4ef150b71546 100644 --- a/torch/csrc/jit/codegen/cuda/consume_at.cpp +++ b/torch/csrc/jit/codegen/cuda/consume_at.cpp @@ -69,8 +69,6 @@ void ComputeAtPosPropagator::consumeAllAt( MaxRootDomainInfoSpanningTree path(reference, reference_pos, 
&selector); path.traverse(&propagator); - - ca_propagator.recordReplayedPos(reference, reference_pos); path.traverse(&ca_propagator); path.traverse(&updater); } @@ -92,6 +90,7 @@ size_t ComputeAtPosPropagator::adjustComputeAtPos(TensorView* tv, size_t pos) { auto ca_pos = std::distance(dom.begin(), first_vectorized_axis); + // TODO: why not doing the same for producers? for (auto consumer_tv : ir_utils::consumerTvsOf(tv)) { auto max_pos = getReplayablePosCasP(consumer_tv, tv); ca_pos = std::min(ca_pos, max_pos); @@ -267,8 +266,11 @@ size_t ComputeAtPosPropagator::getReplayablePosCasP( } void ComputeAtPosPropagator::propagateTvPasC(TensorView* from, TensorView* to) { + if (is_first_) { + is_first_ = false; + recordReplayedPos(reference_, reference_pos_); + } int pos = getReplayPosPasC(to, from); - // Short cut if no replay is necessary auto to_pos = TransformReplay::getMatchedLeafPosWithoutReplayPasC(to, from, pos); // TORCH_CHECK(to_pos >= 0); @@ -281,8 +283,11 @@ void ComputeAtPosPropagator::propagateTvPasC(TensorView* from, TensorView* to) { } void ComputeAtPosPropagator::propagateTvCasP(TensorView* from, TensorView* to) { + if (is_first_) { + is_first_ = false; + recordReplayedPos(reference_, reference_pos_); + } int pos = getReplayPosCasP(to, from); - // Short cut if no replay is necessary auto to_pos = TransformReplay::getMatchedLeafPosWithoutReplayCasP(to, from, pos); // TORCH_CHECK(to_pos >= 0); @@ -297,11 +302,15 @@ void ComputeAtPosPropagator::propagateTvCasP(TensorView* from, TensorView* to) { void ComputeAtPosPropagator::propagateTvSibling( TensorView* from, TensorView* to) { - // TODO: copy-paste computeAts + if (is_first_) { + is_first_ = false; + recordReplayedPos(reference_, reference_pos_); + } auto from_pos = retrieveReplayedPos(from); - // TODO: check skip - auto replay = TransformReplay::fullSelfReplay(to->domain(), from->domain()); - to->setDomain(replay); + if (!TransformReplay::fullyMatching(to, from)) { + auto replay = 
TransformReplay::fullSelfReplay(to->domain(), from->domain()); + to->setDomain(replay); + } recordReplayedPos(to, from_pos); } @@ -455,9 +464,7 @@ size_t getConsumerPosAlignedToProducerCA( void MaxProducerPosUpdater::propagateTvPasC(TensorView* from, TensorView* to) { if (updated_.empty()) { - // This code path is only taken for the first hop of the traverse. This is - // needed because the starting tensor's CA pos is also set, so we need to - // update its consumers as well. + // handle the reference tensor updated_.insert(nullptr); propagateTvPasC(nullptr, from); } diff --git a/torch/csrc/jit/codegen/cuda/consume_at.h b/torch/csrc/jit/codegen/cuda/consume_at.h index b821bb5337d78..11924f68b019a 100644 --- a/torch/csrc/jit/codegen/cuda/consume_at.h +++ b/torch/csrc/jit/codegen/cuda/consume_at.h @@ -42,6 +42,7 @@ class ComputeAtPosPropagator : public MaxInfoSpanningTree::Propagator { size_t reference_pos_; ComputeAtMode mode_ = ComputeAtMode::Standard; std::unordered_map replayed_pos_; + bool is_first_ = true; // Root domains in producer that's unmappable to any of its consumers std::unordered_set unmappable_dims_; From f8d6b8a968d84ff7b225fdf5610b30a5c786a063 Mon Sep 17 00:00:00 2001 From: Xiang Gao Date: Tue, 28 Jun 2022 10:23:40 -0700 Subject: [PATCH 058/100] save --- torch/csrc/jit/codegen/cuda/compute_at.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch/csrc/jit/codegen/cuda/compute_at.h b/torch/csrc/jit/codegen/cuda/compute_at.h index 9462571c7963d..683d0782482cd 100644 --- a/torch/csrc/jit/codegen/cuda/compute_at.h +++ b/torch/csrc/jit/codegen/cuda/compute_at.h @@ -1,8 +1,8 @@ #pragma once +#include #include #include -#include #include #include From 220129b9a967555a91c5f915d65823b46569a90a Mon Sep 17 00:00:00 2001 From: Xiang Gao Date: Tue, 28 Jun 2022 10:25:12 -0700 Subject: [PATCH 059/100] save --- torch/csrc/jit/codegen/cuda/consume_at.cpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git 
a/torch/csrc/jit/codegen/cuda/consume_at.cpp b/torch/csrc/jit/codegen/cuda/consume_at.cpp index b4ef150b71546..205175702171f 100644 --- a/torch/csrc/jit/codegen/cuda/consume_at.cpp +++ b/torch/csrc/jit/codegen/cuda/consume_at.cpp @@ -90,7 +90,8 @@ size_t ComputeAtPosPropagator::adjustComputeAtPos(TensorView* tv, size_t pos) { auto ca_pos = std::distance(dom.begin(), first_vectorized_axis); - // TODO: why not doing the same for producers? + // We only check consumers here, because producers are not always replayed + // consistently for (auto consumer_tv : ir_utils::consumerTvsOf(tv)) { auto max_pos = getReplayablePosCasP(consumer_tv, tv); ca_pos = std::min(ca_pos, max_pos); From 77564853e718b75dba6b7ca5fdf5d9c5546456d5 Mon Sep 17 00:00:00 2001 From: Xiang Gao Date: Tue, 28 Jun 2022 10:55:57 -0700 Subject: [PATCH 060/100] save --- torch/csrc/jit/codegen/cuda/transform_replay.cpp | 7 ------- 1 file changed, 7 deletions(-) diff --git a/torch/csrc/jit/codegen/cuda/transform_replay.cpp b/torch/csrc/jit/codegen/cuda/transform_replay.cpp index dad649c76db4c..a043c3e03cd3d 100644 --- a/torch/csrc/jit/codegen/cuda/transform_replay.cpp +++ b/torch/csrc/jit/codegen/cuda/transform_replay.cpp @@ -814,13 +814,6 @@ void TransformPropagator::propagateTvSibling(TensorView* from, TensorView* to) { replayed_pos_[to] = pos; } -void TransformPropagator::propagateTvSibling(TensorView* from, TensorView* to) { - int pos = replayed_pos_.at(from); - auto replay = TransformReplay::fullSelfReplay(to->domain(), from->domain()); - to->setDomain(replay); - replayed_pos_[to] = pos; -} - TransformPropagator::TransformPropagator(TensorView* from, int64_t pos) { if (pos < 0) { pos += int64_t(from->nDims()) + 1; From ad552041f79189b12171d94bcf4ad73b5fbe8657 Mon Sep 17 00:00:00 2001 From: Xiang Gao Date: Tue, 28 Jun 2022 11:42:42 -0700 Subject: [PATCH 061/100] cleanup getConsumerPosAlignedToProducerCA --- torch/csrc/jit/codegen/cuda/consume_at.cpp | 24 ++++++++-------------- 1 file changed, 9 
insertions(+), 15 deletions(-) diff --git a/torch/csrc/jit/codegen/cuda/consume_at.cpp b/torch/csrc/jit/codegen/cuda/consume_at.cpp index 205175702171f..57faa44ad3fa4 100644 --- a/torch/csrc/jit/codegen/cuda/consume_at.cpp +++ b/torch/csrc/jit/codegen/cuda/consume_at.cpp @@ -407,8 +407,6 @@ bool ComputeAtSubgraphSelector::allowSibling(TensorView* from, TensorView* to) { namespace { -// TODO: most of this is copy-pasted code. I need to investigate this to -// see which makes sense and which needs change. // Try to find the aligned position on consumer's domain corresponding to the // compute at position of producer domain. Used in computeAt pass only. No // checking on actual producer-consumer relationship. @@ -437,22 +435,18 @@ size_t getConsumerPosAlignedToProducerCA( PairwiseRootDomainMap(producer, consumer)) .getReplay(); - // Find the innermost position of consumer that has - // been mapped within the producer ca axis. + auto p_dom = producer->domain()->domain(); + std::unordered_set producer_ca_ids( + p_dom.begin(), p_dom.begin() + producer->getComputeAtPosition()); + + // Find the innermost position of consumer that has been mapped within the + // producer ca axis. 
unsigned int consumer_pos = consumer->nDims(); while (consumer_pos > 0) { auto consumer_id = consumer->axis((int)consumer_pos - 1); - auto p_dom = producer->domain()->domain(); - if (std::any_of( - p_dom.begin(), - p_dom.begin() + producer->getComputeAtPosition(), - [&consumer_id, &c2p_map](IterDomain* p_id) { - auto c_id_it = c2p_map.find(consumer_id); - if (c_id_it != c2p_map.end()) { - return c_id_it->second == p_id; - } - return false; - })) { + auto c_id_it = c2p_map.find(consumer_id); + if (c_id_it != c2p_map.end() && + producer_ca_ids.count(c_id_it->second) > 0) { break; } consumer_pos--; From 6e497f6fb6e6c0eca5c93b50da9fbd0ebfdb12f3 Mon Sep 17 00:00:00 2001 From: Xiang Gao Date: Tue, 28 Jun 2022 12:02:07 -0700 Subject: [PATCH 062/100] more cleanup of getConsumerPosAlignedToProducerCA --- torch/csrc/jit/codegen/cuda/consume_at.cpp | 42 +++++++--------------- torch/csrc/jit/codegen/cuda/consume_at.h | 1 + 2 files changed, 14 insertions(+), 29 deletions(-) diff --git a/torch/csrc/jit/codegen/cuda/consume_at.cpp b/torch/csrc/jit/codegen/cuda/consume_at.cpp index 57faa44ad3fa4..4a9305490752a 100644 --- a/torch/csrc/jit/codegen/cuda/consume_at.cpp +++ b/torch/csrc/jit/codegen/cuda/consume_at.cpp @@ -422,31 +422,11 @@ size_t getConsumerPosAlignedToProducerCA( // have the mapping iS22{( 3 * 1 )} <- iS1{3} We need the latter. Refer to // NVFuserTest.FusionComplexBCast1_CUDA - auto c2p_map = - BestEffortReplay::replayPasC( - producer, - consumer, - -1, - // Compute at root domain may not be valid here, as all - // producers don't have to be able to map into consumer at - // max producer position. Since computeAt should be valid - // and this mechanism is only intended to lower produce - // position of consumer, we can simply use the pairwise map. 
- PairwiseRootDomainMap(producer, consumer)) - .getReplay(); - - auto p_dom = producer->domain()->domain(); - std::unordered_set producer_ca_ids( - p_dom.begin(), p_dom.begin() + producer->getComputeAtPosition()); - - // Find the innermost position of consumer that has been mapped within the - // producer ca axis. unsigned int consumer_pos = consumer->nDims(); while (consumer_pos > 0) { - auto consumer_id = consumer->axis((int)consumer_pos - 1); - auto c_id_it = c2p_map.find(consumer_id); - if (c_id_it != c2p_map.end() && - producer_ca_ids.count(c_id_it->second) > 0) { + auto producer_pos = TransformReplay::getMatchedLeafPosWithoutReplayPasC( + producer, consumer, consumer_pos); + if (producer_pos >= 0 && producer_pos <= producer->getComputeAtPosition()) { break; } consumer_pos--; @@ -457,6 +437,15 @@ size_t getConsumerPosAlignedToProducerCA( } // namespace +void MaxProducerPosUpdater::handle(TensorView* tv) { + size_t max_pos = 0; + for (auto p : ir_utils::producerTvsOf(tv)) { + max_pos = std::max( + max_pos, getConsumerPosAlignedToProducerCA(tv, p)); + } + tv->setMaxProducer(max_pos, true); +} + void MaxProducerPosUpdater::propagateTvPasC(TensorView* from, TensorView* to) { if (updated_.empty()) { // handle the reference tensor @@ -467,12 +456,7 @@ void MaxProducerPosUpdater::propagateTvPasC(TensorView* from, TensorView* to) { if (updated_.count(consumer_tv) > 0) { continue; } - size_t max_pos = 0; - for (auto p : ir_utils::producerTvsOf(consumer_tv)) { - max_pos = std::max( - max_pos, getConsumerPosAlignedToProducerCA(consumer_tv, p)); - } - consumer_tv->setMaxProducer(max_pos, true); + handle(consumer_tv); updated_.insert(consumer_tv); } } diff --git a/torch/csrc/jit/codegen/cuda/consume_at.h b/torch/csrc/jit/codegen/cuda/consume_at.h index 11924f68b019a..940183742610e 100644 --- a/torch/csrc/jit/codegen/cuda/consume_at.h +++ b/torch/csrc/jit/codegen/cuda/consume_at.h @@ -74,6 +74,7 @@ class ComputeAtPosPropagator : public MaxInfoSpanningTree::Propagator { // 
very convenient API to visit the tensors, so I just use it for cleaner code. class MaxProducerPosUpdater : public MaxInfoSpanningTree::Propagator { std::unordered_set updated_; + void handle(TensorView* tv); public: virtual void propagateTvPasC(TensorView* from, TensorView* to) override; From 194c7df0b889f131b4c681254f52ec69f7db725e Mon Sep 17 00:00:00 2001 From: Xiang Gao Date: Tue, 28 Jun 2022 12:14:49 -0700 Subject: [PATCH 063/100] more cleanup on getConsumerPosAlignedToProducerCA --- torch/csrc/jit/codegen/cuda/consume_at.cpp | 44 ++++++---------------- 1 file changed, 11 insertions(+), 33 deletions(-) diff --git a/torch/csrc/jit/codegen/cuda/consume_at.cpp b/torch/csrc/jit/codegen/cuda/consume_at.cpp index 4a9305490752a..77703241ddd23 100644 --- a/torch/csrc/jit/codegen/cuda/consume_at.cpp +++ b/torch/csrc/jit/codegen/cuda/consume_at.cpp @@ -405,45 +405,23 @@ bool ComputeAtSubgraphSelector::allowSibling(TensorView* from, TensorView* to) { return true; } -namespace { - // Try to find the aligned position on consumer's domain corresponding to the -// compute at position of producer domain. Used in computeAt pass only. No -// checking on actual producer-consumer relationship. -size_t getConsumerPosAlignedToProducerCA( - TensorView* consumer, - TensorView* producer) { - // Locate consumer's position that aligns with - // the producer's new compute at axis. We need broadcast axes forwarded so we - // need to replay PasC as CasP will not forward braodcast dims. For example - // if we have: - // T2[ iS22{( 3 * 1 )} ] ca_pos( 1 ) = broadcast( T1[ iS1{3} ] ca_pos( 1 ) - // produce_pos( 1) ) CasP will have the mapping iS1{3} -> iS2{3} and PasC will - // have the mapping iS22{( 3 * 1 )} <- iS1{3} We need the latter. Refer to - // NVFuserTest.FusionComplexBCast1_CUDA - +// compute at position of producer domain. 
+void MaxProducerPosUpdater::handle(TensorView* consumer) { unsigned int consumer_pos = consumer->nDims(); while (consumer_pos > 0) { - auto producer_pos = TransformReplay::getMatchedLeafPosWithoutReplayPasC( - producer, consumer, consumer_pos); - if (producer_pos >= 0 && producer_pos <= producer->getComputeAtPosition()) { - break; + for (auto producer : ir_utils::producerTvsOf(consumer)) { + auto producer_pos = TransformReplay::getMatchedLeafPosWithoutReplayPasC( + producer, consumer, consumer_pos); + if (producer_pos >= 0 && + producer_pos <= producer->getComputeAtPosition()) { + goto finished; + } } consumer_pos--; } - - return consumer_pos; -} - -} // namespace - -void MaxProducerPosUpdater::handle(TensorView* tv) { - size_t max_pos = 0; - for (auto p : ir_utils::producerTvsOf(tv)) { - max_pos = std::max( - max_pos, getConsumerPosAlignedToProducerCA(tv, p)); - } - tv->setMaxProducer(max_pos, true); +finished: + consumer->setMaxProducer(consumer_pos, true); } void MaxProducerPosUpdater::propagateTvPasC(TensorView* from, TensorView* to) { From b4fe1feea9ba75617d0e1c296e1de4867b8fa498 Mon Sep 17 00:00:00 2001 From: Xiang Gao Date: Tue, 28 Jun 2022 12:40:29 -0700 Subject: [PATCH 064/100] cleanup ComputeAtSubgraphSelector --- torch/csrc/jit/codegen/cuda/compute_at.cpp | 54 +++++++++++++++++----- torch/csrc/jit/codegen/cuda/compute_at.h | 13 ++++++ torch/csrc/jit/codegen/cuda/consume_at.cpp | 32 ------------- torch/csrc/jit/codegen/cuda/consume_at.h | 21 +-------- 4 files changed, 57 insertions(+), 63 deletions(-) diff --git a/torch/csrc/jit/codegen/cuda/compute_at.cpp b/torch/csrc/jit/codegen/cuda/compute_at.cpp index a6065bf95d316..9c8898b71190c 100644 --- a/torch/csrc/jit/codegen/cuda/compute_at.cpp +++ b/torch/csrc/jit/codegen/cuda/compute_at.cpp @@ -165,6 +165,26 @@ std::unordered_set getPropagationSubgraph( } // namespace +bool ComputeAtSubgraphSelector::allowPasC(TensorView* from, TensorView* to) { + return selected_.count(to) > 0; +} + +bool 
ComputeAtSubgraphSelector::allowCasP(TensorView* from, TensorView* to) { + // If the producer is in the consume set, then the consumer must also be + // replayed to obtain a compatible loop structure so that this producer + // can be consumed in this loop. + return selected_.count(from) > 0 || selected_.count(to) > 0; +} + +bool ComputeAtSubgraphSelector::allowSibling(TensorView* from, TensorView* to) { + return true; +} + +ComputeAtSubgraphSelector::ComputeAtSubgraphSelector( + TensorView* producer, + TensorView* consumer) + : selected_(getPropagationSubgraph(producer, consumer)) {} + void ComputeAt::runAt( TensorView* producer, TensorView* consumer, @@ -181,11 +201,18 @@ void ComputeAt::runAt( " are not in the same fusion."); FusionGuard fg(producer->fusion()); - ComputeAtPosPropagator::consumeAllAt( - getPropagationSubgraph(producer, consumer), - consumer, - consumer_position, - mode); + + ComputeAtSubgraphSelector selector(producer, consumer); + + TransformPropagator propagator(consumer, consumer_position); + ComputeAtPosPropagator ca_propagator( + selector.selected(), consumer, consumer_position, mode); + MaxProducerPosUpdater updater; + + MaxRootDomainInfoSpanningTree path(consumer, consumer_position, &selector); + path.traverse(&propagator); + path.traverse(&ca_propagator); + path.traverse(&updater); } void ComputeAt::runWith( @@ -203,14 +230,19 @@ void ComputeAt::runWith( consumer, " are not in the same fusion."); - // Make sure Fusion Guard is set appropriately FusionGuard fg(producer->fusion()); - ComputeAtPosPropagator::consumeAllAt( - getPropagationSubgraph(producer, consumer), - producer, - producer_position, - mode); + ComputeAtSubgraphSelector selector(producer, consumer); + + TransformPropagator propagator(producer, producer_position); + ComputeAtPosPropagator ca_propagator( + selector.selected(), producer, producer_position, mode); + MaxProducerPosUpdater updater; + + MaxRootDomainInfoSpanningTree path(producer, producer_position, &selector); + 
path.traverse(&propagator); + path.traverse(&ca_propagator); + path.traverse(&updater); } } // namespace cuda diff --git a/torch/csrc/jit/codegen/cuda/compute_at.h b/torch/csrc/jit/codegen/cuda/compute_at.h index 683d0782482cd..0e3e3c47eaccb 100644 --- a/torch/csrc/jit/codegen/cuda/compute_at.h +++ b/torch/csrc/jit/codegen/cuda/compute_at.h @@ -20,6 +20,19 @@ namespace cuda { class TensorDomain; class TensorView; +class ComputeAtSubgraphSelector : public MaxInfoSpanningTree::Selector { + std::unordered_set selected_; + + public: + virtual bool allowPasC(TensorView* from, TensorView* to) override; + virtual bool allowCasP(TensorView* from, TensorView* to) override; + virtual bool allowSibling(TensorView* from, TensorView* to) override; + ComputeAtSubgraphSelector(TensorView* producer, TensorView* consumer); + const std::unordered_set& selected() const { + return selected_; + } +}; + struct ComputeAt { public: // Runs the compute at pass making producer look like consumer, computing diff --git a/torch/csrc/jit/codegen/cuda/consume_at.cpp b/torch/csrc/jit/codegen/cuda/consume_at.cpp index 77703241ddd23..437bd0190d884 100644 --- a/torch/csrc/jit/codegen/cuda/consume_at.cpp +++ b/torch/csrc/jit/codegen/cuda/consume_at.cpp @@ -56,23 +56,6 @@ void ComputeAtPosPropagator::buildUnmappableDims() { } } -void ComputeAtPosPropagator::consumeAllAt( - std::unordered_set consume, - TensorView* reference, - size_t reference_pos, - ComputeAtMode mode) { - ComputeAtSubgraphSelector selector(consume); - - TransformPropagator propagator(reference, reference_pos); - ComputeAtPosPropagator ca_propagator(consume, reference, reference_pos, mode); - MaxProducerPosUpdater updater; - - MaxRootDomainInfoSpanningTree path(reference, reference_pos, &selector); - path.traverse(&propagator); - path.traverse(&ca_propagator); - path.traverse(&updater); -} - size_t ComputeAtPosPropagator::adjustComputeAtPos(TensorView* tv, size_t pos) { auto dom = tv->domain()->domain(); auto first_reduction = @@ 
-390,21 +373,6 @@ size_t ComputeAtPosPropagator::getReplayPosCasP( return pos; } -bool ComputeAtSubgraphSelector::allowPasC(TensorView* from, TensorView* to) { - return consume_.count(to) > 0; -} - -bool ComputeAtSubgraphSelector::allowCasP(TensorView* from, TensorView* to) { - // If the producer is in the consume set, then the consumer must also be - // replayed to obtain a compatible loop structure so that this producer - // can be consumed in this loop. - return consume_.count(from) > 0 || consume_.count(to) > 0; -} - -bool ComputeAtSubgraphSelector::allowSibling(TensorView* from, TensorView* to) { - return true; -} - // Try to find the aligned position on consumer's domain corresponding to the // compute at position of producer domain. void MaxProducerPosUpdater::handle(TensorView* consumer) { diff --git a/torch/csrc/jit/codegen/cuda/consume_at.h b/torch/csrc/jit/codegen/cuda/consume_at.h index 940183742610e..d2ad077ab8916 100644 --- a/torch/csrc/jit/codegen/cuda/consume_at.h +++ b/torch/csrc/jit/codegen/cuda/consume_at.h @@ -11,17 +11,6 @@ namespace jit { namespace fuser { namespace cuda { -class ComputeAtSubgraphSelector : public MaxInfoSpanningTree::Selector { - std::unordered_set consume_; - - public: - virtual bool allowPasC(TensorView* from, TensorView* to) override; - virtual bool allowCasP(TensorView* from, TensorView* to) override; - virtual bool allowSibling(TensorView* from, TensorView* to) override; - ComputeAtSubgraphSelector(std::unordered_set consume) - : consume_(std::move(consume)) {} -}; - class ComputeAtPosPropagator : public MaxInfoSpanningTree::Propagator { // TODO: change arguments to `from`, `to`. 
size_t getReplayablePosPasC(TensorView* producer, TensorView* consumer); @@ -47,6 +36,7 @@ class ComputeAtPosPropagator : public MaxInfoSpanningTree::Propagator { // Root domains in producer that's unmappable to any of its consumers std::unordered_set unmappable_dims_; + public: ComputeAtPosPropagator( std::unordered_set consume, TensorView* reference, @@ -55,18 +45,9 @@ class ComputeAtPosPropagator : public MaxInfoSpanningTree::Propagator { ~ComputeAtPosPropagator() = default; - void run(); - - public: virtual void propagateTvPasC(TensorView* from, TensorView* to) override; virtual void propagateTvCasP(TensorView* from, TensorView* to) override; virtual void propagateTvSibling(TensorView* from, TensorView* to) override; - - static void consumeAllAt( - std::unordered_set consume, - TensorView* reference, - size_t reference_pos, - ComputeAtMode mode); }; // This is actually not a propagation, and it is not needed to compute the max From 0f8c52e88e45810e87e7b98c03eea1bff55dfe2a Mon Sep 17 00:00:00 2001 From: Xiang Gao Date: Tue, 28 Jun 2022 14:04:40 -0700 Subject: [PATCH 065/100] save --- torch/csrc/jit/codegen/cuda/test/test_gpu.cpp | 4 ++-- torch/csrc/jit/codegen/cuda/transform_replay.cpp | 2 +- torch/csrc/jit/codegen/cuda/transform_replay.h | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/torch/csrc/jit/codegen/cuda/test/test_gpu.cpp b/torch/csrc/jit/codegen/cuda/test/test_gpu.cpp index d38d37ff6cf3a..190f2eb447daf 100644 --- a/torch/csrc/jit/codegen/cuda/test/test_gpu.cpp +++ b/torch/csrc/jit/codegen/cuda/test/test_gpu.cpp @@ -23710,7 +23710,7 @@ TEST_F(NVFuserTest, FusionTransformPropagateSibling_CUDA) { for (auto tensors : siblings) { for (auto t1 : tensors) { for (auto t2 : tensors) { - TORCH_CHECK(TransformReplay::fullyMatching(t1, t2)); + TORCH_CHECK(TransformReplay::fullSelfMatching(t1, t2)); } } } @@ -23769,7 +23769,7 @@ TEST_F(NVFuserTest, FusionTransformPropagateSelectorSibling_CUDA) { for (auto tensors : siblings) { for (auto t1 : 
tensors) { for (auto t2 : tensors) { - TORCH_CHECK(TransformReplay::fullyMatching(t1, t2)); + TORCH_CHECK(TransformReplay::fullSelfMatching(t1, t2)); } } } diff --git a/torch/csrc/jit/codegen/cuda/transform_replay.cpp b/torch/csrc/jit/codegen/cuda/transform_replay.cpp index a043c3e03cd3d..c00b43c4660ba 100644 --- a/torch/csrc/jit/codegen/cuda/transform_replay.cpp +++ b/torch/csrc/jit/codegen/cuda/transform_replay.cpp @@ -741,7 +741,7 @@ int TransformReplay::getMatchedLeafPosWithoutReplayCasP( producer, consumer, producer_pos, false); } -bool TransformReplay::fullyMatching( +bool TransformReplay::fullSelfMatching( const TensorView* replay, const TensorView* target) { auto replay_root = replay->getRootDomain(); diff --git a/torch/csrc/jit/codegen/cuda/transform_replay.h b/torch/csrc/jit/codegen/cuda/transform_replay.h index f1f3232a5c2f1..6b307d348ad58 100644 --- a/torch/csrc/jit/codegen/cuda/transform_replay.h +++ b/torch/csrc/jit/codegen/cuda/transform_replay.h @@ -172,7 +172,7 @@ class TORCH_CUDA_CU_API TransformReplay { int producer_pos); // tests if two tensors has fully matching transformations - static bool fullyMatching(const TensorView* replay, const TensorView* target); + static bool fullSelfMatching(const TensorView* replay, const TensorView* target); }; class TORCH_CUDA_CU_API TransformPropagator From a873c177ec27e8de2cca4924ec0cf16ed226af0c Mon Sep 17 00:00:00 2001 From: Xiang Gao Date: Tue, 28 Jun 2022 14:32:54 -0700 Subject: [PATCH 066/100] save --- torch/csrc/jit/codegen/cuda/consume_at.cpp | 2 +- torch/csrc/jit/codegen/cuda/transform_replay.cpp | 2 +- torch/csrc/jit/codegen/cuda/transform_replay.h | 4 +++- 3 files changed, 5 insertions(+), 3 deletions(-) diff --git a/torch/csrc/jit/codegen/cuda/consume_at.cpp b/torch/csrc/jit/codegen/cuda/consume_at.cpp index 437bd0190d884..9911f2cd100b4 100644 --- a/torch/csrc/jit/codegen/cuda/consume_at.cpp +++ b/torch/csrc/jit/codegen/cuda/consume_at.cpp @@ -291,7 +291,7 @@ void 
ComputeAtPosPropagator::propagateTvSibling( recordReplayedPos(reference_, reference_pos_); } auto from_pos = retrieveReplayedPos(from); - if (!TransformReplay::fullyMatching(to, from)) { + if (!TransformReplay::fullSelfMatching(to, from)) { auto replay = TransformReplay::fullSelfReplay(to->domain(), from->domain()); to->setDomain(replay); } diff --git a/torch/csrc/jit/codegen/cuda/transform_replay.cpp b/torch/csrc/jit/codegen/cuda/transform_replay.cpp index c00b43c4660ba..e961867865181 100644 --- a/torch/csrc/jit/codegen/cuda/transform_replay.cpp +++ b/torch/csrc/jit/codegen/cuda/transform_replay.cpp @@ -807,7 +807,7 @@ void TransformPropagator::propagateTvCasP(TensorView* from, TensorView* to) { void TransformPropagator::propagateTvSibling(TensorView* from, TensorView* to) { int pos = replayed_pos_.at(from); // See note [Using multiple TransformPropagators] - if (!TransformReplay::fullyMatching(to, from)) { + if (!TransformReplay::fullSelfMatching(to, from)) { auto replay = TransformReplay::fullSelfReplay(to->domain(), from->domain()); to->setDomain(replay); } diff --git a/torch/csrc/jit/codegen/cuda/transform_replay.h b/torch/csrc/jit/codegen/cuda/transform_replay.h index 6b307d348ad58..d026de618c88f 100644 --- a/torch/csrc/jit/codegen/cuda/transform_replay.h +++ b/torch/csrc/jit/codegen/cuda/transform_replay.h @@ -172,7 +172,9 @@ class TORCH_CUDA_CU_API TransformReplay { int producer_pos); // tests if two tensors has fully matching transformations - static bool fullSelfMatching(const TensorView* replay, const TensorView* target); + static bool fullSelfMatching( + const TensorView* replay, + const TensorView* target); }; class TORCH_CUDA_CU_API TransformPropagator From 3f13eb28ffc9c8225b7a9fea85fbf073de5658ae Mon Sep 17 00:00:00 2001 From: Xiang Gao Date: Tue, 28 Jun 2022 14:41:47 -0700 Subject: [PATCH 067/100] save --- torch/csrc/jit/codegen/cuda/consume_at.cpp | 8 ++++++-- torch/csrc/jit/codegen/cuda/consume_at.h | 4 ++-- 2 files changed, 8 insertions(+), 4 
deletions(-) diff --git a/torch/csrc/jit/codegen/cuda/consume_at.cpp b/torch/csrc/jit/codegen/cuda/consume_at.cpp index 9911f2cd100b4..72a346cb3416c 100644 --- a/torch/csrc/jit/codegen/cuda/consume_at.cpp +++ b/torch/csrc/jit/codegen/cuda/consume_at.cpp @@ -257,7 +257,9 @@ void ComputeAtPosPropagator::propagateTvPasC(TensorView* from, TensorView* to) { int pos = getReplayPosPasC(to, from); auto to_pos = TransformReplay::getMatchedLeafPosWithoutReplayPasC(to, from, pos); - // TORCH_CHECK(to_pos >= 0); + // TODO: Can we make TransformPropagator do the transformation, and + // ComputeAtPosPropagator only set the CA positions? + // TORCH_CHECK(to_pos >= 0); if (to_pos < 0) { auto replay = TransformReplay::replayPasC(to, from, pos); to->setDomain(replay.first); @@ -274,7 +276,9 @@ void ComputeAtPosPropagator::propagateTvCasP(TensorView* from, TensorView* to) { int pos = getReplayPosCasP(to, from); auto to_pos = TransformReplay::getMatchedLeafPosWithoutReplayCasP(to, from, pos); - // TORCH_CHECK(to_pos >= 0); + // TODO: Can we make TransformPropagator do the transformation, and + // ComputeAtPosPropagator only set the CA positions? 
+ // TORCH_CHECK(to_pos >= 0); if (to_pos < 0) { auto replay = TransformReplay::replayCasP(to, from, pos); to->setDomain(replay.first); diff --git a/torch/csrc/jit/codegen/cuda/consume_at.h b/torch/csrc/jit/codegen/cuda/consume_at.h index d2ad077ab8916..73da5a8dafaa2 100644 --- a/torch/csrc/jit/codegen/cuda/consume_at.h +++ b/torch/csrc/jit/codegen/cuda/consume_at.h @@ -17,11 +17,11 @@ class ComputeAtPosPropagator : public MaxInfoSpanningTree::Propagator { size_t getReplayablePosCasP(TensorView* consumer, TensorView* producer); size_t getReplayPosPasC(TensorView* producer, TensorView* consumer); size_t getReplayPosCasP(TensorView* consumer, TensorView* producer); + size_t adjustComputeAtPos(TensorView* tv, size_t pos); + void recordReplayedPos(TensorView* tv, size_t pos); size_t retrieveReplayedPos(TensorView* tv); - size_t adjustComputeAtPos(TensorView* tv, size_t pos); - // Iterate through all TVs and collect the dimensions of each TV that don't // map to all its consumer TVs. void buildUnmappableDims(); From 59ba3274a099755210729a6b2f98a6d60eca8c47 Mon Sep 17 00:00:00 2001 From: Xiang Gao Date: Tue, 28 Jun 2022 14:57:21 -0700 Subject: [PATCH 068/100] cleanup --- torch/csrc/jit/codegen/cuda/consume_at.cpp | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/torch/csrc/jit/codegen/cuda/consume_at.cpp b/torch/csrc/jit/codegen/cuda/consume_at.cpp index 72a346cb3416c..d858d189dc03f 100644 --- a/torch/csrc/jit/codegen/cuda/consume_at.cpp +++ b/torch/csrc/jit/codegen/cuda/consume_at.cpp @@ -151,8 +151,7 @@ size_t ComputeAtPosPropagator::getReplayablePosPasC( // convert to iter domains auto consumer_root_dim_ids = ir_utils::filterByType(consumer_root_dim_vals); - // If any root dimensions cannot be mapped to producer we can't inline. If - // any root dimension + // If any root dimensions cannot be mapped to producer we can't inline. 
if (std::any_of( consumer_root_dim_ids.begin(), consumer_root_dim_ids.end(), From 679867cec2ba2678d754eaaa2761bf14e053efc6 Mon Sep 17 00:00:00 2001 From: Xiang Gao Date: Tue, 28 Jun 2022 17:45:55 -0700 Subject: [PATCH 069/100] cleanup max pos logic --- torch/csrc/jit/codegen/cuda/consume_at.cpp | 57 ++++++++++++++++------ torch/csrc/jit/codegen/cuda/consume_at.h | 11 +++-- 2 files changed, 50 insertions(+), 18 deletions(-) diff --git a/torch/csrc/jit/codegen/cuda/consume_at.cpp b/torch/csrc/jit/codegen/cuda/consume_at.cpp index d858d189dc03f..8d76e0634c8b7 100644 --- a/torch/csrc/jit/codegen/cuda/consume_at.cpp +++ b/torch/csrc/jit/codegen/cuda/consume_at.cpp @@ -56,10 +56,10 @@ void ComputeAtPosPropagator::buildUnmappableDims() { } } -size_t ComputeAtPosPropagator::adjustComputeAtPos(TensorView* tv, size_t pos) { +size_t ComputeAtPosPropagator::getMaxPosSelf(TensorView* tv) { auto dom = tv->domain()->domain(); auto first_reduction = - std::find_if(dom.begin(), dom.begin() + pos, [](IterDomain* id) { + std::find_if(dom.begin(), dom.begin() + tv->nDims(), [](IterDomain* id) { return id->isReduction(); }); @@ -71,21 +71,50 @@ size_t ComputeAtPosPropagator::adjustComputeAtPos(TensorView* tv, size_t pos) { id->getParallelType() == ParallelType::Unroll); }); - auto ca_pos = std::distance(dom.begin(), first_vectorized_axis); + auto all_mappable = [&](size_t ca_pos) -> bool { + auto dom = tv->domain()->domain(); + // Grab all root dimensions as roots must be used to understand inlining + // potential. + auto root_dim_vals = + IterVisitor::getInputsTo({dom.begin(), dom.begin() + ca_pos}); + auto root_dim_ids = ir_utils::filterByType(root_dim_vals); + // If any root dimensions cannot be mapped we can't inline. 
+ return std::all_of( + root_dim_ids.begin(), root_dim_ids.end(), [this](IterDomain* id) { + return unmappable_dims_.find(id) == unmappable_dims_.end(); + }); + }; + + auto pos = std::distance(dom.begin(), first_vectorized_axis); + while (!all_mappable(pos)) { + pos--; + } + return pos; +} - // We only check consumers here, because producers are not always replayed - // consistently +size_t ComputeAtPosPropagator::getMaxPosAll(TensorView* tv) { + auto max_pos = getMaxPosSelf(tv); + for (auto producer_tv : ir_utils::producerTvsOf(tv)) { + if (consume_.count(producer_tv) > 0) { + max_pos = std::min(max_pos, getMaxPosPasC(producer_tv, tv)); + } + } for (auto consumer_tv : ir_utils::consumerTvsOf(tv)) { - auto max_pos = getReplayablePosCasP(consumer_tv, tv); - ca_pos = std::min(ca_pos, max_pos); + // consume_.count(consumer_tv) > 0 is always true by definition + max_pos = std::min(max_pos, getMaxPosCasP(consumer_tv, tv)); } + return max_pos; +} + +size_t ComputeAtPosPropagator::adjustComputeAtPos(TensorView* tv, size_t pos) { + pos = std::min(pos, getMaxPosAll(tv)); // hoist inner most broadcast - while (ca_pos > 0 && tv->axis(ca_pos - 1)->isBroadcast()) { - ca_pos--; + while (pos > 0 && tv->axis(pos - 1)->isBroadcast()) { + pos--; } - return ca_pos; + return pos; } // Return the max position in consumer that producer can be inlined to @@ -96,7 +125,7 @@ size_t ComputeAtPosPropagator::adjustComputeAtPos(TensorView* tv, size_t pos) { // Unrolled dimensions in producer or consumer // Dimensions derived from root dimensions that exist in both but are // unmappable -size_t ComputeAtPosPropagator::getReplayablePosPasC( +size_t ComputeAtPosPropagator::getMaxPosPasC( TensorView* producer, TensorView* consumer) { // Check if any consumer dimensions are marked as vectorize as producer can @@ -178,7 +207,7 @@ size_t ComputeAtPosPropagator::getReplayablePosPasC( // Unrolled dimensions in producer or consumer // Dimensions derived from root dimensions that exist in both but are // 
unmappable -size_t ComputeAtPosPropagator::getReplayablePosCasP( +size_t ComputeAtPosPropagator::getMaxPosCasP( TensorView* consumer, TensorView* producer) { auto p_dom = producer->domain()->domain(); @@ -329,7 +358,7 @@ size_t ComputeAtPosPropagator::retrieveReplayedPos(TensorView* tv) { size_t ComputeAtPosPropagator::getReplayPosPasC( TensorView* producer, TensorView* consumer) { - size_t max_pos = getReplayablePosPasC(producer, consumer); + size_t max_pos = getMaxPosPasC(producer, consumer); size_t pos = retrieveReplayedPos(consumer); if (mode_ == ComputeAtMode::BestEffort) { @@ -354,7 +383,7 @@ size_t ComputeAtPosPropagator::getReplayPosPasC( size_t ComputeAtPosPropagator::getReplayPosCasP( TensorView* consumer, TensorView* producer) { - size_t max_pos = getReplayablePosCasP(consumer, producer); + size_t max_pos = getMaxPosCasP(consumer, producer); size_t pos = retrieveReplayedPos(producer); if (mode_ == ComputeAtMode::BestEffort) { diff --git a/torch/csrc/jit/codegen/cuda/consume_at.h b/torch/csrc/jit/codegen/cuda/consume_at.h index 73da5a8dafaa2..40023ceef0457 100644 --- a/torch/csrc/jit/codegen/cuda/consume_at.h +++ b/torch/csrc/jit/codegen/cuda/consume_at.h @@ -13,12 +13,15 @@ namespace cuda { class ComputeAtPosPropagator : public MaxInfoSpanningTree::Propagator { // TODO: change arguments to `from`, `to`. 
- size_t getReplayablePosPasC(TensorView* producer, TensorView* consumer); - size_t getReplayablePosCasP(TensorView* consumer, TensorView* producer); - size_t getReplayPosPasC(TensorView* producer, TensorView* consumer); - size_t getReplayPosCasP(TensorView* consumer, TensorView* producer); + size_t getMaxPosSelf(TensorView* tv); + size_t getMaxPosPasC(TensorView* producer, TensorView* consumer); + size_t getMaxPosCasP(TensorView* consumer, TensorView* producer); + size_t getMaxPosAll(TensorView* tv); + size_t adjustComputeAtPos(TensorView* tv, size_t pos); + size_t getReplayPosPasC(TensorView* producer, TensorView* consumer); + size_t getReplayPosCasP(TensorView* consumer, TensorView* producer); void recordReplayedPos(TensorView* tv, size_t pos); size_t retrieveReplayedPos(TensorView* tv); From 18391587aa945a6e83c7f8b6700900811c074d52 Mon Sep 17 00:00:00 2001 From: Xiang Gao Date: Tue, 28 Jun 2022 22:26:51 -0700 Subject: [PATCH 070/100] getMaxPos* cleanup, step 1 --- torch/csrc/jit/codegen/cuda/consume_at.cpp | 136 +++++++++++++-------- torch/csrc/jit/codegen/cuda/consume_at.h | 6 +- 2 files changed, 90 insertions(+), 52 deletions(-) diff --git a/torch/csrc/jit/codegen/cuda/consume_at.cpp b/torch/csrc/jit/codegen/cuda/consume_at.cpp index 8d76e0634c8b7..e6752093154bf 100644 --- a/torch/csrc/jit/codegen/cuda/consume_at.cpp +++ b/torch/csrc/jit/codegen/cuda/consume_at.cpp @@ -56,65 +56,45 @@ void ComputeAtPosPropagator::buildUnmappableDims() { } } -size_t ComputeAtPosPropagator::getMaxPosSelf(TensorView* tv) { +size_t ComputeAtPosPropagator::getMaxPosSelf( + TensorView* tv, + bool allow_reduction, + bool allow_vectorize, + bool allow_unmappable) { auto dom = tv->domain()->domain(); - auto first_reduction = - std::find_if(dom.begin(), dom.begin() + tv->nDims(), [](IterDomain* id) { - return id->isReduction(); - }); - auto first_vectorized_axis = - std::find_if(dom.begin(), first_reduction, [this](IterDomain* id) { - return 
isParallelTypeVectorize(id->getParallelType()) || - ((mode_ == ComputeAtMode::BestEffort || - mode_ == ComputeAtMode::MostInlined) && - id->getParallelType() == ParallelType::Unroll); - }); + auto iter = dom.end(); - auto all_mappable = [&](size_t ca_pos) -> bool { - auto dom = tv->domain()->domain(); - // Grab all root dimensions as roots must be used to understand inlining - // potential. - auto root_dim_vals = - IterVisitor::getInputsTo({dom.begin(), dom.begin() + ca_pos}); - auto root_dim_ids = ir_utils::filterByType(root_dim_vals); - // If any root dimensions cannot be mapped we can't inline. - return std::all_of( - root_dim_ids.begin(), root_dim_ids.end(), [this](IterDomain* id) { - return unmappable_dims_.find(id) == unmappable_dims_.end(); - }); - }; - - auto pos = std::distance(dom.begin(), first_vectorized_axis); - while (!all_mappable(pos)) { - pos--; + if (!allow_reduction) { + iter = std::find_if( + dom.begin(), iter, [](IterDomain* id) { return id->isReduction(); }); } - return pos; -} -size_t ComputeAtPosPropagator::getMaxPosAll(TensorView* tv) { - auto max_pos = getMaxPosSelf(tv); - for (auto producer_tv : ir_utils::producerTvsOf(tv)) { - if (consume_.count(producer_tv) > 0) { - max_pos = std::min(max_pos, getMaxPosPasC(producer_tv, tv)); - } - } - for (auto consumer_tv : ir_utils::consumerTvsOf(tv)) { - // consume_.count(consumer_tv) > 0 is always true by definition - max_pos = std::min(max_pos, getMaxPosCasP(consumer_tv, tv)); + if (!allow_vectorize) { + iter = std::find_if(dom.begin(), iter, [this](IterDomain* id) { + return isParallelTypeVectorize(id->getParallelType()) || + ((mode_ == ComputeAtMode::BestEffort || + mode_ == ComputeAtMode::MostInlined) && + id->getParallelType() == ParallelType::Unroll); + }); } - return max_pos; -} -size_t ComputeAtPosPropagator::adjustComputeAtPos(TensorView* tv, size_t pos) { - pos = std::min(pos, getMaxPosAll(tv)); - - // hoist inner most broadcast - while (pos > 0 && tv->axis(pos - 1)->isBroadcast()) { - 
pos--; + if (!allow_unmappable) { + auto root_dom = tv->getMaybeRFactorDomain(); + std::unordered_set root_dom_set(root_dom.begin(), root_dom.end()); + auto& unmappable_dims = unmappable_dims_; + auto is_unmappable = [&](IterDomain* id) -> bool { + auto all_vals = DependencyCheck::getAllValsBetween(root_dom_set, {id}); + auto all_ids = ir_utils::filterByType(all_vals); + auto check_unmappable = [&](IterDomain* root_id) { + return unmappable_dims.count(root_id) > 0; + }; + return std::any_of(all_ids.begin(), all_ids.end(), check_unmappable); + }; + iter = std::find_if(dom.begin(), iter, is_unmappable); } - return pos; + return std::distance(dom.begin(), iter); } // Return the max position in consumer that producer can be inlined to @@ -125,6 +105,20 @@ size_t ComputeAtPosPropagator::adjustComputeAtPos(TensorView* tv, size_t pos) { // Unrolled dimensions in producer or consumer // Dimensions derived from root dimensions that exist in both but are // unmappable +// size_t ComputeAtPosPropagator::getMaxPosPasC( +// TensorView* producer, +// TensorView* consumer) { +// auto max_consumer_pos = getMaxPosSelf(consumer, true, false, true); +// auto max_producer_pos = getMaxPosSelf(producer, true, false, false); +// for (size_t pos = max_consumer_pos; pos > 0; pos--) { +// auto producer_pos = TransformReplay::getMatchedLeafPosWithoutReplayPasC( +// producer, consumer, pos); +// if (producer_pos >= 0 && producer_pos <= max_producer_pos) { +// return pos; +// } +// } +// return 0; +// } size_t ComputeAtPosPropagator::getMaxPosPasC( TensorView* producer, TensorView* consumer) { @@ -150,7 +144,7 @@ size_t ComputeAtPosPropagator::getMaxPosPasC( auto replay_PasC = BestEffortReplay::replayPasC(producer, consumer, -1, pairwise_root_map); - // Look for id's that map to a consumer id that's vectorized + // Look for id's that map to a producer id that's vectorized auto c2p_replay_map = replay_PasC.getReplay(); for (size_t consumer_pos = max_consumer_pos; consumer_pos > 0; @@ -207,6 
+201,20 @@ size_t ComputeAtPosPropagator::getMaxPosPasC( // Unrolled dimensions in producer or consumer // Dimensions derived from root dimensions that exist in both but are // unmappable +// size_t ComputeAtPosPropagator::getMaxPosCasP( +// TensorView* consumer, +// TensorView* producer) { +// auto max_consumer_pos = getMaxPosSelf(consumer, true, false, true); +// auto max_producer_pos = getMaxPosSelf(producer, false, false, false); +// for (size_t pos = max_producer_pos; pos > 0; pos--) { +// auto consumer_pos = TransformReplay::getMatchedLeafPosWithoutReplayCasP( +// consumer, producer, pos); +// if (consumer_pos >= 0 && consumer_pos <= max_consumer_pos) { +// return pos; +// } +// } +// return 0; +// } size_t ComputeAtPosPropagator::getMaxPosCasP( TensorView* consumer, TensorView* producer) { @@ -277,6 +285,32 @@ size_t ComputeAtPosPropagator::getMaxPosCasP( return 0; } +size_t ComputeAtPosPropagator::getMaxPosAll(TensorView* tv) { + auto max_pos = getMaxPosSelf(tv, false, false, false); + for (auto producer_tv : ir_utils::producerTvsOf(tv)) { + // only check producers that are replayed consistently + if (consume_.count(producer_tv) > 0) { + max_pos = std::min(max_pos, getMaxPosPasC(producer_tv, tv)); + } + } + for (auto consumer_tv : ir_utils::consumerTvsOf(tv)) { + // consumers are always replayed consistently + max_pos = std::min(max_pos, getMaxPosCasP(consumer_tv, tv)); + } + return max_pos; +} + +size_t ComputeAtPosPropagator::adjustComputeAtPos(TensorView* tv, size_t pos) { + pos = std::min(pos, getMaxPosAll(tv)); + + // hoist inner most broadcast + while (pos > 0 && tv->axis(pos - 1)->isBroadcast()) { + pos--; + } + + return pos; +} + void ComputeAtPosPropagator::propagateTvPasC(TensorView* from, TensorView* to) { if (is_first_) { is_first_ = false; diff --git a/torch/csrc/jit/codegen/cuda/consume_at.h b/torch/csrc/jit/codegen/cuda/consume_at.h index 40023ceef0457..2764895fab3c5 100644 --- a/torch/csrc/jit/codegen/cuda/consume_at.h +++ 
b/torch/csrc/jit/codegen/cuda/consume_at.h @@ -13,7 +13,11 @@ namespace cuda { class ComputeAtPosPropagator : public MaxInfoSpanningTree::Propagator { // TODO: change arguments to `from`, `to`. - size_t getMaxPosSelf(TensorView* tv); + size_t getMaxPosSelf( + TensorView* tv, + bool allow_reduction, + bool allow_vectorize, + bool allow_unmappable); size_t getMaxPosPasC(TensorView* producer, TensorView* consumer); size_t getMaxPosCasP(TensorView* consumer, TensorView* producer); size_t getMaxPosAll(TensorView* tv); From c1738d8886a1d8c88044a89c39f26083fe035b9c Mon Sep 17 00:00:00 2001 From: Xiang Gao Date: Tue, 28 Jun 2022 22:48:56 -0700 Subject: [PATCH 071/100] getMaxPos* cleanup, step 2 --- torch/csrc/jit/codegen/cuda/consume_at.cpp | 54 ++-------------------- 1 file changed, 4 insertions(+), 50 deletions(-) diff --git a/torch/csrc/jit/codegen/cuda/consume_at.cpp b/torch/csrc/jit/codegen/cuda/consume_at.cpp index e6752093154bf..ff7e888d3b215 100644 --- a/torch/csrc/jit/codegen/cuda/consume_at.cpp +++ b/torch/csrc/jit/codegen/cuda/consume_at.cpp @@ -122,19 +122,8 @@ size_t ComputeAtPosPropagator::getMaxPosSelf( size_t ComputeAtPosPropagator::getMaxPosPasC( TensorView* producer, TensorView* consumer) { - // Check if any consumer dimensions are marked as vectorize as producer can - // not be inlined to vectorized dimensions in consumer. - auto c_dom = consumer->domain()->domain(); - auto vector_dim_it = - std::find_if(c_dom.begin(), c_dom.end(), [this](IterDomain* id) { - return isParallelTypeVectorize(id->getParallelType()) || - ((mode_ == ComputeAtMode::BestEffort || - mode_ == ComputeAtMode::MostInlined) && - id->getParallelType() == ParallelType::Unroll); - }); - // Limit max position based on vectorized dims in consumer. 
- auto max_consumer_pos = std::distance(c_dom.begin(), vector_dim_it); + auto max_consumer_pos = getMaxPosSelf(consumer, true, false, true); auto pairwise_root_map = PairwiseRootDomainMap(producer, consumer); auto c2p_root_map = @@ -169,6 +158,7 @@ size_t ComputeAtPosPropagator::getMaxPosPasC( consumer_pos--) { // Grab all root dimensions of consumer as roots must be used to understand // inlining potential. + auto c_dom = consumer->domain()->domain(); auto consumer_root_dim_vals = IterVisitor::getInputsTo({c_dom.begin(), c_dom.begin() + consumer_pos}); // convert to iter domains @@ -218,21 +208,7 @@ size_t ComputeAtPosPropagator::getMaxPosPasC( size_t ComputeAtPosPropagator::getMaxPosCasP( TensorView* consumer, TensorView* producer) { - auto p_dom = producer->domain()->domain(); - auto first_reduction = - std::find_if(p_dom.begin(), p_dom.end(), [](IterDomain* id) { - return id->isReduction(); - }); - - auto first_vectorized_axis = - std::find_if(p_dom.begin(), first_reduction, [this](IterDomain* id) { - return isParallelTypeVectorize(id->getParallelType()) || - ((mode_ == ComputeAtMode::BestEffort || - mode_ == ComputeAtMode::MostInlined) && - id->getParallelType() == ParallelType::Unroll); - }); - - auto max_producer_pos = std::distance(p_dom.begin(), first_vectorized_axis); + auto max_producer_pos = getMaxPosSelf(producer, false, false, false); auto pairwise_root_map = PairwiseRootDomainMap(producer, consumer); auto p2c_root_map = pairwise_root_map.mapProducerToConsumer( @@ -260,29 +236,7 @@ size_t ComputeAtPosPropagator::getMaxPosCasP( } } - for (size_t producer_pos = max_producer_pos; producer_pos > 0; - producer_pos--) { - auto all_vals = DependencyCheck::getAllValsBetween( - {producer->getMaybeRFactorDomain().begin(), - producer->getMaybeRFactorDomain().end()}, - {p_dom.begin(), p_dom.begin() + producer_pos}); - - // If any root dims could have mapped to consumer, but don't, then we can't - // compute at this point - if (std::any_of( - 
producer->getMaybeRFactorDomain().begin(), - producer->getMaybeRFactorDomain().end(), - [this, &all_vals](IterDomain* p_root_id) { - return std::find(all_vals.begin(), all_vals.end(), p_root_id) != - all_vals.end() && - unmappable_dims_.find(p_root_id) != unmappable_dims_.end(); - })) { - continue; - } - - return producer_pos; - } - return 0; + return max_producer_pos; } size_t ComputeAtPosPropagator::getMaxPosAll(TensorView* tv) { From 9f765a4db6493f0eb7d81db63006b1344d93f33a Mon Sep 17 00:00:00 2001 From: Xiang Gao Date: Tue, 28 Jun 2022 23:48:46 -0700 Subject: [PATCH 072/100] getMaxPos* cleanup, step 3 --- torch/csrc/jit/codegen/cuda/consume_at.cpp | 134 +++++---------------- torch/csrc/jit/codegen/cuda/consume_at.h | 7 +- 2 files changed, 38 insertions(+), 103 deletions(-) diff --git a/torch/csrc/jit/codegen/cuda/consume_at.cpp b/torch/csrc/jit/codegen/cuda/consume_at.cpp index ff7e888d3b215..09d0db4ca0ea4 100644 --- a/torch/csrc/jit/codegen/cuda/consume_at.cpp +++ b/torch/csrc/jit/codegen/cuda/consume_at.cpp @@ -56,44 +56,52 @@ void ComputeAtPosPropagator::buildUnmappableDims() { } } -size_t ComputeAtPosPropagator::getMaxPosSelf( +bool ComputeAtPosPropagator::isAllowedID( + IterDomain* id, TensorView* tv, bool allow_reduction, bool allow_vectorize, bool allow_unmappable) { - auto dom = tv->domain()->domain(); - - auto iter = dom.end(); + bool allowed = true; if (!allow_reduction) { - iter = std::find_if( - dom.begin(), iter, [](IterDomain* id) { return id->isReduction(); }); + allowed = allowed && !id->isReduction(); } if (!allow_vectorize) { - iter = std::find_if(dom.begin(), iter, [this](IterDomain* id) { - return isParallelTypeVectorize(id->getParallelType()) || - ((mode_ == ComputeAtMode::BestEffort || - mode_ == ComputeAtMode::MostInlined) && - id->getParallelType() == ParallelType::Unroll); - }); + bool is_vectorize = isParallelTypeVectorize(id->getParallelType()) || + ((mode_ == ComputeAtMode::BestEffort || + mode_ == ComputeAtMode::MostInlined) && + 
id->getParallelType() == ParallelType::Unroll); + allowed = allowed && !is_vectorize; } if (!allow_unmappable) { auto root_dom = tv->getMaybeRFactorDomain(); std::unordered_set root_dom_set(root_dom.begin(), root_dom.end()); - auto& unmappable_dims = unmappable_dims_; - auto is_unmappable = [&](IterDomain* id) -> bool { - auto all_vals = DependencyCheck::getAllValsBetween(root_dom_set, {id}); - auto all_ids = ir_utils::filterByType(all_vals); - auto check_unmappable = [&](IterDomain* root_id) { - return unmappable_dims.count(root_id) > 0; - }; - return std::any_of(all_ids.begin(), all_ids.end(), check_unmappable); + auto all_vals = DependencyCheck::getAllValsBetween(root_dom_set, {id}); + auto all_ids = ir_utils::filterByType(all_vals); + auto check_unmappable = [&](IterDomain* root_id) { + return unmappable_dims_.count(root_id) > 0; }; - iter = std::find_if(dom.begin(), iter, is_unmappable); + bool is_unmappable = + std::any_of(all_ids.begin(), all_ids.end(), check_unmappable); + allowed = allowed && !is_unmappable; } + return allowed; +} + +size_t ComputeAtPosPropagator::getMaxPosSelf( + TensorView* tv, + bool allow_reduction, + bool allow_vectorize, + bool allow_unmappable) { + auto dom = tv->domain()->domain(); + auto iter = std::find_if(dom.begin(), dom.end(), [=](IterDomain* id) { + return !isAllowedID( + id, tv, allow_reduction, allow_vectorize, allow_unmappable); + }); return std::distance(dom.begin(), iter); } @@ -105,20 +113,6 @@ size_t ComputeAtPosPropagator::getMaxPosSelf( // Unrolled dimensions in producer or consumer // Dimensions derived from root dimensions that exist in both but are // unmappable -// size_t ComputeAtPosPropagator::getMaxPosPasC( -// TensorView* producer, -// TensorView* consumer) { -// auto max_consumer_pos = getMaxPosSelf(consumer, true, false, true); -// auto max_producer_pos = getMaxPosSelf(producer, true, false, false); -// for (size_t pos = max_consumer_pos; pos > 0; pos--) { -// auto producer_pos = 
TransformReplay::getMatchedLeafPosWithoutReplayPasC( -// producer, consumer, pos); -// if (producer_pos >= 0 && producer_pos <= max_producer_pos) { -// return pos; -// } -// } -// return 0; -// } size_t ComputeAtPosPropagator::getMaxPosPasC( TensorView* producer, TensorView* consumer) { @@ -126,14 +120,8 @@ size_t ComputeAtPosPropagator::getMaxPosPasC( auto max_consumer_pos = getMaxPosSelf(consumer, true, false, true); auto pairwise_root_map = PairwiseRootDomainMap(producer, consumer); - auto c2p_root_map = - PairwiseRootDomainMap(producer, consumer) - .mapConsumerToProducer(consumer->domain(), producer->domain()); - auto replay_PasC = BestEffortReplay::replayPasC(producer, consumer, -1, pairwise_root_map); - - // Look for id's that map to a producer id that's vectorized auto c2p_replay_map = replay_PasC.getReplay(); for (size_t consumer_pos = max_consumer_pos; consumer_pos > 0; @@ -141,47 +129,13 @@ size_t ComputeAtPosPropagator::getMaxPosPasC( auto map_it = c2p_replay_map.find(consumer->axis((int)consumer_pos - 1)); if (map_it != c2p_replay_map.end()) { auto p_id = map_it->second; - // If we find a consumer dim that maps to a producer dim that's - // vectorized or unrolled limit max compute at by it. - if (isParallelTypeVectorize(p_id->getParallelType()) || - ((mode_ == ComputeAtMode::BestEffort || - mode_ == ComputeAtMode::MostInlined) && - p_id->getParallelType() == ParallelType::Unroll)) { + if (!isAllowedID(p_id, producer, true, false, false)) { max_consumer_pos = consumer_pos - 1; } } } - // Start at max position and work backwards, try to find a location where - // producer can be inlined. - for (size_t consumer_pos = max_consumer_pos; consumer_pos > 0; - consumer_pos--) { - // Grab all root dimensions of consumer as roots must be used to understand - // inlining potential. 
- auto c_dom = consumer->domain()->domain(); - auto consumer_root_dim_vals = - IterVisitor::getInputsTo({c_dom.begin(), c_dom.begin() + consumer_pos}); - // convert to iter domains - auto consumer_root_dim_ids = - ir_utils::filterByType(consumer_root_dim_vals); - // If any root dimensions cannot be mapped to producer we can't inline. - if (std::any_of( - consumer_root_dim_ids.begin(), - consumer_root_dim_ids.end(), - [this, &c2p_root_map](IterDomain* c_root_id) { - auto p_root_id_it = c2p_root_map.find(c_root_id); - if (p_root_id_it == c2p_root_map.end()) { - return false; - } - auto p_id = p_root_id_it->second; - return unmappable_dims_.find(p_id) != unmappable_dims_.end(); - })) { - continue; - } - return consumer_pos; - } - - return 0; + return max_consumer_pos; } // Return the max position in producer that can be inlined to consumer @@ -191,33 +145,14 @@ size_t ComputeAtPosPropagator::getMaxPosPasC( // Unrolled dimensions in producer or consumer // Dimensions derived from root dimensions that exist in both but are // unmappable -// size_t ComputeAtPosPropagator::getMaxPosCasP( -// TensorView* consumer, -// TensorView* producer) { -// auto max_consumer_pos = getMaxPosSelf(consumer, true, false, true); -// auto max_producer_pos = getMaxPosSelf(producer, false, false, false); -// for (size_t pos = max_producer_pos; pos > 0; pos--) { -// auto consumer_pos = TransformReplay::getMatchedLeafPosWithoutReplayCasP( -// consumer, producer, pos); -// if (consumer_pos >= 0 && consumer_pos <= max_consumer_pos) { -// return pos; -// } -// } -// return 0; -// } size_t ComputeAtPosPropagator::getMaxPosCasP( TensorView* consumer, TensorView* producer) { auto max_producer_pos = getMaxPosSelf(producer, false, false, false); auto pairwise_root_map = PairwiseRootDomainMap(producer, consumer); - auto p2c_root_map = pairwise_root_map.mapProducerToConsumer( - producer->domain(), consumer->domain()); - auto replay_CasP = BestEffortReplay::replayCasP(consumer, producer, -1, 
pairwise_root_map); - - // Look for id's that map to a consumer id that's vectorized auto p2c_replay_map = replay_CasP.getReplay(); for (size_t producer_pos = max_producer_pos; producer_pos > 0; @@ -225,12 +160,7 @@ size_t ComputeAtPosPropagator::getMaxPosCasP( auto map_it = p2c_replay_map.find(producer->axis((int)producer_pos - 1)); if (map_it != p2c_replay_map.end()) { auto c_id = map_it->second; - // If we find a producer dim that maps to a consumer vectorized or - // unrolled dim, limit max compute at by it - if (isParallelTypeVectorize(c_id->getParallelType()) || - ((mode_ == ComputeAtMode::BestEffort || - mode_ == ComputeAtMode::MostInlined) && - c_id->getParallelType() == ParallelType::Unroll)) { + if (!isAllowedID(c_id, consumer, true, false, true)) { max_producer_pos = producer_pos - 1; } } diff --git a/torch/csrc/jit/codegen/cuda/consume_at.h b/torch/csrc/jit/codegen/cuda/consume_at.h index 2764895fab3c5..e8d02c586cb2e 100644 --- a/torch/csrc/jit/codegen/cuda/consume_at.h +++ b/torch/csrc/jit/codegen/cuda/consume_at.h @@ -12,7 +12,12 @@ namespace fuser { namespace cuda { class ComputeAtPosPropagator : public MaxInfoSpanningTree::Propagator { - // TODO: change arguments to `from`, `to`. 
+ bool isAllowedID( + IterDomain* id, + TensorView* tv, + bool allow_reduction, + bool allow_vectorize, + bool allow_unmappable); size_t getMaxPosSelf( TensorView* tv, bool allow_reduction, From b3ef17eee2c66a8571e68b9945ec85e4c1224fb8 Mon Sep 17 00:00:00 2001 From: Xiang Gao Date: Tue, 28 Jun 2022 23:56:29 -0700 Subject: [PATCH 073/100] renaming --- torch/csrc/jit/codegen/cuda/compute_at.cpp | 25 ++------- torch/csrc/jit/codegen/cuda/compute_at.h | 10 +--- torch/csrc/jit/codegen/cuda/consume_at.cpp | 51 ++++++++++++------- torch/csrc/jit/codegen/cuda/consume_at.h | 21 ++++++-- .../jit/codegen/cuda/ir_interface_nodes.h | 4 +- 5 files changed, 58 insertions(+), 53 deletions(-) diff --git a/torch/csrc/jit/codegen/cuda/compute_at.cpp b/torch/csrc/jit/codegen/cuda/compute_at.cpp index 9c8898b71190c..73db1f8c17900 100644 --- a/torch/csrc/jit/codegen/cuda/compute_at.cpp +++ b/torch/csrc/jit/codegen/cuda/compute_at.cpp @@ -165,25 +165,10 @@ std::unordered_set getPropagationSubgraph( } // namespace -bool ComputeAtSubgraphSelector::allowPasC(TensorView* from, TensorView* to) { - return selected_.count(to) > 0; -} - -bool ComputeAtSubgraphSelector::allowCasP(TensorView* from, TensorView* to) { - // If the producer is in the consume set, then the consumer must also be - // replayed to obtain a compatible loop structure so that this producer - // can be consumed in this loop. 
- return selected_.count(from) > 0 || selected_.count(to) > 0; -} - -bool ComputeAtSubgraphSelector::allowSibling(TensorView* from, TensorView* to) { - return true; -} - ComputeAtSubgraphSelector::ComputeAtSubgraphSelector( TensorView* producer, TensorView* consumer) - : selected_(getPropagationSubgraph(producer, consumer)) {} + : InlinePropagatorSelector(getPropagationSubgraph(producer, consumer)) {} void ComputeAt::runAt( TensorView* producer, @@ -205,13 +190,13 @@ void ComputeAt::runAt( ComputeAtSubgraphSelector selector(producer, consumer); TransformPropagator propagator(consumer, consumer_position); - ComputeAtPosPropagator ca_propagator( + InlinePropagator inline_propagator( selector.selected(), consumer, consumer_position, mode); MaxProducerPosUpdater updater; MaxRootDomainInfoSpanningTree path(consumer, consumer_position, &selector); path.traverse(&propagator); - path.traverse(&ca_propagator); + path.traverse(&inline_propagator); path.traverse(&updater); } @@ -235,13 +220,13 @@ void ComputeAt::runWith( ComputeAtSubgraphSelector selector(producer, consumer); TransformPropagator propagator(producer, producer_position); - ComputeAtPosPropagator ca_propagator( + InlinePropagator inline_propagator( selector.selected(), producer, producer_position, mode); MaxProducerPosUpdater updater; MaxRootDomainInfoSpanningTree path(producer, producer_position, &selector); path.traverse(&propagator); - path.traverse(&ca_propagator); + path.traverse(&inline_propagator); path.traverse(&updater); } diff --git a/torch/csrc/jit/codegen/cuda/compute_at.h b/torch/csrc/jit/codegen/cuda/compute_at.h index 0e3e3c47eaccb..ba34efa836265 100644 --- a/torch/csrc/jit/codegen/cuda/compute_at.h +++ b/torch/csrc/jit/codegen/cuda/compute_at.h @@ -20,17 +20,9 @@ namespace cuda { class TensorDomain; class TensorView; -class ComputeAtSubgraphSelector : public MaxInfoSpanningTree::Selector { - std::unordered_set selected_; - +class ComputeAtSubgraphSelector : public InlinePropagatorSelector { 
public: - virtual bool allowPasC(TensorView* from, TensorView* to) override; - virtual bool allowCasP(TensorView* from, TensorView* to) override; - virtual bool allowSibling(TensorView* from, TensorView* to) override; ComputeAtSubgraphSelector(TensorView* producer, TensorView* consumer); - const std::unordered_set& selected() const { - return selected_; - } }; struct ComputeAt { diff --git a/torch/csrc/jit/codegen/cuda/consume_at.cpp b/torch/csrc/jit/codegen/cuda/consume_at.cpp index 09d0db4ca0ea4..4337a3ee9bcf5 100644 --- a/torch/csrc/jit/codegen/cuda/consume_at.cpp +++ b/torch/csrc/jit/codegen/cuda/consume_at.cpp @@ -10,7 +10,22 @@ namespace jit { namespace fuser { namespace cuda { -ComputeAtPosPropagator::ComputeAtPosPropagator( +bool InlinePropagatorSelector::allowPasC(TensorView* from, TensorView* to) { + return selected_.count(to) > 0; +} + +bool InlinePropagatorSelector::allowCasP(TensorView* from, TensorView* to) { + // If the producer is in the consume set, then the consumer must also be + // replayed to obtain a compatible loop structure so that this producer + // can be consumed in this loop. 
+ return selected_.count(from) > 0 || selected_.count(to) > 0; +} + +bool InlinePropagatorSelector::allowSibling(TensorView* from, TensorView* to) { + return true; +} + +InlinePropagator::InlinePropagator( std::unordered_set consume, TensorView* reference, size_t reference_pos, @@ -32,7 +47,7 @@ ComputeAtPosPropagator::ComputeAtPosPropagator( buildUnmappableDims(); } -void ComputeAtPosPropagator::buildUnmappableDims() { +void InlinePropagator::buildUnmappableDims() { ComputeAtRootDomainMap root_map; root_map.build(); @@ -56,7 +71,7 @@ void ComputeAtPosPropagator::buildUnmappableDims() { } } -bool ComputeAtPosPropagator::isAllowedID( +bool InlinePropagator::isAllowedID( IterDomain* id, TensorView* tv, bool allow_reduction, @@ -92,7 +107,7 @@ bool ComputeAtPosPropagator::isAllowedID( return allowed; } -size_t ComputeAtPosPropagator::getMaxPosSelf( +size_t InlinePropagator::getMaxPosSelf( TensorView* tv, bool allow_reduction, bool allow_vectorize, @@ -113,7 +128,7 @@ size_t ComputeAtPosPropagator::getMaxPosSelf( // Unrolled dimensions in producer or consumer // Dimensions derived from root dimensions that exist in both but are // unmappable -size_t ComputeAtPosPropagator::getMaxPosPasC( +size_t InlinePropagator::getMaxPosPasC( TensorView* producer, TensorView* consumer) { // Limit max position based on vectorized dims in consumer. 
@@ -145,7 +160,7 @@ size_t ComputeAtPosPropagator::getMaxPosPasC( // Unrolled dimensions in producer or consumer // Dimensions derived from root dimensions that exist in both but are // unmappable -size_t ComputeAtPosPropagator::getMaxPosCasP( +size_t InlinePropagator::getMaxPosCasP( TensorView* consumer, TensorView* producer) { auto max_producer_pos = getMaxPosSelf(producer, false, false, false); @@ -169,7 +184,7 @@ size_t ComputeAtPosPropagator::getMaxPosCasP( return max_producer_pos; } -size_t ComputeAtPosPropagator::getMaxPosAll(TensorView* tv) { +size_t InlinePropagator::getMaxPosAll(TensorView* tv) { auto max_pos = getMaxPosSelf(tv, false, false, false); for (auto producer_tv : ir_utils::producerTvsOf(tv)) { // only check producers that are replayed consistently @@ -184,7 +199,7 @@ size_t ComputeAtPosPropagator::getMaxPosAll(TensorView* tv) { return max_pos; } -size_t ComputeAtPosPropagator::adjustComputeAtPos(TensorView* tv, size_t pos) { +size_t InlinePropagator::adjustComputeAtPos(TensorView* tv, size_t pos) { pos = std::min(pos, getMaxPosAll(tv)); // hoist inner most broadcast @@ -195,7 +210,7 @@ size_t ComputeAtPosPropagator::adjustComputeAtPos(TensorView* tv, size_t pos) { return pos; } -void ComputeAtPosPropagator::propagateTvPasC(TensorView* from, TensorView* to) { +void InlinePropagator::propagateTvPasC(TensorView* from, TensorView* to) { if (is_first_) { is_first_ = false; recordReplayedPos(reference_, reference_pos_); @@ -204,7 +219,7 @@ void ComputeAtPosPropagator::propagateTvPasC(TensorView* from, TensorView* to) { auto to_pos = TransformReplay::getMatchedLeafPosWithoutReplayPasC(to, from, pos); // TODO: Can we make TransformPropagator do the transformation, and - // ComputeAtPosPropagator only set the CA positions? + // InlinePropagator only set the CA positions? 
// TORCH_CHECK(to_pos >= 0); if (to_pos < 0) { auto replay = TransformReplay::replayPasC(to, from, pos); @@ -214,7 +229,7 @@ void ComputeAtPosPropagator::propagateTvPasC(TensorView* from, TensorView* to) { recordReplayedPos(to, to_pos); } -void ComputeAtPosPropagator::propagateTvCasP(TensorView* from, TensorView* to) { +void InlinePropagator::propagateTvCasP(TensorView* from, TensorView* to) { if (is_first_) { is_first_ = false; recordReplayedPos(reference_, reference_pos_); @@ -223,7 +238,7 @@ void ComputeAtPosPropagator::propagateTvCasP(TensorView* from, TensorView* to) { auto to_pos = TransformReplay::getMatchedLeafPosWithoutReplayCasP(to, from, pos); // TODO: Can we make TransformPropagator do the transformation, and - // ComputeAtPosPropagator only set the CA positions? + // InlinePropagator only set the CA positions? // TORCH_CHECK(to_pos >= 0); if (to_pos < 0) { auto replay = TransformReplay::replayCasP(to, from, pos); @@ -233,9 +248,7 @@ void ComputeAtPosPropagator::propagateTvCasP(TensorView* from, TensorView* to) { recordReplayedPos(to, to_pos); } -void ComputeAtPosPropagator::propagateTvSibling( - TensorView* from, - TensorView* to) { +void InlinePropagator::propagateTvSibling(TensorView* from, TensorView* to) { if (is_first_) { is_first_ = false; recordReplayedPos(reference_, reference_pos_); @@ -248,7 +261,7 @@ void ComputeAtPosPropagator::propagateTvSibling( recordReplayedPos(to, from_pos); } -void ComputeAtPosPropagator::recordReplayedPos(TensorView* tv, size_t pos) { +void InlinePropagator::recordReplayedPos(TensorView* tv, size_t pos) { if (consume_.count(tv)) { auto new_pos = adjustComputeAtPos(tv, pos); if (pos != new_pos) { @@ -265,7 +278,7 @@ void ComputeAtPosPropagator::recordReplayedPos(TensorView* tv, size_t pos) { } } -size_t ComputeAtPosPropagator::retrieveReplayedPos(TensorView* tv) { +size_t InlinePropagator::retrieveReplayedPos(TensorView* tv) { auto it = replayed_pos_.find(tv); if (it != replayed_pos_.end()) { return it->second; @@ 
-273,7 +286,7 @@ size_t ComputeAtPosPropagator::retrieveReplayedPos(TensorView* tv) { return tv->getComputeAtPosition(); } -size_t ComputeAtPosPropagator::getReplayPosPasC( +size_t InlinePropagator::getReplayPosPasC( TensorView* producer, TensorView* consumer) { size_t max_pos = getMaxPosPasC(producer, consumer); @@ -298,7 +311,7 @@ size_t ComputeAtPosPropagator::getReplayPosPasC( return pos; } -size_t ComputeAtPosPropagator::getReplayPosCasP( +size_t InlinePropagator::getReplayPosCasP( TensorView* consumer, TensorView* producer) { size_t max_pos = getMaxPosCasP(consumer, producer); diff --git a/torch/csrc/jit/codegen/cuda/consume_at.h b/torch/csrc/jit/codegen/cuda/consume_at.h index e8d02c586cb2e..28735410d2033 100644 --- a/torch/csrc/jit/codegen/cuda/consume_at.h +++ b/torch/csrc/jit/codegen/cuda/consume_at.h @@ -11,7 +11,22 @@ namespace jit { namespace fuser { namespace cuda { -class ComputeAtPosPropagator : public MaxInfoSpanningTree::Propagator { +class InlinePropagatorSelector : public MaxInfoSpanningTree::Selector { + std::unordered_set selected_; + + public: + virtual bool allowPasC(TensorView* from, TensorView* to) override; + virtual bool allowCasP(TensorView* from, TensorView* to) override; + virtual bool allowSibling(TensorView* from, TensorView* to) override; + + InlinePropagatorSelector(std::unordered_set selected) + : selected_(std::move(selected)){}; + const std::unordered_set& selected() const { + return selected_; + } +}; + +class InlinePropagator : public MaxInfoSpanningTree::Propagator { bool isAllowedID( IterDomain* id, TensorView* tv, @@ -49,13 +64,13 @@ class ComputeAtPosPropagator : public MaxInfoSpanningTree::Propagator { std::unordered_set unmappable_dims_; public: - ComputeAtPosPropagator( + InlinePropagator( std::unordered_set consume, TensorView* reference, size_t reference_pos, ComputeAtMode mode); - ~ComputeAtPosPropagator() = default; + ~InlinePropagator() = default; virtual void propagateTvPasC(TensorView* from, TensorView* to) 
override; virtual void propagateTvCasP(TensorView* from, TensorView* to) override; diff --git a/torch/csrc/jit/codegen/cuda/ir_interface_nodes.h b/torch/csrc/jit/codegen/cuda/ir_interface_nodes.h index 4c754ad128cee..bb4484f7db6c2 100644 --- a/torch/csrc/jit/codegen/cuda/ir_interface_nodes.h +++ b/torch/csrc/jit/codegen/cuda/ir_interface_nodes.h @@ -154,7 +154,7 @@ class TORCH_CUDA_CU_API ComplexDouble : public Val { //! the compute at position to maximum possible through traversal. enum class ComputeAtMode { Standard, BestEffort, MostInlined }; -class ComputeAtPosPropagator; +class InlinePropagator; class MaxProducerPosUpdater; class TransformPropagator; class TransformIter; @@ -459,7 +459,7 @@ class TORCH_CUDA_CU_API TensorView : public Val { friend TORCH_CUDA_CU_API TransformPropagator; friend TORCH_CUDA_CU_API TransformReplay; friend TORCH_CUDA_CU_API OptOutMutator; - friend TORCH_CUDA_CU_API ComputeAtPosPropagator; + friend TORCH_CUDA_CU_API InlinePropagator; friend TORCH_CUDA_CU_API MaxProducerPosUpdater; friend class ir_utils::TVDomainGuard; friend TORCH_CUDA_CU_API void groupReductions( From a28f7aadfc583300f6cbda7e6a62e203945df9cf Mon Sep 17 00:00:00 2001 From: Xiang Gao Date: Wed, 29 Jun 2022 00:00:28 -0700 Subject: [PATCH 074/100] file renaming --- build_variables.bzl | 2 +- torch/csrc/jit/codegen/cuda/compute_at.h | 2 +- .../jit/codegen/cuda/{consume_at.cpp => inline_propagator.cpp} | 2 +- .../csrc/jit/codegen/cuda/{consume_at.h => inline_propagator.h} | 0 4 files changed, 3 insertions(+), 3 deletions(-) rename torch/csrc/jit/codegen/cuda/{consume_at.cpp => inline_propagator.cpp} (99%) rename torch/csrc/jit/codegen/cuda/{consume_at.h => inline_propagator.h} (100%) diff --git a/build_variables.bzl b/build_variables.bzl index 84645b1f5d96d..adebcd9125b63 100644 --- a/build_variables.bzl +++ b/build_variables.bzl @@ -642,7 +642,7 @@ libtorch_cuda_core_sources = [ "torch/csrc/autograd/functions/comm.cpp", "torch/csrc/jit/codegen/cuda/arith.cpp", 
"torch/csrc/jit/codegen/cuda/compute_at.cpp", - "torch/csrc/jit/codegen/cuda/consume_at.cpp", + "torch/csrc/jit/codegen/cuda/inline_propagator.cpp", "torch/csrc/jit/codegen/cuda/compute_at_map.cpp", "torch/csrc/jit/codegen/cuda/codegen.cpp", "torch/csrc/jit/codegen/cuda/contiguity.cpp", diff --git a/torch/csrc/jit/codegen/cuda/compute_at.h b/torch/csrc/jit/codegen/cuda/compute_at.h index ba34efa836265..6334f97930e60 100644 --- a/torch/csrc/jit/codegen/cuda/compute_at.h +++ b/torch/csrc/jit/codegen/cuda/compute_at.h @@ -1,6 +1,6 @@ #pragma once -#include +#include #include #include diff --git a/torch/csrc/jit/codegen/cuda/consume_at.cpp b/torch/csrc/jit/codegen/cuda/inline_propagator.cpp similarity index 99% rename from torch/csrc/jit/codegen/cuda/consume_at.cpp rename to torch/csrc/jit/codegen/cuda/inline_propagator.cpp index 4337a3ee9bcf5..729351bfd67a7 100644 --- a/torch/csrc/jit/codegen/cuda/consume_at.cpp +++ b/torch/csrc/jit/codegen/cuda/inline_propagator.cpp @@ -1,4 +1,4 @@ -#include +#include #include #include #include diff --git a/torch/csrc/jit/codegen/cuda/consume_at.h b/torch/csrc/jit/codegen/cuda/inline_propagator.h similarity index 100% rename from torch/csrc/jit/codegen/cuda/consume_at.h rename to torch/csrc/jit/codegen/cuda/inline_propagator.h From 6bf74e66c333f007fcb14a87a728bdca20939c41 Mon Sep 17 00:00:00 2001 From: Xiang Gao Date: Wed, 29 Jun 2022 00:58:04 -0700 Subject: [PATCH 075/100] validate domain --- .../jit/codegen/cuda/inline_propagator.cpp | 32 +++++++++++++++++++ 1 file changed, 32 insertions(+) diff --git a/torch/csrc/jit/codegen/cuda/inline_propagator.cpp b/torch/csrc/jit/codegen/cuda/inline_propagator.cpp index 729351bfd67a7..361bc6ec4646f 100644 --- a/torch/csrc/jit/codegen/cuda/inline_propagator.cpp +++ b/torch/csrc/jit/codegen/cuda/inline_propagator.cpp @@ -210,6 +210,17 @@ size_t InlinePropagator::adjustComputeAtPos(TensorView* tv, size_t pos) { return pos; } +namespace { + +bool validateDomain(TensorView* tv, TensorDomain* 
new_td) { + auto first_mismatch = + BestEffortReplay::findFirstMismatchedID(tv->domain(), new_td); + return first_mismatch >= (int)tv->getMaxProducerPosition() && + first_mismatch >= (int)tv->getComputeAtPosition(); +} + +} // namespace + void InlinePropagator::propagateTvPasC(TensorView* from, TensorView* to) { if (is_first_) { is_first_ = false; @@ -223,6 +234,13 @@ void InlinePropagator::propagateTvPasC(TensorView* from, TensorView* to) { // TORCH_CHECK(to_pos >= 0); if (to_pos < 0) { auto replay = TransformReplay::replayPasC(to, from, pos); + TORCH_INTERNAL_ASSERT( + validateDomain(to, replay.first), + "Tried to set the domain of ", + to, + " to ", + replay.first, + " but that would invalidate previously compute at position or max producer position."); to->setDomain(replay.first); to_pos = replay.second; } @@ -242,6 +260,13 @@ void InlinePropagator::propagateTvCasP(TensorView* from, TensorView* to) { // TORCH_CHECK(to_pos >= 0); if (to_pos < 0) { auto replay = TransformReplay::replayCasP(to, from, pos); + TORCH_INTERNAL_ASSERT( + validateDomain(to, replay.first), + "Tried to set the domain of ", + to, + " to ", + replay.first, + " but that would invalidate previously compute at position or max producer position."); to->setDomain(replay.first); to_pos = replay.second; } @@ -256,6 +281,13 @@ void InlinePropagator::propagateTvSibling(TensorView* from, TensorView* to) { auto from_pos = retrieveReplayedPos(from); if (!TransformReplay::fullSelfMatching(to, from)) { auto replay = TransformReplay::fullSelfReplay(to->domain(), from->domain()); + TORCH_INTERNAL_ASSERT( + validateDomain(to, replay), + "Tried to set the domain of ", + to, + " to ", + replay, + " but that would invalidate previously compute at position or max producer position."); to->setDomain(replay); } recordReplayedPos(to, from_pos); From 2a84c76650a5f7edcd0b6d35142ef49294375ce1 Mon Sep 17 00:00:00 2001 From: Xiang Gao Date: Wed, 29 Jun 2022 01:18:13 -0700 Subject: [PATCH 076/100] split out 
MaxPosCalculator --- .../jit/codegen/cuda/inline_propagator.cpp | 73 ++++++++++--------- .../csrc/jit/codegen/cuda/inline_propagator.h | 35 +++++---- 2 files changed, 60 insertions(+), 48 deletions(-) diff --git a/torch/csrc/jit/codegen/cuda/inline_propagator.cpp b/torch/csrc/jit/codegen/cuda/inline_propagator.cpp index 361bc6ec4646f..2dc6e6775c9cf 100644 --- a/torch/csrc/jit/codegen/cuda/inline_propagator.cpp +++ b/torch/csrc/jit/codegen/cuda/inline_propagator.cpp @@ -25,33 +25,15 @@ bool InlinePropagatorSelector::allowSibling(TensorView* from, TensorView* to) { return true; } -InlinePropagator::InlinePropagator( - std::unordered_set consume, - TensorView* reference, - size_t reference_pos, - ComputeAtMode mode) - : consume_(std::move(consume)), - reference_(reference), - reference_pos_(reference_pos), - mode_(mode) { - TORCH_INTERNAL_ASSERT( - reference_pos >= 0 && reference_pos <= reference->nDims(), - "Invalid computeAt axis, received ", - reference_pos, - " but should be > -", - reference->nDims(), - " and <= ", - reference->nDims(), - "."); - +MaxPosCalculator::MaxPosCalculator(ComputeAtMode mode) : mode_(mode) { buildUnmappableDims(); } -void InlinePropagator::buildUnmappableDims() { +void MaxPosCalculator::buildUnmappableDims() { ComputeAtRootDomainMap root_map; root_map.build(); - auto all_tvs = ir_utils::allTvs(reference_->fusion()); + auto all_tvs = ir_utils::allTvs(FusionGuard::getCurFusion()); for (auto tv : all_tvs) { auto consumers = ir_utils::consumerTvsOf(tv); for (auto consumer : consumers) { @@ -71,12 +53,12 @@ void InlinePropagator::buildUnmappableDims() { } } -bool InlinePropagator::isAllowedID( +bool MaxPosCalculator::isAllowedID( IterDomain* id, TensorView* tv, bool allow_reduction, bool allow_vectorize, - bool allow_unmappable) { + bool allow_unmappable) const { bool allowed = true; if (!allow_reduction) { @@ -107,11 +89,11 @@ bool InlinePropagator::isAllowedID( return allowed; } -size_t InlinePropagator::getMaxPosSelf( +size_t 
MaxPosCalculator::getMaxPosSelf( TensorView* tv, bool allow_reduction, bool allow_vectorize, - bool allow_unmappable) { + bool allow_unmappable) const { auto dom = tv->domain()->domain(); auto iter = std::find_if(dom.begin(), dom.end(), [=](IterDomain* id) { return !isAllowedID( @@ -128,9 +110,9 @@ size_t InlinePropagator::getMaxPosSelf( // Unrolled dimensions in producer or consumer // Dimensions derived from root dimensions that exist in both but are // unmappable -size_t InlinePropagator::getMaxPosPasC( +size_t MaxPosCalculator::getMaxPosPasC( TensorView* producer, - TensorView* consumer) { + TensorView* consumer) const { // Limit max position based on vectorized dims in consumer. auto max_consumer_pos = getMaxPosSelf(consumer, true, false, true); @@ -160,9 +142,9 @@ size_t InlinePropagator::getMaxPosPasC( // Unrolled dimensions in producer or consumer // Dimensions derived from root dimensions that exist in both but are // unmappable -size_t InlinePropagator::getMaxPosCasP( +size_t MaxPosCalculator::getMaxPosCasP( TensorView* consumer, - TensorView* producer) { + TensorView* producer) const { auto max_producer_pos = getMaxPosSelf(producer, false, false, false); auto pairwise_root_map = PairwiseRootDomainMap(producer, consumer); @@ -184,17 +166,40 @@ size_t InlinePropagator::getMaxPosCasP( return max_producer_pos; } +InlinePropagator::InlinePropagator( + std::unordered_set consume, + TensorView* reference, + size_t reference_pos, + ComputeAtMode mode) + : max_pos_calc(mode), + consume_(std::move(consume)), + reference_(reference), + reference_pos_(reference_pos), + mode_(mode) { + TORCH_INTERNAL_ASSERT( + reference_pos >= 0 && reference_pos <= reference->nDims(), + "Invalid computeAt axis, received ", + reference_pos, + " but should be > -", + reference->nDims(), + " and <= ", + reference->nDims(), + "."); +} + size_t InlinePropagator::getMaxPosAll(TensorView* tv) { - auto max_pos = getMaxPosSelf(tv, false, false, false); + auto max_pos = 
max_pos_calc.getMaxPosSelf(tv, false, false, false); for (auto producer_tv : ir_utils::producerTvsOf(tv)) { // only check producers that are replayed consistently if (consume_.count(producer_tv) > 0) { - max_pos = std::min(max_pos, getMaxPosPasC(producer_tv, tv)); + max_pos = std::min( + max_pos, max_pos_calc.getMaxPosPasC(producer_tv, tv)); } } for (auto consumer_tv : ir_utils::consumerTvsOf(tv)) { // consumers are always replayed consistently - max_pos = std::min(max_pos, getMaxPosCasP(consumer_tv, tv)); + max_pos = + std::min(max_pos, max_pos_calc.getMaxPosCasP(consumer_tv, tv)); } return max_pos; } @@ -321,7 +326,7 @@ size_t InlinePropagator::retrieveReplayedPos(TensorView* tv) { size_t InlinePropagator::getReplayPosPasC( TensorView* producer, TensorView* consumer) { - size_t max_pos = getMaxPosPasC(producer, consumer); + size_t max_pos = max_pos_calc.getMaxPosPasC(producer, consumer); size_t pos = retrieveReplayedPos(consumer); if (mode_ == ComputeAtMode::BestEffort) { @@ -346,7 +351,7 @@ size_t InlinePropagator::getReplayPosPasC( size_t InlinePropagator::getReplayPosCasP( TensorView* consumer, TensorView* producer) { - size_t max_pos = getMaxPosCasP(consumer, producer); + size_t max_pos = max_pos_calc.getMaxPosCasP(consumer, producer); size_t pos = retrieveReplayedPos(producer); if (mode_ == ComputeAtMode::BestEffort) { diff --git a/torch/csrc/jit/codegen/cuda/inline_propagator.h b/torch/csrc/jit/codegen/cuda/inline_propagator.h index 28735410d2033..ffc6ec5bf85ec 100644 --- a/torch/csrc/jit/codegen/cuda/inline_propagator.h +++ b/torch/csrc/jit/codegen/cuda/inline_propagator.h @@ -26,33 +26,43 @@ class InlinePropagatorSelector : public MaxInfoSpanningTree::Selector { } }; -class InlinePropagator : public MaxInfoSpanningTree::Propagator { +class MaxPosCalculator { + ComputeAtMode mode_ = ComputeAtMode::Standard; + + // Root domains in producer that's unmappable to any of its consumers + std::unordered_set unmappable_dims_; + + // Iterate through all TVs and 
collect the dimensions of each TV that don't + // map to all its consumer TVs. + void buildUnmappableDims(); + bool isAllowedID( IterDomain* id, TensorView* tv, bool allow_reduction, bool allow_vectorize, - bool allow_unmappable); + bool allow_unmappable) const; + + public: size_t getMaxPosSelf( TensorView* tv, bool allow_reduction, bool allow_vectorize, - bool allow_unmappable); - size_t getMaxPosPasC(TensorView* producer, TensorView* consumer); - size_t getMaxPosCasP(TensorView* consumer, TensorView* producer); - size_t getMaxPosAll(TensorView* tv); + bool allow_unmappable) const; + size_t getMaxPosPasC(TensorView* producer, TensorView* consumer) const; + size_t getMaxPosCasP(TensorView* consumer, TensorView* producer) const; + MaxPosCalculator(ComputeAtMode mode); +}; +class InlinePropagator : public MaxInfoSpanningTree::Propagator { + size_t getMaxPosAll(TensorView* tv); size_t adjustComputeAtPos(TensorView* tv, size_t pos); - size_t getReplayPosPasC(TensorView* producer, TensorView* consumer); size_t getReplayPosCasP(TensorView* consumer, TensorView* producer); void recordReplayedPos(TensorView* tv, size_t pos); size_t retrieveReplayedPos(TensorView* tv); - // Iterate through all TVs and collect the dimensions of each TV that don't - // map to all its consumer TVs. 
- void buildUnmappableDims(); - + const MaxPosCalculator max_pos_calc; std::unordered_set consume_; TensorView* reference_; size_t reference_pos_; @@ -60,9 +70,6 @@ class InlinePropagator : public MaxInfoSpanningTree::Propagator { std::unordered_map replayed_pos_; bool is_first_ = true; - // Root domains in producer that's unmappable to any of its consumers - std::unordered_set unmappable_dims_; - public: InlinePropagator( std::unordered_set consume, From 4953b6c6f4671d7bae8dbe8d6594d40dea5d9777 Mon Sep 17 00:00:00 2001 From: Xiang Gao Date: Wed, 29 Jun 2022 01:26:31 -0700 Subject: [PATCH 077/100] cleanup computeAt --- torch/csrc/jit/codegen/cuda/compute_at.cpp | 11 ++++------- torch/csrc/jit/codegen/cuda/compute_at.h | 5 ----- 2 files changed, 4 insertions(+), 12 deletions(-) diff --git a/torch/csrc/jit/codegen/cuda/compute_at.cpp b/torch/csrc/jit/codegen/cuda/compute_at.cpp index 73db1f8c17900..0a6fbd5f05dea 100644 --- a/torch/csrc/jit/codegen/cuda/compute_at.cpp +++ b/torch/csrc/jit/codegen/cuda/compute_at.cpp @@ -165,11 +165,6 @@ std::unordered_set getPropagationSubgraph( } // namespace -ComputeAtSubgraphSelector::ComputeAtSubgraphSelector( - TensorView* producer, - TensorView* consumer) - : InlinePropagatorSelector(getPropagationSubgraph(producer, consumer)) {} - void ComputeAt::runAt( TensorView* producer, TensorView* consumer, @@ -187,7 +182,8 @@ void ComputeAt::runAt( FusionGuard fg(producer->fusion()); - ComputeAtSubgraphSelector selector(producer, consumer); + auto selected = getPropagationSubgraph(producer, consumer); + InlinePropagatorSelector selector(selected); TransformPropagator propagator(consumer, consumer_position); InlinePropagator inline_propagator( @@ -217,7 +213,8 @@ void ComputeAt::runWith( FusionGuard fg(producer->fusion()); - ComputeAtSubgraphSelector selector(producer, consumer); + auto selected = getPropagationSubgraph(producer, consumer); + InlinePropagatorSelector selector(selected); TransformPropagator propagator(producer, 
producer_position); InlinePropagator inline_propagator( diff --git a/torch/csrc/jit/codegen/cuda/compute_at.h b/torch/csrc/jit/codegen/cuda/compute_at.h index 6334f97930e60..f438048d79a0c 100644 --- a/torch/csrc/jit/codegen/cuda/compute_at.h +++ b/torch/csrc/jit/codegen/cuda/compute_at.h @@ -20,11 +20,6 @@ namespace cuda { class TensorDomain; class TensorView; -class ComputeAtSubgraphSelector : public InlinePropagatorSelector { - public: - ComputeAtSubgraphSelector(TensorView* producer, TensorView* consumer); -}; - struct ComputeAt { public: // Runs the compute at pass making producer look like consumer, computing From 9100c3d5db1d3e823f0e9b1958717e6e2d180c1b Mon Sep 17 00:00:00 2001 From: Xiang Gao Date: Wed, 29 Jun 2022 01:37:46 -0700 Subject: [PATCH 078/100] siblingTvsOf --- torch/csrc/jit/codegen/cuda/compute_at.cpp | 13 ++++--------- 1 file changed, 4 insertions(+), 9 deletions(-) diff --git a/torch/csrc/jit/codegen/cuda/compute_at.cpp b/torch/csrc/jit/codegen/cuda/compute_at.cpp index 0a6fbd5f05dea..35a35acfe0a97 100644 --- a/torch/csrc/jit/codegen/cuda/compute_at.cpp +++ b/torch/csrc/jit/codegen/cuda/compute_at.cpp @@ -116,16 +116,11 @@ TensorView* getCommonConsumer(TensorView* producer, TensorView* consumer) { void pullInSiblings(std::unordered_set& s) { for (auto tv : s) { - auto tvd = tv->definition(); - if (tvd != nullptr) { - auto outs = tvd->outputs(); - auto out_tvs = ir_utils::filterByType(outs); - for (auto sibling_tv : out_tvs) { - if (sibling_tv == tv) { - continue; - } - s.emplace(sibling_tv); + for (auto sibling_tv : ir_utils::siblingTvsOf(tv)) { + if (sibling_tv == tv) { + continue; } + s.emplace(sibling_tv); } } } From 0830c0a6dfadf7de2363f4039c34db629b587763 Mon Sep 17 00:00:00 2001 From: Christian Sarofeen Date: Wed, 29 Jun 2022 11:18:41 -0400 Subject: [PATCH 079/100] Move functions around to be consistent with header order, add comments in header. 
--- .../jit/codegen/cuda/inline_propagator.cpp | 194 +++++++++--------- .../csrc/jit/codegen/cuda/inline_propagator.h | 43 +++- 2 files changed, 138 insertions(+), 99 deletions(-) diff --git a/torch/csrc/jit/codegen/cuda/inline_propagator.cpp b/torch/csrc/jit/codegen/cuda/inline_propagator.cpp index 2dc6e6775c9cf..26ba8a560df8b 100644 --- a/torch/csrc/jit/codegen/cuda/inline_propagator.cpp +++ b/torch/csrc/jit/codegen/cuda/inline_propagator.cpp @@ -166,27 +166,6 @@ size_t MaxPosCalculator::getMaxPosCasP( return max_producer_pos; } -InlinePropagator::InlinePropagator( - std::unordered_set consume, - TensorView* reference, - size_t reference_pos, - ComputeAtMode mode) - : max_pos_calc(mode), - consume_(std::move(consume)), - reference_(reference), - reference_pos_(reference_pos), - mode_(mode) { - TORCH_INTERNAL_ASSERT( - reference_pos >= 0 && reference_pos <= reference->nDims(), - "Invalid computeAt axis, received ", - reference_pos, - " but should be > -", - reference->nDims(), - " and <= ", - reference->nDims(), - "."); -} - size_t InlinePropagator::getMaxPosAll(TensorView* tv) { auto max_pos = max_pos_calc.getMaxPosSelf(tv, false, false, false); for (auto producer_tv : ir_utils::producerTvsOf(tv)) { @@ -215,8 +194,106 @@ size_t InlinePropagator::adjustComputeAtPos(TensorView* tv, size_t pos) { return pos; } +size_t InlinePropagator::getReplayPosPasC( + TensorView* producer, + TensorView* consumer) { + size_t max_pos = max_pos_calc.getMaxPosPasC(producer, consumer); + size_t pos = retrieveReplayedPos(consumer); + + if (mode_ == ComputeAtMode::BestEffort) { + return std::min(pos, max_pos); + } else if (mode_ == ComputeAtMode::MostInlined) { + return max_pos; + } + + TORCH_INTERNAL_ASSERT( + pos <= max_pos, + "Invalid compute at position detected in compute at when trying to replay producer: ", + producer, + " as consumer: ", + consumer, + " tried to do this at position: ", + pos, + " but max position that's allowed is ", + max_pos); + return pos; +} + +size_t 
InlinePropagator::getReplayPosCasP( + TensorView* consumer, + TensorView* producer) { + size_t max_pos = max_pos_calc.getMaxPosCasP(consumer, producer); + size_t pos = retrieveReplayedPos(producer); + + if (mode_ == ComputeAtMode::BestEffort) { + return std::min(pos, max_pos); + } else if (mode_ == ComputeAtMode::MostInlined) { + return max_pos; + } + + TORCH_INTERNAL_ASSERT( + pos <= max_pos, + "Invalid compute at position detected in compute at when trying to replay consumer: ", + consumer, + " as producer: ", + producer, + " tried to do this at position: ", + pos, + " but max position that's allowed is ", + max_pos); + return pos; +} + +void InlinePropagator::recordReplayedPos(TensorView* tv, size_t pos) { + if (consume_.count(tv)) { + auto new_pos = adjustComputeAtPos(tv, pos); + if (pos != new_pos) { + replayed_pos_[tv] = pos; + pos = new_pos; + } + if (!tv->isFusionInput()) { + tv->setComputeAt(pos); + } else { + replayed_pos_[tv] = pos; + } + } else { + replayed_pos_[tv] = pos; + } +} + +size_t InlinePropagator::retrieveReplayedPos(TensorView* tv) { + auto it = replayed_pos_.find(tv); + if (it != replayed_pos_.end()) { + return it->second; + } + return tv->getComputeAtPosition(); +} + +InlinePropagator::InlinePropagator( + std::unordered_set consume, + TensorView* reference, + size_t reference_pos, + ComputeAtMode mode) + : max_pos_calc(mode), + consume_(std::move(consume)), + reference_(reference), + reference_pos_(reference_pos), + mode_(mode) { + TORCH_INTERNAL_ASSERT( + reference_pos >= 0 && reference_pos <= reference->nDims(), + "Invalid computeAt axis, received ", + reference_pos, + " but should be > -", + reference->nDims(), + " and <= ", + reference->nDims(), + "."); +} + namespace { +// Make sure if tv is set to new_td it doesn't violate set compute at and max +// produce at positions. 
bool validateDomain(TensorView* tv, TensorDomain* new_td) { auto first_mismatch = BestEffortReplay::findFirstMismatchedID(tv->domain(), new_td); @@ -298,81 +375,6 @@ void InlinePropagator::propagateTvSibling(TensorView* from, TensorView* to) { recordReplayedPos(to, from_pos); } -void InlinePropagator::recordReplayedPos(TensorView* tv, size_t pos) { - if (consume_.count(tv)) { - auto new_pos = adjustComputeAtPos(tv, pos); - if (pos != new_pos) { - replayed_pos_[tv] = pos; - pos = new_pos; - } - if (!tv->isFusionInput()) { - tv->setComputeAt(pos); - } else { - replayed_pos_[tv] = pos; - } - } else { - replayed_pos_[tv] = pos; - } -} - -size_t InlinePropagator::retrieveReplayedPos(TensorView* tv) { - auto it = replayed_pos_.find(tv); - if (it != replayed_pos_.end()) { - return it->second; - } - return tv->getComputeAtPosition(); -} - -size_t InlinePropagator::getReplayPosPasC( - TensorView* producer, - TensorView* consumer) { - size_t max_pos = max_pos_calc.getMaxPosPasC(producer, consumer); - size_t pos = retrieveReplayedPos(consumer); - - if (mode_ == ComputeAtMode::BestEffort) { - return std::min(pos, max_pos); - } else if (mode_ == ComputeAtMode::MostInlined) { - return max_pos; - } - - TORCH_INTERNAL_ASSERT( - pos <= max_pos, - "Invalid compute at position detected in compute at when trying to replay producer: ", - producer, - " as consumer: ", - consumer, - " tried to do this at position: ", - pos, - " but max position that's allowed is ", - max_pos); - return pos; -} - -size_t InlinePropagator::getReplayPosCasP( - TensorView* consumer, - TensorView* producer) { - size_t max_pos = max_pos_calc.getMaxPosCasP(consumer, producer); - size_t pos = retrieveReplayedPos(producer); - - if (mode_ == ComputeAtMode::BestEffort) { - return std::min(pos, max_pos); - } else if (mode_ == ComputeAtMode::MostInlined) { - return max_pos; - } - - TORCH_INTERNAL_ASSERT( - pos <= max_pos, - "Invalid compute at position detected in compute at when trying to replay consumer: ", - 
consumer, - " as producer: ", - producer, - " tried to do this at position: ", - pos, - " but max position that's allowed is ", - max_pos); - return pos; -} - // Try to find the aligned position on consumer's domain corresponding to the // compute at position of producer domain. void MaxProducerPosUpdater::handle(TensorView* consumer) { diff --git a/torch/csrc/jit/codegen/cuda/inline_propagator.h b/torch/csrc/jit/codegen/cuda/inline_propagator.h index ffc6ec5bf85ec..a54276fd93512 100644 --- a/torch/csrc/jit/codegen/cuda/inline_propagator.h +++ b/torch/csrc/jit/codegen/cuda/inline_propagator.h @@ -11,6 +11,9 @@ namespace jit { namespace fuser { namespace cuda { +// Simple selector that only propagates across tensor views in the provided +// unordered_set. Will also propagate to all consumers of those tensors, and the +// siblings of those tensors. class InlinePropagatorSelector : public MaxInfoSpanningTree::Selector { std::unordered_set selected_; @@ -36,6 +39,10 @@ class MaxPosCalculator { // map to all its consumer TVs. void buildUnmappableDims(); + // Utility function to return if an id of tv is a valid iter domain to inline + // within. This is used in getMaxPos{PasC,CasP}. Different variations of the + // bool values are used if checking max position of PasC, CasP, or checking + // for a max "self" position. bool isAllowedID( IterDomain* id, TensorView* tv, @@ -44,22 +51,49 @@ class MaxPosCalculator { bool allow_unmappable) const; public: + // Returns the position at which tv can be relayed within. 
size_t getMaxPosSelf( TensorView* tv, bool allow_reduction, bool allow_vectorize, bool allow_unmappable) const; + + // Returns the maximum position producer can be replayed based on consumer + // given the set ComputeAtMode size_t getMaxPosPasC(TensorView* producer, TensorView* consumer) const; + + // Returns the maximum position consumer can be replayed based on producer + // given the set ComputeAtMode size_t getMaxPosCasP(TensorView* consumer, TensorView* producer) const; + MaxPosCalculator(ComputeAtMode mode); }; class InlinePropagator : public MaxInfoSpanningTree::Propagator { + // Checks producers and consumers to see what the maximum position in tv is + // that can be shared across both directions. size_t getMaxPosAll(TensorView* tv); + + // Returns position of getMaxPosAll while also hoisting outside broadcast + // dimensions. size_t adjustComputeAtPos(TensorView* tv, size_t pos); + + // Returns the replay position that producer should be replayed as based on + // consumer, taking into consideration the max possible returned by + // getMaxPos{PasC, CasP}, the compute at mode type. size_t getReplayPosPasC(TensorView* producer, TensorView* consumer); + + // Returns the replay position that consumer should be replayed as based on + // producer, taking into consideration the max possible returned by + // getMaxPos{PasC, CasP}, the compute at mode type. size_t getReplayPosCasP(TensorView* consumer, TensorView* producer); + + // Sets the compute at position of tv and records the position in + // replayed_pos_ void recordReplayedPos(TensorView* tv, size_t pos); + + // Returns the entry for tv in replayed_pos_ if it exists, else returns the + // compute at position of tv. size_t retrieveReplayedPos(TensorView* tv); const MaxPosCalculator max_pos_calc; @@ -79,14 +113,17 @@ class InlinePropagator : public MaxInfoSpanningTree::Propagator { ~InlinePropagator() = default; + // Actually propagate the transformations for the inlining pass. 
Uses the + // functions above to figure out what position to do the propagation at. virtual void propagateTvPasC(TensorView* from, TensorView* to) override; virtual void propagateTvCasP(TensorView* from, TensorView* to) override; virtual void propagateTvSibling(TensorView* from, TensorView* to) override; }; -// This is actually not a propagation, and it is not needed to compute the max -// producer position in a specific order. But MaxInfoSpanningTree provides a -// very convenient API to visit the tensors, so I just use it for cleaner code. +// This is actually not a propagation, it only sets the max producer position of +// the tensors, and it is not needed to compute the max producer position in a +// specific order. But MaxInfoSpanningTree provides a very convenient API to +// visit the tensors, so I just use it for cleaner code. class MaxProducerPosUpdater : public MaxInfoSpanningTree::Propagator { std::unordered_set updated_; void handle(TensorView* tv); From d0b0fc1b2c3aa37d52574ecc979f7c97a0e6612c Mon Sep 17 00:00:00 2001 From: Xiang Gao Date: Wed, 29 Jun 2022 12:37:08 -0700 Subject: [PATCH 080/100] fix --- .../jit/codegen/cuda/inline_propagator.cpp | 19 ++++++++++++------- .../csrc/jit/codegen/cuda/inline_propagator.h | 2 +- 2 files changed, 13 insertions(+), 8 deletions(-) diff --git a/torch/csrc/jit/codegen/cuda/inline_propagator.cpp b/torch/csrc/jit/codegen/cuda/inline_propagator.cpp index 26ba8a560df8b..d5615b27bc8ed 100644 --- a/torch/csrc/jit/codegen/cuda/inline_propagator.cpp +++ b/torch/csrc/jit/codegen/cuda/inline_propagator.cpp @@ -77,12 +77,14 @@ bool MaxPosCalculator::isAllowedID( auto root_dom = tv->getMaybeRFactorDomain(); std::unordered_set root_dom_set(root_dom.begin(), root_dom.end()); auto all_vals = DependencyCheck::getAllValsBetween(root_dom_set, {id}); - auto all_ids = ir_utils::filterByType(all_vals); - auto check_unmappable = [&](IterDomain* root_id) { - return unmappable_dims_.count(root_id) > 0; - }; - bool is_unmappable = - 
std::any_of(all_ids.begin(), all_ids.end(), check_unmappable); + bool is_unmappable = false; + for (auto val : all_vals) { + auto id = val->as(); + if (root_dom_set.count(val) > 0 && unmappable_dims_.count(id) > 0) { + is_unmappable = true; + break; + } + } allowed = allowed && !is_unmappable; } @@ -272,13 +274,16 @@ size_t InlinePropagator::retrieveReplayedPos(TensorView* tv) { InlinePropagator::InlinePropagator( std::unordered_set consume, TensorView* reference, - size_t reference_pos, + int64_t reference_pos, ComputeAtMode mode) : max_pos_calc(mode), consume_(std::move(consume)), reference_(reference), reference_pos_(reference_pos), mode_(mode) { + if (reference_pos < 0) { + reference_pos += int64_t(reference->nDims()) + 1; + } TORCH_INTERNAL_ASSERT( reference_pos >= 0 && reference_pos <= reference->nDims(), "Invalid computeAt axis, received ", diff --git a/torch/csrc/jit/codegen/cuda/inline_propagator.h b/torch/csrc/jit/codegen/cuda/inline_propagator.h index a54276fd93512..8ba760065f9e9 100644 --- a/torch/csrc/jit/codegen/cuda/inline_propagator.h +++ b/torch/csrc/jit/codegen/cuda/inline_propagator.h @@ -108,7 +108,7 @@ class InlinePropagator : public MaxInfoSpanningTree::Propagator { InlinePropagator( std::unordered_set consume, TensorView* reference, - size_t reference_pos, + int64_t reference_pos, ComputeAtMode mode); ~InlinePropagator() = default; From ea0d9cbf4cfdba4155cbdcfe5369f45ba18ded2b Mon Sep 17 00:00:00 2001 From: Xiang Gao Date: Wed, 29 Jun 2022 12:46:46 -0700 Subject: [PATCH 081/100] minor cleanup on variable names and comments --- torch/csrc/jit/codegen/cuda/inline_propagator.cpp | 10 +++++----- torch/csrc/jit/codegen/cuda/inline_propagator.h | 12 ++++++------ 2 files changed, 11 insertions(+), 11 deletions(-) diff --git a/torch/csrc/jit/codegen/cuda/inline_propagator.cpp b/torch/csrc/jit/codegen/cuda/inline_propagator.cpp index d5615b27bc8ed..f915774d437b1 100644 --- a/torch/csrc/jit/codegen/cuda/inline_propagator.cpp +++ 
b/torch/csrc/jit/codegen/cuda/inline_propagator.cpp @@ -15,7 +15,7 @@ bool InlinePropagatorSelector::allowPasC(TensorView* from, TensorView* to) { } bool InlinePropagatorSelector::allowCasP(TensorView* from, TensorView* to) { - // If the producer is in the consume set, then the consumer must also be + // If the producer is in the selected set, then the consumer must also be // replayed to obtain a compatible loop structure so that this producer // can be consumed in this loop. return selected_.count(from) > 0 || selected_.count(to) > 0; @@ -172,7 +172,7 @@ size_t InlinePropagator::getMaxPosAll(TensorView* tv) { auto max_pos = max_pos_calc.getMaxPosSelf(tv, false, false, false); for (auto producer_tv : ir_utils::producerTvsOf(tv)) { // only check producers that are replayed consistently - if (consume_.count(producer_tv) > 0) { + if (selected_.count(producer_tv) > 0) { max_pos = std::min( max_pos, max_pos_calc.getMaxPosPasC(producer_tv, tv)); } @@ -247,7 +247,7 @@ size_t InlinePropagator::getReplayPosCasP( } void InlinePropagator::recordReplayedPos(TensorView* tv, size_t pos) { - if (consume_.count(tv)) { + if (selected_.count(tv)) { auto new_pos = adjustComputeAtPos(tv, pos); if (pos != new_pos) { replayed_pos_[tv] = pos; @@ -272,12 +272,12 @@ size_t InlinePropagator::retrieveReplayedPos(TensorView* tv) { } InlinePropagator::InlinePropagator( - std::unordered_set consume, + std::unordered_set selected, TensorView* reference, int64_t reference_pos, ComputeAtMode mode) : max_pos_calc(mode), - consume_(std::move(consume)), + selected_(std::move(selected)), reference_(reference), reference_pos_(reference_pos), mode_(mode) { diff --git a/torch/csrc/jit/codegen/cuda/inline_propagator.h b/torch/csrc/jit/codegen/cuda/inline_propagator.h index 8ba760065f9e9..40df6548add0d 100644 --- a/torch/csrc/jit/codegen/cuda/inline_propagator.h +++ b/torch/csrc/jit/codegen/cuda/inline_propagator.h @@ -78,13 +78,13 @@ class InlinePropagator : public MaxInfoSpanningTree::Propagator { // 
dimensions. size_t adjustComputeAtPos(TensorView* tv, size_t pos); - // Returns the replay position that producer should be replayed as based on - // consumer, taking into consideration the max possible returned by + // Returns the replay position in consumer that producer should be replayed as + // based on consumer, taking into consideration the max possible returned by // getMaxPos{PasC, CasP}, the compute at mode type. size_t getReplayPosPasC(TensorView* producer, TensorView* consumer); - // Returns the replay position that consumer should be replayed as based on - // producer, taking into consideration the max possible returned by + // Returns the replay position in producer that consumer should be replayed as + // based on producer, taking into consideration the max possible returned by // getMaxPos{PasC, CasP}, the compute at mode type. size_t getReplayPosCasP(TensorView* consumer, TensorView* producer); @@ -97,7 +97,7 @@ class InlinePropagator : public MaxInfoSpanningTree::Propagator { size_t retrieveReplayedPos(TensorView* tv); const MaxPosCalculator max_pos_calc; - std::unordered_set consume_; + std::unordered_set selected_; TensorView* reference_; size_t reference_pos_; ComputeAtMode mode_ = ComputeAtMode::Standard; @@ -106,7 +106,7 @@ class InlinePropagator : public MaxInfoSpanningTree::Propagator { public: InlinePropagator( - std::unordered_set consume, + std::unordered_set selected, TensorView* reference, int64_t reference_pos, ComputeAtMode mode); From aa5f6020b9f2a12e4f3ab861b19b842a9e386a06 Mon Sep 17 00:00:00 2001 From: Xiang Gao Date: Wed, 29 Jun 2022 20:29:53 -0700 Subject: [PATCH 082/100] Add SpanningTreePrinter --- .../jit/codegen/cuda/maxinfo_propagator.cpp | 18 ++++++++++++++++++ .../csrc/jit/codegen/cuda/maxinfo_propagator.h | 11 +++++++++++ 2 files changed, 29 insertions(+) diff --git a/torch/csrc/jit/codegen/cuda/maxinfo_propagator.cpp b/torch/csrc/jit/codegen/cuda/maxinfo_propagator.cpp index 06c2fcaf01547..44656b4df46f1 100644 --- 
a/torch/csrc/jit/codegen/cuda/maxinfo_propagator.cpp +++ b/torch/csrc/jit/codegen/cuda/maxinfo_propagator.cpp @@ -416,6 +416,24 @@ std::shared_ptr MaxRootDomainInfoSpanningTree: return from_info; } +void SpanningTreePrinter::propagateTvPasC(TensorView* from, TensorView* to) { + stream_ << "propagateTvPasC" << std::endl; + stream_ << " from: " << from->toString() << std::endl; + stream_ << " to: " << to->toString() << std::endl; +} + +void SpanningTreePrinter::propagateTvCasP(TensorView* from, TensorView* to) { + stream_ << "propagateTvCasP" << std::endl; + stream_ << " from: " << from->toString() << std::endl; + stream_ << " to: " << to->toString() << std::endl; +} + +void SpanningTreePrinter::propagateTvSibling(TensorView* from, TensorView* to) { + stream_ << "propagateTvSibling" << std::endl; + stream_ << " from: " << from->toString() << std::endl; + stream_ << " to: " << to->toString() << std::endl; +} + } // namespace cuda } // namespace fuser } // namespace jit diff --git a/torch/csrc/jit/codegen/cuda/maxinfo_propagator.h b/torch/csrc/jit/codegen/cuda/maxinfo_propagator.h index db32aaef6d23c..5a3ac6d46f479 100644 --- a/torch/csrc/jit/codegen/cuda/maxinfo_propagator.h +++ b/torch/csrc/jit/codegen/cuda/maxinfo_propagator.h @@ -232,6 +232,17 @@ class TORCH_CUDA_CU_API MaxRootDomainInfoSpanningTree selector) {} }; +class TORCH_CUDA_CU_API SpanningTreePrinter + : public MaxInfoSpanningTree::Propagator { + std::ostream& stream_; + + public: + virtual void propagateTvPasC(TensorView* from, TensorView* to) override; + virtual void propagateTvCasP(TensorView* from, TensorView* to) override; + virtual void propagateTvSibling(TensorView* from, TensorView* to) override; + SpanningTreePrinter(std::ostream& stream = std::cout) : stream_(stream) {} +}; + } // namespace cuda } // namespace fuser } // namespace jit From e4d0aac7368190d45e2436e22e247e665c7b58b2 Mon Sep 17 00:00:00 2001 From: Xiang Gao Date: Wed, 29 Jun 2022 20:49:08 -0700 Subject: [PATCH 083/100] no check 
producer --- torch/csrc/jit/codegen/cuda/inline_propagator.cpp | 9 +-------- 1 file changed, 1 insertion(+), 8 deletions(-) diff --git a/torch/csrc/jit/codegen/cuda/inline_propagator.cpp b/torch/csrc/jit/codegen/cuda/inline_propagator.cpp index f915774d437b1..2fcee1c989787 100644 --- a/torch/csrc/jit/codegen/cuda/inline_propagator.cpp +++ b/torch/csrc/jit/codegen/cuda/inline_propagator.cpp @@ -147,7 +147,7 @@ size_t MaxPosCalculator::getMaxPosPasC( size_t MaxPosCalculator::getMaxPosCasP( TensorView* consumer, TensorView* producer) const { - auto max_producer_pos = getMaxPosSelf(producer, false, false, false); + auto max_producer_pos = producer->nDims(); auto pairwise_root_map = PairwiseRootDomainMap(producer, consumer); auto replay_CasP = @@ -170,13 +170,6 @@ size_t MaxPosCalculator::getMaxPosCasP( size_t InlinePropagator::getMaxPosAll(TensorView* tv) { auto max_pos = max_pos_calc.getMaxPosSelf(tv, false, false, false); - for (auto producer_tv : ir_utils::producerTvsOf(tv)) { - // only check producers that are replayed consistently - if (selected_.count(producer_tv) > 0) { - max_pos = std::min( - max_pos, max_pos_calc.getMaxPosPasC(producer_tv, tv)); - } - } for (auto consumer_tv : ir_utils::consumerTvsOf(tv)) { // consumers are always replayed consistently max_pos = From 5d4093510446c5aef402914d21df3b05ec1ec656 Mon Sep 17 00:00:00 2001 From: Xiang Gao Date: Thu, 30 Jun 2022 10:09:11 -0700 Subject: [PATCH 084/100] revert getMaxPosCasP to restore previous behavior --- torch/csrc/jit/codegen/cuda/inline_propagator.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch/csrc/jit/codegen/cuda/inline_propagator.cpp b/torch/csrc/jit/codegen/cuda/inline_propagator.cpp index 2fcee1c989787..74e2e6b407c1a 100644 --- a/torch/csrc/jit/codegen/cuda/inline_propagator.cpp +++ b/torch/csrc/jit/codegen/cuda/inline_propagator.cpp @@ -147,7 +147,7 @@ size_t MaxPosCalculator::getMaxPosPasC( size_t MaxPosCalculator::getMaxPosCasP( TensorView* consumer, TensorView* 
producer) const { - auto max_producer_pos = producer->nDims(); + auto max_producer_pos = getMaxPosSelf(producer, false, false, false); auto pairwise_root_map = PairwiseRootDomainMap(producer, consumer); auto replay_CasP = From 1281d343ba80d2ee0ae080b19d02b8ff240d2aba Mon Sep 17 00:00:00 2001 From: Xiang Gao Date: Thu, 30 Jun 2022 12:15:22 -0700 Subject: [PATCH 085/100] revert #1786 --- .../jit/codegen/cuda/maxinfo_propagator.cpp | 18 ------------------ .../csrc/jit/codegen/cuda/maxinfo_propagator.h | 11 ----------- 2 files changed, 29 deletions(-) diff --git a/torch/csrc/jit/codegen/cuda/maxinfo_propagator.cpp b/torch/csrc/jit/codegen/cuda/maxinfo_propagator.cpp index 44656b4df46f1..06c2fcaf01547 100644 --- a/torch/csrc/jit/codegen/cuda/maxinfo_propagator.cpp +++ b/torch/csrc/jit/codegen/cuda/maxinfo_propagator.cpp @@ -416,24 +416,6 @@ std::shared_ptr MaxRootDomainInfoSpanningTree: return from_info; } -void SpanningTreePrinter::propagateTvPasC(TensorView* from, TensorView* to) { - stream_ << "propagateTvPasC" << std::endl; - stream_ << " from: " << from->toString() << std::endl; - stream_ << " to: " << to->toString() << std::endl; -} - -void SpanningTreePrinter::propagateTvCasP(TensorView* from, TensorView* to) { - stream_ << "propagateTvCasP" << std::endl; - stream_ << " from: " << from->toString() << std::endl; - stream_ << " to: " << to->toString() << std::endl; -} - -void SpanningTreePrinter::propagateTvSibling(TensorView* from, TensorView* to) { - stream_ << "propagateTvSibling" << std::endl; - stream_ << " from: " << from->toString() << std::endl; - stream_ << " to: " << to->toString() << std::endl; -} - } // namespace cuda } // namespace fuser } // namespace jit diff --git a/torch/csrc/jit/codegen/cuda/maxinfo_propagator.h b/torch/csrc/jit/codegen/cuda/maxinfo_propagator.h index 5a3ac6d46f479..db32aaef6d23c 100644 --- a/torch/csrc/jit/codegen/cuda/maxinfo_propagator.h +++ b/torch/csrc/jit/codegen/cuda/maxinfo_propagator.h @@ -232,17 +232,6 @@ class 
TORCH_CUDA_CU_API MaxRootDomainInfoSpanningTree selector) {} }; -class TORCH_CUDA_CU_API SpanningTreePrinter - : public MaxInfoSpanningTree::Propagator { - std::ostream& stream_; - - public: - virtual void propagateTvPasC(TensorView* from, TensorView* to) override; - virtual void propagateTvCasP(TensorView* from, TensorView* to) override; - virtual void propagateTvSibling(TensorView* from, TensorView* to) override; - SpanningTreePrinter(std::ostream& stream = std::cout) : stream_(stream) {} -}; - } // namespace cuda } // namespace fuser } // namespace jit From befa9dc9509a5a26a9ee13296b0e142be4b35b9c Mon Sep 17 00:00:00 2001 From: Xiang Gao Date: Thu, 30 Jun 2022 14:11:37 -0700 Subject: [PATCH 086/100] fix TransformReplay::getMatchedLeafPosWithoutReplay --- torch/csrc/jit/codegen/cuda/test/test_gpu.cpp | 25 +++++++++++++++++++ .../jit/codegen/cuda/transform_replay.cpp | 5 ++++ 2 files changed, 30 insertions(+) diff --git a/torch/csrc/jit/codegen/cuda/test/test_gpu.cpp b/torch/csrc/jit/codegen/cuda/test/test_gpu.cpp index d543d8dc356ef..951245f4841c1 100644 --- a/torch/csrc/jit/codegen/cuda/test/test_gpu.cpp +++ b/torch/csrc/jit/codegen/cuda/test/test_gpu.cpp @@ -24079,6 +24079,31 @@ TEST_F(NVFuserTest, FusionIssue1785Repro_CUDA) { testValidate(&fusion, cg_outputs, {in1, in2}, {tv_ref}, __LINE__, __FILE__); } +TEST_F(NVFuserTest, FusionSkipReplay_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + TensorView* tv0 = makeContigTensor(1); + TensorView* tv1 = makeContigTensor(2); + fusion.addInput(tv0); + fusion.addInput(tv1); + + auto tv2 = broadcast(tv0, {false, true}); + auto tv3 = add(tv2, tv1); + fusion.addOutput(tv3); + + tv3->split(1, 2, false); + + TransformPropagator propagator(tv3); + MaxRootDomainInfoSpanningTree(tv3).traverse(&propagator); + + TORCH_CHECK( + TransformReplay::getMatchedLeafPosWithoutReplayPasC(tv0, tv2, 3) == 1); + TORCH_CHECK( + TransformReplay::getMatchedLeafPosWithoutReplayCasP(tv2, tv0, 1) >= 1 && + 
TransformReplay::getMatchedLeafPosWithoutReplayCasP(tv2, tv0, 1) <= 3); +} + } // namespace jit } // namespace torch #endif // #if defined(USE_CUDA) diff --git a/torch/csrc/jit/codegen/cuda/transform_replay.cpp b/torch/csrc/jit/codegen/cuda/transform_replay.cpp index e961867865181..bc555c5aa90be 100644 --- a/torch/csrc/jit/codegen/cuda/transform_replay.cpp +++ b/torch/csrc/jit/codegen/cuda/transform_replay.cpp @@ -689,6 +689,11 @@ int getMatchedLeafPosWithoutReplay( if (!mapped_consumer_domain_ids.count(consumer_id)) { ++it_consumer; mismatched_consumer_pos++; + if (consumer_pos) { + if (consumer_or_producer_pos == mismatched_consumer_pos) { + return mismatched_producer_pos; + } + } continue; } From 11d7ff9431db613c62d4df96d532ae204cc7cd5b Mon Sep 17 00:00:00 2001 From: Xiang Gao Date: Thu, 30 Jun 2022 15:11:52 -0700 Subject: [PATCH 087/100] save --- .../csrc/jit/codegen/cuda/inline_propagator.cpp | 17 +++++++++++++++-- torch/csrc/jit/codegen/cuda/test/test_gpu.cpp | 1 + .../csrc/jit/codegen/cuda/transform_replay.cpp | 9 +++++++++ 3 files changed, 25 insertions(+), 2 deletions(-) diff --git a/torch/csrc/jit/codegen/cuda/inline_propagator.cpp b/torch/csrc/jit/codegen/cuda/inline_propagator.cpp index 74e2e6b407c1a..cbe2e824b26cb 100644 --- a/torch/csrc/jit/codegen/cuda/inline_propagator.cpp +++ b/torch/csrc/jit/codegen/cuda/inline_propagator.cpp @@ -311,7 +311,9 @@ void InlinePropagator::propagateTvPasC(TensorView* from, TensorView* to) { TransformReplay::getMatchedLeafPosWithoutReplayPasC(to, from, pos); // TODO: Can we make TransformPropagator do the transformation, and // InlinePropagator only set the CA positions? 
- // TORCH_CHECK(to_pos >= 0); + if (mode_ == ComputeAtMode::Standard) { + TORCH_CHECK(to_pos >= 0); + } if (to_pos < 0) { auto replay = TransformReplay::replayPasC(to, from, pos); TORCH_INTERNAL_ASSERT( @@ -325,6 +327,9 @@ void InlinePropagator::propagateTvPasC(TensorView* from, TensorView* to) { to_pos = replay.second; } recordReplayedPos(to, to_pos); + std::cout << "InlinePropagator::propagateTvPasC" << std::endl; + std::cout << " from: " << from->toString() << " at " << pos << std::endl; + std::cout << " to: " << to->toString() << " at " << to_pos << std::endl; } void InlinePropagator::propagateTvCasP(TensorView* from, TensorView* to) { @@ -337,7 +342,9 @@ void InlinePropagator::propagateTvCasP(TensorView* from, TensorView* to) { TransformReplay::getMatchedLeafPosWithoutReplayCasP(to, from, pos); // TODO: Can we make TransformPropagator do the transformation, and // InlinePropagator only set the CA positions? - // TORCH_CHECK(to_pos >= 0); + if (mode_ == ComputeAtMode::Standard) { + TORCH_CHECK(to_pos >= 0); + } if (to_pos < 0) { auto replay = TransformReplay::replayCasP(to, from, pos); TORCH_INTERNAL_ASSERT( @@ -351,6 +358,9 @@ void InlinePropagator::propagateTvCasP(TensorView* from, TensorView* to) { to_pos = replay.second; } recordReplayedPos(to, to_pos); + std::cout << "InlinePropagator::propagateTvCasP" << std::endl; + std::cout << " from: " << from->toString() << " at " << pos << std::endl; + std::cout << " to: " << to->toString() << " at " << to_pos << std::endl; } void InlinePropagator::propagateTvSibling(TensorView* from, TensorView* to) { @@ -371,6 +381,9 @@ void InlinePropagator::propagateTvSibling(TensorView* from, TensorView* to) { to->setDomain(replay); } recordReplayedPos(to, from_pos); + std::cout << "InlinePropagator::propagateTvSibling" << std::endl; + std::cout << " from: " << from->toString() << " at " << from_pos << std::endl; + std::cout << " to: " << to->toString() << " at " << from_pos << std::endl; } // Try to find the aligned position 
on consumer's domain corresponding to the diff --git a/torch/csrc/jit/codegen/cuda/test/test_gpu.cpp b/torch/csrc/jit/codegen/cuda/test/test_gpu.cpp index d543d8dc356ef..2d8236a679f27 100644 --- a/torch/csrc/jit/codegen/cuda/test/test_gpu.cpp +++ b/torch/csrc/jit/codegen/cuda/test/test_gpu.cpp @@ -2137,6 +2137,7 @@ TEST_F(NVFuserTest, FusionAdvancedComputeAt8_CUDA) { // Reverse computeAt structure from previous test tv0->computeAt(tv4, -1); tv2->computeAt(tv4, -1); + fusion.print(); tv0->computeAt(tv7, -1); const int numel_x = 100; diff --git a/torch/csrc/jit/codegen/cuda/transform_replay.cpp b/torch/csrc/jit/codegen/cuda/transform_replay.cpp index e961867865181..0d8ca9eb69cb7 100644 --- a/torch/csrc/jit/codegen/cuda/transform_replay.cpp +++ b/torch/csrc/jit/codegen/cuda/transform_replay.cpp @@ -789,6 +789,9 @@ void TransformPropagator::propagateTvPasC(TensorView* from, TensorView* to) { new_pos = replay.second; } replayed_pos_[to] = new_pos; + std::cout << "TransformPropagator::propagateTvPasC" << std::endl; + std::cout << " from: " << from->toString() << " at " << pos << std::endl; + std::cout << " to: " << to->toString() << " at " << new_pos << std::endl; } void TransformPropagator::propagateTvCasP(TensorView* from, TensorView* to) { @@ -802,6 +805,9 @@ void TransformPropagator::propagateTvCasP(TensorView* from, TensorView* to) { new_pos = replay.second; } replayed_pos_[to] = new_pos; + std::cout << "TransformPropagator::propagateTvCasP" << std::endl; + std::cout << " from: " << from->toString() << " at " << pos << std::endl; + std::cout << " to: " << to->toString() << " at " << new_pos << std::endl; } void TransformPropagator::propagateTvSibling(TensorView* from, TensorView* to) { @@ -812,6 +818,9 @@ void TransformPropagator::propagateTvSibling(TensorView* from, TensorView* to) { to->setDomain(replay); } replayed_pos_[to] = pos; + std::cout << "TransformPropagator::propagateTvSibling" << std::endl; + std::cout << " from: " << from->toString() << " at " << pos 
<< std::endl; + std::cout << " to: " << to->toString() << " at " << pos << std::endl; } TransformPropagator::TransformPropagator(TensorView* from, int64_t pos) { From ee4cd26393502c338bb0b151d97a52ab1bd8d2c4 Mon Sep 17 00:00:00 2001 From: Xiang Gao Date: Thu, 30 Jun 2022 16:17:46 -0700 Subject: [PATCH 088/100] save --- .../jit/codegen/cuda/inline_propagator.cpp | 18 +++++++++--------- torch/csrc/jit/codegen/cuda/test/test_gpu.cpp | 1 - .../csrc/jit/codegen/cuda/transform_replay.cpp | 18 +++++++++--------- 3 files changed, 18 insertions(+), 19 deletions(-) diff --git a/torch/csrc/jit/codegen/cuda/inline_propagator.cpp b/torch/csrc/jit/codegen/cuda/inline_propagator.cpp index cbe2e824b26cb..f76570c8e2c6d 100644 --- a/torch/csrc/jit/codegen/cuda/inline_propagator.cpp +++ b/torch/csrc/jit/codegen/cuda/inline_propagator.cpp @@ -327,9 +327,9 @@ void InlinePropagator::propagateTvPasC(TensorView* from, TensorView* to) { to_pos = replay.second; } recordReplayedPos(to, to_pos); - std::cout << "InlinePropagator::propagateTvPasC" << std::endl; - std::cout << " from: " << from->toString() << " at " << pos << std::endl; - std::cout << " to: " << to->toString() << " at " << to_pos << std::endl; + // std::cout << "InlinePropagator::propagateTvPasC" << std::endl; + // std::cout << " from: " << from->toString() << " at " << pos << std::endl; + // std::cout << " to: " << to->toString() << " at " << to_pos << std::endl; } void InlinePropagator::propagateTvCasP(TensorView* from, TensorView* to) { @@ -358,9 +358,9 @@ void InlinePropagator::propagateTvCasP(TensorView* from, TensorView* to) { to_pos = replay.second; } recordReplayedPos(to, to_pos); - std::cout << "InlinePropagator::propagateTvCasP" << std::endl; - std::cout << " from: " << from->toString() << " at " << pos << std::endl; - std::cout << " to: " << to->toString() << " at " << to_pos << std::endl; + // std::cout << "InlinePropagator::propagateTvCasP" << std::endl; + // std::cout << " from: " << from->toString() << " at " << 
pos << std::endl; + // std::cout << " to: " << to->toString() << " at " << to_pos << std::endl; } void InlinePropagator::propagateTvSibling(TensorView* from, TensorView* to) { @@ -381,9 +381,9 @@ void InlinePropagator::propagateTvSibling(TensorView* from, TensorView* to) { to->setDomain(replay); } recordReplayedPos(to, from_pos); - std::cout << "InlinePropagator::propagateTvSibling" << std::endl; - std::cout << " from: " << from->toString() << " at " << from_pos << std::endl; - std::cout << " to: " << to->toString() << " at " << from_pos << std::endl; + // std::cout << "InlinePropagator::propagateTvSibling" << std::endl; + // std::cout << " from: " << from->toString() << " at " << from_pos << std::endl; + // std::cout << " to: " << to->toString() << " at " << from_pos << std::endl; } // Try to find the aligned position on consumer's domain corresponding to the diff --git a/torch/csrc/jit/codegen/cuda/test/test_gpu.cpp b/torch/csrc/jit/codegen/cuda/test/test_gpu.cpp index cd4edd4f9b9b1..951245f4841c1 100644 --- a/torch/csrc/jit/codegen/cuda/test/test_gpu.cpp +++ b/torch/csrc/jit/codegen/cuda/test/test_gpu.cpp @@ -2137,7 +2137,6 @@ TEST_F(NVFuserTest, FusionAdvancedComputeAt8_CUDA) { // Reverse computeAt structure from previous test tv0->computeAt(tv4, -1); tv2->computeAt(tv4, -1); - fusion.print(); tv0->computeAt(tv7, -1); const int numel_x = 100; diff --git a/torch/csrc/jit/codegen/cuda/transform_replay.cpp b/torch/csrc/jit/codegen/cuda/transform_replay.cpp index 19e50abaea54f..e3aec965b06f4 100644 --- a/torch/csrc/jit/codegen/cuda/transform_replay.cpp +++ b/torch/csrc/jit/codegen/cuda/transform_replay.cpp @@ -794,9 +794,9 @@ void TransformPropagator::propagateTvPasC(TensorView* from, TensorView* to) { new_pos = replay.second; } replayed_pos_[to] = new_pos; - std::cout << "TransformPropagator::propagateTvPasC" << std::endl; - std::cout << " from: " << from->toString() << " at " << pos << std::endl; - std::cout << " to: " << to->toString() << " at " << new_pos << 
std::endl; + // std::cout << "TransformPropagator::propagateTvPasC" << std::endl; + // std::cout << " from: " << from->toString() << " at " << pos << std::endl; + // std::cout << " to: " << to->toString() << " at " << new_pos << std::endl; } void TransformPropagator::propagateTvCasP(TensorView* from, TensorView* to) { @@ -810,9 +810,9 @@ void TransformPropagator::propagateTvCasP(TensorView* from, TensorView* to) { new_pos = replay.second; } replayed_pos_[to] = new_pos; - std::cout << "TransformPropagator::propagateTvCasP" << std::endl; - std::cout << " from: " << from->toString() << " at " << pos << std::endl; - std::cout << " to: " << to->toString() << " at " << new_pos << std::endl; + // std::cout << "TransformPropagator::propagateTvCasP" << std::endl; + // std::cout << " from: " << from->toString() << " at " << pos << std::endl; + // std::cout << " to: " << to->toString() << " at " << new_pos << std::endl; } void TransformPropagator::propagateTvSibling(TensorView* from, TensorView* to) { @@ -823,9 +823,9 @@ void TransformPropagator::propagateTvSibling(TensorView* from, TensorView* to) { to->setDomain(replay); } replayed_pos_[to] = pos; - std::cout << "TransformPropagator::propagateTvSibling" << std::endl; - std::cout << " from: " << from->toString() << " at " << pos << std::endl; - std::cout << " to: " << to->toString() << " at " << pos << std::endl; + // std::cout << "TransformPropagator::propagateTvSibling" << std::endl; + // std::cout << " from: " << from->toString() << " at " << pos << std::endl; + // std::cout << " to: " << to->toString() << " at " << pos << std::endl; } TransformPropagator::TransformPropagator(TensorView* from, int64_t pos) { From b18a85a6d1e83f5d85607f4c2430922fc71766b7 Mon Sep 17 00:00:00 2001 From: Xiang Gao Date: Thu, 30 Jun 2022 17:12:39 -0700 Subject: [PATCH 089/100] fix --- .../jit/codegen/cuda/inline_propagator.cpp | 87 +++++++++++++++---- 1 file changed, 68 insertions(+), 19 deletions(-) diff --git 
a/torch/csrc/jit/codegen/cuda/inline_propagator.cpp b/torch/csrc/jit/codegen/cuda/inline_propagator.cpp index f76570c8e2c6d..eefb2ea451cfd 100644 --- a/torch/csrc/jit/codegen/cuda/inline_propagator.cpp +++ b/torch/csrc/jit/codegen/cuda/inline_propagator.cpp @@ -311,9 +311,9 @@ void InlinePropagator::propagateTvPasC(TensorView* from, TensorView* to) { TransformReplay::getMatchedLeafPosWithoutReplayPasC(to, from, pos); // TODO: Can we make TransformPropagator do the transformation, and // InlinePropagator only set the CA positions? - if (mode_ == ComputeAtMode::Standard) { - TORCH_CHECK(to_pos >= 0); - } + // if (mode_ != ComputeAtMode::MostInlined) { + // TORCH_CHECK(to_pos >= 0); + // } if (to_pos < 0) { auto replay = TransformReplay::replayPasC(to, from, pos); TORCH_INTERNAL_ASSERT( @@ -342,9 +342,9 @@ void InlinePropagator::propagateTvCasP(TensorView* from, TensorView* to) { TransformReplay::getMatchedLeafPosWithoutReplayCasP(to, from, pos); // TODO: Can we make TransformPropagator do the transformation, and // InlinePropagator only set the CA positions? 
- if (mode_ == ComputeAtMode::Standard) { - TORCH_CHECK(to_pos >= 0); - } + // if (mode_ != ComputeAtMode::MostInlined) { + // TORCH_CHECK(to_pos >= 0); + // } if (to_pos < 0) { auto replay = TransformReplay::replayCasP(to, from, pos); TORCH_INTERNAL_ASSERT( @@ -382,27 +382,76 @@ void InlinePropagator::propagateTvSibling(TensorView* from, TensorView* to) { } recordReplayedPos(to, from_pos); // std::cout << "InlinePropagator::propagateTvSibling" << std::endl; - // std::cout << " from: " << from->toString() << " at " << from_pos << std::endl; - // std::cout << " to: " << to->toString() << " at " << from_pos << std::endl; + // std::cout << " from: " << from->toString() << " at " << from_pos << + // std::endl; std::cout << " to: " << to->toString() << " at " << from_pos << + // std::endl; } +namespace { + // Try to find the aligned position on consumer's domain corresponding to the -// compute at position of producer domain. -void MaxProducerPosUpdater::handle(TensorView* consumer) { +// compute at position of producer domain. Used in computeAt pass only. No +// checking on actual producer-consumer relationship. +unsigned int getConsumerPosAlignedToProducerCA( + TensorView* consumer, + TensorView* producer) { + // Locate consumer's position that aligns with + // the producer's new compute at axis. We need broadcast axes forwarded so we + // need to replay PasC as CasP will not forward braodcast dims. For example + // if we have: + // T2[ iS22{( 3 * 1 )} ] ca_pos( 1 ) = broadcast( T1[ iS1{3} ] ca_pos( 1 ) + // produce_pos( 1) ) CasP will have the mapping iS1{3} -> iS2{3} and PasC will + // have the mapping iS22{( 3 * 1 )} <- iS1{3} We need the latter. Refer to + // NVFuserTest.FusionComplexBCast1_CUDA + + auto c2p_map = + BestEffortReplay::replayPasC( + producer, + consumer, + -1, + // Compute at root domain may not be valid here, as all + // producers don't have to be able to map into consumer at + // max producer position. 
Since computeAt should be valid + // and this mechanism is only intended to lower produce + // position of consumer, we can simply use the pairwise map. + PairwiseRootDomainMap(producer, consumer)) + .getReplay(); + + // Find the innermost position of consumer that has + // been mapped within the producer ca axis. unsigned int consumer_pos = consumer->nDims(); while (consumer_pos > 0) { - for (auto producer : ir_utils::producerTvsOf(consumer)) { - auto producer_pos = TransformReplay::getMatchedLeafPosWithoutReplayPasC( - producer, consumer, consumer_pos); - if (producer_pos >= 0 && - producer_pos <= producer->getComputeAtPosition()) { - goto finished; - } + auto consumer_id = consumer->axis((int)consumer_pos - 1); + auto p_dom = producer->domain()->domain(); + if (std::any_of( + p_dom.begin(), + p_dom.begin() + producer->getComputeAtPosition(), + [&consumer_id, &c2p_map](IterDomain* p_id) { + auto c_id_it = c2p_map.find(consumer_id); + if (c_id_it != c2p_map.end()) { + return c_id_it->second == p_id; + } + return false; + })) { + break; } consumer_pos--; } -finished: - consumer->setMaxProducer(consumer_pos, true); + + return consumer_pos; +} + +} // namespace + +// Try to find the aligned position on consumer's domain corresponding to the +// compute at position of producer domain. 
+void MaxProducerPosUpdater::handle(TensorView* consumer) { + unsigned int consumer_pos = 0; + for (auto producer : ir_utils::producerTvsOf(consumer)) { + consumer_pos = std::max( + consumer_pos, getConsumerPosAlignedToProducerCA(consumer, producer)); + } + consumer->setMaxProducer(consumer_pos); } void MaxProducerPosUpdater::propagateTvPasC(TensorView* from, TensorView* to) { From c07fd52177a4bc9d2fe0a1a61825514deec852f8 Mon Sep 17 00:00:00 2001 From: Xiang Gao Date: Thu, 30 Jun 2022 22:21:01 -0700 Subject: [PATCH 090/100] symmetric --- torch/csrc/jit/codegen/cuda/test/test_gpu.cpp | 86 ++++++++++--- .../jit/codegen/cuda/transform_replay.cpp | 119 ++++++++++-------- .../csrc/jit/codegen/cuda/transform_replay.h | 1 + 3 files changed, 136 insertions(+), 70 deletions(-) diff --git a/torch/csrc/jit/codegen/cuda/test/test_gpu.cpp b/torch/csrc/jit/codegen/cuda/test/test_gpu.cpp index 951245f4841c1..179c7e1e0e8f2 100644 --- a/torch/csrc/jit/codegen/cuda/test/test_gpu.cpp +++ b/torch/csrc/jit/codegen/cuda/test/test_gpu.cpp @@ -24080,28 +24080,80 @@ TEST_F(NVFuserTest, FusionIssue1785Repro_CUDA) { } TEST_F(NVFuserTest, FusionSkipReplay_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); + struct TransformPropagatorWithCheck : public TransformPropagator { + public: + virtual void propagateTvPasC(TensorView* from, TensorView* to) override { + TransformPropagator::propagateTvPasC(from, to); + auto from_pos = replayed_pos_.at(from); + auto to_pos = replayed_pos_.at(to); + std::cout << "propagateTvPasC" << std::endl; + std::cout << " from: " << from->toString() << std::endl; + std::cout << " to: " << to->toString() << std::endl; + std::cout << TransformReplay::getMatchedLeafPosWithoutReplayPasC( + to, from, from_pos) + << std::endl; + std::cout << to_pos << std::endl; + TORCH_CHECK( + TransformReplay::getMatchedLeafPosWithoutReplayPasC( + to, from, from_pos) == to_pos); + } + virtual void propagateTvCasP(TensorView* from, TensorView* to) override { + 
TransformPropagator::propagateTvCasP(from, to); + auto from_pos = replayed_pos_.at(from); + auto to_pos = replayed_pos_.at(to); + std::cout << "propagateTvCasP" << std::endl; + std::cout << " from: " << from->toString() << std::endl; + std::cout << " to: " << to->toString() << std::endl; + TORCH_CHECK( + TransformReplay::getMatchedLeafPosWithoutReplayCasP( + to, from, from_pos) == to_pos); + } + virtual void propagateTvSibling(TensorView* from, TensorView* to) override { + TransformPropagator::propagateTvCasP(from, to); + auto from_pos = replayed_pos_.at(from); + auto to_pos = replayed_pos_.at(to); + TORCH_CHECK(from_pos == to_pos); + TORCH_CHECK(TransformReplay::fullSelfMatching(from, to)); + } + using TransformPropagator::TransformPropagator; + }; - TensorView* tv0 = makeContigTensor(1); - TensorView* tv1 = makeContigTensor(2); - fusion.addInput(tv0); - fusion.addInput(tv1); + { + Fusion fusion; + FusionGuard fg(&fusion); - auto tv2 = broadcast(tv0, {false, true}); - auto tv3 = add(tv2, tv1); - fusion.addOutput(tv3); + TensorView* tv0 = makeContigTensor(1); + TensorView* tv1 = makeContigTensor(2); + fusion.addInput(tv0); + fusion.addInput(tv1); - tv3->split(1, 2, false); + auto tv2 = broadcast(tv0, {false, true}); + auto tv3 = add(tv2, tv1); + fusion.addOutput(tv3); - TransformPropagator propagator(tv3); - MaxRootDomainInfoSpanningTree(tv3).traverse(&propagator); + tv3->split(1, 2, false); - TORCH_CHECK( - TransformReplay::getMatchedLeafPosWithoutReplayPasC(tv0, tv2, 3) == 1); - TORCH_CHECK( - TransformReplay::getMatchedLeafPosWithoutReplayCasP(tv2, tv0, 1) >= 1 && - TransformReplay::getMatchedLeafPosWithoutReplayCasP(tv2, tv0, 1) <= 3); + TransformPropagatorWithCheck propagator(tv3); + MaxRootDomainInfoSpanningTree(tv3).traverse(&propagator); + } + + { + Fusion fusion; + FusionGuard fg(&fusion); + + TensorView* tv0 = makeContigTensor(3); + fusion.addInput(tv0); + + auto tv1 = sum(tv0, {0, 2}); + fusion.addOutput(tv1); + + tv0->split(1, 2, false); + + 
fusion.print(); + + TransformPropagatorWithCheck propagator(tv0); + MaxRootDomainInfoSpanningTree(tv0).traverse(&propagator); + } } } // namespace jit diff --git a/torch/csrc/jit/codegen/cuda/transform_replay.cpp b/torch/csrc/jit/codegen/cuda/transform_replay.cpp index bc555c5aa90be..b56cd173922cf 100644 --- a/torch/csrc/jit/codegen/cuda/transform_replay.cpp +++ b/torch/csrc/jit/codegen/cuda/transform_replay.cpp @@ -647,80 +647,93 @@ std::pair TransformReplay::replayCasP( namespace { int getMatchedLeafPosWithoutReplay( - const TensorView* producer, - const TensorView* consumer, - int consumer_or_producer_pos, - bool consumer_pos = true) { + const TensorView* from, + const TensorView* to, + int from_pos, + bool from_consumer) { FUSER_PERF_SCOPE("transform_replay.cpp::getMatchedLeafPosWithoutReplay"); - const auto c2p_root_map = - PairwiseRootDomainMap(producer, consumer) - .mapConsumerToProducer(consumer->domain(), producer->domain()); + const TensorView *producer, *consumer; + if (from_consumer) { + consumer = from; + producer = to; + } else { + consumer = to; + producer = from; + } - // IterDomains in consumer root also in producer root - std::unordered_set mapped_consumer_roots; - for (auto entry : c2p_root_map) { - mapped_consumer_roots.emplace(entry.first); + const auto pairwise_map = PairwiseRootDomainMap(producer, consumer); + id_map root_map; + if (from_consumer) { + root_map = pairwise_map.mapConsumerToProducer( + consumer->domain(), producer->domain()); + } else { + root_map = pairwise_map.mapProducerToConsumer( + producer->domain(), consumer->domain()); } - const auto consumer_domain = consumer->domain()->domain(); + // IterDomains in consumer root also in producer root + std::unordered_set mapped_from_roots; + for (auto entry : root_map) { + mapped_from_roots.emplace(entry.first); + } - auto mapped_consumer_domain_ids_vec = DependencyCheck::getAllValsBetween( - mapped_consumer_roots, {consumer_domain.begin(), consumer_domain.end()}); + const auto 
from_domain = from->domain()->domain(); - std::unordered_set mapped_consumer_domain_ids( - mapped_consumer_domain_ids_vec.begin(), - mapped_consumer_domain_ids_vec.end()); + auto mapped_from_domain_ids_vec = DependencyCheck::getAllValsBetween( + mapped_from_roots, {from_domain.begin(), from_domain.end()}); - const auto producer_domain = producer->domain()->domain(); + std::unordered_set mapped_from_domain_ids( + mapped_from_domain_ids_vec.begin(), mapped_from_domain_ids_vec.end()); - auto it_consumer = consumer_domain.begin(); - auto it_producer = producer_domain.begin(); + const auto to_domain = to->domain()->domain(); - auto best_effort_PasC = BestEffortReplay::replayPasC( - producer, consumer, -1, PairwiseRootDomainMap(producer, consumer)); + auto it_from = from_domain.begin(); + auto it_to = to_domain.begin(); - auto c2p_map = best_effort_PasC.getReplay(); + id_map replay_map; + if (from_consumer) { + auto best_effort_PasC = + BestEffortReplay::replayPasC(producer, consumer, -1, pairwise_map); + replay_map = best_effort_PasC.getReplay(); + } else { + auto best_effort_CasP = + BestEffortReplay::replayCasP(consumer, producer, -1, pairwise_map); + replay_map = best_effort_CasP.getReplay(); + } - int mismatched_consumer_pos = 0; - int mismatched_producer_pos = 0; - while (it_consumer != consumer_domain.end()) { - auto consumer_id = *it_consumer; - if (!mapped_consumer_domain_ids.count(consumer_id)) { - ++it_consumer; - mismatched_consumer_pos++; - if (consumer_pos) { - if (consumer_or_producer_pos == mismatched_consumer_pos) { - return mismatched_producer_pos; - } + int mismatched_from_pos = 0; + int mismatched_to_pos = 0; + while (it_from != from_domain.end()) { + auto from_id = *it_from; + if (!mapped_from_domain_ids.count(from_id)) { + ++it_from; + mismatched_from_pos++; + if (from_pos == mismatched_from_pos) { + return mismatched_to_pos; } continue; } + std::cout << from_id << std::endl; - auto c2p_it = c2p_map.find(consumer_id); - if (c2p_it == c2p_map.end()) 
{ + auto replay_it = replay_map.find(from_id); + if (replay_it == replay_map.end()) { break; } - if (it_producer == producer_domain.end()) { + if (it_to == to_domain.end()) { break; } - auto producer_id = *it_producer; - - if (c2p_it->second == producer_id) { - ++mismatched_consumer_pos; - ++mismatched_producer_pos; - ++it_consumer; - ++it_producer; - if (consumer_pos) { - if (consumer_or_producer_pos == mismatched_consumer_pos) { - return mismatched_producer_pos; - } - } else { - if (consumer_or_producer_pos == mismatched_producer_pos) { - return mismatched_consumer_pos; - } + auto to_id = *it_to; + + if (replay_it->second == to_id) { + ++mismatched_from_pos; + ++mismatched_to_pos; + ++it_from; + ++it_to; + if (from_pos == mismatched_from_pos) { + return mismatched_to_pos; } } else { break; @@ -735,7 +748,7 @@ int TransformReplay::getMatchedLeafPosWithoutReplayPasC( const TensorView* producer, const TensorView* consumer, int consumer_pos) { - return getMatchedLeafPosWithoutReplay(producer, consumer, consumer_pos, true); + return getMatchedLeafPosWithoutReplay(consumer, producer, consumer_pos, true); } int TransformReplay::getMatchedLeafPosWithoutReplayCasP( diff --git a/torch/csrc/jit/codegen/cuda/transform_replay.h b/torch/csrc/jit/codegen/cuda/transform_replay.h index d026de618c88f..9fafcf5c734c5 100644 --- a/torch/csrc/jit/codegen/cuda/transform_replay.h +++ b/torch/csrc/jit/codegen/cuda/transform_replay.h @@ -179,6 +179,7 @@ class TORCH_CUDA_CU_API TransformReplay { class TORCH_CUDA_CU_API TransformPropagator : public MaxRootDomainInfoSpanningTree::Propagator { + protected: std::unordered_map replayed_pos_; public: From 1bb4e029801dbd72205440d92cf30d6d8817018b Mon Sep 17 00:00:00 2001 From: Xiang Gao Date: Thu, 30 Jun 2022 22:27:42 -0700 Subject: [PATCH 091/100] save --- torch/csrc/jit/codegen/cuda/test/test_gpu.cpp | 13 ++----------- 1 file changed, 2 insertions(+), 11 deletions(-) diff --git a/torch/csrc/jit/codegen/cuda/test/test_gpu.cpp 
b/torch/csrc/jit/codegen/cuda/test/test_gpu.cpp index 179c7e1e0e8f2..d39980116307d 100644 --- a/torch/csrc/jit/codegen/cuda/test/test_gpu.cpp +++ b/torch/csrc/jit/codegen/cuda/test/test_gpu.cpp @@ -24086,13 +24086,6 @@ TEST_F(NVFuserTest, FusionSkipReplay_CUDA) { TransformPropagator::propagateTvPasC(from, to); auto from_pos = replayed_pos_.at(from); auto to_pos = replayed_pos_.at(to); - std::cout << "propagateTvPasC" << std::endl; - std::cout << " from: " << from->toString() << std::endl; - std::cout << " to: " << to->toString() << std::endl; - std::cout << TransformReplay::getMatchedLeafPosWithoutReplayPasC( - to, from, from_pos) - << std::endl; - std::cout << to_pos << std::endl; TORCH_CHECK( TransformReplay::getMatchedLeafPosWithoutReplayPasC( to, from, from_pos) == to_pos); @@ -24101,9 +24094,6 @@ TEST_F(NVFuserTest, FusionSkipReplay_CUDA) { TransformPropagator::propagateTvCasP(from, to); auto from_pos = replayed_pos_.at(from); auto to_pos = replayed_pos_.at(to); - std::cout << "propagateTvCasP" << std::endl; - std::cout << " from: " << from->toString() << std::endl; - std::cout << " to: " << to->toString() << std::endl; TORCH_CHECK( TransformReplay::getMatchedLeafPosWithoutReplayCasP( to, from, from_pos) == to_pos); @@ -24145,7 +24135,8 @@ TEST_F(NVFuserTest, FusionSkipReplay_CUDA) { fusion.addInput(tv0); auto tv1 = sum(tv0, {0, 2}); - fusion.addOutput(tv1); + auto tv2 = sin(tv1); + fusion.addOutput(tv2); tv0->split(1, 2, false); From f2d17f3ef7f05636c42bb94e1f3707806c8b5cac Mon Sep 17 00:00:00 2001 From: Xiang Gao Date: Thu, 30 Jun 2022 22:29:05 -0700 Subject: [PATCH 092/100] cleanup --- torch/csrc/jit/codegen/cuda/test/test_gpu.cpp | 2 -- torch/csrc/jit/codegen/cuda/transform_replay.cpp | 1 - 2 files changed, 3 deletions(-) diff --git a/torch/csrc/jit/codegen/cuda/test/test_gpu.cpp b/torch/csrc/jit/codegen/cuda/test/test_gpu.cpp index d39980116307d..229a6feafb515 100644 --- a/torch/csrc/jit/codegen/cuda/test/test_gpu.cpp +++ 
b/torch/csrc/jit/codegen/cuda/test/test_gpu.cpp @@ -24140,8 +24140,6 @@ TEST_F(NVFuserTest, FusionSkipReplay_CUDA) { tv0->split(1, 2, false); - fusion.print(); - TransformPropagatorWithCheck propagator(tv0); MaxRootDomainInfoSpanningTree(tv0).traverse(&propagator); } diff --git a/torch/csrc/jit/codegen/cuda/transform_replay.cpp b/torch/csrc/jit/codegen/cuda/transform_replay.cpp index b56cd173922cf..914159f629f9f 100644 --- a/torch/csrc/jit/codegen/cuda/transform_replay.cpp +++ b/torch/csrc/jit/codegen/cuda/transform_replay.cpp @@ -714,7 +714,6 @@ int getMatchedLeafPosWithoutReplay( } continue; } - std::cout << from_id << std::endl; auto replay_it = replay_map.find(from_id); if (replay_it == replay_map.end()) { From 4521302b50d3a8b4be4fc77de1ba56053681f290 Mon Sep 17 00:00:00 2001 From: Xiang Gao Date: Thu, 30 Jun 2022 22:50:08 -0700 Subject: [PATCH 093/100] skip both from and to ids --- .../jit/codegen/cuda/transform_replay.cpp | 27 +++++++++++++++---- 1 file changed, 22 insertions(+), 5 deletions(-) diff --git a/torch/csrc/jit/codegen/cuda/transform_replay.cpp b/torch/csrc/jit/codegen/cuda/transform_replay.cpp index 914159f629f9f..05e0fac78fe9e 100644 --- a/torch/csrc/jit/codegen/cuda/transform_replay.cpp +++ b/torch/csrc/jit/codegen/cuda/transform_replay.cpp @@ -672,7 +672,7 @@ int getMatchedLeafPosWithoutReplay( producer->domain(), consumer->domain()); } - // IterDomains in consumer root also in producer root + // IterDomains in `from` root also in `to` root std::unordered_set mapped_from_roots; for (auto entry : root_map) { mapped_from_roots.emplace(entry.first); @@ -686,8 +686,20 @@ int getMatchedLeafPosWithoutReplay( std::unordered_set mapped_from_domain_ids( mapped_from_domain_ids_vec.begin(), mapped_from_domain_ids_vec.end()); + // IterDomains in `to` root also in `from` root + std::unordered_set mapped_to_roots; + for (auto entry : root_map) { + mapped_to_roots.emplace(entry.second); + } + const auto to_domain = to->domain()->domain(); + auto 
mapped_to_domain_ids_vec = DependencyCheck::getAllValsBetween( + mapped_to_roots, {to_domain.begin(), to_domain.end()}); + + std::unordered_set mapped_to_domain_ids( + mapped_to_domain_ids_vec.begin(), mapped_to_domain_ids_vec.end()); + auto it_from = from_domain.begin(); auto it_to = to_domain.begin(); @@ -704,17 +716,24 @@ int getMatchedLeafPosWithoutReplay( int mismatched_from_pos = 0; int mismatched_to_pos = 0; - while (it_from != from_domain.end()) { + while (it_from != from_domain.end() && it_to != to_domain.end()) { auto from_id = *it_from; if (!mapped_from_domain_ids.count(from_id)) { ++it_from; - mismatched_from_pos++; + ++mismatched_from_pos; if (from_pos == mismatched_from_pos) { return mismatched_to_pos; } continue; } + auto to_id = *it_to; + if (!mapped_to_domain_ids.count(to_id)) { + ++it_to; + ++mismatched_to_pos; + continue; + } + auto replay_it = replay_map.find(from_id); if (replay_it == replay_map.end()) { break; @@ -724,8 +743,6 @@ int getMatchedLeafPosWithoutReplay( break; } - auto to_id = *it_to; - if (replay_it->second == to_id) { ++mismatched_from_pos; ++mismatched_to_pos; From 6793f276f03ced34f65ca3c5fd49886ea3b4832a Mon Sep 17 00:00:00 2001 From: Xiang Gao Date: Thu, 30 Jun 2022 23:51:19 -0700 Subject: [PATCH 094/100] save --- .../jit/codegen/cuda/transform_replay.cpp | 62 +++++++++++-------- 1 file changed, 36 insertions(+), 26 deletions(-) diff --git a/torch/csrc/jit/codegen/cuda/transform_replay.cpp b/torch/csrc/jit/codegen/cuda/transform_replay.cpp index 05e0fac78fe9e..229d6bd102901 100644 --- a/torch/csrc/jit/codegen/cuda/transform_replay.cpp +++ b/torch/csrc/jit/codegen/cuda/transform_replay.cpp @@ -673,13 +673,13 @@ int getMatchedLeafPosWithoutReplay( } // IterDomains in `from` root also in `to` root + const auto from_domain = from->domain()->domain(); + std::unordered_set mapped_from_roots; for (auto entry : root_map) { mapped_from_roots.emplace(entry.first); } - const auto from_domain = from->domain()->domain(); - auto 
mapped_from_domain_ids_vec = DependencyCheck::getAllValsBetween( mapped_from_roots, {from_domain.begin(), from_domain.end()}); @@ -687,18 +687,27 @@ int getMatchedLeafPosWithoutReplay( mapped_from_domain_ids_vec.begin(), mapped_from_domain_ids_vec.end()); // IterDomains in `to` root also in `from` root - std::unordered_set mapped_to_roots; - for (auto entry : root_map) { - mapped_to_roots.emplace(entry.second); - } - const auto to_domain = to->domain()->domain(); - auto mapped_to_domain_ids_vec = DependencyCheck::getAllValsBetween( - mapped_to_roots, {to_domain.begin(), to_domain.end()}); + std::unordered_set mapped_to_domain_ids; + + // Unmappable dims in the `to` tensor is allowed to be skipped only when `to` + // is a consumer. In a C->P replay, unmappable dims will be reductions, and we + // want reductions to be pushed to the end. + if (to == consumer) { + std::unordered_set mapped_to_roots; + for (auto entry : root_map) { + mapped_to_roots.emplace(entry.second); + } + + auto mapped_to_domain_ids_vec = DependencyCheck::getAllValsBetween( + mapped_to_roots, {to_domain.begin(), to_domain.end()}); - std::unordered_set mapped_to_domain_ids( - mapped_to_domain_ids_vec.begin(), mapped_to_domain_ids_vec.end()); + mapped_to_domain_ids.insert( + mapped_to_domain_ids_vec.begin(), mapped_to_domain_ids_vec.end()); + } else { + mapped_to_domain_ids.insert(to_domain.begin(), to_domain.end()); + } auto it_from = from_domain.begin(); auto it_to = to_domain.begin(); @@ -716,19 +725,24 @@ int getMatchedLeafPosWithoutReplay( int mismatched_from_pos = 0; int mismatched_to_pos = 0; - while (it_from != from_domain.end() && it_to != to_domain.end()) { + while (it_from != from_domain.end()) { + if (from_pos == mismatched_from_pos) { + return mismatched_to_pos; + } + auto from_id = *it_from; - if (!mapped_from_domain_ids.count(from_id)) { + if (mapped_from_domain_ids.count(from_id) == 0) { ++it_from; ++mismatched_from_pos; - if (from_pos == mismatched_from_pos) { - return 
mismatched_to_pos; - } continue; } + if (it_to == to_domain.end()) { + return -1; + } + auto to_id = *it_to; - if (!mapped_to_domain_ids.count(to_id)) { + if (mapped_to_domain_ids.count(to_id) == 0) { ++it_to; ++mismatched_to_pos; continue; @@ -736,11 +750,7 @@ int getMatchedLeafPosWithoutReplay( auto replay_it = replay_map.find(from_id); if (replay_it == replay_map.end()) { - break; - } - - if (it_to == to_domain.end()) { - break; + return -1; } if (replay_it->second == to_id) { @@ -748,13 +758,13 @@ int getMatchedLeafPosWithoutReplay( ++mismatched_to_pos; ++it_from; ++it_to; - if (from_pos == mismatched_from_pos) { - return mismatched_to_pos; - } } else { - break; + return -1; } } + if (from_pos == mismatched_from_pos) { + return mismatched_to_pos; + } return -1; } From 8de84bbd10c7ab446ee63d8e316d7faf8d2d1216 Mon Sep 17 00:00:00 2001 From: Xiang Gao Date: Fri, 1 Jul 2022 01:23:01 -0700 Subject: [PATCH 095/100] fix --- .../jit/codegen/cuda/transform_replay.cpp | 213 ++++++++++-------- 1 file changed, 114 insertions(+), 99 deletions(-) diff --git a/torch/csrc/jit/codegen/cuda/transform_replay.cpp b/torch/csrc/jit/codegen/cuda/transform_replay.cpp index 229d6bd102901..17171b325e0aa 100644 --- a/torch/csrc/jit/codegen/cuda/transform_replay.cpp +++ b/torch/csrc/jit/codegen/cuda/transform_replay.cpp @@ -644,147 +644,162 @@ std::pair TransformReplay::replayCasP( return replayCasP(consumer, producer, compute_at_axis, root_map); } -namespace { +int TransformReplay::getMatchedLeafPosWithoutReplayPasC( + const TensorView* producer, + const TensorView* consumer, + int consumer_pos) { + FUSER_PERF_SCOPE("transform_replay.cpp::getMatchedLeafPosWithoutReplayPasC"); + + const auto pairwise_map = PairwiseRootDomainMap(producer, consumer); + id_map c2p_root_map = pairwise_map.mapConsumerToProducer( + consumer->domain(), producer->domain()); -int getMatchedLeafPosWithoutReplay( - const TensorView* from, - const TensorView* to, - int from_pos, - bool from_consumer) { - 
FUSER_PERF_SCOPE("transform_replay.cpp::getMatchedLeafPosWithoutReplay"); + // IterDomains in `consumer` root also in `producer` root + const auto consumer_domain = consumer->domain()->domain(); - const TensorView *producer, *consumer; - if (from_consumer) { - consumer = from; - producer = to; - } else { - consumer = to; - producer = from; + std::unordered_set mapped_consumer_roots; + for (auto entry : c2p_root_map) { + mapped_consumer_roots.emplace(entry.first); } - const auto pairwise_map = PairwiseRootDomainMap(producer, consumer); - id_map root_map; - if (from_consumer) { - root_map = pairwise_map.mapConsumerToProducer( - consumer->domain(), producer->domain()); - } else { - root_map = pairwise_map.mapProducerToConsumer( - producer->domain(), consumer->domain()); - } + auto unskippable_consumer_ids_vec = DependencyCheck::getAllValsBetween( + mapped_consumer_roots, {consumer_domain.begin(), consumer_domain.end()}); - // IterDomains in `from` root also in `to` root - const auto from_domain = from->domain()->domain(); + std::unordered_set unskippable_consumer_ids( + unskippable_consumer_ids_vec.begin(), unskippable_consumer_ids_vec.end()); - std::unordered_set mapped_from_roots; - for (auto entry : root_map) { - mapped_from_roots.emplace(entry.first); - } + // IterDomains in `producer` root also in `consumer` root + const auto producer_domain = producer->domain()->domain(); - auto mapped_from_domain_ids_vec = DependencyCheck::getAllValsBetween( - mapped_from_roots, {from_domain.begin(), from_domain.end()}); + auto it_consumer = consumer_domain.begin(); + auto it_producer = producer_domain.begin(); - std::unordered_set mapped_from_domain_ids( - mapped_from_domain_ids_vec.begin(), mapped_from_domain_ids_vec.end()); + id_map c2p_map = + BestEffortReplay::replayPasC(producer, consumer, -1, pairwise_map) + .getReplay(); - // IterDomains in `to` root also in `from` root - const auto to_domain = to->domain()->domain(); + int mismatched_consumer_pos = 0; + int 
mismatched_producer_pos = 0; + while (it_consumer != consumer_domain.end()) { + if (consumer_pos == mismatched_consumer_pos) { + return mismatched_producer_pos; + } - std::unordered_set mapped_to_domain_ids; + auto consumer_id = *it_consumer; + if (unskippable_consumer_ids.count(consumer_id) == 0) { + ++it_consumer; + ++mismatched_consumer_pos; + continue; + } - // Unmappable dims in the `to` tensor is allowed to be skipped only when `to` - // is a consumer. In a C->P replay, unmappable dims will be reductions, and we - // want reductions to be pushed to the end. - if (to == consumer) { - std::unordered_set mapped_to_roots; - for (auto entry : root_map) { - mapped_to_roots.emplace(entry.second); + if (it_producer == producer_domain.end()) { + return -1; } - auto mapped_to_domain_ids_vec = DependencyCheck::getAllValsBetween( - mapped_to_roots, {to_domain.begin(), to_domain.end()}); + auto c2p_it = c2p_map.find(consumer_id); + if (c2p_it == c2p_map.end()) { + return -1; + } - mapped_to_domain_ids.insert( - mapped_to_domain_ids_vec.begin(), mapped_to_domain_ids_vec.end()); - } else { - mapped_to_domain_ids.insert(to_domain.begin(), to_domain.end()); + auto producer_id = *it_producer; + if (c2p_it->second == producer_id) { + ++mismatched_consumer_pos; + ++mismatched_producer_pos; + ++it_consumer; + ++it_producer; + } else { + return -1; + } + } + if (consumer_pos == mismatched_consumer_pos) { + return mismatched_producer_pos; } + return -1; +} + +int TransformReplay::getMatchedLeafPosWithoutReplayCasP( + const TensorView* consumer, + const TensorView* producer, + int producer_pos) { + FUSER_PERF_SCOPE("transform_replay.cpp::getMatchedLeafPosWithoutReplayCasP"); + + const auto pairwise_map = PairwiseRootDomainMap(producer, consumer); + id_map p2c_root_map = pairwise_map.mapProducerToConsumer( + producer->domain(), consumer->domain()); + + // IterDomains in `producer` root that are not reduction + const auto producer_domain = producer->domain()->domain(); + auto 
unskippable_producer_ids_vec = + TensorDomain::noReductions(producer_domain); + std::unordered_set unskippable_producer_ids( + unskippable_producer_ids_vec.begin(), unskippable_producer_ids_vec.end()); - auto it_from = from_domain.begin(); - auto it_to = to_domain.begin(); + // IterDomains in `consumer` root also in `producer` root + const auto consumer_domain = consumer->domain()->domain(); - id_map replay_map; - if (from_consumer) { - auto best_effort_PasC = - BestEffortReplay::replayPasC(producer, consumer, -1, pairwise_map); - replay_map = best_effort_PasC.getReplay(); - } else { - auto best_effort_CasP = - BestEffortReplay::replayCasP(consumer, producer, -1, pairwise_map); - replay_map = best_effort_CasP.getReplay(); + std::unordered_set mapped_consumer_roots; + for (auto entry : p2c_root_map) { + mapped_consumer_roots.emplace(entry.second); } - int mismatched_from_pos = 0; - int mismatched_to_pos = 0; - while (it_from != from_domain.end()) { - if (from_pos == mismatched_from_pos) { - return mismatched_to_pos; + auto unskippable_consumer_ids_vec = DependencyCheck::getAllValsBetween( + mapped_consumer_roots, {consumer_domain.begin(), consumer_domain.end()}); + + std::unordered_set unskippable_consumer_ids( + unskippable_consumer_ids_vec.begin(), unskippable_consumer_ids_vec.end()); + + auto it_producer = producer_domain.begin(); + auto it_consumer = consumer_domain.begin(); + + id_map replay_map = + BestEffortReplay::replayCasP(consumer, producer, -1, pairwise_map) + .getReplay(); + + int mismatched_producer_pos = 0; + int mismatched_consumer_pos = 0; + while (it_producer != producer_domain.end()) { + if (producer_pos == mismatched_producer_pos) { + return mismatched_consumer_pos; } - auto from_id = *it_from; - if (mapped_from_domain_ids.count(from_id) == 0) { - ++it_from; - ++mismatched_from_pos; + auto producer_id = *it_producer; + if (unskippable_producer_ids.count(producer_id) == 0) { + ++it_producer; + ++mismatched_producer_pos; continue; } - if (it_to == 
to_domain.end()) { + if (it_consumer == consumer_domain.end()) { return -1; } - auto to_id = *it_to; - if (mapped_to_domain_ids.count(to_id) == 0) { - ++it_to; - ++mismatched_to_pos; + auto consumer_id = *it_consumer; + if (unskippable_consumer_ids.count(consumer_id) == 0) { + ++it_consumer; + ++mismatched_consumer_pos; continue; } - auto replay_it = replay_map.find(from_id); + auto replay_it = replay_map.find(producer_id); if (replay_it == replay_map.end()) { return -1; } - if (replay_it->second == to_id) { - ++mismatched_from_pos; - ++mismatched_to_pos; - ++it_from; - ++it_to; + if (replay_it->second == consumer_id) { + ++mismatched_producer_pos; + ++mismatched_consumer_pos; + ++it_producer; + ++it_consumer; } else { return -1; } } - if (from_pos == mismatched_from_pos) { - return mismatched_to_pos; + if (producer_pos == mismatched_producer_pos) { + return mismatched_consumer_pos; } return -1; } -} // namespace - -int TransformReplay::getMatchedLeafPosWithoutReplayPasC( - const TensorView* producer, - const TensorView* consumer, - int consumer_pos) { - return getMatchedLeafPosWithoutReplay(consumer, producer, consumer_pos, true); -} - -int TransformReplay::getMatchedLeafPosWithoutReplayCasP( - const TensorView* consumer, - const TensorView* producer, - int producer_pos) { - return getMatchedLeafPosWithoutReplay( - producer, consumer, producer_pos, false); -} - bool TransformReplay::fullSelfMatching( const TensorView* replay, const TensorView* target) { From 18c89233b72d48761ae47b8a789357e70ab287b9 Mon Sep 17 00:00:00 2001 From: Xiang Gao Date: Fri, 1 Jul 2022 01:52:37 -0700 Subject: [PATCH 096/100] save --- .../jit/codegen/cuda/inline_propagator.cpp | 24 ++++++++++++++----- 1 file changed, 18 insertions(+), 6 deletions(-) diff --git a/torch/csrc/jit/codegen/cuda/inline_propagator.cpp b/torch/csrc/jit/codegen/cuda/inline_propagator.cpp index eefb2ea451cfd..732121a21cdb9 100644 --- a/torch/csrc/jit/codegen/cuda/inline_propagator.cpp +++ 
b/torch/csrc/jit/codegen/cuda/inline_propagator.cpp @@ -311,9 +311,15 @@ void InlinePropagator::propagateTvPasC(TensorView* from, TensorView* to) { TransformReplay::getMatchedLeafPosWithoutReplayPasC(to, from, pos); // TODO: Can we make TransformPropagator do the transformation, and // InlinePropagator only set the CA positions? - // if (mode_ != ComputeAtMode::MostInlined) { - // TORCH_CHECK(to_pos >= 0); - // } + if (true || mode_ != ComputeAtMode::MostInlined) { + TORCH_CHECK( + to_pos >= 0, + "Unable to propagate CA position from consumer ", + from, + " to producer ", + to, + " because this would require replay."); + } if (to_pos < 0) { auto replay = TransformReplay::replayPasC(to, from, pos); TORCH_INTERNAL_ASSERT( @@ -342,9 +348,15 @@ void InlinePropagator::propagateTvCasP(TensorView* from, TensorView* to) { TransformReplay::getMatchedLeafPosWithoutReplayCasP(to, from, pos); // TODO: Can we make TransformPropagator do the transformation, and // InlinePropagator only set the CA positions? 
- // if (mode_ != ComputeAtMode::MostInlined) { - // TORCH_CHECK(to_pos >= 0); - // } + if (true || mode_ != ComputeAtMode::MostInlined) { + TORCH_CHECK( + to_pos >= 0, + "Unable to propagate CA position from producer ", + from, + " to consumer ", + to, + " because this would require replay."); + } if (to_pos < 0) { auto replay = TransformReplay::replayCasP(to, from, pos); TORCH_INTERNAL_ASSERT( From 2987b48c657439f404741607cf94f4860efd09e9 Mon Sep 17 00:00:00 2001 From: Xiang Gao Date: Fri, 1 Jul 2022 01:58:50 -0700 Subject: [PATCH 097/100] fix most inlined compute at --- torch/csrc/jit/codegen/cuda/compute_at.cpp | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/torch/csrc/jit/codegen/cuda/compute_at.cpp b/torch/csrc/jit/codegen/cuda/compute_at.cpp index 35a35acfe0a97..e0271ddd81a6a 100644 --- a/torch/csrc/jit/codegen/cuda/compute_at.cpp +++ b/torch/csrc/jit/codegen/cuda/compute_at.cpp @@ -165,7 +165,7 @@ void ComputeAt::runAt( TensorView* consumer, unsigned int consumer_position, ComputeAtMode mode) { - FUSER_PERF_SCOPE("ComputeAt::run"); + FUSER_PERF_SCOPE("ComputeAt::runAt"); // Make sure the correct fusion is setup between this and consumer. 
TORCH_CHECK( @@ -175,6 +175,10 @@ void ComputeAt::runAt( consumer, " are not in the same fusion."); + if (mode == ComputeAtMode::MostInlined) { + consumer_position = -1; + } + FusionGuard fg(producer->fusion()); auto selected = getPropagationSubgraph(producer, consumer); @@ -206,6 +210,10 @@ void ComputeAt::runWith( consumer, " are not in the same fusion."); + if (mode == ComputeAtMode::MostInlined) { + producer_position = -1; + } + FusionGuard fg(producer->fusion()); auto selected = getPropagationSubgraph(producer, consumer); From 30a621b5b585ca62901b78b04027d685aac7f985 Mon Sep 17 00:00:00 2001 From: Xiang Gao Date: Fri, 1 Jul 2022 02:01:14 -0700 Subject: [PATCH 098/100] fix --- torch/csrc/jit/codegen/cuda/compute_at.cpp | 4 ++-- torch/csrc/jit/codegen/cuda/inline_propagator.cpp | 10 ---------- 2 files changed, 2 insertions(+), 12 deletions(-) diff --git a/torch/csrc/jit/codegen/cuda/compute_at.cpp b/torch/csrc/jit/codegen/cuda/compute_at.cpp index e0271ddd81a6a..ed0f9fb271d57 100644 --- a/torch/csrc/jit/codegen/cuda/compute_at.cpp +++ b/torch/csrc/jit/codegen/cuda/compute_at.cpp @@ -176,7 +176,7 @@ void ComputeAt::runAt( " are not in the same fusion."); if (mode == ComputeAtMode::MostInlined) { - consumer_position = -1; + consumer_position = consumer->nDims(); } FusionGuard fg(producer->fusion()); @@ -211,7 +211,7 @@ void ComputeAt::runWith( " are not in the same fusion."); if (mode == ComputeAtMode::MostInlined) { - producer_position = -1; + producer_position = producer->nDims(); } FusionGuard fg(producer->fusion()); diff --git a/torch/csrc/jit/codegen/cuda/inline_propagator.cpp b/torch/csrc/jit/codegen/cuda/inline_propagator.cpp index 732121a21cdb9..09da9a4260b00 100644 --- a/torch/csrc/jit/codegen/cuda/inline_propagator.cpp +++ b/torch/csrc/jit/codegen/cuda/inline_propagator.cpp @@ -333,9 +333,6 @@ void InlinePropagator::propagateTvPasC(TensorView* from, TensorView* to) { to_pos = replay.second; } recordReplayedPos(to, to_pos); - // std::cout << 
"InlinePropagator::propagateTvPasC" << std::endl; - // std::cout << " from: " << from->toString() << " at " << pos << std::endl; - // std::cout << " to: " << to->toString() << " at " << to_pos << std::endl; } void InlinePropagator::propagateTvCasP(TensorView* from, TensorView* to) { @@ -370,9 +367,6 @@ void InlinePropagator::propagateTvCasP(TensorView* from, TensorView* to) { to_pos = replay.second; } recordReplayedPos(to, to_pos); - // std::cout << "InlinePropagator::propagateTvCasP" << std::endl; - // std::cout << " from: " << from->toString() << " at " << pos << std::endl; - // std::cout << " to: " << to->toString() << " at " << to_pos << std::endl; } void InlinePropagator::propagateTvSibling(TensorView* from, TensorView* to) { @@ -393,10 +387,6 @@ void InlinePropagator::propagateTvSibling(TensorView* from, TensorView* to) { to->setDomain(replay); } recordReplayedPos(to, from_pos); - // std::cout << "InlinePropagator::propagateTvSibling" << std::endl; - // std::cout << " from: " << from->toString() << " at " << from_pos << - // std::endl; std::cout << " to: " << to->toString() << " at " << from_pos << - // std::endl; } namespace { From 0f2249d9027c943d3b3ca6632466dc2da7aad77e Mon Sep 17 00:00:00 2001 From: Xiang Gao Date: Fri, 1 Jul 2022 02:01:51 -0700 Subject: [PATCH 099/100] cleanup --- torch/csrc/jit/codegen/cuda/transform_replay.cpp | 9 --------- 1 file changed, 9 deletions(-) diff --git a/torch/csrc/jit/codegen/cuda/transform_replay.cpp b/torch/csrc/jit/codegen/cuda/transform_replay.cpp index ea8219fb6cf33..17171b325e0aa 100644 --- a/torch/csrc/jit/codegen/cuda/transform_replay.cpp +++ b/torch/csrc/jit/codegen/cuda/transform_replay.cpp @@ -848,9 +848,6 @@ void TransformPropagator::propagateTvPasC(TensorView* from, TensorView* to) { new_pos = replay.second; } replayed_pos_[to] = new_pos; - // std::cout << "TransformPropagator::propagateTvPasC" << std::endl; - // std::cout << " from: " << from->toString() << " at " << pos << std::endl; - // std::cout << " to: 
" << to->toString() << " at " << new_pos << std::endl; } void TransformPropagator::propagateTvCasP(TensorView* from, TensorView* to) { @@ -864,9 +861,6 @@ void TransformPropagator::propagateTvCasP(TensorView* from, TensorView* to) { new_pos = replay.second; } replayed_pos_[to] = new_pos; - // std::cout << "TransformPropagator::propagateTvCasP" << std::endl; - // std::cout << " from: " << from->toString() << " at " << pos << std::endl; - // std::cout << " to: " << to->toString() << " at " << new_pos << std::endl; } void TransformPropagator::propagateTvSibling(TensorView* from, TensorView* to) { @@ -877,9 +871,6 @@ void TransformPropagator::propagateTvSibling(TensorView* from, TensorView* to) { to->setDomain(replay); } replayed_pos_[to] = pos; - // std::cout << "TransformPropagator::propagateTvSibling" << std::endl; - // std::cout << " from: " << from->toString() << " at " << pos << std::endl; - // std::cout << " to: " << to->toString() << " at " << pos << std::endl; } TransformPropagator::TransformPropagator(TensorView* from, int64_t pos) { From 4c3d9a70b7e8e1df6185155e389f67ddb18f2b58 Mon Sep 17 00:00:00 2001 From: Xiang Gao Date: Fri, 1 Jul 2022 02:07:03 -0700 Subject: [PATCH 100/100] save --- torch/csrc/jit/codegen/cuda/inline_propagator.cpp | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/torch/csrc/jit/codegen/cuda/inline_propagator.cpp b/torch/csrc/jit/codegen/cuda/inline_propagator.cpp index 09da9a4260b00..195ef3e67a188 100644 --- a/torch/csrc/jit/codegen/cuda/inline_propagator.cpp +++ b/torch/csrc/jit/codegen/cuda/inline_propagator.cpp @@ -309,9 +309,7 @@ void InlinePropagator::propagateTvPasC(TensorView* from, TensorView* to) { int pos = getReplayPosPasC(to, from); auto to_pos = TransformReplay::getMatchedLeafPosWithoutReplayPasC(to, from, pos); - // TODO: Can we make TransformPropagator do the transformation, and - // InlinePropagator only set the CA positions? 
- if (true || mode_ != ComputeAtMode::MostInlined) { + if (mode_ != ComputeAtMode::MostInlined) { TORCH_CHECK( to_pos >= 0, "Unable to propagate CA position from consumer ", @@ -343,9 +341,7 @@ void InlinePropagator::propagateTvCasP(TensorView* from, TensorView* to) { int pos = getReplayPosCasP(to, from); auto to_pos = TransformReplay::getMatchedLeafPosWithoutReplayCasP(to, from, pos); - // TODO: Can we make TransformPropagator do the transformation, and - // InlinePropagator only set the CA positions? - if (true || mode_ != ComputeAtMode::MostInlined) { + if (mode_ != ComputeAtMode::MostInlined) { TORCH_CHECK( to_pos >= 0, "Unable to propagate CA position from producer ",