Skip to content

Commit 5de429e

Browse files
skotapatiSiddharth Kotapati
authored andcommitted
Added zero check to inverse op, resolving crash seen in inverse & matrix_pow tests (#236)
Co-authored-by: Siddharth Kotapati <sidk@Siddharths-MacBook-Pro.local>
1 parent 67d87ea commit 5de429e

File tree

2 files changed

+4
-2
lines changed

2 files changed

+4
-2
lines changed

aten/src/ATen/native/mps/operations/Inverse.mm

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,10 @@
2424
MPSStream* stream = getCurrentMPSStream();
2525
info.zero_();
2626

27+
if (A.numel() == 0) {
28+
return;
29+
}
30+
2731
struct CachedGraph : public MPSCachedGraph
2832
{
2933
CachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {}

test/test_mps.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9402,8 +9402,6 @@ class TestConsistency(TestCase):
94029402
'nonzero': [torch.bool, torch.uint8, torch.float16],
94039403
'median': [torch.float32, torch.int16, torch.int32, torch.uint8, torch.int16],
94049404
'sgn': [torch.bool],
9405-
'linalg.inv': [torch.float32],
9406-
'linalg.inv_ex': [torch.float32],
94079405
'linalg.matrix_power': [torch.float32],
94089406
'nn.functional.interpolate': [torch.float32],
94099407
'resize_': [torch.bool, torch.float16, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8],

0 commit comments

Comments
 (0)