Skip to content

Commit

Permalink
Add support for ragged tensors
Browse files Browse the repository at this point in the history
  • Loading branch information
Frightera committed Mar 7, 2023
1 parent 5696b5a commit 40e547f
Showing 1 changed file with 20 additions and 0 deletions.
20 changes: 20 additions & 0 deletions keras/losses.py
Original file line number Diff line number Diff line change
Expand Up @@ -2111,6 +2111,26 @@ def _smooth_labels():
from_logits=from_logits,
axis=axis, )

@dispatch.dispatch_for_types(categorical_focal_crossentropy, tf.RaggedTensor)
def _ragged_tensor_categorical_focal_crossentropy(
y_true,
y_pred,
alpha=0.25,
gamma=2.0,
from_logits=False,
label_smoothing=0.0,
axis=-1,
):
fn = functools.partial(
categorical_focal_crossentropy,
alpha=alpha,
gamma=gamma,
from_logits=from_logits,
label_smoothing=label_smoothing,
axis=axis,
)
return _ragged_tensor_apply_loss(fn, y_true, y_pred)

@keras_export(
"keras.metrics.sparse_categorical_crossentropy",
"keras.losses.sparse_categorical_crossentropy",
Expand Down

0 comments on commit 40e547f

Please sign in to comment.