From 3751e32ca7498ef74e33183283970487479eae53 Mon Sep 17 00:00:00 2001
From: gs-olive <113141689+gs-olive@users.noreply.github.com>
Date: Thu, 1 Jun 2023 15:35:55 -0700
Subject: [PATCH 1/3] fix: Upgrade Torch version to `2.1.0.dev20230605`

- Upgrade Torch version across the stack
- Update `setup.py` dependencies
- Switch default required C++ version to C++17, as required for new Torch distributions
---
 .bazelrc                                                 | 2 +-
 .circleci/config.yml                                     | 8 ++++----
 CMakeLists.txt                                           | 4 ++--
 README.md                                                | 2 +-
 WORKSPACE                                                | 8 ++++----
 py/requirements.txt                                      | 4 ++--
 py/setup.py                                              | 2 ++
 toolchains/ci_workspaces/WORKSPACE.x86_64.release.rhel   | 8 ++++----
 toolchains/ci_workspaces/WORKSPACE.x86_64.release.ubuntu | 8 ++++----
 9 files changed, 24 insertions(+), 22 deletions(-)

diff --git a/.bazelrc b/.bazelrc
index bcc2a11042..f9e0b4ab07 100644
--- a/.bazelrc
+++ b/.bazelrc
@@ -22,7 +22,7 @@
 # +------------------------------------------------------------+
 # Enable colorful output of GCC
 build --cxxopt="-fdiagnostics-color=always"
-build --cxxopt='-std=c++14'
+build --cxxopt='-std=c++17'

 #build --linkopt="-Wl,--no-as-needed"

diff --git a/.circleci/config.yml b/.circleci/config.yml
index 5422b31a5a..d286866fd1 100644
--- a/.circleci/config.yml
+++ b/.circleci/config.yml
@@ -269,10 +269,10 @@ commands:
     parameters:
       torch-build:
         type: string
-        default: "2.1.0.dev20230419+cu118"
+        default: "2.1.0.dev20230605+cu118"
       torchvision-build:
         type: string
-        default: "0.16.0.dev20230419+cu118"
+        default: "0.16.0.dev20230605+cu118"
       torch-build-index:
         type: string
         default: "https://download.pytorch.org/whl/nightly/cu118"
@@ -1352,10 +1352,10 @@ parameters:
   # Nightly platform config
   torch-build:
     type: string
-    default: "2.1.0.dev20230419+cu118"
+    default: "2.1.0.dev20230605+cu118"
   torchvision-build:
     type: string
-    default: "0.16.0.dev20230419+cu118"
+    default: "0.16.0.dev20230605+cu118"
   torch-build-index:
     type: string
     default: "https://download.pytorch.org/whl/nightly/cu118"

diff --git a/CMakeLists.txt b/CMakeLists.txt
index 0b103b8d86..1fdd9390e6 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -2,8 +2,8 @@
 cmake_minimum_required(VERSION 3.17)
 project(Torch-TensorRT LANGUAGES CXX)

-# use c++14 like PyTorch
-set(CMAKE_CXX_STANDARD 14)
+# use c++17 like PyTorch
+set(CMAKE_CXX_STANDARD 17)

 # Build the libraries with -fPIC
 set(CMAKE_POSITION_INDEPENDENT_CODE ON)

diff --git a/README.md b/README.md
index e600400d9a..8db561e2e1 100644
--- a/README.md
+++ b/README.md
@@ -116,7 +116,7 @@ torch.jit.save(trt_ts_module, "trt_torchscript_module.ts") # save the TRT embedd
 These are the following dependencies used to verify the testcases. Torch-TensorRT can work with other versions, but the tests are not guaranteed to pass.
 - Bazel 5.2.0
-- Libtorch 2.1.0.dev20230419 (built with CUDA 11.8)
+- Libtorch 2.1.0.dev20230605 (built with CUDA 11.8)
 - CUDA 11.8
 - cuDNN 8.8.0
 - TensorRT 8.6.1

diff --git a/WORKSPACE b/WORKSPACE
index 256e817592..4df265c64c 100644
--- a/WORKSPACE
+++ b/WORKSPACE
@@ -51,17 +51,17 @@ new_local_repository(
 http_archive(
     name = "libtorch",
     build_file = "@//third_party/libtorch:BUILD",
-    sha256 = "1a526a9cd19c1015674d26921dbb94bcd2d632a6f9c431a21c43f4e24768d834",
+    sha256 = "999becce82b73e566d0ffe010cd21fea8cf3a33f90f09dcc6b01150b820ae063",
     strip_prefix = "libtorch",
-    urls = ["https://download.pytorch.org/libtorch/nightly/cu118/libtorch-cxx11-abi-shared-with-deps-2.1.0.dev20230419%2Bcu118.zip"],
+    urls = ["https://download.pytorch.org/libtorch/nightly/cu118/libtorch-cxx11-abi-shared-with-deps-2.1.0.dev20230605%2Bcu118.zip"],
 )

 http_archive(
     name = "libtorch_pre_cxx11_abi",
     build_file = "@//third_party/libtorch:BUILD",
-    sha256 = "60c5912a5085a6a7073b3804b10d41d6cc054693bbeb7a45e0247050c2837bac",
+    sha256 = "786cc728c63ea69c40bd8fb535cf8e5e1dfff1d43eaad3eb5256b9ed89c1b268",
     strip_prefix = "libtorch",
-    urls = ["https://download.pytorch.org/libtorch/nightly/cu118/libtorch-shared-with-deps-2.1.0.dev20230419%2Bcu118.zip"],
+    urls = ["https://download.pytorch.org/libtorch/nightly/cu118/libtorch-shared-with-deps-2.1.0.dev20230605%2Bcu118.zip"],
 )

 # Download these tarballs manually from the NVIDIA website
diff --git a/py/requirements.txt b/py/requirements.txt
index c25e9b2737..d49e95c9b6 100644
--- a/py/requirements.txt
+++ b/py/requirements.txt
@@ -2,7 +2,7 @@ numpy
 packaging
 pybind11==2.6.2
 --extra-index-url https://download.pytorch.org/whl/nightly/cu118
-torch==2.1.0.dev20230419+cu118
-torchvision==0.16.0.dev20230419+cu118
+torch==2.1.0.dev20230605+cu118
+torchvision==0.16.0.dev20230605+cu118
 --extra-index-url https://pypi.ngc.nvidia.com
 tensorrt==8.6.1

diff --git a/py/setup.py b/py/setup.py
index b870560ae5..eb382559f8 100644
--- a/py/setup.py
+++ b/py/setup.py
@@ -427,6 +427,8 @@ def run(self):
     ext_modules=ext_modules,
     install_requires=[
         "torch >=2.1.dev,<2.2" if not LEGACY else "torch >=1.13.0,<2.0",
+        "pyyaml",
+        "packaging",
     ],
     setup_requires=[],
     cmdclass={

diff --git a/toolchains/ci_workspaces/WORKSPACE.x86_64.release.rhel b/toolchains/ci_workspaces/WORKSPACE.x86_64.release.rhel
index 263ed6d40b..23443f18a7 100644
--- a/toolchains/ci_workspaces/WORKSPACE.x86_64.release.rhel
+++ b/toolchains/ci_workspaces/WORKSPACE.x86_64.release.rhel
@@ -56,17 +56,17 @@ new_local_repository(
 http_archive(
     name = "libtorch",
     build_file = "@//third_party/libtorch:BUILD",
-    sha256 = "1a526a9cd19c1015674d26921dbb94bcd2d632a6f9c431a21c43f4e24768d834",
+    sha256 = "999becce82b73e566d0ffe010cd21fea8cf3a33f90f09dcc6b01150b820ae063",
     strip_prefix = "libtorch",
-    urls = ["https://download.pytorch.org/libtorch/nightly/cu118/libtorch-cxx11-abi-shared-with-deps-2.1.0.dev20230419%2Bcu118.zip"],
+    urls = ["https://download.pytorch.org/libtorch/nightly/cu118/libtorch-cxx11-abi-shared-with-deps-2.1.0.dev20230605%2Bcu118.zip"],
 )

 http_archive(
     name = "libtorch_pre_cxx11_abi",
     build_file = "@//third_party/libtorch:BUILD",
-    sha256 = "60c5912a5085a6a7073b3804b10d41d6cc054693bbeb7a45e0247050c2837bac",
+    sha256 = "786cc728c63ea69c40bd8fb535cf8e5e1dfff1d43eaad3eb5256b9ed89c1b268",
     strip_prefix = "libtorch",
-    urls = ["https://download.pytorch.org/libtorch/nightly/cu118/libtorch-shared-with-deps-2.1.0.dev20230419%2Bcu118.zip"],
+    urls = ["https://download.pytorch.org/libtorch/nightly/cu118/libtorch-shared-with-deps-2.1.0.dev20230605%2Bcu118.zip"],
 )

 ####################################################################################

diff --git a/toolchains/ci_workspaces/WORKSPACE.x86_64.release.ubuntu b/toolchains/ci_workspaces/WORKSPACE.x86_64.release.ubuntu
index 263ed6d40b..23443f18a7 100644
--- a/toolchains/ci_workspaces/WORKSPACE.x86_64.release.ubuntu
+++ b/toolchains/ci_workspaces/WORKSPACE.x86_64.release.ubuntu
@@ -56,17 +56,17 @@ new_local_repository(
 http_archive(
     name = "libtorch",
     build_file = "@//third_party/libtorch:BUILD",
-    sha256 = "1a526a9cd19c1015674d26921dbb94bcd2d632a6f9c431a21c43f4e24768d834",
+    sha256 = "999becce82b73e566d0ffe010cd21fea8cf3a33f90f09dcc6b01150b820ae063",
     strip_prefix = "libtorch",
-    urls = ["https://download.pytorch.org/libtorch/nightly/cu118/libtorch-cxx11-abi-shared-with-deps-2.1.0.dev20230419%2Bcu118.zip"],
+    urls = ["https://download.pytorch.org/libtorch/nightly/cu118/libtorch-cxx11-abi-shared-with-deps-2.1.0.dev20230605%2Bcu118.zip"],
 )

 http_archive(
     name = "libtorch_pre_cxx11_abi",
     build_file = "@//third_party/libtorch:BUILD",
-    sha256 = "60c5912a5085a6a7073b3804b10d41d6cc054693bbeb7a45e0247050c2837bac",
+    sha256 = "786cc728c63ea69c40bd8fb535cf8e5e1dfff1d43eaad3eb5256b9ed89c1b268",
     strip_prefix = "libtorch",
-    urls = ["https://download.pytorch.org/libtorch/nightly/cu118/libtorch-shared-with-deps-2.1.0.dev20230419%2Bcu118.zip"],
+    urls = ["https://download.pytorch.org/libtorch/nightly/cu118/libtorch-shared-with-deps-2.1.0.dev20230605%2Bcu118.zip"],
 )

 ####################################################################################

From 34638b664b352e220e02131edb004b638350ce23 Mon Sep 17 00:00:00 2001
From: George S <113141689+gs-olive@users.noreply.github.com>
Date: Tue, 20 Jun 2023 13:21:53 -0700
Subject: [PATCH 2/3] feat: Add `options` kwargs for Torch compile [3 / x] (#2005)

---
 py/torch_tensorrt/dynamo/backend/backends.py | 13 +++----
 py/torch_tensorrt/dynamo/backend/utils.py    | 39 ++++++++++++++++++++
 2 files changed, 45 insertions(+), 7 deletions(-)

diff --git a/py/torch_tensorrt/dynamo/backend/backends.py b/py/torch_tensorrt/dynamo/backend/backends.py
index 8f6408492a..cf869562b6 100644
--- a/py/torch_tensorrt/dynamo/backend/backends.py
+++ b/py/torch_tensorrt/dynamo/backend/backends.py
@@ -12,6 +12,7 @@
     partition,
     get_submod_inputs,
 )
+from torch_tensorrt.dynamo.backend.utils import parse_dynamo_kwargs
 from torch_tensorrt.dynamo.backend.conversion import convert_module

 from torch._dynamo.backends.common import fake_tensor_unsupported
@@ -25,22 +26,20 @@
 @td.register_backend(name="torch_tensorrt")
 @fake_tensor_unsupported
 def torch_tensorrt_backend(
-    gm: torch.fx.GraphModule,
-    sample_inputs: Sequence[torch.Tensor],
-    settings: CompilationSettings = CompilationSettings(),
+    gm: torch.fx.GraphModule, sample_inputs: Sequence[torch.Tensor], **kwargs
 ):
     DEFAULT_BACKEND = aot_torch_tensorrt_aten_backend

-    return DEFAULT_BACKEND(gm, sample_inputs, settings=settings)
+    return DEFAULT_BACKEND(gm, sample_inputs, **kwargs)


 @td.register_backend(name="aot_torch_tensorrt_aten")
 @fake_tensor_unsupported
 def aot_torch_tensorrt_aten_backend(
-    gm: torch.fx.GraphModule,
-    sample_inputs: Sequence[torch.Tensor],
-    settings: CompilationSettings = CompilationSettings(),
+    gm: torch.fx.GraphModule, sample_inputs: Sequence[torch.Tensor], **kwargs
 ):
+    settings = parse_dynamo_kwargs(kwargs)
+
     custom_backend = partial(
         _pretraced_backend,
         settings=settings,
diff --git a/py/torch_tensorrt/dynamo/backend/utils.py b/py/torch_tensorrt/dynamo/backend/utils.py
index e6e22d5f96..9396373790 100644
--- a/py/torch_tensorrt/dynamo/backend/utils.py
+++ b/py/torch_tensorrt/dynamo/backend/utils.py
@@ -1,9 +1,15 @@
 import torch
+import logging
+from dataclasses import replace, fields
+from torch_tensorrt.dynamo.backend._settings import CompilationSettings
 from typing import Any, Union, Sequence, Dict
 from torch_tensorrt import _Input, Device

+logger = logging.getLogger(__name__)
+
+
 def prepare_inputs(
     inputs: Union[_Input.Input, torch.Tensor, Sequence, Dict],
     device: torch.device = torch.device("cuda"),
@@ -66,3 +72,36 @@ def prepare_device(device: Union[Device, torch.device]) -> torch.device:
         )

     return device
+
+
+def parse_dynamo_kwargs(kwargs: Dict) -> CompilationSettings:
+    """Parses the kwargs field of a Dynamo backend
+
+    Args:
+        kwargs: Keyword arguments dictionary provided to the backend
+    Returns:
+        CompilationSettings object with relevant kwargs
+    """
+
+    # Initialize an empty CompilationSettings object
+    settings = CompilationSettings()
+
+    # If the user specifies keyword args, overwrite those fields in settings
+    # Validate all specified kwargs to ensure they are true fields of the dataclass
+    #
+    # Note: kwargs provided by torch.compile are wrapped in the "options" key
+    if kwargs:
+        if "options" in kwargs and len(kwargs) == 1:
+            kwargs = kwargs["options"]
+
+        valid_attrs = {attr.name for attr in fields(settings)}
+        valid_kwargs = {k: v for k, v in kwargs.items() if k in valid_attrs}
+        settings = replace(settings, **valid_kwargs)
+
+    # Enable debug/verbose mode if requested
+    if settings.debug:
+        logger.setLevel(logging.DEBUG)
+
+    logger.debug(f"Compiling with Settings:\n{settings}")
+
+    return settings
From 784ec9bbe8633520d92531361ef9423bd0e4d6dd Mon Sep 17 00:00:00 2001
From: George S <113141689+gs-olive@users.noreply.github.com>
Date: Tue, 20 Jun 2023 14:28:42 -0700
Subject: [PATCH 3/3] feat: Add support for output data types in `TRTInterpreter` [2 / x] (#2004)

---
 .../dynamo/backend/conversion.py  | 10 +++++++++
 .../dynamo/fx_ts_compat/fx2trt.py | 22 ++++++++++++++++---
 2 files changed, 29 insertions(+), 3 deletions(-)

diff --git a/py/torch_tensorrt/dynamo/backend/conversion.py b/py/torch_tensorrt/dynamo/backend/conversion.py
index f359020bfb..f2631f0c87 100644
--- a/py/torch_tensorrt/dynamo/backend/conversion.py
+++ b/py/torch_tensorrt/dynamo/backend/conversion.py
@@ -24,11 +24,21 @@ def convert_module(
     Returns:
         TRTModule or TRTModuleNext
     """
+    # Specify module output data types to ensure TRT output types agree with
+    # those of the equivalent Torch module
+    module_outputs = module(*inputs)
+
+    if not isinstance(module_outputs, (list, tuple)):
+        module_outputs = [module_outputs]
+
+    output_dtypes = list(output.dtype for output in module_outputs)
+
     interpreter = TRTInterpreter(
         module,
         InputTensorSpec.from_tensors(inputs),
         explicit_batch_dimension=True,
         logger_level=(trt.Logger.VERBOSE if settings.debug else trt.Logger.WARNING),
+        output_dtypes=output_dtypes,
     )

     interpreter_result = interpreter.run(
diff --git a/py/torch_tensorrt/dynamo/fx_ts_compat/fx2trt.py b/py/torch_tensorrt/dynamo/fx_ts_compat/fx2trt.py
index e4298600cb..444efc0f4e 100644
--- a/py/torch_tensorrt/dynamo/fx_ts_compat/fx2trt.py
+++ b/py/torch_tensorrt/dynamo/fx_ts_compat/fx2trt.py
@@ -1,6 +1,7 @@
 import logging
 import warnings
 from datetime import datetime
+from packaging import version
 from typing import Any, Callable, Dict, List, NamedTuple, Optional, Sequence

 import numpy
@@ -40,6 +41,7 @@ def __init__(
         explicit_batch_dimension: bool = True,
         explicit_precision: bool = False,
         logger_level=None,
+        output_dtypes=None,
     ):
         super().__init__(module)
@@ -78,6 +80,9 @@ def __init__(
             trt.tensorrt.ITensor, TensorMetadata
         ] = dict()

+        # Data types for TRT Module output Tensors
+        self.output_dtypes = output_dtypes
+
     def validate_input_specs(self):
         for shape, _, _, shape_ranges, has_batch_dim in self.input_specs:
             if not self.network.has_implicit_batch_dimension:
@@ -178,13 +183,17 @@ def run(
             algorithm_selector: set up algorithm selection for certain layer
             timing_cache: enable timing cache for TensorRT
             profiling_verbosity: TensorRT logging level
+            max_aux_streams: Maximum number of allowed auxiliary TRT streams for each engine
+            version_compatible: Provide version forward-compatibility for engine plan files
+            optimization_level: Builder optimization level 0-5; higher levels imply longer
+                build times, searching for more optimization options. TRT defaults to 3
         Return:
             TRTInterpreterResult
         """
         TRT_INTERPRETER_CALL_PRE_OBSERVER.observe(self.module)

         # For float outputs, we set their dtype to fp16 only if lower_precision == LowerPrecision.FP16 and
-        # force_fp32_output=False.
+        # force_fp32_output=False. Overridden by specifying output_dtypes
         self.output_fp16 = (
             not force_fp32_output and lower_precision == LowerPrecision.FP16
         )
@@ -224,14 +233,14 @@ def run(
             cache = builder_config.create_timing_cache(b"")
             builder_config.set_timing_cache(cache, False)

-        if trt.__version__ >= "8.2":
+        if version.parse(trt.__version__) >= version.parse("8.2"):
             builder_config.profiling_verbosity = (
                 profiling_verbosity
                 if profiling_verbosity
                 else trt.ProfilingVerbosity.LAYER_NAMES_ONLY
             )

-        if trt.__version__ >= "8.6":
+        if version.parse(trt.__version__) >= version.parse("8.6"):
             if max_aux_streams is not None:
                 _LOGGER.info(f"Setting max aux streams to {max_aux_streams}")
                 builder_config.max_aux_streams = max_aux_streams
@@ -372,6 +381,11 @@ def output(self, target, args, kwargs):
         if not all(isinstance(output, trt.tensorrt.ITensor) for output in outputs):
             raise RuntimeError("TensorRT requires all outputs to be Tensor!")

+        if self.output_dtypes is not None and len(self.output_dtypes) != len(outputs):
+            raise RuntimeError(
+                f"Specified output dtypes ({len(self.output_dtypes)}) differ from number of outputs ({len(outputs)})"
+            )
+
         for i, output in enumerate(outputs):
             if any(
                 op_name in output.name.split("_")
@@ -396,6 +410,8 @@ def output(self, target, args, kwargs):
             self.network.mark_output(output)
             if output_bool:
                 output.dtype = trt.bool
+            elif self.output_dtypes is not None:
+                output.dtype = torch_dtype_to_trt(self.output_dtypes[i])
             elif self.output_fp16 and output.dtype == trt.float32:
                 output.dtype = trt.float16
             self._output_names.append(name)
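Usage note (not part of the diff): the dtype plumbing added in this patch can be sketched in isolation. The helper below mirrors the recording step added to `convert_module` (run the module eagerly, normalize single outputs to a list, collect dtypes for `TRTInterpreter` to enforce on the engine outputs); the `MixedOut` module and input are illustrative assumptions.

```python
import torch

# Mirrors the step added to convert_module: run the module eagerly,
# normalize single outputs to a list, and record each output dtype so
# TRTInterpreter can mark the matching TensorRT engine output types.
def record_output_dtypes(module: torch.nn.Module, inputs):
    module_outputs = module(*inputs)
    if not isinstance(module_outputs, (list, tuple)):
        module_outputs = [module_outputs]
    return [output.dtype for output in module_outputs]

# Illustrative module returning mixed output dtypes (not from the patch)
class MixedOut(torch.nn.Module):
    def forward(self, x):
        return x.sum(dim=1), (x > 0)

dtypes = record_output_dtypes(MixedOut(), [torch.randn(2, 3)])
print(dtypes)  # [torch.float32, torch.bool]
```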