
Commit e79b220: ignore type

1 parent: fa95264

File tree

3 files changed: +16 -7 lines changed

vllm/engine/arg_utils.py

Lines changed: 4 additions & 4 deletions
@@ -1476,8 +1476,8 @@ def _is_v1_supported_oracle(self, model_config: ModelConfig) -> bool:
                 ("model" in self.speculative_config and
                  self.speculative_config["model"] in ("ngram", "[ngram]"))):
                 is_ngram_enabled = True
-            elif (("model" in self.speculative_config and
-                   "eagle" in self.speculative_config["model"].lower())):
+            elif ("model" in self.speculative_config
+                  and "eagle" in self.speculative_config["model"].lower()):
                 is_eagle_enabled = True
             else:
                 _raise_or_fallback(feature_name="Speculative Decoding",
@@ -1517,7 +1517,7 @@ def _is_v1_supported_oracle(self, model_config: ModelConfig) -> bool:
 
         # LoRA is supported on V1, but off by default for now.
         if self.enable_lora and _warn_or_fallback("LORA"):
-            return Falsef
+            return False
 
         # PP is supported on V1 with Ray distributed executor,
         # but off for MP distributed executor for now.
@@ -1529,7 +1529,7 @@ def _is_v1_supported_oracle(self, model_config: ModelConfig) -> bool:
         # ngram is supported on V1, but off by default for now.
         if is_ngram_enabled and _warn_or_fallback("ngram"):
             return False
-
+
         if is_eagle_enabled and _warn_or_fallback("eagle"):
             return False
 
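The eagle branch above is only reformatted (the doubled parentheses are dropped and the `and` moves to the continuation line); the behavioral fix in this file is the `return Falsef` typo, which would have raised a NameError whenever the LoRA fallback path was taken. For readers skimming the hunk, here is a minimal stand-alone sketch of the detection logic it touches; `detect_spec_method` and its return values are illustrative stand-ins, not vLLM API:

def detect_spec_method(speculative_config: dict) -> str | None:
    # Mirrors the checks above: ngram can be requested through the
    # "model" field, eagle by substring match on the model name.
    if ("model" in speculative_config
            and speculative_config["model"] in ("ngram", "[ngram]")):
        return "ngram"
    if ("model" in speculative_config
            and "eagle" in speculative_config["model"].lower()):
        return "eagle"
    return None

assert detect_spec_method({"model": "[ngram]"}) == "ngram"
assert detect_spec_method({"model": "EAGLE-LLaMA3-Instruct-8B"}) == "eagle"
assert detect_spec_method({"model": "medusa"}) is None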

vllm/v1/spec_decode/ngram_proposer.py

Lines changed: 5 additions & 0 deletions
@@ -4,9 +4,14 @@
 import numpy as np
 from numba import jit
 
+from vllm.config import VllmConfig
+
 
 class NgramProposer:
 
+    def __init__(self, vllm_config: VllmConfig):
+        self.vllm_config = vllm_config
+
     def propose(
         self,
         context_token_ids: np.ndarray,
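With the new `__init__`, `NgramProposer` is constructed with the engine's `VllmConfig` (matching the call site in gpu_model_runner.py below) rather than taking no arguments. A hedged usage sketch, assuming a `VllmConfig` already built by the engine; `make_drafter` is an illustrative name, not vLLM API:

from vllm.config import VllmConfig
from vllm.v1.spec_decode.ngram_proposer import NgramProposer

def make_drafter(vllm_config: VllmConfig) -> NgramProposer:
    # The proposer now keeps a reference to the config; the propose()
    # signature is unchanged by this diff.
    return NgramProposer(vllm_config)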

vllm/v1/worker/gpu_model_runner.py

Lines changed: 7 additions & 3 deletions
@@ -159,9 +159,11 @@ def __init__(
             self.use_spec_decode = True
             if get_pp_group().is_last_rank:
                 if self.speculative_config.method == "ngram":
-                    self.drafter = NgramProposer(self.vllm_config)
+                    self.drafter = NgramProposer(
+                        self.vllm_config)  # type:ignore
                 elif self.speculative_config.method == "eagle":
-                    self.drafter = EagleProposer(self.vllm_config, self.device)
+                    self.drafter = EagleProposer(self.vllm_config,
+                                                 self.device)  # type:ignore
                 else:
                     raise ValueError("Unknown speculative decoding method: "
                                      f"{self.speculative_config.method}")
@@ -1143,9 +1145,11 @@ def execute_model(
             # Speculative decoding is not enabled.
             spec_token_ids = None
         elif self.speculative_config.method == "ngram":
+            assert isinstance(self.drafter, NgramProposer)
             spec_token_ids = self.generate_draft_token_ids(
                 valid_sampled_token_ids, sampling_metadata)
         elif self.speculative_config.method == "eagle":
+            assert isinstance(self.drafter, EagleProposer)
             # TODO(woosuk): Refactor the loop.
             next_token_ids: list[int] = []
             for i, token_ids in enumerate(valid_sampled_token_ids):
@@ -1265,7 +1269,7 @@ def load_model(self) -> None:
                                                  self.lora_config,
                                                  self.device)
         if (hasattr(self, "drafter")
-                and self.speculative_config.method != "ngram"):
+                and not isinstance(self.drafter, NgramProposer)):
             logger.info("Loading drafter model...")
             self.drafter.load_model(self.model)
         time_after_load = time.perf_counter()
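The two `assert isinstance(...)` lines are there mainly for the type checker: `self.drafter` can hold more than one proposer type, and the assert narrows it so mypy accepts variant-specific calls (the `# type:ignore` comments quiet the assignment side, per the commit title). The `load_model` guard likewise switches from comparing a method string to an `isinstance` check, which composes with that narrowing. A minimal sketch of the narrowing pattern with stand-in classes (`NgramLike`, `EagleLike`, `RunnerSketch` are illustrative, not vLLM's):

from typing import Union

class NgramLike:
    def propose_from_tokens(self) -> list[int]:
        return []

class EagleLike:
    def load_model(self) -> None:
        pass

class RunnerSketch:
    def __init__(self, drafter: Union[NgramLike, EagleLike]) -> None:
        self.drafter: Union[NgramLike, EagleLike] = drafter

    def step(self) -> list[int]:
        # Same pattern as the asserts above: self.drafter is a union, and
        # isinstance() narrows it so the checker accepts a call that only
        # one variant defines; at runtime it is a cheap sanity check.
        assert isinstance(self.drafter, NgramLike)
        return self.drafter.propose_from_tokens()

print(RunnerSketch(NgramLike()).step())  # prints []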
