From b5bd10a7e86cf4d35a9157eb252cafa0f048514d Mon Sep 17 00:00:00 2001 From: shmsong Date: Thu, 2 Jun 2022 22:23:17 -0700 Subject: [PATCH] cleanup --- torch/csrc/jit/codegen/cuda/index_compute.cpp | 33 ++----------------- 1 file changed, 3 insertions(+), 30 deletions(-) diff --git a/torch/csrc/jit/codegen/cuda/index_compute.cpp b/torch/csrc/jit/codegen/cuda/index_compute.cpp index eac8586473421..d61887014f2a3 100644 --- a/torch/csrc/jit/codegen/cuda/index_compute.cpp +++ b/torch/csrc/jit/codegen/cuda/index_compute.cpp @@ -1923,27 +1923,8 @@ std::vector Index::getGlobalConsumerStridedIndices( const TensorView* consumer_tv, const std::vector& loops) { FUSER_PERF_SCOPE("GpuLower::Lower::getGlobalConsumerIndex"); - const auto gpu_lower = GpuLower::current(); - // Get a reference tensor replayed as existing loop structure - auto reference = IndexReferenceReplay::getReference(loops, consumer_tv); - auto reference_domain = reference.domain; - auto reference_id_map = reference.concrete_to_id; - - // Map everything we can from reference to consumer using compute at index - // map. - std::unordered_map index_map_ref_to_consumer = - indexMapReferenceTo(consumer_tv, gpu_lower->caMap(), reference_id_map); - - // Index into the reference tensor. Reference indexing will handle vectorized - // dims where index should be set to 0 - auto ref_compute = getReferenceIndexing(loops, reference_domain); - - ContigIDs contig_finder( - consumer_tv->domain()->domain(), - consumer_tv->getMaybeRFactorDomain(), - consumer_tv->domain()->contiguity(), - reference_id_map); + auto gpu_lower = GpuLower::current(); auto index_from_id_graph = getTensorIndexFromIdGraph(loops, consumer_tv); @@ -1999,11 +1980,7 @@ std::vector Index::getGlobalConsumerStridedIndices( " dim: ", dim, " id: ", - root_dom[dim]->toString(), - ", reference domain: ", - reference_domain->toString(), - ", reference root: ", - ir_utils::toString(reference_domain->getRootDomain())); + root_dom[dim]->toString()); if (consumer_tv->domain()->contiguity()[dim]) { // If contig, used the stored stride which may be the previous @@ -2046,11 +2023,7 @@ std::vector Index::getGlobalConsumerStridedIndices( " dim: ", i, " id: ", - root_dom[i]->toString(), - ", reference domain: ", - reference_domain->toString(), - ", reference root: ", - ir_utils::toString(reference_domain->getRootDomain())); + root_dom[i]->toString()); auto root_ind = consumer_indexing.indexMap().at(root_dom[i]);