@@ -356,19 +356,12 @@ void nllnd_loss_backward_impl(
356356 MPSShape* weight_shape = getMPSShape (weight);
357357 MPSShape* total_weight_shape = getMPSShape (total_weight);
358358
359- NSString * ns_shape_key = [[input_shape valueForKey: @" description" ] componentsJoinedByString: @" ," ];
360-
361359 string key = " nllnd_loss_backward_impl:" + to_string (numClasses) + " :" +
362360 to_string (ignore_index) + " :" +
363361 to_string (isWeightsArrayValid) + " :" +
364362 reductionToString (reduction) + " :" +
365- [ns_shape_key UTF8String ] + " :" +
366- getMPSTypeString (input.scalar_type ()) + " :" +
367- getMPSTypeString (target.scalar_type ()) + " :" +
368- getMPSTypeString (weight.scalar_type ()) + " :" +
369- getMPSTypeString (total_weight.scalar_type ());
363+ getTensorsStringKey ({input, target, weight, total_weight});
370364 CachedGraph* cachedGraph = static_cast <CachedGraph *>(cache_->LookUp (key));
371-
372365 if (!cachedGraph) {
373366 MPSCachedGraph *tmpCachedGraph = cache_->CreateCachedGraph (key, ^ MPSCachedGraph * () {
374367
@@ -408,12 +401,11 @@ void nllnd_loss_backward_impl(
408401 }
409402
410403 float onValue = -1 .0f ;
404+ auto target_axis = target.defined () ? target.dim () : 1 ;
411405
412- MPSGraphTensor *oneHotTensor;
413-
414- oneHotTensor = [mpsGraph oneHotWithIndicesTensor: udpatedTargetTensor
406+ MPSGraphTensor *oneHotTensor = [mpsGraph oneHotWithIndicesTensor: udpatedTargetTensor
415407 depth: numClasses
416- axis: 1
408+ axis: target_axis
417409 dataType: inputTensor.dataType
418410 onValue: onValue
419411 offValue: 0 .0f
0 commit comments