
Commit ebb554c

Support nvfp4 quantization
Signed-off-by: kaixih <kaixih@nvidia.com>
1 parent: 18016a5

9 files changed: +736 −5 lines changed
CMakeLists.txt
Lines changed: 18 additions & 0 deletions

@@ -264,6 +264,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
     "csrc/custom_all_reduce.cu"
     "csrc/permute_cols.cu"
     "csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu"
+    "csrc/quantization/fp4/nvfp4_quant_entry.cu"
     "csrc/sparse/cutlass/sparse_scaled_mm_entry.cu"
     "csrc/sparse/cutlass/sparse_compressor_entry.cu"
     "csrc/cutlass_extensions/common.cpp")

@@ -377,6 +378,23 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
    endif()
  endif()
 
+  # FP4 Archs and flags
+  cuda_archs_loose_intersection(FP4_ARCHS "10.0a" "${CUDA_ARCHS}")
+  if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.0 AND FP4_ARCHS)
+    set(SRCS
+      "csrc/quantization/fp4/nvfp4_quant_kernels.cu"
+    )
+    set_gencode_flags_for_srcs(
+      SRCS "${SRCS}"
+      CUDA_ARCHS "${FP4_ARCHS}")
+    list(APPEND VLLM_EXT_SRC "${SRCS}")
+    list(APPEND VLLM_GPU_FLAGS "-DENABLE_NVFP4=1")
+    message(STATUS "Building NVFP4 for archs: ${FP4_ARCHS}")
+  else()
+    message(STATUS "Not building NVFP4 as no compatible archs were found.")
+    # clear FP4_ARCHS
+    set(FP4_ARCHS)
+  endif()
 
 #
 # Machete kernels

cmake/utils.cmake
Lines changed: 13 additions & 5 deletions

@@ -257,9 +257,9 @@ endmacro()
 # where `<=` is the version comparison operator.
 # In other words, for each version in `TGT_CUDA_ARCHS` find the highest version
 # in `SRC_CUDA_ARCHS` that is less or equal to the version in `TGT_CUDA_ARCHS`.
-# We have special handling for 9.0a, if 9.0a is in `SRC_CUDA_ARCHS` and 9.0 is
-# in `TGT_CUDA_ARCHS` then we should remove 9.0a from `SRC_CUDA_ARCHS` and add
-# 9.0a to the result (and remove 9.0 from TGT_CUDA_ARCHS).
+# We have special handling for x.0a: if x.0a is in `SRC_CUDA_ARCHS` and x.0 is
+# in `TGT_CUDA_ARCHS`, then we remove x.0a from `SRC_CUDA_ARCHS` and add
+# x.0a to the result (and remove x.0 from TGT_CUDA_ARCHS).
 # The result is stored in `OUT_CUDA_ARCHS`.
 #
 # Example:

@@ -272,8 +272,8 @@ function(cuda_archs_loose_intersection OUT_CUDA_ARCHS SRC_CUDA_ARCHS TGT_CUDA_AR
   list(REMOVE_DUPLICATES SRC_CUDA_ARCHS)
   set(TGT_CUDA_ARCHS_ ${TGT_CUDA_ARCHS})
 
-  # if 9.0a is in SRC_CUDA_ARCHS and 9.0 is in CUDA_ARCHS then we should
-  # remove 9.0a from SRC_CUDA_ARCHS and add 9.0a to _CUDA_ARCHS
+  # if x.0a is in SRC_CUDA_ARCHS and x.0 is in CUDA_ARCHS then we should
+  # remove x.0a from SRC_CUDA_ARCHS and add x.0a to _CUDA_ARCHS
   set(_CUDA_ARCHS)
   if ("9.0a" IN_LIST SRC_CUDA_ARCHS)
     list(REMOVE_ITEM SRC_CUDA_ARCHS "9.0a")

@@ -283,6 +283,14 @@ function(cuda_archs_loose_intersection OUT_CUDA_ARCHS SRC_CUDA_ARCHS TGT_CUDA_AR
     endif()
   endif()
 
+  if ("10.0a" IN_LIST SRC_CUDA_ARCHS)
+    list(REMOVE_ITEM SRC_CUDA_ARCHS "10.0a")
+    if ("10.0" IN_LIST TGT_CUDA_ARCHS)
+      list(REMOVE_ITEM TGT_CUDA_ARCHS_ "10.0")
+      set(_CUDA_ARCHS "10.0a")
+    endif()
+  endif()
+
   list(SORT SRC_CUDA_ARCHS COMPARE NATURAL ORDER ASCENDING)
 
   # for each ARCH in TGT_CUDA_ARCHS find the highest arch in SRC_CUDA_ARCHS that
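A concrete consequence of this hunk: the cuda_archs_loose_intersection(FP4_ARCHS "10.0a" "${CUDA_ARCHS}") call in the CMakeLists.txt change above now resolves FP4_ARCHS to "10.0a" precisely when "10.0" is among the requested CUDA_ARCHS, the same promotion the function already performed for the 9.0/9.0a (Hopper) pair.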

csrc/ops.h
Lines changed: 4 additions & 0 deletions

@@ -195,6 +195,10 @@ torch::Tensor gptq_gemm(torch::Tensor a, torch::Tensor b_q_weight,
 
 void gptq_shuffle(torch::Tensor q_weight, torch::Tensor q_perm, int64_t bit);
 
+void scaled_fp4_quant(torch::Tensor& output, torch::Tensor const& input,
+                      torch::Tensor& output_scale,
+                      torch::Tensor const& input_scale);
+
 void static_scaled_fp8_quant(torch::Tensor& out, torch::Tensor const& input,
                              torch::Tensor const& scale);
 
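For orientation, here is a minimal sketch of how a caller might drive this op. It is not part of the commit, and the buffer shapes and dtypes are assumptions based on the usual NVFP4 convention (two FP4 values packed per uint8 byte, one FP8 E4M3 scale per 16-element block, one FP32 global input scale); quantize_example is a hypothetical name. The authoritative contract is defined by the kernel file this commit adds (nvfp4_quant_kernels.cu).

#include <torch/all.h>
#include "ops.h"

// Hypothetical caller (not from the commit). Shapes/dtypes are assumptions:
// output packs two FP4 values per byte, output_scale holds one E4M3 scale
// per 16-element block, input_scale is a single FP32 global scale.
void quantize_example(torch::Tensor const& input,        // [m, n], fp16/bf16
                      torch::Tensor const& input_scale)  // [1], fp32
{
  int64_t m = input.size(0);
  int64_t n = input.size(1);
  auto opts = torch::TensorOptions().device(input.device());

  // Two FP4 values per output byte, so the last dimension halves.
  torch::Tensor output =
      torch::empty({m, n / 2}, opts.dtype(torch::kUInt8));
  // One E4M3 scale factor per 16-element block along the last dimension.
  torch::Tensor output_scale =
      torch::empty({m, n / 16}, opts.dtype(at::ScalarType::Float8_e4m3fn));

  scaled_fp4_quant(output, input, output_scale, input_scale);
}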

csrc/quantization/fp4/cudaUtils.h (new file)
Lines changed: 67 additions & 0 deletions

@@ -0,0 +1,67 @@
/*
 * Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#pragma once

#include "cutlass/cutlass.h"
#include <climits>

namespace vllm {
namespace common {

class CudaException : public std::runtime_error {
 public:
  CudaException(const std::string& file, int line, const std::string& message)
      : std::runtime_error("CUDA Error at " + file + ":" +
                           std::to_string(line) + " - " + message) {}
};

template <typename T>
void check(T result, const char* func, const char* file, int line) {
  if (result != cudaSuccess) {
    throw CudaException(
        file, line,
        std::string("[VLLM][ERROR] CUDA runtime error in ") + func + ": " +
            cudaGetErrorString(static_cast<cudaError_t>(result)));
  }
}

#define check_cuda_error(val) check((val), #val, __FILE__, __LINE__)

inline int getMaxSharedMemoryPerBlockOptin() {
  int device_id;
  int max_shared_memory_per_block;
  check_cuda_error(cudaGetDevice(&device_id));
  check_cuda_error(cudaDeviceGetAttribute(
      &max_shared_memory_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin,
      device_id));
  return max_shared_memory_per_block;
}

inline int getSMVersion() {
  int device{-1};
  check_cuda_error(cudaGetDevice(&device));
  int sm_major = 0;
  int sm_minor = 0;
  check_cuda_error(cudaDeviceGetAttribute(
      &sm_major, cudaDevAttrComputeCapabilityMajor, device));
  check_cuda_error(cudaDeviceGetAttribute(
      &sm_minor, cudaDevAttrComputeCapabilityMinor, device));
  return sm_major * 10 + sm_minor;
}

}  // namespace common
}  // namespace vllm
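A brief usage sketch of the helpers above (not part of the commit; the include path depends on the build setup): check_cuda_error stringifies the wrapped call via #val, so a failing CUDA runtime call throws a CudaException reporting file, line, and the offending expression, and getSMVersion() folds major/minor compute capability into one integer.

#include <cuda_runtime.h>
#include "cudaUtils.h"

void example() {
  using namespace vllm::common;

  void* buf = nullptr;
  // Throws CudaException on failure instead of silently returning an error.
  check_cuda_error(cudaMalloc(&buf, 1024));

  // getSMVersion() returns major * 10 + minor, e.g. 100 for SM 10.0.
  if (getSMVersion() >= 100) {
    // Blackwell-class device: the NVFP4 path is usable here.
  }

  check_cuda_error(cudaFree(buf));
}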
csrc/quantization/fp4/nvfp4_quant_entry.cu (new file)
Lines changed: 31 additions & 0 deletions

@@ -0,0 +1,31 @@
/*
 * Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <torch/all.h>

#if defined ENABLE_NVFP4 && ENABLE_NVFP4
void scaled_fp4_quant_sm100a(torch::Tensor& output, torch::Tensor const& input,
                             torch::Tensor& output_sf,
                             torch::Tensor const& input_sf);
#endif

void scaled_fp4_quant(torch::Tensor& output, torch::Tensor const& input,
                      torch::Tensor& output_sf, torch::Tensor const& input_sf) {
#if defined ENABLE_NVFP4 && ENABLE_NVFP4
  return scaled_fp4_quant_sm100a(output, input, output_sf, input_sf);
#endif
  TORCH_CHECK_NOT_IMPLEMENTED(false, "No compiled nvfp4 quantization");
}
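The entry file uses a two-level guard: the wrapper is always compiled, while the SM100a implementation is only declared and dispatched when the build defined ENABLE_NVFP4 (see the CMakeLists.txt hunk above), so unsupported builds fail with a clean NotImplementedError rather than a link error. A hypothetical pre-flight helper, sketched below under the assumption that the kernels target compute capability 10.0a only, could let callers route around the TORCH_CHECK_NOT_IMPLEMENTED at runtime:

#include "cudaUtils.h"

// Hypothetical helper (not part of the commit): probe for NVFP4 support
// before calling scaled_fp4_quant, instead of catching its error.
bool nvfp4_supported() {
#if defined ENABLE_NVFP4 && ENABLE_NVFP4
  // Kernels are compiled for 10.0a; getSMVersion() yields 100 on SM 10.0.
  return vllm::common::getSMVersion() == 100;
#else
  return false;
#endif
}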
