@@ -1308,7 +1308,7 @@ def _build_attention_metadata(
13081308 use_spec_decode : bool = False ,
13091309 for_cudagraph_capture : bool = False ,
13101310 scheduled_encoder_inputs : dict [str , list [int ]] | None = None ,
1311- common_prefix_lens : list [list [int ]] | None = None ,
1311+ cascade_attn_prefix_lens : list [list [int ]] | None = None ,
13121312 ) -> tuple [PerLayerAttnMetadata , CommonAttentionMetadata | None ]:
13131313 """
13141314 :return: tuple[attn_metadata, spec_decode_common_attn_metadata]
@@ -1414,9 +1414,9 @@ def _build_attention_metadata(
14141414 spec_decode_common_attn_metadata = common_attn_metadata
14151415
14161416 for attn_gid , attn_group in enumerate (self .attn_groups [kv_cache_gid ]):
1417- common_prefix_len = (
1418- common_prefix_lens [kv_cache_gid ][attn_gid ]
1419- if common_prefix_lens
1417+ cascade_attn_prefix_len = (
1418+ cascade_attn_prefix_lens [kv_cache_gid ][attn_gid ]
1419+ if cascade_attn_prefix_lens
14201420 else 0
14211421 )
14221422 builder = attn_group .get_metadata_builder ()
@@ -1444,7 +1444,7 @@ def _build_attention_metadata(
14441444 )
14451445 else :
14461446 attn_metadata_i = builder .build (
1447- common_prefix_len = common_prefix_len ,
1447+ common_prefix_len = cascade_attn_prefix_len ,
14481448 common_attn_metadata = common_attn_metadata ,
14491449 )
14501450 for layer_name in kv_cache_group .layer_names :
@@ -1458,7 +1458,7 @@ def _build_attention_metadata(
14581458 )
14591459 else :
14601460 attn_metadata_i = builder .build (
1461- common_prefix_len = common_prefix_len ,
1461+ common_prefix_len = cascade_attn_prefix_len ,
14621462 common_attn_metadata = common_attn_metadata ,
14631463 ** extra_attn_metadata_args ,
14641464 )
@@ -1473,30 +1473,33 @@ def _compute_cascade_attn_prefix_lens(
14731473 num_common_prefix_blocks : list [int ],
14741474 ) -> list [list [int ]] | None :
14751475 """
1476- :return: Optional[common_prefix_lens ]
1477- common_prefix_lens is 2D: ``[kv_cache_group_id][attn_group_idx]``,
1478- None if we should not use cascade attention
1476+ :return: Optional[cascade_attn_prefix_lens ]
1477+ cascade_attn_prefix_lens is 2D: ``[kv_cache_group_id][attn_group_idx]``,
1478+ None if we should not use cascade attention
14791479 """
14801480
14811481 use_cascade_attn = False
14821482 num_kv_cache_groups = len (self .kv_cache_config .kv_cache_groups )
1483- common_prefix_lens : list [list [int ]] = [[] for _ in range (num_kv_cache_groups )]
1483+ cascade_attn_prefix_lens : list [list [int ]] = [
1484+ [] for _ in range (num_kv_cache_groups )
1485+ ]
14841486
14851487 for kv_cache_gid in range (num_kv_cache_groups ):
14861488 for attn_group in self .attn_groups [kv_cache_gid ]:
14871489 if isinstance (attn_group .kv_cache_spec , EncoderOnlyAttentionSpec ):
1488- prefix_len = 0
1490+ cascade_attn_prefix_len = 0
14891491 else :
1490- prefix_len = self ._compute_cascade_attn_prefix_len (
1492+ # 0 if cascade attention should not be used
1493+ cascade_attn_prefix_len = self ._compute_cascade_attn_prefix_len (
14911494 num_scheduled_tokens ,
14921495 num_common_prefix_blocks [kv_cache_gid ],
14931496 attn_group .kv_cache_spec ,
14941497 attn_group .get_metadata_builder (),
14951498 )
1496- common_prefix_lens [kv_cache_gid ].append (prefix_len )
1499+ cascade_attn_prefix_lens [kv_cache_gid ].append (cascade_attn_prefix_len )
1497- use_cascade_attn |= prefix_len > 0
1500+ use_cascade_attn |= cascade_attn_prefix_len > 0
14981501
1499- return common_prefix_lens if use_cascade_attn else None
1502+ return cascade_attn_prefix_lens if use_cascade_attn else None
15001503
15011504 def _compute_cascade_attn_prefix_len (
15021505 self ,
@@ -2528,12 +2531,12 @@ def execute_model(
25282531 scheduler_output , num_scheduled_tokens_np , max_num_scheduled_tokens
25292532 )
25302533
2531- common_prefix_lens = None
2534+ cascade_attn_prefix_lens = None
25322535 # Disable cascade attention when using microbatching (DBO)
25332536 if self .cascade_attn_enabled and ubatch_slices is None :
25342537 # Pre-compute cascade attention prefix lengths
25352538 # NOTE: Must be AFTER _prepare_inputs uses self.input_batch state
2536- common_prefix_lens = self ._compute_cascade_attn_prefix_lens (
2539+ cascade_attn_prefix_lens = self ._compute_cascade_attn_prefix_lens (
25372540 num_scheduled_tokens_np ,
25382541 scheduler_output .num_common_prefix_blocks ,
25392542 )
@@ -2552,7 +2555,7 @@ def execute_model(
25522555 logits_indices = logits_indices ,
25532556 use_spec_decode = use_spec_decode ,
25542557 scheduled_encoder_inputs = scheduler_output .scheduled_encoder_inputs ,
2555- common_prefix_lens = common_prefix_lens ,
2558+ cascade_attn_prefix_lens = cascade_attn_prefix_lens ,
25562559 )
25572560 )
25582561
@@ -2588,7 +2591,8 @@ def execute_model(
25882591 )
25892592 cudagraph_runtime_mode , batch_descriptor = (
25902593 self .cudagraph_dispatcher .dispatch (
2591- batch_descriptor , use_cascade_attn = common_prefix_lens is not None
2594+ batch_descriptor ,
2595+ use_cascade_attn = cascade_attn_prefix_lens is not None ,
25922596 )
25932597 )
25942598
0 commit comments