Skip to content

Commit ed08c56

Browse files
committed
fix test scatter reduce
1 parent 9441f60 commit ed08c56

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

test/legacy_test/test_scatter_reduce_op.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
def scatter_reduce_net(x, axis=-1):
2828
index = paddle.full_like(x, fill_value=2, dtype='int64')
2929
n = paddle.numel(x)
30-
ind = paddle.arange(n, dtype='float32')
30+
ind = paddle.arange(n, dtype='int32').astype(x.dtype)
3131
value = paddle.reshape(ind, x.shape)
3232
return paddle.scatter_reduce(x, axis, index, value, reduce='sum')
3333

0 commit comments

Comments
 (0)