     CommonAttentionMetadata,
 )
 from vllm.v1.kv_cache_interface import KVCacheConfig
+from vllm.v1.outputs import TokenIDs, convert_to_token_id_list, get_token_count
 from vllm.v1.sample.metadata import SamplingMetadata
 from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
 from vllm.v1.utils import CpuGpuBuffer
@@ -475,7 +476,7 @@ def propose(
 
     def prepare_next_token_ids_cpu(
         self,
-        sampled_token_ids: list[list[int]],
+        sampled_token_ids: list[TokenIDs],
         requests: dict[str, CachedRequestState],
         gpu_input_batch: InputBatch,
         num_scheduled_tokens: dict[str, int],
@@ -490,6 +491,7 @@ def prepare_next_token_ids_cpu(
         req_ids = gpu_input_batch.req_ids
         next_token_ids: list[int] = []
         for i, token_ids in enumerate(sampled_token_ids):
+            token_ids = convert_to_token_id_list(token_ids)
             if token_ids:
                 # Common case.
                 next_token_id = token_ids[-1]
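
The helpers above are only imported in this diff, so their definitions are not visible here. The following is a minimal sketch of plausible definitions, assuming TokenIDs unions a plain Python list with a torch.Tensor; the actual code in vllm/v1/outputs.py may differ.

    # Hedged sketch only: assumes TokenIDs is list[int] | torch.Tensor.
    # The real definitions live in vllm/v1/outputs.py and may differ.
    from typing import Union

    import torch

    TokenIDs = Union[list[int], torch.Tensor]  # assumed alias


    def convert_to_token_id_list(token_ids: TokenIDs) -> list[int]:
        # Normalize either representation to a plain Python list so that
        # CPU-side code such as prepare_next_token_ids_cpu can index it.
        if isinstance(token_ids, torch.Tensor):
            return token_ids.tolist()
        return token_ids


    def get_token_count(token_ids: TokenIDs) -> int:
        # len() works on both lists and 1-D tensors, so callers can count
        # sampled tokens without materializing a tensor as a list first.
        return len(token_ids)
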
@@ -807,7 +809,7 @@ def propose_tree(
     def prepare_inputs(
         self,
         common_attn_metadata: CommonAttentionMetadata,
-        sampled_token_ids: list[list[int]],
+        sampled_token_ids: list[TokenIDs],
         num_draft_tokens: list[int],
     ) -> tuple[CommonAttentionMetadata, torch.Tensor]:
         """
@@ -833,7 +835,7 @@ def prepare_inputs(
         # q1 + q2, q1 + q2 + 1, ..., q1 + q2 + q3 - n3 - 1]
 
         num_rejected_tokens = [
-            n + 1 - len(sampled_token_ids[i]) if n > 0 else 0
+            n + 1 - get_token_count(sampled_token_ids[i]) if n > 0 else 0
             for i, n in enumerate(num_draft_tokens)
         ]
         num_rejected_tokens = torch.tensor(num_rejected_tokens, dtype=torch.int32)
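
As a sanity check on the arithmetic in that hunk: with n draft tokens a request can yield at most n + 1 sampled tokens (the accepted drafts plus one bonus token), so n + 1 minus the sampled count is the number of rejected tokens. A small illustrative example (values invented, not from a real run):

    # Worked example of the num_rejected_tokens computation above.
    num_draft_tokens = [3, 0, 2]
    # Request 0: 3 drafts, 2 tokens sampled -> 3 + 1 - 2 = 2 rejected.
    # Request 1: no drafts (n == 0)         -> 0 rejected by definition.
    # Request 2: 2 drafts, all accepted plus bonus -> 2 + 1 - 3 = 0 rejected.
    sampled_counts = [2, 1, 3]
    num_rejected = [
        n + 1 - sampled_counts[i] if n > 0 else 0
        for i, n in enumerate(num_draft_tokens)
    ]
    assert num_rejected == [2, 0, 0]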