From 41fab53e85dfae2c8908322a9aed3bf9764b7200 Mon Sep 17 00:00:00 2001 From: CLRSDev Date: Wed, 16 Feb 2022 11:51:58 -0800 Subject: [PATCH] Bug in hint loss for CATEGORICAL type. The number of unmasked datapoints (jnp.sum(unmasked_data)) was computed over the whole time sequence instead of the pertinent time slice. The change doesn't seem to affect results. PiperOrigin-RevId: 429104366 --- clrs/_src/losses.py | 21 +++++++++------------ 1 file changed, 9 insertions(+), 12 deletions(-) diff --git a/clrs/_src/losses.py b/clrs/_src/losses.py index 35b88ce7..6f24acbd 100644 --- a/clrs/_src/losses.py +++ b/clrs/_src/losses.py @@ -31,6 +31,8 @@ _Trajectory = samplers.Trajectory _Type = specs.Type +EPS = 1e-12 + def output_loss(truth: _DataPoint, preds: _Trajectory, nb_nodes: int) -> float: """Calculates the output loss.""" @@ -166,19 +168,14 @@ def _hint_loss( truth.data[i + 1] * jax.nn.log_softmax(pred) * is_not_done, axis=-1)) elif truth.type_ == _Type.CATEGORICAL: - unmasked_data = truth.data[truth.data == _OutputClass.POSITIVE] - masked_truth = truth.data * (truth.data != _OutputClass.MASKED).astype( - jnp.float32) + masked_truth = truth.data[i + 1] * ( + truth.data[i + 1] != _OutputClass.MASKED).astype(jnp.float32) + loss = -jnp.sum(masked_truth * jax.nn.log_softmax(pred) * is_not_done, + axis=-1) if decode_diffs: - total_loss = jnp.sum(-jnp.sum( - masked_truth[i + 1] * jax.nn.log_softmax(pred), - axis=-1, - keepdims=True) * jnp.expand_dims(gt_diffs[i][truth.location], -1) * - is_not_done) / jnp.sum(unmasked_data) - else: - total_loss = jnp.sum( - -jnp.sum(masked_truth[i + 1] * jax.nn.log_softmax(pred), axis=-1) * - is_not_done) / jnp.sum(unmasked_data) + loss *= gt_diffs[i][truth.location] + total_unmasked = jnp.sum(truth.data[i + 1] == _OutputClass.POSITIVE) + total_loss = jnp.sum(loss) / jnp.maximum(total_unmasked, EPS) elif truth.type_ == _Type.POINTER: if decode_diffs: