From 6effcfb11ebf9a30d4aaa160b4a1efc41578d661 Mon Sep 17 00:00:00 2001 From: zoooo0820 Date: Tue, 10 Oct 2023 07:02:48 +0000 Subject: [PATCH] remove redundant equal --- python/paddle/base/variable_index.py | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/python/paddle/base/variable_index.py b/python/paddle/base/variable_index.py index 0ff628ed48f4f..5f886ea91386c 100644 --- a/python/paddle/base/variable_index.py +++ b/python/paddle/base/variable_index.py @@ -335,14 +335,10 @@ def get_value_for_bool_tensor(var, item): empty_shape = [0] + list(var.shape[i:]) def idx_not_empty(var, item): - from ..tensor import gather_nd + bool_2_idx = paddle.nonzero(item) + return paddle.gather_nd(var, bool_2_idx) - bool_2_idx = paddle.nonzero(item == True) - return gather_nd(var, bool_2_idx) - - from paddle.static.nn import cond - - return cond( + return paddle.static.nn.cond( item.any(), lambda: idx_not_empty(var, item), lambda: paddle.empty(empty_shape, var.dtype),