Skip to content

Commit ff458f0

Browse files
authored
Dev/kulin/nll (#189)

* Fix the NLLLoss2D crash.
* Cleanup.
1 parent 42f00a5 commit ff458f0

File tree

2 files changed

+18
-12
lines changed

2 files changed

+18
-12
lines changed

aten/src/ATen/native/mps/operations/LossOps.mm

Lines changed: 4 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -356,19 +356,12 @@ void nllnd_loss_backward_impl(
356356
MPSShape* weight_shape = getMPSShape(weight);
357357
MPSShape* total_weight_shape = getMPSShape(total_weight);
358358

359-
NSString* ns_shape_key = [[input_shape valueForKey:@"description"] componentsJoinedByString:@","];
360-
361359
string key = "nllnd_loss_backward_impl:" + to_string(numClasses) + ":" +
362360
to_string(ignore_index) + ":" +
363361
to_string(isWeightsArrayValid) + ":" +
364362
reductionToString(reduction) + ":" +
365-
[ns_shape_key UTF8String] + ":" +
366-
getMPSTypeString(input.scalar_type()) + ":" +
367-
getMPSTypeString(target.scalar_type()) + ":" +
368-
getMPSTypeString(weight.scalar_type()) + ":" +
369-
getMPSTypeString(total_weight.scalar_type());
363+
getTensorsStringKey({input, target, weight, total_weight});
370364
CachedGraph* cachedGraph = static_cast<CachedGraph *>(cache_->LookUp(key));
371-
372365
if(!cachedGraph) {
373366
MPSCachedGraph *tmpCachedGraph = cache_->CreateCachedGraph(key, ^ MPSCachedGraph * () {
374367

@@ -408,12 +401,11 @@ void nllnd_loss_backward_impl(
408401
}
409402

410403
float onValue = -1.0f;
404+
auto target_axis = target.defined() ? target.dim() : 1;
411405

412-
MPSGraphTensor *oneHotTensor;
413-
414-
oneHotTensor = [mpsGraph oneHotWithIndicesTensor:udpatedTargetTensor
406+
MPSGraphTensor *oneHotTensor = [mpsGraph oneHotWithIndicesTensor:udpatedTargetTensor
415407
depth:numClasses
416-
axis:1
408+
axis:target_axis
417409
dataType:inputTensor.dataType
418410
onValue:onValue
419411
offValue:0.0f

test/test_mps.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2308,6 +2308,20 @@ def test_smooth_l1_loss_reduction_mean_sum_backward(self):
23082308

23092309

23102310
class TestNLLLoss(TestCase):
2311+
def test_nll2d_loss_backward(self, device='mps'):
2312+
a = torch.randn(3, 5, requires_grad=True, device=device)
2313+
b = torch.tensor([1, 0, 4], device=device)
2314+
loss = nn.NLLLoss()
2315+
out = loss(a, b)
2316+
self.assertIsNone(out.grad_fn._saved_weight)
2317+
loss = nn.NLLLoss(weight=torch.ones((5,), device=device))
2318+
out = loss(a, b)
2319+
self.assertEqual(out.grad_fn._saved_weight, torch.ones((5,)))
2320+
2321+
out.sum().backward()
2322+
with self.assertRaisesRegex(RuntimeError, "after they have already been freed"):
2323+
out.grad_fn._saved_weight
2324+
23112325
def test_nll_loss_mismatched_batch(self, device='mps'):
23122326
x = torch.randn((10, 3), requires_grad=True, device=device)
23132327
# t should have size (10,)

0 commit comments

Comments (0)