diff --git a/torch/csrc/jit/codegen/cuda/expr_evaluator.cpp b/torch/csrc/jit/codegen/cuda/expr_evaluator.cpp index 9527520f6041f..6e1c628111113 100644 --- a/torch/csrc/jit/codegen/cuda/expr_evaluator.cpp +++ b/torch/csrc/jit/codegen/cuda/expr_evaluator.cpp @@ -61,6 +61,12 @@ void ExpressionEvaluator::bind(Val* value, const IntOrDouble& concrete_value) { } } +void ExpressionEvaluator::bind( + const std::string& name, + const IntOrDouble& concrete_value) { + known_named_scalars_[name] = concrete_value; +} + c10::optional ExpressionEvaluator::evaluate(Val* value) { if (evaluator_precomputed_values_ != nullptr) { return toOptionalIntOrDouble( diff --git a/torch/csrc/jit/codegen/cuda/expr_evaluator.h b/torch/csrc/jit/codegen/cuda/expr_evaluator.h index d6001137725d7..4329f9604304b 100644 --- a/torch/csrc/jit/codegen/cuda/expr_evaluator.h +++ b/torch/csrc/jit/codegen/cuda/expr_evaluator.h @@ -7,6 +7,7 @@ #include +#include #include namespace torch { @@ -30,6 +31,9 @@ class TORCH_CUDA_CU_API ExpressionEvaluator : private OptOutDispatch { //! Bind a concrete value to an IR variable void bind(Val* value, const IntOrDouble& concrete_value); + //! Bind a concrete value to a named scalar + void bind(const std::string& name, const IntOrDouble& concrete_value); + //! 
Try to evaluate a Fusion IR value c10::optional evaluate(Val* value); diff --git a/torch/csrc/jit/codegen/cuda/lower_bank_conflict.cpp b/torch/csrc/jit/codegen/cuda/lower_bank_conflict.cpp index 2f29d79f26678..0b97b973f786e 100644 --- a/torch/csrc/jit/codegen/cuda/lower_bank_conflict.cpp +++ b/torch/csrc/jit/codegen/cuda/lower_bank_conflict.cpp @@ -1,5 +1,6 @@ #include +#include #include #include #include @@ -48,23 +49,78 @@ inline int64_t getPhaseSize(int64_t word_size_bytes) { return 32; } +bool isThreadIdx(const std::string& name) { + return name == "threadIdx.x" || name == "threadIdx.y" || + name == "threadIdx.z"; +} + +bool isBlockIdx(const std::string& name) { + return name == "blockIdx.x" || name == "blockIdx.y" || name == "blockIdx.z"; +} + +bool isBlockDim(const std::string& name) { + return name == "blockDim.x" || name == "blockDim.y" || name == "blockDim.z"; +} + +bool isGridDim(const std::string& name) { + return name == "gridDim.x" || name == "gridDim.y" || name == "gridDim.z"; +} + +ParallelType getParallelType(const std::string& name) { + if (name == "threadIdx.x") { + return ParallelType::TIDx; + } else if (name == "threadIdx.y") { + return ParallelType::TIDy; + } else if (name == "threadIdx.z") { + return ParallelType::TIDz; + } else if (name == "blockIdx.x") { + return ParallelType::BIDx; + } else if (name == "blockIdx.y") { + return ParallelType::BIDy; + } else if (name == "blockIdx.z") { + return ParallelType::BIDz; + } + TORCH_INTERNAL_ASSERT(false, "Not a parallel type"); +} + std::vector evaluateAddressesOnFirstPhase( kir::TensorIndex* ti, - const std::vector& for_loops) { + const std::vector& for_loops, + c10::optional launch_params, + const ExpressionEvaluator& expr_eval_common) { std::vector addresses; const auto word_size_bytes = dataTypeSize(*(ti->getDataType())) * getVectorizeSize(ti); int64_t phase_size = getPhaseSize(word_size_bytes); - for (auto tidx : c10::irange(phase_size)) { + if (launch_params.has_value()) { + phase_size = 
std::min(phase_size, launch_params->nThreads()); + } + + for (int64_t linear_tidx : c10::irange(phase_size)) { + int64_t tidx = linear_tidx; + int64_t tidy = 0; + int64_t tidz = 0; + if (launch_params.has_value()) { + tidy = tidx / launch_params->bdimx(); + tidx = tidx % launch_params->bdimx(); + tidz = tidy / launch_params->bdimy(); + tidy = tidy % launch_params->bdimy(); + } int64_t index = 0; - ExpressionEvaluator expr_eval(ti->fusion()); + // make a copy of the expression evaluator + ExpressionEvaluator expr_eval = expr_eval_common; + expr_eval.bind("threadIdx.x", tidx); + expr_eval.bind("threadIdx.y", tidy); + expr_eval.bind("threadIdx.z", tidz); for (auto fl : for_loops) { - if (fl->index()->isA() && - fl->index()->as()->name() == "threadIdx.x") { - expr_eval.bind(fl->index(), tidx); + if (fl->index()->isA()) { + auto name = fl->index()->as()->name(); + TORCH_INTERNAL_ASSERT( + isThreadIdx(name) || isBlockIdx(name), "unknown loop index"); } else { - expr_eval.bind(fl->index(), 0); + auto start = expr_eval.evaluate(fl->start())->as(); + expr_eval.bind(fl->index(), start); } } for (auto ind : ti->indices()) { @@ -89,17 +145,97 @@ int getConflictWays(const std::vector& addresses) { return conflict; } -} // namespace +class InferLaunchParams : public kir::IrVisitor { + public: + static c10::optional get( + const std::vector& exprs, + const std::unordered_map& known_values) { + if (exprs.empty()) { + return c10::nullopt; + } + return InferLaunchParams(exprs, known_values).launch_params_; + } + + private: + InferLaunchParams( + const std::vector& exprs, + const std::unordered_map& known_values) + : expr_eval_(exprs[0]->fusion()) { + for (auto pair : known_values) { + expr_eval_.bind(pair.first, pair.second); + } + handle(exprs); + } + + using kir::IrVisitor::handle; + + void handle(Expr* expr) final { + if (expr->isA() || expr->isA()) { + kir::IrVisitor::handle(expr); + return; + } + + for (auto fl : for_loops_) { + if (fl->index()->isA()) { + auto name = 
fl->index()->as()->name(); + if (isThreadIdx(name) || isBlockIdx(name)) { + auto ptype = getParallelType(name); + auto stop = expr_eval_.evaluate(fl->stop()); + if (stop.has_value()) { + if (!launch_params_.has_value()) { + launch_params_ = LaunchParams(); + } + if (launch_params_->getRawVal(ptype) == + LaunchParams::UNINITIALIZED_VAL) { + launch_params_->bind(stop->as(), ptype); + } else { + TORCH_INTERNAL_ASSERT( + launch_params_->getDim(ptype) == stop, + "Unable to infer launch parameters"); + } + } + } + } + } + } + + ExpressionEvaluator expr_eval_; + c10::optional launch_params_; +}; class BankConflictInfo : public kir::IrVisitor { public: static std::unordered_map> get( - const std::vector& exprs) { - return BankConflictInfo(exprs).bank_conflict_info_; + const std::vector& exprs, + c10::optional launch_params, + const std::unordered_map& known_values) { + if (exprs.empty()) { + return {}; + } + return BankConflictInfo(exprs, launch_params, known_values) + .bank_conflict_info_; } private: - BankConflictInfo(const std::vector& exprs) { + BankConflictInfo( + const std::vector& exprs, + c10::optional launch_params, + const std::unordered_map& known_values) + : launch_params_(launch_params), expr_eval_common_(exprs[0]->fusion()) { + expr_eval_common_.bind("blockIdx.x", 0); + expr_eval_common_.bind("blockIdx.y", 0); + expr_eval_common_.bind("blockIdx.z", 0); + if (launch_params.has_value()) { + expr_eval_common_.bind("blockDim.x", launch_params->bdimx()); + expr_eval_common_.bind("blockDim.y", launch_params->bdimy()); + expr_eval_common_.bind("blockDim.z", launch_params->bdimz()); + expr_eval_common_.bind("gridDim.x", launch_params->gdimx()); + expr_eval_common_.bind("gridDim.y", launch_params->gdimy()); + expr_eval_common_.bind("gridDim.z", launch_params->gdimz()); + } + for (auto pair : known_values) { + expr_eval_common_.bind(pair.first, pair.second); + } handle(exprs); } @@ -119,11 +255,17 @@ class BankConflictInfo : public kir::IrVisitor { std::pair 
conflict_ways{0, 0}; if (isSmemTensorIndex(uop->in())) { conflict_ways.first = getConflictWays(evaluateAddressesOnFirstPhase( - uop->in()->as(), for_loops_)); + uop->in()->as(), + for_loops_, + launch_params_, + expr_eval_common_)); } if (isSmemTensorIndex(uop->out())) { conflict_ways.second = getConflictWays(evaluateAddressesOnFirstPhase( - uop->out()->as(), for_loops_)); + uop->out()->as(), + for_loops_, + launch_params_, + expr_eval_common_)); } if (conflict_ways.first > 1 || conflict_ways.second > 1) { bank_conflict_info_[expr] = conflict_ways; @@ -133,11 +275,17 @@ class BankConflictInfo : public kir::IrVisitor { std::pair conflict_ways{0, 0}; if (isSmemTensorIndex(ldst->in())) { conflict_ways.first = getConflictWays(evaluateAddressesOnFirstPhase( - ldst->in()->as(), for_loops_)); + ldst->in()->as(), + for_loops_, + launch_params_, + expr_eval_common_)); } if (isSmemTensorIndex(ldst->out())) { conflict_ways.second = getConflictWays(evaluateAddressesOnFirstPhase( - ldst->out()->as(), for_loops_)); + ldst->out()->as(), + for_loops_, + launch_params_, + expr_eval_common_)); } if (conflict_ways.first > 1 || conflict_ways.second > 1) { bank_conflict_info_[expr] = conflict_ways; @@ -146,11 +294,36 @@ class BankConflictInfo : public kir::IrVisitor { } std::unordered_map> bank_conflict_info_; + c10::optional launch_params_; + ExpressionEvaluator expr_eval_common_; }; +} // namespace + std::unordered_map> getBankConflictInfo( - kir::Kernel* kernel) { - return BankConflictInfo::get(kernel->topLevelExprs()); + kir::Kernel* kernel, + c10::optional launch_params, + const std::unordered_map& known_values) { + for (auto pair : known_values) { + TORCH_CHECK( + !isThreadIdx(pair.first), + "threadIdx.{x,y,z} should be computed instead of provided"); + TORCH_CHECK( + !isBlockIdx(pair.first), + "blockIdx.{x,y,z} should not be provided (they are always zero)"); + TORCH_CHECK( + !isBlockDim(pair.first), + "blockDim.{x,y,z} should be provided by launch_params"); + TORCH_CHECK( + 
!isGridDim(pair.first), + "gridDim.{x,y,z} should be provided by launch_params"); + } + if (!launch_params.has_value()) { + launch_params = + InferLaunchParams::get(kernel->topLevelExprs(), known_values); + } + return BankConflictInfo::get( + kernel->topLevelExprs(), launch_params, known_values); } } // namespace cuda diff --git a/torch/csrc/jit/codegen/cuda/lower_bank_conflict.h b/torch/csrc/jit/codegen/cuda/lower_bank_conflict.h index 12c12d4bff4d8..b651c4ed33e22 100644 --- a/torch/csrc/jit/codegen/cuda/lower_bank_conflict.h +++ b/torch/csrc/jit/codegen/cuda/lower_bank_conflict.h @@ -1,5 +1,7 @@ #pragma once +#include +#include #include #include @@ -18,27 +20,25 @@ namespace cuda { // nsight compute. This utility currently has the following assumptions and // limitations: // -// 1. This utility assumes that `blockDim.x` is large enough to hold one phase -// 2. This utility assumes that the address only depends on loop variables -// (there can not be a thing like `T0.stride[0]`, `blockDim.x`) -// 3. This utility assumes that the data of the tensor is accessed by +// 1. This utility assumes that the data of the tensor is accessed by // `T0[index]`, where `index` is the one stored in the `TensorIndex` // object. -// 4. This utility only checks the first iteration, and the start of all -// loop variables are assumed to be `0` (if we have something like +// 2. This utility only checks the first iteration. If we have something like // `T1_s[tidx, 5]`, then different iterations should have different -// results, which this utility will not be able to handle all of them now) -// 5. This utility assumes that all tensors are independent, which means: -// 5.1 All shared memory tensors are allocated starting from a multiple of +// conflictions, which will not be evaluated for all of them +// 3. 
This utility assumes that all tensors are independent, which means: +// 3.1 All shared memory tensors are allocated starting from a multiple of // 4*32 bytes -// 5.2 The only source of bank confliction is from within a tensor. +// 3.2 The only source of bank confliction is from within a tensor. // There is no bank conflict between different tensors. // // Also note that this utility will not provide accurate estimation if the above // assumptions are satisfied std::unordered_map> getBankConflictInfo( - kir::Kernel* kernel); + kir::Kernel* kernel, + c10::optional launch_params = c10::nullopt, + const std::unordered_map& known_values = {}); } // namespace cuda } // namespace fuser diff --git a/torch/csrc/jit/codegen/cuda/test/test_gpu_transpose.cpp b/torch/csrc/jit/codegen/cuda/test/test_gpu_transpose.cpp index 8c00fea08489c..b10360f00315e 100644 --- a/torch/csrc/jit/codegen/cuda/test/test_gpu_transpose.cpp +++ b/torch/csrc/jit/codegen/cuda/test/test_gpu_transpose.cpp @@ -1041,6 +1041,7 @@ TEST_F(NVFuserTest, FusionTransposeBankConflict1_CUDA) { auto bank_conflict_info = fusion.bankConflictInfo(); + TORCH_CHECK(!bank_conflict_info.empty()); for (auto info : bank_conflict_info) { std::pair expect{32, 0}; TORCH_CHECK(info.second == expect); @@ -1065,6 +1066,7 @@ TEST_F(NVFuserTest, FusionTransposeBankConflict2_CUDA) { auto bank_conflict_info = fusion.bankConflictInfo(); + TORCH_CHECK(!bank_conflict_info.empty()); for (auto info : bank_conflict_info) { std::pair expect{0, 32}; TORCH_CHECK(info.second == expect); @@ -1089,6 +1091,7 @@ TEST_F(NVFuserTest, FusionTransposeBankConflict3_CUDA) { auto bank_conflict_info = fusion.bankConflictInfo(); + TORCH_CHECK(!bank_conflict_info.empty()); for (auto info : bank_conflict_info) { std::pair expect{8, 0}; TORCH_CHECK(info.second == expect); @@ -1129,6 +1132,7 @@ TEST_F(NVFuserTest, FusionTransposeBankConflict4_CUDA) { auto bank_conflict_info = fusion.bankConflictInfo(); + TORCH_CHECK(!bank_conflict_info.empty()); for (auto info 
: bank_conflict_info) { std::pair expect1{0, 8}; std::pair expect2{8, 4}; @@ -1139,6 +1143,118 @@ TEST_F(NVFuserTest, FusionTransposeBankConflict4_CUDA) { } } +TEST_F(NVFuserTest, FusionTransposeBankConflict5_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeConcreteTensor({1024, 32, 32}); + fusion.addInput(tv0); + auto tv1 = set(tv0); + auto tv2 = transpose(tv1, 1, 2); + auto tv3 = set(tv2); + fusion.addOutput(tv3); + + tv1->setMemoryType(MemoryType::Shared); + tv1->axis(2)->parallelize(ParallelType::TIDx); + tv2->axis(2)->parallelize(ParallelType::TIDx); + tv3->axis(2)->parallelize(ParallelType::TIDx); + tv1->axis(0)->parallelize(ParallelType::BIDx); + tv2->axis(0)->parallelize(ParallelType::BIDx); + tv3->axis(0)->parallelize(ParallelType::BIDx); + + auto bank_conflict_info = fusion.bankConflictInfo(); + + TORCH_CHECK(!bank_conflict_info.empty()); + for (auto info : bank_conflict_info) { + std::pair expect{32, 0}; + TORCH_CHECK(info.second == expect); + } +} + +TEST_F(NVFuserTest, FusionTransposeBankConflict6_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeConcreteTensor({1024, 32, 32}); + fusion.addInput(tv0); + auto tv1 = set(tv0); + auto tv2 = transpose(tv1, 1, 2); + auto tv3 = set(tv2); + fusion.addOutput(tv3); + + tv1->setMemoryType(MemoryType::Shared); + tv1->axis(2)->parallelize(ParallelType::TIDy); + tv2->axis(2)->parallelize(ParallelType::TIDy); + tv3->axis(2)->parallelize(ParallelType::TIDy); + tv1->axis(0)->parallelize(ParallelType::BIDx); + tv2->axis(0)->parallelize(ParallelType::BIDx); + tv3->axis(0)->parallelize(ParallelType::BIDx); + + auto bank_conflict_info = fusion.bankConflictInfo(); + + TORCH_CHECK(!bank_conflict_info.empty()); + for (auto info : bank_conflict_info) { + std::pair expect{32, 0}; + TORCH_CHECK(info.second == expect); + } +} + +TEST_F(NVFuserTest, FusionTransposeBankConflict7_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeConcreteTensor({1024, 8, 8}); + 
fusion.addInput(tv0); + auto tv1 = set(tv0); + auto tv2 = transpose(tv1, 1, 2); + auto tv3 = set(tv2); + fusion.addOutput(tv3); + + tv1->setMemoryType(MemoryType::Shared); + tv1->axis(1)->parallelize(ParallelType::TIDx); + tv2->axis(1)->parallelize(ParallelType::TIDx); + tv3->axis(1)->parallelize(ParallelType::TIDx); + tv1->axis(2)->parallelize(ParallelType::TIDy); + tv2->axis(2)->parallelize(ParallelType::TIDy); + tv3->axis(2)->parallelize(ParallelType::TIDy); + tv1->axis(0)->parallelize(ParallelType::BIDx); + tv2->axis(0)->parallelize(ParallelType::BIDx); + tv3->axis(0)->parallelize(ParallelType::BIDx); + + auto bank_conflict_info = fusion.bankConflictInfo(); + + TORCH_CHECK(!bank_conflict_info.empty()); + for (auto info : bank_conflict_info) { + std::pair expect{0, 2}; + TORCH_CHECK(info.second == expect); + } +} + +TEST_F(NVFuserTest, FusionTransposeBankConflict8_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeConcreteTensor({1024, 8, 8}); + fusion.addInput(tv0); + auto tv1 = set(tv0); + auto tv2 = transpose(tv1, 1, 2); + auto tv3 = set(tv2); + fusion.addOutput(tv3); + + tv1->setMemoryType(MemoryType::Shared); + tv1->axis(2)->parallelize(ParallelType::TIDx); + tv2->axis(2)->parallelize(ParallelType::TIDy); + tv3->axis(2)->parallelize(ParallelType::TIDy); + tv1->axis(0)->parallelize(ParallelType::BIDx); + tv2->axis(0)->parallelize(ParallelType::BIDx); + tv3->axis(0)->parallelize(ParallelType::BIDx); + + auto bank_conflict_info = fusion.bankConflictInfo(); + + // no bank confliction + TORCH_CHECK(bank_conflict_info.empty()); +} + } // namespace jit } // namespace torch #endif // #if defined(USE_CUDA)