diff --git a/tests/compile/piecewise/test_toy_llama.py b/tests/compile/piecewise/test_toy_llama.py index 45317b456af4..eaf0a15479e9 100644 --- a/tests/compile/piecewise/test_toy_llama.py +++ b/tests/compile/piecewise/test_toy_llama.py @@ -258,13 +258,13 @@ def tractable_computation( @torch.inference_mode def run_model( - llama_config, use_compile: bool, use_inductor: bool, split_attn: bool = False + llama_config, use_compile: bool, backend: str, split_attn: bool = False ) -> torch.Tensor: if use_compile: compilation_config = CompilationConfig( level=CompilationLevel.PIECEWISE, use_cudagraph=True, - use_inductor=use_inductor, + backend=backend, cudagraph_capture_sizes=[1, 2], ) if split_attn: @@ -338,8 +338,8 @@ def run_model( return output.cpu() -@pytest.mark.parametrize("use_inductor", [True, False]) -def test_toy_llama(use_inductor: bool): +@pytest.mark.parametrize("backend", ["inductor", "eager"]) +def test_toy_llama(backend: str): # compare output with and without piecewise compilation llama_config = LlamaConfig( @@ -358,10 +358,10 @@ def test_toy_llama(use_inductor: bool): num_backend_compilations=0, num_cudagraph_captured=0, ): - outputs.append(run_model(llama_config, use_inductor=False, use_compile=False)) - run_model(tractable_config, use_inductor=False, use_compile=False) + outputs.append(run_model(llama_config, backend="eager", use_compile=False)) + run_model(tractable_config, backend="eager", use_compile=False) - if use_inductor: + if backend == "inductor": kwargs = {"num_inductor_compiles": 1, "num_eager_compiles": 0} else: kwargs = {"num_eager_compiles": 1, "num_inductor_compiles": 0} @@ -377,10 +377,8 @@ def test_toy_llama(use_inductor: bool): num_cudagraph_captured=2, **kwargs, ): - outputs.append( - run_model(llama_config, use_inductor=use_inductor, use_compile=True) - ) - run_model(tractable_config, use_inductor=use_inductor, use_compile=True) + outputs.append(run_model(llama_config, backend=backend, use_compile=True)) + run_model(tractable_config, backend=backend, use_compile=True) with compilation_counter.expect( num_graphs_seen=1, # one graph for the model @@ -395,16 +393,9 @@ def test_toy_llama(use_inductor: bool): ), # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen ): outputs.append( - run_model( - llama_config, - use_inductor=use_inductor, - use_compile=True, - split_attn=True, - ) + run_model(llama_config, backend=backend, use_compile=True, split_attn=True) ) - run_model( - tractable_config, use_inductor=use_inductor, use_compile=True, split_attn=True - ) + run_model(tractable_config, backend=backend, use_compile=True, split_attn=True) for i in range(1, len(outputs)): assert torch.allclose(outputs[0], outputs[i]) diff --git a/tests/compile/test_basic_correctness.py b/tests/compile/test_basic_correctness.py index 9bfd72260436..ab6a17e149fc 100644 --- a/tests/compile/test_basic_correctness.py +++ b/tests/compile/test_basic_correctness.py @@ -77,14 +77,15 @@ class TestSetting: method="encode", ), # vision language model - TestSetting( - model="microsoft/Phi-3.5-vision-instruct", - model_args=["--trust-remote-code", "--max-model-len", "2048"], - pp_size=2, - tp_size=1, - attn_backend="FLASH_ATTN", - method="generate_with_image", - ), + # See https://github.com/vllm-project/vllm/issues/26716. 
+ # TestSetting( + # model="microsoft/Phi-3.5-vision-instruct", + # model_args=["--trust-remote-code", "--max-model-len", "2048"], + # pp_size=2, + # tp_size=1, + # attn_backend="FLASH_ATTN", + # method="generate_with_image", + # ), ], ) def test_compile_correctness( @@ -109,41 +110,46 @@ def test_compile_correctness( with monkeypatch.context() as m: m.setenv("VLLM_ATTENTION_BACKEND", attn_backend) final_args = [ - "--enforce-eager", *model_args, "-pp", str(pp_size), "-tp", str(tp_size), + "-O.cudagraph_mode=none", ] all_args: list[list[str]] = [] all_envs: list[dict[str, str] | None] = [] - for level in [ - CompilationLevel.NO_COMPILATION, + for comp_level in [ + CompilationLevel.DYNAMO_AS_IS, + CompilationLevel.DYNAMO_ONCE, CompilationLevel.PIECEWISE, ]: - all_args.append(final_args + [f"-O{level}"]) - all_envs.append({}) + for level in [CompilationLevel.NO_COMPILATION, comp_level]: + all_args.append( + final_args + [f"-O.level={level}", "-O.backend=inductor"] + ) - # inductor will change the output, so we only compare if the output - # is close, not exactly the same. - compare_all_settings( - model, - all_args, - all_envs, - method=method if method != "generate" else "generate_close", - ) - all_envs.clear() - all_args.clear() + # inductor will change the output, so we only compare if the output + # is close, not exactly the same. + compare_all_settings( + model, + all_args, + all_envs, + method=method if method != "generate" else "generate_close", + ) + all_envs.clear() + all_args.clear() for level in [ CompilationLevel.NO_COMPILATION, CompilationLevel.DYNAMO_AS_IS, CompilationLevel.DYNAMO_ONCE, + CompilationLevel.PIECEWISE, ]: - all_args.append(final_args + [f"-O{level}"]) + all_args.append(final_args + [f"-O.level={level}", "-O.backend=eager"]) + all_envs.append({}) all_envs.append({}) compare_all_settings(model, all_args * 3, all_envs, method=method) diff --git a/tests/model_executor/test_enabled_custom_ops.py b/tests/model_executor/test_enabled_custom_ops.py index bf290079469a..254e9b3ab8af 100644 --- a/tests/model_executor/test_enabled_custom_ops.py +++ b/tests/model_executor/test_enabled_custom_ops.py @@ -36,55 +36,56 @@ class Relu3(ReLUSquaredActivation): @pytest.mark.parametrize( - "env, torch_level, use_inductor, ops_enabled, default_on", + "env, torch_level, backend, ops_enabled, default_on", [ # Default values based on compile level # - All by default (no Inductor compilation) - (None, 0, False, [True] * 4, True), - (None, 1, True, [True] * 4, True), - (None, 2, False, [True] * 4, True), + (None, 0, "eager", [True] * 4, True), + (None, 1, "eager", [True] * 4, True), + (None, 2, "eager", [True] * 4, True), + (None, 3, "eager", [True] * 4, True), # - None by default (with Inductor) - (None, 3, True, [False] * 4, False), - (None, 4, True, [False] * 4, False), - # - All by default (without Inductor) - (None, 3, False, [True] * 4, True), - (None, 4, False, [True] * 4, True), + (None, 0, "inductor", [True] * 4, True), + # - None by default (with Inductor) + (None, 1, "inductor", [False] * 4, False), + (None, 2, "inductor", [False] * 4, False), + (None, 3, "inductor", [False] * 4, False), # Explicitly enabling/disabling # # Default: all # # All but SiluAndMul - ("+rms_norm,-silu_and_mul", 0, True, [1, 0, 1, 1], True), + ("+rms_norm,-silu_and_mul", 0, "inductor", [1, 0, 1, 1], True), # Only ReLU3 - ("none,-rms_norm,+relu3", 1, False, [0, 0, 0, 1], False), + ("none,-rms_norm,+relu3", 1, "eager", [0, 0, 0, 1], False), # All but SiluAndMul - ("all,-silu_and_mul", 2, True, [1, 0, 1, 1], 
True),
+        ("all,-silu_and_mul", 2, "inductor", [1, 0, 1, 1], True),
         # All but ReLU3 (even if ReLU2 is on)
-        ("-relu3,+relu2", 3, False, [1, 1, 1, 0], True),
+        ("-relu3,+relu2", 3, "eager", [1, 1, 1, 0], True),
         # RMSNorm and SiluAndMul
-        ("none,-relu3,+rms_norm,+silu_and_mul", 4, False, [1, 1, 0, 0], False),
+        ("none,-relu3,+rms_norm,+silu_and_mul", 3, "eager", [1, 1, 0, 0], False),
         # All but RMSNorm
-        ("-rms_norm", 3, False, [0, 1, 1, 1], True),
+        ("-rms_norm", 3, "eager", [0, 1, 1, 1], True),
         #
         # Default: none
         #
         # Only ReLU3
-        ("-silu_and_mul,+relu3", 3, True, [0, 0, 0, 1], False),
+        ("none,+relu3", 3, "inductor", [0, 0, 0, 1], False),
         # All but RMSNorm
-        ("all,-rms_norm", 4, True, [0, 1, 1, 1], True),
+        ("all,-rms_norm", 3, "inductor", [0, 1, 1, 1], True),
     ],
 )
 def test_enabled_ops(
     env: str | None,
     torch_level: int,
-    use_inductor: bool,
+    backend: str,
     ops_enabled: list[int],
     default_on: bool,
 ):
     custom_ops = env.split(",") if env else []
     vllm_config = VllmConfig(
         compilation_config=CompilationConfig(
-            use_inductor=bool(use_inductor), level=torch_level, custom_ops=custom_ops
+            backend=backend, level=torch_level, custom_ops=custom_ops
         )
     )
     with set_current_vllm_config(vllm_config):
diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py
index e559fdb397fa..46c433fe6aef 100644
--- a/vllm/compilation/backends.py
+++ b/vllm/compilation/backends.py
@@ -41,7 +41,7 @@ def make_compiler(compilation_config: CompilationConfig) -> CompilerInterface:
-    if compilation_config.use_inductor:
+    if compilation_config.backend == "inductor":
         # Use standalone compile only if requested, version is new enough,
         # and the symbol actually exists in this PyTorch build.
         if (
@@ -55,6 +55,10 @@ def make_compiler(compilation_config: CompilationConfig) -> CompilerInterface:
             logger.debug("Using InductorAdaptor")
             return InductorAdaptor()
     else:
+        assert compilation_config.backend == "eager", (
+            "Custom backends not supported with CompilationLevel.PIECEWISE"
+        )
+
         logger.debug("Using EagerAdaptor")
         return EagerAdaptor()
diff --git a/vllm/config/compilation.py b/vllm/config/compilation.py
index 657c430049f8..5313112a19a6 100644
--- a/vllm/config/compilation.py
+++ b/vllm/config/compilation.py
@@ -15,6 +15,7 @@
 from vllm.compilation.inductor_pass import CallableInductorPass, InductorPass
 from vllm.config.utils import config
 from vllm.logger import init_logger
+from vllm.platforms import current_platform
 from vllm.utils import is_torch_equal_or_newer, resolve_obj_by_qualname
 
 if TYPE_CHECKING:
@@ -187,7 +188,8 @@ class CompilationConfig:
     backend: str = ""
     """The backend for compilation. It needs to be a string:
 
-    - "" (empty string): use the default backend.
+    - "" (empty string): use the default backend ("inductor" on CUDA-alike
+      platforms).
     - "eager"/"openxla"/...: use the specified backend registered in PyTorch.
     - "full.module.name": a qualified name which can be used to import the
 
@@ -196,7 +198,12 @@ class CompilationConfig:
     distributed setting. When the compilation level is 1 or 2, the backend is
     used for the compilation directly (it sees the whole graph). When the
     compilation level is 3, the backend is used for the piecewise compilation
-    (it sees a part of the graph)."""
+    (it sees a part of the graph). The backend cannot be custom for compilation
+    level 3 (PIECEWISE), i.e. it must be either "eager" or "inductor".
+    Furthermore, compilation is only piecewise if splitting ops are set
+    accordingly and use_inductor_cudagraphs_partition is off. Note that the
+    default options for splitting ops are sufficient for piecewise compilation.
+    """
     custom_ops: list[str] = field(default_factory=list)
     """Fine-grained control over which custom ops to enable/disable. Use 'all'
     to enable all, 'none' to disable all. Also specify a list of custom op
@@ -229,8 +236,12 @@ class CompilationConfig:
     If empty list [], no ops are excluded (suitable for full cudagraphs)."""
 
     # Inductor capture
-    use_inductor: bool = True
-    """Whether to use inductor compilation:
+    use_inductor: bool | None = None
+    """
+    Whether to use inductor compilation.
+
+    This flag is deprecated and will be removed in the next release (v0.12.0).
+    Please use the 'backend' option instead.
 
     - False: inductor compilation is not used. graph runs in eager (custom_ops
       enabled by default).
@@ -238,7 +249,11 @@ class CompilationConfig:
       One graph for symbolic shape and one graph per size in compile_sizes
       are compiled using configurations in inductor_compile_config.
 
-    This setting is ignored if level<PIECEWISE."""
@@ def __post_init__(self) -> None:
                     "(where 'op' is the registered op name)"
                 )
 
