8 changes: 5 additions & 3 deletions csrc/ops.h
@@ -152,14 +152,16 @@ torch::Tensor ggml_mul_mat_a8(torch::Tensor W, torch::Tensor X, int64_t type,
int64_t row);

#ifndef USE_ROCM

bool cutlass_scaled_mm_supports_fp4(int64_t cuda_device_capability);
bool cutlass_scaled_mm_supports_fp8(int64_t cuda_device_capability);
bool cutlass_scaled_mm_supports_block_fp8(int64_t cuda_device_capability);

void cutlass_scaled_fp4_mm(torch::Tensor& D, torch::Tensor const& A,
torch::Tensor const& B, torch::Tensor const& A_sf,
torch::Tensor const& B_sf,
torch::Tensor const& alpha);

bool cutlass_scaled_mm_supports_fp8(int64_t cuda_device_capability);
bool cutlass_scaled_mm_supports_block_fp8(int64_t cuda_device_capability);

void cutlass_scaled_mm(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b, torch::Tensor const& a_scales,
torch::Tensor const& b_scales,
6 changes: 6 additions & 0 deletions csrc/quantization/fp4/nvfp4_scaled_mm_entry.cu
@@ -36,3 +36,9 @@ void cutlass_scaled_fp4_mm(torch::Tensor& D, torch::Tensor const& A,
"be compiled using CUDA 12.8 and target "
"compute capability 100 or above.");
}

bool cutlass_scaled_mm_supports_fp4(int64_t cuda_device_capability) {
int runtimeVersion;
cudaRuntimeGetVersion(&runtimeVersion);
return cuda_device_capability >= 100 && runtimeVersion >= 12080;
}
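
The check above gates the FP4 path on both the device (compute capability 100, i.e. SM100/Blackwell) and the installed CUDA runtime (12080, i.e. CUDA 12.8). A minimal sketch of how the Python side might consume it, assuming the cutlass_scaled_mm_supports_fp4 wrapper added to vllm/_custom_ops.py later in this diff:

# Sketch (not part of this PR): gate an FP4 GEMM path on the new support check.
# Assumes vllm._custom_ops exposes cutlass_scaled_mm_supports_fp4 as added below.
import torch

from vllm import _custom_ops as ops

major, minor = torch.cuda.get_device_capability()
capability = major * 10 + minor  # e.g. (10, 0) -> 100 for SM100

if ops.cutlass_scaled_mm_supports_fp4(capability):
    # Safe to dispatch to cutlass_scaled_fp4_mm on this device/runtime.
    use_fp4_gemm = True
else:
    # Fall back to a non-FP4 path (e.g. dequantize and run a standard GEMM).
    use_fp4_gemm = False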
7 changes: 4 additions & 3 deletions csrc/quantization/fp4/nvfp4_scaled_mm_kernels.cu
@@ -201,10 +201,11 @@ void runGemm(at::Tensor& D, at::Tensor const& A, at::Tensor const& B,
#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)

#define CHECK_TYPE(x, st, m) \
TORCH_CHECK(x.scalar_type() == st, "Inconsistency of Tensor type:", m)
#define CHECK_TH_CUDA(x, m) TORCH_CHECK(x.is_cuda(), m, "must be a CUDA tensor")
TORCH_CHECK(x.scalar_type() == st, ": Inconsistency of Tensor type:", m)
#define CHECK_TH_CUDA(x, m) \
TORCH_CHECK(x.is_cuda(), m, ": must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x, m) \
TORCH_CHECK(x.is_contiguous(), m, "must be contiguous")
TORCH_CHECK(x.is_contiguous(), m, ": must be contiguous")
#define CHECK_INPUT(x, st, m) \
CHECK_TH_CUDA(x, m); \
CHECK_CONTIGUOUS(x, m); \
4 changes: 4 additions & 0 deletions csrc/torch_bindings.cpp
@@ -424,6 +424,10 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
" Tensor! output_scale, Tensor input_scale) -> ()");
ops.impl("scaled_fp4_quant", torch::kCUDA, &scaled_fp4_quant);

// Check if cutlass_scaled_mm_fp4 is supported for CUDA devices
// of the given capability
ops.def("cutlass_scaled_mm_supports_fp4(int cuda_device_capability) -> bool");
ops.impl("cutlass_scaled_mm_supports_fp4", &cutlass_scaled_mm_supports_fp4);
#endif

// Quantized GEMM for GPTQ.
82 changes: 82 additions & 0 deletions tests/models/decoder_only/language/test_nvfp4.py
@@ -0,0 +1,82 @@
# SPDX-License-Identifier: Apache-2.0
# flake8: noqa
"""Tests Model Optimizer nvfp4 models against ground truth generation
Note: these tests will only pass on B200
"""
import os
from typing import List

import pytest
from transformers import AutoTokenizer

from tests.quantization.utils import is_quant_method_supported
from vllm import LLM, SamplingParams

os.environ["TOKENIZERS_PARALLELISM"] = "true"

MAX_MODEL_LEN = 1024

MODELS = ["nvidia/Llama-3.3-70B-Instruct-FP4"]

EXPECTED_STRS_MAP = {
"nvidia/Llama-3.3-70B-Instruct-FP4": [
'vLLM (Vectorized Large Language Model) is indeed a high-throughput and memory-efficient inference',
'Here are the major milestones in the development of artificial intelligence (AI) from 1950 to ',
'Artificial intelligence (AI) and human intelligence (HI) are two distinct forms of intelligence that process',
'A neural network is a type of machine learning model inspired by the structure and function of the human brain',
'In the heart of a cutting-edge robotics lab, a team of engineers had been working tirelessly to push',
'The COVID-19 pandemic has had a profound impact on global economic structures and future business models, leading',
'The Mona Lisa, painted by Leonardo da Vinci in the early 16th century, is one of',
'Here are the translations:\n\n* Japanese: (Sasuga no tori ga miwa o ts'
]
}


# This test compares against golden strings for exact match since
# there is no baseline implementation to compare against.
# The golden strings are unstable w.r.t. the specifics of the fp4
# implementation and the hardware being run on.
# Disabled to prevent it from breaking the build.
@pytest.mark.skip(
reason=
"Prevent unstable test based on golden strings from breaking the build "
" and test input model being too large and hanging the system.")
@pytest.mark.quant_model
@pytest.mark.skipif(not is_quant_method_supported("nvfp4"),
reason="nvfp4 is not supported on this GPU type.")
@pytest.mark.parametrize("model_name", MODELS)
def test_models(example_prompts, model_name) -> None:
model = LLM(
model=model_name,
max_model_len=MAX_MODEL_LEN,
trust_remote_code=True,
enforce_eager=True,
quantization="nvfp4",
)

tokenizer = AutoTokenizer.from_pretrained(model_name)
formatted_prompts = [
tokenizer.apply_chat_template([{
"role": "user",
"content": prompt
}],
tokenize=False,
add_generation_prompt=True)
for prompt in example_prompts
]
params = SamplingParams(max_tokens=20, temperature=0)
generations: List[str] = []
# Note: these need to be run 1 at a time due to numerical precision,
# since the expected strs were generated this way.
for prompt in formatted_prompts:
outputs = model.generate(prompt, params)
generations.append(outputs[0].outputs[0].text)
del model

print(model_name, generations)
expected_strs = EXPECTED_STRS_MAP[model_name]
for i in range(len(example_prompts)):
generated_str = generations[i]
expected_str = expected_strs[i]
assert expected_str == generated_str, (
f"Test{i}:\nExpected: {expected_str!r}\nvLLM: {generated_str!r}")
4 changes: 4 additions & 0 deletions vllm/_custom_ops.py
@@ -450,6 +450,10 @@ def _ggml_mul_mat_a8_fake(


# cutlass
def cutlass_scaled_mm_supports_fp4(cuda_device_capability: int) -> bool:
return torch.ops._C.cutlass_scaled_mm_supports_fp4(cuda_device_capability)


def cutlass_scaled_fp4_mm(a: torch.Tensor, b: torch.Tensor,
block_scale_a: torch.Tensor,
block_scale_b: torch.Tensor, alpha: torch.Tensor,
2 changes: 1 addition & 1 deletion vllm/config.py
@@ -613,7 +613,7 @@ def _verify_quantization(self) -> None:
optimized_quantization_methods = [
"fp8", "marlin", "modelopt", "gptq_marlin_24", "gptq_marlin",
"awq_marlin", "fbgemm_fp8", "compressed_tensors",
"compressed-tensors", "experts_int8", "quark"
"compressed-tensors", "experts_int8", "quark", "nvfp4"
]
if self.quantization is not None:
self.quantization = self.quantization.lower()
23 changes: 17 additions & 6 deletions vllm/model_executor/layers/linear.py
@@ -29,12 +29,23 @@
logger = init_logger(__name__)

WEIGHT_LOADER_V2_SUPPORTED = [
"CompressedTensorsLinearMethod", "AWQMarlinLinearMethod",
"AWQLinearMethod", "GPTQMarlinLinearMethod", "Fp8LinearMethod",
"MarlinLinearMethod", "QQQLinearMethod", "GPTQMarlin24LinearMethod",
"TPUInt8LinearMethod", "GPTQLinearMethod", "FBGEMMFp8LinearMethod",
"ModelOptFp8LinearMethod", "IPEXAWQLinearMethod", "IPEXGPTQLinearMethod",
"HQQMarlinMethod", "QuarkLinearMethod"
"CompressedTensorsLinearMethod",
"AWQMarlinLinearMethod",
"AWQLinearMethod",
"GPTQMarlinLinearMethod",
"Fp8LinearMethod",
"MarlinLinearMethod",
"QQQLinearMethod",
"GPTQMarlin24LinearMethod",
"TPUInt8LinearMethod",
"GPTQLinearMethod",
"FBGEMMFp8LinearMethod",
"ModelOptFp8LinearMethod",
"IPEXAWQLinearMethod",
"IPEXGPTQLinearMethod",
"HQQMarlinMethod",
"QuarkLinearMethod",
"ModelOptNvFp4LinearMethod",
]


4 changes: 3 additions & 1 deletion vllm/model_executor/layers/quantization/__init__.py
@@ -14,6 +14,7 @@
"ptpc_fp8",
"fbgemm_fp8",
"modelopt",
"nvfp4",
# The order of gptq methods is important for config.py iteration over
# override_quantization_method(..)
"marlin",
@@ -97,7 +98,7 @@ def get_quantization_config(quantization: str) -> Type[QuantizationConfig]:
from .hqq_marlin import HQQMarlinConfig
from .ipex_quant import IPEXConfig
from .marlin import MarlinConfig
from .modelopt import ModelOptFp8Config
from .modelopt import ModelOptFp8Config, ModelOptNvFp4Config
from .moe_wna16 import MoeWNA16Config
from .neuron_quant import NeuronQuantConfig
from .ptpc_fp8 import PTPCFp8Config
@@ -112,6 +113,7 @@ def get_quantization_config(quantization: str) -> Type[QuantizationConfig]:
"fp8": Fp8Config,
"fbgemm_fp8": FBGEMMFp8Config,
"modelopt": ModelOptFp8Config,
"nvfp4": ModelOptNvFp4Config,
# The order of gptq methods is important for config.py iteration over
# override_quantization_method(..)
"marlin": MarlinConfig,