From 482265fa8a3fa5e7278d13723d926909750af1d4 Mon Sep 17 00:00:00 2001
From: Naren Dasan
Date: Sat, 17 Jul 2021 16:31:20 -0700
Subject: [PATCH] feat(//py)!: Implementing top level python api changes to
 reflect new Input type and enabled_precisions set

BREAKING CHANGE: This commit introduces the next iteration of the Python
TRTorch API. Starting in TRTorch v0.5.0, support for the "input_shapes" and
"op_precision" compile spec keys will be removed. Users should port forward
to using the "inputs" key, which expects a list of trtorch.Input objects,
and the "enabled_precisions" key, which expects a set of enums specifying
the precisions to enable.

Signed-off-by: Naren Dasan
Signed-off-by: Naren Dasan
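For illustration, porting a spec forward might look like the following
(a minimal sketch; `scripted_model` is a stand-in for any
torch.jit.ScriptModule, and dtypes follow the parsing added in this patch):

    import torch
    import trtorch

    # Deprecated style: raw shape tuples and a single op precision
    old_spec = {
        "input_shapes": [(1, 3, 224, 224)],
        "op_precision": torch.half,
    }

    # New style: explicit Input objects and a set of enabled precisions
    new_spec = {
        "inputs": [trtorch.Input((1, 3, 224, 224), dtype=torch.float)],
        "enabled_precisions": {torch.float, torch.half},
    }

    trt_module = trtorch.compile(scripted_model, new_spec)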
---
 core/conversion/conversion.cpp                |  16 +-
 core/ir/Input.cpp                             |   5 +-
 py/setup.py                                   |   2 +-
 py/trtorch/Input.py                           | 181 ++++++++++++++++++
 py/trtorch/__init__.py                        |   1 +
 py/trtorch/_compile_spec.py                   | 126 ++++++------
 py/trtorch/_compiler.py                       |  40 ++--
 py/trtorch/_types.py                          |   2 +-
 py/trtorch/csrc/register_tensorrt_classes.cpp |  19 +-
 py/trtorch/csrc/tensorrt_classes.cpp          | 110 +++++++----
 py/trtorch/csrc/tensorrt_classes.h            |  44 +++--
 py/trtorch/csrc/trtorch_py.cpp                |  23 ++-
 12 files changed, 406 insertions(+), 163 deletions(-)
 create mode 100644 py/trtorch/Input.py

diff --git a/core/conversion/conversion.cpp b/core/conversion/conversion.cpp
index c00c39a1cd..324f3d0890 100644
--- a/core/conversion/conversion.cpp
+++ b/core/conversion/conversion.cpp
@@ -160,21 +160,21 @@ void AddInputs(
   for (size_t i = 0; i < input_tensors.size(); i++) {
     auto in = input_tensors[i];
-    auto dims = input_specs[i];
+    auto spec = input_specs[i];
     std::string name = std::string("input_") + std::to_string(ctx->num_inputs);
     LOG_INFO(
         ctx->logger,
-        "Adding Input " << in->debugName() << " (named: " << name << "): " << dims << " in engine (conversion.AddInputs)");
+        "Adding Input " << in->debugName() << " (named: " << name << "): " << spec << " in engine (conversion.AddInputs)");

-    auto trt_in = ctx->net->addInput(name.c_str(), dims.dtype, dims.input_shape);
+    auto trt_in = ctx->net->addInput(name.c_str(), spec.dtype, spec.input_shape);
     TRTORCH_CHECK(trt_in, "Failed to add input node: " << in->debugName() << " (conversion.AddInputs)");
-    trt_in->setAllowedFormats(1U << static_cast<int>(dims.format));
+    trt_in->setAllowedFormats(1U << static_cast<int>(spec.format));

-    profile->setDimensions(trt_in->getName(), nvinfer1::OptProfileSelector::kMIN, dims.min);
-    profile->setDimensions(trt_in->getName(), nvinfer1::OptProfileSelector::kOPT, dims.opt);
-    profile->setDimensions(trt_in->getName(), nvinfer1::OptProfileSelector::kMAX, dims.max);
+    profile->setDimensions(trt_in->getName(), nvinfer1::OptProfileSelector::kMIN, spec.min);
+    profile->setDimensions(trt_in->getName(), nvinfer1::OptProfileSelector::kOPT, spec.opt);
+    profile->setDimensions(trt_in->getName(), nvinfer1::OptProfileSelector::kMAX, spec.max);

-    if (dims.input_is_dynamic) {
+    if (spec.input_is_dynamic) {
       ctx->input_is_dynamic = true;
     }

diff --git a/core/ir/Input.cpp b/core/ir/Input.cpp
index 62702cccb9..f1f9f694eb 100644
--- a/core/ir/Input.cpp
+++ b/core/ir/Input.cpp
@@ -129,7 +129,7 @@ Input::Input(std::vector<int64_t> shape, nvinfer1::DataType dtype, nvinfer1::TensorFormat format) {
   input_shape = util::toDims(shape);
   input_is_dynamic = false;
   format = nvinfer1::TensorFormat::kLINEAR;
-  dtype = nvinfer1::DataType::kFLOAT;
+  dtype = dtype;

   TRTORCH_CHECK(valid_input_dtype(dtype), "Unsupported input data type: " << dtype);
   this->dtype = dtype;
@@ -165,6 +165,7 @@ Input::Input(std::vector<int64_t> min_shape, std::vector<int64_t> opt_shape, std::vector<int64_t> max_shape,
       dim.insert(min_shape[i]);
       dim.insert(opt_shape[i]);
       dim.insert(max_shape[i]);
+      LOG_DEBUG(dim.size());
       if (dim.size() != 1) {
         dyn_shape.push_back(-1);
         input_is_dynamic = true;
@@ -182,7 +183,7 @@ Input::Input(std::vector<int64_t> min_shape, std::vector<int64_t> opt_shape, std::vector<int64_t> max_shape,
 }

 std::ostream& operator<<(std::ostream& os, const Input& input) {
-  if (input.input_is_dynamic) {
+  if (!input.input_is_dynamic) {
     os << "Input(shape: " << input.input_shape << ", dtype: " << input.dtype << ", format: " << input.format << ')';
   } else {
     os << "Input(shape: " << input.input_shape << ", min: " << input.min << ", opt: " << input.opt << ", max: " << input.max << ", dtype: " << input.dtype << ", format: " << input.format << ')';

diff --git a/py/setup.py b/py/setup.py
index 9790000ea5..8a35239ecd 100644
--- a/py/setup.py
+++ b/py/setup.py
@@ -181,7 +181,7 @@ def run(self):
             include_dirs=[
                 dir_path + "trtorch/csrc",
                 dir_path + "/../",
-                dir_path + "/../bazel-TRTorch/external/tensorrt/include",
+                dir_path + "/../bazel-trtorch-testing/external/tensorrt/include",
             ],
             extra_compile_args=[
                 "-Wno-deprecated",

diff --git a/py/trtorch/Input.py b/py/trtorch/Input.py
new file mode 100644
index 0000000000..6fd01d1416
--- /dev/null
+++ b/py/trtorch/Input.py
@@ -0,0 +1,181 @@
+from enum import Enum
+from typing import List, Dict, Any
+
+import torch
+
+from trtorch import _types
+import trtorch._C
+
+
+class Input(object):
+    """
+    Defines an input to a module in terms of expected shape, data type and tensor format.
+
+    Attributes:
+        shape_mode (trtorch.Input._ShapeMode): Is input statically or dynamically shaped
+        shape (Tuple or Dict): Either a single Tuple or a dict of tuples defining the input shape.
+            Static shaped inputs will have a single tuple. Dynamic inputs will have a dict of the form
+            ``{
+                "min_shape": Tuple,
+                "opt_shape": Tuple,
+                "max_shape": Tuple
+            }``
+        dtype (trtorch.dtype): The expected data type of the input tensor (default: trtorch.dtype.float32)
+        format (trtorch.TensorFormat): The expected format of the input tensor (default: trtorch.TensorFormat.contiguous)
+    """
+
+    class _ShapeMode(Enum):
+        STATIC = 0
+        DYNAMIC = 1
+
+    shape_mode = None
+    shape = None
+    dtype = _types.dtype.float32
+    format = _types.TensorFormat.contiguous
+
+    def __init__(self, *args, **kwargs):
+        """__init__ Method for trtorch.Input
+
+        Input accepts one of a few construction patterns
+
+        Args:
+            shape (Tuple or List, optional): Static shape of input tensor
+
+        Keyword Arguments:
+            shape (Tuple or List, optional): Static shape of input tensor
+            min_shape (Tuple or List, optional): Min size of input tensor's shape range
+                Note: All three of min_shape, opt_shape, max_shape must be provided, there must be no positional arguments, shape must not be defined and implicitly this sets Input's shape_mode to DYNAMIC
+            opt_shape (Tuple or List, optional): Opt size of input tensor's shape range
+                Note: All three of min_shape, opt_shape, max_shape must be provided, there must be no positional arguments, shape must not be defined and implicitly this sets Input's shape_mode to DYNAMIC
+            max_shape (Tuple or List, optional): Max size of input tensor's shape range
+                Note: All three of min_shape, opt_shape, max_shape must be provided, there must be no positional arguments, shape must not be defined and implicitly this sets Input's shape_mode to DYNAMIC
+            dtype (torch.dtype or trtorch.dtype): Expected data type for input tensor (default: trtorch.dtype.float32)
+            format (torch.memory_format or trtorch.TensorFormat): The expected format of the input tensor (default: trtorch.TensorFormat.contiguous)
+
+        Examples:
+            - Input([1,3,32,32], dtype=torch.float32, format=torch.channels_last)
+            - Input(shape=(1,3,32,32), dtype=trtorch.dtype.int32, format=trtorch.TensorFormat.contiguous)
+            - Input(min_shape=(1,3,32,32), opt_shape=[2,3,32,32], max_shape=(3,3,32,32)) #Implicitly dtype=trtorch.dtype.float32, format=trtorch.TensorFormat.contiguous
+        """
+        if len(args) == 1:
+            if not Input._supported_input_size_type(args[0]):
+                raise TypeError(
+                    "Input shape specifications for inputs are required to be a List, tuple or torch.Size, found type: " +
+                    str(type(args[0])))
+            if any(k in kwargs for k in ["min_shape", "opt_shape", "max_shape"]):
+                raise ValueError("Found that both shape (as a positional argument), and one or more of min_shape, opt_shape, max_shape were specified\nclass Input expects that only either shape or all three of min_shape, opt_shape, max_shape are defined")
+            self.shape = tuple(args[0])
+            self.shape_mode = Input._ShapeMode.STATIC
+
+        elif len(args) == 0:
+            if not ("shape" in kwargs) and not (all(k in kwargs for k in ["min_shape", "opt_shape", "max_shape"])):
+                raise ValueError("Missing required arguments for class Input\nEither shape or all three of min_shape, opt_shape, max_shape must be defined")
+            elif ("shape" in kwargs) and all(k in kwargs for k in ["min_shape", "opt_shape", "max_shape"]):
+                raise ValueError("Found that both shape, and one or more of min_shape, opt_shape, max_shape were specified\nclass Input expects that only either shape or all three of min_shape, opt_shape, max_shape are defined")
+
+            if "shape" in kwargs:
+                if not Input._supported_input_size_type(kwargs["shape"]):
+                    raise TypeError(
+                        "Input shape specifications for inputs are required to be a List, tuple or torch.Size, found type: " +
+                        str(type(kwargs["shape"])))
+                self.shape = tuple(kwargs["shape"])
+                self.shape_mode = Input._ShapeMode.STATIC
+            else:
+                if not Input._supported_input_size_type(kwargs["min_shape"]):
+                    raise TypeError(
+                        "Input shape specifications for inputs are required to be a List, tuple or torch.Size, found type: " +
+                        str(type(kwargs["min_shape"])) + " for min_shape")
+                if not Input._supported_input_size_type(kwargs["opt_shape"]):
+                    raise TypeError(
+                        "Input shape specifications for inputs are required to be a List, tuple or torch.Size, found type: " +
+                        str(type(kwargs["opt_shape"])) + " for opt_shape")
+                if not Input._supported_input_size_type(kwargs["max_shape"]):
+                    raise TypeError(
+                        "Input shape specifications for inputs are required to be a List, tuple or torch.Size, found type: " +
+                        str(type(kwargs["max_shape"])) + " for max_shape")
+
+                self.shape = {
+                    "min_shape": tuple(kwargs["min_shape"]),
+                    "opt_shape": tuple(kwargs["opt_shape"]),
+                    "max_shape": tuple(kwargs["max_shape"])
+                }
+                self.shape_mode = Input._ShapeMode.DYNAMIC
+
+        else:
+            raise ValueError(
+                "Unexpected number of positional arguments for class Input \n Found {} arguments, expected either zero or a single positional argument".format(len(args)))
+
+        if "dtype" in kwargs:
+            self.dtype = Input._parse_dtype(kwargs["dtype"])
+
+        if "format" in kwargs:
+            self.format = Input._parse_format(kwargs["format"])
+
+    def __str__(self) -> str:
+        if self.shape_mode == Input._ShapeMode.STATIC:
+            return "Input(shape={}, dtype={}, format={})".format(self.shape, str(self.dtype), str(self.format))
+        elif self.shape_mode == Input._ShapeMode.DYNAMIC:
+            return "Input(min_shape={}, opt_shape={}, max_shape={}, dtype={}, format={})".format(
+                self.shape["min_shape"], self.shape["opt_shape"], self.shape["max_shape"], str(self.dtype), str(self.format))
+        else:
+            raise RuntimeError("Unknown input shape mode")
+
+    def _to_internal(self) -> trtorch._C.Input:
+        internal_in = trtorch._C.Input()
+        if self.shape_mode == Input._ShapeMode.DYNAMIC:
+            internal_in.min = self.shape["min_shape"]
+            internal_in.opt = self.shape["opt_shape"]
+            internal_in.max = self.shape["max_shape"]
+            internal_in.input_is_dynamic = True
+        else:
+            internal_in.opt = self.shape
+            internal_in.input_is_dynamic = False
+        internal_in.dtype = self.dtype
+        internal_in.format = self.format
+        return internal_in
+
+    @staticmethod
+    def _supported_input_size_type(input_size: Any) -> bool:
+        if isinstance(input_size, torch.Size):
+            return True
+        elif isinstance(input_size, tuple):
+            return True
+        elif isinstance(input_size, list):
+            return True
+        else:
+            return False
+
+    @staticmethod
+    def _parse_dtype(dtype: Any) -> _types.dtype:
+        if isinstance(dtype, torch.dtype):
+            if dtype == torch.int32:
+                return _types.dtype.int32
+            elif dtype == torch.half:
+                return _types.dtype.half
+            elif dtype == torch.float:
+                return _types.dtype.float
+            elif dtype == torch.bool:
+                return _types.dtype.bool
+            else:
+                raise TypeError(
+                    "Provided an unsupported data type as an input data type (support: bool, int32, half, float), got: " +
+                    str(dtype))
+
+        elif isinstance(dtype, _types.dtype):
+            return dtype
+
+        else:
+            raise TypeError(
+                "Input data type needs to be specified with a torch.dtype or a trtorch.dtype, got: " + str(type(dtype)))
+
+    @staticmethod
+    def _parse_format(format: Any) -> _types.TensorFormat:
+        if isinstance(format, torch.memory_format):
+            if format == torch.contiguous_format:
+                return _types.TensorFormat.contiguous
+            elif format == torch.channels_last:
+                return _types.TensorFormat.channel_last
+            else:
+                raise ValueError(
+                    "Provided an unsupported tensor format (support: NCHW/contiguous_format, NHWC/channels_last)")
+
+        elif isinstance(format, _types.TensorFormat):
+            return format
+
+        else:
+            raise TypeError("Tensor format needs to be specified with either torch.memory_format or trtorch.TensorFormat")
\ No newline at end of file
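The Input class above accepts either a single static shape or a full
min/opt/max range; a small usage sketch of the two construction modes
(assuming the module is built with the bindings in this patch):

    import torch
    import trtorch

    # Static shape: stored in Input.shape, shape_mode == STATIC
    static_in = trtorch.Input((1, 3, 32, 32), dtype=torch.half)

    # Dynamic range: all three of min/opt/max required, shape_mode == DYNAMIC
    dynamic_in = trtorch.Input(
        min_shape=(1, 3, 128, 128),
        opt_shape=(1, 3, 256, 256),
        max_shape=(1, 3, 512, 512),
        format=torch.channels_last)

    print(static_in)  # e.g. Input(shape=(1, 3, 32, 32), dtype=..., format=...)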
diff --git a/py/trtorch/__init__.py b/py/trtorch/__init__.py
index 49e13e71d4..a31aafbd1f 100644
--- a/py/trtorch/__init__.py
+++ b/py/trtorch/__init__.py
@@ -13,6 +13,7 @@
 from trtorch import ptq
 from trtorch._types import *
 from trtorch import logging
+from trtorch.Input import Input


 def _register_with_torch():

diff --git a/py/trtorch/_compile_spec.py b/py/trtorch/_compile_spec.py
index eccad7118e..0ed137a45d 100644
--- a/py/trtorch/_compile_spec.py
+++ b/py/trtorch/_compile_spec.py
@@ -1,7 +1,10 @@
-from typing import List, Dict, Any
+from typing import List, Dict, Any, Set
 import torch
 import trtorch._C
 from trtorch import _types
+from trtorch.Input import Input
+
+import warnings


 def _supported_input_size_type(input_size: Any) -> bool:
@@ -26,36 +29,23 @@ def _parse_input_ranges(input_sizes: List) -> List:
     for i in input_sizes:
         if isinstance(i, dict):
             if all(k in i for k in ["min", "opt", "min"]):
-                in_range = trtorch._C.InputRange()
-                in_range.min = i["min"]
-                in_range.opt = i["opt"]
-                in_range.max = i["max"]
-                parsed_input_sizes.append(in_range)
+                parsed_input_sizes.append(Input(min_shape=i["min"], opt_shape=i["opt"], max_shape=i["max"])._to_internal())
             elif "opt" in i:
-                in_range = trtorch._C.InputRange()
-                in_range.min = i["opt"]
-                in_range.opt = i["opt"]
-                in_range.max = i["opt"]
-                parsed_input_sizes.append(in_range)
+                parsed_input_sizes.append(Input(shape=i["opt"])._to_internal())
             else:
                 raise KeyError(
                     "An input size must either be a static size or a range of three sizes (min, opt, max) as Dict")
         elif isinstance(i, list):
-            in_range = trtorch._C.InputRange()
-            in_range.min = i
-            in_range.opt = i
-            in_range.max = i
-            parsed_input_sizes.append(in_range)
+            parsed_input_sizes.append(Input(shape=i)._to_internal())
         elif isinstance(i, tuple):
-            in_range = trtorch._C.InputRange()
-            in_range.min = list(i)
-            in_range.opt = list(i)
-            in_range.max = list(i)
-            parsed_input_sizes.append(in_range)
+            parsed_input_sizes.append(Input(shape=i)._to_internal())
+
+        elif isinstance(i, torch.Size):
+            parsed_input_sizes.append(Input(shape=i)._to_internal())

     return parsed_input_sizes
@@ -80,6 +70,15 @@ def _parse_op_precision(precision: Any) -> _types.dtype:
                         str(type(precision)))


+def _parse_enabled_precisions(precisions: Any) -> Set:
+    parsed_precisions = set()
+    if isinstance(precisions, (list, tuple, set)):
+        for p in precisions:
+            parsed_precisions.add(_parse_op_precision(p))
+    else:
+        parsed_precisions.add(_parse_op_precision(precisions))
+    return parsed_precisions
+
+
 def _parse_device_type(device: Any) -> _types.DeviceType:
     if isinstance(device, torch.device):
         if device.type == 'cuda':
@@ -140,39 +139,36 @@ def _parse_torch_fallback(fallback_info: Dict[str, Any]) -> trtorch._C.TorchFallback:
     return info


-def _parse_input_dtypes(input_dtypes: List) -> List:
-    parsed_input_dtypes = []
-    for dtype in input_dtypes:
-        if isinstance(dtype, torch.dtype):
-            if dtype == torch.int8:
-                parsed_input_dtypes.append(_types.dtype.int8)
-            elif dtype == torch.half:
-                parsed_input_dtypes.append(_types.dtype.half)
-            elif dtype == torch.float:
-                parsed_input_dtypes.append(_types.dtype.float)
-            elif dtype == torch.int32:
-                parsed_input_dtypes.append(_types.dtype.int32)
-            elif dtype == torch.bool:
-                parsed_input_dtypes.append(_types.dtype.bool)
-            else:
-                raise TypeError("Invalid input dtype. Supported input datatypes include float|half|int8|int32|bool), got: " + str(dtype))
-
-    return parsed_input_dtypes
-
 def _parse_compile_spec(compile_spec: Dict[str, Any]) -> trtorch._C.CompileSpec:
     info = trtorch._C.CompileSpec()
-    if "input_shapes" not in compile_spec:
+    if "input_shapes" not in compile_spec and "inputs" not in compile_spec:
         raise KeyError(
-            "Input shapes for inputs are required as a List, provided as either a static sizes or a range of three sizes (min, opt, max) as Dict"
+            "Module input definitions are required to compile a module. Provide a list of trtorch.Input keyed to \"inputs\" in the compile spec"
         )

-    info.input_ranges = _parse_input_ranges(compile_spec["input_shapes"])
+    if "input_shapes" in compile_spec and "inputs" in compile_spec:
+        raise KeyError(
+            "Found both key \"input_shapes\", and \"inputs\" in compile spec, please port forward to using only \"inputs\""
+        )
+
+    if "input_shapes" in compile_spec:
+        warnings.warn("Key \"input_shapes\" is deprecated in favor of \"inputs\". Support for \"input_shapes\" will be removed in TRTorch v0.5.0", DeprecationWarning)
+        info.inputs = _parse_input_ranges(compile_spec["input_shapes"])
+
+    if "inputs" in compile_spec:
+        info.inputs = [i._to_internal() for i in compile_spec["inputs"]]
+
+    if "op_precision" in compile_spec and "enabled_precisions" in compile_spec:
+        raise KeyError(
+            "Found both key \"op_precision\", and \"enabled_precisions\" in compile spec, please port forward to using only \"enabled_precisions\""
+        )

     if "op_precision" in compile_spec:
-        info.op_precision = _parse_op_precision(compile_spec["op_precision"])
+        warnings.warn("Key \"op_precision\" is being deprecated in favor of \"enabled_precisions\", which expects a set of precisions to be enabled during compilation (FP32 will always be enabled). Support for \"op_precision\" will be removed in TRTorch v0.5.0", DeprecationWarning)
+        info.enabled_precisions = _parse_enabled_precisions(compile_spec["op_precision"])

-    if "input_dtypes" in compile_spec:
-        info.input_dtypes = _parse_input_dtypes(compile_spec["input_dtypes"])
+    if "enabled_precisions" in compile_spec:
+        info.enabled_precisions = _parse_enabled_precisions(compile_spec["enabled_precisions"])

     if "calibrator" in compile_spec:
         info.ptq_calibrator = compile_spec["calibrator"]
@@ -233,7 +229,8 @@ def TensorRTCompileSpec(compile_spec: Dict[str, Any]) -> torch.classes.tensorrt.CompileSpec:
     Args:
         compile_spec (dict): Compilation settings including operating precision, target device, etc.
             One key is required which is ``input_shapes``, describing the input sizes or ranges for inputs
-            to the graph. All other keys are optional. Entries for each method to be compiled.
+            to the graph as well as expected types and formats for those inputs. All other keys are optional.
+            Entries for each method to be compiled.

             Note: Partial compilation of TorchScript modules is not supported through the PyTorch TensorRT backend
             If you need this feature, use trtorch.compile to compile your module. Usage of the resulting module is
@@ -243,13 +240,15 @@ def TensorRTCompileSpec(compile_spec: Dict[str, Any]) -> torch.classes.tensorrt.CompileSpec:

         CompileSpec = {
             "forward" : trtorch.TensorRTCompileSpec({
-                "input_shapes": [
-                    (1, 3, 224, 224), # Static input shape for input #1
-                    {
-                        "min": (1, 3, 224, 224),
-                        "opt": (1, 3, 512, 512),
-                        "max": (1, 3, 1024, 1024)
-                    } # Dynamic input shape for input #2
+                "inputs": [
+                    trtorch.Input((1, 3, 224, 224)), # Static input shape for input #1
+                    trtorch.Input(
+                        min_shape=(1, 3, 224, 224),
+                        opt_shape=(1, 3, 512, 512),
+                        max_shape=(1, 3, 1024, 1024),
+                        dtype=torch.int32,
+                        format=torch.channels_last
+                    ) # Dynamic input shape for input #2
                 ],
                 "device": {
                     "device_type": torch.device("cuda"), # Type of device to run engine on (for DLA use trtorch.DeviceType.DLA)
@@ -284,12 +283,15 @@ def TensorRTCompileSpec(compile_spec: Dict[str, Any]) -> torch.classes.tensorrt.CompileSpec:
     backend_spec = torch.classes.tensorrt.CompileSpec()

-    for i in parsed_spec.input_ranges:
-        ir = torch.classes.tensorrt._InputRange()
-        ir._set_min(i.min)
-        ir._set_opt(i.opt)
-        ir._set_max(i.max)
-        backend_spec._append_input_range(ir)
+    for i in parsed_spec.inputs:
+        clone = torch.classes.tensorrt._Input()
+        clone._set_min(i.min)
+        clone._set_opt(i.opt)
+        clone._set_max(i.max)
+        clone._set_dtype(i.dtype)
+        clone._set_format(i.format)
+        clone._set_input_is_dynamic(i.input_is_dynamic)
+        backend_spec._append_input(clone)

     d = torch.classes.tensorrt._Device()
     d._set_device_type(int(parsed_spec.device.device_type))
@@ -309,9 +311,7 @@ def TensorRTCompileSpec(compile_spec: Dict[str, Any]) -> torch.classes.tensorrt.CompileSpec:
     backend_spec._set_device(d)
     backend_spec._set_torch_fallback(torch_fallback)
-    backend_spec._set_op_precision(int(parsed_spec.op_precision))
-    for dtype in parsed_spec.input_dtypes:
-        backend_spec._append_input_dtypes(int64_t(dtype))
+    backend_spec._set_precisions([int(i) for i in parsed_spec.enabled_precisions])

     backend_spec._set_disable_tf32(parsed_spec.disable_tf32)
     backend_spec._set_refit(parsed_spec.refit)
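Given the parsing above, both spellings are accepted during the deprecation
window, but supplying old and new keys together raises a KeyError, and legacy
keys emit a DeprecationWarning. A sketch of equivalent specs under this
scheme:

    import torch
    import trtorch

    # Legacy keys: still parsed (through trtorch.Input internally), but warn
    legacy_spec = {
        "input_shapes": [{"min": (1, 3, 224, 224), "opt": (1, 3, 512, 512), "max": (1, 3, 1024, 1024)}],
        "op_precision": torch.half,
    }

    # Preferred equivalents going forward
    modern_spec = {
        "inputs": [
            trtorch.Input(min_shape=(1, 3, 224, 224),
                          opt_shape=(1, 3, 512, 512),
                          max_shape=(1, 3, 1024, 1024))
        ],
        "enabled_precisions": {torch.half},
    }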
diff --git a/py/trtorch/_compiler.py b/py/trtorch/_compiler.py
index bcb9e8ba77..dc4a93f1ac 100644
--- a/py/trtorch/_compiler.py
+++ b/py/trtorch/_compiler.py
@@ -20,19 +20,21 @@ def compile(module: torch.jit.ScriptModule, compile_spec: Any) -> torch.jit.ScriptModule:
         module (torch.jit.ScriptModule): Source module, a result of tracing or scripting a PyTorch
             ``torch.nn.Module``
         compile_spec (dict): Compilation settings including operating precision, target device, etc.
-            One key is required which is ``input_shapes``, describing the input sizes or ranges for inputs
-            to the graph. All other keys are optional
+            One key is required which is ``inputs``, describing the input sizes or ranges for inputs
+            to the graph as well as expected types and formats for those inputs. All other keys are optional

             .. code-block:: py

                 compile_spec = {
-                    "input_shapes": [
-                        (1, 3, 224, 224), # Static input shape for input #1
-                        {
-                            "min": (1, 3, 224, 224),
-                            "opt": (1, 3, 512, 512),
-                            "max": (1, 3, 1024, 1024)
-                        } # Dynamic input shape for input #2
+                    "inputs": [
+                        trtorch.Input((1, 3, 224, 224)), # Static input shape for input #1
+                        trtorch.Input(
+                            min_shape=(1, 3, 224, 224),
+                            opt_shape=(1, 3, 512, 512),
+                            max_shape=(1, 3, 1024, 1024),
+                            dtype=torch.int32,
+                            format=torch.channels_last
+                        ) # Dynamic input shape for input #2
                     ],
                     "device": {
                         "device_type": torch.device("cuda"), # Type of device to run engine on (for DLA use trtorch.DeviceType.DLA)
@@ -86,19 +88,21 @@ def convert_method_to_trt_engine(module: torch.jit.ScriptModule, method_name: str
             ``torch.nn.Module``
         method_name (str): Name of method to convert
         compile_spec (dict): Compilation settings including operating precision, target device, etc.
-            One key is required which is ``input_shapes``, describing the input sizes or ranges for inputs
-            to the graph. All other keys are optional
+            One key is required which is ``inputs``, describing the input sizes or ranges for inputs
+            to the graph as well as expected types and formats for those inputs. All other keys are optional

             .. code-block:: py

                 CompileSpec = {
-                    "input_shapes": [
-                        (1, 3, 224, 224), # Static input shape for input #1
-                        {
-                            "min": (1, 3, 224, 224),
-                            "opt": (1, 3, 512, 512),
-                            "max": (1, 3, 1024, 1024)
-                        } # Dynamic input shape for input #2
+                    "inputs": [
+                        trtorch.Input((1, 3, 224, 224)), # Static input shape for input #1
+                        trtorch.Input(
+                            min_shape=(1, 3, 224, 224),
+                            opt_shape=(1, 3, 512, 512),
+                            max_shape=(1, 3, 1024, 1024),
+                            dtype=torch.int32,
+                            format=torch.channels_last
+                        ) # Dynamic input shape for input #2
                     ],
                     "device": {
                         "device_type": torch.device("cuda"), # Type of device to run engine on (for DLA use trtorch.DeviceType.DLA)

diff --git a/py/trtorch/_types.py b/py/trtorch/_types.py
index 48244c3e85..7f323c7140 100644
--- a/py/trtorch/_types.py
+++ b/py/trtorch/_types.py
@@ -1 +1 @@
-from trtorch._C import dtype, DeviceType, EngineCapability
+from trtorch._C import dtype, DeviceType, EngineCapability, TensorFormat

diff --git a/py/trtorch/csrc/register_tensorrt_classes.cpp b/py/trtorch/csrc/register_tensorrt_classes.cpp
index c5f93b4560..9080f33fdd 100644
--- a/py/trtorch/csrc/register_tensorrt_classes.cpp
+++ b/py/trtorch/csrc/register_tensorrt_classes.cpp
@@ -10,13 +10,17 @@ namespace {
 void RegisterTRTCompileSpec() {
   static auto TRTORCH_UNUSED TRTInputRangeTSRegistration =
-      torch::class_<trtorch::pyapi::InputRange>("tensorrt", "_InputRange")
+      torch::class_<trtorch::pyapi::Input>("tensorrt", "_Input")
           .def(torch::init<>())
-          .def("__str__", &trtorch::pyapi::InputRange::to_str);
+          .def("__str__", &trtorch::pyapi::Input::to_str);
+
+  ADD_FIELD_GET_SET_REGISTRATION(TRTInputRangeTSRegistration, trtorch::pyapi::Input, min);
+  ADD_FIELD_GET_SET_REGISTRATION(TRTInputRangeTSRegistration, trtorch::pyapi::Input, opt);
+  ADD_FIELD_GET_SET_REGISTRATION(TRTInputRangeTSRegistration, trtorch::pyapi::Input, max);
+  ADD_FIELD_GET_SET_REGISTRATION(TRTInputRangeTSRegistration, trtorch::pyapi::Input, dtype);
+  ADD_FIELD_GET_SET_REGISTRATION(TRTInputRangeTSRegistration, trtorch::pyapi::Input, format);
+  ADD_FIELD_GET_SET_REGISTRATION(TRTInputRangeTSRegistration, trtorch::pyapi::Input, input_is_dynamic);

-  ADD_FIELD_GET_SET_REGISTRATION(TRTInputRangeTSRegistration, trtorch::pyapi::InputRange, min);
-  ADD_FIELD_GET_SET_REGISTRATION(TRTInputRangeTSRegistration, trtorch::pyapi::InputRange, opt);
-  ADD_FIELD_GET_SET_REGISTRATION(TRTInputRangeTSRegistration, trtorch::pyapi::InputRange, max);
   static auto TRTORCH_UNUSED TRTDeviceTSRegistration = torch::class_<trtorch::pyapi::Device>("tensorrt", "_Device")
                                                            .def(torch::init<>())
@@ -39,14 +43,13 @@ void RegisterTRTCompileSpec() {
   static auto TRTORCH_UNUSED TRTCompileSpecTSRegistration =
       torch::class_<trtorch::pyapi::CompileSpec>("tensorrt", "CompileSpec")
           .def(torch::init<>())
-          .def("_append_input_range", &trtorch::pyapi::CompileSpec::appendInputRange)
+          .def("_append_input", &trtorch::pyapi::CompileSpec::appendInput)
+          .def("_set_precisions", &trtorch::pyapi::CompileSpec::setPrecisions)
           .def("_set_device", &trtorch::pyapi::CompileSpec::setDeviceIntrusive)
           .def("_set_torch_fallback", &trtorch::pyapi::CompileSpec::setTorchFallbackIntrusive)
           .def("_set_ptq_calibrator", &trtorch::pyapi::CompileSpec::setPTQCalibratorViaHandle)
-          .def("_append_input_dtypes", &trtorch::pyapi::CompileSpec::appendInputDtypes)
           .def("__str__", &trtorch::pyapi::CompileSpec::stringify);

-  ADD_FIELD_GET_SET_REGISTRATION(TRTCompileSpecTSRegistration, trtorch::pyapi::CompileSpec, op_precision);
   ADD_FIELD_GET_SET_REGISTRATION(TRTCompileSpecTSRegistration, trtorch::pyapi::CompileSpec, disable_tf32);
   ADD_FIELD_GET_SET_REGISTRATION(TRTCompileSpecTSRegistration, trtorch::pyapi::CompileSpec, refit);
   ADD_FIELD_GET_SET_REGISTRATION(TRTCompileSpecTSRegistration, trtorch::pyapi::CompileSpec, debug);
diff --git a/py/trtorch/csrc/tensorrt_classes.cpp b/py/trtorch/csrc/tensorrt_classes.cpp
index a74cfa92ab..19b921dc8a 100644
--- a/py/trtorch/csrc/tensorrt_classes.cpp
+++ b/py/trtorch/csrc/tensorrt_classes.cpp
@@ -4,26 +4,6 @@
 namespace trtorch {
 namespace pyapi {

-std::string InputRange::to_str() {
-  auto vec_to_str = [](std::vector<int64_t> shape) -> std::string {
-    std::stringstream ss;
-    ss << '[';
-    for (auto i : shape) {
-      ss << i << ',';
-    }
-    ss << ']';
-    return ss.str();
-  };
-
-  std::stringstream ss;
-  ss << "    {" << std::endl;
-  ss << "        min: " << vec_to_str(min) << ',' << std::endl;
-  ss << "        opt: " << vec_to_str(opt) << ',' << std::endl;
-  ss << "        max: " << vec_to_str(max) << ',' << std::endl;
-  ss << "    }" << std::endl;
-  return ss.str();
-}
-
 std::string to_str(DataType value) {
   switch (value) {
     case DataType::kHalf:
@@ -35,8 +15,9 @@ std::string to_str(DataType value) {
     case DataType::kBool:
       return "Bool";
     case DataType::kFloat:
-    default:
       return "Float";
+    default:
+      return "Unknown data type";
   }
 }
@@ -51,11 +32,69 @@ nvinfer1::DataType toTRTDataType(DataType value) {
     case DataType::kBool:
       return nvinfer1::DataType::kBOOL;
     case DataType::kFloat:
-    default:
       return nvinfer1::DataType::kFLOAT;
+    default:
+      TRTORCH_THROW_ERROR("Unknown data type: " << to_str(value));
   }
 }

+nvinfer1::TensorFormat toTRTTensorFormat(TensorFormat value) {
+  switch (value) {
+    case TensorFormat::kChannelLast:
+      return nvinfer1::TensorFormat::kHWC;
+    case TensorFormat::kContiguous:
+    default:
+      return nvinfer1::TensorFormat::kLINEAR;
+  }
+}
+
+std::string to_str(TensorFormat value) {
+  switch (value) {
+    case TensorFormat::kContiguous:
+      return "Contiguous/Linear/NCHW";
+    case TensorFormat::kChannelLast:
+      return "Channel Last/NHWC";
+    default:
+      return "UNKNOWN";
+  }
+}
+
+core::ir::Input Input::toInternalInput() {
+  if (!input_is_dynamic) {
+    return core::ir::Input(opt, toTRTDataType(dtype), toTRTTensorFormat(format));
+  } else {
+    return core::ir::Input(min, opt, max, toTRTDataType(dtype), toTRTTensorFormat(format));
+  }
+}
+
+std::string Input::to_str() {
+  auto vec_to_str = [](std::vector<int64_t> shape) -> std::string {
+    std::stringstream ss;
+    ss << '(';
+    for (auto i : shape) {
+      ss << i << ',';
+    }
+    ss << ')';
+    return ss.str();
+  };
+
+  std::stringstream ss;
+  ss << "Input(";
+
+  if (!input_is_dynamic) {
+    ss << "shape=" << vec_to_str(opt) << ", ";
+  } else {
+    ss << "min_shape=" << vec_to_str(min) << ", ";
+    ss << "opt_shape=" << vec_to_str(opt) << ", ";
+    ss << "max_shape=" << vec_to_str(max) << ", ";
+  }
+
+  ss << "dtype=" << pyapi::to_str(dtype) << ", ";
+  ss << "format=" << pyapi::to_str(format) << ')';
+
+  return ss.str();
+}
+
 std::string to_str(DeviceType value) {
   switch (value) {
     case DeviceType::kDLA:
@@ -128,19 +167,17 @@ std::string TorchFallback::to_str() {
 }

 core::CompileSpec CompileSpec::toInternalCompileSpec() {
-  std::vector<core::ir::InputRange> internal_input_ranges;
-  for (auto i : input_ranges) {
-    internal_input_ranges.push_back(i.toInternalInputRange());
+  std::vector<core::ir::Input> internal_inputs;
+  for (auto i : inputs) {
+    internal_inputs.push_back(i.toInternalInput());
   }

-  std::vector<nvinfer1::DataType> trt_input_dtypes;
-  for (auto dtype : input_dtypes) {
-    trt_input_dtypes.push_back(toTRTDataType(dtype));
+  auto info = core::CompileSpec(internal_inputs);
+
+  for (auto p : enabled_precisions) {
+    info.convert_info.engine_settings.enabled_precisions.insert(toTRTDataType(p));
   }

-  auto info = core::CompileSpec(internal_input_ranges);
-  info.convert_info.engine_settings.op_precision = toTRTDataType(op_precision);
-  info.convert_info.engine_settings.input_dtypes = trt_input_dtypes;
   info.convert_info.engine_settings.calibrator = ptq_calibrator;
   info.convert_info.engine_settings.disable_tf32 = disable_tf32;
   info.convert_info.engine_settings.refit = refit;
@@ -170,15 +207,14 @@ core::CompileSpec CompileSpec::toInternalCompileSpec() {
 std::string CompileSpec::stringify() {
   std::stringstream ss;
   ss << "TensorRT Compile Spec: {" << std::endl;
-  ss << "    \"Input Shapes\": [" << std::endl;
-  for (auto i : input_ranges) {
+  ss << "    \"Inputs\": [" << std::endl;
+  for (auto i : inputs) {
     ss << i.to_str();
   }
   ss << "    ]" << std::endl;
-  ss << "    \"Op Precision\": " << to_str(op_precision) << std::endl;
-  ss << "    \"Input dtypes\": [" << std::endl;
-  for (auto i : input_dtypes) {
-    ss << to_str(i);
+  ss << "    \"Enabled Precision\": [" << std::endl;
+  for (auto p : enabled_precisions) {
+    ss << to_str(p);
   }
   ss << "    ]" << std::endl;
   ss << "    \"TF32 Disabled\": " << disable_tf32 << std::endl;

diff --git a/py/trtorch/csrc/tensorrt_classes.h b/py/trtorch/csrc/tensorrt_classes.h
index 4d6d4395f8..bba3af46ad 100644
--- a/py/trtorch/csrc/tensorrt_classes.h
+++ b/py/trtorch/csrc/tensorrt_classes.h
@@ -27,27 +27,34 @@ namespace pyapi {
     return static_cast<int64_t>(field_name);                                  \
   }

-struct InputRange : torch::CustomClassHolder {
+enum class DataType : int8_t { kFloat, kHalf, kChar, kInt32, kBool };
+std::string to_str(DataType value);
+nvinfer1::DataType toTRTDataType(DataType value);
+
+enum class TensorFormat : int8_t { kContiguous, kChannelLast };
+std::string to_str(TensorFormat value);
+nvinfer1::TensorFormat toTRTTensorFormat(TensorFormat value);
+
+struct Input : torch::CustomClassHolder {
   std::vector<int64_t> min;
   std::vector<int64_t> opt;
   std::vector<int64_t> max;
-  core::ir::InputRange toInternalInputRange() {
-    return core::ir::InputRange(min, opt, max);
-  }
+  bool input_is_dynamic;
+  DataType dtype;
+  TensorFormat format;

   ADD_FIELD_GET_SET(min, std::vector<int64_t>);
   ADD_FIELD_GET_SET(opt, std::vector<int64_t>);
   ADD_FIELD_GET_SET(max, std::vector<int64_t>);
+  ADD_FIELD_GET_SET(input_is_dynamic, bool);
+  ADD_ENUM_GET_SET(dtype, DataType, static_cast<int64_t>(DataType::kBool));
+  ADD_ENUM_GET_SET(format, TensorFormat, static_cast<int64_t>(TensorFormat::kContiguous));

+  core::ir::Input toInternalInput();
   std::string to_str();
 };

-enum class DataType : int8_t { kFloat, kHalf, kChar, kInt32, kBool };
-
-std::string to_str(DataType value);
-nvinfer1::DataType toTRTDataType(DataType value);
-
 enum DeviceType : int8_t {
   kGPU,
   kDLA,
@@ -101,12 +108,17 @@ nvinfer1::EngineCapability toTRTEngineCapability(EngineCapability value);
 struct CompileSpec : torch::CustomClassHolder {
   core::CompileSpec toInternalCompileSpec();
   std::string stringify();
-  void appendInputRange(const c10::intrusive_ptr<InputRange>& ir) {
-    input_ranges.push_back(*ir);
+  void appendInput(const c10::intrusive_ptr<Input>& ir) {
+    inputs.push_back(*ir);
   }
-  void appendInputDtypes(int64_t dtype) {
-    input_dtypes.push_back(static_cast<DataType>(dtype));
+
+  void setPrecisions(const std::vector<int64_t>& precisions_raw) {
+    for (auto p : precisions_raw) {
+      TRTORCH_CHECK(p >= 0 && p <= static_cast<int64_t>(DataType::kBool), "Invalid enum value for field");
+      enabled_precisions.insert(static_cast<DataType>(p));
+    }
   }
+
   int64_t getPTQCalibratorHandle() {
     return (int64_t)ptq_calibrator;
   }
@@ -123,7 +135,6 @@ struct CompileSpec : torch::CustomClassHolder {
     ptq_calibrator = (nvinfer1::IInt8Calibrator*)handle;
   }

-  ADD_ENUM_GET_SET(op_precision, DataType, static_cast<int64_t>(DataType::kChar));
   ADD_FIELD_GET_SET(disable_tf32, bool);
   ADD_FIELD_GET_SET(refit, bool);
   ADD_FIELD_GET_SET(debug, bool);
@@ -138,10 +149,9 @@ struct CompileSpec : torch::CustomClassHolder {
   ADD_FIELD_GET_SET(torch_fallback, TorchFallback);
   ADD_FIELD_GET_SET(ptq_calibrator, nvinfer1::IInt8Calibrator*);

-  std::vector<InputRange> input_ranges;
+  std::vector<Input> inputs;
   nvinfer1::IInt8Calibrator* ptq_calibrator = nullptr;
-  DataType op_precision = DataType::kFloat;
-  std::vector<DataType> input_dtypes;
+  std::set<DataType> enabled_precisions = {DataType::kFloat};
   bool disable_tf32 = false;
   bool refit = false;
   bool debug = false;

diff --git a/py/trtorch/csrc/trtorch_py.cpp b/py/trtorch/csrc/trtorch_py.cpp
index 91e76a707c..6b2520bd57 100644
--- a/py/trtorch/csrc/trtorch_py.cpp
+++ b/py/trtorch/csrc/trtorch_py.cpp
@@ -163,12 +163,15 @@ void log(core::util::logging::LogLevel lvl, const std::string& msg) {
 } // namespace logging

 PYBIND11_MODULE(_C, m) {
-  py::class_<InputRange>(m, "InputRange")
+  py::class_<Input>(m, "Input")
       .def(py::init<>())
-      .def("__str__", &trtorch::pyapi::InputRange::to_str)
-      .def_readwrite("min", &InputRange::min)
-      .def_readwrite("opt", &InputRange::opt)
-      .def_readwrite("max", &InputRange::max);
+      .def("__str__", &trtorch::pyapi::Input::to_str)
+      .def_readwrite("min", &Input::min)
+      .def_readwrite("opt", &Input::opt)
+      .def_readwrite("max", &Input::max)
+      .def_readwrite("input_is_dynamic", &Input::input_is_dynamic)
+      .def_readwrite("dtype", &Input::dtype)
+      .def_readwrite("format", &Input::format);

   py::enum_<DataType>(m, "dtype", "Enum to specify operating precision for engine execution")
       .value("float", DataType::kFloat, "32 bit floating point number")
@@ -193,6 +196,11 @@ PYBIND11_MODULE(_C, m) {
       .value("safe_dla", EngineCapability::kSAFE_DLA, "Use safety DLA kernels only")
       .value("default", EngineCapability::kDEFAULT, "Use default behavior");

+  py::enum_<TensorFormat>(m, "TensorFormat", "Enum to specify the memory layout of tensors")
+      .value("contiguous", TensorFormat::kContiguous, "Contiguous memory layout (NCHW / Linear)")
+      .value("channel_last", TensorFormat::kChannelLast, "Channel last memory layout (NHWC)")
+      .export_values();
+
   py::enum_<nvinfer1::CalibrationAlgoType>(m, "CalibrationAlgo", py::module_local(), "Type of calibration algorithm")
       .value("LEGACY_CALIBRATION", nvinfer1::CalibrationAlgoType::kLEGACY_CALIBRATION)
       .value("ENTROPY_CALIBRATION", nvinfer1::CalibrationAlgoType::kENTROPY_CALIBRATION)
@@ -242,9 +250,8 @@ PYBIND11_MODULE(_C, m) {
       .def(py::init<>())
       .def("__str__", &trtorch::pyapi::CompileSpec::stringify)
       .def("_get_calibrator_handle", &CompileSpec::getPTQCalibratorHandle, "[Internal] gets a handle from a calibrator")
-      .def_readwrite("input_ranges", &CompileSpec::input_ranges)
-      .def_readwrite("op_precision", &CompileSpec::op_precision)
-      .def_readwrite("input_dtypes", &CompileSpec::input_dtypes)
+      .def_readwrite("inputs", &CompileSpec::inputs)
+      .def_readwrite("enabled_precisions", &CompileSpec::enabled_precisions)
      .def_readwrite("ptq_calibrator", &CompileSpec::ptq_calibrator)
      .def_readwrite("refit", &CompileSpec::refit)
      .def_readwrite("disable_tf32", &CompileSpec::disable_tf32)
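Taken together, the Python-side Input flows through `_to_internal()` into the
pybind `trtorch._C.Input` mirror registered above; a quick sketch of what that
internal object carries (internal API, shown for illustration only):

    import trtorch

    i = trtorch.Input(min_shape=(1, 3, 128, 128),
                      opt_shape=(1, 3, 256, 256),
                      max_shape=(1, 3, 512, 512))
    raw = i._to_internal()            # a trtorch._C.Input
    print(raw.min, raw.opt, raw.max)  # the three optimization-profile shapes
    print(raw.input_is_dynamic)       # True for a range-based input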