
Commit 8276d03

djmmoss and yzh119 authored
feat: enable fp8 blockscale moe for fused cutlass for sm90 (#1819)
## πŸ“Œ Description

Adds FP8 Block Scaling Fused Cutlass MoE for SM90.

### βœ… Pre-commit Checks

- [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/).

## πŸ§ͺ Tests

- [x] Tests have been added or updated as needed.
- [x] All tests are passing (`unittest`, etc.).

---------

Signed-off-by: Duncan Moss <djm.moss@gmail.com>
Co-authored-by: Zihao Ye <expye@outlook.com>
1 parent bbb57ad commit 8276d03

22 files changed: +10029 βˆ’132 lines

csrc/nv_internal/tensorrt_llm/deep_gemm/compiler.cuh

Lines changed: 543 additions & 0 deletions
Large diffs are not rendered by default.

csrc/nv_internal/tensorrt_llm/deep_gemm/fp8_gemm.cuh

Lines changed: 414 additions & 0 deletions
Large diffs are not rendered by default.

csrc/nv_internal/tensorrt_llm/deep_gemm/fp8_gemm_impl.cuh

Lines changed: 823 additions & 0 deletions
Large diffs are not rendered by default.
Lines changed: 231 additions & 0 deletions
@@ -0,0 +1,231 @@
/*
 * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#pragma once
#include <cuda_runtime.h>
#include <nvrtc.h>

#include <climits>
#include <cstdint>
#include <iostream>
#include <string>
#include <tuple>
#include <vector>

#include "scheduler.cuh"

// Helper macro to check NVRTC errors
#define CHECK_NVRTC(call)                                                       \
  do {                                                                          \
    nvrtcResult result = call;                                                  \
    if (result != NVRTC_SUCCESS) {                                              \
      std::cerr << "NVRTC error: " << nvrtcGetErrorString(result) << std::endl; \
      exit(1);                                                                  \
    }                                                                           \
  } while (0)

// Helper macro to check CUDA driver errors
#define CHECK_CUDA(call)                                          \
  do {                                                            \
    CUresult result = call;                                       \
    if (result != CUDA_SUCCESS) {                                 \
      const char* error_string;                                   \
      cuGetErrorString(result, &error_string);                    \
      std::cerr << "CUDA error: " << error_string << std::endl;   \
      exit(1);                                                    \
    }                                                             \
  } while (0)

namespace deep_gemm::jit {

using GemmConfig = std::tuple<int, int, int, int, int>;  // block_m, block_n, num_stages,
                                                         // num_tma_multicast, best_smem_size

std::string gemm_type_to_string(deep_gemm::GemmType gemm_type);

int div_up(int a, int b);
int get_smem_size(int num_stages, int k, int block_m, int block_n, int block_k, bool swap_ab);
bool is_tma_multicast_legal(int n, int block_n, int num_tma_multicast, int num_sms);
GemmConfig get_best_gemm_config(uint32_t shape_m, uint32_t shape_n, uint32_t shape_k,
                                int num_groups, int num_device_sms, bool is_grouped_contiguous,
                                bool swap_ab);
}  // namespace deep_gemm::jit

namespace deep_gemm::jit {

std::string gemm_type_to_string(deep_gemm::GemmType gemm_type) {
  switch (gemm_type) {
    case deep_gemm::GemmType::Normal:
      return std::string("Normal");
    case deep_gemm::GemmType::GroupedContiguous:
      return std::string("GroupedContiguous");
    case deep_gemm::GemmType::GroupedMasked:
      return std::string("GroupedMasked");
    case deep_gemm::GemmType::GroupedWithOffset:
      return std::string("GroupedWithOffset");
    case deep_gemm::GemmType::StridedBatched:
      return std::string("StridedBatched");
    // Add other GEMM types as needed
    default:
      return std::string("Unknown");
  }
}

int div_up(int a, int b) { return (a + b - 1) / b; }

int get_smem_size(int num_stages, int k, int block_m, int block_n, int block_k = 128,
                  bool swap_ab = false) {
  if (!swap_ab) {
    int smem_d = block_m * block_n * 2;
    int smem_a_per_stage = block_m * block_k;
    int smem_scales_a_per_stage = block_m * 4;
    int smem_b_per_stage = block_n * block_k;
    int smem_scales_b = div_up(k, block_k) * 4;
    int smem_barrier = num_stages * 8 * 2;

    int smem_size = 0;
    smem_size += smem_d;
    smem_size += num_stages * smem_a_per_stage;
    smem_size += num_stages * smem_scales_a_per_stage;
    smem_size += num_stages * smem_b_per_stage;
    smem_size += div_up(smem_scales_b * (block_k % block_n == 0 ? 1 : 2), 8) * 8;
    smem_size += smem_barrier;

    return smem_size;
  } else {
    int smem_d = block_n * block_m * 2;
    int smem_a_per_stage = block_m * block_k;              // weight
    int smem_scales_a_per_stage = div_up(k, block_k) * 4;  // weight scales
    int smem_b_per_stage = block_n * block_k;              // act
    int smem_scales_b = div_up(block_n * 4, 128) * 128;    // act scales, TMA 128B alignment
    int smem_barrier = num_stages * 8 * 2;

    int smem_size = 0;
    smem_size += smem_d;
    smem_size += num_stages * smem_a_per_stage;
    smem_size += num_stages * smem_scales_b;
    smem_size += num_stages * smem_b_per_stage;
    smem_size += div_up(smem_scales_a_per_stage, 8) * 8;
    smem_size += smem_barrier;

    return smem_size;
  }
}
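
// Worked example: num_stages = 5, k = 7168, block_m = block_n = block_k = 128,
// swap_ab = false gives
//   32768 (D) + 5 * (16384 (A) + 512 (A scales) + 16384 (B)) + 224 (B scales)
//   + 80 (barriers) = 199472 bytes,
// which fits the 232448-byte SM90 shared-memory budget checked below.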

bool is_tma_multicast_legal(int n, int block_n, int num_tma_multicast, int num_sms) {
  if (num_tma_multicast == 1) {
    return true;
  }
  return (n % (block_n * num_tma_multicast) == 0) && num_sms % num_tma_multicast == 0;
}

GemmConfig get_best_gemm_config(uint32_t shape_m, uint32_t shape_n, uint32_t shape_k,
                                int num_groups, int num_device_sms,
                                bool is_grouped_contiguous = false, bool swap_ab = false) {
  // Choose candidate block sizes
  std::vector<int> block_ms;
  block_ms.push_back((!is_grouped_contiguous && shape_m <= 64) ? 64 : 128);

  // Candidate block sizes for N dimension
  std::vector<int> block_ns;
  for (int i = 16; i <= 128; i += 8) {
    block_ns.push_back(i);
  }

  // Lambda functions for calculating waves and utilization
  auto fix_wave_saturate = [num_device_sms](int x) -> int { return x == 0 ? num_device_sms : x; };

  auto get_num_waves = [shape_m, shape_n, num_groups, num_device_sms](int block_m,
                                                                      int block_n) -> int {
    return div_up(div_up(shape_m, block_m) * div_up(shape_n, block_n) * num_groups, num_device_sms);
  };

  auto get_last_wave_util = [shape_m, shape_n, num_groups, num_device_sms, &fix_wave_saturate](
                                int block_m, int block_n) -> int {
    return fix_wave_saturate((div_up(shape_m, block_m) * div_up(shape_n, block_n) * num_groups) %
                             num_device_sms);
  };

  // Find best block sizes
  int best_block_m = 0;
  int best_block_n = 0;
  for (int block_m : block_ms) {
    for (int block_n : block_ns) {
      bool success = false;
      int num_waves = get_num_waves(block_m, block_n);
      int best_num_waves = best_block_m == 0 ? INT_MAX : get_num_waves(best_block_m, best_block_n);

      if (best_block_m == 0 || best_block_n == 0) {
        success = true;
      } else if (num_waves < best_num_waves) {
        success = true;
      } else if (num_waves == best_num_waves) {
        // Check last wave utilization
        int util = get_last_wave_util(block_m, block_n);
        int best_util = get_last_wave_util(best_block_m, best_block_n);
        success = util > best_util ||
                  (util == best_util &&
                   (block_m > best_block_m || (block_m == best_block_m && block_n < best_block_n)));
      }

      if (success) {
        best_block_m = block_m;
        best_block_n = block_n;
      }
    }
  }

  // Find best number of stages
  int best_num_stages = 0;
  int best_smem_size = 0;
  constexpr int sm90_capacity = 232448;

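  // With block_k = 128, get_smem_size doubles the B-scales buffer when
  // 128 % block_n != 0, so the deeper 7- and 8-stage pipelines are only
  // tried when block_n divides 128.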
  std::vector<int> stage_candidates;
  if (128 % best_block_n != 0) {
    stage_candidates = {6, 5, 4};
  } else {
    stage_candidates = {8, 7, 6, 5, 4};
  }

  for (int num_stages : stage_candidates) {
    int smem_size = get_smem_size(num_stages, shape_k, best_block_m, best_block_n, 128, swap_ab);
    if (smem_size <= sm90_capacity) {
      best_num_stages = num_stages;
      best_smem_size = smem_size;
      break;
    }
  }

  // Determine TMA multicast settings
  int best_num_tma_multicast = 1;

  if (!swap_ab) {
    if (shape_m >= 1024 && is_tma_multicast_legal(shape_n, best_block_n, 2, num_device_sms) &&
        num_groups == 1) {
      best_num_tma_multicast = 2;
    }
  } else {
    if (shape_n >= 1024 && is_tma_multicast_legal(shape_m, best_block_m, 2, num_device_sms) &&
        num_groups == 1) {
      best_num_tma_multicast = 2;
    }
  }

  return std::make_tuple(best_block_m, best_block_n, best_num_stages, best_num_tma_multicast,
                         best_smem_size);
}
}  // namespace deep_gemm::jit
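
For reference, a minimal sketch of how the heuristic above can be driven from host code. The shapes and SM count are illustrative, not taken from this PR, and the include is left as a comment since this file's path is not shown in the excerpt:

#include <cstdio>
// #include the helper header above (path not shown in this excerpt)

int main() {
  // Illustrative shapes; num_device_sms would normally come from
  // cudaDeviceGetAttribute(..., cudaDevAttrMultiProcessorCount, ...).
  auto [block_m, block_n, num_stages, num_tma_multicast, smem_size] =
      deep_gemm::jit::get_best_gemm_config(/*shape_m=*/4096, /*shape_n=*/7168,
                                           /*shape_k=*/2048, /*num_groups=*/1,
                                           /*num_device_sms=*/132,
                                           /*is_grouped_contiguous=*/false,
                                           /*swap_ab=*/false);
  std::printf("block %dx%d, %d stages, %dx TMA multicast, %d B smem\n", block_m, block_n,
              num_stages, num_tma_multicast, smem_size);
  return 0;
}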
