Skip to content

Commit 482265f

Browse files
committed
feat(//py)!: Implementing top level python api changes to reflect new
Input type and enabled_precisions set BREAKING CHANGE: This commit introduces the next iteration of the Python TRTorch API. Starting in TRTorch v0.5.0, support for the "input_shapes" and "op_precision" compile spec keys will be removed. Users should port forward to using the "inputs" key, which expects a list of trtorch.Input objects, and the "enabled_precisions" key, which expects a set of enums specifying the enabled data types (precisions). Signed-off-by: Naren Dasan <naren@narendasan.com> Signed-off-by: Naren Dasan <narens@nvidia.com>
1 parent 316df28 commit 482265f

12 files changed

+406
-163
lines changed

core/conversion/conversion.cpp

+8-8
Original file line numberDiff line numberDiff line change
@@ -160,21 +160,21 @@ void AddInputs(
160160

161161
for (size_t i = 0; i < input_tensors.size(); i++) {
162162
auto in = input_tensors[i];
163-
auto dims = input_specs[i];
163+
auto spec = input_specs[i];
164164
std::string name = std::string("input_") + std::to_string(ctx->num_inputs);
165165
LOG_INFO(
166166
ctx->logger,
167-
"Adding Input " << in->debugName() << " (named: " << name << "): " << dims << " in engine (conversion.AddInputs)");
167+
"Adding Input " << in->debugName() << " (named: " << name << "): " << spec << " in engine (conversion.AddInputs)");
168168

169-
auto trt_in = ctx->net->addInput(name.c_str(), dims.dtype, dims.input_shape);
169+
auto trt_in = ctx->net->addInput(name.c_str(), spec.dtype, spec.input_shape);
170170
TRTORCH_CHECK(trt_in, "Failed to add input node: " << in->debugName() << " (conversion.AddInputs)");
171-
trt_in->setAllowedFormats(1U << static_cast<int>(dims.format));
171+
trt_in->setAllowedFormats(1U << static_cast<int>(spec.format));
172172

173-
profile->setDimensions(trt_in->getName(), nvinfer1::OptProfileSelector::kMIN, dims.min);
174-
profile->setDimensions(trt_in->getName(), nvinfer1::OptProfileSelector::kOPT, dims.opt);
175-
profile->setDimensions(trt_in->getName(), nvinfer1::OptProfileSelector::kMAX, dims.max);
173+
profile->setDimensions(trt_in->getName(), nvinfer1::OptProfileSelector::kMIN, spec.min);
174+
profile->setDimensions(trt_in->getName(), nvinfer1::OptProfileSelector::kOPT, spec.opt);
175+
profile->setDimensions(trt_in->getName(), nvinfer1::OptProfileSelector::kMAX, spec.max);
176176

177-
if (dims.input_is_dynamic) {
177+
if (spec.input_is_dynamic) {
178178
ctx->input_is_dynamic = true;
179179
}
180180

core/ir/Input.cpp

+3-2
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,7 @@ Input::Input(std::vector<int64_t> shape, nvinfer1::DataType dtype, nvinfer1::Ten
129129
input_shape = util::toDims(shape);
130130
input_is_dynamic = false;
131131
format = nvinfer1::TensorFormat::kLINEAR;
132-
dtype = nvinfer1::DataType::kFLOAT;
132+
dtype = dtype;
133133

134134
TRTORCH_CHECK(valid_input_dtype(dtype), "Unsupported input data type: " << dtype);
135135
this->dtype = dtype;
@@ -165,6 +165,7 @@ Input::Input(std::vector<int64_t> min_shape, std::vector<int64_t> opt_shape, std
165165
dim.insert(min_shape[i]);
166166
dim.insert(opt_shape[i]);
167167
dim.insert(max_shape[i]);
168+
LOG_DEBUG(dim.size());
168169
if (dim.size() != 1) {
169170
dyn_shape.push_back(-1);
170171
input_is_dynamic = true;
@@ -182,7 +183,7 @@ Input::Input(std::vector<int64_t> min_shape, std::vector<int64_t> opt_shape, std
182183
}
183184

184185
std::ostream& operator<<(std::ostream& os, const Input& input) {
185-
if (input.input_is_dynamic) {
186+
if (!input.input_is_dynamic) {
186187
os << "Input(shape: " << input.input_shape << ", dtype: " << input.dtype << ", format: " << input.format << ')';
187188
} else {
188189
os << "Input(shape: " << input.input_shape << ", min: " << input.min << ", opt: " << input.opt << ", max: " << input.max << ", dtype: " << input.dtype << ", format: " << input.format << ')';

py/setup.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -181,7 +181,7 @@ def run(self):
181181
include_dirs=[
182182
dir_path + "trtorch/csrc",
183183
dir_path + "/../",
184-
dir_path + "/../bazel-TRTorch/external/tensorrt/include",
184+
dir_path + "/../bazel-trtorch-testing/external/tensorrt/include",
185185
],
186186
extra_compile_args=[
187187
"-Wno-deprecated",

py/trtorch/Input.py

+181
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,181 @@
1+
from enum import Enum
2+
from typing import List, Dict, Any
3+
4+
import torch
5+
6+
from trtorch import _types
7+
import trtorch._C
8+
9+
class Input(object):
    """
    Defines an input to a module in terms of expected shape, data type and tensor format.

    Attributes:
        shape_mode (trtorch.Input._ShapeMode): Is input statically or dynamically shaped
        shape (Tuple or Dict): Either a single Tuple or a dict of tuples defining the input shape.
            Static shaped inputs will have a single tuple. Dynamic inputs will have a dict of the form
            ``{
                "min_shape": Tuple,
                "opt_shape": Tuple,
                "max_shape": Tuple
            }``
        dtype (trtorch.dtype): The expected data type of the input tensor (default: trtorch.dtype.float32)
        format (trtorch.TensorFormat): The expected format of the input tensor (default: trtorch.TensorFormat.NCHW)
    """

    class _ShapeMode(Enum):
        # Discriminates whether shape holds a single static tuple or a
        # min/opt/max dict describing a dynamic shape range.
        STATIC = 0
        DYNAMIC = 1

    shape_mode = None  # trtorch.Input._ShapeMode — set by __init__
    shape = None  # tuple (STATIC) or dict with min_shape/opt_shape/max_shape (DYNAMIC)
    dtype = _types.dtype.float32  # default expected input data type
    format = _types.TensorFormat.contiguous  # default expected tensor format (NCHW)

    def __init__(self, *args, **kwargs):
        """ __init__ Method for trtorch.Input

        Input accepts one of a few construction patterns

        Args:
            shape (Tuple or List, optional): Static shape of input tensor

        Keyword Arguments:
            shape (Tuple or List, optional): Static shape of input tensor
            min_shape (Tuple or List, optional): Min size of input tensor's shape range
                Note: All three of min_shape, opt_shape, max_shape must be provided, there must be no positional arguments, shape must not be defined and implicitly this sets Input's shape_mode to DYNAMIC
            opt_shape (Tuple or List, optional): Opt size of input tensor's shape range
                Note: All three of min_shape, opt_shape, max_shape must be provided, there must be no positional arguments, shape must not be defined and implicitly this sets Input's shape_mode to DYNAMIC
            max_shape (Tuple or List, optional): Max size of input tensor's shape range
                Note: All three of min_shape, opt_shape, max_shape must be provided, there must be no positional arguments, shape must not be defined and implicitly this sets Input's shape_mode to DYNAMIC
            dtype (torch.dtype or trtorch.dtype): Expected data type for input tensor (default: trtorch.dtype.float32)
            format (torch.memory_format or trtorch.TensorFormat): The expected format of the input tensor (default: trtorch.TensorFormat.NCHW)

        Raises:
            TypeError: If a shape specification is not a List, tuple or torch.Size,
                or dtype/format are of an unsupported type.
            ValueError: If both shape and the min/opt/max trio are given, if neither
                is given, or if more than one positional argument is passed.

        Examples:
            - Input([1,3,32,32], dtype=torch.float32, format=torch.channels_last)
            - Input(shape=(1,3,32,32), dtype=trtorch.dtype.int32, format=trtorch.TensorFormat.NCHW)
            - Input(min_shape=(1,3,32,32), opt_shape=[2,3,32,32], max_shape=(3,3,32,32)) #Implicitly dtype=trtorch.dtype.float32, format=trtorch.TensorFormat.NCHW
        """
        if len(args) == 1:
            # Positional pattern: a single static shape.
            if not Input._supported_input_size_type(args[0]):
                raise TypeError(
                    "Input shape specifications for inputs are required to be a List, tuple or torch.Size, found type: "
                    + str(type(args[0])))
            if any(k in kwargs for k in ["min_shape", "opt_shape", "max_shape"]):
                raise ValueError(
                    "Found that both shape (as a positional argument), and one or more of min_shape, opt_shape, max_shape were specified\nclass Input expects that only either shape or all three of min_shape, opt_shape, max_shape are defined"
                )
            self.shape = tuple(args[0])
            self.shape_mode = Input._ShapeMode.STATIC

        elif len(args) == 0:
            # Keyword pattern: either shape=... (static) or all of
            # min_shape/opt_shape/max_shape (dynamic), never both.
            if not ("shape" in kwargs) and not (all(k in kwargs for k in ["min_shape", "opt_shape", "max_shape"])):
                raise ValueError(
                    "Missing required arguments for class Input\nEither shape or all three of min_shape, opt_shape, max_shape must be defined"
                )
            elif ("shape" in kwargs) and all(k in kwargs for k in ["min_shape", "opt_shape", "max_shape"]):
                raise ValueError(
                    "Found that both shape, and one or more of min_shape, opt_shape, max_shape were specified\nclass Input expects that only either shape or all three of min_shape, opt_shape, max_shape are defined"
                )

            if "shape" in kwargs:
                if not Input._supported_input_size_type(kwargs["shape"]):
                    raise TypeError(
                        "Input shape specifications for inputs are required to be a List, tuple or torch.Size, found type: "
                        + str(type(kwargs["shape"])))
                self.shape = tuple(kwargs["shape"])
                self.shape_mode = Input._ShapeMode.STATIC
            else:
                if not Input._supported_input_size_type(kwargs["min_shape"]):
                    raise TypeError(
                        "Input shape specifications for inputs are required to be a List, tuple or torch.Size, found type: "
                        + str(type(kwargs["min_shape"])) + " for min_shape")
                if not Input._supported_input_size_type(kwargs["opt_shape"]):
                    raise TypeError(
                        "Input shape specifications for inputs are required to be a List, tuple or torch.Size, found type: "
                        + str(type(kwargs["opt_shape"])) + " for opt_shape")
                if not Input._supported_input_size_type(kwargs["max_shape"]):
                    raise TypeError(
                        "Input shape specifications for inputs are required to be a List, tuple or torch.Size, found type: "
                        + str(type(kwargs["max_shape"])) + " for max_shape")

                self.shape = {
                    "min_shape": tuple(kwargs["min_shape"]),
                    "opt_shape": tuple(kwargs["opt_shape"]),
                    "max_shape": tuple(kwargs["max_shape"])
                }
                self.shape_mode = Input._ShapeMode.DYNAMIC

        else:
            raise ValueError(
                "Unexpected number of positional arguments for class Input \n Found {} arguments, expected either zero or a single positional arguments"
                .format(len(args)))

        if "dtype" in kwargs:
            self.dtype = Input._parse_dtype(kwargs["dtype"])

        if "format" in kwargs:
            self.format = Input._parse_format(kwargs["format"])

    def __str__(self) -> str:
        if self.shape_mode == Input._ShapeMode.STATIC:
            return "Input(shape={}, dtype={}, format={})".format(self.shape, str(self.dtype), str(self.format))
        elif self.shape_mode == Input._ShapeMode.DYNAMIC:
            # BUGFIX: previously formatted shape["min_shape"] for all three
            # fields; now reports min, opt and max correctly.
            return "Input(min_shape={}, opt_shape={}, max_shape={}, dtype={}, format={})".format(
                self.shape["min_shape"], self.shape["opt_shape"], self.shape["max_shape"], str(self.dtype),
                str(self.format))
        else:
            raise RuntimeError("Unknown input shape mode")

    def _to_internal(self) -> trtorch._C.Input:
        """Convert this Input into the C extension's trtorch._C.Input spec."""
        internal_in = trtorch._C.Input()
        if self.shape_mode == Input._ShapeMode.DYNAMIC:
            internal_in.min = self.shape["min_shape"]
            internal_in.opt = self.shape["opt_shape"]
            internal_in.max = self.shape["max_shape"]
            internal_in.input_is_dynamic = True
        else:
            # Static shapes only populate opt; min/max are unused downstream.
            internal_in.opt = self.shape
            internal_in.input_is_dynamic = False
        internal_in.dtype = self.dtype
        internal_in.format = self.format
        return internal_in

    @staticmethod
    def _supported_input_size_type(input_size: Any) -> bool:
        """Return True if input_size is an accepted shape container
        (torch.Size, tuple or list)."""
        return isinstance(input_size, (torch.Size, tuple, list))

    @staticmethod
    def _parse_dtype(dtype: Any) -> _types.dtype:
        """Normalize a torch.dtype or trtorch dtype enum to _types.dtype.

        Raises:
            TypeError: For torch dtypes outside {bool, int32, half, float}
                or values that are neither torch.dtype nor _types.DataTypes.
        """
        if isinstance(dtype, torch.dtype):
            if dtype == torch.int32:
                return _types.dtype.int32
            elif dtype == torch.half:
                return _types.dtype.half
            elif dtype == torch.float:
                return _types.dtype.float
            elif dtype == torch.bool:
                return _types.dtype.bool
            else:
                raise TypeError(
                    "Provided an unsupported data type as an input data type (support: bool, int32, half, float), got: "
                    + str(dtype))

        elif isinstance(dtype, _types.DataTypes):
            return dtype

        else:
            raise TypeError("Input data type needs to be specified with a torch.dtype or a trtorch.dtype, got: " +
                            str(type(dtype)))

    @staticmethod
    def _parse_format(format: Any) -> _types.TensorFormat:
        """Normalize a torch.memory_format or trtorch TensorFormat enum to
        _types.TensorFormat.

        Raises:
            ValueError: For torch memory formats other than contiguous_format
                or channels_last.
            TypeError: For values of any other type.
        """
        if isinstance(format, torch.memory_format):
            if format == torch.contiguous_format:
                return _types.TensorFormat.contiguous
            elif format == torch.channels_last:
                return _types.TensorFormat.channel_last
            else:
                # BUGFIX: message previously said "NHCW" for the contiguous layout.
                raise ValueError(
                    "Provided an unsupported tensor format (support: NCHW/contiguous_format, NHWC/channel_last)")

        elif isinstance(format, _types.TensorFormat):
            return format

        else:
            raise TypeError("Tensor format needs to be specified with either torch.memory_format or trtorch.TensorFormat")

py/trtorch/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from trtorch import ptq
1414
from trtorch._types import *
1515
from trtorch import logging
16+
from trtorch.Input import Input
1617

1718

1819
def _register_with_torch():

0 commit comments

Comments
 (0)