diff --git a/core/runtime/TRTEngine.cpp b/core/runtime/TRTEngine.cpp
index 1bca7869e3..13cbe3a126 100644
--- a/core/runtime/TRTEngine.cpp
+++ b/core/runtime/TRTEngine.cpp
@@ -52,6 +52,7 @@ TRTEngine::TRTEngine(
   auto most_compatible_device = get_most_compatible_device(cuda_device);
   TORCHTRT_CHECK(most_compatible_device, "No compatible device was found for instantiating TensorRT engine");
   device_info = most_compatible_device.value();
+  multi_gpu_device_check();
   set_rt_device(device_info);
 
   rt = make_trt(nvinfer1::createInferRuntime(util::logging::get_logger()));
diff --git a/core/runtime/execute_engine.cpp b/core/runtime/execute_engine.cpp
index 2a7fe884da..5551010a2a 100644
--- a/core/runtime/execute_engine.cpp
+++ b/core/runtime/execute_engine.cpp
@@ -74,7 +74,7 @@ std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intrusive_ptr<TRTEngine> compiled_engine) {
     LOG_INFO("" << log_info);
   }
 
-  {
+  if (MULTI_DEVICE_SAFE_MODE) {
     std::unique_ptr<torch::autograd::profiler::RecordProfile> device_profiler_guard;
     if (compiled_engine->profile_execution) {
       device_profiler_guard =
diff --git a/core/runtime/register_jit_hooks.cpp b/core/runtime/register_jit_hooks.cpp
index c5b9118fee..1acc27dda5 100644
--- a/core/runtime/register_jit_hooks.cpp
+++ b/core/runtime/register_jit_hooks.cpp
@@ -114,6 +114,10 @@ TORCH_LIBRARY(tensorrt, m) {
   m.def("execute_engine", execute_engine);
   m.def("SERIALIZED_ENGINE_BINDING_DELIM", []() -> std::string { return std::string(1, TRTEngine::BINDING_DELIM); });
   m.def("ABI_VERSION", []() -> std::string { return ABI_VERSION; });
+  m.def("get_multi_device_safe_mode", []() -> bool { return MULTI_DEVICE_SAFE_MODE; });
+  m.def("set_multi_device_safe_mode", [](bool multi_device_safe_mode) -> void {
+    MULTI_DEVICE_SAFE_MODE = multi_device_safe_mode;
+  });
 }
 
 } // namespace
diff --git a/core/runtime/runtime.cpp b/core/runtime/runtime.cpp
index 0372258919..2d7f7f1198 100644
--- a/core/runtime/runtime.cpp
+++ b/core/runtime/runtime.cpp
@@ -7,6 +7,8 @@ namespace torch_tensorrt {
 namespace core {
 namespace runtime {
 
+bool MULTI_DEVICE_SAFE_MODE = false;
+
 c10::optional<RTDevice> get_most_compatible_device(const RTDevice& target_device, const RTDevice& curr_device) {
   LOG_DEBUG("Target Device: " << target_device);
   auto device_options = find_compatible_devices(target_device);
@@ -31,13 +33,13 @@ c10::optional<RTDevice> get_most_compatible_device(const RTDevice& target_device
     if (device.device_name == target_device.device_name) {
       // First priority is selecting a candidate which agrees with the current device ID
       // If such a device is found, we can select it and break out of the loop
-      if (device.id == current_device.id && best_match.id != current_device.id) {
+      if (device.id == current_device.id) {
         best_match = device;
         break;
       }
       // Second priority is selecting a candidate which agrees with the target device ID
       // At deserialization time, the current device and target device may not agree
-      else if (device.id == target_device.id && best_match.id != target_device.id) {
+      else if (device.id == target_device.id) {
         best_match = device;
       }
       // If no such GPU ID is found, select the first available candidate GPU
@@ -103,6 +105,17 @@ RTDevice get_current_device() {
   return RTDevice(device_id, nvinfer1::DeviceType::kGPU);
 }
 
+void multi_gpu_device_check() {
+  // If multi-device safe mode is disabled and more than 1 device is registered on the machine, warn user
+  if (!(MULTI_DEVICE_SAFE_MODE) && get_available_device_list().get_devices().size() > 1) {
+    LOG_WARNING(
+        "Detected this engine is being instantiated in a multi-GPU system with "
+        << "multi-device safe mode disabled. For more on the implications of this "
+        << "as well as workarounds, see the linked documentation "
+        << "(https://pytorch.org/TensorRT/user_guide/runtime.html#multi-device-safe-mode)");
+  }
+}
+
 namespace {
 static DeviceList cuda_device_list;
 }
diff --git a/core/runtime/runtime.h b/core/runtime/runtime.h
index 05d97a30b8..ea863850ba 100644
--- a/core/runtime/runtime.h
+++ b/core/runtime/runtime.h
@@ -16,6 +16,7 @@ namespace runtime {
 
 using EngineID = int64_t;
 const std::string ABI_VERSION = "4";
+extern bool MULTI_DEVICE_SAFE_MODE;
 typedef enum {
   ABI_TARGET_IDX = 0,
   NAME_IDX,
@@ -33,6 +34,8 @@ std::vector<RTDevice> find_compatible_devices(const RTDevice& target_device);
 
 std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intrusive_ptr<TRTEngine> compiled_engine);
 
+void multi_gpu_device_check();
+
 class DeviceList {
   using DeviceMap = std::unordered_map<int, RTDevice>;
   DeviceMap device_list;
diff --git a/docsrc/user_guide/runtime.rst b/docsrc/user_guide/runtime.rst
index 0cfc93200f..8264abdd32 100644
--- a/docsrc/user_guide/runtime.rst
+++ b/docsrc/user_guide/runtime.rst
@@ -34,3 +34,63 @@ Plugin Library
 In the case you use Torch-TensorRT as a converter to a TensorRT engine and your engine uses plugins provided by Torch-TensorRT, Torch-TensorRT
 ships the library ``libtorchtrt_plugins.so`` which contains the implementation of the TensorRT plugins used by Torch-TensorRT during
 compilation. This library can be ``DL_OPEN`` or ``LD_PRELOAD`` similar to other TensorRT plugin libraries.
+
+Multi Device Safe Mode
+----------------------
+
+Multi-device safe mode is a setting in Torch-TensorRT which allows the user to determine whether
+the runtime checks for device consistency prior to every inference call.
+
+There is a non-negligible, fixed cost per inference call when multi-device safe mode is enabled, which is why
+it is disabled by default. It can be controlled via the following convenience function, which
+doubles as a context manager.
+
+.. code-block:: python
+
+    # Enables Multi Device Safe Mode
+    torch_tensorrt.runtime.set_multi_device_safe_mode(True)
+
+    # Disables Multi Device Safe Mode [Default Behavior]
+    torch_tensorrt.runtime.set_multi_device_safe_mode(False)
+
+    # Enables Multi Device Safe Mode, then resets the safe mode to its prior setting
+    with torch_tensorrt.runtime.set_multi_device_safe_mode(True):
+        ...
+
+TensorRT requires that each engine be associated with the CUDA context in the active thread from which it is invoked.
+Therefore, if the device changes in the active thread, which may be the case when invoking
+engines on multiple GPUs from the same Python process, safe mode causes Torch-TensorRT to display
+an alert and switch GPUs accordingly. If safe mode is not enabled, the engine device and the CUDA
+context device could become mismatched, which could lead the program to crash.
+
+One technique for managing multiple TRT engines on different GPUs without sacrificing performance to
+multi-device safe mode is to use Python threads. Each thread is responsible for all of the TRT engines
+on a single GPU, and the default CUDA device on each thread corresponds to the GPU for which it is
+responsible (can be set via ``torch.cuda.set_device(...)``). In this way, multiple threads can be used in the same
+Python script without needing to switch CUDA contexts or incurring performance overhead.
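+
+The following is an illustrative sketch of this threading pattern, not itself part of the
+Torch-TensorRT API; it assumes ``trt_model_0`` and ``trt_model_1`` are hypothetical modules
+already compiled by Torch-TensorRT (one per GPU) and ``inputs`` is a list of CPU tensors:
+
+.. code-block:: python
+
+    import threading
+
+    import torch
+
+    def infer_on_gpu(device_id, trt_model, inputs, results):
+        # Pin this thread's default CUDA device to the engine's GPU so that
+        # no device switch is ever required during inference
+        torch.cuda.set_device(device_id)
+        gpu_inputs = [tensor.to(f"cuda:{device_id}") for tensor in inputs]
+        results[device_id] = trt_model(*gpu_inputs)
+
+    # trt_model_0 and trt_model_1 are assumed to be precompiled TRT modules,
+    # each built targeting the GPU index it is paired with below
+    results = {}
+    threads = [
+        threading.Thread(target=infer_on_gpu, args=(i, model, inputs, results))
+        for i, model in enumerate((trt_model_0, trt_model_1))
+    ]
+    for thread in threads:
+        thread.start()
+    for thread in threads:
+        thread.join()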
diff --git a/py/torch_tensorrt/__init__.py b/py/torch_tensorrt/__init__.py
index c015bd89db..b9d2af39c5 100644
--- a/py/torch_tensorrt/__init__.py
+++ b/py/torch_tensorrt/__init__.py
@@ -85,15 +85,17 @@ def _find_lib(name: str, paths: List[str]) -> str:
 
 from torch_tensorrt._Device import Device  # noqa: F401
 from torch_tensorrt._enums import *  # noqa: F403
 from torch_tensorrt._Input import Input  # noqa: F401
-from torch_tensorrt.logging import *
-from torch_tensorrt.ptq import *
 from torch_tensorrt._utils import *  # noqa: F403
 from torch_tensorrt._utils import sanitized_torch_version
+from torch_tensorrt.logging import *
+from torch_tensorrt.ptq import *
+from torch_tensorrt.runtime import *  # noqa: F403
 
 if version.parse(sanitized_torch_version()) >= version.parse("2.1.dev"):
-    from torch_tensorrt import dynamo  # noqa: F401
     from torch_tensorrt.dynamo import backend  # noqa: F401
+    from torch_tensorrt import dynamo  # noqa: F401
+
 
 def _register_with_torch() -> None:
     trtorch_dir = os.path.dirname(__file__)
diff --git a/py/torch_tensorrt/dynamo/conversion/_conversion.py b/py/torch_tensorrt/dynamo/conversion/_conversion.py
index 1cdea63680..ea00113a23 100644
--- a/py/torch_tensorrt/dynamo/conversion/_conversion.py
+++ b/py/torch_tensorrt/dynamo/conversion/_conversion.py
@@ -3,6 +3,7 @@
 import io
 from typing import Sequence
 
+import tensorrt as trt
 import torch
 from torch_tensorrt._Input import Input
 from torch_tensorrt.dynamo._settings import CompilationSettings
@@ -10,8 +11,6 @@
 from torch_tensorrt.dynamo.runtime import PythonTorchTensorRTModule, TorchTensorRTModule
 from torch_tensorrt.dynamo.utils import get_torch_inputs
 
-import tensorrt as trt
-
 
 def convert_module(
     module: torch.fx.GraphModule,
@@ -72,6 +71,8 @@ def convert_module(
             engine=interpreter_result.engine,
             input_names=list(interpreter_result.input_names),
             output_names=list(interpreter_result.output_names),
+            target_device=settings.device,
+            profiling_enabled=settings.debug,
         )
 
     else:
diff --git a/py/torch_tensorrt/dynamo/partitioning/_adjacency_partitioner.py b/py/torch_tensorrt/dynamo/partitioning/_adjacency_partitioner.py
index 55df3cb2b3..5bdbb8919b 100644
--- a/py/torch_tensorrt/dynamo/partitioning/_adjacency_partitioner.py
+++ b/py/torch_tensorrt/dynamo/partitioning/_adjacency_partitioner.py
@@ -42,10 +42,10 @@ def is_node_supported(
         node_name = ConverterRegistry.qualified_name_or_str(node.target)
 
         if (
-            node in CONVERTERS or (node.op == "get_attr" and "constant" in node_name)
+            node in CONVERTERS or node.op == "get_attr"
         ) and node_name not in self.torch_executed_ops:
             # If node is a proper, supported computational node, store the operator
-            if not node.is_impure():
+            if not node.is_impure() and node.op != "get_attr":
                 if node_name not in self.supported_operators:
                     self.supported_operators[node_name] = 1
                 else:
diff --git a/py/torch_tensorrt/dynamo/partitioning/_global_partitioner.py b/py/torch_tensorrt/dynamo/partitioning/_global_partitioner.py
index f6149a2271..092bdabfd0 100644
--- a/py/torch_tensorrt/dynamo/partitioning/_global_partitioner.py
+++ b/py/torch_tensorrt/dynamo/partitioning/_global_partitioner.py
@@ -150,10 +150,10 @@ def is_node_supported(
         node_name = ConverterRegistry.qualified_name_or_str(node.target)
 
         if (
-            node in CONVERTERS or (node.op == "get_attr" and "constant" in node_name)
+            node in CONVERTERS or node.op == "get_attr"
         ) and node_name not in self.torch_executed_ops:
             # If node is a proper, supported computational node, store the operator
-            if not node.is_impure():
+            if not node.is_impure() and node.op != "get_attr":
                 if node_name not in self.supported_operators:
                     self.supported_operators[node_name] = 1
                 else:
diff --git a/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py b/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py
index 41baecc7ab..db45609123 100644
--- a/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py
+++ b/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py
@@ -1,13 +1,22 @@
 from __future__ import annotations
 
 import logging
+from contextlib import nullcontext
 from typing import Any, Dict, List, Optional, Sequence, Tuple
 
 import tensorrt as trt
 import torch
 from torch.nn import Module
+from torch_tensorrt._Device import Device
+from torch_tensorrt.dynamo.runtime.tools import (
+    _is_switch_required,
+    _select_rt_device,
+    multi_gpu_device_check,
+)
 from torch_tensorrt.fx.utils import Frameworks, unified_dtype_converter
 
+import torch_tensorrt
+
 logger = logging.getLogger(__name__)
 
 
@@ -23,13 +32,26 @@ def __init__(
         engine: trt.ICudaEngine,
         input_names: Optional[List[str]] = None,
         output_names: Optional[List[str]] = None,
+        target_device: Device = Device._current_device(),
+        profiling_enabled: Optional[bool] = None,
     ):
         super(PythonTorchTensorRTModule, self).__init__()
         self._register_state_dict_hook(PythonTorchTensorRTModule._on_state_dict)
+
+        # Run multi-gpu device check to validate engine instantiation
+        multi_gpu_device_check()
+
         self.engine = engine
         self.input_names = input_names if input_names is not None else []
         self.output_names = output_names if output_names is not None else []
         self.initialized = False
+        self.target_device_id = target_device.gpu_id
+        self.target_device_properties = torch.cuda.get_device_properties(
+            self.target_device_id
+        )
+        self.profiling_enabled = (
+            profiling_enabled if profiling_enabled is not None else False
+        )
         self._initialize()
 
     def _initialize(self) -> None:
@@ -119,6 +141,9 @@ def _load_from_state_dict(
     ) -> None:
         engine_bytes = state_dict[prefix + "engine"]
 
+        # Run multi-gpu device check to validate engine instantiation
+        multi_gpu_device_check()
+
         logger = trt.Logger()
         runtime = trt.Runtime(logger)
         self.engine = runtime.deserialize_cuda_engine(engine_bytes)
@@ -141,15 +166,43 @@ def __setstate__(self, state: Dict[str, Any]) -> None:
         if self.engine:
             self.context = self.engine.create_execution_context()
 
-    def forward(self, *inputs: Any) -> torch.Tensor | Tuple[torch.Tensor, ...]:
+    def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, ...]:
         with torch.autograd.profiler.record_function(
             "PythonTorchTensorRTModule:Forward"
-        ):
+        ) if self.profiling_enabled else nullcontext():
             self._check_initialized()
 
+            # If in safe mode, check at each iteration whether a switch is required
+            if (
+                torch_tensorrt.runtime.multi_device_safe_mode._PY_RT_MULTI_DEVICE_SAFE_MODE
+            ):
+                curr_device_id = torch.cuda.current_device()
+                curr_device_properties = torch.cuda.get_device_properties(
+                    curr_device_id
+                )
+                logger.debug(f"Current Device: cuda:{curr_device_id}")
+
+                # If a switch is required, move all inputs to new device and set as active device
+                if _is_switch_required(
+                    curr_device_id,
+                    self.target_device_id,
+                    curr_device_properties,
+                    self.target_device_properties,
+                ):
+                    device_id, _ = _select_rt_device(
+                        curr_device_id,
+                        self.target_device_id,
+                        self.target_device_properties,
+                    )
+                    device = torch.device(device_id)
+                    torch.cuda.set_device(device_id)
+
+                    inputs = tuple([tensor.to(device) for tensor in inputs])
+                    logger.warning(f"Moved all input Tensors to cuda:{device_id}")
+
             with torch.autograd.profiler.record_function(
                 "PythonTorchTensorRTModule:ProcessInputs"
-            ):
+            ) if self.profiling_enabled else nullcontext():
                 assert len(inputs) == len(
                     self.input_names
                 ), f"Wrong number of inputs, expect {len(self.input_names)} get {len(inputs)}."
@@ -188,7 +241,7 @@ def forward(self, *inputs: Any) -> torch.Tensor | Tuple[torch.Tensor, ...]:
 
             with torch.autograd.profiler.record_function(
                 "PythonTorchTensorRTModule:ProcessOutputs"
-            ):
+            ) if self.profiling_enabled else nullcontext():
                 # create output tensors
                 outputs: List[torch.Tensor] = []
 
@@ -215,7 +268,7 @@ def forward(self, *inputs: Any) -> torch.Tensor | Tuple[torch.Tensor, ...]:
 
             with torch.autograd.profiler.record_function(
                 "PythonTorchTensorRTModule:TensorRTRuntime"
-            ):
+            ) if self.profiling_enabled else nullcontext():
                 self.context.execute_async_v2(
                     bindings, torch.cuda.current_stream().cuda_stream
                 )
@@ -235,6 +288,8 @@ def enable_profiling(self, profiler: "trt.IProfiler" = None) -> None:
         if not self.context.profiler:
             self.context.profiler = trt.Profiler() if profiler is None else profiler
 
+        self.profiling_enabled = True
+
     def disable_profiling(self) -> None:
         """
         Disable TensorRT profiling.
         """
@@ -244,6 +299,7 @@ def disable_profiling(self) -> None:
         torch.cuda.synchronize()
         del self.context
         self.context = self.engine.create_execution_context()
+        self.profiling_enabled = False
 
     def get_layer_info(self) -> str:
         """
diff --git a/py/torch_tensorrt/dynamo/runtime/tools.py b/py/torch_tensorrt/dynamo/runtime/tools.py
new file mode 100644
index 0000000000..75c83a4f60
--- /dev/null
+++ b/py/torch_tensorrt/dynamo/runtime/tools.py
@@ -0,0 +1,131 @@
+import logging
+from typing import Optional, Tuple
+
+import torch
+
+import torch_tensorrt
+
+logger = logging.getLogger(__name__)
+
+
+def multi_gpu_device_check() -> None:
+    # If multi-device safe mode is disabled and more than 1 device is registered on the machine, warn user
+    if (
+        not torch_tensorrt.runtime.multi_device_safe_mode._PY_RT_MULTI_DEVICE_SAFE_MODE
+        and torch.cuda.device_count() > 1
+    ):
+        logger.warning(
+            "Detected this engine is being instantiated in a multi-GPU system with "
+            "multi-device safe mode disabled. For more on the implications of this "
+            "as well as workarounds, see the linked documentation "
+            "(https://pytorch.org/TensorRT/user_guide/runtime.html#multi-device-safe-mode). "
+            f"The engine is set to be instantiated on the current default cuda device, cuda:{torch.cuda.current_device()}. "
+            "If this is incorrect, please set the desired cuda device via torch.cuda.set_device(...) and retry."
+        )
+
+
+def _is_switch_required(
+    curr_device_id: int,
+    engine_device_id: int,
+    curr_device_properties: torch._C._CudaDeviceProperties,
+    engine_device_properties: torch._C._CudaDeviceProperties,
+) -> bool:
+    """Determines whether a device switch is required based on input device parameters"""
+    # Device capabilities disagree
+    if (curr_device_properties.major, curr_device_properties.minor) != (
+        engine_device_properties.major,
+        engine_device_properties.minor,
+    ):
+        logger.warning(
+            f"Configured SM capability {(engine_device_properties.major, engine_device_properties.minor)} does not match with "
+            f"current device SM capability {(curr_device_properties.major, curr_device_properties.minor)}. Switching device context."
+        )
+
+        return True
+
+    # Names disagree
+    if curr_device_properties.name != engine_device_properties.name:
+        logger.warning(
+            f"Program compiled for {engine_device_properties.name} but current CUDA device is "
+            f"{curr_device_properties.name}. Attempting to switch device context for better compatibility."
+        )
+
+        return True
+
+    # Device IDs disagree
+    if curr_device_id != engine_device_id:
+        logger.warning(
+            f"Configured Device ID: {engine_device_id} is different than current device ID: "
+            f"{curr_device_id}. Attempting to switch device context for better compatibility."
+        )
+
+        return True
+
+    return False
+
+
+def _select_rt_device(
+    curr_device_id: int,
+    engine_device_id: int,
+    engine_device_properties: torch._C._CudaDeviceProperties,
+) -> Tuple[int, torch._C._CudaDeviceProperties]:
+    """Wraps compatible device check and raises error if none are found"""
+    new_target_device_opt = _get_most_compatible_device(
+        curr_device_id, engine_device_id, engine_device_properties
+    )
+
+    assert (
+        new_target_device_opt is not None
+    ), "Could not find a compatible device on the system to run TRT Engine"
+
+    return new_target_device_opt
+
+
+def _get_most_compatible_device(
+    curr_device_id: int,
+    engine_device_id: int,
+    engine_device_properties: torch._C._CudaDeviceProperties,
+) -> Optional[Tuple[int, torch._C._CudaDeviceProperties]]:
+    """Selects a runtime device based on compatibility checks"""
+    all_devices = [
+        (i, torch.cuda.get_device_properties(i))
+        for i in range(torch.cuda.device_count())
+    ]
+    logger.debug(f"All available devices: {all_devices}")
+    target_device_sm = (engine_device_properties.major, engine_device_properties.minor)
+
+    # Any devices with the same SM capability are valid candidates
+    candidate_devices = [
+        (i, device_properties)
+        for i, device_properties in all_devices
+        if (device_properties.major, device_properties.minor) == target_device_sm
+    ]
+
+    logger.debug(f"Found candidate devices: {candidate_devices}")
+
+    # If fewer than 2 candidates are found, return
+    if len(candidate_devices) <= 1:
+        return candidate_devices[0] if candidate_devices else None
+
+    # If 2 or more candidates are found, select the best match
+    best_match = None
+
+    for candidate in candidate_devices:
+        i, device_properties = candidate
+        # First priority is selecting a candidate which agrees with the current device ID
+        # If such a device is found, we can select it and break out of the loop
+        if device_properties.name == engine_device_properties.name:
+            if i == curr_device_id:
+                best_match = candidate
+                break
+
+            # Second priority is selecting a candidate which agrees with the target device ID
+            # At deserialization time, the current device and target device may not agree
+            elif i == engine_device_id:
+                best_match = candidate
+
+            # If no such GPU ID is found, select the first available candidate GPU
+            elif best_match is None:
+                best_match = candidate
+
+    return best_match
diff --git a/py/torch_tensorrt/runtime/__init__.py b/py/torch_tensorrt/runtime/__init__.py
new file mode 100644
index 0000000000..29895c83d5
--- /dev/null
+++ b/py/torch_tensorrt/runtime/__init__.py
@@ -0,0 +1 @@
+from .multi_device_safe_mode import set_multi_device_safe_mode
diff --git a/py/torch_tensorrt/runtime/multi_device_safe_mode.py b/py/torch_tensorrt/runtime/multi_device_safe_mode.py
new file mode 100644
index 0000000000..0ddd900ab6
--- /dev/null
+++ b/py/torch_tensorrt/runtime/multi_device_safe_mode.py
@@ -0,0 +1,51 @@
+import logging
+from importlib.util import find_spec
+from typing import Any
+
+import torch
+
+if find_spec("torch_tensorrt._C") is not None:
+    _PY_RT_MULTI_DEVICE_SAFE_MODE = torch.ops.tensorrt.get_multi_device_safe_mode()
+else:
+    _PY_RT_MULTI_DEVICE_SAFE_MODE = False
+
+
+logger = logging.getLogger(__name__)
+
+
+class _MultiDeviceSafeModeContextManager(object):
+    """Helper class used in conjunction with `set_multi_device_safe_mode`
+
+    Used to enable `set_multi_device_safe_mode` as a dual-purpose context manager
+    """
+
+    def __init__(self, old_mode: bool) -> None:
+        self.old_mode = old_mode
+
+    def __enter__(self) -> "_MultiDeviceSafeModeContextManager":
+        return self
+
+    def __exit__(self, *args: Any) -> None:
+        # Set multi-device safe mode back to old mode in Python
+        global _PY_RT_MULTI_DEVICE_SAFE_MODE
+        _PY_RT_MULTI_DEVICE_SAFE_MODE = self.old_mode
+
+        # Set multi-device safe mode back to old mode in C++
+        if find_spec("torch_tensorrt._C") is not None:
+            torch.ops.tensorrt.set_multi_device_safe_mode(self.old_mode)
+
+
+def set_multi_device_safe_mode(mode: bool) -> _MultiDeviceSafeModeContextManager:
+    # Fetch existing safe mode and set new mode for Python
+    global _PY_RT_MULTI_DEVICE_SAFE_MODE
+    old_mode = _PY_RT_MULTI_DEVICE_SAFE_MODE
+    _PY_RT_MULTI_DEVICE_SAFE_MODE = mode
+
+    # Set new mode for C++
+    if find_spec("torch_tensorrt._C") is not None:
+        torch.ops.tensorrt.set_multi_device_safe_mode(mode)
+
+    logger.info(f"Set multi-device safe mode to {mode}")
+
+    # Return context manager in case the function is used in a `with` call
+    return _MultiDeviceSafeModeContextManager(old_mode)
diff --git a/setup.py b/setup.py
index 82f1ac42f7..38d2121461 100644
--- a/setup.py
+++ b/setup.py
@@ -403,6 +403,7 @@ def run(self):
     "torch_tensorrt.fx.tracer",
     "torch_tensorrt.fx.tracer.acc_tracer",
     "torch_tensorrt.fx.tracer.dispatch_tracer",
+    "torch_tensorrt.runtime",
 ]
 
 package_dir = {
@@ -430,6 +431,7 @@ def run(self):
     "torch_tensorrt.fx.tracer": "py/torch_tensorrt/fx/tracer",
     "torch_tensorrt.fx.tracer.acc_tracer": "py/torch_tensorrt/fx/tracer/acc_tracer",
     "torch_tensorrt.fx.tracer.dispatch_tracer": "py/torch_tensorrt/fx/tracer/dispatch_tracer",
+    "torch_tensorrt.runtime": "py/torch_tensorrt/runtime",
 }
 
 package_data = {}
diff --git a/tests/py/dynamo/runtime/test_safe_mode.py b/tests/py/dynamo/runtime/test_safe_mode.py
new file mode 100644
index 0000000000..bd196b12f0
--- /dev/null
+++ b/tests/py/dynamo/runtime/test_safe_mode.py
@@ -0,0 +1,105 @@
+import torch
+from torch.testing._internal.common_utils import TestCase, run_tests
+
+import torch_tensorrt
+
+from ..testing_utilities import DECIMALS_OF_AGREEMENT
+
+
+class TestSafeMode(TestCase):
+    def test_multi_device_safe_mode_on(self):
+        torch_tensorrt.runtime.set_multi_device_safe_mode(True)
+        self.assertTrue(torch.ops.tensorrt.get_multi_device_safe_mode())
+
+    def test_multi_device_safe_mode_off(self):
+        torch_tensorrt.runtime.set_multi_device_safe_mode(False)
+        self.assertFalse(torch.ops.tensorrt.get_multi_device_safe_mode())
+
+    def test_multi_device_safe_mode_context(self):
+        with torch_tensorrt.runtime.set_multi_device_safe_mode(True):
+            self.assertTrue(torch.ops.tensorrt.get_multi_device_safe_mode())
+        self.assertFalse(torch.ops.tensorrt.get_multi_device_safe_mode())
+
+    def test_multi_device_safe_mode_enabled_inference_python(self):
+        torch_tensorrt.runtime.set_multi_device_safe_mode(True)
+
+        class SampleModel(torch.nn.Module):
+            def forward(self, x):
+                return torch.softmax((x + 2) * 7, dim=0)
+
+        inputs = [
+            torch.randn(
+                3,
+                5,
+                7,
+            ).cuda()
+        ]
+
+        fx_graph = torch.fx.symbolic_trace(SampleModel())
+
+        # Validate that the results between Torch and Torch-TRT are similar
+        optimized_model = torch_tensorrt.compile(
+            fx_graph,
+            "torch_compile",
+            inputs,
+            min_block_size=1,
+            pass_through_build_failures=True,
+            use_python_runtime=True,
+        )
+        optimized_model_results = optimized_model(*inputs).detach().cpu()
+        torch_model_results = fx_graph(*inputs).detach().cpu()
+
+        max_diff = float(
+            torch.max(torch.abs(optimized_model_results - torch_model_results))
+        )
+        self.assertAlmostEqual(
+            max_diff,
+            0,
+            DECIMALS_OF_AGREEMENT,
+            msg="Safe Mode Python TRT outputs don't match the original model.",
+        )
+        torch._dynamo.reset()
+
+    def test_multi_device_safe_mode_enabled_inference_cpp(self):
+        torch_tensorrt.runtime.set_multi_device_safe_mode(True)
+
+        class SampleModel(torch.nn.Module):
+            def forward(self, x):
+                return torch.softmax((x + 2) * 7, dim=0)
+
+        inputs = [
+            torch.randn(
+                3,
+                5,
+                7,
+            ).cuda()
+        ]
+
+        fx_graph = torch.fx.symbolic_trace(SampleModel())
+
+        # Validate that the results between Torch and Torch-TRT are similar
+        optimized_model = torch_tensorrt.compile(
+            fx_graph,
+            "torch_compile",
+            inputs,
+            min_block_size=1,
+            pass_through_build_failures=True,
+            use_python_runtime=False,
+        )
+        optimized_model_results = optimized_model(*inputs).detach().cpu()
+        torch_model_results = fx_graph(*inputs).detach().cpu()
+
+        max_diff = float(
+            torch.max(torch.abs(optimized_model_results - torch_model_results))
+        )
+        self.assertAlmostEqual(
+            max_diff,
+            0,
+            DECIMALS_OF_AGREEMENT,
+            msg="Safe Mode C++ TRT outputs don't match the original model.",
+        )
+        torch._dynamo.reset()
+
+
+if __name__ == "__main__":
+    run_tests()