diff --git a/flamby/datasets/fed_synthetic/loss.py b/flamby/datasets/fed_synthetic/loss.py index 63b141ccc..08094a7d7 100644 --- a/flamby/datasets/fed_synthetic/loss.py +++ b/flamby/datasets/fed_synthetic/loss.py @@ -8,4 +8,4 @@ def __init__(self): self.ce = torch.nn.CrossEntropyLoss() def forward(self, input: torch.Tensor, target: torch.Tensor): - return self.ce(input, target.squeeze().long()) + return self.ce(input, target.squeeze(axis=1).long())