
Commit 20d93dc

Add ngram-eagle SD method
Signed-off-by: Ekagra Ranjan <3116519+ekagra-ranjan@users.noreply.github.com>
1 parent 23a6c52 commit 20d93dc

7 files changed: +127 -31 lines changed

examples/offline_inference/spec_decode.py

Lines changed: 28 additions & 2 deletions
@@ -1,6 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-
+import json
 from transformers import AutoTokenizer
 
 from vllm import LLM, SamplingParams
@@ -53,9 +53,10 @@ def parse_args():
         "--method",
         type=str,
         default="eagle",
-        choices=["ngram", "eagle", "eagle3", "mtp"],
+        choices=["ngram", "eagle", "eagle3", "mtp", "ngram-eagle"],
     )
     parser.add_argument("--num-spec-tokens", type=int, default=2)
+    parser.add_argument("--num-speculative-tokens-per-method", type=str, default='{"ngram": 2, "eagle": 2}')
     parser.add_argument("--prompt-lookup-max", type=int, default=5)
     parser.add_argument("--prompt-lookup-min", type=int, default=2)
     parser.add_argument("--tp", type=int, default=1)
@@ -118,6 +119,22 @@ def main():
             "prompt_lookup_max": args.prompt_lookup_max,
             "prompt_lookup_min": args.prompt_lookup_min,
         }
+    elif args.method == "ngram-eagle":
+        num_speculative_tokens_per_method = json.loads(args.num_speculative_tokens_per_method)
+        eagle_dir = args.eagle_dir
+        if eagle_dir is None:
+            eagle_dir = "yuhuili/EAGLE-LLaMA3.1-Instruct-8B"
+        args.num_spec_tokens = max(
+            num_speculative_tokens_per_method["ngram"],
+            num_speculative_tokens_per_method["eagle"],
+        )
+        speculative_config = {
+            "method": "ngram-eagle",
+            "model": eagle_dir,
+            "num_speculative_tokens_per_method": num_speculative_tokens_per_method,
+            "prompt_lookup_max": args.prompt_lookup_max,
+            "prompt_lookup_min": args.prompt_lookup_min,
+        }
     else:
         raise ValueError(f"unknown method: {args.method}")

@@ -150,6 +167,7 @@ def main():
         print("-" * 50)
         print(f"prompt: {output.prompt}")
         print(f"generated text: {output.outputs[0].text}")
+        print(f"num of generated tokens: {len(output.outputs[0].token_ids)}")
         print("-" * 50)
 
     try:
@@ -179,6 +197,10 @@ def main():
             assert isinstance(metric, Vector)
             for pos in range(len(metric.values)):
                 acceptance_counts[pos] += metric.values[pos]
+        elif metric.name == "vllm:generation_tokens":
+            assert isinstance(metric, Counter)
+            print(f"num generation tokens: {metric.value}")
+            total_tokens_generated = metric.value
 
     print("-" * 50)
     print(f"total_num_output_tokens: {total_num_output_tokens}")
@@ -187,6 +209,10 @@ def main():
     print(f"num_accepted_tokens: {num_accepted_tokens}")
     acceptance_length = 1 + (num_accepted_tokens / num_drafts) if num_drafts > 0 else 1
     print(f"mean acceptance length: {acceptance_length:.2f}")
+    num_tokens_generated_without_sd = total_tokens_generated - (num_drafts + num_accepted_tokens)
+    seq_normalized_acceptance_length = total_tokens_generated / (num_drafts + num_tokens_generated_without_sd)
+    print(f"num_tokens_generated_without_sd: {num_tokens_generated_without_sd}")
+    print(f"seq normalized acceptance length: {seq_normalized_acceptance_length:.2f}")
     print("-" * 50)
 
     # print acceptance at each token position
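To make the new metric concrete, here is a worked example with invented numbers, using the formulas from the hunk above:

```python
# Worked example (invented values) of the seq-normalized acceptance length.
total_tokens_generated = 120   # from the vllm:generation_tokens counter
num_drafts = 30                # speculation rounds issued
num_accepted_tokens = 60       # draft tokens accepted across all rounds
# Tokens the target model emitted outside speculative decoding:
num_tokens_generated_without_sd = total_tokens_generated - (num_drafts + num_accepted_tokens)
# The denominator approximates target-model forward passes, so this is
# generated tokens per forward pass:
seq_normalized_acceptance_length = total_tokens_generated / (num_drafts + num_tokens_generated_without_sd)
assert num_tokens_generated_without_sd == 30
assert seq_normalized_acceptance_length == 2.0  # two tokens per forward pass
```

And a minimal sketch of driving the new method programmatically, assuming a vLLM build that includes this commit; the target model name below is a stand-in, while the eagle head and lookup defaults come from the diff:

```python
from vllm import LLM, SamplingParams

# Mirrors the speculative_config the example script builds for "ngram-eagle".
llm = LLM(
    model="meta-llama/Llama-3.1-8B-Instruct",  # hypothetical target model
    speculative_config={
        "method": "ngram-eagle",
        "model": "yuhuili/EAGLE-LLaMA3.1-Instruct-8B",  # eagle head from the diff
        "num_speculative_tokens_per_method": {"ngram": 2, "eagle": 2},
        "prompt_lookup_max": 5,
        "prompt_lookup_min": 2,
    },
)
outputs = llm.generate(["The capital of France is"],
                       SamplingParams(temperature=0, max_tokens=32))
print(outputs[0].outputs[0].text)
```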

vllm/benchmarks/datasets.py

Lines changed: 2 additions & 0 deletions
@@ -1278,6 +1278,8 @@ def normalize(d: dict) -> dict[tuple[int, int, int], float]:
 
 
 def get_samples(args, tokenizer) -> list[SampleRequest]:
+    if not hasattr(args, "request_id_prefix"):
+        args.request_id_prefix = ""
     if args.dataset_name == "custom":
         dataset = CustomDataset(dataset_path=args.dataset_path)
         input_requests = dataset.sample(
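A self-contained sketch of the guard added to `get_samples`; `SimpleNamespace` stands in for the parsed benchmark args:

```python
from types import SimpleNamespace

# Callers that never set request_id_prefix now get a safe default.
args = SimpleNamespace(dataset_name="custom", dataset_path="data.jsonl")
if not hasattr(args, "request_id_prefix"):
    args.request_id_prefix = ""
assert args.request_id_prefix == ""
```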

vllm/config/__init__.py

Lines changed: 28 additions & 7 deletions
@@ -1936,7 +1936,7 @@ def __post_init__(self):
         self.device = torch.device(self.device_type)
 
 
-SpeculativeMethod = Literal["ngram", "eagle", "eagle3", "medusa",
+SpeculativeMethod = Literal["ngram", "eagle", "eagle3", "ngram-eagle", "medusa",
                             "mlp_speculator", "draft_model", "deepseek_mtp",
                             "ernie_mtp"]

@@ -1950,6 +1950,9 @@ class SpeculativeConfig:
     num_speculative_tokens: SkipValidation[int] = None  # type: ignore
     """The number of speculative tokens, if provided. It will default to the
     number in the draft model config if present, otherwise, it is required."""
+    num_speculative_tokens_per_method: Optional[dict[str, int]] = None
+    """The number of speculative tokens for each method, if provided. The max
+    of the values will be used if `num_speculative_tokens` is not provided."""
     model: Optional[str] = None
     """The name of the draft model, eagle head, or additional weights, if
     provided."""
@@ -2109,6 +2112,18 @@ def __post_init__(self):
             raise ValueError("num_speculative_tokens was provided without "
                              "speculative model.")
 
+        # Set num_speculative_tokens from num_speculative_tokens_per_method
+        # for methods like ngram-eagle.
+        if self.num_speculative_tokens_per_method is not None:
+            max_num_speculative_tokens = max(
+                self.num_speculative_tokens_per_method.values())
+            if self.num_speculative_tokens is None:
+                self.num_speculative_tokens = max_num_speculative_tokens
+            else:
+                assert self.num_speculative_tokens <= max_num_speculative_tokens, (
+                    "num_speculative_tokens should be None or less than or equal "
+                    "to the max value in num_speculative_tokens_per_method.")
+
         # Automatically configure the method for ngram when "model" is used
         # instead of "method"
         if self.method is None and (self.model is not None
@@ -2118,6 +2133,8 @@ def __post_init__(self):
         if self.method in ("ngram", "[ngram]"):
             # Unified to "ngram" internally
             self.method = "ngram"
+
+        if self.method in ("ngram", "ngram-eagle"):
             # Set default values if not provided
             if (self.prompt_lookup_min is None
                     and self.prompt_lookup_max is None):
@@ -2148,9 +2165,13 @@ def __post_init__(self):
                 # draft related config as None here.
                 self.draft_model_config = self.target_model_config
                 self.draft_parallel_config = self.target_parallel_config
-        else:
-            self.prompt_lookup_max = 0
-            self.prompt_lookup_min = 0
+
+        # allow ngram-eagle to use this code block similar to eagle
+        if self.method != "ngram":
+
+            if self.method != "ngram-eagle":
+                self.prompt_lookup_max = 0
+                self.prompt_lookup_min = 0
 
             if self.model is not None:
                 self.draft_model_config = ModelConfig(
@@ -2179,7 +2200,7 @@ def __post_init__(self):
                 )
 
             # Automatically detect the method
-            if self.method in ('eagle', 'eagle3'):
+            if self.method in ('eagle', 'eagle3', 'ngram-eagle'):
                 pass
             elif "eagle-" in self.draft_model_config.model.lower() or \
                     "eagle3-" in self.draft_model_config.model.lower():
@@ -2216,7 +2237,7 @@ def __post_init__(self):
                     "eagle, or deepseek_mtp.")
 
             # Replace hf_config for EAGLE draft_model
-            if self.method in ("eagle", "eagle3"):
+            if self.method in ("eagle", "eagle3", "ngram-eagle"):
                 if self.enable_chunked_prefill and not envs.VLLM_USE_V1:
                     raise ValueError(
                         "Chunked prefill and EAGLE are not compatible "
@@ -2422,7 +2443,7 @@ def num_lookahead_slots(self) -> int:
         return self.num_speculative_tokens
 
     def use_eagle(self) -> bool:
-        return self.method in ("eagle", "eagle3", "deepseek_mtp", "ernie_mtp")
+        return self.method in ("eagle", "eagle3", "ngram-eagle", "deepseek_mtp", "ernie_mtp")
 
     def __repr__(self) -> str:
         method = self.method
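To make the new precedence concrete, here is a small self-contained sketch (a hypothetical helper, not part of the commit) of how `num_speculative_tokens` is resolved from `num_speculative_tokens_per_method` in `__post_init__`:

```python
from typing import Optional

def resolve_num_speculative_tokens(
        num_speculative_tokens: Optional[int],
        per_method: Optional[dict[str, int]]) -> Optional[int]:
    """Mirrors the __post_init__ logic added above (sketch)."""
    if per_method is None:
        return num_speculative_tokens
    max_tokens = max(per_method.values())
    if num_speculative_tokens is None:
        # Lookahead slots are sized for the largest per-method draft.
        return max_tokens
    assert num_speculative_tokens <= max_tokens
    return num_speculative_tokens

# ngram drafts up to 2 tokens, eagle up to 4 -> the overall budget is 4.
assert resolve_num_speculative_tokens(None, {"ngram": 2, "eagle": 4}) == 4
```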

vllm/transformers_utils/configs/eagle.py

Lines changed: 2 additions & 2 deletions
@@ -46,7 +46,7 @@ def __init__(self,
         # Eagle model name should follow naming convention of
         # LlamaForCausalLM -> EagleLlamaForCausalLM
         # LlamaForCausalLM -> Eagle3LlamaForCausalLM
-        if method == "eagle":
+        if method in ("eagle", "ngram-eagle"):
             assert self.model is not None, \
                 "model should not be None when method is eagle"
             kwargs["architectures"] = [
@@ -62,7 +62,7 @@ def __init__(self,
             ]
         else:
             raise ValueError(f"Invalid method {method}. "
-                             "Supported methods are eagle and eagle3.")
+                             "Supported methods are eagle, ngram-eagle and eagle3.")
 
         super().__init__(**kwargs)

vllm/v1/spec_decode/eagle.py

Lines changed: 10 additions & 2 deletions
@@ -61,8 +61,16 @@ def __init__(
         self.dtype = vllm_config.model_config.dtype
         self.max_model_len = vllm_config.model_config.max_model_len
         self.block_size = vllm_config.cache_config.block_size
-        self.num_speculative_tokens = (
-            self.speculative_config.num_speculative_tokens)
+
+        if self.method == "ngram-eagle":
+            self.num_speculative_tokens = (
+                self.speculative_config.num_speculative_tokens_per_method["eagle"])
+        else:
+            self.num_speculative_tokens = (
+                self.speculative_config.num_speculative_tokens)
+
+        logger.info("EagleProposer: method=%s, num_speculative_tokens=%s",
+                    self.method, self.num_speculative_tokens)
+
         self.max_num_tokens = (
             vllm_config.scheduler_config.max_num_batched_tokens)
         self.token_arange_np = np.arange(self.max_num_tokens)

vllm/v1/spec_decode/ngram_proposer.py

Lines changed: 9 additions & 1 deletion
@@ -6,6 +6,9 @@
 from numba import jit
 
 from vllm.config import VllmConfig
+from vllm.logger import init_logger
+
+logger = init_logger(__name__)
 
 
 class NgramProposer:
@@ -22,13 +25,18 @@ def __init__(self, vllm_config: VllmConfig):
         # Number of tokens follow the match. If there are less than k
         # tokens follow the match, we will return the maximum amount of
         # tokens until the end.
-        self.k = vllm_config.speculative_config.num_speculative_tokens
+        self.method = vllm_config.speculative_config.method
+        if self.method == "ngram-eagle":
+            self.k = vllm_config.speculative_config.num_speculative_tokens_per_method["ngram"]
+        else:
+            self.k = vllm_config.speculative_config.num_speculative_tokens
         # Maximum length of the model.
         self.max_model_len = vllm_config.model_config.max_model_len
 
         # Trigger Numba JIT compilation for N-gram proposer.
         # This usually takes less than 1 second.
         self.propose(np.zeros(1024, dtype=np.int32))
+        logger.info("NgramProposer: min_n=%s, max_n=%s, k=%s, max_model_len=%s", self.min_n, self.max_n, self.k, self.max_model_len)  # noqa: E501
 
     def propose(
         self,
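Taken together with the eagle.py change above, each proposer reads its own token budget under "ngram-eagle". A minimal sketch of that selection, with a plain dict standing in for `SpeculativeConfig`:

```python
def draft_budget(spec: dict, proposer: str) -> int:
    """Per-proposer speculative-token budget (illustrative, not vLLM API)."""
    if spec["method"] == "ngram-eagle":
        return spec["num_speculative_tokens_per_method"][proposer]
    return spec["num_speculative_tokens"]

spec = {
    "method": "ngram-eagle",
    "num_speculative_tokens": 3,  # max of the per-method values
    "num_speculative_tokens_per_method": {"ngram": 2, "eagle": 3},
}
assert draft_budget(spec, "ngram") == 2
assert draft_budget(spec, "eagle") == 3
```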

vllm/v1/worker/gpu_model_runner.py

Lines changed: 48 additions & 17 deletions
@@ -185,19 +185,26 @@ def __init__(
         # NOTE(Jiayi): currently we put the entire draft model on
         # the last PP rank. This is not ideal if there are many
         # layers in the draft model.
+        found_draft = False
         if self.speculative_config and get_pp_group().is_last_rank:
-            if self.speculative_config.method == "ngram":
-                self.drafter = NgramProposer(self.vllm_config)
-            elif self.speculative_config.use_eagle():
-                self.drafter = EagleProposer(self.vllm_config, self.device,
+            # Use ifs and not elifs to allow multiple
+            # draft models to be initialized.
+            if self.speculative_config.method == "ngram" \
+                    or self.speculative_config.method == "ngram-eagle":
+                self.drafter_ngram = NgramProposer(self.vllm_config)
+                found_draft = True
+            if self.speculative_config.use_eagle():
+                self.drafter_eagle = EagleProposer(self.vllm_config, self.device,
                                              self)  # type: ignore
                 if self.speculative_config.method == "eagle3":
                     self.use_aux_hidden_state_outputs = True
-            elif self.speculative_config.method == "medusa":
+                found_draft = True
+            if self.speculative_config.method == "medusa":
                 self.drafter = MedusaProposer(
                     vllm_config=self.vllm_config,
                     device=self.device)  # type: ignore
-            else:
+                found_draft = True
+            if not found_draft:
                 raise ValueError("Unknown speculative decoding method: "
                                  f"{self.speculative_config.method}")
         self.rejection_sampler = RejectionSampler()
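The switch from elif-dispatch to independent ifs is what lets a composite method own two drafters. A toy sketch of the pattern, with strings standing in for the proposer objects:

```python
def init_drafters(method: str, use_eagle: bool) -> dict[str, str]:
    """Toy version of the dispatch above; values stand in for proposers."""
    drafters: dict[str, str] = {}
    if method in ("ngram", "ngram-eagle"):
        drafters["ngram"] = "NgramProposer"
    if use_eagle:  # true for eagle, eagle3, ngram-eagle, and the mtp methods
        drafters["eagle"] = "EagleProposer"
    if method == "medusa":
        drafters["medusa"] = "MedusaProposer"
    if not drafters:
        raise ValueError(f"Unknown speculative decoding method: {method}")
    return drafters

assert set(init_drafters("ngram-eagle", use_eagle=True)) == {"ngram", "eagle"}
```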
@@ -1775,10 +1782,12 @@ def propose_draft_token_ids(
         common_attn_metadata: CommonAttentionMetadata,
     ) -> Union[list[list[int]], torch.Tensor]:
         num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
-        if self.speculative_config.method == "ngram":
-            assert isinstance(self.drafter, NgramProposer)
+        if self.speculative_config.method in ("ngram", "ngram-eagle"):
+            assert isinstance(self.drafter_ngram, NgramProposer)
             draft_token_ids = self.propose_ngram_draft_token_ids(
                 sampled_token_ids)
+            if self.speculative_config.method == "ngram-eagle":
+                draft_token_ids_ngram = draft_token_ids
         elif self.speculative_config.method == "medusa":
             assert isinstance(self.drafter, MedusaProposer)
             if sample_hidden_states.shape[0] == len(sampled_token_ids):
@@ -1799,8 +1808,8 @@ def propose_draft_token_ids(
                 target_hidden_states=hidden_states,
                 sampling_metadata=sampling_metadata,
             )
-        elif self.speculative_config.use_eagle():
-            assert isinstance(self.drafter, EagleProposer)
+        if self.speculative_config.use_eagle():
+            assert isinstance(self.drafter_eagle, EagleProposer)
             # TODO(woosuk): Refactor the loop.
             req_ids = self.input_batch.req_ids
             next_token_ids: list[int] = []
@@ -1842,7 +1851,7 @@ def propose_draft_token_ids(
             num_rejected_tokens_cpu = torch.tensor(num_rejected_tokens,
                                                    dtype=torch.int32)
             common_attn_metadata, token_indices =\
-                self.drafter.prepare_inputs(
+                self.drafter_eagle.prepare_inputs(
                 common_attn_metadata, num_rejected_tokens_cpu)
 
             target_token_ids = self.input_ids.gpu[token_indices]
@@ -1858,7 +1867,7 @@ def propose_draft_token_ids(
                 mm_embeds = self._gather_mm_embeddings(scheduler_output,
                                                        shift_computed_tokens=1)
 
-            draft_token_ids = self.drafter.propose(
+            draft_token_ids = self.drafter_eagle.propose(
                 target_token_ids=target_token_ids,
                 target_positions=target_positions,
                 target_hidden_states=target_hidden_states,
@@ -1867,6 +1876,25 @@ def propose_draft_token_ids(
                 common_attn_metadata=common_attn_metadata,
                 mm_embeds=mm_embeds,
             )
+            if self.speculative_config.method == "ngram-eagle":
+                draft_token_ids_eagle = draft_token_ids
+
+        if self.speculative_config.method == "ngram-eagle":
+            assert draft_token_ids_ngram is not None, "ngram proposer failed"
+            assert draft_token_ids_eagle is not None, "eagle proposer failed"
+            # The eagle draft is a torch tensor but we need a list.
+            draft_token_ids_eagle = draft_token_ids_eagle.tolist()
+            draft_token_ids = []
+
+            # Combine the ngram and eagle drafts:
+            # prefer the ngram draft when available and
+            # fall back to the eagle draft when the ngram draft is empty.
+            for bid in range(len(draft_token_ids_ngram)):
+                if len(draft_token_ids_ngram[bid]):
+                    draft_token_ids.append(draft_token_ids_ngram[bid])
+                else:
+                    draft_token_ids.append(draft_token_ids_eagle[bid])
+
         return draft_token_ids
 
     def propose_ngram_draft_token_ids(
@@ -1896,7 +1924,7 @@ def propose_ngram_draft_token_ids(
                 draft_token_ids.append([])
                 continue
 
-            drafter_output = self.drafter.propose(
+            drafter_output = self.drafter_ngram.propose(
                 self.input_batch.token_ids_cpu[i, :num_tokens])
             if drafter_output is None or len(drafter_output) == 0:
                 draft_token_ids.append([])
@@ -1963,6 +1991,9 @@ def load_model(self, eep_scale_up: bool = False) -> None:
         if hasattr(self, "drafter"):
             logger.info("Loading drafter model...")
             self.drafter.load_model(self.model)
+        if hasattr(self, "drafter_eagle"):
+            logger.info("Loading eagle drafter model...")
+            self.drafter_eagle.load_model(self.model)
         if self.use_aux_hidden_state_outputs:
             if supports_eagle3(self.model):
                 self.model.set_aux_hidden_state_layers(
@@ -2379,8 +2410,8 @@ def _dummy_run(
             hidden_states = outputs
 
         if self.speculative_config and self.speculative_config.use_eagle():
-            assert isinstance(self.drafter, EagleProposer)
-            self.drafter.dummy_run(num_tokens)
+            assert isinstance(self.drafter_eagle, EagleProposer)
+            self.drafter_eagle.dummy_run(num_tokens)
 
         # This is necessary to avoid blocking DP.
         # For dummy runs, we typically skip EPLB since we don't have any real
@@ -3133,10 +3164,10 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
         kv_caches = self.initialize_kv_cache_tensors(kv_cache_config)
 
         if self.speculative_config and self.speculative_config.use_eagle():
-            assert isinstance(self.drafter, EagleProposer)
+            assert isinstance(self.drafter_eagle, EagleProposer)
             # validate all draft model layers belong to the same kv cache
             # group
-            self.drafter.validate_same_kv_cache_group(kv_cache_config)
+            self.drafter_eagle.validate_same_kv_cache_group(kv_cache_config)
 
         if has_kv_transfer_group():
             get_kv_transfer_group().register_kv_caches(kv_caches)
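Finally, a standalone sketch of the draft-combination policy introduced in `propose_draft_token_ids` (prefer the ngram draft per request, fall back to the eagle draft when the ngram lookup found no match); the function name and values are illustrative:

```python
def combine_drafts(ngram_drafts: list[list[int]],
                   eagle_drafts: list[list[int]]) -> list[list[int]]:
    """Per request: take the ngram draft if non-empty, else the eagle draft."""
    return [n if n else e for n, e in zip(ngram_drafts, eagle_drafts)]

# Request 0 had an n-gram match; request 1 did not, so eagle fills in.
assert combine_drafts([[5, 6], []], [[7], [8, 9]]) == [[5, 6], [8, 9]]
```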
