
Commit 3373dd7

toy llama tests passing
Signed-off-by: morrison-turnansky <mturnans@redhat.com>
1 parent 40c8d5a commit 3373dd7

3 files changed: +40 −31 lines


tests/compile/piecewise/test_toy_llama.py

Lines changed: 11 additions & 13 deletions
@@ -241,14 +241,14 @@ def tractable_computation(input_ids: torch.Tensor,
 @torch.inference_mode
 def run_model(llama_config,
               use_compile: bool,
-              use_inductor: bool,
+              backend: str,
               split_attn: bool = False) -> torch.Tensor:

     if use_compile:
         compilation_config = CompilationConfig(
             level=CompilationLevel.PIECEWISE,
             use_cudagraph=True,
-            use_inductor=use_inductor,
+            backend=backend,
             cudagraph_capture_sizes=[1, 2],
         )
         if split_attn:
@@ -310,8 +310,8 @@ def run_model(llama_config,
     return output.cpu()


-@pytest.mark.parametrize("use_inductor", [True, False])
-def test_toy_llama(use_inductor: bool):
+@pytest.mark.parametrize("backend", ["inductor", "eager"])
+def test_toy_llama(backend: str):
     # compare output with and without piecewise compilation

     llama_config = LlamaConfig(hidden_size=128,
@@ -334,10 +334,10 @@ def test_toy_llama(use_inductor: bool):
             num_cudagraph_captured=0,
     ):
         outputs.append(
-            run_model(llama_config, use_inductor=False, use_compile=False))
-    run_model(tractable_config, use_inductor=False, use_compile=False)
+            run_model(llama_config, backend="eager", use_compile=False))
+    run_model(tractable_config, backend="eager", use_compile=False)

-    if use_inductor:
+    if backend == "inductor":
         kwargs = {"num_inductor_compiles": 1, "num_eager_compiles": 0}
     else:
         kwargs = {"num_eager_compiles": 1, "num_inductor_compiles": 0}
@@ -352,10 +352,8 @@ def test_toy_llama(use_inductor: bool):
             **kwargs,
     ):
         outputs.append(
-            run_model(llama_config,
-                      use_inductor=use_inductor,
-                      use_compile=True))
-    run_model(tractable_config, use_inductor=use_inductor, use_compile=True)
+            run_model(llama_config, backend=backend, use_compile=True))
+    run_model(tractable_config, backend=backend, use_compile=True)

     with compilation_counter.expect(
             num_graphs_seen=1,  # one graph for the model
@@ -371,11 +369,11 @@ def test_toy_llama(use_inductor: bool):
     ):
         outputs.append(
             run_model(llama_config,
-                      use_inductor=use_inductor,
+                      backend=backend,
                       use_compile=True,
                       split_attn=True))
     run_model(tractable_config,
-              use_inductor=use_inductor,
+              backend=backend,
               use_compile=True,
               split_attn=True)
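
The change above is a mechanical migration: the boolean use_inductor argument becomes a string backend argument, with True mapping to "inductor" and False to "eager". A minimal sketch of that mapping (the helper below is illustrative, not part of the commit):

# Hypothetical shim showing how the deprecated boolean flag maps onto
# the new string-valued "backend" option used by run_model above.
def backend_from_use_inductor(use_inductor: bool) -> str:
    return "inductor" if use_inductor else "eager"

# old: run_model(config, use_inductor=True, use_compile=True)
# new: run_model(config, backend="inductor", use_compile=True)
assert backend_from_use_inductor(True) == "inductor"
assert backend_from_use_inductor(False) == "eager"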

tests/model_executor/test_enabled_custom_ops.py

Lines changed: 4 additions & 5 deletions
@@ -57,7 +57,8 @@ class Relu3(ReLUSquaredActivation):
     # All but ReLU3 (even if ReLU2 is on)
     ("-relu3,+relu2", 3, "eager", [1, 1, 1, 0], True),
     # RMSNorm and SiluAndMul
-    ("none,-relu3,+rms_norm,+silu_and_mul", 4, "eager", [1, 1, 0, 0], False),
+    ("none,-relu3,+rms_norm,+silu_and_mul", 4, "eager", [1, 1, 0, 0], False
+     ),
     # All but RMSNorm
     ("-rms_norm", 3, "eager", [0, 1, 1, 1], True),
     #
@@ -71,10 +72,8 @@ class Relu3(ReLUSquaredActivation):
 def test_enabled_ops(env: Optional[str], torch_level: int, backend: str,
                      ops_enabled: list[int], default_on: bool):
     custom_ops = env.split(',') if env else []
-    vllm_config = VllmConfig(
-        compilation_config=CompilationConfig(backend=backend,
-                                             level=torch_level,
-                                             custom_ops=custom_ops))
+    vllm_config = VllmConfig(compilation_config=CompilationConfig(
+        backend=backend, level=torch_level, custom_ops=custom_ops))
     with set_current_vllm_config(vllm_config):

         assert CustomOp.default_on() == default_on
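
For context, the env strings in the parametrization above ("none,-relu3,+rms_norm,+silu_and_mul", etc.) are comma-separated custom_ops tokens; per the validation in compilation.py below, each token must be 'all', 'none', '+op' or '-op'. A standalone sketch of that token check (illustrative only; vllm's actual parsing lives in CompilationConfig):

# Illustrative restatement of the custom_ops token rule; not vllm's code.
def validate_custom_ops(spec: str) -> list[str]:
    tokens = spec.split(',') if spec else []
    for tok in tokens:
        if tok in ("all", "none"):
            continue
        if tok[:1] in ("+", "-") and len(tok) > 1:
            continue  # '+op' or '-op' with a registered op name
        raise ValueError("must be 'all', 'none', '+op' or '-op' "
                         "(where 'op' is the registered op name)")
    return tokens

validate_custom_ops("none,-relu3,+rms_norm,+silu_and_mul")  # ok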

vllm/config/compilation.py

Lines changed: 25 additions & 13 deletions
@@ -7,6 +7,7 @@
 from dataclasses import asdict, field
 from pathlib import Path
 from typing import TYPE_CHECKING, Any, Callable, ClassVar, Optional, Union
+
 from pydantic import TypeAdapter, field_validator
 from pydantic.dataclasses import dataclass

@@ -212,7 +213,8 @@ class CompilationConfig:
     """
     Whether to use inductor compilation.

-    This flag is deprecated and will be removed. Please use the 'backend' option instead.
+    This flag is deprecated and will be removed.
+    Please use the 'backend' option instead.

     - False: inductor compilation is not used. graph runs in eager
       (custom_ops enabled by default).
@@ -514,17 +516,22 @@ def __post_init__(self, **kwargs) -> None:
                 "must be 'all', 'none', '+op' or '-op' "
                 "(where 'op' is the registered op name)")

-        # Currently only eager and inductor backends are supported for piecewise compilation.
-        # Update when more backends are supported.
-        if self.level == CompilationLevel.PIECEWISE and self.backend not in ["", "eager", "inductor"]:
-            raise ValueError(f"Invalid backend for piecewise compilation: {self.backend}")
+        # Currently only eager and inductor backends are supported
+        # for piecewise compilation. Update when more backends are supported.
+        if self.level == CompilationLevel.PIECEWISE and self.backend not in [
+                "", "eager", "inductor"
+        ]:
+            raise ValueError(
+                f"Invalid backend for piecewise compilation: {self.backend}")
+
+        if self.backend == "":
+            self.backend = "inductor"

         logger.warning_once(
-            "The 'use_inductor' flag is deprecated and will be removed in a future release. "
-            "Please use the 'backend' option instead.",
-        )
-
-
+            "The 'use_inductor' flag is deprecated and will be "
+            "removed in a future release. "
+            "Please use the 'backend' option instead.")
+
     def init_backend(self, vllm_config: "VllmConfig") -> Union[str, Callable]:
@@ -534,7 +541,10 @@ def init_backend(self, vllm_config: "VllmConfig") -> Union[str, Callable]:
         The backend for the compilation config.
         """
         if self.level is None:
-            raise ValueError("No compilation level is set. This method should only be called via vllm config where the level is set if none is provided.")
+            raise ValueError(
+                "No compilation level is set. This method should only be "
+                "called via vllm config where the level is set if none is "
+                "provided.")
         if self.level == CompilationLevel.NO_COMPILATION:
             raise ValueError("No compilation level is set.")

@@ -554,8 +564,10 @@ def init_backend(self, vllm_config: "VllmConfig") -> Union[str, Callable]:
         elif self.backend in ["eager", "inductor"]:
             vllm_config.compilation_config.backend = self.backend
         else:
-            raise ValueError(f"Invalid backend for piecewise compilation: {self.backend}")
-
+            raise ValueError(
+                f"Invalid backend for piecewise compilation: {self.backend}"
+            )
+
         assert self.level == CompilationLevel.PIECEWISE

         from vllm.compilation.backends import VllmBackend
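
Net effect of the config change: for PIECEWISE compilation, backend must be "", "eager", or "inductor", and an empty value now defaults to "inductor". A minimal standalone sketch of that decision logic (the function name is illustrative; the real logic sits in CompilationConfig.__post_init__):

# Sketch of the validate-then-default behavior added in __post_init__.
_PIECEWISE_BACKENDS = ("", "eager", "inductor")

def resolve_piecewise_backend(backend: str) -> str:
    if backend not in _PIECEWISE_BACKENDS:
        raise ValueError(
            f"Invalid backend for piecewise compilation: {backend}")
    # An unset backend falls back to inductor, matching the new default.
    return backend or "inductor"

assert resolve_piecewise_backend("") == "inductor"
assert resolve_piecewise_backend("eager") == "eager"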
