Skip to content

Commit 0c8c359

Browse files
ProExpertProg authored and lk-chen committed
[Fix] [torch.compile] Improve UUID system for custom passes (vllm-project#15249)
Signed-off-by: luka <luka@neuralmagic.com>
1 parent b8200d9 commit 0c8c359

File tree

5 files changed

+125
-84
lines changed

5 files changed

+125
-84
lines changed

tests/compile/test_pass_manager.py

Lines changed: 46 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
# SPDX-License-Identifier: Apache-2.0
2-
3-
import pickle
2+
import copy
43

54
import pytest
65
import torch
@@ -10,32 +9,63 @@
109
from vllm.config import CompilationConfig
1110

1211

12+
# dummy custom pass that doesn't inherit
1313
def simple_callable(graph: torch.fx.Graph):
1414
pass
1515

1616

17-
callable_uuid = CallableInductorPass(simple_callable,
18-
InductorPass.hash_source(__file__))
17+
# Should fail to add directly to the pass manager
18+
def test_bad_callable():
19+
config = CompilationConfig().pass_config
20+
21+
pass_manager = PostGradPassManager()
22+
pass_manager.configure(config)
23+
24+
with pytest.raises(AssertionError):
25+
pass_manager.add(simple_callable) # noqa, type wrong on purpose
26+
27+
28+
# Pass that inherits from InductorPass
29+
class ProperPass(InductorPass):
30+
31+
def __call__(self, graph: torch.fx.graph.Graph) -> None:
32+
pass
1933

2034

2135
@pytest.mark.parametrize(
22-
"works, callable",
36+
"callable",
2337
[
24-
(False, simple_callable),
25-
(True, callable_uuid),
26-
(True, CallableInductorPass(simple_callable)),
38+
ProperPass(),
39+
# Can also wrap callables in CallableInductorPass for compliance
40+
CallableInductorPass(simple_callable),
41+
CallableInductorPass(simple_callable,
42+
InductorPass.hash_source(__file__))
2743
],
2844
)
29-
def test_pass_manager(works: bool, callable):
45+
def test_pass_manager_uuid(callable):
3046
config = CompilationConfig().pass_config
3147

3248
pass_manager = PostGradPassManager()
3349
pass_manager.configure(config)
3450

35-
# Try to add the callable to the pass manager
36-
if works:
37-
pass_manager.add(callable)
38-
pickle.dumps(pass_manager)
39-
else:
40-
with pytest.raises(AssertionError):
41-
pass_manager.add(callable)
51+
# Check that UUID is different if the same pass is added 2x
52+
pass_manager.add(callable)
53+
uuid1 = pass_manager.uuid()
54+
pass_manager.add(callable)
55+
uuid2 = pass_manager.uuid()
56+
assert uuid1 != uuid2
57+
58+
# UUID should be the same as the original one,
59+
# as we constructed in the same way.
60+
pass_manager2 = PostGradPassManager()
61+
pass_manager2.configure(config)
62+
pass_manager2.add(callable)
63+
assert uuid1 == pass_manager2.uuid()
64+
65+
# UUID should be different due to config change
66+
config2 = copy.deepcopy(config)
67+
config2.enable_fusion = not config2.enable_fusion
68+
pass_manager3 = PostGradPassManager()
69+
pass_manager3.configure(config2)
70+
pass_manager3.add(callable)
71+
assert uuid1 != pass_manager3.uuid()

vllm/compilation/inductor_pass.py

Lines changed: 25 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,30 @@
11
# SPDX-License-Identifier: Apache-2.0
22

33
import hashlib
4+
import importlib.metadata
45
import inspect
6+
import json
57
import types
6-
from abc import ABC, abstractmethod
7-
from typing import Any, Callable, Optional, Union
8+
from typing import Any, Callable, Dict, Optional, Union
89

910
import torch
11+
from packaging.version import Version
1012
from torch import fx
1113

14+
if Version(importlib.metadata.version('torch')) >= Version("2.6"):
15+
from torch._inductor.custom_graph_pass import CustomGraphPass
16+
else:
17+
# CustomGraphPass is not present in 2.5 or lower, import our version
18+
from .torch25_custom_graph_pass import ( # noqa: yapf
19+
Torch25CustomGraphPass as CustomGraphPass)
1220

13-
class InductorPass(ABC):
21+
22+
class InductorPass(CustomGraphPass):
1423
"""
15-
General custom inductor pass interface.
24+
A custom graph pass that uses a hash of its source as the UUID.
25+
This is defined as a convenience and should work in most cases.
1626
"""
1727

18-
@abstractmethod
19-
def __call__(self, graph: torch.fx.Graph):
20-
"""
21-
Execute the pass on the given graph.
22-
"""
23-
raise NotImplementedError
24-
2528
def uuid(self) -> Any:
2629
"""
2730
Provide a unique identifier for the pass, used in Inductor code cache.
@@ -48,7 +51,16 @@ def hash_source(*srcs: Union[str, Any]):
4851
else:
4952
src_str = inspect.getsource(src.__class__)
5053
hasher.update(src_str.encode("utf-8"))
51-
return hasher.digest()
54+
return hasher.hexdigest()
55+
56+
@staticmethod
57+
def hash_dict(dict_: Dict[Any, Any]):
58+
"""
59+
Utility method to hash a dictionary, can alternatively be used for uuid.
60+
:return: A sha256 hash of the json rep of the dictionary.
61+
"""
62+
encoded = json.dumps(dict_, sort_keys=True).encode("utf-8")
63+
return hashlib.sha256(encoded).hexdigest()
5264

5365

5466
class CallableInductorPass(InductorPass):
@@ -61,25 +73,10 @@ def __init__(self,
6173
callable: Callable[[fx.Graph], None],
6274
uuid: Optional[Any] = None):
6375
self.callable = callable
64-
if uuid is None:
65-
uuid = InductorPass.hash_source(callable)
66-
self._uuid = uuid
76+
self._uuid = self.hash_source(callable) if uuid is None else uuid
6777

6878
def __call__(self, graph: torch.fx.Graph):
6979
self.callable(graph)
7080

7181
def uuid(self) -> Any:
7282
return self._uuid
73-
74-
def __getstate__(self):
75-
"""
76-
Pickling occurs in the Inductor code cache if a pass is not given to
77-
the pass manager but is instead directly added to config as a pass.
78-
See PostGradPassManager for more.
79-
80-
TODO(torch==2.6), use the `uuid` method in CustomGraphPass instead.
81-
"""
82-
return self._uuid
83-
84-
def __setstate__(self, state):
85-
raise ValueError("Cannot unpickle CallableInductorPass")

vllm/compilation/pass_manager.py

Lines changed: 9 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -1,38 +1,26 @@
11
# SPDX-License-Identifier: Apache-2.0
22

3-
from typing import Any, Dict, List
3+
from typing import List
44

5-
import torch
65
from torch import fx as fx
76

87
from vllm.config import CompilationConfig
98
from vllm.logger import init_logger
109

1110
from .fix_functionalization import FixFunctionalizationPass
1211
from .fusion import FusionPass
13-
from .inductor_pass import InductorPass
12+
from .inductor_pass import CustomGraphPass, InductorPass
1413
from .noop_elimination import NoOpEliminationPass
1514

1615
logger = init_logger(__name__)
1716

1817

19-
class PlaceHolder:
20-
pass
21-
22-
23-
if torch.__version__ < "2.6":
24-
Parent = PlaceHolder # type: ignore
25-
else:
26-
Parent = torch._inductor.custom_graph_pass.CustomGraphPass # type: ignore
27-
28-
29-
class PostGradPassManager(Parent):
18+
class PostGradPassManager(CustomGraphPass):
3019
"""
3120
The pass manager for post-grad passes.
3221
It handles configuration, adding custom passes, and running passes.
33-
It also supports pickling, which is used by the Inductor code cache.
34-
TODO(torch==2.6), use CustomGraphPass
35-
(torch._inductor.custom_graph_pass.CustomGraphPass)
22+
It supports uuid for the Inductor code cache. That includes torch<2.6
23+
support using pickling (in .inductor_pass.CustomGraphPass).
3624
3725
The order of the post-grad post-passes is:
3826
1. passes (constructor parameter)
@@ -67,27 +55,13 @@ def add(self, pass_: InductorPass):
6755
self.passes.append(pass_)
6856

6957
def uuid(self):
70-
return self.__getstate__()
71-
72-
def __getstate__(self) -> Dict[str, List[Any]]:
7358
"""
74-
Custom pickling for the pass manager, as some passes cannot be pickled.
75-
Pickling occurs because the pass manager is set as the value of
76-
`config["post_grad_custom_post_pass"]` in the Inductor config.
77-
The config is pickled to act as a key in the Inductor code cache.
78-
Any other passes in the config are pickled as well.
79-
80-
TODO(torch==2.6), use the `uuid` method in CustomGraphPass instead.
59+
The PostGradPassManager is set as a custom pass in the Inductor and
60+
affects compilation caching. Its uuid depends on the UUIDs of all
61+
dependent passes and the pass config. See InductorPass for more info.
8162
"""
8263
state = {"pass_config": self.pass_config.uuid(), "passes": []}
8364
for pass_ in self.passes:
8465
state["passes"].append(pass_.uuid())
8566
state["passes"].append(self.fix_functionalization.uuid())
86-
return state
87-
88-
def __setstate__(self, state):
89-
"""
90-
Do not allow unpickling of the pass manager.
91-
If this is needed in the future, it should properly pickle the passes.
92-
"""
93-
raise ValueError("Cannot unpickle PostGradPassManager")
67+
return InductorPass.hash_dict(state)
vllm/compilation/torch25_custom_graph_pass.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
from abc import ABC, abstractmethod
3+
from typing import Any, Optional
4+
5+
import torch
6+
7+
8+
class Torch25CustomGraphPass(ABC): # noqa (redefinition)
9+
"""
10+
This class replaces CustomGraphPass from torch==2.6 when using torch<2.6.
11+
It conforms to the 2.6 interface but also supports pickling, as that's what
12+
the inductor code cache uses to determine the cache key before 2.6.
13+
(in 2.6 and above, uuid() is used.)
14+
15+
Subclasses can just "pretend" that uuid is used.
16+
"""
17+
18+
@abstractmethod
19+
def __call__(self, graph: torch.fx.graph.Graph) -> None:
20+
"""
21+
Implementation of the custom pass.
22+
"""
23+
24+
@abstractmethod
25+
def uuid(self) -> Optional[Any]:
26+
"""
27+
Return an ID to uniquely identify your custom pass implementation.
28+
Return None to skip inductor code caching entirely.
29+
"""
30+
31+
def __getstate__(self):
32+
"""
33+
Pickling is used instead of uuid() in torch<2.6. Just return uuid()
34+
to enable subclasses to only have to implement uuid.
35+
"""
36+
return self.uuid()
37+
38+
def __setstate__(self, state):
39+
raise ValueError("Cannot unpickle CustomGraphPass because pickling"
40+
" is used for cache key uuid. Use torch>=2.6 with"
41+
" native uuid support for custom passes.")

vllm/config.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import copy
55
import enum
66
import hashlib
7+
import importlib.metadata
78
import json
89
import sys
910
import warnings
@@ -17,6 +18,7 @@
1718
Optional, Protocol, Union)
1819

