Skip to content

Commit d345020

Browse files
committed
Dev/kulin/nll (#189)
* Fix the NLLLoss2D crash. * Cleanup.
1 parent 158bb02 commit d345020

File tree

2 files changed

+18
-12
lines changed

2 files changed

+18
-12
lines changed

aten/src/ATen/native/mps/operations/LossOps.mm

Lines changed: 4 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -355,19 +355,12 @@ void nllnd_loss_backward_impl(
355355
MPSShape* weight_shape = getMPSShape(weight);
356356
MPSShape* total_weight_shape = getMPSShape(total_weight);
357357

358-
NSString* ns_shape_key = [[input_shape valueForKey:@"description"] componentsJoinedByString:@","];
359-
360358
string key = "nllnd_loss_backward_impl:" + to_string(numClasses) + ":" +
361359
to_string(ignore_index) + ":" +
362360
to_string(isWeightsArrayValid) + ":" +
363361
reductionToString(reduction) + ":" +
364-
[ns_shape_key UTF8String] + ":" +
365-
getMPSTypeString(input.scalar_type()) + ":" +
366-
getMPSTypeString(target.scalar_type()) + ":" +
367-
getMPSTypeString(weight.scalar_type()) + ":" +
368-
getMPSTypeString(total_weight.scalar_type());
362+
getTensorsStringKey({input, target, weight, total_weight});
369363
CachedGraph* cachedGraph = static_cast<CachedGraph *>(cache_->LookUp(key));
370-
371364
if(!cachedGraph) {
372365
MPSCachedGraph *tmpCachedGraph = cache_->CreateCachedGraph(key, ^ MPSCachedGraph * () {
373366

@@ -407,12 +400,11 @@ void nllnd_loss_backward_impl(
407400
}
408401

409402
float onValue = -1.0f;
403+
auto target_axis = target.defined() ? target.dim() : 1;
410404

411-
MPSGraphTensor *oneHotTensor;
412-
413-
oneHotTensor = [mpsGraph oneHotWithIndicesTensor:udpatedTargetTensor
405+
MPSGraphTensor *oneHotTensor = [mpsGraph oneHotWithIndicesTensor:udpatedTargetTensor
414406
depth:numClasses
415-
axis:1
407+
axis:target_axis
416408
dataType:inputTensor.dataType
417409
onValue:onValue
418410
offValue:0.0f

test/test_mps.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2398,6 +2398,20 @@ def test_smooth_l1_loss_reduction_mean_sum_backward(self):
23982398

23992399

24002400
class TestNLLLoss(TestCase):
2401+
def test_nll2d_loss_backward(self, device='mps'):
2402+
a = torch.randn(3, 5, requires_grad=True, device=device)
2403+
b = torch.tensor([1, 0, 4], device=device)
2404+
loss = nn.NLLLoss()
2405+
out = loss(a, b)
2406+
self.assertIsNone(out.grad_fn._saved_weight)
2407+
loss = nn.NLLLoss(weight=torch.ones((5,), device=device))
2408+
out = loss(a, b)
2409+
self.assertEqual(out.grad_fn._saved_weight, torch.ones((5,)))
2410+
2411+
out.sum().backward()
2412+
with self.assertRaisesRegex(RuntimeError, "after they have already been freed"):
2413+
out.grad_fn._saved_weight
2414+
24012415
def test_nll_loss_mismatched_batch(self, device='mps'):
24022416
x = torch.randn((10, 3), requires_grad=True, device=device)
24032417
# t should have size (10,)

0 commit comments

Comments (0)