
Commit fe8a0c0

fhl2000 authored and rtourgeman committed
[Bugfix][CI] Move resolving cudagraph_mode before initializing attn_metadata_builder (vllm-project#27427)
Signed-off-by: fhl2000 <63384265+fhl2000@users.noreply.github.com>
1 parent 8f72137 commit fe8a0c0

File tree

3 files changed: +34 -19 lines changed


docs/design/cuda_graphs.md

Lines changed: 1 addition & 1 deletion

@@ -167,7 +167,7 @@ class AttentionCGSupport(enum.Enum):
     """NO CUDA Graphs support"""
 ```

-Suppose we have hybrid attention backends (e.g., in mamba mixer models). In that case, we seek the minimum capability of all backends to determine the final capability of the model, and we might resolve the incompatible CUDA Graphs mode by downgrading the mode to the best fit one. For example, downgrading `FULL` mode to `FULL_AND_PIECEWISE` mode if the minimum capability is `UNIFORM_BATCH`, or `PIECEWISE` mode if the minimum capability is `NEVER` for -O3 compilation mode. For the complete fallback policy, please see the code of [initialize_cudagraph_capture][vllm.v1.worker.gpu_model_runner.GPUModelRunner.initialize_cudagraph_capture].
+Suppose we have hybrid attention backends (e.g., in mamba mixer models). In that case, we seek the minimum capability of all backends to determine the final capability of the model, and we might resolve the incompatible CUDA Graphs mode by downgrading the mode to the best fit one. For example, downgrading `FULL` mode to `FULL_AND_PIECEWISE` mode if the minimum capability is `UNIFORM_BATCH`, or `PIECEWISE` mode if the minimum capability is `NEVER` for -O3 compilation mode. For the complete fallback policy, please see the code for [this][vllm.v1.worker.gpu_model_runner.GPUModelRunner._check_and_update_cudagraph_mode].

 The following table lists backends that support full CUDA Graphs at the time of writing.

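To make the documented fallback concrete, here is a minimal, self-contained sketch of the downgrade policy described in the changed paragraph. The enums below only mirror the shape of vLLM's `AttentionCGSupport` and `CUDAGraphMode` (the member set and numeric values are assumptions), and `resolve_mode` is an illustration of the policy, not the actual `_check_and_update_cudagraph_mode` implementation.

```python
import enum


class AttentionCGSupport(enum.Enum):
    """Simplified stand-in: ordered from least to most CUDA Graphs capable."""

    NEVER = 0
    UNIFORM_BATCH = 1
    ALWAYS = 2


class CUDAGraphMode(enum.Enum):
    """Simplified stand-in for the compilation config's cudagraph_mode."""

    NONE = 0
    PIECEWISE = 1
    FULL = 2
    FULL_AND_PIECEWISE = 3


def resolve_mode(
    requested: CUDAGraphMode, capabilities: list[AttentionCGSupport]
) -> CUDAGraphMode:
    """Downgrade the requested mode to the best fit for the weakest backend."""
    min_support = min(capabilities, key=lambda c: c.value)
    if requested is not CUDAGraphMode.FULL:
        return requested
    if min_support is AttentionCGSupport.ALWAYS:
        return CUDAGraphMode.FULL
    if min_support is AttentionCGSupport.UNIFORM_BATCH:
        # Full graphs only for uniform decode batches; the rest runs piecewise.
        return CUDAGraphMode.FULL_AND_PIECEWISE
    # NEVER: only piecewise graphs are safe (needs -O3 / VLLM_COMPILE).
    return CUDAGraphMode.PIECEWISE


# Example: FULL requested, but one hybrid backend only supports uniform batches.
assert (
    resolve_mode(
        CUDAGraphMode.FULL,
        [AttentionCGSupport.ALWAYS, AttentionCGSupport.UNIFORM_BATCH],
    )
    is CUDAGraphMode.FULL_AND_PIECEWISE
)
```

The real resolver also accounts for spec-decode and compilation-mode constraints, which this sketch omits.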
tests/compile/test_fusions_e2e.py

Lines changed: 3 additions & 0 deletions

@@ -132,6 +132,9 @@ def test_attn_quant(
         mode = CUDAGraphMode.FULL_AND_PIECEWISE
         splitting_ops: list[str] | None = None
     else:
+        # FIXME: Llama-4-Scout-17B-16E-Instruct-FP8 + FlashInfer + Blackwell end at
+        # CUDAGraphMode.NONE here because it derives an attention backend that
+        # does not support full cudagraphs
         mode = CUDAGraphMode.FULL_DECODE_ONLY
         splitting_ops = []

vllm/v1/worker/gpu_model_runner.py

Lines changed: 30 additions & 18 deletions

@@ -3751,8 +3751,6 @@ def capture_model(self) -> int:
                 "ensure `cudagraph_mode` was not manually set to `NONE`"
             )
             return 0
-        else:
-            self.initialize_cudagraph_capture()

         compilation_counter.num_gpu_runner_capture_triggers += 1

@@ -3926,7 +3924,7 @@ class AttentionGroupKey(NamedTuple):

         def get_attn_backends_for_group(
             kv_cache_group_spec: KVCacheGroupSpec,
-        ) -> dict[AttentionGroupKey, list[str]]:
+        ) -> tuple[dict[AttentionGroupKey, list[str]], set[type[AttentionBackend]]]:
             layers = get_layers_from_vllm_config(
                 self.vllm_config, AttentionLayerBase, kv_cache_group_spec.layer_names
             )
@@ -3955,7 +3953,10 @@ def get_attn_backends_for_group(
                     attn_backend, layer_kv_cache_spec
                 )
                 attn_backend_layers[key].append(layer_name)
-            return {attn_backends[k]: v for k, v in attn_backend_layers.items()}
+            return (
+                {attn_backends[k]: v for k, v in attn_backend_layers.items()},
+                set(group_key.attn_backend for group_key in attn_backends.values()),
+            )

         def create_attn_groups(
             attn_backends_map: dict[AttentionGroupKey, list[str]],
@@ -3976,28 +3977,39 @@ def create_attn_groups(
                 attn_groups.append(attn_group)
             return attn_groups

+        attention_backend_maps = []
+        attention_backend_set: set[type[AttentionBackend]] = set()
         for kv_cache_group_spec in kv_cache_config.kv_cache_groups:
             attn_backends = get_attn_backends_for_group(kv_cache_group_spec)
-            self.attn_groups.append(create_attn_groups(attn_backends))
+            attention_backend_maps.append(attn_backends[0])
+            attention_backend_set.update(attn_backends[1])
+
+        # Resolve cudagraph_mode before actually initialize metadata_builders
+        self._check_and_update_cudagraph_mode(attention_backend_set)
+
+        for attn_backends_map in attention_backend_maps:
+            self.attn_groups.append(create_attn_groups(attn_backends_map))

         # Calculate reorder batch threshold (if needed)
         self.calculate_reorder_batch_threshold()

-    def initialize_cudagraph_capture(self) -> None:
+    def _check_and_update_cudagraph_mode(
+        self, attention_backends: set[type[AttentionBackend]]
+    ) -> None:
         """
         Resolve the cudagraph_mode when there are multiple attention
         backends with potential conflicting CUDA graph support.
         Then initialize the cudagraph_dispatcher based on the resolved
         cudagraph_mode.
         """
         min_cg_support = AttentionCGSupport.ALWAYS
-        min_cg_builder_name = None
+        min_cg_backend_name = None

-        for attn_group in self._attn_group_iterator():
-            builder = attn_group.get_metadata_builder()
-            if builder.cudagraph_support.value < min_cg_support.value:
-                min_cg_support = builder.cudagraph_support
-                min_cg_builder_name = builder.__class__.__name__
+        for attn_backend in attention_backends:
+            builder_cls = attn_backend.get_builder_cls()
+            if builder_cls.cudagraph_support.value < min_cg_support.value:
+                min_cg_support = builder_cls.cudagraph_support
+                min_cg_backend_name = attn_backend.__name__
         # Flexible resolve the cudagraph mode
         cudagraph_mode = self.compilation_config.cudagraph_mode
         # check cudagraph for mixed batch is supported
@@ -4007,7 +4019,7 @@ def initialize_cudagraph_capture(self) -> None:
         ):
             msg = (
                 f"CUDAGraphMode.{cudagraph_mode.name} is not supported "
-                f"with {min_cg_builder_name} backend (support: "
+                f"with {min_cg_backend_name} backend (support: "
                 f"{min_cg_support})"
             )
             if min_cg_support == AttentionCGSupport.NEVER:
@@ -4038,7 +4050,7 @@ def initialize_cudagraph_capture(self) -> None:
         ):
             msg = (
                 f"CUDAGraphMode.{cudagraph_mode.name} is not supported "
-                f"with {min_cg_builder_name} backend (support: "
+                f"with {min_cg_backend_name} backend (support: "
                 f"{min_cg_support})"
             )
             if self.compilation_config.mode == CompilationMode.VLLM_COMPILE and (
@@ -4072,7 +4084,7 @@ def initialize_cudagraph_capture(self) -> None:
             msg = (
                 f"CUDAGraphMode.{cudagraph_mode.name} is not supported"
                 f" with spec-decode for attention backend "
-                f"{min_cg_builder_name} (support: {min_cg_support})"
+                f"{min_cg_backend_name} (support: {min_cg_support})"
             )
             if self.compilation_config.splitting_ops_contain_attention():
                 msg += "; setting cudagraph_mode=PIECEWISE"
@@ -4094,14 +4106,14 @@ def initialize_cudagraph_capture(self) -> None:
         ):
             raise ValueError(
                 f"CUDAGraphMode.{cudagraph_mode.name} is not "
-                f"supported with {min_cg_builder_name} backend ("
+                f"supported with {min_cg_backend_name} backend ("
                 f"support:{min_cg_support}) "
                 "; please try cudagraph_mode=PIECEWISE, "
                 "and make sure compilation mode is VLLM_COMPILE"
             )

-        # Trigger cudagraph dispatching keys initialization here (after
-        # initializing attn backends).
+        # Trigger cudagraph dispatching keys initialization after
+        # resolved cudagraph mode.
         self.cudagraph_dispatcher.initialize_cudagraph_keys(
             self.compilation_config.cudagraph_mode, self.uniform_decode_query_len
         )

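The core of the fix is an ordering change: the backend classes collected per KV-cache group are now used to resolve the CUDA graph mode before any attention metadata builder is constructed. The toy sketch below illustrates only that ordering; `MiniRunner`, `initialize_attention`, and the stub backend/builder classes are hypothetical stand-ins, with only `_check_and_update_cudagraph_mode` and `get_builder_cls` echoing names from the diff.

```python
from dataclasses import dataclass, field


class _BuilderA:
    cudagraph_support = 2  # pretend this means "ALWAYS"


class _BuilderB:
    cudagraph_support = 1  # pretend this means "UNIFORM_BATCH"


class BackendA:
    @classmethod
    def get_builder_cls(cls):
        return _BuilderA


class BackendB:
    @classmethod
    def get_builder_cls(cls):
        return _BuilderB


@dataclass
class MiniRunner:
    """Toy stand-in for the model runner, showing only the new ordering."""

    resolved_support: int = -1
    attn_groups: list = field(default_factory=list)

    def _check_and_update_cudagraph_mode(self, backends: set) -> None:
        # Only backend *classes* are inspected, so this can run before any
        # metadata builder instance exists.
        self.resolved_support = min(
            b.get_builder_cls().cudagraph_support for b in backends
        )

    def initialize_attention(self, groups: list) -> None:
        backend_maps, backend_set = [], set()
        for group in groups:
            # Per group: a backend -> layer-names map plus the set of backends.
            backend_maps.append({b: [f"layer_{i}"] for i, b in enumerate(group)})
            backend_set.update(group)

        # Resolve the cudagraph mode first (this commit's change) ...
        self._check_and_update_cudagraph_mode(backend_set)

        # ... and only then build the per-group attention structures, which is
        # where the real runner instantiates the metadata builders.
        for backend_map in backend_maps:
            self.attn_groups.append(backend_map)


runner = MiniRunner()
runner.initialize_attention([[BackendA], [BackendA, BackendB]])
assert runner.resolved_support == 1  # the weakest backend's support wins
```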