2 changes: 1 addition & 1 deletion aten/src/ATen/native/mps/operations/LossOps.mm
@@ -277,7 +277,7 @@ void mse_loss_out_impl(const Tensor& input, const Tensor& target,
newCachedGraph->gradInputTensor = bceLoss;
}
} else {
-        newCachedGraph->lossTensor = reduceTensor(bceLoss, reduction, mpsGraph, input.sizes().size());
+        newCachedGraph->lossTensor = reduceTensor(bceLoss, reduction, mpsGraph, input_squeezed.sizes().size());
}
}
return newCachedGraph;
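The one-line fix above passes the rank of the squeezed input, rather than of the original input, to reduceTensor. Judging from the variable name and the new [1, 1, 32, 32] test case below, the MPS BCE path squeezes away singleton dimensions before building the loss graph, so the reduction has to be constructed against the squeezed rank. A minimal end-to-end check, assuming an MPS-capable machine (this snippet is illustrative and not part of the PR):

import torch
import torch.nn as nn

# Compare MPS against the CPU reference for a shape with leading
# singleton dims; before this fix the rank handed to the reduction
# came from the unsqueezed input.
x = torch.rand(1, 1, 32, 32)
t = torch.rand(1, 1, 32, 32)
loss_cpu = nn.BCELoss(reduction='mean')(x, t)
loss_mps = nn.BCELoss(reduction='mean')(x.to('mps'), t.to('mps'))
assert torch.allclose(loss_cpu, loss_mps.cpu(), atol=1e-5)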
142 changes: 141 additions & 1 deletion test/test_mps.py
@@ -1695,7 +1695,7 @@ def helper(shape, reduction):
helper([8, 4, 5, 7, 6], 'mean')

# Binary Cross Entropy
-    def test_bce_loss(self):
+    def test_bce_loss_simple(self):
def helper(shape, reduction):
# create the criterion
loss = torch.nn.BCELoss(reduction=reduction)
@@ -1725,6 +1725,146 @@ def helper(shape, reduction):
# verify if changes in shape would cause cached graph lookup problems
helper([7, 5, 2, 4, 6], 'sum')
helper([8, 4, 5, 7, 6], 'mean')
helper([1, 1, 32, 32], 'mean')

def test_bce_loss_always_nonnegative(self):
target = torch.ones(5, device='mps')
input = torch.ones(5, device='mps')
self.assertEqual((nn.BCELoss()(input, target) < 0).sum(), 0)

target = torch.zeros(5, device='mps')
input = torch.zeros(5, device='mps')
self.assertEqual((nn.BCELoss()(input, target) < 0).sum(), 0)

def test_bce_loss_size_mismatch(self):
bceloss = nn.BCELoss()
a = torch.rand(25, device='mps')
b = torch.rand(25, 1, device='mps')
with self.assertRaisesRegex(ValueError, r'Using a target size \('):
bceloss(a, b)

def test_bce_with_logits_gives_same_result_as_sigmoid_and_bce_loss_large_tensors_with_grad(self):
x_size = 1024
y_size = 256
target = torch.rand(x_size, y_size, device='mps')

for reduction in ['none', 'mean', 'sum']:
output_sig = torch.rand(x_size, y_size, device='mps') - 0.5
output_logits = output_sig.clone().detach()

output_sig.requires_grad = True
output_logits.requires_grad = True
weight = torch.rand(y_size, device='mps')

loss_sig = nn.BCELoss(weight, reduction=reduction)(
torch.sigmoid(output_sig), target
)
loss_logits = nn.BCEWithLogitsLoss(weight, reduction=reduction)(
output_logits, target
)

self.assertEqual(loss_logits, loss_sig)

if reduction == 'none':
grad = torch.rand(x_size, y_size, device='mps')
loss_sig.backward(grad)
loss_logits.backward(grad)
else:
loss_sig.backward()
loss_logits.backward()

self.assertEqual(output_sig.grad, output_logits.grad)

def test_bce_with_logits_has_correct_grad_at_zero(self):
output = torch.zeros(3, 1, requires_grad=True, device='mps')
target = torch.zeros(3, 1, device='mps')
nn.BCEWithLogitsLoss(reduction='sum')(output, target).backward()
expected_grad = torch.empty(3, 1, device='mps').fill_(0.5)
self.assertEqual(output.grad, expected_grad)

def test_bce_with_logits_broadcasts_weights(self):
target = torch.rand(16, 4, device='mps')
output = torch.rand(16, 4, device='mps') - 0.5

weight = torch.rand(4, device='mps')
out1 = nn.BCEWithLogitsLoss(weight)(output, target)

weight = weight.expand(16, 4).contiguous()
out2 = nn.BCEWithLogitsLoss(weight)(output, target)

self.assertEqual(out1, out2)

weight = torch.rand(16, 1, device='mps')
out1 = nn.BCEWithLogitsLoss(weight)(output, target)

weight = weight.expand(16, 4).contiguous()
out2 = nn.BCEWithLogitsLoss(weight)(output, target)

self.assertEqual(out1, out2)

def test_bce_with_logits_ones_in_pos_weights_are_the_same_as_none(self):
target = torch.rand(64, 4, device='mps')
output = torch.rand(64, 4, device='mps') - 0.5
pos_weight = torch.ones(64, 4, device='mps')

self.assertEqual(nn.BCEWithLogitsLoss()(output, target),
nn.BCEWithLogitsLoss(pos_weight=pos_weight)(output, target))

def test_bce_with_logits_broadcasts_pos_weights(self):
target = torch.rand(64, 4, device='mps')
output = torch.rand(64, 4, device='mps') - 0.5
pos_weight = torch.rand(4, device='mps')
out1 = nn.BCEWithLogitsLoss(pos_weight=pos_weight)(output, target)

pos_weight1 = pos_weight.expand(1, 4)
out2 = nn.BCEWithLogitsLoss(pos_weight=pos_weight1)(output, target)

pos_weight2 = pos_weight.expand(64, 4)
out3 = nn.BCEWithLogitsLoss(pos_weight=pos_weight2)(output, target)

self.assertEqual(out1, out2)
self.assertEqual(out1, out3)

def test_bce_with_logits_with_pos_weight_has_correct_grad_at_zero(self):
output = torch.zeros(3, 1, requires_grad=True, device='mps')
target = torch.zeros(3, 1, device='mps')
pos_weight = torch.ones(3, 1, device='mps')
nn.BCEWithLogitsLoss(pos_weight=pos_weight, reduction='sum')(output, target).backward()
expected_grad = torch.empty(3, 1, device='mps').fill_(0.5)
grad = output.grad
self.assertEqual(grad, expected_grad)

def test_bce_with_logits_stability(self):
output = torch.tensor([0., -120.], device='mps')
target = torch.tensor([0., 1.], device='mps')
pos_weight = torch.tensor([1., 1.], device='mps')

out1 = nn.BCEWithLogitsLoss()(output, target)
self.assertTrue(torch.isfinite(out1).all().item())

out2 = nn.BCEWithLogitsLoss(pos_weight=pos_weight)(output, target)
self.assertTrue(torch.isfinite(out2).all().item())

def test_bce_loss_broadcasts_weights(self):
sigmoid = nn.Sigmoid()
target = torch.rand(16, 4, device='mps')
output = torch.rand(16, 4, device='mps') - 0.5

weight = torch.rand(4, device='mps')
out1 = nn.BCELoss(weight)(sigmoid(output), target)

weight = weight.expand(16, 4).contiguous()
out2 = nn.BCELoss(weight)(sigmoid(output), target)

self.assertEqual(out1, out2)

weight = torch.rand(16, 1, device='mps')
out1 = nn.BCELoss(weight)(sigmoid(output), target)

weight = weight.expand(16, 4).contiguous()
out2 = nn.BCELoss(weight)(sigmoid(output), target)

self.assertEqual(out1, out2)

def test_log_softmax(self):
values = [[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], [[7.0, 8.0, 9.0], [10.0, 11.0, 12.0]]]
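test_bce_with_logits_stability above exercises the case that makes BCEWithLogitsLoss preferable to a separate sigmoid followed by BCELoss: for a large negative logit, log(sigmoid(x)) underflows to -inf, while the fused loss uses the standard log-sum-exp identity max(x, 0) - x*y + log(1 + exp(-|x|)) and stays finite. A quick sketch of the difference (the identity is the standard formulation, not code from this PR):

import torch

x = torch.tensor([0., -120.])
y = torch.tensor([0., 1.])

# Naive two-step version: sigmoid(-120.) flushes to 0 in float32,
# so its log is -inf and the loss blows up to inf.
naive = -(y * torch.log(torch.sigmoid(x)) + (1 - y) * torch.log(1 - torch.sigmoid(x)))

# Log-sum-exp formulation: every term stays well conditioned.
stable = torch.clamp(x, min=0) - x * y + torch.log1p(torch.exp(-x.abs()))

print(naive)   # tensor([0.6931, inf])
print(stable)  # tensor([0.6931, 120.])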