diff --git a/aten/src/ATen/native/mps/operations/LossOps.mm b/aten/src/ATen/native/mps/operations/LossOps.mm
index 26570b02ac0d5..7f08968bf07cc 100644
--- a/aten/src/ATen/native/mps/operations/LossOps.mm
+++ b/aten/src/ATen/native/mps/operations/LossOps.mm
@@ -277,7 +277,7 @@ void mse_loss_out_impl(const Tensor& input, const Tensor& target,
             newCachedGraph->gradInputTensor = bceLoss;
           }
         } else {
-          newCachedGraph->lossTensor = reduceTensor(bceLoss, reduction, mpsGraph, input.sizes().size());
+          newCachedGraph->lossTensor = reduceTensor(bceLoss, reduction, mpsGraph, input_squeezed.sizes().size());
         }
       }
       return newCachedGraph;
diff --git a/test/test_mps.py b/test/test_mps.py
index c6754cf308073..2cf74c1803f2a 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -1695,7 +1695,7 @@ def helper(shape, reduction):
         helper([8, 4, 5, 7, 6], 'mean')
 
     # Binary Cross Enropy
-    def test_bce_loss(self):
+    def test_bce_loss_simple(self):
         def helper(shape, reduction):
             # create the criterion
             loss = torch.nn.BCELoss(reduction=reduction)
@@ -1725,6 +1725,146 @@ def helper(shape, reduction):
         # verify if changes in shape would cause cached graph lookup problems
         helper([7, 5, 2, 4, 6], 'sum')
         helper([8, 4, 5, 7, 6], 'mean')
+        helper([1, 1, 32, 32], 'mean')
+
+    def test_bce_loss_always_nonnegative(self):
+        target = torch.ones(5, device='mps')
+        input = torch.ones(5, device='mps')
+        self.assertEqual((nn.BCELoss()(input, target) < 0).sum(), 0)
+
+        target = torch.zeros(5, device='mps')
+        input = torch.zeros(5, device='mps')
+        self.assertEqual((nn.BCELoss()(input, target) < 0).sum(), 0)
+
+    def test_bce_loss_size_mismatch(self):
+        bceloss = nn.BCELoss()
+        a = torch.rand(25, device='mps')
+        b = torch.rand(25, 1, device='mps')
+        with self.assertRaisesRegex(ValueError, r'Using a target size \('):
+            bceloss(a, b)
+
+    def test_bce_with_logits_gives_same_result_as_sigmoid_and_bce_loss_large_tensors_with_grad(self):
+        x_size = 1024
+        y_size = 256
+        target = torch.rand(x_size, y_size, device='mps')
+
+        for reduction in ['none', 'mean', 'sum']:
+            output_sig = torch.rand(x_size, y_size, device='mps') - 0.5
+            output_logits = output_sig.clone().detach()
+
+            output_sig.requires_grad = True
+            output_logits.requires_grad = True
+            weight = torch.rand(y_size, device='mps')
+
+            loss_sig = nn.BCELoss(weight, reduction=reduction)(
+                torch.sigmoid(output_sig), target
+            )
+            loss_logits = nn.BCEWithLogitsLoss(weight, reduction=reduction)(
+                output_logits, target
+            )
+
+            self.assertEqual(loss_logits, loss_sig)
+
+            if reduction == 'none':
+                grad = torch.rand(x_size, y_size, device='mps')
+                loss_sig.backward(grad)
+                loss_logits.backward(grad)
+            else:
+                loss_sig.backward()
+                loss_logits.backward()
+
+            self.assertEqual(output_sig.grad, output_logits.grad)
+
+    def test_bce_with_logits_has_correct_grad_at_zero(self):
+        output = torch.zeros(3, 1, requires_grad=True, device='mps')
+        target = torch.zeros(3, 1, device='mps')
+        nn.BCEWithLogitsLoss(reduction='sum')(output, target).backward()
+        expected_grad = torch.empty(3, 1, device='mps').fill_(0.5)
+        self.assertEqual(output.grad, expected_grad)
+
+    def test_bce_with_logits_broadcasts_weights(self):
+        target = torch.rand(16, 4, device='mps')
+        output = torch.rand(16, 4, device='mps') - 0.5
+
+        weight = torch.rand(4, device='mps')
+        out1 = nn.BCEWithLogitsLoss(weight)(output, target)
+
+        weight = weight.expand(16, 4).contiguous()
+        out2 = nn.BCEWithLogitsLoss(weight)(output, target)
+
+        self.assertEqual(out1, out2)
+
+        weight = torch.rand(16, 1, device='mps')
+        out1 = nn.BCEWithLogitsLoss(weight)(output, target)
+
+        weight = weight.expand(16, 4).contiguous()
+        out2 = nn.BCEWithLogitsLoss(weight)(output, target)
+
+        self.assertEqual(out1, out2)
+
+    def test_bce_with_logits_ones_in_pos_weights_are_the_same_as_none(self):
+        target = torch.rand(64, 4, device='mps')
+        output = torch.rand(64, 4, device='mps') - 0.5
+        pos_weight = torch.ones(64, 4, device='mps')
+
+        self.assertEqual(nn.BCEWithLogitsLoss()(output, target),
+                         nn.BCEWithLogitsLoss(pos_weight=pos_weight)(output, target))
+
+    def test_bce_with_logits_broadcasts_pos_weights(self):
+        target = torch.rand(64, 4, device='mps')
+        output = torch.rand(64, 4, device='mps') - 0.5
+        pos_weight = torch.rand(4, device='mps')
+        out1 = nn.BCEWithLogitsLoss(pos_weight=pos_weight)(output, target)
+
+        pos_weight1 = pos_weight.expand(1, 4)
+        out2 = nn.BCEWithLogitsLoss(pos_weight=pos_weight1)(output, target)
+
+        pos_weight2 = pos_weight.expand(64, 4)
+        out3 = nn.BCEWithLogitsLoss(pos_weight=pos_weight2)(output, target)
+
+        self.assertEqual(out1, out2)
+        self.assertEqual(out1, out3)
+
+    def test_bce_with_logits_with_pos_weight_has_correct_grad_at_zero(self):
+        output = torch.zeros(3, 1, requires_grad=True, device='mps')
+        target = torch.zeros(3, 1, device='mps')
+        pos_weight = torch.ones(3, 1, device='mps')
+        nn.BCEWithLogitsLoss(pos_weight=pos_weight, reduction='sum')(output, target).backward()
+        expected_grad = torch.empty(3, 1, device='mps').fill_(0.5)
+        grad = output.grad
+        self.assertEqual(grad, expected_grad)
+
+    def test_bce_with_logits_stability(self):
+        output = torch.tensor([0., -120.], device='mps')
+        target = torch.tensor([0., 1.], device='mps')
+        pos_weight = torch.tensor([1., 1.], device='mps')
+
+        out1 = nn.BCEWithLogitsLoss()(output, target)
+        self.assertTrue(torch.isfinite(out1).all().item())
+
+        out2 = nn.BCEWithLogitsLoss(pos_weight=pos_weight)(output, target)
+        self.assertTrue(torch.isfinite(out2).all().item())
+
+    def test_bce_loss_broadcasts_weights(self):
+        sigmoid = nn.Sigmoid()
+        target = torch.rand(16, 4, device='mps')
+        output = torch.rand(16, 4, device='mps') - 0.5
+
+        weight = torch.rand(4, device='mps')
+        out1 = nn.BCELoss(weight)(sigmoid(output), target)
+
+        weight = weight.expand(16, 4).contiguous()
+        out2 = nn.BCELoss(weight)(sigmoid(output), target)
+
+        self.assertEqual(out1, out2)
+
+        weight = torch.rand(16, 1, device='mps')
+        out1 = nn.BCELoss(weight)(sigmoid(output), target)
+
+        weight = weight.expand(16, 4).contiguous()
+        out2 = nn.BCELoss(weight)(sigmoid(output), target)
+
+        self.assertEqual(out1, out2)
 
     def test_log_softmax(self):
         values = [[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], [[7.0, 8.0, 9.0], [10.0, 11.0, 12.0]]]
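
Background on test_bce_with_logits_stability: nn.BCEWithLogitsLoss fuses the sigmoid and the BCE term using the log-sum-exp trick, which is why PyTorch documents it as more numerically stable than applying sigmoid followed by BCELoss. The sketch below is illustrative only and not part of this patch; stable_bce_with_logits is a hypothetical helper name, and it runs on CPU just to show why the test expects finite values for a logit of -120.

import torch
import torch.nn.functional as F

def stable_bce_with_logits(logits, target):
    # Stable form: max(x, 0) - x*y + log(1 + exp(-|x|)).
    # log1p(exp(-|x|)) cannot overflow or hit log(0), whereas the
    # naive -log(sigmoid(x)) underflows to -inf for large negative x.
    return logits.clamp(min=0) - logits * target + torch.log1p(torch.exp(-logits.abs()))

logits = torch.tensor([0., -120.])
target = torch.tensor([0., 1.])

# Naive sigmoid + BCE arithmetic: sigmoid(-120) underflows to 0 in
# float32, so the second element becomes -log(0) = inf.
sig = torch.sigmoid(logits)
naive = -(target * sig.log() + (1 - target) * (1 - sig).log())
print(naive)  # tensor([0.6931, inf])

# The fused, stable form stays finite -- the property the test asserts on MPS.
print(stable_bce_with_logits(logits, target))                                 # tensor([0.6931, 120.])
print(F.binary_cross_entropy_with_logits(logits, target, reduction='none'))   # tensor([0.6931, 120.])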