31 changes: 20 additions & 11 deletions tests/compile/piecewise/test_toy_llama.py
@@ -258,13 +258,13 @@ def tractable_computation(

@torch.inference_mode
def run_model(
llama_config, use_compile: bool, backend: str, split_attn: bool = False
llama_config, use_compile: bool, use_inductor: bool, split_attn: bool = False
) -> torch.Tensor:
if use_compile:
compilation_config = CompilationConfig(
level=CompilationLevel.PIECEWISE,
use_cudagraph=True,
backend=backend,
use_inductor=use_inductor,
cudagraph_capture_sizes=[1, 2],
)
if split_attn:
@@ -338,8 +338,8 @@ def run_model(
return output.cpu()


@pytest.mark.parametrize("backend", ["inductor", "eager"])
def test_toy_llama(backend: str):
@pytest.mark.parametrize("use_inductor", [True, False])
def test_toy_llama(use_inductor: bool):
# compare output with and without piecewise compilation

llama_config = LlamaConfig(
@@ -358,10 +358,10 @@ def test_toy_llama(backend: str):
num_backend_compilations=0,
num_cudagraph_captured=0,
):
outputs.append(run_model(llama_config, backend="eager", use_compile=False))
run_model(tractable_config, backend="eager", use_compile=False)
outputs.append(run_model(llama_config, use_inductor=False, use_compile=False))
run_model(tractable_config, use_inductor=False, use_compile=False)

if backend == "inductor":
if use_inductor:
kwargs = {"num_inductor_compiles": 1, "num_eager_compiles": 0}
else:
kwargs = {"num_eager_compiles": 1, "num_inductor_compiles": 0}
@@ -377,8 +377,10 @@ def test_toy_llama(backend: str):
num_cudagraph_captured=2,
**kwargs,
):
outputs.append(run_model(llama_config, backend=backend, use_compile=True))
run_model(tractable_config, backend=backend, use_compile=True)
outputs.append(
run_model(llama_config, use_inductor=use_inductor, use_compile=True)
)
run_model(tractable_config, use_inductor=use_inductor, use_compile=True)

with compilation_counter.expect(
num_graphs_seen=1, # one graph for the model
@@ -393,9 +395,16 @@
), # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
):
outputs.append(
run_model(llama_config, backend=backend, use_compile=True, split_attn=True)
run_model(
llama_config,
use_inductor=use_inductor,
use_compile=True,
split_attn=True,
)
)
run_model(tractable_config, backend=backend, use_compile=True, split_attn=True)
run_model(
tractable_config, use_inductor=use_inductor, use_compile=True, split_attn=True
)

for i in range(1, len(outputs)):
assert torch.allclose(outputs[0], outputs[i])
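For readers skimming the hunks above: the test now toggles a boolean `use_inductor` instead of selecting a backend string. A minimal sketch of the two configurations the parametrized test ends up building, assuming the post-change `CompilationConfig` fields shown in this diff:

```python
# Illustrative sketch only (not part of the PR): the two piecewise-compile
# configurations exercised by test_toy_llama after this change.
from vllm.config import CompilationConfig, CompilationLevel

inductor_cfg = CompilationConfig(
    level=CompilationLevel.PIECEWISE,
    use_cudagraph=True,
    use_inductor=True,   # previously backend="inductor"
    cudagraph_capture_sizes=[1, 2],
)

eager_cfg = CompilationConfig(
    level=CompilationLevel.PIECEWISE,
    use_cudagraph=True,
    use_inductor=False,  # previously backend="eager"
    cudagraph_capture_sizes=[1, 2],
)
```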
40 changes: 19 additions & 21 deletions tests/model_executor/test_enabled_custom_ops.py
@@ -37,59 +37,57 @@ class Relu3(ReLUSquaredActivation):


@pytest.mark.parametrize(
"env, torch_level, backend, ops_enabled, default_on",
"env, torch_level, use_inductor, ops_enabled, default_on",
[
# Default values based on compile level
# - All by default (no Inductor compilation)
(None, 0, "eager", [True] * 4, True),
(None, 1, "eager", [True] * 4, True),
(None, 2, "eager", [True] * 4, True),
(None, 3, "eager", [True] * 4, True),
(None, 0, False, [True] * 4, True),
(None, 1, True, [True] * 4, True),
(None, 2, False, [True] * 4, True),
# - None by default (with Inductor)
(None, 0, "inductor", [True] * 4, True),
# - None by default (with Inductor)
(None, 1, "inductor", [False] * 4, False),
(None, 2, "inductor", [False] * 4, False),
(None, 3, "inductor", [False] * 4, False),
(None, 3, True, [False] * 4, False),
(None, 4, True, [False] * 4, False),
# - All by default (without Inductor)
(None, 3, False, [True] * 4, True),
(None, 4, False, [True] * 4, True),
# Explicitly enabling/disabling
#
# Default: all
#
# All but SiluAndMul
("+rms_norm,-silu_and_mul", 0, "inductor", [1, 0, 1, 1], True),
("+rms_norm,-silu_and_mul", 0, True, [1, 0, 1, 1], True),
# Only ReLU3
("none,-rms_norm,+relu3", 1, "eager", [0, 0, 0, 1], False),
("none,-rms_norm,+relu3", 1, False, [0, 0, 0, 1], False),
# All but SiluAndMul
("all,-silu_and_mul", 2, "inductor", [1, 0, 1, 1], True),
("all,-silu_and_mul", 2, True, [1, 0, 1, 1], True),
# All but ReLU3 (even if ReLU2 is on)
("-relu3,+relu2", 3, "eager", [1, 1, 1, 0], True),
("-relu3,+relu2", 3, False, [1, 1, 1, 0], True),
# RMSNorm and SiluAndMul
("none,-relu3,+rms_norm,+silu_and_mul", 3, "eager", [1, 1, 0, 0], False),
("none,-relu3,+rms_norm,+silu_and_mul", 4, False, [1, 1, 0, 0], False),
# All but RMSNorm
("-rms_norm", 3, "eager", [0, 1, 1, 1], True),
("-rms_norm", 3, False, [0, 1, 1, 1], True),
#
# Default: none
#
# Only ReLU3
("none,+relu3", 3, "inductor", [0, 0, 0, 1], False),
("-silu_and_mul,+relu3", 3, True, [0, 0, 0, 1], False),
# All but RMSNorm
("all,-rms_norm", 3, "inductor", [0, 1, 1, 1], True),
("all,-rms_norm", 4, True, [0, 1, 1, 1], True),
],
)
def test_enabled_ops(
env: Optional[str],
torch_level: int,
backend: str,
use_inductor: bool,
ops_enabled: list[int],
default_on: bool,
):
custom_ops = env.split(",") if env else []
vllm_config = VllmConfig(
compilation_config=CompilationConfig(
backend=backend, level=torch_level, custom_ops=custom_ops
use_inductor=bool(use_inductor), level=torch_level, custom_ops=custom_ops
)
)
# breakpoint()
with set_current_vllm_config(vllm_config):
assert CustomOp.default_on() == default_on

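The parametrization above boils down to one rule for when custom ops are enabled by default. A hedged sketch of that rule as a standalone predicate (the helper name is illustrative, and it assumes `CompilationLevel.PIECEWISE == 3` as the table implies; the real logic lives in `CustomOp.default_on`, changed later in this PR):

```python
# Illustrative helper, not part of vLLM: the expected "default_on" value
# for each (torch_level, use_inductor) pair in the table above.
def expected_default_on(torch_level: int, use_inductor: bool) -> bool:
    PIECEWISE = 3  # assumed value of CompilationLevel.PIECEWISE
    # Custom ops stay on by default unless Inductor compiles the piecewise graph.
    return torch_level < PIECEWISE or not use_inductor

assert expected_default_on(2, True)       # below PIECEWISE: on
assert not expected_default_on(3, True)   # PIECEWISE + Inductor: off
assert expected_default_on(4, False)      # no Inductor: on
```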
6 changes: 1 addition & 5 deletions vllm/compilation/backends.py
@@ -34,7 +34,7 @@


def make_compiler(compilation_config: CompilationConfig) -> CompilerInterface:
if compilation_config.backend == "inductor":
if compilation_config.use_inductor:
# Use standalone compile only if requested, version is new enough,
# and the symbol actually exists in this PyTorch build.
if (
@@ -48,10 +48,6 @@ def make_compiler(compilation_config: CompilationConfig) -> CompilerInterface:
logger.debug("Using InductorAdaptor")
return InductorAdaptor()
else:
assert compilation_config.backend == "eager", (
"Custom backends not supported with CompilationLevel.PIECEWISE"
)

logger.debug("Using EagerAdaptor")
return EagerAdaptor()

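A brief, hedged usage sketch of the simplified selection above (import paths are assumptions based on where the diff places `make_compiler`; the standalone-compile branch is ignored here):

```python
# Illustrative only: compiler selection now keys off the boolean flag rather
# than a backend string, so no assertion on "eager" is needed.
from vllm.compilation.backends import make_compiler
from vllm.config import CompilationConfig, CompilationLevel

cfg = CompilationConfig(level=CompilationLevel.PIECEWISE, use_inductor=False)
compiler = make_compiler(cfg)  # expected to be an EagerAdaptor instance
```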
71 changes: 11 additions & 60 deletions vllm/config/compilation.py
@@ -180,11 +180,10 @@ class CompilationConfig:
"""The directory to store the compiled graph, to accelerate Inductor
compilation. By default, it will use model-related information to generate
a cache directory."""
backend: str = "inductor"
backend: str = ""
"""The backend for compilation. It needs to be a string:

- "" (empty string): use the default backend ("inductor" on CUDA-alike
platforms).
- "" (empty string): use the default backend.
- "eager"/"openxla"/...: use the specified backend registered in PyTorch.
- "full.module.name": a qualified name which can be used to import the

@@ -193,11 +192,7 @@
distributed setting. When the compilation level is 1 or 2, the backend is
used for the compilation directly (it sees the whole graph). When the
compilation level is 3, the backend is used for the piecewise compilation
(it sees a part of the graph). The backend can not be custom for compilation
level 3. Furthermore, compilation is only piecewise if splitting ops is set
accordingly and use_inductor_cudagraphs_partition is off. Note that the
default options for splitting ops are sufficient for piecewise compilation.
"""
(it sees a part of the graph)."""
custom_ops: list[str] = field(default_factory=list)
"""Fine-grained control over which custom ops to enable/disable. Use 'all'
to enable all, 'none' to disable all. Also specify a list of custom op
@@ -215,24 +210,16 @@
compilation."""

# Inductor capture
use_inductor: Optional[bool] = None
"""
Whether to use inductor compilation.

This flag is deprecated and will be removed.
Please use the 'backend' option instead.
use_inductor: bool = True
"""Whether to use inductor compilation:

- False: inductor compilation is not used. graph runs in eager
(custom_ops enabled by default).
- True: inductor compilation is used (custom_ops disabled by default).
One graph for symbolic shape and one graph per size in compile_sizes
are compiled using configurations in inductor_compile_config.

This setting is ignored if level<PIECEWISE.

For future compatibility:
If use_inductor is True, backend="inductor" otherwise backend="eager".
"""
This setting is ignored if level<PIECEWISE."""
compile_sizes: Optional[list[Union[int, str]]] = None
"""Sizes to compile for inductor. In addition
to integers, it also supports "cudagraph_capture_sizes" to
@@ -538,59 +525,23 @@ def __post_init__(self) -> None:
"(where 'op' is the registered op name)"
)

# Currently only eager and inductor backend are supported.
# for piecewise compilation. Custom backends are not suppported for
# piecewise compilation. Update when more backends are supported.
if self.level == CompilationLevel.PIECEWISE and self.backend not in [
"",
"eager",
"inductor",
]:
raise ValueError(
f"Invalid backend for piecewise compilation: {self.backend}"
)

if self.use_inductor is not None:
logger.warning_once(
"The 'use_inductor' flag is deprecated and will be\
removed in a future release."
"Please use the 'backend' option instead.",
)
self.backend = "inductor" if self.use_inductor else "eager"

if self.backend == "":
self.backend = "inductor"

def init_backend(self, vllm_config: "VllmConfig") -> Union[str, Callable]:
"""
Initialize the backend for the compilation config from a vllm config.
Arguments:
vllm_config: The vllm config to initialize the backend from.
Returns:
The backend for the compilation config.
"""
if self.level is None:
raise ValueError(
"No compilation level is set. This method should only be \
called via vllm config where the level is set if none is \
provided."
)
if self.level == CompilationLevel.NO_COMPILATION:
raise ValueError("No compilation level is set.")

from torch._dynamo.backends.registry import list_backends

torch_backends = list_backends(exclude_tags=tuple())
if self.level in [CompilationLevel.DYNAMO_AS_IS, CompilationLevel.DYNAMO_ONCE]:
if self.backend == "":
return "eager"
if self.backend in torch_backends:
return self.backend
return resolve_obj_by_qualname(self.backend)

# TODO: pass user-specified backend to piecewise compilation
# merge with the config use_inductor
assert self.level == CompilationLevel.PIECEWISE
if self.backend not in ["eager", "inductor"]:
raise ValueError(
f"Invalid backend for piecewise compilation: {self.backend}"
)

from vllm.compilation.backends import VllmBackend

@@ -743,7 +694,7 @@ def is_attention_compiled_piecewise(self) -> bool:
)

inductor_used = (
self.level == CompilationLevel.PIECEWISE and self.backend == "inductor"
self.level == CompilationLevel.PIECEWISE and self.use_inductor
) or (
self.level >= CompilationLevel.DYNAMO_AS_IS and self.backend == "inductor"
)
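Taken together with the restored `use_inductor` docstring above, a hedged configuration sketch (option values are illustrative examples, not recommendations from this PR):

```python
# Illustrative only: with use_inductor=True at PIECEWISE level, one
# symbolic-shape graph plus one graph per entry in compile_sizes is
# compiled through Inductor using inductor_compile_config.
from vllm.config import CompilationConfig, CompilationLevel

cfg = CompilationConfig(
    level=CompilationLevel.PIECEWISE,
    use_inductor=True,
    compile_sizes=[1, 8, "cudagraph_capture_sizes"],
    inductor_compile_config={"max_autotune": True},  # illustrative Inductor option
)
```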
19 changes: 0 additions & 19 deletions vllm/config/vllm.py
@@ -318,25 +318,6 @@ def __post_init__(self):
# NB: Passing both --enforce-eager and a compilation level
# in V0 means the compilation level wins out.
self.compilation_config.level = CompilationLevel.NO_COMPILATION
else:
assert self.compilation_config.level >= CompilationLevel.NO_COMPILATION
assert self.compilation_config.level <= CompilationLevel.PIECEWISE
assert self.compilation_config.level <= 3

# If user does not set custom ops via none or all set it here based on
# compilation level and backend.
if (
self.compilation_config.custom_ops.count("none")
+ self.compilation_config.custom_ops.count("all")
== 0
):
if (
self.compilation_config.level > 0
and self.compilation_config.backend != "eager"
):
self.compilation_config.custom_ops.append("none")
else:
self.compilation_config.custom_ops.append("all")

# async tp is built on top of sequence parallelism
# and requires it to be enabled.
20 changes: 10 additions & 10 deletions vllm/model_executor/custom_op.py
@@ -114,9 +114,7 @@ def enabled(cls) -> bool:
custom_ops = compilation_config.custom_ops
if not hasattr(cls, "name"):
logger.warning_once(
"Custom op %s was not registered, which means it won't appear\
in the op registry. It will be enabled/disabled based on the\
global settings.", # noqa: E501
"Custom op %s was not registered, which means it won't appear in the op registry. It will be enabled/disabled based on the global settings.", # noqa: E501
cls.__name__,
)
return CustomOp.default_on()
@@ -130,17 +128,19 @@ def enabled(cls) -> bool:
@staticmethod
def default_on() -> bool:
"""
Behavior controlled by `CompilationConfig.custom_ops`: On by default if
'all', off by default if 'none'.
When PyTorch Inductor is used, 'none' is the default value,
otherwise 'all'.
On by default if PyTorch Inductor is not used.
Specifying 'all' or 'none' in custom_op takes precedence.
"""
from vllm.config import CompilationLevel

compilation_config = get_cached_compilation_config()
default_on = (
compilation_config.level < CompilationLevel.PIECEWISE
or not compilation_config.use_inductor
)
count_none = compilation_config.custom_ops.count("none")
count_all = compilation_config.custom_ops.count("all")
assert count_none + count_all == 1

return not count_none > 0 or count_all > 0
return default_on and not count_none > 0 or count_all > 0

# Dictionary of all custom ops (classes, indexed by registered name).
# To check if an op with a name is enabled, call .enabled() on the class.
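A hedged walk-through of the restored `default_on` logic above (a standalone mirror of the diff; the function name, the assumed PIECEWISE value, and the example inputs are illustrative):

```python
# Illustrative mirror of CustomOp.default_on() after this change.
# Note Python precedence: `a and not b or c` parses as `(a and (not b)) or c`,
# so "none" in custom_ops forces the default off and "all" forces it on.
def default_on_sketch(level: int, use_inductor: bool, custom_ops: list[str]) -> bool:
    PIECEWISE = 3  # assumed value of CompilationLevel.PIECEWISE
    default_on = level < PIECEWISE or not use_inductor
    count_none = custom_ops.count("none")
    count_all = custom_ops.count("all")
    return default_on and not count_none > 0 or count_all > 0

assert default_on_sketch(3, True, []) is False      # Inductor piecewise: off
assert default_on_sketch(3, False, []) is True      # eager piecewise: on
assert default_on_sketch(3, True, ["all"]) is True  # explicit "all" overrides
```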
2 changes: 2 additions & 0 deletions vllm/platforms/cpu.py
@@ -274,6 +274,8 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
"epilogue_fusion": True,
}
)
if compilation_config.use_inductor:
compilation_config.custom_ops = ["none"]

Comment on lines +277 to 279 (Contributor review, severity: critical):

The logic to disable custom ops is based on compilation_config.use_inductor, which defaults to True and is not updated based on the backend variable set earlier. This causes custom ops to be incorrectly disabled even when the backend is 'eager', which should support them.

To fix this, use_inductor should be explicitly set based on the chosen backend. This ensures that custom ops are only disabled when inductor is the backend, and use_inductor is consistent for any other logic that might depend on it.

Suggested change
if compilation_config.use_inductor:
compilation_config.custom_ops = ["none"]
compilation_config.use_inductor = (backend == "inductor")
if compilation_config.use_inductor:
compilation_config.custom_ops = ["none"]

if vllm_config.lora_config is not None:
compilation_config.level = CompilationLevel.NO_COMPILATION