39 changes: 39 additions & 0 deletions tests/v1/test_outputs.py
@@ -0,0 +1,39 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project


from vllm.v1.outputs import (
get_token_count,
list_to_token_ids,
token_ids_to_list,
)


def test_token_ids_to_list():
assert token_ids_to_list(1) == [1]
assert token_ids_to_list(2) == [2]

    # Multi-token lists are returned as-is, without creating a new list
vs = [3, 4]
assert token_ids_to_list(vs) is vs
vs = [5, 6, 7]
assert token_ids_to_list(vs) is vs


def test_list_to_token_ids():
assert list_to_token_ids([10]) == 10
assert list_to_token_ids([20]) == 20

    # Multi-token lists are returned as-is, without creating a new list
vs = [30, 40]
assert list_to_token_ids(vs) is vs
vs = [50, 60, 70]
assert list_to_token_ids(vs) is vs


def test_get_token_count():
assert get_token_count(100) == 1
assert get_token_count(200) == 1
assert get_token_count([300]) == 1
assert get_token_count([400, 500]) == 2
assert get_token_count([600, 700, 800]) == 3
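For orientation, a minimal usage sketch of the three helpers under test (imports as in the test file above; the size comparison at the end is CPython-specific and purely illustrative):

import sys

from vllm.v1.outputs import get_token_count, list_to_token_ids, token_ids_to_list

compact = list_to_token_ids([42])        # -> 42: the common single-token case stays a plain int
expanded = token_ids_to_list(compact)    # -> [42] whenever a real list is needed
assert get_token_count(compact) == len(expanded) == 1

multi = list_to_token_ids([1, 2, 3])     # multi-token lists pass through unchanged
assert token_ids_to_list(multi) is multi

# CPython-specific illustration of the per-request saving in the single-token case.
print(sys.getsizeof(42), sys.getsizeof([42]))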
12 changes: 9 additions & 3 deletions vllm/v1/core/sched/scheduler.py
@@ -29,7 +29,13 @@
from vllm.v1.engine import EngineCoreEventType, EngineCoreOutput, EngineCoreOutputs
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.metrics.stats import SchedulerStats
from vllm.v1.outputs import DraftTokenIds, KVConnectorOutput, ModelRunnerOutput
from vllm.v1.outputs import (
DraftTokenIds,
KVConnectorOutput,
ModelRunnerOutput,
list_to_token_ids,
token_ids_to_list,
)
from vllm.v1.request import Request, RequestStatus
from vllm.v1.spec_decode.metrics import SpecDecodingStats
from vllm.v1.structured_output import StructuredOutputManager
@@ -952,7 +958,7 @@ def update_from_output(
continue

req_index = model_runner_output.req_id_to_index[req_id]
generated_token_ids = (
generated_token_ids: list[int] = token_ids_to_list(
sampled_token_ids[req_index] if sampled_token_ids else []
)

@@ -1028,7 +1034,7 @@ def update_from_output(
outputs[request.client_index].append(
EngineCoreOutput(
request_id=req_id,
new_token_ids=new_token_ids,
new_token_ids=list_to_token_ids(new_token_ids),
finish_reason=request.get_finished_reason(),
new_logprobs=new_logprobs,
new_prompt_logprobs_tensors=prompt_logprobs_tensors,
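In condensed form, the conversion pattern these two hunks introduce (the helper below is hypothetical; the real scheduler does this inline in update_from_output): expand the runner's per-request output to list[int] for bookkeeping, then compress it back to TokenIDs before it is shipped out on an EngineCoreOutput.

from vllm.v1.outputs import TokenIDs, list_to_token_ids, token_ids_to_list

def relay_new_tokens(sampled: TokenIDs) -> TokenIDs:
    # Expand to list[int] so stop checks, logprob handling, etc. can iterate.
    generated_token_ids: list[int] = token_ids_to_list(sampled)
    new_token_ids = generated_token_ids  # scheduler-side processing elided
    # Compress again so a single-token step adds no list allocation to the output.
    return list_to_token_ids(new_token_ids)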
4 changes: 2 additions & 2 deletions vllm/v1/engine/__init__.py
@@ -14,7 +14,7 @@
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams
from vllm.v1.metrics.stats import SchedulerStats
from vllm.v1.outputs import LogprobsLists, LogprobsTensors
from vllm.v1.outputs import LogprobsLists, LogprobsTensors, TokenIDs

# These are possible values of RequestOutput.finish_reason,
# so form part of the external API.
@@ -106,7 +106,7 @@ class EngineCoreOutput(
gc=False,
): # type: ignore[call-arg]
request_id: str
new_token_ids: list[int]
new_token_ids: TokenIDs

new_logprobs: LogprobsLists | None = None
new_prompt_logprobs_tensors: LogprobsTensors | None = None
5 changes: 4 additions & 1 deletion vllm/v1/engine/output_processor.py
@@ -23,6 +23,7 @@
from vllm.v1.engine.logprobs import LogprobsProcessor
from vllm.v1.engine.parallel_sampling import ParentRequest
from vllm.v1.metrics.stats import IterationStats, LoRARequestStates, RequestStateStats
from vllm.v1.outputs import token_ids_to_list


class RequestOutputCollector:
@@ -421,7 +422,9 @@ def process_outputs(
req_state, engine_core_output, engine_core_timestamp, iteration_stats
)

new_token_ids = engine_core_output.new_token_ids
new_token_ids: list[int] = token_ids_to_list(
engine_core_output.new_token_ids
)
pooling_output = engine_core_output.pooling_output
finish_reason = engine_core_output.finish_reason
stop_reason = engine_core_output.stop_reason
3 changes: 2 additions & 1 deletion vllm/v1/metrics/stats.py
@@ -6,6 +6,7 @@
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any

from vllm.v1.outputs import get_token_count
from vllm.v1.spec_decode.metrics import SpecDecodingStats

if TYPE_CHECKING:
@@ -232,7 +233,7 @@ def update_from_output(
req_stats: RequestStateStats,
lora_stats: LoRAStats | None,
):
num_new_generation_tokens = len(output.new_token_ids)
num_new_generation_tokens = get_token_count(output.new_token_ids)

self.num_generation_tokens += num_new_generation_tokens
if is_prefilling:
37 changes: 35 additions & 2 deletions vllm/v1/outputs.py
@@ -3,7 +3,7 @@

from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, NamedTuple
from typing import TYPE_CHECKING, NamedTuple, TypeAlias

import torch

@@ -12,6 +12,39 @@
else:
KVConnectorStats = object

# Alias type for token ID(s).
# 1) int: a single token (most requests emit one output token per batch)
# 2) list[int]: multiple tokens (for spec decoding, requests might emit
# more than one token per batch)
#
# The alias exists to mitigate GC costs at large batch sizes. Without it, the
# output tokens of every request are stored as a list[int]; at large batch
# sizes, allocating all of those lists is enough to trigger a generation-0 GC
# collection for each ModelRunnerOutput, which in turn drives up collection
# frequency across all generations.
TokenIDs: TypeAlias = int | list[int]


def token_ids_to_list(tokens: TokenIDs) -> list[int]:
"""
Converts TokenIDs to a list of token IDs (i.e. list[int]).
"""
return [tokens] if isinstance(tokens, int) else tokens


def list_to_token_ids(tokens: list[int]) -> TokenIDs:
"""
Converts a list of token IDs (i.e. list[int]) to TokenIDs.
"""
return tokens[0] if len(tokens) == 1 else tokens


def get_token_count(tokens: TokenIDs) -> int:
"""
Returns the number of tokens in TokenIDs.
"""
return 1 if isinstance(tokens, int) else len(tokens)


class LogprobsLists(NamedTuple):
# [num_reqs, max_num_logprobs + 1]
@@ -111,7 +144,7 @@ class ModelRunnerOutput:
# num_generated_tokens is the number of tokens
# generated in the current step. It can be different for
# each request due to speculative/jump decoding.
sampled_token_ids: list[list[int]]
sampled_token_ids: list[TokenIDs]

# [num_reqs, max_num_logprobs + 1]
# [num_reqs, max_num_logprobs + 1]
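The rationale in that comment can be observed outside vLLM with a CPython-specific sketch: only GC-tracked containers such as lists advance the generation-0 allocation counter, so keeping single-token steps as plain ints stops a large batch from hitting the gen-0 threshold on every step.

import gc

gc.disable()   # keep automatic collections from resetting the counter mid-experiment
gc.collect()   # start from a clean generation-0 count

base = gc.get_count()[0]
per_request_lists = [[tok] for tok in range(8192)]   # roughly one tracked list per request
print("tracked allocations, list per request:", gc.get_count()[0] - base)

base = gc.get_count()[0]
per_request_ints = list(range(8192))                 # ints are not GC-tracked
print("tracked allocations, int per request:", gc.get_count()[0] - base)

gc.enable()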
6 changes: 4 additions & 2 deletions vllm/v1/sample/rejection_sampler.py
@@ -6,6 +6,7 @@

from vllm.logger import init_logger
from vllm.triton_utils import tl, triton
from vllm.v1.outputs import TokenIDs, list_to_token_ids
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.sample.ops.bad_words import apply_bad_words_with_drafts
from vllm.v1.sample.ops.penalties import apply_all_penalties
@@ -117,7 +118,7 @@ def forward(
def parse_output(
output_token_ids: torch.Tensor,
vocab_size: int,
) -> list[list[int]]:
) -> list[TokenIDs]:
"""Parse the output of the rejection sampler.

Args:
@@ -136,7 +137,8 @@ def parse_output(
output_token_ids_np < vocab_size
)
outputs = [
row[valid_mask[i]].tolist() for i, row in enumerate(output_token_ids_np)
list_to_token_ids(row[valid_mask[i]].tolist())
for i, row in enumerate(output_token_ids_np)
]
return outputs

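A toy illustration of the compressed return value (not the real parse_output; it assumes rejected slots hold a placeholder id of -1 and uses a made-up vocab size):

import numpy as np

from vllm.v1.outputs import list_to_token_ids

PLACEHOLDER_TOKEN_ID = -1  # assumption for this sketch
vocab_size = 32_000

output_token_ids_np = np.array([
    [11, PLACEHOLDER_TOKEN_ID, PLACEHOLDER_TOKEN_ID],  # only the first token kept
    [21, 22, 23],                                       # every position accepted
])
valid_mask = (output_token_ids_np != PLACEHOLDER_TOKEN_ID) & (
    output_token_ids_np < vocab_size
)
outputs = [
    list_to_token_ids(row[valid_mask[i]].tolist())
    for i, row in enumerate(output_token_ids_np)
]
assert outputs == [11, [21, 22, 23]]  # single-token rows collapse to an int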
8 changes: 5 additions & 3 deletions vllm/v1/spec_decode/eagle.py
@@ -36,6 +36,7 @@
CommonAttentionMetadata,
)
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.outputs import TokenIDs, get_token_count, token_ids_to_list
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
from vllm.v1.utils import CpuGpuBuffer
@@ -475,7 +476,7 @@ def propose(

def prepare_next_token_ids_cpu(
self,
sampled_token_ids: list[list[int]],
sampled_token_ids: list[TokenIDs],
requests: dict[str, CachedRequestState],
gpu_input_batch: InputBatch,
num_scheduled_tokens: dict[str, int],
@@ -490,6 +491,7 @@ def prepare_next_token_ids_cpu(
req_ids = gpu_input_batch.req_ids
next_token_ids: list[int] = []
for i, token_ids in enumerate(sampled_token_ids):
token_ids = token_ids_to_list(token_ids)
if token_ids:
# Common case.
next_token_id = token_ids[-1]
@@ -807,7 +809,7 @@ def propose_tree(
def prepare_inputs(
self,
common_attn_metadata: CommonAttentionMetadata,
sampled_token_ids: list[list[int]],
sampled_token_ids: list[TokenIDs],
num_draft_tokens: list[int],
) -> tuple[CommonAttentionMetadata, torch.Tensor]:
"""
@@ -833,7 +835,7 @@ def prepare_inputs(
# q1 + q2, q1 + q2 + 1, ..., q1 + q2 + q3 - n3 - 1]

num_rejected_tokens = [
n + 1 - len(sampled_token_ids[i]) if n > 0 else 0
n + 1 - get_token_count(sampled_token_ids[i]) if n > 0 else 0
for i, n in enumerate(num_draft_tokens)
]
num_rejected_tokens = torch.tensor(num_rejected_tokens, dtype=torch.int32)
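As a quick check of the updated arithmetic, the same list comprehension rerun with illustrative values: a request with n = 3 drafts whose sampler output kept three tokens (two accepted drafts plus the bonus token) rejects 3 + 1 - 3 = 1 draft, while a request with no drafts contributes 0.

from vllm.v1.outputs import get_token_count

num_draft_tokens = [3, 0]
sampled_token_ids = [[7, 8, 9], 5]   # TokenIDs per request: list[int] or plain int
num_rejected_tokens = [
    n + 1 - get_token_count(sampled_token_ids[i]) if n > 0 else 0
    for i, n in enumerate(num_draft_tokens)
]
assert num_rejected_tokens == [1, 0]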
8 changes: 4 additions & 4 deletions vllm/v1/spec_decode/ngram_proposer.py
@@ -6,6 +6,7 @@
from numba import get_num_threads, jit, njit, prange, set_num_threads

from vllm.config import VllmConfig
from vllm.v1.outputs import TokenIDs, get_token_count


class NgramProposer:
@@ -131,7 +132,7 @@ def batch_propose(

def propose(
self,
sampled_token_ids: list[list[int]],
sampled_token_ids: list[TokenIDs],
req_ids: list[str],
num_tokens_no_spec: np.ndarray,
token_ids_cpu: np.ndarray,
@@ -140,9 +141,8 @@ def propose(
# find which requests need ngram proposals
valid_ngram_requests = []
for i, sampled_ids in enumerate(sampled_token_ids):
num_sampled_ids = len(sampled_ids)
if not num_sampled_ids:
# Skip speculative decoding.
if get_token_count(sampled_ids) == 0:
# Skip speculative decoding if no tokens were sampled
continue

# Skip requests that require sampling parameters that are not
20 changes: 12 additions & 8 deletions vllm/v1/worker/gpu_model_runner.py
@@ -119,6 +119,9 @@
ModelRunnerOutput,
PoolerOutput,
SamplerOutput,
TokenIDs,
get_token_count,
token_ids_to_list,
)
from vllm.v1.pool.metadata import PoolingMetadata
from vllm.v1.sample.logits_processor import LogitsProcessors, build_logitsprocs
@@ -2248,7 +2251,7 @@ def _bookkeeping_sync(
) -> tuple[
dict[str, int],
LogprobsLists | None,
list[list[int]],
list[TokenIDs],
dict[str, LogprobsTensors | None],
list[str],
dict[str, int],
@@ -2287,12 +2290,13 @@
num_sampled_tokens = sampler_output.sampled_token_ids.shape[0]
sampled_token_ids = sampler_output.sampled_token_ids
invalid_req_indices = []
valid_sampled_token_ids: list[TokenIDs]
if not self.use_async_scheduling:
# Get the valid generated tokens.
max_gen_len = sampled_token_ids.shape[-1]
if max_gen_len == 1:
# No spec decode tokens.
valid_sampled_token_ids = self._to_list(sampled_token_ids)
valid_sampled_token_ids = self._to_token_ids(sampled_token_ids)
else:
# Includes spec decode tokens.
valid_sampled_token_ids = self.rejection_sampler.parse_output(
@@ -2301,7 +2305,7 @@
)
# Mask out the sampled tokens that should not be sampled.
for i in discard_sampled_tokens_req_indices:
valid_sampled_token_ids[int(i)].clear()
valid_sampled_token_ids[int(i)] = []
else:
valid_sampled_token_ids = []
invalid_req_indices = discard_sampled_tokens_req_indices.tolist()
@@ -2328,7 +2332,7 @@
if self.use_async_scheduling:
sampled_ids = [-1] if req_idx not in invalid_req_indices_set else None
else:
sampled_ids = valid_sampled_token_ids[req_idx]
sampled_ids = token_ids_to_list(valid_sampled_token_ids[req_idx])
if not sampled_ids:
continue

@@ -2691,7 +2695,7 @@ def take_draft_token_ids(self) -> DraftTokenIds | None:
def propose_draft_token_ids(
self,
scheduler_output: "SchedulerOutput",
sampled_token_ids: torch.Tensor | list[list[int]],
sampled_token_ids: torch.Tensor | list[TokenIDs],
sampling_metadata: SamplingMetadata,
hidden_states: torch.Tensor,
sample_hidden_states: torch.Tensor,
@@ -2724,7 +2728,7 @@ def propose_draft_token_ids(
for num_draft, tokens in zip(
spec_decode_metadata.num_draft_tokens, sampled_token_ids
):
indices.append(offset + len(tokens) - 1)
indices.append(offset + get_token_count(tokens) - 1)
offset += num_draft + 1
indices = torch.tensor(indices, device=self.device)
hidden_states = sample_hidden_states[indices]
@@ -4688,7 +4692,7 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:

return kv_cache_spec

def _to_list(self, sampled_token_ids: torch.Tensor) -> list[list[int]]:
def _to_token_ids(self, sampled_token_ids: torch.Tensor) -> list[TokenIDs]:
# This is a short term mitigation for issue mentioned in
# https://github.com/vllm-project/vllm/issues/22754.
# `tolist` would trigger a cuda wise stream sync, which
@@ -4701,4 +4705,4 @@ def _to_list(self, sampled_token_ids: torch.Tensor) -> list[list[int]]:
pinned.copy_(sampled_token_ids, non_blocking=True)
self.transfer_event.record()
self.transfer_event.synchronize()
return pinned.tolist()
return pinned.squeeze(dim=-1).tolist()
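For reference, a minimal PyTorch-only sketch of the shape handling behind the renamed _to_token_ids: with no speculative tokens, sampled_token_ids has shape [num_reqs, 1], so dropping the trailing dimension before tolist() yields one int per request (the TokenIDs fast path) instead of a one-element list per request.

import torch

sampled_token_ids = torch.tensor([[7], [11], [13]])   # shape [num_reqs, 1]
print(sampled_token_ids.tolist())                     # [[7], [11], [13]]
print(sampled_token_ids.squeeze(dim=-1).tolist())     # [7, 11, 13]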