Bank conflict checker improvements (#2032)
zasdfgbnm authored Oct 8, 2022
1 parent d2ca7e3 commit f5bca33
Showing 5 changed files with 327 additions and 28 deletions.
6 changes: 6 additions & 0 deletions torch/csrc/jit/codegen/cuda/expr_evaluator.cpp
@@ -61,6 +61,12 @@ void ExpressionEvaluator::bind(Val* value, const IntOrDouble& concrete_value) {
}
}

void ExpressionEvaluator::bind(
const std::string& name,
const IntOrDouble& concrete_value) {
known_named_scalars_[name] = concrete_value;
}

c10::optional<IntOrDouble> ExpressionEvaluator::evaluate(Val* value) {
if (evaluator_precomputed_values_ != nullptr) {
return toOptionalIntOrDouble(
4 changes: 4 additions & 0 deletions torch/csrc/jit/codegen/cuda/expr_evaluator.h
@@ -7,6 +7,7 @@

#include <c10/util/Optional.h>

#include <string>
#include <unordered_map>

namespace torch {
@@ -30,6 +31,9 @@ class TORCH_CUDA_CU_API ExpressionEvaluator : private OptOutDispatch {
//! Bind a concrete value to an IR variable
void bind(Val* value, const IntOrDouble& concrete_value);

//! Bind a concrete value to a named scalar
void bind(const std::string& name, const IntOrDouble& concrete_value);

//! Try to evaluate a Fusion IR value
c10::optional<IntOrDouble> evaluate(Val* value);

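Together with the implementation in expr_evaluator.cpp above, this overload lets callers bind values to CUDA built-ins by name, with no handle to the underlying NamedScalar required. A minimal usage sketch (the `fusion` object and the values are illustrative, not part of this commit):

```cpp
// Sketch: bind built-in CUDA variables by name, then evaluate.
ExpressionEvaluator expr_eval(fusion); // `fusion` assumed already constructed
int64_t tid = 3;
int64_t bdimx = 128;
expr_eval.bind("threadIdx.x", tid);  // keyed by name, not by Val*
expr_eval.bind("blockDim.x", bdimx); // any NamedScalar with this name now evaluates
```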
207 changes: 190 additions & 17 deletions torch/csrc/jit/codegen/cuda/lower_bank_conflict.cpp
@@ -1,5 +1,6 @@
#include <torch/csrc/jit/codegen/cuda/lower_bank_conflict.h>

#include <torch/csrc/jit/codegen/cuda/dynamic_type.h>
#include <torch/csrc/jit/codegen/cuda/expr_evaluator.h>
#include <torch/csrc/jit/codegen/cuda/kernel_ir.h>
#include <torch/csrc/jit/codegen/cuda/kernel_ir_dispatch.h>
@@ -48,23 +49,78 @@ inline int64_t getPhaseSize(int64_t word_size_bytes) {
return 32;
}
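For orientation, `getPhaseSize` (only its `return 32;` tail is visible in this hunk) maps the vectorized access width to the number of threads whose shared-memory accesses are serviced together. A sketch of that rule as commonly stated for NVIDIA shared memory; an illustration under that assumption, not the committed body:

```cpp
// Sketch: threads per shared-memory access phase, by word size in bytes.
inline int64_t phaseSizeSketch(int64_t word_size_bytes) {
  if (word_size_bytes == 16) {
    return 8; // 16-byte vectorized accesses: a quarter warp per phase
  }
  if (word_size_bytes == 8) {
    return 16; // 8-byte accesses: half a warp per phase
  }
  return 32; // 4-byte or smaller: the full warp, matching the tail above
}
```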

bool isThreadIdx(const std::string& name) {
return name == "threadIdx.x" || name == "threadIdx.y" ||
name == "threadIdx.z";
}

bool isBlockIdx(const std::string& name) {
return name == "blockIdx.x" || name == "blockIdx.y" || name == "blockIdx.z";
}

bool isBlockDim(const std::string& name) {
return name == "blockDim.x" || name == "blockDim.y" || name == "blockDim.z";
}

bool isGridDim(const std::string& name) {
return name == "gridDim.x" || name == "gridDim.y" || name == "gridDim.z";
}

ParallelType getParallelType(const std::string& name) {
if (name == "threadIdx.x") {
return ParallelType::TIDx;
} else if (name == "threadIdx.y") {
return ParallelType::TIDy;
} else if (name == "threadIdx.z") {
return ParallelType::TIDz;
} else if (name == "blockIdx.x") {
return ParallelType::BIDx;
} else if (name == "blockIdx.y") {
return ParallelType::BIDy;
} else if (name == "blockIdx.z") {
return ParallelType::BIDz;
}
TORCH_INTERNAL_ASSERT(false, "Not a parallel type");
}

std::vector<int64_t> evaluateAddressesOnFirstPhase(
kir::TensorIndex* ti,
- const std::vector<kir::ForLoop*>& for_loops) {
const std::vector<kir::ForLoop*>& for_loops,
c10::optional<LaunchParams> launch_params,
const ExpressionEvaluator& expr_eval_common) {
std::vector<int64_t> addresses;
const auto word_size_bytes =
dataTypeSize(*(ti->getDataType())) * getVectorizeSize(ti);
int64_t phase_size = getPhaseSize(word_size_bytes);

- for (auto tidx : c10::irange(phase_size)) {
if (launch_params.has_value()) {
phase_size = std::min<int64_t>(phase_size, launch_params->nThreads());
}

for (int64_t linear_tidx : c10::irange(phase_size)) {
int64_t tidx = linear_tidx;
int64_t tidy = 0;
int64_t tidz = 0;
if (launch_params.has_value()) {
tidy = tidx / launch_params->bdimx();
tidx = tidx % launch_params->bdimx();
tidz = tidy / launch_params->bdimy();
tidy = tidy % launch_params->bdimy();
}
int64_t index = 0;
- ExpressionEvaluator expr_eval(ti->fusion());
// make a copy of the expression evaluator
ExpressionEvaluator expr_eval = expr_eval_common;
expr_eval.bind("threadIdx.x", tidx);
expr_eval.bind("threadIdx.y", tidy);
expr_eval.bind("threadIdx.z", tidz);
for (auto fl : for_loops) {
- if (fl->index()->isA<NamedScalar>() &&
- fl->index()->as<NamedScalar>()->name() == "threadIdx.x") {
- expr_eval.bind(fl->index(), tidx);
if (fl->index()->isA<NamedScalar>()) {
auto name = fl->index()->as<NamedScalar>()->name();
TORCH_INTERNAL_ASSERT(
isThreadIdx(name) || isBlockIdx(name), "unknown loop index");
} else {
- expr_eval.bind(fl->index(), 0);
auto start = expr_eval.evaluate(fl->start())->as<int64_t>();
expr_eval.bind(fl->index(), start);
}
}
for (auto ind : ti->indices()) {
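The div/mod chain added in this hunk recovers `(tidx, tidy, tidz)` from a linear thread id, following CUDA's linearization with x fastest. A worked check of the arithmetic, assuming blockDim = (4, 2, 2) (values chosen for illustration):

```cpp
// Sketch: delinearize linear_tidx = tidx + bdimx * (tidy + bdimy * tidz).
int64_t bdimx = 4;
int64_t bdimy = 2;
int64_t tidx = 13;           // the linear thread id to decompose
int64_t tidy = tidx / bdimx; // 13 / 4 = 3
tidx = tidx % bdimx;         // 13 % 4 = 1
int64_t tidz = tidy / bdimy; // 3 / 2 = 1
tidy = tidy % bdimy;         // 3 % 2 = 1
// Result (1, 1, 1); check: 1 + 4 * (1 + 2 * 1) == 13.
```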
@@ -89,17 +145,97 @@ int getConflictWays(const std::vector<int64_t>& addresses) {
return conflict;
}

- } // namespace
class InferLaunchParams : public kir::IrVisitor {
public:
static c10::optional<LaunchParams> get(
const std::vector<Expr*>& exprs,
const std::unordered_map<std::string, IntOrDouble>& known_values) {
if (exprs.empty()) {
return c10::nullopt;
}
return InferLaunchParams(exprs, known_values).launch_params_;
}

private:
InferLaunchParams(
const std::vector<Expr*>& exprs,
const std::unordered_map<std::string, IntOrDouble>& known_values)
: expr_eval_(exprs[0]->fusion()) {
for (auto pair : known_values) {
expr_eval_.bind(pair.first, pair.second);
}
handle(exprs);
}

using kir::IrVisitor::handle;

void handle(Expr* expr) final {
if (expr->isA<kir::ForLoop>() || expr->isA<kir::IfThenElse>()) {
kir::IrVisitor::handle(expr);
return;
}

for (auto fl : for_loops_) {
if (fl->index()->isA<NamedScalar>()) {
auto name = fl->index()->as<NamedScalar>()->name();
if (isThreadIdx(name) || isBlockIdx(name)) {
auto ptype = getParallelType(name);
auto stop = expr_eval_.evaluate(fl->stop());
if (stop.has_value()) {
if (!launch_params_.has_value()) {
launch_params_ = LaunchParams();
}
if (launch_params_->getRawVal(ptype) ==
LaunchParams::UNINITIALIZED_VAL) {
launch_params_->bind(stop->as<int64_t>(), ptype);
} else {
TORCH_INTERNAL_ASSERT(
launch_params_->getDim(ptype) == stop,
"Unable to infer launch parameters");
}
}
}
}
}
}

ExpressionEvaluator expr_eval_;
c10::optional<LaunchParams> launch_params_;
};

class BankConflictInfo : public kir::IrVisitor {
public:
static std::unordered_map<const Expr*, std::pair<int, int>> get(
- const std::vector<Expr*>& exprs) {
- return BankConflictInfo(exprs).bank_conflict_info_;
const std::vector<Expr*>& exprs,
c10::optional<LaunchParams> launch_params,
const std::unordered_map<std::string, IntOrDouble>& known_values) {
if (exprs.empty()) {
return {};
}
return BankConflictInfo(exprs, launch_params, known_values)
.bank_conflict_info_;
}

private:
- BankConflictInfo(const std::vector<Expr*>& exprs) {
BankConflictInfo(
const std::vector<Expr*>& exprs,
c10::optional<LaunchParams> launch_params,
const std::unordered_map<std::string, IntOrDouble>& known_values)
: launch_params_(launch_params), expr_eval_common_(exprs[0]->fusion()) {
expr_eval_common_.bind("blockIdx.x", 0);
expr_eval_common_.bind("blockIdx.y", 0);
expr_eval_common_.bind("blockIdx.z", 0);
if (launch_params.has_value()) {
expr_eval_common_.bind("blockDim.x", launch_params->bdimx());
expr_eval_common_.bind("blockDim.y", launch_params->bdimy());
expr_eval_common_.bind("blockDim.z", launch_params->bdimz());
expr_eval_common_.bind("gridDim.x", launch_params->gdimx());
expr_eval_common_.bind("gridDim.y", launch_params->gdimy());
expr_eval_common_.bind("gridDim.z", launch_params->gdimz());
}
for (auto pair : known_values) {
expr_eval_common_.bind(pair.first, pair.second);
}
handle(exprs);
}

@@ -119,11 +255,17 @@ class BankConflictInfo : public kir::IrVisitor {
std::pair<int, int> conflict_ways{0, 0};
if (isSmemTensorIndex(uop->in())) {
conflict_ways.first = getConflictWays(evaluateAddressesOnFirstPhase(
- uop->in()->as<kir::TensorIndex>(), for_loops_));
uop->in()->as<kir::TensorIndex>(),
for_loops_,
launch_params_,
expr_eval_common_));
}
if (isSmemTensorIndex(uop->out())) {
conflict_ways.second = getConflictWays(evaluateAddressesOnFirstPhase(
- uop->out()->as<kir::TensorIndex>(), for_loops_));
uop->out()->as<kir::TensorIndex>(),
for_loops_,
launch_params_,
expr_eval_common_));
}
if (conflict_ways.first > 1 || conflict_ways.second > 1) {
bank_conflict_info_[expr] = conflict_ways;
@@ -133,11 +275,17 @@ class BankConflictInfo : public kir::IrVisitor {
std::pair<int, int> conflict_ways{0, 0};
if (isSmemTensorIndex(ldst->in())) {
conflict_ways.first = getConflictWays(evaluateAddressesOnFirstPhase(
- ldst->in()->as<kir::TensorIndex>(), for_loops_));
ldst->in()->as<kir::TensorIndex>(),
for_loops_,
launch_params_,
expr_eval_common_));
}
if (isSmemTensorIndex(ldst->out())) {
conflict_ways.second = getConflictWays(evaluateAddressesOnFirstPhase(
- ldst->out()->as<kir::TensorIndex>(), for_loops_));
ldst->out()->as<kir::TensorIndex>(),
for_loops_,
launch_params_,
expr_eval_common_));
}
if (conflict_ways.first > 1 || conflict_ways.second > 1) {
bank_conflict_info_[expr] = conflict_ways;
@@ -146,11 +294,36 @@ class BankConflictInfo : public kir::IrVisitor {
}

std::unordered_map<const Expr*, std::pair<int, int>> bank_conflict_info_;
c10::optional<LaunchParams> launch_params_;
ExpressionEvaluator expr_eval_common_;
};

} // namespace

std::unordered_map<const Expr*, std::pair<int, int>> getBankConflictInfo(
- kir::Kernel* kernel) {
- return BankConflictInfo::get(kernel->topLevelExprs());
kir::Kernel* kernel,
c10::optional<LaunchParams> launch_params,
const std::unordered_map<std::string, IntOrDouble>& known_values) {
for (auto pair : known_values) {
TORCH_CHECK(
!isThreadIdx(pair.first),
"threadIdx.{x,y,z} should be computed instead of provided");
TORCH_CHECK(
!isBlockIdx(pair.first),
"blockIdx.{x,y,z} should not be provided (they are always zero)");
TORCH_CHECK(
!isBlockDim(pair.first),
"blockDim.{x,y,z} should be provided by launch_params");
TORCH_CHECK(
!isGridDim(pair.first),
"gridDim.{x,y,z} should be provided by launch_params");
}
if (!launch_params.has_value()) {
launch_params =
InferLaunchParams::get(kernel->topLevelExprs(), known_values);
}
return BankConflictInfo::get(
kernel->topLevelExprs(), launch_params, known_values);
}

} // namespace cuda
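Only the tail of `getConflictWays` appears in this diff. From how it is called (one phase's addresses in, a way count out), it buckets addresses by shared-memory bank. A self-contained sketch under the standard assumption of 32 banks of 4-byte words; an illustration, not the committed implementation:

```cpp
#include <algorithm>
#include <cstdint>
#include <unordered_map>
#include <unordered_set>
#include <vector>

// Sketch: n-way conflict = the largest number of distinct 4-byte words that
// one bank must serve within a phase (accesses to the same word broadcast).
int conflictWaysSketch(const std::vector<int64_t>& addresses) {
  std::unordered_map<int64_t, std::unordered_set<int64_t>> words_per_bank;
  for (int64_t addr : addresses) {
    int64_t word = addr / 4;                // 4-byte word index
    words_per_bank[word % 32].insert(word); // bank id = word index mod 32
  }
  int ways = 1; // 1 means conflict-free; callers flag values > 1
  for (const auto& kv : words_per_bank) {
    ways = std::max(ways, static_cast<int>(kv.second.size()));
  }
  return ways;
}
```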
22 changes: 11 additions & 11 deletions torch/csrc/jit/codegen/cuda/lower_bank_conflict.h
@@ -1,5 +1,7 @@
#pragma once

#include <torch/csrc/jit/codegen/cuda/dynamic_type.h>
#include <torch/csrc/jit/codegen/cuda/executor_launch_params.h>
#include <torch/csrc/jit/codegen/cuda/ir_base_nodes.h>
#include <torch/csrc/jit/codegen/cuda/kernel.h>

@@ -18,27 +20,25 @@ namespace cuda {
// nsight compute. This utility currently has the following assumptions and
// limitations:
//
- // 1. This utility assumes that `blockDim.x` is large enough to hold one phase
- // 2. This utility assumes that the address only depends on loop variables
- // (there can not be a thing like `T0.stride[0]`, `blockDim.x`)
- // 3. This utility assumes that the data of the tensor is accessed by
// 1. This utility assumes that the data of the tensor is accessed by
// `T0[index]`, where `index` is the one stored in the `TensorIndex`
// object.
- // 4. This utility only checks the first iteration, and the start of all
- // loop variables are assumed to be `0` (if we have something like
// 2. This utility only checks the first iteration. If we have something like
// `T1_s[tidx, 5]`, then different iterations should have different
- // results, which this utility will not be able to handle all of them now)
- // 5. This utility assumes that all tensors are independent, which means:
- // 5.1 All shared memory tensors are allocated starting from a multiple of
// conflicts, and only the first iteration's conflicts are evaluated
// 3. This utility assumes that all tensors are independent, which means:
// 3.1 All shared memory tensors are allocated starting from a multiple of
// 4*32 bytes
- // 5.2 The only source of bank confliction is from within a tensor.
// 3.2 The only source of bank conflicts is from within a tensor.
// There is no bank conflict between different tensors.
//
// Also note that this utility will not provide accurate estimation if the above
// assumptions are not satisfied

std::unordered_map<const Expr*, std::pair<int, int>> getBankConflictInfo(
- kir::Kernel* kernel);
kir::Kernel* kernel,
c10::optional<LaunchParams> launch_params = c10::nullopt,
const std::unordered_map<std::string, IntOrDouble>& known_values = {});

} // namespace cuda
} // namespace fuser
[Diff for the fifth changed file not shown.]
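To round out the new interface: launch dimensions can be passed explicitly or, when omitted, inferred from loop extents by `InferLaunchParams`; any remaining runtime scalars go through `known_values` (which, per the checks above, must not name thread/block built-ins). A hedged usage sketch; the `kernel` handle and the scalar name `"i35"` are illustrative:

```cpp
// Sketch: query a lowered kernel for shared-memory bank conflicts.
LaunchParams lp;
lp.bind(128, ParallelType::TIDx); // blockDim.x = 128
lp.bind(2, ParallelType::TIDy);   // blockDim.y = 2

std::unordered_map<std::string, IntOrDouble> known_values;
known_values.emplace("i35", IntOrDouble(7)); // hypothetical runtime scalar

auto info = getBankConflictInfo(kernel, lp, known_values);
for (const auto& kv : info) {
  (void)kv; // kv.second.first / kv.second.second are the conflict ways for
            // the expression's input / output; > 1 means an n-way conflict
}
```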
