From 90fb4952339cbc95b9e79acae3996f7475f73faf Mon Sep 17 00:00:00 2001 From: Xiang Gao Date: Tue, 4 Oct 2022 16:04:09 -0700 Subject: [PATCH 1/8] improvements on bank conflict checker --- .../csrc/jit/codegen/cuda/expr_evaluator.cpp | 6 + torch/csrc/jit/codegen/cuda/expr_evaluator.h | 4 + .../jit/codegen/cuda/lower_bank_conflict.cpp | 120 +++++++++++++++--- .../jit/codegen/cuda/lower_bank_conflict.h | 22 ++-- 4 files changed, 125 insertions(+), 27 deletions(-) diff --git a/torch/csrc/jit/codegen/cuda/expr_evaluator.cpp b/torch/csrc/jit/codegen/cuda/expr_evaluator.cpp index 9527520f6041f..6e1c628111113 100644 --- a/torch/csrc/jit/codegen/cuda/expr_evaluator.cpp +++ b/torch/csrc/jit/codegen/cuda/expr_evaluator.cpp @@ -61,6 +61,12 @@ void ExpressionEvaluator::bind(Val* value, const IntOrDouble& concrete_value) { } } +void ExpressionEvaluator::bind( + const std::string& name, + const IntOrDouble& concrete_value) { + known_named_scalars_[name] = concrete_value; +} + c10::optional ExpressionEvaluator::evaluate(Val* value) { if (evaluator_precomputed_values_ != nullptr) { return toOptionalIntOrDouble( diff --git a/torch/csrc/jit/codegen/cuda/expr_evaluator.h b/torch/csrc/jit/codegen/cuda/expr_evaluator.h index d6001137725d7..4329f9604304b 100644 --- a/torch/csrc/jit/codegen/cuda/expr_evaluator.h +++ b/torch/csrc/jit/codegen/cuda/expr_evaluator.h @@ -7,6 +7,7 @@ #include +#include #include namespace torch { @@ -30,6 +31,9 @@ class TORCH_CUDA_CU_API ExpressionEvaluator : private OptOutDispatch { //! Bind a concrete value to an IR variable void bind(Val* value, const IntOrDouble& concrete_value); + //! Bind a concrete value to a named scalar + void bind(const std::string& name, const IntOrDouble& concrete_value); + //! 
Try to evaluate a Fusion IR value c10::optional evaluate(Val* value); diff --git a/torch/csrc/jit/codegen/cuda/lower_bank_conflict.cpp b/torch/csrc/jit/codegen/cuda/lower_bank_conflict.cpp index 2f29d79f26678..b666fa4f1ec51 100644 --- a/torch/csrc/jit/codegen/cuda/lower_bank_conflict.cpp +++ b/torch/csrc/jit/codegen/cuda/lower_bank_conflict.cpp @@ -48,23 +48,73 @@ inline int64_t getPhaseSize(int64_t word_size_bytes) { return 32; } +bool isThreadIdx(const std::string& name) { + return name == "threadIdx.x" || name == "threadIdx.y" || + name == "threadIdx.z"; +} + +bool isBlockIdx(const std::string& name) { + return name == "blockIdx.x" || name == "blockIdx.y" || name == "blockIdx.z"; +} + +bool isBlockDim(const std::string& name) { + return name == "blockDim.x" || name == "blockDim.y" || name == "blockDim.z"; +} + +bool isGridDim(const std::string& name) { + return name == "gridDim.x" || name == "gridDim.y" || name == "gridDim.z"; +} + std::vector evaluateAddressesOnFirstPhase( kir::TensorIndex* ti, - const std::vector& for_loops) { + const std::vector& for_loops, + c10::optional launch_params, + const std::unordered_map& known_values) { std::vector addresses; const auto word_size_bytes = dataTypeSize(*(ti->getDataType())) * getVectorizeSize(ti); int64_t phase_size = getPhaseSize(word_size_bytes); - for (auto tidx : c10::irange(phase_size)) { + ExpressionEvaluator expr_eval_common(ti->fusion()); + expr_eval_common.bind("blockIdx.x", 0); + expr_eval_common.bind("blockIdx.y", 0); + expr_eval_common.bind("blockIdx.z", 0); + if (launch_params.has_value()) { + expr_eval_common.bind("blockDim.x", launch_params->bdimx()); + expr_eval_common.bind("blockDim.y", launch_params->bdimy()); + expr_eval_common.bind("blockDim.z", launch_params->bdimz()); + expr_eval_common.bind("gridDim.x", launch_params->gdimx()); + expr_eval_common.bind("gridDim.y", launch_params->gdimy()); + expr_eval_common.bind("gridDim.z", launch_params->gdimz()); + } + for (auto pair : known_values) { + 
expr_eval_common.bind(pair.first, pair.second); + } + + for (int64_t linear_tidx : c10::irange(phase_size)) { + int64_t tidx = linear_tidx; + int64_t tidy = 0; + int64_t tidz = 0; + if (launch_params.has_value()) { + tidy = tidx / launch_params->bdimx(); + tidx = tidx % launch_params->bdimx(); + tidz = tidy / launch_params->bdimy(); + tidy = tidy % launch_params->bdimy(); + } int64_t index = 0; - ExpressionEvaluator expr_eval(ti->fusion()); + // make a copy of the expression evaluator + ExpressionEvaluator expr_eval = expr_eval_common; + expr_eval.bind("threadIdx.x", tidx); + expr_eval.bind("threadIdx.y", tidy); + expr_eval.bind("threadIdx.z", tidz); for (auto fl : for_loops) { - if (fl->index()->isA() && - fl->index()->as()->name() == "threadIdx.x") { - expr_eval.bind(fl->index(), tidx); + if (fl->index()->isA()) { + auto name = fl->index()->as()->name(); + TORCH_INTERNAL_ASSERT( + isThreadIdx(name) || isBlockIdx(name), "unknown loop index"); } else { - expr_eval.bind(fl->index(), 0); + auto start = expr_eval.evaluate(fl->start())->as(); + expr_eval.bind(fl->index(), start); } } for (auto ind : ti->indices()) { @@ -94,12 +144,33 @@ int getConflictWays(const std::vector& addresses) { class BankConflictInfo : public kir::IrVisitor { public: static std::unordered_map> get( - const std::vector& exprs) { - return BankConflictInfo(exprs).bank_conflict_info_; + const std::vector& exprs, + c10::optional launch_params, + std::unordered_map known_values) { + return BankConflictInfo(exprs, launch_params, known_values) + .bank_conflict_info_; } private: - BankConflictInfo(const std::vector& exprs) { + BankConflictInfo( + const std::vector& exprs, + c10::optional launch_params, + std::unordered_map known_values) + : launch_params_(launch_params), known_values_(std::move(known_values)) { + for (auto pair : known_values) { + TORCH_CHECK( + !isThreadIdx(pair.first), + "threadIdx.{x,y,z} should be computed instead of provided"); + TORCH_CHECK( + !isBlockIdx(pair.first), + 
"blockIdx.{x,y,z} should not be provided (they are always zero)"); + TORCH_CHECK( + !isBlockDim(pair.first), + "blockDim.{x,y,z} should be provided by launch_params"); + TORCH_CHECK( + !isGridDim(pair.first), + "gridDim.{x,y,z} should be provided by launch_params"); + } handle(exprs); } @@ -119,11 +190,17 @@ class BankConflictInfo : public kir::IrVisitor { std::pair conflict_ways{0, 0}; if (isSmemTensorIndex(uop->in())) { conflict_ways.first = getConflictWays(evaluateAddressesOnFirstPhase( - uop->in()->as(), for_loops_)); + uop->in()->as(), + for_loops_, + launch_params_, + known_values_)); } if (isSmemTensorIndex(uop->out())) { conflict_ways.second = getConflictWays(evaluateAddressesOnFirstPhase( - uop->out()->as(), for_loops_)); + uop->out()->as(), + for_loops_, + launch_params_, + known_values_)); } if (conflict_ways.first > 1 || conflict_ways.second > 1) { bank_conflict_info_[expr] = conflict_ways; @@ -133,11 +210,17 @@ class BankConflictInfo : public kir::IrVisitor { std::pair conflict_ways{0, 0}; if (isSmemTensorIndex(ldst->in())) { conflict_ways.first = getConflictWays(evaluateAddressesOnFirstPhase( - ldst->in()->as(), for_loops_)); + ldst->in()->as(), + for_loops_, + launch_params_, + known_values_)); } if (isSmemTensorIndex(ldst->out())) { conflict_ways.second = getConflictWays(evaluateAddressesOnFirstPhase( - ldst->out()->as(), for_loops_)); + ldst->out()->as(), + for_loops_, + launch_params_, + known_values_)); } if (conflict_ways.first > 1 || conflict_ways.second > 1) { bank_conflict_info_[expr] = conflict_ways; @@ -146,11 +229,16 @@ class BankConflictInfo : public kir::IrVisitor { } std::unordered_map> bank_conflict_info_; + c10::optional launch_params_; + std::unordered_map known_values_; }; std::unordered_map> getBankConflictInfo( - kir::Kernel* kernel) { - return BankConflictInfo::get(kernel->topLevelExprs()); + kir::Kernel* kernel, + c10::optional launch_params, + std::unordered_map known_values) { + return BankConflictInfo::get( + 
kernel->topLevelExprs(), launch_params, std::move(known_values)); } } // namespace cuda diff --git a/torch/csrc/jit/codegen/cuda/lower_bank_conflict.h b/torch/csrc/jit/codegen/cuda/lower_bank_conflict.h index 12c12d4bff4d8..acef403d139cd 100644 --- a/torch/csrc/jit/codegen/cuda/lower_bank_conflict.h +++ b/torch/csrc/jit/codegen/cuda/lower_bank_conflict.h @@ -1,7 +1,9 @@ #pragma once +#include #include #include +#include #include #include @@ -18,27 +20,25 @@ namespace cuda { // nsight compute. This utility currently has the following assumptions and // limitations: // -// 1. This utility assumes that `blockDim.x` is large enough to hold one phase -// 2. This utility assumes that the address only depends on loop variables -// (there can not be a thing like `T0.stride[0]`, `blockDim.x`) -// 3. This utility assumes that the data of the tensor is accessed by +// 1. This utility assumes that the data of the tensor is accessed by // `T0[index]`, where `index` is the one stored in the `TensorIndex` // object. -// 4. This utility only checks the first iteration, and the start of all -// loop variables are assumed to be `0` (if we have something like +// 2. This utility only checks the first iteration. If we have something like // `T1_s[tidx, 5]`, then different iterations should have different -// results, which this utility will not be able to handle all of them now) -// 5. This utility assumes that all tensors are independent, which means: -// 5.1 All shared memory tensors are allocated starting from a multiple of +// conflictions, which will not be evaluated for all of them +// 3. This utility assumes that all tensors are independent, which means: +// 3.1 All shared memory tensors are allocated starting from a multiple of // 4*32 bytes -// 5.2 The only source of bank confliction is from within a tensor. +// 3.2 The only source of bank confliction is from within a tensor. // There is no bank conflict between different tensors. 
// // Also note that this utility will not provide accurate estimation if the above // assumptions are satisfied std::unordered_map> getBankConflictInfo( - kir::Kernel* kernel); + kir::Kernel* kernel, + c10::optional launch_params = c10::nullopt, + std::unordered_map known_values = {}); } // namespace cuda } // namespace fuser From 4d1720dee8f1ad1d626bdfc700ef68821f88bc14 Mon Sep 17 00:00:00 2001 From: Xiang Gao Date: Tue, 4 Oct 2022 16:15:37 -0700 Subject: [PATCH 2/8] test block parallel --- .../codegen/cuda/test/test_gpu_transpose.cpp | 27 +++++++++++++++++++ 1 file changed, 27 insertions(+) diff --git a/torch/csrc/jit/codegen/cuda/test/test_gpu_transpose.cpp b/torch/csrc/jit/codegen/cuda/test/test_gpu_transpose.cpp index 8c00fea08489c..409ffbf96cdbb 100644 --- a/torch/csrc/jit/codegen/cuda/test/test_gpu_transpose.cpp +++ b/torch/csrc/jit/codegen/cuda/test/test_gpu_transpose.cpp @@ -1139,6 +1139,33 @@ TEST_F(NVFuserTest, FusionTransposeBankConflict4_CUDA) { } } +TEST_F(NVFuserTest, FusionTransposeBankConflict5_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeConcreteTensor({1024, 32, 32}); + fusion.addInput(tv0); + auto tv1 = set(tv0); + auto tv2 = transpose(tv1, 1, 2); + auto tv3 = set(tv2); + fusion.addOutput(tv3); + + tv1->setMemoryType(MemoryType::Shared); + tv1->axis(2)->parallelize(ParallelType::TIDx); + tv2->axis(2)->parallelize(ParallelType::TIDx); + tv3->axis(2)->parallelize(ParallelType::TIDx); + tv1->axis(0)->parallelize(ParallelType::BIDx); + tv2->axis(0)->parallelize(ParallelType::BIDx); + tv3->axis(0)->parallelize(ParallelType::BIDx); + + auto bank_conflict_info = fusion.bankConflictInfo(); + + for (auto info : bank_conflict_info) { + std::pair expect{32, 0}; + TORCH_CHECK(info.second == expect); + } +} + } // namespace jit } // namespace torch #endif // #if defined(USE_CUDA) From b0c4352bc1a7d87422556f98c97d5868ab673c29 Mon Sep 17 00:00:00 2001 From: Xiang Gao Date: Tue, 4 Oct 2022 17:56:46 -0700 Subject: [PATCH 3/8] infer 
launch params if possible --- .../jit/codegen/cuda/lower_bank_conflict.cpp | 36 +++++++++++++++++++ .../codegen/cuda/test/test_gpu_transpose.cpp | 33 +++++++++++++++++ 2 files changed, 69 insertions(+) diff --git a/torch/csrc/jit/codegen/cuda/lower_bank_conflict.cpp b/torch/csrc/jit/codegen/cuda/lower_bank_conflict.cpp index b666fa4f1ec51..057d09dc1ee94 100644 --- a/torch/csrc/jit/codegen/cuda/lower_bank_conflict.cpp +++ b/torch/csrc/jit/codegen/cuda/lower_bank_conflict.cpp @@ -86,6 +86,42 @@ std::vector evaluateAddressesOnFirstPhase( expr_eval_common.bind("gridDim.x", launch_params->gdimx()); expr_eval_common.bind("gridDim.y", launch_params->gdimy()); expr_eval_common.bind("gridDim.z", launch_params->gdimz()); + } else { + // infer launch params + for (auto fl : for_loops) { + if (fl->index()->isA()) { + auto name = fl->index()->as()->name(); + if (isThreadIdx(name)) { + auto stop = expr_eval_common.evaluate(fl->stop()); + if (stop.has_value()) { + if (!launch_params.has_value()) { + launch_params = LaunchParams(); + } + if (name == "threadIdx.x") { + launch_params->bind(stop->as(), ParallelType::TIDx); + } else if (name == "threadIdx.y") { + launch_params->bind(stop->as(), ParallelType::TIDy); + } else if (name == "threadIdx.z") { + launch_params->bind(stop->as(), ParallelType::TIDz); + } + } + } else if (isBlockIdx(name)) { + auto stop = expr_eval_common.evaluate(fl->stop()); + if (stop.has_value()) { + if (!launch_params.has_value()) { + launch_params = LaunchParams(); + } + if (name == "blockIdx.x") { + launch_params->bind(stop->as(), ParallelType::BIDx); + } else if (name == "blockIdx.y") { + launch_params->bind(stop->as(), ParallelType::BIDy); + } else if (name == "blockIdx.z") { + launch_params->bind(stop->as(), ParallelType::BIDz); + } + } + } + } + } } for (auto pair : known_values) { expr_eval_common.bind(pair.first, pair.second); diff --git a/torch/csrc/jit/codegen/cuda/test/test_gpu_transpose.cpp 
b/torch/csrc/jit/codegen/cuda/test/test_gpu_transpose.cpp index 409ffbf96cdbb..72a710d39f1f9 100644 --- a/torch/csrc/jit/codegen/cuda/test/test_gpu_transpose.cpp +++ b/torch/csrc/jit/codegen/cuda/test/test_gpu_transpose.cpp @@ -1041,6 +1041,7 @@ TEST_F(NVFuserTest, FusionTransposeBankConflict1_CUDA) { auto bank_conflict_info = fusion.bankConflictInfo(); + TORCH_CHECK(!bank_conflict_info.empty()); for (auto info : bank_conflict_info) { std::pair expect{32, 0}; TORCH_CHECK(info.second == expect); @@ -1065,6 +1066,7 @@ TEST_F(NVFuserTest, FusionTransposeBankConflict2_CUDA) { auto bank_conflict_info = fusion.bankConflictInfo(); + TORCH_CHECK(!bank_conflict_info.empty()); for (auto info : bank_conflict_info) { std::pair expect{0, 32}; TORCH_CHECK(info.second == expect); @@ -1089,6 +1091,7 @@ TEST_F(NVFuserTest, FusionTransposeBankConflict3_CUDA) { auto bank_conflict_info = fusion.bankConflictInfo(); + TORCH_CHECK(!bank_conflict_info.empty()); for (auto info : bank_conflict_info) { std::pair expect{8, 0}; TORCH_CHECK(info.second == expect); @@ -1129,6 +1132,7 @@ TEST_F(NVFuserTest, FusionTransposeBankConflict4_CUDA) { auto bank_conflict_info = fusion.bankConflictInfo(); + TORCH_CHECK(!bank_conflict_info.empty()); for (auto info : bank_conflict_info) { std::pair expect1{0, 8}; std::pair expect2{8, 4}; @@ -1160,6 +1164,35 @@ TEST_F(NVFuserTest, FusionTransposeBankConflict5_CUDA) { auto bank_conflict_info = fusion.bankConflictInfo(); + TORCH_CHECK(!bank_conflict_info.empty()); + for (auto info : bank_conflict_info) { + std::pair expect{32, 0}; + TORCH_CHECK(info.second == expect); + } +} + +TEST_F(NVFuserTest, FusionTransposeBankConflict6_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeConcreteTensor({1024, 32, 32}); + fusion.addInput(tv0); + auto tv1 = set(tv0); + auto tv2 = transpose(tv1, 1, 2); + auto tv3 = set(tv2); + fusion.addOutput(tv3); + + tv1->setMemoryType(MemoryType::Shared); + tv1->axis(2)->parallelize(ParallelType::TIDy); + 
tv2->axis(2)->parallelize(ParallelType::TIDy); + tv3->axis(2)->parallelize(ParallelType::TIDy); + tv1->axis(0)->parallelize(ParallelType::BIDx); + tv2->axis(0)->parallelize(ParallelType::BIDx); + tv3->axis(0)->parallelize(ParallelType::BIDx); + + auto bank_conflict_info = fusion.bankConflictInfo(); + + TORCH_CHECK(!bank_conflict_info.empty()); for (auto info : bank_conflict_info) { std::pair expect{32, 0}; TORCH_CHECK(info.second == expect); From 5009af296c2f0c357e6bdfaadcf3e16ffd71acf5 Mon Sep 17 00:00:00 2001 From: Xiang Gao Date: Tue, 4 Oct 2022 18:11:09 -0700 Subject: [PATCH 4/8] more test --- .../codegen/cuda/test/test_gpu_transpose.cpp | 31 +++++++++++++++++++ 1 file changed, 31 insertions(+) diff --git a/torch/csrc/jit/codegen/cuda/test/test_gpu_transpose.cpp b/torch/csrc/jit/codegen/cuda/test/test_gpu_transpose.cpp index 72a710d39f1f9..0cdbe81133970 100644 --- a/torch/csrc/jit/codegen/cuda/test/test_gpu_transpose.cpp +++ b/torch/csrc/jit/codegen/cuda/test/test_gpu_transpose.cpp @@ -1199,6 +1199,37 @@ TEST_F(NVFuserTest, FusionTransposeBankConflict6_CUDA) { } } +TEST_F(NVFuserTest, FusionTransposeBankConflict7_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeConcreteTensor({1024, 8, 8}); + fusion.addInput(tv0); + auto tv1 = set(tv0); + auto tv2 = transpose(tv1, 1, 2); + auto tv3 = set(tv2); + fusion.addOutput(tv3); + + tv1->setMemoryType(MemoryType::Shared); + tv1->axis(1)->parallelize(ParallelType::TIDx); + tv2->axis(1)->parallelize(ParallelType::TIDx); + tv3->axis(1)->parallelize(ParallelType::TIDx); + tv1->axis(2)->parallelize(ParallelType::TIDy); + tv2->axis(2)->parallelize(ParallelType::TIDy); + tv3->axis(2)->parallelize(ParallelType::TIDy); + tv1->axis(0)->parallelize(ParallelType::BIDx); + tv2->axis(0)->parallelize(ParallelType::BIDx); + tv3->axis(0)->parallelize(ParallelType::BIDx); + + auto bank_conflict_info = fusion.bankConflictInfo(); + + TORCH_CHECK(!bank_conflict_info.empty()); + for (auto info : bank_conflict_info) { + 
std::pair expect{0, 2}; + TORCH_CHECK(info.second == expect); + } +} + } // namespace jit } // namespace torch #endif // #if defined(USE_CUDA) From d75f1a9cd89c3dfddf9ab4d88e346101672d32e2 Mon Sep 17 00:00:00 2001 From: Xiang Gao Date: Tue, 4 Oct 2022 18:16:27 -0700 Subject: [PATCH 5/8] move --- .../jit/codegen/cuda/lower_bank_conflict.cpp | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/torch/csrc/jit/codegen/cuda/lower_bank_conflict.cpp b/torch/csrc/jit/codegen/cuda/lower_bank_conflict.cpp index 057d09dc1ee94..a62204d4a7096 100644 --- a/torch/csrc/jit/codegen/cuda/lower_bank_conflict.cpp +++ b/torch/csrc/jit/codegen/cuda/lower_bank_conflict.cpp @@ -79,14 +79,7 @@ std::vector evaluateAddressesOnFirstPhase( expr_eval_common.bind("blockIdx.x", 0); expr_eval_common.bind("blockIdx.y", 0); expr_eval_common.bind("blockIdx.z", 0); - if (launch_params.has_value()) { - expr_eval_common.bind("blockDim.x", launch_params->bdimx()); - expr_eval_common.bind("blockDim.y", launch_params->bdimy()); - expr_eval_common.bind("blockDim.z", launch_params->bdimz()); - expr_eval_common.bind("gridDim.x", launch_params->gdimx()); - expr_eval_common.bind("gridDim.y", launch_params->gdimy()); - expr_eval_common.bind("gridDim.z", launch_params->gdimz()); - } else { + if (!launch_params.has_value()) { // infer launch params for (auto fl : for_loops) { if (fl->index()->isA()) { @@ -123,6 +116,14 @@ std::vector evaluateAddressesOnFirstPhase( } } } + if (launch_params.has_value()) { + expr_eval_common.bind("blockDim.x", launch_params->bdimx()); + expr_eval_common.bind("blockDim.y", launch_params->bdimy()); + expr_eval_common.bind("blockDim.z", launch_params->bdimz()); + expr_eval_common.bind("gridDim.x", launch_params->gdimx()); + expr_eval_common.bind("gridDim.y", launch_params->gdimy()); + expr_eval_common.bind("gridDim.z", launch_params->gdimz()); + } for (auto pair : known_values) { expr_eval_common.bind(pair.first, pair.second); } From 
335e45c0c1717a9769eb8f27be3bae2ad162b665 Mon Sep 17 00:00:00 2001 From: Xiang Gao Date: Tue, 4 Oct 2022 18:55:07 -0700 Subject: [PATCH 6/8] infer launch params --- .../jit/codegen/cuda/lower_bank_conflict.cpp | 163 +++++++++++------- .../jit/codegen/cuda/lower_bank_conflict.h | 4 +- 2 files changed, 107 insertions(+), 60 deletions(-) diff --git a/torch/csrc/jit/codegen/cuda/lower_bank_conflict.cpp b/torch/csrc/jit/codegen/cuda/lower_bank_conflict.cpp index a62204d4a7096..aee2d8f06eeeb 100644 --- a/torch/csrc/jit/codegen/cuda/lower_bank_conflict.cpp +++ b/torch/csrc/jit/codegen/cuda/lower_bank_conflict.cpp @@ -1,5 +1,6 @@ #include +#include #include #include #include @@ -65,6 +66,23 @@ bool isGridDim(const std::string& name) { return name == "gridDim.x" || name == "gridDim.y" || name == "gridDim.z"; } +ParallelType getParallelType(const std::string& name) { + if (name == "threadIdx.x") { + return ParallelType::TIDx; + } else if (name == "threadIdx.y") { + return ParallelType::TIDy; + } else if (name == "threadIdx.z") { + return ParallelType::TIDz; + } else if (name == "blockIdx.x") { + return ParallelType::BIDx; + } else if (name == "blockIdx.y") { + return ParallelType::BIDy; + } else if (name == "blockIdx.z") { + return ParallelType::BIDz; + } + TORCH_INTERNAL_ASSERT(false, "Not a parallel type"); +} + std::vector evaluateAddressesOnFirstPhase( kir::TensorIndex* ti, const std::vector& for_loops, @@ -75,47 +93,14 @@ std::vector evaluateAddressesOnFirstPhase( dataTypeSize(*(ti->getDataType())) * getVectorizeSize(ti); int64_t phase_size = getPhaseSize(word_size_bytes); + if (launch_params.has_value()) { + phase_size = std::min(phase_size, launch_params->nThreads()); + } + ExpressionEvaluator expr_eval_common(ti->fusion()); expr_eval_common.bind("blockIdx.x", 0); expr_eval_common.bind("blockIdx.y", 0); expr_eval_common.bind("blockIdx.z", 0); - if (!launch_params.has_value()) { - // infer launch params - for (auto fl : for_loops) { - if (fl->index()->isA()) { - auto 
name = fl->index()->as()->name(); - if (isThreadIdx(name)) { - auto stop = expr_eval_common.evaluate(fl->stop()); - if (stop.has_value()) { - if (!launch_params.has_value()) { - launch_params = LaunchParams(); - } - if (name == "threadIdx.x") { - launch_params->bind(stop->as(), ParallelType::TIDx); - } else if (name == "threadIdx.y") { - launch_params->bind(stop->as(), ParallelType::TIDy); - } else if (name == "threadIdx.z") { - launch_params->bind(stop->as(), ParallelType::TIDz); - } - } - } else if (isBlockIdx(name)) { - auto stop = expr_eval_common.evaluate(fl->stop()); - if (stop.has_value()) { - if (!launch_params.has_value()) { - launch_params = LaunchParams(); - } - if (name == "blockIdx.x") { - launch_params->bind(stop->as(), ParallelType::BIDx); - } else if (name == "blockIdx.y") { - launch_params->bind(stop->as(), ParallelType::BIDy); - } else if (name == "blockIdx.z") { - launch_params->bind(stop->as(), ParallelType::BIDz); - } - } - } - } - } - } if (launch_params.has_value()) { expr_eval_common.bind("blockDim.x", launch_params->bdimx()); expr_eval_common.bind("blockDim.y", launch_params->bdimy()); @@ -176,14 +161,70 @@ int getConflictWays(const std::vector& addresses) { return conflict; } -} // namespace +class InferLaunchParams : public kir::IrVisitor { + public: + static c10::optional get( + const std::vector& exprs, + const std::unordered_map& known_values) { + if (exprs.empty()) { + return c10::nullopt; + } + return InferLaunchParams(exprs, known_values).launch_params_; + } + + private: + InferLaunchParams( + const std::vector& exprs, + const std::unordered_map& known_values) + : expr_eval_(exprs[0]->fusion()) { + for (auto pair : known_values) { + expr_eval_.bind(pair.first, pair.second); + } + handle(exprs); + } + + using kir::IrVisitor::handle; + + void handle(Expr* expr) final { + if (expr->isA() || expr->isA()) { + kir::IrVisitor::handle(expr); + return; + } + + for (auto fl : for_loops_) { + if (fl->index()->isA()) { + auto name = 
fl->index()->as()->name(); + if (isThreadIdx(name) || isBlockIdx(name)) { + auto ptype = getParallelType(name); + auto stop = expr_eval_.evaluate(fl->stop()); + if (stop.has_value()) { + if (!launch_params_.has_value()) { + launch_params_ = LaunchParams(); + } + if (launch_params_->getRawVal(ptype) == + LaunchParams::UNINITIALIZED_VAL) { + launch_params_->bind(stop->as(), ptype); + } else { + TORCH_INTERNAL_ASSERT( + launch_params_->getDim(ptype) == stop, + "Unable to infer launch parameters"); + } + } + } + } + } + } + + ExpressionEvaluator expr_eval_; + c10::optional launch_params_; +}; class BankConflictInfo : public kir::IrVisitor { public: static std::unordered_map> get( const std::vector& exprs, c10::optional launch_params, - std::unordered_map known_values) { + const std::unordered_map& known_values) { return BankConflictInfo(exprs, launch_params, known_values) .bank_conflict_info_; } @@ -192,22 +233,8 @@ class BankConflictInfo : public kir::IrVisitor { BankConflictInfo( const std::vector& exprs, c10::optional launch_params, - std::unordered_map known_values) - : launch_params_(launch_params), known_values_(std::move(known_values)) { - for (auto pair : known_values) { - TORCH_CHECK( - !isThreadIdx(pair.first), - "threadIdx.{x,y,z} should be computed instead of provided"); - TORCH_CHECK( - !isBlockIdx(pair.first), - "blockIdx.{x,y,z} should not be provided (they are always zero)"); - TORCH_CHECK( - !isBlockDim(pair.first), - "blockDim.{x,y,z} should be provided by launch_params"); - TORCH_CHECK( - !isGridDim(pair.first), - "gridDim.{x,y,z} should be provided by launch_params"); - } + const std::unordered_map& known_values) + : launch_params_(launch_params), known_values_(known_values) { handle(exprs); } @@ -267,15 +294,35 @@ class BankConflictInfo : public kir::IrVisitor { std::unordered_map> bank_conflict_info_; c10::optional launch_params_; - std::unordered_map known_values_; + const std::unordered_map& known_values_; }; +} // namespace + 
std::unordered_map> getBankConflictInfo( kir::Kernel* kernel, c10::optional launch_params, - std::unordered_map known_values) { + const std::unordered_map& known_values) { + for (auto pair : known_values) { + TORCH_CHECK( + !isThreadIdx(pair.first), + "threadIdx.{x,y,z} should be computed instead of provided"); + TORCH_CHECK( + !isBlockIdx(pair.first), + "blockIdx.{x,y,z} should not be provided (they are always zero)"); + TORCH_CHECK( + !isBlockDim(pair.first), + "blockDim.{x,y,z} should be provided by launch_params"); + TORCH_CHECK( + !isGridDim(pair.first), + "gridDim.{x,y,z} should be provided by launch_params"); + } + if (!launch_params.has_value()) { + launch_params = + InferLaunchParams::get(kernel->topLevelExprs(), known_values); + } return BankConflictInfo::get( - kernel->topLevelExprs(), launch_params, std::move(known_values)); + kernel->topLevelExprs(), launch_params, known_values); } } // namespace cuda diff --git a/torch/csrc/jit/codegen/cuda/lower_bank_conflict.h b/torch/csrc/jit/codegen/cuda/lower_bank_conflict.h index acef403d139cd..b651c4ed33e22 100644 --- a/torch/csrc/jit/codegen/cuda/lower_bank_conflict.h +++ b/torch/csrc/jit/codegen/cuda/lower_bank_conflict.h @@ -1,9 +1,9 @@ #pragma once +#include #include #include #include -#include #include #include @@ -38,7 +38,7 @@ namespace cuda { std::unordered_map> getBankConflictInfo( kir::Kernel* kernel, c10::optional launch_params = c10::nullopt, - std::unordered_map known_values = {}); + const std::unordered_map& known_values = {}); } // namespace cuda } // namespace fuser From d7f062631f77813065c1fdcee4eafe09817b7214 Mon Sep 17 00:00:00 2001 From: Xiang Gao Date: Tue, 4 Oct 2022 19:01:08 -0700 Subject: [PATCH 7/8] test --- .../codegen/cuda/test/test_gpu_transpose.cpp | 25 +++++++++++++++++++ 1 file changed, 25 insertions(+) diff --git a/torch/csrc/jit/codegen/cuda/test/test_gpu_transpose.cpp b/torch/csrc/jit/codegen/cuda/test/test_gpu_transpose.cpp index 0cdbe81133970..b10360f00315e 100644 --- 
a/torch/csrc/jit/codegen/cuda/test/test_gpu_transpose.cpp +++ b/torch/csrc/jit/codegen/cuda/test/test_gpu_transpose.cpp @@ -1230,6 +1230,31 @@ TEST_F(NVFuserTest, FusionTransposeBankConflict7_CUDA) { } } +TEST_F(NVFuserTest, FusionTransposeBankConflict8_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeConcreteTensor({1024, 8, 8}); + fusion.addInput(tv0); + auto tv1 = set(tv0); + auto tv2 = transpose(tv1, 1, 2); + auto tv3 = set(tv2); + fusion.addOutput(tv3); + + tv1->setMemoryType(MemoryType::Shared); + tv1->axis(2)->parallelize(ParallelType::TIDx); + tv2->axis(2)->parallelize(ParallelType::TIDy); + tv3->axis(2)->parallelize(ParallelType::TIDy); + tv1->axis(0)->parallelize(ParallelType::BIDx); + tv2->axis(0)->parallelize(ParallelType::BIDx); + tv3->axis(0)->parallelize(ParallelType::BIDx); + + auto bank_conflict_info = fusion.bankConflictInfo(); + + // no bank confliction + TORCH_CHECK(bank_conflict_info.empty()); +} + } // namespace jit } // namespace torch #endif // #if defined(USE_CUDA) From 3c6c5bdb0144b881408b9882d446b315a6661eb9 Mon Sep 17 00:00:00 2001 From: Xiang Gao Date: Tue, 4 Oct 2022 19:08:59 -0700 Subject: [PATCH 8/8] cleanup --- .../jit/codegen/cuda/lower_bank_conflict.cpp | 47 ++++++++++--------- 1 file changed, 24 insertions(+), 23 deletions(-) diff --git a/torch/csrc/jit/codegen/cuda/lower_bank_conflict.cpp b/torch/csrc/jit/codegen/cuda/lower_bank_conflict.cpp index aee2d8f06eeeb..0b97b973f786e 100644 --- a/torch/csrc/jit/codegen/cuda/lower_bank_conflict.cpp +++ b/torch/csrc/jit/codegen/cuda/lower_bank_conflict.cpp @@ -87,7 +87,7 @@ std::vector evaluateAddressesOnFirstPhase( kir::TensorIndex* ti, const std::vector& for_loops, c10::optional launch_params, - const std::unordered_map& known_values) { + const ExpressionEvaluator& expr_eval_common) { std::vector addresses; const auto word_size_bytes = dataTypeSize(*(ti->getDataType())) * getVectorizeSize(ti); @@ -97,22 +97,6 @@ std::vector evaluateAddressesOnFirstPhase( phase_size 
= std::min(phase_size, launch_params->nThreads()); } - ExpressionEvaluator expr_eval_common(ti->fusion()); - expr_eval_common.bind("blockIdx.x", 0); - expr_eval_common.bind("blockIdx.y", 0); - expr_eval_common.bind("blockIdx.z", 0); - if (launch_params.has_value()) { - expr_eval_common.bind("blockDim.x", launch_params->bdimx()); - expr_eval_common.bind("blockDim.y", launch_params->bdimy()); - expr_eval_common.bind("blockDim.z", launch_params->bdimz()); - expr_eval_common.bind("gridDim.x", launch_params->gdimx()); - expr_eval_common.bind("gridDim.y", launch_params->gdimy()); - expr_eval_common.bind("gridDim.z", launch_params->gdimz()); - } - for (auto pair : known_values) { - expr_eval_common.bind(pair.first, pair.second); - } - for (int64_t linear_tidx : c10::irange(phase_size)) { int64_t tidx = linear_tidx; int64_t tidy = 0; @@ -225,6 +209,9 @@ class BankConflictInfo : public kir::IrVisitor { const std::vector& exprs, c10::optional launch_params, const std::unordered_map& known_values) { + if (exprs.empty()) { + return {}; + } return BankConflictInfo(exprs, launch_params, known_values) .bank_conflict_info_; } @@ -234,7 +221,21 @@ class BankConflictInfo : public kir::IrVisitor { const std::vector& exprs, c10::optional launch_params, const std::unordered_map& known_values) - : launch_params_(launch_params), known_values_(known_values) { + : launch_params_(launch_params), expr_eval_common_(exprs[0]->fusion()) { + expr_eval_common_.bind("blockIdx.x", 0); + expr_eval_common_.bind("blockIdx.y", 0); + expr_eval_common_.bind("blockIdx.z", 0); + if (launch_params.has_value()) { + expr_eval_common_.bind("blockDim.x", launch_params->bdimx()); + expr_eval_common_.bind("blockDim.y", launch_params->bdimy()); + expr_eval_common_.bind("blockDim.z", launch_params->bdimz()); + expr_eval_common_.bind("gridDim.x", launch_params->gdimx()); + expr_eval_common_.bind("gridDim.y", launch_params->gdimy()); + expr_eval_common_.bind("gridDim.z", launch_params->gdimz()); + } + for (auto pair 
: known_values) { + expr_eval_common_.bind(pair.first, pair.second); + } handle(exprs); } @@ -257,14 +258,14 @@ class BankConflictInfo : public kir::IrVisitor { uop->in()->as(), for_loops_, launch_params_, - known_values_)); + expr_eval_common_)); } if (isSmemTensorIndex(uop->out())) { conflict_ways.second = getConflictWays(evaluateAddressesOnFirstPhase( uop->out()->as(), for_loops_, launch_params_, - known_values_)); + expr_eval_common_)); } if (conflict_ways.first > 1 || conflict_ways.second > 1) { bank_conflict_info_[expr] = conflict_ways; @@ -277,14 +278,14 @@ class BankConflictInfo : public kir::IrVisitor { ldst->in()->as(), for_loops_, launch_params_, - known_values_)); + expr_eval_common_)); } if (isSmemTensorIndex(ldst->out())) { conflict_ways.second = getConflictWays(evaluateAddressesOnFirstPhase( ldst->out()->as(), for_loops_, launch_params_, - known_values_)); + expr_eval_common_)); } if (conflict_ways.first > 1 || conflict_ways.second > 1) { bank_conflict_info_[expr] = conflict_ways; @@ -294,7 +295,7 @@ class BankConflictInfo : public kir::IrVisitor { std::unordered_map> bank_conflict_info_; c10::optional launch_params_; - const std::unordered_map& known_values_; + ExpressionEvaluator expr_eval_common_; }; } // namespace