
Commit 8ba3b17

jiahanc authored and yewentao256 committed
[Speculators][Speculative Decoding] Fix gpt-oss eagle3 accuracy issue (#25406)
Signed-off-by: jiahanc <173873397+jiahanc@users.noreply.github.com>
Signed-off-by: yewentao256 <zhyanwentao@126.com>
1 parent 8222e26 commit 8ba3b17

6 files changed: +79 −17 lines changed

tests/v1/spec_decode/test_eagle.py

Lines changed: 4 additions & 0 deletions
@@ -534,6 +534,8 @@ def create_deterministic_logits(token_ids):
     proposer.runner.attn_groups.append([mock.MagicMock()])
     proposer.runner.attn_groups[0][0].get_metadata_builder.return_value = \
         attn_metadata_builder
+    proposer._get_attention_metadata_builder = mock.MagicMock(
+        return_value=attn_metadata_builder)

     result = proposer.propose(target_token_ids=target_token_ids,
                               target_positions=target_positions,

@@ -660,6 +662,8 @@ def create_deterministic_logits(token_ids, k: int):
     proposer.runner.attn_groups.append([mock.MagicMock()])
     proposer.runner.attn_groups[0][0].get_metadata_builder.return_value = \
         attn_metadata_builder
+    proposer._get_attention_metadata_builder = mock.MagicMock(
+        return_value=attn_metadata_builder)

     # Setup inputs for the proposer.
     target_token_ids = torch.randint(0,
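The test now stubs the new lookup helper directly, so the proposer under test receives the mocked builder no matter how it resolves attention groups. A minimal, self-contained sketch of that stubbing pattern (`DummyProposer` is a hypothetical stand-in for the real test fixture):

```python
# Sketch: replace a lookup method with a MagicMock so the code under test
# always receives a known builder. DummyProposer is a hypothetical stand-in.
from unittest import mock


class DummyProposer:
    def _get_attention_metadata_builder(self):
        raise RuntimeError("would normally search runner.attn_groups")


proposer = DummyProposer()
attn_metadata_builder = mock.MagicMock(name="attn_metadata_builder")

# Same pattern as the test above: patch the bound method on the instance.
proposer._get_attention_metadata_builder = mock.MagicMock(
    return_value=attn_metadata_builder)

assert proposer._get_attention_metadata_builder() is attn_metadata_builder
```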

vllm/config/model.py

Lines changed: 1 addition & 0 deletions
@@ -1003,6 +1003,7 @@ def _verify_quantization(self) -> None:
                 self.quantization = quantization_override
                 break

+        quant_method = quant_method if quant_method != "" else None
         # Verify quantization configurations.
         if self.quantization is None:
             self.quantization = quant_method
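The added line normalizes an empty `quant_method` string to `None` before the fallback assignment, so a checkpoint config that serializes the method as `""` is treated as "no quantization method specified". A hedged sketch of the effect (a standalone function, not the real `_verify_quantization`):

```python
# Sketch of the normalization, isolated from the real _verify_quantization.
def resolve_quantization(quantization, quant_method):
    # An empty string from a checkpoint config means "not specified".
    quant_method = quant_method if quant_method != "" else None
    if quantization is None:
        quantization = quant_method
    return quantization


assert resolve_quantization(None, "") is None         # "" no longer leaks through
assert resolve_quantization(None, "mxfp4") == "mxfp4"
assert resolve_quantization("awq", "") == "awq"       # explicit setting wins
```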

vllm/model_executor/models/llama_eagle.py

Lines changed: 5 additions & 0 deletions
@@ -134,6 +134,11 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         nn.Module.__init__(self)
         self.config = vllm_config. \
             speculative_config.draft_model_config.hf_config
+        # Ensure draft_vocab_size is set;
+        # default to the base vocab size when absent.
+        if getattr(self.config, "draft_vocab_size", None) is None:
+            base_vocab_size = getattr(self.config, "vocab_size", None)
+            self.config.draft_vocab_size = base_vocab_size
         target_layer_num = vllm_config.model_config.get_num_layers(
             vllm_config.parallel_config)
         self.model = LlamaModel(vllm_config=vllm_config,
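The guard gives draft configs that omit `draft_vocab_size` a default equal to the base model's `vocab_size`. A small standalone sketch, with `SimpleNamespace` standing in for the hf_config object (the vocab size is an arbitrary example value):

```python
# Sketch of the fallback; SimpleNamespace mimics an hf_config that lacks
# a draft_vocab_size attribute.
from types import SimpleNamespace

config = SimpleNamespace(vocab_size=32000)

if getattr(config, "draft_vocab_size", None) is None:
    config.draft_vocab_size = getattr(config, "vocab_size", None)

assert config.draft_vocab_size == 32000
```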

vllm/model_executor/models/llama_eagle3.py

Lines changed: 5 additions & 0 deletions
@@ -203,6 +203,11 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         nn.Module.__init__(self)
         self.config = vllm_config. \
             speculative_config.draft_model_config.hf_config
+        # Ensure draft_vocab_size is set;
+        # default to the base vocab size when absent.
+        if getattr(self.config, "draft_vocab_size", None) is None:
+            base_vocab_size = getattr(self.config, "vocab_size", None)
+            self.config.draft_vocab_size = base_vocab_size
         target_layer_num = vllm_config.model_config.get_num_layers(
             vllm_config.parallel_config)

vllm/v1/spec_decode/eagle.py

Lines changed: 56 additions & 14 deletions
@@ -9,6 +9,7 @@
 import torch
 import torch.nn as nn

+from vllm.attention.backends.abstract import AttentionMetadataBuilder
 from vllm.attention.layer import Attention
 from vllm.config import (CompilationLevel, VllmConfig,
                          get_layers_from_vllm_config)

@@ -77,6 +78,8 @@ def __init__(
         self.is_multimodal_model = vllm_config.model_config \
             .is_multimodal_model

+        self.attn_metadata_builder: Optional[AttentionMetadataBuilder] = None
+
         self.use_cuda_graph = (self.vllm_config.compilation_config.level
                                == CompilationLevel.PIECEWISE and
                                not self.vllm_config.model_config.enforce_eager)

@@ -117,7 +120,7 @@ def __init__(
                                          with_numpy=True)

         # Determine allowed attention backends once during initialization.
-        self.allowed_attn_types: tuple[type[EagleAttentionMetadata], ...]
+        self.allowed_attn_types: tuple[type, ...]
         if current_platform.is_rocm():
             rocm_types = [TritonAttentionMetadata, FlashAttentionMetadata]
             # vllm.v1.attention.backends.rocm_aiter_fa is an optional backend
@@ -190,10 +193,12 @@ def propose(

         assert self.runner is not None

-        # FIXME: need to consider multiple kv_cache_groups
-        attn_metadata_builder = \
-            self.runner.attn_groups[0][0].get_metadata_builder()
-        attn_metadata = attn_metadata_builder.build_for_drafting(
+        # Select the correct attention metadata builders for EAGLE layers.
+        # Get the attention metadata builders once and reuse for later.
+        builder = (self._get_attention_metadata_builder()
+                   if self.attn_metadata_builder is None else
+                   self.attn_metadata_builder)
+        attn_metadata = builder.build_for_drafting(
             common_attn_metadata=common_attn_metadata, draft_index=0)

         # At this moment, we assume all eagle layers belong to the same KV
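The builder is now resolved once at the top of `propose` and the same local is reused when attention metadata is rebuilt for each later draft token (see the next hunk), replacing the repeated `attn_groups[0][0]` lookups. A toy sketch of the resolve-once-then-reuse pattern (all names hypothetical):

```python
# Toy sketch: resolve the builder once, then reuse it for every draft step.
def get_builder():
    # In the real code this is the (non-trivial) group search.
    return lambda draft_index: f"metadata for draft {draft_index}"


def propose(num_draft_tokens=3):
    builder = get_builder()        # resolved once per propose() call
    metadata = [builder(0)]        # metadata for the initial drafting pass
    for token_index in range(num_draft_tokens - 1):
        metadata.append(builder(token_index + 1))  # reuse, no re-lookup
    return metadata


assert propose() == ["metadata for draft 0", "metadata for draft 1",
                     "metadata for draft 2"]
```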
@@ -327,11 +332,9 @@ def propose(
                                          exceeds_max_model_len, PADDING_SLOT_ID)

             # Rebuild attention metadata
-            attn_metadata_builder = \
-                self.runner.attn_groups[0][0].get_metadata_builder()
-            attn_metadata = attn_metadata_builder\
-                .build_for_drafting(common_attn_metadata=common_attn_metadata,
-                                    draft_index=token_index + 1)
+            attn_metadata = builder.build_for_drafting(
+                common_attn_metadata=common_attn_metadata,
+                draft_index=token_index + 1)
             for layer_name in self.attn_layer_names:
                 per_layer_attn_metadata[layer_name] = attn_metadata

@@ -851,10 +854,24 @@ def load_model(self, target_model: nn.Module) -> None:
         # share lm_head with the target model if needed
         # some model definition do not define lm_head explicitly
         # and reuse embed_tokens for lm_head, e.g., CohereForCausalLM
-        if self.vllm_config.speculative_config.method != "eagle3" and \
-                hasattr(target_language_model, "lm_head"):
-            logger.info("Loading EAGLE LM head weights from the target model.")
-            self.model.lm_head = target_language_model.lm_head
+        if self.vllm_config.speculative_config.method != "eagle3":
+            if hasattr(target_language_model, "lm_head"):
+                logger.info(
+                    "Loading EAGLE LM head weights from the target model.")
+                self.model.lm_head = target_language_model.lm_head
+        else:
+            if (hasattr(self.model, "lm_head")
+                    and hasattr(target_language_model, "lm_head")
+                    and self.model.lm_head.weight.shape
+                    == target_language_model.lm_head.weight.shape):
+                logger.info("Assuming the EAGLE head shares the same lm_head"
+                            " with the target model.")
+                del self.model.lm_head
+                self.model.lm_head = target_language_model.lm_head
+            else:
+                logger.info(
+                    "The EAGLE head's lm_head will be loaded separately"
+                    " from the target model.")

     @torch.inference_mode()
     def dummy_run(
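For eagle3 heads, `lm_head` is now shared with the target model only when both modules define one and the weight shapes match; otherwise the draft keeps its own, separately loaded head. A minimal sketch of the shape-guarded sharing using toy `nn.Linear` heads (not the real model classes):

```python
# Sketch: share lm_head only when both modules have one and shapes agree.
import torch.nn as nn

draft, target = nn.Module(), nn.Module()
draft.lm_head = nn.Linear(64, 1000, bias=False)
target.lm_head = nn.Linear(64, 1000, bias=False)

if (hasattr(draft, "lm_head") and hasattr(target, "lm_head")
        and draft.lm_head.weight.shape == target.lm_head.weight.shape):
    del draft.lm_head                 # drop the duplicate parameters first
    draft.lm_head = target.lm_head    # then alias the target's module

assert draft.lm_head is target.lm_head
```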
@@ -877,6 +894,31 @@ def dummy_run(
             inputs_embeds=inputs_embeds,
         )

+    def _get_attention_metadata_builder(
+            self) -> AttentionMetadataBuilder:
+        """Find and return the attention metadata builder for EAGLE layers.
+
+        Returns:
+            The metadata builder for EAGLE layers.
+
+        Raises:
+            AssertionError: If no metadata builder is found for EAGLE layers.
+        """
+        builder = None
+        chosen_layer = self.attn_layer_names[0]
+
+        for kv_cache_group in self.runner.attn_groups:
+            for attn_group in kv_cache_group:
+                if chosen_layer in attn_group.layer_names:
+                    builder = attn_group.get_metadata_builder()
+                    break
+            if builder is not None:
+                break
+
+        assert builder is not None, (
+            "Failed to find attention metadata builder for EAGLE layers.")
+        return builder
+
     def validate_same_kv_cache_group(self,
                                      kv_cache_config: KVCacheConfig) -> None:
         """

vllm/v1/worker/gpu_model_runner.py

Lines changed: 8 additions & 3 deletions
@@ -1177,9 +1177,14 @@ def _prepare_inputs(
                 encoder_seq_lens=encoder_seq_lens,
             )

-            if self.speculative_config and \
-                    spec_decode_common_attn_metadata is None:
-                spec_decode_common_attn_metadata = common_attn_metadata
+            if (self.speculative_config
+                    and spec_decode_common_attn_metadata is None):
+                if isinstance(self.drafter, EagleProposer):
+                    if (self.drafter.attn_layer_names[0]
+                            in kv_cache_group_spec.layer_names):
+                        spec_decode_common_attn_metadata = common_attn_metadata
+                else:
+                    spec_decode_common_attn_metadata = common_attn_metadata

             for attn_group in self.attn_groups[kv_cache_group_id]:
                 # Prepare for cascade attention if enabled & beneficial.
