@@ -412,7 +412,7 @@ def build(
412412 num_decode_tokens = num_decode_tokens ,
413413 slot_mapping = slot_mapping ,
414414 # to ensure inference when chunked_prefill is disabled
415- seq_lens = seq_lens_cpu .tolist (),
415+ seq_lens = seq_lens_cpu .tolist ()[ num_decodes :], # prefill
416416 decode_seq_lens_tensor = seq_lens_cpu [:num_decodes ], # decode
417417 decode_max_seq_len = max_decode_seq_len , # decode
418418 decode_block_tables = block_table_tensor [:num_decodes ], # decode
@@ -617,7 +617,6 @@ def forward(
617617 prefill_meta .prefill_block_tables ,
618618 self .alibi_slopes ,
619619 )
620-
621620 if decode_meta := attn_metadata .decode_metadata :
622621 assert attn_type != AttentionType .ENCODER_ONLY , (
623622 "Encoder-only models should not have decode metadata."
@@ -686,7 +685,12 @@ def _run_sdpa_forward(
686685 causal_attn = attn_type == AttentionType .DECODER
687686
688687 seq_lens_q , seq_lens_kv = attn_metadata .get_seq_lens (attn_type )
689- start_q , start_kv = 0 , 0
688+ # The incoming Q and KV also contain the decode tokens, and decode requests
689+ # appear first, so the prefill sequences start at offset num_decode_tokens.
690+ start_q , start_kv = (
691+ attn_metadata .num_decode_tokens ,
692+ attn_metadata .num_decode_tokens ,
693+ )
690694 for seq_len_q , seq_len_kv , mask in zip (seq_lens_q , seq_lens_kv , attn_masks ):
691695 end_q = start_q + seq_len_q
692696 end_kv = start_kv + seq_len_kv
0 commit comments