Skip to content

Commit db1ca39

Browse files
DarkLight1337afeldman-nm
authored andcommitted
[Misc] Improve type annotations for support_torch_compile (vllm-project#10763)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk> Signed-off-by: Andrew Feldman <afeldman@neuralmagic.com>
1 parent cbf1489 commit db1ca39

File tree

1 file changed

+29
-9
lines changed

1 file changed

+29
-9
lines changed

vllm/compilation/decorators.py

Lines changed: 29 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
import inspect
2-
from typing import Dict, List, Optional, Union
2+
from typing import Callable, Dict, List, Optional, TypeVar, Union, overload
33

44
import torch
5+
import torch.nn as nn
56

67
from vllm.compilation.counter import compilation_counter
78
from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher
@@ -12,10 +13,27 @@
1213

1314
logger = init_logger(__name__)
1415

16+
_T = TypeVar("_T", bound=type[nn.Module])
17+
18+
19+
@overload
20+
def support_torch_compile(
21+
*,
22+
dynamic_arg_dims: Optional[Dict[str, Union[int, List[int]]]],
23+
) -> Callable[[_T], _T]:
24+
...
25+
26+
27+
@overload
28+
def support_torch_compile(cls: _T) -> _T:
29+
...
30+
1531

1632
def support_torch_compile(
17-
cls: Optional[type] = None,
18-
dynamic_arg_dims: Optional[Dict[str, Union[int, List[int]]]] = None):
33+
cls: Optional[_T] = None,
34+
*,
35+
dynamic_arg_dims: Optional[Dict[str, Union[int, List[int]]]] = None,
36+
) -> Union[Callable[[_T], _T], _T]:
1937
"""
2038
A decorator to add support for compiling the forward method of a class.
2139
@@ -66,7 +84,7 @@ def forward(self, x: torch.Tensor, y: Optional[torch.Tensor]):
6684
computation graph.
6785
"""
6886

69-
def cls_decorator_helper(cls: type):
87+
def cls_decorator_helper(cls: _T) -> _T:
7088
# helper to pass `dynamic_arg_dims`` to `_support_torch_compile``
7189
# to avoid too much indentation for `_support_torch_compile``
7290
if not hasattr(cls, 'forward'):
@@ -105,8 +123,10 @@ def cls_decorator_helper(cls: type):
105123
return cls_decorator_helper
106124

107125

108-
def _support_torch_compile(cls: type,
109-
dynamic_arg_dims: Dict[str, Union[int, List[int]]]):
126+
def _support_torch_compile(
127+
cls: _T,
128+
dynamic_arg_dims: Dict[str, Union[int, List[int]]],
129+
) -> _T:
110130
"""
111131
A decorator to add support for compiling the forward method of a class.
112132
"""
@@ -119,7 +139,7 @@ def _support_torch_compile(cls: type,
119139
# other than TorchCompileWrapperWithCustomDispatcher
120140
cls.__bases__ = cls.__bases__ + (TorchCompileWrapperWithCustomDispatcher, )
121141

122-
old_init = cls.__init__ # type: ignore
142+
old_init = cls.__init__
123143

124144
def __init__(self, *, vllm_config: VllmConfig, prefix: str = '', **kwargs):
125145
old_init(self, vllm_config=vllm_config, prefix=prefix, **kwargs)
@@ -135,7 +155,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = '', **kwargs):
135155
TorchCompileWrapperWithCustomDispatcher.__init__(
136156
self, compilation_level=vllm_config.compilation_config.level)
137157

138-
cls.__init__ = __init__ # type: ignore
158+
cls.__init__ = __init__
139159

140160
def __call__(self, *args, **kwargs):
141161
# torch.compiler.is_compiling() means we are inside the compilation
@@ -180,5 +200,5 @@ def __call__(self, *args, **kwargs):
180200
model_output = self.forward(*args, **kwargs)
181201
return model_output
182202

183-
cls.__call__ = __call__ # type: ignore
203+
cls.__call__ = __call__
184204
return cls

0 commit comments

Comments
 (0)