diff --git a/core/runtime/TRTEngine.cpp b/core/runtime/TRTEngine.cpp
index 52a9b47c12..135d940c2d 100644
--- a/core/runtime/TRTEngine.cpp
+++ b/core/runtime/TRTEngine.cpp
@@ -237,6 +237,14 @@ TRTEngine::TRTEngine(
     out_binding_names[pyt_idx] = binding_name;
   }
   num_io = std::make_pair(inputs_size, outputs);
+
+  this->current_device_id = at::cuda::current_device();
+  this->stream = c10::cuda::getCurrentCUDAStream(this->current_device_id);
+  this->io_size = this->cuda_engine->getNbIOTensors();
+  for (int64_t i = 0; i < this->in_binding_names.size(); i++) {
+    this->isShapeInferenceIO[this->in_binding_names[i]] =
+        this->cuda_engine->isShapeInferenceIO(this->in_binding_names[i].c_str());
+  }
 }

 #ifndef NDEBUG
@@ -281,6 +289,14 @@ void TRTEngine::enable_profiling() {
   exec_ctx->setProfiler(trt_engine_profiler.get());
 }

+void TRTEngine::set_unowned_output_tensor(bool enable) {
+  this->unowned_output_tensor = enable;
+}
+
+bool TRTEngine::is_unowned_output_tensor() {
+  return this->unowned_output_tensor;
+}
+
 void TRTEngine::set_profile_format(std::string format) {
   if (format == "trex") {
     this->trt_engine_profiler->set_profile_format(TraceFormat::kTREX);
diff --git a/core/runtime/TRTEngine.h b/core/runtime/TRTEngine.h
index 15d723ce4e..c44869c7c8 100644
--- a/core/runtime/TRTEngine.h
+++ b/core/runtime/TRTEngine.h
@@ -103,6 +103,9 @@ struct TRTEngine : torch::CustomClassHolder {
   std::shared_ptr<nvinfer1::ICudaEngine> cuda_engine;
   std::shared_ptr<nvinfer1::IExecutionContext> exec_ctx;
   std::pair<uint64_t, uint64_t> num_io;
+  uint64_t io_size;
+  std::map<std::string, bool> isShapeInferenceIO;
+  bool unowned_output_tensor = false;
   std::string name;
   RTDevice device_info;
@@ -159,6 +162,8 @@ struct TRTEngine : torch::CustomClassHolder {
   int64_t get_automatic_device_memory_budget();
   std::vector<std::string> infer_outputs(std::vector<std::vector<int64_t>> input_shapes);
   void set_pre_allocated_outputs(bool enable);
+  void set_unowned_output_tensor(bool enable);
+  bool is_unowned_output_tensor();
   TorchTRTRuntimeStates runtime_states;
   friend std::ostream& operator<<(std::ostream& os, const TRTEngine& engine);
   static const char BINDING_DELIM = '%';
@@ -169,13 +174,14 @@ struct TRTEngine : torch::CustomClassHolder {
   // CUDAGraph-Related Functionality
   at::cuda::CUDAGraph cudagraph = {};
-  at::cuda::CUDAStream engine_stream = c10::cuda::getDefaultCUDAStream();
-  at::cuda::CUDAStream caller_stream = c10::cuda::getDefaultCUDAStream();
+  at::cuda::CUDAStream stream = c10::cuda::getDefaultCUDAStream();
+  int64_t current_device_id = at::cuda::current_device();
   std::vector<at::Tensor> input_buffers = {};
   std::vector<at::Tensor> output_buffers = {};
   std::string shape_key = "None";
   bool use_pre_allocated_outputs = false;
   std::vector<at::Tensor> pre_allocated_outputs;
+  std::vector<at::Tensor> allocated_outputs;

   // Output Allocator-Related Functionality
   bool requires_output_allocator = false; // engine requires output allocator
diff --git a/core/runtime/execute_engine.cpp b/core/runtime/execute_engine.cpp
index 64b111750f..22d3c6340e 100644
--- a/core/runtime/execute_engine.cpp
+++ b/core/runtime/execute_engine.cpp
@@ -96,7 +96,8 @@ void setup_input_tensors(
     std::vector<at::Tensor> inputs,
     c10::intrusive_ptr<TRTEngine> compiled_engine,
     bool cudagraphs_enabled,
-    bool need_cudagraphs_record) {
+    bool need_cudagraphs_record,
+    bool shape_changed) {
   // this is a buffer to store shape tensor input addresses throughout the runtime scope
   std::list<std::vector<int64_t>> inputShapeTensorValues;
   std::list<at::Tensor> formatted_inputs(compiled_engine->num_io.first);
@@ -117,7 +118,7 @@ void setup_input_tensors(
     auto shape = core::util::toVec(dims);
     LOG_DEBUG("Input Name: " << name << " Shape: " << dims);

-    if
(compiled_engine->cuda_engine->isShapeInferenceIO(name.c_str())) { + if (compiled_engine->isShapeInferenceIO[name]) { // Shape tensor inputs are casted to int64 explicitly. // Refer to // https://github.com/NVIDIA/TensorRT/blob/d2f4ef789a9a6ffdf37b55c3f81b486225f6b380/samples/common/sampleInference.cpp#L435 @@ -145,10 +146,10 @@ void setup_input_tensors( // Create a new persistent input buffer compiled_engine->input_buffers[i] = std::move(formatted_inputs.back().clone()); } - - TORCHTRT_CHECK( - compiled_engine->exec_ctx->setInputShape(name.c_str(), dims), "Error while setting the input shape"); - + if (shape_changed) { + TORCHTRT_CHECK( + compiled_engine->exec_ctx->setInputShape(name.c_str(), dims), "Error while setting the input shape"); + } if (cudagraphs_enabled) { // If using CUDAGraphs copy formatted input to the corresponding persistent input buffer compiled_engine->input_buffers[i].copy_(formatted_inputs.back(), true); @@ -217,7 +218,7 @@ std::vector execute_engine(std::vector inputs, c10::intr compiled_engine->cudagraph.reset(); } - std::vector outputs(compiled_engine->num_io.second); + std::vector outputs; // Intialize inputs and outputs to be available throughout the succeeding scopes { // Input Setup @@ -226,10 +227,9 @@ std::vector execute_engine(std::vector inputs, c10::intr input_profiler_guard = std::make_unique(compiled_engine->input_profile_path); } - - setup_input_tensors(inputs, compiled_engine, cudagraphs_enabled, need_cudagraphs_record); + setup_input_tensors(inputs, compiled_engine, cudagraphs_enabled, need_cudagraphs_record, shape_changed); // Check if input shapes can be inferred. - int32_t const io_size{compiled_engine->cuda_engine->getNbIOTensors()}; + int32_t const io_size{compiled_engine->io_size}; std::vector names(io_size); int32_t const nbNames = compiled_engine->exec_ctx->inferShapes(names.size(), names.data()); TORCHTRT_CHECK( @@ -240,6 +240,7 @@ std::vector execute_engine(std::vector inputs, c10::intr } { // Output Setup + bool new_outputs = false; std::unique_ptr output_profiler_guard; if (compiled_engine->profile_execution) { output_profiler_guard = @@ -248,26 +249,32 @@ std::vector execute_engine(std::vector inputs, c10::intr if (can_use_pre_allocated_outputs) { outputs = compiled_engine->pre_allocated_outputs; } else { - outputs = create_output_tensors(compiled_engine); + if (compiled_engine->allocated_outputs.size() == 0 or compiled_engine->unowned_output_tensor or shape_changed) { + compiled_engine->allocated_outputs = create_output_tensors(compiled_engine); + new_outputs = true; + } + outputs = compiled_engine->allocated_outputs; } - for (auto output_indices : compiled_engine->out_binding_map) { - auto pyt_idx = output_indices.second; - std::string name = compiled_engine->out_binding_names[pyt_idx]; - if (need_cudagraphs_record) { - // If we are recording the cuda graph then we need to update the persistent output buffer - compiled_engine->output_buffers[pyt_idx] = std::move(outputs[pyt_idx].clone()); - } + if (new_outputs) { + for (auto output_indices : compiled_engine->out_binding_map) { + auto pyt_idx = output_indices.second; + std::string name = compiled_engine->out_binding_names[pyt_idx]; + if (need_cudagraphs_record) { + // If we are recording the cuda graph then we need to update the persistent output buffer + compiled_engine->output_buffers[pyt_idx] = std::move(outputs[pyt_idx].clone()); + } - if (cudagraphs_enabled) { - TORCHTRT_CHECK( - compiled_engine->exec_ctx->setTensorAddress( - name.c_str(), 
compiled_engine->output_buffers[pyt_idx].data_ptr()), - "Error while setting the output tensor address"); - } else { - TORCHTRT_CHECK( - compiled_engine->exec_ctx->setTensorAddress(name.c_str(), outputs[pyt_idx].data_ptr()), - "Error while setting the output tensor address"); + if (cudagraphs_enabled) { + TORCHTRT_CHECK( + compiled_engine->exec_ctx->setTensorAddress( + name.c_str(), compiled_engine->output_buffers[pyt_idx].data_ptr()), + "Error while setting the output tensor address"); + } else { + TORCHTRT_CHECK( + compiled_engine->exec_ctx->setTensorAddress(name.c_str(), outputs[pyt_idx].data_ptr()), + "Error while setting the output tensor address"); + } } } } @@ -275,18 +282,12 @@ std::vector execute_engine(std::vector inputs, c10::intr auto current_device_id = -1; if (inputs.size() > 0) { current_device_id = inputs[0].device().index(); // Done this way to avoid a call to cudart - } else if (outputs.size() > 0) { - current_device_id = outputs[0].device().index(); // Done this way to avoid a call to cudart - } - - compiled_engine->caller_stream = c10::cuda::getCurrentCUDAStream(current_device_id); - if (compiled_engine->engine_stream == c10::cuda::getDefaultCUDAStream(current_device_id)) { - // Create a new stream if the engine stream is the default stream - compiled_engine->engine_stream = c10::cuda::getStreamFromPool(false, current_device_id); + if (current_device_id != compiled_engine->current_device_id) { + compiled_engine->stream = c10::cuda::getCurrentCUDAStream(current_device_id); + } } { // Engine Execution (execute on engine stream) - c10::cuda::CUDAStreamGuard stream_guard(compiled_engine->engine_stream); std::unique_ptr enqueue_profiler_guard; if (compiled_engine->profile_execution) { @@ -294,18 +295,13 @@ std::vector execute_engine(std::vector inputs, c10::intr std::make_unique(compiled_engine->enqueue_profile_path); } - // Block engine stream until results are available on caller stream - at::cuda::CUDAEvent caller_exec_complete; - caller_exec_complete.record(compiled_engine->caller_stream); - caller_exec_complete.block(compiled_engine->engine_stream); - if (!cudagraphs_enabled) { // Direct execution uses the caller buffers directly - compiled_engine->exec_ctx->enqueueV3(compiled_engine->engine_stream); + compiled_engine->exec_ctx->enqueueV3(compiled_engine->stream); } else { if (need_cudagraphs_record) { // If cudagraphs needs to record a graph, capture the enqueueV3 call in a graph - c10::cuda::CUDAStream recording_stream = compiled_engine->engine_stream; + c10::cuda::CUDAStream recording_stream = compiled_engine->stream; compiled_engine->cudagraph.capture_begin(); compiled_engine->exec_ctx->enqueueV3(recording_stream); compiled_engine->cudagraph.capture_end(); @@ -325,11 +321,6 @@ std::vector execute_engine(std::vector inputs, c10::intr compiled_engine->pre_allocated_outputs = create_output_tensors(compiled_engine); } - // Block caller stream until engine execution is complete - at::cuda::CUDAEvent trt_exec_complete; - trt_exec_complete.record(compiled_engine->engine_stream); - trt_exec_complete.block(compiled_engine->caller_stream); - if (cudagraphs_enabled) { // If in CUDAGraph mode, results need to be copied to the result buffers (on caller stream) for (size_t o = 0; o < compiled_engine->output_buffers.size(); o++) { @@ -354,7 +345,7 @@ std::vector execute_engine(std::vector inputs, c10::intr std::make_unique(compiled_engine->input_profile_path); } - setup_input_tensors(inputs, compiled_engine, false, false); + setup_input_tensors(inputs, compiled_engine, false, 
false, true); // Check if input shapes can be inferred. int32_t const io_size{compiled_engine->cuda_engine->getNbIOTensors()}; std::vector names(io_size); @@ -378,18 +369,12 @@ std::vector execute_engine(std::vector inputs, c10::intr auto current_device_id = -1; if (inputs.size() > 0) { current_device_id = inputs[0].device().index(); // Done this way to avoid a call to cudart - } else { - current_device_id = at::cuda::current_device(); - } - - compiled_engine->caller_stream = c10::cuda::getCurrentCUDAStream(current_device_id); - if (compiled_engine->engine_stream == c10::cuda::getDefaultCUDAStream(current_device_id)) { - // Create a new stream if the engine stream is the default stream - compiled_engine->engine_stream = c10::cuda::getStreamFromPool(false, current_device_id); + if (current_device_id != compiled_engine->current_device_id) { + compiled_engine->stream = c10::cuda::getCurrentCUDAStream(current_device_id); + } } { // Engine Execution (execute on engine stream) - c10::cuda::CUDAStreamGuard stream_guard(compiled_engine->engine_stream); std::unique_ptr enqueue_profiler_guard; if (compiled_engine->profile_execution) { @@ -397,21 +382,11 @@ std::vector execute_engine(std::vector inputs, c10::intr std::make_unique(compiled_engine->enqueue_profile_path); } - // Block engine stream until results are available on caller stream - at::cuda::CUDAEvent caller_exec_complete; - caller_exec_complete.record(compiled_engine->caller_stream); - caller_exec_complete.block(compiled_engine->engine_stream); - // Direct execution uses the caller buffers directly - compiled_engine->exec_ctx->enqueueV3(compiled_engine->engine_stream); + compiled_engine->exec_ctx->enqueueV3(compiled_engine->stream); } // End engine exeuction (resets to caller stream) - // Block caller stream until engine execution is complete - at::cuda::CUDAEvent trt_exec_complete; - trt_exec_complete.record(compiled_engine->engine_stream); - trt_exec_complete.block(compiled_engine->caller_stream); - std::unique_ptr output_profiler_guard; if (compiled_engine->profile_execution) { output_profiler_guard = diff --git a/core/runtime/register_jit_hooks.cpp b/core/runtime/register_jit_hooks.cpp index 173ff8c35f..36d85481ff 100644 --- a/core/runtime/register_jit_hooks.cpp +++ b/core/runtime/register_jit_hooks.cpp @@ -90,6 +90,8 @@ static auto TORCHTRT_UNUSED TRTEngineTSRegistrtion = .def("get_engine_layer_info", &TRTEngine::get_engine_layer_info) .def("infer_outputs", &TRTEngine::infer_outputs) .def("reset_captured_graph", &TRTEngine::reset_captured_graph) + .def("set_unowned_output_tensor", &TRTEngine::set_unowned_output_tensor) + .def("is_unowned_output_tensor", &TRTEngine::is_unowned_output_tensor) .def_readwrite("use_pre_allocated_outputs", &TRTEngine::use_pre_allocated_outputs) .def_readwrite("use_output_allocator_outputs", &TRTEngine::use_output_allocator_outputs) .def_property( diff --git a/py/torch_tensorrt/dynamo/_compiler.py b/py/torch_tensorrt/dynamo/_compiler.py index bc345947d3..2a52e488f0 100644 --- a/py/torch_tensorrt/dynamo/_compiler.py +++ b/py/torch_tensorrt/dynamo/_compiler.py @@ -877,6 +877,7 @@ def preserve_module_specs( for attr in dir(gm): if attr.startswith("_frozen_param"): delattr(gm, attr) + trt_module = None for name, _ in partitioned_module.named_children(): submodule = getattr(partitioned_module, name) # filter on the GraphModule @@ -997,6 +998,10 @@ def preserve_module_specs( ) as f: f.write(trt_module.get_layer_info()) + # Only set the requires_unique_output flag for the last TRT Module when user has access to the 
output tensor + if trt_module: + trt_module.set_unowned_output_tensor(True) + # Parse the graph I/O and store it in dryrun tracker parse_graph_io(gm, dryrun_tracker) @@ -1339,7 +1344,7 @@ def convert_exported_program_to_serialized_trt_engine( ) flattened_input_list = get_flat_args_with_check( - exported_program, list(trt_arg_inputs), trt_kwarg_inputs # type: ignore + exported_program, list(trt_arg_inputs), trt_kwarg_inputs )[0] try: diff --git a/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py b/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py index d18a5674e0..d54be6f9e8 100644 --- a/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py +++ b/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py @@ -172,8 +172,9 @@ def __init__( self._input_buffers: List[torch.Tensor] = [] self._output_buffers: List[torch.Tensor] = [] self.cudagraph: Optional[torch.cuda.CUDAGraph] = None - self._caller_stream: Optional[torch.cuda.Stream] = None - self._engine_stream: Optional[torch.cuda.Stream] = None + self._engine_stream: torch.cuda.Stream = torch.cuda.current_stream() + self.output_tensors: Optional[List[torch.Tensor]] = None + self.sync_stream = True # TODO: Make the below a Dictionary {shape: cudagraph} self.shape_key: Optional[str] = None @@ -218,9 +219,30 @@ def __init__( self.requires_output_allocator = requires_output_allocator self.output_allocator: Optional[DynamicOutputAllocator] = None self.use_output_allocator_outputs = False - + self.device = torch.cuda.current_device() + self.cudagraphs_enabled = torch_tensorrt.runtime.get_cudagraphs_mode() + # If the output tensor is not owned by the engine (unowned_output_tensor=True), we need to create a new output tensor in each forward pass + self.unowned_output_tensor = False if self.serialized_engine is not None and not self.settings.lazy_engine_init: self.setup_engine() + self.is_shape_inference_io = { + input_name: self.engine.is_shape_inference_io(input_name) + for input_name in self.input_names + } + + def set_unowned_output_tensor(self, enabled: bool) -> None: + """ + Set the flag to indicate if the output tensor is unowned by the engine. + If self.unowned_output_tensor=True, the engine will create a new output tensor in each forward pass. + This would be slower but is required when users need to manipulate the output tensor after each forward pass. + Therefore, this should be set to True only for the last module in a graph and leave to False for intermediate modules, + which users don't have access to. + Args: + enabled: bool + Whether to set the flag to True. 
+ + """ + self.unowned_output_tensor = enabled def get_streamable_device_memory_budget(self) -> Any: return self.engine.streamable_weights_size @@ -263,6 +285,9 @@ def setup_engine(self) -> None: assert ( self.target_platform == Platform.current_platform() ), f"TensorRT engine was not built to target current platform (target: {self.target_platform}, current: {Platform.current_platform()})" + # Stream handling: if the caller stream is the pytorch default stream, create a new engine stream + # otherwise, use the caller stream and disable stream synchronization + self._engine_stream = torch.cuda.current_stream() self.initialized = True runtime = trt.Runtime(TRT_LOGGER) @@ -287,10 +312,14 @@ def setup_engine(self) -> None: for output_name in self.output_names ] self.output_shapes = [ - self.engine.get_tensor_shape(output_name) + tuple(self.context.get_tensor_shape(output_name)) for output_name in self.output_names ] + self.shape_key = "".join( + str(tuple(t)).replace(" ", "") for t in self.input_shapes + ) + if self.requires_output_allocator: self.create_output_allocator() @@ -356,6 +385,7 @@ def setup_input_tensors( contiguous_inputs: List[torch.Tensor], cudagraphs_enabled: bool, need_cudagraphs_record: bool, + shape_changed: bool = True, ) -> None: for i, input_name in enumerate(self.input_names): if not contiguous_inputs[i].is_cuda: @@ -382,16 +412,17 @@ def setup_input_tensors( # For shape tensors, we use CPU pointers and for data tensors, we use GPU pointers # as per TensorRT requirements - if self.engine.is_shape_inference_io(input_name): + if self.is_shape_inference_io[input_name]: # Shape tensor inputs are casted to int64 explicitly # Currently Torch CPU pointers are not working; numpy pointers are used instead # to refer to underlying memory inputs_cpu = contiguous_inputs[i].cpu().to(torch.int64).numpy().copy() self.context.set_tensor_address(input_name, inputs_cpu.ctypes.data) else: - self.context.set_input_shape( - input_name, tuple(contiguous_inputs[i].shape) - ) + if shape_changed: + self.context.set_input_shape( + input_name, tuple(contiguous_inputs[i].shape) + ) if cudagraphs_enabled: self._input_buffers[i].copy_(contiguous_inputs[i]) self.context.set_tensor_address( @@ -410,7 +441,7 @@ def create_output_tensors(self) -> List[torch.Tensor]: output = torch.empty( size=self.output_shapes[o], dtype=self.output_dtypes[o], - device=torch.cuda.current_device(), + device=self.device, ) outputs.append(output) return outputs @@ -459,7 +490,11 @@ def run_standard_execution() -> torch.Tensor | Tuple[torch.Tensor, ...]: ), f"Wrong number of inputs, expect {len(self.input_names)} get {len(contiguous_inputs)}." self.setup_input_tensors( - contiguous_inputs, self.cudagraphs_enabled, need_cudagraphs_record + contiguous_inputs, + self.cudagraphs_enabled, + need_cudagraphs_record, + shape_changed + or self.output_tensors is None, # First time execution ) if shape_changed: @@ -481,15 +516,22 @@ def run_standard_execution() -> torch.Tensor | Tuple[torch.Tensor, ...]: if can_use_pre_allocated_outputs: outputs = self.pre_allocated_outputs else: - self.output_shapes = [ - tuple(self.context.get_tensor_shape(output_name)) - for output_name in self.output_names - ] + if shape_changed: + self.output_shapes = [ + tuple(self.context.get_tensor_shape(output_name)) + for output_name in self.output_names + ] if DYNAMIC_DIM in self.output_shapes: raise ValueError( "Encountered dynamic output shapes during runtime. This could mean the network has data-dependent output shapes which is not currently supported." 
                    )
-                    outputs = self.create_output_tensors()
+                    if (
+                        self.output_tensors is None
+                        or self.unowned_output_tensor
+                        or shape_changed
+                    ):
+                        self.output_tensors = self.create_output_tensors()
+                    outputs = self.output_tensors

                 for o, output_name in enumerate(self.output_names):
                     if need_cudagraphs_record:
@@ -511,41 +553,46 @@ def run_standard_execution() -> torch.Tensor | Tuple[torch.Tensor, ...]:
                 if self.profiling_enabled
                 else nullcontext()
             ):
-                self._caller_stream = torch.cuda.current_stream()
-                if (
-                    self._engine_stream == torch.cuda.default_stream()
-                    or self._engine_stream is None
-                ):
-                    self._engine_stream = torch.cuda.Stream()
-                self._engine_stream.wait_stream(self._caller_stream)
-
-                with torch.cuda.stream(self._engine_stream):
-                    if self.cudagraphs_enabled:
-                        if need_cudagraphs_record:
-                            self.cudagraph = torch.cuda.CUDAGraph()
+                if self.cudagraphs_enabled:
+                    if need_cudagraphs_record:
+                        self.cudagraph = torch.cuda.CUDAGraph()

-                        if self.profiling_enabled:
-                            self.cudagraph.enable_debug_mode()
+                        if self.profiling_enabled:
+                            self.cudagraph.enable_debug_mode()

-                        with torch.cuda.graph(
-                            self.cudagraph, stream=self._engine_stream
-                        ):
-                            self.context.execute_async_v3(
-                                self._engine_stream.cuda_stream
-                            )
+                        with torch.cuda.graph(
+                            self.cudagraph, stream=self._engine_stream
+                        ):
+                            self.context.execute_async_v3(
+                                self._engine_stream.cuda_stream
+                            )

                         if self.profiling_enabled:
                             self.cudagraph.debug_dump(
                                 f"{DEBUG_LOGGING_DIR}/{self.name}_cudagraph.dot"
                             )
+                    self.cudagraph.replay()  # type: ignore

-                        self.cudagraph.replay()  # type: ignore
-
-                    else:
-                        self.context.execute_async_v3(self._engine_stream.cuda_stream)
-
-                self._caller_stream.wait_stream(self._engine_stream)
+                else:
+                    self.context.execute_async_v3(self._engine_stream.cuda_stream)

             if self.use_pre_allocated_outputs:
                 self.pre_allocated_outputs = self.create_output_tensors()
@@ -600,22 +647,12 @@ def run_output_allocator() -> torch.Tensor | Tuple[torch.Tensor, ...]:
                 if self.profiling_enabled
                 else nullcontext()
             ):
-                self._caller_stream = torch.cuda.current_stream()
-                if (
-                    self._engine_stream == torch.cuda.default_stream()
-                    or self._engine_stream is None
-                ):
-                    self._engine_stream = torch.cuda.Stream()
-
-                self._engine_stream.wait_stream(self._caller_stream)

                 with torch.cuda.stream(self._engine_stream):
                     self.context.execute_async_v3(
                         self._engine_stream.cuda_stream
                     )  # The OutputAllocator is called by execute_async_v3()

-                self._caller_stream.wait_stream(self._engine_stream)
-
             with (
                 torch.autograd.profiler.record_function(
                     "PythonTorchTensorRTModule:ProcessOutputs"
@@ -644,8 +681,6 @@ def run_output_allocator() -> torch.Tensor | Tuple[torch.Tensor, ...]:

             return outputs

-        self.cudagraphs_enabled = torch_tensorrt.runtime.get_cudagraphs_mode()
-
         # Run forward function
         contiguous_inputs: List[torch.Tensor] = [
             (i.contiguous() if isinstance(i, torch.Tensor) else torch.tensor(i).cuda())
@@ -750,13 +785,13 @@ def validate_input_shapes(self, inputs: Sequence[torch.Tensor]) -> bool:
         # Representation of input shapes to a given model
         # Shapes are concatenated as so:
         # x: (3, 4), y: (4, 5) --> Key: (3,4)(4,5)
-        tensor_inputs = []
-        for t in inputs:
-            if not isinstance(t, torch.Tensor):
-                return True
-            tensor_inputs.append(t)
+        if not all(isinstance(t, torch.Tensor) for t in inputs):
+            return True
+
         new_shape_key = "".join(
-            str(tuple(t.shape)).replace(" ", "") for t in tensor_inputs
+            str(tuple(t.shape)).replace(" ", "")
+            for t in inputs
+            if isinstance(t, torch.Tensor)
         )

         # If the new shape key differs from the existing one,
diff --git a/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py b/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py
index 95f1581881..2a7d22e312 100644
--- a/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py
+++ b/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py
@@ -156,6 +156,11 @@ def _pack_engine_info(self) -> List[str | bytes]:
         metadata = {
             "settings": self.settings,
             "weight_name_map": self.weight_name_map,
+            "unowned_output_tensor": (
+                False
+                if self.engine is None
+                else self.engine.is_unowned_output_tensor()
+            ),
         }
         target_platform = (
             Platform.current_platform()
@@ -284,6 +289,8 @@ def set_extra_state(self, state: SerializedTorchTensorRTModuleFmt) -> None:
             metadata = TorchTensorRTModule.decode_metadata(serialized_metadata)
             self.settings = metadata["settings"]
             self.weight_name_map = metadata["weight_name_map"]
+            self.unowned_output_tensor = metadata["unowned_output_tensor"]
+            self.engine.set_unowned_output_tensor(self.unowned_output_tensor)
         else:
             self.engine = None
@@ -355,6 +362,12 @@ def enable_profiling(
         self.engine.enable_profiling()
         self.engine.set_profile_format(profile_format)

+    def set_unowned_output_tensor(self, enabled: bool) -> None:
+        self.engine.set_unowned_output_tensor(enabled)
+
+    def is_unowned_output_tensor(self) -> bool:
+        return self.engine.is_unowned_output_tensor()  # type: ignore[no-any-return]
+
     def disable_profiling(self) -> None:
         """Disable the profiler"""
         if self.engine is None:
diff --git a/setup.py b/setup.py
index d151e0df0c..bde77cef3d 100644
--- a/setup.py
+++ b/setup.py
@@ -195,10 +195,10 @@ def build_libtorchtrt_cxx11_abi(
     else:
         cmd.append("//:libtorchtrt")

-    if develop:
-        cmd.append("--compilation_mode=dbg")
-    else:
-        cmd.append("--compilation_mode=opt")
+    # if develop:
+    #     cmd.append("--compilation_mode=dbg")
+    # else:
+    cmd.append("--compilation_mode=opt")
     if use_dist_dir:
         if IS_AARCH64:
             cmd.append("--distdir=third_party/dist_dir/aarch64-linux-gnu")
diff --git a/tools/perf/graph_break_overhead/graph_break.py b/tools/perf/graph_break_overhead/graph_break.py
new file mode 100644
index 0000000000..c7e5129527
--- /dev/null
+++ b/tools/perf/graph_break_overhead/graph_break.py
@@ -0,0 +1,150 @@
+import time
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch_tensorrt as torchtrt
+import torchvision
+from pyinstrument import Profiler
+from torch_tensorrt.dynamo.utils import get_model_device
+
+torch.manual_seed(0)
+torch.cuda.manual_seed_all(0)
+import argparse
+
+
+def benchmark_model(model, input, label, profile=False):
+    if profile:
+        profiler = Profiler(interval=0.01)
+        profiler.start()
+    start_time = time.time()
+    for _ in range(1000):
+        model_outputs = model(*input)
+    end_time = time.time()
+    print(f"{label} 1000 runs: {end_time - start_time:.4f} seconds")
+    if profile:
+        profiler.stop()
+        profiler.write_html(
+            f"/home/other/{label.replace(' ', '_')}.html", timeline=False, show_all=True
+        )
+
+
+def main(args):
+    profile = args.profile
+    use_python_runtime = args.use_python_runtime
+    model_name = args.model
+
+    with torchtrt.dynamo.Debugger(log_level="debug", engine_builder_monitor=False):
+
+        model = (
+            torchvision.models.detection.ssd300_vgg16(pretrained=True).eval().to("cuda")
+        )
+        input = [torch.randn((1, 3, 224,
224)).to("cuda")] + + BATCH = torch.export.Dim("BATCH", min=1, max=16) + exp_program = torch.export.export(model, tuple(input), strict=True) + trt_mod2 = trt_gm = torchtrt.dynamo.compile( + exp_program, + tuple(input), + use_python_runtime=use_python_runtime, + enabled_precisions={torch.float}, + min_block_size=1, + immutable_weights=False, + reuse_cached_engines=False, + ) + + trt_mod1 = trt_gm = torchtrt.dynamo.compile( + exp_program, + tuple(input), + use_python_runtime=use_python_runtime, + enabled_precisions={torch.float}, + min_block_size=1, + immutable_weights=False, + torch_executed_ops={torch.ops.aten.relu.default}, + reuse_cached_engines=False, + ) + + # AOTI + if not use_python_runtime: + torchtrt.save( + trt_mod1, + "/home/other/aoti.pt2", + output_format="aot_inductor", + inputs=input, + retrace=True, + ) + aoti_model_gb = torch._inductor.aoti_load_package("/home/other/aoti.pt2") + torchtrt.save( + trt_mod2, + "/home/other/aoti_no_gb.pt2", + output_format="aot_inductor", + inputs=input, + retrace=True, + ) + aoti_model_no_gb = torch._inductor.aoti_load_package( + "/home/other/aoti_no_gb.pt2" + ) + + # Warmup runs to avoid measuring first-run overheads + for _ in range(100): + trt_mod2(*input) + model(*input) + if not use_python_runtime: + aoti_model_gb(*input) + aoti_model_no_gb(*input) + + time.sleep(1) + benchmark_model(trt_mod1, input, "trt_mod1 (with graph break)", profile=profile) + benchmark_model(trt_mod2, input, "trt_mod2 (without graph break)", profile=profile) + if not use_python_runtime: + benchmark_model(aoti_model_gb, input, "aoti_model_gb", profile=profile) + benchmark_model(aoti_model_no_gb, input, "aoti_model_no_gb", profile=profile) + + out1 = trt_mod1(*input) + out2 = trt_mod2(*input) + if not use_python_runtime: + out3 = aoti_model_gb(*input) + out4 = aoti_model_no_gb(*input) + + def _to_tuple(x): + if isinstance(x, (tuple, list)): + return tuple(x) + return (x,) + + outs1 = _to_tuple(out1) + outs2 = _to_tuple(out2) + if not use_python_runtime: + outs3 = _to_tuple(out3) + outs4 = _to_tuple(out4) + + def compare_outputs(a, b, name1="A", name2="B"): + if len(a) != len(b): + print(f"Number of outputs differ: {len(a)} vs {len(b)}") + return False + all_equal = True + for i, (x, y) in enumerate(zip(a, b)): + if not torch.allclose(x, y, atol=1e-3, rtol=1e-3): + print(f"Output {i} differs between {name1} and {name2}") + print(f"max diff: {torch.max(torch.abs(x - y))}") + print(f"Mean diff: {torch.mean(torch.abs(x - y))}") + all_equal = False + if all_equal: + print(f"All outputs match between {name1} and {name2}") + return all_equal + + compare_outputs(outs1, outs2, "trt_mod1", "trt_mod2") + if not use_python_runtime: + compare_outputs(outs1, outs3, "trt_mod1", "aoti_model_gb") + compare_outputs(outs1, outs4, "trt_mod1", "aoti_model_no_gb") + compare_outputs(outs2, outs3, "trt_mod2", "aoti_model") + + +if __name__ == "__main__": + arg_parser = argparse.ArgumentParser() + arg_parser.add_argument("--profile", action="store_true") + arg_parser.add_argument("--use_python_runtime", action="store_true") + arg_parser.add_argument( + "--model", type=str, default="resnet18", choices=["resnet18", "resnet152"] + ) + args = arg_parser.parse_args() + main(args)