Skip to content

Commit

Permalink
clean up
Browse files Browse the repository at this point in the history
  • Loading branch information
shmsong committed Jun 3, 2022
1 parent 0ae1c5b commit 71f7c68
Showing 1 changed file with 4 additions and 90 deletions.
94 changes: 4 additions & 90 deletions torch/csrc/jit/codegen/cuda/index_compute.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2065,84 +2065,6 @@ std::vector<Val*> Index::getNonGlobalConsumerStridedIndices(
const std::vector<kir::ForLoop*>& loops) {
const auto gpu_lower = GpuLower::current();

// Get a reference tensor replayed as existing loop structure
auto reference = IndexReferenceReplay::getReference(loops, consumer_tv);

auto reference_domain = reference.domain;
auto reference_id_map = reference.concrete_to_id;

auto alloc_info = loop_utils::getAllocInformation(consumer_tv, loops);
std::unordered_map<kir::ForLoop*, Val*> loop_to_ind_map;
std::unordered_set<kir::ForLoop*> zero_loops;
std::tie(loop_to_ind_map, zero_loops) =
indexMapFromTV(consumer_tv, loops, alloc_info.init_for_loop, true);

ensureStaticIndexing(consumer_tv, alloc_info.init_for_loop, loops);

// Map loop nests to indicies, zeroing out those not used due to locality of
// memory
std::unordered_map<IterDomain*, Val*> ref_id_to_ind_map;
std::unordered_set<IterDomain*> ref_zero_domains;

// Due to rfactor/initialization reference_domain may be bigger than loop nest
// structure, ignore IterDomains that aren't present in the loop nest when
// indexing reference.
TORCH_INTERNAL_ASSERT(loops.size() <= reference_domain->nDims());
for (const auto loop_i : c10::irange(loops.size())) {
auto ref_axis = reference_domain->axis(loop_i);
ref_id_to_ind_map[ref_axis] = loop_to_ind_map[loops[loop_i]];
if (zero_loops.count(loops[loop_i]) > 0) {
ref_zero_domains.insert(ref_axis);
}
}

// Map everything we can from reference to consumer using compute at index
// map.
std::unordered_map<IterDomain*, IterDomain*> index_map_ref_to_consumer =
indexMapReferenceTo(consumer_tv, gpu_lower->caMap(), reference_id_map);

// Grab roots that map into consumer and save them into the preferred roots
// set for references indexing
std::unordered_set<IterDomain*> preferred_roots;
for (auto entry : index_map_ref_to_consumer) {
if (entry.second->isBroadcast() || entry.second->isReduction() ||
entry.second->isStride()) {
continue;
}
preferred_roots.emplace(entry.first);
}

// Make sure propagation of indexing while mixing with 0 indicies we propagate
// in a way that consumer will be able to see what's going on.
auto preferred_paths = buildPreferredPaths(reference_domain, preferred_roots);

// Index into the reference tensor
auto ref_compute = getReferenceIndexing(
loops,
reference_domain,
ref_id_to_ind_map,
ref_zero_domains,
preferred_paths);

// Adds halo info mappings for the reference
updateHaloInfoForReference(reference, consumer_tv);

const auto reference_halo_extent_map =
getReferenceHaloExtentMap(reference, index_map_ref_to_consumer);

ContigIDs contig_finder(
consumer_tv->domain()->domain(),
consumer_tv->getMaybeRFactorDomain(),
consumer_tv->domain()->contiguity(),
reference_id_map);

// Index into consumer using reference indexing
// auto consumer_indexing = ref_compute.updateIndexCompute(
// consumer_tv->domain(),
// index_map_ref_to_consumer,
// contig_finder,
// reference_halo_extent_map);

auto consumer_indexing_from_idgraph = getTensorIndexFromIdGraph(
loops,
consumer_tv,
Expand Down Expand Up @@ -2185,11 +2107,7 @@ std::vector<Val*> Index::getNonGlobalConsumerStridedIndices(
" dim: ",
i,
" id: ",
root_dom[i]->toString(),
", reference domain: ",
reference_domain->toString(),
", reference root: ",
ir_utils::toString(reference_domain->getRootDomain()));
root_dom[i]->toString());

auto root_ind_i = index_map.at(root_dom[i]);
if (root_ind_i->isZeroInt()) {
Expand All @@ -2201,8 +2119,8 @@ std::vector<Val*> Index::getNonGlobalConsumerStridedIndices(
root_dom[i],
consumer_tv,
consumer_indexing,
reference.domain,
ref_compute,
consumer_indexing_from_idgraph.resolved_loop_domains,
consumer_indexing_from_idgraph.initial_concrete_index_map,
loops,
root_ind_i);

Expand All @@ -2222,11 +2140,7 @@ std::vector<Val*> Index::getNonGlobalConsumerStridedIndices(
" dim: ",
j,
" id: ",
root_dom[j]->toString(),
", reference domain: ",
reference_domain->toString(),
", reference root: ",
ir_utils::toString(reference_domain->getRootDomain()));
root_dom[j]->toString());

auto root_ext_j = extent_map.find(root_dom[j]) == extent_map.end()
? root_dom[j]->extent()
Expand Down

0 comments on commit 71f7c68

Please sign in to comment.