Skip to content

Commit 2ae73c8

Browse files
committed
Add test.
1 parent 5d3d878 commit 2ae73c8

File tree

1 file changed

+25
-0
lines changed

1 file changed

+25
-0
lines changed

test/test_ops_error_message.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -251,6 +251,31 @@ def test():
251251
expect="""clamp(): expected at least one of `min` or `max` arguments to be specified."""
252252
)
253253

254+
def test_avg_pool_3d_raises_error_on_bad_spec(self):
255+
device = torch_xla.device()
256+
a = torch.rand(1, 1, 4, 4, 4, device=device)
257+
258+
def gen_test_fn(kernel_size=[2, 2, 2], stride=[], padding=[0]):
259+
return lambda: torch.nn.functional.avg_pool3d(a, kernel_size, stride, padding)
260+
261+
self.assertExpectedRaisesInline(
262+
exc_type=RuntimeError,
263+
callable=gen_test_fn(kernel_size=[2, 2]),
264+
expect="""avg_pool3d(): expected argument kernel_size [2, 2] (size: 2) to have size of 3."""
265+
)
266+
267+
self.assertExpectedRaisesInline(
268+
exc_type=RuntimeError,
269+
callable=gen_test_fn(stride=[1, 2]),
270+
expect="""avg_pool3d(): expected argument stride [1, 2] (size: 2) to have size of 3."""
271+
)
272+
273+
self.assertExpectedRaisesInline(
274+
exc_type=RuntimeError,
275+
callable=gen_test_fn(padding=[1, 2]),
276+
expect="""avg_pool3d(): expected argument padding [1, 2] (size: 2) to have size of 3."""
277+
)
278+
254279

255280
if __name__ == "__main__":
256281
unittest.main()

0 commit comments

Comments
 (0)