
Commit 31a4b3e

Revert #24446 and #26168 (#26332)
Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com>
1 parent caf8b1c commit 31a4b3e

5 files changed: +10 −117 lines changed

tests/entrypoints/llm/test_generate.py

Lines changed: 2 additions & 3 deletions

@@ -85,11 +85,10 @@ def test_max_model_len():
         num_total_tokens = len(output.prompt_token_ids) + len(
             output.outputs[0].token_ids
         )
-        # Total tokens must not exceed max_model_len + 1 (the last token can be
-        # generated with the context length equal to the max model length)
+        # Total tokens must not exceed max_model_len.
         # It can be less if generation finishes due to other reasons (e.g., EOS)
         # before reaching the absolute model length limit.
-        assert num_total_tokens <= max_model_len + 1
+        assert num_total_tokens <= max_model_len


 def test_log_stats():
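For context, a minimal sketch of the invariant the restored assertion enforces: after this revert, prompt tokens plus generated tokens should never exceed max_model_len, even when max_tokens would allow more. The model name, prompt, and parameter values below are illustrative assumptions, not taken from the test file.

```python
# Minimal sketch (not the test file itself); model name and values are assumptions.
from vllm import LLM, SamplingParams

max_model_len = 64
llm = LLM(model="facebook/opt-125m", max_model_len=max_model_len)

# Ask for more tokens than the context window allows; ignore_eos forces
# generation to run until the length cap stops it.
params = SamplingParams(max_tokens=max_model_len + 10, ignore_eos=True)
outputs = llm.generate(["Hello, my name is"], params)

for output in outputs:
    num_total_tokens = len(output.prompt_token_ids) + len(output.outputs[0].token_ids)
    # With the reverted behaviour, the sequence is capped at max_model_len (not +1).
    assert num_total_tokens <= max_model_len
```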

tests/v1/e2e/test_context_length.py

Lines changed: 0 additions & 90 deletions
This file was deleted.

vllm/v1/core/sched/scheduler.py

Lines changed: 1 addition & 1 deletion

@@ -223,7 +223,7 @@ def schedule(self) -> SchedulerOutput:
             # Make sure the input position does not exceed the max model len.
             # This is necessary when using spec decoding.
             num_new_tokens = min(
-                num_new_tokens, self.max_model_len - request.num_computed_tokens
+                num_new_tokens, self.max_model_len - 1 - request.num_computed_tokens
             )

             # Schedule encoder inputs.
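To illustrate the clamp restored here, a standalone sketch with made-up numbers (not the scheduler code itself): with the "- 1", input positions stay within the first max_model_len - 1 slots, so the token sampled for the last scheduled position brings the sequence to at most max_model_len tokens.

```python
# Standalone sketch of the restored clamp; variable names mirror the scheduler
# code but the values here are made up for illustration.
max_model_len = 16
num_computed_tokens = 12   # tokens already processed for this request
num_new_tokens = 8         # e.g. proposed decode + speculative draft tokens

# Restored behaviour: cap scheduled tokens so that
# num_computed_tokens + num_new_tokens <= max_model_len - 1; the token sampled
# at the last scheduled position then makes the sequence at most max_model_len
# long, which is exactly when check_stop (>=) finishes the request.
num_new_tokens = min(num_new_tokens, max_model_len - 1 - num_computed_tokens)
print(num_new_tokens)  # 3: only 3 of the 8 proposed tokens are scheduled
```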

vllm/v1/core/sched/utils.py

Lines changed: 1 addition & 1 deletion

@@ -44,7 +44,7 @@ def check_stop(
     request: Request, max_model_len: int, pooler_output: Optional[torch.Tensor] = None
 ) -> bool:
     if (
-        request.num_tokens > max_model_len
+        request.num_tokens >= max_model_len
         or request.num_output_tokens >= request.max_tokens
     ):
         request.status = RequestStatus.FINISHED_LENGTH_CAPPED
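A hedged sketch of how the two restored conditions interact, as a simplified standalone function rather than the vLLM implementation (the real check_stop also handles pooling requests and other stop criteria): a request is length-capped once it holds max_model_len tokens in total or has produced max_tokens outputs.

```python
# Simplified standalone version of the stop check, for illustration only.
def is_length_capped(num_tokens: int, num_output_tokens: int,
                     max_model_len: int, max_tokens: int) -> bool:
    # Restored condition: stop as soon as the total sequence length reaches
    # max_model_len (>=), rather than only after exceeding it (>).
    return num_tokens >= max_model_len or num_output_tokens >= max_tokens

# With max_model_len = 16, a request holding exactly 16 tokens is now finished:
assert is_length_capped(num_tokens=16, num_output_tokens=4,
                        max_model_len=16, max_tokens=100)
```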

vllm/v1/worker/gpu_model_runner.py

Lines changed: 6 additions & 22 deletions

@@ -2317,30 +2317,14 @@ def _bookkeeping_sync(

             start_idx = self.input_batch.num_tokens_no_spec[req_idx]
             end_idx = start_idx + len(sampled_ids)
-            assert end_idx <= self.max_model_len + 1, (
-                "Sampled token IDs exceed the max model length + 1. "
-                f"Total number of tokens: {end_idx} > max_model_len + 1: "
-                f"{self.max_model_len + 1}"
+            assert end_idx <= self.max_model_len, (
+                "Sampled token IDs exceed the max model length. "
+                f"Total number of tokens: {end_idx} > max_model_len: "
+                f"{self.max_model_len}"
             )

-            n_tokens_cache = len(sampled_ids)
-
-            # Sampled token IDs exceed the max model length by 1. This is
-            # legitimate as we can still sample 1 last token when the context
-            # length equals the max model length. Note that we do not need to
-            # cache this token ID as the sequence finishes after this step.
-            # Additionally, the buffers token_ids_cpu and is_token_ids are of
-            # size max model length only.
-            if end_idx == self.max_model_len + 1:
-                n_tokens_cache -= 1
-
-            self.input_batch.token_ids_cpu[
-                req_idx, start_idx : (start_idx + n_tokens_cache)
-            ] = sampled_ids[:n_tokens_cache]
-            self.input_batch.is_token_ids[
-                req_idx, start_idx : (start_idx + n_tokens_cache)
-            ] = True
-
+            self.input_batch.token_ids_cpu[req_idx, start_idx:end_idx] = sampled_ids
+            self.input_batch.is_token_ids[req_idx, start_idx:end_idx] = True
             self.input_batch.num_tokens_no_spec[req_idx] = end_idx
             self.input_batch.num_tokens[req_idx] = end_idx
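A small sketch of what the restored bookkeeping does, using NumPy stand-ins for the input-batch buffers rather than the GpuModelRunner class (the buffer names mirror the runner code, but the setup is made up): since end_idx can no longer exceed max_model_len, every sampled token fits in the fixed-size per-request buffers and can be written back with a single slice assignment, with no special case for an extra token.

```python
# Illustrative sketch with NumPy stand-ins for the input-batch buffers.
import numpy as np

max_model_len = 16
max_num_reqs = 4
token_ids_cpu = np.zeros((max_num_reqs, max_model_len), dtype=np.int64)
is_token_ids = np.zeros((max_num_reqs, max_model_len), dtype=bool)
num_tokens_no_spec = np.array([10, 0, 0, 0])

req_idx = 0
sampled_ids = [101, 102, 103]  # tokens accepted for this request in this step

start_idx = num_tokens_no_spec[req_idx]
end_idx = start_idx + len(sampled_ids)
# With the revert, end_idx is guaranteed to be <= max_model_len, so the
# sampled tokens always fit inside the (max_num_reqs, max_model_len) buffers.
assert end_idx <= max_model_len

token_ids_cpu[req_idx, start_idx:end_idx] = sampled_ids
is_token_ids[req_idx, start_idx:end_idx] = True
num_tokens_no_spec[req_idx] = end_idx
```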
