Support FP8 grouped GEMM with cudagraph (#3373)
Summary:

X-link: facebookresearch/FBGEMM#463

Enable cudagraph support for FP8 grouped GEMM

Compared to the cudagraph support added for the CK grouped GEMM in D65634843, this case is considerably more challenging: the CUTLASS kernel arguments involve several pointer arrays and alignment-sensitive buffers, which now have to be assembled in device memory by a small argument-setup kernel (instead of being staged in pinned host memory and copied over), so that the whole sequence can be captured and replayed by a CUDA graph.
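
The core of the change is visible in the first diff hunk below: a small set_kernel_args_kernel writes each group's pointers, problem shape, and strides into device-resident tensors on the current stream, so the writes themselves become nodes in the captured graph. A rough sketch of that pattern (hypothetical names such as set_group_args and arg_buf; not the actual FBGEMM kernel, which also records problem shapes and strides):

#include <cstdint>

// Hypothetical sketch of the device-side argument-setup pattern; names and
// buffer layout are illustrative only. One single-thread launch per group
// writes that group's raw device pointers into a device-resident argument
// buffer. Because the writes run as kernels on the capture stream, they are
// recorded into the CUDA graph and re-executed on every replay; no host-side
// pointer staging or cudaMemcpyAsync is required.
__global__ void set_group_args(
    int64_t* arg_buf,  // device buffer laid out as [A ptrs | B ptrs], one slot per group
    int64_t xq_ptr,    // this group's FP8 activation tensor, cast to an integer
    int64_t wq_ptr,    // this group's FP8 weight tensor, cast to an integer
    int group_index,
    int group_count) {
  if (blockIdx.x == 0 && threadIdx.x == 0) {
    arg_buf[group_index] = xq_ptr;                // A-operand pointer slot
    arg_buf[group_count + group_index] = wq_ptr;  // B-operand pointer slot
  }
}

// Host side, per group i of G, on the current stream:
//   set_group_args<<<1, 1, 0, stream>>>(arg_buf, (int64_t)xq_i, (int64_t)wq_i, i, G);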

Differential Revision: D65864972
jiawenliu64 authored and facebook-github-bot committed Nov 14, 2024
1 parent 9b4b04b commit 1c3720a
Showing 2 changed files with 148 additions and 92 deletions.
@@ -108,6 +108,80 @@ struct GroupedGemmConfigs {
};
} // namespace GroupedGemmArgs

__global__ void set_kernel_args_kernel(
int64_t xq_ptr,
int64_t wq_ptr,
int64_t scale_ptr,
int64_t* input_args_ptr,
int64_t* output_args_ptr,
at::BFloat16* output_data,
int output_offset,
int xq_ptr_offset,
int wq_ptr_offset,
int scale_ptr_offset,
int problem_shape_buf_offset,
int stride_buf_offset,
int stride_size,
int problem_count,
int problem_shape_size,
int group_index,
int M,
int N,
int K) {
int idx = blockIdx.x * blockDim.x + threadIdx.x;
// Each kernel annoyingly can only set the kernel args for one group.
// This could only be avoided with complicated memory management.
if (idx == 0) {
int64_t* xq_ptr_ = input_args_ptr + xq_ptr_offset;
int64_t* wq_ptr_ = input_args_ptr + wq_ptr_offset;
int64_t* scale_ptr_ = input_args_ptr + scale_ptr_offset;
uint8_t* problem_shape_buf =
reinterpret_cast<uint8_t*>(input_args_ptr + problem_shape_buf_offset);
uint8_t* stride_buf =
reinterpret_cast<uint8_t*>(input_args_ptr + stride_buf_offset);

GroupedGemmArgs::ProblemShape::UnderlyingProblemShape* problem_shape_ptr =
reinterpret_cast<
GroupedGemmArgs::ProblemShape::UnderlyingProblemShape*>(
problem_shape_buf);
// Pass dummy configs to get Stride structure
GroupedGemmArgs::GroupedGemmConfigs<128, 128, 128, 2, 1, 1, true>::
StrideInputA* stride_input_A_ptr = reinterpret_cast<
GroupedGemmArgs::GroupedGemmConfigs<128, 128, 128, 2, 1, 1, true>::
StrideInputA*>(stride_buf);
GroupedGemmArgs::GroupedGemmConfigs<128, 128, 128, 2, 1, 1, true>::
StrideInputB* stride_input_B_ptr = reinterpret_cast<
GroupedGemmArgs::GroupedGemmConfigs<128, 128, 128, 2, 1, 1, true>::
StrideInputB*>(stride_buf + stride_size);
GroupedGemmArgs::GroupedGemmConfigs<128, 128, 128, 2, 1, 1, true>::
StrideOutput* stride_output_ptr = reinterpret_cast<
GroupedGemmArgs::GroupedGemmConfigs<128, 128, 128, 2, 1, 1, true>::
StrideOutput*>(stride_buf + (stride_size * 2));

output_args_ptr[group_index] =
reinterpret_cast<int64_t>(output_data + output_offset);

// Write kernel arguments directly to memory.
xq_ptr_[group_index] = xq_ptr;
wq_ptr_[group_index] = wq_ptr;
scale_ptr_[group_index] = scale_ptr;
problem_shape_ptr[group_index] =
GroupedGemmArgs::ProblemShape::UnderlyingProblemShape(M, N, K);
stride_input_A_ptr[group_index] = cutlass::make_cute_packed_stride(
typename GroupedGemmArgs::
GroupedGemmConfigs<128, 128, 128, 2, 1, 1, true>::StrideInputA{},
{M, K, 1});
stride_input_B_ptr[group_index] = cutlass::make_cute_packed_stride(
typename GroupedGemmArgs::
GroupedGemmConfigs<128, 128, 128, 2, 1, 1, true>::StrideInputB{},
{N, K, 1});
stride_output_ptr[group_index] = cutlass::make_cute_packed_stride(
typename GroupedGemmArgs::
GroupedGemmConfigs<128, 128, 128, 2, 1, 1, true>::StrideOutput{},
{M, N, 1});
}
}

template <
int TB_M,
int TB_N,
@@ -150,14 +224,11 @@ std::vector<at::Tensor> f8f8bf16_grouped_impl(
// in units of elements
// (up to 16 bytes)

int64_t output_offset = 0;
int64_t total_output_size = 0;
std::vector<int64_t> output_sizes;
output_sizes.reserve(problem_count);
at::Tensor output_args = at::empty(
{problem_count},
at::TensorOptions().dtype(at::kLong).pinned_memory(true));
int64_t* output_ptr = output_args.data_ptr<int64_t>();
at::Tensor output_args =
at::empty({problem_count}, XQ[0].options().dtype(at::kLong));

const int64_t problem_shape_size = problem_count *
((int64_t)sizeof(GroupedGemmArgs::ProblemShape::UnderlyingProblemShape));
@@ -166,30 +237,14 @@ std::vector<at::Tensor> f8f8bf16_grouped_impl(

at::Tensor input_args = at::empty(
{problem_count * 3 + problem_shape_size + stride_size * 3},
at::TensorOptions().dtype(at::kLong).pinned_memory(true));
XQ[0].options().dtype(at::kLong));

int64_t* xq_ptr = input_args.data_ptr<int64_t>();
int64_t* wq_ptr =
input_args.data_ptr<int64_t>() + (problem_count * sizeof(int64_t));
int64_t* scale_ptr =
input_args.data_ptr<int64_t>() + (problem_count * 2 * sizeof(int64_t));
uint8_t* problem_shape_buf = reinterpret_cast<uint8_t*>(
input_args.data_ptr<int64_t>() + (problem_count * 3 * sizeof(int64_t)));
uint8_t* stride_buf = reinterpret_cast<uint8_t*>(
input_args.data_ptr<int64_t>() + (problem_count * 3 * sizeof(int64_t)) +
problem_shape_size);

GroupedGemmArgs::ProblemShape::UnderlyingProblemShape* problem_shape_ptr =
reinterpret_cast<GroupedGemmArgs::ProblemShape::UnderlyingProblemShape*>(
problem_shape_buf);
typename GroupedGemmConfigs::StrideInputA* stride_input_A_ptr =
reinterpret_cast<typename GroupedGemmConfigs::StrideInputA*>(stride_buf);
typename GroupedGemmConfigs::StrideInputB* stride_input_B_ptr =
reinterpret_cast<typename GroupedGemmConfigs::StrideInputB*>(
stride_buf + stride_size);
typename GroupedGemmConfigs::StrideOutput* stride_output_ptr =
reinterpret_cast<typename GroupedGemmConfigs::StrideOutput*>(
stride_buf + (stride_size * 2));
int xq_ptr_offset = 0;
int wq_ptr_offset = problem_count * sizeof(int64_t);
int scale_ptr_offset = problem_count * 2 * sizeof(int64_t);
int problem_shape_buf_offset = problem_count * 3 * sizeof(int64_t);
int stride_buf_offset =
problem_count * 3 * sizeof(int64_t) + problem_shape_size;

for (int i = 0; i < problem_count; ++i) {
const int64_t output_size = XQ[i].size(0) * WQ[i].size(0);
@@ -199,78 +254,61 @@ std::vector<at::Tensor> f8f8bf16_grouped_impl(

at::Tensor output_tensor =
at::empty(total_output_size, XQ[0].options().dtype(at::kBFloat16));
at::BFloat16* output_data = output_tensor.data_ptr<at::BFloat16>();

int blockSize = 256;
int numBlocks = 1;
auto stream = at::cuda::getCurrentCUDAStream().stream();
int64_t output_offset = 0;

// Set arguments
for (int i = 0; i < problem_count; ++i) {
int m = XQ[i].size(0);
int n = WQ[i].size(0);
int k = XQ[i].size(1);
TORCH_CHECK_EQ(WQ[i].size(1), k);
int M = XQ[i].size(0);
int N = WQ[i].size(0);
int K = XQ[i].size(1);
TORCH_CHECK_EQ(WQ[i].size(1), K);
set_kernel_args_kernel<<<numBlocks, blockSize, 0, stream>>>(
reinterpret_cast<int64_t>(XQ[i].data_ptr<at::Float8_e4m3fn>()),
reinterpret_cast<int64_t>(WQ[i].data_ptr<at::Float8_e4m3fn>()),
reinterpret_cast<int64_t>(
scale[i].data_ptr<GroupedGemmArgs::ElementAccumulator>()),
input_args.data_ptr<int64_t>(),
output_args.data_ptr<int64_t>(),
output_tensor.data_ptr<at::BFloat16>(),
output_offset,
xq_ptr_offset,
wq_ptr_offset,
scale_ptr_offset,
problem_shape_buf_offset,
stride_buf_offset,
stride_size,
problem_count,
problem_shape_size,
i,
M,
N,
K);

output_ptr[i] = reinterpret_cast<int64_t>(output_data + output_offset);
output_offset += output_sizes[i];

xq_ptr[i] = reinterpret_cast<int64_t>(XQ[i].data_ptr<at::Float8_e4m3fn>());
wq_ptr[i] = reinterpret_cast<int64_t>(WQ[i].data_ptr<at::Float8_e4m3fn>());
scale_ptr[i] = reinterpret_cast<int64_t>(
scale[i].data_ptr<GroupedGemmArgs::ElementAccumulator>());
problem_shape_ptr[i] =
GroupedGemmArgs::ProblemShape::UnderlyingProblemShape(m, n, k);
stride_input_A_ptr[i] = cutlass::make_cute_packed_stride(
typename GroupedGemmConfigs::StrideInputA{}, {m, k, 1});
stride_input_B_ptr[i] = cutlass::make_cute_packed_stride(
typename GroupedGemmConfigs::StrideInputB{}, {n, k, 1});
stride_output_ptr[i] = cutlass::make_cute_packed_stride(
typename GroupedGemmConfigs::StrideOutput{}, {m, n, 1});
}

// Allocate input args memory on the GPU
size_t input_args_size = input_args.numel() * sizeof(int64_t);
at::Tensor d_input_args = at::empty(
{problem_count * 3 + problem_shape_size + stride_size * 3},
at::TensorOptions().dtype(at::kLong).device(at::kCUDA));

// Allocate output args memory on the GPU
size_t output_args_size = output_args.numel() * sizeof(int64_t);
at::Tensor d_output_args = at::empty(
{problem_count}, at::TensorOptions().dtype(at::kLong).device(at::kCUDA));

// Copy data from CPU to GPU asynchronously
cudaMemcpyAsync(
d_input_args.data_ptr(),
input_args.data_ptr<int64_t>(),
input_args_size,
cudaMemcpyHostToDevice,
at::cuda::getCurrentCUDAStream());

cudaMemcpyAsync(
d_output_args.data_ptr(),
output_args.data_ptr<int64_t>(),
output_args_size,
cudaMemcpyHostToDevice,
at::cuda::getCurrentCUDAStream());

output_ptr = output_args.data_ptr<int64_t>();
xq_ptr = input_args.data_ptr<int64_t>();
wq_ptr = input_args.data_ptr<int64_t>() + (problem_count * sizeof(int64_t));
scale_ptr =
input_args.data_ptr<int64_t>() + (problem_count * 2 * sizeof(int64_t));

problem_shape_buf = reinterpret_cast<uint8_t*>(
input_args.data_ptr<int64_t>() + (problem_count * 3 * sizeof(int64_t)));
problem_shape_ptr =
int64_t* output_ptr = output_args.data_ptr<int64_t>();
int64_t* xq_ptr = input_args.data_ptr<int64_t>() + xq_ptr_offset;
int64_t* wq_ptr = input_args.data_ptr<int64_t>() + wq_ptr_offset;
int64_t* scale_ptr = input_args.data_ptr<int64_t>() + scale_ptr_offset;
uint8_t* problem_shape_buf = reinterpret_cast<uint8_t*>(
input_args.data_ptr<int64_t>() + problem_shape_buf_offset);
uint8_t* stride_buf = reinterpret_cast<uint8_t*>(
input_args.data_ptr<int64_t>() + stride_buf_offset);

GroupedGemmArgs::ProblemShape::UnderlyingProblemShape* problem_shape_ptr =
reinterpret_cast<GroupedGemmArgs::ProblemShape::UnderlyingProblemShape*>(
problem_shape_buf);
stride_buf = reinterpret_cast<uint8_t*>(
input_args.data_ptr<int64_t>() + (problem_count * 3 * sizeof(int64_t)) +
problem_shape_size);
stride_input_A_ptr =
typename GroupedGemmConfigs::StrideInputA* stride_input_A_ptr =
reinterpret_cast<typename GroupedGemmConfigs::StrideInputA*>(stride_buf);
stride_input_B_ptr =
typename GroupedGemmConfigs::StrideInputB* stride_input_B_ptr =
reinterpret_cast<typename GroupedGemmConfigs::StrideInputB*>(
stride_buf + stride_size);
stride_output_ptr =
typename GroupedGemmConfigs::StrideOutput* stride_output_ptr =
reinterpret_cast<typename GroupedGemmConfigs::StrideOutput*>(
stride_buf + (stride_size * 2));

@@ -301,7 +339,8 @@ std::vector<at::Tensor> f8f8bf16_grouped_impl(
GroupedGemmConfigs::Gemm::get_workspace_size(arguments);

// Allocate workspace memory
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
at::Tensor workspace =
at::empty(workspace_size, XQ[0].options().dtype(at::kByte));

// Check the problem size is supported or not
cutlass::Status status = gemm.can_implement(arguments);
@@ -310,7 +349,8 @@ std::vector<at::Tensor> f8f8bf16_grouped_impl(
}

// Initialize CUTLASS kernel with arguments and workspace pointer
status = gemm.initialize(arguments, workspace.get());
status = gemm.initialize(
arguments, reinterpret_cast<uint8_t*>(workspace.data_ptr()));
if (status != cutlass::Status::kSuccess) {
throw std::runtime_error("cutlass cannot initialize");
}
fbgemm_gpu/experimental/gen_ai/test/quantize/quantize_test.py (20 changes: 18 additions & 2 deletions)
@@ -14,6 +14,7 @@
import fbgemm_gpu.experimental.gen_ai # noqa: F401

import torch
import triton # noqa: F401

from fbgemm_gpu.experimental.gemm.triton_gemm.fp8_gemm import (
matmul_fp8_block,
@@ -746,13 +747,15 @@ def fp8_loopover_bmm(
M=st.sampled_from([2048, 3584]),
N=st.sampled_from([1024, 6144]),
K=st.sampled_from([512, 3584]),
use_cudagraph=st.sampled_from([True, False]),
)
def test_fp8_grouped_gemm(
self,
G: int,
M: int,
N: int,
K: int,
use_cudagraph: bool,
) -> None:
ms = torch.randint(1, (M // 64) + 1, (G,), dtype=torch.int) * 64
ns = torch.randint(1, (N // 64) + 1, (G,), dtype=torch.int) * 64
@@ -775,13 +778,26 @@ def test_fp8_grouped_gemm(
wq_group.append(wq)
scale_group.append(x_scale * w_scale)

# FP8 grouped gemm kernel
if use_cudagraph:
# warmup
torch.ops.fbgemm.f8f8bf16_grouped(xq_group, wq_group, scale_group)
# With cudagraph
g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
y_group = torch.ops.fbgemm.f8f8bf16_grouped(
xq_group, wq_group, scale_group
)
g.replay()
else:
y_group = torch.ops.fbgemm.f8f8bf16_grouped(xq_group, wq_group, scale_group)

# BF16 loopover gemm reference
y_group_ref = []
for i in range(len(x_group)):
y = torch.matmul(x_group[i], w_group[i].t())
y_group_ref.append(y)

y_group = torch.ops.fbgemm.f8f8bf16_grouped(xq_group, wq_group, scale_group)

for i in range(len(y_group)):
torch.testing.assert_close(
y_group[i], y_group_ref[i], atol=8.0e-2, rtol=8.0e-2
