import jax
import jax.numpy as jnp
import optax
from functools import partial
from flax.training import common_utils


def loss_fn(
    variables,
    state,
    inputs,
    labels,
    num_classes,
    loss_method="cross_entropy",
    ignore_label=0,
    weight_decay=1e-7,
    epsilon=1e-15,
):
"""
Calculates the loss for a batch of images.
Args:
variables: The segmentation models parameters.
state: State of the semantic segmentation model.
inputs: A batch of raw input images.
labels: A batch of arrays of segmentation mask ID's.
num_classes: Total number of distinct classes.
loss_method: Loss calculation method.
ignore_label: Background class label to ignore.
weight_decay: Regularization coefficient.
epsilon: A small number to prevent loss function from dividing by zero.
Returns:
The loss from a batch of images and the model log odds for the batch.
"""
    forward_fn = state.apply_fn
    logits = forward_fn(variables, inputs)

    if loss_method == "cross_entropy":
        loss_calc = cross_entropy_loss
    elif loss_method == "dice":
        loss_calc = dice_loss
    else:
        raise NotImplementedError(f"The loss method {loss_method} is not supported.")

    loss = loss_calc(logits, labels, num_classes, ignore_label, epsilon=epsilon)

    # L2 regularization over weight matrices; 1-D parameters such as biases
    # and normalization scales are excluded.
    weight_penalty_params = jax.tree_util.tree_leaves(variables["params"])
    weight_l2 = sum(jnp.sum(x**2) for x in weight_penalty_params if x.ndim > 1)
    weight_penalty = weight_decay * 0.5 * weight_l2
    loss = loss + weight_penalty
    return loss, logits
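

# Illustrative training-step sketch (an assumption, not part of the original
# module): wires loss_fn into jax.value_and_grad. It assumes `state` is a
# flax.training.train_state.TrainState-like object, so that grads["params"]
# matches state.params.
def train_step(state, inputs, labels, num_classes):
    grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
    (loss, logits), grads = grad_fn(
        {"params": state.params}, state, inputs, labels, num_classes
    )
    # loss_fn differentiates w.r.t. the full `variables` dict, so unwrap the
    # "params" collection before applying the gradient update.
    state = state.apply_gradients(grads=grads["params"])
    return state, loss, logits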


@partial(jax.jit, static_argnames=['num_classes', 'ignore_label', 'class_weights'])
def dice_loss(
    logits,
    labels,
    num_classes,
    ignore_label=0,
    class_weights=None,
    label_smoothing=1e-3,
    epsilon=1e-15,
):
"""
Calculates the dice loss for a batch of images.
Args:
logits: The log odds for a batch of images.
labels: A batch of arrays of segmentation mask ID's.
num_classes: Total number of distinct classes.
ignore_label: Background class label to ignore.
class_weights: List of importance weightings for the classes.
label_smoothing: Prevents overconfident predictions.
epsilon: A small number to prevent loss function from dividing by zero.
Returns:
The dice loss from a batch of images.
"""
    one_hot_labels = common_utils.onehot(labels, num_classes=num_classes)
    smoothed_one_hot_labels = (
        one_hot_labels * (1 - label_smoothing) + label_smoothing / num_classes
    )
    if class_weights is None:
        class_weights = jnp.ones(num_classes, dtype=jnp.float32)

    # Soft Dice coefficient per class, reducing over the spatial axes
    # (assumes an NHWC layout), then averaged over the batch axis.
    probs = jax.nn.softmax(logits, axis=-1)
    numerator = 2 * (probs * smoothed_one_hot_labels).sum(axis=(-2, -3))
    denominator = probs.sum(axis=(-2, -3)) + smoothed_one_hot_labels.sum(axis=(-2, -3))
    dice_coefficient = (numerator + epsilon) / (denominator + epsilon)
    dice_coefficient = dice_coefficient.mean(axis=0)

    # Remove the background class from the loss function.
    dice_coefficient = jnp.delete(dice_coefficient, ignore_label)
    class_weights = jnp.delete(class_weights, ignore_label)

    # Average the per-class coefficients using the provided class weighting.
    mean_dice_coeff = (dice_coefficient * class_weights).sum() / class_weights.sum()
    return 1 - mean_dice_coeff
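

# Illustrative usage sketch (an assumption, not part of the original module):
# evaluates dice_loss on random data. The NHWC layout (batch, height, width,
# num_classes) is inferred from the axis=(-2, -3) reductions above. Note that
# class_weights is passed as a tuple because it is a static jax.jit argument.
def demo_dice_loss():
    key = jax.random.PRNGKey(0)
    logits_key, labels_key = jax.random.split(key)
    logits = jax.random.normal(logits_key, (2, 8, 8, 4))         # (N, H, W, C)
    labels = jax.random.randint(labels_key, (2, 8, 8), 0, 4)     # integer mask IDs
    return dice_loss(
        logits, labels, num_classes=4, class_weights=(0.0, 1.0, 1.0, 2.0)
    )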


@partial(jax.jit, static_argnames=['num_classes', 'class_weights'])
def cross_entropy_loss(
    logits,
    labels,
    num_classes,
    ignore_label=0,
    class_weights=None,
    label_smoothing=1e-3,
    epsilon=1e-15,
):
"""
Calculates the cross entropy loss for a batch of images.
Args:
logits: The log odds for a batch of images.
labels: A batch of arrays of segmentation mask ID's.
num_classes: Total number of distinct classes.
ignore_label: Background class label to ignore.
class_weights: List of importance weightings for the classes.
label_smoothing: Prevents overconfident predictions.
epsilon: A small number to prevent loss function from dividing by zero.
Returns:
The dice loss from a batch of images.
Obtained from: https://github.com/NobuoTsukamoto/jax_examples/blob/4711d70bdf6ce707c8c5130a1e57fe4741176198/segmentation/train.py#L58
"""
    # Pixels whose label equals ignore_label are masked out of the loss and
    # excluded from the normalizer.
    valid_mask = jnp.not_equal(labels, ignore_label)
    normalizer = jnp.sum(valid_mask.astype(jnp.float32)) + epsilon
    labels = jnp.where(valid_mask, labels, jnp.zeros_like(labels))

    one_hot_labels = common_utils.onehot(labels, num_classes=num_classes)
    smoothed_one_hot_labels = (
        one_hot_labels * (1 - label_smoothing) + label_smoothing / num_classes
    )

    if class_weights is None:
        class_weights = jnp.ones(num_classes, dtype=jnp.float32)
    # Per-pixel weight, taken from the weight of that pixel's class.
    weight_mask = jnp.einsum(
        "...y,y->...",
        one_hot_labels,
        class_weights,
    )
    valid_mask = valid_mask * weight_mask

    xentropy = optax.softmax_cross_entropy(
        logits=logits, labels=smoothed_one_hot_labels
    )
    xentropy = xentropy * valid_mask
    loss = jnp.sum(xentropy) / normalizer
    return loss
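

# Illustrative usage sketch (an assumption, not part of the original module):
# computes both losses on the same random batch, using the NHWC layout
# assumed above, to show the two functions share a calling convention.
def demo_losses():
    key = jax.random.PRNGKey(42)
    logits_key, labels_key = jax.random.split(key)
    logits = jax.random.normal(logits_key, (2, 8, 8, 4))
    labels = jax.random.randint(labels_key, (2, 8, 8), 0, 4)
    ce = cross_entropy_loss(logits, labels, num_classes=4)
    dice = dice_loss(logits, labels, num_classes=4)
    return ce, dice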