Gradient in distributed training #363

@KinglittleQ

Here is the current implementation of the distributed wrapper:

```python
def all_gather_embeddings_labels(embeddings, labels):
    if c_f.is_list_or_tuple(embeddings):
        assert c_f.is_list_or_tuple(labels)
        all_embeddings, all_labels = [], []
        for i in range(len(embeddings)):
            E, L = all_gather(embeddings[i], labels[i])
            all_embeddings.append(E)
            all_labels.append(L)
        embeddings = torch.cat(all_embeddings, dim=0)
        labels = torch.cat(all_labels, dim=0)
    else:
        embeddings, labels = all_gather(embeddings, labels)
    return embeddings, labels


class DistributedLossWrapper(torch.nn.Module):
    def __init__(self, loss, **kwargs):
        super().__init__()
        has_parameters = len([p for p in loss.parameters()]) > 0
        self.loss = DDP(loss, **kwargs) if has_parameters else loss

    def forward(self, embeddings, labels, *args, **kwargs):
        embeddings, labels = all_gather_embeddings_labels(embeddings, labels)
        return self.loss(embeddings, labels, *args, **kwargs)
```

But according to this blog, the loss should be multiplied by `world_size`: DDP averages gradients across processes, and since `all_gather` does not propagate gradients to the other workers' embeddings, each worker only contributes the gradient of its own shard, so the averaged gradient ends up divided by `world_size`.

What do you think about it? Maybe I can create a PR to fix it.
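The scaling argument can be checked with a toy example (pure Python, no actual distributed setup; the two-worker values here are made up for illustration). Each worker computes the loss on the gathered batch, but its gradient is nonzero only for its local embeddings; averaging those per-worker gradients divides the true full-batch gradient by `world_size`, and multiplying the loss by `world_size` cancels that out:

```python
# Toy demonstration of why the gathered loss needs a world_size factor.
# Loss: L = sum_i x_i^2, so the true gradient is dL/dx_i = 2 * x_i.
world_size = 2
x = [3.0, 5.0]                      # x[w] = embedding owned by worker w
full_grad = [2 * v for v in x]      # the gradient we actually want

# all_gather detaches remote tensors, so worker w's gradient of the
# gathered loss is nonzero only at its own shard.
worker_grads = []
for w in range(world_size):
    g = [0.0] * world_size
    g[w] = 2 * x[w]
    worker_grads.append(g)

# DDP averages gradients across workers -> full gradient / world_size.
avg = [sum(g[i] for g in worker_grads) / world_size
       for i in range(world_size)]
assert avg == [v / world_size for v in full_grad]

# Scaling each worker's loss by world_size restores the full gradient.
scaled = [sum(world_size * g[i] for g in worker_grads) / world_size
          for i in range(world_size)]
assert scaled == full_grad
```

Equivalently, the fix would just be `return world_size * self.loss(...)` (or multiplying the loss before `backward()`) inside `DistributedLossWrapper.forward`.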