+        # Currently only the eager and inductor backends are supported for
+        # piecewise compilation; custom backends are not. Update this check
+        # when more backends are supported.
+        if self.level == CompilationLevel.PIECEWISE and self.backend not in [
+            "",
+            "eager",
+            "inductor",
+        ]:
+            raise ValueError(
+                f"Invalid backend for piecewise compilation: {self.backend}"
+            )
+
+        if self.use_inductor is not None:
+            logger.warning_once(
+                "The 'use_inductor' flag is deprecated and will be "
+                "removed in the next release (v0.12.0). "
+                "Please use the 'backend' option instead.",
+            )
+            self.backend = "inductor" if self.use_inductor else "eager"
+
+        if self.backend == "":
+            self.backend = current_platform.simple_compile_backend
+
     def init_backend(self, vllm_config: "VllmConfig") -> str | Callable:
+        """
+        Initialize the backend for the compilation config from a vllm config.
+
+        Arguments:
+            vllm_config: The vllm config to initialize the backend from.
+
+        Returns:
+            The backend for the compilation config.
+        """
+        if self.level is None:
+            raise ValueError(
+                "No compilation level is set. This method should only be "
+                "called via vllm config, where the level is set if none is "
+                "provided."
+            )
 
         if self.level == CompilationLevel.NO_COMPILATION:
             raise ValueError("No compilation level is set.")
 
