Skip to content

Commit

Permalink
update test_uniform_random_inplace_op.py (#44852)
Browse files Browse the repository at this point in the history
  • Loading branch information
wuyefeilin authored Aug 3, 2022
1 parent a0bf44f commit 9a17f05
Showing 1 changed file with 5 additions and 26 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ class TestUniformRandomInplaceGrad(unittest.TestCase):
def setUp(self):
self.shape = (1000, 784)

def test_uniform_random_inplace_grad(self):
def run_(self):
fluid.set_flags({"FLAGS_retain_grad_for_all_tensor": True})

def test_grad():
Expand All @@ -191,33 +191,12 @@ def test_grad():
test_grad()
fluid.set_flags({"FLAGS_retain_grad_for_all_tensor": False})


class TestUniformRandomInplaceGradOldDygraph(unittest.TestCase):

def setUp(self):
self.shape = (1000, 784)

def test_uniform_random_inplace_grad(self):
_enable_legacy_dygraph()
fluid.set_flags({"FLAGS_retain_grad_for_all_tensor": True})

def test_grad():
tensor_a = paddle.ones(self.shape)
tensor_a.stop_gradient = False
tensor_b = tensor_a * 0.5
tensor_b.uniform_(min=-2, max=2)
loss = tensor_b.sum()
loss.backward()
uniform_grad = tensor_b.grad.numpy()
self.assertTrue((uniform_grad == 0).all())
self.run_()

places = ['cpu']
if fluid.core.is_compiled_with_cuda():
places.append('gpu')
for place in places:
paddle.set_device(place)
test_grad()
fluid.set_flags({"FLAGS_retain_grad_for_all_tensor": False})
def test_uniform_random_inplace_grad_old_dygraph(self):
_enable_legacy_dygraph()
self.run_()
_disable_legacy_dygraph()


Expand Down

0 comments on commit 9a17f05

Please sign in to comment.