feat(//py)!: Implementing top level python api changes to reflect new Input type and enabled_precisions set

BREAKING CHANGE: This commit introduces the next iteration of the Python
TRTorch API. Starting in TRTorch v0.5.0, support for the "input_shapes"
and "op_precision" compile spec keys will be removed. Users should port
forward to the "inputs" key, which expects a list of trtorch.Input
objects, and the "enabled_precisions" key, which expects a set of data
type enums.
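
For example, a compile spec written against the deprecated keys could be ported as follows. This is a minimal sketch against the new API; the module, shapes, and precisions here are hypothetical:

    import torch
    import trtorch

    class ExampleModule(torch.nn.Module):  # hypothetical stand-in module
        def forward(self, x):
            return torch.relu(x)

    scripted_mod = torch.jit.script(ExampleModule().eval())

    # Before (keys deprecated, removed in TRTorch v0.5.0):
    # spec = {
    #     "input_shapes": [(1, 3, 224, 224)],
    #     "op_precision": torch.half
    # }

    # After: "inputs" takes a list of trtorch.Input objects and
    # "enabled_precisions" takes a set of data type enums
    spec = {
        "inputs": [trtorch.Input((1, 3, 224, 224), dtype=torch.half)],
        "enabled_precisions": {torch.float, torch.half}
    }
    trt_mod = trtorch.compile(scripted_mod, spec)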

Signed-off-by: Naren Dasan <naren@narendasan.com>
Signed-off-by: Naren Dasan <narens@nvidia.com>
narendasan committed Jul 21, 2021
1 parent 316df28 commit 482265f
Showing 12 changed files with 406 additions and 163 deletions.
16 changes: 8 additions & 8 deletions core/conversion/conversion.cpp
@@ -160,21 +160,21 @@ void AddInputs(
 
   for (size_t i = 0; i < input_tensors.size(); i++) {
     auto in = input_tensors[i];
-    auto dims = input_specs[i];
+    auto spec = input_specs[i];
     std::string name = std::string("input_") + std::to_string(ctx->num_inputs);
     LOG_INFO(
         ctx->logger,
-        "Adding Input " << in->debugName() << " (named: " << name << "): " << dims << " in engine (conversion.AddInputs)");
+        "Adding Input " << in->debugName() << " (named: " << name << "): " << spec << " in engine (conversion.AddInputs)");
 
-    auto trt_in = ctx->net->addInput(name.c_str(), dims.dtype, dims.input_shape);
+    auto trt_in = ctx->net->addInput(name.c_str(), spec.dtype, spec.input_shape);
     TRTORCH_CHECK(trt_in, "Failed to add input node: " << in->debugName() << " (conversion.AddInputs)");
-    trt_in->setAllowedFormats(1U << static_cast<int>(dims.format));
+    trt_in->setAllowedFormats(1U << static_cast<int>(spec.format));
 
-    profile->setDimensions(trt_in->getName(), nvinfer1::OptProfileSelector::kMIN, dims.min);
-    profile->setDimensions(trt_in->getName(), nvinfer1::OptProfileSelector::kOPT, dims.opt);
-    profile->setDimensions(trt_in->getName(), nvinfer1::OptProfileSelector::kMAX, dims.max);
+    profile->setDimensions(trt_in->getName(), nvinfer1::OptProfileSelector::kMIN, spec.min);
+    profile->setDimensions(trt_in->getName(), nvinfer1::OptProfileSelector::kOPT, spec.opt);
+    profile->setDimensions(trt_in->getName(), nvinfer1::OptProfileSelector::kMAX, spec.max);
 
-    if (dims.input_is_dynamic) {
+    if (spec.input_is_dynamic) {
       ctx->input_is_dynamic = true;
     }
 
5 changes: 3 additions & 2 deletions core/ir/Input.cpp
@@ -129,7 +129,7 @@ Input::Input(std::vector<int64_t> shape, nvinfer1::DataType dtype, nvinfer1::Ten
   input_shape = util::toDims(shape);
   input_is_dynamic = false;
   format = nvinfer1::TensorFormat::kLINEAR;
-  dtype = nvinfer1::DataType::kFLOAT;
+  dtype = dtype;
 
   TRTORCH_CHECK(valid_input_dtype(dtype), "Unsupported input data type: " << dtype);
   this->dtype = dtype;
@@ -165,6 +165,7 @@ Input::Input(std::vector<int64_t> min_shape, std::vector<int64_t> opt_shape, std
     dim.insert(min_shape[i]);
     dim.insert(opt_shape[i]);
     dim.insert(max_shape[i]);
+    LOG_DEBUG(dim.size());
     if (dim.size() != 1) {
       dyn_shape.push_back(-1);
       input_is_dynamic = true;
@@ -182,7 +183,7 @@
 }
 
 std::ostream& operator<<(std::ostream& os, const Input& input) {
-  if (input.input_is_dynamic) {
+  if (!input.input_is_dynamic) {
     os << "Input(shape: " << input.input_shape << ", dtype: " << input.dtype << ", format: " << input.format << ')';
   } else {
     os << "Input(shape: " << input.input_shape << ", min: " << input.min << ", opt: " << input.opt << ", max: " << input.max << ", dtype: " << input.dtype << ", format: " << input.format << ')';
2 changes: 1 addition & 1 deletion py/setup.py
@@ -181,7 +181,7 @@ def run(self):
         include_dirs=[
             dir_path + "trtorch/csrc",
             dir_path + "/../",
-            dir_path + "/../bazel-TRTorch/external/tensorrt/include",
+            dir_path + "/../bazel-trtorch-testing/external/tensorrt/include",
         ],
         extra_compile_args=[
             "-Wno-deprecated",
181 changes: 181 additions & 0 deletions py/trtorch/Input.py
@@ -0,0 +1,181 @@
from enum import Enum
from typing import List, Dict, Any

import torch

from trtorch import _types
import trtorch._C

class Input(object):
    """
    Defines an input to a module in terms of expected shape, data type and tensor format.

    Attributes:
        shape_mode (trtorch.Input._ShapeMode): Is input statically or dynamically shaped
        shape (Tuple or Dict): Either a single Tuple or a dict of tuples defining the input shape.
            Static shaped inputs will have a single tuple. Dynamic inputs will have a dict of the form
            ``{
                "min_shape": Tuple,
                "opt_shape": Tuple,
                "max_shape": Tuple
            }``
        dtype (trtorch.dtype): The expected data type of the input tensor (default: trtorch.dtype.float32)
        format (trtorch.TensorFormat): The expected format of the input tensor (default: trtorch.TensorFormat.NCHW)
    """

    class _ShapeMode(Enum):
        STATIC = 0
        DYNAMIC = 1

    shape_mode = None
    shape = None
    dtype = _types.dtype.float32
    format = _types.TensorFormat.contiguous

    def __init__(self, *args, **kwargs):
        """__init__ Method for trtorch.Input

        Input accepts one of a few construction patterns

        Args:
            shape (Tuple or List, optional): Static shape of input tensor

        Keyword Arguments:
            shape (Tuple or List, optional): Static shape of input tensor
            min_shape (Tuple or List, optional): Min size of input tensor's shape range
                Note: All three of min_shape, opt_shape, max_shape must be provided, there must be no positional arguments, shape must not be defined and implicitly this sets Input's shape_mode to DYNAMIC
            opt_shape (Tuple or List, optional): Opt size of input tensor's shape range
                Note: All three of min_shape, opt_shape, max_shape must be provided, there must be no positional arguments, shape must not be defined and implicitly this sets Input's shape_mode to DYNAMIC
            max_shape (Tuple or List, optional): Max size of input tensor's shape range
                Note: All three of min_shape, opt_shape, max_shape must be provided, there must be no positional arguments, shape must not be defined and implicitly this sets Input's shape_mode to DYNAMIC
            dtype (torch.dtype or trtorch.dtype): Expected data type for input tensor (default: trtorch.dtype.float32)
            format (torch.memory_format or trtorch.TensorFormat): The expected format of the input tensor (default: trtorch.TensorFormat.NCHW)

        Examples:
            - Input([1,3,32,32], dtype=torch.float32, format=torch.channels_last)
            - Input(shape=(1,3,32,32), dtype=trtorch.dtype.int32, format=trtorch.TensorFormat.NCHW)
            - Input(min_shape=(1,3,32,32), opt_shape=[2,3,32,32], max_shape=(3,3,32,32)) # Implicitly dtype=trtorch.dtype.float32, format=trtorch.TensorFormat.NCHW
        """
        if len(args) == 1:
            if not Input._supported_input_size_type(args[0]):
                raise TypeError(
                    "Input shape specifications for inputs are required to be a List, tuple or torch.Size, found type: "
                    + str(type(args[0])))
            if any(k in kwargs for k in ["min_shape", "opt_shape", "max_shape"]):
                raise ValueError(
                    "Found that both shape (as a positional argument), and one or more of min_shape, opt_shape, max_shape were specified\nclass Input expects that only either shape or all three of min_shape, opt_shape, max_shape are defined"
                )
            self.shape = tuple(args[0])
            self.shape_mode = Input._ShapeMode.STATIC

        elif len(args) == 0:
            if not ("shape" in kwargs) and not (all(k in kwargs for k in ["min_shape", "opt_shape", "max_shape"])):
                raise ValueError(
                    "Missing required arguments for class Input\nEither shape or all three of min_shape, opt_shape, max_shape must be defined"
                )
            elif ("shape" in kwargs) and all(k in kwargs for k in ["min_shape", "opt_shape", "max_shape"]):
                raise ValueError(
                    "Found that both shape, and one or more of min_shape, opt_shape, max_shape were specified\nclass Input expects that only either shape or all three of min_shape, opt_shape, max_shape are defined"
                )

            if "shape" in kwargs:
                if not Input._supported_input_size_type(kwargs["shape"]):
                    raise TypeError(
                        "Input shape specifications for inputs are required to be a List, tuple or torch.Size, found type: "
                        + str(type(kwargs["shape"])))
                self.shape = tuple(kwargs["shape"])
                self.shape_mode = Input._ShapeMode.STATIC
            else:
                if not Input._supported_input_size_type(kwargs["min_shape"]):
                    raise TypeError(
                        "Input shape specifications for inputs are required to be a List, tuple or torch.Size, found type: "
                        + str(type(kwargs["min_shape"])) + " for min_shape")
                if not Input._supported_input_size_type(kwargs["opt_shape"]):
                    raise TypeError(
                        "Input shape specifications for inputs are required to be a List, tuple or torch.Size, found type: "
                        + str(type(kwargs["opt_shape"])) + " for opt_shape")
                if not Input._supported_input_size_type(kwargs["max_shape"]):
                    raise TypeError(
                        "Input shape specifications for inputs are required to be a List, tuple or torch.Size, found type: "
                        + str(type(kwargs["max_shape"])) + " for max_shape")

                self.shape = {
                    "min_shape": tuple(kwargs["min_shape"]),
                    "opt_shape": tuple(kwargs["opt_shape"]),
                    "max_shape": tuple(kwargs["max_shape"])
                }
                self.shape_mode = Input._ShapeMode.DYNAMIC

        else:
            raise ValueError(
                "Unexpected number of positional arguments for class Input \n Found {} arguments, expected either zero or a single positional argument".format(
                    len(args)))

        if "dtype" in kwargs:
            self.dtype = Input._parse_dtype(kwargs["dtype"])

        if "format" in kwargs:
            self.format = Input._parse_format(kwargs["format"])

    def __str__(self) -> str:
        if self.shape_mode == Input._ShapeMode.STATIC:
            return "Input(shape={}, dtype={}, format={})".format(self.shape, str(self.dtype), str(self.format))
        elif self.shape_mode == Input._ShapeMode.DYNAMIC:
            return "Input(min_shape={}, opt_shape={}, max_shape={}, dtype={}, format={})".format(
                self.shape["min_shape"], self.shape["opt_shape"], self.shape["max_shape"], str(self.dtype),
                str(self.format))
        else:
            raise RuntimeError("Unknown input shape mode")

    def _to_internal(self) -> trtorch._C.Input:
        internal_in = trtorch._C.Input()
        if self.shape_mode == Input._ShapeMode.DYNAMIC:
            internal_in.min = self.shape["min_shape"]
            internal_in.opt = self.shape["opt_shape"]
            internal_in.max = self.shape["max_shape"]
            internal_in.input_is_dynamic = True
        else:
            internal_in.opt = self.shape
            internal_in.input_is_dynamic = False
        internal_in.dtype = self.dtype
        internal_in.format = self.format
        return internal_in

    @staticmethod
    def _supported_input_size_type(input_size: Any) -> bool:
        if isinstance(input_size, torch.Size):
            return True
        elif isinstance(input_size, tuple):
            return True
        elif isinstance(input_size, list):
            return True
        else:
            return False

    @staticmethod
    def _parse_dtype(dtype: Any) -> _types.dtype:
        if isinstance(dtype, torch.dtype):
            if dtype == torch.int32:
                return _types.dtype.int32
            elif dtype == torch.half:
                return _types.dtype.half
            elif dtype == torch.float:
                return _types.dtype.float
            elif dtype == torch.bool:
                return _types.dtype.bool
            else:
                raise TypeError(
                    "Provided an unsupported data type as an input data type (support: bool, int32, half, float), got: "
                    + str(dtype))

        elif isinstance(dtype, _types.DataTypes):
            return dtype

        else:
            raise TypeError(
                "Input data type needs to be specified with a torch.dtype or a trtorch.dtype, got: " + str(type(dtype)))

    @staticmethod
    def _parse_format(format: Any) -> _types.TensorFormat:
        if isinstance(format, torch.memory_format):
            if format == torch.contiguous_format:
                return _types.TensorFormat.contiguous
            elif format == torch.channels_last:
                return _types.TensorFormat.channel_last
            else:
                raise ValueError(
                    "Provided an unsupported tensor format (support: NCHW/contiguous_format, NHWC/channels_last)")

        elif isinstance(format, _types.TensorFormat):
            return format

        else:
            raise TypeError("Tensor format needs to be specified with either torch.memory_format or trtorch.TensorFormat")
1 change: 1 addition & 0 deletions py/trtorch/__init__.py
@@ -13,6 +13,7 @@
 from trtorch import ptq
 from trtorch._types import *
 from trtorch import logging
+from trtorch.Input import Input
 
 
 def _register_with_torch():