Skip to content

Commit 5d9be8c

Browse files
skotapatiSiddharth Kotapati
andauthored
Added zero check to inverse op, resolving crash seen in inverse & matrix_pow tests (#236)
Co-authored-by: Siddharth Kotapati <sidk@Siddharths-MacBook-Pro.local>
1 parent f079740 commit 5d9be8c

File tree

2 files changed

+4
-2
lines changed

2 files changed

+4
-2
lines changed

aten/src/ATen/native/mps/operations/Inverse.mm

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,10 @@
2424
MPSStream* stream = getCurrentMPSStream();
2525
info.zero_();
2626

27+
if (A.numel() == 0) {
28+
return;
29+
}
30+
2731
struct CachedGraph : public MPSCachedGraph
2832
{
2933
CachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {}

test/test_mps.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9272,8 +9272,6 @@ class TestConsistency(TestCase):
92729272
'nonzero': [torch.bool, torch.uint8, torch.float16],
92739273
'median': [torch.float32, torch.int16, torch.int32, torch.uint8, torch.int16],
92749274
'sgn': [torch.bool],
9275-
'linalg.inv': [torch.float32],
9276-
'linalg.inv_ex': [torch.float32],
92779275
'linalg.matrix_power': [torch.float32],
92789276
'nn.functional.interpolate': [torch.float32],
92799277
'resize_': [torch.bool, torch.float16, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8],

0 commit comments

Comments
 (0)