diff --git a/aten/src/ATen/native/mps/operations/PointwiseOps.mm b/aten/src/ATen/native/mps/operations/PointwiseOps.mm index 66427c73e0c75..2cd65a7f6b541 100644 --- a/aten/src/ATen/native/mps/operations/PointwiseOps.mm +++ b/aten/src/ATen/native/mps/operations/PointwiseOps.mm @@ -15,8 +15,9 @@ const bool is_div, const string op_name) { - if (&output != &self) { - output.resize_(output.sizes()); + if (value_opt.toDouble() == 0.0) { + output.copy_(self); + return output; } MPSStream* mpsStream = getCurrentMPSStream(); @@ -44,7 +45,7 @@ newCachedGraph->inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, self); newCachedGraph->firstTensor = mpsGraphRankedPlaceHolder(mpsGraph, tensor1); newCachedGraph->secondTensor = mpsGraphRankedPlaceHolder(mpsGraph, tensor2); - newCachedGraph->valueTensor = mpsGraphUnrankedPlaceHolder(mpsGraph, getMPSScalarType(self.scalar_type())); + newCachedGraph->valueTensor = mpsGraphRankedPlaceHolder(mpsGraph, getMPSScalarType(self.scalar_type()), @[@1]); // the tensor to be optionally multiplied by value_scalar MPSGraphTensor *multiplicandTensor = nil;