[BUG] SimCLR NT Xent loss does not take into account batches from other DDP processes #179

Closed
OlivierDehaene opened this issue Sep 3, 2020 · 1 comment · Fixed by #329

OlivierDehaene commented Sep 3, 2020

🐛 Bug

In the NT Xent loss, out_1 and out_2 are not gathered across the whole DDP process group. This is a serious problem: each process classifies the correct pair among only local_batch_size*2 candidates instead of world_size*local_batch_size*2, so the loss never sees the embeddings computed by the other DDP processes.
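To make the gap concrete, here is a small Python sketch of the count described above. The helper name and the example sizes (a local batch of 128 on 8 processes) are illustrative assumptions, not values taken from this repository.

```python
def num_candidates(local_batch_size, world_size=1, gathered=True):
    """Number of candidates each anchor is classified against.

    Follows the counting used in this issue: 2 * batch entries per anchor,
    where the batch is global only if embeddings are gathered across processes.
    """
    effective_world = world_size if gathered else 1
    return 2 * effective_world * local_batch_size


# Hypothetical setup: 8 DDP processes, 128 samples per process.
local_candidates = num_candidates(128, world_size=8, gathered=False)
global_candidates = num_candidates(128, world_size=8, gathered=True)

print(local_candidates)   # 256
print(global_candidates)  # 2048
```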

Code Sample

See the "Gather hidden1/hidden2 across replicas and create local labels." comment in the original implementation:

def add_contrastive_loss(hidden,
                         hidden_norm=True,
                         temperature=1.0,
                         tpu_context=None,
                         weights=1.0):
  """Compute loss for model.
  Args:
    hidden: hidden vector (`Tensor`) of shape (2 * bsz, dim).
    hidden_norm: whether or not to use normalization on the hidden vector.
    temperature: a `floating` number for temperature scaling.
    tpu_context: context information for tpu.
    weights: a weighting number or vector.
  Returns:
    A loss scalar.
    The logits for contrastive prediction task.
    The labels for contrastive prediction task.
  """
  # Get (normalized) hidden1 and hidden2.
  if hidden_norm:
    hidden = tf.math.l2_normalize(hidden, -1)
  hidden1, hidden2 = tf.split(hidden, 2, 0)
  batch_size = tf.shape(hidden1)[0]

  # Gather hidden1/hidden2 across replicas and create local labels.
  if tpu_context is not None:
    hidden1_large = tpu_cross_replica_concat(hidden1, tpu_context)
    hidden2_large = tpu_cross_replica_concat(hidden2, tpu_context)
    enlarged_batch_size = tf.shape(hidden1_large)[0]
    # TODO(iamtingchen): more elegant way to convert u32 to s32 for replica_id.
    replica_id = tf.cast(tf.cast(xla.replica_id(), tf.uint32), tf.int32)
    labels_idx = tf.range(batch_size) + replica_id * batch_size
    labels = tf.one_hot(labels_idx, enlarged_batch_size * 2)
    masks = tf.one_hot(labels_idx, enlarged_batch_size)
  else:
    hidden1_large = hidden1
    hidden2_large = hidden2
    labels = tf.one_hot(tf.range(batch_size), batch_size * 2)
    masks = tf.one_hot(tf.range(batch_size), batch_size)

  logits_aa = tf.matmul(hidden1, hidden1_large, transpose_b=True) / temperature
  logits_aa = logits_aa - masks * LARGE_NUM
  logits_bb = tf.matmul(hidden2, hidden2_large, transpose_b=True) / temperature
  logits_bb = logits_bb - masks * LARGE_NUM
  logits_ab = tf.matmul(hidden1, hidden2_large, transpose_b=True) / temperature
  logits_ba = tf.matmul(hidden2, hidden1_large, transpose_b=True) / temperature

  loss_a = tf.losses.softmax_cross_entropy(
      labels, tf.concat([logits_ab, logits_aa], 1), weights=weights)
  loss_b = tf.losses.softmax_cross_entropy(
      labels, tf.concat([logits_ba, logits_bb], 1), weights=weights)
  loss = loss_a + loss_b

  return loss, logits_ab, labels
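
A rough PyTorch counterpart of the gather-and-label step above might look like the sketch below. It is not the fix that landed in this repository: `gather_with_local_grad` and `nt_xent_global` are hypothetical names, the temperature default is arbitrary, and `LARGE_NUM = 1e9` is an assumed value. `torch.distributed.all_gather` does not propagate gradients, so the sketch puts the local tensor back into its own slot. Gradients therefore flow only through the local slice, which may not match the behaviour of the TPU cross-replica concat.

```python
import torch
import torch.distributed as dist
import torch.nn.functional as F

LARGE_NUM = 1e9  # assumed value, used to mask out self-similarities


def gather_with_local_grad(x):
    """Concatenate x from all processes; only the local slice keeps gradients.

    Returns the concatenated tensor and this process's rank (0 if not distributed).
    Assumes every process holds the same local batch size.
    """
    if not (dist.is_available() and dist.is_initialized()):
        return x, 0
    rank = dist.get_rank()
    gathered = [torch.zeros_like(x) for _ in range(dist.get_world_size())]
    dist.all_gather(gathered, x.detach().contiguous())
    gathered[rank] = x  # re-insert the local tensor so autograd can reach it
    return torch.cat(gathered, dim=0), rank


def nt_xent_global(out_1, out_2, temperature=0.5):
    """NT-Xent loss with negatives taken from the gathered (global) batch."""
    out_1 = F.normalize(out_1, dim=-1)
    out_2 = F.normalize(out_2, dim=-1)
    local_bs = out_1.shape[0]

    out_1_large, rank = gather_with_local_grad(out_1)
    out_2_large, _ = gather_with_local_grad(out_2)
    global_bs = out_1_large.shape[0]

    # Position of each local sample inside the gathered batch.
    labels_idx = torch.arange(local_bs, device=out_1.device) + rank * local_bs
    masks = F.one_hot(labels_idx, global_bs).to(out_1.dtype)

    logits_aa = out_1 @ out_1_large.t() / temperature - masks * LARGE_NUM
    logits_bb = out_2 @ out_2_large.t() / temperature - masks * LARGE_NUM
    logits_ab = out_1 @ out_2_large.t() / temperature
    logits_ba = out_2 @ out_1_large.t() / temperature

    # The positive for row i sits at column labels_idx[i] of the "ab"/"ba" half.
    loss_a = F.cross_entropy(torch.cat([logits_ab, logits_aa], dim=1), labels_idx)
    loss_b = F.cross_entropy(torch.cat([logits_ba, logits_bb], dim=1), labels_idx)
    return loss_a + loss_b
```

Without an initialized process group the function falls back to the local batch, so it can be tried on a single process with `nt_xent_global(torch.randn(4, 8), torch.randn(4, 8))`.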

OlivierDehaene added the help wanted label on Sep 3, 2020

@pbontrager

I would like to second this. I was trying to see if this was happening in a hidden way, similar to how sync_batch_norm is applied, but I couldn't find it. I don't know if PyTorch Lightning has an idiomatic way to allow for this. Maybe it should expose its own all_gather, broadcast, ... commands that connect to whichever distributed backend is in use?
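
For reference, one way such a helper could look in plain `torch.distributed` is sketched below. `AllGatherWithGrad` is a hypothetical name, not an existing PyTorch Lightning API. It gathers in the forward pass and sum-reduces the incoming gradients in the backward pass, so each process receives the gradient for its own slice. The sketch assumes every process holds the same local batch size.

```python
import torch
import torch.distributed as dist


def _is_distributed():
    """True only when a torch.distributed process group has been initialized."""
    return dist.is_available() and dist.is_initialized()


class AllGatherWithGrad(torch.autograd.Function):
    """Hypothetical all_gather along dim 0 that also propagates gradients."""

    @staticmethod
    def forward(ctx, x):
        ctx.local_size = x.shape[0]
        ctx.distributed = _is_distributed()
        if not ctx.distributed:
            # Single-process fallback: the "gathered" batch is the local batch.
            ctx.rank = 0
            return x.clone()
        ctx.rank = dist.get_rank()
        gathered = [torch.zeros_like(x) for _ in range(dist.get_world_size())]
        dist.all_gather(gathered, x.contiguous())
        return torch.cat(gathered, dim=0)

    @staticmethod
    def backward(ctx, grad_output):
        # Copy before reducing so the incoming gradient is not modified in place.
        grad = grad_output.clone(memory_format=torch.contiguous_format)
        if ctx.distributed:
            # Sum the gradients every process computed for the gathered tensor.
            dist.all_reduce(grad, op=dist.ReduceOp.SUM)
        start = ctx.rank * ctx.local_size
        return grad[start:start + ctx.local_size]
```

It would be called as `out_1_large = AllGatherWithGrad.apply(out_1)` before the similarity matrices are computed.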

This was referenced Nov 2, 2020
ananyahjha93 self-assigned this on Nov 3, 2020
Borda added this to the v0.3 milestone on Jan 18, 2021