Skip to content

Commit 06323f6

Browse files
committed
Ensure we check the right types and handle some edge cases
1 parent f45c339 commit 06323f6

File tree

3 files changed

+40
-45
lines changed

3 files changed

+40
-45
lines changed

src/coreclr/jit/hwintrinsicxarch.cpp

+1
Original file line number | Diff line number | Diff line change
@@ -882,6 +882,7 @@ NamedIntrinsic HWIntrinsicInfo::lookupEvexMaskId(NamedIntrinsic intrinsic)
882882
case NI_SSE2_CompareLessThan:
883883
case NI_SSE42_CompareLessThan:
884884
case NI_AVX_CompareLessThan:
885+
case NI_AVX2_CompareLessThan:
885886
{
886887
return NI_EVEX_CompareLessThanMask;
887888
}

src/coreclr/jit/morph.cpp

+15 -25
Original file line number | Diff line number | Diff line change
@@ -9939,15 +9939,19 @@ GenTree* Compiler::fgOptimizeHWIntrinsic(GenTreeHWIntrinsic* node)
99399939

99409940
GenTreeHWIntrinsic* op1Intrin = op1->AsHWIntrinsic();
99419941

9942-
if (!isCndSel)
9942+
if (isCndSel)
99439943
{
9944-
// CndSel knows how to handle mismatched mask sizes, but not all consumers can
9945-
9946-
if (genTypeSize(op1Intrin->GetSimdBaseType()) != genTypeSize(simdBaseType))
9944+
if (op1->OperIsConvertMaskToVector())
99479945
{
9946+
// We're already in the "correct" shape for lowering to take advantage
99489947
break;
99499948
}
99509949
}
9950+
else if (genTypeSize(op1Intrin->GetSimdBaseType()) != genTypeSize(simdBaseType))
9951+
{
9952+
// CndSel knows how to handle mismatched mask sizes, but not all consumers can
9953+
break;
9954+
}
99519955

