Bug in hint loss for CATEGORICAL type. The number of unmasked datapoints (jnp.sum(unmasked_data)) was computed over the whole time sequence instead of the pertinent time slice.

The change doesn't seem to affect results.

PiperOrigin-RevId: 429104366
CLRSDev authored and copybara-github committed Feb 16, 2022
1 parent 387b846 commit 41fab53
Showing 1 changed file with 9 additions and 12 deletions.
21 changes: 9 additions & 12 deletions clrs/_src/losses.py
@@ -31,6 +31,8 @@
 _Trajectory = samplers.Trajectory
 _Type = specs.Type
 
+EPS = 1e-12
+
 
 def output_loss(truth: _DataPoint, preds: _Trajectory, nb_nodes: int) -> float:
   """Calculates the output loss."""
@@ -166,19 +168,14 @@ def _hint_loss(
           truth.data[i + 1] * jax.nn.log_softmax(pred) * is_not_done, axis=-1))
 
   elif truth.type_ == _Type.CATEGORICAL:
-    unmasked_data = truth.data[truth.data == _OutputClass.POSITIVE]
-    masked_truth = truth.data * (truth.data != _OutputClass.MASKED).astype(
-        jnp.float32)
+    masked_truth = truth.data[i + 1] * (
+        truth.data[i + 1] != _OutputClass.MASKED).astype(jnp.float32)
+    loss = -jnp.sum(masked_truth * jax.nn.log_softmax(pred) * is_not_done,
+                    axis=-1)
     if decode_diffs:
-      total_loss = jnp.sum(-jnp.sum(
-          masked_truth[i + 1] * jax.nn.log_softmax(pred),
-          axis=-1,
-          keepdims=True) * jnp.expand_dims(gt_diffs[i][truth.location], -1) *
-          is_not_done) / jnp.sum(unmasked_data)
-    else:
-      total_loss = jnp.sum(
-          -jnp.sum(masked_truth[i + 1] * jax.nn.log_softmax(pred), axis=-1) *
-          is_not_done) / jnp.sum(unmasked_data)
+      loss *= gt_diffs[i][truth.location]
+    total_unmasked = jnp.sum(truth.data[i + 1] == _OutputClass.POSITIVE)
+    total_loss = jnp.sum(loss) / jnp.maximum(total_unmasked, EPS)
 
   elif truth.type_ == _Type.POINTER:
     if decode_diffs:
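
For illustration, here is a minimal sketch of the normalisation the commit message refers to; it is not the library code. The toy shapes, the one-hot data, and the assumption that _OutputClass.POSITIVE encodes to 1 are mine; only the [i + 1] slicing and the EPS guard mirror the diff above.

import jax.numpy as jnp

POSITIVE = 1.0   # assumed stand-in for _OutputClass.POSITIVE
EPS = 1e-12

T, B, N, C = 4, 2, 3, 5                      # time steps, batch, nodes, classes
labels = jnp.zeros((T, B, N), dtype=jnp.int32)
truth_data = jnp.eye(C)[labels]              # one-hot categorical hints, [T, B, N, C]

i = 1                                        # hint loss is computed for step i + 1

# Old denominator: unmasked datapoints counted over the whole time sequence.
old_denom = jnp.sum(truth_data == POSITIVE)  # T * B * N = 24

# Fixed denominator: count only the pertinent time slice, with EPS guarding
# against division by zero when the slice is fully masked.
new_denom = jnp.maximum(jnp.sum(truth_data[i + 1] == POSITIVE), EPS)  # B * N = 6

For fully unmasked hints the old count is exactly T times the new one, so the bug amounted to a constant rescaling of the per-step hint loss, which is consistent with the note above that the change doesn't seem to affect results.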
