1313from torch_tensorrt .dynamo .lowering import (
1414 apply_lowering_passes ,
1515 get_decompositions ,
16+ remove_sym_nodes ,
1617 repair_input_aliasing ,
1718)
1819from torch_tensorrt .dynamo .utils import (
2728@td .register_backend (name = "tensorrt" ) # type: ignore[misc]
2829@td .register_backend (name = "torch_tensorrt" ) # type: ignore[misc]
2930def torch_tensorrt_backend (
30- gm : torch .fx .GraphModule , sample_inputs : Sequence [torch . Tensor ], ** kwargs : Any
31+ gm : torch .fx .GraphModule , sample_inputs : Sequence [Any ], ** kwargs : Any
3132) -> torch .nn .Module :
3233 # Set log level at the top of compilation (torch_tensorrt.dynamo)
3334 if (
@@ -44,15 +45,15 @@ def torch_tensorrt_backend(
4445
4546@td .register_backend (name = "aot_torch_tensorrt_aten" ) # type: ignore[misc]
4647def aot_torch_tensorrt_aten_backend (
47- gm : torch .fx .GraphModule , sample_inputs : Sequence [torch . Tensor ], ** kwargs : Any
48+ gm : torch .fx .GraphModule , sample_inputs : Sequence [Any ], ** kwargs : Any
4849) -> torch .nn .Module :
4950 settings = parse_dynamo_kwargs (kwargs )
5051 return _pretraced_backend (gm , sample_inputs , settings )
5152
5253
5354def _pretraced_backend (
5455 gm : torch .fx .GraphModule ,
55- sample_inputs : Sequence [torch . Tensor ],
56+ sample_inputs : Sequence [Any ],
5657 settings : CompilationSettings = CompilationSettings (),
5758) -> torch .fx .GraphModule | Callable [..., Any ]:
5859 """Helper function to manage translation of traced FX module to TRT engines
@@ -74,10 +75,17 @@ def _pretraced_backend(
7475 fake_mode , "allow_non_fake_inputs" , True
7576 ), fake_mode :
7677 repair_input_aliasing (gm )
78+
79+ # Remove sym_int placeholders and inputs
80+ remove_sym_nodes (gm )
81+ torch_inputs = [
82+ input for input in sample_inputs if isinstance (input , torch .Tensor )
83+ ]
84+
7785 # Invoke AOTAutograd to translate operators to aten
7886 gm = aot_export_joint_simple (
7987 gm ,
80- sample_inputs ,
88+ torch_inputs ,
8189 trace_joint = False ,
8290 decompositions = get_decompositions (
8391 settings .enable_experimental_decompositions
@@ -86,10 +94,10 @@ def _pretraced_backend(
8694
8795 logger .debug ("Post-AOT Autograd graph:\n " + str (gm .graph ))
8896
89- gm = apply_lowering_passes (gm , sample_inputs )
97+ gm = apply_lowering_passes (gm , torch_inputs )
9098
9199 torchtrt_inputs = prepare_inputs (
92- sample_inputs , disable_memory_format_check = True
100+ torch_inputs , disable_memory_format_check = True
93101 )
94102 trt_compiled = compile_module (
95103 gm ,
0 commit comments