feat(//py)!: Implementing top level python api changes to reflect new
Input type and enabled_precisions set
BREAKING CHANGE: This commit introduces the next iteration of the Python
TRTorch API. Starting in TRTorch v0.5.0 support for the "input_shapes"
and "op_precision" compile spec keys will be removed. Users should port
forward to using the "inputs" key, which expects a list of trtorch.Input
objects, and the "enabled_precisions" key, which expects a set of data
type enums.
Signed-off-by: Naren Dasan <naren@narendasan.com>
Signed-off-by: Naren Dasan <narens@nvidia.com>
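The key migration the commit message describes can be sketched as plain Python data. The `Input` stand-in below is hypothetical, included only so the example runs without trtorch/TensorRT installed; in real use the objects come from the trtorch package:

```python
from dataclasses import dataclass
from typing import Tuple

# Hypothetical stand-in for trtorch.Input so this sketch runs without
# trtorch/TensorRT installed; the real class lives in the trtorch package.
@dataclass
class Input:
    shape: Tuple[int, ...]

# Deprecated compile spec keys, slated for removal in TRTorch v0.5.0:
old_spec = {
    "input_shapes": [(1, 3, 224, 224)],
    "op_precision": "half",
}

# New-style spec: "inputs" takes a list of trtorch.Input objects and
# "enabled_precisions" takes a set of data type enums (strings stand in here).
new_spec = {
    "inputs": [Input(shape=(1, 3, 224, 224))],
    "enabled_precisions": {"half"},
}
```

Note that "enabled_precisions" is a set rather than a single value, so several precisions can be enabled at once.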
    Defines an input to a module in terms of expected shape, data type and tensor format.

    Attributes:
        shape_mode (trtorch.Input._ShapeMode): Is input statically or dynamically shaped
        shape (Tuple or Dict): Either a single Tuple or a dict of tuples defining the input shape.
            Static shaped inputs will have a single tuple. Dynamic inputs will have a dict of the form
            ``{
                "min_shape": Tuple,
                "opt_shape": Tuple,
                "max_shape": Tuple
            }``
        dtype (trtorch.dtype): The expected data type of the input tensor (default: trtorch.dtype.float32)
        format (trtorch.TensorFormat): The expected format of the input tensor (default: trtorch.TensorFormat.NCHW)
    """

    class _ShapeMode(Enum):
        STATIC = 0
        DYNAMIC = 1

    shape_mode = None
    shape = None
    dtype = _types.dtype.float32
    format = _types.TensorFormat.contiguous

    def __init__(self, *args, **kwargs):
        """__init__ Method for trtorch.Input

        Input accepts one of a few construction patterns

        Args:
            shape (Tuple or List, optional): Static shape of input tensor

        Keyword Arguments:
            shape (Tuple or List, optional): Static shape of input tensor
            min_shape (Tuple or List, optional): Min size of input tensor's shape range
                Note: All three of min_shape, opt_shape, max_shape must be provided, there must be no positional arguments, shape must not be defined and implicitly this sets Input's shape_mode to DYNAMIC
            opt_shape (Tuple or List, optional): Opt size of input tensor's shape range
                Note: All three of min_shape, opt_shape, max_shape must be provided, there must be no positional arguments, shape must not be defined and implicitly this sets Input's shape_mode to DYNAMIC
            max_shape (Tuple or List, optional): Max size of input tensor's shape range
                Note: All three of min_shape, opt_shape, max_shape must be provided, there must be no positional arguments, shape must not be defined and implicitly this sets Input's shape_mode to DYNAMIC
            dtype (torch.dtype or trtorch.dtype): Expected data type for input tensor (default: trtorch.dtype.float32)
            format (torch.memory_format or trtorch.TensorFormat): The expected format of the input tensor (default: trtorch.TensorFormat.NCHW)
                raise ValueError(
                    "Found that both shape (as a positional argument), and one or more of min_shape, opt_shape, max_shape were specified\nclass Input expects that only either shape or all three of min_shape, opt_shape, max_shape are defined")
                raise ValueError(
                    "Found that both shape, and one or more of min_shape, opt_shape, max_shape were specified\nclass Input expects that only either shape or all three of min_shape, opt_shape, max_shape are defined")
                        "Input shape specifications for inputs are required to be a List, tuple or torch.Size, found type: "
                        + str(type(kwargs["max_shape"])) + " for max_shape")

                self.shape = {
                    "min_shape": tuple(kwargs["min_shape"]),
                    "opt_shape": tuple(kwargs["opt_shape"]),
                    "max_shape": tuple(kwargs["max_shape"])
                }
                self.shape_mode = Input._ShapeMode.DYNAMIC

        else:
            raise ValueError(
                "Unexpected number of positional arguments for class Input \n Found {} arguments, expected either zero or a single positional arguments".format(len(args)))
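Putting the fragments above together, a minimal runnable sketch of the construction patterns the docstring describes might look like the following. This is an illustration of the documented behavior only, not the actual trtorch implementation (the real class also validates dtype, format, and the shape element types):

```python
from enum import Enum

# Minimal sketch of trtorch.Input's documented construction patterns:
# one positional shape or a shape= kwarg -> STATIC; all three of
# min_shape/opt_shape/max_shape -> DYNAMIC; mixing the two is an error.
class Input:
    class _ShapeMode(Enum):
        STATIC = 0
        DYNAMIC = 1

    def __init__(self, *args, **kwargs):
        dynamic_keys = ("min_shape", "opt_shape", "max_shape")
        if len(args) == 1:
            if any(k in kwargs for k in dynamic_keys):
                raise ValueError(
                    "Found that both shape (as a positional argument), and one or more of "
                    "min_shape, opt_shape, max_shape were specified\nclass Input expects that "
                    "only either shape or all three of min_shape, opt_shape, max_shape are defined")
            self.shape = tuple(args[0])
            self.shape_mode = Input._ShapeMode.STATIC
        elif len(args) == 0:
            if "shape" in kwargs:
                if any(k in kwargs for k in dynamic_keys):
                    raise ValueError(
                        "Found that both shape, and one or more of min_shape, opt_shape, "
                        "max_shape were specified\nclass Input expects that only either shape "
                        "or all three of min_shape, opt_shape, max_shape are defined")
                self.shape = tuple(kwargs["shape"])
                self.shape_mode = Input._ShapeMode.STATIC
            elif all(k in kwargs for k in dynamic_keys):
                # All three bounds present: record the shape range as a dict.
                self.shape = {k: tuple(kwargs[k]) for k in dynamic_keys}
                self.shape_mode = Input._ShapeMode.DYNAMIC
            else:
                raise ValueError(
                    "Either shape or all three of min_shape, opt_shape, max_shape must be defined")
        else:
            raise ValueError(
                "Unexpected number of positional arguments for class Input\nFound {} arguments, "
                "expected either zero or a single positional argument".format(len(args)))
```

`Input((1, 3, 224, 224))` yields a static input, while `Input(min_shape=..., opt_shape=..., max_shape=...)` yields a dynamic one, matching the two modes of `_ShapeMode`.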