Skip to content

Commit bcc9f63

Browse files
authored
Fix the crash in nan_to_num() with Float16 data type (#226)
1 parent 7f4cf84 commit bcc9f63

File tree

1 file changed

+10
-3
lines changed

1 file changed

+10
-3
lines changed

aten/src/ATen/native/mps/operations/TensorCompare.mm

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -487,9 +487,16 @@ Tensor where_mps(const Tensor& condition,
487487
MPSGraphTensor* subZeroTensor = [mpsGraph lessThanWithPrimaryTensor: nanFreeTensor
488488
secondaryTensor: [mpsGraph constantWithScalar: 0.0 dataType: self_dtype]
489489
name: nil];
490-
// the cast is a workaround for the issue #103149520 (crash when bool and fp16 passed to binary ops)
491-
MPSGraphTensor* isNegInfTensor = [mpsGraph logicalANDWithPrimaryTensor: [mpsGraph castTensor: subZeroTensor toType: self_dtype name: @"castTensor"]
492-
secondaryTensor: [mpsGraph isInfiniteWithTensor: nanFreeTensor name:nil]
490+
MPSGraphTensor* isInfTensor = [mpsGraph isInfiniteWithTensor: nanFreeTensor name:nil];
491+
// workaround for Monterey; On Ventura the output of lessThan() is always Boolean
492+
if (subZeroTensor.dataType != MPSDataTypeBool) {
493+
subZeroTensor = castMPSTensor(mpsGraph, subZeroTensor, kBool);
494+
}
495+
if (isInfTensor.dataType != MPSDataTypeBool) {
496+
isInfTensor = castMPSTensor(mpsGraph, isInfTensor, kBool);
497+
}
498+
MPSGraphTensor* isNegInfTensor = [mpsGraph logicalANDWithPrimaryTensor: subZeroTensor
499+
secondaryTensor: isInfTensor
493500
name: nil];
494501
MPSGraphTensor* negInfFreeTensor = [mpsGraph selectWithPredicateTensor: isNegInfTensor
495502
truePredicateTensor: newCachedGraph->negInfReplacementTensor

0 commit comments

Comments
 (0)