
[Dy2St] Fix cond_block_grad error when handling no-need-grad vars #43034

Merged
merged 2 commits into PaddlePaddle:develop from 0x45f:dy2st-fix-sub-block-grad-var
May 30, 2022

Conversation

0x45f
Contributor

@0x45f 0x45f commented May 26, 2022

PR types

Bug fixes

PR changes

Others

Describe

This problem was reported on the ppyoloe model in the model suite; after some investigation it was reproduced with the demo below.

Repro case:

import paddle

class Net(paddle.nn.Layer):
    def __init__(self):
        super(Net, self).__init__()
        self.param = self.create_parameter(
            shape=[3, 2],
            dtype='float32',
            is_bias=False)

    @paddle.jit.to_static
    def forward(self, a, b, c):
        a = paddle.matmul(a, self.param)  # (4, 3) x (3, 2) -> (4, 2)
        cond = paddle.to_tensor([10])
        a = paddle.reshape(a, (2, 4))
        if cond == 10:
            # argmax output has stop_gradient=True, so no gradient needs
            # to flow back through `a` inside this branch
            tmp = a.argmax(axis=-1)
            b = b + self.param
        else:
            print(c)
        return b

a = paddle.randn((4, 3), dtype='float32')
a.stop_gradient = False
b = paddle.to_tensor([10])
b.stop_gradient = False
c = paddle.to_tensor([2])
c.stop_gradient = False

net = Net()
net.train()
out = net(a, b, c)
out.backward()
print(out)

Error analysis

Error message:
(screenshot of the error log)
Printing the program shows that the uninitialized var in the error message is reshape2_0.tmp_0@GRAD; the later reshape2_grad op then reports that this var has never been initialized. This var is supposed to be produced by conditional_block_grad, whose corresponding sub_block is block 4, but block 4 never computes reshape2_0.tmp_0@GRAD:

{Input@GRAD=['net_0.w_0@GRAD@RENAME@block0@0', 'reshape2_0.tmp_0@GRAD', 'generated_tensor_0@GRAD@RENAME@block0@1']} = conditional_block_grad(inputs={Cond=['tmp_1'], Input=['net_0.w_0', 'reshape2_0.tmp_0', 'generated_tensor_0'], Out=['_generated_var_0'], Out@GRAD=['_generated_var_0@GRAD'], Scope=['_generated_var_1']}, is_scalar_condition = True, op_device = , op_role = 1, sub_block = block[4])

(screenshot of the relevant program blocks)
Why does block 0 contain the var reshape2_0.tmp_0@GRAD while the backward block 4 never computes it? During backward construction, _append_backward_ops_ computes sub_block_path, the path of ops inside the sub_block that require gradients. In block 1 (the forward sub-block) we can see that reshape2_0.tmp_0 is consumed by arg_max, and arg_max's output var argmax_0.tmp_0 has stop_gradient=True. So no gradient needs to flow from argmax_0.tmp_0 back to reshape2_0.tmp_0, arg_max is excluded from sub_block_path, and block 4 therefore never differentiates reshape2_0.tmp_0, i.e. never computes reshape2_0.tmp_0@GRAD.
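To make the pruning concrete, here is a minimal, self-contained sketch of the idea (a toy model only, not Paddle's actual _append_backward_ops_ internals; Op, find_grad_op_path, and the helper logic are illustrative):

from collections import namedtuple

# Toy model: an op joins the backward path only if some output is both
# needed for the loss gradient and not marked stop_gradient.
Op = namedtuple('Op', ['type', 'inputs', 'outputs'])

def find_grad_op_path(ops, loss_vars, stop_gradient_vars):
    needed = set(loss_vars)
    path = []
    for op in reversed(ops):
        live_outs = [v for v in op.outputs
                     if v in needed and v not in stop_gradient_vars]
        if live_outs:
            path.append(op)
            needed.update(op.inputs)  # their grads are now needed too
    return list(reversed(path))

# Var names taken from the program dump above.
sub_block_ops = [
    Op('arg_max', ['reshape2_0.tmp_0'], ['argmax_0.tmp_0']),
    Op('elementwise_add', ['generated_tensor_0', 'net_0.w_0'],
       ['_generated_var_0']),
]
print(find_grad_op_path(sub_block_ops, ['_generated_var_0'],
                        {'argmax_0.tmp_0'}))
# -> only elementwise_add; arg_max is pruned, so nothing in the backward
#    sub-block ever produces reshape2_0.tmp_0@GRAD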
By the same reasoning, reshape2_0.tmp_0@GRAD should not exist in block 0 either, so why is it there? The analysis of no-need-grad vars in no_grad_dict appears to be slightly off. If reshape2_0.tmp_0 were in that dict (although putting it there would itself be questionable, since in block 0 reshape2_0.tmp_0 has stop_gradient=False), then when get_grad_op_desc generates the backward op_desc, the Input@GRAD slot corresponding to reshape2_0.tmp_0@GRAD would be replaced with @empty@. But reshape2_0.tmp_0 is not in no_grad_dict, so the generated backward op never substitutes @empty@ and block 0 ends up declaring a grad var that nothing fills in.
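Conceptually, the substitution that should have happened looks like this small sketch (rename_no_grad_slots is a hypothetical helper standing in for what get_grad_op_desc does with no_grad_dict):

EMPTY_VAR = '@empty@'
GRAD_SUFFIX = '@GRAD'

def rename_no_grad_slots(grad_op_outputs, no_grad_set):
    # Replace any grad slot whose forward var needs no gradient with the
    # empty-var placeholder, so no dangling grad var is declared.
    renamed = {}
    for slot, names in grad_op_outputs.items():
        renamed[slot] = [
            EMPTY_VAR
            if n.endswith(GRAD_SUFFIX) and n[:-len(GRAD_SUFFIX)] in no_grad_set
            else n
            for n in names
        ]
    return renamed

# Had reshape2_0.tmp_0 been in no_grad_dict, its slot would be emptied and
# block 0 would never declare reshape2_0.tmp_0@GRAD:
outs = {'Input@GRAD': ['net_0.w_0@GRAD', 'reshape2_0.tmp_0@GRAD']}
print(rename_no_grad_slots(outs, {'reshape2_0.tmp_0'}))
# {'Input@GRAD': ['net_0.w_0@GRAD', '@empty@']}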

Fix

We first tried to change how _append_backward_ops_ handles sub_block_path and no_grad_dict, but after reading the related code that change turned out to be difficult. Instead, this PR adjusts the logic after conditional_block_grad finishes its computation: at the end, conditional_block_grad calls AssignLocalGradientToParentScope to copy the grad vars from the sub_block back to the parent_block. In AssignLocalGradientToParentScope we now detect the grad vars that the parent_block needs for its backward pass but that the sub_block never computed, and assign zeros to them.
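The real change lives in the C++ conditional_block_grad op; the following self-contained Python sketch only mirrors the control flow (scopes are modeled as plain dicts, and all names besides AssignLocalGradientToParentScope itself are hypothetical):

import numpy as np

def assign_local_gradient_to_parent_scope(parent_scope, sub_scope,
                                          grad_to_forward):
    # grad_to_forward maps each grad name the parent block expects to the
    # forward value (a numpy array here) it is the gradient of.
    for grad_name, fwd_value in grad_to_forward.items():
        if grad_name in sub_scope:
            # Normal case: the sub-block computed this grad; copy it up.
            parent_scope[grad_name] = sub_scope[grad_name]
        else:
            # Fix: the sub-block skipped this grad (e.g.
            # reshape2_0.tmp_0@GRAD); hand the parent a zero tensor of the
            # forward shape so downstream grad ops see an initialized input.
            parent_scope[grad_name] = np.zeros_like(fwd_value)

# Example mirroring the bug: block 4 computed no grad for reshape2_0.tmp_0,
# but block 0 still expects reshape2_0.tmp_0@GRAD.
sub = {'net_0.w_0@GRAD': np.ones((3, 2), dtype='float32')}
parent = {}
assign_local_gradient_to_parent_scope(
    parent, sub,
    {'net_0.w_0@GRAD': np.ones((3, 2), dtype='float32'),
     'reshape2_0.tmp_0@GRAD': np.ones((2, 4), dtype='float32')})
print(parent['reshape2_0.tmp_0@GRAD'])  # zeros, not uninitialized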

@paddle-bot-old

Your PR has been submitted. Thanks for your contribution!
Please wait for the result of CI firstly. See Paddle CI Manual for details.

@0x45f 0x45f changed the title Fix cond_block_grad error when handling no-need-grad vars to [Dy2St] Fix cond_block_grad error when handling no-need-grad vars May 26, 2022
Contributor

@2742195759 2742195759 left a comment


LGTM

Contributor

@Aurelius84 Aurelius84 left a comment


LGTM

@0x45f 0x45f merged commit cd3d091 into PaddlePaddle:develop May 30, 2022
@0x45f 0x45f deleted the dy2st-fix-sub-block-grad-var branch May 30, 2022 04:09
0x45f added a commit to 0x45f/Paddle that referenced this pull request May 30, 2022
…ePaddle#43034)

* Fix cond_block_grad error when handling no-need-grad vars

* Add comment and UT
phlrain pushed a commit that referenced this pull request May 30, 2022
… (#43084)

* Fix cond_block_grad error when handling no-need-grad vars

* Add comment and UT