Skip to content

Commit 81c2928

Browse files
DenisVieriu97kulinseth
authored andcommitted
Fix arange_mps_out for empty tensor (#245)
* Fix arange_mps_out for empty tensor * Address PR comments
1 parent 71ebe62 commit 81c2928

File tree

2 files changed

+13
-0
lines changed

2 files changed

+13
-0
lines changed

aten/src/ATen/native/mps/operations/RangeFactories.mm

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,11 @@
8787
}
8888
result.resize_({size});
8989
}
90+
91+
if (result.numel() == 0) {
92+
return;
93+
}
94+
9095
bool is_contiguous = result.is_contiguous();
9196
Tensor r = !is_contiguous ? at::empty_like(result, LEGACY_CONTIGUOUS_MEMORY_FORMAT) : result;
9297
using namespace mps;

test/test_mps.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5564,6 +5564,14 @@ def test_arange(self):
55645564
self.assertEqual(np.arange(1, 2, .3, dtype=np.float32), torch.arange(1, 2, .3, device='mps'))
55655565
self.assertEqual(np.arange(6.3, dtype=np.float32), torch.arange(6.3, device='mps'))
55665566

5567+
def test_arange_empty(self):
5568+
out_mps = torch.tensor([], device="mps")
5569+
out_cpu = torch.tensor([], device="cpu")
5570+
5571+
y_mps = torch.arange(0, 0, 1, out=out_mps)
5572+
y_cpu = torch.arange(0, 0, 1, out=out_cpu)
5573+
self.assertEqual(y_mps, y_cpu)
5574+
55675575
# Test softmax
55685576
def test_softmax(self):
55695577
def helper(shape, dim, channels_last=False):

0 commit comments

Comments
 (0)