Skip to content

Commit 3f548d0

Browse files
abhudevkulinseth
authored andcommitted
Repeat crash fix (#69)
* Handle empty repeats * Add repeat to allow list * Replace numberWithInt with numberWithInteger
1 parent dadfe1c commit 3f548d0

File tree

2 files changed

+19
-5
lines changed

2 files changed

+19
-5
lines changed

aten/src/ATen/native/mps/operations/Repeat.mm

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -44,12 +44,21 @@ void set_apparent_shapes(NSMutableArray<NSNumber*> * input_shape,
4444
int64_t num_repeat_dims) {
4545

4646

47-
// Set repeats_shape
47+
bool repeat_empty = false;
48+
if(num_repeat_dims == 0) {
49+
num_repeat_dims = num_input_dims;
50+
repeat_empty = true;
51+
}
4852

53+
// Set repeats_shape
4954
repeats_shape = [NSMutableArray<NSNumber*> arrayWithCapacity:num_repeat_dims];
5055

51-
for(int i = 0; i < num_repeat_dims; i++)
52-
repeats_shape[i] = [NSNumber numberWithInt:repeats[i]];
56+
for(int i = 0; i < num_repeat_dims; i++) {
57+
if(repeat_empty)
58+
repeats_shape[i] = [NSNumber numberWithInteger:1];
59+
else
60+
repeats_shape[i] = [NSNumber numberWithInteger:repeats[i]];
61+
}
5362

5463
// If no extension of the shape is needed
5564
if(num_repeat_dims == num_input_dims) {
@@ -115,7 +124,7 @@ Tensor repeat_mps(const Tensor& self, IntArrayRef repeats) {
115124
c10::nullopt);
116125

117126
// Empty output
118-
if(zero_tensor)
127+
if(zero_tensor || output.numel() == 0)
119128
return output;
120129

121130
auto stream = at::mps::getCurrentMPSStream();

test/test_mps.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6272,6 +6272,12 @@ class TestConsistency(TestCase):
62726272
'torch.int32',
62736273
'torch.int64',
62746274
'torch.uint8'],
6275+
'repeat': ['torch.float16',
6276+
'torch.float32',
6277+
'torch.int16',
6278+
'torch.int32',
6279+
'torch.int64',
6280+
'torch.uint8'],
62756281
'repeat_interleave': ['torch.bool',
62766282
'torch.float16',
62776283
'torch.float32',
@@ -6441,7 +6447,6 @@ class TestConsistency(TestCase):
64416447
# These were moved from ALLOWLIST to BLOCK as they are not working
64426448
# locally
64436449
'tile': ['torch.float16', 'torch.float32', 'torch.int16', 'torch.int32', 'torch.int64', 'torch.uint8'],
6444-
'repeat': ['torch.float16', 'torch.float32', 'torch.int16', 'torch.int32', 'torch.int64', 'torch.uint8'],
64456450
'__radd__': ['torch.bool', 'torch.uint8'],
64466451
'__rmul__': ['torch.uint8'],
64476452
'add': ['torch.bool', 'torch.uint8'],

0 commit comments

Comments
 (0)