-
Notifications
You must be signed in to change notification settings - Fork 667
Closed
Labels
Description
Here is the current implementation of the distributed wrapper.
| def all_gather_embeddings_labels(embeddings, labels): | |
| if c_f.is_list_or_tuple(embeddings): | |
| assert c_f.is_list_or_tuple(labels) | |
| all_embeddings, all_labels = [], [] | |
| for i in range(len(embeddings)): | |
| E, L = all_gather(embeddings[i], labels[i]) | |
| all_embeddings.append(E) | |
| all_labels.append(L) | |
| embeddings = torch.cat(all_embeddings, dim=0) | |
| labels = torch.cat(all_labels, dim=0) | |
| else: | |
| embeddings, labels = all_gather(embeddings, labels) | |
| return embeddings, labels | |
| class DistributedLossWrapper(torch.nn.Module): | |
| def __init__(self, loss, **kwargs): | |
| super().__init__() | |
| has_parameters = len([p for p in loss.parameters()]) > 0 | |
| self.loss = DDP(loss, **kwargs) if has_parameters else loss | |
| def forward(self, embeddings, labels, *args, **kwargs): | |
| embeddings, labels = all_gather_embeddings_labels(embeddings, labels) | |
| return self.loss(embeddings, labels, *args, **kwargs) |
But according to this blog, the loss function should be multiplied by world_size.
What do you think about it? Maybe I can create a PR to fix it.