27 changes: 16 additions & 11 deletions aten/src/ATen/native/mps/operations/ReduceOps.mm
@@ -590,7 +590,7 @@ Tensor std_var_common_impl_mps(
NSMutableArray<NSNumber *> *axes = nil;
NSMutableArray<NSNumber*> *apparent_output_shape = nil;
NSMutableArray<NSNumber*> *apparent_input_shape = nil;
int64_t* output_shape = nil;
std::vector<int64_t> output_shape;

if ((!keepdim && !use_dim) || (!keepdim && use_dim && dim_value.size() <= 0))
{
@@ -630,7 +630,6 @@ Tensor std_var_common_impl_mps(
axes);

num_output_dims = (num_input_dims >= num_reduce_dims) ? (num_input_dims - num_reduce_dims) : 0; //num_input_dims;
output_shape = (int64_t *)malloc(num_output_dims * sizeof(int64_t));

unsigned int curr_i = 0;
for (int i = 0; i < num_input_dims; i++)
@@ -645,13 +644,17 @@
}
}
if (found) continue;
output_shape[curr_i] = input_shape[i];
output_shape.push_back(input_shape[i]);
curr_i += 1;
// End loop when output shape is filled
if(curr_i == num_output_dims)
break;
}

for(int i = 0; i < num_reduce_dims; i++)
{
correction_n *= input_shape[dim_value[i]];
auto wrap_dim = maybe_wrap_dim(dim_value[i], input_shape.size());
correction_n *= input_shape[wrap_dim];
}
// (3, 4, 5) --> (3, 5)
}
@@ -668,10 +671,9 @@
input_shape,
axes);
num_output_dims = num_input_dims;
output_shape = (int64_t *)malloc(num_output_dims * sizeof(int64_t));
for (int i = 0; i < num_input_dims; i++)
{
output_shape[i] = (int64_t) 1;
output_shape.push_back((int64_t) 1);
correction_n *= input_shape[i];
}
// scalar --> vector case [[1.0034567]]
@@ -691,21 +693,24 @@
axes);

num_output_dims = num_input_dims;//(num_input_dims >= num_reduce_dims) ? (num_input_dims - num_reduce_dims) : 0;
output_shape = (int64_t *)malloc(num_output_dims * sizeof(int64_t));

for(int i = 0; i < num_reduce_dims; i++)
{
correction_n *= input_shape[dim_value[i]];
auto wrap_dim = maybe_wrap_dim(dim_value[i], input_shape.size());
correction_n *= input_shape[wrap_dim];
}

for (int i = 0; i < num_input_dims; i++)
{
output_shape[i] = [apparent_output_shape[i] longValue];
output_shape.push_back([apparent_output_shape[i] longValue]);
}
}

int64_t output_shape_array[output_shape.size()];
std::copy(output_shape.begin(), output_shape.end(), output_shape_array);

Tensor output_t = at::native::empty_mps(
IntArrayRef(output_shape, num_output_dims),
IntArrayRef(output_shape_array, num_output_dims),
input_t.scalar_type(),
c10::nullopt,
kMPS,
@@ -790,7 +795,7 @@ Tensor std_var_common_impl_mps(
};
native_mps::runMPSGraph(stream, cachedGraph->graph(), feeds, results);
}
free(output_shape);

return output_t;
}

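Stepping back from the diff: the ReduceOps.mm change replaces the malloc'd int64_t* output_shape buffer (and its matching free()) with a std::vector<int64_t>, and wraps possibly-negative entries of dim_value with maybe_wrap_dim before indexing input_shape. Below is a minimal standalone sketch of that pattern in plain C++ — not the PR's code; wrap_dim here is a hypothetical stand-in for ATen's maybe_wrap_dim.

// Minimal sketch (assumption: plain C++, no ATen) of the pattern adopted above:
// build the reduced output shape in a std::vector instead of a malloc'd buffer,
// and wrap possibly-negative reduce dims before indexing the input shape.
#include <cstdint>
#include <stdexcept>
#include <vector>

// Hypothetical stand-in for ATen's maybe_wrap_dim: maps dim in [-rank, rank) to [0, rank).
static std::int64_t wrap_dim(std::int64_t dim, std::int64_t rank) {
  if (dim < -rank || dim >= rank) throw std::out_of_range("dim out of range");
  return dim < 0 ? dim + rank : dim;
}

// Returns input_shape with the (wrapped) reduce dims removed.
static std::vector<std::int64_t> reduced_shape(const std::vector<std::int64_t>& input_shape,
                                               const std::vector<std::int64_t>& reduce_dims) {
  const std::int64_t rank = static_cast<std::int64_t>(input_shape.size());
  std::vector<bool> is_reduced(static_cast<std::size_t>(rank), false);
  for (std::int64_t d : reduce_dims)
    is_reduced[static_cast<std::size_t>(wrap_dim(d, rank))] = true;

  std::vector<std::int64_t> out;
  for (std::int64_t i = 0; i < rank; ++i)
    if (!is_reduced[static_cast<std::size_t>(i)])
      out.push_back(input_shape[static_cast<std::size_t>(i)]);
  return out;  // lifetime managed by the vector; no malloc/free pair to keep balanced
}

For example, reduced_shape({2, 3, 4, 5}, {0, -1}) yields {3, 4}, matching the dim=[0, -1] case exercised in the updated test below.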
108 changes: 26 additions & 82 deletions test/test_mps.py
@@ -2948,104 +2948,48 @@ def helper(shape):
helper((9, 5, 6, 7))

# Test var
def test_var(self):
def helper(shape):
cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False)
x = cpu_x.detach().clone().to('mps')

all_var = torch.var(x, unbiased=False)
all_var_cpu = torch.var(cpu_x, unbiased=False)

self.assertEqual(all_var, all_var_cpu)

nil_dim_var = torch.var(x, dim=[], unbiased=False)
nil_dim_var_cpu = torch.var(cpu_x, dim=[], unbiased=False)

self.assertEqual(nil_dim_var, nil_dim_var_cpu)

nil_dim_var_keepdim = torch.var(x, dim=[], keepdim=True, unbiased=False)
nil_dim_var_cpu_keepdim = torch.var(cpu_x, dim=[], keepdim=True, unbiased=False)

self.assertEqual(nil_dim_var_keepdim, nil_dim_var_cpu_keepdim)

zero_dim_var = torch.var(x, dim=[0], unbiased=False)
zero_dim_var_cpu = torch.var(cpu_x, dim=[0], unbiased=False)

self.assertEqual(zero_dim_var, zero_dim_var_cpu)

zero_dim_var_keepdim = torch.var(x, dim=[0], keepdim=True, unbiased=False)
zero_dim_var_cpu_keepdim = torch.var(cpu_x, dim=[0], keepdim=True, unbiased=False)

self.assertEqual(zero_dim_var_keepdim, zero_dim_var_cpu_keepdim)

zero_one_dim_var = torch.var(x, dim=[0, 1], unbiased=False)
zero_one_dim_var_cpu = torch.var(cpu_x, dim=[0, 1], unbiased=False)

self.assertEqual(zero_one_dim_var, zero_one_dim_var_cpu)
def test_var_simple(self):
def helper():

zero_one_dim_var_keepdim = torch.var(x, dim=[0, 1], keepdim=True, unbiased=False)
zero_one_dim_var_cpu_keepdim = torch.var(cpu_x, dim=[0, 1], keepdim=True, unbiased=False)
shape = [2,3,4,5]

self.assertEqual(zero_one_dim_var_keepdim, zero_one_dim_var_cpu_keepdim)

two_three_dim_var = torch.var(x, dim=[2, 3], unbiased=False)
two_three_dim_var_cpu = torch.var(cpu_x, dim=[2, 3], unbiased=False)

self.assertEqual(two_three_dim_var, two_three_dim_var_cpu)

two_three_keepdim_var = torch.var(x, dim=[2, 3], keepdim=True, unbiased=False)
two_three_dim_keepvar_cpu = torch.var(cpu_x, dim=[2, 3], keepdim=True, unbiased=False)

self.assertEqual(two_three_keepdim_var, two_three_dim_keepvar_cpu)

all_var = torch.var(x, unbiased=True)
all_var_cpu = torch.var(cpu_x, unbiased=True)

self.assertEqual(all_var, all_var_cpu)

nil_dim_var = torch.var(x, dim=[], unbiased=True)
nil_dim_var_cpu = torch.var(cpu_x, dim=[], unbiased=True)

self.assertEqual(nil_dim_var, nil_dim_var_cpu)

nil_dim_var_keepdim = torch.var(x, dim=[], keepdim=True, unbiased=True)
nil_dim_var_cpu_keepdim = torch.var(cpu_x, dim=[], keepdim=True, unbiased=True)
cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False)
x = cpu_x.detach().clone().to('mps')

self.assertEqual(nil_dim_var_keepdim, nil_dim_var_cpu_keepdim)
for unbiased in [False, True]:
for keepdim in [False, True]:

zero_dim_var = torch.var(x, dim=[0], unbiased=True)
zero_dim_var_cpu = torch.var(cpu_x, dim=[0], unbiased=True)
zero_dim_var = x.var(-1, keepdim=keepdim, unbiased=unbiased)
zero_dim_var_cpu = cpu_x.var(-1, keepdim=keepdim, unbiased=unbiased)

self.assertEqual(zero_dim_var, zero_dim_var_cpu)
self.assertEqual(zero_dim_var, zero_dim_var_cpu)

zero_dim_var_keepdim = torch.var(x, dim=[0], keepdim=True, unbiased=True)
zero_dim_var_cpu_keepdim = torch.var(cpu_x, dim=[0], keepdim=True, unbiased=True)
all_var = torch.var(x, unbiased=unbiased)
all_var_cpu = torch.var(cpu_x, unbiased=unbiased)

self.assertEqual(zero_dim_var_keepdim, zero_dim_var_cpu_keepdim)
self.assertEqual(all_var, all_var_cpu)

zero_one_dim_var = torch.var(x, dim=[0, 1], unbiased=True)
zero_one_dim_var_cpu = torch.var(cpu_x, dim=[0, 1], unbiased=True)
nil_dim_var = torch.var(x, dim=[], keepdim=keepdim, unbiased=unbiased)
nil_dim_var_cpu = torch.var(cpu_x, dim=[], keepdim=keepdim, unbiased=unbiased)

self.assertEqual(zero_one_dim_var, zero_one_dim_var_cpu)
self.assertEqual(nil_dim_var, nil_dim_var_cpu)

zero_one_dim_var_keepdim = torch.var(x, dim=[0, 1], keepdim=True, unbiased=True)
zero_one_dim_var_cpu_keepdim = torch.var(cpu_x, dim=[0, 1], keepdim=True, unbiased=True)
zero_dim_var = torch.var(x, dim=[0], keepdim=keepdim, unbiased=unbiased)
zero_dim_var_cpu = torch.var(cpu_x, dim=[0], keepdim=keepdim, unbiased=unbiased)

self.assertEqual(zero_one_dim_var_keepdim, zero_one_dim_var_cpu_keepdim)
self.assertEqual(zero_dim_var, zero_dim_var_cpu)

two_three_dim_var = torch.var(x, dim=[2, 3], unbiased=True)
two_three_dim_var_cpu = torch.var(cpu_x, dim=[2, 3], unbiased=True)
zero_one_dim_var = torch.var(x, dim=[0, -1], keepdim=keepdim, unbiased=unbiased)
zero_one_dim_var_cpu = torch.var(cpu_x, dim=[0, -1], keepdim=keepdim, unbiased=unbiased)

self.assertEqual(two_three_dim_var, two_three_dim_var_cpu)
self.assertEqual(zero_one_dim_var, zero_one_dim_var_cpu)

two_three_keepdim_var = torch.var(x, dim=[2, 3], keepdim=True, unbiased=True)
two_three_dim_keepvar_cpu = torch.var(cpu_x, dim=[2, 3], keepdim=True, unbiased=True)
two_three_dim_var = torch.var(x, dim=[2, 3], keepdim=keepdim, unbiased=unbiased)
two_three_dim_var_cpu = torch.var(cpu_x, dim=[2, 3], keepdim=keepdim, unbiased=unbiased)

self.assertEqual(two_three_keepdim_var, two_three_dim_keepvar_cpu)
self.assertEqual(two_three_dim_var, two_three_dim_var_cpu)

helper((4, 5, 6, 7))
# verify if a change in shape of input would cause problems with graph caching
helper((9, 5, 6, 7))
helper()

# Test forward amax
def test_amax(self):