Skip to content

Commit

Permalink
Handle casting for div operation
Browse files Browse the repository at this point in the history
  • Loading branch information
abhudev committed Aug 8, 2022
1 parent e957317 commit 955e2bb
Showing 1 changed file with 28 additions and 5 deletions.
33 changes: 28 additions & 5 deletions aten/src/ATen/native/mps/operations/BinaryOps.mm
Original file line number Diff line number Diff line change
Expand Up @@ -72,16 +72,39 @@ void binaryOpTensor(const Tensor& self, const Tensor& other, const Scalar& alpha

// this type inference is only required at the time of graph creation
const ScalarType common_dtype = c10::promoteTypes(self.scalar_type(), other.scalar_type());
if (self.scalar_type() != common_dtype) {
primaryCastTensor = castMPSTensor(mpsGraph, newCachedGraph->primaryTensor, common_dtype);

// Condition -
// 1. Division operation
// 2. Inputs are not float
bool div_condition = op_name.rfind("div", 0) == 0
&& (!(common_dtype == ScalarType::Float || common_dtype == ScalarType::Half));

auto compute_type = ScalarType::Float;

if(div_condition) {

if(output_.scalar_type() == ScalarType::Float || output_.scalar_type() == ScalarType::Half)
compute_type = output_.scalar_type();

primaryCastTensor = castMPSTensor(mpsGraph, newCachedGraph->primaryTensor, compute_type);
secondaryCastTensor = castMPSTensor(mpsGraph, newCachedGraph->secondaryTensor, compute_type);
}
if (other.scalar_type() != common_dtype) {
secondaryCastTensor = castMPSTensor(mpsGraph, newCachedGraph->secondaryTensor, common_dtype);
else {
if (self.scalar_type() != common_dtype) {
primaryCastTensor = castMPSTensor(mpsGraph, newCachedGraph->primaryTensor, common_dtype);
}
if (other.scalar_type() != common_dtype) {
secondaryCastTensor = castMPSTensor(mpsGraph, newCachedGraph->secondaryTensor, common_dtype);
}
}

newCachedGraph->outputTensor = binaryBlock(newCachedGraph, primaryCastTensor, secondaryCastTensor);
// Cast output tensor to an expected type if needed, which addresses discrepancy when int64 scalar is added to int32 tensor
// Output tensor should have been promoted but it remains an int32 tensor
if (output_.scalar_type() != common_dtype) {


if ((div_condition && compute_type != output_.scalar_type()) ||
output_.scalar_type() != common_dtype) {
newCachedGraph->outputTensor = castMPSTensor(mpsGraph, newCachedGraph->outputTensor, output_.scalar_type());
}
}
Expand Down

0 comments on commit 955e2bb

Please sign in to comment.