@@ -9,7 +9,7 @@ class TestCatConverter(DispatchTestCase):
99 @parameterized .expand (
1010 [
1111 ("pos" , 1 ),
12- # ("neg", -2), #dim can not have dynamic input
12+ ("neg" , - 2 ),
1313 ]
1414 )
1515 def test_cat (self , _ , dim ):
@@ -27,7 +27,7 @@ def forward(self, x, y, z):
2727 @parameterized .expand (
2828 [
2929 ("pos" , 1 ),
30- # ("neg", -2), #dim can not have dynamic input
30+ ("neg" , - 2 ),
3131 ]
3232 )
3333 def test_cat_dynamic_shape (self , _ , dim ):
@@ -53,6 +53,41 @@ def forward(self, x, y):
5353 expected_ops = {torch .ops .aten .cat .default },
5454 )
5555
56+ def test_cat_no_dim (self ):
57+ class Cat (nn .Module ):
58+ def forward (self , x , y , z ):
59+ return torch .cat ((x , y , z ))
60+
61+ inputs = [torch .randn (2 , 1 , 3 ), torch .randn (1 , 1 , 3 ), torch .randn (3 , 1 , 3 )]
62+ self .run_test (
63+ Cat (),
64+ inputs ,
65+ expected_ops = {torch .ops .aten .cat .default },
66+ )
67+
68+ def test_cat_dynamic_shape_no_dim (self ):
69+ class Cat (nn .Module ):
70+ def forward (self , x , y ):
71+ return torch .cat ((x , y ))
72+
73+ input_specs = [
74+ InputTensorSpec (
75+ shape = (- 1 , 16 , 3 ),
76+ dtype = torch .float32 ,
77+ shape_ranges = [((2 , 16 , 3 ), (3 , 16 , 3 ), (32 , 16 , 3 ))],
78+ ),
79+ InputTensorSpec (
80+ shape = (- 1 , 16 , 3 ),
81+ dtype = torch .float32 ,
82+ shape_ranges = [((2 , 16 , 3 ), (3 , 16 , 3 ), (32 , 16 , 3 ))],
83+ ),
84+ ]
85+ self .run_test_with_dynamic_shape (
86+ Cat (),
87+ input_specs ,
88+ expected_ops = {torch .ops .aten .cat .default },
89+ )
90+
5691
5792if __name__ == "__main__" :
5893 run_tests ()
0 commit comments