diff --git a/aten/src/ATen/native/mps/operations/ReduceOps.mm b/aten/src/ATen/native/mps/operations/ReduceOps.mm
index 6473f6988f436..779ff0fffada8 100644
--- a/aten/src/ATen/native/mps/operations/ReduceOps.mm
+++ b/aten/src/ATen/native/mps/operations/ReduceOps.mm
@@ -590,7 +590,7 @@ Tensor std_var_common_impl_mps(
   NSMutableArray *axes = nil;
   NSMutableArray *apparent_output_shape = nil;
   NSMutableArray *apparent_input_shape = nil;
-  int64_t* output_shape = nil;
+  std::vector<int64_t> output_shape;
 
   if ((!keepdim && !use_dim) || (!keepdim && use_dim && dim_value.size() <= 0))
   {
@@ -630,7 +630,6 @@ Tensor std_var_common_impl_mps(
                                     axes);
     num_output_dims = (num_input_dims >= num_reduce_dims) ? (num_input_dims - num_reduce_dims) : 0; //num_input_dims;
 
-    output_shape = (int64_t *)malloc(num_output_dims * sizeof(int64_t));
 
     unsigned int curr_i = 0;
     for (int i = 0; i < num_input_dims; i++)
@@ -645,13 +644,17 @@ Tensor std_var_common_impl_mps(
         }
       }
       if (found) continue;
-      output_shape[curr_i] = input_shape[i];
+      output_shape.push_back(input_shape[i]);
       curr_i += 1;
+      // End loop when output shape is filled
+      if(curr_i == num_output_dims)
+        break;
     }
 
     for(int i = 0; i < num_reduce_dims; i++)
     {
-      correction_n *= input_shape[dim_value[i]];
+      auto wrap_dim = maybe_wrap_dim(dim_value[i], input_shape.size());
+      correction_n *= input_shape[wrap_dim];
     }
     // (3, 4, 5) --> (3, 5)
   }
@@ -668,10 +671,9 @@ Tensor std_var_common_impl_mps(
                                   input_shape,
                                   axes);
     num_output_dims = num_input_dims;
-    output_shape = (int64_t *)malloc(num_output_dims * sizeof(int64_t));
     for (int i = 0; i < num_input_dims; i++)
     {
-      output_shape[i] = (int64_t) 1;
+      output_shape.push_back((int64_t) 1);
       correction_n *= input_shape[i];
     }
     // scalar --> vector case [[1.0034567]]
@@ -691,21 +693,24 @@ Tensor std_var_common_impl_mps(
                                       axes);
     num_output_dims = num_input_dims;//(num_input_dims >= num_reduce_dims) ? (num_input_dims - num_reduce_dims) : 0;
 
-    output_shape = (int64_t *)malloc(num_output_dims * sizeof(int64_t));
 
     for(int i = 0; i < num_reduce_dims; i++)
     {
-      correction_n *= input_shape[dim_value[i]];
+      auto wrap_dim = maybe_wrap_dim(dim_value[i], input_shape.size());
+      correction_n *= input_shape[wrap_dim];
     }
 
     for (int i = 0; i < num_input_dims; i++)
     {
-      output_shape[i] = [apparent_output_shape[i] longValue];
+      output_shape.push_back([apparent_output_shape[i] longValue]);
     }
   }
 
+  int64_t output_shape_array[output_shape.size()];
+  std::copy(output_shape.begin(), output_shape.end(), output_shape_array);
+
   Tensor output_t = at::native::empty_mps(
-                      IntArrayRef(output_shape, num_output_dims),
+                      IntArrayRef(output_shape_array, num_output_dims),
                       input_t.scalar_type(),
                       c10::nullopt,
                       kMPS,
@@ -790,7 +795,7 @@ Tensor std_var_common_impl_mps(
     };
 
     native_mps::runMPSGraph(stream, cachedGraph->graph(), feeds, results);
   }
-  free(output_shape);
+
   return output_t;
 }
