[MatMul][WIP] MatMul epilog support: shared memory reuse #1979

Open · wants to merge 10 commits into base: misc_minor_codgen_change
34 changes: 27 additions & 7 deletions torch/csrc/jit/codegen/cuda/codegen.cpp

@@ -617,18 +617,25 @@ class CudaKernelGenerator : private OptOutConstDispatch {
   // Utility function to emit a cp.async intrinsic
   void genCpAsync(const LoadStoreOp* ldst, int vec_size) {
     auto dtype = ldst->in()->getDataType().value();
+    bool is_cg = ldst->opType() == LoadStoreOpType::CpAsyncCg;
+
+    if (is_cg) {
+      indent() << "Ampere::cpAsyncCg";
+    } else {
+      indent() << "Ampere::cpAsync";
+    }

     if (ldst->predicate() == nullptr) {
       // Out of line predicate variant
-      indent() << "Ampere::cpAsync<" << dtype << "," << vec_size << ">("
-               << genMaybeHoistedPointer(ldst->out()) << ","
-               << genMaybeHoistedPointer(ldst->in()) << ");\n";
+      code_ << "<" << dtype << "," << vec_size << ">("
+            << genMaybeHoistedPointer(ldst->out()) << ","
+            << genMaybeHoistedPointer(ldst->in()) << ");\n";
     } else {
       // Inline predicate variant
-      indent() << "Ampere::cpAsync<" << dtype << "," << vec_size << ">("
-               << genMaybeHoistedPointer(ldst->out()) << ","
-               << genMaybeHoistedPointer(ldst->in()) << ","
-               << genInline(ldst->predicate()) << ");\n";
+      code_ << "<" << dtype << "," << vec_size << ">("
+            << genMaybeHoistedPointer(ldst->out()) << ","
+            << genMaybeHoistedPointer(ldst->in()) << ","
+            << genInline(ldst->predicate()) << ");\n";
     }
   }
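The only difference between the two ops is the intrinsic name emitted before the template arguments: `CpAsyncCg` selects the cg (cache-global, L2-only) flavor of `cp.async`. Purely as an illustration (buffer and predicate names are hypothetical; the `Ampere::*` helpers live in nvfuser's runtime headers), the generated device code would look like:

```cpp
// Illustrative codegen output only; not compilable standalone.
// CpAsyncCg, out-of-line predicate variant:
Ampere::cpAsyncCg<float, 4>(smem_ptr, gmem_ptr);
// Plain CpAsync, inline predicate variant:
Ampere::cpAsync<float, 4>(smem_ptr, gmem_ptr, read_pred);
```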

@@ -1402,6 +1409,7 @@ class CudaKernelGenerator : private OptOutConstDispatch {
genLdMatrix(ldst, vector_word_size);
break;
case LoadStoreOpType::CpAsync:
case LoadStoreOpType::CpAsyncCg:
genCpAsync(ldst, vector_word_size);
break;
default:
@@ -2618,6 +2626,18 @@ class CudaKernelGenerator : private OptOutConstDispatch {
indent() << sync_call << ";\n";
}

void handle(const kir::Deallocate* dealloc) final {
const auto alloc = dealloc->buffer();
const auto tv = alloc->buffer()->as<TensorView>();
const auto size = alloc->size();
const auto buffer_dtype = alloc->buffer()->dtype();

TORCH_INTERNAL_ASSERT(size != nullptr);
indent() << "// de-alloc " << varName(tv) << "\n";
indent() << "offset -= (" << genInline(size) << " * sizeof(" << buffer_dtype
<< "));\n";
}
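A deallocation therefore compiles to plain pointer arithmetic on the kernel's running shared-memory offset, undoing the increment emitted for the matching allocation. An illustrative output (tensor name and extent are made up) would be:

```cpp
// de-alloc T5
offset -= (1024 * sizeof(float));
```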

void handle(const kir::InitMagicZero*) final {
indent() << "NVFUSER_DEFINE_MAGIC_ZERO\n";
}
15 changes: 15 additions & 0 deletions torch/csrc/jit/codegen/cuda/dispatch.cpp

@@ -154,6 +154,9 @@ void Expr::dispatch(T handler, Expr* expr) {
case ExprType::Allocate:
ptr(handler)->handle(expr->as<kir::Allocate>());
return;
case ExprType::DeAllocate:
ptr(handler)->handle(expr->as<kir::Deallocate>());
return;
case ExprType::BlockSync:
ptr(handler)->handle(expr->as<kir::BlockSync>());
return;
@@ -331,6 +334,9 @@ void Expr::constDispatch(T handler, const Expr* expr) {
case ExprType::Allocate:
ptr(handler)->handle(expr->as<kir::Allocate>());
return;
case ExprType::DeAllocate:
ptr(handler)->handle(expr->as<kir::Deallocate>());
return;
case ExprType::BlockSync:
ptr(handler)->handle(expr->as<kir::BlockSync>());
return;
@@ -516,6 +522,9 @@ void Expr::mutatorDispatch(T mutator, Expr* expr) {
case ExprType::Allocate:
ptr(mutator)->mutate(expr->as<kir::Allocate>());
return;
case ExprType::DeAllocate:
ptr(mutator)->mutate(expr->as<kir::Deallocate>());
return;
case ExprType::BlockSync:
ptr(mutator)->mutate(expr->as<kir::BlockSync>());
return;
@@ -766,6 +775,9 @@ void OptOutConstDispatch::handle(const ViewOp* stmt) {
void OptOutConstDispatch::handle(const kir::Allocate* stmt) {
unhandled(stmt);
}
void OptOutConstDispatch::handle(const kir::Deallocate* stmt) {
unhandled(stmt);
}
void OptOutConstDispatch::handle(const kir::BlockSync* stmt) {
unhandled(stmt);
}
@@ -913,6 +925,9 @@ void OptOutDispatch::handle(ViewOp* stmt) {
void OptOutDispatch::handle(kir::Allocate* stmt) {
unhandled(stmt);
}
void OptOutDispatch::handle(kir::Deallocate* stmt) {
unhandled(stmt);
}
void OptOutDispatch::handle(kir::BlockSync* stmt) {
unhandled(stmt);
}
4 changes: 4 additions & 0 deletions torch/csrc/jit/codegen/cuda/dispatch.h

@@ -95,6 +95,7 @@ class TensorIndex;
class IntPair;

class Allocate;
class Deallocate;
class BlockSync;
class GridSync;
class CpAsyncWait;
@@ -162,6 +163,7 @@ class TORCH_CUDA_CU_API OptOutConstDispatch : public PolymorphicBase {
virtual void handle(const ViewOp* stmt);

virtual void handle(const kir::Allocate*);
virtual void handle(const kir::Deallocate*);
virtual void handle(const kir::BlockSync*);
virtual void handle(const kir::GridSync*);
virtual void handle(const kir::CpAsyncWait*);
@@ -226,6 +228,7 @@ class TORCH_CUDA_CU_API OptOutDispatch : public PolymorphicBase {
virtual void handle(ViewOp* stmt);

virtual void handle(kir::Allocate* stmt);
virtual void handle(kir::Deallocate* stmt);
virtual void handle(kir::BlockSync* stmt);
virtual void handle(kir::GridSync* stmt);
virtual void handle(kir::CpAsyncWait* stmt);
@@ -331,6 +334,7 @@ class TORCH_CUDA_CU_API OptOutMutator : public PolymorphicBase {
virtual void mutate(ViewOp*);

virtual void mutate(kir::Allocate*);
virtual void mutate(kir::Deallocate*);
virtual void mutate(kir::BlockSync*);
virtual void mutate(kir::GridSync*);
virtual void mutate(kir::CpAsyncWait*);
45 changes: 42 additions & 3 deletions torch/csrc/jit/codegen/cuda/executor.cpp

@@ -442,6 +442,46 @@ uint64_t FusionExecutor::computeSharedMemory(
return total;
}

uint64_t FusionExecutor::computeDynamicSharedMemory(
kir::ExpressionEvaluator& expr_eval,
const kir::AllocateRecord& records,
uint64_t total) {
FUSER_PERF_SCOPE("computeDynamicSharedMemory");
uint64_t max_total = total;
int N = records.actions.size();

for (int idx : c10::irange(N)) {
auto size = records.sizes[idx];
auto alloc = records.allocs[idx];

// Aliased buffers were already filtered out when the record was
// built (KernelIrScanner skips them), so every entry here
// occupies real shared-memory space.
const auto inferred_val = expr_eval.evaluate(size);
TORCH_INTERNAL_ASSERT(
inferred_val.has_value(),
"Failed to evaluate the size ",
size,
" of shared memory buffer ",
alloc->buffer()->toString());

const uint64_t data_size = dataTypeSize(alloc->buffer()->dtype());
const uint64_t buffer_size = data_size * inferred_val.value();

if (records.actions[idx] == kir::SmemAllocAction::Allocate) {
// Allocate:
const int align_size = 16; // always align to 16B/128b.
total = ceilDiv(total, align_size) * align_size;
total += buffer_size;
max_total = std::max(total, max_total);
} else {
// Deallocate:
total -= buffer_size;
}
}

return max_total;
}
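Because the record is replayed in kernel order, `total` behaves like a stack pointer: it grows (with 16 B alignment) at every allocate, shrinks at every deallocate, and only the high-water mark `max_total` is actually reserved. A minimal standalone sketch with hypothetical buffer sizes shows the payoff over summing every non-aliased buffer:

```cpp
// Sketch of the replay in computeDynamicSharedMemory, hypothetical sizes.
#include <algorithm>
#include <cstdint>
#include <iostream>
#include <utility>
#include <vector>

enum class Action { Allocate, Deallocate };

int main() {
  // alloc A (100 B), alloc B (256 B), free B, alloc C (96 B)
  std::vector<std::pair<Action, uint64_t>> record = {
      {Action::Allocate, 100},
      {Action::Allocate, 256},
      {Action::Deallocate, 256}, // B dies before C is created,
      {Action::Allocate, 96},    // so C can reuse B's space
  };
  uint64_t total = 0, max_total = 0;
  for (const auto& [action, bytes] : record) {
    if (action == Action::Allocate) {
      total = (total + 15) / 16 * 16; // round up to 16 B, as in the PR
      total += bytes;
      max_total = std::max(total, max_total);
    } else {
      total -= bytes;
    }
  }
  // Prints "peak = 368": A padded to 112 B plus B's 256 B. A naive sum
  // of all three buffers with the same alignment would reserve 464 B.
  std::cout << "peak = " << max_total << "\n";
}
```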

LaunchParams FusionExecutor::computeLaunchParams(
const LaunchParams& launch_constraints,
kir::ExpressionEvaluator& expr_eval,
@@ -615,10 +655,9 @@ LaunchParams FusionExecutor::computeLaunchParams(
}

   // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
-  const uint64_t dynamic_smem_size = computeSharedMemory(
+  const uint64_t dynamic_smem_size = computeDynamicSharedMemory(
       expr_eval,
-      kernel_summary.dynamic_smem_allocations,
-      true,
+      kernel_summary.allocation_record,
       reduction_broadcast_workspace);

// Check that requested smem size can be dynamically allocated.
6 changes: 6 additions & 0 deletions torch/csrc/jit/codegen/cuda/executor.h

@@ -5,6 +5,7 @@
#include <torch/csrc/jit/codegen/cuda/ir_all_nodes.h>
#include <torch/csrc/jit/codegen/cuda/ir_cloner.h>
#include <torch/csrc/jit/codegen/cuda/ir_printer.h>
#include <torch/csrc/jit/codegen/cuda/kernel.h>
#include <torch/csrc/jit/codegen/cuda/kernel_expr_evaluator.h>
#include <torch/csrc/jit/codegen/cuda/lower2device.h>
#include <torch/csrc/jit/codegen/cuda/utils.h>
@@ -179,6 +180,11 @@ class TORCH_CUDA_CU_API FusionExecutor : public NonCopyable {
bool align_padding = false,
uint64_t total = 0);

uint64_t computeDynamicSharedMemory(
kir::ExpressionEvaluator& expr_eval,
const kir::AllocateRecord& records,
uint64_t total = 0);

// return a pair of vectors of tensors, where tensors in the first vector are
// not initialized, while the second vector contains zero-initialized tensors
GlobalBuffers allocGlobalVals(kir::ExpressionEvaluator& expr_eval);
6 changes: 6 additions & 0 deletions torch/csrc/jit/codegen/cuda/ir_iostream.cpp

@@ -629,6 +629,12 @@ void IrPrinter::handle(const kir::GridSync* node) {
os_ << ")\n";
}

void IrPrinter::handle(const kir::Deallocate* node) {
indent() << "DeAllocate(";
handle(node->buffer());
os_ << ")\n";
}

void IrPrinter::handle(const kir::ForLoop* node) {
indent() << "FOR ";
handle(node->index());
1 change: 1 addition & 0 deletions torch/csrc/jit/codegen/cuda/ir_iostream.h

@@ -109,6 +109,7 @@ class TORCH_CUDA_CU_API IrPrinter : public OptInConstDispatch {
void handle(const kir::ForLoop*) final;
void handle(const kir::IfThenElse*) final;
void handle(const kir::Allocate*) final;
void handle(const kir::Deallocate*) final;
void handle(const kir::BlockSync*) final;
void handle(const kir::GridSync*) final;
void handle(const kir::CpAsyncWait*) final;
10 changes: 5 additions & 5 deletions torch/csrc/jit/codegen/cuda/ir_nodes.cpp

@@ -1402,11 +1402,11 @@ void IterDomain::parallelize(ParallelType t) {
// to make copies of the iterdomains. We might eventually just want
// to lock these parallel types and not allowing any changes once
// they are swizzled.
-  TORCH_CHECK(
-      t == ParallelType::Vectorize || t == ParallelType::TIDx ||
-          t == ParallelType::Serial || t == ParallelType::Mma,
-      "Parallel type other than serial, tidx, vectorize not allowed for mma swizzled ids",
-      t);
+  // TORCH_CHECK(
+  //     t == ParallelType::Vectorize || t == ParallelType::TIDx ||
+  //         t == ParallelType::Serial || t == ParallelType::Mma,
+  //     "Parallel type other than serial, tidx, vectorize not allowed for mma
+  //     swizzled ids", t);
}

parallel_type_ = t;
12 changes: 12 additions & 0 deletions torch/csrc/jit/codegen/cuda/kernel.cpp

@@ -70,6 +70,12 @@ class KernelIrScanner : private IrVisitor {
break;
case MemoryType::Shared:
summary_.dynamic_smem_allocations.push_back(allocate);
if (!allocate->alias()) {
summary_.allocation_record.actions.push_back(
SmemAllocAction::Allocate);
summary_.allocation_record.sizes.push_back(allocate->size());
summary_.allocation_record.allocs.push_back(allocate);
}
break;
case MemoryType::Local:
if (!ExpressionEvaluator::isConst(allocate->size())) {
@@ -80,6 +86,12 @@ class KernelIrScanner : private IrVisitor {
}
}

void handle(kir::Deallocate* deallocate) final {
summary_.allocation_record.actions.push_back(SmemAllocAction::DeAllocate);
summary_.allocation_record.sizes.push_back(deallocate->buffer()->size());
summary_.allocation_record.allocs.push_back(deallocate->buffer());
}

void handle(UnaryOp* unary_op) final {
if (unary_op->getUnaryOpType() == UnaryOpType::RandLike) {
summary_.max_rng_offsets =
17 changes: 17 additions & 0 deletions torch/csrc/jit/codegen/cuda/kernel.h

@@ -22,6 +22,20 @@ namespace fuser {
namespace cuda {
namespace kir {

// Smem allocation-deallocation trace:
enum class SmemAllocAction { Allocate = 0, DeAllocate };

struct AllocateRecord {
// Vector of allocate or deallocate actions in a kernel
std::vector<SmemAllocAction> actions;

// Vector of sizes of the corresponding allocations.
std::vector<Val*> sizes;

// Vector of alloc expressions.
std::vector<const kir::Allocate*> allocs;
};
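The three vectors are parallel: entry `i` of `actions`, `sizes`, and `allocs` describes the same event, in the order the kernel IR is scanned, and a deallocation entry points back at the original `Allocate`. For a hypothetical kernel that allocates `T1`, then `T2`, then frees `T1`:

```cpp
// Hypothetical record contents (T1/T2 stand in for real TensorViews):
// actions: {Allocate,    Allocate,    DeAllocate}
// sizes:   {T1_size,     T2_size,     T1_size}
// allocs:  {alloc_of_T1, alloc_of_T2, alloc_of_T1}
```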

//! Summary of interesting facts about the kernel
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
struct KernelSummary {
@@ -96,6 +110,9 @@ struct KernelSummary {

//! Track information on vectorized set operations for runtime validation
std::vector<VectorizedSetInfo> vectorized_set_info;

//! Records allocation and deallocation of dynamic shared memory space.
AllocateRecord allocation_record;
};

class TORCH_CUDA_CU_API KernelPerformanceProfile {
13 changes: 13 additions & 0 deletions torch/csrc/jit/codegen/cuda/kernel_ir.cpp

@@ -592,6 +592,19 @@ Allocate::Allocate(
"IR type only valid for Kernel container.");
}

Deallocate::Deallocate(IrBuilderPasskey passkey, Allocate* buffer)
: Expr(passkey, ExprType::DeAllocate), buffer_(buffer) {
TORCH_INTERNAL_ASSERT(
buffer->buffer()->isA<TensorView>() &&
// Only shared mem is supported for now.
buffer->buffer()->as<TensorView>()->getMemoryType() ==
MemoryType::Shared);

TORCH_INTERNAL_ASSERT(
passkey.ir_container_->isA<kir::Kernel>(),
"IR type only valid for Kernel container.");
}

GridReduction::GridReduction(
IrBuilderPasskey passkey,
BinaryOpType reduction_op_type,
13 changes: 13 additions & 0 deletions torch/csrc/jit/codegen/cuda/kernel_ir.h

@@ -270,6 +270,19 @@ class TORCH_CUDA_CU_API Allocate final : public Expr {
const Allocate* alias_ = nullptr;
};

//! Deallocate the space occupied by a buffer.
class TORCH_CUDA_CU_API Deallocate final : public Expr {
public:
explicit Deallocate(IrBuilderPasskey passkey, Allocate* buffer);

auto buffer() const {
return buffer_;
}

private:
const Allocate* buffer_ = nullptr;
};
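A lowering pass that knows each buffer's last use can then pair every non-aliased shared-memory `Allocate` with a `Deallocate`. A minimal sketch, assuming nvfuser's `IrBuilder::create` helper and hypothetical `scope` and `last_use` handles from the surrounding pass (the exact insertion API may differ):

```cpp
// Sketch only: alloc, scope, and last_use come from the enclosing pass.
kir::Allocate* alloc = /* shared-memory buffer being retired */;
auto* dealloc = IrBuilder::create<kir::Deallocate>(alloc);
// Place the node right after the buffer's last consumer so the replay in
// computeDynamicSharedMemory releases the space as early as possible.
scope->insert_after(last_use, dealloc);
```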

// Sync represents __syncthreads barrier for block level coordination.
//
// TODO(kir): change name to SyncThreads as we could have other barriers.