@@ -744,8 +744,8 @@ Tensor std_var_common_impl_mps(
744744 }
745745 }
746746
747- bool use_correction = correction.has_value ();
748- const auto correction_value = use_correction ? correction.value () : false ;
747+ bool use_correction = !( correction.has_value () && correction. value () == 0 );
748+ const auto correction_value = correction.value_or ( 1 ) ;
749749 int64_t correction_n = 1 ;
750750
751751 native_mps::MPSGraphCache* cache_ = native_mps::MPSGraphCache::getInstance ();
@@ -884,7 +884,7 @@ Tensor std_var_common_impl_mps(
884884 return output_t ;
885885 }
886886
887- double bessel_correction = ((double ) correction_n) / ((double ) (correction_n-1 ));
887+ double bessel_correction = ((double ) correction_n) / ((double ) (correction_n-correction_value ));
888888
889889 auto stream = at::mps::getCurrentMPSStream ();
890890
@@ -894,7 +894,7 @@ Tensor std_var_common_impl_mps(
894894 string bessel_corrected = (use_correction && correction_value) ? " unbiased " : " biased " ;
895895 string use_dim_info = (use_dim) ? " use_dim=1:" + to_string (dim_value.size ()) : " use_dim=0" ;
896896 string keepdim_info = (keepdim) ? " keepdim=1" : " keepdim=0" ;
897- string key = op_key + use_dim_info + " :" + keepdim_info + " :" + string ([ns_key UTF8String ]) + " :" + native_mps::getTensorsStringKey (input_t ) + " :" + bessel_corrected;
897+ string key = op_key + use_dim_info + " :" + keepdim_info + " :" + string ([ns_key UTF8String ]) + " :" + native_mps::getTensorsStringKey (input_t ) + " :" + bessel_corrected + " : " + std::to_string (correction_value) ;
898898
899899 auto cachedGraph = cache_->LookUpAs <CachedGraph>(key);
900900 // Initialize once if configuration not found in cache
@@ -916,7 +916,7 @@ Tensor std_var_common_impl_mps(
916916 if (use_correction && correction_value)
917917 {
918918 MPSGraphTensor *besselTensor= [mpsGraph constantWithScalar: bessel_correction
919- dataType: MPSDataTypeFloat32 ];
919+ dataType: native_mps::getMPSDataType ( input_t . scalar_type ()) ];
920920 MPSGraphTensor *correctedTensor = [mpsGraph multiplicationWithPrimaryTensor: outputVarTensor
921921 secondaryTensor: besselTensor
922922 name: nil ];
0 commit comments