Skip to content

Commit

Permalink
Heuristic Tuning for CK FP8 Grouped Gemm (#3356)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #3356

X-link: facebookresearch/FBGEMM#448

This diff gets our CK grouped gemm kernel with fused fp8 rowwise scaling production ready. The major improvements made are:
1. Added a ton of new kernel configurations.
2. Uses heuristic dispatch to select efficient kernels across many relevant shapes.
3. Makes this kernel compatible with cuda graphs.

Cuda graph compatibility in particular turned out to be very tricky. Grouped gemm kernels on both AMD and NV require a special kernel arguments tensor that has to be set on the GPU. Normally this would be done with a host to device copy, but that is not allowed in a cuda graph. Instead, we need to launch a kernel that directly sets the memory on device.

What makes it tricky is that no host memory is allowed, including the shapes of the tensors. We instead have to launch one kernel for each group. This does mean that there will be a bunch of extra kernel launches, but in my testing they dont seem to be expensive. The only alternative to this approach is doing comprehensive memory planning for all layers, which is what [TensorRT LLM](https://github.com/NVIDIA/TensorRT-LLM/blob/main/cpp/tensorrt_llm/kernels/mixtureOfExperts/moe_kernels.cu#L816) does.

Adding memory planning would complicate end to end integration by quite a bit, so I think this multi-kernel launch approach is a good balance of performance and simplicity.

Reviewed By: jianyuh, jiawenliu64

Differential Revision: D65634843

fbshipit-source-id: 807d05334937533e9d54c75d3bd3c6c62d78a672
  • Loading branch information
jwfromm authored and facebook-github-bot committed Nov 13, 2024
1 parent f12bfd0 commit 2146145
Show file tree
Hide file tree
Showing 76 changed files with 5,107 additions and 1,183 deletions.
42 changes: 6 additions & 36 deletions fbgemm_gpu/experimental/gen_ai/bench/profile_grouped_gemm.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,35 +13,6 @@
from .quantize_ops import FP8RowwiseGroupedGemm


grouped_kernel_registry: list[str] = [
"fp8_rowwise_grouped_128x128x16x128_16x16_4x1_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_interwave_v2",
"fp8_rowwise_grouped_128x128x32x128_32x32_2x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2",
"fp8_rowwise_grouped_128x16x32x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2",
"fp8_rowwise_grouped_128x16x32x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2",
"fp8_rowwise_grouped_128x16x32x512_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2",
"fp8_rowwise_grouped_128x32x128x128_32x32_1x2_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_interwave_v2",
"fp8_rowwise_grouped_128x32x16x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_interwave_v2",
"fp8_rowwise_grouped_128x32x64x128_32x32_1x1_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_interwave_v2",
"fp8_rowwise_grouped_128x32x64x128_32x32_1x1_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_intrawave_v2",
"fp8_rowwise_grouped_128x64x32x128_32x32_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2",
"fp8_rowwise_grouped_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_interwave_v1",
"fp8_rowwise_grouped_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3",
"fp8_rowwise_grouped_256x128x128x64_32x32_2x2_4x64x1_4x64x1_1x32x1x8_8x8x1_1x1_intrawave_v4",
"fp8_rowwise_grouped_256x128x64x128_32x32_2x1_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3",
"fp8_rowwise_grouped_256x224x256x128_16x16_7x8_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3",
"fp8_rowwise_grouped_256x256x224x128_16x16_8x7_8x32x1_8x32x1_1x64x1x4_8x8x1_2x1_intrawave_v3",
"fp8_rowwise_grouped_256x256x256x128_16x16_8x8_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3",
"fp8_rowwise_grouped_256x256x256x64_16x16_8x8_4x64x1_4x64x1_1x32x1x8_8x8x1_1x2_intrawave_v3",
"fp8_rowwise_grouped_256x256x256x64_32x32_4x4_4x64x1_4x64x1_1x32x1x8_8x8x1_1x1_intrawave_v4",
"fp8_rowwise_grouped_256x64x64x128_32x32_1x1_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3",
"fp8_rowwise_grouped_64x16x16x128_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_interwave_v2",
"fp8_rowwise_grouped_64x16x16x256_16x16_1x1_16x4x1_16x4x1_1x4x1x16_4x4x1_1x1_intrawave_v1",
"fp8_rowwise_grouped_64x16x16x512_16x16_1x1_32x2x1_32x2x1_1x16x1x4_4x4x1_1x1_interwave_v2",
"fp8_rowwise_grouped_64x16x16x512_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_interwave_v2",
"fp8_rowwise_grouped_64x16x16x64_16x16_1x1_4x16x1_4x16x1_1x16x1x4_4x4x1_1x1_interwave_v2",
]


def main(args: Any):
# Extract and format shape arguments.
M = [int(m) for m in args.M.strip().split(",")]
Expand All @@ -62,12 +33,12 @@ def main(args: Any):
quantized_vals = group_gemm_op.quantize(A, B)
# Iterate over kernels to find the most performant one.
benchmark_results = []
for kernel_name in grouped_kernel_registry:
for kernel_name in torch.ops.fbgemm.get_f8f8bf16_rowwise_grouped_kernels():
# Do a warmup run of the kernel.
output = group_gemm_op.compute(*quantized_vals, kernel_name=kernel_name)
# Benchmark this kernel implementation.
ms_runtime = group_gemm_op.benchmark(
*quantized_vals, use_cuda_graph=False, kernel_name=kernel_name
*quantized_vals, use_cuda_graph=True, kernel_name=kernel_name
)
# Compute statistics for this kernel.
tflops = 0
Expand All @@ -84,6 +55,7 @@ def main(args: Any):
/ 1e9
)
# Record results.
print(f"Kernel: {kernel_name}, ms: {ms_runtime:.4f}, TFLOPS: {tflops:.2f}")
benchmark_results.append(
{
"kernel_name": kernel_name,
Expand All @@ -92,13 +64,11 @@ def main(args: Any):
"gbps": gbps,
}
)
# Print all results.
print("Benchmark results:")
for result in benchmark_results:
print(f"Kernel: {result['kernel_name']}, TFLOPS: {result['tflops']}")
# Report best kernel.
best_kernel = min(benchmark_results, key=lambda x: x["ms_runtime"])
print(f"Best kernel for this shape: {best_kernel['kernel_name']}")
print(
f"Best kernel for this shape: {best_kernel['kernel_name']}: {best_kernel['tflops']:.2f} TFLOPS"
)

# If specified, save all results.
if args.export_csv:
Expand Down
4 changes: 2 additions & 2 deletions fbgemm_gpu/experimental/gen_ai/bench/quantize_bench.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,14 +107,14 @@ def benchmark_grouped(
B,
bench_quantize=True,
use_rotating_buffer_bench=use_rotating_buffer_bench,
use_cuda_graph=False,
use_cuda_graph=True,
)
else:
ms_runtime = quantize_op.benchmark(
*quantized_vals,
bench_quantize=False,
use_rotating_buffer_bench=use_rotating_buffer_bench,
use_cuda_graph=False,
use_cuda_graph=True,
)

# Print out results for this op.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
* LICENSE file in the root directory of this source tree.
*/

#include <cmath>
#include <cstdlib>
#include <functional>
#include <initializer_list>
Expand All @@ -19,23 +20,139 @@
#include <c10/hip/HIPStream.h>
#include <torch/torch.h>

#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_xdl_cshuffle_tile_loop.hpp"
#include "kernels/fp8_rowwise_grouped_kernel_manifest.h"

namespace fbgemm_gpu {

// Define useful types that are needed for various kernels.
using KernelArguments =
ck::tensor_operation::device::GroupedGemmTileLoopKernelArguments<2>;
using ADataType = ck::f8_t;
using BDataType = ck::f8_t;
using D0DataType = float;
using D1DataType = float;
using DsDataType = ck::Tuple<D0DataType, D1DataType>;
using EDataType = ck::bhalf_t;

RowwiseGroupedKernel rowwise_grouped_heuristic_dispatch(int M, int N, int K) {
// We use shape heuristics to find the best kernel.
// To do this, we divide by the size of M and find the best
// option within that grouping.
if (M <= 16) {
if (N < 8192 && K <= 8192) {
return fp8_rowwise_grouped_64x16x16x256_16x16_1x1_16x4x1_16x4x1_1x4x1x16_4x4x1_1x1_intrawave_v1;
}
if (K <= 8192) {
return fp8_rowwise_grouped_128x16x64x128_16x16_1x2_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2;
}
return fp8_rowwise_grouped_128x16x32x256_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2;
}
if (M <= 32) {
if (N < 8192 && K <= 8192) {
return fp8_rowwise_grouped_128x32x64x128_32x32_1x1_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_interwave_v2;
}
if (K <= 8192) {
return fp8_rowwise_grouped_128x32x128x128_32x32_1x2_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_interwave_v2;
}
return fp8_rowwise_grouped_128x32x128x128_32x32_1x2_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_intrawave_v2;
}
if (M <= 64) {
return fp8_rowwise_grouped_256x64x64x128_32x32_1x1_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3;
}
if (M <= 128) {
if (N < 8192 && K <= 8192) {
return fp8_rowwise_grouped_256x128x64x128_32x32_2x1_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3;
}
return fp8_rowwise_grouped_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3;
}
if (M <= 256) {
return fp8_rowwise_grouped_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3;
}
if (M <= 512) {
if (K <= 8192) {
return fp8_rowwise_grouped_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_interwave_v1;
}
return fp8_rowwise_grouped_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3;
}
// Default kernel for all other shapes.
return fp8_rowwise_grouped_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_interwave_v1;
}

__global__ void set_kernel_args_kernel(
KernelArguments* kernel_args,
ADataType* XQ,
BDataType* WQ,
D0DataType* w_scale,
D1DataType* x_scale,
EDataType* output,
int M,
int N,
int K) {
int idx = blockIdx.x * blockDim.x + threadIdx.x;
// Each kernel annoyingly can only set the kernel args for one group.
// This could only be avoided with complicated memory management.
if (idx == 0) {
// Write kernel arguments directly to memory.
KernelArguments kernel_group_args = {
XQ, WQ, {w_scale, x_scale}, output, M, N, K, K, K, {0, 0}, N};
kernel_args[0] = kernel_group_args;
}
}

void set_grouped_kernel_args(
at::TensorList XQ,
at::TensorList WQ,
at::TensorList x_scale,
at::TensorList w_scale,
at::Tensor kernel_args,
std::vector<at::Tensor> output) {
TORCH_CHECK(
XQ.size() == WQ.size() && XQ.size() == x_scale.size() &&
XQ.size() == w_scale.size(),
"All inputs must have the same number of groups.");
int group_count = XQ.size();
// We use the smallest reasonable block size since we effectively need only 1 thread.
int blockSize = 32;
int numBlocks = 1;
auto stream = at::cuda::getCurrentHIPStream().stream();

// Launch a kernel for each group to set kernel memory on device.
for (int i = 0; i < group_count; i++) {
int M = XQ[i].size(0);
int K = XQ[i].size(1);
int N = WQ[i].size(0);
// Launch kernel to set kernel arguments.
set_kernel_args_kernel<<<numBlocks, blockSize, 0, stream>>>(
reinterpret_cast<KernelArguments*>(
reinterpret_cast<char*>(kernel_args.data_ptr()) +
(i * sizeof(KernelArguments))),
reinterpret_cast<ADataType*>(XQ[i].data_ptr()),
reinterpret_cast<BDataType*>(WQ[i].data_ptr()),
reinterpret_cast<D0DataType*>(w_scale[i].data_ptr()),
reinterpret_cast<D1DataType*>(x_scale[i].data_ptr()),
reinterpret_cast<EDataType*>(output[i].data_ptr()),
M,
N,
K);
}
}

std::vector<at::Tensor> f8f8bf16_rowwise_grouped(
at::TensorList XQ,
at::TensorList WQ,
at::TensorList x_scale,
at::TensorList w_scale,
std::optional<at::TensorList> output = std::nullopt,
std::optional<std::vector<at::Tensor>> output = std::nullopt,
std::optional<std::string> kernel_name = std::nullopt) {
// Check that input datatypes are valid.
// First confirm that there are the same number of groups in all inputs.
TORCH_CHECK(
XQ.size() == WQ.size() && XQ.size() == x_scale.size() &&
XQ.size() == w_scale.size(),
"All inputs must have the same number of groups.");
int group_count = XQ.size();
// Iterate over inputs and check they are valid.
for (at::Tensor x : XQ) {
TORCH_CHECK(x.is_cuda() && x.is_contiguous());
Expand All @@ -58,39 +175,68 @@ std::vector<at::Tensor> f8f8bf16_rowwise_grouped(
TORCH_CHECK(ws.dtype() == at::kFloat, "Scales must be float32.");
}

// Allocate output if needed.
std::vector<at::Tensor> Y;
Y.reserve(XQ.size());
if (output.has_value()) {
TORCH_CHECK(output.value().size() == XQ.size(), "Output and input must have same number of groups.");
Y = output.value();
TORCH_CHECK(
Y.size() == group_count,
"Output and input must have same number of groups.");
// Check that output shapes are correct.
for (int i = 0; i < output.value().size(); i++) {
for (int i = 0; i < group_count; i++) {
int M = XQ[i].size(0);
int N = WQ[i].size(0);
int out_M = output.value()[i].size(0);
int out_N = output.value()[i].size(1);
TORCH_CHECK(M == out_M && N == out_N, "Output tensors do not have the expected shape.");
TORCH_CHECK(output.value()[i].dtype() == at::kBFloat16, "Output dtype must be bfloat16.");
Y.push_back(output.value()[i]);
int out_M = Y[i].size(0);
int out_N = Y[i].size(1);
TORCH_CHECK(
M == out_M && N == out_N,
"Output tensors do not have the expected shape.");
TORCH_CHECK(
Y[i].dtype() == at::kBFloat16, "Output dtype must be bfloat16.");
}
} else {
for (int i = 0; i < XQ.size(); i++) {
for (int i = 0; i < group_count; i++) {
int M = XQ[i].size(0);
int N = WQ[i].size(0);
Y.push_back(at::empty({M, N}, XQ[i].options().dtype(at::kBFloat16)));
}
}

// Prepare kernel arguments by copying them to the proper device location.
at::Tensor kernel_args = at::empty({1000}, XQ[0].options().dtype(at::kByte));
set_grouped_kernel_args(XQ, WQ, x_scale, w_scale, kernel_args, Y);

// If provided a specific kernel implementation, dispatch to it.
if (kernel_name.has_value()) {
auto it = kernel_name_map.find(kernel_name.value());
// If not found, raise an error.
TORCH_CHECK(it != kernel_name_map.end(), "Could not find kernel " + kernel_name.value());
TORCH_CHECK(
it != kernel_name_map.end(),
"Could not find kernel " + kernel_name.value());
// If found, always use requested kernel.
return it->second(XQ, WQ, x_scale, w_scale, Y);
return it->second(XQ, WQ, x_scale, w_scale, kernel_args, Y);
}
// Otherwise, use heuristics to find the best kernel options.
// We use the largest of each shape for heuristics.
int MaxM = 0;
int MaxN = 0;
int MaxK = 0;
for (int i = 0; i < group_count; i++) {
MaxM = max(MaxM, XQ[i].size(0));
MaxN = max(MaxN, WQ[i].size(0));
MaxK = max(MaxK, XQ[i].size(1));
}
RowwiseGroupedKernel selected_kernel =
rowwise_grouped_heuristic_dispatch(MaxM, MaxN, MaxK);
return selected_kernel(XQ, WQ, x_scale, w_scale, kernel_args, Y);
}

std::vector<std::string> get_f8f8bf16_rowwise_grouped_kernels() {
/* Helper function to get the names of avaialable grouped gemm kernels.*/
std::vector<std::string> kernel_names;
for (const auto& pair : kernel_name_map) {
kernel_names.push_back(pair.first);
}
return fp8_rowwise_grouped_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3(
XQ, WQ, x_scale, w_scale, Y);
return kernel_names;
}

} // namespace fbgemm_gpu
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#include "fp8_rowwise_grouped_common.h"

std::vector<at::Tensor>
fp8_rowwise_grouped_128x128x16x128_16x16_4x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_interwave_v2(
at::TensorList XQ,
at::TensorList WQ,
at::TensorList x_scale,
at::TensorList w_scale,
at::Tensor kernel_args,
std::vector<at::Tensor> Y) {
// Check if this input needs to be padded.
bool pad = false;
for (int i = 0; i < XQ.size(); i++) {
int K = XQ[i].size(1);
if (K % 128 != 0) {
pad = true;
}
}
if (pad) {
using DeviceGemmInstance = DeviceGemmHelper<
128,
128,
16,
128,
16,
16,
4,
1,
S<8, 16, 1>,
S<8, 16, 1>,
S<1, 16, 1, 8>,
S<2, 2, 1>,
1,
1,
ck::BlockGemmPipelineScheduler::Interwave,
ck::BlockGemmPipelineVersion::v2,
ck::tensor_operation::device::GemmSpecialization::KPadding>;
// Run kernel instance.
return f8f8bf16_rowwise_grouped_impl<DeviceGemmInstance>(
XQ, WQ, x_scale, w_scale, kernel_args, Y);
} else {
using DeviceGemmInstance = DeviceGemmHelper<
128,
128,
16,
128,
16,
16,
4,
1,
S<8, 16, 1>,
S<8, 16, 1>,
S<1, 16, 1, 8>,
S<2, 2, 1>,
1,
1,
ck::BlockGemmPipelineScheduler::Interwave,
ck::BlockGemmPipelineVersion::v2,
ck::tensor_operation::device::GemmSpecialization::Default>;
// Run kernel instance.
return f8f8bf16_rowwise_grouped_impl<DeviceGemmInstance>(
XQ, WQ, x_scale, w_scale, kernel_args, Y);
}
}
Loading

0 comments on commit 2146145

Please sign in to comment.