Skip to content

[AMDGPU] Convert 64-bit sra to 32-bit if shift amt >= 32 #144421

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 8 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
113 changes: 87 additions & 26 deletions llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4151,32 +4151,96 @@ SDValue AMDGPUTargetLowering::performShlCombine(SDNode *N,

SDValue AMDGPUTargetLowering::performSraCombine(SDNode *N,
DAGCombinerInfo &DCI) const {
if (N->getValueType(0) != MVT::i64)
SDValue RHS = N->getOperand(1);
ConstantSDNode *CRHS = dyn_cast<ConstantSDNode>(RHS);
EVT VT = N->getValueType(0);
SDValue LHS = N->getOperand(0);
SelectionDAG &DAG = DCI.DAG;
SDLoc SL(N);

if (VT.getScalarType() != MVT::i64)
return SDValue();

const ConstantSDNode *RHS = dyn_cast<ConstantSDNode>(N->getOperand(1));
if (!RHS)
// For C >= 32
// i64 (sra x, C) -> (build_pair (sra hi_32(x), C - 32), sra hi_32(x), 31))

// On some subtargets, 64-bit shift is a quarter rate instruction. In the
// common case, splitting this into a move and a 32-bit shift is faster and
// the same code size.
KnownBits Known = DAG.computeKnownBits(RHS);

EVT ElementType = VT.getScalarType();
EVT TargetScalarType = ElementType.getHalfSizedIntegerVT(*DAG.getContext());
EVT TargetType = VT.isVector() ? VT.changeVectorElementType(TargetScalarType)
: TargetScalarType;

if (Known.getMinValue().getZExtValue() < TargetScalarType.getSizeInBits())
return SDValue();

SelectionDAG &DAG = DCI.DAG;
SDLoc SL(N);
unsigned RHSVal = RHS->getZExtValue();
SDValue ShiftFullAmt =
DAG.getConstant(TargetScalarType.getSizeInBits() - 1, SL, TargetType);
SDValue ShiftAmt;
if (CRHS) {
unsigned RHSVal = CRHS->getZExtValue();
ShiftAmt = DAG.getConstant(RHSVal - TargetScalarType.getSizeInBits(), SL,
TargetType);
} else if (Known.getMinValue().getZExtValue() ==
(ElementType.getSizeInBits() - 1)) {
ShiftAmt = ShiftFullAmt;
} else {
SDValue truncShiftAmt = DAG.getNode(ISD::TRUNCATE, SL, TargetType, RHS);
const SDValue ShiftMask =
DAG.getConstant(TargetScalarType.getSizeInBits() - 1, SL, TargetType);
// This AND instruction will clamp out of bounds shift values.
// It will also be removed during later instruction selection.
ShiftAmt = DAG.getNode(ISD::AND, SL, TargetType, truncShiftAmt, ShiftMask);
}

// For C >= 32
// (sra i64:x, C) -> build_pair (sra hi_32(x), C - 32), (sra hi_32(x), 31)
if (RHSVal >= 32) {
SDValue Hi = getHiHalf64(N->getOperand(0), DAG);
Hi = DAG.getFreeze(Hi);
SDValue HiShift = DAG.getNode(ISD::SRA, SL, MVT::i32, Hi,
DAG.getConstant(31, SL, MVT::i32));
SDValue LoShift = DAG.getNode(ISD::SRA, SL, MVT::i32, Hi,
DAG.getConstant(RHSVal - 32, SL, MVT::i32));
EVT ConcatType;
SDValue Hi;
SDLoc LHSSL(LHS);
// Bitcast LHS into ConcatType so hi-half of source can be extracted into Hi
if (VT.isVector()) {
unsigned NElts = TargetType.getVectorNumElements();
ConcatType = TargetType.getDoubleNumVectorElementsVT(*DAG.getContext());
SDValue SplitLHS = DAG.getNode(ISD::BITCAST, LHSSL, ConcatType, LHS);
SmallVector<SDValue, 8> HiOps(NElts);
SmallVector<SDValue, 16> HiAndLoOps;

SDValue BuildVec = DAG.getBuildVector(MVT::v2i32, SL, {LoShift, HiShift});
return DAG.getNode(ISD::BITCAST, SL, MVT::i64, BuildVec);
DAG.ExtractVectorElements(SplitLHS, HiAndLoOps, 0, NElts * 2);
for (unsigned I = 0; I != NElts; ++I) {
HiOps[I] = HiAndLoOps[2 * I + 1];
}
Hi = DAG.getNode(ISD::BUILD_VECTOR, LHSSL, TargetType, HiOps);
} else {
const SDValue One = DAG.getConstant(1, LHSSL, TargetScalarType);
ConcatType = EVT::getVectorVT(*DAG.getContext(), TargetType, 2);
SDValue SplitLHS = DAG.getNode(ISD::BITCAST, LHSSL, ConcatType, LHS);
Hi = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, LHSSL, TargetType, SplitLHS, One);
}
Hi = DAG.getFreeze(Hi);

return SDValue();
SDValue HiShift = DAG.getNode(ISD::SRA, SL, TargetType, Hi, ShiftFullAmt);
SDValue NewShift = DAG.getNode(ISD::SRA, SL, TargetType, Hi, ShiftAmt);

SDValue Vec;
if (VT.isVector()) {
unsigned NElts = TargetType.getVectorNumElements();
SmallVector<SDValue, 8> HiOps;
SmallVector<SDValue, 8> LoOps;
SmallVector<SDValue, 16> HiAndLoOps(NElts * 2);

DAG.ExtractVectorElements(HiShift, HiOps, 0, NElts);
DAG.ExtractVectorElements(NewShift, LoOps, 0, NElts);
for (unsigned I = 0; I != NElts; ++I) {
HiAndLoOps[2 * I + 1] = HiOps[I];
HiAndLoOps[2 * I] = LoOps[I];
}
Vec = DAG.getNode(ISD::BUILD_VECTOR, SL, ConcatType, HiAndLoOps);
} else {
Vec = DAG.getBuildVector(ConcatType, SL, {NewShift, HiShift});
}
return DAG.getNode(ISD::BITCAST, SL, VT, Vec);
}

SDValue AMDGPUTargetLowering::performSrlCombine(SDNode *N,
Expand Down Expand Up @@ -4213,7 +4277,7 @@ SDValue AMDGPUTargetLowering::performSrlCombine(SDNode *N,
return SDValue();

// for C >= 32
// i64 (srl x, C) -> (build_pair (srl hi_32(x), C -32), 0)
// i64 (srl x, C) -> (build_pair (srl hi_32(x), C - 32), 0)

// On some subtargets, 64-bit shift is a quarter rate instruction. In the
// common case, splitting this into a move and a 32-bit shift is faster and
Expand Down Expand Up @@ -5265,25 +5329,22 @@ SDValue AMDGPUTargetLowering::PerformDAGCombine(SDNode *N,
break;
}
case ISD::SHL:
case ISD::SRA:
case ISD::SRL: {
// Range metadata can be invalidated when loads are converted to legal types
// (e.g. v2i64 -> v4i32).
// Try to convert vector shl/srl before type legalization so that range
// Try to convert vector shl/sra/srl before type legalization so that range
// metadata can be utilized.
if (!(N->getValueType(0).isVector() &&
DCI.getDAGCombineLevel() == BeforeLegalizeTypes) &&
DCI.getDAGCombineLevel() < AfterLegalizeDAG)
break;
if (N->getOpcode() == ISD::SHL)
return performShlCombine(N, DCI);
if (N->getOpcode() == ISD::SRA)
return performSraCombine(N, DCI);
return performSrlCombine(N, DCI);
}
case ISD::SRA: {
if (DCI.getDAGCombineLevel() < AfterLegalizeDAG)
break;

return performSraCombine(N, DCI);
}
case ISD::TRUNCATE:
return performTruncateCombine(N, DCI);
case ISD::MUL:
Expand Down
Loading
Loading