diff --git a/aten/src/ATen/native/mps/operations/AdaptiveAveragePooling.mm b/aten/src/ATen/native/mps/operations/AdaptivePooling.mm
similarity index 59%
rename from aten/src/ATen/native/mps/operations/AdaptiveAveragePooling.mm
rename to aten/src/ATen/native/mps/operations/AdaptivePooling.mm
index c82818318e9e6..1d58de2902cfd 100644
--- a/aten/src/ATen/native/mps/operations/AdaptiveAveragePooling.mm
+++ b/aten/src/ATen/native/mps/operations/AdaptivePooling.mm
@@ -26,6 +26,8 @@
   kernel_sizeW = isizeW - (osizeW-1) * strideW;
 }
 
+// Adaptive average pooling
+
 Tensor& adaptive_avg_pool2d_out_mps
   (const Tensor& input,
    IntArrayRef output_size,
@@ -150,5 +152,93 @@
 }
 
+// Adaptive max pooling
+
+TORCH_IMPL_FUNC(adaptive_max_pool2d_out_mps)
+ (const Tensor& input,
+  IntArrayRef output_size,
+  const Tensor& output,
+  const Tensor& indices) {
+
+  for (int64_t i = 1; i < input.ndimension(); i++) {
+    TORCH_CHECK(input.size(i) > 0,
+                "adaptive_max_pool2d(): Expected input to have non-zero size for non-batch dimensions, "
+                "but input has sizes ", input.sizes(), " with dimension ", i, " being empty");
+  }
+
+  int64_t isizeH = input.size(-2);
+  int64_t isizeW = input.size(-1);
+
+  int64_t osizeH = output_size[0];
+  int64_t osizeW = output_size[1];
+
+  if (input.suggest_memory_format() == at::MemoryFormat::ChannelsLast)
+    TORCH_CHECK(input.ndimension() == 4,
+                "adaptive_max_pool2d(): Expected 4D tensor, but got ",
+                input.sizes());
+
+  switch (input.suggest_memory_format()) {
+    case at::MemoryFormat::Contiguous:
+    case at::MemoryFormat::ChannelsLast:
+      break;
+    default:
+      TORCH_CHECK(false,
+                  "Unsupported memory format. Supports only ChannelsLast, Contiguous");
+  }
+
+  // Reduce adaptive pooling to fixed-window pooling; valid because the
+  // input size is an integer multiple of the output size.
+  int64_t strideH;
+  int64_t strideW;
+  int64_t kernel_sizeH;
+  int64_t kernel_sizeW;
+
+  set_kernel_params(isizeH, isizeW,
+                    osizeH, osizeW,
+                    strideH, strideW,
+                    kernel_sizeH, kernel_sizeW);
+
+  auto outputs = at::max_pool2d_with_indices(input,
+                                             IntArrayRef({kernel_sizeH, kernel_sizeW}),
+                                             IntArrayRef({strideH, strideW}),
+                                             IntArrayRef({0, 0}),
+                                             IntArrayRef({1, 1}),
+                                             false);
+
+  output.copy_(std::get<0>(outputs));
+  indices.copy_(std::get<1>(outputs));
+}
+
+TORCH_IMPL_FUNC(adaptive_max_pool2d_backward_out_mps)
+ (const Tensor& gradOutput,
+  const Tensor& input,
+  const Tensor& indices,
+  const Tensor& gradInput) {
+
+  int64_t isizeH = input.size(-2);
+  int64_t isizeW = input.size(-1);
+  int64_t osizeH = gradOutput.size(-2);
+  int64_t osizeW = gradOutput.size(-1);
+
+  int64_t strideH, strideW, kernel_sizeH, kernel_sizeW;
+
+  set_kernel_params(isizeH, isizeW,
+                    osizeH, osizeW,
+                    strideH, strideW,
+                    kernel_sizeH, kernel_sizeW);
+
+  auto returnGradInput = at::max_pool2d_with_indices_backward(gradOutput,
+                                                              input,
+                                                              IntArrayRef({kernel_sizeH, kernel_sizeW}),
+                                                              IntArrayRef({strideH, strideW}),
+                                                              IntArrayRef({0, 0}),
+                                                              IntArrayRef({1, 1}),
+                                                              false,
+                                                              indices);
+
+  gradInput.copy_(returnGradInput);
+}
+
 }
 }
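Both kernels above reduce adaptive max pooling to a single `at::max_pool2d_with_indices` call, with `set_kernel_params` deriving a fixed stride and kernel size from the input/output sizes. Below is a minimal Python sketch of that reduction for reference; `adaptive_max_pool2d_via_fixed_pool` is a hypothetical helper name, and the sketch assumes (as this port does) that the input size is an integer multiple of the output size:

```python
import torch
import torch.nn.functional as F

def adaptive_max_pool2d_via_fixed_pool(x, osizeH, osizeW):
    # Mirrors set_kernel_params in AdaptivePooling.mm:
    #   stride      = isize // osize
    #   kernel_size = isize - (osize - 1) * stride
    isizeH, isizeW = x.shape[-2], x.shape[-1]
    strideH, strideW = isizeH // osizeH, isizeW // osizeW
    kH = isizeH - (osizeH - 1) * strideH
    kW = isizeW - (osizeW - 1) * strideW
    # When isize is a multiple of osize, kernel == stride, so the fixed
    # non-overlapping windows coincide with the adaptive pooling regions.
    return F.max_pool2d(x, kernel_size=(kH, kW), stride=(strideH, strideW),
                        padding=0, dilation=1, ceil_mode=False,
                        return_indices=True)

x = torch.randn(2, 2, 16, 16)
out, _ = adaptive_max_pool2d_via_fixed_pool(x, 2, 2)
ref = F.adaptive_max_pool2d(x, (2, 2))
torch.testing.assert_close(out, ref)
```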
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index c81d6badfd926..b97c8c9dfa3a4 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -9791,6 +9791,7 @@
   dispatch:
     CPU: adaptive_max_pool2d_out_cpu
     CUDA: adaptive_max_pool2d_out_cuda
+    MPS: adaptive_max_pool2d_out_mps
 
 # Return: (Tensor output, Tensor indices)
 - func: adaptive_max_pool2d(Tensor self, int[2] output_size) -> (Tensor, Tensor)
@@ -9803,6 +9804,7 @@
   dispatch:
     CPU: adaptive_max_pool2d_backward_out_cpu
     CUDA: adaptive_max_pool2d_backward_out_cuda
+    MPS: adaptive_max_pool2d_backward_out_mps
 
 - func: adaptive_max_pool2d_backward(Tensor grad_output, Tensor self, Tensor indices) -> Tensor
   python_module: nn
diff --git a/test/test_mps.py b/test/test_mps.py
index 72b4be7231552..f9778b4759695 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -3089,6 +3089,50 @@ def helper(input_shape, out_shape, channels_last):
 
         helper((2, 16, 16), (4, 4), False)
 
+    # Test adaptive max pool2d - when the input size is a multiple of the output size
+    # Not testing for channels last right now
+    def test_adaptive_max_pool2d_simple(self):
+        def helper(input_shape, out_shape, return_indices, dtype, channels_last=False):
+            cpu_x = None
+            if(dtype in [torch.float16, torch.float32]):
+                cpu_x = torch.randn(input_shape, device='cpu', dtype=dtype, requires_grad=True)
+            else:
+                cpu_x = torch.randint(50, input_shape, device='cpu', dtype=dtype, requires_grad=True)
+            if(channels_last):
+                cpu_x = cpu_x.to(memory_format=torch.channels_last)
+                cpu_x.retain_grad()
+            x = cpu_x.detach().clone().to('mps').requires_grad_()
+
+            max_result, max_indices = None, None
+            max_result_cpu, max_indices_cpu = None, None
+
+            if(return_indices):
+                max_result, max_indices = torch.nn.AdaptiveMaxPool2d(out_shape, return_indices)(x)
+                max_result_cpu, max_indices_cpu = torch.nn.AdaptiveMaxPool2d(out_shape, return_indices)(cpu_x)
+            else:
+                max_result = torch.nn.AdaptiveMaxPool2d(out_shape, return_indices)(x)
+                max_result_cpu = torch.nn.AdaptiveMaxPool2d(out_shape, return_indices)(cpu_x)
+
+            cpu_grad = torch.randn(max_result_cpu.shape)
+            grad = cpu_grad.to('mps')
+
+            max_result.backward(gradient=grad)
+            max_result_cpu.backward(gradient=cpu_grad)
+
+            self.assertEqual(max_result, max_result_cpu)
+            if(return_indices):
+                self.assertEqual(max_indices, max_indices_cpu)
+            self.assertEqual(x.grad, cpu_x.grad)
+
+        for dtype in [torch.float32]:
+            for return_indices in [False, True]:
+                helper((2, 2, 4, 4), (2, 2), return_indices, dtype)
+                helper((2, 2, 9, 9), (3, 3), return_indices, dtype)
+                helper((2, 2, 9, 9), (9, 9), return_indices, dtype)
+                helper((2, 2, 16, 16), (2, 2), return_indices, dtype)
+                helper((2, 2, 16, 16), (2, 16), return_indices, dtype)
+                helper((2, 16, 16), (4, 4), return_indices, dtype)
+
     def test_gelu_simple(self):
         def helper(shape):
             cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True)
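With the yaml dispatch entries in place, the module-level API works end to end on the MPS device. A quick smoke-test sketch (assumes a build where `torch.backends.mps.is_available()` is True):

```python
import torch

# Forward dispatches to adaptive_max_pool2d_out_mps; the backward call
# exercises adaptive_max_pool2d_backward_out_mps.
x = torch.randn(2, 2, 16, 16, device="mps", requires_grad=True)
pool = torch.nn.AdaptiveMaxPool2d((2, 2), return_indices=True)
out, indices = pool(x)
out.sum().backward()
print(out.shape, indices.shape, x.grad.shape)
```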