diff --git a/torch/csrc/jit/codegen/cuda/compute_at_map.cpp b/torch/csrc/jit/codegen/cuda/compute_at_map.cpp
index 095cbdd05568c..79e7bc4cf8faf 100644
--- a/torch/csrc/jit/codegen/cuda/compute_at_map.cpp
+++ b/torch/csrc/jit/codegen/cuda/compute_at_map.cpp
@@ -83,31 +83,6 @@ DisjointSets<IterDomain*>& IterDomainGraph::nodes(IdMappingMode mode) {
   return node_set_it->second;
 }
 
-//! Map corresponding inputs and outputs of swizzle op together
-//! on the given disjoint set, if the given id is an output
-//! of a swizzle operator.
-//!
-//! The current usage of swizzle operator is local to each tensor
-//! itself, so they should not affect exact or permissive mapping
-//! between iterdomains on different tensor domains.
-//! TODO:
-//!   Exact mapping based index hoisting of swizzled iterdomains
-//!   is disabled currently and will be re-enabled in the next
-//!   few build out steps.
-void mapMaybeSwizzleOp(
-    DisjointSets<IterDomain*>& disjoint_sets,
-    IterDomain* id) {
-  if (auto swizzle_2d = dynamic_cast<Swizzle2D*>(id->definition())) {
-    // Map each input to its corresponding output on the given
-    // disjoint set if this is a loop swizzle. Loop swizzles don't impact
-    // indexing, only iteration order.
-    if (swizzle_2d->swizzleMode() == SwizzleMode::Loop) {
-      disjoint_sets.mapEntries(swizzle_2d->inX(), swizzle_2d->outX());
-      disjoint_sets.mapEntries(swizzle_2d->inY(), swizzle_2d->outY());
-    }
-  }
-}
-
 bool IterDomainGraph::exprsMap(
     Expr* first,
     Expr* second,
@@ -344,10 +319,27 @@ findFirstSelfMapping(Fusion* fusion, const IterDomainGraph& id_graph) {
 
 } // namespace
 
-void IterDomainGraph::build(Fusion* fusion) {
-  FusionGuard fg(fusion);
+// TODO: Should we avoid marking leaf nodes at this point?
+void IterDomainGraph::initializeId(
+    IterDomain* id,
+    bool is_view_rfactor_id,
+    bool is_leaf_id) {
+  nodes(IdMappingMode::PERMISSIVE).initializeSet(id);
+  nodes(IdMappingMode::EXACT).initializeSet(id);
+  if (is_leaf_id) {
+    nodes(IdMappingMode::LOOP).initializeSet(id);
+  }
+  consumers_[id] = {};
+  producers_[id] = {};
+
+  if (is_view_rfactor_id) {
+    view_rfactor_ids_.emplace(id);
+  }
+}
 
-  // Initialize a node for every iteration domain
+void IterDomainGraph::initialIdProcessing(Fusion* fusion) {
+  // Initialize a node for every iteration domain and mark view-like iteration
+  // domains and leaf iteration domains.
   for (auto tv : ir_utils::allTvs(fusion)) {
     const auto& domain = tv->domain()->domain();
     auto all_ids = ir_utils::allIDsOf(tv);
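A note between hunks: the build process above (and everything below) leans on nvfuser's DisjointSets utility, where initializeSet() places an ID in its own singleton set and mapEntries() (wrapped by mapNodes()) unions two sets. As a point of reference, a minimal union-find with the same two entry points could look like the sketch below. The names mirror the real API, but the implementation is illustrative only, not nvfuser's DisjointSets.

// Minimal union-find sketch of the DisjointSets semantics used in this diff.
#include <cassert>
#include <initializer_list>
#include <unordered_map>

template <typename T>
class ToyDisjointSets {
 public:
  // Analogous to initializeSet(): each entry starts in its own singleton set.
  void initializeSet(T entry) {
    parent_.emplace(entry, entry);
  }
  // Analogous to mapEntries(): union the sets containing a and b.
  void mapEntries(T a, T b) {
    parent_[find(a)] = find(b);
  }
  bool areMapped(T a, T b) {
    return find(a) == find(b);
  }

 private:
  // Follow parent links until the set representative is reached.
  T find(T entry) {
    while (parent_.at(entry) != entry) {
      entry = parent_.at(entry);
    }
    return entry;
  }
  std::unordered_map<T, T> parent_;
};

int main() {
  ToyDisjointSets<int> sets;
  for (int id : {0, 1, 2}) {
    sets.initializeSet(id); // like initializeId() registering an IterDomain
  }
  sets.mapEntries(0, 1); // like mapNodes(id0, id1, mode)
  assert(sets.areMapped(0, 1));
  assert(!sets.areMapped(0, 2));
}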
@@ -373,169 +365,209 @@ void IterDomainGraph::build(Fusion* fusion) {
       initializeId(id, is_view_rfactor_id, is_leaf_id);
     }
   }
+}
 
-  // All ID's are initialized, start connecting them on the permissive, exact,
-  // and loop dimensions.
-
-  for (auto expr : fusion->exprs()) {
-    if (!ir_utils::isTvOp(expr)) {
-      continue;
-    }
+void IterDomainGraph::mapMultiOutput(Expr* expr) {
+  auto tv_outputs = ir_utils::filterByType<TensorView>(expr->outputs());
+  if (std::distance(tv_outputs.begin(), tv_outputs.end()) <= 1) {
+    // No multi TV outputs to map, just return
+    return;
+  }
 
-    auto tv_outputs = ir_utils::filterByType<TensorView>(expr->outputs());
-    TensorView* first_output_tv = nullptr;
-
-    for (auto c_tv : tv_outputs) {
-      if (first_output_tv == nullptr) {
-        first_output_tv = c_tv;
-      } else {
-        // Map multi outputs of an expression to each other. c is current
-        // output, and f as first output. Keep consistent with the later section
-        // of producer and consumers. Which here producer is now "first output",
-        // and consumer is still consumer. One exception is how the
-        // domains left of CA positions are handled in the Parallel
-        // map. Those domains are not mapped in producer and consumer
-        // mappings as they do not share loops, but are mapped in the
-        // case of mapping multiple outputs since they do share the
-        // same loops.
+  TensorView* first_output_tv = *tv_outputs.begin();
+  std::deque<TensorView*> other_tv_outputs(
+      tv_outputs.begin(), tv_outputs.end());
+  other_tv_outputs.pop_front();
+
+  for (auto other_tv_output : other_tv_outputs) {
+    // Map the outputs of a multi-output expression to each other, treating
+    // the first output ("f") as the reference and every other output ("o")
+    // as a consumer. This stays consistent with the producer/consumer
+    // section below, where the first output plays the producer role. One
+    // exception is how the domains left of CA positions are handled in the
+    // Parallel map. Those domains are not mapped in producer and consumer
+    // mappings as they do not share loops, but they are mapped when mapping
+    // multiple outputs since those do share the same loops.
 
-        TORCH_INTERNAL_ASSERT(
-            c_tv->getRootDomain().size() ==
-                first_output_tv->getRootDomain().size(),
-            "Multiple outputs with mismatched dimensions is not supported. ",
-            "Only supported case is welford op where all outputs tvs have idential domains.");
-        // p->f, c->c
-        std::unordered_map<IterDomain*, IterDomain*> c2f_root_map;
-        for (const auto i :
-             c10::irange(first_output_tv->getRootDomain().size())) {
-          c2f_root_map.insert(std::make_pair(
-              c_tv->getRootDomain()[i], first_output_tv->getRootDomain()[i]));
-        }
+    TORCH_INTERNAL_ASSERT(
+        other_tv_output->getRootDomain().size() ==
+            first_output_tv->getRootDomain().size(),
+        "Multiple outputs with mismatched dimensions is not supported. ",
+        "Only supported case is welford op where all output tvs have identical domains.");
+    // other to first map
+    std::unordered_map<IterDomain*, IterDomain*> o2f;
+    for (const auto i : c10::irange(first_output_tv->getRootDomain().size())) {
+      o2f.insert(std::make_pair(
+          other_tv_output->getRootDomain()[i],
+          first_output_tv->getRootDomain()[i]));
+    }
 
-        // Multi output mapping, outputs are required to have the same domain
-        // and same transformations, so they can be mapped in permissive/exact,
-        // and when within compute at position of domain()->domain() in the
-        // parallel map.
-        auto replay_FasC = BestEffortReplay(
-            first_output_tv->domain()->domain(),
-            c_tv->domain()->domain(),
-            c2f_root_map);
-
-        // Map the entire replay map between the multiple
-        // consumers
-        auto c2f_disjoint_sets = replay_FasC.getIterDomainEquivalence();
-        for (auto disjoint_set : c2f_disjoint_sets.disjointSets()) {
-          if (disjoint_set->empty()) {
-            continue;
-          }
-          auto id0 = *disjoint_set->begin();
-          for (auto id1 : disjoint_set->vector()) {
-            mapNodes(id0, id1, IdMappingMode::PERMISSIVE);
-            mapNodes(id0, id1, IdMappingMode::EXACT);
-          }
-        }
+    // Multi output mapping: outputs are required to have the same domain
+    // and same transformations, so they can be mapped in permissive/exact,
+    // and, when within the compute at position of domain()->domain(), in the
+    // parallel map.
+    auto replay_FasC = BestEffortReplay(
+        first_output_tv->domain()->domain(),
+        other_tv_output->domain()->domain(),
+        o2f);
+
+    // Map the entire replay map between the multiple
+    // consumers
+    auto c2f_disjoint_sets = replay_FasC.getIterDomainEquivalence();
+    for (auto disjoint_set : c2f_disjoint_sets.disjointSets()) {
+      if (disjoint_set->empty()) {
+        continue;
+      }
+      auto id0 = *disjoint_set->begin();
+      for (auto id1 : disjoint_set->vector()) {
+        mapNodes(id0, id1, IdMappingMode::PERMISSIVE);
+        mapNodes(id0, id1, IdMappingMode::EXACT);
+      }
+    }
 
-        // Map all entries for the Loop map as they share the same loops.
-        for (auto f_id : first_output_tv->domain()->domain()) {
-          auto disjoint_set = c2f_disjoint_sets.getDisjointSetOf(f_id);
-          auto id0 = *(disjoint_set.begin());
-          for (auto id1 : disjoint_set) {
-            mapNodes(id0, id1, IdMappingMode::LOOP);
-          }
-        }
-      }
+    // Map all entries for the Loop map as they share the same loops.
+    for (auto f_id : first_output_tv->domain()->domain()) {
+      auto disjoint_set = c2f_disjoint_sets.getDisjointSetOf(f_id);
+      auto id0 = *(disjoint_set.begin());
+      for (auto id1 : disjoint_set) {
+        mapNodes(id0, id1, IdMappingMode::LOOP);
+      }
+    }
+  }
+}
 
-    auto tv_inputs = ir_utils::filterByType<TensorView>(expr->inputs());
-
-    for (auto p_tv : tv_inputs) {
-      auto pairwise_map = PairwiseRootDomainMap(p_tv, c_tv);
-
-      // Look for matching ID transformations in producer and consumer, replay
-      // producer as consumer. We use the symmetric API of BestEffortReplay so
-      // that both broadcast and squeeze are handled correctly.
-      const auto permissive_disjoint_sets =
-          BestEffortReplay::replayPasC(p_tv, c_tv, -1, pairwise_map)
-              .getIterDomainEquivalence();
-
-      // For exact mapings do not map any broadcast dimensions to
-      // non-broadcast dimensions. Prevent any broadcasted axes being mapped
-      // to non-broadcasted axes.
-      auto exact_c2p_root_map =
-          PairwiseRootDomainMap(p_tv, c_tv, true)
-              .mapConsumerToProducer(c_tv->domain(), p_tv->domain());
-
-      // Same as permissive above but for exact
-      auto exact_replay_PasC = BestEffortReplay(
-          p_tv->domain()->domain(),
-          c_tv->domain()->domain(),
-          exact_c2p_root_map);
-
-      const auto& exact_c2p_map = exact_replay_PasC.getReplay();
-
-      for (auto c_id : getSortedKeys(exact_c2p_map, Statement::lessThan)) {
-        auto p_id = exact_c2p_map.at(c_id);
-        mapNodes(c_id, p_id, IdMappingMode::EXACT);
-        consumers_.at(p_id).pushBack(c_id);
-        producers_.at(c_id).pushBack(p_id);
-
-        // Add the swizzle inputs to the same
-        // disjoint set as well if either c_id
-        // or p_id is swizzle output.
-        mapMaybeSwizzleOp(nodes(IdMappingMode::EXACT), p_id);
-        mapMaybeSwizzleOp(nodes(IdMappingMode::EXACT), c_id);
-      }
+namespace {
+//! Map corresponding inputs and outputs of swizzle op together
+//! on the given disjoint set, if the given id is an output
+//! of a swizzle operator.
+//!
+//! The current usage of swizzle operator is local to each tensor
+//! itself, so they should not affect exact or permissive mapping
+//! between iterdomains on different tensor domains.
+//! TODO:
+//!   Exact mapping based index hoisting of swizzled iterdomains
+//!   is disabled currently and will be re-enabled in the next
+//!   few build out steps.
+void mapMaybeSwizzleOp(
+    DisjointSets<IterDomain*>& disjoint_sets,
+    IterDomain* id) {
+  if (auto swizzle_2d = dynamic_cast<Swizzle2D*>(id->definition())) {
+    // Map each input to its corresponding output on the given
+    // disjoint set if this is a loop swizzle. Loop swizzles don't impact
+    // indexing, only iteration order.
+    if (swizzle_2d->swizzleMode() == SwizzleMode::Loop) {
+      disjoint_sets.mapEntries(swizzle_2d->inX(), swizzle_2d->outX());
+      disjoint_sets.mapEntries(swizzle_2d->inY(), swizzle_2d->outY());
+    }
+  }
+}
+} // namespace
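A toy illustration of the guard in mapMaybeSwizzleOp() above: only Loop-mode swizzles union their inputs with their outputs, since they reorder iteration without changing indexing. Everything here (ToySwizzle2D, integer stand-ins for IterDomain*) is hypothetical sketch code, not nvfuser IR.

#include <cassert>
#include <unordered_map>

enum class Mode { Data, Loop };

struct ToySwizzle2D {
  Mode mode;
  int in_x, in_y, out_x, out_y; // stand-ins for the inX/inY/outX/outY IDs
};

struct ToySets {
  std::unordered_map<int, int> parent;
  int find(int a) {
    if (!parent.count(a)) parent[a] = a;
    while (parent[a] != a) a = parent[a];
    return a;
  }
  void mapEntries(int a, int b) { parent[find(a)] = find(b); }
};

// Same dispatch rule as mapMaybeSwizzleOp: map through Loop swizzles only.
void mapMaybeSwizzle(ToySets& sets, const ToySwizzle2D& sw) {
  if (sw.mode == Mode::Loop) {
    sets.mapEntries(sw.in_x, sw.out_x);
    sets.mapEntries(sw.in_y, sw.out_y);
  }
}

int main() {
  ToySets sets;
  ToySwizzle2D loop{Mode::Loop, 0, 1, 2, 3};
  ToySwizzle2D data{Mode::Data, 4, 5, 6, 7};
  mapMaybeSwizzle(sets, loop);
  mapMaybeSwizzle(sets, data);
  assert(sets.find(0) == sets.find(2)); // loop swizzle: inX mapped to outX
  assert(sets.find(4) != sets.find(6)); // data swizzle: left unmapped
}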
-      auto p_ids_vec = ir_utils::allIDsOf(p_tv);
-      auto c_ids_vec = ir_utils::allIDsOf(c_tv);
-      std::unordered_set<IterDomain*> p_ids(
-          p_ids_vec.begin(), p_ids_vec.end());
-      std::unordered_set<IterDomain*> c_ids(
-          c_ids_vec.begin(), c_ids_vec.end());
-
-      for (auto& dset : permissive_disjoint_sets.disjointSets()) {
-        auto& vec = dset->vector();
-        for (auto i : c10::irange(vec.size())) {
-          auto id1 = vec[i];
-          mapNodes(id1, vec[0], IdMappingMode::PERMISSIVE);
-
-          // Add the swizzle inputs to the same
-          // disjoint set as well if either c_id
-          // or p_id is swizzle output.
-          mapMaybeSwizzleOp(nodes(IdMappingMode::PERMISSIVE), id1);
-
-          for (auto j : c10::irange(i + 1, vec.size())) {
-            auto id2 = vec[j];
-            if (p_ids.count(id1) && c_ids.count(id2)) {
-              consumers_.at(id1).pushBack(id2);
-              producers_.at(id2).pushBack(id1);
-              if (idIsAComputeAtLeafDomain(id1, p_tv, c_tv) &&
-                  idIsALeafDomain(id2, c_tv)) {
-                mapNodes(id1, id2, IdMappingMode::LOOP);
-              }
-            }
-            if (c_ids.count(id1) && p_ids.count(id2)) {
-              producers_.at(id1).pushBack(id2);
-              consumers_.at(id2).pushBack(id1);
-              if (idIsAComputeAtLeafDomain(id2, p_tv, c_tv) &&
-                  idIsALeafDomain(id1, c_tv)) {
-                mapNodes(id1, id2, IdMappingMode::LOOP);
-              }
-            }
-          }
-        }
-      }
-    }
-  }
 
+void IterDomainGraph::mapExact(Expr* expr) {
+  TensorView* c_tv = ir_utils::getTvOutput(expr);
+
+  auto tv_inputs = ir_utils::filterByType<TensorView>(expr->inputs());
+  for (auto p_tv : tv_inputs) {
+    // For exact mappings do not map any broadcast dimensions to
+    // non-broadcast dimensions; prevent any broadcast axes from being
+    // mapped to non-broadcast axes.
+    auto exact_c2p_root_map =
+        PairwiseRootDomainMap(p_tv, c_tv, true)
+            .mapConsumerToProducer(c_tv->domain(), p_tv->domain());
+
+    // Same replay as in the permissive map below, but restricted to the
+    // exact root mapping
+    auto exact_replay_PasC = BestEffortReplay(
+        p_tv->domain()->domain(), c_tv->domain()->domain(), exact_c2p_root_map);
+
+    const auto& exact_c2p_map = exact_replay_PasC.getReplay();
+
+    for (auto c_id : getSortedKeys(exact_c2p_map, Statement::lessThan)) {
+      auto p_id = exact_c2p_map.at(c_id);
+      mapNodes(c_id, p_id, IdMappingMode::EXACT);
+
+      // TODO: consumers/producers should be on a per map basis; mapping
+      // should include the unique expr between the disjoint sets
+      consumers_.at(p_id).pushBack(c_id);
+      producers_.at(c_id).pushBack(p_id);
+
+      // Add the swizzle inputs to the same
+      // disjoint set as well if either c_id
+      // or p_id is a swizzle output.
+      mapMaybeSwizzleOp(nodes(IdMappingMode::EXACT), p_id);
+      mapMaybeSwizzleOp(nodes(IdMappingMode::EXACT), c_id);
+    }
+  }
+}
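The `true` flag passed to PairwiseRootDomainMap in mapExact() is what keeps broadcast axes out of the exact map. Below is a toy version of just that restriction, with plain structs standing in for root IterDomains; the real PairwiseRootDomainMap also handles reductions, rfactor, and non-positional mappings, none of which this sketch attempts.

#include <cassert>
#include <cstddef>
#include <unordered_map>
#include <vector>

struct ToyId {
  int name;
  bool is_broadcast;
};

// Pair consumer root axes with producer root axes positionally, refusing to
// pair a broadcast axis with a non-broadcast axis (the "exact" restriction).
std::unordered_map<int, int> exactC2PRootMap(
    const std::vector<ToyId>& producer_root,
    const std::vector<ToyId>& consumer_root) {
  std::unordered_map<int, int> c2p;
  for (std::size_t i = 0; i < consumer_root.size(); ++i) {
    if (consumer_root[i].is_broadcast != producer_root[i].is_broadcast) {
      continue; // the broadcast is resolved here; not an exact match
    }
    c2p.emplace(consumer_root[i].name, producer_root[i].name);
  }
  return c2p;
}

int main() {
  // producer tv1[i0, b1] (b1 broadcast), consumer tv3[i0, i1]
  std::vector<ToyId> producer{{0, false}, {1, true}};
  std::vector<ToyId> consumer{{2, false}, {3, false}};
  auto c2p = exactC2PRootMap(producer, consumer);
  assert(c2p.count(2));  // i0 pairs exactly with i0
  assert(!c2p.count(3)); // b1 -> i1 is only a permissive match
}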
+void IterDomainGraph::mapPermissiveAndLoop(Expr* expr) {
+  // Multiple outputs are already mapped, we can ignore all but the first
+  // consumer given they have to be replayed in the same exact way
+  TensorView* c_tv = ir_utils::getTvOutput(expr);
+
+  auto tv_inputs = ir_utils::filterByType<TensorView>(expr->inputs());
+
+  for (auto p_tv : tv_inputs) {
+    auto p_ids_vec = ir_utils::allIDsOf(p_tv);
+    auto c_ids_vec = ir_utils::allIDsOf(c_tv);
+    std::unordered_set<IterDomain*> p_ids(p_ids_vec.begin(), p_ids_vec.end());
+    std::unordered_set<IterDomain*> c_ids(c_ids_vec.begin(), c_ids_vec.end());
+
+    auto permissive_c2p_root_map = PairwiseRootDomainMap(p_tv, c_tv);
+
+    // Look for matching ID transformations in producer and consumer, replay
+    // producer as consumer. We use the symmetric API of BestEffortReplay so
+    // that both broadcast and squeeze are handled correctly.
+    const auto permissive_disjoint_sets =
+        BestEffortReplay::replayPasC(p_tv, c_tv, -1, permissive_c2p_root_map)
+            .getIterDomainEquivalence();
+
+    for (auto& dset : permissive_disjoint_sets.disjointSets()) {
+      auto& vec = dset->vector();
+      for (auto i : c10::irange(vec.size())) {
+        auto id1 = vec[i];
+        mapNodes(id1, vec[0], IdMappingMode::PERMISSIVE);
+
+        // Add the swizzle inputs to the same
+        // disjoint set as well if either c_id
+        // or p_id is a swizzle output.
+        mapMaybeSwizzleOp(nodes(IdMappingMode::PERMISSIVE), id1);
+
+        // Loop/producer/consumer
+        for (auto j : c10::irange(i + 1, vec.size())) {
+          auto id2 = vec[j];
+          if (p_ids.count(id1) && c_ids.count(id2)) {
+            consumers_.at(id1).pushBack(id2);
+            producers_.at(id2).pushBack(id1);
+            if (idIsAComputeAtLeafDomain(id1, p_tv, c_tv) &&
+                idIsALeafDomain(id2, c_tv)) {
+              mapNodes(id1, id2, IdMappingMode::LOOP);
+            }
+          }
+          if (c_ids.count(id1) && p_ids.count(id2)) {
+            producers_.at(id1).pushBack(id2);
+            consumers_.at(id2).pushBack(id1);
+            if (idIsAComputeAtLeafDomain(id2, p_tv, c_tv) &&
+                idIsALeafDomain(id1, c_tv)) {
+              mapNodes(id1, id2, IdMappingMode::LOOP);
+            }
+          }
+        }
+      }
+    }
+  }
+}
 
+void IterDomainGraph::mapRFactorExprs(Fusion* fusion) {
   // Explicitly map through rfactor transformations, if we have an op like:
   //
   // T1[x, y*z] = view(T0[x*y, z])
   // T3[x, y*z] = view(T2[x*y, z])
   // T4 = T0 + T2
   //
-  // We want to map T1 and T3's rfactor transformations together by playing the
-  // transformations forward since their root domains map. If instead we have:
+  // We want to map T1 and T3's rfactor transformations together by playing
+  // the transformations forward since their root domains map. If instead we
+  // have:
   //
   // T1[x, y*z] = view(T0[x*y, z])
   // T3[x, y*z] = view(T2[x*y, z])
@@ -546,10 +578,10 @@ void IterDomainGraph::build(Fusion* fusion) {
   // rfactor transformations starting at their rfactor domains.
   //
   // Therefore we'll explicitly map rfactor transformation iteration domains
-  // forward and backwards. Something similar could happen with rfactor of root
-  // domains, though it seems mapping rfactor reduction domains aren't that
-  // important. Mapping view transformations is more important since view is
-  // part of the compute definition so having the map through the
+  // forward and backwards. Something similar could happen with rfactor of
+  // root domains, though it seems mapping rfactor reduction domains isn't
+  // that important. Mapping view transformations is more important since view
+  // is part of the compute definition so having the map through the
   // transformations makes it easy to check if different view operations are
   // consistent with eachother.
 
@@ -563,10 +595,10 @@ void IterDomainGraph::build(Fusion* fusion) {
 
   // IterDomains could have multiple uses defined in the fusion if multiple
   // transformations were redefined (more than one transform propagation pass
-  // was run and retransformed sections of the graph). We're going to make a new
-  // uses map so we can easily process the actual uses of IterDomains. We
-  // actually only need rfactor uses for this section of mapping, so we'll limit
-  // this map to only rfactor transformations.
+  // was run and retransformed sections of the graph). We're going to make a
+  // new uses map so we can easily process the actual uses of IterDomains. We
+  // actually only need rfactor uses for this section of mapping, so we'll
+  // limit this map to only rfactor transformations.
   std::unordered_map<IterDomain*, Expr*> rfactor_id_uses;
 
   // Order of traversal is important for processing all the rfactor ids as the
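The forward/backward propagation described above reduces to one rule: two rfactor transformations map if their inputs are already exactly mapped and their parameters agree (roughly what exprsMap() checks). A toy version of the forward step for a Split, with integer IDs standing in for IterDomain* and a minimal union-find, follows; this is a sketch of the rule, not the real exprsMap()/mapNodes() machinery.

#include <cassert>
#include <unordered_map>

struct ToySets {
  std::unordered_map<int, int> parent;
  int find(int a) {
    if (!parent.count(a)) parent[a] = a;
    while (parent[a] != a) a = parent[a];
    return a;
  }
  void mapEntries(int a, int b) { parent[find(a)] = find(b); }
};

struct ToySplit {
  int in, outer, inner; // stand-ins for IterDomain*
  int factor;
};

// Forward step: if two splits consume exactly-mapped inputs with the same
// factor, their outputs become exactly mapped too.
void mapThroughSplit(ToySets& sets, const ToySplit& a, const ToySplit& b) {
  if (sets.find(a.in) == sets.find(b.in) && a.factor == b.factor) {
    sets.mapEntries(a.outer, b.outer);
    sets.mapEntries(a.inner, b.inner);
  }
}

int main() {
  ToySets sets;
  // T1 = view(T0) and T3 = view(T2) split their root IDs by the same factor.
  ToySplit t1_split{/*in=*/0, /*outer=*/1, /*inner=*/2, /*factor=*/4};
  ToySplit t3_split{/*in=*/10, /*outer=*/11, /*inner=*/12, /*factor=*/4};
  sets.mapEntries(0, 10); // root domains already map (e.g. T4 = T0 + T2)
  mapThroughSplit(sets, t1_split, t3_split);
  assert(sets.find(1) == sets.find(11)); // outer outputs now mapped
  assert(sets.find(2) == sets.find(12)); // inner outputs now mapped
}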
@@ -621,8 +653,8 @@ void IterDomainGraph::build(Fusion* fusion) {
         ? rfactor_id_order[rfactor_id_i]
         : rfactor_id_order[rfactor_id_order.size() - 1 - rfactor_id_i];
 
-    // At should be safe since we made rfactor_id_order and rfactor_id_uses at
-    // the same time so they should have the same exact entries.
+    // At should be safe since we made rfactor_id_order and rfactor_id_uses
+    // at the same time so they should have the same exact entries.
     auto first_expr = prop_forward ? rfactor_id_uses.at(first_rfactor_id)
                                    : first_rfactor_id->definition();
 
@@ -675,7 +707,9 @@ void IterDomainGraph::build(Fusion* fusion) {
       }
     }
   }
+}
 
+void IterDomainGraph::buildAlmostExactMap() {
   // Build almost exact map by forwarding through broadcast axes
   nodes(IdMappingMode::ALMOSTEXACT) = nodes(IdMappingMode::EXACT);
   std::unordered_set<IterDomain*> visited;
@@ -706,25 +740,39 @@ void IterDomainGraph::build(Fusion* fusion) {
       }
     }
   }
-
-  self_mapping_info_ = findFirstSelfMapping(fusion, *this);
 }
 
-void IterDomainGraph::initializeId(
-    IterDomain* id,
-    bool is_view_rfactor_id,
-    bool is_leaf_id) {
-  nodes(IdMappingMode::PERMISSIVE).initializeSet(id);
-  nodes(IdMappingMode::EXACT).initializeSet(id);
-  if (is_leaf_id) {
-    nodes(IdMappingMode::LOOP).initializeSet(id);
-  }
-  consumers_[id] = {};
-  producers_[id] = {};
+void IterDomainGraph::build(Fusion* fusion) {
+  FusionGuard fg(fusion);
 
-  if (is_view_rfactor_id) {
-    view_rfactor_ids_.emplace(id);
+  // Initialize the maps with all the IterDomains defined in the fusion.
+  initialIdProcessing(fusion);
+
+  for (auto expr : fusion->exprs()) {
+    if (!ir_utils::isTvOp(expr)) {
+      continue;
+    }
+
+    // Connect multi-output expressions as they're trivial to connect.
+    mapMultiOutput(expr);
+
+    // Connect IDs on the exact dimension
+    mapExact(expr);
+
+    // Connect across the permissive and loop dimensions and, for now, the
+    // consumers_ and producers_ maps.
+    mapPermissiveAndLoop(expr);
   }
+
+  // Map forward and backward through TV root<->rfactor to cross-map
+  // connections that are not explicitly defined through input<->output
+  // expression maps.
+  mapRFactorExprs(fusion);
+
+  buildAlmostExactMap();
+
+  // Debug check: make sure there's no self-mapping in TensorViews during
+  // lowering that would invalidate lowering assumptions.
+  self_mapping_info_ = findFirstSelfMapping(fusion, *this);
 }
 
 ComputeAtMap::ComputeAtMap(Fusion* fusion)
@@ -872,13 +920,13 @@ IterDomain* ComputeAtMap::computeConcreteId(
       id->toString());
 
   if (disjoint_set_shared_ptr->vector().size() == 1) {
-    // If only one entry in the disjoint set, by definition the existing ID has
-    // to be the concrete ID.
+    // If only one entry in the disjoint set, by definition the existing ID
+    // has to be the concrete ID.
     return disjoint_set_shared_ptr->vector().front();
   }
 
-  // Grab a set of candidate concrete_ids, we track towards the consumers in the
-  // ID group as one of those is guaranteed to be a valid concrete id.
+  // Grab a set of candidate concrete_ids, we track towards the consumers in
+  // the ID group as one of those is guaranteed to be a valid concrete id.
   VectorOfUniqueEntries<IterDomain*> maybe_concrete_ids;
   for (auto id : disjoint_set_shared_ptr->vector()) {
     bool id_output = true;
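The next hunk documents how computeConcreteId() narrows these candidates down. The nub of broadcast resolution is that when a broadcast ID and the ID resolving it end up loop-mapped, the loop must take its extent from the non-broadcast side. A deliberately oversimplified sketch of that selection follows; the real algorithm described below must additionally handle view/rfactor domains and chains of partially resolved broadcasts.

#include <cassert>
#include <vector>

struct ToyId {
  int extent;        // broadcast extent is 1
  bool is_broadcast;
};

// Toy selection rule: prefer any non-broadcast member of a loop-mapped set,
// since its extent is the one the generated for-loop must use.
const ToyId* pickConcrete(const std::vector<ToyId>& loop_set) {
  const ToyId* concrete = &loop_set.front();
  for (const auto& id : loop_set) {
    if (!id.is_broadcast) {
      concrete = &id;
    }
  }
  return concrete;
}

int main() {
  // {b1 (extent 1), i1 (extent 128)} share a loop once b1 is resolved by i1.
  std::vector<ToyId> loop_set{{1, true}, {128, false}};
  assert(!pickConcrete(loop_set)->is_broadcast); // i1 wins, extent 128
}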
@@ -904,17 +952,17 @@ IterDomain* ComputeAtMap::computeConcreteId(
     return maybe_concrete_ids.vector().front();
   }
 
-  // Broadcast resolution is what we have to figure out here. So if we traverse
-  // back from leaves to rfactor inputs through the exact map, if there's an
-  // operation with a broadcast input that's resolved within the history all of
-  // the domains in all of the maybe_rfactor_ids, then the concrete ID must
-  // resolve that broadcast.
+  // Broadcast resolution is what we have to figure out here. So if we
+  // traverse back from leaves to rfactor inputs through the exact map, if
+  // there's an operation with a broadcast input that's resolved within the
+  // history of all the domains in all of the maybe_rfactor_ids, then the
+  // concrete ID must resolve that broadcast.
   //
   // (1) Compute "traversed IDs" which is every exact disjoint set starting at
   //     all maybe concrete ID's traversing back through exact map.
   //
-  // (2) Check all broadcast sets, remove from "traversed IDs" any broadcast set
-  // that has its broadcast resolved ID within "traversed IDs", and all
+  // (2) Check all broadcast sets, remove from "traversed IDs" any broadcast
+  // set that has its broadcast resolved ID within "traversed IDs", and all
   // IterDomains dependant on that broadcast.
   //
   // (3) Start at all "traversed IDs" set that has an rfactor domain, traverse
@@ -934,14 +982,14 @@ IterDomain* ComputeAtMap::computeConcreteId(
         disjointSetOf(maybe_concrete_id, IdMappingMode::EXACT));
   }
 
-  // Going to iteratively modify this to be all sets that the concrete ID needs
-  // to cover
+  // Going to iteratively modify this to be all sets that the concrete ID
+  // needs to cover
   VectorOfUniqueEntries<std::shared_ptr<VectorOfUniqueEntries<IterDomain*>>>
       all_exact_sets_covered =
           getAllDisjointSetProducers(maybe_concrete_exact_sets);
 
-  // Remove all broadcast domains that are resolved within the history of any of
-  // the maybe concrete sets.
+  // Remove all broadcast domains that are resolved within the history of any
+  // of the maybe concrete sets.
   {
     // All broadcast exact sets in all_exact_sets_covered that are resolved by
     // IterDomains in all_exact_sets_covered
@@ -985,8 +1033,8 @@ IterDomain* ComputeAtMap::computeConcreteId(
     auto all_resolved_broadcast_uses =
         getAllDisjointSetConsumers(resolved_broadcasts);
 
-    // Remove broadcast resolved sets from all_exact_sets_covered by effectively
-    // doing an inplace copy_if
+    // Remove broadcast resolved sets from all_exact_sets_covered by
+    // effectively doing an inplace copy_if
     VectorOfUniqueEntries<std::shared_ptr<VectorOfUniqueEntries<IterDomain*>>>
         tmp_all_exact_sets_covered;
     std::swap(tmp_all_exact_sets_covered, all_exact_sets_covered);
@@ -1065,8 +1113,8 @@ IterDomain* ComputeAtMap::computeConcreteId(
 
   // The concrete_id should have the most roots it can trace back to that are
   // iter domains, (non-broadcast/non-reduction). We don't trace back through
-  // view operations, so the one with the most iter root domains is the concrete
-  // ID.
+  // view operations, so the one with the most iter root domains is the
+  // concrete ID.
   IterDomain* concrete_id = nullptr;
   int max_iter_root_count = 0;
   int max_bcast_root_count = 0;
@@ -1103,8 +1151,8 @@ IterDomain* ComputeAtMap::computeConcreteId(
 void ComputeAtMap::buildConcreteIds() {
   // For the exact map just select the first ID since they're all exactly the
   // same size, it doesn't matter which is selected. This should be run-to-run
-  // deterministic but which ID gets selected her depends on the traversal order
-  // generating the set (compute at map build).
+  // deterministic, but which ID gets selected here depends on the traversal
+  // order generating the set (compute at map build).
   for (const auto& disjoint_set_shared_ptr :
        id_graph_.getNodes(IdMappingMode::EXACT).disjointSets()) {
     TORCH_INTERNAL_ASSERT(
@@ -1211,8 +1259,8 @@ void ComputeAtMap::buildUniqueExactExprMaps() {
           // Definition to this exact map, shouldn't be marked as a definition
           // to traverse on the exact map.
 
-          // This is a WAR for FusionSimpleSwizzle2_CUDA wher there is a pattern
-          // like:
+          // This is a WAR for FusionSimpleSwizzle2_CUDA where there is a
+          // pattern like:
          //
          // tv0[32, 32]
          // tv0->swizzle(Swizzle2DType::ZShape, 0, 1);
@@ -1221,8 +1269,8 @@ void ComputeAtMap::buildUniqueExactExprMaps() {
          // So the pre and post swizzle ID is in an exact set, but that exact
          // set also has the swizzle as a definition that leads to itself.
          //
-          // TODO: Try to formalize this better in the exact ID traversal. Right
-          // now its just interfering with concrete ID detection.
+          // TODO: Try to formalize this better in the exact ID traversal.
+          // Right now it's just interfering with concrete ID detection.
          continue;
        }
        bool match = false;
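Several hunks above rely on getAllDisjointSetProducers() / getAllDisjointSetConsumers() to take the closure of a starting collection of exact sets. Stripped of the disjoint-set machinery, that traversal is an ordinary breadth-first walk over a definition DAG, roughly like this sketch with integer IDs:

#include <cassert>
#include <deque>
#include <set>
#include <unordered_map>
#include <vector>

int main() {
  // producers_of[x] holds the IDs consumed by the expression defining x.
  std::unordered_map<int, std::vector<int>> producers_of{
      {4, {2, 3}}, {3, {1}}, {2, {0}}};

  // Start from the candidate set {4} and walk definitions back to inputs,
  // collecting every set visited (the "traversed IDs" of step (1) above).
  std::deque<int> to_visit{4};
  std::set<int> covered;
  while (!to_visit.empty()) {
    int id = to_visit.front();
    to_visit.pop_front();
    if (!covered.insert(id).second) {
      continue; // already visited
    }
    auto it = producers_of.find(id);
    if (it != producers_of.end()) {
      for (int producer : it->second) {
        to_visit.push_back(producer);
      }
    }
  }
  assert(covered == std::set<int>({0, 1, 2, 3, 4}));
}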
diff --git a/torch/csrc/jit/codegen/cuda/compute_at_map.h b/torch/csrc/jit/codegen/cuda/compute_at_map.h
index bb6abd0c21b8f..6c8051d6993c8 100644
--- a/torch/csrc/jit/codegen/cuda/compute_at_map.h
+++ b/torch/csrc/jit/codegen/cuda/compute_at_map.h
@@ -100,6 +100,45 @@ class TORCH_CUDA_CU_API IterDomainGraph {
  private:
   void build(Fusion* fusion);
 
+  // ======= START Iteration domain build process in order called =======
+
+  // Initializes entries for the provided IterDomain in the overall
+  // IterDomainGraph
+  void initializeId(IterDomain* id, bool is_view_rfactor_id, bool is_leaf_id);
+
+  // Iterates over all IterDomains in allTvs(fusion), computes
+  // is_view_rfactor_id and is_leaf_id, and calls initializeId.
+  void initialIdProcessing(Fusion* fusion);
+
+  // Maps sibling TensorViews that are outputs of expr. TensorView outputs
+  // must be replayed the same as each other, so mapping them is very
+  // straightforward.
+  void mapMultiOutput(Expr* expr);
+
+  // Fills nodes_[IdMappingMode::EXACT] with relationships between the inputs
+  // and first output of expr
+  void mapExact(Expr* expr);
+
+  // Fills nodes_[IdMappingMode::PERMISSIVE] with relationships between the
+  // inputs and first output of expr
+  //
+  // Currently also fills nodes_[IdMappingMode::LOOP], consumers_, and
+  // producers_
+  void mapPermissiveAndLoop(Expr* expr);
+
+  // Propagates forward then backward through all view-like rfactor
+  // transformations to map IDs across view operations.
+  //
+  // TODO: This should be refactored to just process all IterDomain
+  // expressions between every TV's root and rfactor domain. View is the only
+  // place where this has a significant perf implication, but there's no
+  // reason we can't do it for all such transformations.
+  void mapRFactorExprs(Fusion* fusion);
+
+  // Initialize AlmostExact as Exact entries, then map anything that's either
+  // merged with a size-1 or split by a size-1 dimension.
+  void buildAlmostExactMap();
+
+  // ======= END Iteration domain build process in order called =======
+
   // Non-const internal only version of getNodes.
   DisjointSets<IterDomain*>& nodes(IdMappingMode mode);
 
@@ -108,8 +147,6 @@ class TORCH_CUDA_CU_API IterDomainGraph {
     nodes(mode).mapEntries(id0, id1);
   }
 
-  void initializeId(IterDomain* id, bool is_view_rfactor_id, bool is_leaf_id);
-
   // Checks if expr's are considered "the same" where sameness inputs and
   // outputs in the same position across expressions map with provided
   // MappingMode. If the expressions are determined the same then
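To close, a usage-level sketch of how lowering code consumes the maps this PR reorganizes. It is written against the API visible in the two files above (ComputeAtMap construction over a Fusion, plus this header's public query surface); getConcreteMappedID() is assumed to be the public accessor that buildConcreteIds() feeds, and the fusion itself is assumed to be defined elsewhere.

#include <torch/csrc/jit/codegen/cuda/compute_at_map.h>
#include <torch/csrc/jit/codegen/cuda/ir_utils.h>

using namespace torch::jit::fuser::cuda;

// For each leaf domain, ask the loop map which ID actually names the loop.
void inspectLoopStructure(Fusion* fusion) {
  // Construction builds the underlying IterDomainGraph and the concrete IDs.
  ComputeAtMap ca_map(fusion);
  for (auto tv : ir_utils::allTvs(fusion)) {
    for (auto leaf_id : tv->domain()->domain()) {
      // The representative whose extent the generated for-loop will use:
      IterDomain* concrete =
          ca_map.getConcreteMappedID(leaf_id, IdMappingMode::LOOP);
      (void)concrete;
    }
  }
}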