diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py index 826d1744d88a5..4f960b441f21d 100644 --- a/vllm/compilation/backends.py +++ b/vllm/compilation/backends.py @@ -208,8 +208,8 @@ def wrap_inductor(graph: fx.GraphModule, from torch._inductor.compile_fx import graph_returns_tuple returns_tuple = graph_returns_tuple(graph) - # this is the graph we return to Dynamo to run - def compiled_graph(*args) -> Optional[fx.CompiledFxGraph]: + # this is the callable we return to Dynamo to run + def compiled_graph(*args): # convert args to list list_args = list(args) graph_output = inductor_compiled_graph(list_args) @@ -537,7 +537,8 @@ def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable: example_inputs[x].clone() for x in self.sym_tensor_indices ] - def copy_and_call(*args) -> fx.GraphModule: + # this is the callable we return to Dynamo to run + def copy_and_call(*args): list_args = list(args) for i, index in enumerate(self.sym_tensor_indices): runtime_tensor = list_args[index]