Skip to content

Commit ded98e6

Browse files
committed
Revert "Added zero check to inverse op, resolving crash seen in inverse & matrix_pow tests (#236)"
This reverts commit 5d9be8c.
1 parent e4eadc3 commit ded98e6

File tree

2 files changed

+2
-4
lines changed

2 files changed

+2
-4
lines changed

aten/src/ATen/native/mps/operations/Inverse.mm

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,10 +24,6 @@
2424
MPSStream* stream = getCurrentMPSStream();
2525
info.zero_();
2626

27-
if (A.numel() == 0) {
28-
return;
29-
}
30-
3127
struct CachedGraph : public MPSCachedGraph
3228
{
3329
CachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {}

test/test_mps.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9426,6 +9426,8 @@ class TestConsistency(TestCase):
94269426
'nonzero': [torch.bool, torch.uint8, torch.float16],
94279427
'median': [torch.float32, torch.int16, torch.int32, torch.uint8, torch.int16],
94289428
'sgn': [torch.bool],
9429+
'linalg.inv': [torch.float32],
9430+
'linalg.inv_ex': [torch.float32],
94299431
'linalg.matrix_power': [torch.float32],
94309432
'nn.functional.interpolate': [torch.float32],
94319433
'resize_': [torch.bool, torch.float16, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8],

0 commit comments

Comments
 (0)