From 0666bb93d96a92d9fabab007fbc48a37f1b4c76e Mon Sep 17 00:00:00 2001 From: shmsong Date: Mon, 9 May 2022 13:34:23 -0700 Subject: [PATCH 01/12] pre-allocate loop index variables --- .../csrc/jit/codegen/cuda/compute_at_map.cpp | 48 ++++++++++++++++++- torch/csrc/jit/codegen/cuda/compute_at_map.h | 16 +++++++ torch/csrc/jit/codegen/cuda/executor.cpp | 3 +- torch/csrc/jit/codegen/cuda/kernel_ir.cpp | 3 +- torch/csrc/jit/codegen/cuda/lower2device.cpp | 1 + torch/csrc/jit/codegen/cuda/lower_loops.cpp | 3 +- torch/csrc/jit/codegen/cuda/test/test_gpu.cpp | 22 ++++----- 7 files changed, 79 insertions(+), 17 deletions(-) diff --git a/torch/csrc/jit/codegen/cuda/compute_at_map.cpp b/torch/csrc/jit/codegen/cuda/compute_at_map.cpp index 43382f865d435..0fe3950d2e235 100644 --- a/torch/csrc/jit/codegen/cuda/compute_at_map.cpp +++ b/torch/csrc/jit/codegen/cuda/compute_at_map.cpp @@ -226,7 +226,8 @@ void IterDomainGraph::initializeId( } } -ComputeAtMap::ComputeAtMap(Fusion* fusion) : id_graph_(fusion) { +ComputeAtMap::ComputeAtMap(Fusion* fusion) + : id_graph_(fusion), fusion_(fusion) { build(fusion); } @@ -257,6 +258,51 @@ void ComputeAtMap::validateAndPropagatePType() { } } +void ComputeAtMap::allocateIndexVariables() { + for (const auto& loop_disjoint_set : id_graph_.loopNodes().disjointSets()) { + ParallelType ptype; + // first allocate thread and grid parallel indices: + if (std::any_of( + loop_disjoint_set->vector().begin(), + loop_disjoint_set->vector().end(), + [&ptype](IterDomain* id) { + if (id->isThread() && + // Thread/Grid parallel Iterdomains do not generate + // serial loops unless they are halo extended. + (GpuLower::current()->haloInfo().getExtent(id) == nullptr)) { + ptype = id->getParallelType(); + return true; + } + return false; + })) { + loop_index_variable_map_[loop_disjoint_set.get()] = + NamedScalar::getParallelIndex(ptype); + + // Finish allocating parallel indices. + continue; + } + + // Non-parallel broadcast dimensions get 0 for their indices: + if (std::all_of( + loop_disjoint_set->vector().begin(), + loop_disjoint_set->vector().end(), + [](IterDomain* id) { return id->isBroadcast(); })) { + loop_index_variable_map_[loop_disjoint_set.get()] = fusion_->zeroVal(); + continue; + } + + // Everything now should be serial concrete loops, + // where we allocate a loop index integer for each loop. + loop_index_variable_map_[loop_disjoint_set.get()] = + IrBuilder::create(c10::nullopt); + } +} + +Val* ComputeAtMap::getIndexVariable(IterDomain* id) const { + const auto* loop_set = &(id_graph_.loopNodes().getDisjointSetOf(id)); + return loop_index_variable_map_.at(loop_set); +} + bool ComputeAtMap::areMapped( IterDomain* id0, IterDomain* id1, diff --git a/torch/csrc/jit/codegen/cuda/compute_at_map.h b/torch/csrc/jit/codegen/cuda/compute_at_map.h index 54bb7537a3f16..dc99ba21206a9 100644 --- a/torch/csrc/jit/codegen/cuda/compute_at_map.h +++ b/torch/csrc/jit/codegen/cuda/compute_at_map.h @@ -122,6 +122,10 @@ class TORCH_CUDA_CU_API ComputeAtMap { //! all IterDomains in the disjoint set to that PType. void validateAndPropagatePType(); + //! Run through disjoint sets in the LOOP map and allocate the index + //! variables if any. + void allocateIndexVariables(); + //! Returns if id0 and id1 are mapped to eachother with provided IdMappingMode bool areMapped(IterDomain* id0, IterDomain* id1, IdMappingMode mode) const; @@ -151,6 +155,10 @@ class TORCH_CUDA_CU_API ComputeAtMap { //! Get the ID sets for a provided IdMappingMode const DisjointSets& getIdSets(IdMappingMode mode) const; + //! 
Returns the index variable corresponding to the + //! given iterdomain. + Val* getIndexVariable(IterDomain* id) const; + private: // Build id_graph_ void build(Fusion* fusion); @@ -178,6 +186,14 @@ class TORCH_CUDA_CU_API ComputeAtMap { std::shared_ptr>, IterDomain*> concrete_id_cache_; + + // Allocated Loop index variable through the CA map. + std::unordered_map*, Val*> + loop_index_variable_map_; + + // Shortcut to access the fusion this computeAt map was + // built from. + Fusion* fusion_; }; } // namespace cuda diff --git a/torch/csrc/jit/codegen/cuda/executor.cpp b/torch/csrc/jit/codegen/cuda/executor.cpp index d25ef0fda9c3a..0280a0e9295ae 100644 --- a/torch/csrc/jit/codegen/cuda/executor.cpp +++ b/torch/csrc/jit/codegen/cuda/executor.cpp @@ -800,7 +800,8 @@ std::vector FusionExecutor::runFusion( launch_params_.gdimy() * launch_params_.gdimz(), "Wanted to launch a cooperative kernel, however the number of blocks is greater than ", "what can be resident on the GPU at once. Need: ", - launch_params_.gdimx() * launch_params_.gdimy() * launch_params_.gdimz(), + launch_params_.gdimx() * launch_params_.gdimy() * + launch_params_.gdimz(), " but limited to ", num_blocks_per_SM, " * ", diff --git a/torch/csrc/jit/codegen/cuda/kernel_ir.cpp b/torch/csrc/jit/codegen/cuda/kernel_ir.cpp index a2da12bd0e18c..25a4c4700f75a 100644 --- a/torch/csrc/jit/codegen/cuda/kernel_ir.cpp +++ b/torch/csrc/jit/codegen/cuda/kernel_ir.cpp @@ -209,8 +209,7 @@ ForLoop::ForLoop(IrBuilderPasskey passkey, IterDomain* iter_domain) : ForLoop( passkey, iter_domain, - iter_domain->isBroadcast() ? FusionGuard::getCurFusion()->zeroVal() - : IrBuilder::create(c10::nullopt), + GpuLower::current()->caMap()->getIndexVariable(iter_domain), nullptr, nullptr, nullptr, diff --git a/torch/csrc/jit/codegen/cuda/lower2device.cpp b/torch/csrc/jit/codegen/cuda/lower2device.cpp index 3e644fc9a44d1..e5a05f784ef61 100644 --- a/torch/csrc/jit/codegen/cuda/lower2device.cpp +++ b/torch/csrc/jit/codegen/cuda/lower2device.cpp @@ -288,6 +288,7 @@ void GpuLower::lower(Fusion* fusion, DataType index_type) { doubleBufferInfo().build(fusion_); + compute_at_map_->allocateIndexVariables(); // Run our passes keeping the lowered expressions and forwarding // them diff --git a/torch/csrc/jit/codegen/cuda/lower_loops.cpp b/torch/csrc/jit/codegen/cuda/lower_loops.cpp index aa0ff1a444699..b1b73769c1fb4 100644 --- a/torch/csrc/jit/codegen/cuda/lower_loops.cpp +++ b/torch/csrc/jit/codegen/cuda/lower_loops.cpp @@ -44,8 +44,7 @@ kir::ForLoop* openForHelper(kir::ForLoop* scope, IterDomain* id) { // Use the extent that's extended by halo new_scope = IrBuilder::create( id, - id->isBroadcast() ? GpuLower::current()->kernel()->zeroVal() - : IrBuilder::create(c10::nullopt), + GpuLower::current()->caMap()->getIndexVariable(id), nullptr, extent_with_halo, nullptr, diff --git a/torch/csrc/jit/codegen/cuda/test/test_gpu.cpp b/torch/csrc/jit/codegen/cuda/test/test_gpu.cpp index 3e8c14e924175..3c84535e12c2a 100644 --- a/torch/csrc/jit/codegen/cuda/test/test_gpu.cpp +++ b/torch/csrc/jit/codegen/cuda/test/test_gpu.cpp @@ -1325,17 +1325,17 @@ TEST_F(NVFuserTest, FusionParser_CUDA) { // 2. 
use a fuzzy compare (ignore non-significant whitespaces for example) const std::string expected_kernel = R"( __global__ void CUDAGeneratedKernel(Tensor T0, Tensor T1, Tensor T3) { - int64_t i52; - i52 = (((nvfuser_index_t)blockIdx.x) * 128) + ((nvfuser_index_t)threadIdx.x); - if ((i52 < T0.size[0])) { + int64_t i51; + i51 = (((nvfuser_index_t)blockIdx.x) * 128) + ((nvfuser_index_t)threadIdx.x); + if ((i51 < T0.size[0])) { float T5[1]; T5[0] = 0; T5[0] - = T1[i52]; + = T1[i51]; float T4[1]; T4[0] = 0; T4[0] - = T0[i52]; + = T0[i51]; float T6[1]; float T2[1]; T2[0] @@ -1344,7 +1344,7 @@ __global__ void CUDAGeneratedKernel(Tensor T0, Tensor T1, Te T6[0] = T2[0] * T4[0]; - T3[i52] + T3[i51] = T6[0]; } } @@ -19043,9 +19043,9 @@ TEST_F(NVFuserTest, FusionChannelsLastParser_CUDA) { // 2. use a fuzzy compare (ignore non-significant whitespaces for example) const std::string expected_kernel = R"( __global__ void CUDAGeneratedKernel(Tensor<__half, 4> T0, Tensor<__half, 4> T2, Tensor<__half, 4> T7) { - int64_t i173; - i173 = (((nvfuser_index_t)blockIdx.x) * 128) + ((nvfuser_index_t)threadIdx.x); - if ((i173 < (T0.size[0] * (T0.size[1] * (T0.size[2] * T0.size[3]))))) { + int64_t i172; + i172 = (((nvfuser_index_t)blockIdx.x) * 128) + ((nvfuser_index_t)threadIdx.x); + if ((i172 < (T0.size[0] * (T0.size[1] * (T0.size[2] * T0.size[3]))))) { __half T9[1]; T9[0] = 0; T9[0] @@ -19053,7 +19053,7 @@ __global__ void CUDAGeneratedKernel(Tensor<__half, 4> T0, Tensor<__half, 4> T2, __half T8[1]; T8[0] = 0; T8[0] - = T0[i173]; + = T0[i172]; __half T10[1]; float T3[1]; T3[0] @@ -19073,7 +19073,7 @@ __global__ void CUDAGeneratedKernel(Tensor<__half, 4> T0, Tensor<__half, 4> T2, = relu(T5[0]); T10[0] = __float2half(T6[0]); - T7[i173] + T7[i172] = T10[0]; } } From 374082f2e0e4bb821ef243a79cc87cfacabc8cf1 Mon Sep 17 00:00:00 2001 From: shmsong Date: Wed, 25 May 2022 10:34:30 -0700 Subject: [PATCH 02/12] add allocated index in debug print --- torch/csrc/jit/codegen/cuda/compute_at_map.cpp | 3 +++ 1 file changed, 3 insertions(+) diff --git a/torch/csrc/jit/codegen/cuda/compute_at_map.cpp b/torch/csrc/jit/codegen/cuda/compute_at_map.cpp index 270ab524857b0..e60a6d2380bfe 100644 --- a/torch/csrc/jit/codegen/cuda/compute_at_map.cpp +++ b/torch/csrc/jit/codegen/cuda/compute_at_map.cpp @@ -632,6 +632,9 @@ std::string idGraphNodesToString( concrete_id = ca_map.getConcreteMappedID(id, mode); } ss << " {"; + if (mode == IdMappingMode::LOOP) { + ss << "(index variable: " << ca_map.getIndexVariable(concrete_id) << ") "; + } for (auto entry : set.vector()) { ss << abstractToString(entry); if (entry == concrete_id) { From 0d4da1de4a7dbf0304a63461b63072e59d0bb5eb Mon Sep 17 00:00:00 2001 From: shmsong Date: Wed, 25 May 2022 10:39:18 -0700 Subject: [PATCH 03/12] assertion on loop map entry --- torch/csrc/jit/codegen/cuda/compute_at_map.cpp | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/torch/csrc/jit/codegen/cuda/compute_at_map.cpp b/torch/csrc/jit/codegen/cuda/compute_at_map.cpp index e60a6d2380bfe..9b53264281c8d 100644 --- a/torch/csrc/jit/codegen/cuda/compute_at_map.cpp +++ b/torch/csrc/jit/codegen/cuda/compute_at_map.cpp @@ -299,6 +299,11 @@ void ComputeAtMap::allocateIndexVariables() { } Val* ComputeAtMap::getIndexVariable(IterDomain* id) const { + TORCH_INTERNAL_ASSERT( + id_graph_.loopNodes().mappingExists(id), + "Index Variable: no index variable allocated as ", + id->toString(), + " is not registered in loop map"); const auto* loop_set = &(id_graph_.loopNodes().getDisjointSetOf(id)); return 
loop_index_variable_map_.at(loop_set); } From 9be2874e92aef49740ace03609a483885f8b0c18 Mon Sep 17 00:00:00 2001 From: shmsong Date: Wed, 25 May 2022 10:55:58 -0700 Subject: [PATCH 04/12] comment --- .../csrc/jit/codegen/cuda/compute_at_map.cpp | 18 +++++++++++------ torch/csrc/jit/codegen/cuda/compute_at_map.h | 20 ++++++++++++++++--- 2 files changed, 29 insertions(+), 9 deletions(-) diff --git a/torch/csrc/jit/codegen/cuda/compute_at_map.cpp b/torch/csrc/jit/codegen/cuda/compute_at_map.cpp index 9b53264281c8d..cc2c3f6c2d6c2 100644 --- a/torch/csrc/jit/codegen/cuda/compute_at_map.cpp +++ b/torch/csrc/jit/codegen/cuda/compute_at_map.cpp @@ -259,16 +259,23 @@ void ComputeAtMap::validateAndPropagatePType() { } void ComputeAtMap::allocateIndexVariables() { + // Run through all disjoint sets registered in loop map, + // all lowered kir::ForLoop will correspond to one of the disjoint sets + // and we only need one index variable for each set. for (const auto& loop_disjoint_set : id_graph_.loopNodes().disjointSets()) { ParallelType ptype; // first allocate thread and grid parallel indices: + // The validation pass will check that the parallel bindings within the + // loop nodes are consistent so all the loops within this disjoint set + // will be realized implicitly using parallel index variables. if (std::any_of( loop_disjoint_set->vector().begin(), loop_disjoint_set->vector().end(), [&ptype](IterDomain* id) { if (id->isThread() && - // Thread/Grid parallel Iterdomains do not generate - // serial loops unless they are halo extended. + // Halo extended parallel loops currently are handled + // differently and an index variable would still + // be allocated in this case. (GpuLower::current()->haloInfo().getExtent(id) == nullptr)) { ptype = id->getParallelType(); return true; @@ -277,12 +284,11 @@ void ComputeAtMap::allocateIndexVariables() { })) { loop_index_variable_map_[loop_disjoint_set.get()] = NamedScalar::getParallelIndex(ptype); - - // Finish allocating parallel indices. continue; } - // Non-parallel broadcast dimensions get 0 for their indices: + // All loops in this set are non-parallel, non-concretized broadcast + // iterdomains, their "index variable" should be zero. if (std::all_of( loop_disjoint_set->vector().begin(), loop_disjoint_set->vector().end(), @@ -292,7 +298,7 @@ void ComputeAtMap::allocateIndexVariables() { } // Everything now should be serial concrete loops, - // where we allocate a loop index integer for each loop. + // we just allocate a loop index integer for each set of loops. loop_index_variable_map_[loop_disjoint_set.get()] = IrBuilder::create(c10::nullopt); } diff --git a/torch/csrc/jit/codegen/cuda/compute_at_map.h b/torch/csrc/jit/codegen/cuda/compute_at_map.h index dc99ba21206a9..da99bd776b849 100644 --- a/torch/csrc/jit/codegen/cuda/compute_at_map.h +++ b/torch/csrc/jit/codegen/cuda/compute_at_map.h @@ -123,7 +123,21 @@ class TORCH_CUDA_CU_API ComputeAtMap { void validateAndPropagatePType(); //! Run through disjoint sets in the LOOP map and allocate the index - //! variables if any. + //! for each disjoint sets in the loop map. This pre-allocation makes + //! 2 key assumptions about computeAt map that would very likely be + //! long term invariant: + //! 1. All kir::forloop created in the lowering pass should belong + //! to one of the disjoint sets in loop map. + //! 2. The lowering pass will *never* create a loop nest with 2 + //! different nesting levels mapped together, i.e. the case below + //! never occurs: + //! for i in IterDomain1 + //! 
for j in IterDomain2 + //! ... + //! With loop_map.areMapped(IterDomain1, IterDomain2) == true. + //! Under this condition, we can pre-allocate all required index + //! variable integers before creating any kir::forloop, and this + //! would help optimizing the generated integer math for indexing. void allocateIndexVariables(); //! Returns if id0 and id1 are mapped to eachother with provided IdMappingMode @@ -155,8 +169,8 @@ class TORCH_CUDA_CU_API ComputeAtMap { //! Get the ID sets for a provided IdMappingMode const DisjointSets& getIdSets(IdMappingMode mode) const; - //! Returns the index variable corresponding to the - //! given iterdomain. + //! Returns the pre-allocated index variable integer used in + //! the kir::ForLoop corresponding to the given IterDomain. Val* getIndexVariable(IterDomain* id) const; private: From 627af9de3e006ad62f5f6b928bac34ffffc586f7 Mon Sep 17 00:00:00 2001 From: shmsong Date: Wed, 25 May 2022 15:11:38 -0700 Subject: [PATCH 05/12] use original index in double buffer cloned loop --- torch/csrc/jit/codegen/cuda/lower_double_buffer.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch/csrc/jit/codegen/cuda/lower_double_buffer.cpp b/torch/csrc/jit/codegen/cuda/lower_double_buffer.cpp index 9d761b81fcd24..1ebe973d439b5 100644 --- a/torch/csrc/jit/codegen/cuda/lower_double_buffer.cpp +++ b/torch/csrc/jit/codegen/cuda/lower_double_buffer.cpp @@ -189,7 +189,7 @@ class DoubleBufferLoopCloner : public kir::IrVisitor { // Main: 0 to (extent-1) // Epilogue: (extent-1) to extent - auto index = IrBuilder::create(c10::nullopt); + auto index = double_buffer_loop_->index(); auto start = double_buffer_loop_->start(); auto stop = double_buffer_loop_->stop(); From 0b666d699007f1a7ae1a607b3d08e0219296e4c5 Mon Sep 17 00:00:00 2001 From: shmsong Date: Fri, 24 Jun 2022 10:01:24 -0700 Subject: [PATCH 06/12] allocate different index variable for different double buffer stages --- .../csrc/jit/codegen/cuda/compute_at_map.cpp | 60 +++++++++++++++---- torch/csrc/jit/codegen/cuda/compute_at_map.h | 14 ++++- torch/csrc/jit/codegen/cuda/kernel_ir.cpp | 12 ++-- torch/csrc/jit/codegen/cuda/kernel_ir.h | 14 ++++- .../jit/codegen/cuda/lower_allocation.cpp | 3 +- .../jit/codegen/cuda/lower_double_buffer.cpp | 53 ++++++++++------ .../jit/codegen/cuda/lower_double_buffer.h | 8 +++ torch/csrc/jit/codegen/cuda/lower_loops.cpp | 3 +- .../cuda/lower_misaligned_vectorization.cpp | 3 +- torch/csrc/jit/codegen/cuda/type.cpp | 21 +++++++ torch/csrc/jit/codegen/cuda/type.h | 9 +++ 11 files changed, 164 insertions(+), 36 deletions(-) diff --git a/torch/csrc/jit/codegen/cuda/compute_at_map.cpp b/torch/csrc/jit/codegen/cuda/compute_at_map.cpp index cc2c3f6c2d6c2..21993f98d585e 100644 --- a/torch/csrc/jit/codegen/cuda/compute_at_map.cpp +++ b/torch/csrc/jit/codegen/cuda/compute_at_map.cpp @@ -274,7 +274,7 @@ void ComputeAtMap::allocateIndexVariables() { [&ptype](IterDomain* id) { if (id->isThread() && // Halo extended parallel loops currently are handled - // differently and an index variable would still + // differently and an index variable would still // be allocated in this case. (GpuLower::current()->haloInfo().getExtent(id) == nullptr)) { ptype = id->getParallelType(); @@ -297,21 +297,64 @@ void ComputeAtMap::allocateIndexVariables() { continue; } - // Everything now should be serial concrete loops, - // we just allocate a loop index integer for each set of loops. 
- loop_index_variable_map_[loop_disjoint_set.get()] = - IrBuilder::create(c10::nullopt); + // Allocate variable for the iterdomains: + auto concrete_loop_id_it = concrete_id_cache_.find(loop_disjoint_set); + TORCH_INTERNAL_ASSERT( + concrete_loop_id_it != concrete_id_cache_.end(), + "Concrete id not computed"); + + auto concrete_loop_id = concrete_loop_id_it->second; + + // Need to allocate double buffered loop differently. + if (GpuLower::current()->doubleBufferInfo().isDoubleBufferedIterDomain( + concrete_loop_id)) { + // Allocate index variable for each stage of the double buffered loop. + double_buffered_loop_index_variable_map_[loop_disjoint_set.get()] = + std::make_unique(DoubleBufferIndices( + {{DoubleBufferLoopStage::Prolog, + IrBuilder::create(c10::nullopt)}, + {DoubleBufferLoopStage::Main, + IrBuilder::create(c10::nullopt)}, + {DoubleBufferLoopStage::Epilog, + IrBuilder::create(c10::nullopt)}})); + } else { + // Everything now should be serial concrete loops, + // we just allocate a loop index integer for each set of loops. + loop_index_variable_map_[loop_disjoint_set.get()] = + IrBuilder::create(c10::nullopt); + } } } -Val* ComputeAtMap::getIndexVariable(IterDomain* id) const { +Val* ComputeAtMap::getIndexVariable( + IterDomain* id, + DoubleBufferLoopStage double_buffer_loop_stage) const { TORCH_INTERNAL_ASSERT( id_graph_.loopNodes().mappingExists(id), "Index Variable: no index variable allocated as ", id->toString(), " is not registered in loop map"); const auto* loop_set = &(id_graph_.loopNodes().getDisjointSetOf(id)); - return loop_index_variable_map_.at(loop_set); + + // Check if this loop was modified by double buffer pass. + bool is_double_buffer_iterdomain = + GpuLower::current()->doubleBufferInfo().isDoubleBufferedIterDomain(id); + + if (is_double_buffer_iterdomain) { + // Use dedicated double buffer index variable if the loop is double buffer + // loop + if (double_buffer_loop_stage == DoubleBufferLoopStage::NotApplicable) { + // The double buffered loop stages are created after the loop nest + // lowering phase so this function will be querried before the double + // buffer pass. At that point, no forloop has any double buffer + // stage defined, and we just default to using the main stage index. + double_buffer_loop_stage = DoubleBufferLoopStage::Main; + } + return double_buffered_loop_index_variable_map_.at(loop_set)->at( + double_buffer_loop_stage); + } else { + return loop_index_variable_map_.at(loop_set); + } } bool ComputeAtMap::areMapped( @@ -643,9 +686,6 @@ std::string idGraphNodesToString( concrete_id = ca_map.getConcreteMappedID(id, mode); } ss << " {"; - if (mode == IdMappingMode::LOOP) { - ss << "(index variable: " << ca_map.getIndexVariable(concrete_id) << ") "; - } for (auto entry : set.vector()) { ss << abstractToString(entry); if (entry == concrete_id) { diff --git a/torch/csrc/jit/codegen/cuda/compute_at_map.h b/torch/csrc/jit/codegen/cuda/compute_at_map.h index da99bd776b849..fce36ca1f29de 100644 --- a/torch/csrc/jit/codegen/cuda/compute_at_map.h +++ b/torch/csrc/jit/codegen/cuda/compute_at_map.h @@ -112,6 +112,8 @@ class TORCH_CUDA_CU_API IterDomainGraph { class TrivialReductionInfo; +using DoubleBufferIndices = std::unordered_map; + class TORCH_CUDA_CU_API ComputeAtMap { public: ComputeAtMap() = delete; @@ -171,7 +173,10 @@ class TORCH_CUDA_CU_API ComputeAtMap { //! Returns the pre-allocated index variable integer used in //! the kir::ForLoop corresponding to the given IterDomain. 
- Val* getIndexVariable(IterDomain* id) const; + Val* getIndexVariable( + IterDomain* id, + DoubleBufferLoopStage double_buffer_loop_stage = + DoubleBufferLoopStage::NotApplicable) const; private: // Build id_graph_ @@ -205,6 +210,13 @@ class TORCH_CUDA_CU_API ComputeAtMap { std::unordered_map*, Val*> loop_index_variable_map_; + // Allocated loop indices for double buffer loop. + using DoubleBufferIndicesPtr = std::unique_ptr; + std::unordered_map< + const VectorOfUniqueEntries*, + DoubleBufferIndicesPtr> + double_buffered_loop_index_variable_map_; + // Shortcut to access the fusion this computeAt map was // built from. Fusion* fusion_; diff --git a/torch/csrc/jit/codegen/cuda/kernel_ir.cpp b/torch/csrc/jit/codegen/cuda/kernel_ir.cpp index 21b41baf6e3d6..eddb49b6b2fc3 100644 --- a/torch/csrc/jit/codegen/cuda/kernel_ir.cpp +++ b/torch/csrc/jit/codegen/cuda/kernel_ir.cpp @@ -183,7 +183,8 @@ ForLoop::ForLoop( Val* step, bool vectorize, Val* vectorize_shift, - bool unroll_required) + bool unroll_required, + DoubleBufferLoopStage double_buffer_loop_stage) : Expr(passkey, ExprType::ForLoop), iter_domain_{iter_domain}, index_(index), @@ -193,7 +194,8 @@ ForLoop::ForLoop( vectorize_(vectorize), vectorize_shift_(vectorize_shift), unroll_required_(unroll_required), - body_(this) { + body_(this), + double_buffer_loop_stage_(double_buffer_loop_stage) { TORCH_INTERNAL_ASSERT( passkey.ir_container_->isA(), "IR type only valid for Kernel container."); @@ -223,7 +225,8 @@ ForLoop::ForLoop(IrBuilderPasskey passkey, IterDomain* iter_domain) !iter_domain->isBroadcast() && isParallelTypeVectorize(iter_domain->getParallelType()), nullptr, - false) { + false, + DoubleBufferLoopStage::NotApplicable) { TORCH_INTERNAL_ASSERT( passkey.ir_container_->isA(), "IR type only valid for Kernel container."); @@ -239,7 +242,8 @@ ForLoop::ForLoop(IrBuilderPasskey passkey, const ForLoop* other) other->step(), other->vectorize(), other->vectorize_shift(), - other->isUnrollRequired()) { + other->isUnrollRequired(), + other->doubleBufferLoopStage()) { TORCH_INTERNAL_ASSERT( passkey.ir_container_->isA(), "IR type only valid for Kernel container."); diff --git a/torch/csrc/jit/codegen/cuda/kernel_ir.h b/torch/csrc/jit/codegen/cuda/kernel_ir.h index a04debdd49128..730e0eb61804b 100644 --- a/torch/csrc/jit/codegen/cuda/kernel_ir.h +++ b/torch/csrc/jit/codegen/cuda/kernel_ir.h @@ -392,7 +392,8 @@ class TORCH_CUDA_CU_API ForLoop final : public Expr { Val* step, bool vectorize, Val* vectorize_shift, - bool unroll_required); + bool unroll_required, + DoubleBufferLoopStage double_buffer_loop_stage); ForLoop(IrBuilderPasskey passkey, IterDomain* iter_domain); @@ -445,6 +446,12 @@ class TORCH_CUDA_CU_API ForLoop final : public Expr { //! True if no actual for-loop is materialized bool isTrivial() const; + //! Returns the stage of a double buffered iterdomain + //! that this for loop materializes. + auto doubleBufferLoopStage() const { + return double_buffer_loop_stage_; + } + private: //! Returns if a loop could be unrolled. bool isUnrollable() const; @@ -468,6 +475,11 @@ class TORCH_CUDA_CU_API ForLoop final : public Expr { bool unroll_required_ = false; Scope body_; + + //! Tracks if this for loop is implementing a stage of + //! a double buffered iterdomain. + DoubleBufferLoopStage double_buffer_loop_stage_ = + DoubleBufferLoopStage::NotApplicable; }; //! IfThenElse provides scoping for an boolean operator. 
Exprs placed in its diff --git a/torch/csrc/jit/codegen/cuda/lower_allocation.cpp b/torch/csrc/jit/codegen/cuda/lower_allocation.cpp index afd386ce4130e..4dae7e600da5c 100644 --- a/torch/csrc/jit/codegen/cuda/lower_allocation.cpp +++ b/torch/csrc/jit/codegen/cuda/lower_allocation.cpp @@ -141,7 +141,8 @@ class AllocationInserter : public kir::ExprMutator { nullptr, false, nullptr, - false); + false, + DoubleBufferLoopStage::NotApplicable); } else { new_loop = IrBuilder::create(id); } diff --git a/torch/csrc/jit/codegen/cuda/lower_double_buffer.cpp b/torch/csrc/jit/codegen/cuda/lower_double_buffer.cpp index a362a91401780..24c725f607085 100644 --- a/torch/csrc/jit/codegen/cuda/lower_double_buffer.cpp +++ b/torch/csrc/jit/codegen/cuda/lower_double_buffer.cpp @@ -138,9 +138,6 @@ class DoubleBufferFusionInspector : private IterVisitor { DoubleBufferInfo& db_info_; }; -// The type of replicated double-buffer loops -enum class LoopType { Prologue, Main, Epilogue }; - // The epilogue loop is only created when the producer of a double // buffer tensor is on smem, in which case it would otherwise require // an additional predicate to guard buffer overruns. When it's on @@ -162,7 +159,7 @@ class DoubleBufferLoopCloner : public kir::IrVisitor { static kir::ForLoop* clone( kir::ForLoop* double_buffer_loop, const std::vector& double_buffer_load_exprs, - LoopType loop_type) { + DoubleBufferLoopStage loop_type) { DoubleBufferLoopCloner cloner( double_buffer_loop, double_buffer_load_exprs, loop_type); cloner.clone(); @@ -173,7 +170,7 @@ class DoubleBufferLoopCloner : public kir::IrVisitor { DoubleBufferLoopCloner( kir::ForLoop* double_buffer_loop, const std::vector& double_buffer_load_exprs, - LoopType loop_type) + DoubleBufferLoopStage loop_type) : double_buffer_loop_(double_buffer_loop), double_buffer_load_exprs_(double_buffer_load_exprs), loop_type_(loop_type) {} @@ -189,19 +186,20 @@ class DoubleBufferLoopCloner : public kir::IrVisitor { // Main: 0 to (extent-1) // Epilogue: (extent-1) to extent - auto index = double_buffer_loop_->index(); + auto index = GpuLower::current()->caMap()->getIndexVariable( + double_buffer_loop_->iter_domain(), loop_type_); auto start = double_buffer_loop_->start(); auto stop = double_buffer_loop_->stop(); - if (loop_type_ == LoopType::Prologue) { + if (loop_type_ == DoubleBufferLoopStage::Prolog) { TORCH_INTERNAL_ASSERT(start->isZeroInt()); stop = gpu_lower->kernel()->oneVal(); } else if ( - loop_type_ == LoopType::Main && + loop_type_ == DoubleBufferLoopStage::Main && requireEpilogue(double_buffer_load_exprs_)) { stop = IrBuilder::subExpr( double_buffer_loop_->stop(), gpu_lower->kernel()->oneVal()); - } else if (loop_type_ == LoopType::Epilogue) { + } else if (loop_type_ == DoubleBufferLoopStage::Epilog) { TORCH_INTERNAL_ASSERT(requireEpilogue(double_buffer_load_exprs_)); start = IrBuilder::subExpr( double_buffer_loop_->stop(), gpu_lower->kernel()->oneVal()); @@ -215,7 +213,8 @@ class DoubleBufferLoopCloner : public kir::IrVisitor { gpu_lower->kernel()->oneVal(), false, nullptr, - double_buffer_loop_->isUnrollRequired()); + double_buffer_loop_->isUnrollRequired(), + loop_type_); handle(double_buffer_loop_); } @@ -250,7 +249,7 @@ class DoubleBufferLoopCloner : public kir::IrVisitor { TORCH_INTERNAL_ASSERT(!cloned_scopes_.empty()); - if (loop_type_ == LoopType::Main) { + if (loop_type_ == DoubleBufferLoopStage::Main) { cloned_scopes_.back()->push_back(expr); return; } @@ -268,8 +267,10 @@ class DoubleBufferLoopCloner : public kir::IrVisitor { 
TORCH_INTERNAL_ASSERT(double_buffer_tv != nullptr); return out_tv == double_buffer_tv; }); - if ((loop_type_ == LoopType::Prologue && is_double_buffer_load_expr) || - (loop_type_ == LoopType::Epilogue && !is_double_buffer_load_expr)) { + if ((loop_type_ == DoubleBufferLoopStage::Prolog && + is_double_buffer_load_expr) || + (loop_type_ == DoubleBufferLoopStage::Epilog && + !is_double_buffer_load_expr)) { cloned_scopes_.back()->push_back(expr); } } @@ -277,7 +278,7 @@ class DoubleBufferLoopCloner : public kir::IrVisitor { private: kir::ForLoop* double_buffer_loop_ = nullptr; const std::vector& double_buffer_load_exprs_; - const LoopType loop_type_; + const DoubleBufferLoopStage loop_type_; kir::ForLoop* cloned_top_level_loop_ = nullptr; std::deque cloned_scopes_; @@ -409,7 +410,7 @@ class DoubleBufferInserter : private kir::ExprMutator { kir::ForLoop* double_buffer_loop, const std::vector& loads) { auto prologue_loop = DoubleBufferLoopCloner::clone( - double_buffer_loop, loads, LoopType::Prologue); + double_buffer_loop, loads, DoubleBufferLoopStage::Prolog); registerInsertBefore(double_buffer_loop, prologue_loop); auto write_to_smem = @@ -453,7 +454,7 @@ class DoubleBufferInserter : private kir::ExprMutator { } auto main_loop = DoubleBufferLoopCloner::clone( - double_buffer_loop, loads, LoopType::Main); + double_buffer_loop, loads, DoubleBufferLoopStage::Main); registerReplace(double_buffer_loop, main_loop); @@ -489,7 +490,7 @@ class DoubleBufferInserter : private kir::ExprMutator { if (requireEpilogue(loads)) { auto epilogue_loop = DoubleBufferLoopCloner::clone( - double_buffer_loop, loads, LoopType::Epilogue); + double_buffer_loop, loads, DoubleBufferLoopStage::Epilog); registerInsertAfter(double_buffer_loop, epilogue_loop); } } @@ -534,6 +535,24 @@ class DoubleBufferInserter : private kir::ExprMutator { void DoubleBufferInfo::build(Fusion* fusion) { DoubleBufferFusionInspector inspector(fusion, *this); + + // Build double buffered loop id's + for (auto& info : map_) { + auto double_buffer_axis = info.second.double_buffer_axis; + // Keeps track of which loop disjoint set has been + // double buffered. In index allocation, one index + // variable would need to be allocated in each + // double buffer stage. + concrete_double_buffered_loop_id_.insert( + GpuLower::current()->caMap()->getConcreteMappedID( + double_buffer_axis, IdMappingMode::LOOP)); + } +} + +bool DoubleBufferInfo::isDoubleBufferedIterDomain(IterDomain* id) { + auto concrete_loop_id = GpuLower::current()->caMap()->getConcreteMappedID( + id, IdMappingMode::LOOP); + return concrete_double_buffered_loop_id_.count(concrete_loop_id); } DoubleBufferInfo::TvInfo& DoubleBufferInfo::getTvInfo(const TensorView* tv) { diff --git a/torch/csrc/jit/codegen/cuda/lower_double_buffer.h b/torch/csrc/jit/codegen/cuda/lower_double_buffer.h index 96bc247f4ff6f..b17beec31fb6c 100644 --- a/torch/csrc/jit/codegen/cuda/lower_double_buffer.h +++ b/torch/csrc/jit/codegen/cuda/lower_double_buffer.h @@ -128,12 +128,20 @@ class TORCH_CUDA_CU_API DoubleBufferInfo { Val* getOriginalAllocSize(const TensorView* tv); + //! Returns true if the iterdomain will be realized + //! as a double buffer loop. + bool isDoubleBufferedIterDomain(IterDomain* id); + private: TvInfo& getTvInfo(const TensorView* tv); private: //! Keeps track of information for lowering double buffered tensors std::unordered_map map_; + + //! Keeps track of which concrete loop map is realizing double buffer + //! iterdomains. 
+ std::unordered_set concrete_double_buffered_loop_id_; }; } // namespace cuda diff --git a/torch/csrc/jit/codegen/cuda/lower_loops.cpp b/torch/csrc/jit/codegen/cuda/lower_loops.cpp index b1b73769c1fb4..7fdb149da9359 100644 --- a/torch/csrc/jit/codegen/cuda/lower_loops.cpp +++ b/torch/csrc/jit/codegen/cuda/lower_loops.cpp @@ -50,7 +50,8 @@ kir::ForLoop* openForHelper(kir::ForLoop* scope, IterDomain* id) { nullptr, false, nullptr, - false); + false, + DoubleBufferLoopStage::NotApplicable); } else { new_scope = IrBuilder::create(id); } diff --git a/torch/csrc/jit/codegen/cuda/lower_misaligned_vectorization.cpp b/torch/csrc/jit/codegen/cuda/lower_misaligned_vectorization.cpp index 66b405ac8e2f8..bd3c9baf66e1f 100644 --- a/torch/csrc/jit/codegen/cuda/lower_misaligned_vectorization.cpp +++ b/torch/csrc/jit/codegen/cuda/lower_misaligned_vectorization.cpp @@ -367,7 +367,8 @@ class MisalignedVectorizationModifier : public kir::ExprMutator { GpuLower::current()->kernel()->oneVal(), vectorize && has_vectorize_op, vectorize_shift, - fl->isUnrollRequired()); + fl->isUnrollRequired(), + fl->doubleBufferLoopStage()); auto body = &new_loop->body(); diff --git a/torch/csrc/jit/codegen/cuda/type.cpp b/torch/csrc/jit/codegen/cuda/type.cpp index 5c06287f90cb6..4563ed9efa0b2 100644 --- a/torch/csrc/jit/codegen/cuda/type.cpp +++ b/torch/csrc/jit/codegen/cuda/type.cpp @@ -1091,6 +1091,27 @@ size_t dataTypeSize(DataType type, DataType index_type) { return dataTypeSize(type); } +TORCH_CUDA_CU_API std::ostream& operator<<( + std::ostream& os, + const DoubleBufferLoopStage loop_stage) { + switch (loop_stage) { + case DoubleBufferLoopStage::NotApplicable: + break; + case DoubleBufferLoopStage::Prolog: + os << "{DoubleBufferProlog}"; + break; + case DoubleBufferLoopStage::Main: + os << "{DoubleBufferMainLoop}"; + break; + case DoubleBufferLoopStage::Epilog: + os << "{DoubleBufferEpilog}"; + break; + default: + TORCH_INTERNAL_ASSERT(false, "unknown double buffer stage"); + } + return os; +} + } // namespace cuda } // namespace fuser } // namespace jit diff --git a/torch/csrc/jit/codegen/cuda/type.h b/torch/csrc/jit/codegen/cuda/type.h index 38f22c308ce27..490188b77af4c 100644 --- a/torch/csrc/jit/codegen/cuda/type.h +++ b/torch/csrc/jit/codegen/cuda/type.h @@ -303,8 +303,14 @@ static constexpr std::array kIdMappingModes = { IdMappingMode::EXACT, IdMappingMode::LOOP}; +// Used to annotate the special memory intrinsics that a loadstore +// op will be lowered to. enum class LoadStoreOpType { LdMatrix, LdMatrixTranspose, CpAsync }; +// Used to label what part of the double buffered iterdomain +// a for loop is materializing. +enum class DoubleBufferLoopStage { NotApplicable, Prolog, Main, Epilog }; + // Returns if function needs an f suffix on the operator when operating on a // float value i.e. 
sin->sinf bool needFloatSuffix(UnaryOpType t); @@ -332,6 +338,9 @@ TORCH_CUDA_CU_API std::ostream& operator<<(std::ostream&, const IdMappingMode); TORCH_CUDA_CU_API std::ostream& operator<<( std::ostream&, const LoadStoreOpType); +TORCH_CUDA_CU_API std::ostream& operator<<( + std::ostream&, + const DoubleBufferLoopStage); std::string stringifyBooleanOp(const UnaryOpType); std::string stringifyBooleanOp(const BinaryOpType); From b18ca04d67427f9e39e36eb80547123b412ea813 Mon Sep 17 00:00:00 2001 From: shmsong Date: Fri, 24 Jun 2022 10:09:16 -0700 Subject: [PATCH 07/12] comments --- torch/csrc/jit/codegen/cuda/compute_at_map.h | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/torch/csrc/jit/codegen/cuda/compute_at_map.h b/torch/csrc/jit/codegen/cuda/compute_at_map.h index fce36ca1f29de..44f541f0a6a63 100644 --- a/torch/csrc/jit/codegen/cuda/compute_at_map.h +++ b/torch/csrc/jit/codegen/cuda/compute_at_map.h @@ -125,6 +125,7 @@ class TORCH_CUDA_CU_API ComputeAtMap { void validateAndPropagatePType(); //! Run through disjoint sets in the LOOP map and allocate the index + //! variable for the associated for loop that will be generated //! for each disjoint sets in the loop map. This pre-allocation makes //! 2 key assumptions about computeAt map that would very likely be //! long term invariant: @@ -173,6 +174,9 @@ class TORCH_CUDA_CU_API ComputeAtMap { //! Returns the pre-allocated index variable integer used in //! the kir::ForLoop corresponding to the given IterDomain. + //! this interface is only valid if the ID has a loop mapping, + //! ca_map will throw exceptions if given iterdomain doesn't + //! have a loop map entry. Val* getIndexVariable( IterDomain* id, DoubleBufferLoopStage double_buffer_loop_stage = @@ -206,11 +210,14 @@ class TORCH_CUDA_CU_API ComputeAtMap { IterDomain*> concrete_id_cache_; - // Allocated Loop index variable through the CA map. + //! Allocated Loop index variable through the CA map. + //! only valid for disjoint sets on the loop ca map. std::unordered_map*, Val*> loop_index_variable_map_; - // Allocated loop indices for double buffer loop. + //! Allocated loop indices for double buffer loop. + //! only valid for disjoint sets on the loop ca map + //! that have double buffer-ed iterdomains. using DoubleBufferIndicesPtr = std::unique_ptr; std::unordered_map< const VectorOfUniqueEntries*, From 456bdcb2c18fbb98b2266d586ec399c287515d9d Mon Sep 17 00:00:00 2001 From: "S. 
Song" <41357537+shmsong@users.noreply.github.com> Date: Fri, 24 Jun 2022 16:20:15 -0700 Subject: [PATCH 08/12] Use IterDomain Graph to generate gmem consumer indexing (#1734) Co-authored-by: Christian Sarofeen --- build_variables.bzl | 1 + torch/csrc/jit/codegen/cuda/index_compute.cpp | 83 ++- torch/csrc/jit/codegen/cuda/index_compute.h | 23 + .../jit/codegen/cuda/index_idgraph_utils.cpp | 693 ++++++++++++++++++ .../jit/codegen/cuda/index_idgraph_utils.h | 142 ++++ .../codegen/cuda/index_reference_replay.cpp | 4 +- .../jit/codegen/cuda/index_reference_replay.h | 8 + torch/csrc/jit/codegen/cuda/lower_shift.cpp | 91 +++ torch/csrc/jit/codegen/cuda/lower_shift.h | 7 + torch/csrc/jit/codegen/cuda/lower_utils.cpp | 5 + torch/csrc/jit/codegen/cuda/lower_utils.h | 4 + 11 files changed, 1040 insertions(+), 21 deletions(-) create mode 100644 torch/csrc/jit/codegen/cuda/index_idgraph_utils.cpp create mode 100644 torch/csrc/jit/codegen/cuda/index_idgraph_utils.h diff --git a/build_variables.bzl b/build_variables.bzl index 2e8ff13bed1af..b258c0337f455 100644 --- a/build_variables.bzl +++ b/build_variables.bzl @@ -656,6 +656,7 @@ libtorch_cuda_core_sources = [ "torch/csrc/jit/codegen/cuda/graph_fuser.cpp", "torch/csrc/jit/codegen/cuda/grouped_reduction.cpp", "torch/csrc/jit/codegen/cuda/index_compute.cpp", + "torch/csrc/jit/codegen/cuda/index_idgraph_utils.cpp", "torch/csrc/jit/codegen/cuda/index_reference_replay.cpp", "torch/csrc/jit/codegen/cuda/instrumentation.cpp", "torch/csrc/jit/codegen/cuda/ir_base_nodes.cpp", diff --git a/torch/csrc/jit/codegen/cuda/index_compute.cpp b/torch/csrc/jit/codegen/cuda/index_compute.cpp index a0c67ad4ef272..38e751665c5aa 100644 --- a/torch/csrc/jit/codegen/cuda/index_compute.cpp +++ b/torch/csrc/jit/codegen/cuda/index_compute.cpp @@ -5,6 +5,7 @@ #include #include #include +#include #include #include #include @@ -325,9 +326,9 @@ Val* getProducerIndexWithPartialSplit( } // namespace void IndexCompute::handle(Split* split) { - auto in_id = split->in()->as(); - auto outer_id = split->outer()->as(); - auto inner_id = split->inner()->as(); + auto in_id = maybeGetExactMapConcreteID(split->in()->as()); + auto outer_id = maybeGetExactMapConcreteID(split->outer()->as()); + auto inner_id = maybeGetExactMapConcreteID(split->inner()->as()); auto outer_it = index_map_.find(outer_id); auto inner_it = index_map_.find(inner_id); @@ -382,9 +383,9 @@ void IndexCompute::handle(Split* split) { } void IndexCompute::handle(Merge* merge) { - auto out_id = merge->out(); - auto outer_id = merge->outer(); - auto inner_id = merge->inner(); + auto out_id = maybeGetExactMapConcreteID(merge->out()); + auto outer_id = maybeGetExactMapConcreteID(merge->outer()); + auto inner_id = maybeGetExactMapConcreteID(merge->inner()); auto out_it = index_map_.find(out_id); if (out_it == index_map_.end()) { @@ -397,6 +398,8 @@ void IndexCompute::handle(Merge* merge) { if (isZero(out_id)) { index_map_[outer_id] = zero; index_map_[inner_id] = zero; + // TODO: Why do we set extent_map_ to zero? This has to be protected by zero + // merged in, but seems logical to me the extent would still be one. 
extent_map_[outer_id] = zero; extent_map_[inner_id] = zero; zero_domains_.emplace(outer_id); @@ -462,12 +465,18 @@ void IndexCompute::handle(Merge* merge) { index_map_[inner_id] = zero; extent_map_[outer_id] = getExtent(out_id); + if (hasZeroMerged(out_id)) { + zero_merged_in_.insert(outer_id); + } } else if (outer_id->isBroadcast() && outer_extent->isOneInt()) { // Propagate away from broadcast dims index_map_[outer_id] = zero; index_map_[inner_id] = out_ind; extent_map_[inner_id] = getExtent(out_id); + if (hasZeroMerged(out_id)) { + zero_merged_in_.insert(inner_id); + } } else if (hasZeroMerged(out_id)) { // Don't propagate to inner id if it's comprised of only broadcast root // domains, unless outer is also all broadcast domains. Index shouldn't be @@ -583,6 +592,25 @@ IndexCompute::IndexCompute( } } +IndexCompute::IndexCompute( + std::unordered_map initial_index_map, + std::unordered_set zero_domains, + std::unordered_set preferred_paths, + std::unordered_map reference_halo_extent_map) + : index_map_(std::move(initial_index_map)), + zero_domains_(std::move(zero_domains)), + preferred_paths_(std::move(preferred_paths)), + reference_halo_extent_map_(std::move(reference_halo_extent_map)) { + FUSER_PERF_SCOPE("GpuLower::Lower::IndexCompute::IndexCompute"); + concrete_id_pass_ = true; +} + +void IndexCompute::run(const LoopIndexing& loop_indexing) { + for (auto expr : loop_indexing.getBackwardExprList()) { + handle(expr); + } +} + void IndexCompute::run() { const std::vector domain_vals( td_->domain().begin(), td_->domain().end()); @@ -590,6 +618,14 @@ void IndexCompute::run() { traverseFrom(td_->fusion(), domain_vals, false); } +IterDomain* IndexCompute::maybeGetExactMapConcreteID(IterDomain* id) { + if (concrete_id_pass_) { + return GpuLower::current()->caMap()->getConcreteMappedID( + id, IdMappingMode::EXACT); + } + return id; +} + Val* IndexCompute::getExtent(IterDomain* id) const { // Pick from extent_map_ if available. Previously parallel // dimensions were ued (e.g., blockDim.x), however, it would result @@ -651,6 +687,25 @@ IndexCompute IndexCompute::updateIndexCompute( contig_finder, {}, reference_halo_extent_map); + + if (concrete_id_pass_) { + // This should be the same behavior as with a reference tensor + // created, since originally halo was pulled through exact + // ca mapping and in the concrete_id_pass case, the id_map + // also represents exact ca mapping. + // TODO: might need to re-visit pathological cases when we may + // need to traverse and propagate halo info again in here. 
+ for (auto id_entry : id_map) { + IterDomain* prev_id = id_entry.first; + IterDomain* new_id = id_entry.second; + auto halo_extent_it = reference_halo_extent_map_.find(prev_id); + if (halo_extent_it != reference_halo_extent_map_.end()) { + updated_index_compute.reference_halo_extent_map_[new_id] = + halo_extent_it->second; + } + } + } + updated_index_compute.run(); return updated_index_compute; @@ -1829,25 +1884,15 @@ std::vector Index::getGlobalConsumerStridedIndices( // dims where index should be set to 0 auto ref_compute = getReferenceIndexing(loops, reference_domain); - // Index into consumer using reference indexing - - // Adds halo info mappings for the reference - updateHaloInfoForReference(reference, consumer_tv); - - const auto reference_halo_extent_map = - getReferenceHaloExtentMap(reference, index_map_ref_to_consumer); - ContigIDs contig_finder( consumer_tv->domain()->domain(), consumer_tv->getMaybeRFactorDomain(), consumer_tv->domain()->contiguity(), reference_id_map); - auto consumer_indexing = ref_compute.updateIndexCompute( - consumer_tv->domain(), - index_map_ref_to_consumer, - contig_finder, - reference_halo_extent_map); + auto index_from_id_graph = getTensorIndexFromIdGraph(loops, consumer_tv); + + auto consumer_indexing = index_from_id_graph.index; // Indices should now be mapped onto IterDomains in consumer, so just grab // and use them. diff --git a/torch/csrc/jit/codegen/cuda/index_compute.h b/torch/csrc/jit/codegen/cuda/index_compute.h index 1a88b00fa25c9..defcbe4439c68 100644 --- a/torch/csrc/jit/codegen/cuda/index_compute.h +++ b/torch/csrc/jit/codegen/cuda/index_compute.h @@ -61,6 +61,7 @@ namespace fuser { namespace cuda { class ContigIDs; +class LoopIndexing; class IndexCompute : public BackwardVisitor { protected: @@ -78,6 +79,12 @@ class IndexCompute : public BackwardVisitor { //! True if any dependent of a domain is not used to index bool hasZeroMerged(IterDomain* id) const; + //! Returns the concrete ID from the compute at EXACT mode map if + //! concrete_id_pass == true, otherwise returns id passed in. + //! Helps unify the expr handling logic in reference domain and concrete id + //! based traversal. + IterDomain* maybeGetExactMapConcreteID(IterDomain* id); + // Tensor domain we're mapping back to root const TensorDomain* td_; // NOLINT @@ -118,6 +125,10 @@ class IndexCompute : public BackwardVisitor { // reference tensor std::unordered_map reference_halo_extent_map_; + // Temporary flag which tells IndexCompute to use concrete id's from the exact + // map rather than the actual IDs used in the ID expressions. + bool concrete_id_pass_ = false; + public: const std::unordered_map& indexMap() const { return index_map_; @@ -159,6 +170,14 @@ class IndexCompute : public BackwardVisitor { std::unordered_set preferred_paths = {}, std::unordered_map reference_halo_extent_map = {}); + // Entry point used for using concrete id based traversal. This traversal is + // assumed to start at leaf IDs provided by initial_index_map. + IndexCompute( + std::unordered_map initial_index_map, + std::unordered_set zero_domains, + std::unordered_set preferred_paths, + std::unordered_map concrete_halo_extent_map); + // Updates index_map, extent_map, and zero_merged_in based on id_map and // returns a new IndexCompute ready to be used. 
IndexCompute updateIndexCompute( @@ -168,6 +187,10 @@ class IndexCompute : public BackwardVisitor { const std::unordered_map& reference_halo_extent_map = {}) const; + // Interface to run index traversal through loop indexing analysis result to + // be used with the entry point for concrete id based traversal. + void run(const LoopIndexing& loop_indexing); + virtual void run(); }; diff --git a/torch/csrc/jit/codegen/cuda/index_idgraph_utils.cpp b/torch/csrc/jit/codegen/cuda/index_idgraph_utils.cpp new file mode 100644 index 0000000000000..fd0ffe13753af --- /dev/null +++ b/torch/csrc/jit/codegen/cuda/index_idgraph_utils.cpp @@ -0,0 +1,693 @@ +#include +#include +#include +#include +#include +#include +#include +#include + +namespace torch { +namespace jit { +namespace fuser { +namespace cuda { + +IndexFromIdGraph::IndexFromIdGraph( + IndexCompute index_, + std::unordered_map initial_concrete_index_map_, + std::vector loop_domains_) + : index(index_), + initial_concrete_index_map(initial_concrete_index_map_), + resolved_loop_domains(loop_domains_) {} + +namespace { + +//! A struct to keep track of necessary parameters used in +//! configuring index compute pass. +//! These parameters are needed to propagate the indexing from the leaf nodes of +//! the TVs and loop nests to the TVs rfactor domain during +//! index_compute.cpp::IndexCompute passes. +//! TODO: +//! Would expect this list to become shorter over time, +//! as more info can be determined holistically. +struct IndexingParameters { + //! Initial binding of index math to concrete iterdomain ids, + //! from the loop nest analysis. + std::unordered_map initial_concrete_id_index; + + //! (Used in non-global indexing) the concrete iterdomains that + //! we want to skip or merge into contiguous indexing paths. + std::unordered_set zero_domains; + + //! (Used in non-global indexing) the preferred path we would + //! be propagating contiguously merged indices backward. + std::unordered_set preferred_concrete_ids; + + //! The inferred halo padded extents of the concrete iterdomains. + std::unordered_map concrete_id_to_halo_extent; +}; + +// Initial loop index map for global producer or consumer case. +IndexingParameters getGlobalIndexParameters( + const LoopIndexing& loop_indexing, + bool index_producer = false) { + IndexingParameters index_parameters; + TORCH_INTERNAL_ASSERT(!index_producer, " not yet implemented"); + + auto& loops = loop_indexing.loops(); + auto& loop_domain = loop_indexing.loopDomains(); + auto& loop_index_map = index_parameters.initial_concrete_id_index; + + for (auto loop_idx : c10::irange(loops.size())) { + auto loop = loops[loop_idx]; + auto index_domain = ir_utils::caMapExactConcreteId(loop_domain[loop_idx]); + if (loop->isTrivial()) { + // This is useful information in the case of + // MisalignedVectorize and double buffer epilog, etc. + loop_index_map[index_domain] = loop->start(); + } else { + // Default use pre-allocated integers for index + loop_index_map[index_domain] = loop->index(); + } + } + + // Derive the halo extents from the loop indexing result. 
+ index_parameters.concrete_id_to_halo_extent = + GpuLower::current()->haloInfo().buildConcreteHaloExtentMap(loop_indexing); + + return index_parameters; +} + +} // namespace + +class LoopIndexingAnalysis { + public: + static LoopIndexing fromLoopAndConsumer( + const std::vector& loops, + const TensorView* consumer_tv) { + LoopIndexingAnalysis analysis(loops, consumer_tv); + return analysis.getLoopIndexing(); + } + + private: + explicit LoopIndexingAnalysis( + const std::vector& loops, + const TensorView* consumer_tv); + + //! Populate derived information into a LoopIndexing + //! data structure. + LoopIndexing getLoopIndexing() { + LoopIndexing indexing; + indexing.loops_ = loops_; + indexing.consumer_tv_ = consumer_tv_; + indexing.loop_root_ = loop_root_domains_; + indexing.loop_domains_ = loop_domains_.vector(); + indexing.index_exprs_ = replayed_exprs_; + return indexing; + } + + //! Validates that the current loop structure is well formed, in the sense + //! that ca_map would not map any two loops in the loop nest together. + void validateLoopStructure(const std::vector& loops); + + //! Start at the loop iter domains, and traverse back into history on the + //! concrete IDs in the exact map calling "visitExpr" expressions through the + //! history. + void traverseFromDomainVals(); + + //! Concretize the given iterdomain and record the visit (in deterministic + //! order) in terms of the exact mapped concrete id. Marks the mapping of the + //! id to the concrete id in "concrete_to_original_id_" and returns the + //! concrete id. + IterDomain* concretizeAndVisitId(IterDomain* id); + + //! If an equivalent expression has already been processed this function + //! simply returns. Otherwise puts the exact concrete IDs of inputs in + //! consumed_concrete_, and concrete IDs of outputs in produced_concrete_. + //! Then adds the expression to replayed_exprs_. + void visitExpr(Expr* expr); + + //! Iterates through provided vals, calls concretizeAndVisitId on them, and + //! returns if any of the returned vals are in existing_ids. This is used to + //! check if inputs or outputs of ID expressions have already been + //! produced/consumed in the traversal. Indexing only needs to consume/produce + //! one IterDomain per exact disjoint set. + bool visitIdsAndCheckDuplication( + const std::vector& vals, + const std::unordered_set& existing_ids); + + //! Fills loop_domains_ with the corresponding replayed_concrete_id mapping to + //! the provided loops. Must be done after the exact iterdomain "replay" + //! (traverseFromDomainVals). loop_domains_ are the original_id not the + //! concrete_id (translated with concrete_to_original_id). These iter domains + //! are used to grab the history that will be replayed in IndexCompute. We're + //! looking for "new" root domains and subsequent transformations, filling in + //! any missing "outputs" (or inputs for backward traversal). Then fills + //! loop_domains_ with all of these iter domains. + void constructLoopDomains(); + + private: + //! Original loop nest input to derive info from. + const std::vector& loops_; + + //! Original consumer tv to derive view info from. + const TensorView* consumer_tv_ = nullptr; + + // Exact concrete domains that has been used + // in the traversal connection. + std::unordered_set produced_concrete_; + std::unordered_set consumed_concrete_; + + //! Iterdomains that the corresponding loops are generated from. + std::vector initial_loop_domain_ids_; + + //! 
All Id's in consumer's transform history + std::vector all_consumer_id_vals_; + + //! Concrete iterdomains visited in the domain traversal, + //! in the order they are visited in traverseFromDomainVals. + VectorOfUniqueEntries replayed_concrete_ids_; + + //! Keeping track of the original visited id's before they + //! were concretized. + std::unordered_map concrete_to_original_id_; + + //! Map from concrete id to its single consumer on the selected + //! iterdomain expression list. + std::unordered_map concrete_id_to_consumer_; + + //! Source domains that all the Iterdomain transforms + //! in the loop nest originated from. + std::vector loop_root_domains_; + + //! Leaf domains representing the original loop structure + VectorOfUniqueEntries loop_domains_; + + //! Selected list of exprs that will produce and consume each + //! of the exact concrete ids from the loop nest exactly once. + std::vector replayed_exprs_; +}; + +LoopIndexingAnalysis::LoopIndexingAnalysis( + const std::vector& loops, + const TensorView* consumer_tv) + : loops_(loops), consumer_tv_(consumer_tv) { + // Validate consistency in given loop nest + validateLoopStructure(loops); + + // Populate initial loop iter domains. + std::transform( + loops.begin(), + loops.end(), + std::back_inserter(initial_loop_domain_ids_), + [](kir::ForLoop* fl) { return fl->iter_domain(); }); + + // Collect consumer id's for view rfactor traversal. + all_consumer_id_vals_ = DependencyCheck::getAllValsBetween( + {consumer_tv->getRootDomain().begin(), + consumer_tv->getRootDomain().end()}, + {consumer_tv->domain()->domain().begin(), + consumer_tv->domain()->domain().end()}); + + // Resolve definition of each exact concrete id's involved in the whole loop + // nest transform history + traverseFromDomainVals(); + + // Construct concrete to consumer map. The replayed exprs are guaranteed to + // consume each concrete id once so this map is well defined. + for (auto expr : replayed_exprs_) { + for (auto input_id : ir_utils::filterByType(expr->inputs())) { + concrete_id_to_consumer_[ir_utils::caMapExactConcreteId(input_id)] = expr; + } + } + + // Reconstruct the iterdomain view of the original loopnest after resolving + // the exact definition of each index. + constructLoopDomains(); +} + +void LoopIndexingAnalysis::validateLoopStructure( + const std::vector& loops) { + // Throw an error when two loops are mapped with each other, which + // violates an assumption that unique mappings between concrete + // IterDomains and the IterDomains of the loop structure must be + // established. It should be a reasonable assumption, but fusions + // like below won't work: + // tv0 = [I0] + // tv1 = broadcast(tv0, {true, false}); + // tv2 = broadcast(tv0, {false, true}); + // tv3 = tv1 + tv2 + // Notice that the two axes of each of tv1, tv2 and tv3 are mapped + // with each other. We believe it is unlikely this limitation + // becomes a real concern in practice. + // Map concrete id to the original loop iter domain. + std::unordered_map concrete_to_loop; + for (auto it_i = loops.begin(); it_i != loops.end(); ++it_i) { + // Largely duplicating original logic + auto loop_id = (*it_i)->iter_domain(); + auto concrete_loop_id = ir_utils::caMapExactConcreteId(loop_id); + + TORCH_INTERNAL_ASSERT( + !concrete_to_loop.count(concrete_loop_id), + "Unsupported loop structure. 
Two loops are mapped together.", + loop_id->toString(), + " and ", + concrete_to_loop.at(concrete_loop_id)->toString()); + + concrete_to_loop[concrete_loop_id] = loop_id; + } +} + +void LoopIndexingAnalysis::traverseFromDomainVals() { + // Order is really important here, start with outer most for loops in a + // depth first manner. The outer most loops are topologically closer to the + // outputs, so their broadcast dimensions are "more" resolved than those + // towards the inner most loops. + std::deque to_visit( + initial_loop_domain_ids_.begin(), initial_loop_domain_ids_.end()); + std::unordered_set visited_exprs; + std::unordered_set visited_ids; + + while (!to_visit.empty()) { + auto out_id = to_visit.front(); + to_visit.pop_front(); + + if (!visited_ids.emplace(out_id).second) { + continue; + } + auto expr = out_id->definition(); + + if (auto rfactor_id = + getRfactorIDToTraverse(out_id, all_consumer_id_vals_)) { + to_visit.emplace_front(rfactor_id); + } + + // ID's will be copied for the reference as we replay transformations. If + // there was no transformations on an iteration domain, a copy of the + // iteration domain for the reference is made here. + if (expr == nullptr) { + if (std::find( + initial_loop_domain_ids_.begin(), + initial_loop_domain_ids_.end(), + out_id) != initial_loop_domain_ids_.end()) { + concretizeAndVisitId(out_id); + } + continue; + } + + if (!visited_exprs.emplace(expr).second) { + continue; + } + + visitExpr(expr); + + auto inp_ids = ir_utils::filterByType(expr->inputs()); + // Make sure to put at the begining of the deque to maintain correct + // ordering. + to_visit.insert(to_visit.begin(), inp_ids.begin(), inp_ids.end()); + } +} + +IterDomain* LoopIndexingAnalysis::concretizeAndVisitId(IterDomain* id) { + auto concrete_id = ir_utils::caMapExactConcreteId(id); + if (replayed_concrete_ids_.pushBack(concrete_id)) { + concrete_to_original_id_[concrete_id] = id; + } + return concrete_id; +} + +void LoopIndexingAnalysis::visitExpr(Expr* expr) { + // Current implementation just tries to + // follow the exact behavior of reference replay + // except that no expr was actually "replayed". + + // Record all inputs, and stop if current expr + // duplicates id consumption or production. + if (visitIdsAndCheckDuplication(expr->inputs(), consumed_concrete_)) { + return; + } + if (visitIdsAndCheckDuplication(expr->outputs(), produced_concrete_)) { + return; + } + + // Record the expr if no duplication on input or output found + replayed_exprs_.push_back(expr); + + // Record the consumed and produced concrete ids by the newly + // recorded expression. + auto consumed_ids = ir_utils::filterByType(expr->inputs()); + std::transform( + consumed_ids.begin(), + consumed_ids.end(), + std::inserter(consumed_concrete_, consumed_concrete_.end()), + ir_utils::caMapExactConcreteId); + + auto produced_ids = ir_utils::filterByType(expr->outputs()); + std::transform( + produced_ids.begin(), + produced_ids.end(), + std::inserter(produced_concrete_, produced_concrete_.end()), + ir_utils::caMapExactConcreteId); +} + +bool LoopIndexingAnalysis::visitIdsAndCheckDuplication( + const std::vector& vals, + const std::unordered_set& existing_ids) { + bool duplication = false; + for (auto id : ir_utils::filterByType(vals)) { + duplication = duplication || existing_ids.count(concretizeAndVisitId(id)); + } + return duplication; +} + +void LoopIndexingAnalysis::constructLoopDomains() { + for (auto loop_id : initial_loop_domain_ids_) { + // Find the replayed_concrete_id mapping to the loop id. 
+ auto ref_id_it = std::find_if( + replayed_concrete_ids_.vector().begin(), + replayed_concrete_ids_.vector().end(), + [&](IterDomain* concrete_id) { + return + // Make sure the replayed_concrete_id is a leaf ID + !concrete_id_to_consumer_.count(concrete_id) && + // Use permissive map so the selected ID indeed represents the + // loop. + GpuLower::current()->caMap()->areMapped( + concrete_id, loop_id, IdMappingMode::PERMISSIVE); + }); + + TORCH_INTERNAL_ASSERT( + ref_id_it != replayed_concrete_ids_.vector().end(), + "Could not find required iter domain in reference replay: ", + loop_id->toString()); + + auto ref_id = *ref_id_it; + loop_domains_.pushBack(concrete_to_original_id_.at(ref_id)); + } + + // Construct the root domain as the inputs of the replayed domain + auto loops_replayed_domain_vals = + ir_utils::filterByType(loop_domains_.vector()); + auto root_domain_vals = IterVisitor::getInputsTo( + {loops_replayed_domain_vals.begin(), loops_replayed_domain_vals.end()}); + + // Fill loop roots: + auto root_domain_ids = ir_utils::filterByType(root_domain_vals); + loop_root_domains_ = + std::vector(root_domain_ids.begin(), root_domain_ids.end()); + + // The domain may have dangling iteration domains, i.e. the inner output of + // a split but not the outer. Find which replayed vals are dependant on the + // root domains. + auto all_replayed_vals = + ir_utils::filterByType(replayed_concrete_ids_.vector()); + auto all_ids_from_root = DependencyCheck::getAllValsBetween( + {root_domain_vals.begin(), root_domain_vals.end()}, + {all_replayed_vals.begin(), all_replayed_vals.end()}); + + // Fill all dangling outputs as otherwise backwards visitor in index compute + // will complain for not having all outputs of the traversal. + for (auto id : ir_utils::filterByType(all_ids_from_root)) { + if (id->uses().empty()) { + loop_domains_.pushBack(ir_utils::caMapExactConcreteId(id)); + } + } +} + +IndexFromIdGraph getTensorIndexFromIdGraph( + const std::vector& loops, + const TensorView* consumer_tv, + const TensorView* producer_tv, + bool is_global, + std::unordered_map c2p_map) { + bool index_producer = producer_tv != nullptr; + auto target_tv = index_producer ? producer_tv : consumer_tv; + + // TODO: remove this check when adding other indexing + // support. + TORCH_INTERNAL_ASSERT( + is_global && !index_producer, + "ConsumerIndexFromIdGraph: currently only global consumer indexing is supported."); + + auto loop_indexing = + LoopIndexingAnalysis::fromLoopAndConsumer(loops, consumer_tv); + + IndexingParameters index_parameters; + + if (is_global) { + index_parameters = getGlobalIndexParameters(loop_indexing, index_producer); + } else { + TORCH_INTERNAL_ASSERT(false, "not yet implemented"); + } + + // Setup IndexCompute to traverse backwards through concrete IDs. + IndexCompute indexing( + index_parameters.initial_concrete_id_index, + index_parameters.zero_domains, + index_parameters.preferred_concrete_ids, + index_parameters.concrete_id_to_halo_extent); + + // Run first backward traversal to generate + // loop nest based indexing math. + indexing.run(loop_indexing); + + // Populate indexing through exact map from initial indexing + + // First collect all iterdomains in consumer transform history. 
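+  //  The collected vals span from the consumer's (maybe rfactor) root
+  //  domain to its leaf domain, i.e. every id produced by the scheduling
+  //  transforms applied to this consumer.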
+ auto all_consumer_vals = DependencyCheck::getAllValsBetween( + {consumer_tv->getMaybeRFactorDomain().begin(), + consumer_tv->getMaybeRFactorDomain().end()}, + {consumer_tv->domain()->domain().begin(), + consumer_tv->domain()->domain().end()}); + + // Indexable domains are the concrete id's we visited when + // traversing the "reference" indexing pass. + std::unordered_map initial_indexable_map; + + // Map the concrete id indexing back to the producer or consumer tv + std::unordered_map index_update_map; + + for (IterDomain* consumer_id : + ir_utils::filterByType(all_consumer_vals)) { + // Track the non-concrete id we were trying to bind index + // to, whether from producer or consumer. + auto target_id = consumer_id; + + // use mapped producer id when indexing producer + if (index_producer) { + auto target_id_it = c2p_map.find(consumer_id); + if (target_id_it == c2p_map.end()) { + // consumer id not found in c2p map + // skip binding for this id. + continue; + } + target_id = target_id_it->second; + } + + // Exact id will have to be pulled from consumer side as the + // producer side are replayed ids. + auto exact_concrete_id = GpuLower::current()->caMap()->getConcreteMappedID( + consumer_id, IdMappingMode::EXACT); + + index_update_map[exact_concrete_id] = target_id; + + // Keep track of concrete id's that were used for indexing. + if (indexing.indexMap().count(exact_concrete_id)) { + initial_indexable_map[exact_concrete_id] = exact_concrete_id; + } + } + + // TODO: + // This part will be filled in when updating the producer + // indexing logic. This is just placeholder for now. + std::unordered_map p2c_map; + + // No contig indexing was done in reference indexing + ContigIDs contig_finder( + consumer_tv->domain()->domain(), + consumer_tv->getMaybeRFactorDomain(), + consumer_tv->domain()->contiguity(), + initial_indexable_map, + p2c_map); + + auto target_indexing = indexing.updateIndexCompute( + target_tv->domain(), index_update_map, contig_finder); + + // Fill validation info. + // TODO: cleanup seems possible. + if (is_global) { + if (index_producer) { + fillProducerVectorizedContigRootDomains( + target_tv, consumer_tv, c2p_map, contig_finder); + } else { + fillConsumerVectorizedContigRootDomains(consumer_tv, contig_finder); + } + } + + return IndexFromIdGraph( + target_indexing, + index_parameters.initial_concrete_id_index, + loop_indexing.loopDomains()); +} + +namespace { + +class LoopIndexingTraversal { + enum class TraversalOrder { ForwardTopological, BackwardTopological }; + + public: + static std::vector forwardTopologicalOrder( + const std::vector& exprs) { + LoopIndexingTraversal traversal(exprs, TraversalOrder::ForwardTopological); + return traversal.getExprList(); + } + + static std::vector backwardTopologicalOrder( + const std::vector& exprs) { + LoopIndexingTraversal traversal(exprs, TraversalOrder::BackwardTopological); + return traversal.getExprList(); + } + + private: + explicit LoopIndexingTraversal( + const std::vector& exprs, + TraversalOrder traversal_order); + + // Returns the vals following the expression in either + // forward or backward order. + const std::vector& nextValsInTraversalOrder(Expr* expr); + + // Returns the vals that the expression follows in either + // forward or backward order. + const std::vector& prevValsInTraversalOrder(Expr* expr); + + // Returns the sorted list according to the given traversal order. + std::vector getExprList(); + + private: + // Reference to original un-sorted expression list. 
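+  //  Note: held by reference, so the caller must keep the expression
+  //  vector alive while this traversal object is in use.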
+ const std::vector& exprs_; + + // The traversal order in this pass. + const TraversalOrder traversal_order_ = TraversalOrder::ForwardTopological; + + // Internal record of concrete id's and it's corresponding + // iterdomain expression that defines the exact index. + std::unordered_map concrete_id_to_dependency_; +}; + +LoopIndexingTraversal::LoopIndexingTraversal( + const std::vector& exprs, + TraversalOrder traversal_order) + : exprs_(exprs), traversal_order_(traversal_order) { + // Populate concrete id dependencies: + for (auto expr : exprs_) { + auto next_ids = + ir_utils::filterByType(nextValsInTraversalOrder(expr)); + for (auto id : next_ids) { + auto concrete_id = ir_utils::caMapExactConcreteId(id); + TORCH_INTERNAL_ASSERT( + concrete_id_to_dependency_.insert(std::make_pair(concrete_id, expr)) + .second, + "Repeated dependency, invalid iterdomain traversal."); + } + } +} + +const std::vector& LoopIndexingTraversal::nextValsInTraversalOrder( + Expr* expr) { + switch (traversal_order_) { + case TraversalOrder::ForwardTopological: + return expr->outputs(); + break; + case TraversalOrder::BackwardTopological: + return expr->inputs(); + break; + + default: + TORCH_INTERNAL_ASSERT(false, "unimplemented traversal order"); + } + return expr->inputs(); +} + +const std::vector& LoopIndexingTraversal::prevValsInTraversalOrder( + Expr* expr) { + switch (traversal_order_) { + case TraversalOrder::ForwardTopological: + return expr->inputs(); + break; + case TraversalOrder::BackwardTopological: + return expr->outputs(); + break; + + default: + TORCH_INTERNAL_ASSERT(false, "unimplemented traversal order"); + } + return expr->inputs(); +} + +std::vector LoopIndexingTraversal::getExprList() { + std::deque to_visit(exprs_.begin(), exprs_.end()); + + // pre-allocate result space. + std::vector result; + result.reserve(exprs_.size()); + + // Keeps track of visited and inserted expressions. + // An expr is visited if it has been placed in result list. + // An expr is inserted if the traversal has put the expr on + // the top of the stack once. Repeated insertion of the same + // expression would never be observed if the underlying + // dependency of the expressions is cycle free. 
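+  //  In other words, `inserted` only serves as a cycle check: pushing the
+  //  same not-yet-visited expression onto the stack twice can only happen
+  //  when the iterdomain dependencies form a cycle, which is asserted
+  //  against below.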
+ std::unordered_set visited, inserted; + + while (!to_visit.empty()) { + auto top = to_visit.front(); + if (visited.count(top)) { + to_visit.pop_front(); + continue; + } + + bool ready = true; + + for (auto prev_id : + ir_utils::filterByType(prevValsInTraversalOrder(top))) { + auto prev_expr_it = concrete_id_to_dependency_.find( + ir_utils::caMapExactConcreteId(prev_id)); + if (prev_expr_it != concrete_id_to_dependency_.end()) { + auto prev_expr = prev_expr_it->second; + if (!visited.count(prev_expr)) { + ready = false; + to_visit.push_front(prev_expr); + TORCH_INTERNAL_ASSERT( + inserted.insert(prev_expr).second, + "Circular dependency in loop index expressions."); + break; + } + } + } + + if (ready) { + visited.insert(top); + result.emplace_back(top); + to_visit.pop_front(); + } + } + + return result; +} + +} // namespace + +std::vector LoopIndexing::getForwardExprList() const { + return LoopIndexingTraversal::forwardTopologicalOrder(index_exprs_); +} + +std::vector LoopIndexing::getBackwardExprList() const { + return LoopIndexingTraversal::backwardTopologicalOrder(index_exprs_); +} + +} // namespace cuda +} // namespace fuser +} // namespace jit +} // namespace torch diff --git a/torch/csrc/jit/codegen/cuda/index_idgraph_utils.h b/torch/csrc/jit/codegen/cuda/index_idgraph_utils.h new file mode 100644 index 0000000000000..5f6eadb70fe99 --- /dev/null +++ b/torch/csrc/jit/codegen/cuda/index_idgraph_utils.h @@ -0,0 +1,142 @@ +#pragma once + +#include + +namespace torch { +namespace jit { +namespace fuser { +namespace cuda { + +// Struct to hold useful information from an index pass on iterdomain graph. +// Used to return the IndexCompute structure back to the indexing calls in +// index_compute.cpp. Other structurs are required to resolve the actual +// indexing math there. +struct IndexFromIdGraph { + IndexCompute index; + std::unordered_map initial_concrete_index_map; + std::vector resolved_loop_domains; + + explicit IndexFromIdGraph( + IndexCompute index, + std::unordered_map initial_concrete_index_map, + std::vector loop_domains); +}; + +//! Indexing interface, returns IndexFromIdGraph which the IndexCompute object +//! can be queried from directly for the produced indexing. If producer_tv != +//! nullptr producer will be indexed, if producer_tv == nullptr consumer will be +//! indexed. If is_global global indexing will be done, else error will be +//! thrown (TODO: support local/shared memory indexing). +IndexFromIdGraph getTensorIndexFromIdGraph( + const std::vector& loops, + const TensorView* consumer_tv, + const TensorView* producer_tv = nullptr, + bool is_global = true, + std::unordered_map c2p_map = {}); + +//! getTensorIndexFromIdGraph is the function that index_compute will call very +//! straightforwardly. However, for implementing the new indexing logic that +//! starts to abstract some of the indexing away from index_compute we need to +//! move quite a bit of the intertwined indexing logic away from the +//! index_compute file and the index_reference_replay file. This is because we +//! want to separate out what has to be done on the fly, from what analysis we +//! can do early on with the iter domain graph and associated properties. +//! +//! getTensorIndexFromIdGraph places this analysis internally in +//! LoopIndexingAnalysis. LoopIndexingAnalysis though has to communicate to: +//! 1) index_compute.cpp::IndexCompute to tell IndexCompute which expressions +//! it needs to traverse to compute the indexing math. +//! 
2) lower_shift.cpp::HaloInfo::buildConcreteHaloExtentMap to build the halo +//! extent map used in indexing. +//! +//! LoopIndexing is nothing but a mechanism for this communication. +//! +//! Holds information needed to produce indexing math. In the current version of +//! indexing pass, the iter domains combined with the loop nests are the source +//! of truth in terms of resolving the actual integer indexing math from the +//! sequence of iterdomain transforms. +//! +//! This information is crtiical in resolving indexing associated with complex +//! broadcast patterns. Check FusionComplexBCast* test cases as well as +//! FusionAdvancedIndexing* for examples where resolving indices from IterDomain +//! transformations can be challenging. +//! +//! The source of this challenge is due to inling patterns where the IterDomains +//! responsible for control flow are not local to a particular TensorView. +//! Broadcast, operations like view/reshape, and gather/shift can make indexing +//! local buffers complex because of the complex effects inlining into other +//! TensorViews produce. +//! +//! TODO: +//! The first iteration tries to match the semantics of reference +//! replay without any new logic. In a follow up iteration will +//! need to revisit a few further pathological patterns. +//! +//! Note: +//! The current implementation of loop indexing pass works on +//! equivalent classes defined by ComputeAt exact map. The +//! list of expressions stored in this class form a "reference", graph of +//! iterdomain expressions when all of their inputs and outputs are replaced +//! with their exact concrete mapped id's. +//! +//! Here an invariant in a graph of iterdomain expressions is that +//! each iterdomain is produced exactly once and is either a leaf domain +//! or has been consumed exactly once by another expression. This makes sure +//! that a well defined indexing can be generated for each of the concrete ids +//! whenever we either forward or backward traverse the graph. +class LoopIndexing { + public: + //! Returns the original loop nest. + const auto& loops() const { + return loops_; + } + + //! Returns the vector of Iterdomains + //! that match the original loop pattern. + const auto& loopDomains() const { + return loop_domains_; + } + + //! Returns the consumer tv that the view info + //! was derived from. + auto consumerTv() const { + return consumer_tv_; + } + + //! Returns the set of Iterdomain transforms that + //! define the correct indexing path, in forward + //! topological order. + std::vector getForwardExprList() const; + + //! Returns the set of Iterdomain transforms that + //! define the correct indexing path, in backward + //! topological order. + std::vector getBackwardExprList() const; + + private: + friend class LoopIndexingAnalysis; + + //! The loop nest that this loop indexing is derived from. + std::vector loops_; + + //! Consumer tv, where the view related info was derived from. + const TensorView* consumer_tv_; + + //! The source iterdomains that all the Iterdomain transforms + //! in this loop nest originated from. + std::vector loop_root_; + + //! The leaf iterdomains that the original loop nests correspond + //! to. May be longer than loops_ with the dangling iterdomains + //! appended towards the end. + std::vector loop_domains_; + + //! The selected sequence of expressions that should represent + //! the correct indexing math from the given loop nest. 
+ std::vector index_exprs_; +}; + +} // namespace cuda +} // namespace fuser +} // namespace jit +} // namespace torch diff --git a/torch/csrc/jit/codegen/cuda/index_reference_replay.cpp b/torch/csrc/jit/codegen/cuda/index_reference_replay.cpp index 63b6f793d8b75..e1a93ddf0f84f 100644 --- a/torch/csrc/jit/codegen/cuda/index_reference_replay.cpp +++ b/torch/csrc/jit/codegen/cuda/index_reference_replay.cpp @@ -132,6 +132,8 @@ bool isMappedWithAny(IterDomain* id, const std::vector& ids) { }); } +} // namespace + // Get an rfactor IterDomain that is mapped with an IterDomain. If // multiple such IDs exist, select one whose input IDs are mapped with // the consumer IDs. This is to ensure the path from the leaf @@ -170,8 +172,6 @@ IterDomain* getRfactorIDToTraverse( return rfactor_ids.at(0); } -} // namespace - TensorDomain* IndexReferenceReplay::computeReplay() { // Throw an error when two loops are mapped with each other, which // violates an assumption that unique mappings between concrete diff --git a/torch/csrc/jit/codegen/cuda/index_reference_replay.h b/torch/csrc/jit/codegen/cuda/index_reference_replay.h index 144b295faa7e4..3211c4a403dac 100644 --- a/torch/csrc/jit/codegen/cuda/index_reference_replay.h +++ b/torch/csrc/jit/codegen/cuda/index_reference_replay.h @@ -108,6 +108,14 @@ std::unordered_set buildPreferredPaths( TensorDomain* reference_domain, const std::unordered_set& preferred_roots); +// Get an rfactor IterDomain that is mapped with an IterDomain. If +// multiple such IDs exist, select one whose input IDs are mapped with +// the consumer IDs. This is to ensure the path from the leaf +// IterDomains to the root matches with the consumer tensor. +IterDomain* getRfactorIDToTraverse( + IterDomain* id, + const std::vector& consumer_all_ids); + } // namespace cuda } // namespace fuser } // namespace jit diff --git a/torch/csrc/jit/codegen/cuda/lower_shift.cpp b/torch/csrc/jit/codegen/cuda/lower_shift.cpp index 913b246e71ac5..c3fa6ed7f8f6d 100644 --- a/torch/csrc/jit/codegen/cuda/lower_shift.cpp +++ b/torch/csrc/jit/codegen/cuda/lower_shift.cpp @@ -1,5 +1,6 @@ #include #include +#include #include #include #include @@ -715,6 +716,96 @@ bool HaloInfo::needsShiftPredicate(Expr* expr) const { return false; } +std::unordered_map HaloInfo::buildConcreteHaloExtentMap( + const LoopIndexing& loop_indexing) { + // Use a local workspace to avoid re-defining halo info. + HaloInfo local_halo_info; + + auto& global_halo_info = GpuLower::current()->haloInfo(); + + // Setup root: + for (auto consumer_root_id : loop_indexing.consumerTv()->getRootDomain()) { + auto consumer_index_concrete_id = + ir_utils::caMapExactConcreteId(consumer_root_id); + local_halo_info.setRootAxisInfo( + consumer_index_concrete_id, + global_halo_info.getRootAxisInfo(consumer_root_id)); + } + + // Track IDs that are generated by merging halo-extended IDs + std::unordered_set merged_shifted_ids; + + for (auto expr : loop_indexing.getForwardExprList()) { + if (auto split = dynamic_cast(expr)) { + // Merge-then-split of halo-extended IDs is not allowed + TORCH_INTERNAL_ASSERT( + merged_shifted_ids.find(split->in()) == merged_shifted_ids.end(), + "Splitting IterDomain that is a merged domain of halo-extended domains is not allowed"); + + auto in_id = ir_utils::caMapExactConcreteId(split->in()); + + // If no halo info is found, nothing needs to be done. This ID + // must be an ancestor of a domain set by setRootAxisInfo. 
+ if (!local_halo_info.hasHaloWidth(in_id)) { + continue; + } + + const auto halo_width = local_halo_info.getHaloWidth(in_id); + + if (halo_width == 0) { + local_halo_info.setHaloWidth( + ir_utils::caMapExactConcreteId(split->outer()), 0); + local_halo_info.setHaloWidth( + ir_utils::caMapExactConcreteId(split->inner()), 0); + continue; + } + + // propagate to inner domain + auto out_id = ir_utils::caMapExactConcreteId(split->inner()); + + auto expanded_extent = + SimplifyingIrBuilder::addExpr(out_id->extent(), halo_width); + local_halo_info.extent_map_.insert({out_id, expanded_extent}); + + local_halo_info.setHaloWidth( + ir_utils::caMapExactConcreteId(split->outer()), 0); + local_halo_info.setHaloWidth( + ir_utils::caMapExactConcreteId(split->inner()), halo_width); + + // TODO: add support for inheritance map + } else if (auto merge = dynamic_cast(expr)) { + // If either of the two inputs has halo extension, propagate it + // to the merged output ID + auto inner_extent = local_halo_info.getExtent( + ir_utils::caMapExactConcreteId(merge->inner())); + auto outer_extent = local_halo_info.getExtent( + ir_utils::caMapExactConcreteId(merge->outer())); + if (inner_extent != nullptr || outer_extent != nullptr) { + if (inner_extent == nullptr) { + inner_extent = merge->inner()->extent(); + } + if (outer_extent == nullptr) { + outer_extent = merge->outer()->extent(); + } + auto expanded_extent = + SimplifyingIrBuilder::mulExpr(outer_extent, inner_extent); + local_halo_info.extent_map_.insert( + {ir_utils::caMapExactConcreteId(merge->out()), expanded_extent}); + // Splitting the output of this merge is not allowed, so + // remember it + merged_shifted_ids.insert(ir_utils::caMapExactConcreteId(merge->out())); + // Note that halo_width_map_ is not updated + } else { + setHaloWidth(ir_utils::caMapExactConcreteId(merge->out()), 0); + } + } else { + TORCH_INTERNAL_ASSERT(false, "Unsupported expr: ", expr); + } + } + + return local_halo_info.extent_map_; +} + } // namespace cuda } // namespace fuser } // namespace jit diff --git a/torch/csrc/jit/codegen/cuda/lower_shift.h b/torch/csrc/jit/codegen/cuda/lower_shift.h index c0fea8c1eadd2..d1500c5f9f203 100644 --- a/torch/csrc/jit/codegen/cuda/lower_shift.h +++ b/torch/csrc/jit/codegen/cuda/lower_shift.h @@ -13,6 +13,8 @@ namespace jit { namespace fuser { namespace cuda { +class LoopIndexing; + //! Auxiliary class to represent information about halo of an axis class AxisHaloInfo { public: @@ -64,6 +66,11 @@ class TORCH_CUDA_CU_API HaloInfo { //! Build mappings of extent information of a TensorDomain void build(TensorDomain* td); + //! Almost exact duplicate of build(TensorDomain* td), except that + //! the traversal was done on loop indexing expressions. + std::unordered_map buildConcreteHaloExtentMap( + const LoopIndexing& loop_indexing); + //! Set initial AxisHaloInfo of a root axis //! //! 
The axis does not need to be a root domain in the case of diff --git a/torch/csrc/jit/codegen/cuda/lower_utils.cpp b/torch/csrc/jit/codegen/cuda/lower_utils.cpp index fae6a5dfb1f71..c1e948d45901d 100644 --- a/torch/csrc/jit/codegen/cuda/lower_utils.cpp +++ b/torch/csrc/jit/codegen/cuda/lower_utils.cpp @@ -409,6 +409,11 @@ std::vector flattenScopedExprs(const std::vector& loop_nests) { return ExprFlattener::flatten(loop_nests); } +IterDomain* caMapExactConcreteId(IterDomain* id) { + return GpuLower::current()->caMap()->getConcreteMappedID( + id, IdMappingMode::EXACT); +} + } // namespace ir_utils namespace loop_utils { diff --git a/torch/csrc/jit/codegen/cuda/lower_utils.h b/torch/csrc/jit/codegen/cuda/lower_utils.h index 4db436bdca671..53cad32c04731 100644 --- a/torch/csrc/jit/codegen/cuda/lower_utils.h +++ b/torch/csrc/jit/codegen/cuda/lower_utils.h @@ -146,6 +146,10 @@ bool isTensorScalarFillOp(const Expr* expr); TORCH_CUDA_CU_API std::vector flattenScopedExprs( const std::vector& loop_nests); +//! Returns the concretized iterdomain according to +//! the exact compute at map. +IterDomain* caMapExactConcreteId(IterDomain* id); + } // namespace ir_utils namespace loop_utils { From 494294f9b72a6e7e62fe63ce06c45adb56c61e29 Mon Sep 17 00:00:00 2001 From: "S. Song" <41357537+shmsong@users.noreply.github.com> Date: Fri, 24 Jun 2022 17:54:21 -0700 Subject: [PATCH 09/12] Clean up reference domain dependency in gmem consumer indexing (#1735) Co-authored-by: Naoya Maruyama Co-authored-by: Christian Sarofeen --- torch/csrc/jit/codegen/cuda/index_compute.cpp | 106 +++++++++++------- .../jit/codegen/cuda/index_idgraph_utils.cpp | 38 +++++++ .../jit/codegen/cuda/lower_index_hoist.cpp | 80 ++++++++++++- .../csrc/jit/codegen/cuda/lower_index_hoist.h | 34 ++++++ 4 files changed, 215 insertions(+), 43 deletions(-) diff --git a/torch/csrc/jit/codegen/cuda/index_compute.cpp b/torch/csrc/jit/codegen/cuda/index_compute.cpp index 38e751665c5aa..86af31bf278e6 100644 --- a/torch/csrc/jit/codegen/cuda/index_compute.cpp +++ b/torch/csrc/jit/codegen/cuda/index_compute.cpp @@ -1117,23 +1117,22 @@ std::unordered_map indexMapReferenceTo( return index_map_ref_to_producer; } -Val* hoistConsumerIndex( +//! Returns an iterdomain that corresponds to the +//! indexing sub-expression to hoist or a nullopt +//! if the index should not be hoisted. +c10::optional getMaybeIndexedConsumerIdToHoist( IterDomain* consumer_root_id, const TensorView* consumer_tv, const IndexCompute& consumer_indexing, - TensorDomain* ref_td, - const IndexCompute& ref_indexing, - const std::vector& loops, Val* index) { - // If index has no defining expression, there's nothing to hoist if (isDisabled(DisableOption::IndexHoist) || index->definition() == nullptr) { - return index; + return c10::nullopt; } // The old swizzle interface, which should be deprecated, is not // supported. 
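+  // Hoisting is simply skipped in that case by returning nullopt rather
+  // than erroring out, so the caller falls back to the original index.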
if (consumer_tv->swizzleType() != SwizzleType::NoSwizzle) { - return index; + return c10::nullopt; } // auto indexed_consumer_id = consumer_root_id; @@ -1151,12 +1150,30 @@ Val* hoistConsumerIndex( "Invalid contig index: ", contig_id_it->second->toString()); + return indexed_consumer_id; +} + +Val* hoistConsumerIndex( + IterDomain* consumer_root_id, + const TensorView* consumer_tv, + const IndexCompute& consumer_indexing, + TensorDomain* ref_td, + const IndexCompute& ref_indexing, + const std::vector& loops, + Val* index) { + auto maybe_hoisted_consumer_id = getMaybeIndexedConsumerIdToHoist( + consumer_root_id, consumer_tv, consumer_indexing, index); + + if (!maybe_hoisted_consumer_id.has_value()) { + return index; + } + // Insert the index into the common index map. A previously inserted // val can be returned. auto common_index = GpuLower::current() ->commonIndexMap() .insert( - indexed_consumer_id, + maybe_hoisted_consumer_id.value(), consumer_tv->domain(), ref_td, ref_indexing.indexMap(), @@ -1167,6 +1184,40 @@ Val* hoistConsumerIndex( return common_index; } +// Version of hoisting without using reference tensor, +// should eventually deprecate the other one once reference +// tensor is completely deprecated. +Val* hoistConsumerIndex( + IterDomain* consumer_root_id, + const TensorView* consumer_tv, + const IndexCompute& consumer_indexing, + const std::vector& loop_domains, + const std::unordered_map initial_loop_index_map, + const std::vector& loops, + Val* index) { + auto maybe_hoisted_consumer_id = getMaybeIndexedConsumerIdToHoist( + consumer_root_id, consumer_tv, consumer_indexing, index); + + if (!maybe_hoisted_consumer_id.has_value()) { + return index; + } + + // Insert the index into the common index map. A previously inserted + // val can be returned. + auto common_index = GpuLower::current() + ->commonIndexMap() + .insert( + maybe_hoisted_consumer_id.value(), + consumer_tv->domain(), + loop_domains, + initial_loop_index_map, + loops, + index) + .first; + + return common_index; +} + std::unordered_map invertOneToOneMap( const std::unordered_map& map) { std::unordered_map inverted; @@ -1868,27 +1919,8 @@ std::vector Index::getGlobalConsumerStridedIndices( const TensorView* consumer_tv, const std::vector& loops) { FUSER_PERF_SCOPE("GpuLower::Lower::getGlobalConsumerIndex"); - const auto gpu_lower = GpuLower::current(); - - // Get a reference tensor replayed as existing loop structure - auto reference = IndexReferenceReplay::getReference(loops, consumer_tv); - auto reference_domain = reference.domain; - auto reference_id_map = reference.concrete_to_id; - - // Map everything we can from reference to consumer using compute at index - // map. - std::unordered_map index_map_ref_to_consumer = - indexMapReferenceTo(consumer_tv, gpu_lower->caMap(), reference_id_map); - - // Index into the reference tensor. 
Reference indexing will handle vectorized - // dims where index should be set to 0 - auto ref_compute = getReferenceIndexing(loops, reference_domain); - ContigIDs contig_finder( - consumer_tv->domain()->domain(), - consumer_tv->getMaybeRFactorDomain(), - consumer_tv->domain()->contiguity(), - reference_id_map); + auto gpu_lower = GpuLower::current(); auto index_from_id_graph = getTensorIndexFromIdGraph(loops, consumer_tv); @@ -1939,11 +1971,7 @@ std::vector Index::getGlobalConsumerStridedIndices( " dim: ", dim, " id: ", - root_dom[dim]->toString(), - ", reference domain: ", - reference_domain->toString(), - ", reference root: ", - ir_utils::toString(reference_domain->getRootDomain())); + root_dom[dim]->toString()); if (consumer_tv->domain()->contiguity()[dim]) { // If contig, used the stored stride which may be the previous @@ -1984,11 +2012,7 @@ std::vector Index::getGlobalConsumerStridedIndices( " dim: ", i, " id: ", - root_dom[i]->toString(), - ", reference domain: ", - reference_domain->toString(), - ", reference root: ", - ir_utils::toString(reference_domain->getRootDomain())); + root_dom[i]->toString()); auto root_ind = consumer_indexing.indexMap().at(root_dom[i]); @@ -1997,8 +2021,8 @@ std::vector Index::getGlobalConsumerStridedIndices( root_dom[i], consumer_tv, consumer_indexing, - reference.domain, - ref_compute, + index_from_id_graph.resolved_loop_domains, + index_from_id_graph.initial_concrete_index_map, loops, root_ind); @@ -2021,8 +2045,6 @@ std::vector Index::getGlobalConsumerStridedIndices( TORCH_INTERNAL_ASSERT( strided_inds.size() == consumer_tv->getMaybeRFactorDomain().size()); - fillConsumerVectorizedContigRootDomains(consumer_tv, contig_finder); - return strided_inds; } diff --git a/torch/csrc/jit/codegen/cuda/index_idgraph_utils.cpp b/torch/csrc/jit/codegen/cuda/index_idgraph_utils.cpp index fd0ffe13753af..1b1f95481eedc 100644 --- a/torch/csrc/jit/codegen/cuda/index_idgraph_utils.cpp +++ b/torch/csrc/jit/codegen/cuda/index_idgraph_utils.cpp @@ -22,6 +22,39 @@ IndexFromIdGraph::IndexFromIdGraph( namespace { +void insertMagicZero( + const std::vector& loops, + const std::vector& loop_domains, + std::unordered_map& concrete_loop_idx_map) { + // Find magic zero insertion point + IterDomain* magic_zero_loop = nullptr; + + // Search for proper magic zero insertion point, + // prefer innermost. + for (auto idx : c10::irange(loops.size())) { + auto loop = loops[idx]; + auto concrete_loop_id = GpuLower::current()->caMap()->getConcreteMappedID( + loop_domains[idx], IdMappingMode::EXACT); + auto loop_ind = concrete_loop_idx_map.at(concrete_loop_id); + + // Save the concrete id if this loop id is decided to + // be the insertion point by the magic zero util. + if (Index::protectWithMagicZero(loop, concrete_loop_id, loop_ind)) { + magic_zero_loop = concrete_loop_id; + } + } + + // Insert magic zero if insertion point found + if (magic_zero_loop != nullptr && + concrete_loop_idx_map.count(magic_zero_loop)) { + auto& ind = concrete_loop_idx_map.at(magic_zero_loop); + if (!ind->isConstScalar()) { + ind = SimplifyingIrBuilder::addExpr( + ind, GpuLower::current()->kernel()->magicZeroVal()); + } + } +} + //! A struct to keep track of necessary parameters used in //! configuring index compute pass. //! 
These parameters are needed to propagate the indexing from the leaf nodes of @@ -75,6 +108,11 @@ IndexingParameters getGlobalIndexParameters( index_parameters.concrete_id_to_halo_extent = GpuLower::current()->haloInfo().buildConcreteHaloExtentMap(loop_indexing); + insertMagicZero( + loops, + loop_indexing.loopDomains(), + index_parameters.initial_concrete_id_index); + return index_parameters; } diff --git a/torch/csrc/jit/codegen/cuda/lower_index_hoist.cpp b/torch/csrc/jit/codegen/cuda/lower_index_hoist.cpp index 309867477924f..2772bf04d9eef 100644 --- a/torch/csrc/jit/codegen/cuda/lower_index_hoist.cpp +++ b/torch/csrc/jit/codegen/cuda/lower_index_hoist.cpp @@ -94,6 +94,62 @@ CommonIndexKey::CommonIndexKey( loops.size()); } +CommonIndexKey::CommonIndexKey( + IterDomain* consumer_indexed_id, + TensorDomain* consumer_td, + std::vector loop_domains, + const std::unordered_map& loop_index_map, + const std::vector& loops) { + auto gpu_lower = GpuLower::current(); + + concrete_indexed_id_ = gpu_lower->caMap()->getConcreteMappedID( + consumer_indexed_id, IdMappingMode::EXACT); + + const auto consumer_leaf_ids = + getUsedLeafIds(consumer_indexed_id, consumer_td); + + // Convert to Parallel concrete IDs to find matching loops. + std::unordered_set concrete_leaf_ids; + for (auto& id : consumer_leaf_ids) { + concrete_leaf_ids.insert( + gpu_lower->caMap()->getConcreteMappedID(id, IdMappingMode::LOOP)); + } + + // Find used loops and their index vals + for (const auto i : c10::irange(loops.size())) { + auto loop = loops.at(i); + auto loop_id = gpu_lower->caMap()->getConcreteMappedID( + loop->iter_domain(), IdMappingMode::LOOP); + auto it = concrete_leaf_ids.find(loop_id); + if (it != concrete_leaf_ids.end()) { + // This leaf reference id is used for indexing the consumer id + used_loops_.push_back(loop); + auto index_it = + loop_index_map.find(gpu_lower->caMap()->getConcreteMappedID( + loop_domains.at(i), IdMappingMode::EXACT)); + TORCH_INTERNAL_ASSERT( + index_it != loop_index_map.end(), + "Index not found for leaf ID, ", + loop_domains.at(i)->toString()); + loop_index_vals_.push_back(index_it->second); + } + } + + TORCH_INTERNAL_ASSERT( + !used_loops_.empty(), + "No loop used for indexing found. 
", + consumer_indexed_id->toString()); + + TORCH_INTERNAL_ASSERT( + consumer_leaf_ids.size() == used_loops_.size(), + "consumer_leaf_ids.size() = ", + consumer_leaf_ids.size(), + ", used_loops_.size() == ", + used_loops_.size(), + ", loops.size() == ", + loops.size()); +} + bool CommonIndexKey::operator==(const CommonIndexKey& other) const { auto gpu_lower = GpuLower::current(); @@ -179,7 +235,30 @@ std::pair CommonIndexMap::insert( const CommonIndexKey key( indexed_consumer_id, consumer_td, ref_td, ref_index_map, loops); + return tryInsertNewIndex(key, index); +} +std::pair CommonIndexMap::insert( + IterDomain* indexed_consumer_id, + TensorDomain* consumer_td, + std::vector loop_domains, + const std::unordered_map& loop_index_map, + const std::vector& loops, + Val* index) { + if (index->definition() == nullptr) { + // Only expression is eligible to hoist + return {index, false}; + } + + const CommonIndexKey key( + indexed_consumer_id, consumer_td, loop_domains, loop_index_map, loops); + + return tryInsertNewIndex(key, index); +} + +std::pair CommonIndexMap::tryInsertNewIndex( + CommonIndexKey key, + Val* index) { Val* hoisted_index = nullptr; bool new_index_inserted = false; @@ -195,7 +274,6 @@ std::pair CommonIndexMap::insert( new_index_inserted = true; use_counts_[key] = 1; } - return {hoisted_index, new_index_inserted}; } diff --git a/torch/csrc/jit/codegen/cuda/lower_index_hoist.h b/torch/csrc/jit/codegen/cuda/lower_index_hoist.h index 5e0256f9e8449..b3bf36248f8b8 100644 --- a/torch/csrc/jit/codegen/cuda/lower_index_hoist.h +++ b/torch/csrc/jit/codegen/cuda/lower_index_hoist.h @@ -43,6 +43,20 @@ class CommonIndexKey { const std::unordered_map& ref_index_map, const std::vector& loops); + //! \param consumer_indexed_id Indexed consumer domain + //! \param consumer_td TensorDomain of consumer_indexed_id + //! \param loop_domains Resolved vector of iterdomain corresponding to loops + //! \param loop_index_map Index mapping generated from the loop nest. + //! \param loops Loop structure where this id is indexed + //! Duplicate of above, but without a reference domain. TODO: Remove other + //! implementation. + CommonIndexKey( + IterDomain* consumer_indexed_id, + TensorDomain* consumer_td, + const std::vector& loop_domains, + const std::unordered_map& loop_index_map, + const std::vector& loops); + const IterDomain* concreteIndexedId() const { return concrete_indexed_id_; } @@ -96,6 +110,16 @@ class TORCH_CUDA_CU_API CommonIndexMap { const std::vector& loops, Val* index); + //! Duplicate of above, but without a reference domain. TODO: Remove other + //! implementation. + std::pair insert( + IterDomain* indexed_consumer_id, + TensorDomain* consumer_td, + const std::vector& loop_domains, + const std::unordered_map& loop_index_map, + const std::vector& loops, + Val* index); + const auto& commonIndexMap() const { return common_index_map_; } @@ -104,6 +128,16 @@ class TORCH_CUDA_CU_API CommonIndexMap { return use_counts_; } + private: + //! Utility method to insert a key into common index + //! map. Returns a pair of an IR node and a boolean value. + //! The IR node will be the previously inserted index if + //! the key found a match, or will be the original index + //! if this is new key and the key will be stored. + //! The boolean value will be true if the key is stored, + //! i.e. first time it is inserted. + std::pair tryInsertNewIndex(CommonIndexKey key, Val* index); + private: //! 
Map to hold hoisted common indices std::unordered_map From e89071aa6fc654eff06dcb13264d404d29e80c4f Mon Sep 17 00:00:00 2001 From: "S. Song" <41357537+shmsong@users.noreply.github.com> Date: Fri, 24 Jun 2022 21:00:39 -0700 Subject: [PATCH 10/12] Use iterdomain graph to index non-global consumers (#1737) Co-authored-by: Christian Sarofeen --- torch/csrc/jit/codegen/cuda/index_compute.cpp | 106 ++----------- torch/csrc/jit/codegen/cuda/index_compute.h | 27 ++++ .../jit/codegen/cuda/index_idgraph_utils.cpp | 140 +++++++++++++++++- .../jit/codegen/cuda/index_idgraph_utils.h | 8 +- .../codegen/cuda/index_reference_replay.cpp | 81 ++++++++++ .../jit/codegen/cuda/index_reference_replay.h | 10 ++ .../jit/codegen/cuda/lower_index_hoist.cpp | 4 +- 7 files changed, 277 insertions(+), 99 deletions(-) diff --git a/torch/csrc/jit/codegen/cuda/index_compute.cpp b/torch/csrc/jit/codegen/cuda/index_compute.cpp index 86af31bf278e6..7a217c3bec6c0 100644 --- a/torch/csrc/jit/codegen/cuda/index_compute.cpp +++ b/torch/csrc/jit/codegen/cuda/index_compute.cpp @@ -890,8 +890,6 @@ void IndexSwizzle::handle(Expr* e) { } } -namespace { - // Used for local and shared index mapping. Returns a map from loops // to loop indices as well as a set of loops that do not contribute to // indexing. @@ -903,7 +901,7 @@ indexMapFromTV( const std::vector& loops, kir::ForLoop* alloc_loop, bool as_consumer, - kir::ForLoop* double_buffer_loop = nullptr) { + kir::ForLoop* double_buffer_loop) { bool within_alloc = false; if (alloc_loop == nullptr) { within_alloc = true; @@ -1034,7 +1032,7 @@ void ensureStaticIndexing( const TensorView* tv, kir::ForLoop* alloc_loop, const std::vector& loops, - const std::unordered_map& id_map = {}) { + const std::unordered_map& id_map) { if (tv->getMemoryType() != MemoryType::Local) { return; } @@ -1078,6 +1076,8 @@ void ensureStaticIndexing( } } +namespace { + // Map everything we can from reference to provided tv using the provided // compute at map. If root_only is true, only root domains are included. // We can't simply try to use the provided tv root domains and @@ -2054,83 +2054,15 @@ std::vector Index::getNonGlobalConsumerStridedIndices( const std::vector& loops) { const auto gpu_lower = GpuLower::current(); - // Get a reference tensor replayed as existing loop structure - auto reference = IndexReferenceReplay::getReference(loops, consumer_tv); - - auto reference_domain = reference.domain; - auto reference_id_map = reference.concrete_to_id; - - auto alloc_info = loop_utils::getAllocInformation(consumer_tv, loops); - std::unordered_map loop_to_ind_map; - std::unordered_set zero_loops; - std::tie(loop_to_ind_map, zero_loops) = - indexMapFromTV(consumer_tv, loops, alloc_info.init_for_loop, true); - - ensureStaticIndexing(consumer_tv, alloc_info.init_for_loop, loops); - - // Map loop nests to indicies, zeroing out those not used due to locality of - // memory - std::unordered_map ref_id_to_ind_map; - std::unordered_set ref_zero_domains; - - // Due to rfactor/initialization reference_domain may be bigger than loop nest - // structure, ignore IterDomains that aren't present in the loop nest when - // indexing reference. 
- TORCH_INTERNAL_ASSERT(loops.size() <= reference_domain->nDims()); - for (const auto loop_i : c10::irange(loops.size())) { - auto ref_axis = reference_domain->axis(loop_i); - ref_id_to_ind_map[ref_axis] = loop_to_ind_map[loops[loop_i]]; - if (zero_loops.count(loops[loop_i]) > 0) { - ref_zero_domains.insert(ref_axis); - } - } - - // Map everything we can from reference to consumer using compute at index - // map. - std::unordered_map index_map_ref_to_consumer = - indexMapReferenceTo(consumer_tv, gpu_lower->caMap(), reference_id_map); - - // Grab roots that map into consumer and save them into the preferred roots - // set for references indexing - std::unordered_set preferred_roots; - for (auto entry : index_map_ref_to_consumer) { - if (entry.second->isBroadcast() || entry.second->isReduction() || - entry.second->isStride()) { - continue; - } - preferred_roots.emplace(entry.first); - } - - // Make sure propagation of indexing while mixing with 0 indicies we propagate - // in a way that consumer will be able to see what's going on. - auto preferred_paths = buildPreferredPaths(reference_domain, preferred_roots); - - // Index into the reference tensor - auto ref_compute = getReferenceIndexing( + auto consumer_indexing_from_idgraph = getTensorIndexFromIdGraph( loops, - reference_domain, - ref_id_to_ind_map, - ref_zero_domains, - preferred_paths); - - // Adds halo info mappings for the reference - updateHaloInfoForReference(reference, consumer_tv); - - const auto reference_halo_extent_map = - getReferenceHaloExtentMap(reference, index_map_ref_to_consumer); - - ContigIDs contig_finder( - consumer_tv->domain()->domain(), - consumer_tv->getMaybeRFactorDomain(), - consumer_tv->domain()->contiguity(), - reference_id_map); + consumer_tv, + // Producer tv + nullptr, + // Index global + false); - // Index into consumer using reference indexing - auto consumer_indexing = ref_compute.updateIndexCompute( - consumer_tv->domain(), - index_map_ref_to_consumer, - contig_finder, - reference_halo_extent_map); + auto consumer_indexing = consumer_indexing_from_idgraph.index; IndexSwizzle index_swizzle( consumer_tv, @@ -2164,11 +2096,7 @@ std::vector Index::getNonGlobalConsumerStridedIndices( " dim: ", i, " id: ", - root_dom[i]->toString(), - ", reference domain: ", - reference_domain->toString(), - ", reference root: ", - ir_utils::toString(reference_domain->getRootDomain())); + root_dom[i]->toString()); auto root_ind_i = index_map.at(root_dom[i]); if (root_ind_i->isZeroInt()) { @@ -2180,8 +2108,8 @@ std::vector Index::getNonGlobalConsumerStridedIndices( root_dom[i], consumer_tv, consumer_indexing, - reference.domain, - ref_compute, + consumer_indexing_from_idgraph.resolved_loop_domains, + consumer_indexing_from_idgraph.initial_concrete_index_map, loops, root_ind_i); @@ -2201,11 +2129,7 @@ std::vector Index::getNonGlobalConsumerStridedIndices( " dim: ", j, " id: ", - root_dom[j]->toString(), - ", reference domain: ", - reference_domain->toString(), - ", reference root: ", - ir_utils::toString(reference_domain->getRootDomain())); + root_dom[j]->toString()); auto root_ext_j = extent_map.find(root_dom[j]) == extent_map.end() ? root_dom[j]->extent() diff --git a/torch/csrc/jit/codegen/cuda/index_compute.h b/torch/csrc/jit/codegen/cuda/index_compute.h index defcbe4439c68..e9386f5a53a2c 100644 --- a/torch/csrc/jit/codegen/cuda/index_compute.h +++ b/torch/csrc/jit/codegen/cuda/index_compute.h @@ -365,6 +365,33 @@ class Index { Val* ind = nullptr); }; +// Used for local and shared index mapping. 
Returns a map from loops +// to loop indices as well as a set of loops that do not contribute to +// indexing. +// TODO: could be cleaned up further. +std::pair< + std::unordered_map, + std::unordered_set> +indexMapFromTV( + const TensorView* tv, + const std::vector& loops, + kir::ForLoop* alloc_loop, + bool as_consumer, + kir::ForLoop* double_buffer_loop = nullptr); + +//! Set "pragma unroll" required for loops that indexing of Local +//! tensors depends on. +//! +//! \param tv Indexed tensor +//! \param alloc_loop Allocation loop of tv +//! \param loops The current loop structure +//! \param id_map Producer-to-consumer map in case of indexing as producer +void ensureStaticIndexing( + const TensorView* tv, + kir::ForLoop* alloc_loop, + const std::vector& loops, + const std::unordered_map& id_map = {}); + } // namespace cuda } // namespace fuser } // namespace jit diff --git a/torch/csrc/jit/codegen/cuda/index_idgraph_utils.cpp b/torch/csrc/jit/codegen/cuda/index_idgraph_utils.cpp index 1b1f95481eedc..b4d94b45b90f1 100644 --- a/torch/csrc/jit/codegen/cuda/index_idgraph_utils.cpp +++ b/torch/csrc/jit/codegen/cuda/index_idgraph_utils.cpp @@ -6,6 +6,7 @@ #include #include #include +#include namespace torch { namespace jit { @@ -55,6 +56,43 @@ void insertMagicZero( } } +// Maps all producer domains to consumer with broadcast +// forwarding. Used to find the allocation position. +// TODO: should this be an ir_util ? Didn't seem to be +// used too much though. +std::unordered_map mapAllProducerDomainsToConsumer( + const TensorView* producer_tv, + const TensorView* consumer_tv) { + // This map has forwarded broadcast axes, it should only be used to compute + // the allocation position of the producer, and to figure out which producer + // indices are mapped to consumer trivial reductions. + std::unordered_map p2c_alloc_map; + + // We want to replay producer as consumer instead of the other way around + // since consumer may have some broadcasted axes producer doesn't have + // merged into loops producer may use. If we did consumer as producer we + // wouldn't have this information in the mapping. + auto replay_PasC = BestEffortReplay::replayPasC( + producer_tv, + consumer_tv, + -1, + PairwiseRootDomainMap(producer_tv, consumer_tv)); + + // Grab consumer domain entries and reverse replay map. TODO: Maybe + // TransformReplay::replayPasC could return this map + for (auto id : consumer_tv->domain()->domain()) { + const auto& c2p_map = replay_PasC.getReplay(); + auto c2p_it = c2p_map.find(id); + if (c2p_it != c2p_map.end()) { + auto c_id = c2p_it->first; + auto p_id = c2p_it->second; + p2c_alloc_map[p_id] = c_id; + } + } + + return p2c_alloc_map; +} + //! A struct to keep track of necessary parameters used in //! configuring index compute pass. //! These parameters are needed to propagate the indexing from the leaf nodes of @@ -116,6 +154,80 @@ IndexingParameters getGlobalIndexParameters( return index_parameters; } +// Initial index parameters for shared and local case +IndexingParameters getNonGlobalInitialIndexParameters( + const LoopIndexing& loop_indexing, + const TensorView* consumer_tv, + bool index_producer = false, + const TensorView* producer_tv = nullptr, + std::unordered_map p2c_map = {}) { + IndexingParameters index_parameters; + const auto& loops = loop_indexing.loops(); + const auto& loop_domains = loop_indexing.loopDomains(); + + // TODO: + // The non-global path should become shorter as we + // pull more info into id graph. 
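+  //  When indexing a producer, its domains are first mapped to the consumer
+  //  (with broadcast forwarding) so that the allocation point and zero-able
+  //  loops can be resolved against the consumer's loop structure.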
+ std::unordered_map alloc_id_map; + + if (index_producer) { + alloc_id_map = mapAllProducerDomainsToConsumer(producer_tv, consumer_tv); + } + + auto alloc_tv = index_producer ? producer_tv : consumer_tv; + auto alloc_info = loop_utils::getAllocInformation( + alloc_tv, loops, alloc_id_map, index_producer); + + std::unordered_map loop_to_ind_map; + std::unordered_set zero_loops; + + kir::ForLoop* double_buffer_loop = nullptr; + + if (index_producer) { + double_buffer_loop = + GpuLower::current()->doubleBufferInfo().getDoubleBufferLoop( + consumer_tv, loops, true); + } + + std::tie(loop_to_ind_map, zero_loops) = indexMapFromTV( + alloc_tv, + loops, + alloc_info.init_for_loop, + !index_producer, + double_buffer_loop); + + ensureStaticIndexing(alloc_tv, alloc_info.init_for_loop, loops, alloc_id_map); + + TORCH_INTERNAL_ASSERT( + loops.size() <= loop_domains.size(), + "Loop domain didn't replay all loops"); + + for (auto loop_idx : c10::irange(loops.size())) { + auto loop = loops[loop_idx]; + auto loop_domain = loop_domains[loop_idx]; + + auto concrete_loop_domain = ir_utils::caMapExactConcreteId(loop_domain); + + index_parameters.initial_concrete_id_index[concrete_loop_domain] = + loop_to_ind_map.at(loop); + + if (zero_loops.count(loop)) { + index_parameters.zero_domains.insert(concrete_loop_domain); + } + } + + // Derive preferred path from loop indexing result. + const TensorView* target_tv = index_producer ? producer_tv : consumer_tv; + index_parameters.preferred_concrete_ids = buildLoopIndexingPreferredPath( + target_tv, loop_indexing, index_producer, p2c_map); + + // Derive the halo extents from the loop indexing result. + index_parameters.concrete_id_to_halo_extent = + GpuLower::current()->haloInfo().buildConcreteHaloExtentMap(loop_indexing); + + return index_parameters; +} + } // namespace class LoopIndexingAnalysis { @@ -464,8 +576,8 @@ IndexFromIdGraph getTensorIndexFromIdGraph( // TODO: remove this check when adding other indexing // support. TORCH_INTERNAL_ASSERT( - is_global && !index_producer, - "ConsumerIndexFromIdGraph: currently only global consumer indexing is supported."); + !index_producer, + "ConsumerIndexFromIdGraph: currently only consumer indexing is supported."); auto loop_indexing = LoopIndexingAnalysis::fromLoopAndConsumer(loops, consumer_tv); @@ -475,10 +587,10 @@ IndexFromIdGraph getTensorIndexFromIdGraph( if (is_global) { index_parameters = getGlobalIndexParameters(loop_indexing, index_producer); } else { - TORCH_INTERNAL_ASSERT(false, "not yet implemented"); + index_parameters = getNonGlobalInitialIndexParameters( + loop_indexing, consumer_tv, index_producer, producer_tv); } - // Setup IndexCompute to traverse backwards through concrete IDs. 
IndexCompute indexing( index_parameters.initial_concrete_id_index, index_parameters.zero_domains, @@ -725,6 +837,26 @@ std::vector LoopIndexing::getBackwardExprList() const { return LoopIndexingTraversal::backwardTopologicalOrder(index_exprs_); } +std::unordered_set LoopIndexing::getAllExactConcreteIdSet() const { + std::unordered_set all_id_set; + for (auto expr : index_exprs_) { + auto out_ids = ir_utils::filterByType(expr->outputs()); + std::transform( + out_ids.begin(), + out_ids.end(), + std::inserter(all_id_set, all_id_set.end()), + ir_utils::caMapExactConcreteId); + + auto in_ids = ir_utils::filterByType(expr->inputs()); + std::transform( + in_ids.begin(), + in_ids.end(), + std::inserter(all_id_set, all_id_set.end()), + ir_utils::caMapExactConcreteId); + } + return all_id_set; +} + } // namespace cuda } // namespace fuser } // namespace jit diff --git a/torch/csrc/jit/codegen/cuda/index_idgraph_utils.h b/torch/csrc/jit/codegen/cuda/index_idgraph_utils.h index 5f6eadb70fe99..c50b429da94b2 100644 --- a/torch/csrc/jit/codegen/cuda/index_idgraph_utils.h +++ b/torch/csrc/jit/codegen/cuda/index_idgraph_utils.h @@ -25,8 +25,8 @@ struct IndexFromIdGraph { //! Indexing interface, returns IndexFromIdGraph which the IndexCompute object //! can be queried from directly for the produced indexing. If producer_tv != //! nullptr producer will be indexed, if producer_tv == nullptr consumer will be -//! indexed. If is_global global indexing will be done, else error will be -//! thrown (TODO: support local/shared memory indexing). +//! indexed. If is_global global indexing will be done, else shared memory or +//! local indexing will be performed. IndexFromIdGraph getTensorIndexFromIdGraph( const std::vector& loops, const TensorView* consumer_tv, @@ -113,6 +113,10 @@ class LoopIndexing { //! topological order. std::vector getBackwardExprList() const; + //! Returns all exact concrete id's that were produced + //! 
or consumed in the selected indexing expressions + std::unordered_set getAllExactConcreteIdSet() const; + private: friend class LoopIndexingAnalysis; diff --git a/torch/csrc/jit/codegen/cuda/index_reference_replay.cpp b/torch/csrc/jit/codegen/cuda/index_reference_replay.cpp index e1a93ddf0f84f..3a94d04deefd8 100644 --- a/torch/csrc/jit/codegen/cuda/index_reference_replay.cpp +++ b/torch/csrc/jit/codegen/cuda/index_reference_replay.cpp @@ -1,6 +1,7 @@ #include #include +#include #include #include #include @@ -529,6 +530,77 @@ class PreferredPathCompute : public IterVisitor { return compute.preferred_path; } }; + +class LoopIndexingPreferredPathCompute : public IterVisitor { + public: + static std::unordered_set compute( + const TensorView* original_tv, + const LoopIndexing& loop_indexing, + bool use_replay_map, + const std::unordered_map& p2c_map) { + LoopIndexingPreferredPathCompute compute; + + auto all_concrete_ids = loop_indexing.getAllExactConcreteIdSet(); + + // Annotate all ids + auto all_original_ids = DependencyCheck::getAllValsBetween( + {original_tv->getMaybeRFactorDomain().begin(), + original_tv->getMaybeRFactorDomain().end()}, + {original_tv->domain()->domain().begin(), + original_tv->domain()->domain().end()}); + + for (auto original_id : + ir_utils::filterByType(all_original_ids)) { + auto mapped_id = original_id; + if (use_replay_map) { + auto c_id_it = p2c_map.find(original_id); + if (c_id_it == p2c_map.end()) { + continue; + } + mapped_id = c_id_it->second; + } + auto concrete_original_id = ir_utils::caMapExactConcreteId(mapped_id); + if (all_concrete_ids.count(concrete_original_id)) { + if (original_id->isBroadcast() || original_id->isReduction() || + original_id->isStride()) { + continue; + } + compute.preferred_path_.insert(concrete_original_id); + } + } + + for (auto expr : loop_indexing.getForwardExprList()) { + compute.handle(expr); + } + + return compute.preferred_path_; + } + + private: + void handle(Expr* e) override { + // If an input ID is marked, propagate the marking to outputs of the + // expression + auto all_iter_inputs = ir_utils::filterByType(e->inputs()); + if (std::any_of( + all_iter_inputs.begin(), + all_iter_inputs.end(), + [&](IterDomain* inp_id) { + return this->preferred_path_.find(ir_utils::caMapExactConcreteId( + inp_id)) != this->preferred_path_.end(); + })) { + auto all_iter_outputs = ir_utils::filterByType(e->outputs()); + + std::transform( + all_iter_outputs.begin(), + all_iter_outputs.end(), + std::inserter(preferred_path_, preferred_path_.end()), + ir_utils::caMapExactConcreteId); + } + } + + std::unordered_set preferred_path_; +}; + } // namespace // External interface for preferred path propagation. 
@@ -538,6 +610,15 @@ std::unordered_set buildPreferredPaths( return PreferredPathCompute::compute(reference_tensor, preferred_roots); } +std::unordered_set buildLoopIndexingPreferredPath( + const TensorView* original_tv, + const LoopIndexing& loop_indexing, + bool use_replay_map, + std::unordered_map p2c_map) { + return LoopIndexingPreferredPathCompute::compute( + original_tv, loop_indexing, use_replay_map, p2c_map); +} + } // namespace cuda } // namespace fuser } // namespace jit diff --git a/torch/csrc/jit/codegen/cuda/index_reference_replay.h b/torch/csrc/jit/codegen/cuda/index_reference_replay.h index 3211c4a403dac..15062ef452ece 100644 --- a/torch/csrc/jit/codegen/cuda/index_reference_replay.h +++ b/torch/csrc/jit/codegen/cuda/index_reference_replay.h @@ -108,6 +108,16 @@ std::unordered_set buildPreferredPaths( TensorDomain* reference_domain, const std::unordered_set& preferred_roots); +// When indexing there are sometimes an option to propagate an index down +// multiple paths. This will return the IterDomains in the history of the +// reference domain and mark which paths should be taken (if there's a +// preference) to reach the roots provided in preferred_roots. +std::unordered_set buildLoopIndexingPreferredPath( + const TensorView* original_tv, + const LoopIndexing& loop_indexing, + bool use_replay_map = false, + std::unordered_map p2c_map = {}); + // Get an rfactor IterDomain that is mapped with an IterDomain. If // multiple such IDs exist, select one whose input IDs are mapped with // the consumer IDs. This is to ensure the path from the leaf diff --git a/torch/csrc/jit/codegen/cuda/lower_index_hoist.cpp b/torch/csrc/jit/codegen/cuda/lower_index_hoist.cpp index 2772bf04d9eef..3eb6080db82e3 100644 --- a/torch/csrc/jit/codegen/cuda/lower_index_hoist.cpp +++ b/torch/csrc/jit/codegen/cuda/lower_index_hoist.cpp @@ -97,7 +97,7 @@ CommonIndexKey::CommonIndexKey( CommonIndexKey::CommonIndexKey( IterDomain* consumer_indexed_id, TensorDomain* consumer_td, - std::vector loop_domains, + const std::vector& loop_domains, const std::unordered_map& loop_index_map, const std::vector& loops) { auto gpu_lower = GpuLower::current(); @@ -241,7 +241,7 @@ std::pair CommonIndexMap::insert( std::pair CommonIndexMap::insert( IterDomain* indexed_consumer_id, TensorDomain* consumer_td, - std::vector loop_domains, + const std::vector& loop_domains, const std::unordered_map& loop_index_map, const std::vector& loops, Val* index) { From 0de67360646b9297dd9b0f21d52708e8786816f5 Mon Sep 17 00:00:00 2001 From: "S. Song" <41357537+shmsong@users.noreply.github.com> Date: Fri, 24 Jun 2022 21:23:55 -0700 Subject: [PATCH 11/12] Remove reference tensor creation in producer tensor indexing path (#1750) Co-authored-by: Christian Sarofeen --- torch/csrc/jit/codegen/cuda/index_compute.cpp | 430 ++++++++---------- .../jit/codegen/cuda/index_idgraph_utils.cpp | 85 +++- .../jit/codegen/cuda/index_idgraph_utils.h | 2 + 3 files changed, 256 insertions(+), 261 deletions(-) diff --git a/torch/csrc/jit/codegen/cuda/index_compute.cpp b/torch/csrc/jit/codegen/cuda/index_compute.cpp index 7a217c3bec6c0..818d15900ebeb 100644 --- a/torch/csrc/jit/codegen/cuda/index_compute.cpp +++ b/torch/csrc/jit/codegen/cuda/index_compute.cpp @@ -205,6 +205,58 @@ Val* getProducerOffsetWithGather( return producer_offset; } +//! Create a producer offset based off a consumer index +//! +//! \param consumer_root_axis Position of corresponding consumer axis +//! \param consumer_tv Consumer TensorView +//! 
+//! \param index_map Mappings from consumer or concrete domains to indices
+//! \param use_concrete_map True when index_map is keyed by exact concrete
+//!  domains rather than by consumer root domains
+Val* getConcreteProducerOffsetWithGather(
+    size_t consumer_root_axis,
+    const TensorView* consumer_tv,
+    const std::unordered_map& index_map,
+    bool use_concrete_map = false) {
+  const auto gpu_lower = GpuLower::current();
+
+  const auto gather_expr = dynamic_cast(consumer_tv->definition());
+
+  if (gather_expr == nullptr) {
+    return gpu_lower->kernel()->zeroVal();
+  }
+
+  // If the window extent is one, no specific offsetting
+  // is necessary
+  if (consumer_root_axis >= gather_expr->windowShape().size() ||
+      gather_expr->windowShape()[consumer_root_axis] == 1) {
+    return gpu_lower->kernel()->zeroVal();
+  }
+
+  // Basically, the goal is to build an expression of producer_index +
+  // window_index, so we first need to locate the index expression
+  // that corresponds to the window axis of this producer axis.
+
+  const auto window_axis = gather_expr->gatherAxis(consumer_root_axis);
+  auto window_id = consumer_tv->getRootDomain().at(window_axis);
+
+  Val* window_idx = nullptr;
+
+  if (use_concrete_map) {
+    window_idx = index_map.at(ir_utils::caMapExactConcreteId(window_id));
+  } else {
+    window_idx = index_map.at(window_id);
+  }
+
+  // Positive padding at offset zero means the indexing shifted to the
+  // negative direction.
+  auto pad_width = gather_expr->padWidth()[consumer_root_axis][0];
+
+  // producer offset: window_index - padding
+  auto producer_offset = SimplifyingIrBuilder::subExpr(
+      window_idx, SimplifyingIrBuilder::create(pad_width));
+  return producer_offset;
+}
+
 //! Offset a producer index of a gather expression
 //!
 //! Given an index of a producer root axis, build a new index
@@ -248,6 +300,48 @@ Val* getProducerIndexWithGather(
   return SimplifyingIrBuilder::addExpr(producer_index, offset);
 }
 
+//! Offset a producer index of a gather expression
+//!
+//! Given an index of a producer root axis, build a new index
+//! expression that accesses a window position that the current loop
+//! structure refers to. Use getConcreteProducerOffsetWithGather to
+//! create an offset Val.
+Val* getProducerIndexWithGather(
+    Val* producer_index,
+    size_t producer_root_axis,
+    const TensorView* producer_tv,
+    const TensorView* consumer_tv,
+    const std::unordered_map& concrete_index_map) {
+  auto gather_op = dynamic_cast(consumer_tv->definition());
+
+  // Just return the producer index as is if this is not a gather
+  if (gather_op == nullptr) {
+    return producer_index;
+  }
+
+  // Consumer axis that corresponds to the producer axis
+  int consumer_axis = -1;
+  for (const auto i : c10::irange(producer_root_axis + 1)) {
+    if (producer_tv->getMaybeRFactorDomain()[i]->isReduction() ||
+        producer_tv->getMaybeRFactorDomain()[i]->isStride()) {
+      continue;
+    }
+    ++consumer_axis;
+  }
+
+  TORCH_INTERNAL_ASSERT(
+      consumer_axis >= 0 &&
+          consumer_axis < (int)gather_op->windowShape().size(),
+      "Invalid consumer axis",
+      consumer_axis,
+      ", producer_axis: ",
+      producer_root_axis);
+
+  auto offset = getConcreteProducerOffsetWithGather(
+      consumer_axis, consumer_tv, concrete_index_map, true);
+  return SimplifyingIrBuilder::addExpr(producer_index, offset);
+}
+
 // Adjusts a global consumer index when its root domain is partially
 // split. Note that non-global consumer indices don't need any
 // adjustment.
@@ -1120,10 +1214,10 @@ std::unordered_map indexMapReferenceTo(
 //!
Returns an iterdomain that corresponds to the //! indexing sub-expression to hoist or a nullopt //! if the index should not be hoisted. -c10::optional getMaybeIndexedConsumerIdToHoist( - IterDomain* consumer_root_id, - const TensorView* consumer_tv, - const IndexCompute& consumer_indexing, +c10::optional getMaybeIndexedIdToHoist( + IterDomain* root_id, + const TensorView* tv, + const IndexCompute& indexing, Val* index) { if (isDisabled(DisableOption::IndexHoist) || index->definition() == nullptr) { return c10::nullopt; @@ -1131,26 +1225,25 @@ c10::optional getMaybeIndexedConsumerIdToHoist( // The old swizzle interface, which should be deprecated, is not // supported. - if (consumer_tv->swizzleType() != SwizzleType::NoSwizzle) { + if (tv->swizzleType() != SwizzleType::NoSwizzle) { return c10::nullopt; } - // auto indexed_consumer_id = consumer_root_id; // Find the true indexed domain, which can be a merged contiguous domain. - auto contig_id_it = consumer_indexing.rootToContigID().find(consumer_root_id); + auto contig_id_it = indexing.rootToContigID().find(root_id); TORCH_INTERNAL_ASSERT( - contig_id_it != consumer_indexing.rootToContigID().end(), + contig_id_it != indexing.rootToContigID().end(), "Consumer indexed ID not found: ", - consumer_root_id->toString()); - auto indexed_consumer_id = contig_id_it->second; + root_id->toString()); + auto indexed_id = contig_id_it->second; // Make sure this contig ID is indeed indexed TORCH_INTERNAL_ASSERT( - consumer_indexing.indexMap().find(contig_id_it->second) != - consumer_indexing.indexMap().end(), + indexing.indexMap().find(contig_id_it->second) != + indexing.indexMap().end(), "Invalid contig index: ", contig_id_it->second->toString()); - return indexed_consumer_id; + return indexed_id; } Val* hoistConsumerIndex( @@ -1161,7 +1254,7 @@ Val* hoistConsumerIndex( const IndexCompute& ref_indexing, const std::vector& loops, Val* index) { - auto maybe_hoisted_consumer_id = getMaybeIndexedConsumerIdToHoist( + auto maybe_hoisted_consumer_id = getMaybeIndexedIdToHoist( consumer_root_id, consumer_tv, consumer_indexing, index); if (!maybe_hoisted_consumer_id.has_value()) { @@ -1191,11 +1284,11 @@ Val* hoistConsumerIndex( IterDomain* consumer_root_id, const TensorView* consumer_tv, const IndexCompute& consumer_indexing, - const std::vector& loop_domains, + std::vector loop_domains, const std::unordered_map initial_loop_index_map, const std::vector& loops, Val* index) { - auto maybe_hoisted_consumer_id = getMaybeIndexedConsumerIdToHoist( + auto maybe_hoisted_consumer_id = getMaybeIndexedIdToHoist( consumer_root_id, consumer_tv, consumer_indexing, index); if (!maybe_hoisted_consumer_id.has_value()) { @@ -1241,37 +1334,68 @@ Val* hoistProducerIndex( const IndexCompute& ref_indexing, const std::vector& loops, Val* index) { - // If index has no defining expression, there's nothing to hoist - if (isDisabled(DisableOption::IndexHoist) || index->definition() == nullptr) { + auto maybe_indexed_producer_id = getMaybeIndexedIdToHoist( + producer_root_id, producer_tv, producer_indexing, index); + + if (!maybe_indexed_producer_id.has_value()) { return index; } - // The old swizzle interface, which should be deprecated, is not - // supported. - if (producer_tv->swizzleType() != SwizzleType::NoSwizzle) { + auto indexed_consumer_id_it = p2c_map.find(maybe_indexed_producer_id.value()); + + // There can be no corresponding consumer ID. For example, consider: + // consumer: [b1, i2, i3] + // producer: [i2, i3]. 
+  //   Suppose the consumer is transformed as:
+  //     consumer: [(b1*i2)*i3]
+  //   Then the producer would be transformed when indexed:
+  //     producer: [i2*i3]
+  // Assuming i2 and i3 are contiguous, the producer indexing is done
+  // with the merged i2*i3 domain, but there's no domain in the
+  // consumer that maps with the producer indexed domain.
+  // It seems non-trivial to support patterns like this. Skip for now.
+  if (indexed_consumer_id_it == p2c_map.end()) {
     return index;
   }
 
-  // auto indexed_producer_id = producer_root_id;
-  auto contig_id_it = producer_indexing.rootToContigID().find(producer_root_id);
-  TORCH_INTERNAL_ASSERT(
-      contig_id_it != producer_indexing.rootToContigID().end(),
-      "Producer indexed ID not found: ",
-      producer_root_id->toString());
-  auto indexed_producer_id = contig_id_it->second;
-  // Make sure this contig ID is indeed indexed
-  TORCH_INTERNAL_ASSERT(
-      producer_indexing.indexMap().find(indexed_producer_id) !=
-          producer_indexing.indexMap().end(),
-      "Invalid contig id: ",
-      indexed_producer_id->toString());
+  IterDomain* indexed_consumer_id = indexed_consumer_id_it->second;
+
+  auto common_index = GpuLower::current()
+                          ->commonIndexMap()
+                          .insert(
+                              indexed_consumer_id,
+                              consumer_tv->domain(),
+                              ref_td,
+                              ref_indexing.indexMap(),
+                              loops,
+                              index)
+                          .first;
+
+  return common_index;
+}
+
+Val* hoistProducerIndex(
+    IterDomain* producer_root_id,
+    const TensorView* producer_tv,
+    const IndexCompute& producer_indexing,
+    const TensorView* consumer_tv,
+    const std::unordered_map& p2c_map,
+    std::vector loop_domains,
+    const std::unordered_map initial_loop_index_map,
+    const std::vector& loops,
+    Val* index) {
+  auto maybe_indexed_producer_id = getMaybeIndexedIdToHoist(
+      producer_root_id, producer_tv, producer_indexing, index);
+
+  if (!maybe_indexed_producer_id.has_value()) {
+    return index;
+  }
 
   // Use the corresponding consumer domain to find matching
   // for-loops. Note that there's no CA mapping with the producer
   // domains as the producer TensorDomain is a temporary replay
   // domain.
-
-  auto indexed_consumer_id_it = p2c_map.find(indexed_producer_id);
+  auto indexed_consumer_id_it = p2c_map.find(maybe_indexed_producer_id.value());
 
   // There can be no corresponding consumer ID. For example, consider:
   //   consumer: [b1, i2, i3]
   //   producer: [i2, i3]
@@ -1295,8 +1419,8 @@ Val* hoistProducerIndex(
           .insert(
               indexed_consumer_id,
               consumer_tv->domain(),
-              ref_td,
-              ref_indexing.indexMap(),
+              loop_domains,
+              initial_loop_index_map,
               loops,
               index)
           .first;
@@ -1313,11 +1437,6 @@ std::vector Index::getGlobalProducerStridedIndices(
   FUSER_PERF_SCOPE("GpuLower::Lower::getGlobalProducerIndex");
   const auto gpu_lower = GpuLower::current();
 
-  // Get a reference tensor replayed as existing loop structure
-  auto reference = IndexReferenceReplay::getReference(loops, consumer_tv);
-  auto reference_domain = reference.domain;
-  auto reference_id_map = reference.concrete_to_id;
-
   // Replay producer to look like consumer so we can index on producer since
   // our loop nests look like consumer
   auto pairwise_map = PairwiseRootDomainMap(producer_tv, consumer_tv);
@@ -1328,10 +1447,6 @@ std::vector Index::getGlobalProducerStridedIndices(
   // Make the producer_tv look like consumer while performing indexing math
   ir_utils::TVDomainGuard domain_guard(producer_tv, producerAsC);
 
-  // Map everything we can from reference to producer using compute at index
-  // map. Use consumer as a proxy between producer and the generated reference.
- std::unordered_map index_map_ref_to_producer; - // Map sent to best effort replay needs to match the exact incantation for // compute_at_mode.cpp with MappingMode::Index auto c2p_root_map = @@ -1346,27 +1461,6 @@ std::vector Index::getGlobalProducerStridedIndices( const auto& c2p_map = replay_producer_as_consumer.getReplay(); const auto p2c_map = invertOneToOneMap(c2p_map); - { - std::unordered_map index_map_ref_to_consumer = - indexMapReferenceTo(consumer_tv, gpu_lower->caMap(), reference_id_map); - - for (auto entry : index_map_ref_to_consumer) { - auto r_id = entry.first; - auto c_id = entry.second; - auto c2p_it = c2p_map.find(c_id); - if (c2p_it != c2p_map.end()) { - auto p_id = c2p_it->second; - index_map_ref_to_producer[r_id] = p_id; - } - } - } - - kir::ForLoop* db_loop = gpu_lower->doubleBufferInfo().getDoubleBufferLoop( - consumer_tv, loops, true); - - // Index into the reference tensor. Reference indexing will handle vectorized - // dims where index should be set to 0 - auto ref_compute = getReferenceIndexing(loops, reference_domain, db_loop); // Forward vectorized IDs to index into producer correctly // We want p_id to be vectorized like consumer just for the indexing, then we @@ -1374,8 +1468,8 @@ std::vector Index::getGlobalProducerStridedIndices( // need to do this as replaying producer as consumer can use replay best // effort which means some domains may be producer's original domains. std::vector> p_id_backup; - for (auto entry : index_map_ref_to_producer) { - auto ref_id = entry.first; + for (auto entry : c2p_map) { + auto ref_id = ir_utils::caMapExactConcreteId(entry.first); auto p_id = entry.second; if (ref_id->getParallelType() == ParallelType::Vectorize) { p_id_backup.emplace_back(std::make_pair(p_id, p_id->getParallelType())); @@ -1385,25 +1479,10 @@ std::vector Index::getGlobalProducerStridedIndices( } } - // Adds halo info mappings for the reference - updateHaloInfoForReference(reference, consumer_tv); + auto producer_indexing_from_idgraph = + getTensorIndexFromIdGraph(loops, consumer_tv, producer_tv, true, c2p_map); - const auto reference_halo_extent_map = - getReferenceHaloExtentMap(reference, index_map_ref_to_producer); - - ContigIDs contig_finder( - producer_tv->domain()->domain(), - producer_tv->getMaybeRFactorDomain(), - producer_tv->domain()->contiguity(), - reference_id_map, - p2c_map); - - // Index into producer using reference indexing - auto producer_indexing = ref_compute.updateIndexCompute( - producer_tv->domain(), - index_map_ref_to_producer, - contig_finder, - reference_halo_extent_map); + auto producer_indexing = producer_indexing_from_idgraph.index; // Revert p_ids for (auto entry : p_id_backup) { @@ -1454,11 +1533,7 @@ std::vector Index::getGlobalProducerStridedIndices( " dim: ", dim, " id: ", - root_dom[dim]->toString(), - ", reference domain: ", - reference_domain->toString(), - ", reference root: ", - ir_utils::toString(reference_domain->getRootDomain())); + root_dom[dim]->toString()); if (producer_tv->domain()->contiguity()[dim]) { // If contig, used the stored stride which may be the previous @@ -1500,11 +1575,7 @@ std::vector Index::getGlobalProducerStridedIndices( " dim: ", i, " id: ", - root_dom[i]->toString(), - ", reference domain: ", - reference_domain->toString(), - ", reference root: ", - ir_utils::toString(reference_domain->getRootDomain())); + root_dom[i]->toString()); auto root_ind = producer_indexing.indexMap().at(root_dom[i]); @@ -1515,8 +1586,8 @@ std::vector Index::getGlobalProducerStridedIndices( producer_indexing, 
consumer_tv, p2c_map, - reference.domain, - ref_compute, + producer_indexing_from_idgraph.resolved_loop_domains, + producer_indexing_from_idgraph.initial_concrete_index_map, loops, root_ind); @@ -1527,8 +1598,7 @@ std::vector Index::getGlobalProducerStridedIndices( i, producer_tv, consumer_tv, - reference_id_map, - ref_compute.indexMap()); + producer_indexing_from_idgraph.concrete_index.indexMap()); root_ind = getProducerIndexWithPartialSplit( root_ind, root_dom[i], producer_tv, consumer_tv); @@ -1546,10 +1616,6 @@ std::vector Index::getGlobalProducerStridedIndices( } } - // Save indexing info necessary for validating vectorization at launch time - fillProducerVectorizedContigRootDomains( - producer_tv, consumer_tv, c2p_map, contig_finder); - return strided_inds; } @@ -1599,11 +1665,6 @@ std::vector Index::getNonGlobalProducerStridedIndices( const std::vector& loops) { const auto gpu_lower = GpuLower::current(); - // Get a reference tensor replayed as existing loop structure - auto reference = IndexReferenceReplay::getReference(loops, consumer_tv); - auto reference_domain = reference.domain; - auto reference_id_map = reference.concrete_to_id; - // Replay producer to look like consumer so we can index on producer since our // loop nests look like consumer auto pairwise_map = PairwiseRootDomainMap(producer_tv, consumer_tv); @@ -1613,45 +1674,9 @@ std::vector Index::getNonGlobalProducerStridedIndices( ir_utils::TVDomainGuard domain_guard( producer_tv, producer_replayed_as_consumer); - const auto p2c_alloc_map = mapAllProducerDomainsToConsumer(producer_tv, consumer_tv); - kir::ForLoop* consumer_db_loop = - gpu_lower->doubleBufferInfo().getDoubleBufferLoop( - consumer_tv, loops, true); - - // Find allocation point of producer relative to loop nests. P2C map is - // required because producer was replayed as consumer, so we can't use the - // regular compute at maps to line up its iter domains with the for loops. - auto alloc_info = - loop_utils::getAllocInformation(producer_tv, loops, p2c_alloc_map, true); - std::unordered_map loop_to_ind_map; - std::unordered_set zero_loops; - std::tie(loop_to_ind_map, zero_loops) = indexMapFromTV( - producer_tv, loops, alloc_info.init_for_loop, false, consumer_db_loop); - - ensureStaticIndexing( - producer_tv, alloc_info.init_for_loop, loops, p2c_alloc_map); - - // Map loop nests to indicies, zeroing out those not used due to locality of - // memory - std::unordered_map ref_id_to_ind_map; - // Track which domains are not used - std::unordered_set ref_zero_domains; - - // Due to rfactor/initialization reference_domain may be bigger than loop nest - // structure, ignore IterDomains that aren't present in the loop nest when - // indexing reference. - TORCH_INTERNAL_ASSERT(loops.size() <= reference_domain->nDims()); - for (const auto loop_i : c10::irange(loops.size())) { - auto ref_axis = reference_domain->axis(loop_i); - ref_id_to_ind_map[ref_axis] = loop_to_ind_map[loops[loop_i]]; - if (zero_loops.count(loops[loop_i]) > 0) { - ref_zero_domains.insert(ref_axis); - } - } - // Map everything we can from reference to producer using compute at index // map. All producer id's don't exist in the compute at map. 
The rfactor axes // all may be, but since I haven't proven that to be the case, going to do a @@ -1660,59 +1685,21 @@ std::vector Index::getNonGlobalProducerStridedIndices( std::unordered_map index_map_ref_to_producer; std::unordered_map c2p_index_map; std::unordered_map p2c_index_map; - { - // Map sent to best effort replay needs to match the exact incantation for - // compute_at_mode.cpp with MappingMode::Index - auto c2p_root_map = PairwiseRootDomainMap(producer_tv, consumer_tv, true) - .mapConsumerToProducer( - consumer_tv->domain(), producer_tv->domain()); - - // This replay has to be consistent with compute at index map. - BestEffortReplay replay_producer_as_consumer( - producer_tv->domain()->domain(), - consumer_tv->domain()->domain(), - c2p_root_map); - - c2p_index_map = replay_producer_as_consumer.getReplay(); - p2c_index_map = invertOneToOneMap(c2p_index_map); - - std::unordered_map index_map_ref_to_consumer = - indexMapReferenceTo(consumer_tv, gpu_lower->caMap(), reference_id_map); - - for (auto entry : index_map_ref_to_consumer) { - auto r_id = entry.first; - auto c_id = entry.second; - auto c2p_it = c2p_index_map.find(c_id); - if (c2p_it != c2p_index_map.end()) { - auto p_id = c2p_it->second; - index_map_ref_to_producer[r_id] = p_id; - } - } - } - // Grab roots that map into producer and save them into the preferred roots - // set for references indexing - std::unordered_set preferred_roots; - for (auto entry : index_map_ref_to_producer) { - if (entry.second->isBroadcast() || entry.second->isReduction() || - entry.second->isStride()) { - continue; - } - preferred_roots.emplace(entry.first); - } + // Map sent to best effort replay needs to match the exact incantation for + // compute_at_mode.cpp with MappingMode::Index + auto c2p_root_map = + PairwiseRootDomainMap(producer_tv, consumer_tv, true) + .mapConsumerToProducer(consumer_tv->domain(), producer_tv->domain()); - // Make sure propagation of indexing while mixing with 0 indicies we propagate - // in a way that the producer will be able to see what's going on (propagating - // into common roots of reference and producer). - auto preferred_paths = buildPreferredPaths(reference_domain, preferred_roots); + // This replay has to be consistent with compute at index map. + BestEffortReplay replay_producer_as_consumer( + producer_tv->domain()->domain(), + consumer_tv->domain()->domain(), + c2p_root_map); - // Index into the reference tensor - auto ref_compute = getReferenceIndexing( - loops, - reference_domain, - ref_id_to_ind_map, - ref_zero_domains, - preferred_paths); + c2p_index_map = replay_producer_as_consumer.getReplay(); + p2c_index_map = invertOneToOneMap(c2p_index_map); // Forward vectorized IDs to index into producer correctly // We want p_id to be vectorized like consumer just for the indexing, then we @@ -1720,8 +1707,8 @@ std::vector Index::getNonGlobalProducerStridedIndices( // need to do this as replaying producer as consumer can use replay best // effort which means some domains may be the originals. 
std::vector> p_id_backup; - for (auto entry : index_map_ref_to_producer) { - auto ref_id = entry.first; + for (auto entry : c2p_index_map) { + auto ref_id = ir_utils::caMapExactConcreteId(entry.first); auto p_id = entry.second; if (ref_id->getParallelType() == ParallelType::Vectorize) { p_id_backup.emplace_back(std::make_pair(p_id, p_id->getParallelType())); @@ -1731,26 +1718,10 @@ std::vector Index::getNonGlobalProducerStridedIndices( } } - // Index into producer using reference indexing + auto producer_indexing_from_idgraph = getTensorIndexFromIdGraph( + loops, consumer_tv, producer_tv, false, c2p_index_map); - // Adds halo info mappings for the reference - updateHaloInfoForReference(reference, consumer_tv); - - const auto reference_halo_extent_map = - getReferenceHaloExtentMap(reference, index_map_ref_to_producer); - - ContigIDs contig_finder( - producer_tv->domain()->domain(), - producer_tv->getMaybeRFactorDomain(), - producer_tv->domain()->contiguity(), - reference_id_map, - p2c_index_map); - - auto producer_indexing = ref_compute.updateIndexCompute( - producer_tv->domain(), - index_map_ref_to_producer, - contig_finder, - reference_halo_extent_map); + auto producer_indexing = producer_indexing_from_idgraph.index; // Revert p_ids for (auto entry : p_id_backup) { @@ -1812,11 +1783,7 @@ std::vector Index::getNonGlobalProducerStridedIndices( " dim: ", i, " id: ", - root_dom[i]->toString(), - ", reference domain: ", - reference_domain->toString(), - ", reference root: ", - ir_utils::toString(reference_domain->getRootDomain())); + root_dom[i]->toString()); auto root_ind_i = index_map.at(root_dom[i]); @@ -1827,8 +1794,8 @@ std::vector Index::getNonGlobalProducerStridedIndices( producer_indexing, consumer_tv, p2c_index_map, - reference.domain, - ref_compute, + producer_indexing_from_idgraph.resolved_loop_domains, + producer_indexing_from_idgraph.initial_concrete_index_map, loops, root_ind_i); @@ -1840,8 +1807,7 @@ std::vector Index::getNonGlobalProducerStridedIndices( i, producer_tv, consumer_tv, - reference_id_map, - ref_compute.indexMap()); + producer_indexing_from_idgraph.concrete_index.indexMap()); root_ind_i = getProducerIndexWithPartialSplit( root_ind_i, root_dom[i], producer_tv, consumer_tv); @@ -1864,11 +1830,7 @@ std::vector Index::getNonGlobalProducerStridedIndices( " dim: ", j, " id: ", - root_dom[j]->toString(), - ", reference domain: ", - reference_domain->toString(), - ", reference root: ", - ir_utils::toString(reference_domain->getRootDomain())); + root_dom[j]->toString()); auto root_ext_j = extent_map.find(root_dom[j]) == extent_map.end() ? 
root_dom[j]->extent() @@ -1908,10 +1870,6 @@ std::vector Index::getNonGlobalProducerStridedIndices( } } - // Save indexing info necessary for validating vectorization at launch time - fillProducerVectorizedContigRootDomains( - producer_tv, consumer_tv, c2p_index_map, contig_finder); - return strided_inds; } diff --git a/torch/csrc/jit/codegen/cuda/index_idgraph_utils.cpp b/torch/csrc/jit/codegen/cuda/index_idgraph_utils.cpp index b4d94b45b90f1..1b7ec61015022 100644 --- a/torch/csrc/jit/codegen/cuda/index_idgraph_utils.cpp +++ b/torch/csrc/jit/codegen/cuda/index_idgraph_utils.cpp @@ -15,9 +15,11 @@ namespace cuda { IndexFromIdGraph::IndexFromIdGraph( IndexCompute index_, + IndexCompute concrete_index_, std::unordered_map initial_concrete_index_map_, std::vector loop_domains_) : index(index_), + concrete_index(concrete_index_), initial_concrete_index_map(initial_concrete_index_map_), resolved_loop_domains(loop_domains_) {} @@ -93,6 +95,19 @@ std::unordered_map mapAllProducerDomainsToConsumer( return p2c_alloc_map; } +std::unordered_map invertOneToOneMap( + const std::unordered_map& map) { + std::unordered_map inverted; + for (const auto& kv : map) { + bool inserted = inverted.emplace(kv.second, kv.first).second; + TORCH_INTERNAL_ASSERT( + inserted, + "Multiple mappings to the same value detected: ", + kv.second->toString()); + } + return inverted; +} + //! A struct to keep track of necessary parameters used in //! configuring index compute pass. //! These parameters are needed to propagate the indexing from the leaf nodes of @@ -123,7 +138,6 @@ IndexingParameters getGlobalIndexParameters( const LoopIndexing& loop_indexing, bool index_producer = false) { IndexingParameters index_parameters; - TORCH_INTERNAL_ASSERT(!index_producer, " not yet implemented"); auto& loops = loop_indexing.loops(); auto& loop_domain = loop_indexing.loopDomains(); @@ -151,6 +165,32 @@ IndexingParameters getGlobalIndexParameters( loop_indexing.loopDomains(), index_parameters.initial_concrete_id_index); + // Setup double buffer increment for producer case: + // TODO: could unify these double buffer index calculation + // in follow ups. + if (index_producer) { + auto double_buffer_loop = + GpuLower::current()->doubleBufferInfo().getDoubleBufferLoop( + loop_indexing.consumerTv(), loops, true); + + for (auto loop_idx : c10::irange(loops.size())) { + auto loop = loops[loop_idx]; + if (loop == double_buffer_loop) { + TORCH_INTERNAL_ASSERT( + !loop->isTrivial(), "The double buffer loop must be materialized"); + + auto loop_id = loop_indexing.loopDomains()[loop_idx]; + + auto concrete_loop_id = ir_utils::caMapExactConcreteId(loop_id); + + index_parameters.initial_concrete_id_index[concrete_loop_id] = + SimplifyingIrBuilder::addExpr( + index_parameters.initial_concrete_id_index[concrete_loop_id], + GpuLower::current()->kernel()->oneVal()); + } + } + } + return index_parameters; } @@ -573,22 +613,24 @@ IndexFromIdGraph getTensorIndexFromIdGraph( bool index_producer = producer_tv != nullptr; auto target_tv = index_producer ? producer_tv : consumer_tv; - // TODO: remove this check when adding other indexing - // support. - TORCH_INTERNAL_ASSERT( - !index_producer, - "ConsumerIndexFromIdGraph: currently only consumer indexing is supported."); - auto loop_indexing = LoopIndexingAnalysis::fromLoopAndConsumer(loops, consumer_tv); IndexingParameters index_parameters; + std::unordered_map p2c_map; + + // The p2c map is only needed when indexing producer + // as producer has replayed ids. 
+ if (index_producer) { + p2c_map = invertOneToOneMap(c2p_map); + } + if (is_global) { index_parameters = getGlobalIndexParameters(loop_indexing, index_producer); } else { index_parameters = getNonGlobalInitialIndexParameters( - loop_indexing, consumer_tv, index_producer, producer_tv); + loop_indexing, consumer_tv, index_producer, producer_tv, p2c_map); } IndexCompute indexing( @@ -636,8 +678,7 @@ IndexFromIdGraph getTensorIndexFromIdGraph( // Exact id will have to be pulled from consumer side as the // producer side are replayed ids. - auto exact_concrete_id = GpuLower::current()->caMap()->getConcreteMappedID( - consumer_id, IdMappingMode::EXACT); + auto exact_concrete_id = ir_utils::caMapExactConcreteId(consumer_id); index_update_map[exact_concrete_id] = target_id; @@ -647,16 +688,11 @@ IndexFromIdGraph getTensorIndexFromIdGraph( } } - // TODO: - // This part will be filled in when updating the producer - // indexing logic. This is just placeholder for now. - std::unordered_map p2c_map; - // No contig indexing was done in reference indexing ContigIDs contig_finder( - consumer_tv->domain()->domain(), - consumer_tv->getMaybeRFactorDomain(), - consumer_tv->domain()->contiguity(), + target_tv->domain()->domain(), + target_tv->getMaybeRFactorDomain(), + target_tv->domain()->contiguity(), initial_indexable_map, p2c_map); @@ -665,17 +701,16 @@ IndexFromIdGraph getTensorIndexFromIdGraph( // Fill validation info. // TODO: cleanup seems possible. - if (is_global) { - if (index_producer) { - fillProducerVectorizedContigRootDomains( - target_tv, consumer_tv, c2p_map, contig_finder); - } else { - fillConsumerVectorizedContigRootDomains(consumer_tv, contig_finder); - } + if (index_producer) { + fillProducerVectorizedContigRootDomains( + producer_tv, consumer_tv, c2p_map, contig_finder); + } else { + fillConsumerVectorizedContigRootDomains(consumer_tv, contig_finder); } return IndexFromIdGraph( target_indexing, + indexing, index_parameters.initial_concrete_id_index, loop_indexing.loopDomains()); } diff --git a/torch/csrc/jit/codegen/cuda/index_idgraph_utils.h b/torch/csrc/jit/codegen/cuda/index_idgraph_utils.h index c50b429da94b2..a10931e925964 100644 --- a/torch/csrc/jit/codegen/cuda/index_idgraph_utils.h +++ b/torch/csrc/jit/codegen/cuda/index_idgraph_utils.h @@ -13,11 +13,13 @@ namespace cuda { // indexing math there. 
struct IndexFromIdGraph { IndexCompute index; + IndexCompute concrete_index; std::unordered_map initial_concrete_index_map; std::vector resolved_loop_domains; explicit IndexFromIdGraph( IndexCompute index, + IndexCompute concrete_index, std::unordered_map initial_concrete_index_map, std::vector loop_domains); }; From 7205f3a950840ee13b5ee93fecedd8f2243b5e1b Mon Sep 17 00:00:00 2001 From: shmsong Date: Fri, 24 Jun 2022 21:33:36 -0700 Subject: [PATCH 12/12] rename the new files --- build_variables.bzl | 2 +- torch/csrc/jit/codegen/cuda/index_compute.cpp | 2 +- torch/csrc/jit/codegen/cuda/index_reference_replay.cpp | 2 +- .../cuda/{index_idgraph_utils.cpp => lower_index_compute.cpp} | 2 +- .../cuda/{index_idgraph_utils.h => lower_index_compute.h} | 0 torch/csrc/jit/codegen/cuda/lower_shift.cpp | 2 +- 6 files changed, 5 insertions(+), 5 deletions(-) rename torch/csrc/jit/codegen/cuda/{index_idgraph_utils.cpp => lower_index_compute.cpp} (99%) rename torch/csrc/jit/codegen/cuda/{index_idgraph_utils.h => lower_index_compute.h} (100%) diff --git a/build_variables.bzl b/build_variables.bzl index b258c0337f455..37c1b980638fe 100644 --- a/build_variables.bzl +++ b/build_variables.bzl @@ -656,7 +656,7 @@ libtorch_cuda_core_sources = [ "torch/csrc/jit/codegen/cuda/graph_fuser.cpp", "torch/csrc/jit/codegen/cuda/grouped_reduction.cpp", "torch/csrc/jit/codegen/cuda/index_compute.cpp", - "torch/csrc/jit/codegen/cuda/index_idgraph_utils.cpp", + "torch/csrc/jit/codegen/cuda/lower_index_compute.cpp", "torch/csrc/jit/codegen/cuda/index_reference_replay.cpp", "torch/csrc/jit/codegen/cuda/instrumentation.cpp", "torch/csrc/jit/codegen/cuda/ir_base_nodes.cpp", diff --git a/torch/csrc/jit/codegen/cuda/index_compute.cpp b/torch/csrc/jit/codegen/cuda/index_compute.cpp index 818d15900ebeb..2fe1b65854ddd 100644 --- a/torch/csrc/jit/codegen/cuda/index_compute.cpp +++ b/torch/csrc/jit/codegen/cuda/index_compute.cpp @@ -5,7 +5,6 @@ #include #include #include -#include #include #include #include @@ -14,6 +13,7 @@ #include #include #include +#include #include #include #include diff --git a/torch/csrc/jit/codegen/cuda/index_reference_replay.cpp b/torch/csrc/jit/codegen/cuda/index_reference_replay.cpp index 3a94d04deefd8..73a57b6b501d2 100644 --- a/torch/csrc/jit/codegen/cuda/index_reference_replay.cpp +++ b/torch/csrc/jit/codegen/cuda/index_reference_replay.cpp @@ -1,12 +1,12 @@ #include #include -#include #include #include #include #include #include +#include namespace torch { namespace jit { diff --git a/torch/csrc/jit/codegen/cuda/index_idgraph_utils.cpp b/torch/csrc/jit/codegen/cuda/lower_index_compute.cpp similarity index 99% rename from torch/csrc/jit/codegen/cuda/index_idgraph_utils.cpp rename to torch/csrc/jit/codegen/cuda/lower_index_compute.cpp index 1b7ec61015022..ad1610c40aa8f 100644 --- a/torch/csrc/jit/codegen/cuda/index_idgraph_utils.cpp +++ b/torch/csrc/jit/codegen/cuda/lower_index_compute.cpp @@ -1,9 +1,9 @@ #include #include -#include #include #include #include +#include #include #include #include diff --git a/torch/csrc/jit/codegen/cuda/index_idgraph_utils.h b/torch/csrc/jit/codegen/cuda/lower_index_compute.h similarity index 100% rename from torch/csrc/jit/codegen/cuda/index_idgraph_utils.h rename to torch/csrc/jit/codegen/cuda/lower_index_compute.h diff --git a/torch/csrc/jit/codegen/cuda/lower_shift.cpp b/torch/csrc/jit/codegen/cuda/lower_shift.cpp index c3fa6ed7f8f6d..17eb863486a90 100644 --- a/torch/csrc/jit/codegen/cuda/lower_shift.cpp +++ b/torch/csrc/jit/codegen/cuda/lower_shift.cpp @@ -1,12 +1,12 
@@ #include #include -#include #include #include #include #include #include #include +#include #include #include