@@ -634,25 +634,9 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
             num_computed_tokens = req_data.num_computed_tokens[i]
             new_block_ids = req_data.new_block_ids[i]
             resumed_from_preemption = req_data.resumed_from_preemption[i]
+            num_output_tokens = req_data.num_output_tokens[i]
 
             # Update the cached states.
-            if num_computed_tokens <= req_state.num_computed_tokens:
-                # The request was rescheduled after a KV load failure. Clear
-                # the last sampled tokens and rewind the generator state
-                len_output_token_ids = len(req_state.output_token_ids)
-                del req_state.output_token_ids[req_state.
-                                               len_last_output_token_ids:]
-                if req_state.generator:
-                    req_state.generator.set_offset(
-                        req_state.last_generator_offset)
-                req_index = self.input_batch.req_id_to_index.get(req_id)
-                if req_index is not None:
-                    len_last_sampled = (len_output_token_ids -
-                                        req_state.len_last_output_token_ids)
-                    end_idx = self.input_batch.num_tokens_no_spec[
-                        req_index] - len_last_sampled
-                    self.input_batch.num_tokens[req_index] = end_idx
-                    self.input_batch.num_tokens_no_spec[req_index] = end_idx
 
             req_state.num_computed_tokens = num_computed_tokens
 
@@ -671,12 +655,21 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
             elif num_new_tokens > 0:
                 req_state.output_token_ids.extend(
                     new_token_ids[-num_new_tokens:])
+            elif num_output_tokens < len(req_state.output_token_ids):
+                # Some output tokens were discarded due to a sync-KV-load
+                # failure. Align the cached state.
+                del req_state.output_token_ids[num_output_tokens:]
 
-            req_state.len_last_output_token_ids = len(
-                req_state.output_token_ids)
-            if req_state.generator:
-                req_state.last_generator_offset = (
-                    req_state.generator.get_offset())
+                req_index = self.input_batch.req_id_to_index.get(req_id)
+                if req_index is not None:
+                    old_end_idx = self.input_batch.num_tokens_no_spec[
+                        req_index]
+                    end_idx = self.input_batch.num_prompt_tokens[
+                        req_index] + num_output_tokens
+                    self.input_batch.num_tokens[req_index] = end_idx
+                    self.input_batch.num_tokens_no_spec[req_index] = end_idx
+                    self.input_batch.is_token_ids[req_index,
+                                                  end_idx:old_end_idx] = False
 
             # Update the block IDs.
             if not resumed_from_preemption:
@@ -699,11 +692,6 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
                 reqs_to_add.append(req_state)
                 continue
 
-            if req_state.generator:
-                assert (req_state.last_generator_offset is not None)
-                self.input_batch.generators_last_offset[
-                    req_index] = req_state.last_generator_offset
-
             # Update the persistent batch.
             self.input_batch.num_computed_tokens_cpu[req_index] = (
                 num_computed_tokens)
@@ -2185,8 +2173,7 @@ def _bookkeeping_sync(
         for i in discard_sampled_tokens_req_indices:
             gen = self.input_batch.generators.get(int(i))
             if gen is not None:
-                offset = self.input_batch.generators_last_offset.get(int(i))
-                gen.set_offset(offset)
+                gen.set_offset(gen.get_offset() - 4)
 
         # Copy some objects so they don't get modified after returning.
         # This is important when using async scheduling.