Grouped grid welford (#1921)
Enables grouping of grid welford ops across iterations. Same functionality as the iteration grouping for GridReduction. This is intended to improve the outer-norm grid persistence in batchnorm-like fusions.
naoyam committed Aug 23, 2022
1 parent 6cf7eb0 commit 20cf109
Showing 29 changed files with 2,650 additions and 338 deletions.
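
For background on what each grouped op computes: a Welford reduction carries an (avg, var, N) triplet, and the grouped grid version merges several such triplets with a single grid synchronization instead of one per op. The snippet below is a minimal, self-contained sketch of the textbook Welford combine (Chan et al.); it is illustrative only, not the runtime code touched by this commit, and the names (WelfordTriplet, combine) are invented.

    #include <cstdint>

    // Illustrative only: merge two Welford partial results.
    // avg = running mean, m2 = sum of squared deviations, n = count.
    struct WelfordTriplet {
      double avg = 0.0;
      double m2 = 0.0;
      int64_t n = 0;
    };

    WelfordTriplet combine(const WelfordTriplet& a, const WelfordTriplet& b) {
      WelfordTriplet out;
      out.n = a.n + b.n;
      if (out.n == 0) {
        return out;
      }
      const double delta = b.avg - a.avg;
      out.avg = a.avg +
          delta * static_cast<double>(b.n) / static_cast<double>(out.n);
      out.m2 = a.m2 + b.m2 +
          delta * delta * static_cast<double>(a.n) * static_cast<double>(b.n) /
              static_cast<double>(out.n);
      return out;  // variance = out.m2 / out.n once the reduction finishes
    }
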
8 changes: 4 additions & 4 deletions torch/csrc/jit/codegen/cuda/arith.cpp
@@ -1417,12 +1417,12 @@ WelfordResult Welford(
out_avg,
out_var,
out_N, /*out var/avg/count */
tv, /*in var/avg/count */
FusionGuard::getCurFusion()->zeroVal(),
FusionGuard::getCurFusion()->oneVal(),
init_avg_val,
init_var_val,
init_N, /*init var/avg/count */
tv,
FusionGuard::getCurFusion()->zeroVal(),
FusionGuard::getCurFusion()->oneVal()); /*in var/avg/count */
init_N); /*init var/avg/count */

return WelfordResult(out_avg, out_var, out_N);
}
168 changes: 167 additions & 1 deletion torch/csrc/jit/codegen/cuda/codegen.cpp
@@ -1671,6 +1671,16 @@ class CudaKernelGenerator : private OptOutConstDispatch {
indent() << kTab << func_args << ");\n";
}

void handle(const kir::GroupedGridWelford* grouped_gwop) final {
if (grouped_gwop->isAllreduce()) {
generateGroupedGridAllreduceWelford(grouped_gwop);
return;
} else {
TORCH_INTERNAL_ASSERT(
false, "Non-allreduce grouped grid welford is not yet supported");
}
}

// Enumerates all combinations of index values of grouped
// loops. Each combination is a vector of loop index values. The
// length of the vector is the number of grouped loops.
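
The enumeration itself is not shown in this hunk; as a rough, stand-alone sketch of the idea (a Cartesian product over the grouped loop extents, with all names below invented for illustration):

    #include <cstdint>
    #include <vector>

    // Illustrative only: produce every combination of grouped-loop index
    // values. For extents {2, 3} this yields {0,0}, {0,1}, ..., {1,2}.
    std::vector<std::vector<int64_t>> enumerateGroupedIndices(
        const std::vector<int64_t>& extents) {
      std::vector<std::vector<int64_t>> combos(1);
      for (const auto extent : extents) {
        std::vector<std::vector<int64_t>> next;
        for (const auto& prefix : combos) {
          for (int64_t i = 0; i < extent; ++i) {
            auto combo = prefix;
            combo.push_back(i);
            next.push_back(std::move(combo));
          }
        }
        combos = std::move(next);
      }
      return combos;
    }
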
@@ -1872,6 +1882,154 @@ class CudaKernelGenerator : private OptOutConstDispatch {
indent() << kTab << func_args << ");\n";
}

// Mostly the same as the grouped grid reduction version
void generateGroupedGridAllreduceWelford(
const kir::GroupedGridWelford* grouped_gwop) {
TORCH_INTERNAL_ASSERT(grouped_gwop->isAllreduce());

const auto index_replacement_maps = getLoopIndexReplacementMaps();
const auto num_grouped_iterations = index_replacement_maps.size();

// This is also checked at lowering validation time, so it
// isn't strictly necessary.
TORCH_INTERNAL_ASSERT(
num_grouped_iterations * grouped_gwop->numExprs() <=
kMaxNumGroupedReductions,
"Too many grouped reductions: ",
grouped_gwop->toString(),
". Up to ",
kMaxNumGroupedReductions,
" reductions are allowed.");

ArgumentBuilder data_types;
ArgumentBuilder index_types;

// Note that the data type of avg and var, and that of N, are the
// same across all the welford ops since we only support
// grouping of iterations.
const auto data_type = grouped_gwop->outputVals().at(0).avg()->dtype();
const auto index_type = grouped_gwop->outputVals().at(0).N()->dtype();

std::array<ArgumentBuilder, 3> out_args;
std::array<ArgumentBuilder, 3> in_args;
std::array<ArgumentBuilder, 3> init_args;
std::array<ArgumentBuilder, 3> work_bufs;

ArgumentBuilder bool_types;
ArgumentBuilder read_preds;
ArgumentBuilder write_preds;

for (const auto expr_index : c10::irange(grouped_gwop->numExprs())) {
const auto& output = grouped_gwop->outputVals().at(expr_index);
const auto& input = grouped_gwop->inputVals().at(expr_index);
const auto& init = grouped_gwop->initVals().at(expr_index);

for (const auto& group_index :
c10::irange(index_replacement_maps.size())) {
// Set the index replacement map with the concrete values of
// indices of grouped loops.
index_replacement_map_ = index_replacement_maps.at(group_index);

data_types.arg(data_type);
index_types.arg(index_type);

auto work_buffer_offset = group_index == 0
? "0"
: (genInline(grouped_gwop->buffer_stride()) + " * " +
std::to_string(group_index));

// Setup arguments for avg, var, and N
for (const auto i : c10::irange(3)) {
out_args[i].arg(gen(output.get(i)));
in_args[i].arg(gen(input.get(i)));
init_args[i].arg(gen(init.get(i)));
const auto work_buffer = grouped_gwop->reduction_buffers()[i]
.at(expr_index)
->buffer()
->as<TensorView>();
work_bufs[i]
.arg("&")
.append(varName(work_buffer))
.append("[")
.append(work_buffer_offset)
.append("]");
}

// read and write predicates
bool_types.arg("bool");
// Same argument for all inputs. Different predicates would be
// used when grouping is done across iterations
TORCH_INTERNAL_ASSERT(grouped_gwop->predicate() != nullptr);
TORCH_INTERNAL_ASSERT(
grouped_gwop->predicate() != nullptr &&
grouped_gwop->predicate()->hasValue());
const auto read_pred = genInline(grouped_gwop->predicate());
read_preds.arg(read_pred);
if (grouped_gwop->writePredicate() != nullptr) {
TORCH_INTERNAL_ASSERT(grouped_gwop->writePredicate()->hasValue());
write_preds.arg(genInline(grouped_gwop->writePredicate()));
} else {
write_preds.arg(read_pred);
}

index_replacement_map_.clear();
}
}

ArgumentBuilder func_args(block_nest_level_ + 1, kTab);
// output
func_args.arg(genCall("RefTuple", data_types, out_args[0]));
func_args.arg(genCall("RefTuple", data_types, out_args[1]));
func_args.arg(genCall("RefTuple", index_types, out_args[2]));
// input
func_args.arg(genCall("ConstRefTuple", data_types, in_args[0]));
func_args.arg(genCall("ConstRefTuple", data_types, in_args[1]));
func_args.arg(genCall("ConstRefTuple", index_types, in_args[2]));
// init
func_args.arg(genCall("LocalTuple", data_types, init_args[0]));
func_args.arg(genCall("LocalTuple", data_types, init_args[1]));
func_args.arg(genCall("LocalTuple", index_types, init_args[2]));
// work buffer
func_args.arg(genCall("VolatilePtrTuple", data_types, work_bufs[0]));
func_args.arg(genCall("VolatilePtrTuple", data_types, work_bufs[1]));
func_args.arg(genCall("VolatilePtrTuple", index_types, work_bufs[2]));
// global_sync_buffer
const auto sync_buffer =
grouped_gwop->sync_buffer()->buffer()->as<TensorView>();
func_args.arg("&").append(varName(sync_buffer)).append("[0]");

// shared_buf
ArgumentBuilder smem_buffer_args;
smem_buffer_args.arg(
genCall("reinterpret_cast", ptrType(data_type), "shared_mem_avg"));
smem_buffer_args.arg(
genCall("reinterpret_cast", ptrType(data_type), "shared_mem_var"));
smem_buffer_args.arg(
genCall("reinterpret_cast", ptrType(index_type), "shared_mem_n"));
func_args.arg(genCall(
"PtrTuple",
ArgumentBuilder().arg(data_type).arg(data_type).arg(index_type),
smem_buffer_args));

func_args.arg(genCall("LocalTuple", bool_types, read_preds));
func_args.arg(genCall("LocalTuple", bool_types, write_preds));

addProfileArguments(func_args, grouped_gwop);

ArgumentBuilder func_template_args;
func_template_args.arg(
grouped_gwop->numExprs() * index_replacement_maps.size());
func_template_args.arg(data_type);
func_template_args.arg(index_type);

indent() << genCall(
genFusedReductionName(ir_utils::getTvOutput(grouped_gwop)) +
".welfordGroup",
func_template_args,
func_args)
<< ";\n";
}
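
To make the argument packing above easier to follow, the call emitted by this function has roughly the following shape. All identifiers here are invented; the real names come from the kernel IR and genFusedReductionName.

    // Sketch of the emitted call, not actual generated code:
    //   <fused reduction instance>.welfordGroup<
    //       /*num grouped values=*/numExprs * num_grouped_iterations,
    //       float, int64_t>(
    //     RefTuple<float, ...>(/* out avg per grouped value */),
    //     RefTuple<float, ...>(/* out var per grouped value */),
    //     RefTuple<int64_t, ...>(/* out N per grouped value */),
    //     ConstRefTuple<...>(/* in avg */),
    //     ConstRefTuple<...>(/* in var */),
    //     ConstRefTuple<...>(/* in N */),
    //     LocalTuple<...>(/* init avg */),
    //     LocalTuple<...>(/* init var */),
    //     LocalTuple<...>(/* init N */),
    //     VolatilePtrTuple<...>(/* work buffers for avg, var, N */),
    //     &<sync buffer>[0],
    //     PtrTuple<float, float, int64_t>(
    //         reinterpret_cast<float*>(shared_mem_avg),
    //         reinterpret_cast<float*>(shared_mem_var),
    //         reinterpret_cast<int64_t*>(shared_mem_n)),
    //     LocalTuple<bool, ...>(/* read predicates */),
    //     LocalTuple<bool, ...>(/* write predicates */));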

void handle(const kir::GridBroadcast* grop) final {
const auto bop = grop->broadcast_op();
TORCH_INTERNAL_ASSERT(bop->out()->isA<kir::TensorIndex>());
@@ -2208,6 +2366,13 @@ class CudaKernelGenerator : private OptOutConstDispatch {
}
}

void handle(const GroupedWelfordOp* grouped_wop) final {
TORCH_INTERNAL_ASSERT(
false,
"Should not reach here as grouped welford is only enabled for grid welford,",
" which is handled by its own handler");
}

//! True if loop is grouped. The IterDomain of the loop must have
//! ParallelType::Group, but it isn't sufficient as the loop may be
//! for an initialization expression, for which the loop should not
@@ -2216,7 +2381,8 @@
if (loop->iter_domain()->getParallelType() != ParallelType::Group) {
return false;
}
return ExprFinder::exists(loop, {ExprType::GroupedGridReduction});
return ExprFinder::exists(
loop, {ExprType::GroupedGridReduction, ExprType::GroupedGridWelford});
}

void handle(const kir::ForLoop* loop) final {
30 changes: 30 additions & 0 deletions torch/csrc/jit/codegen/cuda/dispatch.cpp
@@ -113,6 +113,9 @@ void Expr::dispatch(T handler, Expr* expr) {
case ExprType::WelfordOp:
ptr(handler)->handle(expr->as<WelfordOp>());
return;
case ExprType::GroupedWelfordOp:
ptr(handler)->handle(expr->as<GroupedWelfordOp>());
return;
case ExprType::LoadStoreOp:
ptr(handler)->handle(expr->as<LoadStoreOp>());
return;
@@ -190,6 +193,9 @@ void Expr::dispatch(T handler, Expr* expr) {
case ExprType::GridWelford:
ptr(handler)->handle(expr->as<kir::GridWelford>());
return;
case ExprType::GroupedGridWelford:
ptr(handler)->handle(expr->as<kir::GroupedGridWelford>());
return;
case ExprType::AllocateFusedReduction:
ptr(handler)->handle(expr->as<kir::AllocateFusedReduction>());
return;
@@ -287,6 +293,9 @@ void Expr::constDispatch(T handler, const Expr* expr) {
case ExprType::WelfordOp:
ptr(handler)->handle(expr->as<WelfordOp>());
return;
case ExprType::GroupedWelfordOp:
ptr(handler)->handle(expr->as<GroupedWelfordOp>());
return;
case ExprType::LoadStoreOp:
ptr(handler)->handle(expr->as<LoadStoreOp>());
return;
@@ -364,6 +373,9 @@ void Expr::constDispatch(T handler, const Expr* expr) {
case ExprType::GridWelford:
ptr(handler)->handle(expr->as<kir::GridWelford>());
return;
case ExprType::GroupedGridWelford:
ptr(handler)->handle(expr->as<kir::GroupedGridWelford>());
return;
case ExprType::AllocateFusedReduction:
ptr(handler)->handle(expr->as<kir::AllocateFusedReduction>());
return;
@@ -469,6 +481,9 @@ void Expr::mutatorDispatch(T mutator, Expr* expr) {
case ExprType::WelfordOp:
ptr(mutator)->mutate(expr->as<WelfordOp>());
return;
case ExprType::GroupedWelfordOp:
ptr(mutator)->mutate(expr->as<GroupedWelfordOp>());
return;
case ExprType::LoadStoreOp:
ptr(mutator)->mutate(expr->as<LoadStoreOp>());
return;
@@ -546,6 +561,9 @@ void Expr::mutatorDispatch(T mutator, Expr* expr) {
case ExprType::GridWelford:
ptr(mutator)->mutate(expr->as<kir::GridWelford>());
return;
case ExprType::GroupedGridWelford:
ptr(mutator)->mutate(expr->as<kir::GroupedGridWelford>());
return;
case ExprType::AllocateFusedReduction:
ptr(mutator)->mutate(expr->as<kir::AllocateFusedReduction>());
return;
@@ -716,6 +734,9 @@ void OptOutConstDispatch::handle(const GroupedReductionOp* stmt) {
void OptOutConstDispatch::handle(const WelfordOp* stmt) {
unhandled(stmt);
}
void OptOutConstDispatch::handle(const GroupedWelfordOp* stmt) {
unhandled(stmt);
}
void OptOutConstDispatch::handle(const LoadStoreOp* stmt) {
unhandled(stmt);
}
@@ -793,6 +814,9 @@ void OptOutConstDispatch::handle(const kir::GridBroadcast* stmt) {
void OptOutConstDispatch::handle(const kir::GridWelford* stmt) {
unhandled(stmt);
}
void OptOutConstDispatch::handle(const kir::GroupedGridWelford* stmt) {
unhandled(stmt);
}
void OptOutConstDispatch::handle(const kir::AllocateFusedReduction* stmt) {
unhandled(stmt);
}
@@ -860,6 +884,9 @@ void OptOutDispatch::handle(GroupedReductionOp* stmt) {
void OptOutDispatch::handle(WelfordOp* stmt) {
unhandled(stmt);
}
void OptOutDispatch::handle(GroupedWelfordOp* stmt) {
unhandled(stmt);
}
void OptOutDispatch::handle(LoadStoreOp* stmt) {
unhandled(stmt);
}
@@ -937,6 +964,9 @@ void OptOutDispatch::handle(kir::GridBroadcast* stmt) {
void OptOutDispatch::handle(kir::GridWelford* stmt) {
unhandled(stmt);
}
void OptOutDispatch::handle(kir::GroupedGridWelford* stmt) {
unhandled(stmt);
}
void OptOutDispatch::handle(kir::AllocateFusedReduction* stmt) {
unhandled(stmt);
}
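
The dispatch.cpp additions above all follow one mechanical pattern: an enum-keyed switch downcasts the generic Expr to its concrete type and calls the matching handle() overload, while the OptOut* base classes supply defaults that fall through to unhandled(). A condensed, self-contained sketch of that pattern (types and names here are illustrative, not the nvFuser classes):

    // Illustrative dispatch pattern: enum switch + typed handle() overloads
    // with an unhandled() fallback, mirroring Expr::dispatch above.
    enum class ExprType { WelfordOp, GroupedWelfordOp };

    struct Expr { ExprType etype; };
    struct WelfordOp : Expr {};
    struct GroupedWelfordOp : Expr {};

    struct Handler {
      virtual ~Handler() = default;
      virtual void handle(const WelfordOp*) { unhandled(); }
      virtual void handle(const GroupedWelfordOp*) { unhandled(); }
      virtual void unhandled() {}
    };

    void dispatch(Handler& h, const Expr* expr) {
      switch (expr->etype) {
        case ExprType::WelfordOp:
          h.handle(static_cast<const WelfordOp*>(expr));
          return;
        case ExprType::GroupedWelfordOp:
          h.handle(static_cast<const GroupedWelfordOp*>(expr));
          return;
      }
    }

A concrete handler, such as the kernel generator, then only overrides the overloads it knows how to emit and leaves the rest to unhandled().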
8 changes: 8 additions & 0 deletions torch/csrc/jit/codegen/cuda/dispatch.h
@@ -74,6 +74,7 @@ class TernaryOp;
class ReductionOp;
class GroupedReductionOp;
class WelfordOp;
class GroupedWelfordOp;
class LoadStoreOp;
class MmaOp;
class BroadcastOp;
@@ -105,6 +106,7 @@ class GridReduction;
class GroupedGridReduction;
class GridBroadcast;
class GridWelford;
class GroupedGridWelford;
class AllocateFusedReduction;
class InitMagicZero;
class UpdateMagicZero;
@@ -146,6 +148,7 @@ class TORCH_CUDA_CU_API OptOutConstDispatch : public PolymorphicBase {
virtual void handle(const ReductionOp* stmt);
virtual void handle(const GroupedReductionOp* stmt);
virtual void handle(const WelfordOp* stmt);
virtual void handle(const GroupedWelfordOp* stmt);
virtual void handle(const LoadStoreOp* stmt);
virtual void handle(const MmaOp* stmt);
virtual void handle(const BroadcastOp* stmt);
@@ -173,6 +176,7 @@ class TORCH_CUDA_CU_API OptOutConstDispatch : public PolymorphicBase {
virtual void handle(const kir::GroupedGridReduction*);
virtual void handle(const kir::GridBroadcast*);
virtual void handle(const kir::GridWelford*);
virtual void handle(const kir::GroupedGridWelford*);
virtual void handle(const kir::AllocateFusedReduction*);
virtual void handle(const kir::Swizzle2DInt*);
virtual void handle(const kir::PairSelect*);
@@ -209,6 +213,7 @@ class TORCH_CUDA_CU_API OptOutDispatch : public PolymorphicBase {
virtual void handle(ReductionOp* stmt);
virtual void handle(GroupedReductionOp* stmt);
virtual void handle(WelfordOp* stmt);
virtual void handle(GroupedWelfordOp* stmt);
virtual void handle(LoadStoreOp* stmt);
virtual void handle(MmaOp* stmt);
virtual void handle(BroadcastOp* stmt);
@@ -236,6 +241,7 @@ class TORCH_CUDA_CU_API OptOutDispatch : public PolymorphicBase {
virtual void handle(kir::GroupedGridReduction* stmt);
virtual void handle(kir::GridBroadcast* stmt);
virtual void handle(kir::GridWelford* stmt);
virtual void handle(kir::GroupedGridWelford* stmt);
virtual void handle(kir::AllocateFusedReduction* stmt);
virtual void handle(kir::Swizzle2DInt* stmt);
virtual void handle(kir::PairSelect* stmt);
@@ -313,6 +319,7 @@ class TORCH_CUDA_CU_API OptOutMutator : public PolymorphicBase {
virtual void mutate(ReductionOp*);
virtual void mutate(GroupedReductionOp*);
virtual void mutate(WelfordOp*);
virtual void mutate(GroupedWelfordOp*);
virtual void mutate(LoadStoreOp*);
virtual void mutate(MmaOp*);
virtual void mutate(BroadcastOp*);
@@ -340,6 +347,7 @@ class TORCH_CUDA_CU_API OptOutMutator : public PolymorphicBase {
virtual void mutate(kir::GroupedGridReduction*);
virtual void mutate(kir::GridBroadcast*);
virtual void mutate(kir::GridWelford*);
virtual void mutate(kir::GroupedGridWelford*);
virtual void mutate(kir::AllocateFusedReduction*);
virtual void mutate(kir::Swizzle2DInt*);
virtual void mutate(kir::PairSelect*);
8 changes: 7 additions & 1 deletion torch/csrc/jit/codegen/cuda/executor.cpp
@@ -860,7 +860,13 @@ std::vector<at::Tensor> FusionExecutor::runFusion(
"what can be resident on the GPU at once. Need: ",
launch_params_.gdimx() * launch_params_.gdimy() *
launch_params_.gdimz(),
" but limited to ",
" (",
launch_params_.gdimx(),
" * ",
launch_params_.gdimy(),
" * ",
launch_params_.gdimz(),
") but limited to ",
num_blocks_per_SM,
" * ",
at::cuda::getDeviceProperties(options_.device.index())