Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add parameter axis to tversky loss #20563

Merged
merged 8 commits into from
Nov 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 14 additions & 7 deletions keras/src/losses/losses.py
Original file line number Diff line number Diff line change
Expand Up @@ -1388,6 +1388,7 @@ def __init__(
beta=0.5,
reduction="sum_over_batch_size",
name="tversky",
axis=None,
dtype=None,
):
super().__init__(
Expand All @@ -1397,13 +1398,17 @@ def __init__(
dtype=dtype,
alpha=alpha,
beta=beta,
axis=axis,
)
self.alpha = alpha
self.beta = beta
self.axis = axis

def get_config(self):
config = Loss.get_config(self)
config.update({"alpha": self.alpha, "beta": self.beta})
config.update(
{"alpha": self.alpha, "beta": self.beta, "axis": self.axis}
)
return config


Expand Down Expand Up @@ -2465,7 +2470,7 @@ def dice(y_true, y_pred, axis=None):


@keras_export("keras.losses.tversky")
def tversky(y_true, y_pred, alpha=0.5, beta=0.5):
def tversky(y_true, y_pred, alpha=0.5, beta=0.5, axis=None):
"""Computes the Tversky loss value between `y_true` and `y_pred`.

This loss function is weighted by the alpha and beta coefficients
Expand All @@ -2479,6 +2484,7 @@ def tversky(y_true, y_pred, alpha=0.5, beta=0.5):
y_pred: tensor of predicted targets.
alpha: coefficient controlling incidence of false positives.
beta: coefficient controlling incidence of false negatives.
axis: tuple for which dimensions the loss is calculated.

Returns:
Tversky loss value.
Expand All @@ -2490,12 +2496,13 @@ def tversky(y_true, y_pred, alpha=0.5, beta=0.5):
y_pred = ops.convert_to_tensor(y_pred)
y_true = ops.cast(y_true, y_pred.dtype)

inputs = ops.reshape(y_true, [-1])
targets = ops.reshape(y_pred, [-1])
inputs = y_true
targets = y_pred

intersection = ops.sum(inputs * targets, axis=axis)
fp = ops.sum((1 - targets) * inputs, axis=axis)
fn = ops.sum(targets * (1 - inputs), axis=axis)

intersection = ops.sum(inputs * targets)
fp = ops.sum((1 - targets) * inputs)
fn = ops.sum(targets * (1 - inputs))
tversky = ops.divide(
intersection,
intersection + fp * alpha + fn * beta + backend.epsilon(),
Expand Down
22 changes: 22 additions & 0 deletions keras/src/losses/losses_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1630,6 +1630,16 @@ def test_binary_segmentation(self):
output = losses.Tversky()(y_true, y_pred)
self.assertAllClose(output, 0.77777773)

def test_binary_segmentation_with_axis(self):
y_true = np.array(
[[[[1.0], [1.0]], [[0.0], [0.0]]], [[[1.0], [1.0]], [[0.0], [0.0]]]]
)
y_pred = np.array(
[[[[0.0], [1.0]], [[0.0], [1.0]]], [[[0.4], [0.0]], [[0.0], [0.9]]]]
)
output = losses.Tversky(axis=(1, 2, 3), reduction=None)(y_true, y_pred)
self.assertAllClose(output, [0.5, 0.75757575])

def test_binary_segmentation_custom_coefficients(self):
y_true = np.array(
([[1, 0, 1, 0], [0, 1, 0, 1], [1, 0, 1, 0], [0, 1, 0, 1]])
Expand All @@ -1640,6 +1650,18 @@ def test_binary_segmentation_custom_coefficients(self):
output = losses.Tversky(alpha=0.2, beta=0.8)(y_true, y_pred)
self.assertAllClose(output, 0.7916667)

def test_binary_segmentation_custom_coefficients_with_axis(self):
y_true = np.array(
[[[[1.0], [1.0]], [[0.0], [0.0]]], [[[1.0], [1.0]], [[0.0], [0.0]]]]
)
y_pred = np.array(
[[[[0.0], [1.0]], [[0.0], [1.0]]], [[[0.4], [0.0]], [[0.0], [0.9]]]]
)
output = losses.Tversky(
alpha=0.2, beta=0.8, axis=(1, 2, 3), reduction=None
)(y_true, y_pred)
self.assertAllClose(output, [0.5, 0.7222222])

def test_dtype_arg(self):
y_true = np.array(([[1, 2], [1, 2]]))
y_pred = np.array(([[4, 1], [6, 1]]))
Expand Down