@@ -355,19 +355,12 @@ void nllnd_loss_backward_impl(
355355 MPSShape* weight_shape = getMPSShape (weight);
356356 MPSShape* total_weight_shape = getMPSShape (total_weight);
357357
358- NSString * ns_shape_key = [[input_shape valueForKey: @" description" ] componentsJoinedByString: @" ," ];
359-
360358 string key = " nllnd_loss_backward_impl:" + to_string (numClasses) + " :" +
361359 to_string (ignore_index) + " :" +
362360 to_string (isWeightsArrayValid) + " :" +
363361 reductionToString (reduction) + " :" +
364- [ns_shape_key UTF8String ] + " :" +
365- getMPSTypeString (input.scalar_type ()) + " :" +
366- getMPSTypeString (target.scalar_type ()) + " :" +
367- getMPSTypeString (weight.scalar_type ()) + " :" +
368- getMPSTypeString (total_weight.scalar_type ());
362+ getTensorsStringKey ({input, target, weight, total_weight});
369363 CachedGraph* cachedGraph = static_cast <CachedGraph *>(cache_->LookUp (key));
370-
371364 if (!cachedGraph) {
372365 MPSCachedGraph *tmpCachedGraph = cache_->CreateCachedGraph (key, ^ MPSCachedGraph * () {
373366
@@ -407,12 +400,11 @@ void nllnd_loss_backward_impl(
407400 }
408401
409402 float onValue = -1 .0f ;
403+ auto target_axis = target.defined () ? target.dim () : 1 ;
410404
411- MPSGraphTensor *oneHotTensor;
412-
413- oneHotTensor = [mpsGraph oneHotWithIndicesTensor: udpatedTargetTensor
405+ MPSGraphTensor *oneHotTensor = [mpsGraph oneHotWithIndicesTensor: udpatedTargetTensor
414406 depth: numClasses
415- axis: 1
407+ axis: target_axis
416408 dataType: inputTensor.dataType
417409 onValue: onValue
418410 offValue: 0 .0f
0 commit comments