34 commits
517b672
fixes and refactors spec-decode cudagraph
fhl2000 Aug 25, 2025
40e1ccb
remove build_for_cudagraph_capture
fhl2000 Aug 26, 2025
3550717
support capturing mutiple uniform_query_len
fhl2000 Aug 26, 2025
a142f14
fix typo
fhl2000 Aug 26, 2025
f7d73f8
fix typo
fhl2000 Aug 26, 2025
02390fc
fix broken examples/offline_inference/spec_decode.py
fhl2000 Aug 26, 2025
198fb66
Merge remote-tracking branch 'origin/main' into fix_cudagraph_drafter
fhl2000 Sep 4, 2025
14c6918
Merge branch 'main' into fix_cudagraph_drafter
fhl2000 Sep 7, 2025
ec02778
fix pre-commit
fhl2000 Sep 7, 2025
6b90770
Merge branch 'main' into fix_cudagraph_drafter
fhl2000 Sep 9, 2025
286677f
revert spec_decode.py
fhl2000 Sep 10, 2025
874639c
Merge branch 'main' into fix_cudagraph_drafter
fhl2000 Sep 14, 2025
0eda111
address comments
fhl2000 Sep 15, 2025
9c50e6e
revert build_for_cudagraph_capturing
fhl2000 Sep 15, 2025
e4a1a78
remove unnecessary assertion
fhl2000 Sep 15, 2025
ce32326
solving conflicts/Merge remote-tracking branch 'origin/main' into fix…
fhl2000 Sep 16, 2025
ad5ba70
Merge branch 'main' into fix_cudagraph_drafter
fhl2000 Sep 20, 2025
691c21e
fixes for ubatching
fhl2000 Sep 21, 2025
43b2753
fix CI
fhl2000 Sep 23, 2025
fde10ba
Merge remote-tracking branch 'origin/main' into fix_cudagraph_drafter
fhl2000 Sep 28, 2025
0a3fe05
fix
fhl2000 Sep 28, 2025
804598b
Merge branch 'main' into fix_cudagraph_drafter
fhl2000 Sep 29, 2025
40bd81b
Merge remote-tracking branch 'origin/main' into fix_cudagraph_drafter
fhl2000 Oct 6, 2025
a51344e
Merge remote-tracking branch 'origin/main' into fix_cudagraph_drafter
fhl2000 Oct 14, 2025
d170341
Merge branch 'main' into fix_cudagraph_drafter
fhl2000 Oct 14, 2025
0ee4aef
WIP:address dp padding issue
fhl2000 Oct 14, 2025
a4872bc
clean up
fhl2000 Oct 15, 2025
872015e
Merge branch 'main' into fix_cudagraph_drafter
fhl2000 Oct 15, 2025
d1499c2
Merge remote-tracking branch 'origin/main' into fix_cudagraph_drafter
fhl2000 Nov 1, 2025
c18486a
refactor eagle dummy run
fhl2000 Nov 1, 2025
9b99056
Merge branch 'main' into fix_cudagraph_drafter
fhl2000 Nov 4, 2025
299ce7d
fix drafter when enforce_eager
fhl2000 Nov 4, 2025
b5c315a
fix pre-commit
fhl2000 Nov 4, 2025
25d3f3b
Merge branch 'main' into fix_cudagraph_drafter
fhl2000 Nov 8, 2025
6 changes: 6 additions & 0 deletions examples/offline_inference/spec_decode.py
@@ -1,6 +1,8 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import json

from transformers import AutoTokenizer

from vllm import LLM, SamplingParams
@@ -71,6 +73,7 @@ def parse_args():
parser.add_argument("--model-dir", type=str, default=None)
parser.add_argument("--eagle-dir", type=str, default=None)
parser.add_argument("--custom-mm-prompts", action="store_true")
parser.add_argument("--compilation-config", type=str, default="")
return parser.parse_args()


@@ -139,6 +142,9 @@ def main(args):
max_model_len=args.max_model_len,
limit_mm_per_prompt={"image": 5},
disable_chunked_mm_input=True,
compilation_config=(
json.loads(args.compilation_config) if args.compilation_config else None
),
)

sampling_params = SamplingParams(temperature=args.temp, max_tokens=args.output_len)
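The new --compilation-config flag takes a JSON object that is parsed with json.loads and forwarded to LLM(...). A minimal sketch of the parsing path, assuming the JSON keys mirror CompilationConfig fields such as cudagraph_mode:

```python
import json

# Hypothetical flag value; any valid CompilationConfig field could appear here.
raw = '{"cudagraph_mode": "FULL_AND_PIECEWISE"}'
compilation_config = json.loads(raw) if raw else None
# The resulting dict (or None) is passed as LLM(..., compilation_config=compilation_config).
```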
59 changes: 58 additions & 1 deletion vllm/config/compilation.py
@@ -428,6 +428,36 @@ class CompilationConfig:
max_num_seqs, and prevents capture of many large graphs (>512) that would
greatly increase startup time with limited performance benefit.
"""
disable_cudagraph_uniform_alignment: bool = False
"""Whether to disable uniformly alignment of cudagraph capture sizes for
uniform decode batch with query length>1 (i.e., for spec-decode). This flag
only takes effective when cudagraph_mode is FULL_DECODE_ONLY or
FULL_AND_PIECEWISE.

Uniform alignment make sure all capture sizes for uniform-decode batch
are multiples of 1+num_speculative_tokens. This aligmnment is typically
useful for padded speculation (see #21984 for details), and is needed by
some attention backends to achieve their sota performance, which support
uniform-decode but no in a varible-length fashion. However, we should
realize here is a trade-off that while it is good for attention layer,
it may introduce slight regressions to other layers if these sizes after
alignment don't hit the multiple of 8.

Note: for DP_size>1, the uniformity of sizes may be broken after dp_padding
sync. Therefore, we only ensure running full cudagraph of uniform-decode batch
of current rank if all dp ranks are uniform-decode batch. Otherwise, it would
fall back to piecewise cudagraphs, where the uniformity batch before padded
should still be utilized by attention layers under eager exectution.
"""
uniform_cudagraph_capture_sizes: list[int] | None = None
"""
List of capture sizes for uniform-decode batches of the main model. Its
elements should be multiples of uniform_decode_len (1 for plain decode,
or 1+num_speculative_tokens for speculative decode).
Not configurable; computed after init.
"""
max_uniform_capture_size: int = field(default=None, init=False) # type: ignore
"""not configurable, computed after init"""
local_cache_dir: str = field(default=None, init=False) # type: ignore
"""local cache dir for each rank"""
bs_to_padded_graph_size: list[int] = field(
@@ -438,6 +468,11 @@
Intuitively, bs_to_padded_graph_size should be dict[int, int].
since we know all keys are in a range [0, max_cudagraph_capture_size],
we can optimize it to list[int] for better lookup performance."""
bs_to_padded_graph_size_uniform: list[int] = field(
default=None, # type: ignore
init=False,
)
"""same as bs_to_padded_graph_size, but for uniform capture sizes"""

# keep track of enabled and disabled custom ops
enabled_custom_ops: Counter[str] = field(default_factory=Counter, init=False)
@@ -503,6 +538,7 @@ def __repr__(self) -> str:
"disabled_custom_ops": True,
"compilation_time": True,
"bs_to_padded_graph_size": True,
"bs_to_padded_graph_size_uniform": True,
"traced_files": True,
"inductor_compile_config": {
"post_grad_custom_post_pass": True,
@@ -718,7 +754,8 @@ def post_init_cudagraph_sizes(self) -> None:
"""To complete the initialization after cudagraph related
configs are set. This includes:
- initialize compile_sizes
- pre-compute the mapping bs_to_padded_graph_size
- pre-compute the mapping bs_to_padded_graph_size and
bs_to_padded_graph_size_uniform
"""

computed_compile_sizes = []
@@ -739,8 +776,14 @@

# make sure the sizes are in ascending order
self.cudagraph_capture_sizes.sort()
self.uniform_cudagraph_capture_sizes.sort()
if self.cudagraph_capture_sizes:
assert self.cudagraph_capture_sizes[-1] == self.max_cudagraph_capture_size
if self.uniform_cudagraph_capture_sizes:
assert (
self.uniform_cudagraph_capture_sizes[-1]
== self.max_uniform_capture_size
)

# pre-compute the mapping from batch size to padded graph size
self.bs_to_padded_graph_size = [
@@ -756,6 +799,20 @@
else:
self.bs_to_padded_graph_size[bs] = end

# pre-compute the mapping for uniform decode padding.
self.bs_to_padded_graph_size_uniform = [
0 for i in range(self.max_uniform_capture_size + 1)
]
for end, start in zip(
self.uniform_cudagraph_capture_sizes + [self.max_uniform_capture_size + 1],
[0] + self.uniform_cudagraph_capture_sizes,
):
for bs in range(start, end):
if bs == start:
self.bs_to_padded_graph_size_uniform[bs] = start
else:
self.bs_to_padded_graph_size_uniform[bs] = end
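For reference, a standalone sketch of how this lookup table behaves, using invented capture sizes [16, 32, 64] and max_uniform_capture_size = 64:

```python
# Mirrors the loop above outside the class; illustrative only.
sizes = [16, 32, 64]
max_size = 64
table = [0] * (max_size + 1)
for end, start in zip(sizes + [max_size + 1], [0] + sizes):
    for bs in range(start, end):
        table[bs] = start if bs == start else end
assert table[16] == 16  # an exact capture size maps to itself
assert table[17] == 32  # anything in between pads up to the next capture size
assert table[64] == 64
```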

def set_splitting_ops_for_v1(self):
# NOTE: this function needs to be called only when mode is
# CompilationMode.VLLM_COMPILE
65 changes: 57 additions & 8 deletions vllm/config/vllm.py
@@ -201,12 +201,23 @@ def compute_hash(self) -> str:
).hexdigest()[:10]
return hash_str

def pad_for_cudagraph(self, batch_size: int) -> int:
# if batch_size > self.compilation_config.max_cudagraph_capture_size,
# it should raise an IndexError.
# the caller should make sure the batch_size is within the range,
# i.e., batch_size <= self.compilation_config.max_cudagraph_capture_size
return self.compilation_config.bs_to_padded_graph_size[batch_size]
def pad_for_cudagraph(self, batch_size: int, uniform_aligned: bool = False) -> int:
"""Get the padded graph size for the batch size.
uniform_aligned: if True, means the padding batch size would be
divisible by the uniform_decode_len for the main model.
For drafter, caller should make sure uniform_aligned is False because
drafter's uniform_decode_len is 1.
"""
if self.compilation_config.disable_cudagraph_uniform_alignment:
uniform_aligned = False
# if batch_size > max_cudagraph_capture_size (uniform_aligned=False)
# or batch_size > max_uniform_capture_size (uniform_aligned=True),
# it would raise an IndexError. So the caller should make sure the
# batch_size is within the range
if not uniform_aligned:
return self.compilation_config.bs_to_padded_graph_size[batch_size]
else:
return self.compilation_config.bs_to_padded_graph_size_uniform[batch_size]
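Hypothetical call sites (the variable names below are illustrative; only the signature comes from this diff):

```python
# Main model, uniform-decode spec-decode batch: pad to a uniform-aligned size.
padded = vllm_config.pad_for_cudagraph(num_input_tokens, uniform_aligned=True)

# Drafter (uniform_decode_len == 1): use the plain padding table.
padded_draft = vllm_config.pad_for_cudagraph(num_draft_tokens)
```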

def enable_trace_function_call_for_thread(self) -> None:
"""
@@ -800,7 +811,6 @@ def _set_cudagraph_sizes(self):
- If batch size > largest `cudagraph_capture_sizes`, cudagraph will
not be used.
"""

if (
self.model_config is not None
and not self.model_config.enforce_eager
@@ -847,14 +857,41 @@
cudagraph_capture_sizes += list(
range(256, max_cudagraph_capture_size + 1, 16)
)

uniform_decode_len = (
1
if not self.speculative_config
else 1 + self.speculative_config.num_speculative_tokens
)
max_num_decode_tokens = min(
max_num_tokens,
self.scheduler_config.max_num_seqs * uniform_decode_len,
)
if (
self.compilation_config.disable_cudagraph_uniform_alignment
or uniform_decode_len == 1
):
uniform_cudagraph_capture_sizes = [
x for x in cudagraph_capture_sizes if x < max_num_decode_tokens
]
else:
uniform_cudagraph_capture_sizes = [
size * uniform_decode_len
for size in cudagraph_capture_sizes
if size >= uniform_decode_len
and size * uniform_decode_len <= max_num_decode_tokens
]
if (
self.parallel_config.tensor_parallel_size > 1
and self.compilation_config.pass_config.enable_sequence_parallelism
):
cudagraph_capture_sizes = self.update_sizes_for_sequence_parallelism(
cudagraph_capture_sizes
)
uniform_cudagraph_capture_sizes = (
self.update_sizes_for_sequence_parallelism(
uniform_cudagraph_capture_sizes
)
)
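To make the selection above concrete, a worked example with invented limits (not the actual defaults):

```python
# Assume max_num_tokens = 2048, max_num_seqs = 64, num_speculative_tokens = 3.
uniform_decode_len = 1 + 3                 # = 4
max_num_decode_tokens = min(2048, 64 * 4)  # = 256
base_sizes = [1, 2, 4, 8, 16, 32, 64, 128]
uniform_sizes = [
    s * uniform_decode_len
    for s in base_sizes
    if s >= uniform_decode_len and s * uniform_decode_len <= max_num_decode_tokens
]
# uniform_sizes == [16, 32, 64, 128, 256]: every entry is a multiple of 4
# and never exceeds max_num_seqs * uniform_decode_len.
```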

# user-specified compilation_config.max_cudagraph_capture_size gets
# truncated to valid_max_size when they are inconsistent.
@@ -899,10 +936,22 @@ def _set_cudagraph_sizes(self):
# always write back the final sizes
self.compilation_config.cudagraph_capture_sizes = cudagraph_capture_sizes

# set uniform_cudagraph_sizes related values.
self.compilation_config.max_uniform_capture_size = (
uniform_cudagraph_capture_sizes[-1]
if uniform_cudagraph_capture_sizes
else 0
)
self.compilation_config.uniform_cudagraph_capture_sizes = (
uniform_cudagraph_capture_sizes
)

else:
# no cudagraph in use
self.compilation_config.max_cudagraph_capture_size = 0
self.compilation_config.cudagraph_capture_sizes = []
self.compilation_config.max_uniform_capture_size = 0
self.compilation_config.uniform_cudagraph_capture_sizes = []

# complete the remaining process.
self.compilation_config.post_init_cudagraph_sizes()
10 changes: 9 additions & 1 deletion vllm/forward_context.py
@@ -40,6 +40,11 @@ class BatchDescriptor(NamedTuple):
False can also be used for a uniform decode batch to dispatch to the
cudagraph supporting non-uniform batches.
"""
uniform_query_len: int = 0
"""
For non-uniform batches, this should be set to 0 so the batch is uniquely
identified. For uniform batches, it is the max_query_len of the uniform batch.
"""
has_lora: bool = False
"""
Whether this batch has active LoRA adapters.
@@ -51,7 +56,10 @@ def non_uniform(self) -> "BatchDescriptor":
Return a non-uniform version of current batch descriptor.
"""
return BatchDescriptor(
self.num_tokens, uniform_decode=False, has_lora=self.has_lora
self.num_tokens,
uniform_decode=False,
uniform_query_len=0,
has_lora=self.has_lora,
)
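A sketch of the descriptors that the new field lets the dispatcher distinguish (all numbers invented):

```python
# 16 requests, each with 1 verified + 3 speculative tokens -> 64 tokens total.
spec_decode_key = BatchDescriptor(num_tokens=64, uniform_decode=True, uniform_query_len=4)
# The same token count keyed as a non-uniform batch (what non_uniform produces).
fallback_key = BatchDescriptor(num_tokens=64, uniform_decode=False, uniform_query_len=0)
```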


17 changes: 0 additions & 17 deletions vllm/v1/attention/backends/mla/common.py
@@ -741,23 +741,6 @@ def _build_decode(
dcp_tot_seq_lens=dcp_tot_seq_lens_device,
)

def build_for_cudagraph_capture(
self, common_attn_metadata: CommonAttentionMetadata
) -> M:
"""
This method builds the metadata for full cudagraph capture.
Currently, only decode is supported for full cudagraphs with MLA.
"""
m = common_attn_metadata
assert m.num_reqs <= (m.num_actual_tokens * self.reorder_batch_threshold), (
"MLA only supports decode-only full CUDAGraph capture. "
"Make sure all cudagraph capture sizes <= max_num_seq."
)

assert m.max_query_len <= self.reorder_batch_threshold # decode only

return self.build(0, m)

def build(
self,
common_prefix_len: int,
10 changes: 0 additions & 10 deletions vllm/v1/attention/backends/triton_attn.py
@@ -86,16 +86,6 @@ def __init__(
self.num_heads_kv = model_config.get_num_kv_heads(vllm_config.parallel_config)
self.headdim = model_config.get_head_size()

def build_for_cudagraph_capture(
self, common_attn_metadata: CommonAttentionMetadata
) -> TritonAttentionMetadata:
attn_metadata = self.build(0, common_attn_metadata)
# When doing full graph capture, setting seq_lens to
# max_model_len will cause graph capture to be extremely
# slow, so here we set it to 1.
attn_metadata.seq_lens.fill_(1)
return attn_metadata

Comment on lines -89 to -98
fhl2000 (Contributor, Author):
Currently, with the frozen CG feature, the time is almost the same as FA2.

Apologies for this partially true statement. I just found that when max_model_len is extremely large, FA2 is slow at capturing FULL, while Triton attention is just as fast as when max_model_len is small. So Triton attention is fine with removing this function, but we should do something else to keep FA2 from slowing down when capturing FULL with an extremely large max_model_len. Not sure if the same situation happens with other attention backends.

fhl2000 (Contributor, Author):

FlashInfer backend is also slow at capturing FULL when max_model_len is large.

def build(
self,
common_prefix_len: int,