Skip to content

Commit ad0d3d8

Browse files
committed
Commit message: CodeX comment
Signed-off-by: Jialin Ouyang <Jialin.Ouyang@gmail.com>
1 parent 687c10d commit ad0d3d8

File tree

2 files changed

+6
-3
lines changed

2 files changed

+6
-3
lines changed

vllm/v1/outputs.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from typing import TYPE_CHECKING, NamedTuple, TypeAlias
77

88
import torch
9+
from typing_extensions import TypeAlias
910

1011
if TYPE_CHECKING:
1112
from vllm.distributed.kv_transfer.kv_connector.v1.metrics import KVConnectorStats

vllm/v1/spec_decode/eagle.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
CommonAttentionMetadata,
3737
)
3838
from vllm.v1.kv_cache_interface import KVCacheConfig
39+
from vllm.v1.outputs import TokenIDs, convert_to_token_id_list, get_token_count
3940
from vllm.v1.sample.metadata import SamplingMetadata
4041
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
4142
from vllm.v1.utils import CpuGpuBuffer
@@ -475,7 +476,7 @@ def propose(
475476

476477
def prepare_next_token_ids_cpu(
477478
self,
478-
sampled_token_ids: list[list[int]],
479+
sampled_token_ids: list[TokenIDs],
479480
requests: dict[str, CachedRequestState],
480481
gpu_input_batch: InputBatch,
481482
num_scheduled_tokens: dict[str, int],
@@ -490,6 +491,7 @@ def prepare_next_token_ids_cpu(
490491
req_ids = gpu_input_batch.req_ids
491492
next_token_ids: list[int] = []
492493
for i, token_ids in enumerate(sampled_token_ids):
494+
token_ids = convert_to_token_id_list(token_ids)
493495
if token_ids:
494496
# Common case.
495497
next_token_id = token_ids[-1]
@@ -807,7 +809,7 @@ def propose_tree(
807809
def prepare_inputs(
808810
self,
809811
common_attn_metadata: CommonAttentionMetadata,
810-
sampled_token_ids: list[list[int]],
812+
sampled_token_ids: list[TokenIDs],
811813
num_draft_tokens: list[int],
812814
) -> tuple[CommonAttentionMetadata, torch.Tensor]:
813815
"""
@@ -833,7 +835,7 @@ def prepare_inputs(
833835
# q1 + q2, q1 + q2 + 1, ..., q1 + q2 + q3 - n3 - 1]
834836

835837
num_rejected_tokens = [
836-
n + 1 - len(sampled_token_ids[i]) if n > 0 else 0
838+
n + 1 - get_token_count(sampled_token_ids[i]) if n > 0 else 0
837839
for i, n in enumerate(num_draft_tokens)
838840
]
839841
num_rejected_tokens = torch.tensor(num_rejected_tokens, dtype=torch.int32)

0 commit comments

Comments (0)