From 29eaaa1237e58fa8539903c8642488bbcaeb0384 Mon Sep 17 00:00:00 2001 From: zoooo0820 Date: Wed, 15 Nov 2023 08:43:18 +0000 Subject: [PATCH] first fix the UT --- test/legacy_test/test_zero_dim_tensor.py | 2 ++ test/xpu/test_zero_dim_tensor_xpu.py | 1 + 2 files changed, 3 insertions(+) diff --git a/test/legacy_test/test_zero_dim_tensor.py b/test/legacy_test/test_zero_dim_tensor.py index 10954ed59c60d7..9fb80fae82c2da 100644 --- a/test/legacy_test/test_zero_dim_tensor.py +++ b/test/legacy_test/test_zero_dim_tensor.py @@ -830,6 +830,7 @@ def test_setitem(self): np.testing.assert_allclose(out[1, 2, 3, 4], np.array(10)) self.assertEqual(x.grad.shape, [2, 3, 4, 5]) x_grad_expected = np.ones((2, 3, 4, 5)) * 2 + x_grad_expected[1, 2, 3, 4] = 0 np.testing.assert_allclose(x.grad, x_grad_expected) # case2: 0-D Tensor indice in some axis @@ -847,6 +848,7 @@ def test_setitem(self): self.assertEqual(out.shape, x.shape) np.testing.assert_allclose(out[1, 1], np.ones((4, 5)) * 0.5) x_grad_expected = np.ones((2, 3, 4, 5)) + x_grad_expected[1, 1] = 0 np.testing.assert_allclose(x.grad, x_grad_expected) # case3:0-D Tensor indice in some axis, value is a Tensor diff --git a/test/xpu/test_zero_dim_tensor_xpu.py b/test/xpu/test_zero_dim_tensor_xpu.py index 08c3bc8a1814a4..a836e2e7fb58ed 100644 --- a/test/xpu/test_zero_dim_tensor_xpu.py +++ b/test/xpu/test_zero_dim_tensor_xpu.py @@ -441,6 +441,7 @@ def test_setitem(self): np.testing.assert_allclose(out[1, 2, 3, 4], np.array(10)) self.assertEqual(x.grad.shape, [2, 3, 4, 5]) x_grad_expected = np.ones((2, 3, 4, 5)) * 2 + x_grad_expected[1, 2, 3, 4] = 0 np.testing.assert_allclose(x.grad, x_grad_expected) # case2: 0-D Tensor indice in some axis