 from __future__ import annotations

+import logging
 from typing import Any, Dict, List, Optional, Sequence, Tuple

+# @manual=//deeplearning/trt/python:py_tensorrt
+import tensorrt as trt
 import torch
 from torch.nn import Module
 from torch_tensorrt.fx.utils import Frameworks, unified_dtype_converter

-# @manual=//deeplearning/trt/python:py_tensorrt
-import tensorrt as trt
+logger = logging.getLogger(__name__)


 class PythonTorchTensorRTModule(Module):  # type: ignore[misc]
@@ -22,14 +24,12 @@ def __init__(
         engine: trt.ICudaEngine,
         input_names: Optional[List[str]] = None,
         output_names: Optional[List[str]] = None,
-        cuda_graph_batch_size: int = -1,
     ):
         super(PythonTorchTensorRTModule, self).__init__()
         self._register_state_dict_hook(PythonTorchTensorRTModule._on_state_dict)
         self.engine = engine
         self.input_names = input_names if input_names is not None else []
         self.output_names = output_names if output_names is not None else []
-        self.cuda_graph_batch_size = cuda_graph_batch_size
         self.initialized = False
         self._initialize()

@@ -107,7 +107,6 @@ def _on_state_dict(self, state_dict: Dict[str, Any], prefix: str, _: Any) -> Non
         state_dict[prefix + "engine"] = bytearray(self.engine.serialize())
         state_dict[prefix + "input_names"] = self.input_names
         state_dict[prefix + "output_names"] = self.output_names
-        state_dict[prefix + "cuda_graph_batch_size"] = self.cuda_graph_batch_size

     def _load_from_state_dict(
         self,
@@ -156,8 +155,6 @@ def forward(self, *inputs: Any) -> torch.Tensor | Tuple[torch.Tensor, ...]:
                     self.input_names
                 ), f"Wrong number of inputs, expect {len(self.input_names)} get {len(inputs)}."

-                # This is only used when the trt engine is using implicit batch dim.
-                batch_size = inputs[0].shape[0]
                 contiguous_inputs: List[torch.Tensor] = [i.contiguous() for i in inputs]
                 bindings: List[Any] = [None] * (
                     len(self.input_names)
@@ -166,37 +163,34 @@ def forward(self, *inputs: Any) -> torch.Tensor | Tuple[torch.Tensor, ...]:
                 )

                 for i, input_name in enumerate(self.input_names):
-                    assert inputs[
-                        i
-                    ].is_cuda, f"{i}th input({input_name}) is not on cuda device."
+                    if not contiguous_inputs[i].is_cuda:
+                        logger.warning(
+                            f"Detected input {input_name} of engine {self.engine.name} is not on a cuda device. "
+                            "This tensor is being moved by the runtime but for performance considerations, "
+                            "ensure your inputs are all on GPU and open an issue here "
+                            "(https://github.com/pytorch/TensorRT/issues) if this warning persists."
+                        )
+                        contiguous_inputs = (
+                            contiguous_inputs[:i]
+                            + [contiguous_inputs[i].cuda()]
+                            + contiguous_inputs[i + 1 :]
+                        )
+
                     assert (
-                        inputs[i].dtype == self.input_dtypes[i]
-                    ), f"Dtype mismatch for {i}th input({input_name}). Expect {self.input_dtypes[i]}, got {inputs[i].dtype}."
+                        contiguous_inputs[i].dtype == self.input_dtypes[i]
+                    ), f"Dtype mismatch for {i}th input({input_name}). Expect {self.input_dtypes[i]}, got {contiguous_inputs[i].dtype}."

                     idx = self.input_binding_indices_in_order[i]
                     bindings[idx] = contiguous_inputs[i].data_ptr()

-                    if not self.engine.has_implicit_batch_dimension:
-                        self.context.set_binding_shape(
-                            idx, tuple(contiguous_inputs[i].shape)
-                        )
-                    else:
-                        assert inputs[i].size()[1:] == self.input_shapes[i], (
-                            f"Shape mismatch for {i}th input({input_name}). "
-                            f"Expect {self.input_shapes[i]}, got {inputs[i].size()[1:]}."
-                        )
-
             with torch.autograd.profiler.record_function(
                 "PythonTorchTensorRTModule:ProcessOutputs"
             ):
                 # create output tensors
                 outputs: List[torch.Tensor] = []

                 for i, idx in enumerate(self.output_binding_indices_in_order):
-                    if self.engine.has_implicit_batch_dimension:
-                        shape = (batch_size,) + self.output_shapes[i]
-                    else:
-                        shape = tuple(self.context.get_binding_shape(idx))
+                    shape = tuple(self.context.get_binding_shape(idx))

                     output = torch.empty(
                         size=shape,
@@ -207,10 +201,7 @@ def forward(self, *inputs: Any) -> torch.Tensor | Tuple[torch.Tensor, ...]:
                     bindings[idx] = output.data_ptr()

                 for i, idx in enumerate(self.hidden_output_binding_indices_in_order):
-                    if self.engine.has_implicit_batch_dimension:
-                        shape = (batch_size,) + self.hidden_output_shapes[i]
-                    else:
-                        shape = tuple(self.context.get_binding_shape(idx))
+                    shape = tuple(self.context.get_binding_shape(idx))

                     output = torch.empty(
                         size=shape,
@@ -222,14 +213,9 @@ def forward(self, *inputs: Any) -> torch.Tensor | Tuple[torch.Tensor, ...]:
             with torch.autograd.profiler.record_function(
                 "PythonTorchTensorRTModule:TensorRTRuntime"
             ):
-                if self.engine.has_implicit_batch_dimension:
-                    self.context.execute_async(
-                        batch_size, bindings, torch.cuda.current_stream().cuda_stream
-                    )
-                else:
-                    self.context.execute_async_v2(
-                        bindings, torch.cuda.current_stream().cuda_stream
-                    )
+                self.context.execute_async_v2(
+                    bindings, torch.cuda.current_stream().cuda_stream
+                )

             if len(outputs) == 1:
                 return outputs[0]
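
For reference, a minimal usage sketch of the module after this change. The engine file name, binding names, and deserialization boilerplate below are illustrative assumptions and are not part of this diff; only the `PythonTorchTensorRTModule(engine, input_names, output_names)` signature, the CPU-input warning with the `.cuda()` fallback, and the unconditional `execute_async_v2` path come from the code above.

```python
# Hypothetical sketch of driving PythonTorchTensorRTModule with a prebuilt engine.
import tensorrt as trt
import torch

# Deserialize a previously built TensorRT engine (path is an assumption).
trt_logger = trt.Logger(trt.Logger.WARNING)
runtime = trt.Runtime(trt_logger)
with open("sample_engine.trt", "rb") as f:
    engine = runtime.deserialize_cuda_engine(f.read())

# Binding names are hypothetical; they must match the engine's bindings.
trt_mod = PythonTorchTensorRTModule(
    engine,
    input_names=["x"],
    output_names=["output0"],
)

# A CPU input now only triggers a logged warning plus an implicit .cuda() copy
# instead of an assertion failure; keeping inputs on GPU avoids that copy.
x = torch.randn(1, 3, 224, 224).cuda()
out = trt_mod(x)
```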