44from parameterized import parameterized
55from torch .testing ._internal .common_utils import TestCase , run_tests
66
7+ from .harness import DispatchTestCase
8+
79rand_ops = [
810 (
911 "rand_one_dimension" ,
1315 (
1416 "rand_two_dimension" ,
1517 (lambda shape : torch .ops .aten .rand (shape )),
16- [2 , 3 ],
18+ [1 , 2 ],
1719 ),
1820 (
1921 "rand_three_dimension" ,
3537 (lambda shape : torch .ops .aten .randn (shape )),
3638 [2 , 3 , 4 ],
3739 ),
40+ ]
41+
42+
# Parameterized cases for aten.randperm:
# (test name, op lambda, single-element shape list holding the permutation bound).
rand_perm_ops = [
    (f"randperm_{label}_case", (lambda x: torch.ops.aten.randperm(x)), [bound])
    for label, bound in (("one", 1), ("two", 150), ("three", 1500))
]
5460
5561
56- class TestRandConverter (TestCase ):
62+ class TestRandConverter (DispatchTestCase ):
5763 @parameterized .expand (
5864 [
5965 (
@@ -64,41 +70,64 @@ class TestRandConverter(TestCase):
6470 for rand_op in rand_ops
6571 ]
6672 )
67- def test_rand (self , _ , op , shape_or_input ):
73+ def test_rand (self , name , op , shape_or_input ):
6874 class TestModule (nn .Module ):
69- def __init__ (self , rand_op , size ):
75+ def __init__ (self ):
7076 super ().__init__ ()
71- self .rand_op = rand_op
72- self .size = size
7377
74- def forward (self ):
75- return self .rand_op (self .size )
78+ def forward (self , x ):
79+ shape_or_input [0 ] = x .shape [0 ]
80+ return op (shape_or_input )
7681
77- rand_model = TestModule (op , shape_or_input )
78- # cannot use self.run_test() since it expects input in form of tensor
82+ rand_model = TestModule ()
7983
80- fx_graph = torch .fx . symbolic_trace ( grid_model )
81- torch . _dynamo . reset ()
82-
83- optimized_model = torch_tensorrt . compile (
84- fx_graph ,
85- "torch_compile" ,
86- None ,
87- min_block_size = 1 ,
88- pass_through_build_failures = True ,
89- truncate_long_and_double = True ,
90- debug = True ,
84+ inputs = [ torch .randint ( 1 , 3 , shape_or_input , dtype = torch . int32 )]
85+ comparator_shape = lambda x , y , check_dtype : x . shape == y . shape and (
86+ x . dtype == y . dtype if check_dtype else True
87+ )
88+ expected_ops = []
89+ self . run_test_comparator (
90+ rand_model ,
91+ inputs ,
92+ expected_ops ,
93+ [( comparator_shape , [ True ])] ,
94+ use_dynamo_tracer = True ,
9195 )
92- optimized_model_results = optimized_model ().detach ().cpu ()
93- torch_model_results = fx_graph ().detach ().cpu ()
94- max_diff = float (
95- torch .max (torch .abs (optimized_model_results - torch_model_results ))
96+
97+ @parameterized .expand (
98+ [
99+ (
100+ rand_op [0 ],
101+ rand_op [1 ],
102+ rand_op [2 ],
103+ )
104+ for rand_op in rand_perm_ops
105+ ]
106+ )
107+ def test_rand (self , name , op , shape_or_input ):
108+ class TestModule (nn .Module ):
109+ def __init__ (self ):
110+ super ().__init__ ()
111+
112+ def forward (self , x ):
113+ shape_or_input [0 ] = x .shape [0 ]
114+ return op (shape_or_input [0 ])
115+
116+ rand_model = TestModule ()
117+ # cannot use self.run_test() since it expects input in form of tensor
118+
119+ inputs = [torch .randint (1 , 3 , shape_or_input , dtype = torch .int32 )]
120+ comparator_shape = lambda x , y , check_dtype : x .shape == y .shape and (
121+ x .dtype == y .dtype if check_dtype else True
96122 )
97- self .assertAlmostEqual (
98- max_diff ,
99- 0 ,
100- 4 ,
101- f"TRT outputs don't match with the original model." ,
123+ expected_ops = []
124+ # TRT-np returns int32 while torch returns float32
125+ self .run_test_comparator (
126+ rand_model ,
127+ inputs ,
128+ expected_ops ,
129+ [(comparator_shape , [False ])],
130+ use_dynamo_tracer = True ,
102131 )
103132
104133