diff --git a/tests/compile/piecewise/test_toy_llama.py b/tests/compile/piecewise/test_toy_llama.py
index c3aff8ddad49..e053367fb3d7 100644
--- a/tests/compile/piecewise/test_toy_llama.py
+++ b/tests/compile/piecewise/test_toy_llama.py
@@ -258,13 +258,13 @@ def tractable_computation(

 @torch.inference_mode
 def run_model(
-    llama_config, use_compile: bool, backend: str, split_attn: bool = False
+    llama_config, use_compile: bool, use_inductor: bool, split_attn: bool = False
 ) -> torch.Tensor:
     if use_compile:
         compilation_config = CompilationConfig(
             level=CompilationLevel.PIECEWISE,
             use_cudagraph=True,
-            backend=backend,
+            use_inductor=use_inductor,
             cudagraph_capture_sizes=[1, 2],
         )
         if split_attn:
@@ -338,8 +338,8 @@ def run_model(
     return output.cpu()


-@pytest.mark.parametrize("backend", ["inductor", "eager"])
-def test_toy_llama(backend: str):
+@pytest.mark.parametrize("use_inductor", [True, False])
+def test_toy_llama(use_inductor: bool):
     # compare output with and without piecewise compilation

     llama_config = LlamaConfig(
@@ -358,10 +358,10 @@ def test_toy_llama(backend: str):
         num_backend_compilations=0,
         num_cudagraph_captured=0,
     ):
-        outputs.append(run_model(llama_config, backend="eager", use_compile=False))
-        run_model(tractable_config, backend="eager", use_compile=False)
+        outputs.append(run_model(llama_config, use_inductor=False, use_compile=False))
+        run_model(tractable_config, use_inductor=False, use_compile=False)

-    if backend == "inductor":
+    if use_inductor:
         kwargs = {"num_inductor_compiles": 1, "num_eager_compiles": 0}
     else:
         kwargs = {"num_eager_compiles": 1, "num_inductor_compiles": 0}
@@ -377,8 +377,10 @@ def test_toy_llama(backend: str):
         num_cudagraph_captured=2,
         **kwargs,
     ):
-        outputs.append(run_model(llama_config, backend=backend, use_compile=True))
-        run_model(tractable_config, backend=backend, use_compile=True)
+        outputs.append(
+            run_model(llama_config, use_inductor=use_inductor, use_compile=True)
+        )
+        run_model(tractable_config, use_inductor=use_inductor, use_compile=True)

     with compilation_counter.expect(
         num_graphs_seen=1,  # one graph for the model
@@ -393,9 +395,16 @@ def test_toy_llama(backend: str):
         ),  # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
     ):
         outputs.append(
-            run_model(llama_config, backend=backend, use_compile=True, split_attn=True)
+            run_model(
+                llama_config,
+                use_inductor=use_inductor,
+                use_compile=True,
+                split_attn=True,
+            )
         )
-        run_model(tractable_config, backend=backend, use_compile=True, split_attn=True)
+        run_model(
+            tractable_config, use_inductor=use_inductor, use_compile=True, split_attn=True
+        )

     for i in range(1, len(outputs)):
         assert torch.allclose(outputs[0], outputs[i])
diff --git a/tests/model_executor/test_enabled_custom_ops.py b/tests/model_executor/test_enabled_custom_ops.py
index ab3a3a8268a3..12aad4cb8da0 100644
--- a/tests/model_executor/test_enabled_custom_ops.py
+++ b/tests/model_executor/test_enabled_custom_ops.py
@@ -37,59 +37,57 @@ class Relu3(ReLUSquaredActivation):


 @pytest.mark.parametrize(
-    "env, torch_level, backend, ops_enabled, default_on",
+    "env, torch_level, use_inductor, ops_enabled, default_on",
     [
         # Default values based on compile level
         # - All by default (no Inductor compilation)
-        (None, 0, "eager", [True] * 4, True),
-        (None, 1, "eager", [True] * 4, True),
-        (None, 2, "eager", [True] * 4, True),
-        (None, 3, "eager", [True] * 4, True),
+        (None, 0, False, [True] * 4, True),
+        (None, 1, True, [True] * 4, True),
+        (None, 2, False, [True] * 4, True),
         # - None by default (with Inductor)
"inductor", [True] * 4, True), - # - None by default (with Inductor) - (None, 1, "inductor", [False] * 4, False), - (None, 2, "inductor", [False] * 4, False), - (None, 3, "inductor", [False] * 4, False), + (None, 3, True, [False] * 4, False), + (None, 4, True, [False] * 4, False), + # - All by default (without Inductor) + (None, 3, False, [True] * 4, True), + (None, 4, False, [True] * 4, True), # Explicitly enabling/disabling # # Default: all # # All but SiluAndMul - ("+rms_norm,-silu_and_mul", 0, "inductor", [1, 0, 1, 1], True), + ("+rms_norm,-silu_and_mul", 0, True, [1, 0, 1, 1], True), # Only ReLU3 - ("none,-rms_norm,+relu3", 1, "eager", [0, 0, 0, 1], False), + ("none,-rms_norm,+relu3", 1, False, [0, 0, 0, 1], False), # All but SiluAndMul - ("all,-silu_and_mul", 2, "inductor", [1, 0, 1, 1], True), + ("all,-silu_and_mul", 2, True, [1, 0, 1, 1], True), # All but ReLU3 (even if ReLU2 is on) - ("-relu3,+relu2", 3, "eager", [1, 1, 1, 0], True), + ("-relu3,+relu2", 3, False, [1, 1, 1, 0], True), # RMSNorm and SiluAndMul - ("none,-relu3,+rms_norm,+silu_and_mul", 3, "eager", [1, 1, 0, 0], False), + ("none,-relu3,+rms_norm,+silu_and_mul", 4, False, [1, 1, 0, 0], False), # All but RMSNorm - ("-rms_norm", 3, "eager", [0, 1, 1, 1], True), + ("-rms_norm", 3, False, [0, 1, 1, 1], True), # # Default: none # # Only ReLU3 - ("none,+relu3", 3, "inductor", [0, 0, 0, 1], False), + ("-silu_and_mul,+relu3", 3, True, [0, 0, 0, 1], False), # All but RMSNorm - ("all,-rms_norm", 3, "inductor", [0, 1, 1, 1], True), + ("all,-rms_norm", 4, True, [0, 1, 1, 1], True), ], ) def test_enabled_ops( env: Optional[str], torch_level: int, - backend: str, + use_inductor: bool, ops_enabled: list[int], default_on: bool, ): custom_ops = env.split(",") if env else [] vllm_config = VllmConfig( compilation_config=CompilationConfig( - backend=backend, level=torch_level, custom_ops=custom_ops + use_inductor=bool(use_inductor), level=torch_level, custom_ops=custom_ops ) ) - # breakpoint() with set_current_vllm_config(vllm_config): assert CustomOp.default_on() == default_on diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py index 55bd3d0c60b1..da9debbb0e27 100644 --- a/vllm/compilation/backends.py +++ b/vllm/compilation/backends.py @@ -34,7 +34,7 @@ def make_compiler(compilation_config: CompilationConfig) -> CompilerInterface: - if compilation_config.backend == "inductor": + if compilation_config.use_inductor: # Use standalone compile only if requested, version is new enough, # and the symbol actually exists in this PyTorch build. if ( @@ -48,10 +48,6 @@ def make_compiler(compilation_config: CompilationConfig) -> CompilerInterface: logger.debug("Using InductorAdaptor") return InductorAdaptor() else: - assert compilation_config.backend == "eager", ( - "Custom backends not supported with CompilationLevel.PIECEWISE" - ) - logger.debug("Using EagerAdaptor") return EagerAdaptor() diff --git a/vllm/config/compilation.py b/vllm/config/compilation.py index f47fec12d7f9..7ed757fd59b0 100644 --- a/vllm/config/compilation.py +++ b/vllm/config/compilation.py @@ -180,11 +180,10 @@ class CompilationConfig: """The directory to store the compiled graph, to accelerate Inductor compilation. By default, it will use model-related information to generate a cache directory.""" - backend: str = "inductor" + backend: str = "" """The backend for compilation. It needs to be a string: - - "" (empty string): use the default backend ("inductor" on CUDA-alike - platforms). + - "" (empty string): use the default backend. 
- "eager"/"openxla"/...: use the specified backend registered in PyTorch. - "full.module.name": a qualified name which can be used to import the @@ -193,11 +192,7 @@ class CompilationConfig: distributed setting. When the compilation level is 1 or 2, the backend is used for the compilation directly (it sees the whole graph). When the compilation level is 3, the backend is used for the piecewise compilation - (it sees a part of the graph). The backend can not be custom for compilation - level 3. Furthermore, compilation is only piecewise if splitting ops is set - accordingly and use_inductor_cudagraphs_partition is off. Note that the - default options for splitting ops are sufficient for piecewise compilation. - """ + (it sees a part of the graph).""" custom_ops: list[str] = field(default_factory=list) """Fine-grained control over which custom ops to enable/disable. Use 'all' to enable all, 'none' to disable all. Also specify a list of custom op @@ -215,12 +210,8 @@ class CompilationConfig: compilation.""" # Inductor capture - use_inductor: Optional[bool] = None - """ - Whether to use inductor compilation. - - This flag is deprecated and will be removed. - Please use the 'backend' option instead. + use_inductor: bool = True + """Whether to use inductor compilation: - False: inductor compilation is not used. graph runs in eager (custom_ops enabled by default). @@ -228,11 +219,7 @@ class CompilationConfig: One graph for symbolic shape and one graph per size in compile_sizes are compiled using configurations in inductor_compile_config. - This setting is ignored if level None: "(where 'op' is the registered op name)" ) - # Currently only eager and inductor backend are supported. - # for piecewise compilation. Custom backends are not suppported for - # piecewise compilation. Update when more backends are supported. - if self.level == CompilationLevel.PIECEWISE and self.backend not in [ - "", - "eager", - "inductor", - ]: - raise ValueError( - f"Invalid backend for piecewise compilation: {self.backend}" - ) - - if self.use_inductor is not None: - logger.warning_once( - "The 'use_inductor' flag is deprecated and will be\ - removed in a future release." - "Please use the 'backend' option instead.", - ) - self.backend = "inductor" if self.use_inductor else "eager" - - if self.backend == "": - self.backend = "inductor" - def init_backend(self, vllm_config: "VllmConfig") -> Union[str, Callable]: - """ - Initialize the backend for the compilation config from a vllm config. - Arguments: - vllm_config: The vllm config to initialize the backend from. - Returns: - The backend for the compilation config. - """ - if self.level is None: - raise ValueError( - "No compilation level is set. This method should only be \ - called via vllm config where the level is set if none is \ - provided." 
-            )
         if self.level == CompilationLevel.NO_COMPILATION:
             raise ValueError("No compilation level is set.")
@@ -582,15 +533,15 @@ def init_backend(self, vllm_config: "VllmConfig") -> Union[str, Callable]:

         torch_backends = list_backends(exclude_tags=tuple())
         if self.level in [CompilationLevel.DYNAMO_AS_IS, CompilationLevel.DYNAMO_ONCE]:
+            if self.backend == "":
+                return "eager"
             if self.backend in torch_backends:
                 return self.backend
             return resolve_obj_by_qualname(self.backend)

+        # TODO: pass user-specified backend to piecewise compilation
+        # merge with the config use_inductor
         assert self.level == CompilationLevel.PIECEWISE
-        if self.backend not in ["eager", "inductor"]:
-            raise ValueError(
-                f"Invalid backend for piecewise compilation: {self.backend}"
-            )

         from vllm.compilation.backends import VllmBackend
@@ -743,7 +694,7 @@ def is_attention_compiled_piecewise(self) -> bool:
         )

         inductor_used = (
-            self.level == CompilationLevel.PIECEWISE and self.backend == "inductor"
+            self.level == CompilationLevel.PIECEWISE and self.use_inductor
         ) or (
             self.level >= CompilationLevel.DYNAMO_AS_IS and self.backend == "inductor"
         )
diff --git a/vllm/config/vllm.py b/vllm/config/vllm.py
index f6e46bb27013..833581035a31 100644
--- a/vllm/config/vllm.py
+++ b/vllm/config/vllm.py
@@ -318,25 +318,6 @@ def __post_init__(self):
             # NB: Passing both --enforce-eager and a compilation level
             # in V0 means the compilation level wins out.
             self.compilation_config.level = CompilationLevel.NO_COMPILATION
-        else:
-            assert self.compilation_config.level >= CompilationLevel.NO_COMPILATION
-            assert self.compilation_config.level <= CompilationLevel.PIECEWISE
-            assert self.compilation_config.level <= 3
-
-        # If user does not set custom ops via none or all set it here based on
-        # compilation level and backend.
-        if (
-            self.compilation_config.custom_ops.count("none")
-            + self.compilation_config.custom_ops.count("all")
-            == 0
-        ):
-            if (
-                self.compilation_config.level > 0
-                and self.compilation_config.backend != "eager"
-            ):
-                self.compilation_config.custom_ops.append("none")
-            else:
-                self.compilation_config.custom_ops.append("all")

         # async tp is built on top of sequence parallelism
         # and requires it to be enabled.
diff --git a/vllm/model_executor/custom_op.py b/vllm/model_executor/custom_op.py
index 6a0ea266378a..ad5a09ca970d 100644
--- a/vllm/model_executor/custom_op.py
+++ b/vllm/model_executor/custom_op.py
@@ -114,9 +114,7 @@ def enabled(cls) -> bool:
         custom_ops = compilation_config.custom_ops
         if not hasattr(cls, "name"):
             logger.warning_once(
-                "Custom op %s was not registered, which means it won't appear\
-                in the op registry. It will be enabled/disabled based on the\
-                global settings.",  # noqa: E501
+                "Custom op %s was not registered, which means it won't appear in the op registry. It will be enabled/disabled based on the global settings.",  # noqa: E501
                 cls.__name__,
             )
             return CustomOp.default_on()
@@ -130,17 +128,19 @@ def enabled(cls) -> bool:
     @staticmethod
     def default_on() -> bool:
         """
-        Behavior controlled by `CompilationConfig.custom_ops`: On by default if
-        'all', off by default if 'none'.
-        When PyTorch Inductor is used, 'none' is the default value,
-        otherwise 'all'.
+        On by default if PyTorch Inductor is not used.
+        Specifying 'all' or 'none' in custom_ops takes precedence.
         """
+        from vllm.config import CompilationLevel
+
         compilation_config = get_cached_compilation_config()
+        default_on = (
+            compilation_config.level < CompilationLevel.PIECEWISE
+            or not compilation_config.use_inductor
+        )
         count_none = compilation_config.custom_ops.count("none")
         count_all = compilation_config.custom_ops.count("all")
-        assert count_none + count_all == 1
-
-        return not count_none > 0 or count_all > 0
+        return default_on and not count_none > 0 or count_all > 0

     # Dictionary of all custom ops (classes, indexed by registered name).
     # To check if an op with a name is enabled, call .enabled() on the class.
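
# Illustrative sketch (not from the patch above): the decision table implemented by the
# restored CustomOp.default_on(). Everything here is local to the sketch; the only value
# assumed from the surrounding code is that CompilationLevel.PIECEWISE equals 3.

PIECEWISE = 3  # assumed value of CompilationLevel.PIECEWISE


def default_on(level: int, use_inductor: bool, custom_ops: list[str]) -> bool:
    # Custom ops are on by default unless piecewise compilation runs through Inductor.
    default = level < PIECEWISE or not use_inductor
    # An explicit 'none' turns them off and an explicit 'all' turns them on
    # ('all' wins because it sits after the `or`).
    return default and "none" not in custom_ops or "all" in custom_ops


assert default_on(3, True, []) is False         # piecewise + Inductor: off by default
assert default_on(3, False, []) is True         # piecewise without Inductor: on
assert default_on(2, True, []) is True          # below piecewise: on
assert default_on(3, True, ["all"]) is True     # explicit 'all' overrides the default
assert default_on(0, False, ["none"]) is False  # explicit 'none' overrides the default
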
""" + from vllm.config import CompilationLevel + compilation_config = get_cached_compilation_config() + default_on = ( + compilation_config.level < CompilationLevel.PIECEWISE + or not compilation_config.use_inductor + ) count_none = compilation_config.custom_ops.count("none") count_all = compilation_config.custom_ops.count("all") - assert count_none + count_all == 1 - - return not count_none > 0 or count_all > 0 + return default_on and not count_none > 0 or count_all > 0 # Dictionary of all custom ops (classes, indexed by registered name). # To check if an op with a name is enabled, call .enabled() on the class. diff --git a/vllm/platforms/cpu.py b/vllm/platforms/cpu.py index 24e08a8ecbd7..2f87664003dc 100644 --- a/vllm/platforms/cpu.py +++ b/vllm/platforms/cpu.py @@ -274,6 +274,8 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None: "epilogue_fusion": True, } ) + if compilation_config.use_inductor: + compilation_config.custom_ops = ["none"] if vllm_config.lora_config is not None: compilation_config.level = CompilationLevel.NO_COMPILATION