
Commit da68fff

yewentao256 and albertoperdomo2 authored and committed
[Feature] Batch Invariant: Support DeepGEMM and Blackwell (vllm-project#27127)
Signed-off-by: yewentao256 <zhyanwentao@126.com>
Signed-off-by: Alberto Perdomo <aperdomo@redhat.com>
1 parent 39e24e5 commit da68fff

File tree: 3 files changed, +71 -21 lines changed

tests/v1/generation/test_batch_invariance.py

Lines changed: 8 additions & 8 deletions

@@ -10,9 +10,9 @@
 from vllm import LLM, SamplingParams
 from vllm.platforms import current_platform
 
-hopper_only = pytest.mark.skipif(
-    not (current_platform.is_cuda() and current_platform.is_device_capability(90)),
-    reason="Requires CUDA and Hopper (SM90)",
+skip_unsupported = pytest.mark.skipif(
+    not (current_platform.is_cuda() and current_platform.has_device_capability(90)),
+    reason="Requires CUDA and >= Hopper (SM90)",
 )
 
 
@@ -74,7 +74,7 @@ def _random_prompt(min_words: int = 1024, max_words: int = 1024 * 2) -> str:
     return base_prompt
 
 
-@hopper_only
+@skip_unsupported
 @pytest.mark.timeout(1000)
 def test_v1_generation_is_deterministic_across_batch_sizes_with_needle():
     """
@@ -219,7 +219,7 @@ def _extract_step_logprobs(request_output):
     return None, None
 
 
-@hopper_only
+@skip_unsupported
 @pytest.mark.parametrize("backend", ["FLASH_ATTN", "FLASHINFER"])
 @pytest.mark.forked
 def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN(backend):
@@ -434,7 +434,7 @@ def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN(backend):
         pytest.fail(msg)
 
 
-@hopper_only
+@skip_unsupported
 def test_simple_generation():
     """
     Simple test that runs the model with a basic prompt and prints the output.
@@ -480,7 +480,7 @@ def test_simple_generation():
     llm.shutdown()
 
 
-@hopper_only
+@skip_unsupported
 @pytest.mark.parametrize("backend", ["FLASH_ATTN", "FLASHINFER"])
 @pytest.mark.forked
 def test_logprobs_WITHOUT_batch_invariance_should_FAIL(backend):
@@ -707,7 +707,7 @@ def test_logprobs_WITHOUT_batch_invariance_should_FAIL(backend):
     os.environ["VLLM_BATCH_INVARIANT"] = old_value
 
 
-@hopper_only
+@skip_unsupported
 @pytest.mark.parametrize("backend", ["FLASH_ATTN"])
 @pytest.mark.forked
 def test_decode_logprobs_match_prefill_logprobs(backend):
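The marker rename matches the semantic change: current_platform.is_device_capability(90) is true only on exactly SM90 (Hopper), while has_device_capability(90) is true on SM90 and newer, which lets these tests run on Blackwell as well. A minimal sketch of the same pattern (assuming a CUDA build of vLLM; the test body is illustrative only):

import pytest
from vllm.platforms import current_platform

# has_device_capability(90): compute capability >= 9.0 (Hopper, Blackwell, ...)
# is_device_capability(90): compute capability == 9.0 (Hopper only)
skip_unsupported = pytest.mark.skipif(
    not (current_platform.is_cuda() and current_platform.has_device_capability(90)),
    reason="Requires CUDA and >= Hopper (SM90)",
)


@skip_unsupported
def test_capability_gate():
    # Runs on SM90+ GPUs only; skipped everywhere else.
    assert current_platform.has_device_capability(90)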

tests/v1/generation/test_rms_norm_batch_invariant.py

Lines changed: 9 additions & 9 deletions

@@ -14,13 +14,13 @@
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.platforms import current_platform
 
-hopper_only = pytest.mark.skipif(
-    not (current_platform.is_cuda() and current_platform.is_device_capability(90)),
-    reason="Requires CUDA and Hopper (SM90)",
+skip_unsupported = pytest.mark.skipif(
+    not (current_platform.is_cuda() and current_platform.has_device_capability(90)),
+    reason="Requires CUDA and >= Hopper (SM90)",
 )
 
 
-@hopper_only
+@skip_unsupported
 @pytest.mark.parametrize("batch_size", [1, 4, 16, 64])
 @pytest.mark.parametrize("hidden_size", [512, 2048, 4096, 8192])
 @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@@ -69,7 +69,7 @@ def test_rms_norm_batch_invariant_vs_standard(
     )
 
 
-@hopper_only
+@skip_unsupported
 @pytest.mark.parametrize("batch_size", [1, 16, 128])
 @pytest.mark.parametrize("seq_len", [1, 32, 512])
 @pytest.mark.parametrize("hidden_size", [2048, 4096])
@@ -111,7 +111,7 @@ def test_rms_norm_3d_input(batch_size: int, seq_len: int, hidden_size: int):
     )
 
 
-@hopper_only
+@skip_unsupported
 def test_rms_norm_numerical_stability():
     """
     Test RMS norm numerical stability with extreme values.
@@ -171,7 +171,7 @@ def test_rms_norm_numerical_stability():
     )
 
 
-@hopper_only
+@skip_unsupported
 def test_rms_norm_formula():
     """
     Test that RMS norm follows the correct mathematical formula.
@@ -204,7 +204,7 @@ def test_rms_norm_formula():
     )
 
 
-@hopper_only
+@skip_unsupported
 @pytest.mark.parametrize("hidden_size", [128, 1024, 4096, 16384])
 def test_rms_norm_different_hidden_sizes(hidden_size: int):
     """
@@ -242,7 +242,7 @@ def test_rms_norm_different_hidden_sizes(hidden_size: int):
     )
 
 
-@hopper_only
+@skip_unsupported
 def test_rms_norm_determinism():
     """
     Test that batch-invariant RMS norm produces deterministic results.
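For reference, the formula these tests exercise (see test_rms_norm_formula) is output = x / sqrt(mean(x^2) + eps) * weight, with the mean taken over the hidden dimension. A plain-PyTorch sketch of that reference, independent of vLLM's RMSNorm layer and its batch-invariant kernel:

import torch

def rms_norm_reference(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Compute in float32 for stability; normalize over the last (hidden) dimension.
    variance = x.float().pow(2).mean(dim=-1, keepdim=True)
    x_normed = x.float() * torch.rsqrt(variance + eps)
    return (x_normed * weight.float()).to(x.dtype)

# Example: a [batch, hidden] activation with a per-channel weight.
x = torch.randn(16, 4096, dtype=torch.bfloat16)
w = torch.ones(4096, dtype=torch.bfloat16)
out = rms_norm_reference(x, w)
assert out.shape == x.shape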

vllm/model_executor/layers/quantization/fp8.py

Lines changed: 54 additions & 4 deletions

@@ -41,6 +41,7 @@
     QuantizationConfig,
     QuantizeMethodBase,
 )
+from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
 from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
 from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
     FlashinferMoeBackend,
@@ -94,9 +95,11 @@
 from vllm.scalar_type import scalar_types
 from vllm.utils import has_deep_gemm
 from vllm.utils.deep_gemm import (
+    fp8_gemm_nt,
     get_col_major_tma_aligned_tensor,
     is_deep_gemm_e8m0_used,
     is_deep_gemm_supported,
+    should_use_deepgemm_for_fp8_linear,
 )
 from vllm.utils.flashinfer import has_flashinfer_moe
 
@@ -539,8 +542,34 @@ def apply(
         x: torch.Tensor,
         bias: torch.Tensor | None = None,
     ) -> torch.Tensor:
-        # If batch invariant mode is enabled, dequantize and use BF16 compute
+        # If batch invariant mode is enabled, prefer the DeepGEMM FP8 path;
+        # fall back to BF16 dequantization when DeepGEMM is not supported.
         if vllm_is_batch_invariant():
+            if self.block_quant and should_use_deepgemm_for_fp8_linear(
+                torch.bfloat16, layer.weight, None
+            ):
+                # Use group quantization consistent with the block size across K
+                assert self.act_q_group_shape is not None
+                q_input, input_scale = QuantFP8(
+                    False,
+                    self.act_q_group_shape,
+                    column_major_scales=True,
+                )(x)
+
+                output_2d = torch.empty(
+                    (q_input.shape[0], layer.weight.shape[0]),
+                    dtype=torch.bfloat16,
+                    device=q_input.device,
+                )
+                fp8_gemm_nt(
+                    (q_input, input_scale),
+                    (layer.weight, layer.weight_scale),
+                    output_2d,
+                )
+                if bias is not None:
+                    output_2d = output_2d + bias
+                return output_2d
+
             # Dequantize FP8 weights to BF16
             weight_fp8 = layer.weight.to(torch.bfloat16)
             weight_scale = layer.weight_scale.to(torch.bfloat16)
@@ -555,9 +584,30 @@ def apply(
 
             N, K = weight_fp8.shape
 
-            # Scale is stored transposed: [num_blocks_k, num_blocks_n]
-            # We need to transpose it to [num_blocks_n, num_blocks_k] first
-            weight_scale = weight_scale.t()
+            # Determine the expected number of blocks along N and K
+            num_blocks_n = (N + block_n - 1) // block_n
+            num_blocks_k = (K + block_k - 1) // block_k
+
+            # The scale layout may be [num_blocks_n, num_blocks_k] or
+            # [num_blocks_k, num_blocks_n] depending on the backend
+            if weight_scale.dim() != 2:
+                raise RuntimeError(
+                    f"FP8 block scale must be 2D, got {tuple(weight_scale.shape)}"
+                )
+
+            scale_rows, scale_cols = weight_scale.shape
+            if (scale_rows, scale_cols) == (num_blocks_k, num_blocks_n):
+                if num_blocks_n == num_blocks_k:
+                    # Ambiguous square case: warn and skip the transpose
+                    logger.warning(
+                        "Batch-invariant FP8: square block-scale %dx%d; "
+                        "skipping transpose to avoid misorientation.",
+                        scale_rows,
+                        scale_cols,
+                    )
+                else:
+                    # Clearly [K-blocks, N-blocks]: transpose to [N-blocks, K-blocks]
+                    weight_scale = weight_scale.t()
 
             # Expand scale to match weight dimensions
             # scale_expanded should have shape [N, K]
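The second hunk replaces the unconditional transpose with an explicit orientation check: the block scale may arrive as [N-blocks, K-blocks] or [K-blocks, N-blocks] depending on the backend, so it is transposed only when the shape clearly indicates the KN layout, and the ambiguous square case is left untouched with a warning. A standalone sketch of the same decision logic (orient_block_scale is a hypothetical helper, not part of vLLM's API):

import torch

def orient_block_scale(
    weight_scale: torch.Tensor, N: int, K: int, block_n: int, block_k: int
) -> torch.Tensor:
    # Expected block counts along N and K (ceiling division).
    num_blocks_n = (N + block_n - 1) // block_n
    num_blocks_k = (K + block_k - 1) // block_k
    if weight_scale.dim() != 2:
        raise RuntimeError(f"FP8 block scale must be 2D, got {tuple(weight_scale.shape)}")
    rows, cols = weight_scale.shape
    if (rows, cols) == (num_blocks_k, num_blocks_n) and num_blocks_n != num_blocks_k:
        # Clearly [K-blocks, N-blocks]: transpose to [N-blocks, K-blocks].
        return weight_scale.t()
    # Already [N-blocks, K-blocks], or an ambiguous square left as-is.
    return weight_scale

# Example: N=512, K=1024 with 128x128 blocks -> scales should end up [4, 8].
scale_kn = torch.rand(8, 4)  # stored as [K-blocks, N-blocks]
assert orient_block_scale(scale_kn, N=512, K=1024, block_n=128, block_k=128).shape == (4, 8)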
