File tree Expand file tree Collapse file tree 1 file changed +9
-4
lines changed Expand file tree Collapse file tree 1 file changed +9
-4
lines changed Original file line number Diff line number Diff line change @@ -6680,10 +6680,15 @@ def take_along_axis(
66806680 )
66816681 axis = non_negative_axis (arr , axis )
66826682 if broadcast :
6683- broadcast_shape = infer_broadcast_shape (arr , indices , axis )
6684- if not broadcast_shape :
6685- # if indices matrix have larger size than arr, arr should broadcast into indices shape.
6686- broadcast_shape = indices .shape
6683+ broadcast_shape_list = list (arr .shape )
6684+ for i in range (len (arr .shape )):
6685+ if indices .shape [i ] == 0 or arr .shape [i ] == 0 :
6686+ broadcast_shape_list [i ] = 0
6687+ else :
6688+ broadcast_shape_list [i ] = max (arr .shape [i ], indices .shape [i ])
6689+ broadcast_shape_list [axis ] = list (indices .shape )[axis ]
6690+ broadcast_shape = tuple (broadcast_shape_list )
6691+
66876692 indices = paddle .broadcast_to (indices , broadcast_shape )
66886693 broadcast_shape_list = list (broadcast_shape )
66896694 broadcast_shape_list [axis ] = list (arr .shape )[axis ]
You can’t perform that action at this time.
0 commit comments