Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 16 additions & 9 deletions aten/src/ATen/native/mps/operations/LossOps.mm
Original file line number Diff line number Diff line change
Expand Up @@ -314,18 +314,18 @@ void mse_loss_out_impl(const Tensor& input, const Tensor& target,

// NLLLoss
void nllnd_loss_backward_impl(
Tensor& grad_input,
Tensor& grad_input_arg,
const Tensor& grad_output,
const Tensor& input,
const Tensor& target,
const Tensor& input_arg,
const Tensor& target_arg,
const Tensor& weight,
int64_t reduction,
int64_t ignore_index,
const Tensor& total_weight,
bool is2D)
{
// Empty output
if(grad_input.numel() == 0)
if(grad_input_arg.numel() == 0)
return;

MPSStream* stream = getCurrentMPSStream();
Expand All @@ -342,6 +342,10 @@ void nllnd_loss_backward_impl(

MPSGraphCache* cache_ = MPSGraphCache::getInstance();

auto input = input_arg.dim() == 1 ? input_arg.view({1, input_arg.size(0)}) : input_arg;
auto target = target_arg.dim() == 0 ? target_arg.view({1}) : target_arg;
auto grad_input = grad_input_arg.dim() == 1 ? grad_input_arg.view({1, grad_input_arg.size(0)}) : grad_input_arg;

@autoreleasepool {

auto numClasses = grad_input.sizes()[1];
Expand Down Expand Up @@ -472,24 +476,24 @@ void nllnd_loss_backward_impl(
void nllnd_loss_forward_impl
(Tensor& output,
Tensor& total_weight,
const Tensor& input,
const Tensor& target,
const Tensor& input_arg,
const Tensor& target_arg,
const Tensor& weight,
int64_t reduction,
int64_t ignore_index,
bool is2D)
{
std::vector<long long> reshapedTarget(target.sizes().begin(), target.sizes().end());
std::vector<long long> reshapedTarget(target_arg.sizes().begin(), target_arg.sizes().end());
reshapedTarget.push_back(1);

Tensor batchSizeTensor = at::empty_like(input).resize_(IntArrayRef(1));
Tensor batchSizeTensor = at::empty_like(input_arg).resize_(IntArrayRef(1));
float batchVal = 1.0f;
for(size_t i = 0; i < reshapedTarget.size(); ++i)
batchVal *= reshapedTarget[i];
batchSizeTensor[0] = batchVal;

if(reduction == Reduction::None)
output.resize_(target.sizes());
output.resize_(target_arg.sizes());
if(reduction == Reduction::Sum)
output.resize_({});
if(reduction == Reduction::Mean)
Expand All @@ -516,6 +520,9 @@ void nllnd_loss_backward_impl(

MPSStream* stream = getCurrentMPSStream();

auto input = input_arg.dim() == 1 ? input_arg.view({1, input_arg.size(0)}) : input_arg;
auto target = target_arg.dim() == 0 ? target_arg.view({1}) : target_arg;

@autoreleasepool {

bool isWeightsArrayValid = (weight.numel() > 0);
Expand Down
25 changes: 25 additions & 0 deletions test/test_mps.py
Original file line number Diff line number Diff line change
Expand Up @@ -1701,6 +1701,26 @@ def _nll_loss_helper(self, input_size, reduction, expected):
output_mps.sum().backward()
self.assertEqual(input.grad, input_mps.grad.to('cpu'))

def _nll_loss_1d_helper(self, input_size, reduction):
    # Exercise F.nll_loss forward and backward on a 1-D input with a
    # scalar (0-dim) class target, checking the MPS results against the
    # CPU reference implementation.
    input = torch.rand(input_size, requires_grad=True, device='cpu')
    target = torch.randint(input_size[0], [], device='cpu')

    # Mirror the reference tensors onto the MPS device.
    input_mps = input.detach().clone().to('mps').requires_grad_()
    target_mps = target.detach().clone().to('mps')

    # Forward pass on both devices.
    outputs = [
        F.nll_loss(x, t, reduction=reduction)
        for x, t in ((input, target), (input_mps, target_mps))
    ]
    output_cpu, output_mps = outputs
    # TODO(#38095): Replace assertEqualIgnoreType. See issue #38095
    self.assertEqualIgnoreType(output_cpu, output_mps.to('cpu'))

    # Backward pass: the input gradients must also agree.
    for out in outputs:
        out.sum().backward()
    self.assertEqual(input.grad, input_mps.grad.to('cpu'))

def test_as_strided(self):
values = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]
values_1 = [[1.0, 1.0], [1.0, 1.0]]
Expand Down Expand Up @@ -1743,6 +1763,11 @@ def helper(n, c):

helper(3, 3)

def test_nll_loss_1d(self, device='cpu'):
    # Run the 1-D NLLLoss CPU-vs-MPS comparison under every reduction mode.
    for reduction in ("none", "mean", "sum"):
        self._nll_loss_1d_helper([10], reduction)

def test_nll_loss_empty_tensor_reduction_none(self, device='cpu'):
self._nll_loss_helper([1, 3], "none", torch.empty([0], device=device))
self._nll_loss_helper([3, 5, 7], "none", torch.empty([5, 7], device=device))
Expand Down