Skip to content

fix: Error with aten.view across Tensor memory #2464

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Nov 27, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from .remove_input_alias_fixing_clones import remove_input_alias_fixing_clones
from .repair_input_as_output import repair_input_as_output
from .replace_max_pool_with_indices import replace_max_pool_with_indices
from .view_to_reshape import view_to_reshape

ATEN_LOWERING_PASSES = DynamoPassManager.build_from_passlist(
[
Expand All @@ -21,6 +22,7 @@
lower_linear,
fuse_prims_broadcast,
replace_max_pool_with_indices,
view_to_reshape,
]
)

Expand Down
41 changes: 41 additions & 0 deletions py/torch_tensorrt/dynamo/lowering/passes/view_to_reshape.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
import logging
from typing import Callable, List, Sequence, Tuple

import torch
from torch_tensorrt.dynamo.lowering.passes.pass_utils import (
clean_up_graph_after_modifications,
)

logger = logging.getLogger(__name__)


def view_to_reshape(
gm: torch.fx.GraphModule, sample_inputs: Sequence[torch.Tensor]
) -> torch.fx.GraphModule:
"""Replace aten.view with an equivalent implementation which avoids Tensor memory issues"""
orig, replacement = view_replacement()

if torch.fx.subgraph_rewriter.replace_pattern(gm, orig, replacement):
gm = clean_up_graph_after_modifications(gm)
logger.debug(f"Graph after replacing view with reshape:\n{gm.graph}")

return gm


def view_replacement() -> (
Tuple[
torch.fx.GraphModule,
Callable[[torch.Tensor, List[torch.SymInt]], torch.Tensor],
]
):
"""Constructs the original and replacement functions for view"""

# Original graph
def orig(input: torch.Tensor, shape: List[torch.SymInt]) -> torch.Tensor:
return torch.ops.aten.view.default(input, shape)

# Replacement graph
def replacement(input: torch.Tensor, shape: List[torch.SymInt]) -> torch.Tensor:
return torch.ops.aten.reshape.default(input, shape)

return orig, replacement
68 changes: 67 additions & 1 deletion tests/py/dynamo/lowering/test_aten_lowering_passes.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import torch
import torch_tensorrt
from torch.testing._internal.common_utils import TestCase, run_tests

import torch_tensorrt

from ..testing_utilities import DECIMALS_OF_AGREEMENT, lower_graph_testing


Expand Down Expand Up @@ -375,5 +376,70 @@ def forward(self, input, weight, bias):
torch._dynamo.reset()


class TestLowerViewToReshape(TestCase):
def test_view_to_reshape(self):
class ViewToReshape(torch.nn.Module):
def forward(self, input):
out = torch.ops.aten.view.default(input, (1, 1, -1))
return out

inputs = [
torch.rand((3, 4, 5, 32)).cuda(),
]

fx_graph = torch.fx.symbolic_trace(ViewToReshape())
expected_ops = {torch.ops.aten.reshape.default}
unexpected_ops = {
torch.ops.aten.view.default,
}

unexpected_ops_seen, expected_ops_unseen = lower_graph_testing(
fx_graph,
inputs,
expected_ops=expected_ops,
unexpected_ops=unexpected_ops,
min_block_size=1,
)

self.assertEquals(
len(unexpected_ops_seen),
0,
f"The following unexpected ops were encountered: {unexpected_ops_seen}",
)

self.assertEquals(
len(expected_ops_unseen),
0,
f"The following expected ops were not encountered: {expected_ops_unseen}",
)
torch._dynamo.reset()

# Validate that the results between Torch and Torch-TRT are similar
optimized_model = torch_tensorrt.compile(
fx_graph,
"torch_compile",
inputs,
min_block_size=1,
pass_through_build_failures=True,
)
optimized_model_results = torch.cat(
[tensor.detach().cpu() for tensor in optimized_model(*inputs)]
)
torch_model_results = torch.cat(
[tensor.detach().cpu() for tensor in fx_graph(*inputs)]
)

max_diff = float(
torch.max(torch.abs(optimized_model_results - torch_model_results))
)
self.assertAlmostEqual(
max_diff,
0,
DECIMALS_OF_AGREEMENT,
msg=f"ViewToReshape TRT outputs don't match with the original model.",
)
torch._dynamo.reset()


if __name__ == "__main__":
run_tests()