@@ -553,15 +604,15 @@ def init_backend(self, vllm_config: "VllmConfig") -> str | Callable:
         torch_backends = list_backends(exclude_tags=tuple())
 
         if self.level in [CompilationLevel.DYNAMO_AS_IS, CompilationLevel.DYNAMO_ONCE]:
-            if self.backend == "":
-                return "eager"
             if self.backend in torch_backends:
                 return self.backend
             return resolve_obj_by_qualname(self.backend)
 
-        # TODO: pass user-specified backend to piecewise compilation
-        # merge with the config use_inductor
         assert self.level == CompilationLevel.PIECEWISE
+        if self.backend not in ["eager", "inductor"]:
+            raise ValueError(
+                f"Invalid backend for piecewise compilation: {self.backend}"
+            )
 
         from vllm.compilation.backends import VllmBackend
 
@@ -710,7 +761,9 @@ def is_attention_compiled_piecewise(self) -> bool:
             return self.level == CompilationLevel.PIECEWISE
 
         # Inductor partition case
-        return self.level > CompilationLevel.NO_COMPILATION and self.use_inductor
+        return (
+            self.backend == "inductor" and self.level > CompilationLevel.NO_COMPILATION
+        )
 
     def custom_op_log_check(self):
         """
diff --git a/vllm/config/vllm.py b/vllm/config/vllm.py
index b15d122c9161..c94101bf608f 100644
--- a/vllm/config/vllm.py
+++ b/vllm/config/vllm.py
@@ -322,6 +322,20 @@ def __post_init__(self):
             # NB: Passing both --enforce-eager and a compilation level
             # in V0 means the compilation level wins out.
             self.compilation_config.level = CompilationLevel.NO_COMPILATION
+        else:
+            assert self.compilation_config.level >= CompilationLevel.NO_COMPILATION
+            assert self.compilation_config.level <= CompilationLevel.PIECEWISE
+
+        # If the user has not set custom ops via 'none' or 'all', set the
+        # default here based on compilation level and backend.
+        if all(s not in self.compilation_config.custom_ops for s in ("all", "none")):
+            if (
+                self.compilation_config.backend == "inductor"
+                and self.compilation_config.level > CompilationLevel.NO_COMPILATION
+            ):
+                self.compilation_config.custom_ops.append("none")
+            else:
+                self.compilation_config.custom_ops.append("all")
 
         # async tp is built on top of sequence parallelism
         # and requires it to be enabled.
diff --git a/vllm/model_executor/custom_op.py b/vllm/model_executor/custom_op.py
index 7f75066f2c36..9ef696d80712 100644
--- a/vllm/model_executor/custom_op.py
+++ b/vllm/model_executor/custom_op.py
@@ -113,7 +113,9 @@ def enabled(cls) -> bool:
         custom_ops = compilation_config.custom_ops
         if not hasattr(cls, "name"):
             logger.warning_once(
-                "Custom op %s was not registered, which means it won't appear in the op registry. It will be enabled/disabled based on the global settings.",  # noqa: E501
+                "Custom op %s was not registered, which means it won't appear "
+                "in the op registry. It will be enabled/disabled based on the "
+                "global settings.",
                 cls.__name__,
             )
             return CustomOp.default_on()
@@ -127,19 +129,17 @@ def enabled(cls) -> bool:
     @staticmethod
     def default_on() -> bool:
         """
