Add an `axis` argument to the Tversky loss.

The `Tversky` class accepts `axis`, forwards it to the base constructor,
stores it on the instance and reports it from `get_config()`. The `tversky`
function no longer flattens `y_true` / `y_pred`; it passes `axis` to each
`ops.sum` call instead. Two new tests exercise `axis=(1, 2, 3)` together
with `reduction=None`.

NOTE(review): in `tversky`, `fp` is `sum((1 - y_pred) * y_true)` and `fn` is
`sum(y_pred * (1 - y_true))`, so `alpha` scales the term the docstring
attributes to `beta` and vice versa. This patch does not change that
pairing, and the expected test values rely on it; confirm which is intended.

diff --git a/keras/src/losses/losses.py b/keras/src/losses/losses.py
index cc18b37df65..559bb472656 100644
--- a/keras/src/losses/losses.py
+++ b/keras/src/losses/losses.py
@@ -1388,6 +1388,7 @@ def __init__(
         beta=0.5,
         reduction="sum_over_batch_size",
         name="tversky",
+        axis=None,
         dtype=None,
     ):
         super().__init__(
@@ -1397,13 +1398,17 @@ def __init__(
             dtype=dtype,
             alpha=alpha,
             beta=beta,
+            axis=axis,
         )
         self.alpha = alpha
         self.beta = beta
+        self.axis = axis
 
     def get_config(self):
         config = Loss.get_config(self)
-        config.update({"alpha": self.alpha, "beta": self.beta})
+        config.update(
+            {"alpha": self.alpha, "beta": self.beta, "axis": self.axis}
+        )
         return config
 
 
@@ -2465,7 +2470,7 @@ def dice(y_true, y_pred, axis=None):
 
 
 @keras_export("keras.losses.tversky")
-def tversky(y_true, y_pred, alpha=0.5, beta=0.5):
+def tversky(y_true, y_pred, alpha=0.5, beta=0.5, axis=None):
     """Computes the Tversky loss value between `y_true` and `y_pred`.
 
     This loss function is weighted by the alpha and beta coefficients
@@ -2479,6 +2484,7 @@ def tversky(y_true, y_pred, alpha=0.5, beta=0.5):
         y_pred: tensor of predicted targets.
         alpha: coefficient controlling incidence of false positives.
         beta: coefficient controlling incidence of false negatives.
+        axis: tuple of axes the sums are taken over. Defaults to `None`.
 
     Returns:
         Tversky loss value.
@@ -2490,12 +2496,13 @@ def tversky(y_true, y_pred, alpha=0.5, beta=0.5):
     y_pred = ops.convert_to_tensor(y_pred)
     y_true = ops.cast(y_true, y_pred.dtype)
 
-    inputs = ops.reshape(y_true, [-1])
-    targets = ops.reshape(y_pred, [-1])
+    inputs = y_true
+    targets = y_pred
+
+    intersection = ops.sum(inputs * targets, axis=axis)
+    fp = ops.sum((1 - targets) * inputs, axis=axis)
+    fn = ops.sum(targets * (1 - inputs), axis=axis)
 
-    intersection = ops.sum(inputs * targets)
-    fp = ops.sum((1 - targets) * inputs)
-    fn = ops.sum(targets * (1 - inputs))
     tversky = ops.divide(
         intersection,
         intersection + fp * alpha + fn * beta + backend.epsilon(),
diff --git a/keras/src/losses/losses_test.py b/keras/src/losses/losses_test.py
index bbecbc06d08..144b5413604 100644
--- a/keras/src/losses/losses_test.py
+++ b/keras/src/losses/losses_test.py
@@ -1630,6 +1630,16 @@ def test_binary_segmentation(self):
         output = losses.Tversky()(y_true, y_pred)
         self.assertAllClose(output, 0.77777773)
 
+    def test_binary_segmentation_with_axis(self):
+        y_true = np.array(
+            [[[[1.0], [1.0]], [[0.0], [0.0]]], [[[1.0], [1.0]], [[0.0], [0.0]]]]
+        )
+        y_pred = np.array(
+            [[[[0.0], [1.0]], [[0.0], [1.0]]], [[[0.4], [0.0]], [[0.0], [0.9]]]]
+        )
+        output = losses.Tversky(axis=(1, 2, 3), reduction=None)(y_true, y_pred)
+        self.assertAllClose(output, [0.5, 0.75757575])
+
     def test_binary_segmentation_custom_coefficients(self):
         y_true = np.array(
             ([[1, 0, 1, 0], [0, 1, 0, 1], [1, 0, 1, 0], [0, 1, 0, 1]])
@@ -1640,6 +1650,18 @@ def test_binary_segmentation_custom_coefficients(self):
         output = losses.Tversky(alpha=0.2, beta=0.8)(y_true, y_pred)
         self.assertAllClose(output, 0.7916667)
 
+    def test_binary_segmentation_custom_coefficients_with_axis(self):
+        y_true = np.array(
+            [[[[1.0], [1.0]], [[0.0], [0.0]]], [[[1.0], [1.0]], [[0.0], [0.0]]]]
+        )
+        y_pred = np.array(
+            [[[[0.0], [1.0]], [[0.0], [1.0]]], [[[0.4], [0.0]], [[0.0], [0.9]]]]
+        )
+        output = losses.Tversky(
+            alpha=0.2, beta=0.8, axis=(1, 2, 3), reduction=None
+        )(y_true, y_pred)
+        self.assertAllClose(output, [0.5, 0.7222222])
+
     def test_dtype_arg(self):
         y_true = np.array(([[1, 2], [1, 2]]))
         y_pred = np.array(([[4, 1], [6, 1]]))