
Commit 8f72d11

Revert "[Frontend] CompilationConfig overhaul (#20283): deprecate use_inducto…"
This reverts commit 0c824fc. Signed-off-by: zjy0516 <riverclouds.zhu@qq.com>
1 parent cf4cd6c commit 8f72d11
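
For context, a minimal sketch of what this revert changes for users (illustrative only; it assumes `CompilationConfig` and `CompilationLevel` are importable from `vllm.config`, as the test diffs below suggest):

from vllm.config import CompilationConfig, CompilationLevel

# With this revert, the boolean `use_inductor` flag again selects Inductor vs.
# eager execution for piecewise (level 3) compilation.
config = CompilationConfig(
    level=CompilationLevel.PIECEWISE,
    use_inductor=False,  # run the piecewise graphs eagerly; custom ops default to on
    use_cudagraph=True,
    cudagraph_capture_sizes=[1, 2],
)

# The reverted overhaul expressed the same choice through the backend string,
# e.g. CompilationConfig(level=CompilationLevel.PIECEWISE, backend="eager"),
# and deprecated `use_inductor`.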

File tree

7 files changed (+63, -126 lines)


tests/compile/piecewise/test_toy_llama.py

Lines changed: 20 additions & 11 deletions
@@ -258,13 +258,13 @@ def tractable_computation(
 
 @torch.inference_mode
 def run_model(
-    llama_config, use_compile: bool, backend: str, split_attn: bool = False
+    llama_config, use_compile: bool, use_inductor: bool, split_attn: bool = False
 ) -> torch.Tensor:
     if use_compile:
         compilation_config = CompilationConfig(
             level=CompilationLevel.PIECEWISE,
             use_cudagraph=True,
-            backend=backend,
+            use_inductor=use_inductor,
             cudagraph_capture_sizes=[1, 2],
         )
         if split_attn:
@@ -338,8 +338,8 @@ def run_model(
     return output.cpu()
 
 
-@pytest.mark.parametrize("backend", ["inductor", "eager"])
-def test_toy_llama(backend: str):
+@pytest.mark.parametrize("use_inductor", [True, False])
+def test_toy_llama(use_inductor: bool):
     # compare output with and without piecewise compilation
 
     llama_config = LlamaConfig(
@@ -358,10 +358,10 @@ def test_toy_llama(backend: str):
         num_backend_compilations=0,
         num_cudagraph_captured=0,
     ):
-        outputs.append(run_model(llama_config, backend="eager", use_compile=False))
-        run_model(tractable_config, backend="eager", use_compile=False)
+        outputs.append(run_model(llama_config, use_inductor=False, use_compile=False))
+        run_model(tractable_config, use_inductor=False, use_compile=False)
 
-    if backend == "inductor":
+    if use_inductor:
        kwargs = {"num_inductor_compiles": 1, "num_eager_compiles": 0}
    else:
        kwargs = {"num_eager_compiles": 1, "num_inductor_compiles": 0}
@@ -377,8 +377,10 @@ def test_toy_llama(backend: str):
         num_cudagraph_captured=2,
         **kwargs,
     ):
-        outputs.append(run_model(llama_config, backend=backend, use_compile=True))
-        run_model(tractable_config, backend=backend, use_compile=True)
+        outputs.append(
+            run_model(llama_config, use_inductor=use_inductor, use_compile=True)
+        )
+        run_model(tractable_config, use_inductor=use_inductor, use_compile=True)
 
     with compilation_counter.expect(
         num_graphs_seen=1,  # one graph for the model
@@ -393,9 +395,16 @@ def test_toy_llama(backend: str):
         ),  # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
     ):
         outputs.append(
-            run_model(llama_config, backend=backend, use_compile=True, split_attn=True)
+            run_model(
+                llama_config,
+                use_inductor=use_inductor,
+                use_compile=True,
+                split_attn=True,
+            )
         )
-        run_model(tractable_config, backend=backend, use_compile=True, split_attn=True)
+        run_model(
+            tractable_config, use_inductor=use_inductor, use_compile=True, split_attn=True
+        )
 
     for i in range(1, len(outputs)):
         assert torch.allclose(outputs[0], outputs[i])

tests/model_executor/test_enabled_custom_ops.py

Lines changed: 19 additions & 21 deletions
@@ -37,59 +37,57 @@ class Relu3(ReLUSquaredActivation):
 
 
 @pytest.mark.parametrize(
-    "env, torch_level, backend, ops_enabled, default_on",
+    "env, torch_level, use_inductor, ops_enabled, default_on",
     [
         # Default values based on compile level
         # - All by default (no Inductor compilation)
-        (None, 0, "eager", [True] * 4, True),
-        (None, 1, "eager", [True] * 4, True),
-        (None, 2, "eager", [True] * 4, True),
-        (None, 3, "eager", [True] * 4, True),
+        (None, 0, False, [True] * 4, True),
+        (None, 1, True, [True] * 4, True),
+        (None, 2, False, [True] * 4, True),
         # - None by default (with Inductor)
-        (None, 0, "inductor", [True] * 4, True),
-        # - None by default (with Inductor)
-        (None, 1, "inductor", [False] * 4, False),
-        (None, 2, "inductor", [False] * 4, False),
-        (None, 3, "inductor", [False] * 4, False),
+        (None, 3, True, [False] * 4, False),
+        (None, 4, True, [False] * 4, False),
+        # - All by default (without Inductor)
+        (None, 3, False, [True] * 4, True),
+        (None, 4, False, [True] * 4, True),
         # Explicitly enabling/disabling
         #
         # Default: all
        #
        # All but SiluAndMul
-        ("+rms_norm,-silu_and_mul", 0, "inductor", [1, 0, 1, 1], True),
+        ("+rms_norm,-silu_and_mul", 0, True, [1, 0, 1, 1], True),
        # Only ReLU3
-        ("none,-rms_norm,+relu3", 1, "eager", [0, 0, 0, 1], False),
+        ("none,-rms_norm,+relu3", 1, False, [0, 0, 0, 1], False),
        # All but SiluAndMul
-        ("all,-silu_and_mul", 2, "inductor", [1, 0, 1, 1], True),
+        ("all,-silu_and_mul", 2, True, [1, 0, 1, 1], True),
        # All but ReLU3 (even if ReLU2 is on)
-        ("-relu3,+relu2", 3, "eager", [1, 1, 1, 0], True),
+        ("-relu3,+relu2", 3, False, [1, 1, 1, 0], True),
        # RMSNorm and SiluAndMul
-        ("none,-relu3,+rms_norm,+silu_and_mul", 3, "eager", [1, 1, 0, 0], False),
+        ("none,-relu3,+rms_norm,+silu_and_mul", 4, False, [1, 1, 0, 0], False),
        # All but RMSNorm
-        ("-rms_norm", 3, "eager", [0, 1, 1, 1], True),
+        ("-rms_norm", 3, False, [0, 1, 1, 1], True),
        #
        # Default: none
        #
        # Only ReLU3
-        ("none,+relu3", 3, "inductor", [0, 0, 0, 1], False),
+        ("-silu_and_mul,+relu3", 3, True, [0, 0, 0, 1], False),
        # All but RMSNorm
-        ("all,-rms_norm", 3, "inductor", [0, 1, 1, 1], True),
+        ("all,-rms_norm", 4, True, [0, 1, 1, 1], True),
     ],
 )
 def test_enabled_ops(
     env: Optional[str],
     torch_level: int,
-    backend: str,
+    use_inductor: bool,
     ops_enabled: list[int],
     default_on: bool,
 ):
     custom_ops = env.split(",") if env else []
     vllm_config = VllmConfig(
         compilation_config=CompilationConfig(
-            backend=backend, level=torch_level, custom_ops=custom_ops
+            use_inductor=bool(use_inductor), level=torch_level, custom_ops=custom_ops
         )
     )
-    # breakpoint()
     with set_current_vllm_config(vllm_config):
         assert CustomOp.default_on() == default_on
 

vllm/compilation/backends.py

Lines changed: 1 addition & 5 deletions
@@ -34,7 +34,7 @@
 
 
 def make_compiler(compilation_config: CompilationConfig) -> CompilerInterface:
-    if compilation_config.backend == "inductor":
+    if compilation_config.use_inductor:
         # Use standalone compile only if requested, version is new enough,
         # and the symbol actually exists in this PyTorch build.
         if (
@@ -48,10 +48,6 @@ def make_compiler(compilation_config: CompilationConfig) -> CompilerInterface:
         logger.debug("Using InductorAdaptor")
         return InductorAdaptor()
     else:
-        assert compilation_config.backend == "eager", (
-            "Custom backends not supported with CompilationLevel.PIECEWISE"
-        )
-
         logger.debug("Using EagerAdaptor")
         return EagerAdaptor()
 
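
A small usage sketch of the dispatch this hunk restores (hypothetical and standalone, not part of the commit; it assumes `make_compiler` is importable from the module above and that a `CompilationConfig` can be built as in the other diffs):

from vllm.compilation.backends import make_compiler
from vllm.config import CompilationConfig, CompilationLevel

# use_inductor=True  -> an Inductor-based adaptor is expected to be returned;
# use_inductor=False -> the EagerAdaptor is expected to be returned.
inductor_compiler = make_compiler(
    CompilationConfig(level=CompilationLevel.PIECEWISE, use_inductor=True)
)
eager_compiler = make_compiler(
    CompilationConfig(level=CompilationLevel.PIECEWISE, use_inductor=False)
)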

vllm/config/compilation.py

Lines changed: 11 additions & 60 deletions
@@ -180,11 +180,10 @@ class CompilationConfig:
     """The directory to store the compiled graph, to accelerate Inductor
     compilation. By default, it will use model-related information to generate
     a cache directory."""
-    backend: str = "inductor"
+    backend: str = ""
     """The backend for compilation. It needs to be a string:
 
-    - "" (empty string): use the default backend ("inductor" on CUDA-alike
-    platforms).
+    - "" (empty string): use the default backend.
     - "eager"/"openxla"/...: use the specified backend registered in PyTorch.
     - "full.module.name": a qualified name which can be used to import the
@@ -193,11 +192,7 @@ class CompilationConfig:
     distributed setting. When the compilation level is 1 or 2, the backend is
     used for the compilation directly (it sees the whole graph). When the
     compilation level is 3, the backend is used for the piecewise compilation
-    (it sees a part of the graph). The backend can not be custom for compilation
-    level 3. Furthermore, compilation is only piecewise if splitting ops is set
-    accordingly and use_inductor_cudagraphs_partition is off. Note that the
-    default options for splitting ops are sufficient for piecewise compilation.
-    """
+    (it sees a part of the graph)."""
     custom_ops: list[str] = field(default_factory=list)
     """Fine-grained control over which custom ops to enable/disable. Use 'all'
     to enable all, 'none' to disable all. Also specify a list of custom op
@@ -215,24 +210,16 @@ class CompilationConfig:
     compilation."""
 
     # Inductor capture
-    use_inductor: Optional[bool] = None
-    """
-    Whether to use inductor compilation.
-
-    This flag is deprecated and will be removed.
-    Please use the 'backend' option instead.
+    use_inductor: bool = True
+    """Whether to use inductor compilation:
 
     - False: inductor compilation is not used. graph runs in eager
     (custom_ops enabled by default).
     - True: inductor compilation is used (custom_ops disabled by default).
     One graph for symbolic shape and one graph per size in compile_sizes
     are compiled using configurations in inductor_compile_config.
 
-    This setting is ignored if level<PIECEWISE.
-
-    For future compatibility:
-    If use_inductor is True, backend="inductor" otherwise backend="eager".
-    """
+    This setting is ignored if level<PIECEWISE."""
     compile_sizes: Optional[list[Union[int, str]]] = None
     """Sizes to compile for inductor. In addition
     to integers, it also supports "cudagraph_capture_sizes" to
@@ -538,59 +525,23 @@ def __post_init__(self) -> None:
                 "(where 'op' is the registered op name)"
             )
 
-        # Currently only eager and inductor backend are supported.
-        # for piecewise compilation. Custom backends are not suppported for
-        # piecewise compilation. Update when more backends are supported.
-        if self.level == CompilationLevel.PIECEWISE and self.backend not in [
-            "",
-            "eager",
-            "inductor",
-        ]:
-            raise ValueError(
-                f"Invalid backend for piecewise compilation: {self.backend}"
-            )
-
-        if self.use_inductor is not None:
-            logger.warning_once(
-                "The 'use_inductor' flag is deprecated and will be\
-                removed in a future release."
-                "Please use the 'backend' option instead.",
-            )
-            self.backend = "inductor" if self.use_inductor else "eager"
-
-        if self.backend == "":
-            self.backend = "inductor"
-
     def init_backend(self, vllm_config: "VllmConfig") -> Union[str, Callable]:
-        """
-        Initialize the backend for the compilation config from a vllm config.
-        Arguments:
-            vllm_config: The vllm config to initialize the backend from.
-        Returns:
-            The backend for the compilation config.
-        """
-        if self.level is None:
-            raise ValueError(
-                "No compilation level is set. This method should only be \
-                called via vllm config where the level is set if none is \
-                provided."
-            )
         if self.level == CompilationLevel.NO_COMPILATION:
             raise ValueError("No compilation level is set.")
 
         from torch._dynamo.backends.registry import list_backends
 
         torch_backends = list_backends(exclude_tags=tuple())
         if self.level in [CompilationLevel.DYNAMO_AS_IS, CompilationLevel.DYNAMO_ONCE]:
+            if self.backend == "":
+                return "eager"
             if self.backend in torch_backends:
                 return self.backend
             return resolve_obj_by_qualname(self.backend)
 
+        # TODO: pass user-specified backend to piecewise compilation
+        # merge with the config use_inductor
         assert self.level == CompilationLevel.PIECEWISE
-        if self.backend not in ["eager", "inductor"]:
-            raise ValueError(
-                f"Invalid backend for piecewise compilation: {self.backend}"
-            )
 
         from vllm.compilation.backends import VllmBackend
 
@@ -743,7 +694,7 @@ def is_attention_compiled_piecewise(self) -> bool:
         )
 
         inductor_used = (
-            self.level == CompilationLevel.PIECEWISE and self.backend == "inductor"
+            self.level == CompilationLevel.PIECEWISE and self.use_inductor
         ) or (
            self.level >= CompilationLevel.DYNAMO_AS_IS and self.backend == "inductor"
        )
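
To summarize the restored semantics in one place, here is an illustrative snippet (a sketch assuming the fields documented above; not part of the commit):

from vllm.config import CompilationConfig, CompilationLevel

# Piecewise compilation with Inductor: custom ops are disabled by default.
with_inductor = CompilationConfig(level=CompilationLevel.PIECEWISE, use_inductor=True)

# Piecewise compilation without Inductor: the split graphs run eagerly and
# custom ops are enabled by default.
without_inductor = CompilationConfig(level=CompilationLevel.PIECEWISE, use_inductor=False)

# For DYNAMO_AS_IS / DYNAMO_ONCE, `backend` still selects the torch.compile
# backend; with this revert an empty string resolves to "eager" in init_backend().
dynamo = CompilationConfig(level=CompilationLevel.DYNAMO_ONCE, backend="eager")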

vllm/config/vllm.py

Lines changed: 0 additions & 19 deletions
@@ -318,25 +318,6 @@ def __post_init__(self):
             # NB: Passing both --enforce-eager and a compilation level
             # in V0 means the compilation level wins out.
             self.compilation_config.level = CompilationLevel.NO_COMPILATION
-        else:
-            assert self.compilation_config.level >= CompilationLevel.NO_COMPILATION
-            assert self.compilation_config.level <= CompilationLevel.PIECEWISE
-            assert self.compilation_config.level <= 3
-
-        # If user does not set custom ops via none or all set it here based on
-        # compilation level and backend.
-        if (
-            self.compilation_config.custom_ops.count("none")
-            + self.compilation_config.custom_ops.count("all")
-            == 0
-        ):
-            if (
-                self.compilation_config.level > 0
-                and self.compilation_config.backend != "eager"
-            ):
-                self.compilation_config.custom_ops.append("none")
-            else:
-                self.compilation_config.custom_ops.append("all")
 
         # async tp is built on top of sequence parallelism
         # and requires it to be enabled.

vllm/model_executor/custom_op.py

Lines changed: 10 additions & 10 deletions
@@ -114,9 +114,7 @@ def enabled(cls) -> bool:
         custom_ops = compilation_config.custom_ops
         if not hasattr(cls, "name"):
             logger.warning_once(
-                "Custom op %s was not registered, which means it won't appear\
-                in the op registry. It will be enabled/disabled based on the\
-                global settings.", # noqa: E501
+                "Custom op %s was not registered, which means it won't appear in the op registry. It will be enabled/disabled based on the global settings.", # noqa: E501
                 cls.__name__,
             )
         return CustomOp.default_on()
@@ -130,17 +128,19 @@ def enabled(cls) -> bool:
     @staticmethod
     def default_on() -> bool:
         """
-        Behavior controlled by `CompilationConfig.custom_ops`: On by default if
-        'all', off by default if 'none'.
-        When PyTorch Inductor is used, 'none' is the default value,
-        otherwise 'all'.
+        On by default if PyTorch Inductor is not used.
+        Specifying 'all' or 'none' in custom_op takes precedence.
         """
+        from vllm.config import CompilationLevel
+
         compilation_config = get_cached_compilation_config()
+        default_on = (
+            compilation_config.level < CompilationLevel.PIECEWISE
+            or not compilation_config.use_inductor
+        )
         count_none = compilation_config.custom_ops.count("none")
         count_all = compilation_config.custom_ops.count("all")
-        assert count_none + count_all == 1
-
-        return not count_none > 0 or count_all > 0
+        return default_on and not count_none > 0 or count_all > 0
 
     # Dictionary of all custom ops (classes, indexed by registered name).
     # To check if an op with a name is enabled, call .enabled() on the class.
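
The restored `default_on()` return expression relies on Python operator precedence (`and` binds tighter than `or`). A standalone re-implementation for illustration (hypothetical helper, not vLLM code; the numeric value of PIECEWISE is assumed):

def default_on_sketch(level: int, use_inductor: bool, custom_ops: list[str]) -> bool:
    PIECEWISE = 3  # assumed numeric value of CompilationLevel.PIECEWISE
    default_on = level < PIECEWISE or not use_inductor
    count_none = custom_ops.count("none")
    count_all = custom_ops.count("all")
    # Same expression as the restored return statement, i.e.
    # (default_on and not count_none > 0) or (count_all > 0)
    return default_on and not count_none > 0 or count_all > 0

assert default_on_sketch(3, True, []) is False        # Inductor piecewise: off by default
assert default_on_sketch(3, False, []) is True        # eager piecewise: on by default
assert default_on_sketch(3, True, ["all"]) is True    # explicit "all" wins
assert default_on_sketch(2, True, ["none"]) is False  # explicit "none" wins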

vllm/platforms/cpu.py

Lines changed: 2 additions & 0 deletions
@@ -274,6 +274,8 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
                 "epilogue_fusion": True,
             }
         )
+        if compilation_config.use_inductor:
+            compilation_config.custom_ops = ["none"]
 
         if vllm_config.lora_config is not None:
             compilation_config.level = CompilationLevel.NO_COMPILATION