-        On by default if PyTorch Inductor is not used.
-        Specifying 'all' or 'none' in custom_op takes precedence.
+        Behavior is controlled by `CompilationConfig.custom_ops`: on by default
+        if it contains 'all', off by default if it contains 'none'.
+        When PyTorch Inductor is used, the default is 'none';
+        otherwise it is 'all'.
""" - from vllm.config import CompilationLevel - compilation_config = get_cached_compilation_config() - default_on = ( - compilation_config.level < CompilationLevel.PIECEWISE - or not compilation_config.use_inductor - ) count_none = compilation_config.custom_ops.count("none") count_all = compilation_config.custom_ops.count("all") - return default_on and not count_none > 0 or count_all > 0 + assert count_none + count_all == 1 + + return not count_none > 0 or count_all > 0 # Dictionary of all custom ops (classes, indexed by registered name). # To check if an op with a name is enabled, call .enabled() on the class. diff --git a/vllm/platforms/cpu.py b/vllm/platforms/cpu.py index ed6724b298a5..17d610ac16a3 100644 --- a/vllm/platforms/cpu.py +++ b/vllm/platforms/cpu.py @@ -275,8 +275,6 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None: "epilogue_fusion": True, } ) - if compilation_config.use_inductor: - compilation_config.custom_ops = ["none"] if vllm_config.lora_config is not None: compilation_config.level = CompilationLevel.NO_COMPILATION