Reformatting after focal loss implementation
Frightera committed Mar 9, 2023
1 parent 363baaf commit c267fa0
Showing 3 changed files with 200 additions and 185 deletions.
67 changes: 34 additions & 33 deletions keras/backend.py
@@ -5578,41 +5578,41 @@ def categorical_crossentropy(target, output, from_logits=False, axis=-1):
@tf.__internal__.dispatch.add_dispatch_support
@doc_controls.do_not_generate_docs
def categorical_focal_crossentropy(
    target,
    output,
    alpha=0.25,
    gamma=2.0,
    from_logits=False,
    axis=-1,
):
"""Categorical focal crossentropy (alpha balanced) between an output tensor and a target tensor.

    According to [Lin et al., 2018](https://arxiv.org/pdf/1708.02002.pdf), it
    helps to apply a focal factor to down-weight easy examples and focus more
    on hard examples. The focal factor is built from pt, the predicted
    probability of the true class:

        pt = p if y = 1, else 1 - p

    The paper uses the alpha-balanced variant of focal loss:

        FL(pt) = -alpha_t * (1 - pt)^gamma * log(pt)

    Extending this to the multi-class case is straightforward:

        FL(pt) = alpha_t * (1 - pt)^gamma * CE

    where the minus sign comes from the negative log-likelihood and is already
    included in CE. The `modulating_factor` is (1 - pt)^gamma, where `gamma`
    is a focusing parameter. When `gamma` = 0, there is no focal effect on the
    categorical crossentropy.

    Args:
        target: A tensor with the same shape as `output`.
        output: A tensor.
        alpha: A weight balancing factor for all classes, defaults to `0.25`
            as mentioned in the reference. It can be a list of floats or a
            scalar. In the multi-class case, alpha may be set by inverse class
            frequency by using `compute_class_weight` from `sklearn.utils` (a
            sketch of this follows the diff).
        gamma: A focusing parameter, defaults to `2.0` as mentioned in the
            reference. It helps to gradually reduce the importance given to
            easy examples in a smooth manner.
        from_logits: Whether `output` is expected to be a logits tensor. By
            default, we assume that `output` encodes a probability
            distribution.

    Returns:
        A tensor.
    """
    target = tf.convert_to_tensor(target)
    output = tf.convert_to_tensor(output)
    target.shape.assert_is_compatible_with(output.shape)
@@ -5642,6 +5642,7 @@ def categorical_focal_crossentropy(
    focal_cce = tf.reduce_sum(focal_cce, axis=axis)
    return focal_cce


@keras_export("keras.backend.sparse_categorical_crossentropy")
@tf.__internal__.dispatch.add_dispatch_support
@doc_controls.do_not_generate_docs
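To make the docstring's formulas concrete, here is a minimal NumPy sketch of the same math, assuming `output` already holds normalized probabilities. It is illustrative only, not the Keras implementation, which also handles `from_logits` and numerical clipping.

import numpy as np

def focal_cce_sketch(target, output, alpha=0.25, gamma=2.0, axis=-1):
    # Per-class crossentropy: -y * log(p); zero everywhere except the
    # true class for one-hot targets.
    cce = -target * np.log(output)
    # Focal modulating factor (1 - pt)^gamma; with one-hot targets it
    # only contributes where target == 1, since cce is zero elsewhere.
    modulating_factor = (1.0 - output) ** gamma
    # Alpha-balanced focal loss: alpha * (1 - pt)^gamma * CE, summed
    # over the class axis.
    return np.sum(alpha * modulating_factor * cce, axis=axis)

y_true = np.array([[0.0, 1.0, 0.0]])
easy = np.array([[0.05, 0.90, 0.05]])
hard = np.array([[0.40, 0.20, 0.40]])
print(focal_cce_sketch(y_true, easy))  # ~0.00026, heavily down-weighted
print(focal_cce_sketch(y_true, hard))  # ~0.26, only mildly down-weighted

The two calls show the intended effect: a confident, correct prediction contributes almost nothing, while an uncertain one keeps most of its crossentropy.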
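The docstring also suggests setting `alpha` by inverse class frequency with scikit-learn. A hedged sketch follows; the label array `y` and three-class setup here are made up for illustration.

import numpy as np
from sklearn.utils import compute_class_weight

y = np.array([0, 0, 0, 0, 1, 1, 2])  # imbalanced integer class labels
classes = np.unique(y)
# "balanced" yields n_samples / (n_classes * class_count) per class.
alpha = compute_class_weight(class_weight="balanced", classes=classes, y=y)
print(alpha)  # [0.583..., 1.166..., 2.333...]: rarer classes weigh more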

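Finally, a usage sketch calling the backend function added in this diff. It assumes a Keras build that contains this commit; everything else is ordinary TensorFlow.

import tensorflow as tf
from keras import backend

y_true = tf.constant([[0.0, 1.0, 0.0], [0.0, 0.0, 1.0]])
y_pred = tf.constant([[0.05, 0.90, 0.05], [0.40, 0.20, 0.40]])

# Returns one focal crossentropy value per sample, shape (2,).
loss = backend.categorical_focal_crossentropy(
    y_true, y_pred, alpha=0.25, gamma=2.0, from_logits=False
)
print(loss.numpy())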