@@ -1,13 +1,14 @@
 from __future__ import annotations
 
+import logging
 from typing import Any, Dict, List, Optional, Sequence, Tuple
 
+import tensorrt as trt
 import torch
 from torch.nn import Module
 from torch_tensorrt.fx.utils import Frameworks, unified_dtype_converter
 
-# @manual=//deeplearning/trt/python:py_tensorrt
-import tensorrt as trt
+logger = logging.getLogger(__name__)
 
 
 class PythonTorchTensorRTModule(Module):  # type: ignore[misc]
@@ -22,14 +23,12 @@ def __init__(
         engine: trt.ICudaEngine,
         input_names: Optional[List[str]] = None,
         output_names: Optional[List[str]] = None,
-        cuda_graph_batch_size: int = -1,
     ):
         super(PythonTorchTensorRTModule, self).__init__()
         self._register_state_dict_hook(PythonTorchTensorRTModule._on_state_dict)
         self.engine = engine
         self.input_names = input_names if input_names is not None else []
         self.output_names = output_names if output_names is not None else []
-        self.cuda_graph_batch_size = cuda_graph_batch_size
         self.initialized = False
         self._initialize()
 
@@ -107,7 +106,6 @@ def _on_state_dict(self, state_dict: Dict[str, Any], prefix: str, _: Any) -> Non
         state_dict[prefix + "engine"] = bytearray(self.engine.serialize())
         state_dict[prefix + "input_names"] = self.input_names
         state_dict[prefix + "output_names"] = self.output_names
-        state_dict[prefix + "cuda_graph_batch_size"] = self.cuda_graph_batch_size
 
     def _load_from_state_dict(
         self,
@@ -156,8 +154,6 @@ def forward(self, *inputs: Any) -> torch.Tensor | Tuple[torch.Tensor, ...]:
                     self.input_names
                 ), f"Wrong number of inputs, expect {len(self.input_names)} get {len(inputs)}."
 
-                # This is only used when the trt engine is using implicit batch dim.
-                batch_size = inputs[0].shape[0]
                 contiguous_inputs: List[torch.Tensor] = [i.contiguous() for i in inputs]
                 bindings: List[Any] = [None] * (
                     len(self.input_names)
@@ -166,25 +162,29 @@ def forward(self, *inputs: Any) -> torch.Tensor | Tuple[torch.Tensor, ...]:
                 )
 
                 for i, input_name in enumerate(self.input_names):
-                    assert inputs[
-                        i
-                    ].is_cuda, f"{i}th input({input_name}) is not on cuda device."
+                    if not contiguous_inputs[i].is_cuda:
+                        logger.warning(
+                            f"Detected input {input_name} of engine {self.engine.name} is not on a cuda device. "
+                            "This tensor is being moved by the runtime but for performance considerations, "
+                            "ensure your inputs are all on GPU and open an issue here "
+                            "(https://github.com/pytorch/TensorRT/issues) if this warning persists."
+                        )
+                        contiguous_inputs = (
+                            contiguous_inputs[:i]
+                            + [contiguous_inputs[i].cuda()]
+                            + contiguous_inputs[i + 1 :]
+                        )
+
                     assert (
-                        inputs[i].dtype == self.input_dtypes[i]
-                    ), f"Dtype mismatch for {i}th input({input_name}). Expect {self.input_dtypes[i]}, got {inputs[i].dtype}."
+                        contiguous_inputs[i].dtype == self.input_dtypes[i]
+                    ), f"Dtype mismatch for {i}th input({input_name}). Expect {self.input_dtypes[i]}, got {contiguous_inputs[i].dtype}."
 
                     idx = self.input_binding_indices_in_order[i]
                     bindings[idx] = contiguous_inputs[i].data_ptr()
 
-                    if not self.engine.has_implicit_batch_dimension:
-                        self.context.set_binding_shape(
-                            idx, tuple(contiguous_inputs[i].shape)
-                        )
-                    else:
-                        assert inputs[i].size()[1:] == self.input_shapes[i], (
-                            f"Shape mismatch for {i}th input({input_name}). "
-                            f"Expect {self.input_shapes[i]}, got {inputs[i].size()[1:]}."
-                        )
+                    self.context.set_binding_shape(
+                        idx, tuple(contiguous_inputs[i].shape)
+                    )
 
             with torch.autograd.profiler.record_function(
                 "PythonTorchTensorRTModule:ProcessOutputs"
@@ -193,10 +193,7 @@ def forward(self, *inputs: Any) -> torch.Tensor | Tuple[torch.Tensor, ...]:
                 outputs: List[torch.Tensor] = []
 
                 for i, idx in enumerate(self.output_binding_indices_in_order):
-                    if self.engine.has_implicit_batch_dimension:
-                        shape = (batch_size,) + self.output_shapes[i]
-                    else:
-                        shape = tuple(self.context.get_binding_shape(idx))
+                    shape = tuple(self.context.get_binding_shape(idx))
 
                     output = torch.empty(
                         size=shape,
@@ -207,10 +204,7 @@ def forward(self, *inputs: Any) -> torch.Tensor | Tuple[torch.Tensor, ...]:
                     bindings[idx] = output.data_ptr()
 
                 for i, idx in enumerate(self.hidden_output_binding_indices_in_order):
-                    if self.engine.has_implicit_batch_dimension:
-                        shape = (batch_size,) + self.hidden_output_shapes[i]
-                    else:
-                        shape = tuple(self.context.get_binding_shape(idx))
+                    shape = tuple(self.context.get_binding_shape(idx))
 
                     output = torch.empty(
                         size=shape,
@@ -222,14 +216,9 @@ def forward(self, *inputs: Any) -> torch.Tensor | Tuple[torch.Tensor, ...]:
             with torch.autograd.profiler.record_function(
                 "PythonTorchTensorRTModule:TensorRTRuntime"
             ):
-                if self.engine.has_implicit_batch_dimension:
-                    self.context.execute_async(
-                        batch_size, bindings, torch.cuda.current_stream().cuda_stream
-                    )
-                else:
-                    self.context.execute_async_v2(
-                        bindings, torch.cuda.current_stream().cuda_stream
-                    )
+                self.context.execute_async_v2(
+                    bindings, torch.cuda.current_stream().cuda_stream
+                )
 
             if len(outputs) == 1:
                 return outputs[0]
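A minimal usage sketch of the runtime module after this change: the wrapped engine must now be built with an explicit batch dimension, and inputs left on the CPU are moved to the GPU with a warning instead of tripping an assert. The import path, engine file name, and binding names below are assumptions for illustration, not part of this diff.

```python
import tensorrt as trt
import torch

# Hypothetical import path; adjust to wherever PythonTorchTensorRTModule
# is exposed in your torch_tensorrt installation.
from torch_tensorrt.dynamo.runtime import PythonTorchTensorRTModule

# Deserialize a previously built, explicit-batch TensorRT engine
# ("model.engine" is a placeholder path).
trt_logger = trt.Logger(trt.Logger.WARNING)
with open("model.engine", "rb") as f:
    engine = trt.Runtime(trt_logger).deserialize_cuda_engine(f.read())

module = PythonTorchTensorRTModule(
    engine,
    input_names=["x"],        # assumed input binding name
    output_names=["output"],  # assumed output binding name
)

# CPU tensors are now moved to CUDA by the runtime (with a warning);
# keeping inputs on the GPU avoids the extra copy.
out = module(torch.randn(1, 3, 224, 224).cuda())
```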