fix dygraph_grad_maker to support the situation where inplace var is leaf var (by using set_value method) #38014

Merged
merged 1 commit into from
Dec 10, 2021

@pangyoki pangyoki commented Dec 9, 2021

PR types

Bug fixes

PR changes

Others

Describe

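Problem 1: dygraph_grad_maker assigns the wrong grad_pending_node

A leaf node with stop_gradient=False normally cannot be modified in place. However, set_value (itself an inplace operation) can be applied to a leaf node with stop_gradient=True, and during the operation it can change the target tensor's stop_gradient to False. The net effect is an inplace operation on a leaf node with stop_gradient=False. The framework had no handling for this case, which led to an error.

Example of the problem: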

import paddle

a = paddle.rand(shape=[1, 4])
b = paddle.rand(shape=[1, 4])
a.stop_gradient = False
b.stop_gradient = False

d = paddle.zeros((4, 4))  # d is a leaf node with stop_gradient=True
print(d.stop_gradient)

d[0, :] = a / b  # a and b have stop_gradient=False, so after set_value d.stop_gradient becomes False
d[0, :].sum().backward()

print(a.grad)
# None
print(b.grad)
# None

Problem: printing the grad of a and b yields None. The gradient was cut off earlier in the backward pass.

  • Root cause

Conclusion: the error occurs in dygraph_grad_maker.cc when grad_pending_node is set for a grad_node.

To model the set_value problem above, consider x = inplace_op(x, y), an inplace operation on x with x.stop_gradient=True and y.stop_gradient=False beforehand. After the operation, x.stop_gradient=False and y.stop_gradient=False.

To tell the two apart, label the input and output of the inplace operation as x1 = inplace_op(x0, y), where x0 and x1 are the same var.
When dygraph_grad_maker.cc creates the corresponding grad_op, the formulas are grad_x0 = inplace_grad_op(x0, y, grad_x1) and grad_y = inplace_grad_op(x0, y, grad_x1). The grad_y computation is fine. The grad_x0 computation is the problem, as analyzed below.

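Following the logic in dygraph_grad_maker.cc, SetInput runs first for inplace_grad_op, and in that step it sets grad_x1->grad_node = inplace_grad_op.
SetOutput then runs for inplace_grad_op, and in that step it assigns the grad_pending_node of inplace_grad_op, that is, inplace_grad_op->grad_pending_node = grad_x0->grad_node = grad_x1->grad_node = inplace_grad_op (grad_x0 and grad_x1 are the same var). This forms a cycle, so the final result is never produced.

In fact, inplace operations already use dirty_grad_node to prevent cycles in the backward graph. However, set_value lets a leaf node undergo an inplace operation, and dirty_grad_node does not take effect for leaf nodes, so in this special case the backward graph still contains a cycle.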
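
To illustrate why a self-referencing grad_pending_node can cut off the gradient, here is a minimal sketch in plain Python. It is a toy model, not Paddle code: it assumes the backward pass runs a node only after all nodes pointing at it have run (dependency counting), and the names `runnable_order`, `start`, and `upstream_of_y` are hypothetical.

```python
from collections import deque


def runnable_order(edges, start):
    """Return the nodes that actually run, in order.

    edges maps a node name to the list of its pending nodes.
    A node runs only once every node that points at it has run.
    """
    deps = {}
    seen = set()
    stack = [start]
    # Count incoming edges for every node reachable from start.
    while stack:
        node = stack.pop()
        if node in seen:
            continue
        seen.add(node)
        deps.setdefault(node, 0)
        for pending in edges.get(node, []):
            deps[pending] = deps.get(pending, 0) + 1
            stack.append(pending)

    queue = deque(n for n in seen if deps[n] == 0)
    order = []
    while queue:
        node = queue.popleft()
        order.append(node)
        for pending in edges.get(node, []):
            deps[pending] -= 1
            if deps[pending] == 0:
                queue.append(pending)
    return order


# Acyclic graph: every node runs.
acyclic = {"start": ["inplace_grad_op"], "inplace_grad_op": ["upstream_of_y"]}

# Same graph, but inplace_grad_op also lists itself as a pending node.
cyclic = {
    "start": ["inplace_grad_op"],
    "inplace_grad_op": ["inplace_grad_op", "upstream_of_y"],
}

print(runnable_order(acyclic, "start"))
print(runnable_order(cyclic, "start"))
```

In this toy model the self-edge keeps the dependency count of `inplace_grad_op` above zero, so neither it nor anything behind it runs, which would be consistent with a.grad and b.grad staying None in the example.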

  • Fix

Where dirty_grad_node cannot handle the case, grad_pending_node is simply not allowed to be the same as the current grad_node, which avoids the cycle.
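
The SetInput / SetOutput sequence and the guard described above can be sketched in plain Python. This is a toy model under stated assumptions, not the code in dygraph_grad_maker.cc: the classes `GradNode` and `GradVar`, the function `make_inplace_grad_op`, and the flag `skip_self_loop` are all hypothetical names used only for illustration.

```python
class GradNode:
    """Toy stand-in for a backward op node (hypothetical, not Paddle's class)."""

    def __init__(self, name):
        self.name = name
        self.pending = []  # toy version of grad_pending_node entries


class GradVar:
    """Toy stand-in for a gradient variable."""

    def __init__(self, name):
        self.name = name
        self.grad_node = None


def make_inplace_grad_op(grad_x, skip_self_loop):
    """Build a toy inplace_grad_op whose input grad (grad_x1) and
    output grad (grad_x0) are the same object, grad_x."""
    node = GradNode("inplace_grad_op")

    # SetInput step: grad_x1->grad_node = inplace_grad_op
    grad_x.grad_node = node

    # SetOutput step: the pending node comes from grad_x0->grad_node,
    # but grad_x0 is grad_x1, so the candidate is the node itself.
    candidate = grad_x.grad_node
    is_self = candidate is node
    if candidate is not None and not (skip_self_loop and is_self):
        node.pending.append(candidate)
    return node


def has_self_loop(node):
    """True if the node lists itself as a pending node."""
    return any(p is node for p in node.pending)


buggy = make_inplace_grad_op(GradVar("grad_x"), skip_self_loop=False)
fixed = make_inplace_grad_op(GradVar("grad_x"), skip_self_loop=True)
print(has_self_loop(buggy), has_self_loop(fixed))
```

The guard only rejects the case where the candidate pending node is the current node. Any other pending node is still recorded, so ordinary backward edges are unaffected in this model.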

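Problem 2: handling leaf nodes with stop_gradient=False in basic_engine

Because a leaf node with stop_gradient=False was previously assumed never to undergo an inplace operation, accumulators_with_grad_node_, which is used for gradient accumulation, did not account for this case.

In other words, the original logic was: whenever a grad_var has a grad_node, use the grad_var and its corresponding grad_pending_node as the key. A leaf inplace_grad_var with stop_gradient=False has a grad_node, so it would normally be tracked in accumulators_with_grad_node_. But this inplace_grad_var is not an input of its grad_pending_node, because it is a leaf node and is not an input of any grad_op.

In the framework's original logic, a grad_var that is not an input of its grad_pending_node raised an error directly.

This case no longer raises an error and is handled specially instead: the inplace_var is treated as a leaf node, and accumulators_ is used to track its gradient accumulation.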
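
The decision described above can be sketched as a small plain-Python function. It is only a model of the description, not the code in basic_engine: the function `pick_accumulator`, its parameters, and the `legacy` flag are hypothetical, and the returned strings simply name the two tables mentioned in this PR.

```python
def pick_accumulator(grad_var, has_grad_node, pending_node_inputs, legacy=False):
    """Return the name of the table that would track accumulation for grad_var.

    grad_var: name of the gradient variable (toy representation).
    has_grad_node: whether the grad_var has a grad_node.
    pending_node_inputs: names of the inputs of its grad_pending_node.
    legacy: if True, model the old behavior, which raised an error.
    """
    if not has_grad_node:
        # Ordinary leaf: tracked in accumulators_.
        return "accumulators_"
    if grad_var in pending_node_inputs:
        # Ordinary non-leaf: keyed by grad_pending_node and grad_var.
        return "accumulators_with_grad_node_"
    if legacy:
        # Old behavior as described: this combination was an error.
        raise RuntimeError("grad_var is not an input of its grad_pending_node")
    # New behavior as described: treat the inplace var as a leaf.
    return "accumulators_"


# A leaf inplace grad var: it has a grad_node but feeds no grad_op.
print(pick_accumulator("grad_d", True, {"grad_other"}))
```

Falling back to the leaf table keeps the special case out of the keyed lookup, which matches the description that the inplace_var is simply treated as a leaf node.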

paddle-bot-old bot commented Dec 9, 2021

✅ This PR's description meets the template requirements!
Please wait for other CI results.

paddle-bot-old bot commented Dec 9, 2021

Thanks for your contribution!
Please wait for the result of CI firstly. See Paddle CI Manual for details.

@zyfncg zyfncg left a comment

LGTM

@zhiqiu zhiqiu left a comment

LGTM

@pangyoki pangyoki changed the title fix dygraph_grad_maker to support inplace var as leaf var using set_value method fix dygraph_grad_maker to support the situation where inplace var is leaf var (by using set_value method) Dec 10, 2021
@pangyoki pangyoki merged commit dabf815 into PaddlePaddle:develop Dec 10, 2021