paddle.put_along_axis family of APIs: some results are incorrect with include_self=False once the data size grows #72803

@HydrogenSulfate

Describe the Bug

With include_self=False, the paddle.put_along_axis family of APIs produces partially incorrect results once the tensors get moderately large. The root cause appears to be a bug in the code in Paddle/paddle/phi/kernels/funcs/gather_scatter_functor.cu.
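For reference, here is a minimal sketch of the include_self=False semantics in question, using torch.scatter_reduce as the oracle (the toy values are illustrative only, not related to the failing case): with an additive reduce and include_self=False, every position touched by the index tensor discards its original value and accumulates only the scattered values, while untouched positions keep the input unchanged.

import torch

x = torch.ones(5)
index = torch.tensor([1, 1, 3])
src = torch.tensor([10.0, 20.0, 30.0])
out = torch.scatter_reduce(x, 0, index, src, reduce='sum', include_self=False)
# position 1 accumulates 10 + 20 = 30 (the original 1.0 is discarded),
# position 3 becomes 30.0, untouched positions keep their original 1.0
print(out)  # tensor([ 1., 30.,  1., 30.,  1.])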
The reproduction code is below; running it, you can see the Paddle results diverge from PyTorch's scatter_reduce:

import paddle
import numpy as np
import torch


x = paddle.randn([10, 400])
x.stop_gradient = False
include_self = False
axis = 0

indices = paddle.randint(0, x.shape[axis], [10, 400])

values = paddle.randn(indices.shape)
values.stop_gradient = False

out = paddle.put_along_axis(
    x,
    indices,
    values,
    axis,
    'add',
    include_self=include_self,
)

dout = paddle.randn_like(out)
dout.stop_gradient = False

dx, dv = paddle.grad(
    out,
    [x, values],
    dout,
    create_graph=True,
)
print(f"dx.shape = {dx.shape}, dv.shape = {dv.shape}")

ddx = paddle.randn_like(dx)
ddx.stop_gradient = False
ddv = paddle.randn_like(dv)
ddv.stop_gradient = False

ddout = paddle.grad(
    [dx, dv],  # you can keep only dx or only dv here, to test the optional-input branch
    dout,
    [ddx, ddv],  # you can keep only ddx or only ddv here, to test the optional-input branch
)[0]
print(f"ddout.shape = {ddout.shape}")

tx = torch.as_tensor(x.detach()).cuda()
tx.requires_grad = True

tindices = torch.as_tensor(indices.detach()).cuda()
tvalues = torch.as_tensor(values.detach()).cuda()
tvalues.requires_grad = True

tout = torch.scatter_reduce(
    tx,
    axis,
    tindices,
    tvalues,
    reduce='sum',
    include_self=include_self,
)

tdout = torch.as_tensor(dout.detach()).cuda()  # grad_outputs must live on the same device as tout
tdout.requires_grad = True

tdx, tdv = torch.autograd.grad(
    tout,
    [tx, tvalues],
    tdout,
    create_graph=True,
)
print(f"dx.shape = {dx.shape}, dv.shape = {dv.shape}")

tddx = torch.as_tensor(ddx.detach()).cuda()  # same device as tdx/tdv
tddx.requires_grad = True
tddv = torch.as_tensor(ddv.detach()).cuda()
tddv.requires_grad = True

tddout = torch.autograd.grad(
    [tdx, tdv],  # you can keep only tdx or only tdv here, to test the optional-input branch
    tdout,
    [tddx, tddv],  # you can keep only tddx or only tddv here, to test the optional-input branch
)[0]
print(f"ddout.shape = {ddout.shape}")

np.testing.assert_allclose(out.numpy(), tout.detach().cpu().numpy(), 1e-6, 1e-6)
np.testing.assert_allclose(dx.numpy(), tdx.detach().cpu().numpy(), 1e-6, 1e-6)
np.testing.assert_allclose(dv.numpy(), tdv.detach().cpu().numpy(), 1e-6, 1e-6)
np.testing.assert_allclose(ddout.numpy(), tddout.detach().cpu().numpy(), 1e-6, 1e-6)
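For a PyTorch-free sanity check of the forward pass, the expected output can also be computed with plain NumPy. put_along_axis_ref below is a hypothetical helper written for this report (it assumes axis=0, 'add', and include_self=False, matching the script above):

def put_along_axis_ref(x, indices, values):
    # zero every position touched by indices (include_self=False discards
    # the original values there), then accumulate the scattered values
    out = x.copy()
    cols = np.broadcast_to(np.arange(x.shape[1]), indices.shape)
    out[indices, cols] = 0.0
    np.add.at(out, (indices, cols), values)
    return out

ref = put_along_axis_ref(x.numpy(), indices.numpy(), values.numpy())
np.testing.assert_allclose(out.numpy(), ref, 1e-6, 1e-6)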

Additional Supplementary Information

No response
