Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

lerp support 0 Tensor #49667

Merged
merged 11 commits into from
Jan 12, 2023

Conversation

SunNy820828449
Copy link
Contributor

PR types

New features

PR changes

OPs

Describe

Support input 0D Tensor for Lerp

@paddle-bot
Copy link

paddle-bot bot commented Jan 9, 2023

你的PR提交成功,感谢你对开源项目的贡献!
请关注后续CI自动化测试结果,详情请参考Paddle-CI手册
Your PR has been submitted. Thanks for your contribution!
Please wait for the result of CI firstly. See Paddle CI Manual for details.

@@ -966,6 +966,20 @@ def test_argsort(self):
self.assertEqual(x1.grad.numpy(), 0)
self.assertEqual(x2.grad.numpy(), 0)

def test_lerp(self):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个API目前支持广播吗?如果支持的话就需要有0D+0D,0D+ND,ND+0D三种case

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done,thanks

@@ -101,12 +127,20 @@ void LerpGradKernel(const Context& ctx,
DenseTensor* x_grad,
DenseTensor* y_grad) {
int rank = out.dims().size();
PADDLE_ENFORCE_EQ(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这两个判断反向shape的判读是不是可以不加?不然静态图创建反向目前可能有问题

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

好的

y2.stop_gradient = False
out2 = paddle.lerp(x2, y2, w2)

prog = paddle.static.default_main_program()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

需要加paddle.static.append_backward,现在上面的判断跑反向应该会报错。.grad的shape现在会是[1]

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK


prog = paddle.static.default_main_program()
res = self.exe.run(prog, fetch_list=[out0, out1, out2])
self.assertEqual(res[0].shape, ())
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个我试了下可以这么写,paddle.static.append_backward(out0.sum()) 前面out、x的梯度shape就都是正常的了,像动态图一样加下反向shape的测试吧

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done,thanks.

Copy link
Contributor

@zhwesky2010 zhwesky2010 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@zhwesky2010 zhwesky2010 merged commit 8cd0d5b into PaddlePaddle:develop Jan 12, 2023
yjjiang11 pushed a commit to yjjiang11/Paddle that referenced this pull request Jan 13, 2023
* lerp support 0 Tensor

* fix lerp grad

* fix lerp zero test

* fix 0D + ND/ND + 0D

* fix check

* update code

* fix lerp infer shape

* static backward test

* updata static graph test
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants