diff --git a/python/paddle/nn/functional/loss.py b/python/paddle/nn/functional/loss.py index c353451d0c8ef..18b16a2d99434 100755 --- a/python/paddle/nn/functional/loss.py +++ b/python/paddle/nn/functional/loss.py @@ -1650,25 +1650,16 @@ def cross_entropy(input, label = paddle.unsqueeze(label, axis=axis) if in_dygraph_mode(): if soft_label == False: - valid_label = paddle.where(label == ignore_index, - paddle.zeros_like(label), label) - # TODO: Temporarily use paddle.nonzero instead of paddle.max - # to detect and find out possible illegal label values - if len(paddle.nonzero(valid_label < 0)) > 0: - invalid_label = paddle.gather_nd( - valid_label, paddle.nonzero(valid_label < 0)) - raise ValueError( - "Target({}) is out of class_dimension's lower bound({})". - format(invalid_label[0], 0)) - # TODO: Temporarily use paddle.nonzero instead of paddle.max - # to detect and find out possible illegal label values - if len(paddle.nonzero(valid_label >= input.shape[axis])) > 0: - invalid_label = paddle.gather_nd( - valid_label, - paddle.nonzero(valid_label >= input.shape[axis])) - raise ValueError( - "Target({}) is out of class_dimension's upper bound({})". - format(invalid_label[0], input.shape[axis] - 1)) + valid_label = paddle.cast( + label != ignore_index, dtype=label.dtype) * label + label_min = paddle.min(valid_label) + label_max = paddle.max(valid_label) + if label_min < 0: + raise ValueError("label should not out of bound, but got{}". + format(label_min)) + if label_max >= input.shape[axis]: + raise ValueError("label should not out of bound, but got{}". + format(label_max)) _, out = _C_ops.softmax_with_cross_entropy( input, label, 'soft_label', soft_label, 'ignore_index', @@ -1817,8 +1808,9 @@ def cross_entropy(input, "when weight is provided"\ .format(input.shape[axis], weight.shape[-1])) - valid_label = paddle.where(label == ignore_index, - paddle.zeros_like(label), label) + valid_label = paddle.multiply( + paddle.cast( + label != ignore_index, dtype=label.dtype), label) ignore_weight_mask = paddle.cast((label != ignore_index), input.dtype) if ignore_weight_mask.ndim > 1 and ignore_weight_mask.shape[