diff --git a/aten/src/ATen/native/mps/operations/ReduceOps.mm b/aten/src/ATen/native/mps/operations/ReduceOps.mm
index 51b52f99c5273..cc1ef6ab3fa12 100644
--- a/aten/src/ATen/native/mps/operations/ReduceOps.mm
+++ b/aten/src/ATen/native/mps/operations/ReduceOps.mm
@@ -220,6 +220,15 @@ void set_axes_and_shapes(const Tensor& input_t,
                                                    axes:axes
                                                    name:nil];
   }
+  else if(reduction_type == "amax") {
+    castOutputTensor = [mpsGraph reductionMaximumWithTensor:inputTensor
+                                                        axes:axes
+                                                        name:nil];
+  } else if(reduction_type == "amin") {
+    castOutputTensor = [mpsGraph reductionMinimumWithTensor:inputTensor
+                                                        axes:axes
+                                                        name:nil];
+  }
 
   MPSGraphTensor* outputTensor = nil;
 
@@ -294,6 +303,24 @@ inline ScalarType get_dtype_from_self(
   return src_type;
 }
 
+TORCH_IMPL_FUNC(amax_out_mps)
+   (const Tensor& input_t,
+    IntArrayRef dim,
+    bool keepdim,
+    const Tensor& output_t) {
+
+  reduction_out_mps(input_t, dim, keepdim, c10::nullopt, output_t, "amax", "amax_out_mps");
+}
+
+TORCH_IMPL_FUNC(amin_out_mps)
+   (const Tensor& input_t,
+    IntArrayRef dim,
+    bool keepdim,
+    const Tensor& output_t) {
+
+  reduction_out_mps(input_t, dim, keepdim, c10::nullopt, output_t, "amin", "amin_out_mps");
+}
+
 Tensor prod_mps(const Tensor &self, c10::optional<ScalarType> opt_dtype) {
   auto num_dims = self.dim();
 
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index 561658085f945..36c95a3326e70 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -3134,6 +3134,7 @@
   structured: True
   dispatch:
     CPU, CUDA: amax_out
+    MPS: amax_out_mps
 
 # Return: (Tensor output, Tensor indices)
 - func: max_pool1d_with_indices(Tensor self, int[1] kernel_size, int[1] stride=[], int[1] padding=0, int[1] dilation=1, bool ceil_mode=False) -> (Tensor, Tensor)
@@ -3290,6 +3291,7 @@
   structured: True
   dispatch:
     CPU, CUDA: amin_out
+    MPS: amin_out_mps
 
 # TODO: Add this function to MPS dispatch key so that we avoid declaring it in
 # native_functions.yaml
diff --git a/test/test_mps.py b/test/test_mps.py
index 2cf74c1803f2a..cb176336e86b2 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -2755,6 +2755,50 @@ def helper(shape):
 
         helper((4, 5, 6, 7))
 
+    # Test forward amax
+    def test_amax(self):
+        def helper(shape, dim, keepdim):
+            cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True)
+            x = cpu_x.detach().clone().to('mps').requires_grad_()
+
+            result = torch.amax(x, dim=dim, keepdim=keepdim)
+            result_cpu = torch.amax(cpu_x, dim=dim, keepdim=keepdim)
+
+            cpu_grad = torch.randn(result_cpu.shape)
+            grad = cpu_grad.to('mps')
+
+            result_cpu.backward(gradient=cpu_grad)
+            result.backward(gradient=grad)
+
+            self.assertEqual(result, result_cpu)
+            self.assertEqual(x.grad, cpu_x.grad)
+
+        for dim in ([], [0], [0, 1], [2, 3]):
+            for keepdim in [False, True]:
+                helper((2, 8, 4, 5), dim, keepdim)
+
+    # Test forward amin
+    def test_amin(self):
+        def helper(shape, dim, keepdim):
+            cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True)
+            x = cpu_x.detach().clone().to('mps').requires_grad_()
+
+            result = torch.amin(x, dim=dim, keepdim=keepdim)
+            result_cpu = torch.amin(cpu_x, dim=dim, keepdim=keepdim)
+
+            cpu_grad = torch.randn(result_cpu.shape)
+            grad = cpu_grad.to('mps')
+
+            result_cpu.backward(gradient=cpu_grad)
+            result.backward(gradient=grad)
+
+            self.assertEqual(result, result_cpu)
+            self.assertEqual(x.grad, cpu_x.grad)
+
+        for dim in ([], [0], [0, 1], [2, 3]):
+            for keepdim in [False, True]:
+                helper((2, 8, 4, 5), dim, keepdim)
+
     # Test minimum and maximum
     def test_minimum_maximum(self):
         def helper(n, c, h, w):
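
Not part of the patch itself: with the `MPS: amax_out_mps` / `MPS: amin_out_mps` dispatch entries above, `torch.amax` and `torch.amin` should run directly on `mps` tensors. A minimal usage sketch, assuming an MPS-capable PyTorch build; the shapes and reduction dims are illustrative:

```python
import torch

# Assumes this build has MPS support (macOS on Apple silicon or a supported AMD GPU).
if torch.backends.mps.is_available():
    x = torch.randn(2, 8, 4, 5, device="mps")

    # Reduce over the last two dims and compare against the CPU reference.
    out_mps = torch.amax(x, dim=[2, 3], keepdim=True)
    out_cpu = torch.amax(x.cpu(), dim=[2, 3], keepdim=True)
    print(torch.allclose(out_mps.cpu(), out_cpu))  # expected: True

    # amin over the leading dims; remaining shape is (4, 5).
    print(torch.amin(x, dim=[0, 1]).shape)
```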