Skip to content

Commit 39403d1

Browse files
[0-size Tensor Retest] fix take_along_axis bug (#74354)
1 parent e554b1e commit 39403d1

File tree

1 file changed

+9
-4
lines changed

1 file changed

+9
-4
lines changed

python/paddle/tensor/manipulation.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6680,10 +6680,15 @@ def take_along_axis(
66806680
)
66816681
axis = non_negative_axis(arr, axis)
66826682
if broadcast:
6683-
broadcast_shape = infer_broadcast_shape(arr, indices, axis)
6684-
if not broadcast_shape:
6685-
# if indices matrix have larger size than arr, arr should broadcast into indices shape.
6686-
broadcast_shape = indices.shape
6683+
broadcast_shape_list = list(arr.shape)
6684+
for i in range(len(arr.shape)):
6685+
if indices.shape[i] == 0 or arr.shape[i] == 0:
6686+
broadcast_shape_list[i] = 0
6687+
else:
6688+
broadcast_shape_list[i] = max(arr.shape[i], indices.shape[i])
6689+
broadcast_shape_list[axis] = list(indices.shape)[axis]
6690+
broadcast_shape = tuple(broadcast_shape_list)
6691+
66876692
indices = paddle.broadcast_to(indices, broadcast_shape)
66886693
broadcast_shape_list = list(broadcast_shape)
66896694
broadcast_shape_list[axis] = list(arr.shape)[axis]

0 commit comments

Comments
 (0)