
Commit 0206c34

fix/feat: Add and repair multiple converters
- Focus on SD-performance-accelerating converters
- Add test cases for converters to avoid regressions
- Add prims sum converter
1 parent a7f9055 commit 0206c34

File tree

10 files changed: +230 −27 lines


py/torch_tensorrt/dynamo/conversion/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -3,4 +3,5 @@
 from .aten_ops_converters import *  # noqa: F403
 from .conversion import *  # noqa: F403
 from .op_evaluators import *  # noqa: F403
+from .prims_ops_converters import *  # noqa: F403
 from .truncate_long_and_double import repair_long_or_double_inputs
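The wildcard import matters because converter registration happens as a side effect of importing the module: each @dynamo_tensorrt_converter decorator registers its function when the defining module is loaded. A toy sketch of that pattern (the registry and decorator names here are illustrative, not torch_tensorrt internals):

    # Toy registry showing decorator-based registration at import time.
    from typing import Callable, Dict

    CONVERTERS: Dict[str, Callable] = {}

    def register_converter(op_name: str) -> Callable:
        def wrap(fn: Callable) -> Callable:
            CONVERTERS[op_name] = fn  # runs when the module is imported
            return fn
        return wrap

    @register_converter("prims.sum.default")
    def convert_prims_sum(*args, **kwargs):
        ...

    # Importing the defining module is enough for the op to appear in the registry.
    assert "prims.sum.default" in CONVERTERS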

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

Lines changed: 16 additions & 1 deletion
@@ -651,14 +651,15 @@ def aten_ops_amax(
 
 @dynamo_tensorrt_converter(torch.ops.aten.sum.default)  # type: ignore[misc]
 @dynamo_tensorrt_converter(torch.ops.aten.sum.dim_IntList)  # type: ignore[misc]
+@dynamo_tensorrt_converter(torch.ops.prims.sum.default)  # type: ignore[misc]
 def aten_ops_sum(
     ctx: ConversionContext,
     target: Target,
     args: Tuple[Argument, ...],
     kwargs: Dict[str, Argument],
     name: str,
 ) -> Union[TRTTensor, Sequence[TRTTensor]]:
-    return impl.reduce.sum(
+    sum_ = impl.reduce.sum(
         ctx,
         target,
         SourceIR.ATEN,
@@ -668,6 +669,19 @@ def aten_ops_sum(
         args_bounds_check(args, 2, replacement=False),
     )
 
+    if kwargs.get("output_dtype", None) is not None:
+        return impl.cast.to_copy(
+            ctx,
+            target,
+            SourceIR.ATEN,
+            name,
+            sum_,
+            kwargs["output_dtype"],
+            force_layer=False,
+        )
+    else:
+        return sum_
+
 
 @dynamo_tensorrt_converter(torch.ops.aten.exp.default)  # type: ignore[misc]
 def aten_ops_exp(
@@ -1166,6 +1180,7 @@ def aten_ops_sub(
 @dynamo_tensorrt_converter(torch.ops.aten.div.Tensor_mode)  # type: ignore[misc]
 @dynamo_tensorrt_converter(torch.ops.aten.div.Scalar)  # type: ignore[misc]
 @dynamo_tensorrt_converter(torch.ops.aten.div.Scalar_mode)  # type: ignore[misc]
+@dynamo_tensorrt_converter(torch.ops.prims.div.default)  # type: ignore[misc]
 def aten_ops_div(
     ctx: ConversionContext,
     target: Target,
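For reference, an eager-mode sketch of the behavior the updated aten_ops_sum is meant to reproduce: reduce first, then apply an optional cast when an output_dtype kwarg is present. This uses plain PyTorch with a hypothetical helper name, not the TensorRT path itself:

    import torch

    def sum_then_cast(x, dims, keepdim=False, output_dtype=None):
        # Mirrors the converter: reduction first, then an optional cast
        # (impl.reduce.sum followed by impl.cast.to_copy in the diff above).
        out = torch.sum(x, dim=dims, keepdim=keepdim)
        if output_dtype is not None:
            out = out.to(output_dtype)
        return out

    y = sum_then_cast(torch.randn(2, 3), [1], output_dtype=torch.half)
    print(y.shape, y.dtype)  # torch.Size([2]) torch.float16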

py/torch_tensorrt/dynamo/conversion/impl/condition/ops.py

Lines changed: 23 additions & 24 deletions
@@ -1,5 +1,6 @@
 from typing import Optional
 
+import numpy as np
 import tensorrt as trt
 import torch
 from torch.fx.node import Target
@@ -23,16 +24,6 @@ def where(
     other: TRTTensor,
     condition: TRTTensor,
 ) -> TRTTensor:
-    input_dim = len(tuple(input.shape))
-    other_dim = len(tuple(other.shape))
-    condition_dim = len(tuple(condition.shape))
-
-    if type(input) != TRTTensor:
-        assert type(input) is torch.Tensor, f"value {input} is not torch.Tensor!"
-
-    if type(other) != TRTTensor:
-        assert type(other) is torch.Tensor, f"value {other} is not torch.Tensor!"
-
     if not (broadcastable(input, other)):
         assert "The two torch tensors should be broadcastable"
 
@@ -49,33 +40,37 @@ def where(
     x_shape = list(input.shape)
     y_shape = list(other.shape)
     condition_shape = list(condition.shape)
+
     output_shape = list(torch.broadcast_shapes(condition_shape, x_shape, y_shape))
 
     # expand shape
-    if type(condition) != TRTTensor:
-        assert condition.dtype == torch.bool, "condition dtype is not bool"
+    if not isinstance(condition, TRTTensor):
+        assert condition.dtype in (torch.bool, np.bool_), "condition dtype is not bool"
         if condition_shape != output_shape:
-            condition.expand(output_shape)
-            condition = condition.to(torch.int32)
-            condition_const = get_trt_tensor(ctx, condition, f"{name}_condition")
-            condition_layer = ctx.net.add_identity(condition_const)
-            condition_layer.set_output_type(0, trt.bool)
-            set_layer_name(condition_layer, target, f"{name}_condition")
-            condition_val = condition_layer.get_output(0)
+            condition = (
+                condition.expand(output_shape)
+                if isinstance(condition, torch.Tensor)
+                else np.broadcast_to(condition, output_shape)
+            )
+        condition_val = get_trt_tensor(ctx, condition, f"{name}_condition")
     else:
         assert condition.dtype == trt.bool, "mask dtype is not bool!"
-        if len(condition_shape) != condition_dim:
+        if condition_shape != output_shape:
             condition_val = expand(
                 ctx, target, source_ir, f"{name}_expand", condition, output_shape
            )
         else:
             condition_val = condition
 
-    if type(input) != TRTTensor:
+    if not isinstance(input, TRTTensor):
         if x_shape != output_shape:
             # special case where 1 element in input
             if len(input.shape) == 0:
-                input = input.unsqueeze(0)
+                input = (
+                    input.unsqueeze(0)
+                    if isinstance(input, torch.Tensor)
+                    else np.expand_dims(input, axis=0)
+                )
             input = input.expand(output_shape)
         x_val = get_trt_tensor(ctx, input, f"{name}_x")
     else:
@@ -85,11 +80,15 @@
             ctx, target, source_ir, f"{name}_x_expand", input, output_shape
         )
 
-    if type(other) != TRTTensor:
+    if not isinstance(other, TRTTensor):
         if y_shape != output_shape:
             # special case where 1 element in other
             if len(other.shape) == 0:
-                other = other.unsqueeze(0)
+                other = (
+                    other.unsqueeze(0)
+                    if isinstance(other, torch.Tensor)
+                    else np.expand_dims(other, axis=0)
+                )
             other = other.expand(output_shape)
         y_val = get_trt_tensor(ctx, other, f"{name}_y")
     else:
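The rewritten branch now accepts both torch.Tensor and numpy constants for the frozen condition input. A minimal standalone sketch of that broadcast logic (outside TensorRT; the function name is hypothetical):

    import numpy as np
    import torch

    def expand_condition(condition, output_shape):
        # torch constants broadcast via expand(); numpy constants via np.broadcast_to().
        if isinstance(condition, torch.Tensor):
            return condition.expand(output_shape)
        return np.broadcast_to(condition, output_shape)

    print(expand_condition(torch.tensor([True, False]), (3, 2)).shape)  # torch.Size([3, 2])
    print(expand_condition(np.array([True, False]), (3, 2)).shape)      # (3, 2)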

py/torch_tensorrt/dynamo/conversion/impl/reduce.py

Lines changed: 1 addition & 1 deletion
@@ -51,7 +51,7 @@ def sum(
 ):
     input_val = cast_trt_tensor(ctx, input_val, trt.float32, name)
 
-    if dim is None:
+    if dim is None or (isinstance(dim, (tuple, list)) and len(dim) == 0):
         dim = tuple(range(len(input_val.shape)))
     layer = ctx.net.add_reduce(
         input_val,
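The guard above now treats an empty dim tuple/list the same as None, i.e. reduce across every axis. A standalone sketch of that normalization (hypothetical helper name, not the actual converter):

    def normalize_reduce_dims(dim, rank):
        # None or an empty sequence both mean "reduce across every dimension".
        if dim is None or (isinstance(dim, (tuple, list)) and len(dim) == 0):
            return tuple(range(rank))
        return tuple(dim) if isinstance(dim, (tuple, list)) else (dim,)

    print(normalize_reduce_dims(None, 3))    # (0, 1, 2)
    print(normalize_reduce_dims([], 3))      # (0, 1, 2)
    print(normalize_reduce_dims([1, 2], 4))  # (1, 2)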

py/torch_tensorrt/dynamo/conversion/impl/unsqueeze.py

Lines changed: 40 additions & 1 deletion
@@ -1,4 +1,4 @@
-from typing import Optional, cast
+from typing import List, Optional, Sequence, cast
 
 from torch.fx.node import Target
 from torch_tensorrt.dynamo._SourceIR import SourceIR
@@ -49,3 +49,42 @@ def unsqueeze(
     )
     set_layer_name(layer, target, name, source_ir)
     return layer.get_output(0)
+
+
+def broadcast_in_dim(
+    ctx: ConversionContext,
+    target: Target,
+    source_ir: Optional[SourceIR],
+    name: str,
+    input_t: TRTTensor,
+    shape: Sequence[int],
+    broadcast_dimensions: Sequence[int],
+) -> TRTTensor:
+    augmented_shape_list: List[Optional[int]] = list(shape)
+
+    # For each dimension being broadcasted, set the augmented shape to None
+    for broadcast_dim in broadcast_dimensions:
+        augmented_shape_list[broadcast_dim] = None
+
+    # TODO: Expand support to arbitrary broadcasts
+    assert all(
+        dim in (1, None) for dim in augmented_shape_list
+    ), "broadcast_in_dim currently only supports unsqueeze broadcasting"
+
+    # Unsqueeze the shape repeatedly to broadcast
+    output = input_t
+    for idx, x in enumerate(augmented_shape_list):
+        # If the value is not None, that dimension is to be broadcasted
+        if x is not None:
+            output = unsqueeze(
+                ctx,
+                target,
+                source_ir,
+                name + f"_unsqueeze_for_broadcast_{idx}",
+                output,
+                idx,
+            )
+
+    assert tuple(output.shape) == tuple(shape), "broadcast_in_dim shapes don't match"
+
+    return output
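An eager-mode reference of what this unsqueeze-only broadcast_in_dim computes (a sketch under the same restriction, not part of the diff): every output dimension not listed in broadcast_dimensions must be 1 and becomes a newly inserted singleton axis.

    import torch

    def broadcast_in_dim_ref(x, shape, broadcast_dimensions):
        out = x
        for idx, size in enumerate(shape):
            if idx not in broadcast_dimensions:
                # Only unsqueeze-style broadcasts are covered, matching the assert above.
                assert size == 1
                out = out.unsqueeze(idx)
        return out

    x = torch.randn(4, 5, 6)
    print(broadcast_in_dim_ref(x, [4, 5, 6, 1, 1], [0, 1, 2]).shape)  # torch.Size([4, 5, 6, 1, 1])
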
py/torch_tensorrt/dynamo/conversion/prims_ops_converters.py

Lines changed: 44 additions & 0 deletions
@@ -0,0 +1,44 @@
+import logging
+from typing import Dict, Sequence, Tuple, Union
+
+import torch
+from torch.fx.node import Argument, Target
+from torch_tensorrt.dynamo._SourceIR import SourceIR
+from torch_tensorrt.dynamo.conversion import impl
+from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
+from torch_tensorrt.fx.types import TRTTensor
+
+from .converter_registry import dynamo_tensorrt_converter
+
+_LOGGER: logging.Logger = logging.getLogger(__name__)
+
+
+# TODO: expand the scope of this converter with aten.expand implementation
+def broadcast_checker(broadcast_node: torch.fx.Node) -> bool:
+    # The current implementation of broadcast_in_dim can only handle unsqueeze
+    return all(
+        broadcast_node.args[1][i] == 1
+        for i in range(len(broadcast_node.args[1]))
+        if i not in broadcast_node.args[2]
+    )
+
+
+@dynamo_tensorrt_converter(
+    torch.ops.prims.broadcast_in_dim.default, capability_validator=broadcast_checker
+)  # type: ignore[misc]
+def aten_ops_broadcast_in_dim(
+    ctx: ConversionContext,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    return impl.unsqueeze.broadcast_in_dim(
+        ctx,
+        target,
+        SourceIR.ATEN,
+        name,
+        args[0],
+        args[1],
+        args[2],
+    )
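The capability_validator only claims broadcast_in_dim nodes this implementation can handle; nodes that fail the check are not claimed (the unsupported-case test added later in this commit exercises that path). The check itself reduces to a small predicate over the node's shape and broadcast-dimension arguments; a plain-Python sketch of that predicate (hypothetical name):

    def is_unsqueeze_broadcast(shape, broadcast_dimensions):
        # Every output dim not listed in broadcast_dimensions must be a new singleton axis.
        return all(
            shape[i] == 1 for i in range(len(shape)) if i not in broadcast_dimensions
        )

    print(is_unsqueeze_broadcast([4, 5, 6, 1, 1], [0, 1, 2]))  # True  -> converter claims the node
    print(is_unsqueeze_broadcast([4, 5, 6, 7, 1], [0, 1, 2]))  # False -> node is not claimed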

tests/py/dynamo/conversion/test_div_aten.py

Lines changed: 18 additions & 0 deletions
@@ -82,6 +82,24 @@ def forward(self, lhs_val):
             expected_ops={torch.ops.aten.div.Tensor_mode},
         )
 
+    @parameterized.expand(
+        [
+            ("2d", (2, 1)),
+            ("3d", (2, 1, 2)),
+        ]
+    )
+    def test_prims_div_tensor(self, _, shape):
+        class div(nn.Module):
+            def forward(self, lhs_val, rhs_val):
+                return torch.ops.prims.div.default(lhs_val, rhs_val)
+
+        inputs = [torch.randn(shape), torch.randn(shape)]
+        self.run_test(
+            div(),
+            inputs,
+            expected_ops={torch.ops.prims.div.default},
+        )
+
 
 if __name__ == "__main__":
     run_tests()

tests/py/dynamo/conversion/test_sum_aten.py

Lines changed: 22 additions & 0 deletions
@@ -113,5 +113,27 @@ def forward(self, x):
         )
 
 
+class TestPrimsSumConverter(DispatchTestCase):
+    @parameterized.expand(
+        [
+            ((3, 2, 4), [1]),
+            ((2, 1, 4, 5), [1, 2]),
+            ((2, 3, 4, 5), [0, 1, 2, 3]),
+            ((6, 7, 5, 4, 5), [1, 3, 4]),
+        ]
+    )
+    def test_sum_dim_sequence(self, input_shape, dim):
+        class Sum(nn.Module):
+            def forward(self, x):
+                return torch.ops.prims.sum.default(x, dim)
+
+        inputs = [torch.randn(*input_shape)]
+        self.run_test(
+            Sum(),
+            inputs,
+            expected_ops={torch.ops.prims.sum.default},
+        )
+
+
 if __name__ == "__main__":
     run_tests()

tests/py/dynamo/conversion/test_unsqueeze_aten.py

Lines changed: 47 additions & 0 deletions
@@ -4,6 +4,7 @@
 from parameterized import parameterized
 from torch.testing._internal.common_utils import run_tests
 from torch_tensorrt import Input
+from torch_tensorrt.dynamo.conversion import UnsupportedOperatorException
 
 from .harness import DispatchTestCase
 
@@ -59,5 +60,51 @@ def forward(self, x):
         )
 
 
+class TestBroadcastInDim(DispatchTestCase):
+    def test_broadcast_in_dim_supported(
+        self,
+    ):
+        class Unsqueeze(nn.Module):
+            def forward(self, x):
+                return torch.ops.prims.broadcast_in_dim.default(
+                    x, [4, 5, 6, 1, 1], [0, 1, 2]
+                )
+
+        inputs = [torch.randn(4, 5, 6)]
+        self.run_test(
+            Unsqueeze(), inputs, expected_ops={torch.ops.prims.broadcast_in_dim.default}
+        )
+
+    def test_broadcast_in_dim_supported_singleton(
+        self,
+    ):
+        class Unsqueeze(nn.Module):
+            def forward(self, x):
+                return torch.ops.prims.broadcast_in_dim.default(x, [1, 1, 1], [0, 1])
+
+        inputs = [torch.randn(1, 1)]
+        self.run_test(
+            Unsqueeze(), inputs, expected_ops={torch.ops.prims.broadcast_in_dim.default}
+        )
+
+    # TODO: Remove this test when support is updated
+    def test_broadcast_in_dim_unsupported(
+        self,
+    ):
+        class Unsqueeze(nn.Module):
+            def forward(self, x):
+                return torch.ops.prims.broadcast_in_dim.default(
+                    x, [4, 5, 6, 7, 1], [0, 1, 2]
+                )
+
+        inputs = [torch.randn(4, 5, 6)]
+        with self.assertRaises(UnsupportedOperatorException):
+            self.run_test(
+                Unsqueeze(),
+                inputs,
+                expected_ops={torch.ops.prims.broadcast_in_dim.default},
+            )
+
+
 if __name__ == "__main__":
     run_tests()

tests/py/dynamo/conversion/test_where_aten.py

Lines changed: 18 additions & 0 deletions
@@ -44,6 +44,24 @@ def forward(self, condition, x, y):
             expected_ops={torch.ops.aten.where.self},
         )
 
+    def test_const_input(self):
+        class Where(nn.Module):
+            def __init__(self, *args, **kwargs) -> None:
+                super().__init__(*args, **kwargs)
+                self.inputY = torch.randn((5, 6, 7))
+                self.inputX = torch.randn((5, 6, 7))
+
+            def forward(self, condition):
+                return torch.where(condition, self.inputX, self.inputY)
+
+        input1 = torch.randn((5, 6, 7))
+        condition = input1 < 0
+        self.run_test(
+            Where(),
+            (condition,),
+            expected_ops={torch.ops.aten.where.self},
+        )
+
 
 if __name__ == "__main__":
     run_tests()
