
Commit 0d6e550

fix func test
Signed-off-by: Luka Govedič <lgovedic@redhat.com>
1 parent 1b1a63e commit 0d6e550

File tree: 2 files changed, +48 −35 lines

tests/compile/backend.py

Lines changed: 2 additions & 1 deletion
@@ -54,7 +54,8 @@ def __init__(self, *passes: Union[InductorPass, Callable[[fx.Graph], None]]):
         self.custom_passes = list(passes)
         vllm_config = get_current_vllm_config()
         compile_config = vllm_config.compilation_config
-        self.inductor_config = compile_config.inductor_compile_config
+        # Deepcopy to allow multiple TestBackend instances to use the same VllmConfig
+        self.inductor_config = deepcopy(compile_config.inductor_compile_config)
         self.inductor_config["force_disable_caches"] = True
         self.inductor_config["post_grad_custom_post_pass"] = self.post_pass
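
Note on the change above: TestBackend goes on to write "force_disable_caches" and "post_grad_custom_post_pass" into inductor_config, so without the deepcopy two TestBackend instances built from the same VllmConfig would mutate one shared dict, and the second backend's post-pass would overwrite the first's. A minimal sketch of the aliasing problem, using plain dicts and a hypothetical make_backend_config() helper rather than the real TestBackend/VllmConfig:

# Minimal sketch of the aliasing bug, with plain dicts standing in for
# compile_config.inductor_compile_config and a hypothetical helper standing
# in for TestBackend.__init__.
from copy import deepcopy

def make_backend_config(inductor_compile_config, post_pass, *, copy_first):
    # Mirrors what TestBackend.__init__ does to inductor_config.
    cfg = deepcopy(inductor_compile_config) if copy_first else inductor_compile_config
    cfg["force_disable_caches"] = True
    cfg["post_grad_custom_post_pass"] = post_pass
    return cfg

shared = {}  # stands in for the shared inductor_compile_config dict

# Without the copy, both "backends" alias the same dict: the second write wins.
a = make_backend_config(shared, "pass_a", copy_first=False)
b = make_backend_config(shared, "pass_b", copy_first=False)
assert a is b and a["post_grad_custom_post_pass"] == "pass_b"

# With the copy (the behavior after this commit), each backend keeps its own
# post-pass, so several TestBackend instances can share one VllmConfig safely.
shared = {}
a = make_backend_config(shared, "pass_a", copy_first=True)
b = make_backend_config(shared, "pass_b", copy_first=True)
assert a["post_grad_custom_post_pass"] == "pass_a"
assert b["post_grad_custom_post_pass"] == "pass_b"
assert shared == {}  # the shared config itself is left untouched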

tests/compile/test_functionalization.py

Lines changed: 46 additions & 34 deletions
@@ -11,7 +11,13 @@
 from vllm.compilation.fx_utils import find_auto_fn, find_auto_fn_maybe, is_func
 from vllm.compilation.noop_elimination import NoOpEliminationPass
 from vllm.compilation.post_cleanup import PostCleanupPass
-from vllm.config import CompilationConfig, PassConfig, VllmConfig
+from vllm.config import (
+    CompilationConfig,
+    ModelConfig,
+    PassConfig,
+    VllmConfig,
+    set_current_vllm_config,
+)
 from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
@@ -217,42 +223,48 @@ def ops_not_in_model(self):
 def test_fix_functionalization(model_class: torch.nn.Module, do_fusion: bool):
     torch.set_default_device("cuda")
 
-    vllm_config = VllmConfig()
-    vllm_config.compilation_config = CompilationConfig(
-        pass_config=PassConfig(enable_fusion=do_fusion, enable_noop=True)
+    vllm_config = VllmConfig(
+        model_config=ModelConfig(dtype=torch.bfloat16),
+        compilation_config=CompilationConfig(
+            custom_ops=["all"],
+            pass_config=PassConfig(enable_fusion=do_fusion, enable_noop=True),
+        ),
     )
-    noop_pass = NoOpEliminationPass(vllm_config)
-    fusion_pass = RMSNormQuantFusionPass(vllm_config)
-    cleanup_pass = PostCleanupPass(vllm_config)
-    act_quant_fusion_pass = ActivationQuantFusionPass(vllm_config)
-
-    passes = (
-        [noop_pass, fusion_pass, act_quant_fusion_pass, cleanup_pass]
-        if do_fusion
-        else [noop_pass, cleanup_pass]
-    )
-    func_pass = FixFunctionalizationPass(vllm_config)
 
-    backend_func = TestBackend(*passes, func_pass)
-    backend_no_func = TestBackend(*passes)
+    with set_current_vllm_config(vllm_config):
+        assert RMSNorm.enabled()
+        noop_pass = NoOpEliminationPass(vllm_config)
+        fusion_pass = RMSNormQuantFusionPass(vllm_config)
+        cleanup_pass = PostCleanupPass(vllm_config)
+        act_quant_fusion_pass = ActivationQuantFusionPass(vllm_config)
+
+        passes = (
+            [noop_pass, fusion_pass, act_quant_fusion_pass, cleanup_pass]
+            if do_fusion
+            else [noop_pass, cleanup_pass]
+        )
+        func_pass = FixFunctionalizationPass(vllm_config)
 
-    model = model_class()
-    torch.compile(model, backend=backend_func)(*model.example_inputs())
-    torch.compile(model, backend=backend_no_func)(*model.example_inputs())
+        backend_func = TestBackend(*passes, func_pass)
+        backend_no_func = TestBackend(*passes)
 
-    # check if the functionalization pass is applied
-    for op in model.ops_in_model(do_fusion):
-        find_auto_fn(backend_no_func.graph_post_pass.nodes, op)
-        assert find_auto_fn_maybe(backend_func.graph_post_pass.nodes, op) is None
+        model = model_class()
+        torch.compile(model, backend=backend_func)(*model.example_inputs())
+        torch.compile(model, backend=backend_no_func)(*model.example_inputs())
 
-    # make sure the ops were all de-functionalized
-    found = dict()
-    for node in backend_func.graph_post_pass.nodes:
+        # check if the functionalization pass is applied
         for op in model.ops_in_model(do_fusion):
-            if is_func(node, op):
-                found[op] = True
-        for op in model.ops_not_in_model():
-            if is_func(node, op):
-                found[op] = True
-    assert all(found[op] for op in model.ops_in_model(do_fusion))
-    assert all(not found.get(op) for op in model.ops_not_in_model())
+            find_auto_fn(backend_no_func.graph_post_pass.nodes, op)
+            assert find_auto_fn_maybe(backend_func.graph_post_pass.nodes, op) is None
+
+        # make sure the ops were all de-functionalized
+        found = dict()
+        for node in backend_func.graph_post_pass.nodes:
+            for op in model.ops_in_model(do_fusion):
+                if is_func(node, op):
+                    found[op] = True
+            for op in model.ops_not_in_model():
+                if is_func(node, op):
+                    found[op] = True
+        assert all(found[op] for op in model.ops_in_model(do_fusion))
+        assert all(not found.get(op) for op in model.ops_not_in_model())
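
Note on the change above: the test now builds its config with ModelConfig(dtype=torch.bfloat16) and custom_ops=["all"], and constructs the passes and both TestBackend instances inside set_current_vllm_config(vllm_config). That matters because TestBackend.__init__ reads the active config via get_current_vllm_config() (see tests/compile/backend.py above), and custom-op classes such as RMSNorm resolve whether their custom implementation is enabled from the currently set config, which is what the new assert RMSNorm.enabled() checks. A trimmed-down sketch of the pattern, assuming a working vLLM installation (this is only the config/context part, not the full test):

# Trimmed-down sketch: the config only takes effect for code that reads the
# "current" vLLM config while the context manager is active.
import torch

from vllm.config import (
    CompilationConfig,
    ModelConfig,
    PassConfig,
    VllmConfig,
    set_current_vllm_config,
)
from vllm.model_executor.layers.layernorm import RMSNorm

vllm_config = VllmConfig(
    model_config=ModelConfig(dtype=torch.bfloat16),
    compilation_config=CompilationConfig(
        custom_ops=["all"],
        pass_config=PassConfig(enable_fusion=True, enable_noop=True),
    ),
)

with set_current_vllm_config(vllm_config):
    # With custom_ops=["all"] in the active config, RMSNorm reports its
    # custom implementation as enabled, as the test asserts.
    assert RMSNorm.enabled()
    # Passes and TestBackend instances are built inside this block so that
    # get_current_vllm_config() inside them sees this vllm_config.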
