Skip to content

Commit

Permalink
first fix the UT
Browse files Browse the repository at this point in the history
  • Loading branch information
zoooo0820 committed Nov 15, 2023
1 parent 341afb7 commit 29eaaa1
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 0 deletions.
2 changes: 2 additions & 0 deletions test/legacy_test/test_zero_dim_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -830,6 +830,7 @@ def test_setitem(self):
np.testing.assert_allclose(out[1, 2, 3, 4], np.array(10))
self.assertEqual(x.grad.shape, [2, 3, 4, 5])
x_grad_expected = np.ones((2, 3, 4, 5)) * 2
x_grad_expected[1, 2, 3, 4] = 0
np.testing.assert_allclose(x.grad, x_grad_expected)

# case2: 0-D Tensor indice in some axis
Expand All @@ -847,6 +848,7 @@ def test_setitem(self):
self.assertEqual(out.shape, x.shape)
np.testing.assert_allclose(out[1, 1], np.ones((4, 5)) * 0.5)
x_grad_expected = np.ones((2, 3, 4, 5))
x_grad_expected[1, 1] = 0
np.testing.assert_allclose(x.grad, x_grad_expected)

# case3:0-D Tensor indice in some axis, value is a Tensor
Expand Down
1 change: 1 addition & 0 deletions test/xpu/test_zero_dim_tensor_xpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -441,6 +441,7 @@ def test_setitem(self):
np.testing.assert_allclose(out[1, 2, 3, 4], np.array(10))
self.assertEqual(x.grad.shape, [2, 3, 4, 5])
x_grad_expected = np.ones((2, 3, 4, 5)) * 2
x_grad_expected[1, 2, 3, 4] = 0
np.testing.assert_allclose(x.grad, x_grad_expected)

# case2: 0-D Tensor indice in some axis
Expand Down

0 comments on commit 29eaaa1

Please sign in to comment.