
Commit 0ce7243

elvischenv authored and xuebwang-amd committed
[Flashinfer] Support Flashinfer TRTLLM FP8-qkv BF16/FP16-out Attention Kernel (vllm-project#23647)
Signed-off-by: elvischenv <219235043+elvischenv@users.noreply.github.com>
Signed-off-by: xuebwang-amd <xuebwang@amd.com>
1 parent 2f51024 commit 0ce7243

5 files changed: +22 additions, -11 deletions

benchmarks/kernels/benchmark_trtllm_decode_attention.py

Lines changed: 1 addition & 0 deletions
@@ -259,6 +259,7 @@ def write_results_to_csv(results, filename=None):
     # (q_quant_dtype, kv_quant_dtype, o_quant_dtype)
     (None, None, None),
     (None, FP8_DTYPE, None),
+    (FP8_DTYPE, FP8_DTYPE, None),
     (FP8_DTYPE, FP8_DTYPE, FP8_DTYPE),
     (FP8_DTYPE, FP8_DTYPE, FP4_DTYPE),
 ]
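Each tuple in quant_dtypes selects the quantization applied to the query, the KV cache, and the attention output, in that order; None means the tensor stays in the model dtype (BF16/FP16). The added (FP8_DTYPE, FP8_DTYPE, None) entry therefore benchmarks the new FP8-qkv, BF16/FP16-out path. A minimal sketch of sweeping such a configuration list in a standalone script (the loop body is illustrative, not the benchmark's actual harness):

import torch

FP8_DTYPE = torch.float8_e4m3fn

quant_dtypes = [
    # (q_quant_dtype, kv_quant_dtype, o_quant_dtype)
    (None, None, None),                 # everything in model dtype
    (None, FP8_DTYPE, None),            # FP8 KV cache only
    (FP8_DTYPE, FP8_DTYPE, None),       # new: FP8 q/kv, BF16/FP16 output
    (FP8_DTYPE, FP8_DTYPE, FP8_DTYPE),  # fully FP8
]

for q_dtype, kv_dtype, o_dtype in quant_dtypes:
    # None means "keep the model dtype" (BF16/FP16) for that tensor.
    print(f"q={q_dtype}, kv={kv_dtype}, out={o_dtype}")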

benchmarks/kernels/benchmark_trtllm_prefill_attention.py

Lines changed: 1 addition & 0 deletions
@@ -274,6 +274,7 @@ def write_results_to_csv(results, filename=None):
 quant_dtypes = [
     # (q_quant_dtype, kv_quant_dtype, o_quant_dtype)
     (None, None, None),
+    (FP8_DTYPE, FP8_DTYPE, None),
     (FP8_DTYPE, FP8_DTYPE, FP8_DTYPE),
     (FP8_DTYPE, FP8_DTYPE, FP4_DTYPE),
 ]

tests/kernels/attention/test_flashinfer_trtllm_attention.py

Lines changed: 12 additions & 0 deletions
@@ -35,6 +35,7 @@ def to_float8(x, dtype=torch.float8_e4m3fn):
     # (q_quant_dtype, kv_quant_dtype, o_quant_dtype)
     (None, None, None),
     (None, FP8_DTYPE, None),
+    (FP8_DTYPE, FP8_DTYPE, None),
     (FP8_DTYPE, FP8_DTYPE, FP8_DTYPE),
     (FP8_DTYPE, FP8_DTYPE, FP4_DTYPE),
 ]
@@ -44,6 +45,7 @@ def to_float8(x, dtype=torch.float8_e4m3fn):
 HEAD_SIZE = [128]
 KV_LAYOUT = ["HND"]  # currently only HND is supported
 BLOCK_SIZE = [16]
+WINDOW_LEFT = [-1, 127]
 SOFT_CAP = [None, 50.0]

 NUM_BLOCKS = 32768  # Large enough to test overflow in index calculation.
@@ -57,6 +59,7 @@ def to_float8(x, dtype=torch.float8_e4m3fn):
 @pytest.mark.parametrize("head_size", HEAD_SIZE)
 @pytest.mark.parametrize("kv_layout", KV_LAYOUT)
 @pytest.mark.parametrize("block_size", BLOCK_SIZE)
+@pytest.mark.parametrize("window_left", WINDOW_LEFT)
 @pytest.mark.parametrize("soft_cap", SOFT_CAP)
 @torch.inference_mode
 def test_flashinfer_trtllm_decode_with_baseline(
@@ -69,6 +72,7 @@ def test_flashinfer_trtllm_decode_with_baseline(
     head_size: int,
     kv_layout: str,
     block_size: int,
+    window_left: int,
     soft_cap: Optional[float],
 ) -> None:
     torch.set_default_device("cuda")
@@ -155,6 +159,7 @@ def test_flashinfer_trtllm_decode_with_baseline(
         sm_scale=sm_scale,
         q_data_type=dtype,
         kv_data_type=dtype,
+        window_left=window_left,
         logits_soft_cap=soft_cap)

     output = torch.empty(ref_query.shape, dtype=dtype)
@@ -188,6 +193,7 @@ def test_flashinfer_trtllm_decode_with_baseline(
         max_seq_len=max_seq_len,
         bmm1_scale=q_scale * k_scale * sm_scale,
         bmm2_scale=v_scale / o_scale,
+        window_left=window_left,
         o_sf_scale=o_sf_scale,
         out=output_trtllm,
     )
@@ -222,6 +228,7 @@ def test_flashinfer_trtllm_decode_with_baseline(
 @pytest.mark.parametrize("head_size", HEAD_SIZE)
 @pytest.mark.parametrize("kv_layout", KV_LAYOUT)
 @pytest.mark.parametrize("block_size", BLOCK_SIZE)
+@pytest.mark.parametrize("window_left", WINDOW_LEFT)
 @pytest.mark.parametrize("soft_cap", [None])
 @torch.inference_mode
 def test_flashinfer_trtllm_prefill_with_baseline(
@@ -234,6 +241,7 @@ def test_flashinfer_trtllm_prefill_with_baseline(
     head_size: int,
     kv_layout: str,
     block_size: int,
+    window_left: int,
     soft_cap: Optional[float],
 ) -> None:
     torch.set_default_device("cuda")
@@ -334,6 +342,7 @@ def test_flashinfer_trtllm_prefill_with_baseline(
         sm_scale=sm_scale,
         q_data_type=dtype,
         kv_data_type=dtype,
+        window_left=window_left,
         logits_soft_cap=soft_cap)

     output = torch.empty(ref_query.shape, dtype=dtype)
@@ -371,6 +380,7 @@ def test_flashinfer_trtllm_prefill_with_baseline(
         batch_size=batch_size,
         cum_seq_lens_q=q_indptr,
         cum_seq_lens_kv=kv_indptr,
+        window_left=window_left,
         o_sf_scale=o_sf_scale,
         out=output_trtllm,
     )
@@ -390,6 +400,8 @@ def test_flashinfer_trtllm_prefill_with_baseline(
         rtol, atol = 4e-1, 1e0
     elif q_quant_dtype == FP8_DTYPE and o_quant_dtype == FP8_DTYPE:
         rtol, atol = 5e-2, 7e-2
+    elif q_quant_dtype == FP8_DTYPE and o_quant_dtype == dtype:
+        rtol, atol = 4e-2, 6e-2
     else:
         rtol, atol = 1e-2, 1e-2

vllm/compilation/fusion_attn.py

Lines changed: 4 additions & 2 deletions
@@ -258,8 +258,10 @@ def __init__(self, config: VllmConfig):
             pattern_fp8 = AttentionFp8StaticQuantPattern(layer)
             pattern_fp8.register_if_supported(self.patterns)

-            pattern_nvfp4 = AttentionNvfp4QuantPattern(layer)
-            pattern_nvfp4.register_if_supported(self.patterns)
+            if current_platform.is_cuda() and hasattr(torch.ops._C,
+                                                      "scaled_fp4_quant"):
+                pattern_nvfp4 = AttentionNvfp4QuantPattern(layer)
+                pattern_nvfp4.register_if_supported(self.patterns)

         if len(attn_layers) == 0:
             logger.warning(
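The NVFP4 fusion pattern is now registered only when the platform actually provides the custom scaled_fp4_quant op, so builds without it no longer hit a missing operator at pattern-registration time. A stripped-down sketch of that capability check, using torch.cuda.is_available() as a stand-in for vLLM's current_platform.is_cuda():

import torch

def maybe_register_nvfp4_pattern(register_fn) -> bool:
    """Register an optional fusion pattern only if the platform supports it.

    Sketch assumptions: torch.cuda.is_available() replaces vLLM's
    current_platform.is_cuda(); the hasattr() probe mirrors the check for
    the custom torch.ops._C.scaled_fp4_quant op compiled into CUDA builds.
    """
    if torch.cuda.is_available() and hasattr(torch.ops._C, "scaled_fp4_quant"):
        register_fn()
        return True
    return False

# Usage: maybe_register_nvfp4_pattern(lambda: print("nvfp4 pattern registered"))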

vllm/v1/attention/backends/flashinfer.py

Lines changed: 4 additions & 9 deletions
@@ -194,19 +194,15 @@ def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
         FlashInferBackend.validate_head_size(self.head_dim)
         self.page_size = self.kv_cache_spec.block_size

-        self.enable_fusion = (
-            self.compilation_config.pass_config.enable_attn_fusion)
-        self.q_data_type = self.model_config.dtype
         self.cache_dtype = self.cache_config.cache_dtype
         if self.cache_dtype.startswith("fp8"):
             self.kv_cache_dtype = (
                 FlashInferBackend.get_fp8_dtype_for_flashinfer(
                     self.cache_dtype))
-            # Insert FP8 quant for query if FP8 kv cache and attn fusion enabled
-            if self.enable_fusion:
-                self.q_data_type = self.kv_cache_dtype
         else:
+            assert self.kv_cache_spec.dtype == self.model_config.dtype
             self.kv_cache_dtype = self.kv_cache_spec.dtype
+        self.q_data_type = self.kv_cache_dtype

         self._cascade_wrapper = None  # Wrapper for cascade attention

@@ -668,8 +664,6 @@ def forward(

         # The attn+quant fusion happens when output_scale is provided.
         if output_scale is None:
-            assert attn_metadata.q_data_type != FP8_DTYPE, \
-                "Query can only be FP8 if output fusion happened."
             assert output_block_scale is None, "output_block_scale "\
                 "is not supported when fusion has not happened"
         else:
@@ -697,7 +691,8 @@ def forward(
             elif output.dtype == FP4_DTYPE:
                 self.o_sf_scale = layer._o_scale_float

-            # Insert FP8 quant for query
+        # Insert FP8 quant for query
+        if attn_metadata.q_data_type == FP8_DTYPE:
             num_tokens, num_heads, head_size = query.shape
             query, _ = ops.scaled_fp8_quant(
                 query.reshape(
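With q_data_type now following the KV-cache dtype, the backend quantizes the query to FP8 on the fly whenever the KV cache is FP8, whether or not attention+quant fusion produced an FP8 output. A simplified stand-in for that step is sketched below; it is not vLLM's ops.scaled_fp8_quant, just a per-tensor e4m3 quantization with a given scale, which the caller compensates for via the kernel's bmm1_scale (as in the test above, bmm1_scale = q_scale * k_scale * sm_scale):

import torch

def quantize_query_fp8(query: torch.Tensor, q_scale: float) -> torch.Tensor:
    """Per-tensor static FP8 (e4m3) quantization of the query (sketch only).

    Divide by the precomputed scale, clamp to the e4m3 representable
    range, and cast; the scale is later folded into the attention
    kernel's bmm1_scale so the softmax logits remain correct.
    """
    finfo = torch.finfo(torch.float8_e4m3fn)
    q = (query.float() / q_scale).clamp(finfo.min, finfo.max)
    return q.to(torch.float8_e4m3fn)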
