@@ -2367,7 +2367,7 @@ def propose_draft_token_ids(
         sampling_metadata: SamplingMetadata,
         hidden_states: torch.Tensor,
         sample_hidden_states: torch.Tensor,
-        aux_hidden_states: Optional[torch.Tensor],
+        aux_hidden_states: Optional[list[torch.Tensor]],
         spec_decode_metadata: Optional[SpecDecodeMetadata],
         common_attn_metadata: CommonAttentionMetadata,
     ) -> Union[list[list[int]], torch.Tensor]:
@@ -2387,6 +2387,7 @@ def propose_draft_token_ids(
         else:
             indices = []
             offset = 0
+            assert spec_decode_metadata is not None
             for num_draft, tokens in zip(
                     spec_decode_metadata.num_draft_tokens,
                     sampled_token_ids):
@@ -2437,6 +2438,7 @@ def propose_draft_token_ids(
             # TODO(woosuk): Support M-RoPE.
             target_positions = self.positions.gpu[:num_scheduled_tokens]
             if self.use_aux_hidden_state_outputs:
+                assert aux_hidden_states is not None
                 target_hidden_states = torch.cat(
                     [h[:num_scheduled_tokens] for h in aux_hidden_states],
                     dim=-1)
@@ -2462,6 +2464,7 @@ def propose_draft_token_ids(
             # TODO(woosuk): Support M-RoPE.
             target_positions = self.positions.gpu[token_indices]
             if self.use_aux_hidden_state_outputs:
+                assert aux_hidden_states is not None
                 target_hidden_states = torch.cat(
                     [h[token_indices] for h in aux_hidden_states], dim=-1)
             else:
@@ -2897,7 +2900,9 @@ def _dummy_run(
         assert not create_mixed_batch
         num_reqs = cdiv(num_tokens, max_query_len)
         assert num_reqs <= max_num_reqs, \
-            "Do not capture num_reqs > max_num_reqs for uniform batch"
+            f"Do not capture num_reqs {num_reqs} > max_num_reqs " \
+            f"{max_num_reqs} for uniform batch. Num tokens: " \
+            f"{num_tokens}, max_query_len: {max_query_len}"
         num_scheduled_tokens_list = [max_query_len] * num_reqs
         if num_tokens % max_query_len != 0:
             num_scheduled_tokens_list[-1] = num_tokens % max_query_len
0 commit comments