diff --git a/README.md b/README.md index 0c86c41ed2..9bf645af0b 100644 --- a/README.md +++ b/README.md @@ -213,7 +213,7 @@ bazel build //:libtorchtrt --compilation_mode opt ``` ### FX path (Python only) installation -If the user plan to try FX path (Python only) and would like to avoid bazel build. Please follow the steps below. +If the user plans to try FX path (Python only) and would like to avoid bazel build. Please follow the steps below. ``` shell cd py && python3 setup.py install --fx-only ``` diff --git a/examples/fx/quantized_resnet_test.py b/examples/fx/quantized_resnet_test.py index 13a044b53b..c25691b95b 100644 --- a/examples/fx/quantized_resnet_test.py +++ b/examples/fx/quantized_resnet_test.py @@ -6,7 +6,11 @@ import torch_tensorrt.fx.tracer.acc_tracer.acc_tracer as acc_tracer import torchvision.models as models -from torch.ao.quantization.quantize_fx import convert_fx, prepare_fx +from torch.ao.quantization.quantize_fx import ( + convert_fx, + convert_to_reference, + prepare_fx, +) from torch.fx.experimental.normalize import NormalizeArgs from torch.fx.passes import shape_prop from torch_tensorrt.fx import InputTensorSpec, TRTInterpreter, TRTModule @@ -48,7 +52,7 @@ def build_int8_trt(rn18): prepared = prepare_fx(rn18, {"": qconfig}) for _ in range(10): prepared(data) - quantized_rn18 = convert_fx(prepared, is_reference=True) + quantized_rn18 = convert_to_reference(prepared) ref_res = quantized_rn18(data) print("quantized model:", quantized_rn18) diff --git a/py/torch_tensorrt/fx/converters/acc_ops_converters.py b/py/torch_tensorrt/fx/converters/acc_ops_converters.py index 49765f4fd3..5558df28f5 100644 --- a/py/torch_tensorrt/fx/converters/acc_ops_converters.py +++ b/py/torch_tensorrt/fx/converters/acc_ops_converters.py @@ -2222,7 +2222,7 @@ def acc_ops_adaptive_avg_poolnd( extend_len = 2 if target == acc_ops.adaptive_avg_pool2d else 3 assert all( input_val.shape[-(i + 1)] != -1 for i in range(extend_len) - ), "AdaptiveAvgPool2d currently doesn't support dynamic shapes for last two dims." + ), "AdaptiveAvgPool2d and AdaptiveAvgPool3d currently doesn't support dynamic shapes for last two dims." output_size = cast( Sequence[int], extend_attr_to_tuple(kwargs["output_size"], extend_len) diff --git a/py/torch_tensorrt/fx/converters/converter_utils.py b/py/torch_tensorrt/fx/converters/converter_utils.py index c59763b5ab..e54bd83efb 100644 --- a/py/torch_tensorrt/fx/converters/converter_utils.py +++ b/py/torch_tensorrt/fx/converters/converter_utils.py @@ -415,9 +415,8 @@ def add_binary_elementwise_layer( This function adds a TensorRT elementwise layer. We allow both operands to be constant (not a trt tensor) because in implicit batch dimension mode, we could introduce constant via .size() op. Other scenario should be const folded first. - If any operand is not a trt tensor, we make it a trt constant layer which has - the same type as the other trt tensor. Then we broadcast these two inputs to - have the same number of dimensions. + If any operand is not a trt tensor, we make it a trt constant layer while preserve + its dtype. Then we broadcast these two inputs to have the same number of dimensions. Limitation: If we are using implicit batch dim mode, the operand that is not a trt @@ -436,14 +435,16 @@ def add_binary_elementwise_layer( Returns: The output of TensorRT Elementwise layer. """ - dtype = None + lhs_dtype = None + rhs_dtype = None is_lhs_trt_tensor = False is_rhs_trt_tensor = False + if isinstance(lhs_val, TRTTensor): - dtype = torch_dtype_from_trt(lhs_val.dtype) + lhs_dtype = torch_dtype_from_trt(lhs_val.dtype) is_lhs_trt_tensor = True if isinstance(rhs_val, TRTTensor): - dtype = torch_dtype_from_trt(rhs_val.dtype) + rhs_dtype = torch_dtype_from_trt(rhs_val.dtype) is_rhs_trt_tensor = True if not is_lhs_trt_tensor and not is_rhs_trt_tensor: @@ -463,10 +464,14 @@ def add_binary_elementwise_layer( # this way the shape will become [1], and then will be properly squeezed # into [], meaning that the result will have shape [], which is what we # expect. + # + # Note that the dtype here is supposed to be the same as the scalar + # dtype but we don't have a way to detect whether it makes sense for the + # scalar to be float or half. Hence we go with the lhs dtype. if is_lhs_trt_tensor and isinstance(rhs_val, (float, int)): - rhs_val = torch.tensor([rhs_val], dtype=dtype) + rhs_val = torch.tensor([rhs_val], dtype=lhs_dtype) if is_rhs_trt_tensor and isinstance(lhs_val, (float, int)): - lhs_val = torch.tensor([lhs_val], dtype=dtype) + lhs_val = torch.tensor([lhs_val], dtype=rhs_dtype) # When lhs is scalar, and rhs has shape [1,], then currently the assert # will fail because lhs shape has fewer dimensions than rhs shape. This @@ -482,8 +487,8 @@ def add_binary_elementwise_layer( if isinstance(rhs_val, torch.Tensor): rhs_val = squeeze_left(rhs_val) - lhs_val = get_trt_tensor(network, lhs_val, f"{name}_lhs", dtype) - rhs_val = get_trt_tensor(network, rhs_val, f"{name}_rhs", dtype) + lhs_val = get_trt_tensor(network, lhs_val, f"{name}_lhs", lhs_dtype) + rhs_val = get_trt_tensor(network, rhs_val, f"{name}_rhs", rhs_dtype) # Check the limitation in the doc string. if network.has_implicit_batch_dimension: diff --git a/py/torch_tensorrt/fx/fx2trt.py b/py/torch_tensorrt/fx/fx2trt.py index 29b1490586..7deed3e470 100644 --- a/py/torch_tensorrt/fx/fx2trt.py +++ b/py/torch_tensorrt/fx/fx2trt.py @@ -80,7 +80,7 @@ def __init__( ] = dict() def validate_input_specs(self): - for shape, dtpe, _, shape_ranges, has_batch_dim in self.input_specs: + for shape, _, _, shape_ranges, has_batch_dim in self.input_specs: if not self.network.has_implicit_batch_dimension: assert ( has_batch_dim diff --git a/py/torch_tensorrt/fx/passes/graph_opts.py b/py/torch_tensorrt/fx/passes/graph_opts.py new file mode 100644 index 0000000000..2adc5c7fe3 --- /dev/null +++ b/py/torch_tensorrt/fx/passes/graph_opts.py @@ -0,0 +1,74 @@ +from collections.abc import Sequence + +import torch +import torch.fx + + +def common_subexpression_elimination(graph_module: torch.fx.GraphModule) -> bool: + """ + Optimize quantization by removing repeated subexpressions. + + Args: + graph_module(torch.fx.GraphModule): target module to be optimized + + Returns: + Graph changed or not. + """ + + def seq_hashable(seq): + if seq is None: + return None + + items = [] + for old in seq: + if isinstance(old, Sequence) and not isinstance(old, str): + new = seq_hashable(old) + elif isinstance(old, dict): + new = dict_hashable(old) + elif isinstance(old, slice): + new = old.__reduce__() + else: + new = old + + items.append(new) + + return tuple(items) + + def dict_hashable(d): + if d is None: + return None + + items = [] + for k, old_v in d.items(): + if isinstance(old_v, Sequence): + new_v = seq_hashable(old_v) + elif isinstance(old_v, dict): + new_v = dict_hashable(old_v) + elif isinstance(old_v, slice): + new_v = old_v.__reduce__() + else: + new_v = old_v + + items.append((k, new_v)) + return tuple(sorted(items)) + + changed = False + env = {} + for n in graph_module.graph.nodes: + # do not CSE away impure ops + if n.op not in {"call_function", "call_method"} or n.is_impure(): + continue + + # hash target, args, kwargs + hash_val = (n.target, seq_hashable(n.args), dict_hashable(n.kwargs)) + + # check if a node has a substitute and can be eliminated + if hash_val in env: + n.replace_all_uses_with(env[hash_val]) + graph_module.graph.erase_node(n) + changed = True + continue + + env[hash_val] = n + + return changed diff --git a/py/torch_tensorrt/fx/passes/lower_pass_manager_builder.py b/py/torch_tensorrt/fx/passes/lower_pass_manager_builder.py index 0fc3557069..0f8e2233a2 100644 --- a/py/torch_tensorrt/fx/passes/lower_pass_manager_builder.py +++ b/py/torch_tensorrt/fx/passes/lower_pass_manager_builder.py @@ -1,5 +1,5 @@ from functools import partial, wraps -from typing import Any, Callable, NamedTuple, Sequence +from typing import Any, Callable, Sequence import torch from torch import nn @@ -10,6 +10,7 @@ from ..lower_setting import LowerSetting from ..observer import Observer from ..passes.remove_duplicate_output_args import remove_duplicate_output_args +from .graph_opts import common_subexpression_elimination from .lower_basic_pass import run_const_fold @@ -94,6 +95,8 @@ def graph_optimization_pass(self) -> PassManager: passes.append(wrapper(p, self._input)) for p in self.lower_setting.lower_basic_fuse_pass.passes: passes.append(wrapper(p, self._input)) + + passes.append(inplace_wrapper(common_subexpression_elimination)) passes.append( inplace_wrapper(lambda m: FUSE_PASSES_POST_OBSERVER.observe(m, self._input)) ) diff --git a/py/torch_tensorrt/fx/test/converters/acc_op/test_adaptive_avgpool.py b/py/torch_tensorrt/fx/test/converters/acc_op/test_adaptive_avgpool.py index 3b60c551df..af211f79d3 100644 --- a/py/torch_tensorrt/fx/test/converters/acc_op/test_adaptive_avgpool.py +++ b/py/torch_tensorrt/fx/test/converters/acc_op/test_adaptive_avgpool.py @@ -92,6 +92,8 @@ def forward(self, x): TestModule(), input_specs, expected_ops={acc_ops.adaptive_avg_pool3d} ) + # Testing with shape(-1, -1, -1, -1) results into error: "AdaptiveAvgPool2d and AdaptiveAvgPool3d currently doesn't support dynamic shapes for last two dims." + if __name__ == "__main__": run_tests() diff --git a/py/torch_tensorrt/fx/test/converters/acc_op/test_any.py b/py/torch_tensorrt/fx/test/converters/acc_op/test_any.py index 5ef417605d..e1d24766ae 100644 --- a/py/torch_tensorrt/fx/test/converters/acc_op/test_any.py +++ b/py/torch_tensorrt/fx/test/converters/acc_op/test_any.py @@ -5,6 +5,8 @@ from torch.testing._internal.common_utils import run_tests from torch_tensorrt.fx.tools.common_fx2trt import AccTestCase +# from torch_tensorrt.fx.tools.common_fx2trt import InputTensorSpec + class TestAnyConverters(AccTestCase): @parameterized.expand( @@ -64,6 +66,26 @@ def forward(self, x): test_implicit_batch_dim=False, ) + # Testing with shape (-1, -1, -1, -1) results into error: torch.zeros(tuple([*input_t.shape])). Trying to create tensor with negative dimension -1: [-1, -1, -1, -1] + """ + def test_ops_with_dynamic_shape_four_dimensions(self): + class TestModule(nn.Module): + def forward(self, x): + return torch.any(x) + + input_specs = [ + InputTensorSpec( + shape=(-1, -1, -1, -1), + dtype=torch.float32, + shape_ranges=[((1, 1, 256, 256), (3, 3, 256, 256), (5, 5, 256, 256))], + ), + ] + + self.run_test_with_dynamic_shape( + TestModule(), input_specs, expected_ops={acc_ops.any} + ) + """ + if __name__ == "__main__": run_tests() diff --git a/py/torch_tensorrt/fx/test/converters/acc_op/test_as_strided.py b/py/torch_tensorrt/fx/test/converters/acc_op/test_as_strided.py index dc014a7e6c..9ba1f83474 100644 --- a/py/torch_tensorrt/fx/test/converters/acc_op/test_as_strided.py +++ b/py/torch_tensorrt/fx/test/converters/acc_op/test_as_strided.py @@ -3,7 +3,7 @@ import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops from parameterized import parameterized from torch.testing._internal.common_utils import run_tests -from torch_tensorrt.fx.tools.common_fx2trt import AccTestCase +from torch_tensorrt.fx.tools.common_fx2trt import AccTestCase, InputTensorSpec class TestConverter(AccTestCase): @@ -30,6 +30,26 @@ def forward(self, x): test_implicit_batch_dim=False, ) + # Testing with shape (-1, -1, -1, -1) results into error: RuntimeError: setStorage: sizes [2, 3], strides [1, 2], storage offset 0, and itemsize 8 requiring a storage size of 48 are out of bounds for storage of size 1 + """ + def test_as_strided_with_dynamic_shape_four_dimensions(self): + class Stride(nn.Module): + def forward(self, x): + return torch.as_strided(torch.tensor([5, 5]), (2, 3), (1, 2), 0) + + input_specs = [ + InputTensorSpec( + shape=(-1, -1, -1, -1), + dtype=torch.float32, + shape_ranges=[((1, 1, 1, 1), (3, 3, 3, 3), (5, 5, 5, 5))], + ), + ] + + self.run_test_with_dynamic_shape( + Stride(), input_specs, expected_ops={acc_ops.as_strided} + ) + """ + if __name__ == "__main__": run_tests() diff --git a/py/torch_tensorrt/fx/test/converters/acc_op/test_avgpool.py b/py/torch_tensorrt/fx/test/converters/acc_op/test_avgpool.py index ca69de7afa..91e9ca9c90 100644 --- a/py/torch_tensorrt/fx/test/converters/acc_op/test_avgpool.py +++ b/py/torch_tensorrt/fx/test/converters/acc_op/test_avgpool.py @@ -2,7 +2,7 @@ import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops from parameterized import param, parameterized from torch.testing._internal.common_utils import run_tests -from torch_tensorrt.fx.tools.common_fx2trt import AccTestCase +from torch_tensorrt.fx.tools.common_fx2trt import AccTestCase, InputTensorSpec class TestAvgPoolConverter(AccTestCase): @@ -39,6 +39,43 @@ def forward(self, x): inputs = [torch.randn(1, 3, 224)] self.run_test(TestModule(), inputs, expected_ops={acc_ops.avg_pool1d}) + def test_avg_pool2d_with_dynamic_shape_four_dimensions( + self, + test_name="default", + kernel_size=1, + stride=1, + padding=0, + ceil_mode=False, + count_include_pad=True, + divisor_override=None, + ): + class TestModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.avg_pool = torch.nn.AvgPool2d( + kernel_size, + stride, + padding, + ceil_mode, + count_include_pad, + divisor_override, + ) + + def forward(self, x): + return self.avg_pool(x) + + input_specs = [ + InputTensorSpec( + shape=(-1, -1, -1, -1), + dtype=torch.float32, + shape_ranges=[((1, 1, 1, 1), (3, 3, 3, 3), (5, 5, 5, 5))], + ), + ] + + self.run_test_with_dynamic_shape( + TestModule(), input_specs, expected_ops={acc_ops.avg_pool2d} + ) + @parameterized.expand( [ ("default", 1), @@ -84,7 +121,7 @@ def forward(self, x): param("stride", 2, stride=()), ] ) - def test_stride_none__avg_pool1d( + def test_stride_none_avg_pool1d( self, test_name, kernel_size, @@ -144,6 +181,75 @@ def forward(self, x): inputs = [torch.randn(1, 3, 224, 224)] self.run_test(TestModule(), inputs, expected_ops={acc_ops.avg_pool2d}) + def test_stride_none_avg_pool2d_with_dynamic_shape_four_dimensions( + self, + test_name="default", + kernel_size=1, + stride=None, + padding=0, + ceil_mode=False, + count_include_pad=True, + divisor_override=None, + ): + class TestModule(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + return torch.nn.functional.avg_pool2d( + x, + kernel_size, + stride=stride, + padding=padding, + ceil_mode=ceil_mode, + count_include_pad=count_include_pad, + divisor_override=divisor_override, + ) + + input_specs = [ + InputTensorSpec( + shape=(-1, -1, -1, -1), + dtype=torch.float32, + shape_ranges=[((1, 1, 1, 1), (3, 3, 3, 3), (5, 5, 5, 5))], + ), + ] + + self.run_test_with_dynamic_shape( + TestModule(), input_specs, expected_ops={acc_ops.avg_pool2d} + ) + + # Testing with (-1, -1, -1, -1) results in error: RuntimeError: ShapeProp error for: node=%avg_pool1d : [#users=1] = call_function[target=torch.avg_pool1d](args = (%x, (1,), (1,), (0,), False, True), kwargs = {}) with meta={} + """ + def test_avg_pool1d_with_dynamic_shape_four_dimensions( + self, + test_name="default", + kernel_size=1, + stride=1, + padding=0, + ceil_mode=False, + count_include_pad=True, + ): + class TestModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.avg_pool = torch.nn.AvgPool1d( + kernel_size, stride, padding, ceil_mode, count_include_pad + ) + + def forward(self, x): + return self.avg_pool(x) + + input_specs = [ + InputTensorSpec( + shape=(-1, -1, -1, -1), + dtype=torch.float32, + shape_ranges=[((1, 1, 1, 1), (3, 3, 3, 3), (5, 5, 5, 5))], + ), + ] + + self.run_test(TestModule(), input_specs, expected_ops={acc_ops.avg_pool1d}) + """ + if __name__ == "__main__": run_tests() diff --git a/py/torch_tensorrt/fx/test/converters/acc_op/test_batchnorm.py b/py/torch_tensorrt/fx/test/converters/acc_op/test_batchnorm.py index 5786f2ecba..7b282f5bde 100644 --- a/py/torch_tensorrt/fx/test/converters/acc_op/test_batchnorm.py +++ b/py/torch_tensorrt/fx/test/converters/acc_op/test_batchnorm.py @@ -38,6 +38,8 @@ def forward(self, x): TestModule(), input_specs, expected_ops={acc_ops.batch_norm} ) + # Testing with shape=(-1, -1, -1, -1) results in AssertionError: Channel dim can't be dynamic for batch norm. + if __name__ == "__main__": run_tests() diff --git a/py/torch_tensorrt/fx/test/converters/acc_op/test_binary_ops.py b/py/torch_tensorrt/fx/test/converters/acc_op/test_binary_ops.py index 56a37b04b0..6da7a4e205 100644 --- a/py/torch_tensorrt/fx/test/converters/acc_op/test_binary_ops.py +++ b/py/torch_tensorrt/fx/test/converters/acc_op/test_binary_ops.py @@ -142,6 +142,46 @@ def forward(self, x, y): self.run_test_with_dynamic_shape(Op(), input_specs, expected_ops={expected_op}) + @parameterized.expand( + [ + ( + f"no_broadcast_{op[1].__name__}", + op[0], + op[1], + ) + for op in elementwise_ops + ] + + [ + ( + f"broadcast_{op[1].__name__}", + op[0], + op[1], + ) + for op in elementwise_ops + ] + ) + def test_elementwise_op_with_dynamic_shape_four_dimensions( + self, _, orig_op, expected_op + ): + class Op(nn.Module): + def forward(self, x, y): + return orig_op(x, y) + + input_specs = [ + InputTensorSpec( + shape=(-1, -1, -1, -1), + dtype=torch.float32, + shape_ranges=[((1, 1, 1, 1), (3, 3, 3, 3), (5, 5, 5, 5))], + ), + InputTensorSpec( + shape=(-1, -1, -1, -1), + dtype=torch.float32, + shape_ranges=[((1, 1, 1, 1), (3, 3, 3, 3), (5, 5, 5, 5))], + ), + ] + + self.run_test_with_dynamic_shape(Op(), input_specs, expected_ops={expected_op}) + def test_elementwise_ops_with_scalar_lhs(self): def orig_op(x, y): return x + y diff --git a/py/torch_tensorrt/fx/test/converters/acc_op/test_cat.py b/py/torch_tensorrt/fx/test/converters/acc_op/test_cat.py index 9408c5f6bc..97360c75e5 100644 --- a/py/torch_tensorrt/fx/test/converters/acc_op/test_cat.py +++ b/py/torch_tensorrt/fx/test/converters/acc_op/test_cat.py @@ -42,6 +42,27 @@ def forward(self, x, y): ] self.run_test_with_dynamic_shape(Cat(), input_specs, expected_ops={acc_ops.cat}) + def test_cat_with_dynamic_shape_four_dimensions(self): + class Cat(nn.Module): + def forward(self, x, y): + x = x + y + return torch.cat((x, y), 0) + + input_specs = [ + InputTensorSpec( + shape=(-1, -1, -1, -1), + dtype=torch.float32, + shape_ranges=[((1, 1, 1, 1), (2, 3, 4, 4), (2, 3, 10, 10))], + ), + InputTensorSpec( + shape=(-1, -1, -1, -1), + dtype=torch.float32, + shape_ranges=[((1, 1, 1, 1), (2, 3, 4, 4), (2, 3, 10, 10))], + ), + ] + + self.run_test_with_dynamic_shape(Cat(), input_specs, expected_ops={acc_ops.cat}) + if __name__ == "__main__": run_tests() diff --git a/py/torch_tensorrt/fx/test/converters/acc_op/test_chunk.py b/py/torch_tensorrt/fx/test/converters/acc_op/test_chunk.py index f1bf53dc07..555f0ba24b 100644 --- a/py/torch_tensorrt/fx/test/converters/acc_op/test_chunk.py +++ b/py/torch_tensorrt/fx/test/converters/acc_op/test_chunk.py @@ -49,6 +49,31 @@ def forward(self, x): Chunk(), input_specs, expected_ops={acc_ops.chunk} ) + # Testing with (-1, -1, -1, -1) results in Error: AssertionError: Can't chunk on dynamic shape dimension! + @parameterized.expand( + [ + ("chunk", 3, 1), + ("chunk", 2000, 1), + ("chunk", 3, -2), + ] + ) + def test_chunk_with_dynamic_shape_four_dimensions(self, _, chunk, dim): + class Chunk(nn.Module): + def forward(self, x): + return x.chunk(chunk, dim)[0] + + input_specs = [ + InputTensorSpec( + shape=(-1, 1, 3, -1), + dtype=torch.float32, + shape_ranges=[((1, 1, 3, 5), (3, 1, 3, 5), (5, 1, 3, 5))], + ), + ] + + self.run_test_with_dynamic_shape( + Chunk(), input_specs, expected_ops={acc_ops.chunk} + ) + if __name__ == "__main__": run_tests() diff --git a/py/torch_tensorrt/fx/test/converters/acc_op/test_clamp.py b/py/torch_tensorrt/fx/test/converters/acc_op/test_clamp.py index 0309cf3e3f..2ba6273daa 100644 --- a/py/torch_tensorrt/fx/test/converters/acc_op/test_clamp.py +++ b/py/torch_tensorrt/fx/test/converters/acc_op/test_clamp.py @@ -27,6 +27,37 @@ def forward(self, x): inputs = [torch.randn(3, 4)] self.run_test(TestModule(), inputs, expected_ops={acc_ops.clamp}) + # Error: RuntimeError: ShapeProp error for: node=%clamp : [#users=1] = call_function[target=torch.clamp](args = (%x, 1, 0), kwargs = {}) with meta={} + """ + @parameterized.expand( + [ + param("default", min=-1, max=0), + param("min", min=0.5), + param("max", max=0.5), + param("minBiggerThanMax", min=1, max=0), + ] + ) + def test_clamp_with_dynamic_shape_four_dimensions( + self, + test_name, + min=None, + max=None, + ): + class TestModule(torch.nn.Module): + def forward(self, x): + return torch.clamp(x, min, max) + + input_specs = [ + InputTensorSpec( + shape=(-1, -1, 3, 3), + dtype=torch.float32, + shape_ranges=[((1, 1, 3, 3), (3, 3, 3, 3), (5, 5, 3, 3))], + ), + ] + + self.run_test(TestModule(), input_specs, expected_ops={acc_ops.clamp}) + """ + if __name__ == "__main__": run_tests() diff --git a/py/torch_tensorrt/fx/test/converters/acc_op/test_convolution.py b/py/torch_tensorrt/fx/test/converters/acc_op/test_convolution.py index d75afeef3c..c379da0217 100644 --- a/py/torch_tensorrt/fx/test/converters/acc_op/test_convolution.py +++ b/py/torch_tensorrt/fx/test/converters/acc_op/test_convolution.py @@ -44,111 +44,111 @@ def forward(self, x): test_explicit_precision=True, ) - # @parameterized.expand( - # [ - # ("default", 1), - # param("no_bias", 1, bias=False), - # ("tuple_parameters", 1, (1, 1), (1, 1)), - # param("non_zero_padding", 1, padding=1), - # param("dilation", 1, dilation=2), - # param("groups", 1, groups=3), - # ] - # ) - # def test_conv2d( - # self, - # _, - # kernel_size, - # stride=1, - # padding=0, - # dilation=1, - # groups=1, - # bias=True, - # ): - # class TestModule(torch.nn.Module): - # def __init__(self): - # super().__init__() - # self.conv = torch.nn.Conv2d( - # 3, 6, kernel_size, stride, padding, dilation, groups, bias - # ) - - # def forward(self, x): - # return self.conv(x) - - # inputs = [torch.randn(1, 3, 32, 32)] - # self.run_test(TestModule(), inputs, expected_ops={acc_ops.conv2d}) - - # def test_conv2d_with_dynamic_shape(self): - # class TestModule(torch.nn.Module): - # def __init__(self): - # super().__init__() - # self.conv = torch.nn.Conv2d(3, 6, 1) - - # def forward(self, x): - # return self.conv(x) - - # input_specs = [ - # InputTensorSpec( - # shape=(-1, 3, -1, -1), - # dtype=torch.float32, - # shape_ranges=[((1, 3, 1, 1), (1, 3, 4, 4), (32, 3, 128, 128))], - # ), - # ] - # self.run_test_with_dynamic_shape( - # TestModule(), input_specs, expected_ops={acc_ops.conv2d} - # ) - - # @parameterized.expand( - # [ - # ("default", 1), - # param("no_bias", 1, bias=False), - # ("tuple_parameters", 1, (1, 1, 1), (1, 1, 1)), - # param("non_zero_padding", 1, padding=1), - # param("dilation", 1, dilation=2), - # param("groups", 1, groups=3), - # ] - # ) - # def test_conv3d( - # self, - # _, - # kernel_size, - # stride=1, - # padding=0, - # dilation=1, - # groups=1, - # bias=True, - # ): - # class TestModule(torch.nn.Module): - # def __init__(self): - # super().__init__() - # self.conv = torch.nn.Conv3d( - # 3, 6, kernel_size, stride, padding, dilation, groups, bias - # ) - - # def forward(self, x): - # return self.conv(x) - - # inputs = [torch.randn(1, 3, 32, 32, 32)] - # self.run_test(TestModule(), inputs, expected_ops={acc_ops.conv3d}) - - # def test_conv3d_with_dynamic_shape(self): - # class TestModule(torch.nn.Module): - # def __init__(self): - # super().__init__() - # self.conv = torch.nn.Conv3d(3, 6, 1) - - # def forward(self, x): - # return self.conv(x) - - # input_specs = [ - # InputTensorSpec( - # shape=(-1, 3, -1, -1, -1), - # dtype=torch.float32, - # shape_ranges=[((1, 3, 1, 1, 1), (1, 3, 4, 4, 4), (8, 3, 32, 32, 32))], - # ), - # ] - # self.run_test_with_dynamic_shape( - # TestModule(), input_specs, expected_ops={acc_ops.conv3d} - # ) + @parameterized.expand( + [ + ("default", 1), + param("no_bias", 1, bias=False), + ("tuple_parameters", 1, (1, 1), (1, 1)), + param("non_zero_padding", 1, padding=1), + param("dilation", 1, dilation=2), + param("groups", 1, groups=3), + ] + ) + def test_conv2d( + self, + _, + kernel_size, + stride=1, + padding=0, + dilation=1, + groups=1, + bias=True, + ): + class TestModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.conv = torch.nn.Conv2d( + 3, 6, kernel_size, stride, padding, dilation, groups, bias + ) + + def forward(self, x): + return self.conv(x) + + inputs = [torch.randn(1, 3, 32, 32)] + self.run_test(TestModule(), inputs, expected_ops={acc_ops.conv2d}) + + def test_conv2d_with_dynamic_shape(self): + class TestModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.conv = torch.nn.Conv2d(3, 6, 1) + + def forward(self, x): + return self.conv(x) + + input_specs = [ + InputTensorSpec( + shape=(-1, 3, -1, -1), + dtype=torch.float32, + shape_ranges=[((1, 3, 1, 1), (1, 3, 4, 4), (32, 3, 128, 128))], + ), + ] + self.run_test_with_dynamic_shape( + TestModule(), input_specs, expected_ops={acc_ops.conv2d} + ) + + @parameterized.expand( + [ + ("default", 1), + param("no_bias", 1, bias=False), + ("tuple_parameters", 1, (1, 1, 1), (1, 1, 1)), + param("non_zero_padding", 1, padding=1), + param("dilation", 1, dilation=2), + param("groups", 1, groups=3), + ] + ) + def test_conv3d( + self, + _, + kernel_size, + stride=1, + padding=0, + dilation=1, + groups=1, + bias=True, + ): + class TestModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.conv = torch.nn.Conv3d( + 3, 6, kernel_size, stride, padding, dilation, groups, bias + ) + + def forward(self, x): + return self.conv(x) + + inputs = [torch.randn(1, 3, 32, 32, 32)] + self.run_test(TestModule(), inputs, expected_ops={acc_ops.conv3d}) + + def test_conv3d_with_dynamic_shape(self): + class TestModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.conv = torch.nn.Conv3d(3, 6, 1) + + def forward(self, x): + return self.conv(x) + + input_specs = [ + InputTensorSpec( + shape=(-1, 3, -1, -1, -1), + dtype=torch.float32, + shape_ranges=[((1, 3, 1, 1, 1), (1, 3, 4, 4, 4), (8, 3, 32, 32, 32))], + ), + ] + self.run_test_with_dynamic_shape( + TestModule(), input_specs, expected_ops={acc_ops.conv3d} + ) if __name__ == "__main__": diff --git a/py/torch_tensorrt/fx/test/converters/acc_op/test_sigmoid.py b/py/torch_tensorrt/fx/test/converters/acc_op/test_sigmoid.py index 61aa581e1f..835e50c10a 100644 --- a/py/torch_tensorrt/fx/test/converters/acc_op/test_sigmoid.py +++ b/py/torch_tensorrt/fx/test/converters/acc_op/test_sigmoid.py @@ -2,7 +2,7 @@ import torch.nn as nn import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops from torch.testing._internal.common_utils import run_tests -from torch_tensorrt.fx.tools.common_fx2trt import AccTestCase +from torch_tensorrt.fx.tools.common_fx2trt import AccTestCase, InputTensorSpec class TestSigmoid(AccTestCase): @@ -14,6 +14,22 @@ def forward(self, x): inputs = [torch.randn(1, 2, 3)] self.run_test(Sigmoid(), inputs, expected_ops={acc_ops.sigmoid}) + def test_sigmoid_with_dynamic_shape_four_dimensions(self): + class Sigmoid(nn.Module): + def forward(self, x): + return torch.sigmoid(x) + + input_specs = [ + InputTensorSpec( + shape=(-1, -1, -1, -1), + dtype=torch.float32, + shape_ranges=[((1, 1, 1, 1), (1, 2, 3, 3), (3, 3, 3, 3))], + ), + ] + self.run_test_with_dynamic_shape( + Sigmoid(), input_specs, expected_ops={acc_ops.sigmoid} + ) + if __name__ == "__main__": run_tests() diff --git a/py/torch_tensorrt/fx/test/passes/test_graph_opts.py b/py/torch_tensorrt/fx/test/passes/test_graph_opts.py new file mode 100644 index 0000000000..f240c95514 --- /dev/null +++ b/py/torch_tensorrt/fx/test/passes/test_graph_opts.py @@ -0,0 +1,183 @@ +import unittest +from collections import Counter +from typing import Callable, Dict, List + +import torch +import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops +import torch_tensorrt.fx.tracer.acc_tracer.acc_tracer as acc_tracer +from torch_tensorrt.fx.passes.graph_opts import common_subexpression_elimination + + +def debug_print_graph_module(mod_graph: torch.fx.GraphModule) -> None: + """ + Helper func to print model's graph in plain and tabular format, also print code. + """ + print(mod_graph.graph) + mod_graph.graph.print_tabular() + print(mod_graph.code) + + +@torch.fx.wrap +def test_op(keys, value): + return value + + +class GraphOptsTest(unittest.TestCase): + def setUp(self): + torch.manual_seed(0) + + def _test_opt_with_module( + self, + module: torch.nn.Module, + inputs: List, + opt: Callable, + should_change_graph: bool, + deleted_ops: Dict = None, + created_ops: Dict = None, + rtol: float = None, + atol: float = None, + ): + assert should_change_graph or not bool(deleted_ops or created_ops) + deleted_ops = deleted_ops or {} + created_ops = created_ops or {} + module.eval() + + # Before Opt + before_results = module(*inputs) + mod_traced = acc_tracer.trace(module, inputs) + before_node_list = list(mod_traced.graph.nodes) + print("Model before opt.") + debug_print_graph_module(mod_traced) + + # Apply Opt + graph_changed = bool(opt(mod_traced)) + + # After Opt + after_results = mod_traced(*inputs) + after_node_list = list(mod_traced.graph.nodes) + print("Model after opt.") + mod_traced.recompile() + debug_print_graph_module(mod_traced) + + # Tests + # * Numerics + tol_args = {} + if rtol is not None: + tol_args["rtol"] = rtol + if atol is not None: + tol_args["atol"] = atol + torch.testing.assert_close(before_results, after_results, **tol_args) + + # * opt changes graph + self.assertEqual(graph_changed, before_node_list != after_node_list) + self.assertEqual(should_change_graph, graph_changed) + + # * modified nodes + before_node_set = set(before_node_list) + after_node_set = set(after_node_list) + self.assertEqual( + dict(Counter([node.target for node in before_node_set - after_node_set])), + deleted_ops, + ) + self.assertEqual( + dict(Counter([node.target for node in after_node_set - before_node_set])), + created_ops, + ) + + return mod_traced + + def test_common_subexpression_elimination(self): + class TestModule(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + xx = x + x + xx2 = x + x + return xx * xx2 - x + + self._test_opt_with_module( + module=TestModule(), + inputs=[torch.rand(3, 2, 1)], + opt=common_subexpression_elimination, + should_change_graph=True, + deleted_ops={acc_ops.add: 1}, + ) + + def test_common_subexpression_elimination2(self): + class TestModule2(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + return x + x + + self._test_opt_with_module( + module=TestModule2(), + inputs=[torch.rand(3, 2, 1)], + opt=common_subexpression_elimination, + should_change_graph=False, + ) + + def test_common_subexpression_elimination3(self): + class TestModule3(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, a, b, c): + x = a * b + y = b - c + z = a * b + xy = x + y + zy = z + y + return xy - zy + + self._test_opt_with_module( + module=TestModule3(), + inputs=[ + torch.rand(3, 2, 1), + torch.rand(3, 2, 1), + torch.rand(3, 2, 1), + ], + opt=common_subexpression_elimination, + should_change_graph=True, + deleted_ops={acc_ops.add: 1, acc_ops.mul: 1}, + ) + + def test_common_subexpression_elimination4(self): + class TestModule3(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, a, b, c): + x = torch.cat([a, b, c]) + y = torch.cat([a, b, c]) + z = torch.cat([c, b, a]) + return x + y + z + + self._test_opt_with_module( + module=TestModule3(), + inputs=[ + torch.rand(3, 2, 1), + torch.rand(3, 2, 1), + torch.rand(3, 2, 1), + ], + opt=common_subexpression_elimination, + should_change_graph=True, + deleted_ops={acc_ops.cat: 1}, + ) + + def test_common_subexpression_elimination_string_arg(self): + class TestModule(torch.nn.Module): + def forward(self, a): + x = test_op(["foo", "bar"], a) + return x + + self._test_opt_with_module( + module=TestModule(), + inputs=[ + torch.rand(3, 2, 1), + ], + opt=common_subexpression_elimination, + should_change_graph=False, + ) diff --git a/py/torch_tensorrt/fx/test/quant/test_quant_trt.py b/py/torch_tensorrt/fx/test/quant/test_quant_trt.py index 4bdc1124f9..4fabc7f18d 100644 --- a/py/torch_tensorrt/fx/test/quant/test_quant_trt.py +++ b/py/torch_tensorrt/fx/test/quant/test_quant_trt.py @@ -15,7 +15,7 @@ from torch.ao.quantization.backend_config.observation_type import ObservationType from torch.ao.quantization.fx.match_utils import MatchAllNode from torch.ao.quantization.quantize_fx import ( - convert_fx, + convert_to_reference, get_tensorrt_backend_config_dict, prepare_fx, prepare_qat_fx, @@ -96,9 +96,7 @@ def forward(self, x): ) self.checkGraphModuleNodes(mp, expected_node_occurrence=prepare_count_check) mp(torch.randn(1, 1, 4, 4)) - mq = convert_fx( - mp, is_reference=True, backend_config_dict=self.trt_backend_config_dict - ) + mq = convert_to_reference(mp, backend_config_dict=self.trt_backend_config_dict) self.checkGraphModuleNodes(mq, expected_node_occurrence=convert_count_check) def test_quantized_input_quantized_output(self): @@ -260,7 +258,7 @@ def forward(self, x): ) # check converted/quantized model - m = convert_fx(m, is_reference=True, backend_config_dict=backend_config_dict) + m = convert_to_reference(m, backend_config_dict=backend_config_dict) self.checkGraphModuleNodes(m, expected_node_occurrence=convert_count_check) self.checkGraphModuleNodes( m.standalone, expected_node_occurrence=standalone_convert_count_check @@ -275,9 +273,7 @@ def forward(self, x): backend_config_dict=backend_config_dict, ) ref_m(data) - ref_m = convert_fx( - ref_m, is_reference=True, backend_config_dict=backend_config_dict - ) + ref_m = convert_to_reference(ref_m, backend_config_dict=backend_config_dict) ref_res = ref_m(data) self.assertEqual(res, ref_res) @@ -437,9 +433,8 @@ def _test_module( self.checkGraphModuleNodes(prepared, expected_node_occurrence=no_prepare) # calibration prepared(*inputs) - quantized = convert_fx( + quantized = convert_to_reference( prepared, - is_reference=True, backend_config_dict=self.trt_backend_config_dict, ) self.checkGraphModuleNodes(quantized, expected_node_occurrence=no_convert) @@ -556,9 +551,7 @@ def forward(self, x): example_inputs, backend_config_dict=self.trt_backend_config_dict, ) - m = convert_fx( - m, is_reference=True, backend_config_dict=self.trt_backend_config_dict - ) + m = convert_to_reference(m, backend_config_dict=self.trt_backend_config_dict) expected_occurrence = { ns.call_function(torch.quantize_per_tensor): 5, ns.call_method("dequantize"): 5, @@ -591,9 +584,8 @@ def forward(self, x): ) # calibration prepared(linear_module_input) - quantized = convert_fx( + quantized = convert_to_reference( prepared, - is_reference=True, backend_config_dict=self.trt_backend_config_dict, ) node_occurrence = { @@ -622,9 +614,8 @@ def forward(self, x): backend_config_dict=self.trt_backend_config_dict, ) self.assertTrue(len(dict(prepared.named_children())) == 1) - quantized = convert_fx( + quantized = convert_to_reference( prepared, - is_reference=True, backend_config_dict=self.trt_backend_config_dict, ) node_occurrence = { @@ -659,9 +650,8 @@ def forward(self, x): ns.call_module(torch.ao.quantization.HistogramObserver): 2, } self.checkGraphModuleNodes(prepared, expected_node_occurrence=node_occurrence) - quantized = convert_fx( + quantized = convert_to_reference( prepared, - is_reference=True, backend_config_dict=self.trt_backend_config_dict, ) node_occurrence = { @@ -729,9 +719,7 @@ def conv_add_extra_inputs_getter(pattern): ns.call_module(torch.ao.quantization.HistogramObserver): 3, } self.checkGraphModuleNodes(m, expected_node_occurrence=node_occurrence) - m = convert_fx( - m, is_reference=True, backend_config_dict=modified_backend_config_dict - ) + m = convert_to_reference(m, backend_config_dict=modified_backend_config_dict) node_occurrence = { ns.call_function(torch.quantize_per_tensor): 3, ns.call_method("dequantize"): 3, @@ -843,7 +831,7 @@ def forward(self, x): self.checkGraphModuleNodes( m.standalone, expected_node_occurrence=standalone_node_occurrence ) - m = convert_fx(m, is_reference=True, backend_config_dict=backend_config_dict) + m = convert_to_reference(m, backend_config_dict=backend_config_dict) node_occurrence = { # two inputs for standalone module ns.call_function(torch.quantize_per_tensor): 2, @@ -882,9 +870,8 @@ def forward(self, x): example_inputs, backend_config_dict=self.trt_backend_config_dict, ) - quantized = convert_fx( + quantized = convert_to_reference( prepared, - is_reference=True, backend_config_dict=self.trt_backend_config_dict, )