use 'paddle.framework.set_grad_enabled' in pylayer
hbwx24 committed Apr 23, 2021
1 parent 51bcd97 commit e9d935e
Showing 2 changed files with 40 additions and 5 deletions.
python/paddle/autograd/py_layer.py: 5 changes (3 additions, 2 deletions)
@@ -176,8 +176,9 @@ def backward(ctx, dy):

 class PyLayerBackward(PyLayerContext):
     def backward(self, *args, **kwargs):
-        with paddle.fluid.dygraph.no_grad():
-            return self._forward_cls.backward(*args, **kwargs)
+        with paddle.fluid.dygraph.guard():
+            with paddle.fluid.dygraph.no_grad():
+                return self._forward_cls.backward(*args, **kwargs)


 class LayerMeta(type):
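The guard added above means a user-defined backward now runs in dygraph mode with gradients disabled by default, and can opt back in through paddle.set_grad_enabled (the 'paddle.framework.set_grad_enabled' named in the commit title) to build and backpropagate a small graph of its own. Below is a minimal standalone sketch of that pattern, mirroring the test added in test_pylayer_op.py; the CusTanh name and the final print are illustrative and not part of this commit.

import paddle
from paddle.autograd import PyLayer


class CusTanh(PyLayer):
    @staticmethod
    def forward(ctx, x):
        # Stash a detached copy so backward can rebuild a small graph from it.
        ctx.inputs = x.detach()
        return x.mean()

    @staticmethod
    def backward(ctx, dy):
        # Gradients are disabled by default inside backward; re-enable them to
        # run a nested backward pass, then return the result as x's gradient.
        with paddle.set_grad_enabled(True):
            temp = ctx.inputs
            temp.stop_gradient = False
            z = paddle.tanh(temp)
            z.backward()
            return paddle.to_tensor(temp.grad)


data = paddle.ones([2, 3], dtype="float32")
data.stop_gradient = False
CusTanh.apply(data).backward()
print(data.grad)  # populated by the nested backward pass above

Anything outside the set_grad_enabled(True) block still executes under the no_grad guard introduced by this change.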
python/paddle/fluid/tests/unittests/test_pylayer_op.py: 40 changes (37 additions, 3 deletions)
@@ -283,20 +283,54 @@ def test_pylayer_inplace(self):
         class cus_tanh(PyLayer):
             @staticmethod
             def forward(ctx, x):
-                return x.mean()
+                return x

             @staticmethod
             def backward(ctx, dy):
                 return dy

+        class Layer(paddle.nn.Layer):
+            def __init__(self):
+                super(Layer, self).__init__()
+
+            def forward(self, data):
+                data = paddle.nn.functional.relu(data)
+                z = paddle.tanh(data)
+                z = cus_tanh.apply(data)
+                return z.mean()
+
+        for i in range(2):
+            data = paddle.ones([2, 3], dtype="float64") / (i + 1)
+            data.stop_gradient = False
+            layer = Layer()
+            z = layer(data)
+            z.backward()
+            self.assertTrue(data.grad is not None)
+
+    def test_backward_in_backward(self):
+        class cus_tanh(PyLayer):
+            @staticmethod
+            def forward(ctx, x):
+                temp = x.detach()
+                ctx.inputs = temp
+                return x.mean()
+
+            @staticmethod
+            def backward(ctx, dy):
+                with paddle.set_grad_enabled(True):
+                    temp = ctx.inputs
+                    temp.stop_gradient = False
+                    z = paddle.tanh(temp)
+                    z.backward()
+                    self.assertTrue(temp.grad is not None)
+                    return paddle.to_tensor(temp.grad)
+
         for i in range(2):
             data = paddle.ones([2, 3], dtype="float32") / (i + 1)
             data.stop_gradient = False
-            data = paddle.nn.functional.relu(data)
-            z = paddle.tanh(data)
             z = cus_tanh.apply(data)
             z.backward()
             self.assertTrue(data.grad is not None)


 if __name__ == '__main__':