diff --git a/test/test_mps.py b/test/test_mps.py
index 6fc27d3ac71b8..23e6512bcebd9 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -2948,104 +2948,48 @@ def helper(shape):
         helper((9, 5, 6, 7))
 
     # Test var
-    def test_var(self):
-        def helper(shape):
-            cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False)
-            x = cpu_x.detach().clone().to('mps')
-
-            all_var = torch.var(x, unbiased=False)
-            all_var_cpu = torch.var(cpu_x, unbiased=False)
-
-            self.assertEqual(all_var, all_var_cpu)
-
-            nil_dim_var = torch.var(x, dim=[], unbiased=False)
-            nil_dim_var_cpu = torch.var(cpu_x, dim=[], unbiased=False)
-
-            self.assertEqual(nil_dim_var, nil_dim_var_cpu)
-
-            nil_dim_var_keepdim = torch.var(x, dim=[], keepdim=True, unbiased=False)
-            nil_dim_var_cpu_keepdim = torch.var(cpu_x, dim=[], keepdim=True, unbiased=False)
-
-            self.assertEqual(nil_dim_var_keepdim, nil_dim_var_cpu_keepdim)
-
-            zero_dim_var = torch.var(x, dim=[0], unbiased=False)
-            zero_dim_var_cpu = torch.var(cpu_x, dim=[0], unbiased=False)
-
-            self.assertEqual(zero_dim_var, zero_dim_var_cpu)
-
-            zero_dim_var_keepdim = torch.var(x, dim=[0], keepdim=True, unbiased=False)
-            zero_dim_var_cpu_keepdim = torch.var(cpu_x, dim=[0], keepdim=True, unbiased=False)
-
-            self.assertEqual(zero_dim_var_keepdim, zero_dim_var_cpu_keepdim)
-
-            zero_one_dim_var = torch.var(x, dim=[0, 1], unbiased=False)
-            zero_one_dim_var_cpu = torch.var(cpu_x, dim=[0, 1], unbiased=False)
-
-            self.assertEqual(zero_one_dim_var, zero_one_dim_var_cpu)
+    def test_var_simple(self):
+        def helper():
 
-            zero_one_dim_var_keepdim = torch.var(x, dim=[0, 1], keepdim=True, unbiased=False)
-            zero_one_dim_var_cpu_keepdim = torch.var(cpu_x, dim=[0, 1], keepdim=True, unbiased=False)
+            shape = [2,3,4,5]
 
-            self.assertEqual(zero_one_dim_var_keepdim, zero_one_dim_var_cpu_keepdim)
-
-            two_three_dim_var = torch.var(x, dim=[2, 3], unbiased=False)
-            two_three_dim_var_cpu = torch.var(cpu_x, dim=[2, 3], unbiased=False)
-
-            self.assertEqual(two_three_dim_var, two_three_dim_var_cpu)
-
-            two_three_keepdim_var = torch.var(x, dim=[2, 3], keepdim=True, unbiased=False)
-            two_three_dim_keepvar_cpu = torch.var(cpu_x, dim=[2, 3], keepdim=True, unbiased=False)
-
-            self.assertEqual(two_three_keepdim_var, two_three_dim_keepvar_cpu)
-
-            all_var = torch.var(x, unbiased=True)
-            all_var_cpu = torch.var(cpu_x, unbiased=True)
-
-            self.assertEqual(all_var, all_var_cpu)
-
-            nil_dim_var = torch.var(x, dim=[], unbiased=True)
-            nil_dim_var_cpu = torch.var(cpu_x, dim=[], unbiased=True)
-
-            self.assertEqual(nil_dim_var, nil_dim_var_cpu)
-
-            nil_dim_var_keepdim = torch.var(x, dim=[], keepdim=True, unbiased=True)
-            nil_dim_var_cpu_keepdim = torch.var(cpu_x, dim=[], keepdim=True, unbiased=True)
+            cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False)
+            x = cpu_x.detach().clone().to('mps')
 
-            self.assertEqual(nil_dim_var_keepdim, nil_dim_var_cpu_keepdim)
+            for unbiased in [False, True]:
+                for keepdim in [False, True]:
 
