diff --git a/aten/src/ATen/native/mps/operations/BinaryOps.mm b/aten/src/ATen/native/mps/operations/BinaryOps.mm
index b619307ef8aa1..768a6199f0274 100644
--- a/aten/src/ATen/native/mps/operations/BinaryOps.mm
+++ b/aten/src/ATen/native/mps/operations/BinaryOps.mm
@@ -72,16 +72,36 @@ void binaryOpTensor(const Tensor& self, const Tensor& other, const Scalar& alpha
 
             // this type inference is only required at the time of graph creation
             const ScalarType common_dtype = c10::promoteTypes(self.scalar_type(), other.scalar_type());
-            if (self.scalar_type() != common_dtype) {
-              primaryCastTensor = castMPSTensor(mpsGraph, newCachedGraph->primaryTensor, common_dtype);
+
+            // Compute in a floating-point type when:
+            // 1. the operation is a division, and
+            // 2. the promoted input type is not a floating type
+            bool div_condition = op_name.rfind("div", 0) == 0 &&
+                !(common_dtype == ScalarType::Float || common_dtype == ScalarType::Half);
+
+            auto compute_type = ScalarType::Float;
+
+            if (div_condition) {
+              if (output_.scalar_type() == ScalarType::Float || output_.scalar_type() == ScalarType::Half) {
+                compute_type = output_.scalar_type();
+              }
+              primaryCastTensor = castMPSTensor(mpsGraph, newCachedGraph->primaryTensor, compute_type);
+              secondaryCastTensor = castMPSTensor(mpsGraph, newCachedGraph->secondaryTensor, compute_type);
             }
-            if (other.scalar_type() != common_dtype) {
-              secondaryCastTensor = castMPSTensor(mpsGraph, newCachedGraph->secondaryTensor, common_dtype);
+            else {
+              if (self.scalar_type() != common_dtype) {
+                primaryCastTensor = castMPSTensor(mpsGraph, newCachedGraph->primaryTensor, common_dtype);
+              }
+              if (other.scalar_type() != common_dtype) {
+                secondaryCastTensor = castMPSTensor(mpsGraph, newCachedGraph->secondaryTensor, common_dtype);
+              }
             }
+
             newCachedGraph->outputTensor = binaryBlock(newCachedGraph, primaryCastTensor, secondaryCastTensor);
             // Cast output tensor to an expected type if needed, which addresses discrepancy when int64 scalar is added to int32 tensor
             // Output tensor should have been promoted but it remains an int32 tensor
-            if (output_.scalar_type() != common_dtype) {
+            if ((div_condition && compute_type != output_.scalar_type()) ||
+                output_.scalar_type() != common_dtype) {
              newCachedGraph->outputTensor = castMPSTensor(mpsGraph, newCachedGraph->outputTensor, output_.scalar_type());
            }
          }
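
For context, a minimal sketch of the user-visible behavior this hunk targets, assuming a PyTorch build where MPS is available; the tensor values and the float16 out= example are illustrative, not taken from the patch:

# Minimal sketch, assuming torch.backends.mps.is_available() is true.
# Values are illustrative.
import torch

if torch.backends.mps.is_available():
    a = torch.tensor([5, 7, 9], dtype=torch.int32, device="mps")
    b = torch.tensor([2, 2, 2], dtype=torch.int32, device="mps")

    # True division of integer tensors yields a floating result; with the
    # patched path the division is computed in Float rather than in the
    # promoted integer common dtype.
    print(torch.div(a, b))  # expected: tensor([2.5000, 3.5000, 4.5000], device='mps:0')

    # With a half-precision output tensor, compute_type follows the output
    # dtype (Half) in the patched path.
    out = torch.empty(3, dtype=torch.float16, device="mps")
    torch.div(a, b, out=out)
    print(out)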