Skip to content

Commit

Permalink
remove redundant equal (PaddlePaddle#57986)
Browse files Browse the repository at this point in the history
  • Loading branch information
zoooo0820 authored Oct 18, 2023
1 parent a0a049c commit 82ffd1e
Showing 1 changed file with 3 additions and 7 deletions.
10 changes: 3 additions & 7 deletions python/paddle/base/variable_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -344,14 +344,10 @@ def get_value_for_bool_tensor(var, item):
empty_shape = [0] + list(var.shape[i:])

def idx_not_empty(var, item):
from ..tensor import gather_nd
bool_2_idx = paddle.nonzero(item)
return paddle.gather_nd(var, bool_2_idx)

bool_2_idx = paddle.nonzero(item == True)
return gather_nd(var, bool_2_idx)

from paddle.static.nn import cond

return cond(
return paddle.static.nn.cond(
item.any(),
lambda: idx_not_empty(var, item),
lambda: paddle.empty(empty_shape, var.dtype),
Expand Down

0 comments on commit 82ffd1e

Please sign in to comment.