remove ref tensor in index hoisting
shmsong committed Jun 3, 2022
1 parent 5bb23e2 commit b3e636d
Showing 4 changed files with 174 additions and 16 deletions.
73 changes: 61 additions & 12 deletions torch/csrc/jit/codegen/cuda/index_compute.cpp
@@ -1115,23 +1115,22 @@ std::unordered_map<IterDomain*, IterDomain*> indexMapReferenceTo(
return index_map_ref_to_producer;
}

Val* hoistConsumerIndex(
//! Returns the IterDomain that corresponds to the
//! indexing sub-expression to hoist, or c10::nullopt
//! if the index should not be hoisted.
c10::optional<IterDomain*> getMaybeIndexedConsumerIdToHoist(
IterDomain* consumer_root_id,
const TensorView* consumer_tv,
const IndexCompute& consumer_indexing,
TensorDomain* ref_td,
const IndexCompute& ref_indexing,
const std::vector<kir::ForLoop*>& loops,
Val* index) {
// If index has no defining expression, there's nothing to hoist
if (isDisabled(DisableOption::IndexHoist) || index->definition() == nullptr) {
return index;
return c10::nullopt;
}

// The old swizzle interface, which should be deprecated, is not
// supported.
if (consumer_tv->swizzleType() != SwizzleType::NoSwizzle) {
return index;
return c10::nullopt;
}

// auto indexed_consumer_id = consumer_root_id;
@@ -1149,12 +1148,30 @@ Val* hoistConsumerIndex(
"Invalid contig index: ",
contig_id_it->second->toString());

return indexed_consumer_id;
}

Val* hoistConsumerIndex(
IterDomain* consumer_root_id,
const TensorView* consumer_tv,
const IndexCompute& consumer_indexing,
TensorDomain* ref_td,
const IndexCompute& ref_indexing,
const std::vector<kir::ForLoop*>& loops,
Val* index) {
auto maybe_hoisted_consumer_id = getMaybeIndexedConsumerIdToHoist(
consumer_root_id, consumer_tv, consumer_indexing, index);

if (!maybe_hoisted_consumer_id.has_value()) {
return index;
}

// Insert the index into the common index map. A previously inserted
// val can be returned.
auto common_index = GpuLower::current()
->commonIndexMap()
.insert(
indexed_consumer_id,
maybe_hoisted_consumer_id.value(),
consumer_tv->domain(),
ref_td,
ref_indexing.indexMap(),
@@ -1165,6 +1182,40 @@ Val* hoistConsumerIndex(
return common_index;
}

// Version of index hoisting that does not use the reference tensor.
// The reference-based overload above should eventually be removed once
// the reference tensor is completely deprecated.
Val* hoistConsumerIndex(
IterDomain* consumer_root_id,
const TensorView* consumer_tv,
const IndexCompute& consumer_indexing,
std::vector<IterDomain*> loop_domains,
const std::unordered_map<IterDomain*, Val*> initial_loop_index_map,
const std::vector<kir::ForLoop*>& loops,
Val* index) {
auto maybe_hoisted_consumer_id = getMaybeIndexedConsumerIdToHoist(
consumer_root_id, consumer_tv, consumer_indexing, index);

if (!maybe_hoisted_consumer_id.has_value()) {
return index;
}

// Insert the index into the common index map. A previously inserted
// val can be returned.
auto common_index = GpuLower::current()
->commonIndexMap()
.insert(
maybe_hoisted_consumer_id.value(),
consumer_tv->domain(),
loop_domains,
initial_loop_index_map,
loops,
index)
.first;

return common_index;
}

std::unordered_map<IterDomain*, IterDomain*> invertOneToOneMap(
const std::unordered_map<IterDomain*, IterDomain*>& map) {
std::unordered_map<IterDomain*, IterDomain*> inverted;
@@ -2008,8 +2059,8 @@ std::vector<Val*> Index::getGlobalConsumerStridedIndices(
root_dom[i],
consumer_tv,
consumer_indexing,
reference.domain,
ref_compute,
index_from_id_graph.resolved_loop_domains,
index_from_id_graph.initial_concrete_index_map,
loops,
root_ind);

@@ -2032,8 +2083,6 @@ std::vector<Val*> Index::getGlobalConsumerStridedIndices(
TORCH_INTERNAL_ASSERT(
strided_inds.size() == consumer_tv->getMaybeRFactorDomain().size());

fillConsumerVectorizedContigRootDomains(consumer_tv, contig_finder);

return strided_inds;
}

6 changes: 3 additions & 3 deletions torch/csrc/jit/codegen/cuda/index_idgraph_utils.h
@@ -26,9 +26,9 @@ struct IndexFromIdGraph {
IndexFromIdGraph getTensorIndexFromIdGraph(
const std::vector<kir::ForLoop*>& loops,
const TensorView* consumer_tv,
const TensorView* producer_tv=nullptr,
bool is_global=true,
std::unordered_map<IterDomain*, IterDomain*> c2p_map={});
const TensorView* producer_tv = nullptr,
bool is_global = true,
std::unordered_map<IterDomain*, IterDomain*> c2p_map = {});

//! A data structure keeping track of loop-nest-dependent indexing
//! math. In the current version of the indexing pass, the loop nest
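For reference, here is a minimal sketch, not part of this commit, of how the reference-free hoistConsumerIndex overload is driven from the IterDomain-graph indexing info, assuming the lowering context (loops, consumer_tv, consumer_indexing, root_dom, root_ind) of Index::getGlobalConsumerStridedIndices shown above; variable names other than the struct fields used in the diff are illustrative.

  // Build loop-domain and initial-index information without a reference
  // tensor; producer_tv, is_global, and c2p_map keep their declared defaults.
  IndexFromIdGraph index_from_id_graph =
      getTensorIndexFromIdGraph(loops, consumer_tv);

  // Hoist a consumer root index using the resolved loop domains and the
  // initial concrete index map in place of reference.domain / ref_compute.
  Val* hoisted_root_ind = hoistConsumerIndex(
      root_dom[i],
      consumer_tv,
      consumer_indexing,
      index_from_id_graph.resolved_loop_domains,
      index_from_id_graph.initial_concrete_index_map,
      loops,
      root_ind);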
80 changes: 79 additions & 1 deletion torch/csrc/jit/codegen/cuda/lower_index_hoist.cpp
@@ -94,6 +94,62 @@ CommonIndexKey::CommonIndexKey(
loops.size());
}

CommonIndexKey::CommonIndexKey(
IterDomain* consumer_indexed_id,
TensorDomain* consumer_td,
std::vector<IterDomain*> loop_domains,
const std::unordered_map<IterDomain*, Val*>& ref_index_map,
const std::vector<kir::ForLoop*>& loops) {
auto gpu_lower = GpuLower::current();

concrete_indexed_id_ = gpu_lower->caMap()->getConcreteMappedID(
consumer_indexed_id, IdMappingMode::EXACT);

const auto consumer_leaf_ids =
getUsedLeafIds(consumer_indexed_id, consumer_td);

// Convert to Parallel concrete IDs to find matching loops.
std::unordered_set<IterDomain*> concrete_leaf_ids;
for (auto& id : consumer_leaf_ids) {
concrete_leaf_ids.insert(
gpu_lower->caMap()->getConcreteMappedID(id, IdMappingMode::LOOP));
}

// Find used loops and their index vals
for (const auto i : c10::irange(loops.size())) {
auto loop = loops.at(i);
auto loop_id = gpu_lower->caMap()->getConcreteMappedID(
loop->iter_domain(), IdMappingMode::LOOP);
auto it = concrete_leaf_ids.find(loop_id);
if (it != concrete_leaf_ids.end()) {
// This loop is used for indexing the consumer id
used_loops_.push_back(loop);
auto index_it =
ref_index_map.find(gpu_lower->caMap()->getConcreteMappedID(
loop_domains.at(i), IdMappingMode::EXACT));
TORCH_INTERNAL_ASSERT(
index_it != ref_index_map.end(),
"Index not found for leaf ID, ",
loop_domains.at(i)->toString());
loop_index_vals_.push_back(index_it->second);
}
}

TORCH_INTERNAL_ASSERT(
!used_loops_.empty(),
"No loop used for indexing found. ",
consumer_indexed_id->toString());

TORCH_INTERNAL_ASSERT(
consumer_leaf_ids.size() == used_loops_.size(),
"consumer_leaf_ids.size() = ",
consumer_leaf_ids.size(),
", used_loops_.size() == ",
used_loops_.size(),
", loops.size() == ",
loops.size());
}

bool CommonIndexKey::operator==(const CommonIndexKey& other) const {
auto gpu_lower = GpuLower::current();

@@ -179,7 +235,30 @@ std::pair<Val*, bool> CommonIndexMap::insert(

const CommonIndexKey key(
indexed_consumer_id, consumer_td, ref_td, ref_index_map, loops);
return tryInsertNewIndex(key, index);
}

std::pair<Val*, bool> CommonIndexMap::insert(
IterDomain* indexed_consumer_id,
TensorDomain* consumer_td,
std::vector<IterDomain*> loop_domains,
const std::unordered_map<IterDomain*, Val*>& ref_index_map,
const std::vector<kir::ForLoop*>& loops,
Val* index) {
if (index->definition() == nullptr) {
// Only an index defined by an expression is eligible for hoisting
return {index, false};
}

const CommonIndexKey key(
indexed_consumer_id, consumer_td, loop_domains, ref_index_map, loops);

return tryInsertNewIndex(key, index);
}

std::pair<Val*, bool> CommonIndexMap::tryInsertNewIndex(
CommonIndexKey key,
Val* index) {
Val* hoisted_index = nullptr;
bool new_index_inserted = false;

@@ -195,7 +274,6 @@ std::pair<Val*, bool> CommonIndexMap::insert(
new_index_inserted = true;
use_counts_[key] = 1;
}

return {hoisted_index, new_index_inserted};
}

31 changes: 31 additions & 0 deletions torch/csrc/jit/codegen/cuda/lower_index_hoist.h
@@ -43,6 +43,18 @@ class CommonIndexKey {
const std::unordered_map<IterDomain*, Val*>& ref_index_map,
const std::vector<kir::ForLoop*>& loops);

//! \param consumer_indexed_id Indexed consumer domain
//! \param consumer_td TensorDomain of consumer_indexed_id
//! \param loop_domains Resolved vector of IterDomains corresponding to loops
//! \param loop_index_map Index mapping generated from the loop nest.
//! \param loops Loop structure where this id is indexed
CommonIndexKey(
IterDomain* consumer_indexed_id,
TensorDomain* consumer_td,
std::vector<IterDomain*> loop_domains,
const std::unordered_map<IterDomain*, Val*>& loop_index_map,
const std::vector<kir::ForLoop*>& loops);

const IterDomain* concreteIndexedId() const {
return concrete_indexed_id_;
}
@@ -96,6 +108,15 @@ class TORCH_CUDA_CU_API CommonIndexMap {
const std::vector<kir::ForLoop*>& loops,
Val* index);

//! Version of insert that does not use the reference tensor domain.
std::pair<Val*, bool> insert(
IterDomain* indexed_consumer_id,
TensorDomain* consumer_td,
std::vector<IterDomain*> loop_domains,
const std::unordered_map<IterDomain*, Val*>& ref_index_map,
const std::vector<kir::ForLoop*>& loops,
Val* index);

const auto& commonIndexMap() const {
return common_index_map_;
}
@@ -104,6 +125,16 @@ class TORCH_CUDA_CU_API CommonIndexMap {
return use_counts_;
}

private:
//! Utility method to insert a key into the common index
//! map. Returns a pair of an IR node and a boolean value.
//! The IR node will be the previously inserted index if
//! the key found a match, or will be the original index
//! if this is a new key, in which case the key is stored.
//! The boolean value will be true if the key is stored,
//! i.e., the first time it is inserted.
std::pair<Val*, bool> tryInsertNewIndex(CommonIndexKey key, Val* index);

private:
//! Map to hold hoisted common indices
std::unordered_map<CommonIndexKey, Val*, CommonIndexKeyHash>
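For reference, a minimal sketch, not part of this diff, of the insert-then-reuse behavior documented for CommonIndexMap above, assuming index is a lowered index Val for indexed_consumer_id and that loop_domains and initial_loop_index_map come from an IndexFromIdGraph as in the call site in index_compute.cpp; equivalent_index is a hypothetical second index expression that resolves to the same CommonIndexKey.

  // First insertion of a key: the original index is stored and returned.
  auto first = GpuLower::current()->commonIndexMap().insert(
      indexed_consumer_id,
      consumer_tv->domain(),
      loop_domains,
      initial_loop_index_map,
      loops,
      index);
  Val* hoisted_index = first.first;   // same as index
  bool newly_inserted = first.second; // true: key stored for the first time

  // A later insertion that resolves to an equivalent CommonIndexKey returns
  // the previously stored Val, so the hoisted sub-expression is reused.
  auto reused = GpuLower::current()->commonIndexMap().insert(
      indexed_consumer_id,
      consumer_tv->domain(),
      loop_domains,
      initial_loop_index_map,
      loops,
      equivalent_index); // hypothetical index with the same defining math
  // reused.first == hoisted_index, reused.second == false

  // Note: an index with no defining expression is never hoisted and is
  // returned unchanged together with false.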
