
Commit 46dcace

Author: David Ben-David (committed)
Fix rollback of invalid output tokens and generator state
Signed-off-by: David Ben-David <davidb@pliops.com>
1 parent 3d41b47 commit 46dcace

5 files changed: +22 −38 lines changed


tests/v1/worker/test_gpu_model_runner.py

Lines changed: 1 addition & 0 deletions
@@ -250,6 +250,7 @@ def test_update_states_request_resumed(model_runner, dist_init):
         new_token_ids=[[]],
         new_block_ids=([[0]], ),
         num_computed_tokens=[0],
+        num_output_tokens=[0],
     )

     scheduler_output = SchedulerOutput(

vllm/v1/core/sched/output.py

Lines changed: 2 additions & 0 deletions
@@ -101,6 +101,7 @@ class CachedRequestData:
     new_token_ids: list[list[int]]
     new_block_ids: list[Optional[tuple[list[int], ...]]]
     num_computed_tokens: list[int]
+    num_output_tokens: list[int]

     @property
     def num_reqs(self) -> int:
@@ -114,6 +115,7 @@ def make_empty(cls) -> CachedRequestData:
             new_token_ids=[],
             new_block_ids=[],
             num_computed_tokens=[],
+            num_output_tokens=[],
         )
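For orientation, a brief sketch of how the parallel lists in CachedRequestData line up per request. The field names num_computed_tokens and num_output_tokens come from this diff; req_data appears in the gpu_model_runner.py diff below, while req_ids and the loop itself are illustrative assumptions, not part of the commit:

# Illustrative only: every list in CachedRequestData is parallel to the
# request ids, so entry i of num_output_tokens belongs to the i-th request.
for i, req_id in enumerate(req_data.req_ids):
    num_computed_tokens = req_data.num_computed_tokens[i]
    num_output_tokens = req_data.num_output_tokens[i]  # new: output length seen by the scheduler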

vllm/v1/core/sched/scheduler.py

Lines changed: 3 additions & 0 deletions
@@ -672,6 +672,7 @@ def _make_cached_request_data(
         new_token_ids: list[list[int]] = []
         new_block_ids: list[Optional[tuple[list[int], ...]]] = []
         num_computed_tokens: list[int] = []
+        num_output_tokens: list[int] = []

         use_connector = self.connector is not None
         for req in itertools.chain(running_reqs, resumed_reqs):
@@ -696,6 +697,7 @@ def _make_cached_request_data(
             new_block_ids.append(
                 req_to_new_blocks[req_id].get_block_ids(allow_none=True))
             num_computed_tokens.append(req.num_computed_tokens)
+            num_output_tokens.append(len(req.output_token_ids))
         # Because resumed_reqs is usually empty, it is more efficient to do
         # in-place appending so that we don't need to allocate a new list.
         resumed_from_preemption = [False] * len(running_reqs)
@@ -707,6 +709,7 @@ def _make_cached_request_data(
             new_token_ids=new_token_ids,
             new_block_ids=new_block_ids,
             num_computed_tokens=num_computed_tokens,
+            num_output_tokens=num_output_tokens,
         )

     def _try_schedule_encoder_inputs(
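A minimal sketch of the scheduler-side idea, assuming all it needs is the length of each request's output_token_ids at scheduling time; the helper name below is hypothetical and not part of vLLM:

def snapshot_output_lengths(reqs) -> list[int]:
    # One entry per scheduled request, parallel to the other CachedRequestData
    # lists. The worker later compares this snapshot against its cached
    # output_token_ids to detect tokens that were invalidated (e.g. by a
    # failed KV load) and must be rolled back.
    return [len(req.output_token_ids) for req in reqs]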

vllm/v1/worker/gpu_input_batch.py

Lines changed: 0 additions & 9 deletions
@@ -48,11 +48,6 @@ class CachedRequestState:
     def __post_init__(self):
         self.num_prompt_tokens = length_from_prompt_token_ids_or_embeds(
             self.prompt_token_ids, self.prompt_embeds)
-        # 'last_generator_offset' and 'len_last_output_token_ids' are used to
-        # allow safe rollback in case a sampled token turns out to be invalid
-        # (e.g., due to KV load errors).
-        self.last_generator_offset = 0 if self.generator else None
-        self.len_last_output_token_ids = len(self.output_token_ids)

     @property
     def num_tokens(self) -> int:
@@ -242,7 +237,6 @@ def __init__(
         # NOTE(woosuk): The indices of the requests that do not have their own
         # generator should not be included in the dictionary.
         self.generators: dict[int, torch.Generator] = {}
-        self.generators_last_offset: dict[int, int] = {}

         self.num_logprobs: dict[str, int] = {}
         # NOTE(rob): num_prompt_logprobs only includes reqs
@@ -393,9 +387,6 @@ def add_request(
         # do not have their own generator.
         if request.generator is not None:
             self.generators[req_index] = request.generator
-            assert (request.last_generator_offset is not None)
-            self.generators_last_offset[
-                req_index] = request.last_generator_offset

         if sampling_params.logprobs is not None:
             self.num_logprobs[req_id] = (self.vocab_size

vllm/v1/worker/gpu_model_runner.py

Lines changed: 16 additions & 29 deletions
@@ -634,25 +634,9 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
             num_computed_tokens = req_data.num_computed_tokens[i]
             new_block_ids = req_data.new_block_ids[i]
             resumed_from_preemption = req_data.resumed_from_preemption[i]
+            num_output_tokens = req_data.num_output_tokens[i]

             # Update the cached states.
-            if num_computed_tokens <= req_state.num_computed_tokens:
-                # The request was rescheduled after a KV load failure. Clear
-                # the last sampled tokens and rewind the generator state
-                len_output_token_ids = len(req_state.output_token_ids)
-                del req_state.output_token_ids[req_state.
-                                               len_last_output_token_ids:]
-                if req_state.generator:
-                    req_state.generator.set_offset(
-                        req_state.last_generator_offset)
-                req_index = self.input_batch.req_id_to_index.get(req_id)
-                if req_index is not None:
-                    len_last_sampled = (len_output_token_ids -
-                                        req_state.len_last_output_token_ids)
-                    end_idx = self.input_batch.num_tokens_no_spec[
-                        req_index] - len_last_sampled
-                    self.input_batch.num_tokens[req_index] = end_idx
-                    self.input_batch.num_tokens_no_spec[req_index] = end_idx

             req_state.num_computed_tokens = num_computed_tokens

@@ -671,12 +655,21 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
             elif num_new_tokens > 0:
                 req_state.output_token_ids.extend(
                     new_token_ids[-num_new_tokens:])
+            elif num_output_tokens < len(req_state.output_token_ids):
+                # Some output tokens were discarded due to a sync-KV-load
+                # failure. Align the cached state.
+                del req_state.output_token_ids[num_output_tokens:]

-            req_state.len_last_output_token_ids = len(
-                req_state.output_token_ids)
-            if req_state.generator:
-                req_state.last_generator_offset = (
-                    req_state.generator.get_offset())
+                req_index = self.input_batch.req_id_to_index.get(req_id)
+                if req_index is not None:
+                    old_end_idx = self.input_batch.num_tokens_no_spec[
+                        req_index]
+                    end_idx = self.input_batch.num_prompt_tokens[
+                        req_index] + num_output_tokens
+                    self.input_batch.num_tokens[req_index] = end_idx
+                    self.input_batch.num_tokens_no_spec[req_index] = end_idx
+                    self.input_batch.is_token_ids[req_index,
+                                                  end_idx:old_end_idx] = False

             # Update the block IDs.
             if not resumed_from_preemption:
@@ -699,11 +692,6 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
                 reqs_to_add.append(req_state)
                 continue

-            if req_state.generator:
-                assert (req_state.last_generator_offset is not None)
-                self.input_batch.generators_last_offset[
-                    req_index] = req_state.last_generator_offset
-
             # Update the persistent batch.
             self.input_batch.num_computed_tokens_cpu[req_index] = (
                 num_computed_tokens)
@@ -2185,8 +2173,7 @@ def _bookkeeping_sync(
         for i in discard_sampled_tokens_req_indices:
             gen = self.input_batch.generators.get(int(i))
             if gen is not None:
-                offset = self.input_batch.generators_last_offset.get(int(i))
-                gen.set_offset(offset)
+                gen.set_offset(gen.get_offset() - 4)

         # Copy some objects so they don't get modified after returning.
         # This is important when using async scheduling.
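Taken together, the worker-side rollback reduces to two steps: trim the cached output tokens back to the scheduler's num_output_tokens snapshot, and rewind the per-request RNG. A minimal sketch, assuming a torch.Generator whose offset advances by 4 per sampled token (the constant behind the gen.get_offset() - 4 rewind above); the helper and its signature are illustrative, not vLLM's API:

from typing import Optional

import torch

PHILOX_OFFSET_PER_SAMPLE = 4  # assumption: one sampled token advances the offset by 4


def rollback_one_token(output_token_ids: list[int],
                       num_output_tokens: int,
                       num_prompt_tokens: int,
                       generator: Optional[torch.Generator] = None) -> int:
    """Trim stale output tokens and return the new end of the valid token range."""
    # Drop the output tokens the scheduler no longer counts
    # (e.g. the token sampled just before a KV-load failure).
    del output_token_ids[num_output_tokens:]
    # Valid tokens for this request now end at prompt length + surviving outputs.
    end_idx = num_prompt_tokens + num_output_tokens
    if generator is not None:
        # Undo the RNG advance of the single discarded sampled token, mirroring
        # gen.set_offset(gen.get_offset() - 4) in _bookkeeping_sync.
        generator.set_offset(generator.get_offset() - PHILOX_OFFSET_PER_SAMPLE)
    return end_idx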
