diff --git a/tests/compile/piecewise/test_simple.py b/tests/compile/piecewise/test_simple.py index 84f4945c8272..41055f431569 100644 --- a/tests/compile/piecewise/test_simple.py +++ b/tests/compile/piecewise/test_simple.py @@ -15,6 +15,7 @@ VllmConfig, set_current_vllm_config) from vllm.envs import VLLM_USE_V1 from vllm.forward_context import BatchDescriptor, set_forward_context +from vllm.utils import is_torch_equal_or_newer # This import automatically registers `torch.ops.silly.attention` from ..silly_attention import get_global_counter, reset_global_counter @@ -50,16 +51,21 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return x -@pytest.mark.parametrize("use_inductor", [True, False]) -@torch.inference_mode() -def test_simple_piecewise_compile(use_inductor): - assert VLLM_USE_V1 - +def _run_simple_model( + splitting_ops, + use_inductor_graph_partition, + use_inductor, + expected_num_piecewise_graphs_seen, + expected_num_piecewise_capturable_graphs_seen, + expected_num_backend_compilations, + expected_num_cudagraph_captured, +): vllm_config = VllmConfig(compilation_config=CompilationConfig( level=CompilationLevel.PIECEWISE, use_cudagraph=True, use_inductor=use_inductor, - splitting_ops=["silly.attention"], + splitting_ops=splitting_ops, + use_inductor_graph_partition=use_inductor_graph_partition, cudagraph_copy_inputs=True, cudagraph_capture_sizes=[1, 2], )) @@ -70,11 +76,11 @@ def test_simple_piecewise_compile(use_inductor): with compilation_counter.expect( num_graphs_seen=1, # one graph for the model - num_piecewise_graphs_seen=5, # 2 * num_layers + 1 - num_piecewise_capturable_graphs_seen=3, # 1 + num_layers - num_backend_compilations=3, # num_piecewise_capturable_graphs_seen - num_cudagraph_captured= - 6, # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen + num_piecewise_graphs_seen=expected_num_piecewise_graphs_seen, + num_piecewise_capturable_graphs_seen= + expected_num_piecewise_capturable_graphs_seen, + num_backend_compilations=expected_num_backend_compilations, + num_cudagraph_captured=expected_num_cudagraph_captured, ), set_forward_context(None, vllm_config=vllm_config): # background context # warm up with background context @@ -104,3 +110,46 @@ def test_simple_piecewise_compile(use_inductor): output = model(input) assert get_global_counter() == 2 assert torch.allclose(output.cpu(), torch.tensor([19.0, 19.0])) + + +@pytest.mark.parametrize("use_inductor", [True, False]) +@torch.inference_mode() +def test_simple_piecewise_compile(use_inductor): + assert VLLM_USE_V1 + _run_simple_model( + splitting_ops=["silly.attention"], + use_inductor_graph_partition=False, + use_inductor=use_inductor, + expected_num_piecewise_graphs_seen=5, # 2 * num_layers + 1 + expected_num_piecewise_capturable_graphs_seen=3, # 1 + num_layers + expected_num_backend_compilations= + 3, # num_piecewise_capturable_graphs_seen + expected_num_cudagraph_captured= + 6, # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen + ) + + +@torch.inference_mode() +@pytest.mark.parametrize("splitting_ops", [["silly.attention"], []]) +def test_simple_inductor_graph_partition(splitting_ops): + assert VLLM_USE_V1 + if not is_torch_equal_or_newer("2.9.0.dev"): + pytest.skip("inductor graph partition is only available " + "in PyTorch 2.9+") + + _run_simple_model( + # inductor graph partition automatically resets splitting_ops + # to be an empty list + splitting_ops=splitting_ops, + use_inductor_graph_partition=True, + use_inductor=True, + expected_num_piecewise_graphs_seen= + 1, # since not splitting at fx graph level + expected_num_piecewise_capturable_graphs_seen= + 1, # since not splitting at fx graph level + expected_num_backend_compilations= + 1, # since not splitting at fx graph level + expected_num_cudagraph_captured= + 6, # inductor graph partition still captures 6 + # graph, same as fx graph partition. + ) diff --git a/tests/compile/silly_attention.py b/tests/compile/silly_attention.py index 13eb0bf4b1fa..baedafbae99f 100644 --- a/tests/compile/silly_attention.py +++ b/tests/compile/silly_attention.py @@ -60,4 +60,5 @@ def silly_attention_fake(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, mutates_args=["out"], fake_impl=silly_attention_fake, target_lib=silly_lib, + tags=(torch._C.Tag.cudagraph_unsafe, ), ) diff --git a/tests/compile/test_full_graph.py b/tests/compile/test_full_graph.py index 84178344a5f3..053236af2725 100644 --- a/tests/compile/test_full_graph.py +++ b/tests/compile/test_full_graph.py @@ -3,6 +3,7 @@ from __future__ import annotations +import logging import tempfile from typing import Any, Optional, Union @@ -10,9 +11,13 @@ import torch from tests.quantization.utils import is_quant_method_supported +from tests.v1.attention.utils import _Backend from vllm import LLM, SamplingParams -from vllm.config import CompilationConfig, CompilationLevel, PassConfig +from vllm.attention.selector import global_force_attn_backend_context_manager +from vllm.config import (CompilationConfig, CompilationLevel, CUDAGraphMode, + PassConfig) from vllm.platforms import current_platform +from vllm.utils import is_torch_equal_or_newer from ..utils import create_new_process_for_each_test @@ -107,6 +112,18 @@ def test_full_graph( (CompilationConfig(level=CompilationLevel.PIECEWISE, debug_dump_path=tempfile.gettempdir()), ("facebook/opt-125m", {})), + ] + [ + # graph inductor partition + ( + CompilationConfig( + level=CompilationLevel.PIECEWISE, + # inductor graph partition uses + # torch._C.Tag.cudagraph_unsafe to specify splitting ops + use_inductor_graph_partition=True, + cudagraph_mode=CUDAGraphMode.PIECEWISE, + compile_sizes=[1, 2]), + model) for model in models_list(all=False) + if is_torch_equal_or_newer("2.9.0.dev") ]) # only test some of the models @create_new_process_for_each_test() @@ -114,11 +131,51 @@ def test_custom_compile_config( compilation_config: CompilationConfig, model_info: tuple[str, dict[str, Any]], ): + if (compilation_config.use_inductor_graph_partition + and not is_torch_equal_or_newer("2.9.0.dev")): + pytest.skip("inductor graph partition is only available " + "in PyTorch 2.9+") + model, model_kwargs = model_info print(f"MODEL={model}") run_model(compilation_config, model, model_kwargs) +def test_inductor_graph_partition_attn_fusion(caplog_vllm): + if not is_torch_equal_or_newer("2.9.0.dev"): + pytest.skip("inductor graph partition is only available " + "in PyTorch 2.9+") + + model = "nvidia/Llama-4-Scout-17B-16E-Instruct-FP8" + compilation_config = CompilationConfig( + level=CompilationLevel.PIECEWISE, + use_inductor_graph_partition=True, + cudagraph_mode=CUDAGraphMode.PIECEWISE, + custom_ops=["+quant_fp8"], + pass_config=PassConfig(enable_attn_fusion=True, enable_noop=True), + ) + model_kwargs = { + "kv_cache_dtype": "fp8", + "max_model_len": 1024, + } + with caplog_vllm.at_level( + logging.DEBUG), global_force_attn_backend_context_manager( + _Backend.FLASHINFER): + run_model(compilation_config, model, model_kwargs) + + try: + assert ("Fused quantization onto 48 attention nodes" + in caplog_vllm.text), caplog_vllm.text + except AssertionError: + # Note: this message is only triggered when the compilation goes + # through the custom pass. Due to multiple layers of cache on + # PyTorch side, the compilation of a graph may be cached such + # that custom pass directly goes through cache. In this case, + # we go through this branch and assert that the pass is not + # triggered. + assert "Fused quantization" not in caplog_vllm.text + + def run_model(compile_config: Union[int, CompilationConfig], model: str, model_kwargs: dict[str, Any]): prompts = [ diff --git a/tests/compile/test_fusion_attn.py b/tests/compile/test_fusion_attn.py index 6baf4bf83f49..022f183b3193 100644 --- a/tests/compile/test_fusion_attn.py +++ b/tests/compile/test_fusion_attn.py @@ -27,6 +27,7 @@ from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( Fp8LinearOp) from vllm.platforms import current_platform +from vllm.utils import is_torch_equal_or_newer from vllm.v1.kv_cache_interface import AttentionSpec FP8_DTYPE = current_platform.fp8_dtype() @@ -339,6 +340,10 @@ def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor): @pytest.mark.parametrize( "split_attention", [False, True] if current_platform.is_rocm() else [False]) +# TODO(boyuan): test inductor graph partition on rocm +@pytest.mark.parametrize( + "use_inductor_graph_partition", + [False] if current_platform.is_rocm() else [False, True]) @pytest.mark.skipif(not current_platform.is_cuda_alike(), reason="Only test ROCm or CUDA") @pytest.mark.skipif(not current_platform.supports_fp8(), reason="Need FP8") @@ -352,9 +357,15 @@ def test_attention_quant_pattern(num_qo_heads: int, num_kv_heads: int, dtype: torch.dtype, model_name: str, model_class: type[AttentionQuantPatternModel], backend: _Backend, split_attention: bool, - monkeypatch, dist_init): + use_inductor_graph_partition: bool, + monkeypatch, dist_init, caplog_vllm): """Test AttentionStaticQuantPattern fusion pass""" + if use_inductor_graph_partition and not is_torch_equal_or_newer( + "2.9.0.dev"): + pytest.skip("inductor graph partition is only available " + "in PyTorch 2.9+") + monkeypatch.setenv("VLLM_USE_V1", "1") if split_attention: monkeypatch.setenv("VLLM_V1_USE_PREFILL_DECODE_ATTENTION", "1") @@ -372,6 +383,7 @@ def test_attention_quant_pattern(num_qo_heads: int, num_kv_heads: int, compilation_config=CompilationConfig( level=CompilationLevel.PIECEWISE, custom_ops=["+quant_fp8"], + use_inductor_graph_partition=use_inductor_graph_partition, ), cache_config=CacheConfig(cache_dtype="fp8")) @@ -444,6 +456,7 @@ def test_attention_quant_pattern(num_qo_heads: int, num_kv_heads: int, backend=test_backend, fullgraph=True) assert model_compiled.attn._o_scale_float is None + result_fused_1 = model_compiled(q, k, v) if backend == _Backend.FLASHINFER: @@ -453,6 +466,7 @@ def test_attention_quant_pattern(num_qo_heads: int, num_kv_heads: int, # _o_scale_float assert model_compiled.attn._o_scale_float is not None result_fused_2 = model_compiled(q, k, v) + assert model_compiled.attn._o_scale_float is not None torch.testing.assert_close(result_unfused, diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py index 15c0ce33e965..b4701380f82c 100644 --- a/vllm/attention/layer.py +++ b/vllm/attention/layer.py @@ -575,6 +575,7 @@ def unified_attention_fake( mutates_args=[], fake_impl=unified_attention_fake, dispatch_key=current_platform.dispatch_key, + tags=(torch._C.Tag.cudagraph_unsafe, ), ) @@ -625,4 +626,5 @@ def unified_attention_with_output_fake( mutates_args=["output", "output_block_scale"], fake_impl=unified_attention_with_output_fake, dispatch_key=current_platform.dispatch_key, + tags=(torch._C.Tag.cudagraph_unsafe, ), ) diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py index 3cc0fc3106f5..28f1bc1552ab 100644 --- a/vllm/compilation/backends.py +++ b/vllm/compilation/backends.py @@ -326,6 +326,7 @@ def call_module(self, target: torch.fx.node.Target, i for i, x in enumerate(args) if isinstance(x, torch.SymInt) ] global compilation_start_time + compiled_graph_for_dynamic_shape = self.vllm_backend.\ compiler_manager.compile( submod, @@ -336,7 +337,6 @@ def call_module(self, target: torch.fx.node.Target, num_graphs=len(self.compile_submod_names), runtime_shape=None) # Lazy import here to avoid circular import - from .cuda_graph import CUDAGraphOptions from .cuda_piecewise_backend import PiecewiseBackend piecewise_backend = PiecewiseBackend( @@ -344,7 +344,13 @@ def call_module(self, target: torch.fx.node.Target, len(self.compile_submod_names), sym_shape_indices, compiled_graph_for_dynamic_shape, self.vllm_backend) - if self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE: + if (self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE + and + not self.compilation_config.use_inductor_graph_partition): + # We're using Dynamo-based piecewise splitting, so we wrap + # the whole subgraph with a static graph wrapper. + from .cuda_graph import CUDAGraphOptions + # resolve the static graph wrapper class (e.g. CUDAGraphWrapper # class) as platform dependent. static_graph_wrapper_class = resolve_obj_by_qualname( diff --git a/vllm/compilation/decorators.py b/vllm/compilation/decorators.py index 41d9fcb824b0..b7a6e23c1aa7 100644 --- a/vllm/compilation/decorators.py +++ b/vllm/compilation/decorators.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import contextlib import inspect from typing import Callable, Optional, TypeVar, Union, overload from unittest.mock import patch @@ -14,7 +15,7 @@ from vllm.config import CompilationLevel, VllmConfig from vllm.logger import init_logger from vllm.sequence import IntermediateTensors -from vllm.utils import supports_dynamo +from vllm.utils import resolve_obj_by_qualname, supports_dynamo from .monitor import start_monitoring_torch_compile @@ -301,8 +302,11 @@ def patched_inline_call(parent, func, args, kwargs): with patch.object(InliningInstructionTranslator, 'inline_call', patched_inline_call), torch._dynamo.config.patch( - **dynamo_config_patches): + **dynamo_config_patches + ), maybe_use_cudagraph_partition_wrapper( + self.vllm_config): output = self.compiled_callable(*args, **kwargs) + return output # usually, capturing the model once is enough, and then we can @@ -314,3 +318,52 @@ def patched_inline_call(parent, func, args, kwargs): cls.__call__ = __call__ return cls + + +@contextlib.contextmanager +def maybe_use_cudagraph_partition_wrapper(vllm_config: VllmConfig): + """ + Context manager to set/unset customized cudagraph partition wrappers. + + If we're using Inductor-based graph partitioning, we currently have the + whole `fx.Graph` before Inductor lowering and and the piecewise + splitting happens after all graph passes and fusions. Here, we add + a custom hook for Inductor to wrap each partition with our static + graph wrapper class to maintain more control over static graph + capture and replay. + """ + from vllm.config import CUDAGraphMode + + compilation_config = vllm_config.compilation_config + if (compilation_config.cudagraph_mode != CUDAGraphMode.NONE + and compilation_config.use_inductor_graph_partition): + from torch._inductor.utils import CUDAGraphWrapperMetadata + + from vllm.compilation.cuda_graph import CUDAGraphOptions + from vllm.platforms import current_platform + + static_graph_wrapper_class = resolve_obj_by_qualname( + current_platform.get_static_graph_wrapper_cls()) + + def customized_cudagraph_wrapper(f, + metadata: CUDAGraphWrapperMetadata): + partition_id = metadata.partition_index + num_partitions = metadata.num_partitions + return static_graph_wrapper_class( + runnable=f, + vllm_config=vllm_config, + runtime_mode=CUDAGraphMode.PIECEWISE, + cudagraph_options=CUDAGraphOptions( + debug_log_enable=partition_id == 0, + gc_disable=partition_id != 0, + weak_ref_output=partition_id == num_partitions - 1, + )) + + torch._inductor.utils.set_customized_partition_wrappers( + customized_cudagraph_wrapper) + + yield + + if (compilation_config.cudagraph_mode != CUDAGraphMode.NONE + and compilation_config.use_inductor_graph_partition): + torch._inductor.utils.set_customized_partition_wrappers(None) diff --git a/vllm/config/compilation.py b/vllm/config/compilation.py index 3618f472e742..22b38daf46c3 100644 --- a/vllm/config/compilation.py +++ b/vllm/config/compilation.py @@ -299,6 +299,26 @@ class CompilationConfig: minor release, i.e. v0.11.0 or v1.0.0. Please use cudagraph_mode instead. """ + use_inductor_graph_partition: bool = False + """Use inductor graph partition to split the graph at cudagraph_unsafe ops. + This partition happens at inductor codegen time after all passes and fusions + are finished. It generates a single `call` function which wraps + cudagraph-safe ops into partition functions and leave cudagraph-unsafe ops + outside the partition functions. For a graph with N cudagraph-unsafe ops + (e.g., Attention), there would be N+1 partitions. To mark an op as + cudagraph unsafe, we can add `tags=(torch._C.Tag.cudagraph_unsafe)` when + register the custom op. + + This config supports both full cudagraph and piecewise cudagraph without + compiling twice. For piecewise cudagraph, it applies vLLM CUDAGraph wrapper + to each partition. For N+1 partitions, there would be N+1 + CUDAGraph wrapper instances. + + For full CUDAGraph, we always apply a single CUDAGraph wrapper outside the + inductor `call` function in the model runner. The top-level full cudagraph + capture ignores all partitioning. + """ + pass_config: PassConfig = field(default_factory=PassConfig) """Custom inductor passes, see PassConfig for more details""" @@ -461,6 +481,12 @@ def __post_init__(self) -> None: "since full_cuda_graph is deprecated.") self.cudagraph_mode = CUDAGraphMode.FULL + if (self.use_inductor_graph_partition + and not is_torch_equal_or_newer("2.9.0.dev")): + raise ValueError("use_inductor_graph_partition is only " + "supported with torch>=2.9.0.dev. Set " + "use_inductor_graph_partition=False instead.") + def init_backend(self, vllm_config: "VllmConfig") -> Union[str, Callable]: if self.level == CompilationLevel.NO_COMPILATION: raise ValueError("No compilation level is set.") @@ -540,19 +566,36 @@ def set_splitting_ops_for_v1(self): "set_splitting_ops_for_v1 should only be called when " "level is CompilationLevel.PIECEWISE") + use_inductor_graph_partition_msg = ( + "When use_inductor_graph_partition=True, splitting_ops " + "are ignored and set to an empty list. Instead, " + "\"tags=(torch._C.Tag.cudagraph_unsafe, ),\" is " + "used to annotate custom ops for graph partition.") + if self.splitting_ops is None: - # NOTE: When using full cudagraph, instead of setting an empty - # list and capture the full cudagraph inside the flattened fx - # graph, we keep the piecewise fx graph structure but capture the - # full cudagraph outside the fx graph. This reduces some cpu - # overhead when the runtime batch_size is not cudagraph captured. - # see https://github.com/vllm-project/vllm/pull/20059 for details. - # make a copy to avoid mutating the class-level list via reference. - self.splitting_ops = list(self._attention_ops) + if self.use_inductor_graph_partition: + # When using inductor graph partition, we set splitting_ops + # to be empty and rely on torch._C.Tag.cudagraph_unsafe to + # annotate custom ops as splitting ops. + logger.warning_once(use_inductor_graph_partition_msg) + self.splitting_ops = [] + else: + # NOTE: When using full cudagraph, instead of setting an empty + # list and capture the full cudagraph inside the flattened fx + # graph, we keep the piecewise fx graph structure but capture + # the full cudagraph outside the fx graph. This reduces some + # cpu overhead when the runtime batch_size is not cudagraph + # captured. see https://github.com/vllm-project/vllm/pull/20059 + # for details. make a copy to avoid mutating the class-level + # list via reference. + self.splitting_ops = list(self._attention_ops) elif len(self.splitting_ops) == 0: - logger.warning_once("Using piecewise compilation with empty " - "splitting_ops.") - if self.cudagraph_mode == CUDAGraphMode.PIECEWISE: + logger.warning_once( + "Using piecewise compilation with empty " + "splitting_ops and use_inductor_graph_partition" + f"={self.use_inductor_graph_partition}.") + if (self.cudagraph_mode == CUDAGraphMode.PIECEWISE + and not self.use_inductor_graph_partition): logger.warning_once( "When compilation level is piecewise with empty " "splitting_ops, PIECEWISE cudagraph_mode will be " @@ -562,7 +605,26 @@ def set_splitting_ops_for_v1(self): "any problems.") self.cudagraph_mode = CUDAGraphMode.FULL self.splitting_ops = [] + elif self.use_inductor_graph_partition: + logger.warning_once(use_inductor_graph_partition_msg) + self.splitting_ops = [] def splitting_ops_contain_attention(self) -> bool: return self.splitting_ops is not None and all( op in self.splitting_ops for op in self._attention_ops) + + def is_attention_compiled_piecewise(self) -> bool: + use_fx_graph_piecewise_compilation = ( + self.level == CompilationLevel.PIECEWISE + and self.splitting_ops_contain_attention()) + + inductor_used = (self.level == CompilationLevel.PIECEWISE + and self.use_inductor) or ( + self.level >= CompilationLevel.DYNAMO_AS_IS + and self.backend == "inductor") + use_inductor_piecewise_compilation = ( + inductor_used and self.use_inductor_graph_partition + and not self.splitting_ops_contain_attention()) + + return use_fx_graph_piecewise_compilation or \ + use_inductor_piecewise_compilation diff --git a/vllm/v1/cudagraph_dispatcher.py b/vllm/v1/cudagraph_dispatcher.py index d2db7dcb3f09..ea4fba8eeea6 100644 --- a/vllm/v1/cudagraph_dispatcher.py +++ b/vllm/v1/cudagraph_dispatcher.py @@ -2,7 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from typing import Optional -from vllm.config import CompilationLevel, CUDAGraphMode, VllmConfig +from vllm.config import CUDAGraphMode, VllmConfig from vllm.forward_context import BatchDescriptor from vllm.logger import init_logger @@ -39,11 +39,15 @@ def __init__(self, vllm_config: VllmConfig): CUDAGraphMode.FULL: set(), } - assert not self.cudagraph_mode.requires_piecewise_compilation() or \ - (self.compilation_config.level == CompilationLevel.PIECEWISE and - self.compilation_config.splitting_ops_contain_attention()), \ + not_use_piecewise_compilation = ( + not self.cudagraph_mode.requires_piecewise_compilation()) + + assert not_use_piecewise_compilation or \ + self.compilation_config.is_attention_compiled_piecewise(), \ "Compilation level should be CompilationLevel.PIECEWISE when "\ "cudagraph_mode piecewise cudagraphs is used, "\ + "and attention should be in splitting_ops or "\ + "inductor splitting should be used. " \ f"cudagraph_mode={self.cudagraph_mode}, "\ f"compilation_level={self.compilation_config.level}, "\ f"splitting_ops={self.compilation_config.splitting_ops}"