Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions core/runtime/TRTEngine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -453,6 +453,10 @@ std::vector<std::string> TRTEngine::serialize() {
return serialized_info;
}

void TRTEngine::reset_captured_graph() {
cudagraph.reset();
}

} // namespace runtime
} // namespace core
} // namespace torch_tensorrt
1 change: 1 addition & 0 deletions core/runtime/TRTEngine.h
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,7 @@ struct TRTEngine : torch::CustomClassHolder {
// c10::List<at::Tensor> Run(c10::List<at::Tensor> inputs);

void set_profiling_paths();
void reset_captured_graph();
#ifndef NDEBUG
bool profile_execution = true;
#else
Expand Down
1 change: 1 addition & 0 deletions core/runtime/register_jit_hooks.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ static auto TORCHTRT_UNUSED TRTEngineTSRegistrtion =
.def("dump_engine_layer_info", &TRTEngine::dump_engine_layer_info)
.def("get_engine_layer_info", &TRTEngine::get_engine_layer_info)
.def("infer_outputs", &TRTEngine::infer_outputs)
.def("reset_captured_graph", &TRTEngine::reset_captured_graph)
.def_readwrite("use_pre_allocated_outputs", &TRTEngine::use_pre_allocated_outputs)
.def_readwrite("use_output_allocator_outputs", &TRTEngine::use_output_allocator_outputs)
.def_property(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -103,9 +103,13 @@ def validate_input_shapes(self, inputs: Sequence[torch.Tensor]) -> bool:

return False

def __del__(self) -> None:
def reset_captured_graph(self) -> None:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

_reset_captured_graph

if self.cudagraph:
self.cudagraph.reset()
self.cudagraph = None

def __del__(self) -> None:
self.reset_captured_graph()

def set_use_output_allocator(self, enable: bool) -> None:
self.use_output_allocator_outputs = enable
Expand All @@ -119,8 +123,7 @@ def forward(
shape_changed = self.validate_input_shapes(inputs)
need_cudagraphs_record = shape_changed or self.is_weight_streaming_set
if need_cudagraphs_record:
if self.cudagraph:
self.cudagraph.reset()
self.reset_captured_graph()
self._input_buffers = [None] * len(inputs)

self.is_weight_streaming_set = False
Expand Down Expand Up @@ -196,7 +199,5 @@ def forward(
return outputs[0]
return outputs
else:
if self.cudagraph:
self.cudagraph.reset()
self.cudagraph = None
self.reset_captured_graph()
return self.compiled_module(*args, **kwargs)
11 changes: 7 additions & 4 deletions py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py
Original file line number Diff line number Diff line change
Expand Up @@ -333,9 +333,13 @@ def __deepcopy__(self, memo: Any) -> PythonTorchTensorRTModule:
result.__setstate__(self.__getstate__())
return result

def __del__(self) -> None:
def reset_captured_graph(self) -> None:
if self.cudagraph:
self.cudagraph.reset()
self.cudagraph = None

def __del__(self) -> None:
self.reset_captured_graph()

def setup_input_tensors(
self,
Expand Down Expand Up @@ -426,9 +430,8 @@ def run_standard_execution() -> torch.Tensor | Tuple[torch.Tensor, ...]:
self.cudagraphs_enabled, self.use_pre_allocated_outputs, shape_changed
)

if need_cudagraphs_reset and self.cudagraph:
self.cudagraph.reset()
self.cudagraph = None
if need_cudagraphs_reset:
self.reset_captured_graph()

if need_cudagraphs_record:
self._input_buffers = [None] * len(self.input_names)
Expand Down
3 changes: 3 additions & 0 deletions py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,9 @@ def set_device_memory_budget(self, budget_bytes: int) -> int:

return budget_bytes

def reset_captured_graph(self) -> None:
self.engine.reset_captured_graph()

def setup_engine(self) -> None:
"""
Setup engine for a module which has deferred engine setup.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,9 @@ def automatic_device_memory_budget_getter(self) -> Any:
def infer_outputs(self, input_shapes: List[Any]) -> Any:
pass

def reset_captured_graph(self) -> Any:
pass

def __setstate__(self, serialized_state: List[str]) -> Any:
pass

Expand Down
9 changes: 7 additions & 2 deletions py/torch_tensorrt/runtime/_cudagraphs.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import logging
from typing import Any, Union
from typing import Any, Optional, Union

import torch
import torch_tensorrt
Expand Down Expand Up @@ -68,6 +68,7 @@ def __init__(self, compiled_module: torch.nn.Module) -> None:
global _PY_RT_CUDAGRAPHS
self.old_mode = _PY_RT_CUDAGRAPHS
self.compiled_module = compiled_module
self.cudagraphs_module: Optional[CudaGraphsTorchTensorRTModule] = None

def __enter__(self) -> torch.nn.Module:
global _PY_RT_CUDAGRAPHS
Expand Down Expand Up @@ -98,7 +99,8 @@ def __enter__(self) -> torch.nn.Module:
logger.debug(
"Found pytorch subgraphs in module, wrapping module in CudaGraphsTorchTensorRTModule"
)
return CudaGraphsTorchTensorRTModule(self.compiled_module)
self.cudagraphs_module = CudaGraphsTorchTensorRTModule(self.compiled_module)
return self.cudagraphs_module
else:
if num_trt_module > 0:
logger.debug("No graph breaks detected, using runtime cudagraphs mode")
Expand All @@ -113,6 +115,9 @@ def __enter__(self) -> torch.nn.Module:
def __exit__(self, *args: Any) -> None:
# Set cudagraphs back to old mode
set_cudagraphs_mode(self.old_mode)
# __del__ is not entirely predictable, so we reset cudagraph here
if self.cudagraphs_module:
self.cudagraphs_module.reset_captured_graph()


def enable_cudagraphs(
Expand Down
7 changes: 5 additions & 2 deletions py/torch_tensorrt/runtime/_weight_streaming.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,12 +76,15 @@ def _set_streamable_weight_bytes(self, requested_budget: int) -> int:
int(streamable_bytes / total_bytes * requested_budget)
for streamable_bytes in self.streamable_budget
]
if self.cuda_graphs_module:
self.cuda_graphs_module.is_weight_streaming_set = True
self.cuda_graphs_module.reset_captured_graph()

for i, (name, rt_mod) in enumerate(self.rt_mods):
rt_mod.reset_captured_graph()
ws_budget_bytes += rt_mod.set_device_memory_budget(normalized_size[i])
logger.debug(f"Set weight streaming size {normalized_size[i]} for {name}")

if self.cuda_graphs_module:
self.cuda_graphs_module.is_weight_streaming_set = True
return ws_budget_bytes

def __setattr__(self, name: str, value: Any) -> None:
Expand Down
Loading