@@ -69,9 +69,11 @@ Val* getProducerIndexWithHalo(
6969 const TensorView* producer_tv,
7070 size_t producer_axis,
7171 Val* producer_index,
72- const TensorView* consumer_tv) {
73- const auto offset =
74- getProducerHaloOffset (producer_tv, producer_axis, consumer_tv);
72+ const TensorView* consumer_tv,
73+ bool is_overriden_index) {
74+ const auto offset = is_overriden_index
75+ ? 0
76+ : getProducerHaloOffset (producer_tv, producer_axis, consumer_tv);
7577
7678 if (offset == 0 ) {
7779 return producer_index;
@@ -1460,7 +1462,8 @@ Val* hoistProducerIndex(
14601462std::vector<Val*> Index::getGlobalProducerStridedIndices (
14611463 TensorView* producer_tv,
14621464 const TensorView* consumer_tv,
1463- const std::vector<kir::ForLoop*>& loops) {
1465+ const std::vector<kir::ForLoop*>& loops,
1466+ const std::unordered_map<IterDomain*, Val*>& override_index) {
14641467 FUSER_PERF_SCOPE (" GpuLower::Lower::getGlobalProducerIndex" );
14651468
14661469 // Replay producer to look like consumer so we can index on producer since
@@ -1545,23 +1548,6 @@ std::vector<Val*> Index::getGlobalProducerStridedIndices(
15451548 continue ;
15461549 }
15471550
1548- Val* root_ind = nullptr ;
1549- if (producer_indexing.indexMap ().find (root_dom[dim]) !=
1550- producer_indexing.indexMap ().end ()) {
1551- root_ind = producer_indexing.indexMap ().at (root_dom[dim]);
1552- } else if (root_dom[dim]->isBroadcast ()) {
1553- root_ind = GpuLower::current ()->kernel ()->zeroVal ();
1554- }
1555-
1556- TORCH_INTERNAL_ASSERT (
1557- root_ind != nullptr ,
1558- " Couldn't find root mapping for " ,
1559- producer_tv->toString (),
1560- " dim: " ,
1561- dim,
1562- " id: " ,
1563- root_dom[dim]->toString ());
1564-
15651551 if (producer_tv->domain ()->contiguity ()[dim]) {
15661552 // If contig, used the stored stride which may be the previous
15671553 // dimensions stride * previous dimensions size
@@ -1591,18 +1577,27 @@ std::vector<Val*> Index::getGlobalProducerStridedIndices(
15911577 continue ;
15921578 }
15931579
1594- TORCH_INTERNAL_ASSERT (
1580+ Val* root_ind = nullptr ;
1581+ auto override_it = override_index.find (root_dom[i]);
1582+ if (override_it != override_index.end ()) {
1583+ root_ind = override_it->second ;
1584+ } else if (
15951585 producer_indexing.indexMap ().find (root_dom[i]) !=
1596- producer_indexing.indexMap ().end (),
1597- " Couldn't find root mapping for TV" ,
1598- producer_tv->name (),
1586+ producer_indexing.indexMap ().end ()) {
1587+ root_ind = producer_indexing.indexMap ().at (root_dom[i]);
1588+ } else if (root_dom[i]->isBroadcast ()) {
1589+ root_ind = GpuLower::current ()->kernel ()->zeroVal ();
1590+ }
1591+
1592+ TORCH_INTERNAL_ASSERT (
1593+ root_ind != nullptr ,
1594+ " Couldn't find root mapping for " ,
1595+ producer_tv->toString (),
15991596 " dim: " ,
16001597 i,
16011598 " id: " ,
16021599 root_dom[i]->toString ());
16031600
1604- auto root_ind = producer_indexing.indexMap ().at (root_dom[i]);
1605-
16061601 // index hoist must be done before the adjustments for halo
16071602 root_ind = hoistProducerIndex (
16081603 root_dom[i],
@@ -1615,7 +1610,12 @@ std::vector<Val*> Index::getGlobalProducerStridedIndices(
16151610 loops,
16161611 root_ind);
16171612
1618- root_ind = getProducerIndexWithHalo (producer_tv, i, root_ind, consumer_tv);
1613+ root_ind = getProducerIndexWithHalo (
1614+ producer_tv,
1615+ i,
1616+ root_ind,
1617+ consumer_tv,
1618+ override_index.count (root_dom[i]));
16191619
16201620 root_ind = getProducerIndexWithGather (
16211621 root_ind,
@@ -1686,7 +1686,8 @@ std::unordered_map<IterDomain*, IterDomain*> mapAllProducerDomainsToConsumer(
16861686std::vector<Val*> Index::getNonGlobalProducerStridedIndices (
16871687 TensorView* producer_tv,
16881688 const TensorView* consumer_tv,
1689- const std::vector<kir::ForLoop*>& loops) {
1689+ const std::vector<kir::ForLoop*>& loops,
1690+ const std::unordered_map<IterDomain*, Val*>& override_index) {
16901691 const auto gpu_lower = GpuLower::current ();
16911692
16921693 // Replay producer to look like consumer so we can index on producer since our
@@ -1827,7 +1828,10 @@ std::vector<Val*> Index::getNonGlobalProducerStridedIndices(
18271828 " id: " ,
18281829 root_dom[i]->toString ());
18291830
1830- auto root_ind_i = index_map.at (root_dom[i]);
1831+ auto override_it = override_index.find (root_dom[i]);
1832+ auto root_ind_i =
1833+ (override_it != override_index.end () ? override_it->second
1834+ : index_map.at (root_dom[i]));
18311835
18321836 // index hoist must be done before the adjustments for halo
18331837 root_ind_i = hoistProducerIndex (
@@ -1841,8 +1845,12 @@ std::vector<Val*> Index::getNonGlobalProducerStridedIndices(
18411845 loops,
18421846 root_ind_i);
18431847
1844- root_ind_i =
1845- getProducerIndexWithHalo (producer_tv, i, root_ind_i, consumer_tv);
1848+ root_ind_i = getProducerIndexWithHalo (
1849+ producer_tv,
1850+ i,
1851+ root_ind_i,
1852+ consumer_tv,
1853+ override_index.count (root_dom[i]));
18461854
18471855 root_ind_i = getProducerIndexWithGather (
18481856 root_ind_i,
@@ -2226,7 +2234,8 @@ std::vector<Val*> Index::getNonGlobalConsumerStridedIndices(
22262234std::vector<Val*> Index::getProducerStridedIndices (
22272235 TensorView* producer,
22282236 const TensorView* consumer,
2229- const std::vector<kir::ForLoop*>& loops) {
2237+ const std::vector<kir::ForLoop*>& loops,
2238+ const std::unordered_map<IterDomain*, Val*>& override_index) {
22302239 FUSER_PERF_SCOPE (" GpuLower::Lower::Index::getProducerStridedIndices" );
22312240 if (producer->domain ()->noReductions ().size () == 0 ) {
22322241 return std::vector<Val*>(
@@ -2236,11 +2245,11 @@ std::vector<Val*> Index::getProducerStridedIndices(
22362245
22372246 std::vector<Val*> strided_indices;
22382247 if (producer->getMemoryType () == MemoryType::Global) {
2239- strided_indices =
2240- getGlobalProducerStridedIndices ( producer, consumer, loops);
2248+ strided_indices = getGlobalProducerStridedIndices (
2249+ producer, consumer, loops, override_index );
22412250 } else {
2242- strided_indices =
2243- getNonGlobalProducerStridedIndices ( producer, consumer, loops);
2251+ strided_indices = getNonGlobalProducerStridedIndices (
2252+ producer, consumer, loops, override_index );
22442253 }
22452254
22462255 TORCH_INTERNAL_ASSERT (
@@ -2256,8 +2265,10 @@ std::vector<Val*> Index::getProducerStridedIndices(
22562265kir::TensorIndex* Index::getProducerIndex (
22572266 TensorView* producer,
22582267 const TensorView* consumer,
2259- const std::vector<kir::ForLoop*>& loops) {
2260- auto strided_indices = getProducerStridedIndices (producer, consumer, loops);
2268+ const std::vector<kir::ForLoop*>& loops,
2269+ const std::unordered_map<IterDomain*, Val*>& override_index) {
2270+ auto strided_indices =
2271+ getProducerStridedIndices (producer, consumer, loops, override_index);
22612272 return SimplifyingIrBuilder::create<kir::TensorIndex>(
22622273 producer, strided_indices);
22632274}
0 commit comments