diff --git a/aten/src/ATen/native/mps/operations/Inverse.mm b/aten/src/ATen/native/mps/operations/Inverse.mm index 2975fd9875949..e78cf15ae90b9 100644 --- a/aten/src/ATen/native/mps/operations/Inverse.mm +++ b/aten/src/ATen/native/mps/operations/Inverse.mm @@ -24,6 +24,10 @@ MPSStream* stream = getCurrentMPSStream(); info.zero_(); + if (A.numel() == 0) { + return; + } + struct CachedGraph : public MPSCachedGraph { CachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {} diff --git a/test/test_mps.py b/test/test_mps.py index b31a9b6518906..c388c2b1ad6dc 100644 --- a/test/test_mps.py +++ b/test/test_mps.py @@ -9272,8 +9272,6 @@ class TestConsistency(TestCase): 'nonzero': [torch.bool, torch.uint8, torch.float16], 'median': [torch.float32, torch.int16, torch.int32, torch.uint8, torch.int16], 'sgn': [torch.bool], - 'linalg.inv': [torch.float32], - 'linalg.inv_ex': [torch.float32], 'linalg.matrix_power': [torch.float32], 'nn.functional.interpolate': [torch.float32], 'resize_': [torch.bool, torch.float16, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8],