Skip to content

Commit 56d1286

Browse files
abhudevkulinseth
authored and committed
Handle 1D inputs (#49)
* Add test for NLL 1d * Fix forward NLL for 1D case * Handle NLL backward for 1d * Fix handling of 0D target; remove newlines
1 parent 9ee3120 commit 56d1286

File tree

2 files changed

+41
-9
lines changed

2 files changed

+41
-9
lines changed

aten/src/ATen/native/mps/operations/LossOps.mm

Lines changed: 16 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -314,18 +314,18 @@ void mse_loss_out_impl(const Tensor& input, const Tensor& target,
314314

315315
// NLLLoss
316316
void nllnd_loss_backward_impl(
317-
Tensor& grad_input,
317+
Tensor& grad_input_arg,
318318
const Tensor& grad_output,
319-
const Tensor& input,
320-
const Tensor& target,
319+
const Tensor& input_arg,
320+
const Tensor& target_arg,
321321
const Tensor& weight,
322322
int64_t reduction,
323323
int64_t ignore_index,
324324
const Tensor& total_weight,
325325
bool is2D)
326326
{
327327
// Empty output
328-
if(grad_input.numel() == 0)
328+
if(grad_input_arg.numel() == 0)
329329
return;
330330

331331
MPSStream* stream = getCurrentMPSStream();
@@ -342,6 +342,10 @@ void nllnd_loss_backward_impl(
342342

343343
MPSGraphCache* cache_ = MPSGraphCache::getInstance();
344344

345+
auto input = input_arg.dim() == 1 ? input_arg.view({1, input_arg.size(0)}) : input_arg;
346+
auto target = target_arg.dim() == 0 ? target_arg.view({1}) : target_arg;
347+
auto grad_input = grad_input_arg.dim() == 1 ? grad_input_arg.view({1, grad_input_arg.size(0)}) : grad_input_arg;
348+
345349
@autoreleasepool {
346350

347351
auto numClasses = grad_input.sizes()[1];
@@ -472,24 +476,24 @@ void nllnd_loss_backward_impl(
472476
void nllnd_loss_forward_impl
473477
(Tensor& output,
474478
Tensor& total_weight,
475-
const Tensor& input,
476-
const Tensor& target,
479+
const Tensor& input_arg,
480+
const Tensor& target_arg,
477481
const Tensor& weight,
478482
int64_t reduction,
479483
int64_t ignore_index,
480484
bool is2D)
481485
{
482-
std::vector<long long> reshapedTarget(target.sizes().begin(), target.sizes().end());
486+
std::vector<long long> reshapedTarget(target_arg.sizes().begin(), target_arg.sizes().end());
483487
reshapedTarget.push_back(1);
484488

485-
Tensor batchSizeTensor = at::empty_like(input).resize_(IntArrayRef(1));
489+
Tensor batchSizeTensor = at::empty_like(input_arg).resize_(IntArrayRef(1));
486490
float batchVal = 1.0f;
487491
for(size_t i = 0; i < reshapedTarget.size(); ++i)
488492
batchVal *= reshapedTarget[i];
489493
batchSizeTensor[0] = batchVal;
490494

491495
if(reduction == Reduction::None)
492-
output.resize_(target.sizes());
496+
output.resize_(target_arg.sizes());
493497
if(reduction == Reduction::Sum)
494498
output.resize_({});
495499
if(reduction == Reduction::Mean)
@@ -516,6 +520,9 @@ void nllnd_loss_backward_impl(
516520

517521
MPSStream* stream = getCurrentMPSStream();
518522

523+
auto input = input_arg.dim() == 1 ? input_arg.view({1, input_arg.size(0)}) : input_arg;
524+
auto target = target_arg.dim() == 0 ? target_arg.view({1}) : target_arg;
525+
519526
@autoreleasepool {
520527

521528
bool isWeightsArrayValid = (weight.numel() > 0);

test/test_mps.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1701,6 +1701,26 @@ def _nll_loss_helper(self, input_size, reduction, expected):
17011701
output_mps.sum().backward()
17021702
self.assertEqual(input.grad, input_mps.grad.to('cpu'))
17031703

1704+
def _nll_loss_1d_helper(self, input_size, reduction):
1705+
1706+
# CPU
1707+
input = torch.rand(input_size, requires_grad=True, device='cpu')
1708+
num_channels = input_size[0]
1709+
target = torch.randint(num_channels, [], device='cpu')
1710+
1711+
# MPS
1712+
input_mps = input.detach().clone().to('mps').requires_grad_()
1713+
target_mps = target.detach().clone().to('mps')
1714+
1715+
output_cpu = F.nll_loss(input, target, reduction=reduction)
1716+
output_mps = F.nll_loss(input_mps, target_mps, reduction=reduction)
1717+
# TODO(#38095): Replace assertEqualIgnoreType. See issue #38095
1718+
self.assertEqualIgnoreType(output_cpu, output_mps.to('cpu'))
1719+
1720+
output_cpu.sum().backward()
1721+
output_mps.sum().backward()
1722+
self.assertEqual(input.grad, input_mps.grad.to('cpu'))
1723+
17041724
def test_as_strided(self):
17051725
values = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]
17061726
values_1 = [[1.0, 1.0], [1.0, 1.0]]
@@ -1743,6 +1763,11 @@ def helper(n, c):
17431763

17441764
helper(3, 3)
17451765

1766+
def test_nll_loss_1d(self, device='cpu'):
1767+
self._nll_loss_1d_helper([10], "none")
1768+
self._nll_loss_1d_helper([10], "mean")
1769+
self._nll_loss_1d_helper([10], "sum")
1770+
17461771
def test_nll_loss_empty_tensor_reduction_none(self, device='cpu'):
17471772
self._nll_loss_helper([1, 3], "none", torch.empty([0], device=device))
17481773
self._nll_loss_helper([3, 5, 7], "none", torch.empty([5, 7], device=device))

0 commit comments

Comments
 (0)