Commit c692506

adabeyta authored and yewentao256 committed

[BugFix][torch.compile] KV scale calculation issues with FP8 quantization (#25513)

Signed-off-by: adabeyta <aabeyta@redhat.com>
Signed-off-by: yewentao256 <zhyanwentao@126.com>

1 parent 9555929

File tree: 3 files changed, +64 -3 lines


tests/compile/test_full_graph.py (15 additions, 0 deletions)

```diff
@@ -139,6 +139,21 @@ def test_custom_compile_config(
     run_model(compilation_config, model, model_kwargs)
 
 
+@pytest.mark.parametrize(
+    "optimization_level",
+    [CompilationLevel.NO_COMPILATION, CompilationLevel.PIECEWISE],
+)
+def test_fp8_kv_scale_compile(optimization_level: int):
+    model = "Qwen/Qwen2-0.5B"
+    model_kwargs = {
+        "quantization": "fp8",
+        "kv_cache_dtype": "fp8_e4m3",
+        "calculate_kv_scales": True,
+        "max_model_len": 512,
+    }
+    run_model(optimization_level, model, model_kwargs)
+
+
 def test_inductor_graph_partition_attn_fusion(caplog_vllm):
     if not is_torch_equal_or_newer("2.9.0.dev"):
         pytest.skip("inductor graph partition is only available "
```
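The new test runs the same FP8-quantized model once with compilation disabled and once with piecewise `torch.compile`, so a regression in the dynamic KV-scale path shows up at either level. A rough standalone equivalent of what the parametrized case exercises, assuming the public `vllm.LLM` entry point rather than the suite's `run_model` helper, might look like:

```python
from vllm import LLM, SamplingParams

# Hypothetical standalone reproduction of the new test case; the real test goes
# through run_model() defined earlier in tests/compile/test_full_graph.py.
llm = LLM(
    model="Qwen/Qwen2-0.5B",
    quantization="fp8",
    kv_cache_dtype="fp8_e4m3",
    calculate_kv_scales=True,   # compute KV scales dynamically on the first pass
    max_model_len=512,
    compilation_config=3,       # CompilationLevel.PIECEWISE; 0 for NO_COMPILATION
)
out = llm.generate(["Hello, my name is"], SamplingParams(max_tokens=8))
print(out[0].outputs[0].text)
```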

vllm/attention/layer.py (40 additions, 3 deletions)

```diff
@@ -277,9 +277,8 @@ def forward(
         `vllm.forward_context.get_forward_context().attn_metadata`.
         """
         if self.calculate_kv_scales:
-            attn_metadata = get_forward_context().attn_metadata
-            if attn_metadata.enable_kv_scales_calculation:
-                self.calc_kv_scales(query, key, value)
+            torch.ops.vllm.maybe_calc_kv_scales(query, key, value,
+                                                self.layer_name)
 
         output_dtype = query.dtype
         if self.query_quant is not None:
@@ -554,6 +553,44 @@ def maybe_save_kv_layer_to_connector(
                                            attn_metadata[layer_name])
 
 
+def maybe_calc_kv_scales(
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    layer_name: str,
+) -> None:
+
+    forward_context: ForwardContext = get_forward_context()
+    attn_metadata = forward_context.attn_metadata
+
+    if isinstance(attn_metadata, dict):
+        attn_metadata = attn_metadata[layer_name]
+
+    if attn_metadata is None or not getattr(
+            attn_metadata, 'enable_kv_scales_calculation', False):
+        return
+
+    self = forward_context.no_compile_layers[layer_name]
+    self.calc_kv_scales(query, key, value)
+
+
+def maybe_calc_kv_scales_fake(
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    layer_name: str,
+) -> None:
+    return
+
+
+direct_register_custom_op(
+    op_name="maybe_calc_kv_scales",
+    op_func=maybe_calc_kv_scales,
+    mutates_args=["query", "key", "value"],
+    fake_impl=maybe_calc_kv_scales_fake,
+)
+
+
 def unified_attention(
     query: torch.Tensor,
     key: torch.Tensor,
```
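Wrapping the scale calculation in a registered custom op with a no-op fake implementation is what makes it `torch.compile`-friendly: the compiler traces a single opaque operator instead of graph-breaking on the data-dependent, state-mutating logic. A minimal generic sketch of the same pattern using the stock `torch.library.custom_op` API (vLLM's `direct_register_custom_op` helper serves the same purpose; the names below are illustrative only, not the project's code):

```python
import torch

# Illustrative only: the eager impl may touch global state (e.g. a forward
# context and per-layer scales), while the fake impl used for tracing is a
# no-op, so the op stays opaque to torch.compile.
@torch.library.custom_op("demo::maybe_calc_kv_scales", mutates_args=())
def maybe_calc_kv_scales_demo(query: torch.Tensor, key: torch.Tensor,
                              value: torch.Tensor, layer_name: str) -> None:
    # A real implementation would look up the layer via the forward context
    # and update its k_scale / v_scale here.
    pass

@maybe_calc_kv_scales_demo.register_fake
def _(query: torch.Tensor, key: torch.Tensor,
      value: torch.Tensor, layer_name: str) -> None:
    return

@torch.compile
def attn_prologue(q, k, v):
    maybe_calc_kv_scales_demo(q, k, v, "layers.0.attn")  # opaque node in the graph
    return q * 1.0

print(attn_prologue(torch.ones(2, 3), torch.ones(2, 3), torch.ones(2, 3)).shape)
```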

vllm/v1/worker/gpu_model_runner.py (9 additions, 0 deletions)

```diff
@@ -2351,6 +2351,15 @@ def execute_model(
                 self.cudagraph_dispatcher.dispatch(batch_descriptor,
                                                    use_cascade_attn)
 
+        # Set cudagraph mode to none if calc_kv_scales is true.
+        if attn_metadata is not None:
+            metadata_list = (attn_metadata.values() if isinstance(
+                attn_metadata, dict) else [attn_metadata])
+            if any(
+                    getattr(m, 'enable_kv_scales_calculation', False)
+                    for m in metadata_list):
+                cudagraph_runtime_mode = CUDAGraphMode.NONE
+
         # This is currently to get around the assert in the DPMetadata
         # where it wants `num_tokens_across_dp` to align with `num_tokens`
         if ubatch_slices is not None:
```
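Forcing `CUDAGraphMode.NONE` on these steps is needed because the KV-scale calculation is host-side, first-pass-only work that a captured CUDA graph cannot replay correctly. A toy illustration of the guard's behavior for both the per-layer (dict) and single-metadata cases, using hypothetical names:

```python
from types import SimpleNamespace

# Toy model of the check added above: fall back to eager execution whenever
# any layer still requests KV-scale calculation.
def needs_eager_step(attn_metadata) -> bool:
    if attn_metadata is None:
        return False
    metadata_list = (attn_metadata.values()
                     if isinstance(attn_metadata, dict) else [attn_metadata])
    return any(getattr(m, "enable_kv_scales_calculation", False)
               for m in metadata_list)

per_layer = {
    "layers.0.attn": SimpleNamespace(enable_kv_scales_calculation=True),
    "layers.1.attn": SimpleNamespace(enable_kv_scales_calculation=False),
}
print(needs_eager_step(per_layer))  # True  -> use CUDAGraphMode.NONE this step
print(needs_eager_step(None))       # False -> keep the dispatched cudagraph mode
```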
