Heuristic Tuning for CK FP8 Grouped Gemm (#3356)
Summary: Pull Request resolved: #3356 X-link: facebookresearch/FBGEMM#448

This diff gets our CK grouped GEMM kernel with fused FP8 rowwise scaling production ready. The major improvements are:

1. Added a ton of new kernel configurations.
2. Uses heuristic dispatch to select efficient kernels across many relevant shapes.
3. Makes this kernel compatible with CUDA graphs.

CUDA graph compatibility in particular turned out to be very tricky. Grouped GEMM kernels on both AMD and NVIDIA require a special kernel-arguments tensor that has to be set on the GPU. Normally this would be done with a host-to-device copy, but that is not allowed inside a CUDA graph. Instead, we need to launch a kernel that sets the memory directly on device. What makes it tricky is that no host memory is allowed, including the shapes of the tensors, so we instead have to launch one kernel for each group. This does mean a bunch of extra kernel launches, but in my testing they don't seem to be expensive. The only alternative to this approach is comprehensive memory planning for all layers, which is what [TensorRT LLM](https://github.com/NVIDIA/TensorRT-LLM/blob/main/cpp/tensorrt_llm/kernels/mixtureOfExperts/moe_kernels.cu#L816) does. Adding memory planning would complicate end-to-end integration by quite a bit, so I think this multi-kernel-launch approach is a good balance of performance and simplicity.

Reviewed By: jianyuh, jiawenliu64
Differential Revision: D65634843
fbshipit-source-id: 807d05334937533e9d54c75d3bd3c6c62d78a672
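To make the per-group argument-setting trick concrete, here is a minimal HIP sketch of the idea, not the actual FBGEMM implementation: the `GroupArgs` layout and the `set_group_args`/`fill_kernel_args` names are hypothetical. A single-thread kernel writes one group's GEMM arguments directly into the device-side kernel-arguments buffer, launched once per group with pointers and shapes passed by value, so no host memory is dereferenced when the graph replays.

```cpp
#include <hip/hip_runtime.h>

// Hypothetical per-group argument record; the real layout is whatever the
// CK grouped-GEMM kernel expects to find in its kernel_args tensor.
struct GroupArgs {
  const void* a_ptr;
  const void* b_ptr;
  void* c_ptr;
  int m;
  int n;
  int k;
};

// Single-thread kernel that writes one group's arguments directly in device
// memory. All inputs arrive as launch parameters (passed by value), so this
// launch is legal inside a captured graph: no host-to-device copy is issued.
__global__ void set_group_args(
    GroupArgs* all_args, // device buffer backing the kernel_args tensor
    int group_idx,
    const void* a_ptr,
    const void* b_ptr,
    void* c_ptr,
    int m,
    int n,
    int k) {
  all_args[group_idx] = GroupArgs{a_ptr, b_ptr, c_ptr, m, n, k};
}

// Host side: one tiny launch per group. Each launch bakes that group's
// pointers and shapes into its kernel arguments, which is why no host
// memory (not even the tensor shapes) is read at graph replay time.
void fill_kernel_args(
    GroupArgs* dev_args,
    int num_groups,
    const void* const* a,
    const void* const* b,
    void* const* c,
    const int* m,
    const int* n,
    const int* k,
    hipStream_t stream) {
  for (int g = 0; g < num_groups; ++g) {
    hipLaunchKernelGGL(
        set_group_args, dim3(1), dim3(1), 0, stream,
        dev_args, g, a[g], b[g], c[g], m[g], n[g], k[g]);
  }
}
```

Since CUDA/HIP graphs already require stable tensor addresses across replays, recording these per-group writes once at capture keeps kernel_args valid on every replay without any host involvement.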
1 parent f12bfd0 · commit 2146145 · 76 changed files with 5,107 additions and 1,183 deletions.
72 changes: 72 additions & 0 deletions — ...owwise_grouped_128x128x16x128_16x16_4x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_interwave_v2.hip
@@ -0,0 +1,72 @@
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include "fp8_rowwise_grouped_common.h"

std::vector<at::Tensor>
fp8_rowwise_grouped_128x128x16x128_16x16_4x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_interwave_v2(
    at::TensorList XQ,
    at::TensorList WQ,
    at::TensorList x_scale,
    at::TensorList w_scale,
    at::Tensor kernel_args,
    std::vector<at::Tensor> Y) {
  // Check whether any group's K dimension needs padding to this
  // configuration's 128-wide K tile.
  bool pad = false;
  for (int i = 0; i < XQ.size(); i++) {
    int K = XQ[i].size(1);
    if (K % 128 != 0) {
      pad = true;
    }
  }
  if (pad) {
    // Same tile configuration as below, but with KPadding enabled so K
    // dimensions that are not multiples of 128 are handled correctly.
    using DeviceGemmInstance = DeviceGemmHelper<
        128,
        128,
        16,
        128,
        16,
        16,
        4,
        1,
        S<8, 16, 1>,
        S<8, 16, 1>,
        S<1, 16, 1, 8>,
        S<2, 2, 1>,
        1,
        1,
        ck::BlockGemmPipelineScheduler::Interwave,
        ck::BlockGemmPipelineVersion::v2,
        ck::tensor_operation::device::GemmSpecialization::KPadding>;
    // Run kernel instance.
    return f8f8bf16_rowwise_grouped_impl<DeviceGemmInstance>(
        XQ, WQ, x_scale, w_scale, kernel_args, Y);
  } else {
    using DeviceGemmInstance = DeviceGemmHelper<
        128,
        128,
        16,
        128,
        16,
        16,
        4,
        1,
        S<8, 16, 1>,
        S<8, 16, 1>,
        S<1, 16, 1, 8>,
        S<2, 2, 1>,
        1,
        1,
        ck::BlockGemmPipelineScheduler::Interwave,
        ck::BlockGemmPipelineVersion::v2,
        ck::tensor_operation::device::GemmSpecialization::Default>;
    // Run kernel instance.
    return f8f8bf16_rowwise_grouped_impl<DeviceGemmInstance>(
        XQ, WQ, x_scale, w_scale, kernel_args, Y);
  }
}
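The heuristic dispatch mentioned in the summary consumes per-shape instances like the one above. As a rough illustration only — the threshold and selection logic below are invented, not FBGEMM's actual heuristic — a shape-based dispatcher might look like this:

```cpp
#include <algorithm>
#include <vector>
#include <ATen/ATen.h>

// Function-pointer type matching the grouped-kernel instances in this diff.
using GroupedKernel = std::vector<at::Tensor> (*)(
    at::TensorList,
    at::TensorList,
    at::TensorList,
    at::TensorList,
    at::Tensor,
    std::vector<at::Tensor>);

// Hypothetical shape-based heuristic. Skinny problems (small M across all
// groups) map to the narrow 16-wide N tile defined above; the 64-row cutoff
// is made up for illustration.
GroupedKernel select_grouped_kernel(at::TensorList XQ) {
  int64_t max_m = 0;
  for (const auto& x : XQ) {
    max_m = std::max(max_m, x.size(0));
  }
  if (max_m <= 64) {
    return fp8_rowwise_grouped_128x128x16x128_16x16_4x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_interwave_v2;
  }
  // A production heuristic would fall through to wider-tile instances here.
  return fp8_rowwise_grouped_128x128x16x128_16x16_4x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_interwave_v2;
}
```

Keeping the heuristic a pure function of the input shapes is what lets the selection happen once on the host, outside the captured graph, while the chosen kernel itself stays graph-compatible.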