From f1be9fe059ceced8652183d8951058b9be752aa4 Mon Sep 17 00:00:00 2001
From: apbose
Date: Mon, 15 Jul 2024 14:15:30 -0700
Subject: [PATCH 1/5] scatter reduce decomposition

---
 .../dynamo/conversion/impl/elementwise/ops.py |   7 +
 .../dynamo/lowering/_decompositions.py        |  94 ++++
 .../py/dynamo/lowering/test_decompositions.py | 407 ++++++++++++++++++
 3 files changed, 508 insertions(+)

diff --git a/py/torch_tensorrt/dynamo/conversion/impl/elementwise/ops.py b/py/torch_tensorrt/dynamo/conversion/impl/elementwise/ops.py
index 6a6b4ea3a1..3f8d9667b3 100644
--- a/py/torch_tensorrt/dynamo/conversion/impl/elementwise/ops.py
+++ b/py/torch_tensorrt/dynamo/conversion/impl/elementwise/ops.py
@@ -1,6 +1,7 @@
 from typing import Optional, Union
 
 import numpy as np
+import tensorrt as trt
 import torch
 import torch_tensorrt.dynamo.conversion.impl as impl
 from torch.fx.node import Target
@@ -17,6 +18,7 @@ from torch_tensorrt.dynamo.conversion.impl.elementwise.base import (
     convert_binary_elementwise,
 )
+from torch_tensorrt.dynamo.conversion.impl.shape import get_shape_with_dynamic_shape
 from torch_tensorrt.dynamo.conversion.impl.unary import atan, sign
 from torch_tensorrt.dynamo.conversion.impl.unary.base import convert_unary
 from torch_tensorrt.fx.converters.converter_utils import broadcast
@@ -67,6 +69,11 @@ def trunc_div(
         prod_output,
     )
 
+    # Cast sign_output back to int32 for trunc_div. This is required for
+    # scatter_reduce_.two(reduce="mean"), where trunc_div yields float32 but the TRTInterpreter expects int32.
+    if (isinstance(sign_output, TRTTensor)) and (sign_output.dtype == trt.float32):
+        sign_output = cast_trt_tensor(ctx, sign_output, trt.int32, name)
+
     # Convert constant input into ITensor for UnaryOperation
     if not isinstance(input, trt.tensorrt.ITensor):
         input = get_trt_tensor(ctx, input, f"{name}_input")

diff --git a/py/torch_tensorrt/dynamo/lowering/_decompositions.py b/py/torch_tensorrt/dynamo/lowering/_decompositions.py
index f86e3c5cb5..9397b7c762 100644
--- a/py/torch_tensorrt/dynamo/lowering/_decompositions.py
+++ b/py/torch_tensorrt/dynamo/lowering/_decompositions.py
@@ -1,4 +1,5 @@
 import logging
+from enum import Enum, auto
 from typing import Any, Callable, Dict, List, Optional
 
 import torch
@@ -287,6 +288,99 @@ def scatter_add_decomposition(
     return scatter_add_tensor
 
 
+# enum class for the reduce operations of scatter_reduce
+class ReduceOperation(Enum):
+    SUM = ("Sum reduce operation", lambda x, y: torch.add(x, y))
+    PROD = ("Product reduce operation", lambda x, y: torch.mul(x, y))
+    MEAN = ("Mean reduce operation", lambda x, y: torch.add(x, y))
+    AMAX = ("Amax reduce operation", lambda x, y: torch.max(x, y))
+    AMIN = ("Amin reduce operation", lambda x, y: torch.min(x, y))
+
+    def __new__(cls, description, func):
+        obj = object.__new__(cls)
+        obj._value_ = auto()
+        obj.description = description
+        obj.func = func
+        return obj
+
+    def reduce_operation_with_scatter(
+        self, operation_lhs, initial_tensor, dim, index_tensor, src_tensor
+    ):
+        scatter_tensor = None
+        if self == ReduceOperation.SUM or self == ReduceOperation.MEAN:
+            scatter_tensor = torch.zeros_like(initial_tensor)
+        elif self == ReduceOperation.PROD:
+            scatter_tensor = torch.ones_like(initial_tensor)
+        elif self == ReduceOperation.AMIN or self == ReduceOperation.AMAX:
+            scatter_tensor = initial_tensor
+        else:
+            # This case would not be encountered from torch itself
+            print("Invalid Operation for Reduce op!!")
+
+        operation_rhs = torch.scatter(scatter_tensor, dim, index_tensor, src_tensor)
+        device = to_torch_device(default_device())
+        operation_lhs = operation_lhs.to(device)
+        operation_rhs = operation_rhs.to(device)
+        return self.func(operation_lhs, operation_rhs)
+
+
+@register_torch_trt_decomposition(
+    torch.ops.aten.scatter_reduce.two, registry=TORCH_TRT_DECOMPOSITIONS
+)
+def scatter_reduce_decomposition(
+    input_tensor: torch.Tensor,
+    dim: int,
+    index: torch.Tensor,
+    src_tensor: torch.Tensor,
+    reduce: str,
+) -> torch.Tensor:
+    scatter_loop_tensor = input_tensor
+    # required for the mean reduce operation
+    scatter_count_tensor = torch.zeros_like(input_tensor)
+    src_shape = list(src_tensor.shape)
+    src_dim = src_shape[dim]
+
+    for i in range(0, src_dim):
+        src_slice = torch.select(src_tensor, dim, i)
+        index_slice = torch.select(index, dim, i)
+        # unsqueeze src and index in dim
+        src_slice = torch.unsqueeze(src_slice, dim)
+        index_slice = torch.unsqueeze(index_slice, dim)
+        device = to_torch_device(default_device())
+
+        # move tensors to the target device
+        scatter_loop_tensor = scatter_loop_tensor.to(device)
+        index_slice = index_slice.to(device)
+        src_slice = src_slice.to(device)
+        if reduce == "sum":
+            reduceOp = ReduceOperation.SUM
+        elif reduce == "prod":
+            reduceOp = ReduceOperation.PROD
+        elif reduce == "mean":
+            reduceOp = ReduceOperation.MEAN
+            scatter_count_tensor = reduceOp.reduce_operation_with_scatter(
+                scatter_count_tensor,
+                input_tensor,
+                dim,
+                index_slice,
+                torch.ones_like(src_slice),
+            )
+        elif reduce == "amax":
+            reduceOp = ReduceOperation.AMAX
+        elif reduce == "amin":
+            reduceOp = ReduceOperation.AMIN
+        scatter_loop_tensor = reduceOp.reduce_operation_with_scatter(
+            scatter_loop_tensor, input_tensor, dim, index_slice, src_slice
+        )
+    if reduce == "mean":
+        scatter_loop_tensor = torch.div(
+            scatter_loop_tensor,
+            torch.add(scatter_count_tensor, torch.ones_like(scatter_count_tensor)),
+            rounding_mode="trunc",
+        )
+    return scatter_loop_tensor
+
+
 def get_decompositions(
     enable_experimental_decompositions: bool = False,
 ) -> Dict[OpOverload, Callable[[Any], Any]]:
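Note on the decomposition above: it rebuilds `scatter_reduce` as a loop over slices of `src` along `dim`, scattering each slice into a neutral tensor (zeros for sum/mean, ones for prod, the running result for amax/amin) and folding it into the accumulator with the enum's binary op. A minimal eager-mode sketch of the same idea, checked against the native op, follows; it is a standalone illustration under our own names, not part of the patch:

```python
import torch

def scatter_reduce_loop(inp, dim, index, src, reduce):
    # Mirror of the decomposition: fold one slice of src at a time.
    out = inp.clone()
    count = torch.zeros_like(inp)  # contribution counter, used for "mean"
    for i in range(src.shape[dim]):
        src_slice = src.select(dim, i).unsqueeze(dim)
        idx_slice = index.select(dim, i).unsqueeze(dim)
        if reduce in ("sum", "mean"):
            out = out + torch.zeros_like(inp).scatter(dim, idx_slice, src_slice)
            if reduce == "mean":
                count = count + torch.zeros_like(inp).scatter(
                    dim, idx_slice, torch.ones_like(src_slice)
                )
        elif reduce == "prod":
            out = out * torch.ones_like(inp).scatter(dim, idx_slice, src_slice)
        elif reduce == "amax":
            out = torch.max(out, inp.scatter(dim, idx_slice, src_slice))
        elif reduce == "amin":
            out = torch.min(out, inp.scatter(dim, idx_slice, src_slice))
    if reduce == "mean":
        # include_self=True counts the original element once, hence count + 1
        out = torch.div(out, count + 1, rounding_mode="trunc")
    return out

inp = torch.zeros(3, 5, dtype=torch.int32)
index = torch.tensor([[0, 1, 2, 0]])
src = torch.tensor([[1, 2, 3, 4]], dtype=torch.int32)
assert torch.equal(
    scatter_reduce_loop(inp, 0, index, src, "sum"),
    inp.scatter_reduce(0, index, src, reduce="sum"),
)
```

As in the decomposition, indices that collide within a single slice follow plain `scatter` overwrite semantics, and the mean path only models `include_self=True`.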
diff --git a/tests/py/dynamo/lowering/test_decompositions.py b/tests/py/dynamo/lowering/test_decompositions.py
index 2f06aa7d23..6568dab202 100644
--- a/tests/py/dynamo/lowering/test_decompositions.py
+++ b/tests/py/dynamo/lowering/test_decompositions.py
@@ -1129,6 +1129,413 @@ def forward(self, input):
             f"Scatter_add TRT outputs don't match with the original model.",
         )
 
+    @parameterized.expand(
+        [
+            ############################sum###########################
+            (
+                "scatter_reduce_add_zero_dim_indexOne_constant",
+                0,
+                torch.tensor([[0, 1, 2, 0]]).cuda(),
+                torch.tensor([[1, 2, 3, 4]], dtype=torch.int32).cuda(),
+                {torch.ops.aten.add.Tensor},
+                torch.zeros(3, 5, dtype=torch.int32).cuda(),
+                "sum",
+            ),
+            (
+                "scatter_reduce_add_zero_dim_indexTwo_constant",
+                0,
+                torch.tensor([[0, 1, 2, 0], [1, 2, 1, 1]]).cuda(),
+                torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8]], dtype=torch.int32).cuda(),
+                {torch.ops.aten.add.Tensor, torch.ops.aten.scatter.src},
+                torch.zeros(3, 5, dtype=torch.int32).cuda(),
+                "sum",
+            ),
+            (
+                "scatter_reduce_add_one_dim_indexOne_constant",
+                1,
+                torch.tensor([[0, 1, 2, 0]]).cuda(),
+                torch.tensor([[1, 2, 3, 1]], dtype=torch.int32).cuda(),
+                {
+                    torch.ops.aten.add.Tensor,
+                    torch.ops.aten.scatter.src,
+                    torch.ops.aten.full_like.default,
+                },
+                torch.zeros(3, 5, dtype=torch.int32).cuda(),
+                "sum",
+            ),
+            (
+                "scatter_reduce_add_one_dim_indexTwo_constant",
+                1,
+                torch.tensor([[0, 1, 2, 0], [1, 2, 1, 1]]).cuda(),
+                torch.tensor([[1, 2, 3, 1], [5, 6, 5, 5]], dtype=torch.int32).cuda(),
+                {
+                    torch.ops.aten.add.Tensor,
+                    torch.ops.aten.scatter.src,
+                    torch.ops.aten.full_like.default,
+                },
+                torch.zeros(3, 5, dtype=torch.int32).cuda(),
+                "sum",
+            ),
+            (
+                "scatter_reduce_add_one_dim_indexOne_constant_3D",
+                1,
+                torch.tensor(
+                    [[[0, 1, 2, 0], [1, 2, 1, 1]], [[3, 2, 1, 2], [0, 1, 2, 0]]]
+                ).cuda(),
+                torch.tensor(
+                    [[[1, 2, 3, 1], [5, 6, 5, 5]], [[2, 4, 3, 2], [1, 2, 3, 1]]],
+                    dtype=torch.int32,
+                ).cuda(),
+                {
+                    torch.ops.aten.add.Tensor,
+                    torch.ops.aten.scatter.src,
+                    torch.ops.aten.full_like.default,
+                },
+                torch.zeros(3, 5, 6, dtype=torch.int32).cuda(),
+                "sum",
+            ),
+            ###########################prod###########################
+            (
+                "scatter_reduce_prod_zero_dim_indexOne_constant",
+                0,
+                torch.tensor([[0, 1, 2, 0]]).cuda(),
+                torch.tensor([[1, 2, 3, 4]], dtype=torch.int32).cuda(),
+                {
+                    torch.ops.aten.mul.Tensor,
+                    torch.ops.aten.scatter.src,
+                    torch.ops.aten.full_like.default,
+                },
+                torch.ones(3, 5, dtype=torch.int32).cuda(),
+                "prod",
+            ),
+            (
+                "scatter_reduce_prod_zero_dim_indexTwo_constant",
+                0,
+                torch.tensor([[0, 1, 2, 0], [1, 2, 1, 1]]).cuda(),
+                torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8]], dtype=torch.int32).cuda(),
+                {
+                    torch.ops.aten.mul.Tensor,
+                    torch.ops.aten.scatter.src,
+                    torch.ops.aten.full_like.default,
+                },
+                torch.ones(3, 5, dtype=torch.int32).cuda(),
+                "prod",
+            ),
+            (
+                "scatter_reduce_prod_one_dim_indexOne_constant",
+                1,
+                torch.tensor([[0, 1, 2, 0]]).cuda(),
+                torch.tensor([[1, 2, 3, 1]], dtype=torch.int32).cuda(),
+                {
+                    torch.ops.aten.mul.Tensor,
+                    torch.ops.aten.scatter.src,
+                    torch.ops.aten.full_like.default,
+                },
+                torch.ones(3, 5, dtype=torch.int32).cuda(),
+                "prod",
+            ),
+            (
+                "scatter_reduce_prod_one_dim_indexTwo_constant",
+                1,
+                torch.tensor([[0, 1, 2, 0], [1, 2, 1, 1]]).cuda(),
+                torch.tensor([[1, 2, 3, 1], [5, 6, 5, 5]], dtype=torch.int32).cuda(),
+                {
+                    torch.ops.aten.mul.Tensor,
+                    torch.ops.aten.scatter.src,
+                    torch.ops.aten.full_like.default,
+                },
+                torch.ones(3, 5, dtype=torch.int32).cuda(),
+                "prod",
+            ),
+            (
+                "scatter_reduce_prod_one_dim_indexTwo_constant_3D",
+                1,
+                torch.tensor(
+                    [[[0, 1, 2, 0], [1, 2, 1, 1]], [[3, 2, 1, 2], [0, 1, 2, 0]]]
+                ).cuda(),
+                torch.tensor(
+                    [[[1, 2, 3, 1], [5, 6, 5, 5]], [[2, 4, 3, 2], [1, 2, 3, 1]]],
+                    dtype=torch.int32,
+                ).cuda(),
+                {
+                    torch.ops.aten.mul.Tensor,
+                    torch.ops.aten.scatter.src,
+                    torch.ops.aten.full_like.default,
+                },
+                torch.ones(3, 5, 6, dtype=torch.int32).cuda(),
+                "prod",
+            ),
+            ###########################mean###########################
+            (
+                "scatter_reduce_mean_zero_dim_indexOne_constant",
+                0,
+                torch.tensor([[0, 1, 2, 0]]).cuda(),
+                torch.tensor([[1, 2, 3, 4]], dtype=torch.int32).cuda(),
+                {
+                    torch.ops.aten.add.Tensor,
+                    torch.ops.aten.div.Tensor_mode,
+                },
+                torch.zeros(3, 5, dtype=torch.int32).cuda(),
+                "mean",
+            ),
+            (
+                "scatter_reduce_mean_zero_dim_indexTwo_constant",
+                0,
+                torch.tensor([[0, 1, 2, 0], [1, 2, 1, 1]]).cuda(),
+                torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8]], dtype=torch.int32).cuda(),
+                {
+                    torch.ops.aten.add.Tensor,
+                    torch.ops.aten.div.Tensor_mode,
+                    torch.ops.aten.scatter.src,
+                },
+                torch.zeros(3, 5, dtype=torch.int32).cuda(),
+                "mean",
+            ),
+            (
+                "scatter_reduce_mean_one_dim_indexOne_constant",
+                1,
+                torch.tensor([[0, 1, 2, 0]]).cuda(),
+                torch.tensor([[1, 2, 3, 1]], dtype=torch.int32).cuda(),
+                {
+                    torch.ops.aten.add.Tensor,
+                    torch.ops.aten.div.Tensor_mode,
+                    torch.ops.aten.scatter.src,
+                    torch.ops.aten.full_like.default,
+                },
+                torch.zeros(3, 5, dtype=torch.int32).cuda(),
+                "mean",
+            ),
+            (
+                "scatter_reduce_mean_one_dim_indexTwo_constant",
+                1,
+                torch.tensor([[0, 1, 2, 0], [1, 2, 1, 1]]).cuda(),
+                torch.tensor([[1, 2, 3, 1], [5, 6, 5, 5]], dtype=torch.int32).cuda(),
+                {
+                    torch.ops.aten.add.Tensor,
+                    torch.ops.aten.div.Tensor_mode,
+                    torch.ops.aten.scatter.src,
+                    torch.ops.aten.full_like.default,
+                },
+                torch.zeros(3, 5, dtype=torch.int32).cuda(),
+                "mean",
+            ),
+            (
+                "scatter_reduce_mean_one_dim_indexTwo_constant_3D",
+                1,
+                torch.tensor(
+                    [[[0, 1, 2, 0], [1, 2, 1, 1]], [[3, 2, 1, 2], [0, 1, 2, 0]]]
+                ).cuda(),
+                torch.tensor(
+                    [[[1, 2, 3, 1], [5, 6, 5, 5]], [[2, 4, 3, 2], [1, 2, 3, 1]]],
+                    dtype=torch.int32,
+                ).cuda(),
+                {
+                    torch.ops.aten.add.Tensor,
+                    torch.ops.aten.div.Tensor_mode,
+                    torch.ops.aten.scatter.src,
+                    torch.ops.aten.full_like.default,
+                },
+                torch.zeros(3, 5, 6, dtype=torch.int32).cuda(),
+                "mean",
+            ),
+            ###########################amax###########################
+            (
+                "scatter_reduce_amax_zero_dim_indexOne_constant",
+                0,
+                torch.tensor([[0, 1, 2, 0]]).cuda(),
+                torch.tensor([[1, 2, 3, 4]], dtype=torch.int32).cuda(),
+                {
+                    torch.ops.aten.maximum.default,
+                    torch.ops.aten.scatter.src,
+                },
+                torch.zeros(3, 5, dtype=torch.int32).cuda(),
+                "amax",
+            ),
+            (
+                "scatter_reduce_amax_zero_dim_indexTwo_constant",
+                0,
+                torch.tensor([[0, 1, 2, 0], [1, 2, 1, 1]]).cuda(),
+                torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8]], dtype=torch.int32).cuda(),
+                {
+                    torch.ops.aten.maximum.default,
+                    torch.ops.aten.scatter.src,
+                },
+                torch.zeros(3, 5, dtype=torch.int32).cuda(),
+                "amax",
+            ),
+            (
+                "scatter_reduce_amax_one_dim_indexOne_constant",
+                1,
+                torch.tensor([[0, 1, 2, 0]]).cuda(),
+                torch.tensor([[1, 2, 3, 1]], dtype=torch.int32).cuda(),
+                {
+                    torch.ops.aten.maximum.default,
+                    torch.ops.aten.scatter.src,
+                },
+                torch.zeros(3, 5, dtype=torch.int32).cuda(),
+                "amax",
+            ),
+            (
+                "scatter_reduce_amax_one_dim_indexTwo_constant",
+                1,
+                torch.tensor([[0, 1, 2, 0], [1, 2, 1, 1]]).cuda(),
+                torch.tensor([[1, 2, 3, 1], [5, 6, 5, 5]], dtype=torch.int32).cuda(),
+                {
+                    torch.ops.aten.maximum.default,
+                    torch.ops.aten.scatter.src,
+                },
+                torch.zeros(3, 5, dtype=torch.int32).cuda(),
+                "amax",
+            ),
+            (
+                "scatter_reduce_amax_one_dim_indexTwo_constant_3D",
+                1,
+                torch.tensor(
+                    [[[0, 1, 2, 0], [1, 2, 1, 1]], [[3, 2, 1, 2], [0, 1, 2, 0]]]
+                ).cuda(),
+                torch.tensor(
+                    [[[1, 2, 3, 1], [5, 6, 5, 5]], [[2, 4, 3, 2], [1, 2, 3, 1]]],
+                    dtype=torch.int32,
+                ).cuda(),
+                {
+                    torch.ops.aten.maximum.default,
+                    torch.ops.aten.scatter.src,
+                },
+                torch.zeros(3, 5, 6, dtype=torch.int32).cuda(),
+                "amax",
+            ),
+            ###########################amin###########################
+            (
+                "scatter_reduce_amin_zero_dim_indexOne_constant",
+                0,
+                torch.tensor([[0, 1, 2, 0]]).cuda(),
+                torch.tensor([[1, 2, 3, 4]], dtype=torch.int32).cuda(),
+                {
+                    torch.ops.aten.minimum.default,
+                    torch.ops.aten.scatter.src,
+                },
+                torch.zeros(3, 5, dtype=torch.int32).cuda(),
+                "amin",
+            ),
+            (
+                "scatter_reduce_amin_zero_dim_indexTwo_constant",
+                0,
+                torch.tensor([[0, 1, 2, 0], [1, 2, 1, 1]]).cuda(),
+                torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8]], dtype=torch.int32).cuda(),
+                {
+                    torch.ops.aten.minimum.default,
+                    torch.ops.aten.scatter.src,
+                },
+                torch.zeros(3, 5, dtype=torch.int32).cuda(),
+                "amin",
+            ),
+            (
+                "scatter_reduce_amin_one_dim_indexOne_constant",
+                1,
+                torch.tensor([[0, 1, 2, 0]]).cuda(),
+                torch.tensor([[1, 2, 3, 1]], dtype=torch.int32).cuda(),
+                {
+                    torch.ops.aten.minimum.default,
+                    torch.ops.aten.scatter.src,
+                },
+                torch.zeros(3, 5, dtype=torch.int32).cuda(),
+                "amin",
+            ),
"scatter_reduce_mean_one_dim_indexTwo_constant", + 1, + torch.tensor([[0, 1, 2, 0], [1, 2, 1, 1]]).cuda(), + torch.tensor([[1, 2, 3, 1], [5, 6, 5, 5]], dtype=torch.int32).cuda(), + { + torch.ops.aten.add.Tensor, + torch.ops.aten.div.Tensor_mode, + torch.ops.aten.scatter.src, + torch.ops.aten.full_like.default, + }, + torch.zeros(3, 5, dtype=torch.int32).cuda(), + "mean", + ), + ( + "scatter_reduce_mean_one_dim_indexTwo_constant_3D", + 1, + torch.tensor( + [[[0, 1, 2, 0], [1, 2, 1, 1]], [[3, 2, 1, 2], [0, 1, 2, 0]]] + ).cuda(), + torch.tensor( + [[[1, 2, 3, 1], [5, 6, 5, 5]], [[2, 4, 3, 2], [1, 2, 3, 1]]], + dtype=torch.int32, + ).cuda(), + { + torch.ops.aten.add.Tensor, + torch.ops.aten.div.Tensor_mode, + torch.ops.aten.scatter.src, + torch.ops.aten.full_like.default, + }, + torch.zeros(3, 5, 6, dtype=torch.int32).cuda(), + "mean", + ), + # #############################amax########################### + ( + "scatter_reduce_amax_zero_dim_indexOne_constant", + 0, + torch.tensor([[0, 1, 2, 0]]).cuda(), + torch.tensor([[1, 2, 3, 4]], dtype=torch.int32).cuda(), + { + torch.ops.aten.maximum.default, + torch.ops.aten.scatter.src, + }, + torch.zeros(3, 5, dtype=torch.int32).cuda(), + "amax", + ), + ( + "scatter_reduce_amax_zero_dim_indexTwo_constant", + 0, + torch.tensor([[0, 1, 2, 0], [1, 2, 1, 1]]).cuda(), + torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8]], dtype=torch.int32).cuda(), + { + torch.ops.aten.maximum.default, + torch.ops.aten.scatter.src, + }, + torch.zeros(3, 5, dtype=torch.int32).cuda(), + "amax", + ), + ( + "scatter_reduce_amax_one_dim_indexOne_constant", + 1, + torch.tensor([[0, 1, 2, 0]]).cuda(), + torch.tensor([[1, 2, 3, 1]], dtype=torch.int32).cuda(), + { + torch.ops.aten.maximum.default, + torch.ops.aten.scatter.src, + }, + torch.zeros(3, 5, dtype=torch.int32).cuda(), + "amax", + ), + ( + "scatter_reduce_amax_one_dim_indexTwo_constant", + 1, + torch.tensor([[0, 1, 2, 0], [1, 2, 1, 1]]).cuda(), + torch.tensor([[1, 2, 3, 1], [5, 6, 5, 5]], dtype=torch.int32).cuda(), + { + torch.ops.aten.maximum.default, + torch.ops.aten.scatter.src, + }, + torch.zeros(3, 5, dtype=torch.int32).cuda(), + "amax", + ), + ( + "scatter_reduce_amax_one_dim_indexTwo_constant_3D", + 1, + torch.tensor( + [[[0, 1, 2, 0], [1, 2, 1, 1]], [[3, 2, 1, 2], [0, 1, 2, 0]]] + ).cuda(), + torch.tensor( + [[[1, 2, 3, 1], [5, 6, 5, 5]], [[2, 4, 3, 2], [1, 2, 3, 1]]], + dtype=torch.int32, + ).cuda(), + { + torch.ops.aten.maximum.default, + torch.ops.aten.scatter.src, + }, + torch.zeros(3, 5, 6, dtype=torch.int32).cuda(), + "amax", + ), + # #############################amin########################### + ( + "scatter_reduce_amin_zero_dim_indexOne_constant", + 0, + torch.tensor([[0, 1, 2, 0]]).cuda(), + torch.tensor([[1, 2, 3, 4]], dtype=torch.int32).cuda(), + { + torch.ops.aten.minimum.default, + torch.ops.aten.scatter.src, + }, + torch.zeros(3, 5, dtype=torch.int32).cuda(), + "amin", + ), + ( + "scatter_reduce_amin_zero_dim_indexTwo_constant", + 0, + torch.tensor([[0, 1, 2, 0], [1, 2, 1, 1]]).cuda(), + torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8]], dtype=torch.int32).cuda(), + { + torch.ops.aten.minimum.default, + torch.ops.aten.scatter.src, + }, + torch.zeros(3, 5, dtype=torch.int32).cuda(), + "amin", + ), + ( + "scatter_reduce_amin_one_dim_indexOne_constant", + 1, + torch.tensor([[0, 1, 2, 0]]).cuda(), + torch.tensor([[1, 2, 3, 1]], dtype=torch.int32).cuda(), + { + torch.ops.aten.minimum.default, + torch.ops.aten.scatter.src, + }, + torch.zeros(3, 5, dtype=torch.int32).cuda(), + "amin", + ), + ( + 
"scatter_reduce_amin_one_dim_indexTwo_constant", + 1, + torch.tensor([[0, 1, 2, 0], [1, 2, 1, 1]]).cuda(), + torch.tensor([[1, 2, 3, 1], [5, 6, 5, 5]], dtype=torch.int32).cuda(), + { + torch.ops.aten.minimum.default, + torch.ops.aten.scatter.src, + }, + torch.zeros(3, 5, dtype=torch.int32).cuda(), + "amin", + ), + ( + "scatter_reduce_amin_one_dim_indexTwo_constant_3D", + 1, + torch.tensor( + [[[0, 1, 2, 0], [1, 2, 1, 1]], [[3, 2, 1, 2], [0, 1, 2, 0]]] + ).cuda(), + torch.tensor( + [[[1, 2, 3, 1], [5, 6, 5, 5]], [[2, 4, 3, 2], [1, 2, 3, 1]]], + dtype=torch.int32, + ).cuda(), + { + torch.ops.aten.minimum.default, + torch.ops.aten.scatter.src, + }, + torch.zeros(3, 5, 6, dtype=torch.int32).cuda(), + "amin", + ), + ] + ) + def test_scatter_reduce( + self, _, dim, index, src, expected_ops_param, input_reduce_op, reduce_op_str + ): + class TestModule(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, input): + + return torch.ops.aten.scatter_reduce_.two( + input, dim, index, src, reduce=reduce_op_str + ) + + # Operations expected to be included in the traced graph after decompositions + expected_ops = expected_ops_param + unexpected_ops = {torch.ops.aten.scatter_reduce_.two} + + input = torch.zeros(3, 5, dtype=torch.int32).cuda() + inputs = [input_reduce_op] + + fx_graph = torch.fx.symbolic_trace(TestModule()) + unexpected_ops_seen, expected_ops_unseen = lower_graph_testing( + fx_graph, + inputs, + expected_ops=expected_ops, + unexpected_ops=unexpected_ops, + ) + + self.assertEqual( + len(expected_ops_unseen), + 0, + f"The following expected ops were not encountered: {expected_ops_unseen}", + ) + + self.assertEqual( + len(unexpected_ops_seen), + 0, + f"The following expected ops were not encountered: {unexpected_ops_seen}", + ) + + torch._dynamo.reset() + + # Validate that the results between Torch and Torch-TRT are similar + optimized_model = torch_tensorrt.compile( + fx_graph, + "torch_compile", + inputs, + min_block_size=3, + truncate_double=True, + pass_through_build_failures=True, + ) + optimized_model_results = optimized_model(*inputs).detach().cpu() + torch_model_results = fx_graph(*inputs).detach().cpu() + + max_diff = float( + torch.max(torch.abs(optimized_model_results - torch_model_results)) + ) + self.assertAlmostEqual( + max_diff, + 0, + DECIMALS_OF_AGREEMENT, + f"Scatter_reduce TRT outputs don't match with the original model.", + ) + if __name__ == "__main__": run_tests() From cd7d682daf1aee8583265bbd185666afac7cd16a Mon Sep 17 00:00:00 2001 From: apbose Date: Fri, 16 Aug 2024 17:11:12 -0700 Subject: [PATCH 2/5] addressing review comments-move all tensors to input tensor device --- py/torch_tensorrt/dynamo/lowering/_decompositions.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/py/torch_tensorrt/dynamo/lowering/_decompositions.py b/py/torch_tensorrt/dynamo/lowering/_decompositions.py index 9397b7c762..e9a741621d 100644 --- a/py/torch_tensorrt/dynamo/lowering/_decompositions.py +++ b/py/torch_tensorrt/dynamo/lowering/_decompositions.py @@ -335,6 +335,7 @@ def scatter_reduce_decomposition( reduce: str, ) -> torch.Tensor: scatter_loop_tensor = input_tensor + device_input_tensor = input_tensor.device # required for mean reduce operation scatter_count_tensor = torch.zeros_like(input_tensor) src_shape = list(src_tensor.shape) @@ -346,12 +347,11 @@ def scatter_reduce_decomposition( # unsqueeze src and index in dim src_slice = torch.unsqueeze(src_slice, dim) index_slice = torch.unsqueeze(index_slice, dim) - device = 
From cd7d682daf1aee8583265bbd185666afac7cd16a Mon Sep 17 00:00:00 2001
From: apbose
Date: Fri, 16 Aug 2024 17:11:12 -0700
Subject: [PATCH 2/5] addressing review comments: move all tensors to the
 input tensor device

---
 py/torch_tensorrt/dynamo/lowering/_decompositions.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/py/torch_tensorrt/dynamo/lowering/_decompositions.py b/py/torch_tensorrt/dynamo/lowering/_decompositions.py
index 9397b7c762..e9a741621d 100644
--- a/py/torch_tensorrt/dynamo/lowering/_decompositions.py
+++ b/py/torch_tensorrt/dynamo/lowering/_decompositions.py
@@ -335,6 +335,7 @@ def scatter_reduce_decomposition(
     reduce: str,
 ) -> torch.Tensor:
     scatter_loop_tensor = input_tensor
+    device_input_tensor = input_tensor.device
     # required for the mean reduce operation
     scatter_count_tensor = torch.zeros_like(input_tensor)
     src_shape = list(src_tensor.shape)
@@ -346,12 +347,11 @@ def scatter_reduce_decomposition(
         # unsqueeze src and index in dim
         src_slice = torch.unsqueeze(src_slice, dim)
         index_slice = torch.unsqueeze(index_slice, dim)
-        device = to_torch_device(default_device())
 
         # move tensors to the target device
-        scatter_loop_tensor = scatter_loop_tensor.to(device)
-        index_slice = index_slice.to(device)
-        src_slice = src_slice.to(device)
+        scatter_loop_tensor = scatter_loop_tensor.to(device_input_tensor)
+        index_slice = index_slice.to(device_input_tensor)
+        src_slice = src_slice.to(device_input_tensor)
         if reduce == "sum":
             reduceOp = ReduceOperation.SUM
         elif reduce == "prod":

From 485adf95c715fae9b4875b3fea59460c98a23425 Mon Sep 17 00:00:00 2001
From: apbose
Date: Wed, 4 Sep 2024 14:42:12 -0700
Subject: [PATCH 3/5] removing full_like decomposition op after PR-3077

---
 tests/py/dynamo/lowering/test_decompositions.py | 11 -----------
 1 file changed, 11 deletions(-)

diff --git a/tests/py/dynamo/lowering/test_decompositions.py b/tests/py/dynamo/lowering/test_decompositions.py
index 6568dab202..4082caafb1 100644
--- a/tests/py/dynamo/lowering/test_decompositions.py
+++ b/tests/py/dynamo/lowering/test_decompositions.py
@@ -1158,7 +1158,6 @@ def forward(self, input):
                 {
                     torch.ops.aten.add.Tensor,
                     torch.ops.aten.scatter.src,
-                    torch.ops.aten.full_like.default,
                 },
                 torch.zeros(3, 5, dtype=torch.int32).cuda(),
                 "sum",
@@ -1171,7 +1170,6 @@ def forward(self, input):
                 {
                     torch.ops.aten.add.Tensor,
                     torch.ops.aten.scatter.src,
-                    torch.ops.aten.full_like.default,
                 },
                 torch.zeros(3, 5, dtype=torch.int32).cuda(),
                 "sum",
@@ -1189,7 +1187,6 @@ def forward(self, input):
                 {
                     torch.ops.aten.add.Tensor,
                     torch.ops.aten.scatter.src,
-                    torch.ops.aten.full_like.default,
                 },
                 torch.zeros(3, 5, 6, dtype=torch.int32).cuda(),
                 "sum",
@@ -1203,7 +1200,6 @@ def forward(self, input):
                 {
                     torch.ops.aten.mul.Tensor,
                     torch.ops.aten.scatter.src,
-                    torch.ops.aten.full_like.default,
                 },
                 torch.ones(3, 5, dtype=torch.int32).cuda(),
                 "prod",
@@ -1216,7 +1212,6 @@ def forward(self, input):
                 {
                     torch.ops.aten.mul.Tensor,
                     torch.ops.aten.scatter.src,
-                    torch.ops.aten.full_like.default,
                 },
                 torch.ones(3, 5, dtype=torch.int32).cuda(),
                 "prod",
@@ -1229,7 +1224,6 @@ def forward(self, input):
                 {
                     torch.ops.aten.mul.Tensor,
                     torch.ops.aten.scatter.src,
-                    torch.ops.aten.full_like.default,
                 },
                 torch.ones(3, 5, dtype=torch.int32).cuda(),
                 "prod",
@@ -1242,7 +1236,6 @@ def forward(self, input):
                 {
                     torch.ops.aten.mul.Tensor,
                     torch.ops.aten.scatter.src,
-                    torch.ops.aten.full_like.default,
                 },
                 torch.ones(3, 5, dtype=torch.int32).cuda(),
                 "prod",
@@ -1260,7 +1253,6 @@ def forward(self, input):
                 {
                     torch.ops.aten.mul.Tensor,
                     torch.ops.aten.scatter.src,
-                    torch.ops.aten.full_like.default,
                 },
                 torch.ones(3, 5, 6, dtype=torch.int32).cuda(),
                 "prod",
@@ -1300,7 +1292,6 @@ def forward(self, input):
                     torch.ops.aten.add.Tensor,
                     torch.ops.aten.div.Tensor_mode,
                     torch.ops.aten.scatter.src,
-                    torch.ops.aten.full_like.default,
                 },
                 torch.zeros(3, 5, dtype=torch.int32).cuda(),
                 "mean",
@@ -1314,7 +1305,6 @@ def forward(self, input):
                     torch.ops.aten.add.Tensor,
                     torch.ops.aten.div.Tensor_mode,
                     torch.ops.aten.scatter.src,
-                    torch.ops.aten.full_like.default,
                 },
                 torch.zeros(3, 5, dtype=torch.int32).cuda(),
                 "mean",
@@ -1333,7 +1323,6 @@ def forward(self, input):
                     torch.ops.aten.add.Tensor,
                     torch.ops.aten.div.Tensor_mode,
                     torch.ops.aten.scatter.src,
-                    torch.ops.aten.full_like.default,
                 },
                 torch.zeros(3, 5, 6, dtype=torch.int32).cuda(),
                 "mean",
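Background for the device handling in patch 2 above and patch 4 below: `torch.scatter` requires `self`, `index`, and `src` to live on the same device, so the decomposition pins every operand to one device rather than assuming a global default. A small repro of the failure mode (illustrative sketch; requires a CUDA device):

```python
import torch

inp = torch.zeros(3, 5).cuda()
index = torch.tensor([[0, 1, 2, 0]])  # still on CPU
src = torch.ones(1, 4)                # still on CPU

try:
    torch.scatter(inp, 0, index, src)
except RuntimeError as err:
    print(f"cross-device scatter fails: {err}")

# Pinning operands to the input tensor's device, as the patches do, avoids it:
out = torch.scatter(inp, 0, index.to(inp.device), src.to(inp.device))
```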
From 35f2b00afca56f001d1859e45bfbefbd2fb62405 Mon Sep 17 00:00:00 2001
From: apbose
Date: Mon, 9 Sep 2024 12:18:25 -0700
Subject: [PATCH 4/5] changing the device setting in conversion.py

---
 py/torch_tensorrt/dynamo/lowering/_decompositions.py |  2 +-
 py/torch_tensorrt/dynamo/utils.py                    | 12 ++++++++----
 2 files changed, 9 insertions(+), 5 deletions(-)

diff --git a/py/torch_tensorrt/dynamo/lowering/_decompositions.py b/py/torch_tensorrt/dynamo/lowering/_decompositions.py
index e9a741621d..359574ecaf 100644
--- a/py/torch_tensorrt/dynamo/lowering/_decompositions.py
+++ b/py/torch_tensorrt/dynamo/lowering/_decompositions.py
@@ -318,7 +318,7 @@ def reduce_operation_with_scatter(
             print("Invalid Operation for Reduce op!!")
 
         operation_rhs = torch.scatter(scatter_tensor, dim, index_tensor, src_tensor)
-        device = to_torch_device(default_device())
+        device = to_torch_device(scatter_tensor.device)
         operation_lhs = operation_lhs.to(device)
         operation_rhs = operation_rhs.to(device)
         return self.func(operation_lhs, operation_rhs)

diff --git a/py/torch_tensorrt/dynamo/utils.py b/py/torch_tensorrt/dynamo/utils.py
index 2af7922cd1..75fbf4c935 100644
--- a/py/torch_tensorrt/dynamo/utils.py
+++ b/py/torch_tensorrt/dynamo/utils.py
@@ -6,16 +6,17 @@
 from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union
 
 import numpy as np
+import tensorrt as trt
 import torch
 from torch._subclasses.fake_tensor import FakeTensor
 from torch_tensorrt._Device import Device
 from torch_tensorrt._enums import dtype
 from torch_tensorrt._Input import Input
 from torch_tensorrt.dynamo import _defaults
+from torch_tensorrt.dynamo._defaults import default_device
 from torch_tensorrt.dynamo._engine_cache import BaseEngineCache
 from torch_tensorrt.dynamo._settings import CompilationSettings
 
-import tensorrt as trt
 from packaging import version
 
 from .types import TRTDataType
@@ -186,11 +187,14 @@ def get_model_device(module: torch.fx.GraphModule) -> torch.device:
     device = None
     for parameter in list(module.parameters()):
         if isinstance(parameter, (torch.nn.parameter.Parameter, torch.Tensor)):
-            device = parameter.device
-            break
+            return parameter.device
+
+    for buffer in list(module.buffers()):
+        if isinstance(buffer, (torch.Tensor)):
+            return buffer.device
 
     if device is None:
-        device = torch.device("cpu")
+        device = to_torch_device(default_device())
         logger.warning(
             "Could not detect the device on which the model exists. Assuming the model is on CPU"
         )

From e0eda18a641c369dfa44c280c089bb1a0ca29f13 Mon Sep 17 00:00:00 2001
From: apbose
Date: Tue, 10 Sep 2024 21:08:33 -0700
Subject: [PATCH 5/5] Assertion error for include_self=False

---
 py/torch_tensorrt/dynamo/lowering/_decompositions.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/py/torch_tensorrt/dynamo/lowering/_decompositions.py b/py/torch_tensorrt/dynamo/lowering/_decompositions.py
index 359574ecaf..8c391afa5b 100644
--- a/py/torch_tensorrt/dynamo/lowering/_decompositions.py
+++ b/py/torch_tensorrt/dynamo/lowering/_decompositions.py
@@ -333,6 +333,7 @@ def scatter_reduce_decomposition(
     index: torch.Tensor,
     src_tensor: torch.Tensor,
     reduce: str,
+    include_self: bool = True,
 ) -> torch.Tensor:
     scatter_loop_tensor = input_tensor
     device_input_tensor = input_tensor.device
@@ -340,7 +341,8 @@ def scatter_reduce_decomposition(
     scatter_count_tensor = torch.zeros_like(input_tensor)
     src_shape = list(src_tensor.shape)
     src_dim = src_shape[dim]
-
+    if include_self == False:
+        raise AssertionError("include_self False for scatter reduce not yet supported")
     for i in range(0, src_dim):
         src_slice = torch.select(src_tensor, dim, i)
         index_slice = torch.select(index, dim, i)
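For context on the final patch: with `include_self=False` the destination elements are excluded from the reduction, which the count-plus-one mean logic and the amax/amin seeding in the decomposition do not model; hence the assertion. An eager-mode illustration of the difference (standalone, not part of the patch):

```python
import torch

inp = torch.full((3, 5), 10, dtype=torch.int32)
index = torch.tensor([[0, 1, 2, 0]])
src = torch.tensor([[1, 2, 3, 4]], dtype=torch.int32)

with_self = inp.scatter_reduce(0, index, src, reduce="sum")
without_self = inp.scatter_reduce(0, index, src, reduce="sum", include_self=False)

print(with_self[0, 0].item())     # 11: the original value 10 participates in the sum
print(without_self[0, 0].item())  # 1: only the scattered contribution remains
```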