aten/src/ATen/native/mps/operations/UnaryOps.mm (4 additions, 1 deletion)

@@ -254,7 +254,10 @@ void unary_op(const Tensor& self, const Tensor& output, std::string op_name, Una
     int64_t dim,
     c10::optional<ScalarType> dtype,
     const Tensor& result) {
-  TORCH_CHECK(dim >=0 && dim < std::max(1LL, self.ndimension()), "Expected dim to be between 0 and ", self.ndimension(), " but got ", dim);
+
+  auto nDims = self.dim();
+  auto wrapped_dim = maybe_wrap_dim(dim, nDims);
+  TORCH_CHECK(wrapped_dim >= 0 && wrapped_dim < std::max(1LL, self.ndimension()), "Expected wrapped dim to be between 0 and ", self.ndimension(), " but got ", wrapped_dim, " (original dim is ", dim, ")");
   if (!is_macos_13_or_newer()) {
     TORCH_WARN_ONCE("torch.cumsum supported by MPS on MacOS 13+, please upgrade");
     auto cpu_result = self.to(at::Device(kCPU)).cumsum(dim, dtype);
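For reference, maybe_wrap_dim normalizes a possibly-negative dimension index into [0, ndim). Below is a minimal Python sketch of the wrapping rule, not the actual implementation: the real helper lives in c10 and also covers the 0-dim scalar case, which the std::max(1LL, ...) guard in the check above mirrors.

def wrap_dim(dim, ndim):
    # Treat 0-dim (scalar) tensors as 1-dimensional for bounds purposes.
    ndim = max(ndim, 1)
    if dim < 0:
        dim += ndim  # e.g. dim=-1, ndim=2 -> 1
    if not 0 <= dim < ndim:
        raise IndexError(f"Dimension out of range (expected [{-ndim}, {ndim - 1}], got {dim})")
    return dim

assert wrap_dim(-1, 2) == 1  # cumsum(-1) on a 2-D tensor maps to the last dim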
test/test_mps.py (17 additions)

@@ -2248,6 +2248,23 @@ def helper(shape, dtype):
 
         helper((2, 8, 4, 5), torch.int16)
 
+    def test_cumsum_minus_one_axis(self):
Owner: Can you also run the TestConsistency test for cumsum?

Owner: Also enable that for all types.

Author: TestConsistency passes, and I expanded this test with more datatypes (which all pass).
+        def helper(dtype):
+            # Test with axis -1
+            cpu_x = None
+            if dtype == torch.float32:
+                cpu_x = torch.randn(10, 3, device='cpu', dtype=torch.float32)
+            else:
+                cpu_x = torch.randint(0, 20, (10, 3), device='cpu', dtype=dtype)
+            x = cpu_x.detach().clone().to('mps')
+
+            cpu_y = cpu_x.cumsum(-1)
+            y = x.cumsum(-1)
+
+            self.assertEqual(y, cpu_y)
+
+        [helper(dtype) for dtype in [torch.float32, torch.int16, torch.int32, torch.uint8]]
+
 class TestLogical(TestCase):
     def _wrap_tensor(self, x, device="cpu", dtype=None, requires_grad=False):
         return torch.tensor(x, device=device, dtype=dtype, requires_grad=requires_grad)
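As a quick sanity check, the fixed behavior can be exercised directly. This is a hypothetical repro sketch, not part of the PR, and it assumes an MPS-capable machine running macOS 13 or newer:

import torch

# Before this change, cumsum with a negative dim tripped the MPS bounds check;
# after it, the dim is wrapped and the result matches the CPU reference.
x = torch.randn(10, 3, device='mps')
assert torch.allclose(x.cumsum(-1).cpu(), x.cpu().cumsum(-1))

The new test can also be run on its own with something like python test/test_mps.py -k test_cumsum_minus_one_axis (assuming the standard PyTorch unittest runner).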