Skip to content

Commit a49c420

Browse files
committed
scatter_add_decomposition
1 parent 7d30714 commit a49c420

File tree

2 files changed

+71
-0
lines changed

2 files changed

+71
-0
lines changed

py/torch_tensorrt/dynamo/lowering/_decompositions.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -162,6 +162,21 @@ def var_decomposition(
162162
return variance
163163

164164

165+
@register_torch_trt_decomposition(
166+
torch.ops.aten.scatter_add.default, registry=TORCH_TRT_DECOMPOSITIONS
167+
)
168+
def scatter_add_decomposition(
169+
input_tensor: torch.Tensor,
170+
src_tensor: torch.Tensor,
171+
dim: int,
172+
index: torch.Tensor,
173+
) -> torch.Tensor:
174+
input_tensor_to_add = torch.empty_like(input_tensor)
175+
input_tensor_to_add = torch.scatter(input_tensor_to_add, dim, index, src_tensor)
176+
scatter_add_tensor = input_tensor + input_tensor_to_add
177+
return scatter_add_tensor
178+
179+
165180
def get_decompositions(
166181
enable_experimental_decompositions: bool = False,
167182
) -> Dict[OpOverload, Callable[[Any], Any]]:

tests/py/dynamo/lowering/test_decompositions.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import torch
22
import torch_tensorrt
3+
from parameterized import parameterized
34
from torch.testing._internal.common_utils import TestCase, run_tests
45

56
from ..testing_utilities import DECIMALS_OF_AGREEMENT, lower_graph_testing
@@ -421,5 +422,60 @@ def forward(self, x):
421422
)
422423

423424

425+
class TestScatterAdd(TestCase):
426+
@parameterized.expand(
427+
[
428+
(
429+
"scatter_add_zero_dim_indexOne_constant",
430+
0,
431+
torch.tensor([[0, 1, 2, 0]]),
432+
torch.tensor([[1, 2, 3, 4]], dtype=torch.int32),
433+
),
434+
(
435+
"scatter_add_zero_dim_indexTwo_constant",
436+
0,
437+
torch.tensor([[0, 1, 2, 0], [1, 2, 1, 1]]),
438+
torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8]], dtype=torch.int32),
439+
),
440+
(
441+
"scatter_add_one_dim_indexOne_constant",
442+
1,
443+
torch.tensor([[0, 1, 2, 0]]),
444+
torch.tensor([[1, 2, 3, 1]], dtype=torch.int32),
445+
),
446+
(
447+
"scatter_add_one_dim_indexTwo_costant",
448+
1,
449+
torch.tensor([[0, 1, 2, 0], [1, 2, 1, 1]]),
450+
torch.tensor([[1, 2, 3, 1], [5, 6, 5, 5]], dtype=torch.int32),
451+
),
452+
]
453+
)
454+
def test_scatter_add(self, _, dim, index, src):
455+
class TestModule(torch.nn.Module):
456+
def __init__(self):
457+
super().__init__()
458+
459+
def forward(self, input):
460+
return torch.ops.aten.scatter_add.default(input, dim, index, src)
461+
462+
# Operations expected to be included in the traced graph after decompositions
463+
expected_ops = {torch.ops.aten.scatter.src}
464+
465+
input = torch.zeros(3, 5, dtype=torch.int32)
466+
inputs = [input]
467+
468+
fx_graph = torch.fx.symbolic_trace(TestModule())
469+
_, expected_ops_unseen = lower_graph_testing(
470+
fx_graph, inputs, expected_ops=expected_ops, min_block_size=2
471+
)
472+
473+
self.assertEquals(
474+
len(expected_ops_unseen),
475+
0,
476+
f"The following expected ops were not encountered: {expected_ops_unseen}",
477+
)
478+
479+
424480
if __name__ == "__main__":
425481
run_tests()

0 commit comments

Comments
 (0)