import vllm.envs as envs
from vllm.compilation.counter import compilation_counter
from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher
-from vllm.config import CompilationMode, VllmConfig, set_current_vllm_config
+from vllm.config import (
+    CompilationMode,
+    VllmConfig,
+    get_current_vllm_config,
+    set_current_vllm_config,
+)
from vllm.logger import init_logger
from vllm.sequence import IntermediateTensors
from vllm.utils.import_utils import resolve_obj_by_qualname
@@ -74,6 +79,21 @@ def support_torch_compile(
) -> Callable[[_T], _T]: ...


+@overload
+def support_torch_compile(
+    *,
+    mark_unbacked_dims: dict[str, int | list[int]] | None,
+) -> Callable[[_T], _T]: ...
+
+
+@overload
+def support_torch_compile(
+    *,
+    dynamic_arg_dims: dict[str, int | list[int]] | None,
+    mark_unbacked_dims: dict[str, int | list[int]] | None,
+) -> Callable[[_T], _T]: ...
+
+
@overload
def support_torch_compile(cls: _T) -> _T: ...

@@ -82,6 +102,7 @@ def support_torch_compile(
    cls: _T | None = None,
    *,
    dynamic_arg_dims: dict[str, int | list[int]] | None = None,
+    mark_unbacked_dims: dict[str, int | list[int]] | None = None,
    enable_if: Callable[[VllmConfig], bool] | None = None,
) -> Callable[[_T], _T] | _T:
    """
@@ -135,6 +156,11 @@ def forward(self, x: torch.Tensor, y: Optional[torch.Tensor]): ...
    returns a boolean value indicating whether to compile the model or not.
    This is useful if you want to compile the model only when certain
    conditions are met.
+
+    `mark_unbacked_dims` is a dictionary that maps argument names to the
+    dynamic dims that should be decorated with `mark_unbacked`. This is
+    useful to enforce that dynamo does not specialize on 0/1 values for
+    dummy inputs, e.g. during vision model compilation.
    """

    def cls_decorator_helper(cls: _T) -> _T:
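
A minimal usage sketch of the new keyword (the class and argument names below are illustrative, not taken from this change): marking dim 0 of `pixel_values` as both dynamic and unbacked keeps dynamo from specializing a size-0/1 dummy batch into the compiled graph.

import torch
from torch import nn

from vllm.compilation.decorators import support_torch_compile


@support_torch_compile(
    dynamic_arg_dims={"pixel_values": 0},
    mark_unbacked_dims={"pixel_values": 0},
)
class VisionBackbone(nn.Module):
    # hypothetical vision module; only the decorator usage matters here
    def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
        return pixel_values * 2
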
@@ -172,7 +198,9 @@ def cls_decorator_helper(cls: _T) -> _T:
                raise ValueError(
                    f"Argument {k} not found in the forward method of {cls}"
                )
-        return _support_torch_compile(cls, inferred_dynamic_arg_dims, enable_if)
+        return _support_torch_compile(
+            cls, inferred_dynamic_arg_dims, mark_unbacked_dims, enable_if
+        )

    if cls is not None:
        # use `support_torch_compile` as a decorator without arguments
@@ -212,6 +240,7 @@ def _verify_source_unchanged(source_info, vllm_config) -> None:
def _support_torch_compile(
    cls: _T,
    dynamic_arg_dims: dict[str, int | list[int]],
+    mark_unbacked_dims: dict[str, int | list[int]] | None = None,
    enable_if: Callable[[VllmConfig], bool] | None = None,
) -> _T:
    """
@@ -230,8 +259,22 @@ def _support_torch_compile(

    setattr(cls, IGNORE_COMPILE_KEY, False)

-    def __init__(self, *, vllm_config: VllmConfig, prefix: str = "", **kwargs):
-        old_init(self, vllm_config=vllm_config, prefix=prefix, **kwargs)
+    def __init__(
+        self, *, vllm_config: VllmConfig | None = None, prefix: str = "", **kwargs
+    ):
+        if vllm_config is None:
+            vllm_config = get_current_vllm_config()
+
+        # NOTE: multimodal submodules (such as vision encoders) may not
+        # receive vllm_config, and their __init__ may not accept
+        # vllm_config/prefix, so only forward the kwargs it declares.
+        sig = inspect.signature(old_init)
+        if "vllm_config" in sig.parameters:
+            kwargs["vllm_config"] = vllm_config
+        if "prefix" in sig.parameters:
+            kwargs["prefix"] = prefix
+        old_init(self, **kwargs)
+
        self.vllm_config = vllm_config
        enable_compile = enable_if is None or enable_if(vllm_config)
        # for CompilationMode.STOCK_TORCH_COMPILE, the upper level model runner
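
A rough sketch of what the relaxed `__init__` enables (the module name and sizes below are made up): a decorated submodule whose original `__init__` accepts neither `vllm_config` nor `prefix` can now be constructed inside a `set_current_vllm_config` context, with the wrapper picking the config up via `get_current_vllm_config()` and forwarding only the kwargs the original `__init__` declares.

import torch
from torch import nn

from vllm.compilation.decorators import support_torch_compile
from vllm.config import VllmConfig, set_current_vllm_config


@support_torch_compile(dynamic_arg_dims={"x": 0})
class TinyEncoder(nn.Module):
    # hypothetical module: no vllm_config/prefix in its __init__
    def __init__(self, hidden_size: int = 64):
        super().__init__()
        self.proj = nn.Linear(hidden_size, hidden_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.proj(x)


vllm_config = VllmConfig()  # default config is enough for this sketch
with set_current_vllm_config(vllm_config):
    encoder = TinyEncoder(hidden_size=64)  # no vllm_config kwarg needed
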
@@ -344,6 +387,15 @@ def __call__(self, *args, **kwargs):
                            "Unsupported dynamic dimensions"
                            f" {dims} for argument {k} with type {type(arg)}."
                        )
+            if mark_unbacked_dims:
+                for k, dims in mark_unbacked_dims.items():
+                    arg = bound_args.arguments.get(k)
+                    if arg is not None:
+                        dims = [dims] if isinstance(dims, int) else dims
+                        if isinstance(arg, torch.Tensor):
+                            # In case dims is specified with negative indexing
+                            dims = [arg.ndim + dim if dim < 0 else dim for dim in dims]
+                            torch._dynamo.decorators.mark_unbacked(arg, dims)
            # here, it is the starting point of the `torch.compile` process
            start_monitoring_torch_compile(self.vllm_config)
            logger.debug("Start compiling function %s", self.original_code_object)
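
For background on what `mark_unbacked` buys here, a standalone sketch independent of vLLM (it assumes a PyTorch build exposing `torch._dynamo.decorators.mark_unbacked`, which this change already relies on): unlike `mark_dynamic`, which can still let the compiler specialize a dimension that happens to be 0 or 1, `mark_unbacked` treats the size as an unbacked symbol, so a size-1 dummy input is not baked into the graph.

import torch


@torch.compile(fullgraph=True)
def scale(t: torch.Tensor) -> torch.Tensor:
    return t * 2.0


dummy = torch.randn(1, 16)  # warmup input with a 0/1-sized leading dim
torch._dynamo.decorators.mark_unbacked(dummy, 0)
scale(dummy)               # compiled without specializing dim 0 to 1
scale(torch.randn(8, 16))  # intended to reuse the same dynamic graph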