Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix gradient(nan) when two inputs are equal #32448

Merged
merged 1 commit into from
Apr 25, 2021

Conversation

zhangting2020
Copy link
Contributor

@zhangting2020 zhangting2020 commented Apr 22, 2021

PR types

Bug fixes

PR changes

OPs

Describe

fix gradient(nan) when two inputs are equal

test code

mport torch
import paddle
import numpy as np

x_i = np.random.random((4, 4)).astype("float32")

x = torch.tensor(x_i, requires_grad=True)
out = torch.dist(x, x, 2)
out.backward(gradient=torch.tensor(1.0), retain_graph=True)
print(x.grad)

x = paddle.to_tensor(x_i)
x.stop_gradient = False
out = paddle.dist(x, x, 2)
dx = paddle.grad(out, x)
print(dx)

Before

  • torch
tensor([[0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.]])
  • paddle
[Tensor(shape=[4, 4], dtype=float32, place=CUDAPlace(0), stop_gradient=True,
       [[nan, nan, nan, nan],
        [nan, nan, nan, nan],
        [nan, nan, nan, nan],
        [nan, nan, nan, nan]])]

After

paddle

[Tensor(shape=[4, 4], dtype=float32, place=CUDAPlace(0), stop_gradient=True,
       [[0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.]])]

@paddle-bot-old
Copy link

Thanks for your contribution!
Please wait for the result of CI firstly. See Paddle CI Manual for details.

Copy link
Contributor

@Xreki Xreki left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@zhangting2020 zhangting2020 merged commit 1896c77 into PaddlePaddle:develop Apr 25, 2021
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants