Skip to content

Commit

Permalink
cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
shmsong committed Jun 3, 2022
1 parent b3e636d commit b5bd10a
Showing 1 changed file with 3 additions and 30 deletions.
33 changes: 3 additions & 30 deletions torch/csrc/jit/codegen/cuda/index_compute.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1923,27 +1923,8 @@ std::vector<Val*> Index::getGlobalConsumerStridedIndices(
const TensorView* consumer_tv,
const std::vector<kir::ForLoop*>& loops) {
FUSER_PERF_SCOPE("GpuLower::Lower::getGlobalConsumerIndex");
const auto gpu_lower = GpuLower::current();

// Get a reference tensor replayed as existing loop structure
auto reference = IndexReferenceReplay::getReference(loops, consumer_tv);
auto reference_domain = reference.domain;
auto reference_id_map = reference.concrete_to_id;

// Map everything we can from reference to consumer using compute at index
// map.
std::unordered_map<IterDomain*, IterDomain*> index_map_ref_to_consumer =
indexMapReferenceTo(consumer_tv, gpu_lower->caMap(), reference_id_map);

// Index into the reference tensor. Reference indexing will handle vectorized
// dims where index should be set to 0
auto ref_compute = getReferenceIndexing(loops, reference_domain);

ContigIDs contig_finder(
consumer_tv->domain()->domain(),
consumer_tv->getMaybeRFactorDomain(),
consumer_tv->domain()->contiguity(),
reference_id_map);
auto gpu_lower = GpuLower::current();

auto index_from_id_graph = getTensorIndexFromIdGraph(loops, consumer_tv);

Expand Down Expand Up @@ -1999,11 +1980,7 @@ std::vector<Val*> Index::getGlobalConsumerStridedIndices(
" dim: ",
dim,
" id: ",
root_dom[dim]->toString(),
", reference domain: ",
reference_domain->toString(),
", reference root: ",
ir_utils::toString(reference_domain->getRootDomain()));
root_dom[dim]->toString());

if (consumer_tv->domain()->contiguity()[dim]) {
// If contig, used the stored stride which may be the previous
Expand Down Expand Up @@ -2046,11 +2023,7 @@ std::vector<Val*> Index::getGlobalConsumerStridedIndices(
" dim: ",
i,
" id: ",
root_dom[i]->toString(),
", reference domain: ",
reference_domain->toString(),
", reference root: ",
ir_utils::toString(reference_domain->getRootDomain()));
root_dom[i]->toString());

auto root_ind = consumer_indexing.indexMap().at(root_dom[i]);

Expand Down

0 comments on commit b5bd10a

Please sign in to comment.