diff --git a/torch/csrc/jit/codegen/cuda/index_compute.cpp b/torch/csrc/jit/codegen/cuda/index_compute.cpp
index 2fe1b65854dddd..ae429c4104ca74 100644
--- a/torch/csrc/jit/codegen/cuda/index_compute.cpp
+++ b/torch/csrc/jit/codegen/cuda/index_compute.cpp
@@ -2813,7 +2813,6 @@ auto getPredicateReferenceIndexing(
 std::pair<Val*, Val*> getStartAndStopOffsets(
     IterDomain* consumer_id,
     TensorView* consumer_tv,
-    const ReferenceTensor& reference,
     const std::unordered_map<IterDomain*, Val*>& consumer_start_index_map,
     const std::unordered_map<IterDomain*, Val*>& consumer_stop_index_map,
     bool padding_predicate,
@@ -2979,12 +2978,12 @@ std::pair<Val*, Val*> hoistPredicates(
     Val* start_index,
     Val* stop_index,
     const std::vector<kir::ForLoop*>& loops,
+    std::vector<IterDomain*> loop_domains,
+    const std::unordered_map<IterDomain*, Val*> start_initial_loop_index_map,
+    const std::unordered_map<IterDomain*, Val*> stop_initial_loop_index_map,
     kir::ForLoop* unswitch_or_vec_loop,
     IterDomain* predicated_consumer_id,
-    TensorView* predicated_consumer_tv,
-    TensorDomain* ref_td,
-    const std::unordered_map<IterDomain*, Val*>& ref_start_index_map,
-    const std::unordered_map<IterDomain*, Val*>& ref_stop_index_map) {
+    TensorView* predicated_consumer_tv) {
   const std::pair<Val*, Val*> same_indices{start_index, stop_index};

   if (isDisabled(DisableOption::IndexHoist)) {
@@ -3004,8 +3003,8 @@ std::pair<Val*, Val*> hoistPredicates(
     GpuLower::current()->commonIndexMap().insert(
         predicated_consumer_id,
         predicated_consumer_tv->domain(),
-        ref_td,
-        ref_stop_index_map,
+        loop_domains,
+        stop_initial_loop_index_map,
         loops,
         stop_index);
   }
@@ -3021,8 +3020,8 @@ std::pair<Val*, Val*> hoistPredicates(
     GpuLower::current()->commonIndexMap().insert(
         predicated_consumer_id,
         predicated_consumer_tv->domain(),
-        ref_td,
-        ref_start_index_map,
+        loop_domains,
+        start_initial_loop_index_map,
         loops,
         start_index);
   }
@@ -3033,12 +3032,11 @@ std::pair<Val*, Val*> hoistPredicates(
 } // namespace

 // Returns predicates and the concrete (by loop map) root domains they cover
-std::pair<std::vector<RootPredicateInfo>, ReferenceTensor> Index::
-    getReferenceRootPredicates(
-        TensorView* consumer_tv,
-        const std::vector<kir::ForLoop*>& loops,
-        kir::ForLoop* unswitch_or_vec_loop,
-        bool shift_padding) {
+std::vector<RootPredicateInfo> Index::getReferenceRootPredicates(
+    TensorView* consumer_tv,
+    const std::vector<kir::ForLoop*>& loops,
+    kir::ForLoop* unswitch_or_vec_loop,
+    bool shift_padding) {
   FUSER_PERF_SCOPE("GpuLower::Lower::Index::getReferenceRootPredicates");

   const auto gpu_lower = GpuLower::current();
@@ -3047,22 +3045,9 @@ std::pair<std::vector<RootPredicateInfo>, ReferenceTensor> Index::
   // Nothing needs to be done when padding is not required.
   if (shift_padding && !needsPadding(consumer_tv)) {
-    return {{RootPredicateInfo::getFalseInfo()}, ReferenceTensor{}};
+    return {RootPredicateInfo::getFalseInfo()};
   }

-  // Get a reference tensor replayed as existing loop structure
-  ReferenceTensor reference =
-      IndexReferenceReplay::getReference(loops, consumer_tv);
-
-  // Generate halo information for reference.
-  updateHaloInfoForReference(reference, consumer_tv);
-
-  const auto ref_2_consumer = indexMapReferenceTo(
-      consumer_tv, gpu_lower->caMap(), reference.concrete_to_id);
-
-  const auto reference_halo_extent_map =
-      getReferenceHaloExtentMap(reference, ref_2_consumer);
-
   auto db_axis = gpu_lower->doubleBufferInfo().getDoubleBufferAxis(consumer_tv);

   // Indexing is done without considering contig merging. Actual
@@ -3073,38 +3058,27 @@ std::pair<std::vector<RootPredicateInfo>, ReferenceTensor> Index::
   ContigIDs contig_finder(
       consumer_tv->domain()->domain(),
       consumer_tv->getMaybeRFactorDomain(),
       std::vector<bool>(consumer_tv->getMaybeRFactorDomain().size(), false),
       {});

+  // Generate start and stop indexing from idgraph.
+  //
   // Both start and stop positions may need to be predicated. Indexing
   // differs when generating predicates for unswitch.
   // NOTE: If we could find-and-replace KIR nodes, we could just
   // generate one index map, clone it and replace the loop-to-index
   // mappings of unswitched loops for the start predicate.
-  auto ref_stop_indexing = getPredicateReferenceIndexing(
-      loops, reference, unswitch_or_vec_loop, db_axis, false);
-  const auto consumer_stop_indexing = ref_stop_indexing.updateIndexCompute(
-      consumer_tv->domain(),
-      ref_2_consumer,
-      contig_finder,
-      reference_halo_extent_map);
+
+  auto stop_indexing_from_idgraph = getPredicateIndexingFromIdGraph(
+      loops, consumer_tv, unswitch_or_vec_loop, db_axis, false);
+  const auto consumer_stop_indexing = stop_indexing_from_idgraph.index;
   const auto& consumer_stop_index_map = consumer_stop_indexing.indexMap();

   // If not unswitch, share the same indexing map as the stop index
   // map
-  const auto& ref_start_indexing = is_unswitch
-      ? getPredicateReferenceIndexing(
-            loops, reference, unswitch_or_vec_loop, db_axis, true)
-      : ref_stop_indexing;
-
-  std::unordered_map<IterDomain*, Val*> consumer_start_index_map;
-  if (is_unswitch) {
-    const auto consumer_start_indexing = ref_start_indexing.updateIndexCompute(
-        consumer_tv->domain(),
-        ref_2_consumer,
-        contig_finder,
-        reference_halo_extent_map);
-    consumer_start_index_map = consumer_start_indexing.indexMap();
-  } else {
-    consumer_start_index_map = consumer_stop_index_map;
-  }
+  const auto start_indexing_from_idgraph = is_unswitch
+      ? getPredicateIndexingFromIdGraph(
+            loops, consumer_tv, unswitch_or_vec_loop, db_axis, true)
+      : stop_indexing_from_idgraph;
+  const auto consumer_start_indexing = start_indexing_from_idgraph.index;
+  const auto& consumer_start_index_map = consumer_start_indexing.indexMap();

   // Get the contiguous ids we need to generate predicates for
   auto contig_id_infos =
@@ -3161,7 +3135,6 @@ std::pair<std::vector<RootPredicateInfo>, ReferenceTensor> Index::
     std::tie(info.start_offset_, info.stop_offset_) = getStartAndStopOffsets(
         contig_id,
         consumer_tv,
-        reference,
         consumer_start_index_map,
         consumer_stop_index_map,
         shift_padding,
@@ -3175,12 +3148,12 @@ std::pair<std::vector<RootPredicateInfo>, ReferenceTensor> Index::
         start_index,
         stop_index,
         loops,
+        stop_indexing_from_idgraph.resolved_loop_domains,
+        start_indexing_from_idgraph.initial_concrete_index_map,
+        stop_indexing_from_idgraph.initial_concrete_index_map,
         unswitch_or_vec_loop,
         contig_id,
-        consumer_tv,
-        reference.domain,
-        ref_start_indexing.indexMap(),
-        ref_stop_indexing.indexMap());
+        consumer_tv);

     // Build predicates for start positions as:
     //   start_index + start_offset >= 0
@@ -3217,7 +3190,7 @@ std::pair<std::vector<RootPredicateInfo>, ReferenceTensor> Index::
     pred_info_vec.emplace_back(info);
   }

-  return {pred_info_vec, reference};
+  return pred_info_vec;
 }

 bool Index::protectWithMagicZero(
diff --git a/torch/csrc/jit/codegen/cuda/index_compute.h b/torch/csrc/jit/codegen/cuda/index_compute.h
index e9386f5a53a2c2..49c124133b0d11 100644
--- a/torch/csrc/jit/codegen/cuda/index_compute.h
+++ b/torch/csrc/jit/codegen/cuda/index_compute.h
@@ -345,8 +345,7 @@ class Index {
   //! this is not a bool value as if we have an unswitch loop with a vectorized
   //! loop inside, we only want to base the "unswitch" like predicate on the
   //! vectorized loop.
-  static std::pair<std::vector<RootPredicateInfo>, ReferenceTensor>
-  getReferenceRootPredicates(
+  static std::vector<RootPredicateInfo> getReferenceRootPredicates(
       TensorView* consumer_tv,
       const std::vector<kir::ForLoop*>& loops,
       kir::ForLoop* unswitch_or_vec_loop,
diff --git a/torch/csrc/jit/codegen/cuda/lower_index_compute.cpp b/torch/csrc/jit/codegen/cuda/lower_index_compute.cpp
index ad1610c40aa8f5..7bf22f4d133021 100644
--- a/torch/csrc/jit/codegen/cuda/lower_index_compute.cpp
+++ b/torch/csrc/jit/codegen/cuda/lower_index_compute.cpp
@@ -268,6 +268,179 @@ IndexingParameters getNonGlobalInitialIndexParameters(
   return index_parameters;
 }

+//! Initial index parameters for predicates. Adjusts the loop-to-index
+//! map according to the information annotated on the loop nest.
+//!
+//! TODO:
+//! This function is mostly copy-pasted from the previous implementation
+//! at this step. Further cleanup is possible since:
+//!  1. Much of the loop-to-index adjustment will be issued from idgraph.
+//!  2. Much of the initial index logic could be shared across all
+//!     three variants.
+IndexingParameters getPredicateInitialIndexParameters(
+    const LoopIndexing& loop_indexing,
+    TensorView* consumer_tv,
+    kir::ForLoop* unswitch_or_vec_loop,
+    IterDomain* double_buffer_axis,
+    bool is_start_predicate) {
+  IndexingParameters index_parameters;
+  const auto& loops = loop_indexing.loops();
+  const auto& loop_domains = loop_indexing.loopDomains();
+
+  // This shouldn't be needed.
+  TORCH_INTERNAL_ASSERT(
+      loops.size() <= loop_domains.size(),
+      "Loop domain didn't replay all loops");
+
+  std::unordered_map<kir::ForLoop*, Val*> loop_to_ind_map;
+
+  // Fill the initial index with each for-loop's index.
+  std::transform(
+      loops.begin(),
+      loops.end(),
+      std::inserter(loop_to_ind_map, loop_to_ind_map.begin()),
+      [](kir::ForLoop* fl) { return std::make_pair(fl, fl->index()); });
+
+  // Generate the loop-to-index map for unswitched loops.
+  if (unswitch_or_vec_loop != nullptr) {
+    // Vectorized predicates are different from unswitch predicates:
+    // with unswitch, all loops within the unswitch (the outermost
+    // unswitch loop) are indexed with loop->extent - 1. With
+    // vectorized predicates, only the vectorized loop is indexed that
+    // way.
+    bool vectorized_pred =
+        unswitch_or_vec_loop->iter_domain()->getParallelType() ==
+        ParallelType::Vectorize;
+
+    bool within_unswitch = false;
+
+    for (const auto loop_i : c10::irange(loops.size())) {
+      auto loop = loops[loop_i];
+      auto loop_id = loop->iter_domain();
+      auto loop_pt = loop_id->getParallelType();
+      auto ref_id = loop_domains.at(loop_i);
+
+      if (loop == unswitch_or_vec_loop) {
+        within_unswitch = true;
+      }
+
+      if (within_unswitch) {
+        // Rely on the replayed loop domain to check broadcasting. The
+        // for-loop could be broadcasted on a constant value from an
+        // unroll split. Since the replay may convert this to an iter
+        // domain, that for-loop could be valid to generate predication
+        // from.
+
+        // Note that loop->stop() is not used below. Instead,
+        // loop->iter_domain()->extent() is used, which is uniform
+        // across the mapped domains irrespective of halo. Predicates are
+        // compared with each other to pick the most restrictive ones. The
+        // comparison is done by only using the offset, which is the
+        // term added to the index. So, the index term must be the
+        // same among all predicates, otherwise the comparison would
+        // be invalid. The effect of halo is added to the offset
+        // term. See getUnswitchStopOffset.
+
+        if (ref_id->isBroadcast()) {
+          // Ignore indexing into broadcasted dimensions.
+          continue;
+        } else if (loop_id->isThread()) {
+          // When parallelized, if the loop stop is the same as the
+          // extent of the associated IterDomain, i.e., no extra
+          // iterations for halo, predicating with the threading index
+          // is sufficient for both the start and stop
+          // predicates. That isn't the case if the loop has halo, and
+          // in that case either the minimum or the maximum value of the
+          // iteration domain needs to be used.
+          //
+          // Note: Better performance was obtained if using
+          // threadIdx in unswitch predicates was avoided. More
+          // specifically, in the Hdiff stencil example, instead of
+          // predicating with threadIdx.x for both the start and stop
+          // predicates, using zero and (blockDim.x - 1) for the start
+          // and stop predicates, respectively, resulted in less
+          // register pressure. The alternative codegen can be done by
+          // adding this to the first if condition:
+          // loop_id->isBlockDim(). This would not be a concern if the
+          // else part could be omitted, so canOmitElseClause should
+          // be used as well.
+          if (loop->stop() == loop_id->extent()) {
+            loop_to_ind_map[loop] = loop->start();
+          } else if (is_start_predicate) {
+            loop_to_ind_map[loop] = GpuLower::current()->kernel()->zeroVal();
+          } else {
+            // Note that the parallel dimension is used rather than
+            // loop->stop(). See the above comment.
+            loop_to_ind_map[loop] =
+                GpuLower::current()->parallelDimensionMap().get(loop_pt);
+          }
+        } else if (is_start_predicate) {
+          loop_to_ind_map[loop] = GpuLower::current()->kernel()->zeroVal();
+        } else {
+          // Similar to the above, loop_id->extent() is used here
+          // instead of loop->stop(). See the above comment.
+          loop_to_ind_map[loop] = SimplifyingIrBuilder::subExpr(
+              loop_id->extent(), GpuLower::current()->kernel()->oneVal());
+        }
+      }
+
+      // If a vectorized predicate, bail after the vectorized loop was found.
+      // Don't continue unswitching loops.
+      if (vectorized_pred && within_unswitch) {
+        break;
+      }
+    }
+  }
+
+  // Modify trivial loops to use the loop start value.
+  // FIXME: eventually this should all be lifted into idgraph.
+  for (const auto loop : loops) {
+    auto& idx = loop_to_ind_map.at(loop);
+    // If the loop is trivial, the loop index can only be the loop
+    // start value.
+    if (idx == loop->index() && loop->isTrivial()) {
+      idx = loop->start();
+    }
+  }
+
+  // Increment the double buffer loop index
+  if (double_buffer_axis != nullptr) {
+    auto db_loop = GpuLower::current()->doubleBufferInfo().getDoubleBufferLoop(
+        double_buffer_axis, loops, true);
+    if (db_loop != nullptr) {
+      auto loop_to_ind_map_it = loop_to_ind_map.find(db_loop);
+      TORCH_INTERNAL_ASSERT(loop_to_ind_map_it != loop_to_ind_map.end());
+      auto cur_index = loop_to_ind_map_it->second;
+      // If cur_index is not the same as the index of db_loop, it must
+      // be that the index has been modified to support unswitch. In
+      // that case, it is not necessary to move ahead the
+      // index for double buffering.
+      if (cur_index == db_loop->index()) {
+        loop_to_ind_map[db_loop] = SimplifyingIrBuilder::addExpr(
+            cur_index, GpuLower::current()->kernel()->oneVal());
+      }
+    }
+  }
+
+  // Convert the loop-to-index map to a concrete-to-index map
+  for (int loop_idx : c10::irange(loops.size())) {
+    auto loop = loops.at(loop_idx);
+    auto concrete_loop_domain =
+        ir_utils::caMapExactConcreteId(loop_domains.at(loop_idx));
+    index_parameters.initial_concrete_id_index[concrete_loop_domain] =
+        loop_to_ind_map.at(loop);
+  }
+
+  insertMagicZero(
+      loops,
+      loop_indexing.loopDomains(),
+      index_parameters.initial_concrete_id_index);
+
+  // Derive the halo extents from the loop indexing result.
+  index_parameters.concrete_id_to_halo_extent =
+      GpuLower::current()->haloInfo().buildConcreteHaloExtentMap(loop_indexing);
+
+  return index_parameters;
+}
+
 } // namespace

 class LoopIndexingAnalysis {
@@ -715,6 +888,74 @@ IndexFromIdGraph getTensorIndexFromIdGraph(
       loop_indexing.loopDomains());
 }

+IndexFromIdGraph getPredicateIndexingFromIdGraph(
+    const std::vector<kir::ForLoop*>& loops,
+    TensorView* consumer_tv,
+    kir::ForLoop* unswitch_or_vec_loop,
+    IterDomain* double_buffer_axis,
+    bool is_start_predicate) {
+  // Run the replay pass on the loop nest to generate deterministic
+  // traversal info from the loop structure.
+  auto loop_indexing =
+      LoopIndexingAnalysis::fromLoopAndConsumer(loops, consumer_tv);
+
+  // Bind initial index variables to the loop nodes and adjust
+  // according to loop and unswitch info.
+  auto index_parameters = getPredicateInitialIndexParameters(
+      loop_indexing,
+      consumer_tv,
+      unswitch_or_vec_loop,
+      double_buffer_axis,
+      is_start_predicate);
+
+  // Run the first backward traversal to generate
+  // loop-nest-based indexing math.
+  IndexCompute indexing(
+      index_parameters.initial_concrete_id_index,
+      index_parameters.zero_domains,
+      index_parameters.preferred_concrete_ids,
+      index_parameters.concrete_id_to_halo_extent);
+
+  indexing.run(loop_indexing);
+
+  // Map the concrete id indexing back to the consumer tv
+  std::unordered_map<IterDomain*, IterDomain*> index_update_map;
+
+  // First collect all IterDomains in the consumer transform history.
+  auto all_consumer_vals = DependencyCheck::getAllValsBetween(
+      {consumer_tv->getMaybeRFactorDomain().begin(),
+       consumer_tv->getMaybeRFactorDomain().end()},
+      {consumer_tv->domain()->domain().begin(),
+       consumer_tv->domain()->domain().end()});
+
+  for (IterDomain* consumer_id :
+       ir_utils::filterByType<IterDomain>(all_consumer_vals)) {
+    // Track the non-concrete id we were trying to bind the index
+    // to, whether from producer or consumer.
+    auto exact_concrete_id = ir_utils::caMapExactConcreteId(consumer_id);
+    index_update_map[exact_concrete_id] = consumer_id;
+  }
+
+  // No contiguity info is used in the predicate indexing pass;
+  // the predicate generation logic that uses the index math
+  // generated here will take contiguity into account.
+  ContigIDs contig_finder(
+      consumer_tv->domain()->domain(),
+      consumer_tv->getMaybeRFactorDomain(),
+      std::vector<bool>(consumer_tv->getMaybeRFactorDomain().size(), false),
+      {});
+
+  // Run the second backward traversal to map back to the consumer_tv
+  auto target_indexing = indexing.updateIndexCompute(
+      consumer_tv->domain(), index_update_map, contig_finder);
+
+  return IndexFromIdGraph(
+      target_indexing,
+      indexing,
+      index_parameters.initial_concrete_id_index,
+      loop_indexing.loopDomains());
+}
+
 namespace {

 class LoopIndexingTraversal {
diff --git a/torch/csrc/jit/codegen/cuda/lower_index_compute.h b/torch/csrc/jit/codegen/cuda/lower_index_compute.h
index a10931e925964d..9217a64976b5af 100644
--- a/torch/csrc/jit/codegen/cuda/lower_index_compute.h
+++ b/torch/csrc/jit/codegen/cuda/lower_index_compute.h
@@ -36,6 +36,17 @@ IndexFromIdGraph getTensorIndexFromIdGraph(
     bool is_global = true,
     std::unordered_map<IterDomain*, IterDomain*> c2p_map = {});

+//! Indexing interface for calculating the predicate index. Returns an
+//! IndexFromIdGraph, whose IndexCompute object can be queried directly
+//! for the produced indexing. If is_start_predicate is true, produces
+//! indexing math for the start predicates.
+IndexFromIdGraph getPredicateIndexingFromIdGraph(
+    const std::vector<kir::ForLoop*>& loops,
+    TensorView* consumer_tv,
+    kir::ForLoop* unswitch_or_vec_loop,
+    IterDomain* double_buffer_axis,
+    bool is_start_predicate);
+
 //! getTensorIndexFromIdGraph is the function that index_compute will call very
 //! straightforwardly. However, for implementing the new indexing logic that
 //! starts to abstract some of the indexing away from index_compute we need to
diff --git a/torch/csrc/jit/codegen/cuda/predicate_compute.cpp b/torch/csrc/jit/codegen/cuda/predicate_compute.cpp
index 965677154c5e20..2941b96fdae104 100644
--- a/torch/csrc/jit/codegen/cuda/predicate_compute.cpp
+++ b/torch/csrc/jit/codegen/cuda/predicate_compute.cpp
@@ -354,10 +354,8 @@ Bool* PredicateCompute::getInlinePredicate(
         ->as<Bool>();
   }

-  auto pred_info_vec =
-      Index::getReferenceRootPredicates(
-          out_tv, loops, nullptr, pred_type == PredicateType::Padding)
-          .first;
+  auto pred_info_vec = Index::getReferenceRootPredicates(
+      out_tv, loops, nullptr, pred_type == PredicateType::Padding);

   std::vector<Bool*> preds;

@@ -466,7 +464,7 @@ void UnswitchPredicate::predicateOn(Expr* tv_expr) {
   // temporarily placed in the predicated_keys map and the final
   // predicates are generated in the finalize function.

-  for (const auto& pred_info : ref_pred_info.first) {
+  for (const auto& pred_info : ref_pred_info) {
     TORCH_INTERNAL_ASSERT(pred_info.startPredicate() != nullptr);
     TORCH_INTERNAL_ASSERT(pred_info.stopPredicate() != nullptr);
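
Reviewer note, not part of the patch: the caller-facing shape of the change can be summarized with the condensed sketch below. It is pieced together from the hunks above; the functions and fields (getPredicateIndexingFromIdGraph, IndexFromIdGraph::index, resolved_loop_domains, initial_concrete_index_map, Index::getReferenceRootPredicates) all appear in this diff, and the local variables (loops, consumer_tv, unswitch_or_vec_loop, db_axis, is_unswitch, out_tv) are assumed to be in scope as they are at the actual call sites.

    // Predicate index maps now come from the loop-nest idgraph replay
    // instead of a replayed ReferenceTensor.
    auto stop_indexing_from_idgraph = getPredicateIndexingFromIdGraph(
        loops, consumer_tv, unswitch_or_vec_loop, db_axis,
        /*is_start_predicate=*/false);

    // Only unswitch needs a distinct start map; otherwise the stop map
    // is shared.
    const auto start_indexing_from_idgraph = is_unswitch
        ? getPredicateIndexingFromIdGraph(
              loops, consumer_tv, unswitch_or_vec_loop, db_axis,
              /*is_start_predicate=*/true)
        : stop_indexing_from_idgraph;

    // IndexFromIdGraph bundles the IndexCompute result with the initial
    // concrete index map and resolved loop domains, which is what
    // hoistPredicates now consumes in place of the old ReferenceTensor
    // fields (ref_td, ref_start_index_map, ref_stop_index_map).
    const auto& stop_index_map = stop_indexing_from_idgraph.index.indexMap();

    // Call sites of getReferenceRootPredicates drop the ReferenceTensor
    // half of the old pair return value:
    auto pred_info_vec = Index::getReferenceRootPredicates(
        out_tv, loops, nullptr, /*shift_padding=*/false);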