-
Notifications
You must be signed in to change notification settings - Fork 5.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
lerp support 0 Tensor #49667
lerp support 0 Tensor #49667
Conversation
你的PR提交成功,感谢你对开源项目的贡献! |
@@ -966,6 +966,20 @@ def test_argsort(self): | |||
self.assertEqual(x1.grad.numpy(), 0) | |||
self.assertEqual(x2.grad.numpy(), 0) | |||
|
|||
def test_lerp(self): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这个API目前支持广播吗?如果支持的话就需要有0D+0D,0D+ND,ND+0D三种case
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done,thanks
@@ -101,12 +127,20 @@ void LerpGradKernel(const Context& ctx, | |||
DenseTensor* x_grad, | |||
DenseTensor* y_grad) { | |||
int rank = out.dims().size(); | |||
PADDLE_ENFORCE_EQ( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这两个判断反向shape的判读是不是可以不加?不然静态图创建反向目前可能有问题
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
好的
y2.stop_gradient = False | ||
out2 = paddle.lerp(x2, y2, w2) | ||
|
||
prog = paddle.static.default_main_program() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
需要加paddle.static.append_backward,现在上面的判断跑反向应该会报错。.grad的shape现在会是[1]
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
OK
|
||
prog = paddle.static.default_main_program() | ||
res = self.exe.run(prog, fetch_list=[out0, out1, out2]) | ||
self.assertEqual(res[0].shape, ()) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这个我试了下可以这么写,paddle.static.append_backward(out0.sum())
前面out、x的梯度shape就都是正常的了,像动态图一样加下反向shape的测试吧
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done,thanks.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
* lerp support 0 Tensor * fix lerp grad * fix lerp zero test * fix 0D + ND/ND + 0D * fix check * update code * fix lerp infer shape * static backward test * updata static graph test
PR types
New features
PR changes
OPs
Describe
Support input 0D Tensor for Lerp