core/runtime/TRTEngine.cpp (4 additions, 0 deletions)

@@ -101,6 +101,7 @@ TRTEngine::TRTEngine(

   runtime_states.old_cudagraphs = CUDAGRAPHS_MODE;
   runtime_states.old_pre_allocated_outputs = false;
+  runtime_states.context_changed = false;
 
   if (_in_binding_names.size() == 0 && _out_binding_names.size() == 0) {
     uint64_t inputs = 0;
@@ -310,6 +311,9 @@ bool TRTEngine::set_device_memory_budget(int64_t budget) {
   if (profile_execution) {
     enable_profiling();
   }
+  // Indicates to reevaluate the runtime settings
+  runtime_states.context_changed = true;
+
   return result;
 }

core/runtime/TRTEngine.h (16 additions, 4 deletions)

@@ -35,25 +35,37 @@ struct TorchTRTRuntimeStates {
   bool old_cudagraphs;
   // Indicates whether pre-allocated output was enabled in the previous execute_engine
   bool old_pre_allocated_outputs;
+  // Indicates whether context has changed
+  bool context_changed;
 
-  // Evaluates whether certain conditions are met to enable CUDA Graph recording or to reuse pre-allocated outputs
+  // Evaluates whether certain conditions are met to enable CUDA Graph recording/reset or to reuse pre-allocated outputs
   // based on the current and previous states, as well as input shape has changed
-  std::tuple<bool, bool> set_runtime_states(bool new_cudagraphs, bool new_pre_allocated_output, bool shape_changed) {
+  std::tuple<bool, bool, bool> set_runtime_states(
+      bool new_cudagraphs,
+      bool new_pre_allocated_output,
+      bool shape_changed) {
     bool need_cudagraphs_record = false;
     bool can_use_pre_allocated_outputs = false;
+    bool need_cudagraphs_reset = false;
 
     // Cudagraphs record is required if cudagraphs_enabled is switched to True regardless of shape change
-    if (new_cudagraphs && (!old_cudagraphs || shape_changed)) {
+    if (new_cudagraphs && (!old_cudagraphs || shape_changed || context_changed)) {
       need_cudagraphs_record = true;
     }
     // Pre-allocated output can be used when previous and current state are true without shape change
     if (old_pre_allocated_outputs && new_pre_allocated_output && !shape_changed) {
       can_use_pre_allocated_outputs = true;
     }
+    if (!new_cudagraphs || shape_changed || context_changed) {
+      need_cudagraphs_reset = true;
+    }
 
     old_cudagraphs = new_cudagraphs;
     old_pre_allocated_outputs = new_pre_allocated_output;
+    // Reset flag
+    context_changed = false;
 
-    return {need_cudagraphs_record, can_use_pre_allocated_outputs};
+    return {need_cudagraphs_record, can_use_pre_allocated_outputs, need_cudagraphs_reset};
   }
 };

core/runtime/execute_engine.cpp (2 additions, 1 deletion)

@@ -211,8 +211,9 @@ std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intr

   bool need_cudagraphs_record = std::get<0>(result);
   bool can_use_pre_allocated_outputs = std::get<1>(result);
+  bool need_cudagraphs_reset = std::get<2>(result);
 
-  if (!cudagraphs_enabled || shape_changed) {
+  if (need_cudagraphs_reset) {
     compiled_engine->cudagraph.reset();
   }

CudaGraphsTorchTensorRTModule (file path not captured in this view)

@@ -27,12 +27,12 @@ def __init__(
         super(CudaGraphsTorchTensorRTModule, self).__init__()
         self.compiled_module = compiled_module
         self.inputs = partitioning.construct_submodule_inputs(compiled_module)
+        self.is_weight_streaming_set = False
 
         self._input_buffers: List[torch.Tensor] = []
         self._output_buffers: List[torch.Tensor] = []
         self.cudagraph: Optional[torch.cuda.CUDAGraph] = None
         self.shape_key: Optional[str] = None
-        self.prev_cudagraphs_enabled = False
         self._caller_stream: Optional[torch.cuda.Stream] = None
         self._engine_stream: Optional[torch.cuda.Stream] = None
         self.warm_up()
@@ -77,15 +77,13 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, .
         cudagraphs_enabled = torch_tensorrt.runtime.get_whole_cudagraphs_mode()
         if cudagraphs_enabled:
             shape_changed = self.validate_input_shapes(inputs)
-            # Cudagraphs record is required if cudagraphs_enabled is toggled to True regardless of shape change
-            need_cudagraphs_record = not self.prev_cudagraphs_enabled or shape_changed
-            self.prev_cudagraphs_enabled = cudagraphs_enabled
-
+            need_cudagraphs_record = shape_changed or self.is_weight_streaming_set
             if need_cudagraphs_record:
                 if self.cudagraph:
                     self.cudagraph.reset()
                 self._input_buffers = [None] * len(self.inputs)
 
+            self.is_weight_streaming_set = False
             # Ensure inputs are available in all scopes and cast symbolic integers to Tensors
             contiguous_inputs: List[torch.Tensor] = [
                 (
py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py (37 additions, 20 deletions)

@@ -24,25 +24,33 @@


 class TorchTRTRuntimeStates:
-    def __init__(self, new_cudagraphs: bool, new_pre_allocated_output: bool):
+    def __init__(self, new_cudagraphs: bool):
         # Indicates whether CUDAGraphs were enabled in the previous execute_engine
         self.old_cudagraphs = new_cudagraphs
         # Indicates whether pre-allocated output was enabled in the previous execute_engine
-        self.old_pre_allocated_outputs = new_pre_allocated_output
+        self.old_pre_allocated_outputs = False
+        # Indicates whether context has changed
+        self.context_changed = False
 
     def set_runtime_states(
         self,
         new_cudagraphs: bool,
         new_pre_allocated_output: bool,
         shape_changed: bool,
-    ) -> Tuple[bool, bool]:
-        # Evaluates whether certain conditions are met to enable CUDA Graph recording or to reuse pre-allocated outputs
+    ) -> Tuple[bool, bool, bool]:
+        # Evaluates whether certain conditions are met to enable CUDA Graph recording or to use pre-allocated outputs
         # based on the current and previous states, as well as input shape has changed
         need_cudagraphs_record = False
         can_use_pre_allocated_outputs = False
+        need_cudagraphs_reset = False
 
-        # Cudagraphs record is required if cudagraphs_enabled is switched to True regardless of shape change
-        if new_cudagraphs and (not self.old_cudagraphs or shape_changed):
+        # CUDA Graph recording is needed if CUDA graphs is enabled and:
+        # - CUDA graphs were previously disabled
+        # - or the shape has changed
+        # - or the execution context has changed (e.g., weight streaming)
+        if new_cudagraphs and (
+            not self.old_cudagraphs or shape_changed or self.context_changed
+        ):
             need_cudagraphs_record = True
 
         # Pre-allocated output can be used when previous and current state are true without shape change
@@ -53,10 +61,19 @@ def set_runtime_states(
         ):
             can_use_pre_allocated_outputs = True
 
+        if not new_cudagraphs or shape_changed or self.context_changed:
+            need_cudagraphs_reset = True
+
         self.old_cudagraphs = new_cudagraphs
         self.old_pre_allocated_outputs = new_pre_allocated_output
+        # reset flag
+        self.context_changed = False
 
-        return need_cudagraphs_record, can_use_pre_allocated_outputs
+        return (
+            need_cudagraphs_record,
+            can_use_pre_allocated_outputs,
+            need_cudagraphs_reset,
+        )
 
 
 class PythonTorchTensorRTModule(Module):  # type: ignore[misc]
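To make the new three-value contract concrete, here is a minimal standalone sketch (illustrative only, not library code) that mirrors the decision logic both the C++ and Python TorchTRTRuntimeStates implement:

def decide(old_cudagraphs, new_cudagraphs, old_pre_allocated, new_pre_allocated, shape_changed, context_changed):
    # Returns (need_cudagraphs_record, can_use_pre_allocated_outputs, need_cudagraphs_reset)
    need_record = new_cudagraphs and (not old_cudagraphs or shape_changed or context_changed)
    can_reuse = old_pre_allocated and new_pre_allocated and not shape_changed
    need_reset = not new_cudagraphs or shape_changed or context_changed
    return need_record, can_reuse, need_reset

# Steady state with CUDA graphs on and stable shapes: replay the captured graph.
assert decide(True, True, False, False, False, False) == (False, False, False)
# Execution context rebuilt (e.g., weight-streaming budget change): reset, then re-record.
assert decide(True, True, False, False, False, True) == (True, False, True)
# CUDA graphs toggled off: reset the stale graph, record nothing.
assert decide(True, False, False, False, False, False) == (False, False, True)

Note that the reset flag is computed independently of the record flag, so a stale graph is torn down even when CUDA graphs are being switched off and no re-record will follow.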
@@ -145,7 +162,7 @@ def __init__(
         self.weight_name_map = weight_name_map
         self.target_platform = Platform.current_platform()
         self.runtime_states = TorchTRTRuntimeStates(
-            torch_tensorrt.runtime.get_cudagraphs_mode(), False
+            torch_tensorrt.runtime.get_cudagraphs_mode()
         )
         self.pre_allocated_outputs: List[torch.Tensor] = []
         self.use_pre_allocated_outputs = False
@@ -168,6 +185,7 @@ def set_device_memory_budget(self, budget_bytes: int) -> int:
         del self.context
         budget_bytes = self._set_device_memory_budget(budget_bytes)
         self.context = self.engine.create_execution_context()
+        self.runtime_states.context_changed = True
         return budget_bytes
 
     def _set_device_memory_budget(self, budget_bytes: int) -> int:
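The flag matters because a captured CUDA graph replays kernels recorded against the old IExecutionContext, so replaying it after the context is recreated would use stale state. A minimal sketch of the resulting behavior (trt_module and x are placeholder names, assuming a module compiled with enable_weight_streaming=True):

import torch_tensorrt

torch_tensorrt.runtime.set_cudagraphs_mode(True)  # enable per-engine CUDA graphs
y0 = trt_module(x)                             # first call records a CUDA graph
trt_module.set_device_memory_budget(16 << 20)  # context recreated, context_changed = True
y1 = trt_module(x)                             # graph is reset and re-recorded, not replayed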
@@ -200,7 +218,6 @@ def setup_engine(self) -> None:
         if self.settings.enable_weight_streaming:
             self.set_default_device_memory_budget()
         self.context = self.engine.create_execution_context()
-
         assert self.engine.num_io_tensors == (
             len(self.input_names) + len(self.output_names)
         )
@@ -356,22 +373,22 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, .

         cudagraphs_enabled = torch_tensorrt.runtime.get_cudagraphs_mode()
         shape_changed = self.validate_input_shapes(inputs)
-        need_cudagraphs_record, can_use_pre_allocated_outputs = (
-            self.runtime_states.set_runtime_states(
-                cudagraphs_enabled, self.use_pre_allocated_outputs, shape_changed
-            )
+        (
+            need_cudagraphs_record,
+            can_use_pre_allocated_outputs,
+            need_cudagraphs_reset,
+        ) = self.runtime_states.set_runtime_states(
+            cudagraphs_enabled, self.use_pre_allocated_outputs, shape_changed
         )
 
+        if need_cudagraphs_reset and self.cudagraph:
+            self.cudagraph.reset()
+            self.cudagraph = None
+
         if need_cudagraphs_record:
-            if self.cudagraph:
-                self.cudagraph.reset()
             self._input_buffers = [None] * len(self.input_names)
             self._output_buffers = [None] * len(self.output_names)
 
-        if not cudagraphs_enabled and self.cudagraph:
-            self.cudagraph.reset()
-            self.cudagraph = None
-
         # If in safe mode, check at each iteration for whether a switch is required
         if (
             torch_tensorrt.runtime._multi_device_safe_mode._PY_RT_MULTI_DEVICE_SAFE_MODE
py/torch_tensorrt/runtime/_weight_streaming.py (4 additions, 0 deletions)

@@ -20,8 +20,10 @@ def __init__(
     ) -> None:
         rt_mods = []
         self.current_device_budget = 0
+        self.cuda_graphs_module = None
 
         if isinstance(module, CudaGraphsTorchTensorRTModule):
+            self.cuda_graphs_module = module
             module = module.compiled_module
         for name, rt_mod in module.named_children():
             if "_run_on_acc" in name and isinstance(
@@ -78,6 +80,8 @@ def _set_streamable_weight_bytes(self, requested_budget: int) -> int:
             ws_budget_bytes += rt_mod.set_device_memory_budget(normalized_size[i])
             logger.debug(f"Set weight streaming size {normalized_size[i]} for {name}")
 
+        if self.cuda_graphs_module:
+            self.cuda_graphs_module.is_weight_streaming_set = True
         return ws_budget_bytes
 
     def __setattr__(self, name: str, value: Any) -> None:
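Taken together, these changes make weight streaming composable with CUDA graph capture. An end-to-end sketch (trt_model and inputs are placeholders; assumes a module compiled with enable_weight_streaming=True and the enable_cudagraphs / weight_streaming runtime context managers):

import torch_tensorrt

# trt_model: output of torch_tensorrt.compile(..., enable_weight_streaming=True)
with torch_tensorrt.runtime.enable_cudagraphs(trt_model) as cg_model:
    with torch_tensorrt.runtime.weight_streaming(cg_model) as ws_ctx:
        out1 = cg_model(*inputs)  # records a CUDA graph on the first call
        ws_ctx.device_budget = ws_ctx.total_device_budget // 2  # flags is_weight_streaming_set
        out2 = cg_model(*inputs)  # re-records instead of replaying a stale graph

Previously, the second call could replay a graph whose weights had been partially evicted; with this change the wrapper re-records on the next forward pass.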