-            zero_dim_var = torch.var(x, dim=[0], unbiased=True)
-            zero_dim_var_cpu = torch.var(cpu_x, dim=[0], unbiased=True)
+                    zero_dim_var = x.var(-1, keepdim=keepdim, unbiased=unbiased)
+                    zero_dim_var_cpu = cpu_x.var(-1, keepdim=keepdim, unbiased=unbiased)
 
-            self.assertEqual(zero_dim_var, zero_dim_var_cpu)
+                    self.assertEqual(zero_dim_var, zero_dim_var_cpu)
 
-            zero_dim_var_keepdim = torch.var(x, dim=[0], keepdim=True, unbiased=True)
-            zero_dim_var_cpu_keepdim = torch.var(cpu_x, dim=[0], keepdim=True, unbiased=True)
+                    all_var = torch.var(x, unbiased=unbiased)
+                    all_var_cpu = torch.var(cpu_x, unbiased=unbiased)
 
-            self.assertEqual(zero_dim_var_keepdim, zero_dim_var_cpu_keepdim)
+                    self.assertEqual(all_var, all_var_cpu)
 
-            zero_one_dim_var = torch.var(x, dim=[0, 1], unbiased=True)
-            zero_one_dim_var_cpu = torch.var(cpu_x, dim=[0, 1], unbiased=True)
+                    nil_dim_var = torch.var(x, dim=[], keepdim=keepdim, unbiased=unbiased)
+                    nil_dim_var_cpu = torch.var(cpu_x, dim=[], keepdim=keepdim, unbiased=unbiased)
 
-            self.assertEqual(zero_one_dim_var, zero_one_dim_var_cpu)
+                    self.assertEqual(nil_dim_var, nil_dim_var_cpu)
 
-            zero_one_dim_var_keepdim = torch.var(x, dim=[0, 1], keepdim=True, unbiased=True)
-            zero_one_dim_var_cpu_keepdim = torch.var(cpu_x, dim=[0, 1], keepdim=True, unbiased=True)
+                    zero_dim_var = torch.var(x, dim=[0], keepdim=keepdim, unbiased=unbiased)
+                    zero_dim_var_cpu = torch.var(cpu_x, dim=[0], keepdim=keepdim, unbiased=unbiased)
 
-            self.assertEqual(zero_one_dim_var_keepdim, zero_one_dim_var_cpu_keepdim)
+                    self.assertEqual(zero_dim_var, zero_dim_var_cpu)
 
-            two_three_dim_var = torch.var(x, dim=[2, 3], unbiased=True)
-            two_three_dim_var_cpu = torch.var(cpu_x, dim=[2, 3], unbiased=True)
+                    zero_one_dim_var = torch.var(x, dim=[0, -1], keepdim=keepdim, unbiased=unbiased)
+                    zero_one_dim_var_cpu = torch.var(cpu_x, dim=[0, -1], keepdim=keepdim, unbiased=unbiased)
 
-            self.assertEqual(two_three_dim_var, two_three_dim_var_cpu)
+                    self.assertEqual(zero_one_dim_var, zero_one_dim_var_cpu)
 
-            two_three_keepdim_var = torch.var(x, dim=[2, 3], keepdim=True, unbiased=True)
-            two_three_dim_keepvar_cpu = torch.var(cpu_x, dim=[2, 3], keepdim=True, unbiased=True)
+                    two_three_dim_var = torch.var(x, dim=[2, 3], keepdim=keepdim, unbiased=unbiased)
+                    two_three_dim_var_cpu = torch.var(cpu_x, dim=[2, 3], keepdim=keepdim, unbiased=unbiased)
 
-            self.assertEqual(two_three_keepdim_var, two_three_dim_keepvar_cpu)
+                    self.assertEqual(two_three_dim_var, two_three_dim_var_cpu)
 
-        helper((4, 5, 6, 7))
-        # verify if a change in shape of input would cause problems with graph caching
-        helper((9, 5, 6, 7))
+        helper()
 
     # Test forward amax
     def test_amax(self):