Ensure we check the right types and handle some edge cases
tannergooding committed Jul 16, 2024
1 parent f45c339 commit 06323f6
Showing 3 changed files with 40 additions and 45 deletions.
1 change: 1 addition & 0 deletions src/coreclr/jit/hwintrinsicxarch.cpp
@@ -882,6 +882,7 @@ NamedIntrinsic HWIntrinsicInfo::lookupEvexMaskId(NamedIntrinsic intrinsic)
         case NI_SSE2_CompareLessThan:
         case NI_SSE42_CompareLessThan:
         case NI_AVX_CompareLessThan:
+        case NI_AVX2_CompareLessThan:
         {
             return NI_EVEX_CompareLessThanMask;
         }
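
Note: lookupEvexMaskId returns NI_Illegal for intrinsics that have no EVEX mask form, which is what the morph.cpp code below checks against. A simplified sketch of the mapping's shape (illustrative, not the full runtime table), showing where the AVX2 case now participates:

    NamedIntrinsic lookupEvexMaskId(NamedIntrinsic intrinsic)
    {
        switch (intrinsic)
        {
            case NI_SSE2_CompareLessThan:
            case NI_SSE42_CompareLessThan:
            case NI_AVX_CompareLessThan:
            case NI_AVX2_CompareLessThan: // previously missing, so AVX2 integer
                                          // compares never mapped to the mask form
                return NI_EVEX_CompareLessThanMask;

            // ... analogous groups for the other comparison kinds ...

            default:
                return NI_Illegal;
        }
    }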
40 changes: 15 additions & 25 deletions src/coreclr/jit/morph.cpp
@@ -9939,15 +9939,19 @@ GenTree* Compiler::fgOptimizeHWIntrinsic(GenTreeHWIntrinsic* node)
 
             GenTreeHWIntrinsic* op1Intrin = op1->AsHWIntrinsic();
 
-            if (!isCndSel)
+            if (isCndSel)
             {
-                // CndSel knows how to handle mismatched mask sizes, but not all consumers can
-
-                if (genTypeSize(op1Intrin->GetSimdBaseType()) != genTypeSize(simdBaseType))
+                if (op1->OperIsConvertMaskToVector())
                 {
+                    // We're already in the "correct" shape for lowering to take advantage
                     break;
                 }
             }
+            else if (genTypeSize(op1Intrin->GetSimdBaseType()) != genTypeSize(simdBaseType))
+            {
+                // CndSel knows how to handle mismatched mask sizes, but not all consumers can
+                break;
+            }
 
             if (!canUseEvexEncoding())
             {
@@ -9956,8 +9960,7 @@ GenTree* Compiler::fgOptimizeHWIntrinsic(GenTreeHWIntrinsic* node)
 
             // We have something expecting a mask and have a case where we could be producing a mask directly
 
-            NamedIntrinsic op1IntrinId = op1Intrin->GetHWIntrinsicId();
-
+            NamedIntrinsic op1IntrinId  = op1Intrin->GetHWIntrinsicId();
             NamedIntrinsic evexIntrinId = HWIntrinsicInfo::lookupEvexMaskId(op1IntrinId);
 
             if (evexIntrinId != NI_Illegal)
@@ -9967,41 +9970,29 @@ GenTree* Compiler::fgOptimizeHWIntrinsic(GenTreeHWIntrinsic* node)
                 op1Intrin->ChangeHWIntrinsicId(evexIntrinId);
                 op1Intrin->gtType = TYP_MASK;
 
-#ifdef DEBUG
-                // We want to remorph the nodes introduced below, so clear the flag
-
-                auto resetMorphedFlag = [](GenTree** slot, fgWalkData* data) -> fgWalkResult {
-                    (*slot)->gtDebugFlags &= ~GTF_DEBUG_NODE_MORPHED;
-                    return WALK_CONTINUE;
-                };
-
-                fgWalkTreePost(&op1, resetMorphedFlag);
-#endif // DEBUG
-
-                switch (op1IntrinId)
+                switch (evexIntrinId)
                 {
                     case NI_EVEX_AndMask:
                     case NI_EVEX_AndNotMask:
                     case NI_EVEX_OrMask:
                     case NI_EVEX_XorMask:
                     {
                         // There's a few special nodes which are allowed to combine masks
-                        // and so we handle these by inserting a CvtVectorToMask over each
-                        // operand and remorphing, which will get us the optimized sequence
+                        // and so we handle these explicitly by introducing the vector to
+                        // mask conversions required and folding, which will get us to an
+                        // optimized sequence where relevant
 
                         cvtNode = op1Intrin->Op(1);
                         cvtNode = gtNewSimdCvtVectorToMaskNode(TYP_MASK, cvtNode, simdBaseJitType, simdSize);
-                        cvtNode = fgMorphHWIntrinsic(cvtNode->AsHWIntrinsic());
+                        cvtNode = gtFoldExpr(cvtNode);
 
                         op1Intrin->Op(1) = cvtNode;
 
                         cvtNode = op1Intrin->Op(2);
                         cvtNode = gtNewSimdCvtVectorToMaskNode(TYP_MASK, cvtNode, simdBaseJitType, simdSize);
-                        cvtNode = fgMorphHWIntrinsic(cvtNode->AsHWIntrinsic());
+                        cvtNode = gtFoldExpr(cvtNode);
 
                         op1Intrin->Op(2) = cvtNode;
-
-                        op1 = fgMorphHWIntrinsic(op1Intrin);
                         break;
                     }
 
@@ -10015,7 +10006,6 @@ GenTree* Compiler::fgOptimizeHWIntrinsic(GenTreeHWIntrinsic* node)
             {
                 // This will allow lowering to emit a vblendm and potentially do embedded masking
                 cvtNode = gtNewSimdCvtMaskToVectorNode(retType, op1, simdBaseJitType, simdSize);
-                cvtNode = fgMorphHWIntrinsic(cvtNode->AsHWIntrinsic());
 
                 node->Op(1) = cvtNode;
                 return node;
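
With the intrinsic rewritten to its EVEX mask form, the AndMask/AndNotMask/OrMask/XorMask cases combine masks directly in mask registers, and CndSel consumes the result through a single conversion that lowering can turn into vblendm. A rough sketch of the machine-level shape this enables, written with AVX-512 intrinsics (assumes AVX-512F; illustrative, not the JIT's own code, and select_in_range is a made-up example function):

    #include <immintrin.h>

    // Both compares produce k-masks, the combine stays in the k registers
    // (kandw), and the blend consumes the mask directly (vpblendmd), with
    // no mask -> vector -> mask round trip.
    __m512i select_in_range(__m512i v, __m512i lo, __m512i hi, __m512i x, __m512i y)
    {
        __mmask16 ge = _mm512_cmpge_epi32_mask(v, lo); // vpcmpd -> k1
        __mmask16 le = _mm512_cmple_epi32_mask(v, hi); // vpcmpd -> k2
        __mmask16 k  = ge & le;                        // kandw k1, k2
        return _mm512_mask_blend_epi32(k, y, x);       // x where k is set, else y
    }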
44 changes: 24 additions & 20 deletions src/coreclr/jit/valuenum.cpp
@@ -8268,7 +8268,9 @@ ValueNum ValueNumStore::EvalHWIntrinsicFunBinary(GenTreeHWIntrinsic* tree,
                         if (varTypeIsFloating(baseType))
                         {
                             // Handle `(x == NaN) == false` and `(NaN == x) == false` for floating-point types
-                            if (VNIsVectorNaN(type, baseType, cnsVN))
+                            var_types simdType = Compiler::getSIMDTypeForSize(simdSize);
+
+                            if (VNIsVectorNaN(simdType, baseType, cnsVN))
                             {
                                 return VNZeroForType(type);
                             }
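
The switch to simdType here is presumably because, under EVEX, the comparison node itself can be typed TYP_MASK (the morph.cpp change above retypes op1Intrin to TYP_MASK), while the constant operand being inspected is still a vector of simdSize bytes; the NaN check therefore needs the SIMD type implied by simdSize rather than the node's own type. A simplified sketch of the size-to-type mapping assumed from Compiler::getSIMDTypeForSize:

    var_types getSIMDTypeForSize(unsigned simdSize)
    {
        switch (simdSize)
        {
            case 8:  return TYP_SIMD8;
            case 12: return TYP_SIMD12;
            case 16: return TYP_SIMD16;
            case 32: return TYP_SIMD32;
            case 64: return TYP_SIMD64;
            default: unreached();
        }
    }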
@@ -8278,91 +8280,91 @@ ValueNum ValueNumStore::EvalHWIntrinsicFunBinary(GenTreeHWIntrinsic* tree,
 
                     case GT_GT:
                     {
-                        ValueNum zeroVN = VNZeroForType(type);
+                        var_types simdType = Compiler::getSIMDTypeForSize(simdSize);
 
                         if (varTypeIsUnsigned(baseType))
                         {
                             // Handle `(0 > x) == false` for unsigned types.
-                            if ((cnsVN == arg0VN) && (cnsVN == zeroVN))
+                            if ((cnsVN == arg0VN) && (cnsVN == VNZeroForType(simdType)))
                             {
-                                return zeroVN;
+                                return VNZeroForType(type);
                             }
                         }
                         else if (varTypeIsFloating(baseType))
                         {
                             // Handle `(x > NaN) == false` and `(NaN > x) == false` for floating-point types
-                            if (VNIsVectorNaN(type, baseType, cnsVN))
+                            if (VNIsVectorNaN(simdType, baseType, cnsVN))
                             {
-                                return zeroVN;
+                                return VNZeroForType(type);
                             }
                         }
                         break;
                     }
 
                     case GT_GE:
                     {
-                        ValueNum zeroVN = VNZeroForType(type);
+                        var_types simdType = Compiler::getSIMDTypeForSize(simdSize);
 
                         if (varTypeIsUnsigned(baseType))
                         {
                             // Handle `x >= 0 == true` for unsigned types.
-                            if ((cnsVN == arg1VN) && (cnsVN == zeroVN))
+                            if ((cnsVN == arg1VN) && (cnsVN == VNZeroForType(simdType)))
                             {
                                 return VNAllBitsForType(type);
                             }
                         }
                         else if (varTypeIsFloating(baseType))
                         {
                             // Handle `(x >= NaN) == false` and `(NaN >= x) == false` for floating-point types
-                            if (VNIsVectorNaN(type, baseType, cnsVN))
+                            if (VNIsVectorNaN(simdType, baseType, cnsVN))
                             {
-                                return zeroVN;
+                                return VNZeroForType(type);
                             }
                         }
                         break;
                     }
 
                     case GT_LT:
                     {
-                        ValueNum zeroVN = VNZeroForType(type);
+                        var_types simdType = Compiler::getSIMDTypeForSize(simdSize);
 
                         if (varTypeIsUnsigned(baseType))
                         {
                             // Handle `x < 0 == false` for unsigned types.
-                            if ((cnsVN == arg1VN) && (cnsVN == zeroVN))
+                            if ((cnsVN == arg1VN) && (cnsVN == VNZeroForType(simdType)))
                             {
-                                return zeroVN;
+                                return VNZeroForType(type);
                             }
                         }
                         else if (varTypeIsFloating(baseType))
                         {
                             // Handle `(x < NaN) == false` and `(NaN < x) == false` for floating-point types
-                            if (VNIsVectorNaN(type, baseType, cnsVN))
+                            if (VNIsVectorNaN(simdType, baseType, cnsVN))
                             {
-                                return zeroVN;
+                                return VNZeroForType(type);
                             }
                         }
                         break;
                     }
 
                     case GT_LE:
                     {
-                        ValueNum zeroVN = VNZeroForType(type);
+                        var_types simdType = Compiler::getSIMDTypeForSize(simdSize);
 
                         if (varTypeIsUnsigned(baseType))
                         {
                             // Handle `0 <= x == true` for unsigned types.
-                            if ((cnsVN == arg0VN) && (cnsVN == zeroVN))
+                            if ((cnsVN == arg0VN) && (cnsVN == VNZeroForType(simdType)))
                             {
                                 return VNAllBitsForType(type);
                             }
                         }
                         else if (varTypeIsFloating(baseType))
                         {
                             // Handle `(x <= NaN) == false` and `(NaN <= x) == false` for floating-point types
-                            if (VNIsVectorNaN(type, baseType, cnsVN))
+                            if (VNIsVectorNaN(simdType, baseType, cnsVN))
                             {
-                                return zeroVN;
+                                return VNZeroForType(type);
                             }
                         }
                         break;
@@ -8417,10 +8419,12 @@ ValueNum ValueNumStore::EvalHWIntrinsicFunBinary(GenTreeHWIntrinsic* tree,
 
                     case GT_NE:
                     {
+                        var_types simdType = Compiler::getSIMDTypeForSize(simdSize);
+
                         if (varTypeIsFloating(baseType))
                         {
                             // Handle `(x != NaN) == true` and `(NaN != x) == true` for floating-point types
-                            if (VNIsVectorNaN(type, baseType, cnsVN))
+                            if (VNIsVectorNaN(simdType, baseType, cnsVN))
                             {
                                 return VNAllBitsForType(type);
                             }
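
All of these folds rest on standard IEEE 754 comparison semantics applied lanewise: any ordered comparison against NaN is false, and != against NaN is true. A small standalone illustration of the scalar identities involved (a sketch, not runtime code):

    #include <cstdio>
    #include <limits>

    int main()
    {
        double nan = std::numeric_limits<double>::quiet_NaN();
        double x   = 1.0;

        std::printf("x == NaN: %d\n", x == nan); // 0 -> vector compare folds to zero
        std::printf("x >= NaN: %d\n", x >= nan); // 0
        std::printf("x <  NaN: %d\n", x <  nan); // 0
        std::printf("x != NaN: %d\n", x != nan); // 1 -> vector compare folds to all-bits-set
        return 0;
    }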