[MatMul] Pipe through global memory cache ops as additional scheduler options #1978

Merged: 5 commits, Feb 9, 2023
81 changes: 15 additions & 66 deletions torch/csrc/jit/codegen/cuda/codegen.cpp
@@ -617,18 +617,25 @@ class CudaKernelGenerator : private OptOutConstDispatch {
// Utility function to emit a cp.async intrinsic
void genCpAsync(const LoadStoreOp* ldst, int vec_size) {
auto dtype = ldst->in()->getDataType().value();
bool is_cg = ldst->opType() == LoadStoreOpType::CpAsyncCg;

if (is_cg) {
indent() << "Ampere::cpAsyncCg";
} else {
indent() << "Ampere::cpAsync";
}

if (ldst->predicate() == nullptr) {
// Out of line predicate variant
indent() << "Ampere::cpAsync<" << dtype << "," << vec_size << ">("
<< genMaybeHoistedPointer(ldst->out()) << ","
<< genMaybeHoistedPointer(ldst->in()) << ");\n";
code_ << "<" << dtype << "," << vec_size << ">("
<< genMaybeHoistedPointer(ldst->out()) << ","
<< genMaybeHoistedPointer(ldst->in()) << ");\n";
} else {
// Inline predicate variant
indent() << "Ampere::cpAsync<" << dtype << "," << vec_size << ">("
<< genMaybeHoistedPointer(ldst->out()) << ","
<< genMaybeHoistedPointer(ldst->in()) << ","
<< genInline(ldst->predicate()) << ");\n";
code_ << "<" << dtype << "," << vec_size << ">("
<< genMaybeHoistedPointer(ldst->out()) << ","
<< genMaybeHoistedPointer(ldst->in()) << ","
<< genInline(ldst->predicate()) << ");\n";
}
}

@@ -672,26 +679,6 @@ class CudaKernelGenerator : private OptOutConstDispatch {
code_ << "(" << index1 << " == " << index2 << ");\n";
}

void handle(const FullOp* fop) final {
indent() << gen(fop->output(0)) << " = (" << fop->dtype() << ")"
<< gen(fop->getFillValue()) << ";\n";
}

void handle(const ARangeOp* aop) final {
auto index =
genTensorIndex(aop->getLinearLogicalIndex()->as<kir::TensorIndex>());
indent() << gen(aop->output(0)) << " = arange<" << aop->dtype() << ">";
code_ << "(" << index << ", " << gen(aop->start()) << ", "
<< gen(aop->step()) << ");\n";
}

void handle(const EyeOp* aop) final {
auto index1 = gen(aop->getIndex1());
auto index2 = gen(aop->getIndex2());
indent() << gen(aop->output(0)) << " = (" << aop->dtype() << ")";
code_ << "(" << index1 << " == " << index2 << ");\n";
}

void handle(const UnaryOp* uop) final {
bool is_vector_op = false;
size_t vector_word_size = 1;
@@ -951,45 +938,6 @@ class CudaKernelGenerator : private OptOutConstDispatch {
code_ << ");\n";
}

void handle(const RNGOp* rop) final {
// TODO: TORCH_INTERNAL_ASSERT that the scheduler correctly creates an
// innermost ID of size 4 (float) or size 2 (double)?
auto index = genTensorIndex(rop->getPhiloxIndex()->as<kir::TensorIndex>());
int multiple = rop->dtype() == DataType::Double ? 2 : 4;
indent() << "nvfuser_index_t linear_index" << rop->name() << " = " << index
<< ";\n";
indent() << "nvfuser_index_t rng_subseq" << rop->name() << " = linear_index"
<< rop->name() << " / " << multiple << ";\n";
indent() << "nvfuser_index_t rng_component" << rop->name()
<< " = linear_index" << rop->name() << " % " << multiple << ";\n";
indent() << "nvfuser_index_t rng_offset" << rop->name() << " = "
<< rop->getRNGOffset() << ";\n";
indent() << "if (rng_subseq != rng_subseq" << rop->name()
<< " || rng_offset != rng_offset" << rop->name() << ") {\n";
indent() << " rng_result = philox(philox_args.seed_, rng_subseq"
<< rop->name() << ", philox_offset / 4 + rng_offset" << rop->name()
<< ");\n";
indent() << " rng_subseq = rng_subseq" << rop->name() << ";\n";
indent() << " rng_offset = rng_offset" << rop->name() << ";\n";
indent() << "}\n";
auto op_type = rop->getRNGOpType();
indent() << gen(rop->output(0)) << " = " << op_type;
if (needFloatSuffix(op_type) && rop->dtype() == DataType::Float) {
code_ << "f";
}
code_ << "(rng_result, rng_component" << rop->name();
switch (op_type) {
case RNGOpType::UniformRange: {
auto parameters = rop->getParameters();
TORCH_INTERNAL_ASSERT(parameters.size() == 2);
code_ << ", " << gen(parameters[0]) << ", " << gen(parameters[1]);
break;
}
default:;
}
code_ << ");\n";
}

std::string genBinaryOp(
BinaryOpType op_type,
DataType data_type,
@@ -1492,6 +1440,7 @@ class CudaKernelGenerator : private OptOutConstDispatch {
genLdMatrix(ldst, vector_word_size);
break;
case LoadStoreOpType::CpAsync:
case LoadStoreOpType::CpAsyncCg:
genCpAsync(ldst, vector_word_size);
break;
default:
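For illustration, this is roughly what the updated genCpAsync emits once the op type is CpAsyncCg. It is a hand-written sketch, not output produced by this diff, assuming a 16-byte vectorized copy of half-precision data; the names T2_smem, T0_gmem, and pred are hypothetical:

// Sketch of generated kernel code (dtype = __half, vec_size = 8; names hypothetical).
Ampere::cpAsync<__half, 8>(T2_smem, T0_gmem);         // existing cpAsync path
Ampere::cpAsyncCg<__half, 8>(T2_smem, T0_gmem);       // new cp.async.cg path, bypasses L1
Ampere::cpAsyncCg<__half, 8>(T2_smem, T0_gmem, pred); // inline-predicate variant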
3 changes: 2 additions & 1 deletion torch/csrc/jit/codegen/cuda/lower_utils.cpp
@@ -178,7 +178,8 @@ bool isLdMatrixOp(const Expr* expr) {

bool isCpAsyncOp(const Expr* expr) {
if (auto ldst = dynamic_cast<const LoadStoreOp*>(expr)) {
return ldst->opType() == LoadStoreOpType::CpAsync;
return ldst->opType() == LoadStoreOpType::CpAsync ||
ldst->opType() == LoadStoreOpType::CpAsyncCg;
}
return false;
}
1 change: 1 addition & 0 deletions torch/csrc/jit/codegen/cuda/lower_validation.cpp
@@ -1017,6 +1017,7 @@ void validateArchMemoryOp(LoadStoreOp* ldst) {
validateLdMatrixOutput(ldst->out()->as<TensorView>());
return;
case LoadStoreOpType::CpAsync:
case LoadStoreOpType::CpAsyncCg:
validateMinimumArch(8, 0);
return;
default:
92 changes: 92 additions & 0 deletions torch/csrc/jit/codegen/cuda/runtime/memory.cu
@@ -271,6 +271,98 @@ DEVICE_INLINE void cpAsync(
"r"((int)predicate));
}

// Asynchronous global to SMEM load, cache-global (.cg) variant, i.e. it
// bypasses L1 and caches only at the L2 level. For details see:
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#cache-operators
// Not guaranteed to be completed until cpAsyncBarrier() is called.
template <typename dtype, int len>
DEVICE_INLINE void cpAsyncCg(void* smem_ptr, void const* gmem_ptr) {
unsigned smem_addr = util::toSmem(smem_ptr);
constexpr int byte_size = sizeof(dtype) * len;

static_assert(
byte_size == 4 || byte_size == 8 || byte_size == 16,
"cp_async : unsupported byte size");

asm volatile(
"cp.async.cg.shared.global [%0], [%1], %2;\n" ::"r"(smem_addr),
"l"(gmem_ptr),
"n"(byte_size));
}

// Asynchronous global to SMEM load, cache-global (.cg) variant with an inline
// predicate; not guaranteed to be completed until cpAsyncBarrier() is called.
template <typename dtype, int len>
DEVICE_INLINE void cpAsyncCg(
void* smem_ptr,
void const* gmem_ptr,
bool predicate) {
unsigned smem_addr = util::toSmem(smem_ptr);
constexpr int byte_size = sizeof(dtype) * len;

static_assert(
byte_size == 4 || byte_size == 8 || byte_size == 16,
"cp_async : unsupported byte size");

asm volatile(
"{\n"
" .reg .pred p;\n"
" setp.ne.b32 p, %3, 0;\n"
"@p cp.async.cg.shared.global [%0], [%1], %2;\n"
"}\n" ::"r"(smem_addr),
"l"(gmem_ptr),
"n"(byte_size),
"r"((int)predicate));
}

// cp.async.cg
// This is the variant that supports lifted indexing.
template <typename dtype, int len>
DEVICE_INLINE void cpAsyncCg(
nvfuser_index_t smem_index,
unsigned smem_addr,
nvfuser_index_t gmem_index,
DataPointer& gmem_ptr) {
constexpr int byte_size = sizeof(dtype) * len;

static_assert(
byte_size == 4 || byte_size == 8 || byte_size == 16,
"cp_async : unsupported byte size");

asm volatile(
"cp.async.cg.shared.global [%0], [%1], %2;\n" ::"r"(
smem_addr + (unsigned)smem_index),
"l"(gmem_ptr + gmem_index),
"n"(byte_size));
}

// cp.async.cg
// This is the variant that supports lifted indexing, with the predicate inlined.
template <typename dtype, int len>
DEVICE_INLINE void cpAsyncCg(
nvfuser_index_t smem_index,
unsigned smem_addr,
nvfuser_index_t gmem_index,
DataPointer& gmem_ptr,
bool predicate) {
constexpr int byte_size = sizeof(dtype) * len;

static_assert(
byte_size == 4 || byte_size == 8 || byte_size == 16,
"cp_async : unsupported byte size");

asm volatile(
"{\n"
" .reg .pred p;\n"
" setp.ne.b32 p, %3, 0;\n"
"@p cp.async.cg.shared.global [%0], [%1], %2;\n"
"}\n" ::"r"(smem_addr + (unsigned)smem_index),
"l"(gmem_ptr + gmem_index),
"n"(byte_size),
"r"((int)predicate));
}

// TODO: Might have a different category of sync if we want to build out this:
DEVICE_INLINE void cpAsyncBarrier() {
asm volatile("cp.async.wait_all;");
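As a standalone illustration of how these primitives pair with the barrier, below is a minimal hand-written kernel sketch. It assumes the helpers above sit in the Ampere namespace referenced by the generated kernels, an sm_80+ device, and purely illustrative block and buffer sizes; it is not part of this diff:

__global__ void stage_tile(const float* gmem_in, float* gmem_out) {
  // 128 threads each issue one 16-byte cp.async.cg copy (4 floats).
  __shared__ float smem[128 * 4];
  int tid = threadIdx.x;
  Ampere::cpAsyncCg<float, 4>(&smem[tid * 4], &gmem_in[tid * 4]);
  // cp.async.wait_all: wait for this thread's outstanding async copies.
  Ampere::cpAsyncBarrier();
  // Make the freshly staged shared memory visible to the whole block.
  __syncthreads();
  for (int i = 0; i < 4; ++i) {
    gmem_out[tid * 4 + i] = smem[tid * 4 + i];
  }
}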
2 changes: 1 addition & 1 deletion torch/csrc/jit/codegen/cuda/scheduler/matmul.cpp
@@ -370,7 +370,7 @@ void scheduleMatmul(
// Use cp.async as requested in scheduler params.
c10::optional<LoadStoreOpType> load_op = c10::nullopt;
if (params.async_gmem_load_operands) {
load_op = LoadStoreOpType::CpAsync;
load_op = LoadStoreOpType::CpAsyncCg;
}

acw_smem = ar->cacheAfter(load_op);
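The PR title frames the cache operator as an additional scheduler option, whereas this hunk simply switches the async path to the .cg variant. Purely as a sketch of how the choice could be surfaced later, with the use_cg knob below being hypothetical (it does not exist in this PR):

// Hypothetical extension, not in this diff: pick the cache operator for the
// async global->smem operand loads instead of hard-coding CpAsyncCg.
const bool use_cg = true; // hypothetical scheduler knob
c10::optional<LoadStoreOpType> load_op = c10::nullopt;
if (params.async_gmem_load_operands) {
  // Operand tiles are staged through smem and read from gmem only once, so
  // the .cg variant keeps them out of L1; plain CpAsync also caches in L1.
  load_op = use_cg ? LoadStoreOpType::CpAsyncCg : LoadStoreOpType::CpAsync;
}
acw_smem = ar->cacheAfter(load_op);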
50 changes: 50 additions & 0 deletions torch/csrc/jit/codegen/cuda/test/test_gpu_tensorcore.cpp
@@ -3042,6 +3042,56 @@ TEST_F(NVFuserTest, FusionTuringMatmulLargeLoad_CUDA) {
}
}

// Matmul test on Ampere using ldmatrix.x4 to load operands, with a large K
// dimension.
TEST_F(NVFuserTest, FusionAmpereMatmulLargeLoadLargeK_CUDA) {
// Keep multiples of 8 to keep vectorizable.
int M = 504, N = 136, K = 2048;
for (auto layout : kAllSupportedLayout) {
Fusion fusion;
FusionGuard fg(&fusion);
auto tv0 = makeContigTensor(2, DataType::Half);
auto tv1 = makeContigTensor(2, DataType::Half);

fusion.addInput(tv0);
fusion.addInput(tv1);

auto tv2 = matmul(tv0, tv1, layout);

fusion.addOutput(tv2);

MatMulTileOptions gemm_tile;
gemm_tile.cta_tile = GemmTile(128, 128, 64);
gemm_tile.warp_tile = GemmTile(64, 64, 64);
gemm_tile.instruction_tile = GemmTile(16, 16, 16);

auto mma_builder =
MmaBuilder(MmaOptions::MacroType::Ampere_16_16_16, gemm_tile)
.layout(layout);

MatmulParam params(mma_builder);
params.tile_sizes = gemm_tile;
params.async_gmem_load_operands = true;
params.double_buffer_options.double_buffer_smem_write = true;
params.double_buffer_options.double_buffer_smem_read = true;
params.double_buffer_options.smem_double_buffer_stage = 3;
scheduleMatmul(tv2, tv0, tv1, params);

at::manual_seed(0);
auto inputs = fp16MatmulAtInput(M, N, K, layout);

FusionExecutor fe;
NVFUSER_TEST_CUDA_ARCH_COMPILE_CHECK(
8,
0,
fe.compileFusion(
&fusion, {inputs.first, inputs.second}, LaunchParams()));
auto cg_outputs = fe.runFusion({inputs.first, inputs.second});
auto tref = atMatmul(
inputs.first.to(at::kFloat), inputs.second.to(at::kFloat), layout);
TORCH_CHECK(cg_outputs[0].allclose(tref, 0.001, 0.001));
}
}

// Small repro for the replay fix needed for non-affine
// swizzle support.
TEST_F(NVFuserTest, FusionSwizzleReplayFixRepro_CUDA) {
2 changes: 2 additions & 0 deletions torch/csrc/jit/codegen/cuda/type.cpp
@@ -858,6 +858,8 @@ static const char* load_store_type2string(LoadStoreOpType t) {
return "LdMatrixTranspose";
case LoadStoreOpType::CpAsync:
return "CpAsync";
case LoadStoreOpType::CpAsyncCg:
return "CpAsyncCg";
default:
TORCH_INTERNAL_ASSERT(false, "Unexpected parallel type");
}
2 changes: 1 addition & 1 deletion torch/csrc/jit/codegen/cuda/type.h
@@ -335,7 +335,7 @@ static constexpr std::array<IdMappingMode, 3> kIdMappingModes = {

// Used to annotate the special memory intrinsics that a loadstore
// op will be lowered to.
enum class LoadStoreOpType { LdMatrix, LdMatrixTranspose, CpAsync };
enum class LoadStoreOpType { LdMatrix, LdMatrixTranspose, CpAsync, CpAsyncCg };

// Used to label what part of the double buffered iterdomain
// a for loop is materializing.