
Commit d84cf3d

Ronian526 authored and committed
Fix std and var for float16 and float32 (#186)
* Fix std and var for float16 and float32
  - fix type mismatch
  - add correction parameter to bessel correction calculation
  - use unbiased (correction=1) std / var by default

* remove space

Co-authored-by: Ronian <ronian@Ronians-MBP.attlocal.net>
1 parent 8e32de4 commit d84cf3d
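In user-facing terms, the change means std/var on MPS float16 and float32 tensors should agree with the CPU reference and default to the unbiased (correction=1) estimator. A minimal sketch of that behaviour, assuming an MPS-enabled PyTorch build (illustrative only, not part of the diff):

import torch

# Illustrative check only; requires a build with the MPS backend available.
if torch.backends.mps.is_available():
    x = torch.randn(1024)
    ref = x.std()                          # CPU reference, unbiased by default
    got = x.to("mps").std().cpu()          # MPS result should now match
    torch.testing.assert_close(ref, got, rtol=1e-4, atol=1e-4)

    x16 = x.to(torch.float16).to("mps")
    print(x16.std(), x16.var())            # float16 std/var previously hard-crashed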

File tree

2 files changed: +7 -7 lines changed

aten/src/ATen/native/mps/operations/ReduceOps.mm

Lines changed: 5 additions & 5 deletions
@@ -744,8 +744,8 @@ Tensor std_var_common_impl_mps(
     }
   }

-  bool use_correction = correction.has_value();
-  const auto correction_value = use_correction ? correction.value() : false;
+  bool use_correction = !(correction.has_value() && correction.value() == 0);
+  const auto correction_value = correction.value_or(1);
   int64_t correction_n = 1;

   native_mps::MPSGraphCache* cache_ = native_mps::MPSGraphCache::getInstance();
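The reworked flags mirror the Python-level semantics: with no correction supplied, correction_value falls back to 1 (unbiased), while an explicit correction of 0 disables the correction and yields the biased estimator. A hedged sketch of the two estimators (variable names x, n, mu are illustrative, not from the source):

import torch

# Sketch of what correction_value selects; unbiased (correction=1) is the default.
x = torch.randn(100)
n = x.numel()
mu = x.mean()
biased = ((x - mu) ** 2).sum() / n            # correction = 0
unbiased = ((x - mu) ** 2).sum() / (n - 1)    # correction = 1

torch.testing.assert_close(torch.var(x, unbiased=False), biased, rtol=1e-5, atol=1e-6)
torch.testing.assert_close(torch.var(x), unbiased, rtol=1e-5, atol=1e-6)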
@@ -884,7 +884,7 @@ Tensor std_var_common_impl_mps(
     return output_t;
   }

-  double bessel_correction = ((double) correction_n) / ((double) (correction_n-1));
+  double bessel_correction = ((double) correction_n) / ((double) (correction_n-correction_value));

   auto stream = at::mps::getCurrentMPSStream();
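The Bessel factor now honours the requested correction: the graph computes the biased variance and rescales it by correction_n / (correction_n - correction_value), which is n / (n - 1) for correction 1 and 1 for correction 0 (correction_n being the element count used for the correction). A small worked sketch of that arithmetic (illustrative only):

import torch

# Rescaling the biased variance by n / (n - correction) reproduces torch.var.
x = torch.randn(50)
n = x.numel()
biased = torch.var(x, unbiased=False)        # mean of squared deviations
for correction in (0, 1):
    factor = n / (n - correction)            # 1.0 when correction == 0
    expected = torch.var(x, unbiased=bool(correction))
    torch.testing.assert_close(biased * factor, expected, rtol=1e-5, atol=1e-6)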

@@ -894,7 +894,7 @@ Tensor std_var_common_impl_mps(
   string bessel_corrected = (use_correction && correction_value) ? "unbiased " : "biased ";
   string use_dim_info = (use_dim) ? "use_dim=1:" + to_string(dim_value.size()) : "use_dim=0";
   string keepdim_info = (keepdim) ? "keepdim=1" : "keepdim=0";
-  string key = op_key + use_dim_info + ":" + keepdim_info + ":" + string([ns_key UTF8String]) + ":" + native_mps::getTensorsStringKey(input_t) + ":" + bessel_corrected;
+  string key = op_key + use_dim_info + ":" + keepdim_info + ":" + string([ns_key UTF8String]) + ":" + native_mps::getTensorsStringKey(input_t) + ":" + bessel_corrected + ":" + std::to_string(correction_value);

   auto cachedGraph = cache_->LookUpAs<CachedGraph>(key);
   // Initialize once if configuration not found in cache
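Appending correction_value to the key keeps graphs that bake in different Bessel constants from colliding in the cache; otherwise a graph built for one correction could be reused for a call requesting another. A minimal sketch of that keying idea (illustrative Python, not the MPSGraphCache API; lookup_or_build is a hypothetical helper):

graph_cache = {}

def lookup_or_build(op_key, dims, keepdim, dtype, correction, build):
    # Include the correction in the key so differently specialised graphs
    # get separate cache entries.
    key = f"{op_key}:{dims}:{keepdim}:{dtype}:corr={correction}"
    if key not in graph_cache:
        graph_cache[key] = build()
    return graph_cache[key]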
@@ -916,7 +916,7 @@ Tensor std_var_common_impl_mps(
   if (use_correction && correction_value)
   {
     MPSGraphTensor *besselTensor= [mpsGraph constantWithScalar:bessel_correction
-                                                       dataType:MPSDataTypeFloat32];
+                                                       dataType: native_mps::getMPSDataType(input_t.scalar_type())];
     MPSGraphTensor *correctedTensor = [mpsGraph multiplicationWithPrimaryTensor: outputVarTensor
                                                                secondaryTensor: besselTensor
                                                                           name: nil];
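This is the type-mismatch fix named in the commit message: the Bessel constant was created as MPSDataTypeFloat32 regardless of the input, so with a float16 input the multiplication's operands had different types; deriving the constant's data type from input_t.scalar_type() keeps them consistent. A rough Python analogue of the idea (illustrative only, not the MPSGraph code):

import torch

def rescale(biased_var: torch.Tensor, n: int, correction: int) -> torch.Tensor:
    # Analogue of the fix: build the scaling constant in the input's dtype
    # rather than hard-coding float32.
    bessel = torch.tensor(n / (n - correction), dtype=biased_var.dtype)
    return biased_var * bessel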

test/test_mps.py

Lines changed: 2 additions & 2 deletions
@@ -8300,6 +8300,8 @@ class TestConsistency(TestCase):
         'linalg.cross': ['f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
         'unique_consecutive': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
         'nn.functional.nll_loss': ['f32'],
+        'std': ['f16','f32'],
+        'var': ['f16','f32'],
     }
@@ -8485,9 +8487,7 @@ class TestConsistency(TestCase):
         'masked.sum': [torch.bool],

         # Functions that hard crash
-        'std': [torch.float16],
         'stft': [torch.float32],
-        'var': [torch.float16],
         # + forward when requires_grad=True or running backward
         '__rpow__': [torch.int64],
         'masked.std': [torch.int32],
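With the crash fixed, std and var for f16/f32 move out of the hard-crash block list and into the consistency allow list above. A rough sketch of the kind of CPU-vs-MPS comparison such lists gate (not the actual TestConsistency harness; check_consistency is a hypothetical helper):

import torch

def check_consistency(op, dtype):
    cpu_in = torch.randn(32, 32, dtype=torch.float32)
    mps_in = cpu_in.to(dtype).to("mps")
    ref = op(cpu_in).to(dtype)          # float32 CPU reference, cast to dtype
    got = op(mps_in).cpu()
    torch.testing.assert_close(ref, got, rtol=1e-2, atol=1e-2)

if torch.backends.mps.is_available():
    for dtype in (torch.float16, torch.float32):
        check_consistency(torch.std, dtype)
        check_consistency(torch.var, dtype)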

0 commit comments
