@@ -775,10 +775,10 @@ def execute_model(
775775 sampling_metadata = sampling_metadata ,
776776 )
777777
778- sampled_token_ids = sampler_output .sampled_token_ids
779778 # TODO(woosuk): The following loop can be slow since it iterates over
780779 # the requests one by one. Optimize.
781780 num_reqs = self .input_batch .num_reqs
781+ request_seq_lens : List [Tuple [int , CachedRequestState , int ]] = []
782782 for i , req_id in enumerate (self .input_batch .req_ids [:num_reqs ]):
783783 assert req_id is not None
784784 req_state = self .requests [req_id ]
@@ -787,10 +787,10 @@ def execute_model(
787787 assert seq_len <= req_state .num_tokens
788788 if seq_len == req_state .num_tokens :
789789 # Append the sampled token to the output token ids.
790- token_id = sampled_token_ids [i ]
791- self .input_batch .token_ids_cpu [i , seq_len ] = token_id
792790 self .input_batch .num_tokens [i ] += 1
793- req_state .output_token_ids .append (token_id )
791+ # OPTIMIZATION: Append a placeholder now; the real token id overwrites it after the GPU -> CPU sync below.
792+ req_state .output_token_ids .append (0 )
793+ request_seq_lens .append ((i , req_state , seq_len ))
794794 else :
795795 # Ignore the sampled token from the partial request.
796796 # Rewind the generator state as if the token was not sampled.
@@ -799,6 +799,21 @@ def execute_model(
799799 # This relies on cuda-specific torch-internal impl details
800800 generator .set_offset (generator .get_offset () - 4 )
801801
802+ # num_reqs entries should be non-None
803+ assert all (
804+ req_id is not None for req_id in
805+ self .input_batch .req_ids [:num_reqs ]), "req_ids contains None"
806+ req_ids = cast (List [str ], self .input_batch .req_ids [:num_reqs ])
807+
808+ # NOTE: GPU -> CPU Sync happens here.
809+ # Move as many CPU operations as possible before this sync point.
810+ sampled_token_ids = sampler_output .sampled_token_ids .tolist ()
811+ # Update with the actual token ids
812+ for i , req_state , seq_len in request_seq_lens :
813+ token_id = sampled_token_ids [i ]
814+ self .input_batch .token_ids_cpu [i , seq_len ] = token_id
815+ req_state .output_token_ids [- 1 ] = token_id
816+
802817 if sampler_output .logprob_token_ids is None :
803818 logprob_token_ids = None
804819 else :
@@ -808,12 +823,6 @@ def execute_model(
808823 else :
809824 logprobs = sampler_output .logprobs .cpu ()
810825
811- # num_reqs entries should be non-None
812- assert all (
813- req_id is not None for req_id in
814- self .input_batch .req_ids [:num_reqs ]), "req_ids contains None"
815- req_ids = cast (List [str ], self .input_batch .req_ids [:num_reqs ])
816-
817826 model_runner_output = ModelRunnerOutput (
818827 req_ids = req_ids ,
819828 req_id_to_index = self .input_batch .req_id_to_index ,
0 commit comments