
Commit 8c0bb00

abhudev and abhipathak97 authored
Make intermediate type for cumsum ScalarType::Int (#221)
* Make intermediate type for cumsum ScalarType::Int

* Disallow int64 as input for cumsum

* Fix error message; add test

Co-authored-by: abhipathak97 <abhipathak97@mps10.scv.apple.com>
1 parent dabe321 commit 8c0bb00

File tree

2 files changed: +19 −1 lines changed


aten/src/ATen/native/mps/operations/UnaryOps.mm

Lines changed: 2 additions & 1 deletion
@@ -262,11 +262,12 @@ void unary_op(const Tensor& self, const Tensor& output, std::string op_name, Una
     return;
   }
   auto input = dtype.has_value() ? self.to(dtype.value()) : self;
+  TORCH_CHECK(input.scalar_type() != ScalarType::Long, "MPS does not support cumsum op with int64 input");
   mps::unary_op(input, result, "cumsum_out_mp" + std::to_string(dim),
                 ^ MPSGraphTensor* (MPSGraph* mpsGraph, MPSGraphTensor* inputTensor) {
                   // cumsum is horribly broken for int8, int16 and as chances for overflow is pretty high, cast to int32
                   if (isIntegralType(input.scalar_type()) && input.scalar_type() != ScalarType::Int) {
-                    inputTensor = mps::castMPSTensor(mpsGraph, inputTensor, result.scalar_type());
+                    inputTensor = mps::castMPSTensor(mpsGraph, inputTensor, ScalarType::Int);
                   }
                   auto rc = [mpsGraph cumulativeSumWithTensor: inputTensor
                                                          axis: dim
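
The in-graph comment above is the core of the change: integer inputs narrower than int32 are now accumulated in ScalarType::Int rather than in the result's own dtype. A minimal CPU-side sketch of the wraparound this avoids (illustration only, not part of the commit):

import torch

# A running sum of int8 values exceeds the int8 range [-128, 127] almost
# immediately; accumulating in int32 keeps every partial sum exact.
t = torch.full((4,), 100, dtype=torch.int8)
print(t.cumsum(0, dtype=torch.int32))  # tensor([100, 200, 300, 400], dtype=torch.int32)
# A pure int8 accumulation would wrap instead: 100, -56, 44, -112.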

test/test_mps.py

Lines changed: 17 additions & 0 deletions
@@ -2165,6 +2165,23 @@ def test_from_numpy_non_contiguous(self):
         t_mps = torch.tensor(a, device="mps")
         self.assertEqual(t_cpu, t_mps.to("cpu"))
 
+    def test_cumsum_all_dtypes(self):
+        def helper(dtype):
+            t = torch.tensor([1,1,1,1], device="mps", dtype=dtype)
+            t_cpu = torch.tensor([1,1,1,1], device="cpu")
+
+            a = t.cumsum(0, dtype=dtype)
+            a_cpu = t_cpu.cumsum(0, dtype=dtype)
+
+            self.assertEqual(a.cpu(), a_cpu)
+        [helper(dtype) for dtype in [torch.int8, torch.int16, torch.int32, torch.float32]]
+
+        try:
+            helper(torch.int64)
+        except Exception as e:
+            e_string = str(e)
+            self.assertEqual(e_string, "MPS does not support cumsum op with int64 input")
+
 
 class TestLogical(TestCase):
     def _wrap_tensor(self, x, device="cpu", dtype=None, requires_grad=False):
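
For context, this is how the new guard surfaces to a user; a quick sketch assuming a build where torch.backends.mps.is_available() returns True (not part of the commit):

import torch

# int64 cumsum on the "mps" device is now rejected up front by the
# TORCH_CHECK added in UnaryOps.mm, rather than producing wrong results.
if torch.backends.mps.is_available():
    t = torch.ones(4, device="mps", dtype=torch.int64)
    try:
        t.cumsum(0)
    except RuntimeError as e:
        print(e)  # MPS does not support cumsum op with int64 input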
