diff --git a/build_variables.bzl b/build_variables.bzl
index 44d76dd9da723a..2e8ff13bed1afa 100644
--- a/build_variables.bzl
+++ b/build_variables.bzl
@@ -699,6 +699,7 @@ libtorch_cuda_core_sources = [
     "torch/csrc/jit/codegen/cuda/lower_warp_reduce.cpp",
     "torch/csrc/jit/codegen/cuda/lower2device.cpp",
     "torch/csrc/jit/codegen/cuda/manager.cpp",
+    "torch/csrc/jit/codegen/cuda/maxinfo_propagator.cpp",
     "torch/csrc/jit/codegen/cuda/mutator.cpp",
     "torch/csrc/jit/codegen/cuda/non_divisible_split.cpp",
     "torch/csrc/jit/codegen/cuda/ops/alias.cpp",
diff --git a/torch/csrc/jit/codegen/cuda/maxinfo_propagator.cpp b/torch/csrc/jit/codegen/cuda/maxinfo_propagator.cpp
new file mode 100644
index 00000000000000..46ffef6b3bfb05
--- /dev/null
+++ b/torch/csrc/jit/codegen/cuda/maxinfo_propagator.cpp
@@ -0,0 +1,311 @@
+#include
+#include
+
+namespace torch {
+namespace jit {
+namespace fuser {
+namespace cuda {
+
+bool MaxInfoPropagator::Information::operator>(const Information& r) const {
+  return r < *this;
+}
+
+bool MaxInfoPropagator::Information::operator==(const Information& r) const {
+  return !(r < *this) && !(*this < r);
+}
+
+// Dijkstra
+void MaxInfoPropagator::run() {
+  // A set that allows us to quickly tell if a tensor has been replayed. If
+  // yes, then we will not bother computing if a new path to this tensor is
+  // worth taking (because the answer is always "not worth it").
+  std::unordered_set<TensorView*> replayed;
+
+  // A sorted list of possible next steps. The list is sorted in the order of
+  // ascending amount of preserved information about the reference tensor. The
+  // back of the list preserves the most amount of information about the
+  // reference tensor, and should always be the next step to take. We use
+  // std::list instead of std::priority_queue because C++'s
+  // std::priority_queue does not support increase-key, and might not be
+  // deterministic either.
+  std::list<NextHopInfo> propagation(1);
+  propagation.back().from = nullptr;
+  propagation.back().to = reference;
+  propagation.back().info_to = reference_info;
+
+  // Insert the given next hop at the correct position in `propagation`. If
+  // there is an existing next hop that preserves more information, then we
+  // will just discard `info`.
+  auto insertNextHopInfo = [&](const NextHopInfo& info) {
+    if (!*(info.info_from)) {
+      // When there is no more information about the starting tensor,
+      // we are not interested in continuing the propagation.
+      return;
+    }
+    // Find if there is already a path to the dest tensor
+    auto existing = std::find_if(
+        propagation.begin(), propagation.end(), [&](const NextHopInfo& i) {
+          return i.to == info.to;
+        });
+    // Only insert if there is no existing path to the dest tensor, or the new
+    // path preserves more information about the starting tensor.
+    if (existing == propagation.end() || *existing < info) {
+      if (existing != propagation.end()) {
+        propagation.erase(existing);
+      }
+      auto pos = std::upper_bound(propagation.begin(), propagation.end(), info);
+      propagation.insert(pos, info);
+    }
+  };
+
+  while (!propagation.empty()) {
+    auto next_hop = propagation.back();
+    propagation.pop_back();
+
+    if (next_hop.from != nullptr) {
+      // nullptr used to start from reference
+      switch (next_hop.type) {
+        case NextHopType::C_AS_P:
+          propagateTvCasP(next_hop.from, next_hop.to);
+          break;
+        case NextHopType::P_AS_C:
+          propagateTvPasC(next_hop.from, next_hop.to);
+          break;
+      }
+    }
+    replayed.emplace(next_hop.to);
+
+    for (auto consumer_tv : ir_utils::consumerTvsOf(next_hop.to)) {
+      if (replayed.count(consumer_tv)) {
+        continue;
+      }
+      insertNextHopInfo(
+          {.type = NextHopType::C_AS_P,
+           .from = next_hop.to,
+           .to = consumer_tv,
+           .info_from = next_hop.info_to,
+           .info_to =
+               computeInfoCasP(next_hop.to, consumer_tv, next_hop.info_to)});
+    }
+
+    for (auto producer_tv : ir_utils::producerTvsOf(next_hop.to)) {
+      if (replayed.count(producer_tv)) {
+        continue;
+      }
+      insertNextHopInfo(
+          {.type = NextHopType::P_AS_C,
+           .from = next_hop.to,
+           .to = producer_tv,
+           .info_from = next_hop.info_to,
+           .info_to =
+               computeInfoPasC(next_hop.to, producer_tv, next_hop.info_to)});
+    }
+  }
+}
+
+MaxRootDomainInfoPropagator::RootDomainInfo::operator bool() const {
+  return !info.empty();
+}
+
+bool MaxRootDomainInfoPropagator::RootDomainInfo::operator<(
+    const MaxInfoPropagator::Information& r) const {
+  auto rr = dynamic_cast<const RootDomainInfo&>(r);
+  if (info.size() != rr.info.size()) {
+    return info.size() < rr.info.size();
+  }
+  size_t l_complete =
+      std::count_if(info.begin(), info.end(), [](const RootIDInfo& i) {
+        return i.is_complete;
+      });
+  size_t r_complete =
+      std::count_if(rr.info.begin(), rr.info.end(), [](const RootIDInfo& i) {
+        return i.is_complete;
+      });
+  return l_complete < r_complete;
+}
+
+namespace {
+
+// Given `root_ids`, a list of IDs in the root domain of `tv`, find their
+// corresponding IDs in the rfactor domain of `tv`.
+std::unordered_set<IterDomain*> mapRootToRFactor(
+    TensorView* tv,
+    const std::unordered_set<IterDomain*>& root_ids) {
+  std::unordered_set<IterDomain*> mapped_rfactor_ids;
+  const auto& rfactor_dom = tv->getMaybeRFactorDomain();
+  for (auto id : rfactor_dom) {
+    if (root_ids.count(id) > 0) {
+      mapped_rfactor_ids.emplace(id);
+      continue;
+    }
+    for (auto root_id : root_ids) {
+      if (id == root_id || DependencyCheck::isDependencyOf(root_id, id)) {
+        mapped_rfactor_ids.emplace(id);
+        break;
+      }
+    }
+  }
+  return mapped_rfactor_ids;
+}
+
+// Given `rfactor_ids`, a list of IDs in the rfactor domain of `tv`, find their
+// corresponding IDs in the root domain of `tv`.
+std::unordered_set<IterDomain*> mapRFactorToRoot(
+    TensorView* tv,
+    const std::unordered_set<IterDomain*>& rfactor_ids) {
+  std::unordered_set<IterDomain*> mapped_root_ids;
+  for (auto id : tv->getRootDomain()) {
+    if (rfactor_ids.count(id) > 0) {
+      mapped_root_ids.emplace(id);
+      continue;
+    }
+    for (auto rfactor_id : rfactor_ids) {
+      if (DependencyCheck::isDependencyOf(id, rfactor_id)) {
+        mapped_root_ids.emplace(id);
+        break;
+      }
+    }
+  }
+  return mapped_root_ids;
+}
+
+} // namespace
+
+// Given the preserved reference root ID info of a producer, compute
+// the corresponding info in consumer. The given info may be represented by
+// producer's root domain, or rfactor domain, depending on how we reached the
+// producer during propagation.
+// If the given info is already represented with producer's rfactor domain,
+// then we directly map it to the consumer's root domain. If the given info is
+// represented with producer's root domain, we need to first map it to the
+// rfactor domain of the producer, then we can map it to the consumer's root
+// domain. The computed info will be represented by the root domain, as the
+// root domain contains the raw information.
+std::shared_ptr<MaxInfoPropagator::Information> MaxRootDomainInfoPropagator::
+    computeInfoCasP(
+        TensorView* from,
+        TensorView* to,
+        std::shared_ptr<Information> from_info) {
+  RootDomainInfo result;
+
+  TensorView* producer = from;
+  TensorView* consumer = to;
+  const auto& producer_root_id_info =
+      std::dynamic_pointer_cast<RootDomainInfo>(from_info)->info;
+
+  auto pairwise_map = PairwiseRootDomainMap(producer, consumer);
+  auto p2c_map = pairwise_map.mapProducerToConsumer(
+      producer->domain(), consumer->domain());
+
+  for (auto& info : producer_root_id_info) {
+    RootIDInfo consumer_info;
+    consumer_info.is_complete = info.is_complete;
+    consumer_info.is_rfactor = false;
+
+    // mapped root ids in producer -> mapped rfactor ids in producer
+    std::unordered_set<IterDomain*> producer_mapped_rfactor_ids;
+    if (producer->hasRFactor() && !info.is_rfactor) {
+      producer_mapped_rfactor_ids = mapRootToRFactor(producer, info.mapped_ids);
+    } else {
+      producer_mapped_rfactor_ids = info.mapped_ids;
+    }
+
+    // mapped rfactor ids in producer -> mapped root ids in consumer
+    for (auto producer_id : producer_mapped_rfactor_ids) {
+      auto it = p2c_map.find(producer_id);
+      if (it != p2c_map.end()) {
+        consumer_info.mapped_ids.insert(it->second);
+      } else {
+        consumer_info.is_complete = false;
+      }
+    }
+
+    // If at least one root id in the consumer contains information
+    // of this starting root id, then keep this record
+    if (!consumer_info.mapped_ids.empty()) {
+      result.info.push_back(consumer_info);
+    }
+  }
+  return std::make_shared<RootDomainInfo>(std::move(result));
+}
+
+// Given the preserved reference root ID info of a consumer, compute
+// the corresponding info in producer. The given info may be represented by
+// consumer's root domain, or rfactor domain, depending on how we reached the
+// consumer during propagation. If the given info is already represented with
+// consumer's root domain, then we directly map it to the producer's rfactor
+// domain. If the given info is represented with consumer's rfactor domain, we
+// need to first map it to the root domain of the consumer, then we can map it
+// to the producer's rfactor domain. The computed info will be represented by
+// the rfactor domain, as the rfactor domain contains the raw information.
+std::shared_ptr<MaxInfoPropagator::Information> MaxRootDomainInfoPropagator::
+    computeInfoPasC(
+        TensorView* from,
+        TensorView* to,
+        std::shared_ptr<Information> from_info) {
+  RootDomainInfo result;
+
+  TensorView* producer = to;
+  TensorView* consumer = from;
+  const auto& consumer_root_id_info =
+      std::dynamic_pointer_cast<RootDomainInfo>(from_info)->info;
+
+  auto pairwise_map = PairwiseRootDomainMap(producer, consumer);
+  auto c2p_map = pairwise_map.mapConsumerToProducer(
+      consumer->domain(), producer->domain());
+
+  for (auto& info : consumer_root_id_info) {
+    RootIDInfo producer_info;
+    producer_info.is_complete = info.is_complete;
+    producer_info.is_rfactor = true;
+
+    // mapped rfactor ids in consumer -> mapped root ids in consumer
+    std::unordered_set<IterDomain*> consumer_mapped_root_ids;
+    if (info.is_rfactor && consumer->hasRFactor()) {
+      consumer_mapped_root_ids = mapRFactorToRoot(consumer, info.mapped_ids);
+    } else {
+      consumer_mapped_root_ids = info.mapped_ids;
+    }
+
+    // mapped root ids in consumer -> mapped rfactor ids in producer
+    for (auto consumer_id : consumer_mapped_root_ids) {
+      auto it = c2p_map.find(consumer_id);
+      if (it != c2p_map.end()) {
+        producer_info.mapped_ids.insert(it->second);
+      } else {
+        producer_info.is_complete = false;
+      }
+    }
+
+    // We will stop at the rfactor ids in producer, and will not further map
+    // them into root ids in producer. This means we only keep the unprocessed
+    // raw information of a tensor. This behavior is important to make sure
+    // that info is as accurate as possible throughout the propagation.
+    //
+    // For example, if we do a C->P->C' propagation, we want to do
+    //   C(root) -> P(rfactor) -> C'(root)
+    // instead of
+    //   C(root) -> P(rfactor) -> P(root) -> P(rfactor) -> C'(root)
+    //
+    // and the above two paths do lead to different results:
+    //
+    // For example, if you have a producer tensor with
+    //   root domain: [I1, I2]
+    //   rfactor domain: [I3, I5]
+    // where I3, I4 = split(I1), I5 = merge(I4, I2),
+    // then the P(rfactor) -> P(root) -> P(rfactor) path could lead to
+    //   P(rfactor: {I5}) -> P(root: {I1, I2}) -> P(rfactor: {I3, I5})
+    // which is not correct.
+
+    // If at least one root id in the producer contains information
+    // of this starting root id, then keep this record
+    if (!producer_info.mapped_ids.empty()) {
+      result.info.push_back(producer_info);
+    }
+  }
+  return std::make_shared<RootDomainInfo>(std::move(result));
+}
+
+} // namespace cuda
+} // namespace fuser
+} // namespace jit
+} // namespace torch
diff --git a/torch/csrc/jit/codegen/cuda/maxinfo_propagator.h b/torch/csrc/jit/codegen/cuda/maxinfo_propagator.h
new file mode 100644
index 00000000000000..aebca46a24bab8
--- /dev/null
+++ b/torch/csrc/jit/codegen/cuda/maxinfo_propagator.h
@@ -0,0 +1,173 @@
+#pragma once
+
+#include
+#include
+
+namespace torch {
+namespace jit {
+namespace fuser {
+namespace cuda {
+
+/*
+ * MaxInfoPropagator is a visitor for TensorViews. It starts from a reference
+ * tensor and the information about that reference tensor that we want to
+ * preserve. It walks the DAG using the Dijkstra algorithm from the reference
+ * tensor to other tensors in the graph. Each step in the propagation will be
+ * called with `propagateTvPasC` or `propagateTvCasP` in the order that the
+ * maximum amount of the given information is being preserved. Every tensor in
+ * the graph is visited only once.
+ *
+ * MaxInfoPropagator is an abstract class that has no idea about what
+ * propagation we want to do and what "information" means.
+ * In order to use this class, the user needs to specify the following things:
+ * - a subclass of `Information`: a class that stores the information about the
+ *   reference tensor. The subclass has to define `operator<`, which is used to
+ *   tell which path contains more information, and `operator bool`, which is
+ *   used to tell if there is any information stored.
+ * - propagateTvPasC, propagateTvCasP: the functions that modify the `to`
+ *   tensor according to the `from` tensor and its stored information.
+ * - computeInfoPasC, computeInfoCasP: the functions that compute the
+ *   information of the `to` tensor from the information of the `from` tensor.
+ */
+// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
+class TORCH_CUDA_CU_API MaxInfoPropagator {
+ protected:
+  struct Information {
+    // Returns true if there is any info about the root domain of the
+    // reference tensor, and false otherwise.
+    virtual operator bool() const = 0;
+    // l < r means l contains a smaller amount of information about the
+    // starting tensor than r.
+    virtual bool operator<(const Information& r) const = 0;
+    // l > r means l contains a bigger amount of information about the
+    // starting tensor than r.
+    bool operator>(const Information& r) const;
+    // l == r means it is hard to tell which one of them contains more
+    // information.
+    bool operator==(const Information& r) const;
+  };
+
+ private:
+  enum class NextHopType {
+    C_AS_P,
+    P_AS_C,
+  };
+
+  // This is a helper struct that contains all the information about the next
+  // step in the Dijkstra algorithm.
+  struct NextHopInfo {
+    NextHopType type;
+    TensorView* from = nullptr;
+    TensorView* to;
+
+    std::shared_ptr<Information> info_from;
+    std::shared_ptr<Information> info_to;
+
+    bool operator<(const NextHopInfo& r) const {
+      return *info_to < *(r.info_to);
+    }
+  };
+
+  TensorView* reference;
+  std::shared_ptr<Information> reference_info;
+
+ protected:
+  virtual void propagateTvPasC(TensorView* from, TensorView* to) = 0;
+  virtual void propagateTvCasP(TensorView* from, TensorView* to) = 0;
+  virtual std::shared_ptr<Information> computeInfoPasC(
+      TensorView* from,
+      TensorView* to,
+      std::shared_ptr<Information> from_info) = 0;
+  virtual std::shared_ptr<Information> computeInfoCasP(
+      TensorView* from,
+      TensorView* to,
+      std::shared_ptr<Information> from_info) = 0;
+
+ public:
+  MaxInfoPropagator(
+      TensorView* reference,
+      std::shared_ptr<Information> reference_info)
+      : reference(reference), reference_info(reference_info){};
+  void run();
+};
+
+// MaxRootDomainInfoPropagator is a MaxInfoPropagator that does propagation
+// along the path that preserves the most amount of root domain information
+// about the reference tensor.
+//
+// During the propagation, we explicitly keep track of the information about
+// which reference tensor's root ID's information is preserved, and to which
+// level. This information is stored as a vector of `RootIDInfo`, where each
+// item in the vector corresponds to one ID in the reference tensor's root
+// domain.
+class TORCH_CUDA_CU_API MaxRootDomainInfoPropagator
+    : public MaxInfoPropagator {
+ protected:
+  // This is a struct storing how the information about a root ID in the
+  // starting tensor is preserved during propagation. If during propagation we
+  // reached a tensor called the "current" tensor, we are interested in the
+  // following information:
+  // - Which reference tensor's root ID's information does the current tensor
+  //   contain?
+  //   Each RootIDInfo object should correspond to one reference tensor's root
+  //   ID, but we don't need to store this ID explicitly.
+  // - For this reference tensor's root ID, what are its corresponding IDs in
+  //   the current tensor's root/rfactor domain?
+  // - Is the current tensor's information about this reference tensor's root
+  //   ID complete?
+  struct RootIDInfo {
+    // Each object of this class corresponds to one root ID in the reference
+    // tensor, but we do not need to explicitly store this ID.
+
+    // The IDs in the current tensor's root or rfactor domain that contain
+    // information of the corresponding reference tensor's root ID. Whether we
+    // are using the root domain or the rfactor domain depends on how we
+    // reached the current tensor during propagation. `is_rfactor` tells us
+    // whether the IDs contained in `mapped_ids` are from the root domain or
+    // the rfactor domain.
+    std::unordered_set<IterDomain*> mapped_ids;
+
+    // Does `mapped_ids` contain all the IDs required to recompute the
+    // corresponding reference tensor's root ID? For example, if we have
+    //   t1 = input tensor of shape (20,)
+    //   t2 = view(t1, {4, 5})
+    //   t3 = sum(t2, {1})
+    //   t4 = set(t3)
+    // and we start the propagation from t1, then t2's and t3's information
+    // about t1 is complete, but t4's is not, because one axis is missing.
+    bool is_complete;
+
+    // Is `mapped_ids` from the root domain or the rfactor domain of the
+    // current tensor? We only store IDs from one of them, depending on how we
+    // reach the current tensor during propagation. If we reached the current
+    // tensor from a consumer, then `mapped_ids` contains IDs in the current
+    // tensor's rfactor domain, because the rfactor domain contains the raw
+    // information. If we reached the current tensor from a producer, then
+    // `mapped_ids` contains IDs in the current tensor's root domain, because
+    // the root domain contains the raw information.
+    bool is_rfactor;
+
+    RootIDInfo() = default;
+  };
+
+  struct RootDomainInfo : public Information {
+    std::vector<RootIDInfo> info;
+    operator bool() const override;
+    bool operator<(const Information& r) const override;
+  };
+
+  virtual std::shared_ptr<Information> computeInfoPasC(
+      TensorView* from,
+      TensorView* to,
+      std::shared_ptr<Information> from_info) override;
+  virtual std::shared_ptr<Information> computeInfoCasP(
+      TensorView* from,
+      TensorView* to,
+      std::shared_ptr<Information> from_info) override;
+
+ public:
+  using MaxInfoPropagator::MaxInfoPropagator;
+};
+
+} // namespace cuda
+} // namespace fuser
+} // namespace jit
+} // namespace torch
diff --git a/torch/csrc/jit/codegen/cuda/scheduler/pointwise.cpp b/torch/csrc/jit/codegen/cuda/scheduler/pointwise.cpp
index 5a482fc321115b..f60c97ff517bee 100644
--- a/torch/csrc/jit/codegen/cuda/scheduler/pointwise.cpp
+++ b/torch/csrc/jit/codegen/cuda/scheduler/pointwise.cpp
@@ -850,7 +850,7 @@ void schedulePointwise(Fusion* fusion, const PointwiseParams& params) {
     }
   }

-  TransformPropagator::from(reference_tv);
+  TransformPropagator(reference_tv).run();
   scheduler_utils::parallelizeAllLike(reference_tv, all_tvs);

   if (params.vectorize) {
diff --git a/torch/csrc/jit/codegen/cuda/scheduler/reduction_utils.cpp b/torch/csrc/jit/codegen/cuda/scheduler/reduction_utils.cpp
index f8a7e04a714309..93e4be71fef035 100644
--- a/torch/csrc/jit/codegen/cuda/scheduler/reduction_utils.cpp
+++ b/torch/csrc/jit/codegen/cuda/scheduler/reduction_utils.cpp
@@ -219,7 +219,7 @@ void multiReductionInliner(
     std::vector<TensorView*> reduction_tvs,
     std::vector<TensorView*> cached_inputs,
     std::vector<std::pair<TensorView*, TensorView*>> cached_outputs) {
-  TransformPropagator::from(reference_tv);
+  TransformPropagator(reference_tv).run();

   // Apply rfactor to all reductions if applicable
   std::vector<TensorView*> rfactor_tvs;
diff --git a/torch/csrc/jit/codegen/cuda/test/test_gpu.cpp b/torch/csrc/jit/codegen/cuda/test/test_gpu.cpp
index 01dc590099cdf0..d4ffb22447caa9 100644
--- a/torch/csrc/jit/codegen/cuda/test/test_gpu.cpp
+++ b/torch/csrc/jit/codegen/cuda/test/test_gpu.cpp
@@ -13795,7 +13795,7 @@ TEST_F(NVFuserTest, FusionSimpleVectorizeUnroll_CUDA) {
   tv3->reorder({{4, 2}});
   // [bidx, unswitch, vectorize{2}, unroll{2}, tidx]

-  TransformPropagator::from(tv3);
+  TransformPropagator(tv3).run();
   scheduler_utils::parallelizeAllLike(tv3, ir_utils::allTvs(&fusion));

   tv0_cache->axis(2)->parallelize(ParallelType::Vectorize);
@@ -14433,7 +14433,7 @@ TEST_F(NVFuserTest, FusionVectorizeMisalignedPointwiseMergeSymbolicPass_CUDA) {
   // Split inner-most dim
   tv2->split(-1, kNumElems);
   tv2->split(-1, kVecSize);
-  TransformPropagator::from(tv2);
+  TransformPropagator(tv2).run();

   c0->computeAt(tv2, -2);
   c1->computeAt(tv2, -2);
@@ -14495,7 +14495,7 @@ TEST_F(NVFuserTest, FusionVectorizeMisalignedPointwiseMergeSymbolicFail_CUDA) {
   }
   tv2->split(-1, kNumElems);
   tv2->split(-1, kVecSize);
-  TransformPropagator::from(tv2);
+  TransformPropagator(tv2).run();

   c0->computeAt(tv2, -2);
   c1->computeAt(tv2, -2);
@@ -15154,7 +15154,7 @@ TEST_F(NVFuserTest, FusionValidateParallelize6_CUDA) {
   tv4->split(0, 3);
   tv4->split(0, 2);
-  TransformPropagator::from(tv4);
+  TransformPropagator(tv4).run();

   tv0->computeAt(tv2, 2);
   tv3->computeAt(tv4, 2);
@@ -16524,7 +16524,7 @@ TEST_F(NVFuserTest, FusionSimpleWarp_CUDA) {
   tv1->split(1, 32);
   auto tv1_rf = tv1->rFactor({1});

-  TransformPropagator::from(tv1_rf);
+  TransformPropagator(tv1_rf).run();
   tv1_rf->axis(-1)->parallelize(ParallelType::TIDx);
   tv1->axis(0)->parallelize(ParallelType::BIDx);
   tv1->axis(-1)->parallelize(ParallelType::TIDx);
@@ -16570,7 +16570,7 @@ TEST_F(NVFuserTest,
FusionSimpleWarpPad_CUDA) { tv1_rf->axis(-1)->padToMultipleOfWarp(32); tv1->axis(-1)->parallelize(ParallelType::TIDx); tv1->axis(-1)->padToMultipleOfWarp(32); - TransformPropagator::from(tv1_rf); + TransformPropagator(tv1_rf).run(); tv0->axis(-1)->parallelize(ParallelType::TIDx); tv0->axis(-1)->padToMultipleOfWarp(32); tv0_cache->axis(-1)->parallelize(ParallelType::TIDx); @@ -16618,7 +16618,7 @@ TEST_F(NVFuserTest, FusionWarpPadMergeSplit_CUDA) { tv1_rf->axis(-1)->parallelize(ParallelType::TIDx); tv1->axis(-1)->parallelize(ParallelType::TIDx); tv1->axis(-1)->padToMultipleOfWarp(); - TransformPropagator::from(tv1_rf); + TransformPropagator(tv1_rf).run(); tv0->axis(-1)->parallelize(ParallelType::TIDx); tv0_cache->axis(-1)->parallelize(ParallelType::TIDx); tv2->axis(-1)->parallelize(ParallelType::TIDx); @@ -16659,7 +16659,7 @@ TEST_F(NVFuserTest, FusionSerialWarpReduction_CUDA) { tv1->axis(-1)->parallelize(ParallelType::TIDx); tv1->axis(-1)->padToMultipleOfWarp(); - TransformPropagator::from(tv1); + TransformPropagator(tv1).run(); tv0->axis(-1)->parallelize(ParallelType::TIDx); tv0_cache->axis(-1)->parallelize(ParallelType::TIDx); tv2->axis(-1)->parallelize(ParallelType::TIDx); @@ -16703,7 +16703,7 @@ TEST_F(NVFuserTest, FusionTrivialWarpReduction_CUDA) { tv1_rf->axis(-2)->parallelize(ParallelType::TIDx); tv1->axis(-2)->parallelize(ParallelType::TIDx); tv1->axis(-2)->padToMultipleOfWarp(); - TransformPropagator::from(tv1_rf); + TransformPropagator(tv1_rf).run(); tv0->axis(-2)->parallelize(ParallelType::TIDx); tv0_cache->axis(-2)->parallelize(ParallelType::TIDx); tv2->axis(-2)->parallelize(ParallelType::TIDx); @@ -16750,7 +16750,7 @@ TEST_F(NVFuserTest, FusionMultipleDimBinding_CUDA) { tv1_rf->axis(-1)->padToMultipleOfWarp(32); tv1->axis(-1)->parallelize(ParallelType::TIDx); tv1->axis(-1)->padToMultipleOfWarp(32); - TransformPropagator::from(tv1_rf); + TransformPropagator(tv1_rf).run(); tv0->axis(-1)->parallelize(ParallelType::TIDx); tv0->axis(-1)->padToMultipleOfWarp(32); tv0_cache->axis(-1)->parallelize(ParallelType::TIDx); @@ -16832,7 +16832,7 @@ TEST_F(NVFuserTest, FusionWarpMutipleThreadDim_CUDA) { tv2_rf->axis(-1)->parallelize(ParallelType::TIDx); tv2_rf->axis(-1)->padToMultipleOfWarp(); - TransformPropagator::from(tv2_rf); + TransformPropagator(tv2_rf).run(); tv0->axis(-1)->parallelize(ParallelType::TIDx); tv1->axis(-1)->parallelize(ParallelType::TIDx); @@ -16878,7 +16878,7 @@ TEST_F(NVFuserTest, FusionWarpReduceUnrollOuterLoop_CUDA) { tv1->axis(-1)->parallelize(ParallelType::TIDx); tv1->axis(-1)->padToMultipleOfWarp(); tv1->axis(1)->parallelize(ParallelType::Unroll); - TransformPropagator::from(tv1_rf); + TransformPropagator(tv1_rf).run(); tv0->axis(-1)->parallelize(ParallelType::TIDx); tv0->axis(1)->parallelize(ParallelType::Unroll); tv0_cache->axis(-1)->parallelize(ParallelType::TIDx); @@ -17080,7 +17080,7 @@ TEST_F(NVFuserTest, FusionPredicateElimination3_CUDA) { tv1->split(0, 10); tv1->split(0, 33); - TransformPropagator::from(tv1); + TransformPropagator(tv1).run(); auto tv4 = tv1->rFactor({-1}); auto tv5 = tv1->rFactor({-1}); @@ -17133,7 +17133,7 @@ TEST_F(NVFuserTest, FusionPredicateElimination4_CUDA) { tv1->split(1, 7); tv1->split(0, 11); tv1->reorder({{1, 2}, {2, 1}}); - TransformPropagator::from(tv1); + TransformPropagator(tv1).run(); tv1->axis(0)->parallelize(ParallelType::TIDy); tv1->axis(1)->parallelize(ParallelType::TIDx); @@ -17181,7 +17181,7 @@ TEST_F(NVFuserTest, FusionPredicateElimination5_CUDA) { fusion.addOutput(tv3); tvs2.avg->split(0, 4); - 
TransformPropagator::from(tvs2.avg); + TransformPropagator(tvs2.avg).run(); auto rtvs2 = tvs2.rFactor({1}); rtvs2.avg->axis(0)->parallelize(ParallelType::TIDx); @@ -17225,7 +17225,7 @@ TEST_F(NVFuserTest, FusionPredicateElimination6_CUDA) { fusion.addOutput(tv4); tv4->split(1, 5); - TransformPropagator::from(tv4); + TransformPropagator(tv4).run(); tv4->reorder({{0, 1}, {1, 0}}); tv3->computeAt(tv4, 1); @@ -17272,7 +17272,7 @@ TEST_F(NVFuserTest, FusionPredicateElimination7_CUDA) { tv3->split(-1, 5); tv3->split(-1, 4); tv3->split(-1, 3); - TransformPropagator::from(tv3); + TransformPropagator(tv3).run(); tv0->computeAt(tv3, 1); @@ -18909,7 +18909,7 @@ TEST_F(NVFuserTest, FusionFloatPow_CUDA) { tv1->axis(0)->parallelize(ParallelType::BIDx); tv1->axis(1)->parallelize(ParallelType::TIDx); - TransformPropagator::from(tv1); + TransformPropagator(tv1).run(); scheduler_utils::parallelizeAllLike(tv1, {tv2, tv3, tv4, tv5, tv6}); auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); @@ -20152,7 +20152,7 @@ TEST_F(NVFuserTest, FusionNonDivisibleSplitVectorize2_CUDA) { tv3->split(0, 8, false); tv3->split(1, 4); - TransformPropagator::from(tv3); + TransformPropagator(tv3).run(); tv3->axis(1)->parallelize(ParallelType::TIDx); scheduler_utils::parallelizeAllLike(tv3, {tv1, tv2}); @@ -20314,7 +20314,7 @@ TEST_F(NVFuserTest, FusionDoubleBuffering1_CUDA) { tv3->split(-1, 128); tv3->split(-1, 32); - TransformPropagator::from(tv3); + TransformPropagator(tv3).run(); tv0->computeAt(tv3, 1); @@ -20351,7 +20351,7 @@ TEST_F(NVFuserTest, FusionDoubleBuffering2_CUDA) { tv3->split(-1, 128); tv3->split(-1, 32); - TransformPropagator::from(tv3); + TransformPropagator(tv3).run(); tv0->computeAt(tv3, -1); @@ -20390,7 +20390,7 @@ TEST_F(NVFuserTest, FusionDoubleBuffering3_CUDA) { tv3->split(-1, 128); tv3->split(-1, 32); - TransformPropagator::from(tv3); + TransformPropagator(tv3).run(); tv0->computeAt(tv3, 1); @@ -20438,7 +20438,7 @@ TEST_F(NVFuserTest, FusionDoubleBuffering4_CUDA) { tv3->split(-1, 128); tv3->split(-1, 32); tv3->split(-1, 8); - TransformPropagator::from(tv3); + TransformPropagator(tv3).run(); tv0->computeAt(tv3, 2); tv2->computeAt(tv3, -1); @@ -20479,7 +20479,7 @@ TEST_F(NVFuserTest, FusionDoubleBuffering5_CUDA) { tv2->split(-1, 128); tv2->split(-1, 32); tv2->split(-1, 8); - TransformPropagator::from(tv2); + TransformPropagator(tv2).run(); tv0->computeAt(tv2, 2); tv1->computeAt(tv2, -1); @@ -20522,7 +20522,7 @@ TEST_F(NVFuserTest, FusionDoubleBuffering6_CUDA) { tv3->split(-1, 16); tv3->split(-2, 4); tv3->split(-2, 2); - TransformPropagator::from(tv3); + TransformPropagator(tv3).run(); tv0->computeAt(tv3, 1); tv2->computeAt(tv3, -1); @@ -20559,7 +20559,7 @@ TEST_F(NVFuserTest, FusionDoubleBuffering7_CUDA) { tv2->split(-1, 128); tv2->split(-1, 4); - TransformPropagator::from(tv2); + TransformPropagator(tv2).run(); tv1->computeAt(tv2, 2); @@ -20599,7 +20599,7 @@ TEST_F(NVFuserTest, FusionDoubleBuffering8_CUDA) { tv4->split(0, 32); tv4->split(0, 4); - TransformPropagator::from(tv4); + TransformPropagator(tv4).run(); tv0->computeAt(tv4, 1); tv1->computeAt(tv4, 1); @@ -20640,7 +20640,7 @@ TEST_F(NVFuserTest, FusionDoubleBuffering9_CUDA) { out->split(0, 32); out->split(0, 4); - TransformPropagator::from(out); + TransformPropagator(out).run(); tv2->setMemoryType(MemoryType::Shared); @@ -20708,7 +20708,7 @@ TEST_F(NVFuserTest, FusionSmemBlockGemmCacheDoubleBuffer_CUDA) { auto tv6_rf = tv6->rFactor({-1}); - TransformPropagator::from(tv6_rf); + TransformPropagator(tv6_rf).run(); 
tv0->computeAt(tv6, 3); tv1->computeAt(tv6, 3); @@ -20774,7 +20774,7 @@ TEST_F(NVFuserTest, FusionIntermediateTensorVectorize_CUDA) { tv1->setMemoryType(mem_type); tv3->split(-1, 4); - TransformPropagator::from(tv3); + TransformPropagator(tv3).run(); tv1->computeAt(tv3, -2); @@ -21423,7 +21423,7 @@ TEST_F(NVFuserTest, FusionIndexHoist2_CUDA) { fusion.addOutput(tv5); tv5->split(-1, 4); - TransformPropagator::from(tv5); + TransformPropagator(tv5).run(); tv4->split(-1, 3); @@ -21515,7 +21515,7 @@ TEST_F(NVFuserTest, FusionTestGridComm2_CUDA) { tv4->merge(0); tv4->split(0, 2); - TransformPropagator::from(tv4); + TransformPropagator(tv4).run(); tv3->computeAt(tv4, 1); @@ -21881,7 +21881,7 @@ TEST_F(NVFuserTest, FusionContigIndexingWithBroadcast_CUDA) { fusion.addOutput(tv3); tv3->merge(0); - TransformPropagator::from(tv3); + TransformPropagator(tv3).run(); tv2->setMemoryType(MemoryType::Local); @@ -21930,7 +21930,7 @@ TEST_F(NVFuserTest, FusionVectorizeContigIndexValidationFail2_CUDA) { tv4->merge(1, 2); tv4->merge(0, 1); tv4->split(0, 4); - TransformPropagator::from(tv4); + TransformPropagator(tv4).run(); tv0->computeAt(tv4, -2); tv1->computeAt(tv4, -2); @@ -21975,7 +21975,7 @@ TEST_F(NVFuserTest, FusionVectorizeContigIndexWithBroadcast_CUDA) { // Don't modify tv1 so that it's replayed as tv2 with actual // transformations. It would create temporary IterDomains, and the // validation should still be able to detect vectorization by 4 is valid. - // TransformPropagator::from(tv3); + // TransformPropagator(tv3).run(); tv2->merge(1, 2); tv2->merge(0, 1); tv2->split(0, 4); @@ -22058,7 +22058,7 @@ TEST_F(NVFuserTest, FusionTrivialReductionForwarding1_CUDA) { tv2->merge(0); tv2->split(0, 4); - TransformPropagator::from(tv2); + TransformPropagator(tv2).run(); // All tensors must be transformed to a 2D tensor with each axis // mapped with each other in the LOOP map. 
@@ -22820,7 +22820,7 @@ TEST_F(NVFuserTest, FusionPropagateParallelTypesToSiblings_CUDA) { fusion.addOutput(tv_avg); tv_avg->split(0, 128); - TransformPropagator::from(tv_avg); + TransformPropagator(tv_avg).run(); tv_avg->axis(0)->parallelize(ParallelType::BIDx); tv_avg->axis(1)->parallelize(ParallelType::TIDx); @@ -23047,7 +23047,7 @@ TEST_F(NVFuserTest, FusionIncompleteConcreteID_CUDA) { tv6->merge(0); tv6->merge(0); - TransformPropagator::from(tv6); + TransformPropagator(tv6).run(); tv0->computeAt(tv6, -1, ComputeAtMode::MostInlined); tv1->computeAt(tv6, -1, ComputeAtMode::MostInlined); @@ -23108,7 +23108,7 @@ TEST_F(NVFuserTest, FusionTestReEntrantGridWelford_CUDA) { // T2_g[iblockIdx.x, ithreadIdx.x24, rblockIdx.y, rthreadIdx.y, rS{16}, // iV25{4}] - TransformPropagator::from(reduction_tv); + TransformPropagator(reduction_tv).run(); auto rfactor_tv = ir_utils::rfactorHelper(reduction_tv, {4}); scheduler_utils::parallelizeAllLike(rfactor_tv, ir_utils::allTvs(&fusion)); @@ -23597,7 +23597,7 @@ TEST_F(NVFuserTest, FusionTransformPropagateSibling_CUDA) { auto tvs2 = tvs.rFactor({1, 4}); - TransformPropagator::from(tvs2.var_sum); + TransformPropagator(tvs2.var_sum).run(); // check that the resulting tensors in tvs2 are identical auto checkSiblingConsistency = [](TensorView* replay, TensorView* target) { @@ -23659,7 +23659,7 @@ TEST_F(NVFuserTest, FusionTransformPropagatePosition_CUDA) { tv0->merge(2); tv0->merge(0); - TransformPropagator::from(tv0); + TransformPropagator(tv0).run(); TORCH_CHECK(tv1->nDims() == 4); } diff --git a/torch/csrc/jit/codegen/cuda/test/test_gpu_fused_reduction.cpp b/torch/csrc/jit/codegen/cuda/test/test_gpu_fused_reduction.cpp index f64d58302daf27..1d47bc48beaf65 100644 --- a/torch/csrc/jit/codegen/cuda/test/test_gpu_fused_reduction.cpp +++ b/torch/csrc/jit/codegen/cuda/test/test_gpu_fused_reduction.cpp @@ -135,7 +135,7 @@ TEST_F(NVFuserTest, FusionGridAllreduce1_CUDA) { fusion.addOutput(tv3); tv3->split(0, tidx); - TransformPropagator::from(tv3); + TransformPropagator(tv3).run(); tv3->axis(0)->parallelize(ParallelType::BIDx); tv3->axis(1)->parallelize(ParallelType::TIDx); @@ -178,7 +178,7 @@ TEST_F(NVFuserTest, FusionGridAllreduce2_CUDA) { fusion.addOutput(tv3); tv3->split(0, tidx); - TransformPropagator::from(tv3); + TransformPropagator(tv3).run(); tv3->axis(0)->parallelize(ParallelType::BIDx); tv3->axis(1)->parallelize(ParallelType::TIDx); @@ -230,7 +230,7 @@ TEST_F(NVFuserTest, FusionGridAllreduce3_CUDA) { fusion.addOutput(tv3); tv3->split(1, tidx); - TransformPropagator::from(tv3); + TransformPropagator(tv3).run(); tv0->computeAt(tv3, 1); @@ -277,7 +277,7 @@ TEST_F(NVFuserTest, FusionGridAllreduce4_CUDA) { fusion.addOutput(tv4); tv4->split(0, tidx); - TransformPropagator::from(tv4); + TransformPropagator(tv4).run(); tv4->axis(0)->parallelize(ParallelType::BIDx); tv4->axis(1)->parallelize(ParallelType::TIDx); @@ -335,7 +335,7 @@ TEST_F(NVFuserTest, FusionGridAllreduce5_CUDA) { // Setup the reduction tv4->split(1, tidx); - TransformPropagator::from(tv4); + TransformPropagator(tv4).run(); tv4->axis(1)->parallelize(ParallelType::BIDx); tv4->axis(2)->parallelize(ParallelType::TIDx); @@ -389,7 +389,7 @@ TEST_F(NVFuserTest, FusionGridAllreduce6_CUDA) { tv1->split(1, vec); tv1->split(1, tidx); tv1->split(0, tidy); - TransformPropagator::from(tv1); + TransformPropagator(tv1).run(); tv1->axis(0)->parallelize(ParallelType::BIDy); tv1->axis(1)->parallelize(ParallelType::TIDy); @@ -437,7 +437,7 @@ TEST_F(NVFuserTest, FusionGridAllreduceWelford1_CUDA) { 
fusion.addOutput(tv5); tv5->split(0, tidx); - TransformPropagator::from(tv5); + TransformPropagator(tv5).run(); tv5->axis(0)->parallelize(ParallelType::BIDx); tv5->axis(1)->parallelize(ParallelType::TIDx); @@ -484,7 +484,7 @@ TEST_F(NVFuserTest, FusionGridAllreduceWelford2_CUDA) { fusion.addOutput(tv3); tv3->split(1, tidx); - TransformPropagator::from(tv3); + TransformPropagator(tv3).run(); tv0->computeAt(tv3, 1); @@ -591,7 +591,7 @@ TEST_F(NVFuserTest, FusionFusedReductionBatchnorm_CUDA) { {9, 8}, {6, 9}}); - TransformPropagator::from(tv0); + TransformPropagator(tv0).run(); auto tvs_rf = tvs.rFactor({-5, -4, -3, -2, -1}); @@ -716,7 +716,7 @@ TEST_F(NVFuserTest, FusionGroupedReduction2_CUDA) { groupReductions({tv2, tv4}); tv2->split(1, 128); - TransformPropagator::from(tv2); + TransformPropagator(tv2).run(); tv0->computeAt(tv4, -1, ComputeAtMode::MostInlined); @@ -759,7 +759,7 @@ TEST_F(NVFuserTest, FusionGroupedReduction3_CUDA) { groupReductions({tv1, tv3}); tv1->split(1, 128); - TransformPropagator::from(tv1); + TransformPropagator(tv1).run(); tv0->computeAt(tv5, -1, ComputeAtMode::MostInlined); @@ -1044,7 +1044,7 @@ TEST_F(NVFuserTest, FusionGroupAllreduce1_CUDA) { groupReductions({tv1, tv3}); tv2->split(0, 128); - TransformPropagator::from(tv2); + TransformPropagator(tv2).run(); tv2->axis(0)->parallelize(ParallelType::BIDx); tv2->axis(1)->parallelize(ParallelType::TIDx); @@ -1089,7 +1089,7 @@ TEST_F(NVFuserTest, FusionGroupAllreduce2_CUDA) { const int tidx = 512; groupReductions({tv1, tv4}); tv1->split(1, tidx); - TransformPropagator::from(tv1); + TransformPropagator(tv1).run(); tv0->computeAt(tv8, -1, ComputeAtMode::MostInlined); @@ -1143,7 +1143,7 @@ TEST_F(NVFuserTest, FusionGroupAllreduce3_CUDA) { groupReductions({tv1, tv4, tv7}); tv1->split(0, 128); - TransformPropagator::from(tv1); + TransformPropagator(tv1).run(); tv1->axis(0)->parallelize(ParallelType::BIDx); tv1->axis(1)->parallelize(ParallelType::TIDx); @@ -1194,7 +1194,7 @@ TEST_F(NVFuserTest, FusionGroupAllreduce4_CUDA) { auto reduction_tv = reduction_tvs.at(0); reduction_tv->split(0, 128); - TransformPropagator::from(reduction_tv); + TransformPropagator(reduction_tv).run(); reduction_tv->axis(0)->parallelize(ParallelType::BIDx); reduction_tv->axis(1)->parallelize(ParallelType::TIDx); @@ -1253,7 +1253,7 @@ TEST_F(NVFuserTest, FusionGroupAllreduce5_CUDA) { groupReductions({tv1, tv5, tv9}); tv1->split(0, 128); - TransformPropagator::from(tv1); + TransformPropagator(tv1).run(); tv1->axis(0)->parallelize(ParallelType::BIDx); tv1->axis(1)->parallelize(ParallelType::TIDx); @@ -1409,7 +1409,7 @@ TEST_F(NVFuserTest, FusionPersistentBNBackwardAllreduce_CUDA) { grad_input->axis(3)->parallelize(ParallelType::BIDy); grad_input->axis(4)->parallelize(ParallelType::TIDx); - TransformPropagator::from(grad_input); + TransformPropagator(grad_input).run(); auto rf_tensors = grad_output_sum->rFactor( {-1}, std::vector({grad_output_sum, dot_p})); @@ -1522,7 +1522,7 @@ TEST_F(NVFuserTest, FusionGroupedReductionReEntrant1_CUDA) { tv2->split(1, tidx); tv2->split(0, tidy); - TransformPropagator::from(tv2); + TransformPropagator(tv2).run(); tv0_cache->axis(-1)->parallelize(ParallelType::Vectorize); @@ -1623,7 +1623,7 @@ TEST_F(NVFuserTest, FusionGroupedReductionChannelsLastBatchNormLike_CUDA) { // Move the serial reduction to the right of the vector axis ref->reorder({{3, 4}, {4, 3}}); - TransformPropagator::from(ref); + TransformPropagator(ref).run(); auto rf_tvs = tv5->rFactor({-2}, {tv5, tv9}); auto tv5_rf = rf_tvs.at(0); @@ -1751,7 +1751,7 @@ 
TEST_F( // Move the serial reduction to the right of the vector axis ref->reorder({{3, 4}, {4, 3}}); - TransformPropagator::from(ref); + TransformPropagator(ref).run(); auto rf_tvs = tv5->rFactor({-2}, {tv5, tv9}); auto tv5_rf = rf_tvs.at(0); diff --git a/torch/csrc/jit/codegen/cuda/test/test_gpu_shift.cpp b/torch/csrc/jit/codegen/cuda/test/test_gpu_shift.cpp index 8c45bb37bbeb81..c8de3545802cce 100644 --- a/torch/csrc/jit/codegen/cuda/test/test_gpu_shift.cpp +++ b/torch/csrc/jit/codegen/cuda/test/test_gpu_shift.cpp @@ -1264,7 +1264,7 @@ TEST_F(NVFuserTest, FusionShiftDoubleSplitMerge2_CUDA) { out->merge(2, 3); out->merge(0, 1); - TransformPropagator::from(out); + TransformPropagator(out).run(); tv0->computeAt(out, 1); @@ -2324,7 +2324,7 @@ TEST_F(NVFuserTest, FusionHdiffPartialSplitUnswitch_CUDA) { out->reorder({{1, 3}, {2, 1}, {3, 4}, {4, 2}}); // out: [NZ/tz, NY/by, NX/bx, tz, by, bx] - TransformPropagator::from(out); + TransformPropagator(out).run(); inp->computeAt(out, 4); @@ -2720,7 +2720,7 @@ TEST_F(NVFuserTest, FusionGather6_CUDA) { out->split(0, block_y); out->reorder({{1, 2}, {2, 1}}); - TransformPropagator::from(out); + TransformPropagator(out).run(); tv0->computeAt(out, 2); @@ -2779,7 +2779,7 @@ TEST_F(NVFuserTest, FusionGather7_CUDA) { out->split(0, block_y); out->reorder({{1, 2}, {2, 1}}); - TransformPropagator::from(out); + TransformPropagator(out).run(); tv0->computeAt(out, 2); @@ -2879,7 +2879,7 @@ TEST_F(NVFuserTest, FusionGather9_CUDA) { out->split(0, block_y); out->reorder({{1, 2}, {2, 1}}); - TransformPropagator::from(out); + TransformPropagator(out).run(); tv0->computeAt(out, 2); @@ -3804,7 +3804,7 @@ TEST_F(NVFuserTest, FusionShiftNoPadding1_CUDA) { tv5->split(-1, 8); tv5->reorder({{1, 2}}); - TransformPropagator::from(tv5); + TransformPropagator(tv5).run(); tv2->computeAt(tv5, -1); tv3->computeAt(tv5, -1); @@ -3860,7 +3860,7 @@ TEST_F(NVFuserTest, FusionShiftNoPadding2_CUDA) { tv5->reorder({{1, 2}}); tv5->merge(-2, -1); - TransformPropagator::from(tv5); + TransformPropagator(tv5).run(); tv2->computeAt(tv5, -1); tv3->computeAt(tv5, -1); @@ -3920,7 +3920,7 @@ TEST_F(NVFuserTest, FusionShiftNoPadding3_CUDA) { tv_avg->reorder({{1, 2}}); tv_avg->merge(-2, -1); - TransformPropagator::from(tv_avg); + TransformPropagator(tv_avg).run(); tv2->computeAt(tv_avg, -1); tv3->computeAt(tv_avg, -1); @@ -4106,7 +4106,7 @@ TEST_F(NVFuserTest, FusionShiftPadding1_CUDA) { tv5->split(-1, 8); tv5->reorder({{1, 2}}); - TransformPropagator::from(tv5); + TransformPropagator(tv5).run(); tv2->computeAt(tv5, -1); tv3->computeAt(tv5, -1); @@ -5314,7 +5314,7 @@ TEST_F(NVFuserTest, FusionGather9ptStencilDoubleBuffering_CUDA) { out->split(-2, 4); out->split(-1, 32); out->reorder({{1, 2}, {2, 1}}); - TransformPropagator::from(out); + TransformPropagator(out).run(); tv0->computeAt(out, 2); @@ -5363,7 +5363,7 @@ TEST_F(NVFuserTest, FusionValidateParallelizeShift_CUDA) { tv5->split(-1, 1024); tv5->split(-1, 2); - TransformPropagator::from(tv5); + TransformPropagator(tv5).run(); tv0->computeAt(tv5, 1); diff --git a/torch/csrc/jit/codegen/cuda/transform_replay.cpp b/torch/csrc/jit/codegen/cuda/transform_replay.cpp index 58b6a74ea1010a..f3bffdbea8a4aa 100644 --- a/torch/csrc/jit/codegen/cuda/transform_replay.cpp +++ b/torch/csrc/jit/codegen/cuda/transform_replay.cpp @@ -8,6 +8,7 @@ #include #include #include +#include #include #include @@ -645,495 +646,35 @@ std::pair TransformReplay::replayCasP( return replayCasP(consumer, producer, compute_at_axis, root_map); } -namespace { - -std::deque 
deduplicate(const std::deque& tv_deuqe) { - std::deque deduplicated; - std::unordered_set inserted; - for (auto tv_entry : tv_deuqe) { - if (inserted.find(tv_entry) == inserted.end()) { - deduplicated.emplace_back(tv_entry); - inserted.emplace(tv_entry); - } - } - return deduplicated; -} - -std::deque tvInputs(Expr* expr) { - auto tv_inputs = ir_utils::filterByType(expr->inputs()); - return std::deque(tv_inputs.begin(), tv_inputs.end()); -} - -std::deque tvOutputs(Expr* expr) { - auto tv_outputs = ir_utils::filterByType(expr->outputs()); - return std::deque(tv_outputs.begin(), tv_outputs.end()); -} - -std::deque consumersOf(TensorView* tv) { - std::deque consumer_tvs; - for (auto def : tv->uses()) { - auto outs = tvOutputs(def); - consumer_tvs.insert(consumer_tvs.end(), outs.begin(), outs.end()); - } - return deduplicate(consumer_tvs); -} - -std::deque producersFor(TensorView* tv) { - auto def = tv->definition(); - if (def == nullptr) { - return {}; - } - - return deduplicate(tvInputs(def)); -} - -// This is a struct storing how the information about a root ID in the -// starting tensor is preserved during propagation. If during propagation, we -// reached a tensor called the "current" tensor, we are interested in the -// following information: -// - Which reference tensor's root ID's information does the current tensor -// contains? Each RootIDInfo object should correspond to one reference -// tensor's root ID, but we don't need to store this ID explicitly. -// - For this reference tensor's root ID, what are its corresponding IDs in -// the current tensor's root/rfactor domain? -// - Is the current tensor's information about this reference tensor's root ID -// complete? -struct RootIDInfo { - // Each object of this class correspond to one root ID in the reference - // tensor, but we do not need to explicitly store this ID. - - // The IDs in the current tensor's root or rfactor domain that contains - // information of the corresponding reference tensor's root ID. Whether we - // are using root domain or rfactor domain depends on how we reached the - // current tensor during propagation. `is_rfactor` tells us whether the IDs - // contained in `mapped_ids` are from the root domain or the rfactor domain. - std::unordered_set mapped_ids; - - // Does `mapped_ids` contain all the IDs required to recompute the - // corresponding reference tensor's root ID? For example, if we have - // t1 = input tensor of shape (20,) - // t2 = view(t1, {4, 5}) - // t3 = sum(t2, {1}) - // t4 = set(t3) - // and we start the propagation from t1, then t2 and t3's information about - // t1 is complete, but t4 is not because one axis is missing. - bool is_complete; - - // Is `mapped_ids` from the root domain or rfactor domain of the current - // tensor? We only store IDs from one of them, depending on how we reach the - // current tensor during propagation. If we reached the current tensor from - // a consumer, then `mapped_ids` containes IDs in the current tensor's - // rfactor domain because the rfactor domain contains raw information. If we - // reached the current tensor from a producer, then `mapped_ids` containes - // IDs in the current tensor's root domain because the root domain contains - // raw information. - bool is_rfactor; - - RootIDInfo() = default; - - // This constructor is only used on the reference tensor where the - // propagation starts, so the mapped_ids are just the starting_root_id. 
- RootIDInfo(IterDomain* starting_root_id) - : mapped_ids{starting_root_id}, is_complete(true), is_rfactor(false) {} -}; - -enum class NextHopType { - C_AS_P, - P_AS_C, -}; - -// This is a helper struct that contains all the information about the next -// step in the Dijkstra algorithm -struct NextHopInfo { - NextHopType type; - TensorView* from = nullptr; - TensorView* to; - - std::vector root_id_info_from; - std::vector root_id_info_to; -}; - -// l < r means l contains a smaller amount of information about the starting -// tensor than r. -bool operator<(const NextHopInfo& l, const NextHopInfo& r) { - if (l.root_id_info_to.size() != r.root_id_info_to.size()) { - return l.root_id_info_to.size() < r.root_id_info_to.size(); - } - size_t l_complete = std::count_if( - l.root_id_info_to.begin(), - l.root_id_info_to.end(), - [](const RootIDInfo& i) { return i.is_complete; }); - size_t r_complete = std::count_if( - r.root_id_info_to.begin(), - r.root_id_info_to.end(), - [](const RootIDInfo& i) { return i.is_complete; }); - return l_complete < r_complete; -} - -// l > r means l contains a bigger amount of information about the starting -// tensor than r. -bool operator>(const NextHopInfo& l, const NextHopInfo& r) { - return r < l; -} - -// l == r means it is hard to tell which one of then contains more information -bool operator==(const NextHopInfo& l, const NextHopInfo& r) { - return !(r < l) && !(l < r); -} - -std::vector getStartingRootIDInfo(TensorView* tv) { - std::vector result; +std::shared_ptr TransformPropagator:: + getStartingRootIDInfo(TensorView* tv) { + RootDomainInfo result; const auto& root_domain = tv->getRootDomain(); - result.reserve(root_domain.size()); + result.info.reserve(root_domain.size()); for (auto id : root_domain) { - result.emplace_back(id); - } - return result; -} - -// Infer the compute-at position from the information of preserved reference -// root ID. -// -// TODO: -// I think I need to modify TransformReplay to add a new interface to specify -// the root domains, instead of a position in the leaf domain. With the new -// interface, this function will not be needed. -size_t getReplayPos(const NextHopInfo& next_hop) { - auto& root_id_info = next_hop.root_id_info_from; - auto from_tv = next_hop.from; - // Flatten `root_id_info_from` to get the list of ids in the `from_tv`'s - // root/rfactor domain that contains information about the reference tensor. - std::unordered_set from_ids; - from_ids.reserve(root_id_info.size()); - for (auto info : root_id_info) { - for (auto id : info.mapped_ids) { - from_ids.insert(id); - } - } - // Get leaf IDs that contain information of `from_ids` - std::unordered_set relevant_leaves; - std::vector to_visit(from_ids.begin(), from_ids.end()); - while (!to_visit.empty()) { - auto front = to_visit.back(); - to_visit.pop_back(); - if (front->uses().empty()) { - relevant_leaves.emplace(front); - } else { - for (auto def : front->uses()) { - auto outs = ir_utils::filterByType(def->outputs()); - to_visit.insert(to_visit.end(), outs.begin(), outs.end()); - } - } - } - // Find the pos where all leaf IDs at <= pos contains - // information about the starting root domain - // - // TODO: should I change to the following behavior? 
- // - // Find the smallest pos where all leaf IDs containing - // information about the starting root domain are <= pos - // - // For example, if I have - // preserved root domain: [I1, I2, I3, I4] - // leaf domain: [I5, I6, I7, I8] - // where - // I5 = merge(I1, I2) - // I6 = something unrelated - // I7 = merge(I3, I4) - // I8 = something unrelated - // should I return 1, or 3 ? - - // size_t i; - // for (i = from_tv->nDims() - 1; i >= 0; i--) { - // if (relevant_leaves.count(from_tv->axis(i)) > 0) { - // break; - // } - // } - // return i + 1; - - for (size_t i = 0; i < from_tv->nDims(); i++) { - if (relevant_leaves.count(from_tv->axis(i)) == 0) { - return i; - } - } - return from_tv->nDims(); -} - -// Given `root_ids`, a list of IDs in the root domain of `tv`, find their -// corresponding IDs in the rfactor domain of `tv`. -std::unordered_set mapRootToRFactor( - TensorView* tv, - const std::unordered_set& root_ids) { - std::unordered_set mapped_rfactor_ids; - const auto& rfactor_dom = tv->getMaybeRFactorDomain(); - for (auto id : rfactor_dom) { - if (root_ids.count(id) > 0) { - mapped_rfactor_ids.emplace(id); - continue; - } - for (auto root_id : root_ids) { - if (id == root_id || DependencyCheck::isDependencyOf(root_id, id)) { - mapped_rfactor_ids.emplace(id); - break; - } - } - } - return mapped_rfactor_ids; -} - -// Given `rfactor_ids`, a list of IDs in the rfactor domain of `tv`, find their -// corresponding IDs in the root domain of `tv`. -std::unordered_set mapRFactorToRoot( - TensorView* tv, - const std::unordered_set& rfactor_ids) { - std::unordered_set mapped_root_ids; - for (auto id : tv->getRootDomain()) { - if (rfactor_ids.count(id) > 0) { - mapped_root_ids.emplace(id); - continue; - } - for (auto rfactor_id : rfactor_ids) { - if (DependencyCheck::isDependencyOf(id, rfactor_id)) { - mapped_root_ids.emplace(id); - break; - } - } + result.info.emplace_back(RootIDInfo{{id}, true, false}); } - return mapped_root_ids; + return std::make_shared( + std::move(result)); } -// Given the preserved reference root ID info of a producer, compute -// the corresponding info in consumer. The given info may be represented by -// producer's root domain, or rfactor domain, depending on how we reached the -// producer during propagation. If the given info is already represented with -// producer's rfactor domain, then we directly map it to the consumer's root -// domain. If the given info is represented with producer's root domain, we need -// to first map it to the rfactor domain of the producer, then we can map it to -// the consumer's root domain. The computed info will be represented by root -// domain as root domain contains the raw information. 
-std::vector computeNextRootIDInfoCasP( - TensorView* producer, - TensorView* consumer, - const std::vector& producer_root_id_info) { - std::vector result; - - auto pairwise_map = PairwiseRootDomainMap(producer, consumer); - auto p2c_map = pairwise_map.mapProducerToConsumer( - producer->domain(), consumer->domain()); - - for (auto& info : producer_root_id_info) { - RootIDInfo consumer_info; - consumer_info.is_complete = info.is_complete; - consumer_info.is_rfactor = false; - - // mapped root ids in producer -> mapped rfactor ids in producer - std::unordered_set producer_mapped_rfactor_ids; - if (producer->hasRFactor() && !info.is_rfactor) { - producer_mapped_rfactor_ids = mapRootToRFactor(producer, info.mapped_ids); - } else { - producer_mapped_rfactor_ids = info.mapped_ids; - } - - // mapped rfactor ids in producer -> mapped root ids in consumer - for (auto producer_id : producer_mapped_rfactor_ids) { - auto it = p2c_map.find(producer_id); - if (it != p2c_map.end()) { - consumer_info.mapped_ids.insert(it->second); - } else { - consumer_info.is_complete = false; - } - } - - // If at least one root id in the consumer contains information - // of this starting root id, then keep this record - if (!consumer_info.mapped_ids.empty()) { - result.push_back(consumer_info); - } - } - return result; +void TransformPropagator::propagateTvPasC(TensorView* from, TensorView* to) { + int pos = replayed_pos.at(from); + auto replay = TransformReplay::replayPasC(to, from, pos); + to->setDomain(replay.first); + replayed_pos[to] = replay.second; } -// Given the preserved reference root ID info of a consumer, compute -// the corresponding info in producer. The given info may be represented by -// consumer's root domain, or rfactor domain, depending on how we reached the -// consumer during propagation. If the given info is already represented with -// consumer's root domain, then we directly map it to the producer's rfactor -// domain. If the given info is represented with consumer's rfactor domain, we -// need to first map it to the root domain of the consumer, then we can map it -// to the producer's rfactor domain. The computed info will be represented by -// rfactor domain as rfactor domain contains the raw information. -std::vector computeNextRootIDInfoPasC( - TensorView* producer, - TensorView* consumer, - const std::vector& consumer_root_id_info) { - std::vector result; - auto pairwise_map = PairwiseRootDomainMap(producer, consumer); - auto c2p_map = pairwise_map.mapConsumerToProducer( - consumer->domain(), producer->domain()); - - for (auto& info : consumer_root_id_info) { - RootIDInfo producer_info; - producer_info.is_complete = info.is_complete; - producer_info.is_rfactor = true; - - // mapped rfactor ids in consumer -> mapped root ids in consumer - std::unordered_set consumer_mapped_root_ids; - if (info.is_rfactor && consumer->hasRFactor()) { - consumer_mapped_root_ids = mapRFactorToRoot(consumer, info.mapped_ids); - } else { - consumer_mapped_root_ids = info.mapped_ids; - } - - // mapped root ids in consumer -> mapped rfactor ids in producer - for (auto consumer_id : consumer_mapped_root_ids) { - auto it = c2p_map.find(consumer_id); - if (it != c2p_map.end()) { - producer_info.mapped_ids.insert(it->second); - } else { - producer_info.is_complete = false; - } - } - - // We will stop at the rfactor ids in producer, and will not further map - // them into root ids in producer. This means, we only keep the unprocessed - // raw information of a tensor. 
This behavior is important to make sure that - // info is as accurate as possible throughout the propagation. - // - // For example, if we do a C->P->C' propagation, we want to do - // C(root) -> P(rfactor) -> C'(root) - // instead of - // C(root) -> P(rfactor) -> P(root) -> P(rfactor) -> C'(root) - // - // and the above two paths do lead to different results: - // - // For example if you have a producer tensor - // root domain: [I1, I2] - // rfactor domain: [I3, I5] - // where I3, I4 = split(I1), I5 = merge(I4, I2) - // Then the P(rfactor) -> P(root) -> P(rfactor) could lead to - // P(rfactor: {I5}) -> P(root: {I1, I2}) -> P(rfactor: {I3, I5}) - // which is not correct - - // If at least one root id in the producer contains information - // of this starting root id, then keep this record - if (!producer_info.mapped_ids.empty()) { - result.push_back(producer_info); - } - } - return result; -} - -}; // namespace - -unsigned int TransformPropagator::replay(const NextHopInfo& next_hop) { - if (next_hop.from == nullptr) { - // nullptr used to start from starting_tv - return next_hop.to->nDims(); - } - // TODO: why does TransformReplay require specifying a position in the - // leaf domain? I want to change the interface to allow specifying - // the starting root domains instead of leaf position. - int pos = getReplayPos(next_hop); - std::pair replay; - switch (next_hop.type) { - case NextHopType::P_AS_C: { - auto pairwiseMap = PairwiseRootDomainMap(next_hop.to, next_hop.from); - replay = TransformReplay::replayPasC( - next_hop.to, next_hop.from, pos, pairwiseMap); - break; - } - case NextHopType::C_AS_P: { - auto pairwiseMap = PairwiseRootDomainMap(next_hop.from, next_hop.to); - replay = TransformReplay::replayCasP( - next_hop.to, next_hop.from, pos, pairwiseMap); - break; - } - } - next_hop.to->setDomain(replay.first); - return replay.second; -} - -// Dijkstra -TransformPropagator::TransformPropagator(TensorView* from) : starting_tv(from) { - // A set that allows us to quickly tell if a tensor has been replayed. If yes, - // then we will not bother computing if a new path to this tensor is worth - // taking (because the answer is always not worth) - std::unordered_set replayed; - - // A sorted list of possible next steps. The list is sorted in the order of - // ascending amount of preserved information about the reference tensor. The - // back of the list preserves the most amount of information about the - // reference tensor, and should always be the next step to take. We use - // std::list instead of std::priority_queue because C++'s - // std::priority_queue does not support increase-key, and might not be - // deterministic either. - std::list propagation(1); - propagation.back().from = nullptr; - propagation.back().to = starting_tv; - propagation.back().root_id_info_to = getStartingRootIDInfo(starting_tv); - - // Insert the given next hop the correct position in `propagation`. If there - // is an existing next hop that preserves more information, then we will just - // discard `info`. - auto insertNextHopInfo = [&](const NextHopInfo& info) { - if (info.root_id_info_from.empty()) { - // When there is no more information about the starting tensor, - // we are not interested in continuing the propagation. 
- return; - } - // Find if there is already a path to the dest tensor - auto existing = std::find_if( - propagation.begin(), propagation.end(), [&](const NextHopInfo& i) { - return i.to == info.to; - }); - // Only insert if there is no existing path to the dest tensor, or the new - // path preserves more information about the starting tensor. - if (existing == propagation.end() || *existing < info) { - if (existing != propagation.end()) { - propagation.erase(existing); - } - auto pos = std::upper_bound(propagation.begin(), propagation.end(), info); - propagation.insert(pos, info); - } - }; - - while (!propagation.empty()) { - auto next_hop = propagation.back(); - propagation.pop_back(); - - replay(next_hop); - replayed.emplace(next_hop.to); - - for (auto consumer_tv : consumersOf(next_hop.to)) { - if (replayed.count(consumer_tv)) { - continue; - } - insertNextHopInfo( - {.type = NextHopType::C_AS_P, - .from = next_hop.to, - .to = consumer_tv, - .root_id_info_from = next_hop.root_id_info_to, - .root_id_info_to = computeNextRootIDInfoCasP( - next_hop.to, consumer_tv, next_hop.root_id_info_to)}); - } - - for (auto producer_tv : producersFor(next_hop.to)) { - if (replayed.count(producer_tv)) { - continue; - } - insertNextHopInfo( - {.type = NextHopType::P_AS_C, - .from = next_hop.to, - .to = producer_tv, - .root_id_info_from = next_hop.root_id_info_to, - .root_id_info_to = computeNextRootIDInfoPasC( - producer_tv, next_hop.to, next_hop.root_id_info_to)}); - } - } +void TransformPropagator::propagateTvCasP(TensorView* from, TensorView* to) { + int pos = replayed_pos.at(from); + auto replay = TransformReplay::replayCasP(to, from, pos); + to->setDomain(replay.first); + replayed_pos[to] = replay.second; } -void TransformPropagator::from(TensorView* tv) { - TransformPropagator propagate(tv); +TransformPropagator::TransformPropagator(TensorView* from) + : MaxRootDomainInfoPropagator(from, getStartingRootIDInfo(from)) { + replayed_pos[from] = from->nDims(); } } // namespace cuda diff --git a/torch/csrc/jit/codegen/cuda/transform_replay.h b/torch/csrc/jit/codegen/cuda/transform_replay.h index 8b12917985467c..cb9daf8fff832f 100644 --- a/torch/csrc/jit/codegen/cuda/transform_replay.h +++ b/torch/csrc/jit/codegen/cuda/transform_replay.h @@ -3,6 +3,7 @@ #include #include #include +#include #include #include @@ -155,29 +156,18 @@ class TORCH_CUDA_CU_API TransformReplay { const TensorDomain* self); }; -namespace { -struct NextHopInfo; -} +class TORCH_CUDA_CU_API TransformPropagator + : public MaxRootDomainInfoPropagator { + std::unordered_map replayed_pos; + static std::shared_ptr + getStartingRootIDInfo(TensorView* tv); -// TransformPropagator starts from a reference tensor, and propagate -// the transformations in this tensor to the entire graph. The propagation -// is done with the Dijkstra algorithm which will transform every tensor -// in this graph based on the information flow from the path that perserves -// the most amount of information about the reference tensor. Every tensor in -// the graph is replayed only once. -// -// During the propagation, we explicitly keep track of the information about -// which reference tensor's root ID's information is preserved, and to which -// level. This information is stored as a vector of `RootIDInfo`, where each -// item in the vector correspond to one ID in the reference tensor's root -// domain. 
-class TORCH_CUDA_CU_API TransformPropagator { - TensorView* starting_tv = nullptr; - TransformPropagator(TensorView* from); - static unsigned int replay(const NextHopInfo&); + protected: + virtual void propagateTvPasC(TensorView* from, TensorView* to) override; + virtual void propagateTvCasP(TensorView* from, TensorView* to) override; public: - static void from(TensorView* tv); + TransformPropagator(TensorView* from); }; } // namespace cuda
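
Usage sketch (illustrative only, not part of the patch): with this change the propagator is a regular object that must be run explicitly, instead of the old static TransformPropagator::from entry point. The snippet below mirrors how the updated call sites above use it; the tensor and variable names are hypothetical placeholders.

  // Schedule a reference tensor, then propagate its transformations to the
  // rest of the fusion along the paths that preserve the most information
  // about the reference tensor's root domain.
  reference_tv->split(-1, 128);
  TransformPropagator(reference_tv).run();
  scheduler_utils::parallelizeAllLike(reference_tv, ir_utils::allTvs(&fusion));

A custom propagation can be defined the same way TransformPropagator is defined here: subclass MaxRootDomainInfoPropagator (or MaxInfoPropagator directly), override propagateTvPasC/propagateTvCasP (and, for MaxInfoPropagator, computeInfoPasC/computeInfoCasP plus an Information subclass), and construct it with the reference tensor and its starting information.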