Skip to content

Commit

Permalink
cherry pick 2740 to release2.4 branch. (#3033)
Browse files Browse the repository at this point in the history
  • Loading branch information
lanluo-nvidia authored Jul 24, 2024
1 parent 9a7f272 commit 9cced93
Show file tree
Hide file tree
Showing 2 changed files with 156 additions and 3 deletions.
46 changes: 43 additions & 3 deletions py/torch_tensorrt/dynamo/lowering/_decompositions.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@
import torch
from torch._decomp import register_decomposition
from torch._ops import OpOverload
from torch_tensorrt.dynamo._defaults import default_device
from torch_tensorrt.dynamo.conversion.converter_utils import get_positive_dim
from torch_tensorrt.dynamo.utils import to_torch_device

from ._decomposition_groups import (
ENABLED_TORCH_DECOMPOSITIONS,
Expand Down Expand Up @@ -166,7 +168,7 @@ def var_decomposition(
@register_torch_trt_decomposition(
torch.ops.aten.empty_permuted.default, registry=TORCH_TRT_DECOMPOSITIONS
)
def empty_permuted_decomposition(*args, **kwargs) -> torch.Tensor:
def empty_permuted_decomposition(*args: Any, **kwargs: Any) -> torch.Tensor:
empty_size = args[0]
empty_permute = args[1]
perm = [0] * len(empty_size)
Expand All @@ -185,7 +187,7 @@ def slice_scatter_decomposition(
start: Optional[int] = None,
end: Optional[int] = None,
step: Optional[int] = None,
):
) -> torch.Tensor:
dim_size = input_tensor.shape[dim]
start = get_positive_dim(start, input_tensor.shape[dim])
if end is None:
Expand Down Expand Up @@ -230,12 +232,50 @@ def select_scatter_decomposition(
@register_torch_trt_decomposition(
    torch.ops.aten.empty_strided.default, registry=TORCH_TRT_DECOMPOSITIONS
)
def empty_strided_decomposition(*args: Any, **kwargs: Any) -> torch.Tensor:
    """Lower ``aten.empty_strided`` to ``torch.empty`` followed by ``torch.as_strided``.

    Args:
        *args: Positional arguments of the original op; ``args[0]`` is used as
            the size and ``args[1]`` as the stride.
        **kwargs: Accepted for signature compatibility; not read by this
            decomposition.

    Returns:
        An uninitialized tensor of the requested size, viewed with the
        requested strides.
    """
    size, stride = args[0], args[1]
    uninitialized = torch.empty(size)
    return torch.as_strided(uninitialized, size, stride)


@register_torch_trt_decomposition(
    torch.ops.aten.scatter_add.default, registry=TORCH_TRT_DECOMPOSITIONS
)
def scatter_add_decomposition(
    input_tensor: torch.Tensor,
    dim: int,
    index: torch.Tensor,
    src_tensor: torch.Tensor,
) -> torch.Tensor:
    """Decompose ``aten.scatter_add`` into repeated ``torch.scatter`` + ``torch.add``.

    For every position ``i`` along ``dim`` of ``index``, the matching slices of
    ``index`` and ``src_tensor`` are scattered into a zero tensor shaped like
    ``input_tensor`` and the result is added to a running total that starts as
    ``input_tensor``. Handling one slice at a time means each ``torch.scatter``
    call sees a single index entry along ``dim``, so contributions that target
    the same output element are summed by ``torch.add`` across iterations.

    Args:
        input_tensor: Tensor the scattered values are accumulated onto.
        dim: Dimension along which to index.
        index: Indices of the elements to scatter. It may be smaller than
            ``src_tensor`` along any dimension.
        src_tensor: Source values to scatter and add.

    Returns:
        The accumulated tensor. When ``index`` has at least one slice along
        ``dim`` the result lives on the Torch-TensorRT default device; with
        zero slices ``input_tensor`` is returned unchanged.
    """
    # The target device does not change between iterations; resolve it once.
    device = to_torch_device(default_device())

    scatter_add_tensor = input_tensor
    # Iterate over the extent of `index`, not `src_tensor`: scatter_add only
    # consumes as many slices of `src_tensor` as `index` provides, and
    # selecting past the end of `index` would raise.
    for i in range(index.shape[dim]):
        to_scatter_tensor = torch.zeros_like(input_tensor)

        # Take the i-th slice of src and index, then restore the reduced
        # dimension so both keep the rank torch.scatter expects.
        src_slice = torch.unsqueeze(torch.select(src_tensor, dim, i), dim)
        index_slice = torch.unsqueeze(torch.select(index, dim, i), dim)

        # Move every operand to the default device before combining them.
        scatter_add_tensor = scatter_add_tensor.to(device)
        to_scatter_tensor = to_scatter_tensor.to(device)
        index_slice = index_slice.to(device)
        src_slice = src_slice.to(device)

        scatter_add_tensor = torch.add(
            scatter_add_tensor,
            torch.scatter(to_scatter_tensor, dim, index_slice, src_slice),
        )

    return scatter_add_tensor


def get_decompositions(
enable_experimental_decompositions: bool = False,
) -> Dict[OpOverload, Callable[[Any], Any]]:
Expand Down
113 changes: 113 additions & 0 deletions tests/py/dynamo/lowering/test_decompositions.py
Original file line number Diff line number Diff line change
Expand Up @@ -962,6 +962,119 @@ def forward(self, input):
f"The optimized model results shape and torch model results shape should be equal in empty_stride",
)

@parameterized.expand(
    [
        (
            "scatter_add_zero_dim_indexOne_constant",
            0,
            torch.tensor([[0, 1, 2, 0]]).cuda(),
            torch.tensor([[1, 2, 3, 4]], dtype=torch.int32).cuda(),
            {torch.ops.aten.add.Tensor},
        ),
        (
            "scatter_add_zero_dim_indexTwo_constant",
            0,
            torch.tensor([[0, 1, 2, 0], [1, 2, 1, 1]]).cuda(),
            torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8]], dtype=torch.int32).cuda(),
            {torch.ops.aten.add.Tensor, torch.ops.aten.scatter.src},
        ),
        (
            "scatter_add_one_dim_indexOne_constant",
            1,
            torch.tensor([[0, 1, 2, 0]]).cuda(),
            torch.tensor([[1, 2, 3, 1]], dtype=torch.int32).cuda(),
            {
                torch.ops.aten.add.Tensor,
                torch.ops.aten.scatter.src,
                torch.ops.aten.full_like.default,
            },
        ),
        (
            "scatter_add_one_dim_indexTwo_constant",
            1,
            torch.tensor([[0, 1, 2, 0], [1, 2, 1, 1]]).cuda(),
            torch.tensor([[1, 2, 3, 1], [5, 6, 5, 5]], dtype=torch.int32).cuda(),
            {
                torch.ops.aten.add.Tensor,
                torch.ops.aten.scatter.src,
                torch.ops.aten.full_like.default,
            },
        ),
        (
            # Three index rows: named distinctly from the two-row case above.
            "scatter_add_one_dim_indexThree_constant",
            1,
            torch.tensor([[0, 1, 2, 0], [1, 2, 1, 1], [3, 2, 1, 2]]).cuda(),
            torch.tensor(
                [[1, 2, 3, 1], [5, 6, 5, 5], [2, 4, 3, 2]], dtype=torch.int32
            ).cuda(),
            {
                torch.ops.aten.add.Tensor,
                torch.ops.aten.scatter.src,
                torch.ops.aten.full_like.default,
            },
        ),
    ]
)
def test_scatter_add(self, _, dim, index, src, expected_ops_param):
    # Checks two things for aten.scatter_add:
    #   1. after lowering, the graph contains the expected decomposed ops and
    #      no remaining aten.scatter_add.default node;
    #   2. the Torch-TensorRT compiled module produces the same values as the
    #      traced Torch module.
    class TestModule(torch.nn.Module):
        def forward(self, input):
            return torch.ops.aten.scatter_add.default(input, dim, index, src)

    # Operations expected to be included in the traced graph after decompositions
    expected_ops = expected_ops_param
    unexpected_ops = {torch.ops.aten.scatter_add.default}

    input_tensor = torch.zeros(3, 5, dtype=torch.int32).cuda()
    inputs = [input_tensor]

    fx_graph = torch.fx.symbolic_trace(TestModule())
    unexpected_ops_seen, expected_ops_unseen = lower_graph_testing(
        fx_graph,
        inputs,
        expected_ops=expected_ops,
        unexpected_ops=unexpected_ops,
        min_block_size=2,
    )

    # assertEqual, not the deprecated assertEquals alias (removed in Python 3.12).
    self.assertEqual(
        len(expected_ops_unseen),
        0,
        f"The following expected ops were not encountered: {expected_ops_unseen}",
    )

    self.assertEqual(
        len(unexpected_ops_seen),
        0,
        f"The following unexpected ops were encountered: {unexpected_ops_seen}",
    )

    torch._dynamo.reset()

    # Validate that the results between Torch and Torch-TRT are similar
    optimized_model = torch_tensorrt.compile(
        fx_graph,
        "torch_compile",
        inputs,
        min_block_size=1,
        truncate_double=True,
        pass_through_build_failures=True,
    )
    optimized_model_results = optimized_model(*inputs).detach().cpu()
    torch_model_results = fx_graph(*inputs).detach().cpu()

    max_diff = float(
        torch.max(torch.abs(optimized_model_results - torch_model_results))
    )
    self.assertAlmostEqual(
        max_diff,
        0,
        DECIMALS_OF_AGREEMENT,
        "Scatter_add TRT outputs don't match with the original model.",
    )


# Run this module's tests when the file is executed directly as a script.
if __name__ == "__main__":
    run_tests()

0 comments on commit 9cced93

Please sign in to comment.