Skip to content

Commit b65d0e8

Browse files
abhudevkulinseth
authored andcommitted
Repeat crash fix (#69)
* Handle empty repeats * Add repeat to allow list * Replace numberWithInt with numberWithInteger
1 parent 00a015f commit b65d0e8

File tree

2 files changed

+19
-5
lines changed

2 files changed

+19
-5
lines changed

aten/src/ATen/native/mps/operations/Repeat.mm

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -44,12 +44,21 @@ void set_apparent_shapes(NSMutableArray<NSNumber*> * input_shape,
4444
int64_t num_repeat_dims) {
4545

4646

47-
// Set repeats_shape
47+
bool repeat_empty = false;
48+
if(num_repeat_dims == 0) {
49+
num_repeat_dims = num_input_dims;
50+
repeat_empty = true;
51+
}
4852

53+
// Set repeats_shape
4954
repeats_shape = [NSMutableArray<NSNumber*> arrayWithCapacity:num_repeat_dims];
5055

51-
for(int i = 0; i < num_repeat_dims; i++)
52-
repeats_shape[i] = [NSNumber numberWithInt:repeats[i]];
56+
for(int i = 0; i < num_repeat_dims; i++) {
57+
if(repeat_empty)
58+
repeats_shape[i] = [NSNumber numberWithInteger:1];
59+
else
60+
repeats_shape[i] = [NSNumber numberWithInteger:repeats[i]];
61+
}
5362

5463
// If no extension of the shape is needed
5564
if(num_repeat_dims == num_input_dims) {
@@ -115,7 +124,7 @@ Tensor repeat_mps(const Tensor& self, IntArrayRef repeats) {
115124
c10::nullopt);
116125

117126
// Empty output
118-
if(zero_tensor)
127+
if(zero_tensor || output.numel() == 0)
119128
return output;
120129

121130
auto stream = at::mps::getCurrentMPSStream();

test/test_mps.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6530,6 +6530,12 @@ class TestConsistency(TestCase):
65306530
'torch.int32',
65316531
'torch.int64',
65326532
'torch.uint8'],
6533+
'repeat': ['torch.float16',
6534+
'torch.float32',
6535+
'torch.int16',
6536+
'torch.int32',
6537+
'torch.int64',
6538+
'torch.uint8'],
65336539
'repeat_interleave': ['torch.bool',
65346540
'torch.float16',
65356541
'torch.float32',
@@ -6699,7 +6705,6 @@ class TestConsistency(TestCase):
66996705
# These were moved from ALLOWLIST to BLOCK as they are not working
67006706
# locally
67016707
'tile': ['torch.float16', 'torch.float32', 'torch.int16', 'torch.int32', 'torch.int64', 'torch.uint8'],
6702-
'repeat': ['torch.float16', 'torch.float32', 'torch.int16', 'torch.int32', 'torch.int64', 'torch.uint8'],
67036708
'__radd__': ['torch.bool', 'torch.uint8'],
67046709
'__rmul__': ['torch.uint8'],
67056710
'add': ['torch.bool', 'torch.uint8'],

0 commit comments

Comments
 (0)