
Commit b296490

[NVIDIA] Support nvfp4 cutlass gemm (vllm-project#13571)

kaixih authored and Akshat-Tripathi committed
1 parent f605734 · commit b296490

File tree

7 files changed: +494 -1 lines changed

CMakeLists.txt

Lines changed: 3 additions & 1 deletion

@@ -229,7 +229,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
 
   # Set CUTLASS_REVISION manually -- its revision detection doesn't work in this case.
   # Please keep this in sync with FetchContent_Declare line below.
-  set(CUTLASS_REVISION "v3.7.0" CACHE STRING "CUTLASS revision to use")
+  set(CUTLASS_REVISION "v3.8.0" CACHE STRING "CUTLASS revision to use")
 
   # Use the specified CUTLASS source directory for compilation if VLLM_CUTLASS_SRC_DIR is provided
   if (DEFINED ENV{VLLM_CUTLASS_SRC_DIR})

@@ -267,6 +267,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
     "csrc/permute_cols.cu"
     "csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu"
     "csrc/quantization/fp4/nvfp4_quant_entry.cu"
+    "csrc/quantization/fp4/nvfp4_scaled_mm_entry.cu"
     "csrc/sparse/cutlass/sparse_scaled_mm_entry.cu"
     "csrc/cutlass_extensions/common.cpp")

@@ -383,6 +384,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
   if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.8 AND FP4_ARCHS)
     set(SRCS
       "csrc/quantization/fp4/nvfp4_quant_kernels.cu"
+      "csrc/quantization/fp4/nvfp4_scaled_mm_kernels.cu"
     )
     set_gencode_flags_for_srcs(
       SRCS "${SRCS}"

csrc/ops.h

Lines changed: 5 additions & 0 deletions

@@ -152,6 +152,11 @@ torch::Tensor ggml_mul_mat_a8(torch::Tensor W, torch::Tensor X, int64_t type,
                               int64_t row);
 
 #ifndef USE_ROCM
+void cutlass_scaled_fp4_mm(torch::Tensor& D, torch::Tensor const& A,
+                           torch::Tensor const& B, torch::Tensor const& A_sf,
+                           torch::Tensor const& B_sf,
+                           torch::Tensor const& alpha);
+
 bool cutlass_scaled_mm_supports_fp8(int64_t cuda_device_capability);
 bool cutlass_scaled_mm_supports_block_fp8(int64_t cuda_device_capability);

csrc/quantization/fp4/nvfp4_scaled_mm_entry.cu (new file)

Lines changed: 37 additions & 0 deletions

@@ -0,0 +1,37 @@
/*
 * Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <torch/all.h>

#if defined ENABLE_NVFP4 && ENABLE_NVFP4
void cutlass_scaled_fp4_mm_sm100a(torch::Tensor& D, torch::Tensor const& A,
                                  torch::Tensor const& B,
                                  torch::Tensor const& A_sf,
                                  torch::Tensor const& B_sf,
                                  torch::Tensor const& alpha);
#endif

void cutlass_scaled_fp4_mm(torch::Tensor& D, torch::Tensor const& A,
                           torch::Tensor const& B, torch::Tensor const& A_sf,
                           torch::Tensor const& B_sf,
                           torch::Tensor const& alpha) {
#if defined ENABLE_NVFP4 && ENABLE_NVFP4
  return cutlass_scaled_fp4_mm_sm100a(D, A, B, A_sf, B_sf, alpha);
#endif
  TORCH_CHECK_NOT_IMPLEMENTED(false, "No compiled nvfp4 mm kernel, vLLM should "
                                     "be compiled using CUDA 12.8 and target "
                                     "compute capability 100 or above.");
}
csrc/quantization/fp4/nvfp4_scaled_mm_kernels.cu (new file)

Lines changed: 280 additions & 0 deletions

@@ -0,0 +1,280 @@
/*
 * Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <torch/all.h>

#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>

#include "cutlass_extensions/common.hpp"

#include "cutlass/cutlass.h"

#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/kernel/gemm_universal.hpp"

#include "cutlass/util/packed_stride.hpp"

using namespace cute;

#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
// Kernel Perf config
template <typename T>
struct KernelTraits;

template <>
struct KernelTraits<float> {
  using MmaTileShape = Shape<_128, _128, _256>;
  using ClusterShape = Shape<_1, _1, _1>;
  using PerSmTileShape_MNK = Shape<_128, _128, _256>;
};

template <>
struct KernelTraits<cutlass::half_t> {
  using MmaTileShape = Shape<_256, _256, _256>;
  using ClusterShape = Shape<_4, _4, _1>;
  using PerSmTileShape_MNK = Shape<_128, _256, _256>;
};

template <>
struct KernelTraits<cutlass::bfloat16_t> {
  using MmaTileShape = Shape<_256, _256, _256>;
  using ClusterShape = Shape<_4, _4, _1>;
  using PerSmTileShape_MNK = Shape<_128, _256, _256>;
};

template <typename T>
struct Fp4GemmSm100 {
  // A matrix configuration
  using ElementA = cutlass::nv_float4_t<cutlass::float_e2m1_t>;
  using LayoutATag = cutlass::layout::RowMajor;
  static constexpr int AlignmentA = 32;

  // B matrix configuration
  using ElementB = cutlass::nv_float4_t<cutlass::float_e2m1_t>;
  using LayoutBTag = cutlass::layout::ColumnMajor;
  static constexpr int AlignmentB = 32;

  // C/D matrix configuration
  using ElementD = T;
  using ElementC = T;
  using LayoutCTag = cutlass::layout::RowMajor;
  using LayoutDTag = cutlass::layout::RowMajor;
  static constexpr int AlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
  static constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value;
  // Kernel functional config
  using ElementAccumulator = float;
  using ArchTag = cutlass::arch::Sm100;
  using OperatorClass = cutlass::arch::OpClassBlockScaledTensorOp;

  // Kernel Perf config
  using MmaTileShape = typename KernelTraits<T>::MmaTileShape;
  using ClusterShape = typename KernelTraits<T>::ClusterShape;
  using PerSmTileShape_MNK = typename KernelTraits<T>::PerSmTileShape_MNK;

  using CollectiveEpilogue =
      typename cutlass::epilogue::collective::CollectiveBuilder<
          ArchTag, OperatorClass, PerSmTileShape_MNK, ClusterShape,
          cutlass::epilogue::collective::EpilogueTileAuto, ElementAccumulator,
          ElementAccumulator, ElementC, LayoutCTag, AlignmentC, ElementD,
          LayoutDTag, AlignmentD,
          cutlass::epilogue::collective::EpilogueScheduleAuto>::CollectiveOp;

  using CollectiveMainloop =
      typename cutlass::gemm::collective::CollectiveBuilder<
          ArchTag, OperatorClass, ElementA, LayoutATag, AlignmentA, ElementB,
          LayoutBTag, AlignmentB, ElementAccumulator, MmaTileShape,
          ClusterShape,
          cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
              sizeof(typename CollectiveEpilogue::SharedStorage))>,
          cutlass::gemm::collective::KernelScheduleAuto>::CollectiveOp;

  using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
      Shape<int, int, int, int>, CollectiveMainloop, CollectiveEpilogue, void>;
  using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
  using StrideA = typename Gemm::GemmKernel::StrideA;
  using LayoutA = decltype(cute::make_layout(make_shape(0, 0, 0), StrideA{}));
  using LayoutSFA = typename Gemm::GemmKernel::CollectiveMainloop::LayoutSFA;
  using StrideB = typename Gemm::GemmKernel::StrideB;
  using LayoutB = decltype(cute::make_layout(make_shape(0, 0, 0), StrideB{}));
  using LayoutSFB = typename Gemm::GemmKernel::CollectiveMainloop::LayoutSFB;
  using StrideC = typename Gemm::GemmKernel::StrideC;
  using LayoutC = decltype(cute::make_layout(make_shape(0, 0, 0), StrideC{}));
  using StrideD = typename Gemm::GemmKernel::StrideD;
  using LayoutD = decltype(cute::make_layout(make_shape(0, 0, 0), StrideD{}));
};

template <typename T>
typename T::Gemm::Arguments args_from_options(
    at::Tensor& D, at::Tensor const& A, at::Tensor const& B,
    at::Tensor const& A_sf, at::Tensor const& B_sf, at::Tensor const& alpha,
    int64_t M, int64_t N, int64_t K) {
  using ElementA = typename T::Gemm::ElementA;
  using ElementB = typename T::Gemm::ElementB;
  using ElementSFA = cutlass::float_ue4m3_t;
  using ElementSFB = cutlass::float_ue4m3_t;
  using ElementD = typename T::Gemm::ElementD;
  using ElementCompute = float;
  using StrideA = typename T::StrideA;
  using StrideB = typename T::StrideB;
  using StrideD = typename T::StrideD;
  using Sm100BlkScaledConfig =
      typename T::Gemm::GemmKernel::CollectiveMainloop::Sm100BlkScaledConfig;

  int m = static_cast<int>(M);
  int n = static_cast<int>(N);
  int k = static_cast<int>(K);
  auto stride_A = cutlass::make_cute_packed_stride(StrideA{}, {m, k, 1});
  auto stride_B = cutlass::make_cute_packed_stride(StrideB{}, {n, k, 1});
  auto stride_D = cutlass::make_cute_packed_stride(StrideD{}, {m, n, 1});

  auto layout_SFA = Sm100BlkScaledConfig::tile_atom_to_shape_SFA(
      cute::make_shape(m, n, k, 1));
  auto layout_SFB = Sm100BlkScaledConfig::tile_atom_to_shape_SFB(
      cute::make_shape(m, n, k, 1));

  typename T::Gemm::Arguments arguments{
      cutlass::gemm::GemmUniversalMode::kGemm,
      {m, n, k, 1},
      {// Mainloop arguments
       static_cast<ElementA const*>(A.data_ptr()), stride_A,
       static_cast<ElementB const*>(B.data_ptr()), stride_B,
       static_cast<ElementSFA const*>(A_sf.data_ptr()), layout_SFA,
       static_cast<ElementSFB const*>(B_sf.data_ptr()), layout_SFB},
      {// Epilogue arguments
       {},  // epilogue.thread
       static_cast<ElementD const*>(D.data_ptr()),
       stride_D,
       static_cast<ElementD*>(D.data_ptr()),
       stride_D}};
  auto& fusion_args = arguments.epilogue.thread;
  fusion_args.alpha_ptr = static_cast<ElementCompute const*>(alpha.data_ptr());
  return arguments;
}

template <typename T>
void runGemm(at::Tensor& D, at::Tensor const& A, at::Tensor const& B,
             at::Tensor const& A_sf, at::Tensor const& B_sf,
             at::Tensor const& alpha, int64_t m, int64_t n, int64_t k,
             cudaStream_t stream) {
  typename Fp4GemmSm100<T>::Gemm gemm;

  auto arguments =
      args_from_options<Fp4GemmSm100<T>>(D, A, B, A_sf, B_sf, alpha, m, n, k);

  size_t workspace_size = Fp4GemmSm100<T>::Gemm::get_workspace_size(arguments);
  auto const workspace_options =
      torch::TensorOptions().dtype(torch::kUInt8).device(A.device());
  auto workspace = torch::empty(workspace_size, workspace_options);

  CUTLASS_CHECK(gemm.can_implement(arguments));

  CUTLASS_CHECK(gemm.initialize(arguments, workspace.data_ptr(), stream));

  CUTLASS_CHECK(gemm.run(arguments, workspace.data_ptr(), stream));
}
#else
template <typename T>
void runGemm(at::Tensor& D, at::Tensor const& A, at::Tensor const& B,
             at::Tensor const& A_sf, at::Tensor const& B_sf,
             at::Tensor const& alpha, int64_t m, int64_t n, int64_t k,
             cudaStream_t stream) {
  TORCH_CHECK(false, "Unsupported CUTLASS version. Set VLLM_CUTLASS_SRC_DIR to "
                     "a CUTLASS 3.8 source directory to enable support.");
}
#endif  // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)

#define CHECK_TYPE(x, st, m) \
  TORCH_CHECK(x.scalar_type() == st, "Inconsistency of Tensor type:", m)
#define CHECK_TH_CUDA(x, m) TORCH_CHECK(x.is_cuda(), m, "must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x, m) \
  TORCH_CHECK(x.is_contiguous(), m, "must be contiguous")
#define CHECK_INPUT(x, st, m) \
  CHECK_TH_CUDA(x, m);        \
  CHECK_CONTIGUOUS(x, m);     \
  CHECK_TYPE(x, st, m)

constexpr auto FLOAT4_E2M1X2 = at::ScalarType::Byte;
constexpr auto SF_DTYPE = at::ScalarType::Float8_e4m3fn;

void cutlass_scaled_fp4_mm_sm100a(torch::Tensor& D, torch::Tensor const& A,
                                  torch::Tensor const& B,
                                  torch::Tensor const& A_sf,
                                  torch::Tensor const& B_sf,
                                  torch::Tensor const& alpha) {
  CHECK_INPUT(A, FLOAT4_E2M1X2, "a");
  CHECK_INPUT(B, FLOAT4_E2M1X2, "b");

  CHECK_INPUT(A_sf, SF_DTYPE, "scale_a");
  CHECK_INPUT(B_sf, SF_DTYPE, "scale_b");

  CHECK_INPUT(alpha, at::ScalarType::Float, "alpha");

  TORCH_CHECK(A.dim() == 2, "a must be a matrix");
  TORCH_CHECK(B.dim() == 2, "b must be a matrix");
  TORCH_CHECK(A.sizes()[1] == B.sizes()[1],
              "a and b shapes cannot be multiplied (", A.sizes()[0], "x",
              A.sizes()[1], " and ", B.sizes()[0], "x", B.sizes()[1], ")");

  auto const m = A.sizes()[0];
  auto const n = B.sizes()[0];
  auto const k = A.sizes()[1] * 2;

  constexpr int alignment = 32;
  TORCH_CHECK(k % alignment == 0, "Expected k to be divisible by ", alignment,
              ", but got a shape: (", A.sizes()[0], "x", A.sizes()[1],
              "), k: ", k, ".");
  TORCH_CHECK(n % alignment == 0, "Expected n to be divisible by ", alignment,
              ", but got b shape: (", B.sizes()[0], "x", B.sizes()[1], ").");

  auto round_up = [](int x, int y) { return (x + y - 1) / y * y; };
  int rounded_m = round_up(m, 128);
  int rounded_n = round_up(n, 128);
  // Since k is divisible by 32 (alignment), k / 16 is guaranteed to be an
  // integer.
  int rounded_k = round_up(k / 16, 4);

  TORCH_CHECK(A_sf.dim() == 2, "scale_a must be a matrix");
  TORCH_CHECK(B_sf.dim() == 2, "scale_b must be a matrix");
  TORCH_CHECK(A_sf.sizes()[1] == B_sf.sizes()[1],
              "scale_a and scale_b shapes cannot be multiplied (",
              A_sf.sizes()[0], "x", A_sf.sizes()[1], " and ", B_sf.sizes()[0],
              "x", B_sf.sizes()[1], ")");
  TORCH_CHECK(A_sf.sizes()[0] == rounded_m && A_sf.sizes()[1] == rounded_k,
              "scale_a must be padded and swizzled to a shape (", rounded_m,
              "x", rounded_k, "), but got a shape (", A_sf.sizes()[0], "x",
              A_sf.sizes()[1], ")");
  TORCH_CHECK(B_sf.sizes()[0] == rounded_n && B_sf.sizes()[1] == rounded_k,
              "scale_b must be padded and swizzled to a shape (", rounded_n,
              "x", rounded_k, "), but got a shape (", B_sf.sizes()[0], "x",
              B_sf.sizes()[1], ")");

  auto out_dtype = D.dtype();
  at::cuda::CUDAGuard device_guard{(char)A.get_device()};
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream(A.get_device());

  if (out_dtype == at::ScalarType::Half) {
    runGemm<cutlass::half_t>(D, A, B, A_sf, B_sf, alpha, m, n, k, stream);
  } else if (out_dtype == at::ScalarType::BFloat16) {
    runGemm<cutlass::bfloat16_t>(D, A, B, A_sf, B_sf, alpha, m, n, k, stream);
  } else if (out_dtype == at::ScalarType::Float) {
    runGemm<float>(D, A, B, A_sf, B_sf, alpha, m, n, k, stream);
  } else {
    TORCH_CHECK(false, "Unsupported output data type of nvfp4 mm");
  }
}
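
For orientation (not part of the diff): the host-side contract enforced by cutlass_scaled_fp4_mm_sm100a is that A is m x (k/2) bytes and B is n x (k/2) bytes (two FP4 values packed per byte), k and n must be multiples of 32, and each scale tensor carries one ue4m3 scale per 16 K-elements, with rows padded to a multiple of 128 and scale columns padded to a multiple of 4 ("padded and swizzled" in the error messages above). A minimal standalone sketch of that arithmetic follows; the helper name expected_sf_shape and the concrete sizes are illustrative only.

// Sketch of the scale-factor shape contract checked above; not part of the commit.
#include <cstdio>
#include <utility>

// Round x up to the nearest multiple of y (same lambda as in the kernel file).
static int round_up(int x, int y) { return (x + y - 1) / y * y; }

// Expected (rows, cols) of a scale-factor tensor for an operand with `rows`
// rows and reduction dimension k: one scale per 16 K-elements, rows padded
// to 128 and scale columns padded to 4.
static std::pair<int, int> expected_sf_shape(int rows, int k) {
  return {round_up(rows, 128), round_up(k / 16, 4)};
}

int main() {
  int m = 16, n = 4096, k = 7168;  // A stored as 16 x 3584 bytes, B as 4096 x 3584 bytes
  auto [a_rows, a_cols] = expected_sf_shape(m, k);  // 128 x 448
  auto [b_rows, b_cols] = expected_sf_shape(n, k);  // 4096 x 448
  std::printf("scale_a: %dx%d, scale_b: %dx%d\n", a_rows, a_cols, b_rows, b_cols);
}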

csrc/torch_bindings.cpp

Lines changed: 7 additions & 0 deletions

@@ -302,6 +302,13 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
       "SymInt size_k) -> Tensor");
   // conditionally compiled so impl registration is in source file
 
+  // CUTLASS nvfp4 block scaled GEMM
+  ops.def(
+      "cutlass_scaled_fp4_mm(Tensor! out, Tensor a, Tensor b,"
+      " Tensor block_scale_a, Tensor block_scale_b,"
+      " Tensor alpha) -> ()");
+  ops.impl("cutlass_scaled_fp4_mm", torch::kCUDA, &cutlass_scaled_fp4_mm);
+
   // CUTLASS w8a8 GEMM, supporting symmetric per-tensor or per-row/column
   // quantization, as well as bias
   ops.def(
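
For illustration only, a hedged C++ sketch of driving the newly registered op through the host declaration in csrc/ops.h. The sizes are hypothetical and chosen to satisfy the checks in nvfp4_scaled_mm_kernels.cu; real callers would pass data produced by the nvfp4 quantization kernels (packed FP4 operands plus swizzled ue4m3 block scales) rather than the placeholder buffers allocated here.

// Hypothetical usage sketch; not part of the commit.
#include <torch/all.h>

// Declared in csrc/ops.h by this commit; dispatches to the SM100 kernel.
void cutlass_scaled_fp4_mm(torch::Tensor& D, torch::Tensor const& A,
                           torch::Tensor const& B, torch::Tensor const& A_sf,
                           torch::Tensor const& B_sf,
                           torch::Tensor const& alpha);

void example_call() {
  const int64_t m = 16, n = 4096, k = 7168;
  auto u8 = torch::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA);
  auto sf = torch::TensorOptions().dtype(at::ScalarType::Float8_e4m3fn).device(torch::kCUDA);
  auto f32 = torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA);

  // Packed operands: two FP4 values per byte, so the last dimension is k / 2.
  torch::Tensor A = torch::empty({m, k / 2}, u8);
  torch::Tensor B = torch::empty({n, k / 2}, u8);
  // Block scales: one ue4m3 per 16 K-elements, rows padded to a multiple of 128.
  torch::Tensor A_sf = torch::empty({128, k / 16}, sf);  // m = 16 padded to 128
  torch::Tensor B_sf = torch::empty({n, k / 16}, sf);    // n = 4096 already aligned
  torch::Tensor alpha = torch::ones({1}, f32);           // epilogue alpha scale
  torch::Tensor D = torch::empty({m, n},
      torch::TensorOptions().dtype(torch::kHalf).device(torch::kCUDA));

  // Matches the registered schema:
  // cutlass_scaled_fp4_mm(Tensor! out, Tensor a, Tensor b,
  //                       Tensor block_scale_a, Tensor block_scale_b,
  //                       Tensor alpha) -> ()
  cutlass_scaled_fp4_mm(D, A, B, A_sf, B_sf, alpha);
}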
