@@ -527,24 +527,27 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
                 self.input_batch.num_tokens[req_index] = end_token_index
         else:
             req_data = scheduler_output.scheduled_cached_reqs
+            is_last_rank = get_pp_group().is_last_rank
             for i, req_id in enumerate(req_data.req_ids):
                 req_state = self.requests[req_id]
                 num_computed_tokens = req_data.num_computed_tokens[i]
-                new_token_ids = req_data.new_token_ids[i]
                 new_block_ids = req_data.new_block_ids[i]
                 resumed_from_preemption = req_data.resumed_from_preemption[i]

                 req_state.num_computed_tokens = num_computed_tokens
-                # Add the sampled token(s) from the previous step (if any).
-                # This doesn't include "unverified" tokens like spec decode tokens.
-                num_new_tokens = (num_computed_tokens + len(new_token_ids) -
-                                  req_state.num_tokens)
-                if num_new_tokens == 1:
-                    # Avoid slicing list in most common case.
-                    req_state.output_token_ids.append(new_token_ids[-1])
-                elif num_new_tokens > 0:
-                    req_state.output_token_ids.extend(
-                        new_token_ids[-num_new_tokens:])
+                if not is_last_rank:
+                    new_token_ids = req_data.new_token_ids[i]
+                    # Add the sampled token(s) from the previous step (if any).
+                    # This doesn't include "unverified" tokens like spec decode tokens.
+                    num_new_tokens = (num_computed_tokens +
+                                      len(new_token_ids) -
+                                      req_state.num_tokens)
+                    if num_new_tokens == 1:
+                        # Avoid slicing list in most common case.
+                        req_state.output_token_ids.append(new_token_ids[-1])
+                    elif num_new_tokens > 0:
+                        req_state.output_token_ids.extend(
+                            new_token_ids[-num_new_tokens:])
                 # Update the block IDs.
                 if not resumed_from_preemption:
                     # Append the new blocks to the existing block IDs.
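For context on the hunk above: the token bookkeeping now gated behind `not is_last_rank` reconciles the scheduler's `new_token_ids` with the tokens the request already holds, so only the genuinely new suffix is appended. A minimal standalone sketch of that reconciliation (illustrative only; the helper name and plain-list state are assumptions, not vLLM APIs):

def append_new_output_tokens(output_token_ids: list, num_tokens: int,
                             num_computed_tokens: int,
                             new_token_ids: list) -> None:
    # Tokens beyond what the request already stores are the new ones.
    num_new_tokens = num_computed_tokens + len(new_token_ids) - num_tokens
    if num_new_tokens == 1:
        # Avoid slicing the list in the most common single-token case.
        output_token_ids.append(new_token_ids[-1])
    elif num_new_tokens > 0:
        output_token_ids.extend(new_token_ids[-num_new_tokens:])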
@@ -570,25 +573,27 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None:

                 self.input_batch.block_table.append_row(
                     new_block_ids, req_index)
-                # Add new_token_ids to token_ids_cpu.
-                start_token_index = num_computed_tokens
-                end_token_index = num_computed_tokens + len(new_token_ids)
-                self.input_batch.token_ids_cpu[
-                    req_index,
-                    start_token_index:end_token_index] = new_token_ids
-                self.input_batch.num_tokens_no_spec[
-                    req_index] = end_token_index
-                # Add spec_token_ids to token_ids_cpu.
-                spec_token_ids = scheduler_output.scheduled_spec_decode_tokens.get(
-                    req_id, ())
-                if spec_token_ids:
-                    start_index = end_token_index
-                    end_token_index += len(spec_token_ids)
+
+                if not is_last_rank:
+                    # Add new_token_ids to token_ids_cpu.
+                    start_token_index = num_computed_tokens
+                    end_token_index = num_computed_tokens + len(new_token_ids)
                     self.input_batch.token_ids_cpu[
                         req_index,
-                        start_index:end_token_index] = spec_token_ids
-                # NOTE(woosuk): `num_tokens` here may include spec decode tokens.
-                self.input_batch.num_tokens[req_index] = end_token_index
+                        start_token_index:end_token_index] = new_token_ids
+                    self.input_batch.num_tokens_no_spec[
+                        req_index] = end_token_index
+                    # Add spec_token_ids to token_ids_cpu.
+                    spec_token_ids = scheduler_output.scheduled_spec_decode_tokens.get(
+                        req_id, ())
+                    if spec_token_ids:
+                        start_index = end_token_index
+                        end_token_index += len(spec_token_ids)
+                        self.input_batch.token_ids_cpu[
+                            req_index,
+                            start_index:end_token_index] = spec_token_ids
+                    # NOTE(woosuk): `num_tokens` here may include spec decode tokens.
+                    self.input_batch.num_tokens[req_index] = end_token_index

         # Check if the batch has changed. If not, we can skip copying the
         # sampling metadata from CPU to GPU.
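As background for the hunk above: the CPU-side buffers track verified and speculative tokens separately. `num_tokens_no_spec` stops at the last verified token, while `num_tokens` may also cover speculative ids written immediately after it. A simplified sketch with assumed toy shapes (not the real `InputBatch` layout):

import numpy as np

token_ids_cpu = np.zeros((4, 32), dtype=np.int64)  # [req_slot, position]
num_tokens = np.zeros(4, dtype=np.int32)           # may include spec tokens
num_tokens_no_spec = np.zeros(4, dtype=np.int32)   # verified tokens only

def write_row(req_index, start, new_token_ids, spec_token_ids=()):
    end = start + len(new_token_ids)
    token_ids_cpu[req_index, start:end] = new_token_ids
    num_tokens_no_spec[req_index] = end
    if spec_token_ids:
        spec_end = end + len(spec_token_ids)
        token_ids_cpu[req_index, end:spec_end] = spec_token_ids
        end = spec_end
    num_tokens[req_index] = end  # counts spec tokens when present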
@@ -1641,6 +1646,30 @@ def execute_model(

         for i in discard_sampled_tokens_req_indices:
             valid_sampled_token_ids[i].clear()
+        if not vllm_version_is("0.9.1"):
+            # Cache the sampled tokens in the model runner, so that the scheduler
+            # doesn't need to send them back.
+            # NOTE(woosuk): As an exception, when using PP, the scheduler sends
+            # the sampled tokens back, because there's no direct communication
+            # between the first-stage worker and the last-stage worker.
+            for req_idx, sampled_ids in enumerate(valid_sampled_token_ids):
+                if not sampled_ids:
+                    continue
+
+                start_idx = self.input_batch.num_tokens_no_spec[req_idx]
+                end_idx = start_idx + len(sampled_ids)
+                assert end_idx <= self.model_config.max_model_len, (
+                    "Sampled token IDs exceed the max model length. "
+                    f"Total number of tokens: {end_idx} > max_model_len: "
+                    f"{self.model_config.max_model_len}")
+
+                self.input_batch.token_ids_cpu[
+                    req_idx, start_idx:end_idx] = sampled_ids
+                self.input_batch.num_tokens_no_spec[req_idx] = end_idx
+                self.input_batch.num_tokens[req_idx] = end_idx
+                req_id = self.input_batch.req_ids[req_idx]
+                req_state = self.requests[req_id]
+                req_state.output_token_ids.extend(sampled_ids)

         spec_token_ids = self._get_spec_token_ids(
             valid_sampled_token_ids,
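The block added above caches sampled tokens locally (post-0.9.1 vLLM), so the scheduler no longer has to echo them back except across pipeline-parallel stage boundaries. A reduced sketch of the write-back loop, with the runner's state replaced by plain Python containers purely for illustration:

def cache_sampled_tokens(valid_sampled_token_ids, token_ids_cpu,
                         num_tokens_no_spec, num_tokens,
                         output_token_ids, max_model_len):
    # valid_sampled_token_ids: per-request lists of newly sampled ids.
    for req_idx, sampled_ids in enumerate(valid_sampled_token_ids):
        if not sampled_ids:
            continue
        start = num_tokens_no_spec[req_idx]
        end = start + len(sampled_ids)
        assert end <= max_model_len, "sampled tokens exceed max_model_len"
        token_ids_cpu[req_idx][start:end] = sampled_ids
        num_tokens_no_spec[req_idx] = end
        num_tokens[req_idx] = end
        output_token_ids[req_idx].extend(sampled_ids)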