Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 46 additions & 16 deletions tests/compile/test_pass_manager.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
# SPDX-License-Identifier: Apache-2.0

import pickle
import copy

import pytest
import torch
Expand All @@ -10,32 +9,63 @@
from vllm.config import CompilationConfig


# Dummy custom pass: a bare function that does not inherit from InductorPass
def simple_callable(graph: torch.fx.Graph):
    """Accept an FX graph and leave it untouched (intentional no-op)."""


callable_uuid = CallableInductorPass(simple_callable,
InductorPass.hash_source(__file__))
# Adding a bare callable directly to the pass manager must be rejected
def test_bad_callable():
    """A plain function passed to ``PostGradPassManager.add`` must assert."""
    pass_config = CompilationConfig().pass_config

    manager = PostGradPassManager()
    manager.configure(pass_config)

    with pytest.raises(AssertionError):
        manager.add(simple_callable)  # noqa, type wrong on purpose


# Minimal pass that inherits from InductorPass
class ProperPass(InductorPass):
    """No-op pass implemented as a real InductorPass subclass."""

    def __call__(self, graph: torch.fx.graph.Graph) -> None:
        """Leave ``graph`` untouched."""


@pytest.mark.parametrize(
"works, callable",
"callable",
[
(False, simple_callable),
(True, callable_uuid),
(True, CallableInductorPass(simple_callable)),
ProperPass(),
# Can also wrap callables in CallableInductorPass for compliance
CallableInductorPass(simple_callable),
CallableInductorPass(simple_callable,
InductorPass.hash_source(__file__))
],
)
def test_pass_manager(works: bool, callable):
def test_pass_manager_uuid(callable):
config = CompilationConfig().pass_config

pass_manager = PostGradPassManager()
pass_manager.configure(config)

# Try to add the callable to the pass manager
if works:
pass_manager.add(callable)
pickle.dumps(pass_manager)
else:
with pytest.raises(AssertionError):
pass_manager.add(callable)
# Check that UUID is different if the same pass is added 2x
pass_manager.add(callable)
uuid1 = pass_manager.uuid()
pass_manager.add(callable)
uuid2 = pass_manager.uuid()
assert uuid1 != uuid2

# UUID should be the same as the original one,
# as we constructed in the same way.
pass_manager2 = PostGradPassManager()
pass_manager2.configure(config)
pass_manager2.add(callable)
assert uuid1 == pass_manager2.uuid()

# UUID should be different due to config change
config2 = copy.deepcopy(config)
config2.enable_fusion = not config2.enable_fusion
pass_manager3 = PostGradPassManager()
pass_manager3.configure(config2)
pass_manager3.add(callable)
assert uuid1 != pass_manager3.uuid()
53 changes: 25 additions & 28 deletions vllm/compilation/inductor_pass.py
Original file line number Diff line number Diff line change
@@ -1,27 +1,30 @@
# SPDX-License-Identifier: Apache-2.0

import hashlib
import importlib.metadata
import inspect
import json
import types
from abc import ABC, abstractmethod
from typing import Any, Callable, Optional, Union
from typing import Any, Callable, Dict, Optional, Union

import torch
from packaging.version import Version
from torch import fx

if Version(importlib.metadata.version('torch')) >= Version("2.6"):
from torch._inductor.custom_graph_pass import CustomGraphPass
else:
# CustomGraphPass is not present in 2.5 or lower, import our version
from .torch25_custom_graph_pass import ( # noqa: yapf
Torch25CustomGraphPass as CustomGraphPass)

class InductorPass(ABC):

class InductorPass(CustomGraphPass):
"""
General custom inductor pass interface.
A custom graph pass that uses a hash of its source as the UUID.
This is defined as a convenience and should work in most cases.
"""

@abstractmethod
def __call__(self, graph: torch.fx.Graph):
"""
Execute the pass on the given graph.
"""
raise NotImplementedError

def uuid(self) -> Any:
"""
Provide a unique identifier for the pass, used in Inductor code cache.
Expand All @@ -48,7 +51,16 @@ def hash_source(*srcs: Union[str, Any]):
else:
src_str = inspect.getsource(src.__class__)
hasher.update(src_str.encode("utf-8"))
return hasher.digest()
return hasher.hexdigest()

@staticmethod
def hash_dict(dict_: Dict[Any, Any]):
"""
Utility method to hash a dictionary, can alternatively be used for uuid.
:return: A sha256 hash of the json rep of the dictionary.
"""
encoded = json.dumps(dict_, sort_keys=True).encode("utf-8")
return hashlib.sha256(encoded).hexdigest()


class CallableInductorPass(InductorPass):
Expand All @@ -61,25 +73,10 @@ def __init__(self,
callable: Callable[[fx.Graph], None],
uuid: Optional[Any] = None):
self.callable = callable
if uuid is None:
uuid = InductorPass.hash_source(callable)
self._uuid = uuid
self._uuid = self.hash_source(callable) if uuid is None else uuid

def __call__(self, graph: torch.fx.Graph):
self.callable(graph)

def uuid(self) -> Any:
return self._uuid

def __getstate__(self):
"""
Pickling occurs in the Inductor code cache if a pass is not given to
the pass manager but is instead directly added to config as a pass.
See PostGradPassManager for more.

