 from vllm.sampling_params import SamplingType
 from vllm.sequence import IntermediateTensors
 from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
-                        GiB_bytes, LazyLoader, cdiv, check_use_alibi,
-                        is_pin_memory_available)
+                        GiB_bytes, LazyLoader, async_tensor_h2d, cdiv,
+                        check_use_alibi, is_pin_memory_available)
 from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
 from vllm.v1.attention.backends.utils import CommonAttentionMetadata
 from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
@@ -281,7 +281,7 @@ def __init__(
     def _may_reorder_batch(self, scheduler_output: "SchedulerOutput") -> bool:
         """
         Update the order of requests in the batch based on the attention
-        backend's needs. For example, some attention backends (namely MLA) may
+        backend's needs. For example, some attention backends (namely MLA) may
         want to separate requests based on if the attention computation will be
         compute-bound or memory-bound.
 
@@ -898,6 +898,7 @@ def _calc_spec_decode_metadata(
             target_logits_indices=target_logits_indices,
             bonus_logits_indices=bonus_logits_indices,
             logits_indices=logits_indices,
+            total_num_scheduled_tokens=cu_num_scheduled_tokens[-1],
         )
         return metadata
 
@@ -1360,9 +1361,10 @@ def execute_model(
                                scheduler_output.num_scheduled_tokens[req_id])
                 next_token_id = req_state.get_token_id(seq_len)
             next_token_ids.append(next_token_id)
-            next_token_ids = torch.tensor(next_token_ids,
-                                          dtype=torch.int32,
-                                          device=self.device)
+            next_token_ids = async_tensor_h2d(next_token_ids,
+                                              dtype=torch.int32,
+                                              target_device=self.device,
+                                              pin_memory=True)
             eagle_attn_metadata = attn_metadata[self.drafter.attn_layer_name]
 
             # NOTE: deepseek_mtp uses MLA which does not have `block_table`
@@ -1390,14 +1392,17 @@ def execute_model(
                     n + 1 - len(valid_sampled_token_ids[i]) if n > 0 else 0
                     for i, n in enumerate(num_draft_tokens)
                 ]
-                num_rejected_tokens = torch.tensor(
+                num_rejected_tokens_tensor = async_tensor_h2d(
                     num_rejected_tokens,
                     dtype=torch.int32,
-                    device=self.device,
-                )
+                    target_device=self.device,
+                    pin_memory=True)
+                num_tokens = spec_decode_metadata.total_num_scheduled_tokens - \
+                    sum(num_rejected_tokens)
                 cu_num_tokens, token_indices = self.drafter.prepare_inputs(
                     eagle_attn_metadata.query_start_loc,
-                    num_rejected_tokens,
+                    num_rejected_tokens_tensor,
+                    num_tokens,
                 )
                 target_token_ids = self.input_ids[token_indices]
                 target_positions = positions[token_indices]
@@ -1408,7 +1413,6 @@ def execute_model(
                     target_hidden_states = hidden_states[token_indices]
                 target_slot_mapping = eagle_attn_metadata.slot_mapping[
                     token_indices]
-
             draft_token_ids = self.drafter.propose(
                 target_token_ids=target_token_ids,
                 target_positions=target_positions,
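
For context on the change: the torch.tensor(..., device=self.device) calls above are replaced with vllm.utils.async_tensor_h2d, which stages the data in host memory (pinned when pin_memory=True) and issues a non-blocking copy to the target device. Below is a minimal sketch of such a helper, assuming it mirrors the signature used in this diff; the authoritative implementation is the one imported from vllm.utils.

import torch

def async_tensor_h2d(data: list, dtype: torch.dtype,
                     target_device: torch.device,
                     pin_memory: bool) -> torch.Tensor:
    # Build the tensor in (optionally page-locked) host memory first.
    t = torch.tensor(data, dtype=dtype, pin_memory=pin_memory, device="cpu")
    # non_blocking=True lets the host-to-device copy overlap with other GPU
    # work; the overlap is only effective when the source buffer is pinned.
    return t.to(device=target_device, non_blocking=True)

The new num_tokens argument to prepare_inputs is computed on the host from total_num_scheduled_tokens and sum(num_rejected_tokens), presumably so the drafter no longer needs to read the rejected-token counts back from the device, which would force a synchronization and defeat the asynchronous copy.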