diff --git a/torch/csrc/jit/codegen/cuda/ir_interface_nodes.h b/torch/csrc/jit/codegen/cuda/ir_interface_nodes.h
index ab37e8fcf6319..46bf0801e2e9e 100644
--- a/torch/csrc/jit/codegen/cuda/ir_interface_nodes.h
+++ b/torch/csrc/jit/codegen/cuda/ir_interface_nodes.h
@@ -5,6 +5,7 @@
 #include
 #include
 #include
+#include

 #include
@@ -158,6 +159,7 @@ class TransformPropagator;
 class TransformIter;
 class TransformReplay;
 class OptOutMutator;
+class TensorDomain;

 namespace ir_utils {
 class TVDomainGuard;
diff --git a/torch/csrc/jit/codegen/cuda/ir_internal_nodes.h b/torch/csrc/jit/codegen/cuda/ir_internal_nodes.h
index 2da0c8de6d4a8..f4a8ea1fe36d1 100644
--- a/torch/csrc/jit/codegen/cuda/ir_internal_nodes.h
+++ b/torch/csrc/jit/codegen/cuda/ir_internal_nodes.h
@@ -4,7 +4,6 @@
 #include
 #include
-#include
 #include

 #include
diff --git a/torch/csrc/jit/codegen/cuda/test/test_gpu.cpp b/torch/csrc/jit/codegen/cuda/test/test_gpu.cpp
index 4fdb04e5760f1..01dc590099cdf 100644
--- a/torch/csrc/jit/codegen/cuda/test/test_gpu.cpp
+++ b/torch/csrc/jit/codegen/cuda/test/test_gpu.cpp
@@ -23574,6 +23574,96 @@ TEST_F(NVFuserTest, FusionReproNoncontigBroadcast_CUDA) {
       executor_cache.fusion(), cg_outputs, {t0, t1}, {t2}, __LINE__, __FILE__);
 }

+TEST_F(NVFuserTest, FusionTransformPropagateSibling_CUDA) {
+  // https://github.com/csarofeen/pytorch/issues/1760
+  Fusion fusion;
+  FusionGuard fg(&fusion);
+
+  auto tv0 = makeSymbolicTensor(2);
+  fusion.addInput(tv0);
+
+  auto tvs = Welford(tv0, {1});
+  fusion.addOutput(tvs.var_sum);
+
+  tvs.avg->split(1, 1);
+  tvs.avg->split(1, 2);
+  tvs.avg->split(1, 3);
+  tvs.var_sum->split(1, 1);
+  tvs.var_sum->split(1, 2);
+  tvs.var_sum->split(1, 3);
+  tvs.n->split(1, 1);
+  tvs.n->split(1, 2);
+  tvs.n->split(1, 3);
+
+  auto tvs2 = tvs.rFactor({1, 4});
+
+  TransformPropagator::from(tvs2.var_sum);
+
+  // Check that all siblings of tvs and tvs2 are transformed consistently.
+  auto checkSiblingConsistency = [](TensorView* replay, TensorView* target) {
+    auto replay_root = replay->getRootDomain();
+    auto replay_dom = replay->domain()->domain();
+    auto target_root = target->getRootDomain();
+    auto target_dom = target->domain()->domain();
+    std::unordered_map<IterDomain*, IterDomain*> target2replay_map;
+    TORCH_CHECK(replay_root.size() == target_root.size());
+    target2replay_map.reserve(replay_root.size());
+    std::transform(
+        target_root.begin(),
+        target_root.end(),
+        replay_root.begin(),
+        std::inserter(target2replay_map, target2replay_map.begin()),
+        [](auto a, auto b) { return std::make_pair(a, b); });
+    BestEffortReplay replay_(replay_dom, target_dom, target2replay_map);
+    auto r = replay_.getReplay();
+    for (size_t i = 0; i < replay_dom.size(); i++) {
+      auto target_id = target_dom[i];
+      auto replay_it = r.find(target_id);
+      TORCH_CHECK(replay_it != r.end());
+      TORCH_CHECK(
+          replay_it->second == replay_dom[i],
+          "IterDomain mismatch when checking ",
+          replay,
+          " and ",
+          target,
+          " at ",
+          i,
+          ", got ",
+          replay_it->second,
+          " and ",
+          replay_dom[i]);
+    }
+  };
+  std::vector<TensorView*> siblings[] = {
+      {tvs.avg, tvs.var_sum, tvs.n}, {tvs2.avg, tvs2.var_sum, tvs2.n}};
+  for (auto tensors : siblings) {
+    for (auto t1 : tensors) {
+      for (auto t2 : tensors) {
+        checkSiblingConsistency(t1, t2);
+      }
+    }
+  }
+}
+
+TEST_F(NVFuserTest, FusionTransformPropagatePosition_CUDA) {
+  Fusion fusion;
+  FusionGuard fg(&fusion);
+
+  auto tv0 = makeSymbolicTensor(4);
+  auto tv1 = makeSymbolicTensor(6);
+  fusion.addInput(tv0);
+
+  auto tv2 = broadcast(tv0, {false, false, true, false, false, true});
+  auto tv3 = add(tv1, tv2);
+  fusion.addOutput(tv3);
+
+  tv0->merge(2);
+  tv0->merge(0);
+  TransformPropagator::from(tv0);
+
+  TORCH_CHECK(tv1->nDims() == 4);
+}
+
 } // namespace jit
 } // namespace torch
 #endif // #if defined(USE_CUDA)
diff --git a/torch/csrc/jit/codegen/cuda/transform_replay.cpp b/torch/csrc/jit/codegen/cuda/transform_replay.cpp
index d71efd00ed0f3..58b6a74ea1010 100644
--- a/torch/csrc/jit/codegen/cuda/transform_replay.cpp
+++ b/torch/csrc/jit/codegen/cuda/transform_replay.cpp
@@ -687,134 +687,447 @@ std::deque<TensorView*> producersFor(TensorView* tv) {
   return deduplicate(tvInputs(def));
 }

-}; // namespace
+// This is a struct storing how the information about a root ID in the
+// starting tensor is preserved during propagation. If, during propagation, we
+// have reached a tensor called the "current" tensor, we are interested in the
+// following information:
+// - Which reference tensor's root ID's information does the current tensor
+//   contain? Each RootIDInfo object should correspond to one reference
+//   tensor's root ID, but we don't need to store this ID explicitly.
+// - For this reference tensor's root ID, what are its corresponding IDs in
+//   the current tensor's root/rfactor domain?
+// - Is the current tensor's information about this reference tensor's root ID
+//   complete?
+struct RootIDInfo {
+  // Each object of this class corresponds to one root ID in the reference
+  // tensor, but we do not need to explicitly store this ID.
+
+  // The IDs in the current tensor's root or rfactor domain that contain
+  // information of the corresponding reference tensor's root ID. Whether we
+  // are using the root domain or the rfactor domain depends on how we reached
+  // the current tensor during propagation. `is_rfactor` tells us whether the
+  // IDs contained in `mapped_ids` are from the root domain or the rfactor
+  // domain.
+  std::unordered_set<IterDomain*> mapped_ids;
+
+  // Does `mapped_ids` contain all the IDs required to recompute the
+  // corresponding reference tensor's root ID? For example, if we have
+  //   t1 = input tensor of shape (20,)
+  //   t2 = view(t1, {4, 5})
+  //   t3 = sum(t2, {1})
+  //   t4 = set(t3)
+  // and we start the propagation from t1, then t2's and t3's information
+  // about t1 is complete, but t4's is not, because one axis is missing.
+  bool is_complete;
+
+  // Is `mapped_ids` from the root domain or the rfactor domain of the current
+  // tensor? We only store IDs from one of them, depending on how we reach the
+  // current tensor during propagation. If we reached the current tensor from
+  // a consumer, then `mapped_ids` contains IDs in the current tensor's
+  // rfactor domain, because the rfactor domain contains the raw information.
+  // If we reached the current tensor from a producer, then `mapped_ids`
+  // contains IDs in the current tensor's root domain, because the root domain
+  // contains the raw information.
+  bool is_rfactor;
+
+  RootIDInfo() = default;
+
+  // This constructor is only used on the reference tensor where the
+  // propagation starts, so the mapped_ids are just the starting_root_id.
+  RootIDInfo(IterDomain* starting_root_id)
+      : mapped_ids{starting_root_id}, is_complete(true), is_rfactor(false) {}
+};
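+
+// A concrete reading of the view example above (illustration only, not part
+// of the original patch): propagating from t1, whose root domain is a single
+// ID, the RootIDInfo tracking that ID would be
+//   at t2: {t2's root ID},                         is_complete = true
+//   at t3: {both of t3's root IDs, one reduction}, is_complete = true
+//   at t4: {the non-reduction root ID only},       is_complete = false
+// The t3 -> t4 step loses the reduction axis because the pairwise root map
+// does not map reduction IDs to a consumer.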

-bool TransformPropagator::replayPasC(
-    TensorView* producer_tv,
-    TensorView* consumer_tv) {
-  if (producer_tv == starting_tv) {
-    return false;
-  }
+enum class NextHopType {
+  C_AS_P,
+  P_AS_C,
+};

-  auto consumer_pos_it = replayed_pos.find(consumer_tv);
-  if (consumer_pos_it == replayed_pos.end()) {
-    return false;
-  }
+// This is a helper struct that contains all the information about the next
+// step in the Dijkstra algorithm.
+struct NextHopInfo {
+  NextHopType type;
+  TensorView* from = nullptr;
+  TensorView* to;

-  auto pairwiseMap = PairwiseRootDomainMap(producer_tv, consumer_tv);
-  auto replayed_producer = TransformReplay::replayPasC(
-      producer_tv, consumer_tv, consumer_pos_it->second, pairwiseMap);
+  std::vector<RootIDInfo> root_id_info_from;
+  std::vector<RootIDInfo> root_id_info_to;
+};

-  auto producer_root = producer_tv->getMaybeRFactorDomain();
-  auto replayed_domain = replayed_producer.first->domain();
+// l < r means l contains a smaller amount of information about the starting
+// tensor than r.
+bool operator<(const NextHopInfo& l, const NextHopInfo& r) {
+  if (l.root_id_info_to.size() != r.root_id_info_to.size()) {
+    return l.root_id_info_to.size() < r.root_id_info_to.size();
+  }
+  size_t l_complete = std::count_if(
+      l.root_id_info_to.begin(),
+      l.root_id_info_to.end(),
+      [](const RootIDInfo& i) { return i.is_complete; });
+  size_t r_complete = std::count_if(
+      r.root_id_info_to.begin(),
+      r.root_id_info_to.end(),
+      [](const RootIDInfo& i) { return i.is_complete; });
+  return l_complete < r_complete;
+}

-  // Find the number of root IDs involved in the transformation
-  auto dep_vals = DependencyCheck::getAllValsBetween(
-      {producer_root.begin(), producer_root.end()},
-      {replayed_domain.begin(),
-       replayed_domain.begin() + replayed_producer.second});
+// l > r means l contains a larger amount of information about the starting
+// tensor than r.
+bool operator>(const NextHopInfo& l, const NextHopInfo& r) {
+  return r < l;
+}

-  std::unordered_set<Val*> dep_vals_set{dep_vals.begin(), dep_vals.end()};
+// l == r means it is hard to tell which one of them contains more information.
+bool operator==(const NextHopInfo& l, const NextHopInfo& r) {
+  return !(r < l) && !(l < r);
+}
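+
+// Example (illustration only, not part of the original patch): if two hops
+// both carry info for three reference root IDs, but the first has two
+// complete entries and the second has three, their sizes tie and the
+// completeness count decides: the second hop compares greater, i.e. it is
+// the path that preserves more information and should be taken first.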

-  auto n_transformed_root_dims = std::count_if(
-      producer_root.begin(),
-      producer_root.end(),
-      [&dep_vals_set](IterDomain* root_id) {
-        return dep_vals_set.find(root_id) != dep_vals_set.end();
-      });
+std::vector<RootIDInfo> getStartingRootIDInfo(TensorView* tv) {
+  std::vector<RootIDInfo> result;
+  const auto& root_domain = tv->getRootDomain();
+  result.reserve(root_domain.size());
+  for (auto id : root_domain) {
+    result.emplace_back(id);
+  }
+  return result;
+}

-  if (replayed_pos.find(producer_tv) != replayed_pos.end()) {
-    if (n_transformed_root_dims < n_replayed_root_dims.at(producer_tv) ||
-        (n_transformed_root_dims == n_replayed_root_dims.at(producer_tv) &&
-         replayed_producer.second <= replayed_pos.at(producer_tv))) {
-      return false; // NOLINT(clang-analyzer-cplusplus.NewDeleteLeaks)
-    }
-  }
+// Infer the compute-at position from the information of preserved reference
+// root IDs.
+//
+// TODO:
+// I think I need to modify TransformReplay to add a new interface to specify
+// the root domains, instead of a position in the leaf domain. With the new
+// interface, this function will not be needed.
+size_t getReplayPos(const NextHopInfo& next_hop) {
+  auto& root_id_info = next_hop.root_id_info_from;
+  auto from_tv = next_hop.from;
+  // Flatten `root_id_info_from` to get the list of IDs in `from_tv`'s
+  // root/rfactor domain that contain information about the reference tensor.
+  std::unordered_set<IterDomain*> from_ids;
+  from_ids.reserve(root_id_info.size());
+  for (auto info : root_id_info) {
+    for (auto id : info.mapped_ids) {
+      from_ids.insert(id);
+    }
+  }

-  producer_tv->setDomain(replayed_producer.first);
-  replayed_pos[producer_tv] = replayed_producer.second;
-  n_replayed_root_dims[producer_tv] = n_transformed_root_dims;
-
-  return true;
-}
+  // Get leaf IDs that contain information of `from_ids`.
+  std::unordered_set<IterDomain*> relevant_leaves;
+  std::vector<IterDomain*> to_visit(from_ids.begin(), from_ids.end());
+  while (!to_visit.empty()) {
+    auto front = to_visit.back();
+    to_visit.pop_back();
+    if (front->uses().empty()) {
+      relevant_leaves.emplace(front);
+    } else {
+      for (auto use : front->uses()) {
+        auto outs = ir_utils::filterByType<IterDomain>(use->outputs());
+        to_visit.insert(to_visit.end(), outs.begin(), outs.end());
+      }
+    }
+  }
+  // Find the position such that all leaf IDs at <= pos contain
+  // information about the starting root domain.
+  //
+  // TODO: should I change to the following behavior?
+  //
+  // Find the smallest pos such that all leaf IDs containing
+  // information about the starting root domain are <= pos.
+  //
+  // For example, if I have
+  //   preserved root domain: [I1, I2, I3, I4]
+  //   leaf domain: [I5, I6, I7, I8]
+  // where
+  //   I5 = merge(I1, I2)
+  //   I6 = something unrelated
+  //   I7 = merge(I3, I4)
+  //   I8 = something unrelated
+  // should I return 1, or 3?

+  // size_t i;
+  // for (i = from_tv->nDims() - 1; i >= 0; i--) {
+  //   if (relevant_leaves.count(from_tv->axis(i)) > 0) {
+  //     break;
+  //   }
+  // }
+  // return i + 1;

+  for (size_t i = 0; i < from_tv->nDims(); i++) {
+    if (relevant_leaves.count(from_tv->axis(i)) == 0) {
+      return i;
+    }
+  }
+  return from_tv->nDims();
+}

-bool TransformPropagator::replayCasP(
-    TensorView* consumer_tv,
-    TensorView* producer_tv) {
-  if (consumer_tv == starting_tv) {
-    return false;
-  }
+// Given `root_ids`, a list of IDs in the root domain of `tv`, find their
+// corresponding IDs in the rfactor domain of `tv`.
+std::unordered_set<IterDomain*> mapRootToRFactor(
+    TensorView* tv,
+    const std::unordered_set<IterDomain*>& root_ids) {
+  std::unordered_set<IterDomain*> mapped_rfactor_ids;
+  const auto& rfactor_dom = tv->getMaybeRFactorDomain();
+  for (auto id : rfactor_dom) {
+    if (root_ids.count(id) > 0) {
+      mapped_rfactor_ids.emplace(id);
+      continue;
+    }
+    for (auto root_id : root_ids) {
+      if (id == root_id || DependencyCheck::isDependencyOf(root_id, id)) {
+        mapped_rfactor_ids.emplace(id);
+        break;
+      }
+    }
+  }
+  return mapped_rfactor_ids;
+}
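+
+// Example (illustration only, not part of the original patch): for a tensor
+// with
+//   root domain:    [I1, I2]
+//   rfactor domain: [I3, I5], where I3, I4 = split(I1) and I5 = merge(I4, I2)
+// mapRootToRFactor(tv, {I1}) returns {I3, I5}, since both rfactor IDs depend
+// on I1, while mapRootToRFactor(tv, {I2}) returns only {I5}.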

-  auto producer_pos_it = replayed_pos.find(producer_tv);
-  if (producer_pos_it == replayed_pos.end()) {
-    return false;
-  }
+// Given `rfactor_ids`, a list of IDs in the rfactor domain of `tv`, find their
+// corresponding IDs in the root domain of `tv`.
+std::unordered_set<IterDomain*> mapRFactorToRoot(
+    TensorView* tv,
+    const std::unordered_set<IterDomain*>& rfactor_ids) {
+  std::unordered_set<IterDomain*> mapped_root_ids;
+  for (auto id : tv->getRootDomain()) {
+    if (rfactor_ids.count(id) > 0) {
+      mapped_root_ids.emplace(id);
+      continue;
+    }
+    for (auto rfactor_id : rfactor_ids) {
+      if (DependencyCheck::isDependencyOf(id, rfactor_id)) {
+        mapped_root_ids.emplace(id);
+        break;
+      }
+    }
+  }
+  return mapped_root_ids;
+}

-  auto pairwiseMap = PairwiseRootDomainMap(producer_tv, consumer_tv);
-  auto replayed_consumer = TransformReplay::replayCasP(
-      consumer_tv, producer_tv, producer_pos_it->second, pairwiseMap);
+// Given the preserved reference root ID info of a producer, compute the
+// corresponding info of its consumer. The given info may be represented by
+// the producer's root domain or rfactor domain, depending on how we reached
+// the producer during propagation. If the given info is already represented
+// with the producer's rfactor domain, then we directly map it to the
+// consumer's root domain. If the given info is represented with the
+// producer's root domain, we need to first map it to the rfactor domain of
+// the producer, then we can map it to the consumer's root domain. The
+// computed info will be represented by the root domain, as the root domain
+// contains the raw information.
+std::vector<RootIDInfo> computeNextRootIDInfoCasP(
+    TensorView* producer,
+    TensorView* consumer,
+    const std::vector<RootIDInfo>& producer_root_id_info) {
+  std::vector<RootIDInfo> result;
+
+  auto pairwise_map = PairwiseRootDomainMap(producer, consumer);
+  auto p2c_map = pairwise_map.mapProducerToConsumer(
+      producer->domain(), consumer->domain());
+
+  for (auto& info : producer_root_id_info) {
+    RootIDInfo consumer_info;
+    consumer_info.is_complete = info.is_complete;
+    consumer_info.is_rfactor = false;
+
+    // mapped root IDs in producer -> mapped rfactor IDs in producer
+    std::unordered_set<IterDomain*> producer_mapped_rfactor_ids;
+    if (producer->hasRFactor() && !info.is_rfactor) {
+      producer_mapped_rfactor_ids = mapRootToRFactor(producer, info.mapped_ids);
+    } else {
+      producer_mapped_rfactor_ids = info.mapped_ids;
+    }

-  auto consumer_root = consumer_tv->getRootDomain();
-  auto replayed_domain = replayed_consumer.first->domain();
+    // mapped rfactor IDs in producer -> mapped root IDs in consumer
+    for (auto producer_id : producer_mapped_rfactor_ids) {
+      auto it = p2c_map.find(producer_id);
+      if (it != p2c_map.end()) {
+        consumer_info.mapped_ids.insert(it->second);
+      } else {
+        consumer_info.is_complete = false;
+      }
+    }

-  // Find the number of root IDs involved in the transformation
-  auto dep_vals = DependencyCheck::getAllValsBetween(
-      {consumer_root.begin(), consumer_root.end()},
-      {replayed_domain.begin(),
-       replayed_domain.begin() + replayed_consumer.second});
+    // If at least one root ID in the consumer contains information
+    // of this starting root ID, then keep this record.
+    if (!consumer_info.mapped_ids.empty()) {
+      result.push_back(consumer_info);
+    }
+  }
+  return result;
+}
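+
+// Example (illustration only, not part of the original patch): this is the
+// t3 -> t4 step of the example in RootIDInfo. With producer root [I1, I2]
+// where I2 is a reduction ID, and consumer root [I3] with I1 -> I3 in the
+// pairwise map, an incoming RootIDInfo{{I1, I2}, complete} becomes
+// RootIDInfo{{I3}, incomplete} on the consumer, and an incoming info that
+// mapped only {I2} would be dropped entirely, since its mapped_ids would end
+// up empty.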

-  std::unordered_set<Val*> dep_vals_set{dep_vals.begin(), dep_vals.end()};
+// Given the preserved reference root ID info of a consumer, compute the
+// corresponding info of its producer. The given info may be represented by
+// the consumer's root domain or rfactor domain, depending on how we reached
+// the consumer during propagation. If the given info is already represented
+// with the consumer's root domain, then we directly map it to the producer's
+// rfactor domain. If the given info is represented with the consumer's
+// rfactor domain, we need to first map it to the root domain of the consumer,
+// then we can map it to the producer's rfactor domain. The computed info will
+// be represented by the rfactor domain, as the rfactor domain contains the
+// raw information.
+std::vector<RootIDInfo> computeNextRootIDInfoPasC(
+    TensorView* producer,
+    TensorView* consumer,
+    const std::vector<RootIDInfo>& consumer_root_id_info) {
+  std::vector<RootIDInfo> result;
+  auto pairwise_map = PairwiseRootDomainMap(producer, consumer);
+  auto c2p_map = pairwise_map.mapConsumerToProducer(
+      consumer->domain(), producer->domain());
+
+  for (auto& info : consumer_root_id_info) {
+    RootIDInfo producer_info;
+    producer_info.is_complete = info.is_complete;
+    producer_info.is_rfactor = true;
+
+    // mapped rfactor IDs in consumer -> mapped root IDs in consumer
+    std::unordered_set<IterDomain*> consumer_mapped_root_ids;
+    if (info.is_rfactor && consumer->hasRFactor()) {
+      consumer_mapped_root_ids = mapRFactorToRoot(consumer, info.mapped_ids);
+    } else {
+      consumer_mapped_root_ids = info.mapped_ids;
+    }

-  auto n_transformed_root_dims = std::count_if(
-      consumer_root.begin(),
-      consumer_root.end(),
-      [&dep_vals_set](IterDomain* root_id) {
-        return dep_vals_set.find(root_id) != dep_vals_set.end();
-      });
+    // mapped root IDs in consumer -> mapped rfactor IDs in producer
+    for (auto consumer_id : consumer_mapped_root_ids) {
+      auto it = c2p_map.find(consumer_id);
+      if (it != c2p_map.end()) {
+        producer_info.mapped_ids.insert(it->second);
+      } else {
+        producer_info.is_complete = false;
+      }
+    }

-  if (replayed_pos.find(consumer_tv) != replayed_pos.end()) {
-    if (n_transformed_root_dims < n_replayed_root_dims.at(consumer_tv) ||
-        (n_transformed_root_dims == n_replayed_root_dims.at(consumer_tv) &&
-         replayed_consumer.second <= replayed_pos.at(consumer_tv))) {
-      return false; // NOLINT(clang-analyzer-cplusplus.NewDeleteLeaks)
-    }
-  }
+    // We will stop at the rfactor IDs in the producer, and will not further
+    // map them into root IDs in the producer. This means we only keep the
+    // unprocessed raw information of a tensor. This behavior is important for
+    // making sure the info is as accurate as possible throughout the
+    // propagation.
+    //
+    // For example, if we do a C->P->C' propagation, we want to do
+    //   C(root) -> P(rfactor) -> C'(root)
+    // instead of
+    //   C(root) -> P(rfactor) -> P(root) -> P(rfactor) -> C'(root)
+    //
+    // and the above two paths do lead to different results:
+    //
+    // For example, if you have a producer tensor
+    //   root domain: [I1, I2]
+    //   rfactor domain: [I3, I5]
+    // where I3, I4 = split(I1) and I5 = merge(I4, I2), then the
+    // P(rfactor) -> P(root) -> P(rfactor) path could lead to
+    //   P(rfactor: {I5}) -> P(root: {I1, I2}) -> P(rfactor: {I3, I5})
+    // which is not correct.

+    // If at least one root ID in the producer contains information
+    // of this starting root ID, then keep this record.
+    if (!producer_info.mapped_ids.empty()) {
+      result.push_back(producer_info);
+    }
+  }
+  return result;
+}

-  consumer_tv->setDomain(replayed_consumer.first);
-  replayed_pos[consumer_tv] = replayed_consumer.second;
-  n_replayed_root_dims[consumer_tv] = n_transformed_root_dims;
+}; // namespace

-  return true;
-}
+unsigned int TransformPropagator::replay(const NextHopInfo& next_hop) {
+  if (next_hop.from == nullptr) {
+    // A nullptr `from` is only used for the starting tensor, starting_tv.
+    return next_hop.to->nDims();
+  }
+  // TODO: why does TransformReplay require specifying a position in the
+  // leaf domain?
+  // I want to change the interface to allow specifying the starting root
+  // domains instead of a leaf position.
+  int pos = getReplayPos(next_hop);
+  std::pair<TensorDomain*, unsigned int> replay;
+  switch (next_hop.type) {
+    case NextHopType::P_AS_C: {
+      auto pairwiseMap = PairwiseRootDomainMap(next_hop.to, next_hop.from);
+      replay = TransformReplay::replayPasC(
+          next_hop.to, next_hop.from, pos, pairwiseMap);
+      break;
+    }
+    case NextHopType::C_AS_P: {
+      auto pairwiseMap = PairwiseRootDomainMap(next_hop.from, next_hop.to);
+      replay = TransformReplay::replayCasP(
+          next_hop.to, next_hop.from, pos, pairwiseMap);
+      break;
+    }
+  }
+  next_hop.to->setDomain(replay.first);
+  return replay.second;
+}

+// Dijkstra
 TransformPropagator::TransformPropagator(TensorView* from) : starting_tv(from) {
-  VectorOfUniqueEntries<TensorView*> propagation{starting_tv};
-
-  // Seed position with local tv
-  replayed_pos[from] = from->nDims();
+  // A set that allows us to quickly tell if a tensor has been replayed. If
+  // yes, then we will not bother computing whether a new path to this tensor
+  // is worth taking (because the answer is always no).
+  std::unordered_set<TensorView*> replayed;
+
+  // A sorted list of possible next steps. The list is sorted in the order of
+  // ascending amount of preserved information about the reference tensor. The
+  // back of the list preserves the most amount of information about the
+  // reference tensor, and should always be the next step to take. We use
+  // std::list instead of std::priority_queue because C++'s
+  // std::priority_queue does not support increase-key, and might not be
+  // deterministic either.
+  std::list<NextHopInfo> propagation(1);
+  propagation.back().from = nullptr;
+  propagation.back().to = starting_tv;
+  propagation.back().root_id_info_to = getStartingRootIDInfo(starting_tv);
+
+  // Insert the given next hop at the correct position in `propagation`. If
+  // there is an existing next hop that preserves more information, then we
+  // will just discard `info`.
+  auto insertNextHopInfo = [&](const NextHopInfo& info) {
+    if (info.root_id_info_from.empty()) {
+      // When there is no more information about the starting tensor,
+      // we are not interested in continuing the propagation.
+      return;
+    }
+    // Find if there is already a path to the dest tensor.
+    auto existing = std::find_if(
+        propagation.begin(), propagation.end(), [&](const NextHopInfo& i) {
+          return i.to == info.to;
+        });
+    // Only insert if there is no existing path to the dest tensor, or the new
+    // path preserves more information about the starting tensor.
+    if (existing == propagation.end() || *existing < info) {
+      if (existing != propagation.end()) {
+        propagation.erase(existing);
+      }
+      auto pos = std::upper_bound(propagation.begin(), propagation.end(), info);
+      propagation.insert(pos, info);
+    }
+  };

-  // While tensor views are being replayed, if they're modified, make sure we
-  // propagate back to all producers as well as consumers. This is definitely
-  // not the most efficient implementation as what we do is any time a tv is
-  // changed we propagate both forward and backward.
   while (!propagation.empty()) {
-    auto tv = propagation.popBack();
-
-    // Replay tv forward to its consumers.
-    for (auto consumer_tv : consumersOf(tv)) {
-      auto replayed = replayCasP(consumer_tv, tv);
-      // If consumer has changed, mark we should propagate
-      if (replayed) {
-        propagation.pushBack(consumer_tv);
+    auto next_hop = propagation.back();
+    propagation.pop_back();
+
+    replay(next_hop);
+    replayed.emplace(next_hop.to);
+
+    for (auto consumer_tv : consumersOf(next_hop.to)) {
+      if (replayed.count(consumer_tv)) {
+        continue;
       }
+      insertNextHopInfo(
+          {.type = NextHopType::C_AS_P,
+           .from = next_hop.to,
+           .to = consumer_tv,
+           .root_id_info_from = next_hop.root_id_info_to,
+           .root_id_info_to = computeNextRootIDInfoCasP(
+               next_hop.to, consumer_tv, next_hop.root_id_info_to)});
     }
-    for (auto producer_tv : producersFor(tv)) {
-      // If producer has changed, mark we should propagate
-      auto replayed = replayPasC(producer_tv, tv);
-      if (replayed) {
-        propagation.pushBack(producer_tv);
+    for (auto producer_tv : producersFor(next_hop.to)) {
+      if (replayed.count(producer_tv)) {
+        continue;
       }
+      insertNextHopInfo(
+          {.type = NextHopType::P_AS_C,
+           .from = next_hop.to,
+           .to = producer_tv,
+           .root_id_info_from = next_hop.root_id_info_to,
+           .root_id_info_to = computeNextRootIDInfoPasC(
+               producer_tv, next_hop.to, next_hop.root_id_info_to)});
     }
   }
 }
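+
+// Propagation order sketch (illustration only, not part of the original
+// patch): with
+//   tv0 (reference) -> tv1 -> tv3   and   tv0 -> tv2 -> tv3
+// popping tv0 enqueues hops to tv1 and tv2. When both branches later propose
+// a hop into tv3, insertNextHopInfo keeps only the proposal that preserves
+// more reference root ID information, so tv3 is replayed exactly once, via
+// the better of the two paths.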
diff --git a/torch/csrc/jit/codegen/cuda/transform_replay.h b/torch/csrc/jit/codegen/cuda/transform_replay.h
index 0f7c7c00c8532..8b12917985467 100644
--- a/torch/csrc/jit/codegen/cuda/transform_replay.h
+++ b/torch/csrc/jit/codegen/cuda/transform_replay.h
@@ -2,9 +2,11 @@

 #include
 #include
+#include
 #include

 #include
+#include
 #include

 namespace torch {
@@ -153,36 +155,26 @@
       const TensorDomain* self);
 };

-class TORCH_CUDA_CU_API TransformPropagator {
- private:
-  bool replayPasC(TensorView* producer_tv, TensorView* consumer_tv = nullptr);
-  bool replayCasP(TensorView* consumer_tv, TensorView* producer_tv = nullptr);
-
-  TransformPropagator(TensorView* from);
+namespace {
+struct NextHopInfo;
+}

- private:
-  std::unordered_map<TensorView*, unsigned int> replayed_pos;
-
-  // This example comes from a BN kernel, the domain:
-  //
-  // [ iS{ceilDiv(ceilDiv(ceilDiv(i4, 128), 4), 1)}, iS{1}, iS{4}, iS{128},
-  //   iS{i0}, iS{i2}, iS{i3} ]
-  //
-  // and
-  //
-  // [ iS{ceilDiv(ceilDiv(ceilDiv(i5*i6*i7*i8, 128), 4), 1)}, iS252{1},
-  //   iS250{4}, iS248{128} ]
-  //
-  // Have the same number of replayed dimensions, however the second one
-  // involves more root domains. The second one is also likely the prefered
-  // replay. Therefore keep track of how many root domains were part of the
-  // replay and prefer transformations with more root domains. We could probably
-  // fix this instances of this occuring by changing the traversal pattern so
-  // that once propagating towards roots through broadcast axes, it can't come
-  // back through another broadcast, losing the transformation on those axes.
-  // However, this should work for existing cases.
-  std::unordered_map<TensorView*, unsigned int> n_replayed_root_dims;
+// TransformPropagator starts from a reference tensor and propagates the
+// transformations in that tensor to the entire graph. The propagation is
+// done with the Dijkstra algorithm, which transforms every tensor in the
+// graph based on the information flow along the path that preserves the
+// largest amount of information about the reference tensor. Every tensor in
+// the graph is replayed only once.
+//
+// During the propagation, we explicitly keep track of which reference
+// tensor's root IDs' information is preserved, and to which level. This
+// information is stored as a vector of `RootIDInfo`, where each item in the
+// vector corresponds to one ID in the reference tensor's root domain.
+class TORCH_CUDA_CU_API TransformPropagator {
   TensorView* starting_tv = nullptr;

+  TransformPropagator(TensorView* from);
+  static unsigned int replay(const NextHopInfo&);
+
 public:
  static void from(TensorView* tv);
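+  // Typical usage (mirrors the FusionTransformPropagate* tests above):
+  // schedule a reference tensor, then replay it onto the whole fusion, e.g.
+  //   reference->merge(0);
+  //   reference->split(0, 128);
+  //   TransformPropagator::from(reference);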