99529956
if (!canUseEvexEncoding())
99539957
{
@@ -9956,8 +9960,7 @@ GenTree* Compiler::fgOptimizeHWIntrinsic(GenTreeHWIntrinsic* node)
99569960

99579961
// We have something expecting a mask and have a case where we could be producing a mask directly
99589962

9959-
NamedIntrinsic op1IntrinId = op1Intrin->GetHWIntrinsicId();
9960-
9963+
NamedIntrinsic op1IntrinId = op1Intrin->GetHWIntrinsicId();
99619964
NamedIntrinsic evexIntrinId = HWIntrinsicInfo::lookupEvexMaskId(op1IntrinId);
99629965

99639966
if (evexIntrinId != NI_Illegal)
@@ -9967,41 +9970,29 @@ GenTree* Compiler::fgOptimizeHWIntrinsic(GenTreeHWIntrinsic* node)
99679970
op1Intrin->ChangeHWIntrinsicId(evexIntrinId);
99689971
op1Intrin->gtType = TYP_MASK;
99699972

9970-
#ifdef DEBUG
9971-
// We want to remorph the nodes introduced below, so clear the flag
9972-
9973-
auto resetMorphedFlag = [](GenTree** slot, fgWalkData* data) -> fgWalkResult {
9974-
(*slot)->gtDebugFlags &= ~GTF_DEBUG_NODE_MORPHED;
9975-
return WALK_CONTINUE;
9976-
};
9977-
9978-
fgWalkTreePost(&op1, resetMorphedFlag);
9979-
#endif // DEBUG
9980-
9981-
switch (op1IntrinId)
9973+
switch (evexIntrinId)
99829974
{
99839975
case NI_EVEX_AndMask:
99849976
case NI_EVEX_AndNotMask:
99859977
case NI_EVEX_OrMask:
99869978
case NI_EVEX_XorMask:
99879979
{
99889980
// There's a few special nodes which are allowed to combine masks
9989-
// and so we handle these by inserting a CvtVectorToMask over each
9990-
// operand and remorphing, which will get us the optimized sequence
9981+
// and so we handle these explicitly by introducing the vector to
9982+
// mask conversions required and folding, which will get us to an
9983+
// optimized sequence where relevant
99919984

99929985
cvtNode = op1Intrin->Op(1);
99939986
cvtNode = gtNewSimdCvtVectorToMaskNode(TYP_MASK, cvtNode, simdBaseJitType, simdSize);
9994-
cvtNode = fgMorphHWIntrinsic(cvtNode->AsHWIntrinsic());
9987+
cvtNode = gtFoldExpr(cvtNode);
99959988

99969989
op1Intrin->Op(1) = cvtNode;
99979990

99989991
cvtNode = op1Intrin->Op(2);
99999992
cvtNode = gtNewSimdCvtVectorToMaskNode(TYP_MASK, cvtNode, simdBaseJitType, simdSize);
10000-
cvtNode = fgMorphHWIntrinsic(cvtNode->AsHWIntrinsic());
9993+
cvtNode = gtFoldExpr(cvtNode);
100019994

100029995
op1Intrin->Op(2) = cvtNode;
10003-
10004-
op1 = fgMorphHWIntrinsic(op1Intrin);
100059996
break;
100069997
}
100079998

@@ -10015,7 +10006,6 @@ GenTree* Compiler::fgOptimizeHWIntrinsic(GenTreeHWIntrinsic* node)
1001510006
{
1001610007
// This will allow lowering to emit a vblendm and potentially do embedded masking
1001710008
cvtNode = gtNewSimdCvtMaskToVectorNode(retType, op1, simdBaseJitType, simdSize);
10018-
cvtNode = fgMorphHWIntrinsic(cvtNode->AsHWIntrinsic());
1001910009

1002010010
node->Op(1) = cvtNode;
1002110011
return node;

src/coreclr/jit/valuenum.cpp

+24 -20
Original file line number | Diff line number | Diff line change
@@ -8268,7 +8268,9 @@ ValueNum ValueNumStore::EvalHWIntrinsicFunBinary(GenTreeHWIntrinsic* tree,
82688268
if (varTypeIsFloating(baseType))
82698269
{
82708270
// Handle `(x == NaN) == false` and `(NaN == x) == false` for floating-point types
8271-
if (VNIsVectorNaN(type, baseType, cnsVN))
8271+
var_types simdType = Compiler::getSIMDTypeForSize(simdSize);
8272+
8273+
if (VNIsVectorNaN(simdType, baseType, cnsVN))
82728274
{
82738275
return VNZeroForType(type);
82748276
}
@@ -8278,91 +8280,91 @@ ValueNum ValueNumStore::EvalHWIntrinsicFunBinary(GenTreeHWIntrinsic* tree,
82788280

82798281
case GT_GT:
82808282
{
8281-
ValueNum zeroVN = VNZeroForType(type);
8283+
var_types simdType = Compiler::getSIMDTypeForSize(simdSize);
82828284

82838285
if (varTypeIsUnsigned(baseType))
82848286
{
82858287
// Handle `(0 > x) == false` for unsigned types.
8286-
if ((cnsVN == arg0VN) && (cnsVN == zeroVN))
8288+
if ((cnsVN == arg0VN) && (cnsVN == VNZeroForType(simdType)))
82878289
{
8288-
return zeroVN;
8290+
return VNZeroForType(type);
82898291
}
82908292
}
82918293
else if (varTypeIsFloating(baseType))
82928294
{
82938295
// Handle `(x > NaN) == false` and `(NaN > x) == false` for floating-point types
8294-
if (VNIsVectorNaN(type, baseType, cnsVN))
8296+
if (VNIsVectorNaN(simdType, baseType, cnsVN))
82958297
{
8296-
return zeroVN;
8298+
return VNZeroForType(type);
82978299
}
82988300
}
82998301
break;
83008302
}
83018303

83028304
case GT_GE:
83038305
{
8304-
ValueNum zeroVN = VNZeroForType(type);
8306+
var_types simdType = Compiler::getSIMDTypeForSize(simdSize);
83058307

83068308
if (varTypeIsUnsigned(baseType))
83078309
{
83088310
// Handle `x >= 0 == true` for unsigned types.
8309-
if ((cnsVN == arg1VN) && (cnsVN == zeroVN))
8311+
if ((cnsVN == arg1VN) && (cnsVN == VNZeroForType(simdType)))
83108312
{
83118313
return VNAllBitsForType(type);
83128314
}
83138315
}
83148316
else if (varTypeIsFloating(baseType))
83158317
{
83168318
// Handle `(x >= NaN) == false` and `(NaN >= x) == false` for floating-point types
8317-
if (VNIsVectorNaN(type, baseType, cnsVN))
8319+
if (VNIsVectorNaN(simdType, baseType, cnsVN))
83188320
{
8319-
return zeroVN;
8321+
return VNZeroForType(type);
83208322
}
83218323
}
83228324
break;
83238325
}
83248326

83258327
case GT_LT:
83268328
{
8327-
ValueNum zeroVN = VNZeroForType(type);
8329+
var_types simdType = Compiler::getSIMDTypeForSize(simdSize);
83288330

83298331
if (varTypeIsUnsigned(baseType))
83308332
{
83318333
// Handle `x < 0 == false` for unsigned types.
8332-
if ((cnsVN == arg1VN) && (cnsVN == zeroVN))
8334+
if ((cnsVN == arg1VN) && (cnsVN == VNZeroForType(simdType)))
83338335
{
8334-
return zeroVN;
8336+
return VNZeroForType(type);
83358337
}
83368338
}
83378339
else if (varTypeIsFloating(baseType))
83388340
{
83398341
// Handle `(x < NaN) == false` and `(NaN < x) == false` for floating-point types
8340-
if (VNIsVectorNaN(type, baseType, cnsVN))
8342+
if (VNIsVectorNaN(simdType, baseType, cnsVN))
83418343
{
8342-
return zeroVN;
8344+
return VNZeroForType(type);
83438345
}
83448346
}
83458347
break;
83468348
}
83478349

83488350
case GT_LE:
83498351
{
8350-
ValueNum zeroVN = VNZeroForType(type);
8352+
var_types simdType = Compiler::getSIMDTypeForSize(simdSize);
83518353

83528354
if (varTypeIsUnsigned(baseType))
83538355
{
83548356
// Handle `0 <= x == true` for unsigned types.
8355-
if ((cnsVN == arg0VN) && (cnsVN == zeroVN))
8357+
if ((cnsVN == arg0VN) && (cnsVN == VNZeroForType(simdType)))
83568358
{
83578359
return VNAllBitsForType(type);
83588360
}
83598361
}
83608362
else if (varTypeIsFloating(baseType))
83618363
{
83628364
// Handle `(x <= NaN) == false` and `(NaN <= x) == false` for floating-point types
8363-
if (VNIsVectorNaN(type, baseType, cnsVN))
8365+
if (VNIsVectorNaN(simdType, baseType, cnsVN))
83648366
{
8365-
return zeroVN;
8367+
return VNZeroForType(type);
83668368
}
83678369
}
83688370
break;
@@ -8417,10 +8419,12 @@ ValueNum ValueNumStore::EvalHWIntrinsicFunBinary(GenTreeHWIntrinsic* tree,
84178419

84188420
case GT_NE:
84198421
{
8422+
var_types simdType = Compiler::getSIMDTypeForSize(simdSize);
8423+
84208424
if (varTypeIsFloating(baseType))
84218425
{
84228426
// Handle `(x != NaN) == true` and `(NaN != x) == true` for floating-point types
8423-
if (VNIsVectorNaN(type, baseType, cnsVN))
8427+
if (VNIsVectorNaN(simdType, baseType, cnsVN))
84248428
{
84258429
return VNAllBitsForType(type);
84268430
}

0 commit comments

Comments
 (0)