[PyLayer] pylayer add api #5148

Merged
merged 6 commits into from
Aug 22, 2022
156 changes: 156 additions & 0 deletions docs/api/paddle/autograd/PyLayerContext_cn.rst
@@ -107,3 +107,159 @@ saved_tensor(self, *tensors)
y, = ctx.saved_tensor()
grad = dy * (1 - paddle.square(y))
return grad


mark_not_inplace(self, *tensors)
'''''''''

Collaborator

1. Add a one-sentence description of what mark_not_inplace(self, *args) does, so it is easier for users to understand. For reference, see the one-sentence description of save_for_backward(self, *tensors) in this document:
[screenshot of the save_for_backward description]

2. For the note, consider polishing it along the lines of the blue-highlighted style in the screenshot, so users are clearer about this API's caveats.

Contributor Author

done,thx!

Marks some inputs as not to be modified inplace.
If an input and an output of ``forward`` are the same ``Tensor``, and that ``Tensor`` is marked as not_inplace, Paddle will create a new ``Tensor`` as the output on the user's behalf.
This prevents the auto grad information of the input ``Tensor`` from being incorrectly tampered with.

Collaborator

A description of the args parameter of mark_not_inplace(self, *args) needs to be added to help users understand it. See how this document describes the tensors parameter of save_for_backward(self, *tensors), as in the screenshot:

[screenshot of the save_for_backward parameter description]

Contributor Author

done,thx!

.. note::
    This method can be called at most once in ``forward``, and all of its arguments must be ``Tensor`` inputs of ``forward``.

**Parameters**

- **tensors** (list of Tensor) - the ``Tensor`` objects to be marked as not inplace

**Returns**

None

**Code Example**

.. code-block:: python

    import paddle

    class Exp(paddle.autograd.PyLayer):
        @staticmethod
        def forward(ctx, x):
            # The output aliases the input, so mark it as not inplace;
            # Paddle then returns a new Tensor as the output.
            ctx.mark_not_inplace(x)
            return x

        @staticmethod
        def backward(ctx, grad_output):
            out = grad_output.exp()
            return out

    x = paddle.randn((1, 1))
    x.stop_gradient = False
    attn_layers = []
    for idx in range(0, 2):
        attn_layers.append(Exp())

    # Apply the layers repeatedly; each call returns a fresh output Tensor.
    for step in range(0, 2):
        a = x
        for j in range(0, 2):
            a = attn_layers[j].apply(x)
        a.backward()
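
For a more minimal illustration of the behavior described above (the ``Identity`` layer and the final gradient check are illustrative assumptions, a sketch rather than part of the example), marking the aliased input as not inplace still lets the gradient flow back to ``x`` through the returned output:

.. code-block:: python

    import paddle

    class Identity(paddle.autograd.PyLayer):
        @staticmethod
        def forward(ctx, x):
            # The output is the input itself; mark it as not inplace so
            # Paddle hands back a new Tensor as the output.
            ctx.mark_not_inplace(x)
            return x

        @staticmethod
        def backward(ctx, grad_output):
            return grad_output

    x = paddle.ones([2, 2])
    x.stop_gradient = False
    y = Identity.apply(x)
    # Backpropagating through the fresh output still reaches ``x``.
    y.sum().backward()
    print(x.grad)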


mark_non_differentiable(self, *tensors)
'''''''''

Collaborator

1. Add a one-sentence description of what mark_non_differentiable(self, *args) does, so it is easier for users to understand. For reference, see the one-sentence description of save_for_backward(self, *tensors) in this document:
[screenshot of the save_for_backward description]

2. For the note, consider polishing it along the lines of the blue-highlighted style in the screenshot, so users are clearer about this API's caveats.

Contributor Author

done,thx!

Marks some outputs as not requiring gradients (non-differentiable).
Marking a ``Tensor`` that does not need a gradient as non-differentiable can improve backward performance. However, you still have to leave a slot for its gradient among the input arguments of the ``backward`` function;
that incoming gradient will simply be an all-zero ``Tensor`` whose shape matches the corresponding ``forward`` output.

Collaborator

mark_non_differentiable(self, *args) needs a description of the args parameter added to help users understand it. Same as above.

Contributor Author

done,thx!

.. note::
    This method can be called at most once in ``forward``, and all of its arguments must be ``Tensor`` outputs of ``forward``.

**Parameters**

- **tensors** (list of Tensor) - the ``Tensor`` objects to be marked as non-differentiable


**Returns**

None

**Code Example**

.. code-block:: python

    import os
    os.environ['FLAGS_enable_eager_mode'] = '1'
    import paddle
    from paddle.autograd import PyLayer
    import numpy as np

    class Tanh(PyLayer):
        @staticmethod
        def forward(ctx, x):
            a = x + x
            b = x + x + x
            # ``a`` does not need a gradient; in backward it receives an
            # all-zero Tensor of the same shape instead.
            ctx.mark_non_differentiable(a)
            return a, b

        @staticmethod
        def backward(ctx, grad_a, grad_b):
            assert np.equal(grad_a.numpy(), paddle.zeros([1]).numpy())
            assert np.equal(grad_b.numpy(), paddle.ones([1], dtype="float64").numpy())
            return grad_b

    x = paddle.ones([1], dtype="float64")
    x.stop_gradient = False
    a, b = Tanh.apply(x)
    b.sum().backward()
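
As a further sketch of the zero-gradient behavior described above (the ``Split`` layer and its output shapes are illustrative assumptions, not taken from the example), the gradient delivered for a non-differentiable output is expected to match that output's shape:

.. code-block:: python

    import os
    os.environ['FLAGS_enable_eager_mode'] = '1'
    import paddle
    from paddle.autograd import PyLayer

    class Split(PyLayer):
        @staticmethod
        def forward(ctx, x):
            a = paddle.concat([x, x])   # shape [2], will not need a gradient
            b = x * 3                   # shape [1]
            ctx.mark_non_differentiable(a)
            return a, b

        @staticmethod
        def backward(ctx, grad_a, grad_b):
            # grad_a should be all zeros with shape [2] (matching ``a``),
            # while grad_b carries the real gradient with shape [1].
            print(grad_a.shape, grad_b.shape)
            return grad_b

    x = paddle.ones([1], dtype="float64")
    x.stop_gradient = False
    a, b = Split.apply(x)
    b.sum().backward()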

set_materialize_grads(self, value)
'''''''''

Collaborator

1. Add a one-sentence description of what set_materialize_grads(self, value: bool) does, so it is easier for users to understand. Same as above.
2. For the note, consider polishing it along the lines of the blue-highlighted style in the screenshot, so users are clearer about this API's caveats.
3. In "value: bool)", please mind the writing convention.

Contributor Author

done,thx!

Sets whether the framework should materialize gradients that have not been initialized. The default is True.
If set to True, the framework initializes any uninitialized incoming gradient to 0 before calling the ``backward`` function.
If set to False, the framework passes uninitialized incoming gradients to the ``backward`` function as None.

Collaborator

1. Could "value: bool" be called out explicitly? It likewise needs its own parameter description; see the screenshot:
[screenshot of the parameter description]
2. Also note that it is of type bool.

Contributor Author

done,thx!

.. note::
    This method can only be called inside ``forward``.

**Parameters**

- **value** (bool) - whether the framework should initialize uninitialized incoming gradients


**Returns**

None

**Code Example**

.. code-block:: python

    import os
    os.environ['FLAGS_enable_eager_mode'] = '1'
    import paddle
    from paddle.autograd import PyLayer
    import numpy as np

    class Tanh(PyLayer):
        @staticmethod
        def forward(ctx, x):
            return x + x + x, x + x

        @staticmethod
        def backward(ctx, grad, grad2):
            # materialize_grads defaults to True, so the gradient for the
            # unused second output arrives as an all-zero Tensor.
            assert np.equal(grad2.numpy(), paddle.zeros([1]).numpy())
            return grad

    class Tanh2(PyLayer):
        @staticmethod
        def forward(ctx, x):
            ctx.set_materialize_grads(False)
            return x + x + x, x + x

        @staticmethod
        def backward(ctx, grad, grad2):
            # With materialize_grads disabled, the unused gradient is None.
            assert grad2 is None
            return grad

    x = paddle.ones([1], dtype="float64")
    x.stop_gradient = False
    Tanh.apply(x)[0].backward()

    x2 = paddle.ones([1], dtype="float64")
    x2.stop_gradient = False
    Tanh2.apply(x2)[0].backward()