@@ -1460,7 +1460,8 @@ Val* hoistProducerIndex(
14601460std::vector<Val*> Index::getGlobalProducerStridedIndices (
14611461 TensorView* producer_tv,
14621462 const TensorView* consumer_tv,
1463- const std::vector<kir::ForLoop*>& loops) {
1463+ const std::vector<kir::ForLoop*>& loops,
1464+ const std::unordered_map<IterDomain*, Val*>& override_index) {
14641465 FUSER_PERF_SCOPE (" GpuLower::Lower::getGlobalProducerIndex" );
14651466
14661467 // Replay producer to look like consumer so we can index on producer since
@@ -1536,13 +1537,6 @@ std::vector<Val*> Index::getGlobalProducerStridedIndices(
15361537 }
15371538 }
15381539
1539- IterDomain* selected_id = nullptr ;
1540- Val* selected_index = nullptr ;
1541- if (auto sop = dynamic_cast <SelectOp*>(consumer_tv->definition ())) {
1542- selected_id = TensorDomain::noReductions (root_dom)[sop->getDim ()];
1543- selected_index = sop->input (1 );
1544- }
1545-
15461540 TORCH_INTERNAL_ASSERT (
15471541 root_dom.size () == producer_tv->domain ()->contiguity ().size ());
15481542 Val* cur_contig_stride = GpuLower::current ()->kernel ()->oneVal ();
@@ -1582,13 +1576,15 @@ std::vector<Val*> Index::getGlobalProducerStridedIndices(
15821576 }
15831577
15841578 Val* root_ind = nullptr ;
1585- if (producer_indexing.indexMap ().find (root_dom[i]) !=
1579+ auto override_it = override_index.find (root_dom[i]);
1580+ if (override_it != override_index.end ()) {
1581+ root_ind = override_it->second ;
1582+ } else if (
1583+ producer_indexing.indexMap ().find (root_dom[i]) !=
15861584 producer_indexing.indexMap ().end ()) {
15871585 root_ind = producer_indexing.indexMap ().at (root_dom[i]);
15881586 } else if (root_dom[i]->isBroadcast ()) {
15891587 root_ind = GpuLower::current ()->kernel ()->zeroVal ();
1590- } else if (root_dom[i] == selected_id) {
1591- root_ind = selected_index;
15921588 }
15931589
15941590 TORCH_INTERNAL_ASSERT (
@@ -1612,7 +1608,7 @@ std::vector<Val*> Index::getGlobalProducerStridedIndices(
16121608 loops,
16131609 root_ind);
16141610
1615- if (root_dom[i] != selected_id ) {
1611+ if (!override_index. count ( root_dom[i]) ) {
16161612 root_ind =
16171613 getProducerIndexWithHalo (producer_tv, i, root_ind, consumer_tv);
16181614 }
@@ -1686,7 +1682,8 @@ std::unordered_map<IterDomain*, IterDomain*> mapAllProducerDomainsToConsumer(
16861682std::vector<Val*> Index::getNonGlobalProducerStridedIndices (
16871683 TensorView* producer_tv,
16881684 const TensorView* consumer_tv,
1689- const std::vector<kir::ForLoop*>& loops) {
1685+ const std::vector<kir::ForLoop*>& loops,
1686+ const std::unordered_map<IterDomain*, Val*>& override_index) {
16901687 const auto gpu_lower = GpuLower::current ();
16911688
16921689 // Replay producer to look like consumer so we can index on producer since our
@@ -1794,13 +1791,6 @@ std::vector<Val*> Index::getNonGlobalProducerStridedIndices(
17941791 // and use them.
17951792 auto root_dom = producer_tv->getMaybeRFactorDomain ();
17961793
1797- IterDomain* selected_id = nullptr ;
1798- Val* selected_index = nullptr ;
1799- if (auto sop = dynamic_cast <SelectOp*>(consumer_tv->definition ())) {
1800- selected_id = TensorDomain::noReductions (root_dom)[sop->getDim ()];
1801- selected_index = sop->input (1 );
1802- }
1803-
18041794 // Figure out which root axes we don't need to index
18051795 std::unordered_set<IterDomain*> skip_indexing;
18061796
@@ -1834,9 +1824,10 @@ std::vector<Val*> Index::getNonGlobalProducerStridedIndices(
18341824 " id: " ,
18351825 root_dom[i]->toString ());
18361826
1827+ auto override_it = override_index.find (root_dom[i]);
18371828 auto root_ind_i =
1838- (selected_id == root_dom[i] ? selected_index
1839- : index_map.at (root_dom[i]));
1829+ (override_it != override_index. end () ? override_it-> second
1830+ : index_map.at (root_dom[i]));
18401831
18411832 // index hoist must be done before the adjustments for halo
18421833 root_ind_i = hoistProducerIndex (
@@ -1850,7 +1841,7 @@ std::vector<Val*> Index::getNonGlobalProducerStridedIndices(
18501841 loops,
18511842 root_ind_i);
18521843
// Only root domains whose index was NOT supplied through override_index get
// the halo adjustment; this preserves the removed `!= selected_id` check and
// agrees with the same check in getGlobalProducerStridedIndices.
1853- if (root_dom[i] != selected_id ) {
1844+ if (!override_index. count ( root_dom[i]) ) {
18541845 root_ind_i =
18551846 getProducerIndexWithHalo (producer_tv, i, root_ind_i, consumer_tv);
18561847 }
@@ -2237,7 +2228,8 @@ std::vector<Val*> Index::getNonGlobalConsumerStridedIndices(
22372228std::vector<Val*> Index::getProducerStridedIndices (
22382229 TensorView* producer,
22392230 const TensorView* consumer,
2240- const std::vector<kir::ForLoop*>& loops) {
2231+ const std::vector<kir::ForLoop*>& loops,
2232+ const std::unordered_map<IterDomain*, Val*>& override_index) {
22412233 FUSER_PERF_SCOPE (" GpuLower::Lower::Index::getProducerStridedIndices" );
22422234 if (producer->domain ()->noReductions ().size () == 0 ) {
22432235 return std::vector<Val*>(
@@ -2247,11 +2239,11 @@ std::vector<Val*> Index::getProducerStridedIndices(
22472239
22482240 std::vector<Val*> strided_indices;
22492241 if (producer->getMemoryType () == MemoryType::Global) {
2250- strided_indices =
2251- getGlobalProducerStridedIndices ( producer, consumer, loops);
2242+ strided_indices = getGlobalProducerStridedIndices (
2243+ producer, consumer, loops, override_index );
22522244 } else {
2253- strided_indices =
2254- getNonGlobalProducerStridedIndices ( producer, consumer, loops);
2245+ strided_indices = getNonGlobalProducerStridedIndices (
2246+ producer, consumer, loops, override_index );
22552247 }
22562248
22572249 TORCH_INTERNAL_ASSERT (
@@ -2267,8 +2259,10 @@ std::vector<Val*> Index::getProducerStridedIndices(
// Builds the kir::TensorIndex used to read `producer`.
// The strided indices come from getProducerStridedIndices, to which
// `producer`, `consumer`, `loops` and `override_index` are forwarded
// unchanged; the result is created through SimplifyingIrBuilder from
// `producer` and those indices.
22672259kir::TensorIndex* Index::getProducerIndex (
22682260 TensorView* producer,
22692261 const TensorView* consumer,
2270- const std::vector<kir::ForLoop*>& loops) {
2271- auto strided_indices = getProducerStridedIndices (producer, consumer, loops);
2262+ const std::vector<kir::ForLoop*>& loops,
2263+ const std::unordered_map<IterDomain*, Val*>& override_index) {
2264+ auto strided_indices =
2265+ getProducerStridedIndices (producer, consumer, loops, override_index);
22722266 return SimplifyingIrBuilder::create<kir::TensorIndex>(
22732267 producer, strided_indices);
22742268}
0 commit comments