From 23ed32ab84d9582cf4941018ec1a6f54778acbcf Mon Sep 17 00:00:00 2001 From: Siddharth Kotapati Date: Tue, 17 Jan 2023 13:15:14 -0800 Subject: [PATCH] Added zero check to inverse op, resolving crash seen in inverse & matrix_pow tests --- aten/src/ATen/native/mps/operations/Inverse.mm | 4 ++++ test/test_mps.py | 2 -- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/aten/src/ATen/native/mps/operations/Inverse.mm b/aten/src/ATen/native/mps/operations/Inverse.mm index 2975fd9875949..e78cf15ae90b9 100644 --- a/aten/src/ATen/native/mps/operations/Inverse.mm +++ b/aten/src/ATen/native/mps/operations/Inverse.mm @@ -24,6 +24,10 @@ MPSStream* stream = getCurrentMPSStream(); info.zero_(); + if (A.numel() == 0) { + return; + } + struct CachedGraph : public MPSCachedGraph { CachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {} diff --git a/test/test_mps.py b/test/test_mps.py index b31a9b6518906..c388c2b1ad6dc 100644 --- a/test/test_mps.py +++ b/test/test_mps.py @@ -9272,8 +9272,6 @@ class TestConsistency(TestCase): 'nonzero': [torch.bool, torch.uint8, torch.float16], 'median': [torch.float32, torch.int16, torch.int32, torch.uint8, torch.int16], 'sgn': [torch.bool], - 'linalg.inv': [torch.float32], - 'linalg.inv_ex': [torch.float32], 'linalg.matrix_power': [torch.float32], 'nn.functional.interpolate': [torch.float32], 'resize_': [torch.bool, torch.float16, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8],