From c20f70befc82751e6e8a0cda6cda9bd8a5a11931 Mon Sep 17 00:00:00 2001 From: Wei Wei Date: Fri, 29 Aug 2025 00:02:18 -0700 Subject: [PATCH] allow calc_kv_scales Summary: 1. The self.calc_kv_scales() should be invoked without checking `attn_metadata` 2. `attn_metadata` is avail only when full graph mode of cudagraph. If user did not use it, there is an error when checking `attn_metadata.enable_kv_scales_calculation` This diff should fix the above problem. But we can not use torch.compile when we set `calculate_kv_scales=True`, it will complain using .item() in `def calc_kv_scales()` Differential Revision: D81300417 --- vllm/attention/layer.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py index 237802afccde..6483b72e4eb9 100644 --- a/vllm/attention/layer.py +++ b/vllm/attention/layer.py @@ -240,9 +240,7 @@ def forward( `vllm.forward_context.get_forward_context().attn_metadata`. """ if self.calculate_kv_scales: - attn_metadata = get_forward_context().attn_metadata - if attn_metadata.enable_kv_scales_calculation: - self.calc_kv_scales(query, key, value) + self.calc_kv_scales(query, key, value) if self.use_output: output_shape = (output_shape if output_shape is not None else query.shape)