@@ -24,6 +24,31 @@ def forward(self, x):
2424 inputs ,
2525 )
2626
27+ def test_layernorm_with_dynamic_shape (self ):
28+ class LayerNorm (torch .nn .Module ):
29+ def forward (self , x ):
30+ return torch .ops .aten .layer_norm .default (
31+ x ,
32+ torch .tensor ([3 , 224 , 224 ]),
33+ torch .ones ((3 , 224 , 224 )),
34+ torch .zeros ((3 , 224 , 224 )),
35+ 1e-05 ,
36+ True ,
37+ )
38+
39+ input_specs = [
40+ Input (
41+ shape = (- 1 , 3 , 224 , 224 ),
42+ dtype = torch .float32 ,
43+ shape_ranges = [((1 , 3 , 224 , 224 ), (1 , 3 , 224 , 224 ), (2 , 3 , 224 , 224 ))],
44+ ),
45+ ]
46+
47+ self .run_test_with_dynamic_shape (
48+ LayerNorm (),
49+ input_specs ,
50+ )
51+
2752
2853class TestNativeLayerNormConverter (DispatchTestCase ):
2954 def test_layer_norm (self ):
@@ -43,6 +68,30 @@ def forward(self, x):
4368 inputs ,
4469 )
4570
71+ def test_layernorm_with_dynamic_shape (self ):
72+ class LayerNorm (torch .nn .Module ):
73+ def forward (self , x ):
74+ return torch .ops .aten .native_layer_norm .default (
75+ x ,
76+ torch .tensor ([3 , 224 , 224 ]),
77+ torch .ones ((3 , 224 , 224 )),
78+ torch .zeros ((3 , 224 , 224 )),
79+ 1e-05 ,
80+ )[0 ]
81+
82+ input_specs = [
83+ Input (
84+ shape = (- 1 , 3 , 224 , 224 ),
85+ dtype = torch .float32 ,
86+ shape_ranges = [((1 , 3 , 224 , 224 ), (1 , 3 , 224 , 224 ), (2 , 3 , 224 , 224 ))],
87+ ),
88+ ]
89+
90+ self .run_test_with_dynamic_shape (
91+ LayerNorm (),
92+ input_specs ,
93+ )
94+
4695
4796if __name__ == "__main__" :
4897 run_tests ()
0 commit comments