Skip to content

Commit 8d551aa

Browse files
committed
Add CUTLASS fused moe kernels from TensorRT-LLM.
1 parent 8a95bb3 commit 8d551aa

File tree

129 files changed

+34588
-3
lines changed

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

129 files changed

+34588
-3
lines changed
Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
/*
2+
* Copyright (c) 2025 by FlashInfer team.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
#include "cutlass_fused_moe_kernels.cuh"
18+
#include "moe_kernels.h"
19+
20+
namespace tensorrt_llm::kernels {
21+
// ==================== Variable batched GEMM specializations ==================================
22+
template class CutlassMoeFCRunner<float, float>;
23+
24+
#ifdef ENABLE_BF16
25+
template class CutlassMoeFCRunner<__nv_bfloat16, __nv_bfloat16>;
26+
template class CutlassMoeFCRunner<__nv_bfloat16, uint8_t>;
27+
template class CutlassMoeFCRunner<__nv_bfloat16, cutlass::uint4b_t>;
28+
#endif
29+
30+
template class CutlassMoeFCRunner<half, half>;
31+
template class CutlassMoeFCRunner<half, uint8_t>;
32+
template class CutlassMoeFCRunner<half, cutlass::uint4b_t>;
33+
#ifdef ENABLE_FP8
34+
// template class CutlassMoeFCRunner<__nv_fp8_e4m3, __nv_fp8_e4m3>;
35+
template class CutlassMoeFCRunner<__nv_fp8_e4m3, __nv_fp8_e4m3, half>;
36+
template class CutlassMoeFCRunner<__nv_fp8_e4m3, cutlass::uint4b_t, half, half>;
37+
#ifdef ENABLE_BF16
38+
template class CutlassMoeFCRunner<__nv_fp8_e4m3, __nv_fp8_e4m3, __nv_bfloat16>;
39+
template class CutlassMoeFCRunner<__nv_bfloat16, __nv_fp8_e4m3, __nv_bfloat16>;
40+
template class CutlassMoeFCRunner<__nv_fp8_e4m3, cutlass::uint4b_t, __nv_bfloat16, __nv_bfloat16>;
41+
#endif
42+
#endif
43+
#ifdef ENABLE_FP4
44+
template class CutlassMoeFCRunner<__nv_fp4_e2m1, __nv_fp4_e2m1, half>;
45+
template class CutlassMoeFCRunner<__nv_fp4_e2m1, __nv_fp4_e2m1, half, half>;
46+
#ifdef ENABLE_BF16
47+
template class CutlassMoeFCRunner<__nv_fp4_e2m1, __nv_fp4_e2m1, __nv_bfloat16>;
48+
template class CutlassMoeFCRunner<__nv_fp4_e2m1, __nv_fp4_e2m1, __nv_bfloat16, __nv_bfloat16>;
49+
#endif
50+
#endif
51+
52+
}; // namespace tensorrt_llm::kernels

0 commit comments

Comments
 (0)