From 5e40c08752737308db568777aeb11c3c5940d970 Mon Sep 17 00:00:00 2001 From: Abhishek Pathak Date: Mon, 1 Aug 2022 18:07:50 -0700 Subject: [PATCH 1/2] Add empty input checks for more ops --- aten/src/ATen/native/mps/operations/Linear.mm | 3 +++ aten/src/ATen/native/mps/operations/PointwiseOps.mm | 4 ++++ 2 files changed, 7 insertions(+) diff --git a/aten/src/ATen/native/mps/operations/Linear.mm b/aten/src/ATen/native/mps/operations/Linear.mm index a6710ea5fc2a5..7df8e9b8e837e 100644 --- a/aten/src/ATen/native/mps/operations/Linear.mm +++ b/aten/src/ATen/native/mps/operations/Linear.mm @@ -46,6 +46,9 @@ Tensor _mps_linear( TORCH_CHECK(output.is_mps()); + if(output.numel() == 0) + return output; + MPSStream *stream = getCurrentMPSStream(); struct CachedGraph : public MPSCachedGraph diff --git a/aten/src/ATen/native/mps/operations/PointwiseOps.mm b/aten/src/ATen/native/mps/operations/PointwiseOps.mm index 66427c73e0c75..2e66a9b154623 100644 --- a/aten/src/ATen/native/mps/operations/PointwiseOps.mm +++ b/aten/src/ATen/native/mps/operations/PointwiseOps.mm @@ -18,6 +18,10 @@ if (&output != &self) { output.resize_(output.sizes()); } + + if(output.numel() == 0) + return output; + MPSStream* mpsStream = getCurrentMPSStream(); struct CachedGraph : public MPSCachedGraph From 17602fbabbd03e30e5c617e86a60069b36ccc0bc Mon Sep 17 00:00:00 2001 From: Abhishek Pathak Date: Tue, 9 Aug 2022 18:23:38 -0700 Subject: [PATCH 2/2] Add empty check for sigmoid --- aten/src/ATen/native/mps/operations/Activation.mm | 3 +++ 1 file changed, 3 insertions(+) diff --git a/aten/src/ATen/native/mps/operations/Activation.mm b/aten/src/ATen/native/mps/operations/Activation.mm index b741276b45e01..287c04374e112 100644 --- a/aten/src/ATen/native/mps/operations/Activation.mm +++ b/aten/src/ATen/native/mps/operations/Activation.mm @@ -417,6 +417,9 @@ Tensor relu_mps(const Tensor& self) { using CachedGraph = MPSUnaryCachedGraph; TORCH_CHECK(output.is_mps()); + if(output.numel() == 0) + return; + MPSGraphCache* cache_ = MPSGraphCache::getInstance(); MPSStream* stream = getCurrentMPSStream();