Skip to content

Commit

Permalink
Correctly apply weights to oneHotTensor in NLLLoss (#233)
Browse files — browse the repository at this point in the history
Co-authored-by: Siddharth Kotapati <sidk@Siddharths-MacBook-Pro.local>
  • Loading branch information
2 people authored and kulinseth committed Feb 5, 2023
1 parent 83877a2 commit c9083ea
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 6 deletions.
19 changes: 15 additions & 4 deletions aten/src/ATen/native/mps/operations/LossOps.mm
Original file line number Diff line number Diff line change
Expand Up @@ -400,18 +400,29 @@ void nllnd_loss_backward_impl(
}

float onValue = -1.0f;
auto target_axis = target.defined() ? target.dim() : 1;

MPSGraphTensor *oneHotTensor = [mpsGraph oneHotWithIndicesTensor:udpatedTargetTensor
depth:numClasses
axis:target_axis
axis:1
dataType:inputTensor.dataType
onValue:onValue
offValue:0.0f
name:nil];

if(isWeightsArrayValid)
{
if(isWeightsArrayValid) {
int64_t nDim = input.sizes().size();
IntArrayRef sizes = input.sizes();
std::vector<NSNumber*> numbers(nDim);
for (const auto i: c10::irange(nDim)) {
NSInteger sz_i = (i == 1) ? sizes[i] : 1;
NSNumber* number = [NSNumber numberWithInteger:sz_i];
numbers[i] = number;
}

MPSGraphTensor *weightTensorReshaped = [mpsGraph reshapeTensor:weightTensor
withShape:[NSArray arrayWithObjects:numbers.data() count:numbers.size()]
name:nil];

oneHotTensor = [mpsGraph multiplicationWithPrimaryTensor:oneHotTensor
secondaryTensor:weightTensor
name:@"scaleByWeightTensor"];
Expand Down
6 changes: 4 additions & 2 deletions test/test_mps.py
Original file line number Diff line number Diff line change
Expand Up @@ -2490,14 +2490,16 @@ def _nll_loss_helper(self, input_size, reduction, expected):
input = torch.rand(input_size, requires_grad=True, device='cpu')
num_channels = input_size[1]
target_size = (input_size[0], ) + tuple(input_size[2:])
weights = torch.randn(num_channels)
weights_mps = weights.to("mps")
target = torch.randint(num_channels, target_size, device='cpu')

# MPS
input_mps = input.detach().clone().to('mps').requires_grad_()
target_mps = target.detach().clone().to('mps')

output_cpu = F.nll_loss(input, target, reduction=reduction)
output_mps = F.nll_loss(input_mps, target_mps, reduction=reduction)
output_cpu = F.nll_loss(input, target, weight=weights, reduction=reduction)
output_mps = F.nll_loss(input_mps, target_mps, weight=weights_mps, reduction=reduction)
self.assertEqual(output_cpu, output_mps.to('cpu'))

output_cpu.sum().backward()
Expand Down

0 comments on commit c9083ea

Please sign in to comment.