Merged
Commits
39 commits
5b46656
added deprecation warning for use_inductor in CompilationConfig
morrison-turnansky Oct 1, 2025
7f6e73a
added logic for init_backend to take in backend information
morrison-turnansky Oct 1, 2025
3688ac5
removed use_inductor from code paths
morrison-turnansky Oct 2, 2025
cc540fa
added all as default for custom_ops field in CompilationConfig, made …
morrison-turnansky Oct 2, 2025
18c765d
updated test_enabled_ops
morrison-turnansky Oct 2, 2025
d99a58a
toy llama tests passing
morrison-turnansky Oct 2, 2025
3069253
removed unused parameter
morrison-turnansky Oct 2, 2025
d6961a0
Apply suggestion from @ProExpertProg
morrison-turnansky Oct 3, 2025
aaf805b
updated behaviour of custom ops initialization
morrison-turnansky Oct 3, 2025
53c3fb7
saving progress
morrison-turnansky Oct 3, 2025
4fe3326
resolved comments for compilation
morrison-turnansky Oct 3, 2025
7883358
Apply suggestion from @ProExpertProg
morrison-turnansky Oct 3, 2025
5acc937
resolved comments for custom_op
morrison-turnansky Oct 3, 2025
6a1c87b
resolved comments cpu
morrison-turnansky Oct 3, 2025
95e0a9f
Apply suggestion from @ProExpertProg
morrison-turnansky Oct 3, 2025
798a7fa
updated test_enabled_custom_ops to be consistent with new initialization
morrison-turnansky Oct 3, 2025
b1d6ca6
formatting
morrison-turnansky Oct 3, 2025
13be66b
Update vllm/config/compilation.py
morrison-turnansky Oct 6, 2025
c979a6f
Update vllm/config/vllm.py
morrison-turnansky Oct 6, 2025
620aadd
adjusted according to reviewer comments
morrison-turnansky Oct 6, 2025
01007f2
linting
morrison-turnansky Oct 6, 2025
bd153a2
Update vllm/compilation/backends.py
ProExpertProg Oct 6, 2025
a1a08a3
updated vllm custom op default behaviour so that none is appended if an…
morrison-turnansky Oct 9, 2025
40826c7
changed test_basic_correctness for new backend behaviour
morrison-turnansky Oct 9, 2025
34a2460
adjusted to 0. format
morrison-turnansky Oct 13, 2025
6ba0620
linting
morrison-turnansky Oct 13, 2025
4c0cd9f
tpu fix
morrison-turnansky Oct 13, 2025
14371c9
made 0.level explicit in test_basic_correctness
morrison-turnansky Oct 13, 2025
70b554c
Update vllm/config/vllm.py
morrison-turnansky Oct 13, 2025
f500719
changed default backend from inductor to
morrison-turnansky Oct 13, 2025
19ddc8d
updated according to comments
morrison-turnansky Oct 13, 2025
e438c43
adjusted for comments
morrison-turnansky Oct 13, 2025
96e2c1e
Update vllm/model_executor/custom_op.py
morrison-turnansky Oct 13, 2025
0f8e420
Update vllm/config/vllm.py
morrison-turnansky Oct 13, 2025
975c909
Update vllm/config/compilation.py
morrison-turnansky Oct 13, 2025
65bc52a
Update tests/compile/test_basic_correctness.py
morrison-turnansky Oct 13, 2025
51d4aa4
Update tests/compile/test_basic_correctness.py
morrison-turnansky Oct 13, 2025
5b5455c
reviewer comments
morrison-turnansky Oct 13, 2025
c25bafb
Add version number to deprecation warning
ProExpertProg Oct 13, 2025
31 changes: 11 additions & 20 deletions tests/compile/piecewise/test_toy_llama.py
@@ -258,13 +258,13 @@ def tractable_computation(

@torch.inference_mode
def run_model(
llama_config, use_compile: bool, use_inductor: bool, split_attn: bool = False
llama_config, use_compile: bool, backend: str, split_attn: bool = False
) -> torch.Tensor:
if use_compile:
compilation_config = CompilationConfig(
level=CompilationLevel.PIECEWISE,
use_cudagraph=True,
use_inductor=use_inductor,
backend=backend,
cudagraph_capture_sizes=[1, 2],
)
if split_attn:
@@ -338,8 +338,8 @@ def run_model(
return output.cpu()


@pytest.mark.parametrize("use_inductor", [True, False])
def test_toy_llama(use_inductor: bool):
@pytest.mark.parametrize("backend", ["inductor", "eager"])
def test_toy_llama(backend: str):
# compare output with and without piecewise compilation

llama_config = LlamaConfig(
@@ -358,10 +358,10 @@ def test_toy_llama(use_inductor: bool):
num_backend_compilations=0,
num_cudagraph_captured=0,
):
outputs.append(run_model(llama_config, use_inductor=False, use_compile=False))
run_model(tractable_config, use_inductor=False, use_compile=False)
outputs.append(run_model(llama_config, backend="eager", use_compile=False))
run_model(tractable_config, backend="eager", use_compile=False)

if use_inductor:
if backend == "inductor":
kwargs = {"num_inductor_compiles": 1, "num_eager_compiles": 0}
else:
kwargs = {"num_eager_compiles": 1, "num_inductor_compiles": 0}
@@ -377,10 +377,8 @@ def test_toy_llama(use_inductor: bool):
num_cudagraph_captured=2,
**kwargs,
):
outputs.append(
run_model(llama_config, use_inductor=use_inductor, use_compile=True)
)
run_model(tractable_config, use_inductor=use_inductor, use_compile=True)
outputs.append(run_model(llama_config, backend=backend, use_compile=True))
run_model(tractable_config, backend=backend, use_compile=True)

with compilation_counter.expect(
num_graphs_seen=1, # one graph for the model
@@ -395,16 +393,9 @@ def test_toy_llama(use_inductor: bool):
), # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
):
outputs.append(
run_model(
llama_config,
use_inductor=use_inductor,
use_compile=True,
split_attn=True,
)
run_model(llama_config, backend=backend, use_compile=True, split_attn=True)
)
run_model(
tractable_config, use_inductor=use_inductor, use_compile=True, split_attn=True
)
run_model(tractable_config, backend=backend, use_compile=True, split_attn=True)

for i in range(1, len(outputs)):
assert torch.allclose(outputs[0], outputs[i])
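
For orientation, a minimal sketch (not part of the diff) of how a caller now picks the compiler via the backend string rather than the removed use_inductor flag, using the same CompilationConfig arguments as the updated run_model above:

from vllm.config import CompilationConfig, CompilationLevel

def piecewise_config(backend: str) -> CompilationConfig:
    # backend is a plain string such as "inductor" or "eager"; the boolean
    # use_inductor flag is deprecated in favour of this field.
    return CompilationConfig(
        level=CompilationLevel.PIECEWISE,
        use_cudagraph=True,
        backend=backend,
        cudagraph_capture_sizes=[1, 2],
    )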
54 changes: 30 additions & 24 deletions tests/compile/test_basic_correctness.py
@@ -77,14 +77,15 @@ class TestSetting:
method="encode",
),
# vision language model
TestSetting(
model="microsoft/Phi-3.5-vision-instruct",
model_args=["--trust-remote-code", "--max-model-len", "2048"],
pp_size=2,
tp_size=1,
attn_backend="FLASH_ATTN",
method="generate_with_image",
),
# See https://github.com/vllm-project/vllm/issues/26716.
# TestSetting(
# model="microsoft/Phi-3.5-vision-instruct",
# model_args=["--trust-remote-code", "--max-model-len", "2048"],
# pp_size=2,
# tp_size=1,
# attn_backend="FLASH_ATTN",
# method="generate_with_image",
# ),
],
)
def test_compile_correctness(
@@ -109,41 +110,46 @@ def test_compile_correctness(
with monkeypatch.context() as m:
m.setenv("VLLM_ATTENTION_BACKEND", attn_backend)
final_args = [
"--enforce-eager",
*model_args,
"-pp",
str(pp_size),
"-tp",
str(tp_size),
"-O.cudagraph_mode=none",
]

all_args: list[list[str]] = []
all_envs: list[dict[str, str] | None] = []

for level in [
CompilationLevel.NO_COMPILATION,
for comp_level in [
CompilationLevel.DYNAMO_AS_IS,
CompilationLevel.DYNAMO_ONCE,
CompilationLevel.PIECEWISE,
]:
all_args.append(final_args + [f"-O{level}"])
all_envs.append({})
for level in [CompilationLevel.NO_COMPILATION, comp_level]:
all_args.append(
final_args + [f"-O.level={level}", "-O.backend=inductor"]
)

# inductor will change the output, so we only compare if the output
# is close, not exactly the same.
compare_all_settings(
model,
all_args,
all_envs,
method=method if method != "generate" else "generate_close",
)
all_envs.clear()
all_args.clear()
# inductor will change the output, so we only compare if the output
# is close, not exactly the same.
compare_all_settings(
model,
all_args,
all_envs,
method=method if method != "generate" else "generate_close",
)
all_envs.clear()
all_args.clear()

for level in [
CompilationLevel.NO_COMPILATION,
CompilationLevel.DYNAMO_AS_IS,
CompilationLevel.DYNAMO_ONCE,
CompilationLevel.PIECEWISE,
]:
all_args.append(final_args + [f"-O{level}"])
all_args.append(final_args + [f"-O.level={level}", "-O.backend=eager"])
all_envs.append({})
all_envs.append({})

compare_all_settings(model, all_args * 3, all_envs, method=method)
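
As an illustrative aside (an assumption about the -O shorthand, not something this diff states), the flag pair used above corresponds roughly to the following config object:

from vllm.config import CompilationConfig, CompilationLevel

# Assumed equivalent of passing `-O.level=3 -O.backend=eager` on the CLI.
cfg = CompilationConfig(
    level=CompilationLevel.PIECEWISE,  # -O.level=3
    backend="eager",                   # -O.backend=eager
)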
39 changes: 20 additions & 19 deletions tests/model_executor/test_enabled_custom_ops.py
@@ -36,55 +36,56 @@ class Relu3(ReLUSquaredActivation):


@pytest.mark.parametrize(
"env, torch_level, use_inductor, ops_enabled, default_on",
"env, torch_level, backend, ops_enabled, default_on",
[
# Default values based on compile level
# - All by default (no Inductor compilation)
(None, 0, False, [True] * 4, True),
(None, 1, True, [True] * 4, True),
(None, 2, False, [True] * 4, True),
(None, 0, "eager", [True] * 4, True),
(None, 1, "eager", [True] * 4, True),
(None, 2, "eager", [True] * 4, True),
(None, 3, "eager", [True] * 4, True),
# - None by default (with Inductor)
(None, 3, True, [False] * 4, False),
(None, 4, True, [False] * 4, False),
# - All by default (without Inductor)
(None, 3, False, [True] * 4, True),
(None, 4, False, [True] * 4, True),
(None, 0, "inductor", [True] * 4, True),
# - None by default (with Inductor)
(None, 1, "inductor", [False] * 4, False),
(None, 2, "inductor", [False] * 4, False),
(None, 3, "inductor", [False] * 4, False),
# Explicitly enabling/disabling
#
# Default: all
#
# All but SiluAndMul
("+rms_norm,-silu_and_mul", 0, True, [1, 0, 1, 1], True),
("+rms_norm,-silu_and_mul", 0, "inductor", [1, 0, 1, 1], True),
# Only ReLU3
("none,-rms_norm,+relu3", 1, False, [0, 0, 0, 1], False),
("none,-rms_norm,+relu3", 1, "eager", [0, 0, 0, 1], False),
# All but SiluAndMul
("all,-silu_and_mul", 2, True, [1, 0, 1, 1], True),
("all,-silu_and_mul", 2, "inductor", [1, 0, 1, 1], True),
# All but ReLU3 (even if ReLU2 is on)
("-relu3,+relu2", 3, False, [1, 1, 1, 0], True),
("-relu3,+relu2", 3, "eager", [1, 1, 1, 0], True),
# RMSNorm and SiluAndMul
("none,-relu3,+rms_norm,+silu_and_mul", 4, False, [1, 1, 0, 0], False),
("none,-relu3,+rms_norm,+silu_and_mul", 3, "eager", [1, 1, 0, 0], False),
# All but RMSNorm
("-rms_norm", 3, False, [0, 1, 1, 1], True),
("-rms_norm", 3, "eager", [0, 1, 1, 1], True),
#
# Default: none
#
# Only ReLU3
("-silu_and_mul,+relu3", 3, True, [0, 0, 0, 1], False),
("none,+relu3", 3, "inductor", [0, 0, 0, 1], False),
# All but RMSNorm
("all,-rms_norm", 4, True, [0, 1, 1, 1], True),
("all,-rms_norm", 3, "inductor", [0, 1, 1, 1], True),
],
)
def test_enabled_ops(
env: str | None,
torch_level: int,
use_inductor: bool,
backend: str,
ops_enabled: list[int],
default_on: bool,
):
custom_ops = env.split(",") if env else []
vllm_config = VllmConfig(
compilation_config=CompilationConfig(
use_inductor=bool(use_inductor), level=torch_level, custom_ops=custom_ops
backend=backend, level=torch_level, custom_ops=custom_ops
)
)
with set_current_vllm_config(vllm_config):
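
A minimal sketch of the behaviour this parametrization pins down (same constructors as the test file; the exact defaulting logic lives in custom_op.py and is only partially shown in this diff):

from vllm.config import CompilationConfig, VllmConfig

# With an eager backend, custom ops default to enabled; once inductor
# compiles the graph they default to disabled unless explicitly enabled.
eager_cfg = VllmConfig(
    compilation_config=CompilationConfig(level=3, backend="eager", custom_ops=[])
)
inductor_cfg = VllmConfig(
    compilation_config=CompilationConfig(
        level=3, backend="inductor", custom_ops=["+rms_norm"]  # force RMSNorm on
    )
)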
6 changes: 5 additions & 1 deletion vllm/compilation/backends.py
@@ -41,7 +41,7 @@


def make_compiler(compilation_config: CompilationConfig) -> CompilerInterface:
if compilation_config.use_inductor:
if compilation_config.backend == "inductor":
# Use standalone compile only if requested, version is new enough,
# and the symbol actually exists in this PyTorch build.
if (
@@ -55,6 +55,10 @@ def make_compiler(compilation_config: CompilationConfig) -> CompilerInterface:
logger.debug("Using InductorAdaptor")
return InductorAdaptor()
else:
assert compilation_config.backend == "eager", (
"Custom backends not supported with CompilationLevel.PIECEWISE"
)

logger.debug("Using EagerAdaptor")
return EagerAdaptor()
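
A short usage sketch (assuming make_compiler remains importable from vllm.compilation.backends, the module this hunk edits):

from vllm.compilation.backends import make_compiler
from vllm.config import CompilationConfig

compiler = make_compiler(CompilationConfig(level=3, backend="eager"))
# Returns an EagerAdaptor; backend="inductor" selects one of the Inductor
# adaptors, and any other value for a piecewise config trips the assertion.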

73 changes: 63 additions & 10 deletions vllm/config/compilation.py
@@ -15,6 +15,7 @@
from vllm.compilation.inductor_pass import CallableInductorPass, InductorPass
from vllm.config.utils import config
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.utils import is_torch_equal_or_newer, resolve_obj_by_qualname

if TYPE_CHECKING:
@@ -187,7 +188,8 @@ class CompilationConfig:
backend: str = ""
"""The backend for compilation. It needs to be a string:

- "" (empty string): use the default backend.
- "" (empty string): use the default backend ("inductor" on CUDA-alike
platforms).
- "eager"/"openxla"/...: use the specified backend registered in PyTorch.
- "full.module.name": a qualified name which can be used to import the

Expand All @@ -196,7 +198,12 @@ class CompilationConfig:
distributed setting. When the compilation level is 1 or 2, the backend is
used for the compilation directly (it sees the whole graph). When the
compilation level is 3, the backend is used for the piecewise compilation
(it sees a part of the graph)."""
(it sees a part of the graph). The backend can not be custom for compilation
level 3, i.e. the backend must be either eager or inductor. Furthermore,
compilation is only piecewise if splitting ops is set accordingly and
use_inductor_cudagraphs_partition is off. Note that the default options for
splitting ops are sufficient for piecewise compilation.
"""
custom_ops: list[str] = field(default_factory=list)
"""Fine-grained control over which custom ops to enable/disable. Use 'all'
to enable all, 'none' to disable all. Also specify a list of custom op
@@ -229,16 +236,24 @@
If empty list [], no ops are excluded (suitable for full cudagraphs)."""

# Inductor capture
use_inductor: bool = True
"""Whether to use inductor compilation:
use_inductor: bool | None = None
"""
Whether to use inductor compilation.

This flag is deprecated and will be removed in the next release 0.12.0.
Please use the 'backend' option instead.

- False: inductor compilation is not used. graph runs in eager
(custom_ops enabled by default).
- True: inductor compilation is used (custom_ops disabled by default).
One graph for symbolic shape and one graph per size in compile_sizes
are compiled using configurations in inductor_compile_config.

This setting is ignored if level<PIECEWISE."""
This setting is ignored if level<PIECEWISE.

For future compatibility:
If use_inductor is True, backend="inductor" otherwise backend="eager".
"""
compile_sizes: list[int | str] | None = None
"""Sizes to compile for inductor. In addition
to integers, it also supports "cudagraph_capture_sizes" to
@@ -545,23 +560,59 @@ def __post_init__(self) -> None:
"(where 'op' is the registered op name)"
)

# Currently only the eager and inductor backends are supported for
# piecewise compilation; custom backends are not supported for it.
# Update this check when more backends are supported.
if self.level == CompilationLevel.PIECEWISE and self.backend not in [
"",
"eager",
"inductor",
]:
raise ValueError(
f"Invalid backend for piecewise compilation: {self.backend}"
)

if self.use_inductor is not None:
logger.warning_once(
"The 'use_inductor' flag is deprecated and will be "
"removed in the next release (v0.12.0). "
"Please use the 'backend' option instead.",
)
self.backend = "inductor" if self.use_inductor else "eager"

if self.backend == "":
self.backend = current_platform.simple_compile_backend

def init_backend(self, vllm_config: "VllmConfig") -> str | Callable:
"""
Initialize the backend for the compilation config from a vllm config.
Arguments:
vllm_config: The vllm config to initialize the backend from.
Returns:
The backend for the compilation config.
"""
if self.level is None:
raise ValueError(
"No compilation level is set. This method should only be \
called via vllm config where the level is set if none is \
provided."
)
if self.level == CompilationLevel.NO_COMPILATION:
raise ValueError("No compilation level is set.")

from torch._dynamo.backends.registry import list_backends

torch_backends = list_backends(exclude_tags=tuple())
if self.level in [CompilationLevel.DYNAMO_AS_IS, CompilationLevel.DYNAMO_ONCE]:
if self.backend == "":
return "eager"
if self.backend in torch_backends:
return self.backend
return resolve_obj_by_qualname(self.backend)

# TODO: pass user-specified backend to piecewise compilation
# merge with the config use_inductor
assert self.level == CompilationLevel.PIECEWISE
if self.backend not in ["eager", "inductor"]:
raise ValueError(
f"Invalid backend for piecewise compilation: {self.backend}"
)

from vllm.compilation.backends import VllmBackend
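
And a sketch of how init_backend resolves the backend once a level is set (VllmConfig construction simplified; at PIECEWISE a VllmBackend wrapping eager or inductor is built instead of returning a plain string):

from vllm.config import CompilationConfig, CompilationLevel, VllmConfig

cfg = CompilationConfig(level=CompilationLevel.DYNAMO_ONCE, backend="eager")
vllm_cfg = VllmConfig(compilation_config=cfg)
backend = cfg.init_backend(vllm_cfg)  # the string "eager" at levels 1 and 2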

@@ -710,7 +761,9 @@ def is_attention_compiled_piecewise(self) -> bool:
return self.level == CompilationLevel.PIECEWISE

# Inductor partition case
return self.level > CompilationLevel.NO_COMPILATION and self.use_inductor
return (
self.backend == "inductor" and self.level > CompilationLevel.NO_COMPILATION
)

def custom_op_log_check(self):
"""