diff --git a/vllm/config/compilation.py b/vllm/config/compilation.py index 657c430049f8..17a063d0c233 100644 --- a/vllm/config/compilation.py +++ b/vllm/config/compilation.py @@ -498,6 +498,16 @@ def __post_init__(self) -> None: if isinstance(self.pass_config, dict): self.pass_config = PassConfig(**self.pass_config) + if ( + is_torch_equal_or_newer("2.9.0.dev") + and "combo_kernels" not in self.inductor_compile_config + and "benchmark_combo_kernel" not in self.inductor_compile_config + ): + # use horizontal fusion, which is useful for fusing qk-norm and + # qk-rope when query and key have different shapes. + self.inductor_compile_config["combo_kernels"] = True + self.inductor_compile_config["benchmark_combo_kernel"] = True + # migrate the deprecated flags if not self.use_cudagraph: logger.warning(