From 7f8849434d9871ffc6993b2943eec6d4fe492461 Mon Sep 17 00:00:00 2001
From: gs-olive <113141689+gs-olive@users.noreply.github.com>
Date: Mon, 13 Nov 2023 20:28:07 -0800
Subject: [PATCH] fix: Error with `aten.view` across Tensor memory

- Address an error where `aten.view` is called on TRT output Tensors,
  which can be in a different memory format than Torch expects
- Specifically, TRT can modify tensor memory to optimize certain layers,
  but Torch's `view` operator depends on specific memory configurations,
  which can be violated at runtime (though not at compile time, since
  Torch itself would produce tensors satisfying them)
- Add a custom lowering pass to replace `view` with `reshape`, avoiding
  this issue; `reshape` will make a copy of the underlying Tensor if
  necessary
- Torch-TRT's `aten.view` converter implementation is the same as that
  for `aten.reshape`, and they share a schema, so no changes are needed
  on the converter side
- Add a test case to validate the new lowering pass
---
 .../lowering/passes/_aten_lowering_pass.py     |  2 +
 .../dynamo/lowering/passes/view_to_reshape.py  | 41 +++++++++++
 .../lowering/test_aten_lowering_passes.py      | 68 ++++++++++++++++++-
 3 files changed, 110 insertions(+), 1 deletion(-)
 create mode 100644 py/torch_tensorrt/dynamo/lowering/passes/view_to_reshape.py

diff --git a/py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py b/py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py
index 604eda8c96..d6e12f5215 100644
--- a/py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py
+++ b/py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py
@@ -11,6 +11,7 @@
 from .remove_input_alias_fixing_clones import remove_input_alias_fixing_clones
 from .repair_input_as_output import repair_input_as_output
 from .replace_max_pool_with_indices import replace_max_pool_with_indices
+from .view_to_reshape import view_to_reshape
 
 ATEN_LOWERING_PASSES = DynamoPassManager.build_from_passlist(
     [
@@ -21,6 +22,7 @@
         lower_linear,
         fuse_prims_broadcast,
         replace_max_pool_with_indices,
+        view_to_reshape,
     ]
 )
 
diff --git a/py/torch_tensorrt/dynamo/lowering/passes/view_to_reshape.py b/py/torch_tensorrt/dynamo/lowering/passes/view_to_reshape.py
new file mode 100644
index 0000000000..efc836814f
--- /dev/null
+++ b/py/torch_tensorrt/dynamo/lowering/passes/view_to_reshape.py
@@ -0,0 +1,41 @@
+import logging
+from typing import Callable, List, Sequence, Tuple
+
+import torch
+from torch_tensorrt.dynamo.lowering.passes.pass_utils import (
+    clean_up_graph_after_modifications,
+)
+
+logger = logging.getLogger(__name__)
+
+
+def view_to_reshape(
+    gm: torch.fx.GraphModule, sample_inputs: Sequence[torch.Tensor]
+) -> torch.fx.GraphModule:
+    """Replace aten.view with an equivalent implementation which avoids Tensor memory issues"""
+    orig, replacement = view_replacement()
+
+    if torch.fx.subgraph_rewriter.replace_pattern(gm, orig, replacement):
+        gm = clean_up_graph_after_modifications(gm)
+        logger.debug(f"Graph after replacing view with reshape:\n{gm.graph}")
+
+    return gm
+
+
+def view_replacement() -> (
+    Tuple[
+        torch.fx.GraphModule,
+        Callable[[torch.Tensor, List[torch.SymInt]], torch.Tensor],
+    ]
+):
+    """Constructs the original and replacement functions for view"""
+
+    # Original graph
+    def orig(input: torch.Tensor, shape: List[torch.SymInt]) -> torch.Tensor:
+        return torch.ops.aten.view.default(input, shape)
+
+    # Replacement graph
+    def replacement(input: torch.Tensor, shape: List[torch.SymInt]) -> torch.Tensor:
+        return torch.ops.aten.reshape.default(input, shape)
+
+    return orig, replacement
diff --git a/tests/py/dynamo/lowering/test_aten_lowering_passes.py b/tests/py/dynamo/lowering/test_aten_lowering_passes.py
index 184e7c9c54..11b989bd90 100644
--- a/tests/py/dynamo/lowering/test_aten_lowering_passes.py
+++ b/tests/py/dynamo/lowering/test_aten_lowering_passes.py
@@ -1,7 +1,8 @@
 import torch
-import torch_tensorrt
 from torch.testing._internal.common_utils import TestCase, run_tests
 
+import torch_tensorrt
+
 from ..testing_utilities import DECIMALS_OF_AGREEMENT, lower_graph_testing
 
 
@@ -375,5 +376,70 @@
         torch._dynamo.reset()
 
 
+class TestLowerViewToReshape(TestCase):
+    def test_view_to_reshape(self):
+        class ViewToReshape(torch.nn.Module):
+            def forward(self, input):
+                out = torch.ops.aten.view.default(input, (1, 1, -1))
+                return out
+
+        inputs = [
+            torch.rand((3, 4, 5, 32)).cuda(),
+        ]
+
+        fx_graph = torch.fx.symbolic_trace(ViewToReshape())
+        expected_ops = {torch.ops.aten.reshape.default}
+        unexpected_ops = {
+            torch.ops.aten.view.default,
+        }
+
+        unexpected_ops_seen, expected_ops_unseen = lower_graph_testing(
+            fx_graph,
+            inputs,
+            expected_ops=expected_ops,
+            unexpected_ops=unexpected_ops,
+            min_block_size=1,
+        )
+
+        self.assertEquals(
+            len(unexpected_ops_seen),
+            0,
+            f"The following unexpected ops were encountered: {unexpected_ops_seen}",
+        )
+
+        self.assertEquals(
+            len(expected_ops_unseen),
+            0,
+            f"The following expected ops were not encountered: {expected_ops_unseen}",
+        )
+        torch._dynamo.reset()
+
+        # Validate that the results between Torch and Torch-TRT are similar
+        optimized_model = torch_tensorrt.compile(
+            fx_graph,
+            "torch_compile",
+            inputs,
+            min_block_size=1,
+            pass_through_build_failures=True,
+        )
+        optimized_model_results = torch.cat(
+            [tensor.detach().cpu() for tensor in optimized_model(*inputs)]
+        )
+        torch_model_results = torch.cat(
+            [tensor.detach().cpu() for tensor in fx_graph(*inputs)]
+        )
+
+        max_diff = float(
+            torch.max(torch.abs(optimized_model_results - torch_model_results))
+        )
+        self.assertAlmostEqual(
+            max_diff,
+            0,
+            DECIMALS_OF_AGREEMENT,
+            msg=f"ViewToReshape TRT outputs don't match with the original model.",
+        )
+        torch._dynamo.reset()
+
+
 if __name__ == "__main__":
     run_tests()
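
For illustration, a minimal standalone sketch (not part of the patch) of the behavior this pass works around: `view` can only reinterpret a tensor whose strides are compatible with the requested shape, while `reshape` falls back to copying the data when they are not.

    import torch

    # A transposed tensor is non-contiguous, loosely analogous to a TRT output
    # whose memory layout differs from what Torch produced at trace time.
    x = torch.rand(3, 4, 5, 32)
    t = x.transpose(1, 2)

    try:
        t.view(1, 1, -1)  # view cannot reinterpret these strides and raises
    except RuntimeError as err:
        print(f"view failed: {err}")

    out = t.reshape(1, 1, -1)  # reshape copies the underlying data when needed
    print(out.shape)  # torch.Size([1, 1, 1920])

Since `reshape` degenerates to a zero-copy view whenever the input layout is already compatible, substituting it for `view` adds no overhead in the common case.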