Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support backward during backward. #32355

Merged
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions python/paddle/autograd/py_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,8 +176,9 @@ def backward(ctx, dy):

class PyLayerBackward(PyLayerContext):
def backward(self, *args, **kwargs):
with paddle.fluid.dygraph.no_grad():
return self._forward_cls.backward(*args, **kwargs)
with paddle.fluid.dygraph.guard():
with paddle.fluid.dygraph.no_grad():
return self._forward_cls.backward(*args, **kwargs)


class LayerMeta(type):
Expand Down
40 changes: 37 additions & 3 deletions python/paddle/fluid/tests/unittests/test_pylayer_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,20 +283,54 @@ def test_pylayer_inplace(self):
class cus_tanh(PyLayer):
@staticmethod
def forward(ctx, x):
return x.mean()
return x

@staticmethod
def backward(ctx, dy):
return dy

class Layer(paddle.nn.Layer):
def __init__(self):
super(Layer, self).__init__()

def forward(self, data):
data = paddle.nn.functional.relu(data)
z = paddle.tanh(data)
z = cus_tanh.apply(data)
return z.mean()

for i in range(2):
data = paddle.ones([2, 3], dtype="float64") / (i + 1)
data.stop_gradient = False
layer = Layer()
z = layer(data)
z.backward()
self.assertTrue(data.grad is not None)

def test_backward_in_backward(self):
class cus_tanh(PyLayer):
@staticmethod
def forward(ctx, x):
temp = x.detach()
ctx.inputs = temp
return x.mean()

@staticmethod
def backward(ctx, dy):
with paddle.set_grad_enabled(True):
temp = ctx.inputs
temp.stop_gradient = False
z = paddle.tanh(temp)
z.backward()
self.assertTrue(temp.grad is not None)
return paddle.to_tensor(temp.grad)

for i in range(2):
data = paddle.ones([2, 3], dtype="float32") / (i + 1)
data.stop_gradient = False
data = paddle.nn.functional.relu(data)
z = paddle.tanh(data)
z = cus_tanh.apply(data)
z.backward()
self.assertTrue(data.grad is not None)


if __name__ == '__main__':
Expand Down