diff --git a/python/paddle/autograd/py_layer.py b/python/paddle/autograd/py_layer.py
index c093565dc92ff..35e2cd2439177 100644
--- a/python/paddle/autograd/py_layer.py
+++ b/python/paddle/autograd/py_layer.py
@@ -176,8 +176,9 @@ def backward(ctx, dy):
 
 class PyLayerBackward(PyLayerContext):
     def backward(self, *args, **kwargs):
-        with paddle.fluid.dygraph.no_grad():
-            return self._forward_cls.backward(*args, **kwargs)
+        with paddle.fluid.dygraph.guard():
+            with paddle.fluid.dygraph.no_grad():
+                return self._forward_cls.backward(*args, **kwargs)
 
 
 class LayerMeta(type):
diff --git a/python/paddle/fluid/tests/unittests/test_pylayer_op.py b/python/paddle/fluid/tests/unittests/test_pylayer_op.py
index 89f8330fe5ba4..72d8efc80a938 100644
--- a/python/paddle/fluid/tests/unittests/test_pylayer_op.py
+++ b/python/paddle/fluid/tests/unittests/test_pylayer_op.py
@@ -283,20 +283,54 @@ def test_pylayer_inplace(self):
         class cus_tanh(PyLayer):
             @staticmethod
             def forward(ctx, x):
-                return x.mean()
+                return x
 
             @staticmethod
             def backward(ctx, dy):
                 return dy
 
+        class Layer(paddle.nn.Layer):
+            def __init__(self):
+                super(Layer, self).__init__()
+
+            def forward(self, data):
+                data = paddle.nn.functional.relu(data)
+                z = paddle.tanh(data)
+                z = cus_tanh.apply(data)
+                return z.mean()
+
         for i in range(2):
             data = paddle.ones([2, 3], dtype="float64") / (i + 1)
             data.stop_gradient = False
+            layer = Layer()
+            z = layer(data)
+            z.backward()
+            self.assertTrue(data.grad is not None)
+
+    def test_backward_in_backward(self):
+        class cus_tanh(PyLayer):
+            @staticmethod
+            def forward(ctx, x):
+                temp = x.detach()
+                ctx.inputs = temp
+                return x.mean()
+
+            @staticmethod
+            def backward(ctx, dy):
+                with paddle.set_grad_enabled(True):
+                    temp = ctx.inputs
+                    temp.stop_gradient = False
+                    z = paddle.tanh(temp)
+                    z.backward()
+                    self.assertTrue(temp.grad is not None)
+                    return paddle.to_tensor(temp.grad)
+
+        for i in range(2):
+            data = paddle.ones([2, 3], dtype="float32") / (i + 1)
+            data.stop_gradient = False
             data = paddle.nn.functional.relu(data)
             z = paddle.tanh(data)
             z = cus_tanh.apply(data)
-            z.backward()
-            self.assertTrue(data.grad is not None)
 
 
 if __name__ == '__main__':
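
Note (not part of the diff): the pattern exercised by the new test_backward_in_backward case can be reduced to the standalone sketch below, assuming a Paddle 2.x build where paddle.autograd.PyLayer and paddle.set_grad_enabled are available. The class name cus_tanh mirrors the test; everything outside the diff above is illustration, not repository code.

import paddle
from paddle.autograd import PyLayer

class cus_tanh(PyLayer):
    @staticmethod
    def forward(ctx, x):
        # Save a detached copy so the inner backward pass below builds
        # its own graph instead of reusing the outer one.
        ctx.inputs = x.detach()
        return x.mean()

    @staticmethod
    def backward(ctx, dy):
        # PyLayerBackward invokes the user backward under no_grad(), so
        # grad must be re-enabled explicitly for the nested autograd pass.
        with paddle.set_grad_enabled(True):
            temp = ctx.inputs
            temp.stop_gradient = False
            z = paddle.tanh(temp)
            z.backward()
            # Return the inner gradient as the outer gradient of x.
            return paddle.to_tensor(temp.grad)

data = paddle.ones([2, 3], dtype="float32")
data.stop_gradient = False
out = cus_tanh.apply(data)
out.backward()
print(data.grad)

Without the extra paddle.fluid.dygraph.guard() added around the no_grad() block in PyLayerBackward.backward, the nested z.backward() call can fire outside a dygraph context; keeping the user-defined backward inside guard() appears to be what the first hunk fixes.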