Skip to content

Commit

Permalink
Fold const WithElement to CNS_VEC (#86212)
Browse files Browse the repository at this point in the history
Co-authored-by: Egor Bogatov <egorbo@gmail.com>
  • Loading branch information
jasper-d and EgorBo authored Jun 13, 2023
1 parent c14f0c7 commit a052348
Show file tree
Hide file tree
Showing 2 changed files with 158 additions and 17 deletions.
165 changes: 148 additions & 17 deletions src/coreclr/jit/valuenum.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2648,7 +2648,8 @@ ValueNum ValueNumStore::VNForFunc(var_types typ, VNFunc func, ValueNum arg0VN, V
assert(arg0VN != NoVN);
assert(arg1VN != NoVN);
assert(arg2VN != NoVN);
assert(VNFuncArity(func) == 3);
// Some SIMD functions with variable number of arguments are defined with zero arity
assert((VNFuncArity(func) == 0) || (VNFuncArity(func) == 3));

#ifdef DEBUG
// Function arguments carry no exceptions.
Expand All @@ -2664,7 +2665,6 @@ ValueNum ValueNumStore::VNForFunc(var_types typ, VNFunc func, ValueNum arg0VN, V
}
assert(arg2VN == VNNormalValue(arg2VN));
#endif
assert(VNFuncArity(func) == 3);

ValueNum resultVN;

Expand Down Expand Up @@ -7813,7 +7813,109 @@ ValueNum ValueNumStore::EvalHWIntrinsicFunBinary(var_types type,
}
return VNForFunc(type, func, arg0VN, arg1VN);
}

ValueNum EvaluateSimdFloatWithElement(ValueNumStore* vns, var_types type, ValueNum arg0VN, int index, float value)
{
assert(vns->IsVNConstant(arg0VN));
assert(static_cast<unsigned>(index) < genTypeSize(type) / genTypeSize(TYP_FLOAT));

switch (type)
{
case TYP_SIMD8:
{
simd8_t cnsVec = vns->GetConstantSimd8(arg0VN);
cnsVec.f32[index] = value;
return vns->VNForSimd8Con(cnsVec);
}
case TYP_SIMD12:
{
simd12_t cnsVec = vns->GetConstantSimd12(arg0VN);
cnsVec.f32[index] = value;
return vns->VNForSimd12Con(cnsVec);
}
case TYP_SIMD16:
{
simd16_t cnsVec = vns->GetConstantSimd16(arg0VN);
cnsVec.f32[index] = value;
return vns->VNForSimd16Con(cnsVec);
}
#if defined TARGET_XARCH
case TYP_SIMD32:
{
simd32_t cnsVec = vns->GetConstantSimd32(arg0VN);
cnsVec.f32[index] = value;
return vns->VNForSimd32Con(cnsVec);
}
case TYP_SIMD64:
{
simd64_t cnsVec = vns->GetConstantSimd64(arg0VN);
cnsVec.f32[index] = value;
return vns->VNForSimd64Con(cnsVec);
}
#endif // TARGET_XARCH
default:
{
unreached();
}
}
}

ValueNum ValueNumStore::EvalHWIntrinsicFunTernary(var_types type,
var_types baseType,
NamedIntrinsic ni,
VNFunc func,
ValueNum arg0VN,
ValueNum arg1VN,
ValueNum arg2VN,
bool encodeResultType,
ValueNum resultTypeVN)
{
if (IsVNConstant(arg0VN) && IsVNConstant(arg1VN) && IsVNConstant(arg2VN))
{

switch (ni)
{
case NI_Vector128_WithElement:
#ifdef TARGET_ARM64
case NI_Vector64_WithElement:
#else
case NI_Vector256_WithElement:
case NI_Vector512_WithElement:
#endif
{
int index = GetConstantInt32(arg1VN);

assert(varTypeIsSIMD(type));

// No meaningful diffs for other base-types.
if ((baseType != TYP_FLOAT) || (TypeOfVN(arg0VN) != type) ||
(static_cast<unsigned>(index) >= (genTypeSize(type) / genTypeSize(baseType))))
{
break;
}

float value = GetConstantSingle(arg2VN);

return EvaluateSimdFloatWithElement(this, type, arg0VN, index, value);
}
default:
{
break;
}
}
}

if (encodeResultType)
{
return VNForFunc(type, func, arg0VN, arg1VN, arg2VN, resultTypeVN);
}
else
{
return VNForFunc(type, func, arg0VN, arg1VN, arg2VN);
}
}

#endif // FEATURE_HW_INTRINSICS

ValueNum ValueNumStore::EvalMathFuncUnary(var_types typ, NamedIntrinsic gtMathFN, ValueNum arg0VN)
{
Expand Down Expand Up @@ -11475,9 +11577,11 @@ void Compiler::fgValueNumberHWIntrinsic(GenTreeHWIntrinsic* tree)
ValueNumPair excSetPair = ValueNumStore::VNPForEmptyExcSet();
ValueNumPair normalPair = ValueNumPair();

if ((tree->GetOperandCount() > 2) || ((JitConfig.JitDisableSimdVN() & 2) == 2))
const size_t opCount = tree->GetOperandCount();

if ((opCount > 3) || (JitConfig.JitDisableSimdVN() & 2) == 2)
{
// TODO-CQ: allow intrinsics with > 2 operands to be properly VN'ed.
// TODO-CQ: allow intrinsics with > 3 operands to be properly VN'ed.
normalPair = vnStore->VNPairForExpr(compCurBB, tree->TypeGet());

for (GenTree* operand : tree->Operands())
Expand Down Expand Up @@ -11525,7 +11629,7 @@ void Compiler::fgValueNumberHWIntrinsic(GenTreeHWIntrinsic* tree)
const bool isVariableNumArgs = HWIntrinsicInfo::lookupNumArgs(intrinsicId) == -1;

// There are some HWINTRINSICS operations that have zero args, i.e. NI_Vector128_Zero
if (tree->GetOperandCount() == 0)
if (opCount == 0)
{
// Currently we don't have intrinsics with variable number of args with a parameter-less option.
assert(!isVariableNumArgs);
Expand All @@ -11542,13 +11646,13 @@ void Compiler::fgValueNumberHWIntrinsic(GenTreeHWIntrinsic* tree)
assert(vnStore->VNFuncArity(func) == 0);
}
}
else // HWINTRINSIC unary or binary operator.
else // HWINTRINSIC unary or binary or ternary operator.
{
ValueNumPair op1vnp;
ValueNumPair op1Xvnp;
getOperandVNs(tree->Op(1), &op1vnp, &op1Xvnp);

if (tree->GetOperandCount() == 1)
if (opCount == 1)
{
ValueNum normalLVN = vnStore->EvalHWIntrinsicFunUnary(tree->TypeGet(), tree->GetSimdBaseType(),
intrinsicId, func, op1vnp.GetLiberal(),
Expand All @@ -11567,17 +11671,44 @@ void Compiler::fgValueNumberHWIntrinsic(GenTreeHWIntrinsic* tree)
ValueNumPair op2Xvnp;
getOperandVNs(tree->Op(2), &op2vnp, &op2Xvnp);

ValueNum normalLVN =
vnStore->EvalHWIntrinsicFunBinary(tree->TypeGet(), tree->GetSimdBaseType(), intrinsicId, func,
op1vnp.GetLiberal(), op2vnp.GetLiberal(), encodeResultType,
resultTypeVNPair.GetLiberal());
ValueNum normalCVN =
vnStore->EvalHWIntrinsicFunBinary(tree->TypeGet(), tree->GetSimdBaseType(), intrinsicId, func,
op1vnp.GetConservative(), op2vnp.GetConservative(),
encodeResultType, resultTypeVNPair.GetConservative());
if (opCount == 2)
{
ValueNum normalLVN =
vnStore->EvalHWIntrinsicFunBinary(tree->TypeGet(), tree->GetSimdBaseType(), intrinsicId, func,
op1vnp.GetLiberal(), op2vnp.GetLiberal(), encodeResultType,
resultTypeVNPair.GetLiberal());
ValueNum normalCVN =
vnStore->EvalHWIntrinsicFunBinary(tree->TypeGet(), tree->GetSimdBaseType(), intrinsicId, func,
op1vnp.GetConservative(), op2vnp.GetConservative(),
encodeResultType, resultTypeVNPair.GetConservative());

normalPair = ValueNumPair(normalLVN, normalCVN);
excSetPair = vnStore->VNPExcSetUnion(op1Xvnp, op2Xvnp);
normalPair = ValueNumPair(normalLVN, normalCVN);
excSetPair = vnStore->VNPExcSetUnion(op1Xvnp, op2Xvnp);
}
else
{
assert(opCount == 3);

ValueNumPair op3vnp;
ValueNumPair op3Xvnp;
getOperandVNs(tree->Op(3), &op3vnp, &op3Xvnp);

ValueNum normalLVN =
vnStore->EvalHWIntrinsicFunTernary(tree->TypeGet(), tree->GetSimdBaseType(), intrinsicId, func,
op1vnp.GetLiberal(), op2vnp.GetLiberal(),
op3vnp.GetLiberal(), encodeResultType,
resultTypeVNPair.GetLiberal());
ValueNum normalCVN =
vnStore->EvalHWIntrinsicFunTernary(tree->TypeGet(), tree->GetSimdBaseType(), intrinsicId, func,
op1vnp.GetConservative(), op2vnp.GetConservative(),
op3vnp.GetConservative(), encodeResultType,
resultTypeVNPair.GetConservative());

normalPair = ValueNumPair(normalLVN, normalCVN);

excSetPair = vnStore->VNPExcSetUnion(op1Xvnp, op2Xvnp);
excSetPair = vnStore->VNPExcSetUnion(excSetPair, op3Xvnp);
}
}
}
}
Expand Down
10 changes: 10 additions & 0 deletions src/coreclr/jit/valuenum.h
Original file line number Diff line number Diff line change
Expand Up @@ -1184,6 +1184,16 @@ class ValueNumStore
bool encodeResultType,
ValueNum resultTypeVN);

ValueNum EvalHWIntrinsicFunTernary(var_types type,
var_types baseType,
NamedIntrinsic ni,
VNFunc func,
ValueNum arg0VN,
ValueNum arg1VN,
ValueNum arg2VN,
bool encodeResultType,
ValueNum resultTypeVN);

// Returns "true" iff "vn" represents a function application.
bool IsVNFunc(ValueNum vn);

Expand Down

0 comments on commit a052348

Please sign in to comment.