|
4 | 4 |
|
5 | 5 | import pytest |
6 | 6 | import torch |
7 | | -from torch._inductor.codecache import BypassFxGraphCache |
8 | 7 |
|
9 | | -from vllm.compilation.config import CompilationConfig |
10 | | -from vllm.compilation.inductor_pass import (CallableInductorPass, |
11 | | - as_inductor_pass) |
| 8 | +from vllm.compilation.inductor_pass import CallableInductorPass, InductorPass |
12 | 9 | from vllm.compilation.pass_manager import PostGradPassManager |
| 10 | +from vllm.config import CompilationConfig |
13 | 11 |
|
14 | 12 |
|
15 | 13 | def simple_callable(graph: torch.fx.Graph): |
16 | 14 | pass |
17 | 15 |
|
18 | 16 |
|
19 | | -@as_inductor_pass(files=(__file__, )) |
20 | | -def callable_decorated(graph: torch.fx.Graph): |
21 | | - pass |
| 17 | +callable_uuid = CallableInductorPass(simple_callable, |
| 18 | + InductorPass.hash_source(__file__)) |
22 | 19 |
|
23 | 20 |
|
24 | 21 | @pytest.mark.parametrize( |
25 | 22 | "works, callable", |
26 | | - [(False, simple_callable), (True, callable_decorated), |
27 | | - (True, CallableInductorPass(simple_callable, "simple_callable"))]) |
| 23 | + [ |
| 24 | + (False, simple_callable), |
| 25 | + (True, callable_uuid), |
| 26 | + (True, CallableInductorPass(simple_callable)), |
| 27 | + ], |
| 28 | +) |
28 | 29 | def test_pass_manager(works: bool, callable): |
29 | 30 | config = CompilationConfig().pass_config |
30 | | - pass_manager = PostGradPassManager([callable]) |
31 | | - pass_manager.configure(config) # Adds default passes |
32 | 31 |
|
| 32 | + pass_manager = PostGradPassManager() |
| 33 | + pass_manager.configure(config) |
| 34 | + |
| 35 | + # Try to add the callable to the pass manager |
33 | 36 | if works: |
| 37 | + pass_manager.add(callable) |
34 | 38 | pickle.dumps(pass_manager) |
35 | 39 | else: |
36 | | - with pytest.raises(BypassFxGraphCache): |
37 | | - pickle.dumps(pass_manager) |
| 40 | + with pytest.raises(AssertionError): |
| 41 | + pass_manager.add(callable) |
0 commit comments