From 10a384ea481be5512577d92517fffd1afb94396b Mon Sep 17 00:00:00 2001
From: apbose
Date: Tue, 9 Apr 2024 14:17:25 -0700
Subject: [PATCH 1/2] scatter_add_decomposition

Fixing scatter_add test cases. To do: fix the index collision cases.

Index collision cases: removing the torch.unique check.
---
 .../dynamo/lowering/_decompositions.py        | 31 ++++++++++
 .../py/dynamo/lowering/test_decompositions.py | 56 +++++++++++++++++++
 2 files changed, 87 insertions(+)

diff --git a/py/torch_tensorrt/dynamo/lowering/_decompositions.py b/py/torch_tensorrt/dynamo/lowering/_decompositions.py
index 8ec5a95da2..7d7d22031e 100644
--- a/py/torch_tensorrt/dynamo/lowering/_decompositions.py
+++ b/py/torch_tensorrt/dynamo/lowering/_decompositions.py
@@ -243,6 +243,37 @@ def empty_strided_decomposition(*args, **kwargs) -> torch.Tensor:
     )
 
 
+@register_torch_trt_decomposition(
+    torch.ops.aten.scatter_add.default, registry=TORCH_TRT_DECOMPOSITIONS
+)
+def scatter_add_decomposition(
+    input_tensor: torch.Tensor,
+    src_tensor: torch.Tensor,
+    dim: int,
+    index: torch.Tensor,
+) -> torch.Tensor:
+    scatter_add_tensor = input_tensor
+    src_copy = src_tensor
+    src_shape = list(src_tensor.shape)
+    del src_shape[dim]
+    select_src_dim = src_copy.shape[dim]
+    to_stack_dummy_src = tuple(torch.empty(src_shape) for _ in range(select_src_dim))
+    for index_src_dim in range(0, select_src_dim, 1):
+        select_tensor_dim = torch.select(src_copy, dim, index_src_dim)
+        to_stack_src = to_stack_dummy_src
+        if(index_src_dim == 0):
+            to_stack_src = (select_tensor_dim.cpu(),) + to_stack_dummy_src[index_src_dim+1:]
+        elif(index_src_dim == select_src_dim - 1 ):
+            to_stack_src = to_stack_dummy_src[:index_src_dim] + (select_tensor_dim.cpu(),)
+        else:
+            to_stack_src = to_stack_dummy_src[:index_src_dim] + (select_tensor_dim.cpu(),) + to_stack_dummy_src[index_src_dim+1:]
+
+        stacked_src = torch.stack(to_stack_src, dim)
+        input_tensor_to_add = torch.scatter(torch.empty_like(input_tensor, dtype= torch.float32), dim, index, stacked_src.cuda())
+        scatter_add_tensor = torch.add(scatter_add_tensor, input_tensor_to_add)
+    return scatter_add_tensor
+
+
 def get_decompositions(
     enable_experimental_decompositions: bool = False,
 ) -> Dict[OpOverload, Callable[[Any], Any]]:
diff --git a/tests/py/dynamo/lowering/test_decompositions.py b/tests/py/dynamo/lowering/test_decompositions.py
index edf7d04d44..422346a0f9 100644
--- a/tests/py/dynamo/lowering/test_decompositions.py
+++ b/tests/py/dynamo/lowering/test_decompositions.py
@@ -2,6 +2,7 @@
 import torch_tensorrt
 from parameterized import parameterized
 from torch.testing._internal.common_utils import TestCase, run_tests
+from parameterized import parameterized
 
 from ..testing_utilities import DECIMALS_OF_AGREEMENT, lower_graph_testing
 
@@ -963,5 +964,60 @@ def forward(self, input):
         )
 
 
+class TestScatterAdd(TestCase):
+    @parameterized.expand(
+        [
+            (
+                "scatter_add_zero_dim_indexOne_constant",
+                0,
+                torch.tensor([[0, 1, 2, 0]]),
+                torch.tensor([[1, 2, 3, 4]], dtype=torch.int32),
+            ),
+            (
+                "scatter_add_zero_dim_indexTwo_constant",
+                0,
+                torch.tensor([[0, 1, 2, 0], [1, 2, 1, 1]]),
+                torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8]], dtype=torch.int32),
+            ),
+            (
+                "scatter_add_one_dim_indexOne_constant",
+                1,
+                torch.tensor([[0, 1, 2, 0]]),
+                torch.tensor([[1, 2, 3, 1]], dtype=torch.int32),
+            ),
+            (
+                "scatter_add_one_dim_indexTwo_costant",
+                1,
+                torch.tensor([[0, 1, 2, 0], [1, 2, 1, 1]]),
+                torch.tensor([[1, 2, 3, 1], [5, 6, 5, 5]], dtype=torch.int32),
+            ),
+        ]
+    )
+    def test_scatter_add(self, _, dim, index, src):
+        class TestModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def forward(self, input):
+                return torch.ops.aten.scatter_add.default(input, dim, index, src)
+
+        # Operations expected to be included in the traced graph after decompositions
+        expected_ops = {torch.ops.aten.scatter.src}
+
+        input = torch.zeros(3, 5, dtype=torch.int32)
+        inputs = [input]
+
+        fx_graph = torch.fx.symbolic_trace(TestModule())
+        _, expected_ops_unseen = lower_graph_testing(
+            fx_graph, inputs, expected_ops=expected_ops, min_block_size=2
+        )
+
+        self.assertEquals(
+            len(expected_ops_unseen),
+            0,
+            f"The following expected ops were not encountered: {expected_ops_unseen}",
+        )
+
+
 if __name__ == "__main__":
     run_tests()

From 2f7221e04a82ac22984c1349008151fff4963088 Mon Sep 17 00:00:00 2001
From: apbose
Date: Thu, 11 Jul 2024 14:07:24 -0700
Subject: [PATCH 2/2] changing the implementation and adding more test cases

---
 .../dynamo/lowering/_decompositions.py        | 45 +++++----
 .../py/dynamo/lowering/test_decompositions.py | 91 +++++++++++++++----
 2 files changed, 100 insertions(+), 36 deletions(-)

diff --git a/py/torch_tensorrt/dynamo/lowering/_decompositions.py b/py/torch_tensorrt/dynamo/lowering/_decompositions.py
index 7d7d22031e..2729e38ff5 100644
--- a/py/torch_tensorrt/dynamo/lowering/_decompositions.py
+++ b/py/torch_tensorrt/dynamo/lowering/_decompositions.py
@@ -248,32 +248,39 @@ def empty_strided_decomposition(*args, **kwargs) -> torch.Tensor:
 )
 def scatter_add_decomposition(
     input_tensor: torch.Tensor,
-    src_tensor: torch.Tensor,
     dim: int,
     index: torch.Tensor,
+    src_tensor: torch.Tensor,
 ) -> torch.Tensor:
     scatter_add_tensor = input_tensor
-    src_copy = src_tensor
     src_shape = list(src_tensor.shape)
-    del src_shape[dim]
-    select_src_dim = src_copy.shape[dim]
-    to_stack_dummy_src = tuple(torch.empty(src_shape) for _ in range(select_src_dim))
-    for index_src_dim in range(0, select_src_dim, 1):
-        select_tensor_dim = torch.select(src_copy, dim, index_src_dim)
-        to_stack_src = to_stack_dummy_src
-        if(index_src_dim == 0):
-            to_stack_src = (select_tensor_dim.cpu(),) + to_stack_dummy_src[index_src_dim+1:]
-        elif(index_src_dim == select_src_dim - 1 ):
-            to_stack_src = to_stack_dummy_src[:index_src_dim] + (select_tensor_dim.cpu(),)
-        else:
-            to_stack_src = to_stack_dummy_src[:index_src_dim] + (select_tensor_dim.cpu(),) + to_stack_dummy_src[index_src_dim+1:]
-
-        stacked_src = torch.stack(to_stack_src, dim)
-        input_tensor_to_add = torch.scatter(torch.empty_like(input_tensor, dtype= torch.float32), dim, index, stacked_src.cuda())
-        scatter_add_tensor = torch.add(scatter_add_tensor, input_tensor_to_add)
+    src_dim = src_shape[dim]
+    for i in range(0, src_dim):
+        to_scatter_tensor = torch.zeros_like(input_tensor)
+
+        # index and src slice
+        src_slice = torch.select(src_tensor, dim, i)
+        index_slice = torch.select(index, dim, i)
+
+        # unsqueeze src and index in dim
+        src_slice = torch.unsqueeze(src_slice, dim)
+        index_slice = torch.unsqueeze(index_slice, dim)
+
+        # moving tensor to default device
+        device = to_torch_device(default_device())
+        scatter_add_tensor = scatter_add_tensor.to(device)
+        to_scatter_tensor = to_scatter_tensor.to(device)
+        index_slice = index_slice.to(device)
+        src_slice = src_slice.to(device)
+
+        scatter_add_tensor = torch.add(
+            scatter_add_tensor,
+            torch.scatter(to_scatter_tensor, dim, index_slice, src_slice),
+        )
+
     return scatter_add_tensor
-
+
 def get_decompositions(
     enable_experimental_decompositions: bool = False,
 ) -> Dict[OpOverload, Callable[[Any], Any]]:
diff --git a/tests/py/dynamo/lowering/test_decompositions.py b/tests/py/dynamo/lowering/test_decompositions.py
index 422346a0f9..a1416c00db 100644
--- a/tests/py/dynamo/lowering/test_decompositions.py
+++ b/tests/py/dynamo/lowering/test_decompositions.py
@@ -2,7 +2,6 @@
 import torch_tensorrt
 from parameterized import parameterized
 from torch.testing._internal.common_utils import TestCase, run_tests
-from parameterized import parameterized
 
 from ..testing_utilities import DECIMALS_OF_AGREEMENT, lower_graph_testing
 
@@ -963,37 +962,60 @@ def forward(self, input):
             f"The optimized model results shape and torch model results shape should be equal in empty_stride",
         )
 
-
-class TestScatterAdd(TestCase):
     @parameterized.expand(
         [
             (
                 "scatter_add_zero_dim_indexOne_constant",
                 0,
-                torch.tensor([[0, 1, 2, 0]]),
-                torch.tensor([[1, 2, 3, 4]], dtype=torch.int32),
+                torch.tensor([[0, 1, 2, 0]]).cuda(),
+                torch.tensor([[1, 2, 3, 4]], dtype=torch.int32).cuda(),
+                {torch.ops.aten.add.Tensor},
             ),
             (
                 "scatter_add_zero_dim_indexTwo_constant",
                 0,
-                torch.tensor([[0, 1, 2, 0], [1, 2, 1, 1]]),
-                torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8]], dtype=torch.int32),
+                torch.tensor([[0, 1, 2, 0], [1, 2, 1, 1]]).cuda(),
+                torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8]], dtype=torch.int32).cuda(),
+                {torch.ops.aten.add.Tensor, torch.ops.aten.scatter.src},
             ),
             (
                 "scatter_add_one_dim_indexOne_constant",
                 1,
-                torch.tensor([[0, 1, 2, 0]]),
-                torch.tensor([[1, 2, 3, 1]], dtype=torch.int32),
+                torch.tensor([[0, 1, 2, 0]]).cuda(),
+                torch.tensor([[1, 2, 3, 1]], dtype=torch.int32).cuda(),
+                {
+                    torch.ops.aten.add.Tensor,
+                    torch.ops.aten.scatter.src,
+                    torch.ops.aten.full_like.default,
+                },
+            ),
+            (
+                "scatter_add_one_dim_indexTwo_constant",
+                1,
+                torch.tensor([[0, 1, 2, 0], [1, 2, 1, 1]]).cuda(),
+                torch.tensor([[1, 2, 3, 1], [5, 6, 5, 5]], dtype=torch.int32).cuda(),
+                {
+                    torch.ops.aten.add.Tensor,
+                    torch.ops.aten.scatter.src,
+                    torch.ops.aten.full_like.default,
+                },
             ),
             (
-                "scatter_add_one_dim_indexTwo_costant",
+                "scatter_add_one_dim_indexThree_constant",
                 1,
-                torch.tensor([[0, 1, 2, 0], [1, 2, 1, 1]]),
-                torch.tensor([[1, 2, 3, 1], [5, 6, 5, 5]], dtype=torch.int32),
+                torch.tensor([[0, 1, 2, 0], [1, 2, 1, 1], [3, 2, 1, 2]]).cuda(),
+                torch.tensor(
+                    [[1, 2, 3, 1], [5, 6, 5, 5], [2, 4, 3, 2]], dtype=torch.int32
+                ).cuda(),
+                {
+                    torch.ops.aten.add.Tensor,
+                    torch.ops.aten.scatter.src,
+                    torch.ops.aten.full_like.default,
+                },
             ),
         ]
     )
-    def test_scatter_add(self, _, dim, index, src):
+    def test_scatter_add(self, _, dim, index, src, expected_ops_param):
         class TestModule(torch.nn.Module):
             def __init__(self):
                 super().__init__()
@@ -1002,14 +1024,19 @@ def forward(self, input):
                 return torch.ops.aten.scatter_add.default(input, dim, index, src)
 
         # Operations expected to be included in the traced graph after decompositions
-        expected_ops = {torch.ops.aten.scatter.src}
+        expected_ops = expected_ops_param
+        unexpected_ops = {torch.ops.aten.scatter_add.default}
 
-        input = torch.zeros(3, 5, dtype=torch.int32)
+        input = torch.zeros(3, 5, dtype=torch.int32).cuda()
         inputs = [input]
 
         fx_graph = torch.fx.symbolic_trace(TestModule())
-        _, expected_ops_unseen = lower_graph_testing(
-            fx_graph, inputs, expected_ops=expected_ops, min_block_size=2
+        unexpected_ops_seen, expected_ops_unseen = lower_graph_testing(
+            fx_graph,
+            inputs,
+            expected_ops=expected_ops,
+            unexpected_ops=unexpected_ops,
+            min_block_size=2,
         )
 
         self.assertEquals(
@@ -1018,6 +1045,36 @@ def forward(self, input):
             f"The following expected ops were not encountered: {expected_ops_unseen}",
         )
 
+        self.assertEquals(
+            len(unexpected_ops_seen),
+            0,
+            f"The following unexpected ops were encountered: {unexpected_ops_seen}",
+        )
+
+        torch._dynamo.reset()
+
+        # Validate that the results between Torch and Torch-TRT are similar
+        optimized_model = torch_tensorrt.compile(
+            fx_graph,
+            "torch_compile",
+            inputs,
+            min_block_size=1,
+            truncate_double=True,
+            pass_through_build_failures=True,
+        )
+        optimized_model_results = optimized_model(*inputs).detach().cpu()
+        torch_model_results = fx_graph(*inputs).detach().cpu()
+
+        max_diff = float(
+            torch.max(torch.abs(optimized_model_results - torch_model_results))
+        )
+        self.assertAlmostEqual(
+            max_diff,
+            0,
+            DECIMALS_OF_AGREEMENT,
+            f"Scatter_add TRT outputs don't match the original model.",
+        )
+
 
 if __name__ == "__main__":
     run_tests()
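
Note on the PATCH 2/2 approach: scatter_add cannot be lowered to a single
scatter, because scatter overwrites colliding indices instead of summing them.
The decomposition therefore walks src one slice at a time along dim; within a
single slice every target position is written at most once, so each iteration
can safely torch.scatter into a zero tensor and accumulate the partial results
with torch.add. The sketch below is a minimal standalone illustration of that
loop, not the Torch-TensorRT code itself: it runs in plain PyTorch, and the
explicit device argument is an assumption standing in for the
to_torch_device(default_device()) call used inside the decomposition.

import torch

def scatter_add_sketch(
    input_tensor: torch.Tensor,
    dim: int,
    index: torch.Tensor,
    src: torch.Tensor,
    device: str = "cpu",  # assumption: stands in for Torch-TensorRT's default device helper
) -> torch.Tensor:
    out = input_tensor.to(device)
    for i in range(src.shape[dim]):
        # take slice i of src/index along dim, then restore the collapsed dim
        src_slice = torch.select(src, dim, i).unsqueeze(dim).to(device)
        index_slice = torch.select(index, dim, i).unsqueeze(dim).to(device)
        # within one slice no target position repeats, so a plain scatter into
        # zeros is collision-free; summing across slices accumulates the
        # collisions that scatter_add must handle
        zeros = torch.zeros_like(out)
        out = out + torch.scatter(zeros, dim, index_slice, src_slice)
    return out

# quick check against the reference op, reusing a test case from the patch
inp = torch.zeros(3, 5, dtype=torch.int32)
index = torch.tensor([[0, 1, 2, 0], [1, 2, 1, 1]])
src = torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8]], dtype=torch.int32)
assert torch.equal(
    scatter_add_sketch(inp, 0, index, src),
    torch.scatter_add(inp, 0, index, src),
)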