@@ -34,8 +34,8 @@
 from vllm.sampling_params import SamplingType
 from vllm.sequence import IntermediateTensors
 from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
-                        GiB_bytes, LazyLoader, cdiv, check_use_alibi,
-                        is_pin_memory_available)
+                        GiB_bytes, LazyLoader, async_tensor_h2d, cdiv,
+                        check_use_alibi, is_pin_memory_available)
 from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
 from vllm.v1.attention.backends.utils import CommonAttentionMetadata
 from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
@@ -281,7 +281,7 @@ def __init__(
     def _may_reorder_batch(self, scheduler_output: "SchedulerOutput") -> bool:
         """
         Update the order of requests in the batch based on the attention
-        backend's needs. For example, some attention backends (namely MLA) may
+        backend's needs. For example, some attention backends (namely MLA) may
         want to separate requests based on if the attention computation will be
         compute-bound or memory-bound.
 
@@ -1360,9 +1360,10 @@ def execute_model(
                                scheduler_output.num_scheduled_tokens[req_id])
                     next_token_id = req_state.get_token_id(seq_len)
                 next_token_ids.append(next_token_id)
-            next_token_ids = torch.tensor(next_token_ids,
-                                          dtype=torch.int32,
-                                          device=self.device)
+            next_token_ids = async_tensor_h2d(next_token_ids,
+                                              dtype=torch.int32,
+                                              target_device=self.device,
+                                              pin_memory=True)
             eagle_attn_metadata = attn_metadata[self.drafter.attn_layer_name]
 
             # NOTE: deepseek_mtp uses MLA which does not have `block_table`
@@ -1390,14 +1391,16 @@ def execute_model(
                     n + 1 - len(valid_sampled_token_ids[i]) if n > 0 else 0
                     for i, n in enumerate(num_draft_tokens)
                 ]
-                num_rejected_tokens = torch.tensor(
+                num_rejected_tokens_tensor = async_tensor_h2d(
                     num_rejected_tokens,
                     dtype=torch.int32,
-                    device=self.device,
-                )
+                    target_device=self.device,
+                    pin_memory=True)
+                num_tokens = num_scheduled_tokens - sum(num_rejected_tokens)
                 cu_num_tokens, token_indices = self.drafter.prepare_inputs(
                     eagle_attn_metadata.query_start_loc,
-                    num_rejected_tokens,
+                    num_rejected_tokens_tensor,
+                    num_tokens,
                 )
                 target_token_ids = self.input_ids[token_indices]
                 target_positions = positions[token_indices]
@@ -1408,7 +1411,6 @@ def execute_model(
                 target_hidden_states = hidden_states[token_indices]
                 target_slot_mapping = eagle_attn_metadata.slot_mapping[
                     token_indices]
-
                 draft_token_ids = self.drafter.propose(
                     target_token_ids=target_token_ids,
                     target_positions=target_positions,
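
Note on the change: both call sites above swap a blocking `torch.tensor(..., device=self.device)` construction for vLLM's `async_tensor_h2d` utility, which stages the Python list in pinned host memory and issues a non-blocking host-to-device copy. Below is a minimal sketch of that pattern, assuming roughly the shape of the helper in `vllm/utils.py`; the function name is local to this example.

```python
import torch

def async_h2d_sketch(data: list[int], dtype: torch.dtype,
                     target_device: torch.device,
                     pin_memory: bool) -> torch.Tensor:
    # Stage the Python list in host memory; pinned (page-locked) memory
    # is required for the copy below to be truly asynchronous.
    t = torch.tensor(data, dtype=dtype, pin_memory=pin_memory, device="cpu")
    # non_blocking=True lets the CPU keep queuing work while the DMA
    # transfer runs; from pageable memory this silently degrades to a
    # synchronous copy.
    return t.to(device=target_device, non_blocking=True)
```

Relatedly, `num_tokens` is now computed on the host from the Python list and passed to `prepare_inputs` alongside the device tensor, presumably so the drafter no longer needs to reduce `num_rejected_tokens_tensor` on device, which would force a synchronization right after the asynchronous copy.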