
 import torch
 import torch.nn as nn
+from packaging import version
 from torch._dynamo.symbolic_convert import InliningInstructionTranslator

 from vllm.compilation.counter import compilation_counter
@@ -300,13 +301,30 @@ def patched_inline_call(parent, func, args, kwargs):
                 logger.debug(
                     "enable_cpp_symbolic_shape_guards config not available")

-            with patch.object(InliningInstructionTranslator, 'inline_call',
+            with patch.object(InliningInstructionTranslator, "inline_call",
                               patched_inline_call), torch._dynamo.config.patch(
                                   **dynamo_config_patches
+<<<<<<< HEAD
                               ), maybe_use_cudagraph_partition_wrapper(
                                   self.vllm_config):
-                output = self.compiled_callable(*args, **kwargs)
+                from vllm.model_executor.parameter import (
+                    BasevLLMParameter, ModelWeightParameter, RowvLLMParameter,
+                    _ColumnvLLMParameter)
+                with (
+                        torch._dynamo.config.patch(
+                            "traceable_tensor_subclasses", [
+                                BasevLLMParameter, ModelWeightParameter,
+                                _ColumnvLLMParameter, RowvLLMParameter
+                            ]),
+                        patch(
+                            "torch._dynamo.variables.torch.can_dispatch_torch_function",  # noqa: E501
+                            return_false)):
+                    output = self.compiled_callable(*args, **kwargs)
+=======
+                              ), _torch27_patch_tensor_subclasses():

+                output = self.compiled_callable(*args, **kwargs)
+>>>>>>> 9adfb5582 (break out function, gate torch)
             return output

         # usually, capturing the model once is enough, and then we can
@@ -320,6 +338,7 @@ def patched_inline_call(parent, func, args, kwargs):
     return cls


+<<<<<<< HEAD
 @contextlib.contextmanager
 def maybe_use_cudagraph_partition_wrapper(vllm_config: VllmConfig):
     """
@@ -367,3 +386,29 @@ def customized_cudagraph_wrapper(f,
     if (compilation_config.cudagraph_mode != CUDAGraphMode.NONE
             and compilation_config.use_inductor_graph_partition):
         torch._inductor.utils.set_customized_partition_wrappers(None)
+=======
+@contextlib.contextmanager
+def _torch27_patch_tensor_subclasses():
+    from vllm.model_executor.parameter import (BasevLLMParameter,
+                                               ModelWeightParameter,
+                                               RowvLLMParameter,
+                                               _ColumnvLLMParameter)
+
+    def return_false(*args, **kwargs):
+        return False
+
+    if not (version.parse("2.7") <= version.parse(
+            torch.__version__) < version.parse("2.8")):
+        yield
+        return
+
+    with (
+            torch._dynamo.config.patch("traceable_tensor_subclasses", [
+                BasevLLMParameter, ModelWeightParameter, _ColumnvLLMParameter,
+                RowvLLMParameter
+            ]),
+            patch(
+                "torch._dynamo.variables.torch.can_dispatch_torch_function",  # noqa: E501
+                return_false)):
+        yield
+>>>>>>> 9adfb5582 (break out function, gate torch)
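
For reference, the version-gating pattern introduced by `_torch27_patch_tensor_subclasses` can be sketched in isolation. This is a minimal sketch, assuming the intent is to apply the temporary patches only on the torch 2.7 series and to be a no-op elsewhere; the `_Flags` class, its `fast_path` attribute, and the `_gated_patches` name are illustrative stand-ins, not vLLM or torch internals.

```python
import contextlib
from unittest.mock import patch

import torch
from packaging import version


class _Flags:
    # Illustrative stand-in for state that would be patched during tracing.
    fast_path = True


@contextlib.contextmanager
def _gated_patches():
    # Gate on the installed torch version: patch only on 2.7.x.
    torch_ver = version.parse(torch.__version__)
    if not (version.parse("2.7") <= torch_ver < version.parse("2.8")):
        yield  # other torch versions: behave as a no-op context manager
        return

    # The patch is active only for the duration of the with-block.
    with patch.object(_Flags, "fast_path", False):
        yield


with _gated_patches():
    # Prints False on torch 2.7.x, True on any other installed version.
    print(_Flags.fast_path)
```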