Skip to content

Commit 8a56d6f

Browse files
razarmehrkulinseth
authored andcommitted
Return input in addcmul/div if value is zero (#84)
Also remove the unnecessary resize (structured op)
1 parent 3b8e823 commit 8a56d6f

File tree

1 file changed

+4
-3
lines changed

1 file changed

+4
-3
lines changed

aten/src/ATen/native/mps/operations/PointwiseOps.mm

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,9 @@
1515
const bool is_div,
1616
const string op_name)
1717
{
18-
if (&output != &self) {
19-
output.resize_(output.sizes());
18+
if (value_opt.toDouble() == 0.0) {
19+
output.copy_(self);
20+
return output;
2021
}
2122

2223
if(output.numel() == 0) {
@@ -49,7 +50,7 @@
4950
newCachedGraph->inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, self);
5051
newCachedGraph->firstTensor = mpsGraphRankedPlaceHolder(mpsGraph, tensor1);
5152
newCachedGraph->secondTensor = mpsGraphRankedPlaceHolder(mpsGraph, tensor2);
52-
newCachedGraph->valueTensor = mpsGraphUnrankedPlaceHolder(mpsGraph, getMPSScalarType(self.scalar_type()));
53+
newCachedGraph->valueTensor = mpsGraphRankedPlaceHolder(mpsGraph, getMPSScalarType(self.scalar_type()), @[@1]);
5354

5455
// the tensor to be optionally multiplied by value_scalar
5556
MPSGraphTensor *multiplicandTensor = nil;

0 commit comments

Comments
 (0)