diff --git a/aten/src/ATen/native/mps/operations/UnaryOps.mm b/aten/src/ATen/native/mps/operations/UnaryOps.mm
index 2670701bacb54..bbbb81cf47432 100644
--- a/aten/src/ATen/native/mps/operations/UnaryOps.mm
+++ b/aten/src/ATen/native/mps/operations/UnaryOps.mm
@@ -254,7 +254,10 @@ void unary_op(const Tensor& self, const Tensor& output, std::string op_name, Una
  int64_t dim,
  c10::optional<ScalarType> dtype,
  const Tensor& result) {
-  TORCH_CHECK(dim >=0 && dim < std::max(1LL, self.ndimension()), "Expected dim to be between 0 and ", self.ndimension(), " but got ", dim);
+
+  auto nDims = self.dim();
+  auto wrapped_dim = maybe_wrap_dim(dim, nDims);
+  TORCH_CHECK(wrapped_dim >= 0 && wrapped_dim < std::max(1LL, self.ndimension()), "Expected wrapped dim to be between 0 and ", self.ndimension(), " but got ", wrapped_dim, " (original dim is ", dim, ")");
   if (!is_macos_13_or_newer()) {
     TORCH_WARN_ONCE("torch.cumsum supported by MPS on MacOS 13+, please upgrade");
     auto cpu_result = self.to(at::Device(kCPU)).cumsum(dim, dtype);
diff --git a/test/test_mps.py b/test/test_mps.py
index 7d36da427667b..04141ace1d2eb 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -2248,6 +2248,23 @@ def helper(shape, dtype):
         helper((2, 8, 4, 5), torch.int16)
 
+    def test_cumsum_minus_one_axis(self):
+        def helper(dtype):
+            # Test with axis -1
+            cpu_x = None
+            if dtype == torch.float32:
+                cpu_x = torch.randn(10, 3, device='cpu', dtype=torch.float32)
+            else:
+                # Use the requested integer dtype so each dtype is actually exercised
+                cpu_x = torch.randint(0, 20, (10, 3), device='cpu', dtype=dtype)
+            x = cpu_x.detach().clone().to('mps')
+
+            cpu_y = cpu_x.cumsum(-1)
+            y = x.cumsum(-1)
+
+            self.assertEqual(y, cpu_y)
+
+        [helper(dtype) for dtype in [torch.float32, torch.int16, torch.int32, torch.uint8]]
+
 class TestLogical(TestCase):
     def _wrap_tensor(self, x, device="cpu", dtype=None, requires_grad=False):
         return torch.tensor(x, device=device, dtype=dtype, requires_grad=requires_grad)
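
For context: `maybe_wrap_dim` converts a negative dimension index into its non-negative equivalent before the bounds check runs. The previous check rejected any negative `dim` outright, so `tensor.cumsum(-1)` failed on MPS even though it is valid on other backends. Below is a minimal Python sketch of the wrapping rule, assuming the standard PyTorch convention that a scalar (0-d tensor) still accepts `dim` in `[-1, 0]`; the function name `wrap_dim` is illustrative, not PyTorch API:

```python
def wrap_dim(dim: int, ndim: int) -> int:
    # Mirror c10::maybe_wrap_dim: negative dims count back from the last axis.
    # Scalars (ndim == 0) are treated as having one addressable dim.
    rng = max(ndim, 1)
    if not -rng <= dim < rng:
        raise IndexError(f"Dimension out of range (expected to be in range of "
                         f"[{-rng}, {rng - 1}], but got {dim})")
    return dim + rng if dim < 0 else dim

assert wrap_dim(-1, 2) == 1  # cumsum(-1) on a 2-D tensor sums along dim 1
assert wrap_dim(0, 2) == 0   # non-negative dims pass through unchanged
assert wrap_dim(-1, 0) == 0  # scalar case: -1 wraps to 0
```

With the wrapped value feeding the `TORCH_CHECK`, a call such as `torch.randn(10, 3, device='mps').cumsum(-1)` passes validation and behaves like its CPU counterpart, which is what the new test verifies across float and integer dtypes.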