diff --git a/tests/v1/test_outputs.py b/tests/v1/test_outputs.py
new file mode 100644
index 000000000000..bbe94cf58134
--- /dev/null
+++ b/tests/v1/test_outputs.py
@@ -0,0 +1,39 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+
+from vllm.v1.outputs import (
+    get_token_count,
+    list_to_token_ids,
+    token_ids_to_list,
+)
+
+
+def test_token_ids_to_list():
+    assert token_ids_to_list(1) == [1]
+    assert token_ids_to_list(2) == [2]
+
+    # Return the original list without creating a new one
+    vs = [3, 4]
+    assert token_ids_to_list(vs) is vs
+    vs = [5, 6, 7]
+    assert token_ids_to_list(vs) is vs
+
+
+def test_list_to_token_ids():
+    assert list_to_token_ids([10]) == 10
+    assert list_to_token_ids([20]) == 20
+
+    # Return the original list without creating a new one
+    vs = [30, 40]
+    assert list_to_token_ids(vs) is vs
+    vs = [50, 60, 70]
+    assert list_to_token_ids(vs) is vs
+
+
+def test_get_token_count():
+    assert get_token_count(100) == 1
+    assert get_token_count(200) == 1
+    assert get_token_count([300]) == 1
+    assert get_token_count([400, 500]) == 2
+    assert get_token_count([600, 700, 800]) == 3
diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py
index cbbdf48c6e0c..adba6e5747ca 100644
--- a/vllm/v1/core/sched/scheduler.py
+++ b/vllm/v1/core/sched/scheduler.py
@@ -29,7 +29,13 @@
 from vllm.v1.engine import EngineCoreEventType, EngineCoreOutput, EngineCoreOutputs
 from vllm.v1.kv_cache_interface import KVCacheConfig
 from vllm.v1.metrics.stats import SchedulerStats
-from vllm.v1.outputs import DraftTokenIds, KVConnectorOutput, ModelRunnerOutput
+from vllm.v1.outputs import (
+    DraftTokenIds,
+    KVConnectorOutput,
+    ModelRunnerOutput,
+    list_to_token_ids,
+    token_ids_to_list,
+)
 from vllm.v1.request import Request, RequestStatus
 from vllm.v1.spec_decode.metrics import SpecDecodingStats
 from vllm.v1.structured_output import StructuredOutputManager
@@ -952,7 +958,7 @@ def update_from_output(
                 continue
 
             req_index = model_runner_output.req_id_to_index[req_id]
-            generated_token_ids = (
+            generated_token_ids: list[int] = token_ids_to_list(
                 sampled_token_ids[req_index] if sampled_token_ids else []
             )
 
@@ -1028,7 +1034,7 @@ def update_from_output(
                 outputs[request.client_index].append(
                     EngineCoreOutput(
                         request_id=req_id,
-                        new_token_ids=new_token_ids,
+                        new_token_ids=list_to_token_ids(new_token_ids),
                         finish_reason=request.get_finished_reason(),
                         new_logprobs=new_logprobs,
                         new_prompt_logprobs_tensors=prompt_logprobs_tensors,
diff --git a/vllm/v1/engine/__init__.py b/vllm/v1/engine/__init__.py
index e2c1ed7b561c..325dde03ff2c 100644
--- a/vllm/v1/engine/__init__.py
+++ b/vllm/v1/engine/__init__.py
@@ -14,7 +14,7 @@
 from vllm.pooling_params import PoolingParams
 from vllm.sampling_params import SamplingParams
 from vllm.v1.metrics.stats import SchedulerStats
-from vllm.v1.outputs import LogprobsLists, LogprobsTensors
+from vllm.v1.outputs import LogprobsLists, LogprobsTensors, TokenIDs
 
 # These are possible values of RequestOutput.finish_reason,
 # so form part of the external API.
@@ -106,7 +106,7 @@ class EngineCoreOutput(
     gc=False,
 ):  # type: ignore[call-arg]
     request_id: str
-    new_token_ids: list[int]
+    new_token_ids: TokenIDs
 
     new_logprobs: LogprobsLists | None = None
     new_prompt_logprobs_tensors: LogprobsTensors | None = None
diff --git a/vllm/v1/engine/output_processor.py b/vllm/v1/engine/output_processor.py
index 2bc1542187c9..40ccdca8d58c 100644
--- a/vllm/v1/engine/output_processor.py
+++ b/vllm/v1/engine/output_processor.py
@@ -23,6 +23,7 @@
 from vllm.v1.engine.logprobs import LogprobsProcessor
 from vllm.v1.engine.parallel_sampling import ParentRequest
 from vllm.v1.metrics.stats import IterationStats, LoRARequestStates, RequestStateStats
+from vllm.v1.outputs import token_ids_to_list
 
 
 class RequestOutputCollector:
@@ -421,7 +422,9 @@ def process_outputs(
                 req_state, engine_core_output, engine_core_timestamp, iteration_stats
             )
 
-            new_token_ids = engine_core_output.new_token_ids
+            new_token_ids: list[int] = token_ids_to_list(
+                engine_core_output.new_token_ids
+            )
             pooling_output = engine_core_output.pooling_output
             finish_reason = engine_core_output.finish_reason
             stop_reason = engine_core_output.stop_reason
diff --git a/vllm/v1/metrics/stats.py b/vllm/v1/metrics/stats.py
index a4a8ab32ad72..0f365d72c5ef 100644
--- a/vllm/v1/metrics/stats.py
+++ b/vllm/v1/metrics/stats.py
@@ -6,6 +6,7 @@
 from dataclasses import dataclass, field
 from typing import TYPE_CHECKING, Any
 
+from vllm.v1.outputs import get_token_count
 from vllm.v1.spec_decode.metrics import SpecDecodingStats
 
 if TYPE_CHECKING:
@@ -232,7 +233,7 @@ def update_from_output(
         req_stats: RequestStateStats,
         lora_stats: LoRAStats | None,
     ):
-        num_new_generation_tokens = len(output.new_token_ids)
+        num_new_generation_tokens = get_token_count(output.new_token_ids)
         self.num_generation_tokens += num_new_generation_tokens
 
         if is_prefilling:
diff --git a/vllm/v1/outputs.py b/vllm/v1/outputs.py
index c224555da6ca..1057bd929f68 100644
--- a/vllm/v1/outputs.py
+++ b/vllm/v1/outputs.py
@@ -3,7 +3,7 @@
 
 from abc import ABC, abstractmethod
 from dataclasses import dataclass, field
-from typing import TYPE_CHECKING, NamedTuple
+from typing import TYPE_CHECKING, NamedTuple, TypeAlias
 
 import torch
 
@@ -12,6 +12,39 @@
 else:
     KVConnectorStats = object
 
+# Alias type for token ID(s).
+# 1) int: a single token (most requests emit one output token per batch)
+# 2) list[int]: multiple tokens (for spec decoding, requests might emit
+#    more than one token per batch)
+#
+# It's introduced to mitigate GC costs in large-batch scenarios.
+# Without the alias, output tokens are always stored as list[int], so at
+# large batch sizes each ModelRunnerOutput triggers a generation-0
+# GC collection and ultimately raises the collection frequency across
+# all generations.
+TokenIDs: TypeAlias = int | list[int]
+
+
+def token_ids_to_list(tokens: TokenIDs) -> list[int]:
+    """
+    Converts TokenIDs to a list of token IDs (i.e. list[int]).
+    """
+    return [tokens] if isinstance(tokens, int) else tokens
+
+
+def list_to_token_ids(tokens: list[int]) -> TokenIDs:
+    """
+    Converts a list of token IDs (i.e. list[int]) to TokenIDs.
+    """
+    return tokens[0] if len(tokens) == 1 else tokens
+
+
+def get_token_count(tokens: TokenIDs) -> int:
+    """
+    Returns the number of tokens in TokenIDs.
+    """
+    return 1 if isinstance(tokens, int) else len(tokens)
+
 
 class LogprobsLists(NamedTuple):
     # [num_reqs, max_num_logprobs + 1]
@@ -111,7 +144,7 @@ class ModelRunnerOutput:
     # num_generated_tokens is the number of tokens
     # generated in the current step. It can be different for
     # each request due to speculative/jump decoding.
-    sampled_token_ids: list[list[int]]
+    sampled_token_ids: list[TokenIDs]
 
     # [num_reqs, max_num_logprobs + 1]
     # [num_reqs, max_num_logprobs + 1]
diff --git a/vllm/v1/sample/rejection_sampler.py b/vllm/v1/sample/rejection_sampler.py
index f5b075e83b84..d6968a40afee 100644
--- a/vllm/v1/sample/rejection_sampler.py
+++ b/vllm/v1/sample/rejection_sampler.py
@@ -6,6 +6,7 @@
 
 from vllm.logger import init_logger
 from vllm.triton_utils import tl, triton
+from vllm.v1.outputs import TokenIDs, list_to_token_ids
 from vllm.v1.sample.metadata import SamplingMetadata
 from vllm.v1.sample.ops.bad_words import apply_bad_words_with_drafts
 from vllm.v1.sample.ops.penalties import apply_all_penalties
@@ -117,7 +118,7 @@ def forward(
     def parse_output(
         output_token_ids: torch.Tensor,
         vocab_size: int,
-    ) -> list[list[int]]:
+    ) -> list[TokenIDs]:
         """Parse the output of the rejection sampler.
 
         Args:
@@ -136,7 +137,8 @@ def parse_output(
             output_token_ids_np < vocab_size
         )
         outputs = [
-            row[valid_mask[i]].tolist() for i, row in enumerate(output_token_ids_np)
+            list_to_token_ids(row[valid_mask[i]].tolist())
+            for i, row in enumerate(output_token_ids_np)
         ]
         return outputs
 
diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py
index ad504da55fd8..06e71a37078f 100644
--- a/vllm/v1/spec_decode/eagle.py
+++ b/vllm/v1/spec_decode/eagle.py
@@ -36,6 +36,7 @@
     CommonAttentionMetadata,
 )
 from vllm.v1.kv_cache_interface import KVCacheConfig
+from vllm.v1.outputs import TokenIDs, get_token_count, token_ids_to_list
 from vllm.v1.sample.metadata import SamplingMetadata
 from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
 from vllm.v1.utils import CpuGpuBuffer
@@ -475,7 +476,7 @@ def propose(
 
     def prepare_next_token_ids_cpu(
         self,
-        sampled_token_ids: list[list[int]],
+        sampled_token_ids: list[TokenIDs],
         requests: dict[str, CachedRequestState],
         gpu_input_batch: InputBatch,
         num_scheduled_tokens: dict[str, int],
@@ -490,6 +491,7 @@ def prepare_next_token_ids_cpu(
         req_ids = gpu_input_batch.req_ids
         next_token_ids: list[int] = []
         for i, token_ids in enumerate(sampled_token_ids):
+            token_ids = token_ids_to_list(token_ids)
             if token_ids:
                 # Common case.
                next_token_id = token_ids[-1]
@@ -807,7 +809,7 @@ def propose_tree(
     def prepare_inputs(
         self,
         common_attn_metadata: CommonAttentionMetadata,
-        sampled_token_ids: list[list[int]],
+        sampled_token_ids: list[TokenIDs],
         num_draft_tokens: list[int],
     ) -> tuple[CommonAttentionMetadata, torch.Tensor]:
         """
@@ -833,7 +835,7 @@ def prepare_inputs(
         # q1 + q2, q1 + q2 + 1, ..., q1 + q2 + q3 - n3 - 1]
 
         num_rejected_tokens = [
-            n + 1 - len(sampled_token_ids[i]) if n > 0 else 0
+            n + 1 - get_token_count(sampled_token_ids[i]) if n > 0 else 0
             for i, n in enumerate(num_draft_tokens)
         ]
         num_rejected_tokens = torch.tensor(num_rejected_tokens, dtype=torch.int32)
diff --git a/vllm/v1/spec_decode/ngram_proposer.py b/vllm/v1/spec_decode/ngram_proposer.py
index e2f83cb24aa9..e68158d7dc06 100644
--- a/vllm/v1/spec_decode/ngram_proposer.py
+++ b/vllm/v1/spec_decode/ngram_proposer.py
@@ -6,6 +6,7 @@
 from numba import get_num_threads, jit, njit, prange, set_num_threads
 
 from vllm.config import VllmConfig
+from vllm.v1.outputs import TokenIDs, get_token_count
 
 
 class NgramProposer:
@@ -131,7 +132,7 @@ def batch_propose(
 
     def propose(
         self,
-        sampled_token_ids: list[list[int]],
+        sampled_token_ids: list[TokenIDs],
         req_ids: list[str],
         num_tokens_no_spec: np.ndarray,
         token_ids_cpu: np.ndarray,
@@ -140,9 +141,8 @@ def propose(
         # find which requests need ngram proposals
        valid_ngram_requests = []
         for i, sampled_ids in enumerate(sampled_token_ids):
-            num_sampled_ids = len(sampled_ids)
-            if not num_sampled_ids:
-                # Skip speculative decoding.
+            if get_token_count(sampled_ids) == 0:
+                # Skip speculative decoding if no tokens were sampled.
                 continue
 
             # Skip requests that require sampling parameters that are not
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index 5c2893bd0926..3343d0eee66d 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -119,6 +119,9 @@
     ModelRunnerOutput,
     PoolerOutput,
     SamplerOutput,
+    TokenIDs,
+    get_token_count,
+    token_ids_to_list,
 )
 from vllm.v1.pool.metadata import PoolingMetadata
 from vllm.v1.sample.logits_processor import LogitsProcessors, build_logitsprocs
@@ -2248,7 +2251,7 @@ def _bookkeeping_sync(
     ) -> tuple[
         dict[str, int],
         LogprobsLists | None,
-        list[list[int]],
+        list[TokenIDs],
         dict[str, LogprobsTensors | None],
         list[str],
         dict[str, int],
@@ -2287,12 +2290,13 @@ def _bookkeeping_sync(
         num_sampled_tokens = sampler_output.sampled_token_ids.shape[0]
         sampled_token_ids = sampler_output.sampled_token_ids
         invalid_req_indices = []
+        valid_sampled_token_ids: list[TokenIDs]
         if not self.use_async_scheduling:
             # Get the valid generated tokens.
             max_gen_len = sampled_token_ids.shape[-1]
             if max_gen_len == 1:
                 # No spec decode tokens.
-                valid_sampled_token_ids = self._to_list(sampled_token_ids)
+                valid_sampled_token_ids = self._to_token_ids(sampled_token_ids)
             else:
                 # Includes spec decode tokens.
                 valid_sampled_token_ids = self.rejection_sampler.parse_output(
@@ -2301,7 +2305,7 @@ def _bookkeeping_sync(
                 )
             # Mask out the sampled tokens that should not be sampled.
             for i in discard_sampled_tokens_req_indices:
-                valid_sampled_token_ids[int(i)].clear()
+                valid_sampled_token_ids[int(i)] = []
         else:
             valid_sampled_token_ids = []
             invalid_req_indices = discard_sampled_tokens_req_indices.tolist()
@@ -2328,7 +2332,7 @@ def _bookkeeping_sync(
             if self.use_async_scheduling:
                 sampled_ids = [-1] if req_idx not in invalid_req_indices_set else None
             else:
-                sampled_ids = valid_sampled_token_ids[req_idx]
+                sampled_ids = token_ids_to_list(valid_sampled_token_ids[req_idx])
             if not sampled_ids:
                 continue
 
@@ -2691,7 +2695,7 @@ def take_draft_token_ids(self) -> DraftTokenIds | None:
     def propose_draft_token_ids(
         self,
         scheduler_output: "SchedulerOutput",
-        sampled_token_ids: torch.Tensor | list[list[int]],
+        sampled_token_ids: torch.Tensor | list[TokenIDs],
         sampling_metadata: SamplingMetadata,
         hidden_states: torch.Tensor,
         sample_hidden_states: torch.Tensor,
@@ -2724,7 +2728,7 @@ def propose_draft_token_ids(
             for num_draft, tokens in zip(
                 spec_decode_metadata.num_draft_tokens, sampled_token_ids
             ):
-                indices.append(offset + len(tokens) - 1)
+                indices.append(offset + get_token_count(tokens) - 1)
                 offset += num_draft + 1
             indices = torch.tensor(indices, device=self.device)
             hidden_states = sample_hidden_states[indices]
@@ -4688,7 +4692,7 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
 
         return kv_cache_spec
 
-    def _to_list(self, sampled_token_ids: torch.Tensor) -> list[list[int]]:
+    def _to_token_ids(self, sampled_token_ids: torch.Tensor) -> list[TokenIDs]:
         # This is a short term mitigation for issue mentioned in
         # https://github.com/vllm-project/vllm/issues/22754.
         # `tolist` would trigger a cuda wise stream sync, which
@@ -4701,4 +4705,4 @@ def _to_list(self, sampled_token_ids: torch.Tensor) -> list[list[int]]:
         pinned.copy_(sampled_token_ids, non_blocking=True)
         self.transfer_event.record()
         self.transfer_event.synchronize()
-        return pinned.tolist()
+        return pinned.squeeze(dim=-1).tolist()
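Illustrative sketch (not part of the patch): how the TokenIDs helpers added to vllm/v1/outputs.py are intended to compose, plus the squeeze-based conversion that _to_token_ids relies on. This assumes the diff above is applied so the helpers are importable; the file name sketch.py is hypothetical.

    # sketch.py -- illustrative only; assumes the diff above is applied.
    import torch

    from vllm.v1.outputs import get_token_count, list_to_token_ids, token_ids_to_list

    # Common case: a single sampled token is stored as a bare int, so no
    # per-request list object is allocated and generation-0 GC pressure stays low.
    single = list_to_token_ids([42])
    assert isinstance(single, int) and single == 42
    assert get_token_count(single) == 1
    assert token_ids_to_list(single) == [42]

    # Spec-decode case: multi-token outputs keep the original list object.
    multi = [7, 8, 9]
    assert list_to_token_ids(multi) is multi
    assert get_token_count(multi) == 3

    # _to_token_ids squeezes the [num_reqs, 1] sampled-token tensor so that
    # tolist() yields one int per request instead of one single-element list.
    sampled = torch.tensor([[1], [2], [3]])
    assert sampled.squeeze(dim=-1).tolist() == [1, 2, 3]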