TODO(torch==2.6), use the `uuid` method in CustomGraphPass instead.
"""
return self._uuid

def __setstate__(self, state):
raise ValueError("Cannot unpickle CallableInductorPass")
44 changes: 9 additions & 35 deletions vllm/compilation/pass_manager.py
Original file line number Diff line number Diff line change
@@ -1,38 +1,26 @@
# SPDX-License-Identifier: Apache-2.0

from typing import Any, Dict, List
from typing import List

import torch
from torch import fx as fx

from vllm.config import CompilationConfig
from vllm.logger import init_logger

from .fix_functionalization import FixFunctionalizationPass
from .fusion import FusionPass
from .inductor_pass import InductorPass
from .inductor_pass import CustomGraphPass, InductorPass
from .noop_elimination import NoOpEliminationPass

logger = init_logger(__name__)


class PlaceHolder:
pass


if torch.__version__ < "2.6":
Parent = PlaceHolder # type: ignore
else:
Parent = torch._inductor.custom_graph_pass.CustomGraphPass # type: ignore


class PostGradPassManager(Parent):
class PostGradPassManager(CustomGraphPass):
"""
The pass manager for post-grad passes.
It handles configuration, adding custom passes, and running passes.
It also supports pickling, which is used by the Inductor code cache.
TODO(torch==2.6), use CustomGraphPass
(torch._inductor.custom_graph_pass.CustomGraphPass)
It supports uuid for the Inductor code cache. That includes torch<2.6
support using pickling (in .inductor_pass.CustomGraphPass).

The order of the post-grad post-passes is:
1. passes (constructor parameter)
Expand Down Expand Up @@ -67,27 +55,13 @@ def add(self, pass_: InductorPass):
self.passes.append(pass_)

def uuid(self):
return self.__getstate__()

def __getstate__(self) -> Dict[str, List[Any]]:
"""
Custom pickling for the pass manager, as some passes cannot be pickled.
Pickling occurs because the pass manager is set as the value of
`config["post_grad_custom_post_pass"]` in the Inductor config.
The config is pickled to act as a key in the Inductor code cache.
Any other passes in the config are pickled as well.

TODO(torch==2.6), use the `uuid` method in CustomGraphPass instead.
The PostGradPassManager is set as a custom pass in the Inductor and
affects compilation caching. Its uuid depends on the UUIDs of all
dependent passes and the pass config. See InductorPass for more info.
"""
state = {"pass_config": self.pass_config.uuid(), "passes": []}
for pass_ in self.passes:
state["passes"].append(pass_.uuid())
state["passes"].append(self.fix_functionalization.uuid())
return state

def __setstate__(self, state):
"""
Do not allow unpickling of the pass manager.
If this is needed in the future, it should properly pickle the passes.
"""
raise ValueError("Cannot unpickle PostGradPassManager")
return InductorPass.hash_dict(state)
41 changes: 41 additions & 0 deletions vllm/compilation/torch25_custom_graph_pass.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
# SPDX-License-Identifier: Apache-2.0
from abc import ABC, abstractmethod
from typing import Any, Optional

import torch


class Torch25CustomGraphPass(ABC):  # noqa (redefinition)
    """
    Stand-in for torch==2.6's CustomGraphPass when running on torch<2.6.

    It follows the 2.6 interface and additionally supports pickling, since
    pickling is what the inductor code cache uses to determine the cache key
    before 2.6 (from 2.6 onwards, uuid() is used instead).

    Subclasses can simply behave as if uuid() were used.
    """

    @abstractmethod
    def __call__(self, graph: torch.fx.graph.Graph) -> None:
        """
        Run the custom pass on the given graph.
        """

    @abstractmethod
    def uuid(self) -> Optional[Any]:
        """
        Return an ID that uniquely identifies this custom pass implementation.
        Returning None skips inductor code caching entirely.
        """

    def __getstate__(self):
        """
        In torch<2.6 pickling takes the place of uuid(), so the pickled state
        is just uuid(); subclasses then only have to implement uuid.
        """
        state = self.uuid()
        return state

    def __setstate__(self, state):
        """
        Always refuse to restore: the pickled state is only a cache key.
        """
        message = ("Cannot unpickle CustomGraphPass because pickling"
                   " is used for cache key uuid. Use torch>=2.6 with"
                   " native uuid support for custom passes.")
        raise ValueError(message)
9 changes: 4 additions & 5 deletions vllm/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import copy
import enum
import hashlib
import importlib.metadata
import json
import sys
import warnings
Expand All @@ -17,6 +18,7 @@
Optional, Protocol, Union)

import torch
from packaging.version import Version
from pydantic import BaseModel, Field, PrivateAttr
from torch.distributed import ProcessGroup, ReduceOp
from transformers import PretrainedConfig
Expand Down Expand Up @@ -52,8 +54,6 @@
else:
QuantizationConfig = None

from packaging.version import Version

logger = init_logger(__name__)

# This value is chosen to have a balance between ITL and TTFT. Note it is
Expand Down Expand Up @@ -3060,8 +3060,7 @@ def uuid(self):
compilation.
"""
dict_ = self.model_dump(include={"enable_fusion", "enable_noop"})
encoded = json.dumps(dict_, sort_keys=True).encode("utf-8")
return hashlib.sha256(encoded).digest()
return InductorPass.hash_dict(dict_)

def model_post_init(self, __context: Any) -> None:
if not self.enable_noop and self.enable_fusion:
Expand Down Expand Up @@ -3150,7 +3149,7 @@ def model_post_init(self, __context: Any) -> None:
# and it is not yet a priority. RFC here:
# https://github.com/vllm-project/vllm/issues/14703

if Version(torch.__version__) >= Version("2.6"):
if Version(importlib.metadata.version('torch')) >= Version("2.6"):
KEY = 'enable_auto_functionalized_v2'
if KEY not in self.inductor_compile_config:
self.inductor_compile_config[KEY] = False
Expand Down