6 changes: 6 additions & 0 deletions docs/source/design/v1/torch_compile.md
@@ -137,3 +137,9 @@ By default, vLLM will try to determine a set of sizes to capture cudagraph. You
`vllm serve meta-llama/Llama-3.2-1B --compilation-config "{'cudagraph_capture_sizes': [1, 2, 4, 8]}"`

vLLM will then capture cudagraphs only for the specified sizes. This is useful when you want fine-grained control over cudagraph capture.

### Full Cudagraph capture

It is possible to include attention as part of the cudagraph when using an attention backend that is cudagraph compatible. This can improve performance in some cases, such as decode speed for smaller models. Enable this using `--compilation-config "{'full_cuda_graph': True}"`.

Currently only FlashAttention 3 is compatible, and only when cascade attention is disabled.
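
The same option can be set programmatically for offline inference. Below is a minimal sketch modeled on the test added in this PR; the model name and the environment variables pinning the V1 engine and FlashAttention 3 mirror the test setup and are illustrative assumptions rather than strict requirements.

```python
import os

from vllm import LLM, SamplingParams
from vllm.config import CompilationConfig

# Mirror the test setup: force the V1 engine and FlashAttention 3,
# currently the only backend compatible with full_cuda_graph.
os.environ["VLLM_USE_V1"] = "1"
os.environ["VLLM_FLASH_ATTN_VERSION"] = "3"

llm = LLM(
    model="Qwen/Qwen2-1.5B-Instruct",  # illustrative model choice from the test
    compilation_config=CompilationConfig(full_cuda_graph=True),
)

outputs = llm.generate(["Hi my name is"],
                       SamplingParams(temperature=0.0, max_tokens=10))
print(outputs[0].outputs[0].text)
```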
97 changes: 97 additions & 0 deletions tests/compile/piecewise/test_full_cudagraph.py
@@ -0,0 +1,97 @@
# SPDX-License-Identifier: Apache-2.0
import contextlib
import os

import pytest

from vllm import LLM, SamplingParams
from vllm.config import CompilationConfig

MODEL = "Qwen/Qwen2-1.5B-Instruct"


@contextlib.contextmanager
def temporary_environ(env_vars):
"""
Temporarily set environment variables and restore them afterward.
    We do this instead of using pytest's monkeypatch because monkeypatch
    does not work with "module"-scoped fixtures.
"""
original_env = {k: os.environ.get(k) for k in env_vars}
try:
os.environ.update(env_vars)
yield
finally:
for k, v in original_env.items():
if v is None:
os.environ.pop(k, None)
else:
os.environ[k] = v


@pytest.fixture(scope="module")
def full_cudagraph_llm():
with temporary_environ({
"VLLM_USE_V1": "1",
"VLLM_FLASH_ATTN_VERSION": "3"
}):
return LLM(model=MODEL,
gpu_memory_utilization=0.2,
compilation_config=CompilationConfig(full_cuda_graph=True))


@pytest.fixture(scope="module")
def piecewise_llm():
with temporary_environ({
"VLLM_USE_V1": "1",
"VLLM_FLASH_ATTN_VERSION": "3"
}):
return LLM(model=MODEL,
gpu_memory_utilization=0.5,
compilation_config=CompilationConfig())


def generate_text(llm: LLM, batch_size: int, max_tokens: int):
prompts = ["Hi my name is"] * batch_size
sampling_params = SamplingParams(temperature=0.0,
max_tokens=max_tokens,
top_p=0.95)

return llm.generate(prompts, sampling_params)


@pytest.mark.parametrize(("batch_size", "max_tokens"), [(1, 10), (7, 10),
(16, 10), (25, 10),
(32, 10), (45, 10),
(64, 10), (8, 5),
(8, 20), (8, 200)])
def test_full_cudagraph(batch_size, max_tokens, full_cudagraph_llm,
piecewise_llm):
"""
    Load the full cudagraph model and the piecewise model once each, keeping
    both loaded at the same time so they can be reused across the test cases.

Test various batch sizes and max_tokens to ensure that the full cudagraph
compilation works for padded cases too.
"""
piecewise_responses = generate_text(piecewise_llm,
batch_size=batch_size,
max_tokens=max_tokens)
full_cudagraph_responses = generate_text(full_cudagraph_llm,
batch_size=batch_size,
max_tokens=max_tokens)

# Check that all responses are the same
for i in range(len(piecewise_responses)):
assert piecewise_responses[i].outputs[
0].text == full_cudagraph_responses[i].outputs[0].text


def test_full_cudagraph_with_invalid_backend():
with temporary_environ({
"VLLM_USE_V1": "1",
"VLLM_FLASH_ATTN_VERSION":
"2" #FA2 not supported with full_cuda_graph
}), pytest.raises(RuntimeError):
LLM(model=MODEL,
compilation_config=CompilationConfig(full_cuda_graph=True))
19 changes: 17 additions & 2 deletions vllm/config.py
@@ -3605,6 +3605,10 @@ class CompilationConfig(BaseModel):
are always used, it can set this to False. Otherwise, it should
set this to True, and the compiler will copy the input to an
internally managed buffer. Default is False.
- full_cuda_graph: whether to use a full cuda graph for the entire forward
pass rather than splitting certain operations such as attention into subgraphs.
Thus this flag cannot be used together with splitting_ops. This may provide
performance benefits for smaller models.
- Inductor compilation:
- use_inductor: whether to use inductor compilation.
- False: inductor compilation is not used. graph runs in eager.
@@ -3649,6 +3653,7 @@
cudagraph_num_of_warmups: int = 0
cudagraph_capture_sizes: Optional[list[int]] = None
cudagraph_copy_inputs: bool = False
full_cuda_graph: bool = False

class PassConfig(BaseModel):
"""
@@ -3871,10 +3876,14 @@ def init_with_cudagraph_sizes(self,
self.max_capture_size] = self.max_capture_size

def set_splitting_ops_for_v1(self):
# If default, override splitting ops for piecewise cudagraph on V1.
# NOTE: this function needs to be called
if self.splitting_ops and self.full_cuda_graph:
raise ValueError("full_cuda_graph cannot be used together with "
"splitting_ops, as Full CUDA graph will override "
f"the splitting_ops: {self.splitting_ops}")

if not self.splitting_ops:
self.splitting_ops = [
self.splitting_ops = [] if self.full_cuda_graph else [
"vllm.unified_attention",
"vllm.unified_attention_with_output",
]
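
The effect of this check, sketched as a standalone snippet; calling set_splitting_ops_for_v1() directly here is for illustration only, since vLLM normally invokes it during engine initialization.

```python
from vllm.config import CompilationConfig

# Explicitly passing both options triggers the ValueError raised above.
cfg = CompilationConfig(full_cuda_graph=True,
                        splitting_ops=["vllm.unified_attention"])
try:
    cfg.set_splitting_ops_for_v1()
except ValueError as err:
    print(err)  # full_cuda_graph cannot be used together with splitting_ops...

# With splitting_ops unset, full_cuda_graph leaves it empty (no subgraph splits).
cfg = CompilationConfig(full_cuda_graph=True)
cfg.set_splitting_ops_for_v1()
print(cfg.splitting_ops)  # []
```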
@@ -4151,6 +4160,12 @@ def __post_init__(self):
"Disabling `torch.compile`.")
self.compilation_config.level = CompilationLevel.NO_COMPILATION

if self.compilation_config.full_cuda_graph and \
not self.model_config.disable_cascade_attn:
logger.warning_once(
"full_cuda_graph is not supported with "
"cascade attention. Disabling cascade attention.")
self.model_config.disable_cascade_attn = True

if self.model_config and self.model_config.use_mla and \
not (current_platform.is_cuda() or current_platform.is_rocm()):
13 changes: 10 additions & 3 deletions vllm/v1/attention/backends/flash_attn.py
@@ -291,6 +291,7 @@ class FlashAttentionMetadataBuilder:

def __init__(self, runner: "GPUModelRunner"):
model_config = runner.model_config
compilation_config = runner.vllm_config.compilation_config

self.runner = runner
self.num_heads_q = model_config.get_num_attention_heads(
@@ -300,7 +301,14 @@ def __init__(self, runner: "GPUModelRunner"):
self.headdim = model_config.get_head_size()
self.page_size = self.runner.block_size

self.aot_schedule = (get_flash_attn_version() == 3)
if get_flash_attn_version() == 3:
self.aot_schedule = not compilation_config.full_cuda_graph
if not self.aot_schedule:
logger.warning(
"AOT Schedule is disabled when using full_cuda_graph")
else:
self.aot_schedule = False

# Sliding window size to be used with the AOT scheduler will be
# populated on first build() call.
self.aot_sliding_window: Optional[tuple[int, int]] = None
@@ -317,8 +325,7 @@ def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int,
seq_lens = common_attn_metadata.seq_lens
block_table = (
self.runner.input_batch.block_table.get_device_tensor()[:num_reqs])
slot_mapping = self.runner.slot_mapping_cpu[:num_actual_tokens].to(
self.runner.device, non_blocking=True).long()
Comment on lines -320 to -321:

Member:
Just wondering: Why do we have this difference?

@WoosukKwon it's because we call .long() here. We might want to still call it here, to keep the dtypes consistent in the model runner.

@chanh (Contributor, Author), May 7, 2025:
@tlrmchlsmth what do you think about just making the CPU tensor int64 too? (that's the route that I went with in the latest update on this PR)

Member:
Had to check - that takes the slot_mapping CPU -> GPU transfer from 32 KB to 64 KB (by default serving on an H100). That seems fine to me since now we don't do that copy in every layer.
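
For context on those figures, a rough back-of-the-envelope check; the 8192-token buffer size is an assumption chosen to match the quoted numbers, not something stated in the thread.

```python
# Rough arithmetic behind the 32 KB -> 64 KB figure quoted above, assuming a
# slot_mapping buffer of 8192 tokens (an assumption consistent with 32 KB / 4 B).
max_num_tokens = 8192
print(max_num_tokens * 4 / 1024)  # int32 slot_mapping transfer: 32.0 KB
print(max_num_tokens * 8 / 1024)  # int64 slot_mapping transfer: 64.0 KB
```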

slot_mapping = self.runner.slot_mapping[:num_actual_tokens]

if self.aot_sliding_window is None:
self.aot_sliding_window = (-1, -1)
68 changes: 60 additions & 8 deletions vllm/v1/worker/gpu_model_runner.py
@@ -12,6 +12,7 @@

from vllm.attention import AttentionType, get_attn_backend
from vllm.attention.layer import Attention
from vllm.attention.utils.fa_utils import get_flash_attn_version
from vllm.config import (CompilationLevel, VllmConfig,
get_layers_from_vllm_config)
from vllm.distributed.kv_transfer import (get_kv_transfer_group,
@@ -139,6 +140,16 @@ def __init__(
raise NotImplementedError(
"Non-Attention backend is not supported by V1 GPUModelRunner.")

if self.vllm_config.compilation_config.full_cuda_graph:
attn_backend_name = self.attn_backend.__name__
flash_attn_version = get_flash_attn_version()
if attn_backend_name != "FlashAttentionBackend" or \
flash_attn_version != 3:
raise ValueError(
f"full_cuda_graph is only supported with "
f"FA3. Current attention backend is {attn_backend_name}, "
f"FlashAttention version is {flash_attn_version}.")

self.attn_metadata_builder = self.attn_backend.get_builder_cls()(
weakref.proxy(self))
self.cascade_attn_enabled = not self.model_config.disable_cascade_attn
@@ -219,6 +230,16 @@ def __init__(
self.positions = torch.zeros(self.max_num_tokens,
dtype=torch.int64,
device=self.device)
self.query_start_loc = torch.zeros(self.max_num_reqs + 1,
dtype=torch.int32,
device=self.device)
self.seq_lens = torch.zeros(self.max_num_reqs,
dtype=torch.int32,
device=self.device)
self.slot_mapping = torch.zeros(self.max_num_tokens,
dtype=torch.int64,
device=self.device)

# None in the first PP rank. The rest are set after load_model.
self.intermediate_tensors: Optional[IntermediateTensors] = None

@@ -271,7 +292,7 @@ def __init__(
pin_memory=self.pin_memory)
self.positions_np = self.positions_cpu.numpy()
self.slot_mapping_cpu = torch.zeros(self.max_num_tokens,
dtype=torch.int32,
dtype=torch.int64,
device="cpu",
pin_memory=self.pin_memory)
self.slot_mapping_np = self.slot_mapping_cpu.numpy()
@@ -589,10 +610,22 @@ def _prepare_inputs(
self.positions_cpu[:total_num_scheduled_tokens],
non_blocking=True)

query_start_loc = self.query_start_loc_cpu[:num_reqs + 1].to(
self.device, non_blocking=True)
seq_lens = self.seq_lens_cpu[:num_reqs].to(self.device,
non_blocking=True)
self.query_start_loc[:num_reqs + 1].copy_(
self.query_start_loc_cpu[:num_reqs + 1], non_blocking=True)
self.seq_lens[:num_reqs].copy_(self.seq_lens_cpu[:num_reqs],
non_blocking=True)
self.slot_mapping[:total_num_scheduled_tokens].copy_(
self.slot_mapping_cpu[:total_num_scheduled_tokens],
non_blocking=True)

# Fill unused with -1. Needed for reshape_and_cache
self.slot_mapping[total_num_scheduled_tokens:].fill_(-1)
self.seq_lens[num_reqs:].fill_(0)
self.query_start_loc[num_reqs + 1:].fill_(-1)

query_start_loc = self.query_start_loc[:num_reqs + 1]
seq_lens = self.seq_lens[:num_reqs]

common_attn_metadata = CommonAttentionMetadata(
query_start_loc=query_start_loc, seq_lens=seq_lens)

@@ -1478,6 +1511,7 @@ def _get_prompt_logprobs_dict(
def _dummy_run(
self,
num_tokens: int,
skip_attn: bool = True,
) -> torch.Tensor:

# Set num_scheduled_tokens based on num_tokens and max_num_seqs
@@ -1494,6 +1528,23 @@
num_scheduled_tokens = np.array(num_scheduled_tokens_list,
dtype=np.int32)

if skip_attn:
attn_metadata = None
else:
query_start_loc = self.query_start_loc[:num_reqs + 1]
seq_lens = self.seq_lens[:num_reqs]

common_attn_metadata = CommonAttentionMetadata(
query_start_loc=query_start_loc, seq_lens=seq_lens)

attn_metadata = self.attn_metadata_builder.build(
num_reqs=num_tokens,
num_actual_tokens=num_tokens,
max_query_len=num_tokens,
common_prefix_len=0,
common_attn_metadata=common_attn_metadata,
)

with self.maybe_dummy_run_with_lora(self.lora_config,
num_scheduled_tokens):
model = self.model
@@ -1522,7 +1573,7 @@ def _dummy_run(
for k, v in self.intermediate_tensors.items()
})

with set_forward_context(None,
with set_forward_context(attn_metadata,
@hidva (Contributor), Jun 11, 2025:
Considering that self.maybe_setup_kv_connector(scheduler_output) is not executed here, in the full CUDA graph scenario the sequence unified_attention_with_output -> maybe_save_kv_layer_to_connector -> connector.save_kv_layer() will cause the connector to read uninitialized metadata.

https://github.com/LMCache/LMCache/blob/680fbdf84e2ee1040bf4e084d43c9155a91b8d5c/lmcache/integration/vllm/vllm_v1_adapter.py#L609-L610

Therefore, full CUDA graph should be incompatible with the KV connector?

@simon-mo
self.vllm_config,
num_tokens=num_tokens):
outputs = model(
@@ -1708,11 +1759,12 @@ def capture_model(self) -> None:
# Capture the large shapes first so that the smaller shapes
# can reuse the memory pool allocated for the large shapes.
with graph_capture(device=self.device):
skip_attn = not self.vllm_config.compilation_config.full_cuda_graph
for num_tokens in reversed(self.cudagraph_batch_sizes):
for _ in range(self.vllm_config.compilation_config.
cudagraph_num_of_warmups):
self._dummy_run(num_tokens)
self._dummy_run(num_tokens)
self._dummy_run(num_tokens, skip_attn=skip_attn)
self._dummy_run(num_tokens, skip_attn=skip_attn)

end_time = time.perf_counter()
end_free_gpu_memory = torch.cuda.mem_get_info()[0]