-
Notifications
You must be signed in to change notification settings - Fork 5.9k
Closed
Labels
Description
bug描述 Describe the Bug
paddle.put_along_axis 系列的 API 在数据量稍大时,include_self=False模式下,部分计算结果就不正确,其根本原因应该是 Paddle/paddle/phi/kernels/funcs/gather_scatter_functor.cu 这个文件内的代码写的有 BUG
复现代码如下,可以看到 Paddle 的二阶梯度结果与 PyTorch 的参考实现不一致:
# Reproduction script: compares paddle.put_along_axis ('add', include_self=False)
# and its first- and second-order gradients against torch.scatter_reduce.
import paddle
import numpy as np
import torch

x = paddle.randn([10, 400])
x.stop_gradient = False
include_self = False
axis = 0
indices = paddle.randint(0, x.shape[axis], [10, 400])
values = paddle.randn(indices.shape)
values.stop_gradient = False

# Forward: scatter-add `values` into `x` along `axis`; with include_self=False
# the original entries at scattered positions are not included in the sum.
out = paddle.put_along_axis(
    x,
    indices,
    values,
    axis,
    'add',
    include_self=include_self,
)

# First-order gradients w.r.t. x and values; create_graph=True so the
# second-order gradients below can be taken.
# NOTE(review): paddle.randn_like may not exist in older Paddle releases;
# paddle.randn(out.shape) is the portable spelling — confirm against the
# Paddle version under test.
dout = paddle.randn_like(out)
dout.stop_gradient = False
dx, dv = paddle.grad(
    out,
    [x, values],
    dout,
    create_graph=True,
)
print(f"dx.shape = {dx.shape}, dv.shape = {dv.shape}")

# Second-order gradient: grad of [dx, dv] w.r.t. dout.
ddx = paddle.randn_like(dx)
ddx.stop_gradient = False
ddv = paddle.randn_like(dv)
ddv.stop_gradient = False
ddout = paddle.grad(
    [dx, dv],  # keeping only dx or only dv exercises the optional-input branch
    dout,
    [ddx, ddv],  # keeping only ddx or only ddv exercises the optional-input branch
)[0]
print(f"ddout.shape = {ddout.shape}")

# --- PyTorch reference computation --------------------------------------
tx = torch.as_tensor(x.detach()).cuda()
tx.requires_grad = True
tindices = torch.as_tensor(indices.detach()).cuda()
tvalues = torch.as_tensor(values.detach()).cuda()
tvalues.requires_grad = True
tout = torch.scatter_reduce(
    tx,
    axis,
    tindices,
    tvalues,
    reduce='sum',
    include_self=include_self,
)

# BUG FIX: the cotangent tensors must live on the same device as the CUDA
# tensors they pair with, otherwise torch.autograd.grad raises a device
# mismatch error — the original repro omitted .cuda() here.
tdout = torch.as_tensor(dout.detach()).cuda()
tdout.requires_grad = True
tdx, tdv = torch.autograd.grad(
    tout,
    [tx, tvalues],
    tdout,
    create_graph=True,
)
# BUG FIX: print the torch results, not the paddle ones again (copy-paste slip).
print(f"tdx.shape = {tdx.shape}, tdv.shape = {tdv.shape}")

tddx = torch.as_tensor(ddx.detach()).cuda()
tddx.requires_grad = True
tddv = torch.as_tensor(ddv.detach()).cuda()
tddv.requires_grad = True
tddout = torch.autograd.grad(
    [tdx, tdv],  # keeping only tdx or only tdv exercises the optional-input branch
    tdout,
    [tddx, tddv],  # keeping only tddx or only tddv exercises the optional-input branch
)[0]
print(f"tddout.shape = {tddout.shape}")

# Forward and first-order results are expected to match; the second-order
# comparison (ddout, asserted just below this block) is where the reported
# include_self=False mismatch shows up.
np.testing.assert_allclose(out.numpy(), tout.detach().cpu().numpy(), 1e-6, 1e-6)
np.testing.assert_allclose(dx.numpy(), tdx.detach().cpu().numpy(), 1e-6, 1e-6)
np.testing.assert_allclose(dv.numpy(), tdv.detach().cpu().numpy(), 1e-6, 1e-6)
np.testing.assert_allclose(ddout.numpy(), tddout.detach().cpu().numpy(), 1e-6, 1e-6)

其他补充信息 Additional Supplementary Information
No response