fix dygraph_grad_maker to support the situation where inplace var is leaf var (by using set_value method) #38014
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
PR types
Bug fixes
PR changes
Others
Describe
问题一:dygraph_grad_maker指定grad_pending_node的问题
stop_gradient=False的叶子节点本来是无法进行inplace操作的。但是set_value操作可以对stop_gradient=True的叶子节点进行inplace操作(set_value为inplace操作),且操作过程中,可以将set_value的tensor的stop_gradient变成False。等同于对stop_gradient=False的叶子节点进行了inplace操作,框架中没有相应的处理导致报错。
问题示例:
问题:输出a和b的grad时,没有结果输出来。梯度在之前被截断了。
结论:在
dygraph_grad_maker.cc
中为grad_node
设定grad_pending_node
时出错。模拟上面set_value的问题:
x = inplace_op(x, y)
,也就是对x
作inplace操作,且x.stop_gradient=True
,y.stop_gradient=False
。操作完后,x.stop_gradient=False
,y.stop_gradient=False
。为了容易区分,我们将做inplace的输入和输出
x
加标志为x1 = inplace_op(x0, y)
,其中x0
与x1
为同一个var。那么在使用
dygraph_grad_maker.cc
创建其对应的grad_op时,公式为grad_x0 = inplace_grad_op(x0, y, grad_x1)
,grad_y = inplace_grad_op(x0, y, grad_x1)
。其中,grad_y = inplace_grad_op(x0, y, grad_x1)
操作没有问题,grad_x0 = inplace_grad_op(x0, y, grad_x1)
有问题,分析如下。按照
dygraph_grad_maker.cc
中的逻辑,首先为inplace_grad_op
SetInput,在该过程中,会指定grad_x1->grad_node = inplace_grad_op
。然后为
inplace_grad_op
SetOutput,在该过程中,会为inplace_grad_op
指定grad_pending_node
,即inplace_grad_op->grad_pending_node = grad_x0->grad_node = grad_x1->grad_node = inplace_grad_op
(其中grad_x0
与grad_x1
为同一个var)。成环导致无法输出最终结果。实际上,在inplace操作中,我们已经为了防止反向网络成环,加入了
dirty_grad_node
来防止成环。但是,由于set_value
操作会让叶子节点做inplace操作,dirty_grad_node
不会对叶子节点生效,因此在这种特殊情况下,反向还是成环了。在
dirty_grad_node
处理不了的情况下,直接不允许grad_pending_node
与当前的grad_node
相同,就避免成环了。问题二:basic_engine中处理stop_gradient=False的叶子节点
由于以前默认stop_gradient=False的叶子节点无法做inplace操作,因此用于梯度累加的
accumulators_with_grad_node_
并没有考虑该情况。也就是说,原来的逻辑是:只要一个grad_var有grad_node,那么使用grad_var和对应的grad_pending_node作为key。对于stop_gradient=False的叶子inplace_grad_var,它有grad_node,所以本来应该使用
accumulators_with_grad_node_
来进行梯度累加的统计。但是,这个inplace_grad_var并不是其对应的grad_pending_node的输入,因为他是叶子节点,不是任何grad_op的输入。在框架原来的逻辑中,对于这种grad_var不是grad_pending_node的输入的情况,会直接报错。
现在不报错了,而是特殊处理这种情况。在这种情况下,会将该inplace_var作为叶子节点,使用
accumulators_
来统计梯度累加信息。