|
1 | 1 | # SPDX-License-Identifier: Apache-2.0 |
2 | 2 |
|
3 | | -from typing import Any, Dict, List |
| 3 | +from typing import List |
4 | 4 |
|
5 | | -import torch |
6 | 5 | from torch import fx as fx |
7 | 6 |
|
8 | 7 | from vllm.config import CompilationConfig |
9 | 8 | from vllm.logger import init_logger |
10 | 9 |
|
11 | 10 | from .fix_functionalization import FixFunctionalizationPass |
12 | 11 | from .fusion import FusionPass |
13 | | -from .inductor_pass import InductorPass |
| 12 | +from .inductor_pass import CustomGraphPass, InductorPass |
14 | 13 | from .noop_elimination import NoOpEliminationPass |
15 | 14 |
|
16 | 15 | logger = init_logger(__name__) |
17 | 16 |
|
18 | 17 |
|
19 | | -class PlaceHolder: |
20 | | - pass |
21 | | - |
22 | | - |
23 | | -if torch.__version__ < "2.6": |
24 | | - Parent = PlaceHolder # type: ignore |
25 | | -else: |
26 | | - Parent = torch._inductor.custom_graph_pass.CustomGraphPass # type: ignore |
27 | | - |
28 | | - |
29 | | -class PostGradPassManager(Parent): |
| 18 | +class PostGradPassManager(CustomGraphPass): |
30 | 19 | """ |
31 | 20 | The pass manager for post-grad passes. |
32 | 21 | It handles configuration, adding custom passes, and running passes. |
33 | | - It also supports pickling, which is used by the Inductor code cache. |
34 | | - TODO(torch==2.6), use CustomGraphPass |
35 | | - (torch._inductor.custom_graph_pass.CustomGraphPass) |
| 22 | + It supports uuid for the Inductor code cache. That includes torch<2.6 |
| 23 | + support using pickling (in .inductor_pass.CustomGraphPass). |
36 | 24 |
|
37 | 25 | The order of the post-grad post-passes is: |
38 | 26 | 1. passes (constructor parameter) |
@@ -67,27 +55,13 @@ def add(self, pass_: InductorPass): |
67 | 55 | self.passes.append(pass_) |
68 | 56 |
|
69 | 57 | def uuid(self): |
70 | | - return self.__getstate__() |
71 | | - |
72 | | - def __getstate__(self) -> Dict[str, List[Any]]: |
73 | 58 | """ |
74 | | - Custom pickling for the pass manager, as some passes cannot be pickled. |
75 | | - Pickling occurs because the pass manager is set as the value of |
76 | | - `config["post_grad_custom_post_pass"]` in the Inductor config. |
77 | | - The config is pickled to act as a key in the Inductor code cache. |
78 | | - Any other passes in the config are pickled as well. |
79 | | -
|
80 | | - TODO(torch==2.6), use the `uuid` method in CustomGraphPass instead. |
| 59 | + The PostGradPassManager is set as a custom pass in the Inductor and |
| 60 | + affects compilation caching. Its uuid depends on the UUIDs of all |
| 61 | + dependent passes and the pass config. See InductorPass for more info. |
81 | 62 | """ |
82 | 63 | state = {"pass_config": self.pass_config.uuid(), "passes": []} |
83 | 64 | for pass_ in self.passes: |
84 | 65 | state["passes"].append(pass_.uuid()) |
85 | 66 | state["passes"].append(self.fix_functionalization.uuid()) |
86 | | - return state |
87 | | - |
88 | | - def __setstate__(self, state): |
89 | | - """ |
90 | | - Do not allow unpickling of the pass manager. |
91 | | - If this is needed in the future, it should properly pickle the passes. |
92 | | - """ |
93 | | - raise ValueError("Cannot unpickle PostGradPassManager") |
| 67 | + return InductorPass.hash_dict(state) |
0 commit comments