
Commit f5536e5

Chanh Nguyen authored and committed
[Core] Support full cuda graph in v1 (vllm-project#16072)
Signed-off-by: Chanh Nguyen <cnguyen@linkedin.com>
Co-authored-by: Chanh Nguyen <cnguyen@linkedin.com>
1 parent 08c489f commit f5536e5

File tree

5 files changed: +190, -13 lines changed

docs/source/design/v1/torch_compile.md

Lines changed: 6 additions & 0 deletions
@@ -137,3 +137,9 @@ By default, vLLM will try to determine a set of sizes to capture cudagraph. You
 `vllm serve meta-llama/Llama-3.2-1B --compilation-config "{'cudagraph_capture_sizes': [1, 2, 4, 8]}"`
 
 Then it will only capture cudagraph for the specified sizes. It can be useful to have fine-grained control over the cudagraph capture.
+
+### Full Cudagraph capture
+
+It is possible to include attention as part of the cudagraph if using an attention backend that is cudagraph compatible. This can improve performance in some cases, such as decode speed for smaller models. Enable this using `--compilation-config "{'full_cuda_graph': True}"`.
+
+Currently only FlashAttention 3 is compatible, and only when cascade attention is disabled.
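
For context, the usage pattern exercised by the new test in this commit looks roughly like the sketch below when using the offline API instead of `vllm serve`. The model name, environment variables, and sampling settings are taken from the test and are illustrative, not required values.

```python
# Minimal sketch (assumes a CUDA GPU with FlashAttention 3 available).
import os

from vllm import LLM, SamplingParams
from vllm.config import CompilationConfig

os.environ["VLLM_USE_V1"] = "1"              # full_cuda_graph is a V1 feature
os.environ["VLLM_FLASH_ATTN_VERSION"] = "3"  # only FA3 is compatible

llm = LLM(model="Qwen/Qwen2-1.5B-Instruct",
          compilation_config=CompilationConfig(full_cuda_graph=True))

outputs = llm.generate(["Hi my name is"],
                       SamplingParams(temperature=0.0, max_tokens=10))
print(outputs[0].outputs[0].text)
```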
Lines changed: 97 additions & 0 deletions

@@ -0,0 +1,97 @@
+# SPDX-License-Identifier: Apache-2.0
+import contextlib
+import os
+
+import pytest
+
+from vllm import LLM, SamplingParams
+from vllm.config import CompilationConfig
+
+MODEL = "Qwen/Qwen2-1.5B-Instruct"
+
+
+@contextlib.contextmanager
+def temporary_environ(env_vars):
+    """
+    Temporarily set environment variables and restore them afterward.
+    We have to do this vs monkeypatch because monkeypatch doesn't work
+    with "module" scoped fixtures.
+    """
+    original_env = {k: os.environ.get(k) for k in env_vars}
+    try:
+        os.environ.update(env_vars)
+        yield
+    finally:
+        for k, v in original_env.items():
+            if v is None:
+                os.environ.pop(k, None)
+            else:
+                os.environ[k] = v
+
+
+@pytest.fixture(scope="module")
+def full_cudagraph_llm():
+    with temporary_environ({
+            "VLLM_USE_V1": "1",
+            "VLLM_FLASH_ATTN_VERSION": "3"
+    }):
+        return LLM(model=MODEL,
+                   gpu_memory_utilization=0.2,
+                   compilation_config=CompilationConfig(full_cuda_graph=True))
+
+
+@pytest.fixture(scope="module")
+def piecewise_llm():
+    with temporary_environ({
+            "VLLM_USE_V1": "1",
+            "VLLM_FLASH_ATTN_VERSION": "3"
+    }):
+        return LLM(model=MODEL,
+                   gpu_memory_utilization=0.5,
+                   compilation_config=CompilationConfig())
+
+
+def generate_text(llm: LLM, batch_size: int, max_tokens: int):
+    prompts = ["Hi my name is"] * batch_size
+    sampling_params = SamplingParams(temperature=0.0,
+                                     max_tokens=max_tokens,
+                                     top_p=0.95)
+
+    return llm.generate(prompts, sampling_params)
+
+
+@pytest.mark.parametrize(("batch_size", "max_tokens"), [(1, 10), (7, 10),
+                                                        (16, 10), (25, 10),
+                                                        (32, 10), (45, 10),
+                                                        (64, 10), (8, 5),
+                                                        (8, 20), (8, 200)])
+def test_full_cudagraph(batch_size, max_tokens, full_cudagraph_llm,
+                        piecewise_llm):
+    """
+    Load full cudagraph model and piecewise model once, and at the same time to
+    reuse them across various test cases.
+
+    Test various batch sizes and max_tokens to ensure that the full cudagraph
+    compilation works for padded cases too.
+    """
+    piecewise_responses = generate_text(piecewise_llm,
+                                        batch_size=batch_size,
+                                        max_tokens=max_tokens)
+    full_cudagraph_responses = generate_text(full_cudagraph_llm,
+                                             batch_size=batch_size,
+                                             max_tokens=max_tokens)
+
+    # Check that all responses are the same
+    for i in range(len(piecewise_responses)):
+        assert piecewise_responses[i].outputs[
+            0].text == full_cudagraph_responses[i].outputs[0].text
+
+
+def test_full_cudagraph_with_invalid_backend():
+    with temporary_environ({
+            "VLLM_USE_V1": "1",
+            "VLLM_FLASH_ATTN_VERSION":
+            "2"  # FA2 not supported with full_cuda_graph
+    }), pytest.raises(RuntimeError):
+        LLM(model=MODEL,
+            compilation_config=CompilationConfig(full_cuda_graph=True))

vllm/config.py

Lines changed: 17 additions & 2 deletions
@@ -3605,6 +3605,10 @@ class CompilationConfig(BaseModel):
         are always used, it can set this to False. Otherwise, it should
         set this to True, and the compiler will copy the input to an
         internally managed buffer. Default is False.
+    - full_cuda_graph: whether to use a full cuda graph for the entire forward
+      pass rather than splitting certain operations such as attention into subgraphs.
+      Thus this flag cannot be used together with splitting_ops. This may provide
+      performance benefits for smaller models.
     - Inductor compilation:
         - use_inductor: whether to use inductor compilation.
             - False: inductor compilation is not used. graph runs in eager.
@@ -3649,6 +3653,7 @@ class CompilationConfig(BaseModel):
     cudagraph_num_of_warmups: int = 0
     cudagraph_capture_sizes: Optional[list[int]] = None
     cudagraph_copy_inputs: bool = False
+    full_cuda_graph: bool = False
 
     class PassConfig(BaseModel):
         """
@@ -3871,10 +3876,14 @@ def init_with_cudagraph_sizes(self,
                 self.max_capture_size] = self.max_capture_size
 
     def set_splitting_ops_for_v1(self):
-        # If default, override splitting ops for piecewise cudagraph on V1.
         # NOTE: this function needs to be called
+        if self.splitting_ops and self.full_cuda_graph:
+            raise ValueError("full_cuda_graph cannot be used together with "
+                             "splitting_ops, as Full CUDA graph will override "
+                             f"the splitting_ops: {self.splitting_ops}")
+
         if not self.splitting_ops:
-            self.splitting_ops = [
+            self.splitting_ops = [] if self.full_cuda_graph else [
                 "vllm.unified_attention",
                 "vllm.unified_attention_with_output",
             ]
@@ -4151,6 +4160,12 @@ def __post_init__(self):
                 "Disabling `torch.compile`.")
             self.compilation_config.level = CompilationLevel.NO_COMPILATION
 
+        if self.compilation_config.full_cuda_graph and \
+                not self.model_config.disable_cascade_attn:
+            logger.warning_once(
+                "full_cuda_graph is not supported with "
+                "cascade attention. Disabling cascade attention.")
+            self.model_config.disable_cascade_attn = True
 
         if self.model_config and self.model_config.use_mla and \
                 not (current_platform.is_cuda() or current_platform.is_rocm()):
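
To illustrate the constraint introduced in `set_splitting_ops_for_v1` above: combining `full_cuda_graph` with explicit `splitting_ops` raises a `ValueError`, while leaving `splitting_ops` unset keeps the graph unsplit. This is only a sketch; in normal use the method is invoked internally by vLLM rather than by user code.

```python
from vllm.config import CompilationConfig

# Conflicting settings: full_cuda_graph overrides any splitting_ops.
cfg = CompilationConfig(full_cuda_graph=True,
                        splitting_ops=["vllm.unified_attention"])
try:
    cfg.set_splitting_ops_for_v1()
except ValueError as err:
    print(err)

# Default path: with full_cuda_graph the op list stays empty, so attention
# is not split out of the captured graph.
cfg = CompilationConfig(full_cuda_graph=True)
cfg.set_splitting_ops_for_v1()
print(cfg.splitting_ops)  # []
```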

vllm/v1/attention/backends/flash_attn.py

Lines changed: 10 additions & 3 deletions
@@ -291,6 +291,7 @@ class FlashAttentionMetadataBuilder:
 
     def __init__(self, runner: "GPUModelRunner"):
         model_config = runner.model_config
+        compilation_config = runner.vllm_config.compilation_config
 
         self.runner = runner
         self.num_heads_q = model_config.get_num_attention_heads(
@@ -300,7 +301,14 @@ def __init__(self, runner: "GPUModelRunner"):
         self.headdim = model_config.get_head_size()
         self.page_size = self.runner.block_size
 
-        self.aot_schedule = (get_flash_attn_version() == 3)
+        if get_flash_attn_version() == 3:
+            self.aot_schedule = not compilation_config.full_cuda_graph
+            if not self.aot_schedule:
+                logger.warning(
+                    "AOT Schedule is disabled when using full_cuda_graph")
+        else:
+            self.aot_schedule = False
+
         # Sliding window size to be used with the AOT scheduler will be
         # populated on first build() call.
         self.aot_sliding_window: Optional[tuple[int, int]] = None
@@ -317,8 +325,7 @@ def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int,
         seq_lens = common_attn_metadata.seq_lens
         block_table = (
             self.runner.input_batch.block_table.get_device_tensor()[:num_reqs])
-        slot_mapping = self.runner.slot_mapping_cpu[:num_actual_tokens].to(
-            self.runner.device, non_blocking=True).long()
+        slot_mapping = self.runner.slot_mapping[:num_actual_tokens]
 
         if self.aot_sliding_window is None:
             self.aot_sliding_window = (-1, -1)
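
Since the AOT scheduler is turned off and FA3 is required when `full_cuda_graph` is enabled, a caller can probe the available FlashAttention version before opting in. A hedged sketch using the `get_flash_attn_version` helper imported elsewhere in this commit; the fallback logic is illustrative and not part of the change.

```python
from vllm.attention.utils.fa_utils import get_flash_attn_version
from vllm.config import CompilationConfig

fa_version = get_flash_attn_version()
use_full_graph = (fa_version == 3)

compilation_config = CompilationConfig(full_cuda_graph=use_full_graph)
if not use_full_graph:
    print(f"FlashAttention version is {fa_version}; "
          "falling back to piecewise cudagraph capture")
```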

vllm/v1/worker/gpu_model_runner.py

Lines changed: 60 additions & 8 deletions
@@ -12,6 +12,7 @@
 
 from vllm.attention import AttentionType, get_attn_backend
 from vllm.attention.layer import Attention
+from vllm.attention.utils.fa_utils import get_flash_attn_version
 from vllm.config import (CompilationLevel, VllmConfig,
                          get_layers_from_vllm_config)
 from vllm.distributed.kv_transfer import (get_kv_transfer_group,
@@ -139,6 +140,16 @@ def __init__(
             raise NotImplementedError(
                 "Non-Attention backend is not supported by V1 GPUModelRunner.")
 
+        if self.vllm_config.compilation_config.full_cuda_graph:
+            attn_backend_name = self.attn_backend.__name__
+            flash_attn_version = get_flash_attn_version()
+            if attn_backend_name != "FlashAttentionBackend" or \
+                    flash_attn_version != 3:
+                raise ValueError(
+                    f"full_cuda_graph is only supported with "
+                    f"FA3. Current attention backend is {attn_backend_name}, "
+                    f"FlashAttention version is {flash_attn_version}.")
+
         self.attn_metadata_builder = self.attn_backend.get_builder_cls()(
             weakref.proxy(self))
         self.cascade_attn_enabled = not self.model_config.disable_cascade_attn
@@ -219,6 +230,16 @@ def __init__(
         self.positions = torch.zeros(self.max_num_tokens,
                                      dtype=torch.int64,
                                      device=self.device)
+        self.query_start_loc = torch.zeros(self.max_num_reqs + 1,
+                                           dtype=torch.int32,
+                                           device=self.device)
+        self.seq_lens = torch.zeros(self.max_num_reqs,
+                                    dtype=torch.int32,
+                                    device=self.device)
+        self.slot_mapping = torch.zeros(self.max_num_tokens,
+                                        dtype=torch.int64,
+                                        device=self.device)
+
         # None in the first PP rank. The rest are set after load_model.
         self.intermediate_tensors: Optional[IntermediateTensors] = None
 
@@ -271,7 +292,7 @@ def __init__(
             pin_memory=self.pin_memory)
         self.positions_np = self.positions_cpu.numpy()
         self.slot_mapping_cpu = torch.zeros(self.max_num_tokens,
-                                            dtype=torch.int32,
+                                            dtype=torch.int64,
                                             device="cpu",
                                             pin_memory=self.pin_memory)
         self.slot_mapping_np = self.slot_mapping_cpu.numpy()
@@ -589,10 +610,22 @@ def _prepare_inputs(
             self.positions_cpu[:total_num_scheduled_tokens],
             non_blocking=True)
 
-        query_start_loc = self.query_start_loc_cpu[:num_reqs + 1].to(
-            self.device, non_blocking=True)
-        seq_lens = self.seq_lens_cpu[:num_reqs].to(self.device,
-                                                   non_blocking=True)
+        self.query_start_loc[:num_reqs + 1].copy_(
+            self.query_start_loc_cpu[:num_reqs + 1], non_blocking=True)
+        self.seq_lens[:num_reqs].copy_(self.seq_lens_cpu[:num_reqs],
+                                       non_blocking=True)
+        self.slot_mapping[:total_num_scheduled_tokens].copy_(
+            self.slot_mapping_cpu[:total_num_scheduled_tokens],
+            non_blocking=True)
+
+        # Fill unused with -1. Needed for reshape_and_cache
+        self.slot_mapping[total_num_scheduled_tokens:].fill_(-1)
+        self.seq_lens[num_reqs:].fill_(0)
+        self.query_start_loc[num_reqs + 1:].fill_(-1)
+
+        query_start_loc = self.query_start_loc[:num_reqs + 1]
+        seq_lens = self.seq_lens[:num_reqs]
+
         common_attn_metadata = CommonAttentionMetadata(
             query_start_loc=query_start_loc, seq_lens=seq_lens)
 
@@ -1478,6 +1511,7 @@ def _get_prompt_logprobs_dict(
     def _dummy_run(
         self,
         num_tokens: int,
+        skip_attn: bool = True,
     ) -> torch.Tensor:
 
         # Set num_scheduled_tokens based on num_tokens and max_num_seqs
@@ -1494,6 +1528,23 @@ def _dummy_run(
         num_scheduled_tokens = np.array(num_scheduled_tokens_list,
                                         dtype=np.int32)
 
+        if skip_attn:
+            attn_metadata = None
+        else:
+            query_start_loc = self.query_start_loc[:num_reqs + 1]
+            seq_lens = self.seq_lens[:num_reqs]
+
+            common_attn_metadata = CommonAttentionMetadata(
+                query_start_loc=query_start_loc, seq_lens=seq_lens)
+
+            attn_metadata = self.attn_metadata_builder.build(
+                num_reqs=num_tokens,
+                num_actual_tokens=num_tokens,
+                max_query_len=num_tokens,
+                common_prefix_len=0,
+                common_attn_metadata=common_attn_metadata,
+            )
+
         with self.maybe_dummy_run_with_lora(self.lora_config,
                                             num_scheduled_tokens):
             model = self.model
@@ -1522,7 +1573,7 @@ def _dummy_run(
                     for k, v in self.intermediate_tensors.items()
                 })
 
-            with set_forward_context(None,
+            with set_forward_context(attn_metadata,
                                      self.vllm_config,
                                      num_tokens=num_tokens):
                 outputs = model(
@@ -1708,11 +1759,12 @@ def capture_model(self) -> None:
         # Capture the large shapes first so that the smaller shapes
         # can reuse the memory pool allocated for the large shapes.
         with graph_capture(device=self.device):
+            skip_attn = not self.vllm_config.compilation_config.full_cuda_graph
            for num_tokens in reversed(self.cudagraph_batch_sizes):
                 for _ in range(self.vllm_config.compilation_config.
                                cudagraph_num_of_warmups):
-                    self._dummy_run(num_tokens)
-                self._dummy_run(num_tokens)
+                    self._dummy_run(num_tokens, skip_attn=skip_attn)
+                self._dummy_run(num_tokens, skip_attn=skip_attn)
 
         end_time = time.perf_counter()
         end_free_gpu_memory = torch.cuda.mem_get_info()[0]
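
The runner changes above rely on the persistent-buffer pattern that full-graph capture requires: `query_start_loc`, `seq_lens`, and `slot_mapping` are allocated once on the GPU at their maximum size, each step copies live data into the head of the buffer, and the padded tail is filled with sentinel values so replayed graphs always read stable addresses. Below is a standalone sketch of that pattern; it is not vLLM code, and the names and sizes are illustrative.

```python
import torch

MAX_TOKENS = 64
device = "cuda" if torch.cuda.is_available() else "cpu"

# Allocated once at the maximum capture size; the address never changes.
slot_mapping = torch.zeros(MAX_TOKENS, dtype=torch.int64, device=device)


def prepare(step_slots_cpu: torch.Tensor) -> torch.Tensor:
    n = step_slots_cpu.numel()
    # Copy this step's values into the head of the persistent buffer.
    slot_mapping[:n].copy_(step_slots_cpu, non_blocking=True)
    # Fill the padded tail with -1 so padded entries are ignored,
    # mirroring the reshape_and_cache convention noted in the diff.
    slot_mapping[n:].fill_(-1)
    return slot_mapping


print(prepare(torch.arange(5, dtype=torch.int64)))
```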
