@@ -251,6 +251,31 @@ def test():
251251 expect = """clamp(): expected at least one of `min` or `max` arguments to be specified."""
252252 )
253253
254+ def test_avg_pool_3d_raises_error_on_bad_spec (self ):
255+ device = torch_xla .device ()
256+ a = torch .rand (1 , 1 , 4 , 4 , 4 , device = device )
257+
258+ def gen_test_fn (kernel_size = [2 , 2 , 2 ], stride = [], padding = [0 ]):
259+ return lambda : torch .nn .functional .avg_pool3d (a , kernel_size , stride , padding )
260+
261+ self .assertExpectedRaisesInline (
262+ exc_type = RuntimeError ,
263+ callable = gen_test_fn (kernel_size = [2 , 2 ]),
264+ expect = """avg_pool3d(): expected argument kernel_size [2, 2] (size: 2) to have size of 3."""
265+ )
266+
267+ self .assertExpectedRaisesInline (
268+ exc_type = RuntimeError ,
269+ callable = gen_test_fn (stride = [1 , 2 ]),
270+ expect = """avg_pool3d(): expected argument stride [1, 2] (size: 2) to have size of 3."""
271+ )
272+
273+ self .assertExpectedRaisesInline (
274+ exc_type = RuntimeError ,
275+ callable = gen_test_fn (padding = [1 , 2 ]),
276+ expect = """avg_pool3d(): expected argument padding [1, 2] (size: 2) to have size of 3."""
277+ )
278+
254279
255280if __name__ == "__main__" :
256281 unittest .main ()
0 commit comments