Skip to content

Commit 6d1a1d4

Browse files
review comments
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
1 parent dfc2236 commit 6d1a1d4

File tree

1 file changed

+21

-17

lines changed

vllm/v1/worker/gpu_model_runner.py

Lines changed: 21 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1308,7 +1308,7 @@ def _build_attention_metadata(
13081308
use_spec_decode: bool = False,
13091309
for_cudagraph_capture: bool = False,
13101310
scheduled_encoder_inputs: dict[str, list[int]] | None = None,
1311-
common_prefix_lens: list[list[int]] | None = None,
1311+
cascade_attn_prefix_lens: list[list[int]] | None = None,
13121312
) -> tuple[PerLayerAttnMetadata, CommonAttentionMetadata | None]:
13131313
"""
13141314
:return: tuple[attn_metadata, spec_decode_common_attn_metadata]
@@ -1414,9 +1414,9 @@ def _build_attention_metadata(
14141414
spec_decode_common_attn_metadata = common_attn_metadata
14151415

14161416
for attn_gid, attn_group in enumerate(self.attn_groups[kv_cache_gid]):
1417-
common_prefix_len = (
1418-
common_prefix_lens[kv_cache_gid][attn_gid]
1419-
if common_prefix_lens
1417+
cascade_attn_prefix_len = (
1418+
cascade_attn_prefix_lens[kv_cache_gid][attn_gid]
1419+
if cascade_attn_prefix_lens
14201420
else 0
14211421
)
14221422
builder = attn_group.get_metadata_builder()
@@ -1444,7 +1444,7 @@ def _build_attention_metadata(
14441444
)
14451445
else:
14461446
attn_metadata_i = builder.build(
1447-
common_prefix_len=common_prefix_len,
1447+
common_prefix_len=cascade_attn_prefix_len,
14481448
common_attn_metadata=common_attn_metadata,
14491449
)
14501450
for layer_name in kv_cache_group.layer_names:
@@ -1458,7 +1458,7 @@ def _build_attention_metadata(
14581458
)
14591459
else:
14601460
attn_metadata_i = builder.build(
1461-
common_prefix_len=common_prefix_len,
1461+
common_prefix_len=cascade_attn_prefix_len,
14621462
common_attn_metadata=common_attn_metadata,
14631463
**extra_attn_metadata_args,
14641464
)
@@ -1473,30 +1473,33 @@ def _compute_cascade_attn_prefix_lens(
14731473
num_common_prefix_blocks: list[int],
14741474
) -> list[list[int]] | None:
14751475
"""
1476-
:return: Optional[common_prefix_lens]
1477-
common_prefix_lens is 2D: ``[kv_cache_group_id][attn_group_idx]``,
1478-
None if we should not use cascade attention
1476+
:return: Optional[cascade_attn_prefix_lens]
1477+
cascade_attn_prefix_lens is 2D: ``[kv_cache_group_id][attn_group_idx]``,
1478+
None if we should not use cascade attention
14791479
"""
14801480

14811481
use_cascade_attn = False
14821482
num_kv_cache_groups = len(self.kv_cache_config.kv_cache_groups)
1483-
common_prefix_lens: list[list[int]] = [[] for _ in range(num_kv_cache_groups)]
1483+
cascade_attn_prefix_lens: list[list[int]] = [
1484+
[] for _ in range(num_kv_cache_groups)
1485+
]
14841486

14851487
for kv_cache_gid in range(num_kv_cache_groups):
14861488
for attn_group in self.attn_groups[kv_cache_gid]:
14871489
if isinstance(attn_group.kv_cache_spec, EncoderOnlyAttentionSpec):
14881490
prefix_len = 0
14891491
else:
1490-
prefix_len = self._compute_cascade_attn_prefix_len(
1492+
# 0 if cascade attention should not be used
1493+
cascade_attn_prefix_len = self._compute_cascade_attn_prefix_len(
14911494
num_scheduled_tokens,
14921495
num_common_prefix_blocks[kv_cache_gid],
14931496
attn_group.kv_cache_spec,
14941497
attn_group.get_metadata_builder(),
14951498
)
1496-
common_prefix_lens[kv_cache_gid].append(prefix_len)
1499+
cascade_attn_prefix_lens[kv_cache_gid].append(cascade_attn_prefix_len)
14971500
use_cascade_attn |= prefix_len > 0
14981501

1499-
return common_prefix_lens if use_cascade_attn else None
1502+
return cascade_attn_prefix_lens if use_cascade_attn else None
15001503

15011504
def _compute_cascade_attn_prefix_len(
15021505
self,
@@ -2528,12 +2531,12 @@ def execute_model(
25282531
scheduler_output, num_scheduled_tokens_np, max_num_scheduled_tokens
25292532
)
25302533

2531-
common_prefix_lens = None
2534+
cascade_attn_prefix_lens = None
25322535
# Disable cascade attention when using microbatching (DBO)
25332536
if self.cascade_attn_enabled and ubatch_slices is None:
25342537
# Pre-compute cascade attention prefix lengths
25352538
# NOTE: Must be AFTER _prepare_inputs uses self.input_batch state
2536-
common_prefix_lens = self._compute_cascade_attn_prefix_lens(
2539+
cascade_attn_prefix_lens = self._compute_cascade_attn_prefix_lens(
25372540
num_scheduled_tokens_np,
25382541
scheduler_output.num_common_prefix_blocks,
25392542
)
@@ -2552,7 +2555,7 @@ def execute_model(
25522555
logits_indices=logits_indices,
25532556
use_spec_decode=use_spec_decode,
25542557
scheduled_encoder_inputs=scheduler_output.scheduled_encoder_inputs,
2555-
common_prefix_lens=common_prefix_lens,
2558+
cascade_attn_prefix_lens=cascade_attn_prefix_lens,
25562559
)
25572560
)
25582561

@@ -2588,7 +2591,8 @@ def execute_model(
25882591
)
25892592
cudagraph_runtime_mode, batch_descriptor = (
25902593
self.cudagraph_dispatcher.dispatch(
2591-
batch_descriptor, use_cascade_attn=common_prefix_lens is not None
2594+
batch_descriptor,
2595+
use_cascade_attn=cascade_attn_prefix_lens is not None,
25922596
)
25932597
)
25942598

0 commit comments

Comments (0)