【Eager】Allow return none when stop_gradient=False (#51740)
* allow return none when stop_gradient=True

* remove useless code

* refine code

* refine code

* fix test case

* change more test

* add more tests
JiabinYang authored Mar 22, 2023
1 parent 8c61a95 commit db59925
Showing 4 changed files with 60 additions and 18 deletions.
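
At the Python level, the change is that `PyLayer.backward` may now return `None` for a gradient slot even when the matching input has `stop_gradient=False`; previously the eager engine raised `ValueError` (see the updated tests below). A minimal sketch of the newly allowed pattern (class and variable names are illustrative, not part of the commit):

import paddle
from paddle.autograd import PyLayer

class Double(PyLayer):
    @staticmethod
    def forward(ctx, x):
        return x * 2

    @staticmethod
    def backward(ctx, dy):
        # Opting out of a gradient: None is now accepted here,
        # and the input's .grad simply stays None.
        return None

x = paddle.randn([2, 3])
x.stop_gradient = False
z = Double.apply(x)
z.sum().backward()  # no longer raises ValueError
assert x.grad is None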
16 changes: 14 additions & 2 deletions paddle/fluid/eager/pylayer/py_layer_node.cc
@@ -155,9 +155,21 @@ GradNodePyLayer::operator()(
       grad_out.push_back({});
     } else {
       if (ctx->forward_input_tensor_is_duplicable[i]) {
-        grad_out.push_back(paddle::pybind::GetTensorListFromPyObject(obj));
+        grad_out.push_back(
+            paddle::pybind::GetTensorListFromPyObject(obj, true));
       } else {
-        grad_out.push_back({paddle::pybind::GetTensorFromPyObject(obj)});
+        if (paddle::pybind::PyCheckTensor(obj)) {
+          grad_out.push_back(
+              {paddle::pybind::UnSafeGetTensorFromPyObject(obj)});
+        } else if (obj == Py_None) {
+          VLOG(4) << "Got None for Tensor with pos: " << i;
+          grad_out.push_back({});
+        } else {
+          PADDLE_THROW(phi::errors::InvalidArgument(
+              "argument must be "
+              "Tensor or None, but got %s",
+              reinterpret_cast<PyTypeObject*>(obj->ob_type)->tp_name));
+        }
       }
     }
   } else {
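The new branch makes the non-duplicable path explicit: each object returned from `backward` must be a `Tensor` or `None`, and anything else hits the `InvalidArgument` throw added above. A hedged sketch of the rejected case, assuming the offending object reaches the eager grad node unchanged (names are illustrative):

import paddle
from paddle.autograd import PyLayer

class BadGrad(PyLayer):
    @staticmethod
    def forward(ctx, x):
        return x + 1

    @staticmethod
    def backward(ctx, dy):
        # Neither Tensor nor None: expected to trigger the new
        # "argument must be Tensor or None, but got ..." error.
        return "not-a-tensor"

x = paddle.randn([2, 3])
x.stop_gradient = False
z = BadGrad.apply(x)
try:
    z.sum().backward()
except Exception as err:
    print(type(err).__name__, err)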
18 changes: 9 additions & 9 deletions paddle/fluid/pybind/eager_utils.cc
@@ -1219,7 +1219,8 @@ std::vector<paddle::Tensor*> GetTensorPtrListFromPyObject(PyObject* obj) {
   return result;
 }
 
-std::vector<paddle::Tensor> GetTensorListFromPyObject(PyObject* obj) {
+std::vector<paddle::Tensor> GetTensorListFromPyObject(PyObject* obj,
+                                                      bool allow_none) {
   std::vector<paddle::Tensor> result;
   if (PyList_Check(obj)) {
     Py_ssize_t len = PyList_Size(obj);
@@ -1229,6 +1230,9 @@ std::vector<paddle::Tensor> GetTensorListFromPyObject(PyObject* obj) {
       if (PyObject_IsInstance(item,
                               reinterpret_cast<PyObject*>(p_tensor_type))) {
         result.emplace_back(reinterpret_cast<TensorObject*>(item)->tensor);
+      } else if (allow_none && (item == Py_None)) {
+        VLOG(4) << "Got None in Tensor list: " << i;
+        result.emplace_back();
       } else {
         PADDLE_THROW(platform::errors::InvalidArgument(
             "argument must be "
@@ -1245,6 +1249,9 @@ std::vector<paddle::Tensor> GetTensorListFromPyObject(PyObject* obj) {
       if (PyObject_IsInstance(item,
                               reinterpret_cast<PyObject*>(p_tensor_type))) {
         result.emplace_back(reinterpret_cast<TensorObject*>(item)->tensor);
+      } else if (allow_none && (item == Py_None)) {
+        VLOG(4) << "Got None in Tensor list: " << i;
+        result.emplace_back();
       } else {
         PADDLE_THROW(platform::errors::InvalidArgument(
             "argument must be "
@@ -1262,16 +1269,9 @@ std::vector<paddle::Tensor> GetTensorListFromPyObject(PyObject* obj) {
   return result;
 }
 
-paddle::Tensor& GetTensorFromPyObject(PyObject* obj) {
-  if (!PyCheckTensor(obj)) {
-    PADDLE_THROW(platform::errors::InvalidArgument(
-        "argument must be "
-        "Tensor, but got %s",
-        reinterpret_cast<PyTypeObject*>(obj->ob_type)->tp_name));
-  }
+paddle::Tensor& UnSafeGetTensorFromPyObject(PyObject* obj) {
   return reinterpret_cast<TensorObject*>(obj)->tensor;
 }
 
 paddle::experimental::Scalar CastNumpy2Scalar(PyObject* obj,
                                               const std::string& op_type,
                                               ssize_t arg_pos) {
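For duplicable (list) inputs, the same relaxation flows through `GetTensorListFromPyObject(obj, true)`: with `allow_none` set, a `None` entry inside the returned gradient list becomes an empty slot instead of an error. A sketch under that assumption (layer and variable names are illustrative):

import paddle
from paddle.autograd import PyLayer

class ConcatLayer(PyLayer):
    @staticmethod
    def forward(ctx, xs):
        return paddle.concat(xs)

    @staticmethod
    def backward(ctx, dy):
        g1, g2, g3 = paddle.split(dy, 3)
        # A None entry in the gradient list is now tolerated; the
        # C++ side records an empty slot for that position.
        return [g1, None, g3]

xs = [paddle.randn([2, 3]) for _ in range(3)]
for x in xs:
    x.stop_gradient = False
z = ConcatLayer.apply(xs)
z.mean().backward()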
6 changes: 3 additions & 3 deletions paddle/fluid/pybind/eager_utils.h
@@ -327,9 +327,9 @@ std::vector<paddle::Tensor*> GetTensorPtrListFromArgs(

 std::vector<paddle::Tensor*> GetTensorPtrListFromPyObject(PyObject* obj);
 
-std::vector<paddle::Tensor> GetTensorListFromPyObject(PyObject* obj);
-
-paddle::Tensor& GetTensorFromPyObject(PyObject* obj);
+std::vector<paddle::Tensor> GetTensorListFromPyObject(PyObject* obj,
+                                                      bool allow_none = false);
+paddle::Tensor& UnSafeGetTensorFromPyObject(PyObject* obj);
 
 // end of Slice related methods
 
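Note on the header change: defaulting `allow_none` to `false` keeps every existing caller of `GetTensorListFromPyObject` source-compatible, and only the PyLayer grad node opts in by passing `true`. Likewise, the rename to `UnSafeGetTensorFromPyObject` signals that the function no longer type-checks its argument; the check (`PyCheckTensor`) now lives at the call site in py_layer_node.cc.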
38 changes: 34 additions & 4 deletions python/paddle/fluid/tests/unittests/test_pylayer_op.py
@@ -122,6 +122,36 @@ def backward(ctx, dy1):
             np.max(np.abs((input1.grad.numpy() - input2.grad.numpy()))) < 1e-10
         )
 
+    def test_simple_pylayer_multi_output(self):
+        class tanh(PyLayer):
+            @staticmethod
+            def forward(ctx, x1, func1, func2=paddle.split):
+                ctx.func = func2
+                y1 = func1(x1)
+                ctx.save_for_backward(y1)
+                return y1
+
+            @staticmethod
+            def backward(ctx, dy1):
+                (y1,) = ctx.saved_tensor()
+                re1 = ctx.func(dy1, 3)
+                return re1
+
+        input1 = paddle.randn([2, 3]).astype("float64")
+        input2 = paddle.randn([2, 3]).astype("float64")
+        input3 = paddle.randn([2, 3]).astype("float64")
+        input1.stop_gradient = False
+        input2.stop_gradient = False
+        input3.stop_gradient = False
+        z = tanh.apply(x1=[input1, input2, input3], func1=paddle.concat)
+        z.mean().backward()
+        z2 = paddle.concat([input1, input2, input3])
+        z2.mean().backward()
+
+        self.assertTrue(
+            np.max(np.abs((input1.grad.numpy() - input2.grad.numpy()))) < 1e-10
+        )
+
     def test_pylayer_num_output_match(self):
         class tanh(PyLayer):
             @staticmethod
@@ -269,8 +299,8 @@ def backward(ctx, dy1):
         input2.stop_gradient = False
         z = Layer_bk_none1.apply(input2)
 
-        with self.assertRaises(ValueError):
-            z.sum().backward()
+        z.sum().backward()
+        self.assertEqual(input2.grad, None)
 
         class Layer_bk_none2(PyLayer):
             @staticmethod
@@ -285,8 +315,8 @@ def backward(ctx, dy1):
         input1.stop_gradient = False
         z = Layer_bk_none2.apply(input1, input1)
 
-        with self.assertRaises(ValueError):
-            z.mean().backward()
+        z.mean().backward()
+        self.assertIsNone(z.grad)
 
         class Layer_bk_one1(PyLayer):
             @staticmethod
