File tree Expand file tree Collapse file tree 1 file changed +10
-3
lines changed
aten/src/ATen/native/mps/operations Expand file tree Collapse file tree 1 file changed +10
-3
lines changed Original file line number Diff line number Diff line change @@ -487,9 +487,16 @@ Tensor where_mps(const Tensor& condition,
487487 MPSGraphTensor* subZeroTensor = [mpsGraph lessThanWithPrimaryTensor: nanFreeTensor
488488 secondaryTensor: [mpsGraph constantWithScalar: 0.0 dataType: self_dtype]
489489 name: nil ];
490- // the cast is a workaround for the issue #103149520 (crash when bool and fp16 passed to binary ops)
491- MPSGraphTensor* isNegInfTensor = [mpsGraph logicalANDWithPrimaryTensor: [mpsGraph castTensor: subZeroTensor toType: self_dtype name: @" castTensor" ]
492- secondaryTensor: [mpsGraph isInfiniteWithTensor: nanFreeTensor name: nil ]
490+ MPSGraphTensor* isInfTensor = [mpsGraph isInfiniteWithTensor: nanFreeTensor name: nil ];
491+ // workaround for Monterey; On Ventura the output of lessThan() is always Boolean
492+ if (subZeroTensor.dataType != MPSDataTypeBool) {
493+ subZeroTensor = castMPSTensor (mpsGraph, subZeroTensor, kBool );
494+ }
495+ if (isInfTensor.dataType != MPSDataTypeBool) {
496+ isInfTensor = castMPSTensor (mpsGraph, isInfTensor, kBool );
497+ }
498+ MPSGraphTensor* isNegInfTensor = [mpsGraph logicalANDWithPrimaryTensor: subZeroTensor
499+ secondaryTensor: isInfTensor
493500 name: nil ];
494501 MPSGraphTensor* negInfFreeTensor = [mpsGraph selectWithPredicateTensor: isNegInfTensor
495502 truePredicateTensor: newCachedGraph->negInfReplacementTensor
You can’t perform that action at this time.
0 commit comments