diff --git a/py/torch_tensorrt/dynamo/_compiler.py b/py/torch_tensorrt/dynamo/_compiler.py
index c6895e7907..796e0690f3 100644
--- a/py/torch_tensorrt/dynamo/_compiler.py
+++ b/py/torch_tensorrt/dynamo/_compiler.py
@@ -16,13 +16,21 @@
 from torch_tensorrt.dynamo._defaults import (
     DEBUG,
     DEVICE,
+    DISABLE_TF32,
+    DLA_GLOBAL_DRAM_SIZE,
+    DLA_LOCAL_DRAM_SIZE,
+    DLA_SRAM_SIZE,
     ENABLE_EXPERIMENTAL_DECOMPOSITIONS,
+    ENGINE_CAPABILITY,
     MAX_AUX_STREAMS,
     MIN_BLOCK_SIZE,
+    NUM_AVG_TIMING_ITERS,
     OPTIMIZATION_LEVEL,
     PASS_THROUGH_BUILD_FAILURES,
     PRECISION,
+    REFIT,
     REQUIRE_FULL_COMPILATION,
+    SPARSE_WEIGHTS,
     TRUNCATE_LONG_AND_DOUBLE,
     USE_FAST_PARTITIONER,
     USE_PYTHON_RUNTIME,
@@ -51,17 +59,18 @@ def compile(
     inputs: Tuple[Any, ...],
     *,
     device: Optional[Union[Device, torch.device, str]] = DEVICE,
-    disable_tf32: bool = False,
-    sparse_weights: bool = False,
+    disable_tf32: bool = DISABLE_TF32,
+    sparse_weights: bool = SPARSE_WEIGHTS,
     enabled_precisions: Set[torch.dtype] | Tuple[torch.dtype] = (torch.float32,),
-    refit: bool = False,
+    engine_capability: EngineCapability = ENGINE_CAPABILITY,
+    refit: bool = REFIT,
     debug: bool = DEBUG,
     capability: EngineCapability = EngineCapability.default,
-    num_avg_timing_iters: int = 1,
+    num_avg_timing_iters: int = NUM_AVG_TIMING_ITERS,
     workspace_size: int = WORKSPACE_SIZE,
-    dla_sram_size: int = 1048576,
-    dla_local_dram_size: int = 1073741824,
-    dla_global_dram_size: int = 536870912,
+    dla_sram_size: int = DLA_SRAM_SIZE,
+    dla_local_dram_size: int = DLA_LOCAL_DRAM_SIZE,
+    dla_global_dram_size: int = DLA_GLOBAL_DRAM_SIZE,
     calibrator: object = None,
     truncate_long_and_double: bool = TRUNCATE_LONG_AND_DOUBLE,
     require_full_compilation: bool = REQUIRE_FULL_COMPILATION,
@@ -199,6 +208,13 @@ def compile(
         "use_fast_partitioner": use_fast_partitioner,
         "enable_experimental_decompositions": enable_experimental_decompositions,
         "require_full_compilation": require_full_compilation,
+        "disable_tf32": disable_tf32,
+        "sparse_weights": sparse_weights,
+        "refit": refit,
+        "engine_capability": engine_capability,
+        "dla_sram_size": dla_sram_size,
+        "dla_local_dram_size": dla_local_dram_size,
+        "dla_global_dram_size": dla_global_dram_size,
     }
 
     settings = CompilationSettings(**compilation_options)
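Reviewer note (illustrative sketch, not part of the patch): with the new keyword arguments forwarded into compilation_options above, a user-facing call through the Dynamo path can toggle several builder options at once. The toy module mirrors the one in the new test file; a CUDA-capable GPU and the usual `ir` dispatch are assumed.

    import torch
    import torch_tensorrt

    class AddSoftmax(torch.nn.Module):
        def forward(self, x):
            return torch.softmax(3 * x + 1, 0)

    inputs = [torch.rand(3, 5, 7).cuda()]

    # Each flag below is forwarded into CompilationSettings and, from there,
    # onto the TensorRT builder config (see the interpreter changes further down).
    trt_model = torch_tensorrt.compile(
        AddSoftmax().cuda(),
        "dynamo",
        inputs,
        min_block_size=1,
        disable_tf32=True,       # clears trt.BuilderFlag.TF32
        sparse_weights=True,     # sets trt.BuilderFlag.SPARSE_WEIGHTS
        refit=True,              # sets trt.BuilderFlag.REFIT
        num_avg_timing_iters=5,
        workspace_size=1 << 10,
    )
    print(trt_model(*inputs).shape)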
diff --git a/py/torch_tensorrt/dynamo/_defaults.py b/py/torch_tensorrt/dynamo/_defaults.py
index 103b5f7792..4ec872fb1b 100644
--- a/py/torch_tensorrt/dynamo/_defaults.py
+++ b/py/torch_tensorrt/dynamo/_defaults.py
@@ -1,19 +1,28 @@
 import torch
+from tensorrt import EngineCapability
 from torch_tensorrt._Device import Device
 
 PRECISION = torch.float32
 DEBUG = False
 DEVICE = None
+DISABLE_TF32 = False
+DLA_LOCAL_DRAM_SIZE = 1073741824
+DLA_GLOBAL_DRAM_SIZE = 536870912
+DLA_SRAM_SIZE = 1048576
+ENGINE_CAPABILITY = EngineCapability.STANDARD
 WORKSPACE_SIZE = 0
 MIN_BLOCK_SIZE = 5
 PASS_THROUGH_BUILD_FAILURES = False
 MAX_AUX_STREAMS = None
+NUM_AVG_TIMING_ITERS = 1
 VERSION_COMPATIBLE = False
 OPTIMIZATION_LEVEL = None
+SPARSE_WEIGHTS = False
 TRUNCATE_LONG_AND_DOUBLE = False
 USE_PYTHON_RUNTIME = False
 USE_FAST_PARTITIONER = True
 ENABLE_EXPERIMENTAL_DECOMPOSITIONS = False
+REFIT = False
 REQUIRE_FULL_COMPILATION = False
 
 
diff --git a/py/torch_tensorrt/dynamo/_settings.py b/py/torch_tensorrt/dynamo/_settings.py
index c9f4534cb8..cd58c9547f 100644
--- a/py/torch_tensorrt/dynamo/_settings.py
+++ b/py/torch_tensorrt/dynamo/_settings.py
@@ -2,16 +2,25 @@
 from typing import Optional, Set
 
 import torch
+from tensorrt import EngineCapability
 from torch_tensorrt._Device import Device
 from torch_tensorrt.dynamo._defaults import (
     DEBUG,
+    DISABLE_TF32,
+    DLA_GLOBAL_DRAM_SIZE,
+    DLA_LOCAL_DRAM_SIZE,
+    DLA_SRAM_SIZE,
     ENABLE_EXPERIMENTAL_DECOMPOSITIONS,
+    ENGINE_CAPABILITY,
     MAX_AUX_STREAMS,
     MIN_BLOCK_SIZE,
+    NUM_AVG_TIMING_ITERS,
     OPTIMIZATION_LEVEL,
     PASS_THROUGH_BUILD_FAILURES,
     PRECISION,
+    REFIT,
     REQUIRE_FULL_COMPILATION,
+    SPARSE_WEIGHTS,
     TRUNCATE_LONG_AND_DOUBLE,
     USE_FAST_PARTITIONER,
     USE_PYTHON_RUNTIME,
@@ -46,6 +55,14 @@ class CompilationSettings:
         device (Device): GPU to compile the model on
         require_full_compilation (bool): Whether to require the graph is fully compiled in TensorRT.
             Only applicable for `ir="dynamo"`; has no effect for `torch.compile` path
+        disable_tf32 (bool): Whether to disable TF32 computation for TRT layers
+        sparse_weights (bool): Whether to allow the builder to use sparse weights
+        refit (bool): Whether to build a refittable engine
+        engine_capability (trt.EngineCapability): Restrict kernel selection to safe gpu kernels or safe dla kernels
+        num_avg_timing_iters (int): Number of averaging timing iterations used to select kernels
+        dla_sram_size (int): Fast software managed RAM used by DLA to communicate within a layer.
+        dla_local_dram_size (int): Host RAM used by DLA to share intermediate tensor data across operations
+        dla_global_dram_size (int): Host RAM used by DLA to store weights and metadata for execution
     """
 
     precision: torch.dtype = PRECISION
@@ -63,3 +80,11 @@ class CompilationSettings:
     enable_experimental_decompositions: bool = ENABLE_EXPERIMENTAL_DECOMPOSITIONS
     device: Device = field(default_factory=default_device)
     require_full_compilation: bool = REQUIRE_FULL_COMPILATION
+    disable_tf32: bool = DISABLE_TF32
+    sparse_weights: bool = SPARSE_WEIGHTS
+    refit: bool = REFIT
+    engine_capability: EngineCapability = ENGINE_CAPABILITY
+    num_avg_timing_iters: int = NUM_AVG_TIMING_ITERS
+    dla_sram_size: int = DLA_SRAM_SIZE
+    dla_local_dram_size: int = DLA_LOCAL_DRAM_SIZE
+    dla_global_dram_size: int = DLA_GLOBAL_DRAM_SIZE
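Reviewer note (illustrative sketch, not part of the patch): the new dataclass fields fall back to the constants in _defaults.py when omitted, so existing callers of CompilationSettings keep their current behavior.

    from torch_tensorrt.dynamo._settings import CompilationSettings

    settings = CompilationSettings(disable_tf32=True, refit=True, num_avg_timing_iters=3)

    assert settings.num_avg_timing_iters == 3    # explicit override
    assert settings.sparse_weights is False      # SPARSE_WEIGHTS default
    assert settings.dla_sram_size == 1048576     # DLA_SRAM_SIZE default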
diff --git a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py
index 0f1c3b0c42..eec7e62516 100644
--- a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py
+++ b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py
@@ -4,6 +4,7 @@
 from typing import Any, Callable, Dict, List, NamedTuple, Optional, Sequence, Set
 
 import numpy as np
+import tensorrt as trt
 import torch
 import torch.fx
 from torch.fx.node import _get_qualified_name
@@ -23,8 +24,6 @@
 from torch_tensorrt.fx.observer import Observer
 from torch_tensorrt.fx.utils import Frameworks, unified_dtype_converter
 
-# @manual=//deeplearning/trt/python:py_tensorrt
-import tensorrt as trt
 from packaging import version
 
 _LOGGER: logging.Logger = logging.getLogger(__name__)
@@ -96,6 +95,7 @@ def __init__(
         self._itensor_to_tensor_meta: Dict[
             trt.tensorrt.ITensor, TensorMetadata
         ] = dict()
+        self.compilation_settings = compilation_settings
 
         # Data types for TRT Module output Tensors
         self.output_dtypes = output_dtypes
@@ -118,40 +118,25 @@ def validate_conversion(self) -> Set[str]:
 
     def run(
         self,
-        workspace_size: int = 0,
-        precision: torch.dtype = torch.float32,  # TODO: @peri044 Needs to be expanded to set
-        sparse_weights: bool = False,
-        disable_tf32: bool = False,
         force_fp32_output: bool = False,
         strict_type_constraints: bool = False,
         algorithm_selector: Optional[trt.IAlgorithmSelector] = None,
         timing_cache: Optional[trt.ITimingCache] = None,
-        profiling_verbosity: Optional[trt.ProfilingVerbosity] = None,
         tactic_sources: Optional[int] = None,
-        max_aux_streams: Optional[int] = None,
-        version_compatible: bool = False,
-        optimization_level: Optional[int] = None,
     ) -> TRTInterpreterResult:
         """
         Build TensorRT engine with some configs.
         Args:
-            workspace_size: Amount of memory used by TensorRT to store intermediate buffers within an operation.
-            precision: the precision model layers are running on (TensorRT will choose the best perforamnce precision).
-            sparse_weights: allow the builder to examine weights and use optimized functions when weights have suitable sparsity
            force_fp32_output: force output to be fp32
            strict_type_constraints: Usually we should set it to False unless we want to control the precision of certain layer for numeric reasons.
            algorithm_selector: set up algorithm selection for certain layer
            timing_cache: enable timing cache for TensorRT
-            profiling_verbosity: TensorRT logging level
-            max_aux_streams: Maximum number of allowed auxiliary TRT streams for each engine
-            version_compatible: Provide version forward-compatibility for engine plan files
-            optimization_level: Builder optimization 0-5, higher levels imply longer build time,
-                searching for more optimization options. TRT defaults to 3
         Return:
            TRTInterpreterResult
        """
        TRT_INTERPRETER_CALL_PRE_OBSERVER.observe(self.module)
 
+        precision = self.compilation_settings.precision
         # For float outputs, we set their dtype to fp16 only if precision == torch.float16 and
         # force_fp32_output=False. Overriden by specifying output_dtypes
         self.output_fp16 = not force_fp32_output and precision == torch.float16
@@ -172,9 +157,9 @@ def run(
 
         builder_config = self.builder.create_builder_config()
 
-        if workspace_size != 0:
+        if self.compilation_settings.workspace_size != 0:
             builder_config.set_memory_pool_limit(
-                trt.MemoryPoolType.WORKSPACE, workspace_size
+                trt.MemoryPoolType.WORKSPACE, self.compilation_settings.workspace_size
             )
 
         cache = None
@@ -187,21 +172,50 @@ def run(
 
         if version.parse(trt.__version__) >= version.parse("8.2"):
             builder_config.profiling_verbosity = (
-                profiling_verbosity
-                if profiling_verbosity
+                trt.ProfilingVerbosity.VERBOSE
+                if self.compilation_settings.debug
                 else trt.ProfilingVerbosity.LAYER_NAMES_ONLY
             )
 
         if version.parse(trt.__version__) >= version.parse("8.6"):
-            if max_aux_streams is not None:
-                _LOGGER.info(f"Setting max aux streams to {max_aux_streams}")
-                builder_config.max_aux_streams = max_aux_streams
-            if version_compatible:
+            if self.compilation_settings.max_aux_streams is not None:
+                _LOGGER.info(
+                    f"Setting max aux streams to {self.compilation_settings.max_aux_streams}"
+                )
+                builder_config.max_aux_streams = (
+                    self.compilation_settings.max_aux_streams
+                )
+            if self.compilation_settings.version_compatible:
                 _LOGGER.info("Using version compatible")
                 builder_config.set_flag(trt.BuilderFlag.VERSION_COMPATIBLE)
-            if optimization_level is not None:
-                _LOGGER.info(f"Using optimization level {optimization_level}")
-                builder_config.builder_optimization_level = optimization_level
+            if self.compilation_settings.optimization_level is not None:
+                _LOGGER.info(
+                    f"Using optimization level {self.compilation_settings.optimization_level}"
+                )
+                builder_config.builder_optimization_level = (
+                    self.compilation_settings.optimization_level
+                )
+
+        builder_config.engine_capability = self.compilation_settings.engine_capability
+        builder_config.avg_timing_iterations = (
+            self.compilation_settings.num_avg_timing_iters
+        )
+
+        if self.compilation_settings.device.device_type == trt.DeviceType.DLA:
+            builder_config.DLA_core = self.compilation_settings.device.dla_core
+            _LOGGER.info(f"Using DLA core {self.compilation_settings.device.dla_core}")
+            builder_config.set_memory_pool_limit(
+                trt.MemoryPoolType.DLA_MANAGED_SRAM,
+                self.compilation_settings.dla_sram_size,
+            )
+            builder_config.set_memory_pool_limit(
+                trt.MemoryPoolType.DLA_LOCAL_DRAM,
+                self.compilation_settings.dla_local_dram_size,
+            )
+            builder_config.set_memory_pool_limit(
+                trt.MemoryPoolType.DLA_GLOBAL_DRAM,
+                self.compilation_settings.dla_global_dram_size,
+            )
 
         if precision == torch.float16:
             builder_config.set_flag(trt.BuilderFlag.FP16)
@@ -209,12 +223,15 @@ def run(
         if precision == torch.int8:
             builder_config.set_flag(trt.BuilderFlag.INT8)
 
-        if sparse_weights:
+        if self.compilation_settings.sparse_weights:
             builder_config.set_flag(trt.BuilderFlag.SPARSE_WEIGHTS)
 
-        if disable_tf32:
+        if self.compilation_settings.disable_tf32:
             builder_config.clear_flag(trt.BuilderFlag.TF32)
 
+        if self.compilation_settings.refit:
+            builder_config.set_flag(trt.BuilderFlag.REFIT)
+
         if strict_type_constraints:
             builder_config.set_flag(trt.BuilderFlag.STRICT_TYPES)
 
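Reviewer note (illustrative sketch, not part of the patch): the new BuilderFlag.REFIT branch is what makes later in-place weight updates possible. A hedged example against the plain TensorRT Python API, where `engine` is assumed to be a deserialized engine from a refit=True build and "conv1" is a hypothetical layer name:

    import numpy as np
    import tensorrt as trt

    logger = trt.Logger(trt.Logger.WARNING)
    refitter = trt.Refitter(engine, logger)  # `engine` assumed to come from a refittable build
    # Swap in replacement kernel weights for a hypothetical convolution layer, then refit in place.
    refitter.set_weights("conv1", trt.WeightsRole.KERNEL, np.zeros((8, 3, 3, 3), dtype=np.float32))
    assert refitter.refit_cuda_engine()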
diff --git a/py/torch_tensorrt/dynamo/conversion/_conversion.py b/py/torch_tensorrt/dynamo/conversion/_conversion.py
index 1cdea63680..d39b7f35c7 100644
--- a/py/torch_tensorrt/dynamo/conversion/_conversion.py
+++ b/py/torch_tensorrt/dynamo/conversion/_conversion.py
@@ -3,6 +3,7 @@
 import io
 from typing import Sequence
 
+import tensorrt as trt
 import torch
 from torch_tensorrt._Input import Input
 from torch_tensorrt.dynamo._settings import CompilationSettings
@@ -10,8 +11,6 @@
 from torch_tensorrt.dynamo.runtime import PythonTorchTensorRTModule, TorchTensorRTModule
 from torch_tensorrt.dynamo.utils import get_torch_inputs
 
-import tensorrt as trt
-
 
 def convert_module(
     module: torch.fx.GraphModule,
@@ -54,18 +53,7 @@ def convert_module(
         output_dtypes=output_dtypes,
         compilation_settings=settings,
     )
-    interpreter_result = interpreter.run(
-        workspace_size=settings.workspace_size,
-        precision=settings.precision,
-        profiling_verbosity=(
-            trt.ProfilingVerbosity.VERBOSE
-            if settings.debug
-            else trt.ProfilingVerbosity.LAYER_NAMES_ONLY
-        ),
-        max_aux_streams=settings.max_aux_streams,
-        version_compatible=settings.version_compatible,
-        optimization_level=settings.optimization_level,
-    )
+    interpreter_result = interpreter.run()
 
     if settings.use_python_runtime:
         return PythonTorchTensorRTModule(
diff --git a/tests/py/dynamo/conversion/harness.py b/tests/py/dynamo/conversion/harness.py
index be13f7d2c1..404f50a187 100644
--- a/tests/py/dynamo/conversion/harness.py
+++ b/tests/py/dynamo/conversion/harness.py
@@ -50,7 +50,6 @@ def run_test(
         interpreter,
         rtol,
         atol,
-        precision=torch.float,
         check_dtype=True,
     ):
         with torch.no_grad():
@@ -60,7 +59,7 @@ def run_test(
 
             mod.eval()
             start = time.perf_counter()
-            interpreter_result = interpreter.run(precision=precision)
+            interpreter_result = interpreter.run()
             sec = time.perf_counter() - start
             _LOGGER.info(f"Interpreter run time(s): {sec}")
             trt_mod = PythonTorchTensorRTModule(
@@ -234,7 +233,9 @@ def run_test(
 
         # Previous instance of the interpreter auto-casted 64-bit inputs
        # We replicate this behavior here
-        compilation_settings = CompilationSettings(truncate_long_and_double=True)
+        compilation_settings = CompilationSettings(
+            precision=precision, truncate_long_and_double=True
+        )
 
         interp = TRTInterpreter(
             mod,
@@ -248,7 +249,6 @@ def run_test(
             interp,
             rtol,
             atol,
-            precision,
             check_dtype,
         )
 
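Reviewer note (illustrative sketch, not part of the patch): the calling-convention change in a nutshell is that build options now travel on the settings object handed to the interpreter constructor, and run() takes no build keyword arguments. `gm`, `input_specs`, and `output_dtypes` are assumed placeholders, mirroring convert_module() above.

    import torch
    from torch_tensorrt.dynamo._settings import CompilationSettings
    from torch_tensorrt.dynamo.conversion._TRTInterpreter import TRTInterpreter

    settings = CompilationSettings(precision=torch.float16, debug=True, version_compatible=True)
    interpreter = TRTInterpreter(
        gm,                              # traced torch.fx.GraphModule (assumed)
        input_specs,                     # matching Sequence[Input] (assumed)
        output_dtypes=output_dtypes,     # assumed, as in convert_module() above
        compilation_settings=settings,
    )
    result = interpreter.run()           # was: run(workspace_size=..., precision=..., ...)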
diff --git a/tests/py/dynamo/runtime/test_compilation_settings.py b/tests/py/dynamo/runtime/test_compilation_settings.py
new file mode 100644
index 0000000000..daa67ad032
--- /dev/null
+++ b/tests/py/dynamo/runtime/test_compilation_settings.py
@@ -0,0 +1,95 @@
+import torch
+import torch_tensorrt
+from torch.testing._internal.common_utils import TestCase, run_tests
+
+from ..testing_utilities import DECIMALS_OF_AGREEMENT
+
+
+class TestEnableTRTFlags(TestCase):
+    def test_toggle_build_args(self):
+        class AddSoftmax(torch.nn.Module):
+            def forward(self, x):
+                x = 3 * x
+                y = x + 1
+                return torch.softmax(y, 0)
+
+        inputs = [
+            torch.rand(
+                3,
+                5,
+                7,
+            ).cuda(),
+        ]
+
+        fx_graph = torch.fx.symbolic_trace(AddSoftmax())
+
+        # Validate that the results between Torch and Torch-TRT are similar
+        # Enable multiple TRT build arguments
+        optimized_model = torch_tensorrt.compile(
+            fx_graph,
+            "torch_compile",
+            inputs,
+            min_block_size=1,
+            pass_through_build_failures=True,
+            disable_tf32=True,
+            sparse_weights=True,
+            refit=True,
+            num_avg_timing_iters=5,
+            workspace_size=1 << 10,
+            truncate_long_and_double=True,
+        )
+
+        optimized_model_results = optimized_model(*inputs).detach().cpu()
+        torch_model_results = fx_graph(*inputs).detach().cpu()
+
+        max_diff = float(
+            torch.max(torch.abs(optimized_model_results - torch_model_results))
+        )
+        self.assertAlmostEqual(
+            max_diff,
+            0,
+            DECIMALS_OF_AGREEMENT,
+            msg=f"AddSoftmax TRT outputs don't match with the original model.",
+        )
+        torch._dynamo.reset()
+
+    def test_dla_args(self):
+        class AddSoftmax(torch.nn.Module):
+            def forward(self, x):
+                x = 3 * x
+                y = x + 1
+                return torch.softmax(y, 0)
+
+        inputs = [
+            torch.rand(
+                3,
+                5,
+                7,
+            ).cuda(),
+        ]
+
+        fx_graph = torch.fx.symbolic_trace(AddSoftmax())
+
+        # Validate that the results between Torch and Torch-TRT are similar
+        # Enable multiple TRT build arguments
+        optimized_model = torch_tensorrt.compile(
+            fx_graph,
+            "torch_compile",
+            inputs,
+            min_block_size=1,
+            device=torch_tensorrt.Device("dla:0", allow_gpu_fallback=True),
+            pass_through_build_failures=True,
+            dla_sram_size=1048577,
+            dla_local_dram_size=1073741825,
+            dla_global_dram_size=536870913,
+        )
+
+        # DLA is not present on the active machine
+        with self.assertRaises(RuntimeError):
+            optimized_model(*inputs).detach().cpu()
+
+        torch._dynamo.reset()
+
+
+if __name__ == "__main__":
+    run_tests()
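Reviewer note: because the new test module ends with run_tests(), it can be executed directly on a CUDA machine, for example as python tests/py/dynamo/runtime/test_compilation_settings.py. As the inline comment states, test_dla_args deliberately expects a RuntimeError on machines without DLA hardware, so it exercises only the argument plumbing, not an actual DLA build.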