Skip to content

Commit 0e1d9af

Browse files
committed
address comments
Signed-off-by: qizixi <qizixi@meta.com>
1 parent 2e5efa8 commit 0e1d9af

File tree

2 files changed, +4 −7 lines changed

vllm/v1/spec_decode/metadata.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@ class SpecDecodeMetadata:
2020
bonus_logits_indices: torch.Tensor
2121
# [num_tokens + batch_size]
2222
logits_indices: torch.Tensor
23-
total_num_scheduled_tokens: int
2423

2524
def __post_init__(self):
2625
self.max_spec_len = max(self.num_draft_tokens)
@@ -59,5 +58,4 @@ def make_dummy(
5958
target_logits_indices=target_logits_indices,
6059
bonus_logits_indices=bonus_logits_indices,
6160
logits_indices=logits_indices,
62-
total_num_scheduled_tokens=num_tokens,
6361
)

vllm/v1/worker/gpu_model_runner.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -34,8 +34,9 @@
3434
from vllm.sampling_params import SamplingType
3535
from vllm.sequence import IntermediateTensors
3636
from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
37-
GiB_bytes, LazyLoader, async_tensor_h2d, cdiv,
38-
check_use_alibi, is_pin_memory_available)
37+
GiB_bytes, LayerBlockType, LazyLoader,
38+
async_tensor_h2d, cdiv, check_use_alibi,
39+
is_pin_memory_available)
3940
from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
4041
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
4142
from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
@@ -898,7 +899,6 @@ def _calc_spec_decode_metadata(
898899
target_logits_indices=target_logits_indices,
899900
bonus_logits_indices=bonus_logits_indices,
900901
logits_indices=logits_indices,
901-
total_num_scheduled_tokens=cu_num_scheduled_tokens[-1],
902902
)
903903
return metadata
904904

@@ -1397,8 +1397,7 @@ def execute_model(
13971397
dtype=torch.int32,
13981398
target_device=self.device,
13991399
pin_memory=True)
1400-
num_tokens = spec_decode_metadata.total_num_scheduled_tokens - \
1401-
sum(num_rejected_tokens)
1400+
num_tokens = num_scheduled_tokens - sum(num_rejected_tokens)
14021401
cu_num_tokens, token_indices = self.drafter.prepare_inputs(
14031402
eagle_attn_metadata.query_start_loc,
14041403
num_rejected_tokens_tensor,

0 commit comments

Comments (0)