-
Notifications
You must be signed in to change notification settings - Fork 7
Add support for select op
#2179
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
b8a8c29
caff3ac
ee5b0e8
2439478
041dbf1
11a56e6
87fee1d
5a54098
02b2425
3ae76f5
7ee9b0b
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -69,9 +69,11 @@ Val* getProducerIndexWithHalo( | |
| const TensorView* producer_tv, | ||
| size_t producer_axis, | ||
| Val* producer_index, | ||
| const TensorView* consumer_tv) { | ||
| const auto offset = | ||
| getProducerHaloOffset(producer_tv, producer_axis, consumer_tv); | ||
| const TensorView* consumer_tv, | ||
| bool is_overriden_index) { | ||
| const auto offset = is_overriden_index | ||
| ? 0 | ||
| : getProducerHaloOffset(producer_tv, producer_axis, consumer_tv); | ||
|
|
||
| if (offset == 0) { | ||
| return producer_index; | ||
|
|
@@ -1460,7 +1462,8 @@ Val* hoistProducerIndex( | |
| std::vector<Val*> Index::getGlobalProducerStridedIndices( | ||
| TensorView* producer_tv, | ||
| const TensorView* consumer_tv, | ||
| const std::vector<kir::ForLoop*>& loops) { | ||
| const std::vector<kir::ForLoop*>& loops, | ||
| const std::unordered_map<IterDomain*, Val*>& override_index) { | ||
| FUSER_PERF_SCOPE("GpuLower::Lower::getGlobalProducerIndex"); | ||
|
|
||
| // Replay producer to look like consumer so we can index on producer since | ||
|
|
@@ -1545,23 +1548,6 @@ std::vector<Val*> Index::getGlobalProducerStridedIndices( | |
| continue; | ||
| } | ||
|
|
||
| Val* root_ind = nullptr; | ||
| if (producer_indexing.indexMap().find(root_dom[dim]) != | ||
| producer_indexing.indexMap().end()) { | ||
| root_ind = producer_indexing.indexMap().at(root_dom[dim]); | ||
| } else if (root_dom[dim]->isBroadcast()) { | ||
| root_ind = GpuLower::current()->kernel()->zeroVal(); | ||
| } | ||
|
|
||
| TORCH_INTERNAL_ASSERT( | ||
| root_ind != nullptr, | ||
| "Couldn't find root mapping for ", | ||
| producer_tv->toString(), | ||
| " dim: ", | ||
| dim, | ||
| " id: ", | ||
| root_dom[dim]->toString()); | ||
|
|
||
zasdfgbnm marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| if (producer_tv->domain()->contiguity()[dim]) { | ||
| // If contig, used the stored stride which may be the previous | ||
| // dimensions stride * previous dimensions size | ||
|
|
@@ -1591,18 +1577,27 @@ std::vector<Val*> Index::getGlobalProducerStridedIndices( | |
| continue; | ||
| } | ||
|
|
||
| TORCH_INTERNAL_ASSERT( | ||
| Val* root_ind = nullptr; | ||
| auto override_it = override_index.find(root_dom[i]); | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I was thinking about passing the optional map to That said, I think this is good enough for now given that the whole indexing code would be redesigned. Pinging @csarofeen
Owner
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Agreed |
||
| if (override_it != override_index.end()) { | ||
| root_ind = override_it->second; | ||
| } else if ( | ||
| producer_indexing.indexMap().find(root_dom[i]) != | ||
| producer_indexing.indexMap().end(), | ||
| "Couldn't find root mapping for TV", | ||
| producer_tv->name(), | ||
| producer_indexing.indexMap().end()) { | ||
| root_ind = producer_indexing.indexMap().at(root_dom[i]); | ||
| } else if (root_dom[i]->isBroadcast()) { | ||
| root_ind = GpuLower::current()->kernel()->zeroVal(); | ||
| } | ||
|
|
||
| TORCH_INTERNAL_ASSERT( | ||
| root_ind != nullptr, | ||
| "Couldn't find root mapping for ", | ||
| producer_tv->toString(), | ||
| " dim: ", | ||
| i, | ||
| " id: ", | ||
| root_dom[i]->toString()); | ||
|
|
||
| auto root_ind = producer_indexing.indexMap().at(root_dom[i]); | ||
|
|
||
| // index hoist must be done before the adjustments for halo | ||
| root_ind = hoistProducerIndex( | ||
| root_dom[i], | ||
|
|
@@ -1615,7 +1610,12 @@ std::vector<Val*> Index::getGlobalProducerStridedIndices( | |
| loops, | ||
| root_ind); | ||
|
|
||
| root_ind = getProducerIndexWithHalo(producer_tv, i, root_ind, consumer_tv); | ||
| root_ind = getProducerIndexWithHalo( | ||
| producer_tv, | ||
| i, | ||
| root_ind, | ||
| consumer_tv, | ||
| override_index.count(root_dom[i])); | ||
|
|
||
| root_ind = getProducerIndexWithGather( | ||
| root_ind, | ||
|
|
@@ -1686,7 +1686,8 @@ std::unordered_map<IterDomain*, IterDomain*> mapAllProducerDomainsToConsumer( | |
| std::vector<Val*> Index::getNonGlobalProducerStridedIndices( | ||
| TensorView* producer_tv, | ||
| const TensorView* consumer_tv, | ||
| const std::vector<kir::ForLoop*>& loops) { | ||
| const std::vector<kir::ForLoop*>& loops, | ||
| const std::unordered_map<IterDomain*, Val*>& override_index) { | ||
| const auto gpu_lower = GpuLower::current(); | ||
|
|
||
| // Replay producer to look like consumer so we can index on producer since our | ||
|
|
@@ -1827,7 +1828,10 @@ std::vector<Val*> Index::getNonGlobalProducerStridedIndices( | |
| " id: ", | ||
| root_dom[i]->toString()); | ||
|
|
||
| auto root_ind_i = index_map.at(root_dom[i]); | ||
| auto override_it = override_index.find(root_dom[i]); | ||
| auto root_ind_i = | ||
| (override_it != override_index.end() ? override_it->second | ||
| : index_map.at(root_dom[i])); | ||
|
|
||
| // index hoist must be done before the adjustments for halo | ||
| root_ind_i = hoistProducerIndex( | ||
|
|
@@ -1841,8 +1845,12 @@ std::vector<Val*> Index::getNonGlobalProducerStridedIndices( | |
| loops, | ||
| root_ind_i); | ||
|
|
||
| root_ind_i = | ||
| getProducerIndexWithHalo(producer_tv, i, root_ind_i, consumer_tv); | ||
| root_ind_i = getProducerIndexWithHalo( | ||
| producer_tv, | ||
| i, | ||
| root_ind_i, | ||
| consumer_tv, | ||
| override_index.count(root_dom[i])); | ||
|
|
||
| root_ind_i = getProducerIndexWithGather( | ||
| root_ind_i, | ||
|
|
@@ -2226,7 +2234,8 @@ std::vector<Val*> Index::getNonGlobalConsumerStridedIndices( | |
| std::vector<Val*> Index::getProducerStridedIndices( | ||
| TensorView* producer, | ||
| const TensorView* consumer, | ||
| const std::vector<kir::ForLoop*>& loops) { | ||
| const std::vector<kir::ForLoop*>& loops, | ||
| const std::unordered_map<IterDomain*, Val*>& override_index) { | ||
| FUSER_PERF_SCOPE("GpuLower::Lower::Index::getProducerStridedIndices"); | ||
| if (producer->domain()->noReductions().size() == 0) { | ||
| return std::vector<Val*>( | ||
|
|
@@ -2236,11 +2245,11 @@ std::vector<Val*> Index::getProducerStridedIndices( | |
|
|
||
| std::vector<Val*> strided_indices; | ||
| if (producer->getMemoryType() == MemoryType::Global) { | ||
| strided_indices = | ||
| getGlobalProducerStridedIndices(producer, consumer, loops); | ||
| strided_indices = getGlobalProducerStridedIndices( | ||
| producer, consumer, loops, override_index); | ||
| } else { | ||
| strided_indices = | ||
| getNonGlobalProducerStridedIndices(producer, consumer, loops); | ||
| strided_indices = getNonGlobalProducerStridedIndices( | ||
| producer, consumer, loops, override_index); | ||
| } | ||
|
|
||
| TORCH_INTERNAL_ASSERT( | ||
|
|
@@ -2256,8 +2265,10 @@ std::vector<Val*> Index::getProducerStridedIndices( | |
| kir::TensorIndex* Index::getProducerIndex( | ||
| TensorView* producer, | ||
| const TensorView* consumer, | ||
| const std::vector<kir::ForLoop*>& loops) { | ||
| auto strided_indices = getProducerStridedIndices(producer, consumer, loops); | ||
| const std::vector<kir::ForLoop*>& loops, | ||
| const std::unordered_map<IterDomain*, Val*>& override_index) { | ||
| auto strided_indices = | ||
| getProducerStridedIndices(producer, consumer, loops, override_index); | ||
| return SimplifyingIrBuilder::create<kir::TensorIndex>( | ||
| producer, strided_indices); | ||
| } | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.