Lower GetElement on arm64 to the correct access sequence #104288

Merged
8 commits merged on Jul 5, 2024
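The change moves the non-constant-index and in-memory handling of Vector64/Vector128 GetElement out of codegen and LSRA and into lowering, where the access is rewritten as an explicit address computation plus an indirection. A minimal sketch of the resulting access pattern, written as ordinary C++ rather than JIT code (the helper name GetElementLowered is made up for illustration):

#include <cstddef>
#include <cstdint>
#include <cstring>

// Emulates the lowered form: element address = base + index * sizeof(element),
// followed by a single load. On arm64 that load can fold into one
// "ldr dst, [base, index, LSL #log2(sizeof(element))]" instruction.
template <typename T>
T GetElementLowered(const uint8_t* vectorInMemory, std::size_t index)
{
    T element;
    std::memcpy(&element, vectorInMemory + index * sizeof(T), sizeof(T));
    return element;
}

If the vector value is still in a register, lowering first spills it to the reusable SIMD temp local and takes that local's address, as the lowerarmarch.cpp changes below show.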
105 changes: 11 additions & 94 deletions src/coreclr/jit/hwintrinsiccodegenarm64.cpp
@@ -1649,6 +1649,10 @@ void CodeGen::genHWIntrinsic(GenTreeHWIntrinsic* node)
case NI_Vector128_GetElement:
{
assert(intrin.numOperands == 2);
assert(!intrin.op1->isContained());

assert(intrin.op2->OperIsConst());
assert(intrin.op2->isContained());

var_types simdType = Compiler::getSIMDTypeForSize(node->GetSimdSize());

@@ -1658,109 +1662,22 @@ void CodeGen::genHWIntrinsic(GenTreeHWIntrinsic* node)
simdType = TYP_SIMD16;
}

if (!intrin.op2->OperIsConst())
{
assert(!intrin.op2->isContained());

emitAttr baseTypeSize = emitTypeSize(intrin.baseType);
unsigned baseTypeScale = genLog2(EA_SIZE_IN_BYTES(baseTypeSize));

regNumber baseReg;
regNumber indexReg = op2Reg;

// Optimize the case of op1 is in memory and trying to access i'th element.
if (!intrin.op1->isUsedFromReg())
{
assert(intrin.op1->isContained());

if (intrin.op1->OperIsLocal())
{
unsigned varNum = intrin.op1->AsLclVarCommon()->GetLclNum();
baseReg = internalRegisters.Extract(node);

// Load the address of varNum
GetEmitter()->emitIns_R_S(INS_lea, EA_PTRSIZE, baseReg, varNum, 0);
}
else
{
// Require GT_IND addr to be not contained.
assert(intrin.op1->OperIs(GT_IND));

GenTree* addr = intrin.op1->AsIndir()->Addr();
assert(!addr->isContained());
baseReg = addr->GetRegNum();
}
}
else
{
unsigned simdInitTempVarNum = compiler->lvaSIMDInitTempVarNum;
noway_assert(simdInitTempVarNum != BAD_VAR_NUM);

baseReg = internalRegisters.Extract(node);

// Load the address of simdInitTempVarNum
GetEmitter()->emitIns_R_S(INS_lea, EA_PTRSIZE, baseReg, simdInitTempVarNum, 0);

// Store the vector to simdInitTempVarNum
GetEmitter()->emitIns_R_R(INS_str, emitTypeSize(simdType), op1Reg, baseReg);
}

assert(genIsValidIntReg(indexReg));
assert(genIsValidIntReg(baseReg));
assert(baseReg != indexReg);
ssize_t ival = intrin.op2->AsIntCon()->IconValue();

// Load item at baseReg[index]
GetEmitter()->emitIns_R_R_R_Ext(ins_Load(intrin.baseType), baseTypeSize, targetReg, baseReg,
indexReg, INS_OPTS_LSL, baseTypeScale);
}
else if (!GetEmitter()->isValidVectorIndex(emitTypeSize(simdType), emitTypeSize(intrin.baseType),
intrin.op2->AsIntCon()->IconValue()))
if (!GetEmitter()->isValidVectorIndex(emitTypeSize(simdType), emitTypeSize(intrin.baseType), ival))
{
// We only need to generate code for the get if the index is valid
// If the index is invalid, previously generated for the range check will throw
break;
}
else if (!intrin.op1->isUsedFromReg())
{
assert(intrin.op1->isContained());
assert(intrin.op2->IsCnsIntOrI());

int offset = (int)intrin.op2->AsIntCon()->IconValue() * genTypeSize(intrin.baseType);
instruction ins = ins_Load(intrin.baseType);

assert(!intrin.op1->isUsedFromReg());

if (intrin.op1->OperIsLocal())
{
unsigned varNum = intrin.op1->AsLclVarCommon()->GetLclNum();
GetEmitter()->emitIns_R_S(ins, emitActualTypeSize(intrin.baseType), targetReg, varNum, offset);
}
else
{
assert(intrin.op1->OperIs(GT_IND));

GenTree* addr = intrin.op1->AsIndir()->Addr();
assert(!addr->isContained());
regNumber baseReg = addr->GetRegNum();

// ldr targetReg, [baseReg, #offset]
GetEmitter()->emitIns_R_R_I(ins, emitActualTypeSize(intrin.baseType), targetReg, baseReg,
offset);
}
}
else
if ((varTypeIsFloating(intrin.baseType) && (targetReg == op1Reg) && (ival == 0)))
{
assert(intrin.op2->IsCnsIntOrI());
ssize_t indexValue = intrin.op2->AsIntCon()->IconValue();

// no-op if vector is float/double, targetReg == op1Reg and fetching for 0th index.
if ((varTypeIsFloating(intrin.baseType) && (targetReg == op1Reg) && (indexValue == 0)))
{
break;
}

GetEmitter()->emitIns_R_R_I(ins, emitTypeSize(intrin.baseType), targetReg, op1Reg, indexValue,
INS_OPTS_NONE);
break;
}

GetEmitter()->emitIns_R_R_I(ins, emitTypeSize(intrin.baseType), targetReg, op1Reg, ival, INS_OPTS_NONE);
break;
}

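With the memory and non-constant-index paths deleted above, the codegen case reduces to three outcomes for a contained constant index. A compact restatement in plain C++ (emitElementMove stands in for GetEmitter()->emitIns_R_R_I with the element-move instruction; it is not a real JIT API):

#include <cstdio>

// Stand-in for the single element-move instruction the JIT emits (mov/umov/smov/dup).
static void emitElementMove(int targetReg, int op1Reg, long index)
{
    std::printf("element-move v%d <- v%d[%ld]\n", targetReg, op1Reg, index);
}

// Hedged restatement of the constant-index decision that remains in codegen.
static void genGetElementConstIndex(bool indexIsValid, bool baseTypeIsFloating,
                                    int targetReg, int op1Reg, long index)
{
    if (!indexIsValid)
    {
        return; // the range check emitted earlier throws at run time instead
    }

    if (baseTypeIsFloating && (targetReg == op1Reg) && (index == 0))
    {
        return; // element 0 already occupies the low lanes of targetReg: no-op
    }

    emitElementMove(targetReg, op1Reg, index); // one instruction, no memory traffic
}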
143 changes: 126 additions & 17 deletions src/coreclr/jit/lowerarmarch.cpp
@@ -1254,6 +1254,128 @@ GenTree* Lowering::LowerHWIntrinsic(GenTreeHWIntrinsic* node)
return LowerHWIntrinsicDot(node);
}

case NI_Vector64_GetElement:
case NI_Vector128_GetElement:
{
GenTree* op1 = node->Op(1);
GenTree* op2 = node->Op(2);

bool isContainableMemory = IsContainableMemoryOp(op1) && IsSafeToContainMem(node, op1);

if (isContainableMemory || !op2->OperIsConst())
{
unsigned simdSize = node->GetSimdSize();
var_types simdBaseType = node->GetSimdBaseType();
var_types simdType = Compiler::getSIMDTypeForSize(simdSize);

// We're either already loading from memory or we need to since
// we don't know what actual index is going to be retrieved.

unsigned lclNum = BAD_VAR_NUM;
unsigned lclOffs = 0;

if (!isContainableMemory)
{
// We aren't already in memory, so we need to spill there

comp->getSIMDInitTempVarNum(simdType);
lclNum = comp->lvaSIMDInitTempVarNum;

GenTree* storeLclVar = comp->gtNewStoreLclVarNode(lclNum, op1);
BlockRange().InsertBefore(node, storeLclVar);
LowerNode(storeLclVar);
}
else if (op1->IsLocal())
{
// We're an existing local that is loaded from memory
GenTreeLclVarCommon* lclVar = op1->AsLclVarCommon();

lclNum = lclVar->GetLclNum();
lclOffs = lclVar->GetLclOffs();

BlockRange().Remove(op1);
}

if (lclNum != BAD_VAR_NUM)
{
// We need to get the address of the local
op1 = comp->gtNewLclAddrNode(lclNum, lclOffs, TYP_BYREF);
BlockRange().InsertBefore(node, op1);
LowerNode(op1);
}
else
{
assert(op1->isIndir());

// We need to get the underlying address
GenTree* addr = op1->AsIndir()->Addr();
BlockRange().Remove(op1);
op1 = addr;
}

GenTree* offset = op2;
unsigned baseTypeSize = genTypeSize(simdBaseType);

if (offset->OperIsConst())
{
// We have a constant index, so scale it up directly
GenTreeIntConCommon* index = offset->AsIntCon();
index->SetIconValue(index->IconValue() * baseTypeSize);
}
else
{
// We have a non-constant index, so scale it up via mul but
// don't lower the GT_MUL node since the indir will try to
// create an addressing mode and will do folding itself. We
// do, however, skip the multiply for scale == 1

if (baseTypeSize != 1)
{
GenTreeIntConCommon* scale = comp->gtNewIconNode(baseTypeSize);
BlockRange().InsertBefore(node, scale);

offset = comp->gtNewOperNode(GT_MUL, offset->TypeGet(), offset, scale);
BlockRange().InsertBefore(node, offset);
}
}

// Add the offset, don't lower the GT_ADD node since the indir will
// try to create an addressing mode and will do folding itself. We
// do, however, skip the add for offset == 0
GenTree* addr = op1;

if (!offset->IsIntegralConst(0))
{
addr = comp->gtNewOperNode(GT_ADD, addr->TypeGet(), addr, offset);
BlockRange().InsertBefore(node, addr);
}
else
{
BlockRange().Remove(offset);
}

// Finally we can indirect the memory address to get the actual value
GenTreeIndir* indir = comp->gtNewIndir(simdBaseType, addr);
BlockRange().InsertBefore(node, indir);

LIR::Use use;
if (BlockRange().TryGetUse(node, &use))
{
use.ReplaceWith(indir);
}
else
{
indir->SetUnusedValue();
}

BlockRange().Remove(node);
return LowerNode(indir);
}

assert(op2->OperIsConst());
break;
}

case NI_Vector64_op_Equality:
case NI_Vector128_op_Equality:
{
@@ -3310,24 +3432,11 @@ void Lowering::ContainCheckHWIntrinsic(GenTreeHWIntrinsic* node)
case NI_Vector64_GetElement:
case NI_Vector128_GetElement:
{
assert(varTypeIsIntegral(intrin.op2));
assert(!IsContainableMemoryOp(intrin.op1) || !IsSafeToContainMem(node, intrin.op1));
assert(intrin.op2->OperIsConst());

if (intrin.op2->IsCnsIntOrI())
{
MakeSrcContained(node, intrin.op2);
}

// TODO: Codegen isn't currently handling this correctly
//
// if (IsContainableMemoryOp(intrin.op1) && IsSafeToContainMem(node, intrin.op1))
// {
// MakeSrcContained(node, intrin.op1);
//
// if (intrin.op1->OperIs(GT_IND))
// {
// intrin.op1->AsIndir()->Addr()->ClearContained();
// }
// }
// Loading a constant index from register
MakeSrcContained(node, intrin.op2);
break;
}

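The net effect of the new LowerHWIntrinsic case above is easiest to see on the LIR shape. A rough before/after for a non-constant index, with node names written in the style of a JIT dump purely for illustration (this is not actual dump output):

// Before lowering (vector in a register, index not a constant):
//     t1 = LCL_VAR      simd16  V01             ; the vector
//     t2 = LCL_VAR      int     V02             ; the index
//     t3 = HWINTRINSIC  float   GetElement t1, t2
//
// After lowering (sketch):
//          STORE_LCL_VAR simd16  VTmp, t1        ; spill to the SIMD init temp
//     a  = LCL_ADDR      byref   VTmp            ; address of the spilled vector
//     s  = CNS_INT       int     4               ; sizeof(float); skipped when scale == 1
//     o  = MUL           int     t2, s           ; scaled index (GT_MUL left unlowered)
//     p  = ADD           byref   a, o            ; GT_ADD left unlowered; folds into the indir
//     t3 = IND           float   p               ; becomes a scaled register-offset load

When the vector is already a containable memory operand, the spill is skipped: a contained local becomes an LCL_ADDR directly, and a contained indirection contributes its address node as the base.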
23 changes: 1 addition & 22 deletions src/coreclr/jit/lsraarm64.cpp
@@ -1917,7 +1917,6 @@ int LinearScan::BuildHWIntrinsic(GenTreeHWIntrinsic* intrinsicTree, int* pDstCou
srcCount += BuildDelayFreeUses(intrin.op3, embOp2Node->Op(1));
}
}

else if (intrin.op2 != nullptr)
{
// RMW intrinsic operands doesn't have to be delayFree when they can be assigned the same register as op1Reg
Expand All @@ -1928,28 +1927,8 @@ int LinearScan::BuildHWIntrinsic(GenTreeHWIntrinsic* intrinsicTree, int* pDstCou
bool forceOp2DelayFree = false;
SingleTypeRegSet lowVectorCandidates = RBM_NONE;
size_t lowVectorOperandNum = 0;
if ((intrin.id == NI_Vector64_GetElement) || (intrin.id == NI_Vector128_GetElement))
{
if (!intrin.op2->IsCnsIntOrI() && (!intrin.op1->isContained() || intrin.op1->OperIsLocal()))
{
// If the index is not a constant and the object is not contained or is a local
// we will need a general purpose register to calculate the address
// internal register must not clobber input index
// TODO-Cleanup: An internal register will never clobber a source; this code actually
// ensures that the index (op2) doesn't interfere with the target.
buildInternalIntRegisterDefForNode(intrinsicTree);
forceOp2DelayFree = true;
}

if (!intrin.op2->IsCnsIntOrI() && !intrin.op1->isContained())
{
// If the index is not a constant or op1 is in register,
// we will use the SIMD temp location to store the vector.
var_types requiredSimdTempType = (intrin.id == NI_Vector64_GetElement) ? TYP_SIMD8 : TYP_SIMD16;
compiler->getSIMDInitTempVarNum(requiredSimdTempType);
}
}
else if (HWIntrinsicInfo::IsLowVectorOperation(intrin.id))
if (HWIntrinsicInfo::IsLowVectorOperation(intrin.id))
{
getLowVectorOperandAndCandidates(intrin, &lowVectorOperandNum, &lowVectorCandidates);
}
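Since lowering now materializes the address computation explicitly, LSRA no longer needs to reserve an internal register or the SIMD init temp for GetElement, which is what the lsraarm64.cpp deletion above removes. For orientation, roughly the arm64 sequences involved (illustrative only, not actual JIT disassembly):

// Constant index, vector in v0, float element to s0:
//     mov  s0, v0.s[2]                  ; single element move
//
// Non-constant index in w1, vector spilled to the SIMD temp on the stack:
//     str  q0, [fp, #tmpOffs]           ; store produced by the inserted STORE_LCL_VAR
//     add  x2, fp, #tmpOffs             ; address of the temp (LCL_ADDR)
//     ldr  s0, [x2, w1, uxtw #2]        ; IND folded into a scaled register-offset load
//
// Vector already in memory (its load consumed directly by lowering), address in x2:
//     ldr  s0, [x2, w1, uxtw #2]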