diff --git a/python/paddle/base/variable_index.py b/python/paddle/base/variable_index.py index e485714f2276a..27e0be1ccb92f 100644 --- a/python/paddle/base/variable_index.py +++ b/python/paddle/base/variable_index.py @@ -344,14 +344,10 @@ def get_value_for_bool_tensor(var, item): empty_shape = [0] + list(var.shape[i:]) def idx_not_empty(var, item): - from ..tensor import gather_nd + bool_2_idx = paddle.nonzero(item) + return paddle.gather_nd(var, bool_2_idx) - bool_2_idx = paddle.nonzero(item == True) - return gather_nd(var, bool_2_idx) - - from paddle.static.nn import cond - - return cond( + return paddle.static.nn.cond( item.any(), lambda: idx_not_empty(var, item), lambda: paddle.empty(empty_shape, var.dtype),