Skip to content

Commit

Permalink
Reapply "[DAGCombiner] Add support for scalarising extracts of a vect…
Browse files Browse the repository at this point in the history
…or setcc (llvm#117566)" (llvm#118823)

[Reverts d57892a]

For IR like this:

%icmp = icmp ult <4 x i32> %a, splat (i32 5)
%res = extractelement <4 x i1> %icmp, i32 1

where there is only one use of %icmp we can take a similar approach
to what we already do for binary ops such as add, sub, etc. and convert
this into

%ext = extractelement <4 x i32> %a, i32 1
%res = icmp ult i32 %ext, 5

For AArch64 targets at least the scalar boolean result will almost
certainly need to be in a GPR anyway, since it will probably be
used by branches for control flow. I've tried to reuse existing code
in scalarizeExtractedBinop to also work for setcc.

NOTE: The optimisations don't apply for tests such as
extract_icmp_v4i32_splat_rhs in the file

CodeGen/AArch64/extract-vector-cmp.ll

because scalarizeExtractedBinOp only works if one of the input
operands is a constant.

---------

Co-authored-by: Paul Walker <paul.walker@arm.com>
  • Loading branch information
david-arm and paulwalker-arm authored Dec 9, 2024
1 parent 6a52a51 commit 8630a7b
Show file tree
Hide file tree
Showing 10 changed files with 304 additions and 41 deletions.
43 changes: 27 additions & 16 deletions llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22755,16 +22755,22 @@ SDValue DAGCombiner::scalarizeExtractedVectorLoad(SDNode *EVE, EVT InVecVT,

/// Transform a vector binary operation (or a vector setcc) into the equivalent
/// scalar operation by moving the math/logic after an extract element of a vector.
static SDValue scalarizeExtractedBinop(SDNode *ExtElt, SelectionDAG &DAG,
const SDLoc &DL, bool LegalOperations) {
static SDValue scalarizeExtractedBinOp(SDNode *ExtElt, SelectionDAG &DAG,
const SDLoc &DL, bool LegalTypes) {
const TargetLowering &TLI = DAG.getTargetLoweringInfo();
SDValue Vec = ExtElt->getOperand(0);
SDValue Index = ExtElt->getOperand(1);
auto *IndexC = dyn_cast<ConstantSDNode>(Index);
if (!IndexC || !TLI.isBinOp(Vec.getOpcode()) || !Vec.hasOneUse() ||
unsigned Opc = Vec.getOpcode();
if (!IndexC || !Vec.hasOneUse() || (!TLI.isBinOp(Opc) && Opc != ISD::SETCC) ||
Vec->getNumValues() != 1)
return SDValue();

EVT ResVT = ExtElt->getValueType(0);
if (Opc == ISD::SETCC &&
(ResVT != Vec.getValueType().getVectorElementType() || LegalTypes))
return SDValue();

// Targets may want to avoid this to prevent an expensive register transfer.
if (!TLI.shouldScalarizeBinop(Vec))
return SDValue();
Expand All @@ -22775,19 +22781,24 @@ static SDValue scalarizeExtractedBinop(SDNode *ExtElt, SelectionDAG &DAG,
SDValue Op0 = Vec.getOperand(0);
SDValue Op1 = Vec.getOperand(1);
APInt SplatVal;
if (isAnyConstantBuildVector(Op0, true) ||
ISD::isConstantSplatVector(Op0.getNode(), SplatVal) ||
isAnyConstantBuildVector(Op1, true) ||
ISD::isConstantSplatVector(Op1.getNode(), SplatVal)) {
// extractelt (binop X, C), IndexC --> binop (extractelt X, IndexC), C'
// extractelt (binop C, X), IndexC --> binop C', (extractelt X, IndexC)
EVT VT = ExtElt->getValueType(0);
SDValue Ext0 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, VT, Op0, Index);
SDValue Ext1 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, VT, Op1, Index);
return DAG.getNode(Vec.getOpcode(), DL, VT, Ext0, Ext1);
}
if (!isAnyConstantBuildVector(Op0, true) &&
!ISD::isConstantSplatVector(Op0.getNode(), SplatVal) &&
!isAnyConstantBuildVector(Op1, true) &&
!ISD::isConstantSplatVector(Op1.getNode(), SplatVal))
return SDValue();

return SDValue();
// extractelt (op X, C), IndexC --> op (extractelt X, IndexC), C'
// extractelt (op C, X), IndexC --> op C', (extractelt X, IndexC)
if (Opc == ISD::SETCC) {
EVT OpVT = Op0.getValueType().getVectorElementType();
Op0 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, OpVT, Op0, Index);
Op1 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, OpVT, Op1, Index);
return DAG.getSetCC(DL, ResVT, Op0, Op1,
cast<CondCodeSDNode>(Vec->getOperand(2))->get());
}
Op0 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, ResVT, Op0, Index);
Op1 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, ResVT, Op1, Index);
return DAG.getNode(Opc, DL, ResVT, Op0, Op1);
}

