11import inspect
2- from typing import Dict , List , Optional , Union
2+ from typing import Callable , Dict , List , Optional , TypeVar , Union , overload
33
44import torch
5+ import torch .nn as nn
56
67from vllm .compilation .counter import compilation_counter
78from vllm .compilation .wrapper import TorchCompileWrapperWithCustomDispatcher
1213
1314logger = init_logger (__name__ )
1415
16+ _T = TypeVar ("_T" , bound = type [nn .Module ])
17+
18+
19+ @overload
20+ def support_torch_compile (
21+ * ,
22+ dynamic_arg_dims : Optional [Dict [str , Union [int , List [int ]]]],
23+ ) -> Callable [[_T ], _T ]:
24+ ...
25+
26+
27+ @overload
28+ def support_torch_compile (cls : _T ) -> _T :
29+ ...
30+
1531
1632def support_torch_compile (
17- cls : Optional [type ] = None ,
18- dynamic_arg_dims : Optional [Dict [str , Union [int , List [int ]]]] = None ):
33+ cls : Optional [_T ] = None ,
34+ * ,
35+ dynamic_arg_dims : Optional [Dict [str , Union [int , List [int ]]]] = None ,
36+ ) -> Union [Callable [[_T ], _T ], _T ]:
1937 """
2038 A decorator to add support for compiling the forward method of a class.
2139
@@ -66,7 +84,7 @@ def forward(self, x: torch.Tensor, y: Optional[torch.Tensor]):
6684 computation graph.
6785 """
6886
69- def cls_decorator_helper (cls : type ) :
87+ def cls_decorator_helper (cls : _T ) -> _T :
7088 # helper to pass `dynamic_arg_dims`` to `_support_torch_compile``
7189 # to avoid too much indentation for `_support_torch_compile``
7290 if not hasattr (cls , 'forward' ):
@@ -105,8 +123,10 @@ def cls_decorator_helper(cls: type):
105123 return cls_decorator_helper
106124
107125
108- def _support_torch_compile (cls : type ,
109- dynamic_arg_dims : Dict [str , Union [int , List [int ]]]):
126+ def _support_torch_compile (
127+ cls : _T ,
128+ dynamic_arg_dims : Dict [str , Union [int , List [int ]]],
129+ ) -> _T :
110130 """
111131 A decorator to add support for compiling the forward method of a class.
112132 """
@@ -119,7 +139,7 @@ def _support_torch_compile(cls: type,
119139 # other than TorchCompileWrapperWithCustomDispatcher
120140 cls .__bases__ = cls .__bases__ + (TorchCompileWrapperWithCustomDispatcher , )
121141
122- old_init = cls .__init__ # type: ignore
142+ old_init = cls .__init__
123143
124144 def __init__ (self , * , vllm_config : VllmConfig , prefix : str = '' , ** kwargs ):
125145 old_init (self , vllm_config = vllm_config , prefix = prefix , ** kwargs )
@@ -135,7 +155,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = '', **kwargs):
135155 TorchCompileWrapperWithCustomDispatcher .__init__ (
136156 self , compilation_level = vllm_config .compilation_config .level )
137157
138- cls .__init__ = __init__ # type: ignore
158+ cls .__init__ = __init__
139159
140160 def __call__ (self , * args , ** kwargs ):
141161 # torch.compiler.is_compiling() means we are inside the compilation
@@ -180,5 +200,5 @@ def __call__(self, *args, **kwargs):
180200 model_output = self .forward (* args , ** kwargs )
181201 return model_output
182202
183- cls .__call__ = __call__ # type: ignore
203+ cls .__call__ = __call__
184204 return cls
0 commit comments