@@ -420,7 +420,7 @@ def forward(self, x):
420420 f"MaxPool3d TRT outputs don't match with the original model." ,
421421 )
422422
423- def test_lowering_select_scatter_module (self ):
423+ def test_lowering_select_scatter_dimZero_module (self ):
424424 class selectScatter (torch .nn .Module ):
425425 def __init__ (self , * args , ** kwargs ) -> None :
426426 super ().__init__ (* args , ** kwargs )
@@ -483,6 +483,69 @@ def forward(self, x, src, dim, index):
483483 f"Select_scatter TRT outputs don't match with the original model." ,
484484 )
485485
486+ def test_lowering_select_scatter_dimOne_module (self ):
487+ class selectScatter (torch .nn .Module ):
488+ def __init__ (self , * args , ** kwargs ) -> None :
489+ super ().__init__ (* args , ** kwargs )
490+
491+ def forward (self , x , src , dim , index ):
492+ y = torch .ops .aten .select_scatter .default (x , src , dim , index )
493+ return y
494+
495+ # Operations expected to be removed in the traced graph after decompositions
496+ expected_ops = {
497+ torch .ops .aten .slice .Tensor ,
498+ torch .ops .aten .squeeze .dim ,
499+ torch .ops .aten .cat .default ,
500+ }
501+ unexpected_ops = {torch .ops .aten .select_scatter .default }
502+
503+ inputs = [torch .zeros (2 , 2 ).cuda (), torch .ones (2 ).cuda (), 1 , 0 ]
504+
505+ fx_graph = torch .fx .symbolic_trace (selectScatter ())
506+ unexpected_ops_seen , expected_ops_unseen = lower_graph_testing (
507+ fx_graph ,
508+ inputs ,
509+ expected_ops = expected_ops ,
510+ unexpected_ops = unexpected_ops ,
511+ min_block_size = 1 ,
512+ )
513+
514+ self .assertEquals (
515+ len (unexpected_ops_seen ),
516+ 0 ,
517+ f"The following unexpected ops were encountered: { unexpected_ops_seen } " ,
518+ )
519+
520+ self .assertEquals (
521+ len (expected_ops_unseen ),
522+ 0 ,
523+ f"The following expected ops were not encountered: { expected_ops_unseen } " ,
524+ )
525+
526+ torch ._dynamo .reset ()
527+
528+ # Validate that the results between Torch and Torch-TRT are similar
529+ optimized_model = torch_tensorrt .compile (
530+ fx_graph ,
531+ "torch_compile" ,
532+ inputs ,
533+ min_block_size = 1 ,
534+ pass_through_build_failures = True ,
535+ )
536+ optimized_model_results = optimized_model (* inputs ).detach ().cpu ()
537+ torch_model_results = fx_graph (* inputs ).detach ().cpu ()
538+
539+ max_diff = float (
540+ torch .max (torch .abs (optimized_model_results - torch_model_results ))
541+ )
542+ self .assertAlmostEqual (
543+ max_diff ,
544+ 0 ,
545+ DECIMALS_OF_AGREEMENT ,
546+ f"Select_scatter TRT outputs don't match with the original model." ,
547+ )
548+
486549
487550if __name__ == "__main__" :
488551 run_tests ()
0 commit comments