// Given a ISD::EXTRACT_VECTOR_ELT, which is a glorified bit sequence extract,
Expand Down Expand Up @@ -23020,7 +23031,7 @@ SDValue DAGCombiner::visitEXTRACT_VECTOR_ELT(SDNode *N) {
}
}

if (SDValue BO = scalarizeExtractedBinop(N, DAG, DL, LegalOperations))
if (SDValue BO = scalarizeExtractedBinOp(N, DAG, DL, LegalTypes))
return BO;

if (VecVT.isScalableVector())
Expand Down
15 changes: 15 additions & 0 deletions llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2835,6 +2835,7 @@ void DAGTypeLegalizer::ExpandIntegerResult(SDNode *N, unsigned ResNo) {
case ISD::SELECT_CC: SplitRes_SELECT_CC(N, Lo, Hi); break;
case ISD::UNDEF: SplitRes_UNDEF(N, Lo, Hi); break;
case ISD::FREEZE: SplitRes_FREEZE(N, Lo, Hi); break;
case ISD::SETCC: ExpandIntRes_SETCC(N, Lo, Hi); break;

case ISD::BITCAST: ExpandRes_BITCAST(N, Lo, Hi); break;
case ISD::BUILD_PAIR: ExpandRes_BUILD_PAIR(N, Lo, Hi); break;
Expand Down Expand Up @@ -3316,6 +3317,20 @@ static std::pair<ISD::CondCode, ISD::NodeType> getExpandedMinMaxOps(int Op) {
}
}

/// Expand the integer result of a SETCC node.
///
/// The comparison is rebuilt with the result type that getSetCCResultType
/// reports for the operand type, converted to N's original result type with
/// DAG.getBoolExtOrTrunc, and finally handed to SplitInteger to produce the
/// \p Lo and \p Hi halves.
void DAGTypeLegalizer::ExpandIntRes_SETCC(SDNode *N, SDValue &Lo, SDValue &Hi) {
  SDLoc Loc(N);

  SDValue CmpLHS = N->getOperand(0);
  SDValue CmpRHS = N->getOperand(1);
  SDValue CondCode = N->getOperand(2);

  // Result type used for the rebuilt comparison, derived from the type of the
  // values being compared.
  EVT CmpResVT = getSetCCResultType(CmpLHS.getValueType());

  // Taking the same approach as ScalarizeVecRes_SETCC: emit the compare with
  // CmpResVT first, then adjust the boolean to the type N originally produced.
  SDValue Cmp =
      DAG.getNode(ISD::SETCC, Loc, CmpResVT, CmpLHS, CmpRHS, CondCode);
  SDValue Wide = DAG.getBoolExtOrTrunc(Cmp, Loc, N->getValueType(0), CmpResVT);

  SplitInteger(Wide, Lo, Hi);
}

