Skip to content

Commit eb30b4f

Browse files
committed
Add test.
1 parent 12f35bf commit eb30b4f

File tree

1 file changed

+25
-0
lines changed

1 file changed

+25
-0
lines changed

test/test_ops_error_message.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -180,6 +180,31 @@ def test():
180180
expect="""mm(): cannot matrix-multiply tensors f32[2,5] and f32[8,2]. Expected the size of dimension 1 of the first input tensor (5) to be equal the size of dimension 0 of the second input tensor (8)."""
181181
)
182182

183+
def test_avg_pool_3d_raises_error_on_bad_spec(self):
184+
device = torch_xla.device()
185+
a = torch.rand(1, 1, 4, 4, 4, device=device)
186+
187+
def gen_test_fn(kernel_size=[2, 2, 2], stride=[], padding=[0]):
188+
return lambda: torch.nn.functional.avg_pool3d(a, kernel_size, stride, padding)
189+
190+
self.assertExpectedRaisesInline(
191+
exc_type=RuntimeError,
192+
callable=gen_test_fn(kernel_size=[2, 2]),
193+
expect="""avg_pool3d(): expected argument kernel_size [2, 2] (size: 2) to have size of 3."""
194+
)
195+
196+
self.assertExpectedRaisesInline(
197+
exc_type=RuntimeError,
198+
callable=gen_test_fn(stride=[1, 2]),
199+
expect="""avg_pool3d(): expected argument stride [1, 2] (size: 2) to have size of 3."""
200+
)
201+
202+
self.assertExpectedRaisesInline(
203+
exc_type=RuntimeError,
204+
callable=gen_test_fn(padding=[1, 2]),
205+
expect="""avg_pool3d(): expected argument padding [1, 2] (size: 2) to have size of 3."""
206+
)
207+
183208

184209
if __name__ == "__main__":
185210
unittest.main()

0 commit comments

Comments
 (0)