Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Bank conflict checker improvements #2032

Merged
merged 8 commits into from
Oct 8, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions torch/csrc/jit/codegen/cuda/expr_evaluator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,12 @@ void ExpressionEvaluator::bind(Val* value, const IntOrDouble& concrete_value) {
}
}

void ExpressionEvaluator::bind(
const std::string& name,
const IntOrDouble& concrete_value) {
known_named_scalars_[name] = concrete_value;
}

c10::optional<IntOrDouble> ExpressionEvaluator::evaluate(Val* value) {
if (evaluator_precomputed_values_ != nullptr) {
return toOptionalIntOrDouble(
Expand Down
4 changes: 4 additions & 0 deletions torch/csrc/jit/codegen/cuda/expr_evaluator.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

#include <c10/util/Optional.h>

#include <string>
#include <unordered_map>

namespace torch {
Expand All @@ -30,6 +31,9 @@ class TORCH_CUDA_CU_API ExpressionEvaluator : private OptOutDispatch {
//! Bind a concrete value to an IR variable
void bind(Val* value, const IntOrDouble& concrete_value);

//! Bind a concrete value to a named scalar
void bind(const std::string& name, const IntOrDouble& concrete_value);

//! Try to evaluate a Fusion IR value
c10::optional<IntOrDouble> evaluate(Val* value);

Expand Down
207 changes: 190 additions & 17 deletions torch/csrc/jit/codegen/cuda/lower_bank_conflict.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#include <torch/csrc/jit/codegen/cuda/lower_bank_conflict.h>

#include <torch/csrc/jit/codegen/cuda/dynamic_type.h>
#include <torch/csrc/jit/codegen/cuda/expr_evaluator.h>
#include <torch/csrc/jit/codegen/cuda/kernel_ir.h>
#include <torch/csrc/jit/codegen/cuda/kernel_ir_dispatch.h>
Expand Down Expand Up @@ -48,23 +49,78 @@ inline int64_t getPhaseSize(int64_t word_size_bytes) {
return 32;
}

bool isThreadIdx(const std::string& name) {
return name == "threadIdx.x" || name == "threadIdx.y" ||
name == "threadIdx.z";
}

bool isBlockIdx(const std::string& name) {
return name == "blockIdx.x" || name == "blockIdx.y" || name == "blockIdx.z";
}

bool isBlockDim(const std::string& name) {
return name == "blockDim.x" && name == "blockDim.y" && name == "blockDim.z";
}

bool isGridDim(const std::string& name) {
return name == "gridDim.x" && name == "gridDim.y" && name == "gridDim.z";
}

ParallelType getParallelType(const std::string& name) {
if (name == "threadIdx.x") {
return ParallelType::TIDx;
} else if (name == "threadIdx.y") {
return ParallelType::TIDy;
} else if (name == "threadIdx.z") {
return ParallelType::TIDz;
} else if (name == "blockIdx.x") {
return ParallelType::BIDx;
} else if (name == "blockIdx.y") {
return ParallelType::BIDy;
} else if (name == "blockIdx.z") {
return ParallelType::BIDz;
}
TORCH_INTERNAL_ASSERT(false, "Not a parallel type");
}

std::vector<int64_t> evaluateAddressesOnFirstPhase(
kir::TensorIndex* ti,
const std::vector<kir::ForLoop*>& for_loops) {
const std::vector<kir::ForLoop*>& for_loops,
c10::optional<LaunchParams> launch_params,
const ExpressionEvaluator& expr_eval_common) {
std::vector<int64_t> addresses;
const auto word_size_bytes =
dataTypeSize(*(ti->getDataType())) * getVectorizeSize(ti);
int64_t phase_size = getPhaseSize(word_size_bytes);

for (auto tidx : c10::irange(phase_size)) {
if (launch_params.has_value()) {
phase_size = std::min<int64_t>(phase_size, launch_params->nThreads());
}

for (int64_t linear_tidx : c10::irange(phase_size)) {
int64_t tidx = linear_tidx;
int64_t tidy = 0;
int64_t tidz = 0;
if (launch_params.has_value()) {
tidy = tidx / launch_params->bdimx();
tidx = tidx % launch_params->bdimx();
tidz = tidy / launch_params->bdimy();
tidy = tidy % launch_params->bdimy();
}
int64_t index = 0;
ExpressionEvaluator expr_eval(ti->fusion());
// make a copy of the expression evaluator
ExpressionEvaluator expr_eval = expr_eval_common;
expr_eval.bind("threadIdx.x", tidx);
expr_eval.bind("threadIdx.y", tidy);
expr_eval.bind("threadIdx.z", tidz);
for (auto fl : for_loops) {
if (fl->index()->isA<NamedScalar>() &&
fl->index()->as<NamedScalar>()->name() == "threadIdx.x") {
expr_eval.bind(fl->index(), tidx);
if (fl->index()->isA<NamedScalar>()) {
auto name = fl->index()->as<NamedScalar>()->name();
TORCH_INTERNAL_ASSERT(
isThreadIdx(name) || isBlockIdx(name), "unknow loop index");
} else {
expr_eval.bind(fl->index(), 0);
auto start = expr_eval.evaluate(fl->start())->as<int64_t>();
expr_eval.bind(fl->index(), start);
}
}
for (auto ind : ti->indices()) {
Expand All @@ -89,17 +145,97 @@ int getConflictWays(const std::vector<int64_t>& addresses) {
return conflict;
}

} // namespace
class InferLaunchParams : public kir::IrVisitor {
public:
static c10::optional<LaunchParams> get(
const std::vector<Expr*>& exprs,
const std::unordered_map<std::string, IntOrDouble>& known_values) {
if (exprs.empty()) {
return c10::nullopt;
}
return InferLaunchParams(exprs, known_values).launch_params_;
}

private:
InferLaunchParams(
const std::vector<Expr*>& exprs,
const std::unordered_map<std::string, IntOrDouble>& known_values)
: expr_eval_(exprs[0]->fusion()) {
for (auto pair : known_values) {
expr_eval_.bind(pair.first, pair.second);
}
handle(exprs);
}

using kir::IrVisitor::handle;

void handle(Expr* expr) final {
if (expr->isA<kir::ForLoop>() || expr->isA<kir::IfThenElse>()) {
kir::IrVisitor::handle(expr);
return;
}

for (auto fl : for_loops_) {
if (fl->index()->isA<NamedScalar>()) {
auto name = fl->index()->as<NamedScalar>()->name();
if (isThreadIdx(name) || isBlockIdx(name)) {
auto ptype = getParallelType(name);
auto stop = expr_eval_.evaluate(fl->stop());
if (stop.has_value()) {
if (!launch_params_.has_value()) {
launch_params_ = LaunchParams();
}
if (launch_params_->getRawVal(ptype) ==
LaunchParams::UNINITIALIZED_VAL) {
launch_params_->bind(stop->as<int64_t>(), ptype);
} else {
TORCH_INTERNAL_ASSERT(
launch_params_->getDim(ptype) == stop,
"Unable to infer launch parameters");
}
}
}
}
}
}

ExpressionEvaluator expr_eval_;
c10::optional<LaunchParams> launch_params_;
};

class BankConflictInfo : public kir::IrVisitor {
public:
static std::unordered_map<const Expr*, std::pair<int, int>> get(
const std::vector<Expr*>& exprs) {
return BankConflictInfo(exprs).bank_conflict_info_;
const std::vector<Expr*>& exprs,
c10::optional<LaunchParams> launch_params,
const std::unordered_map<std::string, IntOrDouble>& known_values) {
if (exprs.empty()) {
return {};
}
return BankConflictInfo(exprs, launch_params, known_values)
.bank_conflict_info_;
}

private:
BankConflictInfo(const std::vector<Expr*>& exprs) {
BankConflictInfo(
const std::vector<Expr*>& exprs,
c10::optional<LaunchParams> launch_params,
const std::unordered_map<std::string, IntOrDouble>& known_values)
: launch_params_(launch_params), expr_eval_common_(exprs[0]->fusion()) {
expr_eval_common_.bind("blockIdx.x", 0);
expr_eval_common_.bind("blockIdx.y", 0);
expr_eval_common_.bind("blockIdx.z", 0);
if (launch_params.has_value()) {
expr_eval_common_.bind("blockDim.x", launch_params->bdimx());
expr_eval_common_.bind("blockDim.y", launch_params->bdimy());
expr_eval_common_.bind("blockDim.z", launch_params->bdimz());
expr_eval_common_.bind("gridDim.x", launch_params->gdimx());
expr_eval_common_.bind("gridDim.y", launch_params->gdimy());
expr_eval_common_.bind("gridDim.z", launch_params->gdimz());
}
for (auto pair : known_values) {
expr_eval_common_.bind(pair.first, pair.second);
}
handle(exprs);
}

Expand All @@ -119,11 +255,17 @@ class BankConflictInfo : public kir::IrVisitor {
std::pair<int, int> conflict_ways{0, 0};
if (isSmemTensorIndex(uop->in())) {
conflict_ways.first = getConflictWays(evaluateAddressesOnFirstPhase(
uop->in()->as<kir::TensorIndex>(), for_loops_));
uop->in()->as<kir::TensorIndex>(),
for_loops_,
launch_params_,
expr_eval_common_));
}
if (isSmemTensorIndex(uop->out())) {
conflict_ways.second = getConflictWays(evaluateAddressesOnFirstPhase(
uop->out()->as<kir::TensorIndex>(), for_loops_));
uop->out()->as<kir::TensorIndex>(),
for_loops_,
launch_params_,
expr_eval_common_));
}
if (conflict_ways.first > 1 || conflict_ways.second > 1) {
bank_conflict_info_[expr] = conflict_ways;
Expand All @@ -133,11 +275,17 @@ class BankConflictInfo : public kir::IrVisitor {
std::pair<int, int> conflict_ways{0, 0};
if (isSmemTensorIndex(ldst->in())) {
conflict_ways.first = getConflictWays(evaluateAddressesOnFirstPhase(
ldst->in()->as<kir::TensorIndex>(), for_loops_));
ldst->in()->as<kir::TensorIndex>(),
for_loops_,
launch_params_,
expr_eval_common_));
}
if (isSmemTensorIndex(ldst->out())) {
conflict_ways.second = getConflictWays(evaluateAddressesOnFirstPhase(
ldst->out()->as<kir::TensorIndex>(), for_loops_));
ldst->out()->as<kir::TensorIndex>(),
for_loops_,
launch_params_,
expr_eval_common_));
}
if (conflict_ways.first > 1 || conflict_ways.second > 1) {
bank_conflict_info_[expr] = conflict_ways;
Expand All @@ -146,11 +294,36 @@ class BankConflictInfo : public kir::IrVisitor {
}

std::unordered_map<const Expr*, std::pair<int, int>> bank_conflict_info_;
c10::optional<LaunchParams> launch_params_;
ExpressionEvaluator expr_eval_common_;
};

} // namespace

std::unordered_map<const Expr*, std::pair<int, int>> getBankConflictInfo(
kir::Kernel* kernel) {
return BankConflictInfo::get(kernel->topLevelExprs());
kir::Kernel* kernel,
c10::optional<LaunchParams> launch_params,
const std::unordered_map<std::string, IntOrDouble>& known_values) {
for (auto pair : known_values) {
TORCH_CHECK(
!isThreadIdx(pair.first),
"threadIdx.{x,y,z} should be computed instead of provided");
TORCH_CHECK(
!isBlockIdx(pair.first),
"blockIdx.{x,y,z} should not be provided (they are always zero)");
TORCH_CHECK(
!isBlockDim(pair.first),
"blockDim.{x,y,z} should be provided by launch_params");
TORCH_CHECK(
!isGridDim(pair.first),
"gridDim.{x,y,z} should be provided by launch_params");
}
if (!launch_params.has_value()) {
launch_params =
InferLaunchParams::get(kernel->topLevelExprs(), known_values);
}
return BankConflictInfo::get(
kernel->topLevelExprs(), launch_params, known_values);
}

} // namespace cuda
Expand Down
22 changes: 11 additions & 11 deletions torch/csrc/jit/codegen/cuda/lower_bank_conflict.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
#pragma once

#include <torch/csrc/jit/codegen/cuda/dynamic_type.h>
#include <torch/csrc/jit/codegen/cuda/executor_launch_params.h>
#include <torch/csrc/jit/codegen/cuda/ir_base_nodes.h>
#include <torch/csrc/jit/codegen/cuda/kernel.h>

Expand All @@ -18,27 +20,25 @@ namespace cuda {
// nsight compute. This utility currently has the following assumptions and
// limitations:
//
// 1. This utility assumes that `blockDim.x` is large enough to hold one phase
// 2. This utility assumes that the address only depends on loop variables
// (there can not be a thing like `T0.stride[0]`, `blockDim.x`)
// 3. This utility assumes that the data of the tensor is accessed by
// 1. This utility assumes that the data of the tensor is accessed by
// `T0[index]`, where `index` is the one stored in the `TensorIndex`
// object.
// 4. This utility only checks the first iteration, and the start of all
// loop variables are assumed to be `0` (if we have something like
// 2. This utility only checks the first iteration. If we have something like
// `T1_s[tidx, 5]`, then different iterations should have different
// results, which this utility will not be able to handle all of them now)
// 5. This utility assumes that all tensors are independent, which means:
// 5.1 All shared memory tensors are allocated starting from a multiple of
// conflictions, which will not be evaluated for all of them
// 3. This utility assumes that all tensors are independent, which means:
// 3.1 All shared memory tensors are allocated starting from a multiple of
// 4*32 bytes
// 5.2 The only source of bank confliction is from within a tensor.
// 3.2 The only source of bank confliction is from within a tensor.
// There is no bank conflict between different tensors.
//
// Also note that this utility will not provide accurate estimation if the above
// assumptions are satisfied

std::unordered_map<const Expr*, std::pair<int, int>> getBankConflictInfo(
kir::Kernel* kernel);
kir::Kernel* kernel,
c10::optional<LaunchParams> launch_params = c10::nullopt,
const std::unordered_map<std::string, IntOrDouble>& known_values = {});

} // namespace cuda
} // namespace fuser
Expand Down
Loading