
Commit 29255cf

[Spec-Decode] Support piecewise cudagraphs for Eagle head (#25109)
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
Signed-off-by: Lucas Wilkinson <LucasWilkinson@users.noreply.github.com>
Co-authored-by: Benjamin Chislett <chislett.ben@gmail.com>
1 parent da44556 commit 29255cf

6 files changed: +84 −16 lines changed

vllm/config/compilation.py

Lines changed: 7 additions & 4 deletions

@@ -50,11 +50,14 @@ def decode_mode(self) -> "CUDAGraphMode":
     def mixed_mode(self) -> "CUDAGraphMode":
         return CUDAGraphMode(self.value[1]) if self.separate_routine() else self
 
+    def has_mode(self, mode: "CUDAGraphMode") -> bool:
+        assert not mode.separate_routine()
+        if self.separate_routine():
+            return mode.value in self.value
+        return self == mode
+
     def requires_piecewise_compilation(self) -> bool:
-        return (
-            self.decode_mode() == CUDAGraphMode.PIECEWISE
-            or self.mixed_mode() == CUDAGraphMode.PIECEWISE
-        )
+        return self.has_mode(CUDAGraphMode.PIECEWISE)
 
 
     def max_cudagraph_mode(self) -> "CUDAGraphMode":
         return CUDAGraphMode(max(self.value)) if self.separate_routine() else self
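
For context, a minimal self-contained sketch of the semantics has_mode() adds. The enum below is a stand-in: only PIECEWISE and FULL_AND_PIECEWISE are referenced elsewhere in this commit, and the other members and their exact values are assumptions:

import enum


class Mode(enum.Enum):
    # Stand-in members; vLLM's CUDAGraphMode defines its own set of modes.
    NONE = 0
    PIECEWISE = 1
    FULL = 2
    FULL_AND_PIECEWISE = (2, 1)  # assumed layout: (decode sub-mode, mixed sub-mode)

    def separate_routine(self) -> bool:
        # Combined modes carry a tuple of sub-mode values.
        return isinstance(self.value, tuple)

    def has_mode(self, mode: "Mode") -> bool:
        # Mirrors the method added above: membership test for combined modes,
        # plain equality otherwise.
        assert not mode.separate_routine()
        if self.separate_routine():
            return mode.value in self.value
        return self == mode


assert Mode.FULL_AND_PIECEWISE.has_mode(Mode.PIECEWISE)
assert Mode.PIECEWISE.has_mode(Mode.PIECEWISE)
assert not Mode.FULL.has_mode(Mode.PIECEWISE)

With the helper in place, requires_piecewise_compilation() reduces to a single has_mode(CUDAGraphMode.PIECEWISE) call that also covers combined modes such as FULL_AND_PIECEWISE.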

vllm/forward_context.py

Lines changed: 6 additions & 0 deletions

@@ -283,6 +283,12 @@ def set_forward_context(
             vllm_config.parallel_config, num_tokens or 0, num_tokens_across_dp
         )
 
+    # Convenience: if cudagraph is used and num_tokens is given, we can just
+    # create a batch descriptor here if not given (there's no harm since if it
+    # doesn't match in the wrapper it'll fall through).
+    if cudagraph_runtime_mode != CUDAGraphMode.NONE and num_tokens is not None:
+        batch_descriptor = batch_descriptor or BatchDescriptor(num_tokens=num_tokens)
+
     forward_context = create_forward_context(
         attn_metadata,
         vllm_config,
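
The new fallback is easiest to see in isolation. A minimal sketch, assuming a simplified descriptor type (vLLM's real BatchDescriptor carries more than a token count):

from typing import NamedTuple, Optional


class BatchDescriptorSketch(NamedTuple):
    # Simplified stand-in for vLLM's BatchDescriptor, used only for illustration.
    num_tokens: int


def resolve_batch_descriptor(
    batch_descriptor: Optional[BatchDescriptorSketch],
    num_tokens: Optional[int],
    cudagraphs_requested: bool,
) -> Optional[BatchDescriptorSketch]:
    # Same fallback as the diff: synthesize a descriptor from num_tokens when
    # cudagraphs are requested and the caller did not pass one. If it doesn't
    # match a captured graph, the wrapper simply falls through to eager mode.
    if cudagraphs_requested and num_tokens is not None:
        return batch_descriptor or BatchDescriptorSketch(num_tokens=num_tokens)
    return batch_descriptor


assert resolve_batch_descriptor(None, 8, True) == BatchDescriptorSketch(num_tokens=8)
assert resolve_batch_descriptor(None, 8, False) is None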

vllm/model_executor/models/deepseek_mtp.py

Lines changed: 2 additions & 0 deletions

@@ -7,6 +7,7 @@
 import torch.nn as nn
 from transformers import PretrainedConfig
 
+from vllm.compilation.decorators import support_torch_compile
 from vllm.config import VllmConfig
 from vllm.model_executor.layers.fused_moe import FusedMoE
 from vllm.model_executor.layers.layernorm import RMSNorm
@@ -162,6 +163,7 @@ def compute_logits(
         return logits
 
 
+@support_torch_compile
 class DeepSeekMTP(nn.Module, SupportsPP):
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()

vllm/model_executor/models/llama_eagle3.py

Lines changed: 11 additions & 0 deletions

@@ -8,6 +8,7 @@
 import torch.nn as nn
 from transformers import LlamaConfig
 
+from vllm.compilation.decorators import support_torch_compile
 from vllm.config import VllmConfig, get_current_vllm_config
 from vllm.logger import init_logger
 from vllm.model_executor.layers.layernorm import RMSNorm
@@ -21,6 +22,7 @@
 )
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.models.llama import LlamaDecoderLayer, LlamaForCausalLM
+from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.inputs import NestedTensors
 
 from .utils import AutoWeightsLoader, maybe_prefix
@@ -119,6 +121,15 @@ def forward(
         return hidden_states, residual
 
 
+@support_torch_compile(
+    # torch.compile is disabled for multimodal EAGLE3 models due to constraint
+    # violations with dynamic shapes during tensor concatenation operations.
+    # See: https://github.com/vllm-project/vllm/pull/22872/files#r2362028132
+    # Non-multimodal EAGLE3 models can still use torch.compile safely.
+    enable_if=lambda vllm_config: not MULTIMODAL_REGISTRY.supports_multimodal_inputs(
+        vllm_config.model_config
+    ),
+)
 class LlamaModel(nn.Module):
     def __init__(
         self,
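
The enable_if hook makes support_torch_compile a conditionally applied class decorator. A hypothetical sketch of that gating pattern, not vLLM's actual implementation (the real decorator wires the class into torch.compile; this stand-in only records the decision):

from typing import Callable, Optional


def support_compile_sketch(enable_if: Optional[Callable[[object], bool]] = None):
    def decorator(cls):
        orig_init = cls.__init__

        def init(self, *args, vllm_config=None, **kwargs):
            orig_init(self, *args, vllm_config=vllm_config, **kwargs)
            # The predicate is evaluated lazily, once the config is available.
            self.compile_enabled = enable_if is None or enable_if(vllm_config)

        cls.__init__ = init
        return cls

    return decorator


@support_compile_sketch(enable_if=lambda cfg: not cfg.get("multimodal", False))
class TinyDrafter:
    def __init__(self, *, vllm_config=None):
        self.vllm_config = vllm_config


# Text-only configs keep compilation on; multimodal configs opt out.
assert TinyDrafter(vllm_config={"multimodal": False}).compile_enabled
assert not TinyDrafter(vllm_config={"multimodal": True}).compile_enabled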

vllm/v1/spec_decode/eagle.py

Lines changed: 56 additions & 11 deletions

@@ -9,7 +9,12 @@
 import torch
 import torch.nn as nn
 
-from vllm.config import CompilationLevel, VllmConfig, get_layers_from_vllm_config
+from vllm.config import (
+    CompilationLevel,
+    CUDAGraphMode,
+    VllmConfig,
+    get_layers_from_vllm_config,
+)
 from vllm.distributed.parallel_state import get_pp_group
 from vllm.forward_context import set_forward_context
 from vllm.logger import init_logger
@@ -80,12 +85,25 @@ def __init__(
         self.attn_layer_names: list[str] = []
         self.indexer_layer_names: list[str] = []
 
-        self.use_cuda_graph = (
-            not current_platform.is_xpu()
-            and self.vllm_config.compilation_config.level == CompilationLevel.PIECEWISE
-            and not self.vllm_config.model_config.enforce_eager
-            and not self.speculative_config.enforce_eager
-        )
+        self.use_cuda_graph = False
+
+        compilation_config = self.vllm_config.compilation_config
+        if compilation_config.level == CompilationLevel.PIECEWISE:
+            cudagraph_mode = compilation_config.cudagraph_mode
+            if cudagraph_mode != CUDAGraphMode.NONE and not cudagraph_mode.has_mode(
+                CUDAGraphMode.PIECEWISE
+            ):
+                logger.warning(
+                    "Currently the eagle proposer only supports cudagraph_mode "
+                    "PIECEWISE, if you want the drafter to use cuda graphs, "
+                    "please set compilation_config.cudagraph_mode to PIECEWISE "
+                    "or FULL_AND_PIECEWISE"
+                )
+            self.use_cuda_graph = (
+                cudagraph_mode.has_mode(CUDAGraphMode.PIECEWISE)
+                and not self.speculative_config.enforce_eager
+            )
+
         self.cudagraph_batch_sizes = (
             list(reversed(self.vllm_config.compilation_config.cudagraph_capture_sizes))
             if self.use_cuda_graph
@@ -239,12 +257,15 @@ def propose(
         per_layer_attn_metadata = {}
         for layer_name in self.attn_layer_names:
             per_layer_attn_metadata[layer_name] = attn_metadata
+
         for layer_name in self.indexer_layer_names:
             assert draft_indexer_metadata is not None
             per_layer_attn_metadata[layer_name] = draft_indexer_metadata
 
+        cudagraph_runtime_mode = CUDAGraphMode.NONE
         if self.use_cuda_graph and num_tokens <= self.cudagraph_batch_sizes[-1]:
             num_input_tokens = self.vllm_config.pad_for_cudagraph(num_tokens)
+            cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE
         else:
             num_input_tokens = num_tokens
         # copy inputs to buffer for cudagraph
@@ -267,7 +288,10 @@ def propose(
             inputs_embeds = None
 
         with set_forward_context(
-            per_layer_attn_metadata, self.vllm_config, num_tokens=num_input_tokens
+            per_layer_attn_metadata,
+            self.vllm_config,
+            num_tokens=num_input_tokens,
+            cudagraph_runtime_mode=cudagraph_runtime_mode,
         ):
             ret_hidden_states = self.model(
                 input_ids=input_ids,
@@ -326,8 +350,10 @@ def propose(
 
         if self.use_cuda_graph and batch_size <= self.cudagraph_batch_sizes[-1]:
             input_batch_size = self.vllm_config.pad_for_cudagraph(batch_size)
+            cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE
         else:
             input_batch_size = batch_size
+            cudagraph_runtime_mode = CUDAGraphMode.NONE
 
         common_attn_metadata.num_actual_tokens = batch_size
         common_attn_metadata.max_query_len = 1
@@ -424,7 +450,10 @@ def propose(
 
             # Run the model.
             with set_forward_context(
-                per_layer_attn_metadata, self.vllm_config, num_tokens=input_batch_size
+                per_layer_attn_metadata,
+                self.vllm_config,
+                num_tokens=input_batch_size,
+                cudagraph_runtime_mode=cudagraph_runtime_mode,
             ):
                 ret_hidden_states = self.model(
                     input_ids=input_ids,
@@ -731,11 +760,16 @@ def propose_tree(
 
             if self.use_cuda_graph and num_tokens <= self.cudagraph_batch_sizes[-1]:
                 num_input_tokens = self.vllm_config.pad_for_cudagraph(num_tokens)
+                cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE
             else:
                 num_input_tokens = num_tokens
+                cudagraph_runtime_mode = CUDAGraphMode.NONE
             # Run the model.
             with set_forward_context(
-                per_layer_attn_metadata, self.vllm_config, num_tokens=num_input_tokens
+                per_layer_attn_metadata,
+                self.vllm_config,
+                num_tokens=num_input_tokens,
+                cudagraph_runtime_mode=cudagraph_runtime_mode,
             ):
                 last_hidden_states, hidden_states = self.model(
                     input_ids=self.input_ids[:num_input_tokens],
@@ -1015,8 +1049,19 @@ def load_model(self, target_model: nn.Module) -> None:
     def dummy_run(
         self,
         num_tokens: int,
+        use_cudagraphs=True,
     ) -> None:
-        with set_forward_context(None, self.vllm_config, num_tokens=num_tokens):
+        if use_cudagraphs and num_tokens <= self.cudagraph_batch_sizes[-1]:
+            num_tokens = self.vllm_config.pad_for_cudagraph(num_tokens)
+
+        with set_forward_context(
+            None,
+            self.vllm_config,
+            num_tokens=num_tokens,
+            cudagraph_runtime_mode=CUDAGraphMode.PIECEWISE
+            if use_cudagraphs
+            else CUDAGraphMode.NONE,
+        ):
             if self.supports_mm_inputs:
                 input_ids = None
                 inputs_embeds = self.inputs_embeds[:num_tokens]
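
Taken together, these hunks follow one pattern: choose a padded token count and a cudagraph runtime mode up front, then pass both into set_forward_context. A small standalone sketch of that selection logic (the Mode enum and function name here are illustrative, not vLLM APIs; the real padding is done by VllmConfig.pad_for_cudagraph):

import bisect
from enum import Enum


class Mode(Enum):
    NONE = 0
    PIECEWISE = 1


def select_graph_mode(num_tokens: int, capture_sizes: list[int], use_cuda_graph: bool):
    """Return (padded_num_tokens, runtime_mode) for a drafter forward pass.

    capture_sizes is assumed to be sorted ascending, mirroring the reversed
    cudagraph_capture_sizes list built in __init__ above.
    """
    if use_cuda_graph and num_tokens <= capture_sizes[-1]:
        # Pad up to the smallest captured size that still fits the batch.
        padded = capture_sizes[bisect.bisect_left(capture_sizes, num_tokens)]
        return padded, Mode.PIECEWISE
    # Too large for any captured graph (or graphs disabled): run eagerly.
    return num_tokens, Mode.NONE


assert select_graph_mode(3, [1, 2, 4, 8], True) == (4, Mode.PIECEWISE)
assert select_graph_mode(9, [1, 2, 4, 8], True) == (9, Mode.NONE)
assert select_graph_mode(3, [1, 2, 4, 8], False) == (3, Mode.NONE)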

vllm/v1/worker/gpu_model_runner.py

Lines changed: 2 additions & 1 deletion

@@ -3441,7 +3441,8 @@ def _dummy_run(
 
         if self.speculative_config and self.speculative_config.use_eagle():
             assert isinstance(self.drafter, EagleProposer)
-            self.drafter.dummy_run(num_tokens)
+            use_cudagraphs = cudagraph_runtime_mode == CUDAGraphMode.PIECEWISE
+            self.drafter.dummy_run(num_tokens, use_cudagraphs=use_cudagraphs)
 
         # This is necessary to avoid blocking DP.
         # For dummy runs, we typically skip EPLB since we don't have any real
