Skip to content

Commit d24fe97

Browse files
Fix arange_mps_out for empty tensor (#245)
* Fix arange_mps_out for empty tensor * Address PR comments
1 parent 1d50f47 commit d24fe97

File tree

2 files changed

+13
-0
lines changed

2 files changed

+13
-0
lines changed

aten/src/ATen/native/mps/operations/RangeFactories.mm

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,11 @@
8888
}
8989
result.resize_({size});
9090
}
91+
92+
if (result.numel() == 0) {
93+
return;
94+
}
95+
9196
bool is_contiguous = result.is_contiguous();
9297
Tensor r = !is_contiguous ? at::empty_like(result, LEGACY_CONTIGUOUS_MEMORY_FORMAT) : result;
9398
using namespace mps;

test/test_mps.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5457,6 +5457,14 @@ def test_arange(self):
54575457
self.assertEqual(np.arange(1, 2, .3, dtype=np.float32), torch.arange(1, 2, .3, device='mps'))
54585458
self.assertEqual(np.arange(6.3, dtype=np.float32), torch.arange(6.3, device='mps'))
54595459

5460+
def test_arange_empty(self):
5461+
out_mps = torch.tensor([], device="mps")
5462+
out_cpu = torch.tensor([], device="cpu")
5463+
5464+
y_mps = torch.arange(0, 0, 1, out=out_mps)
5465+
y_cpu = torch.arange(0, 0, 1, out=out_cpu)
5466+
self.assertEqual(y_mps, y_cpu)
5467+
54605468
# Test softmax
54615469
def test_softmax(self):
54625470
def helper(shape, dim, channels_last=False):

0 commit comments

Comments
 (0)