More generic grouped grid reduction kernel #1740

Merged (8 commits, Jun 13, 2022)
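
This PR makes the grouped grid reduction codegen more generic. The allreduce case is now dispatched to a dedicated generateGroupedGridAllreduce routine before the two-reduction assert, so grouped grid allreduce is no longer restricted to exactly two reductions; up to eight are accepted. Instead of emitting a separate RefTuple/ConstRefTuple/VolatilePtrTuple/LocalTuple argument per reduction, the new path gathers the per-reduction data types, outputs, inputs, work buffers, init values, predicates, and reduction operators into ArgumentBuilder lists and passes one multi-element tuple of each kind to a single reduceGroup call.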
92 changes: 60 additions & 32 deletions torch/csrc/jit/codegen/cuda/codegen.cpp
@@ -1529,16 +1529,16 @@ class CudaKernelGenerator : private OptOutConstDispatch {
     const auto sync_buffer =
         grouped_grop->sync_buffer()->buffer()->as<TensorView>();
 
+    if (grouped_grop->isAllreduce()) {
+      generateGroupedGridAllreduce(grouped_grop);
+      return;
+    }
+
     TORCH_INTERNAL_ASSERT(
         grouped_grop->numReductions() == 2,
         "Only grouping of 2 reductions is supported. ",
         grouped_grop->toString());
 
-    if (grouped_grop->isAllreduce()) {
-      generateGridAllreduce(grouped_grop);
-      return;
-    }
-
     const std::string flags_str = generateGridReduceTemplateFlags2(
         grouped_grop, grouped_grop->threadPredicate());
 
@@ -1553,7 +1553,7 @@ class CudaKernelGenerator : private OptOutConstDispatch {

     ArgumentBuilder func_args(block_nest_level_ + 1, kTab);
 
-    // Apped arguments for each reduction
+    // Append arguments for each reduction
     for (const auto i : c10::irange(grouped_grop->numReductions())) {
       TORCH_INTERNAL_ASSERT(
           grouped_grop->reduction_buffers().at(i)->buffer()->isA<TensorView>());
@@ -1596,41 +1596,77 @@ class CudaKernelGenerator : private OptOutConstDispatch {
     indent() << kTab << func_args << ");\n";
   }
 
-  void generateGridAllreduce(const kir::GroupedGridReduction* grouped_grop) {
+  void generateGroupedGridAllreduce(
+      const kir::GroupedGridReduction* grouped_grop) {
     TORCH_INTERNAL_ASSERT(grouped_grop->isAllreduce());
 
-    // First, build a list of function arguments
-    ArgumentBuilder func_args(block_nest_level_ + 1, kTab);
+    constexpr int max_num_reductions = 8;
+    TORCH_INTERNAL_ASSERT(
+        grouped_grop->numReductions() <= max_num_reductions,
+        "Too many grouped reductions: ",
+        grouped_grop->toString(),
+        ". Up to ",
+        max_num_reductions,
+        " reductions are allowed.");
+
+    ArgumentBuilder types;
+    ArgumentBuilder outputs;
+    ArgumentBuilder inputs;
+    ArgumentBuilder work_bufs;
+    ArgumentBuilder init_vals;
+    ArgumentBuilder reduction_ops;
+
+    ArgumentBuilder bool_types;
+    ArgumentBuilder read_preds;
+    ArgumentBuilder write_preds;
 
     for (const auto i : c10::irange(grouped_grop->numReductions())) {
       const auto data_type = grouped_grop->outputs().at(i)->dtype();
       TORCH_INTERNAL_ASSERT(
           grouped_grop->reduction_buffers().at(i)->buffer()->isA<TensorView>());
 
+      types.arg(data_type);
+
       // out
-      func_args.arg(
-          genCall("RefTuple", data_type, gen(grouped_grop->outputs().at(i))));
+      outputs.arg(gen(grouped_grop->outputs().at(i)));
 
       // inp
-      func_args.arg(genCall(
-          "ConstRefTuple", data_type, gen(grouped_grop->inputs().at(i))));
+      inputs.arg(gen(grouped_grop->inputs().at(i)));
 
       // global_work_buffer
       const auto work_buffer =
          grouped_grop->reduction_buffers().at(i)->buffer()->as<TensorView>();
-      func_args.arg(genCall(
-          "VolatilePtrTuple", data_type, "&" + varName(work_buffer) + "[0]"));
+      work_bufs.arg("&").append(varName(work_buffer)).append("[0]");
 
       // init
-      func_args.arg(genCall(
-          "LocalTuple", data_type, genInline(grouped_grop->initVal(i))));
+      init_vals.arg(genInline(grouped_grop->initVal(i)));
 
       // reduction op
-      func_args.arg(genReductionOp(
+      reduction_ops.arg(genReductionOp(
          grouped_grop->getReductionOpType(i),
          grouped_grop->output(i)->dtype()));
+
+      // read and write predicates
+      bool_types.arg("bool");
+      // Same argument for all inputs. Different predicates would be
+      // used when grouping is done across iterations
+      TORCH_INTERNAL_ASSERT(
+          grouped_grop->predicate() != nullptr &&
+          grouped_grop->predicate()->hasValue());
+      const auto read_pred = genInline(grouped_grop->predicate());
+      read_preds.arg(read_pred);
+      if (grouped_grop->writePredicate() != nullptr) {
+        TORCH_INTERNAL_ASSERT(grouped_grop->writePredicate()->hasValue());
+        write_preds.arg(genInline(grouped_grop->writePredicate()));
+      } else {
+        write_preds.arg(read_pred);
+      }
     }
 
+    ArgumentBuilder func_args(block_nest_level_ + 1, kTab);
+    func_args.arg(genCall("RefTuple", types, outputs));
+    func_args.arg(genCall("ConstRefTuple", types, inputs));
+    func_args.arg(genCall("VolatilePtrTuple", types, work_bufs));
+    func_args.arg(genCall("LocalTuple", types, init_vals));
+
     // global_sync_buffer
     const auto sync_buffer =
         grouped_grop->sync_buffer()->buffer()->as<TensorView>();
@@ -1639,21 +1675,13 @@ class CudaKernelGenerator : private OptOutConstDispatch {
     // shared_buf
     func_args.arg("shared_mem");
 
-    // read and write predicates
-    TORCH_INTERNAL_ASSERT(
-        grouped_grop->predicate() != nullptr &&
-        grouped_grop->predicate()->hasValue());
-    const auto read_pred = genInline(grouped_grop->predicate());
-    func_args.arg(read_pred);
-    if (grouped_grop->writePredicate() != nullptr) {
-      TORCH_INTERNAL_ASSERT(grouped_grop->writePredicate()->hasValue());
-      func_args.arg(genInline(grouped_grop->writePredicate()));
-    } else {
-      func_args.arg(read_pred);
-    }
+    func_args.arg(genCall("LocalTuple", bool_types, read_preds));
+    func_args.arg(genCall("LocalTuple", bool_types, write_preds));
 
     addProfileArguments(func_args, grouped_grop);
 
+    func_args.arg(reduction_ops);
+
     indent() << genFusedReductionName(ir_utils::getTvOutput(grouped_grop))
              << ".reduceGroup(\n";
     indent() << kTab << func_args << ");\n";
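
For orientation, here is a rough sketch of the call shape this path emits for two grouped float sum reductions. It is hand-written for illustration, not actual codegen output: the instance name fr, the tensor names T0 through T3, the buffer and predicate names are all made up, and the real arguments come from the gen/genInline calls above (profiling arguments, when enabled, are appended before the reduction ops).

    fr.reduceGroup(
        RefTuple<float, float>(T2[0], T3[0]),          // outputs
        ConstRefTuple<float, float>(T0[0], T1[0]),     // inputs
        VolatilePtrTuple<float, float>(&work_buf0[0], &work_buf1[0]),
        LocalTuple<float, float>(0.f, 0.f),            // init values
        &sync_buf[0],                                  // global sync buffer
        shared_mem,                                    // shared memory workspace
        LocalTuple<bool, bool>(pred, pred),            // read predicates
        LocalTuple<bool, bool>(pred, pred),            // write predicates
        [](float& a, float b) { a = a + b; },          // reduction op 0
        [](float& a, float b) { a = a + b; });         // reduction op 1

Packing the per-reduction values into tuples keeps reduceGroup a single entry point: grouping one more reduction only widens each tuple and appends one more reduction op, rather than requiring another runtime function overload per group size.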