diff --git a/py/torch_tensorrt/dynamo/backend/__init__.py b/py/torch_tensorrt/dynamo/backend/__init__.py index 3743b263db..0a3096e6a0 100644 --- a/py/torch_tensorrt/dynamo/backend/__init__.py +++ b/py/torch_tensorrt/dynamo/backend/__init__.py @@ -4,7 +4,7 @@ import torch_tensorrt from functools import partial -from typing import Any, Sequence +from typing import Any, Optional, Sequence from torch_tensorrt import EngineCapability, Device from torch_tensorrt.fx.utils import LowerPrecision @@ -14,9 +14,13 @@ from torch_tensorrt.dynamo.backend._defaults import ( PRECISION, DEBUG, - MAX_WORKSPACE_SIZE, + WORKSPACE_SIZE, MIN_BLOCK_SIZE, PASS_THROUGH_BUILD_FAILURES, + MAX_AUX_STREAMS, + VERSION_COMPATIBLE, + OPTIMIZATION_LEVEL, + USE_EXPERIMENTAL_RT, ) @@ -35,7 +39,7 @@ def compile( debug=DEBUG, capability=EngineCapability.default, num_avg_timing_iters=1, - workspace_size=MAX_WORKSPACE_SIZE, + workspace_size=WORKSPACE_SIZE, dla_sram_size=1048576, dla_local_dram_size=1073741824, dla_global_dram_size=536870912, @@ -45,6 +49,10 @@ def compile( min_block_size=MIN_BLOCK_SIZE, torch_executed_ops=[], torch_executed_modules=[], + max_aux_streams=MAX_AUX_STREAMS, + version_compatible=VERSION_COMPATIBLE, + optimization_level=OPTIMIZATION_LEVEL, + use_experimental_rt=USE_EXPERIMENTAL_RT, **kwargs, ): if debug: @@ -86,6 +94,10 @@ def compile( workspace_size=workspace_size, min_block_size=min_block_size, torch_executed_ops=torch_executed_ops, + max_aux_streams=max_aux_streams, + version_compatible=version_compatible, + optimization_level=optimization_level, + use_experimental_rt=use_experimental_rt, **kwargs, ) @@ -105,10 +117,14 @@ def compile( def create_backend( precision: LowerPrecision = PRECISION, debug: bool = DEBUG, - workspace_size: int = MAX_WORKSPACE_SIZE, + workspace_size: int = WORKSPACE_SIZE, min_block_size: int = MIN_BLOCK_SIZE, torch_executed_ops: Sequence[str] = set(), pass_through_build_failures: bool = PASS_THROUGH_BUILD_FAILURES, + max_aux_streams: Optional[int] = MAX_AUX_STREAMS, + version_compatible: bool = VERSION_COMPATIBLE, + optimization_level: Optional[int] = OPTIMIZATION_LEVEL, + use_experimental_rt: bool = USE_EXPERIMENTAL_RT, **kwargs, ): """Create torch.compile backend given specified arguments @@ -116,8 +132,15 @@ def create_backend( Args: precision: debug: Whether to print out verbose debugging information - workspace_size: Maximum workspace TRT is allowed to use for the module - precision: Model Layer precision + workspace_size: Workspace TRT is allowed to use for the module (0 is default) + min_block_size: Minimum number of operators per TRT-Engine Block + torch_executed_ops: Sequence of operations to run in Torch, regardless of converter coverage + pass_through_build_failures: Whether to fail on TRT engine build errors (True) or not (False) + max_aux_streams: Maximum number of allowed auxiliary TRT streams for each engine + version_compatible: Provide version forward-compatibility for engine plan files + optimization_level: Builder optimization 0-5, higher levels imply longer build time, + searching for more optimization options. TRT defaults to 3 + use_experimental_rt: Whether to use the new experimental TRTModuleNext for TRT engines Returns: Backend for torch.compile """ @@ -131,6 +154,10 @@ def create_backend( min_block_size=min_block_size, torch_executed_ops=torch_executed_ops, pass_through_build_failures=pass_through_build_failures, + max_aux_streams=max_aux_streams, + version_compatible=version_compatible, + optimization_level=optimization_level, + use_experimental_rt=use_experimental_rt, ) return partial( diff --git a/py/torch_tensorrt/dynamo/backend/_defaults.py b/py/torch_tensorrt/dynamo/backend/_defaults.py index fe7b5f6b4f..286c60c2fa 100644 --- a/py/torch_tensorrt/dynamo/backend/_defaults.py +++ b/py/torch_tensorrt/dynamo/backend/_defaults.py @@ -3,6 +3,10 @@ PRECISION = LowerPrecision.FP32 DEBUG = False -MAX_WORKSPACE_SIZE = 20 << 30 +WORKSPACE_SIZE = 0 MIN_BLOCK_SIZE = 5 PASS_THROUGH_BUILD_FAILURES = False +MAX_AUX_STREAMS = None +VERSION_COMPATIBLE = False +OPTIMIZATION_LEVEL = None +USE_EXPERIMENTAL_RT = False diff --git a/py/torch_tensorrt/dynamo/backend/_settings.py b/py/torch_tensorrt/dynamo/backend/_settings.py index df3212f54a..7ec4cc596e 100644 --- a/py/torch_tensorrt/dynamo/backend/_settings.py +++ b/py/torch_tensorrt/dynamo/backend/_settings.py @@ -1,13 +1,17 @@ from dataclasses import dataclass, field -from typing import Sequence +from typing import Optional, Sequence from torch_tensorrt.fx.utils import LowerPrecision from torch_tensorrt.dynamo.backend._defaults import ( PRECISION, DEBUG, - MAX_WORKSPACE_SIZE, + WORKSPACE_SIZE, MIN_BLOCK_SIZE, PASS_THROUGH_BUILD_FAILURES, + MAX_AUX_STREAMS, + VERSION_COMPATIBLE, + OPTIMIZATION_LEVEL, + USE_EXPERIMENTAL_RT, ) @@ -15,7 +19,11 @@ class CompilationSettings: precision: LowerPrecision = PRECISION debug: bool = DEBUG - workspace_size: int = MAX_WORKSPACE_SIZE + workspace_size: int = WORKSPACE_SIZE min_block_size: int = MIN_BLOCK_SIZE torch_executed_ops: Sequence[str] = field(default_factory=set) pass_through_build_failures: bool = PASS_THROUGH_BUILD_FAILURES + max_aux_streams: Optional[int] = MAX_AUX_STREAMS + version_compatible: bool = VERSION_COMPATIBLE + optimization_level: Optional[int] = OPTIMIZATION_LEVEL + use_experimental_rt: bool = USE_EXPERIMENTAL_RT diff --git a/py/torch_tensorrt/dynamo/backend/backends.py b/py/torch_tensorrt/dynamo/backend/backends.py index 8f6408492a..78e2172fa8 100644 --- a/py/torch_tensorrt/dynamo/backend/backends.py +++ b/py/torch_tensorrt/dynamo/backend/backends.py @@ -135,6 +135,7 @@ def _compile_module( submodule, submodule_inputs, settings=settings, + name=name, ) # Replace FX Module with TRT Module diff --git a/py/torch_tensorrt/dynamo/backend/conversion.py b/py/torch_tensorrt/dynamo/backend/conversion.py index 1644dea547..85a63a80a8 100644 --- a/py/torch_tensorrt/dynamo/backend/conversion.py +++ b/py/torch_tensorrt/dynamo/backend/conversion.py @@ -1,9 +1,10 @@ from typing import Sequence, Union import torch +import io from torch_tensorrt.fx.trt_module import TRTModule from torch_tensorrt import TRTModuleNext from torch_tensorrt.dynamo.backend._settings import CompilationSettings -from torch_tensorrt.fx.fx2trt import ( +from torch_tensorrt.dynamo.fx_ts_compat.fx2trt import ( InputTensorSpec, TRTInterpreter, ) @@ -15,30 +16,50 @@ def convert_module( module: torch.fx.GraphModule, inputs: Sequence[torch.Tensor], settings: CompilationSettings = CompilationSettings(), + name: str = "", ) -> Union[TRTModuleNext, TRTModule]: """Convert an FX module to a TRT module Args: module: FX GraphModule to convert inputs: Sequence of Tensors representing inputs to the module settings: Compilation settings + name: TRT engine name Returns: TRTModule or TRTModuleNext """ - interp = TRTInterpreter( + interpreter = TRTInterpreter( module, InputTensorSpec.from_tensors(inputs), explicit_batch_dimension=True, logger_level=(trt.Logger.VERBOSE if settings.debug else trt.Logger.WARNING), ) - r = interp.run( - max_workspace_size=settings.workspace_size, + interpreter_result = interpreter.run( + workspace_size=settings.workspace_size, lower_precision=settings.precision, profiling_verbosity=( trt.ProfilingVerbosity.VERBOSE if settings.debug else trt.ProfilingVerbosity.LAYER_NAMES_ONLY ), + max_aux_streams=settings.max_aux_streams, + version_compatible=settings.version_compatible, + optimization_level=settings.optimization_level, ) - return TRTModule(*r) + if settings.use_experimental_rt: + with io.BytesIO() as engine_bytes: + engine_bytes.write(interpreter_result.engine.serialize()) + engine_str = engine_bytes.getvalue() + return TRTModuleNext( + serialized_engine=engine_str, + name=name, + input_binding_names=interpreter_result.input_names, + output_binding_names=interpreter_result.output_names, + ) + else: + return TRTModule( + engine=interpreter_result.engine, + input_names=interpreter_result.input_names, + output_names=interpreter_result.output_names, + ) diff --git a/py/torch_tensorrt/dynamo/fx_ts_compat/fx2trt.py b/py/torch_tensorrt/dynamo/fx_ts_compat/fx2trt.py index e4298600cb..94249cb70a 100644 --- a/py/torch_tensorrt/dynamo/fx_ts_compat/fx2trt.py +++ b/py/torch_tensorrt/dynamo/fx_ts_compat/fx2trt.py @@ -1,6 +1,7 @@ import logging import warnings from datetime import datetime +from packaging import version from typing import Any, Callable, Dict, List, NamedTuple, Optional, Sequence import numpy @@ -224,14 +225,14 @@ def run( cache = builder_config.create_timing_cache(b"") builder_config.set_timing_cache(cache, False) - if trt.__version__ >= "8.2": + if version.parse(trt.__version__) >= version.parse("8.2"): builder_config.profiling_verbosity = ( profiling_verbosity if profiling_verbosity else trt.ProfilingVerbosity.LAYER_NAMES_ONLY ) - if trt.__version__ >= "8.6": + if version.parse(trt.__version__) >= version.parse("8.6"): if max_aux_streams is not None: _LOGGER.info(f"Setting max aux streams to {max_aux_streams}") builder_config.max_aux_streams = max_aux_streams