🐛 Bug
In the NT-Xent loss, out_1 and out_2 are not gathered over the whole DDP process group. This is a significant issue: the loss classifies the correct pair out of only local_batch_size * 2 candidates instead of world_size * local_batch_size * 2, so the number of negatives does not scale with the number of GPUs. For example, with 8 GPUs and a local batch size of 32, each positive is contrasted against 64 candidates instead of 512.
Code Sample
See "Gather hidden1/hidden2 across replicas and create local labels." comment in original implementation:
def add_contrastive_loss(hidden,
                         hidden_norm=True,
                         temperature=1.0,
                         tpu_context=None,
                         weights=1.0):
  """Compute loss for model.
  Args:
    hidden: hidden vector (`Tensor`) of shape (2 * bsz, dim).
    hidden_norm: whether or not to use normalization on the hidden vector.
    temperature: a `floating` number for temperature scaling.
    tpu_context: context information for tpu.
    weights: a weighting number or vector.
  Returns:
    A loss scalar. The logits and labels for the contrastive prediction task.
  """
  # Get (normalized) hidden1 and hidden2.
  if hidden_norm:
    hidden = tf.math.l2_normalize(hidden, -1)
  hidden1, hidden2 = tf.split(hidden, 2, 0)
  batch_size = tf.shape(hidden1)[0]

  # Gather hidden1/hidden2 across replicas and create local labels.
  if tpu_context is not None:
    hidden1_large = tpu_cross_replica_concat(hidden1, tpu_context)
    hidden2_large = tpu_cross_replica_concat(hidden2, tpu_context)
    enlarged_batch_size = tf.shape(hidden1_large)[0]
    # TODO(iamtingchen): more elegant way to convert u32 to s32 for replica_id.
    replica_id = tf.cast(tf.cast(xla.replica_id(), tf.uint32), tf.int32)
    labels_idx = tf.range(batch_size) + replica_id * batch_size
    labels = tf.one_hot(labels_idx, enlarged_batch_size * 2)
    masks = tf.one_hot(labels_idx, enlarged_batch_size)
  else:
    hidden1_large = hidden1
    hidden2_large = hidden2
    labels = tf.one_hot(tf.range(batch_size), batch_size * 2)
    masks = tf.one_hot(tf.range(batch_size), batch_size)

  logits_aa = tf.matmul(hidden1, hidden1_large, transpose_b=True) / temperature
  logits_aa = logits_aa - masks * LARGE_NUM
  logits_bb = tf.matmul(hidden2, hidden2_large, transpose_b=True) / temperature
  logits_bb = logits_bb - masks * LARGE_NUM
  logits_ab = tf.matmul(hidden1, hidden2_large, transpose_b=True) / temperature
  logits_ba = tf.matmul(hidden2, hidden1_large, transpose_b=True) / temperature

  loss_a = tf.losses.softmax_cross_entropy(
      labels, tf.concat([logits_ab, logits_aa], 1), weights=weights)
  loss_b = tf.losses.softmax_cross_entropy(
      labels, tf.concat([logits_ba, logits_bb], 1), weights=weights)
  loss = loss_a + loss_b
  return loss, logits_ab, labels
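For multi-GPU DDP training, the equivalent of tpu_cross_replica_concat would be an all-gather whose backward pass still routes gradients to the local shard, since torch.distributed.all_gather is not differentiable on its own. Below is a minimal sketch of such a helper; the GatherLayer name and the exact backward implementation are illustrative, not code from this repo:

import torch
import torch.distributed as dist

class GatherLayer(torch.autograd.Function):
    """All-gather that keeps gradients flowing back to the local shard."""

    @staticmethod
    def forward(ctx, tensor):
        ctx.batch_size = tensor.shape[0]
        gathered = [torch.zeros_like(tensor) for _ in range(dist.get_world_size())]
        dist.all_gather(gathered, tensor)
        return torch.cat(gathered, dim=0)

    @staticmethod
    def backward(ctx, grad_output):
        # Every rank holds gradients for the full gathered tensor; sum them
        # across ranks, then return only the slice this rank contributed.
        grad_input = grad_output.clone()
        dist.all_reduce(grad_input, op=dist.ReduceOp.SUM)
        rank = dist.get_rank()
        return grad_input[rank * ctx.batch_size:(rank + 1) * ctx.batch_size]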
I would like to second this. I was trying to see if this was happening in a hidden way, similar to how sync_batch_norm is applied, but I couldn't find it. I don't know if PyTorch Lightning has an idiomatic way to allow for this. Maybe it should expose its own all_gather, broadcast, ... commands that connect to whichever distributed backend is in use?
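Until such primitives are exposed, the gather can be wired into the loss by hand. Here is a sketch of an NT-Xent loss that scores each local sample against the candidates from every rank, reusing the GatherLayer helper above; the function name and the temperature default are illustrative, not the repo's API:

import math
import torch
import torch.distributed as dist
import torch.nn.functional as F

def nt_xent_loss(out_1, out_2, temperature=0.5):
    """out_1, out_2: (local_batch_size, dim) projections of the two views."""
    out_1 = F.normalize(out_1, dim=-1)
    out_2 = F.normalize(out_2, dim=-1)

    # Gather candidates from all ranks; fall back to local-only outside DDP.
    if dist.is_available() and dist.is_initialized():
        out_1_dist = GatherLayer.apply(out_1)
        out_2_dist = GatherLayer.apply(out_2)
    else:
        out_1_dist, out_2_dist = out_1, out_2

    out = torch.cat([out_1, out_2], dim=0)                  # local queries
    out_dist = torch.cat([out_1_dist, out_2_dist], dim=0)   # global candidates

    # Each local sample is scored against all 2 * world_size * local_batch_size
    # candidates, not just the 2 * local_batch_size local ones.
    sim = torch.exp(torch.mm(out, out_dist.t().contiguous()) / temperature)

    # Drop the self-similarity term exp(1 / temperature) from each denominator
    # (valid because the embeddings are l2-normalized above).
    neg = sim.sum(dim=-1) - math.e ** (1 / temperature)

    pos = torch.exp(torch.sum(out_1 * out_2, dim=-1) / temperature)
    pos = torch.cat([pos, pos], dim=0)
    return -torch.log(pos / neg).mean()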