1920
import torch
21+
from packaging.version import Version
2022
from pydantic import BaseModel, Field, PrivateAttr
2123
from torch.distributed import ProcessGroup, ReduceOp
2224
from transformers import PretrainedConfig
@@ -52,8 +54,6 @@
5254
else:
5355
QuantizationConfig = None
5456

55-
from packaging.version import Version
56-
5757
logger = init_logger(__name__)
5858

5959
# This value is chosen to have a balance between ITL and TTFT. Note it is
@@ -3088,8 +3088,7 @@ def uuid(self):
30883088
compilation.
30893089
"""
30903090
dict_ = self.model_dump(include={"enable_fusion", "enable_noop"})
3091-
encoded = json.dumps(dict_, sort_keys=True).encode("utf-8")
3092-
return hashlib.sha256(encoded).digest()
3091+
return InductorPass.hash_dict(dict_)
30933092

30943093
def model_post_init(self, __context: Any) -> None:
30953094
if not self.enable_noop and self.enable_fusion:
@@ -3178,7 +3177,7 @@ def model_post_init(self, __context: Any) -> None:
31783177
# and it is not yet a priority. RFC here:
31793178
# https://github.com/vllm-project/vllm/issues/14703
31803179

3181-
if Version(torch.__version__) >= Version("2.6"):
3180+
if Version(importlib.metadata.version('torch')) >= Version("2.6"):
31823181
KEY = 'enable_auto_functionalized_v2'
31833182
if KEY not in self.inductor_compile_config:
31843183
self.inductor_compile_config[KEY] = False

0 commit comments

Comments
 (0)