void DAGTypeLegalizer::ExpandIntRes_MINMAX(SDNode *N,
SDValue &Lo, SDValue &Hi) {
SDLoc DL(N);
Expand Down
1 change: 1 addition & 0 deletions llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
Original file line number Diff line number Diff line change
Expand Up @@ -487,6 +487,7 @@ class LLVM_LIBRARY_VISIBILITY DAGTypeLegalizer {
void ExpandIntRes_MINMAX (SDNode *N, SDValue &Lo, SDValue &Hi);

void ExpandIntRes_CMP (SDNode *N, SDValue &Lo, SDValue &Hi);
void ExpandIntRes_SETCC (SDNode *N, SDValue &Lo, SDValue &Hi);

void ExpandIntRes_SADDSUBO (SDNode *N, SDValue &Lo, SDValue &Hi);
void ExpandIntRes_UADDSUBO (SDNode *N, SDValue &Lo, SDValue &Hi);
Expand Down
4 changes: 4 additions & 0 deletions llvm/lib/Target/AArch64/AArch64ISelLowering.h
Original file line number Diff line number Diff line change
Expand Up @@ -1348,6 +1348,10 @@ class AArch64TargetLowering : public TargetLowering {
unsigned getMinimumJumpTableEntries() const override;

bool softPromoteHalfType() const override { return true; }

// Returns true only when VecOp's opcode is ISD::SETCC; every other opcode
// is declined.
bool shouldScalarizeBinop(SDValue VecOp) const override {
  const unsigned Opc = VecOp.getOpcode();
  return Opc == ISD::SETCC;
}
};

namespace AArch64 {
Expand Down
2 changes: 1 addition & 1 deletion llvm/lib/Target/RISCV/RISCVISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2107,7 +2107,7 @@ bool RISCVTargetLowering::shouldScalarizeBinop(SDValue VecOp) const {

// Assume target opcodes can't be scalarized.
// TODO - do we have any exceptions?
if (Opc >= ISD::BUILTIN_OP_END)
if (Opc >= ISD::BUILTIN_OP_END || !isBinOp(Opc))
return false;

// If the vector op is not supported, try to convert to scalar.
Expand Down
2 changes: 1 addition & 1 deletion llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -429,7 +429,7 @@ bool WebAssemblyTargetLowering::shouldScalarizeBinop(SDValue VecOp) const {

// Assume target opcodes can't be scalarized.
// TODO - do we have any exceptions?
if (Opc >= ISD::BUILTIN_OP_END)
if (Opc >= ISD::BUILTIN_OP_END || !isBinOp(Opc))
return false;

// If the vector op is not supported, try to convert to scalar.
Expand Down
2 changes: 1 addition & 1 deletion llvm/lib/Target/X86/X86ISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3306,7 +3306,7 @@ bool X86TargetLowering::shouldScalarizeBinop(SDValue VecOp) const {

// Assume target opcodes can't be scalarized.
// TODO - do we have any exceptions?
if (Opc >= ISD::BUILTIN_OP_END)
if (Opc >= ISD::BUILTIN_OP_END || !isBinOp(Opc))
return false;

// If the vector op is not supported, try to convert to scalar.
Expand Down
46 changes: 24 additions & 22 deletions llvm/test/CodeGen/AArch64/dag-combine-concat-vectors.ll
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

declare void @llvm.masked.scatter.nxv16i8.nxv16p0(<vscale x 16 x i8>, <vscale x 16 x ptr>, i32 immarg, <vscale x 16 x i1>)

define fastcc i8 @allocno_reload_assign() {
define fastcc i8 @allocno_reload_assign(ptr %p) {
; CHECK-LABEL: allocno_reload_assign:
; CHECK: // %bb.0:
; CHECK-NEXT: fmov d0, xzr
Expand All @@ -14,8 +14,8 @@ define fastcc i8 @allocno_reload_assign() {
; CHECK-NEXT: cmpeq p0.d, p0/z, z0.d, #0
; CHECK-NEXT: uzp1 p0.s, p0.s, p0.s
; CHECK-NEXT: uzp1 p0.h, p0.h, p0.h
; CHECK-NEXT: uzp1 p0.b, p0.b, p0.b
; CHECK-NEXT: mov z0.b, p0/z, #1 // =0x1
; CHECK-NEXT: uzp1 p8.b, p0.b, p0.b
; CHECK-NEXT: mov z0.b, p8/z, #1 // =0x1
; CHECK-NEXT: fmov w8, s0
; CHECK-NEXT: mov z0.b, #0 // =0x0
; CHECK-NEXT: uunpklo z1.h, z0.b
Expand All @@ -30,34 +30,35 @@ define fastcc i8 @allocno_reload_assign() {
; CHECK-NEXT: punpklo p1.h, p0.b
; CHECK-NEXT: punpkhi p0.h, p0.b
; CHECK-NEXT: punpklo p2.h, p1.b
; CHECK-NEXT: punpkhi p3.h, p1.b
; CHECK-NEXT: punpkhi p4.h, p1.b
; CHECK-NEXT: uunpklo z0.d, z2.s
; CHECK-NEXT: uunpkhi z1.d, z2.s
; CHECK-NEXT: punpklo p5.h, p0.b
; CHECK-NEXT: punpklo p6.h, p0.b
; CHECK-NEXT: uunpklo z2.d, z3.s
; CHECK-NEXT: uunpkhi z3.d, z3.s
; CHECK-NEXT: punpkhi p7.h, p0.b
; CHECK-NEXT: punpkhi p0.h, p0.b
; CHECK-NEXT: uunpklo z4.d, z5.s
; CHECK-NEXT: uunpkhi z5.d, z5.s
; CHECK-NEXT: uunpklo z6.d, z7.s
; CHECK-NEXT: uunpkhi z7.d, z7.s
; CHECK-NEXT: punpklo p0.h, p2.b
; CHECK-NEXT: punpkhi p1.h, p2.b
; CHECK-NEXT: punpklo p2.h, p3.b
; CHECK-NEXT: punpkhi p3.h, p3.b
; CHECK-NEXT: punpklo p4.h, p5.b
; CHECK-NEXT: punpkhi p5.h, p5.b
; CHECK-NEXT: punpklo p6.h, p7.b
; CHECK-NEXT: punpkhi p7.h, p7.b
; CHECK-NEXT: punpklo p1.h, p2.b
; CHECK-NEXT: punpkhi p2.h, p2.b
; CHECK-NEXT: punpklo p3.h, p4.b
; CHECK-NEXT: punpkhi p4.h, p4.b
; CHECK-NEXT: punpklo p5.h, p6.b
; CHECK-NEXT: punpkhi p6.h, p6.b
; CHECK-NEXT: punpklo p7.h, p0.b
; CHECK-NEXT: punpkhi p0.h, p0.b
; CHECK-NEXT: .LBB0_1: // =>This Inner Loop Header: Depth=1
; CHECK-NEXT: st1b { z0.d }, p0, [z16.d]
; CHECK-NEXT: st1b { z1.d }, p1, [z16.d]
; CHECK-NEXT: st1b { z2.d }, p2, [z16.d]
; CHECK-NEXT: st1b { z3.d }, p3, [z16.d]
; CHECK-NEXT: st1b { z4.d }, p4, [z16.d]
; CHECK-NEXT: st1b { z5.d }, p5, [z16.d]
; CHECK-NEXT: st1b { z6.d }, p6, [z16.d]
; CHECK-NEXT: st1b { z7.d }, p7, [z16.d]
; CHECK-NEXT: st1b { z0.d }, p1, [z16.d]
; CHECK-NEXT: st1b { z1.d }, p2, [z16.d]
; CHECK-NEXT: st1b { z2.d }, p3, [z16.d]
; CHECK-NEXT: st1b { z3.d }, p4, [z16.d]
; CHECK-NEXT: st1b { z4.d }, p5, [z16.d]
; CHECK-NEXT: st1b { z5.d }, p6, [z16.d]
; CHECK-NEXT: st1b { z6.d }, p7, [z16.d]
; CHECK-NEXT: st1b { z7.d }, p0, [z16.d]
; CHECK-NEXT: str p8, [x0]
; CHECK-NEXT: b .LBB0_1
br label %1

Expand All @@ -66,6 +67,7 @@ define fastcc i8 @allocno_reload_assign() {
%constexpr1 = shufflevector <vscale x 16 x i1> %constexpr, <vscale x 16 x i1> poison, <vscale x 16 x i32> zeroinitializer
%constexpr2 = xor <vscale x 16 x i1> %constexpr1, shufflevector (<vscale x 16 x i1> insertelement (<vscale x 16 x i1> poison, i1 true, i64 0), <vscale x 16 x i1> poison, <vscale x 16 x i32> zeroinitializer)
call void @llvm.masked.scatter.nxv16i8.nxv16p0(<vscale x 16 x i8> zeroinitializer, <vscale x 16 x ptr> zeroinitializer, i32 0, <vscale x 16 x i1> %constexpr2)
store <vscale x 16 x i1> %constexpr, ptr %p, align 16
br label %1
}

Expand Down
Loading

0 comments on commit 8630a7b

Please sign in to comment.