Skip to content

Commit 3ac4827

Browse files
razarmehrkulinseth
authored andcommitted
Return input in addcmul/div if value is zero (#84)
Also remove the unnecessary resize (structured op)
1 parent 7366608 commit 3ac4827

File tree

1 file changed

+4
-3
lines changed

1 file changed

+4
-3
lines changed

aten/src/ATen/native/mps/operations/PointwiseOps.mm

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,9 @@
1414
const bool is_div,
1515
const string op_name)
1616
{
17-
if (&output != &self) {
18-
output.resize_(output.sizes());
17+
if (value_opt.toDouble() == 0.0) {
18+
output.copy_(self);
19+
return output;
1920
}
2021

2122
if(output.numel() == 0) {
@@ -48,7 +49,7 @@
4849
newCachedGraph->inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, self);
4950
newCachedGraph->firstTensor = mpsGraphRankedPlaceHolder(mpsGraph, tensor1);
5051
newCachedGraph->secondTensor = mpsGraphRankedPlaceHolder(mpsGraph, tensor2);
51-
newCachedGraph->valueTensor = mpsGraphUnrankedPlaceHolder(mpsGraph, getMPSScalarType(self.scalar_type()));
52+
newCachedGraph->valueTensor = mpsGraphRankedPlaceHolder(mpsGraph, getMPSScalarType(self.scalar_type()), @[@1]);
5253

5354
// the tensor to be optionally multiplied by value_scalar
5455
MPSGraphTensor *multiplicandTensor = nil;

0 commit comments

Comments
 (0)