|
1 | 1 | import torch |
2 | 2 | import torch_tensorrt |
| 3 | +from parameterized import parameterized |
3 | 4 | from torch.testing._internal.common_utils import TestCase, run_tests |
4 | 5 |
|
5 | 6 | from ..testing_utilities import DECIMALS_OF_AGREEMENT, lower_graph_testing |
@@ -421,5 +422,60 @@ def forward(self, x): |
421 | 422 | ) |
422 | 423 |
|
423 | 424 |
|
| 425 | +class TestScatterAdd(TestCase): |
| 426 | + @parameterized.expand( |
| 427 | + [ |
| 428 | + ( |
| 429 | + "scatter_add_zero_dim_indexOne_constant", |
| 430 | + 0, |
| 431 | + torch.tensor([[0, 1, 2, 0]]), |
| 432 | + torch.tensor([[1, 2, 3, 4]], dtype=torch.int32), |
| 433 | + ), |
| 434 | + ( |
| 435 | + "scatter_add_zero_dim_indexTwo_constant", |
| 436 | + 0, |
| 437 | + torch.tensor([[0, 1, 2, 0], [1, 2, 1, 1]]), |
| 438 | + torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8]], dtype=torch.int32), |
| 439 | + ), |
| 440 | + ( |
| 441 | + "scatter_add_one_dim_indexOne_constant", |
| 442 | + 1, |
| 443 | + torch.tensor([[0, 1, 2, 0]]), |
| 444 | + torch.tensor([[1, 2, 3, 1]], dtype=torch.int32), |
| 445 | + ), |
| 446 | + ( |
| 447 | + "scatter_add_one_dim_indexTwo_costant", |
| 448 | + 1, |
| 449 | + torch.tensor([[0, 1, 2, 0], [1, 2, 1, 1]]), |
| 450 | + torch.tensor([[1, 2, 3, 1], [5, 6, 5, 5]], dtype=torch.int32), |
| 451 | + ), |
| 452 | + ] |
| 453 | + ) |
| 454 | + def test_scatter_add(self, _, dim, index, src): |
| 455 | + class TestModule(torch.nn.Module): |
| 456 | + def __init__(self): |
| 457 | + super().__init__() |
| 458 | + |
| 459 | + def forward(self, input): |
| 460 | + return torch.ops.aten.scatter_add.default(input, dim, index, src) |
| 461 | + |
| 462 | + # Operations expected to be included in the traced graph after decompositions |
| 463 | + expected_ops = {torch.ops.aten.scatter.src} |
| 464 | + |
| 465 | + input = torch.zeros(3, 5, dtype=torch.int32) |
| 466 | + inputs = [input] |
| 467 | + |
| 468 | + fx_graph = torch.fx.symbolic_trace(TestModule()) |
| 469 | + _, expected_ops_unseen = lower_graph_testing( |
| 470 | + fx_graph, inputs, expected_ops=expected_ops, min_block_size=2 |
| 471 | + ) |
| 472 | + |
| 473 | + self.assertEquals( |
| 474 | + len(expected_ops_unseen), |
| 475 | + 0, |
| 476 | + f"The following expected ops were not encountered: {expected_ops_unseen}", |
| 477 | + ) |
| 478 | + |
| 479 | + |
424 | 480 | if __name__ == "__main__": |
425 | 481 | run_tests() |
0 commit comments