[RISCV] Provide a more efficient lowering for experimental.cttz.elts. #88552
Conversation
For experimental.cttz.elts, we can use a vfirst instruction, but we need to correct the result if the input vector can be all zeros: cttz.elts returns the vector length in that case, while vfirst returns -1.
@llvm/pr-subscribers-backend-risc-v
Author: Craig Topper (topperc)
Full diff: https://github.com/llvm/llvm-project/pull/88552.diff
3 Files Affected:
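As a rough illustration of the semantics involved, here is a minimal scalar model (illustrative only, not part of the patch; the function names are invented):

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Scalar model of RVV's vfirst.m: the index of the first set mask element,
// or -1 when no element is set.
int64_t vfirstModel(const std::vector<bool> &Mask) {
  for (std::size_t I = 0; I < Mask.size(); ++I)
    if (Mask[I])
      return static_cast<int64_t>(I);
  return -1;
}

// Scalar model of llvm.experimental.cttz.elts with is_zero_poison = 0:
// identical to vfirst except that the all-zeros case yields the element
// count, hence the select-based correction in the lowering below. With
// is_zero_poison = 1 the raw vfirst result can be returned directly.
int64_t cttzEltsModel(const std::vector<bool> &Mask) {
  int64_t Res = vfirstModel(Mask);
  return Res < 0 ? static_cast<int64_t>(Mask.size()) : Res;
}
```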
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 5a572002091ff3..fa4e9bf002e967 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -1484,6 +1484,11 @@ bool RISCVTargetLowering::shouldExpandGetVectorLength(EVT TripCountVT,
return VF > MaxVF || !isPowerOf2_32(VF);
}
+bool RISCVTargetLowering::shouldExpandCttzElements(EVT VT) const {
+ return !Subtarget.hasVInstructions() ||
+ VT.getVectorElementType() != MVT::i1 || !isTypeLegal(VT);
+}
+
bool RISCVTargetLowering::getTgtMemIntrinsic(IntrinsicInfo &Info,
const CallInst &I,
MachineFunction &MF,
@@ -8718,6 +8723,33 @@ static SDValue lowerGetVectorLength(SDNode *N, SelectionDAG &DAG,
return DAG.getNode(ISD::TRUNCATE, DL, N->getValueType(0), Res);
}
+static SDValue lowerCttzElts(SDNode *N, SelectionDAG &DAG,
+ const RISCVSubtarget &Subtarget) {
+ SDValue Op0 = N->getOperand(1);
+ MVT OpVT = Op0.getSimpleValueType();
+ MVT ContainerVT = OpVT;
+ if (OpVT.isFixedLengthVector()) {
+ ContainerVT = getContainerForFixedLengthVector(DAG, OpVT, Subtarget);
+ Op0 = convertToScalableVector(ContainerVT, Op0, DAG, Subtarget);
+ }
+ MVT XLenVT = Subtarget.getXLenVT();
+ SDLoc DL(N);
+ auto [Mask, VL] = getDefaultVLOps(OpVT, ContainerVT, DL, DAG, Subtarget);
+ SDValue Res = DAG.getNode(RISCVISD::VFIRST_VL, DL, XLenVT, Op0, Mask, VL);
+ if (isOneConstant(N->getOperand(2)))
+ return Res;
+
+ // Convert -1 to VL.
+ SDValue Setcc =
+ DAG.getSetCC(DL, XLenVT, Res, DAG.getConstant(0, DL, XLenVT), ISD::SETLT);
+ // We need to use vscale rather than X0 for scalable vectors.
+ if (!OpVT.isFixedLengthVector())
+ VL = DAG.getVScale(
+ DL, XLenVT,
+ APInt(XLenVT.getSizeInBits(), OpVT.getVectorMinNumElements()));
+ return DAG.getSelect(DL, XLenVT, Setcc, VL, Res);
+}
+
static inline void promoteVCIXScalar(const SDValue &Op,
SmallVectorImpl<SDValue> &Operands,
SelectionDAG &DAG) {
@@ -8913,6 +8945,8 @@ SDValue RISCVTargetLowering::LowerINTRINSIC_WO_CHAIN(SDValue Op,
}
case Intrinsic::experimental_get_vector_length:
return lowerGetVectorLength(Op.getNode(), DAG, Subtarget);
+ case Intrinsic::experimental_cttz_elts:
+ return lowerCttzElts(Op.getNode(), DAG, Subtarget);
case Intrinsic::riscv_vmv_x_s: {
SDValue Res = DAG.getNode(RISCVISD::VMV_X_S, DL, XLenVT, Op.getOperand(1));
return DAG.getNode(ISD::TRUNCATE, DL, Op.getValueType(), Res);
@@ -12336,6 +12370,11 @@ void RISCVTargetLowering::ReplaceNodeResults(SDNode *N,
Results.push_back(DAG.getNode(ISD::TRUNCATE, DL, MVT::i32, Res));
return;
}
+ case Intrinsic::experimental_cttz_elts: {
+ SDValue Res = lowerCttzElts(N, DAG, Subtarget);
+ Results.push_back(DAG.getNode(ISD::TRUNCATE, DL, MVT::i32, Res));
+ return;
+ }
case Intrinsic::riscv_orc_b:
case Intrinsic::riscv_brev8:
case Intrinsic::riscv_sha256sig0:
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.h b/llvm/lib/Target/RISCV/RISCVISelLowering.h
index ace5b3fd2b95b4..e2633733c31b19 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.h
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.h
@@ -986,6 +986,8 @@ class RISCVTargetLowering : public TargetLowering {
bool shouldExpandGetVectorLength(EVT TripCountVT, unsigned VF,
bool IsScalable) const override;
+ bool shouldExpandCttzElements(EVT VT) const override;
+
/// RVV code generation for fixed length vectors does not lower all
/// BUILD_VECTORs. This makes BUILD_VECTOR legalisation a source of stores to
/// merge. However, merging them creates a BUILD_VECTOR that is just as
diff --git a/llvm/test/CodeGen/RISCV/intrinsic-cttz-elts-vscale.ll b/llvm/test/CodeGen/RISCV/intrinsic-cttz-elts-vscale.ll
index 65d0768c60885d..8fe6ef0ab52a06 100644
--- a/llvm/test/CodeGen/RISCV/intrinsic-cttz-elts-vscale.ll
+++ b/llvm/test/CodeGen/RISCV/intrinsic-cttz-elts-vscale.ll
@@ -128,43 +128,88 @@ define i64 @ctz_nxv8i1_no_range(<vscale x 8 x i16> %a) {
define i32 @ctz_nxv16i1(<vscale x 16 x i1> %pg, <vscale x 16 x i1> %a) {
; RV32-LABEL: ctz_nxv16i1:
; RV32: # %bb.0:
-; RV32-NEXT: vmv1r.v v0, v8
+; RV32-NEXT: vsetvli a0, zero, e8, m2, ta, ma
+; RV32-NEXT: vfirst.m a0, v8
+; RV32-NEXT: bgez a0, .LBB2_2
+; RV32-NEXT: # %bb.1:
; RV32-NEXT: csrr a0, vlenb
; RV32-NEXT: slli a0, a0, 1
-; RV32-NEXT: vsetvli a1, zero, e32, m8, ta, ma
-; RV32-NEXT: vmv.v.x v8, a0
-; RV32-NEXT: vid.v v16
-; RV32-NEXT: li a1, -1
-; RV32-NEXT: vmadd.vx v16, a1, v8
-; RV32-NEXT: vmv.v.i v8, 0
-; RV32-NEXT: vmerge.vvm v8, v8, v16, v0
-; RV32-NEXT: vredmaxu.vs v8, v8, v8
-; RV32-NEXT: vmv.x.s a1, v8
-; RV32-NEXT: sub a0, a0, a1
+; RV32-NEXT: .LBB2_2:
; RV32-NEXT: ret
;
; RV64-LABEL: ctz_nxv16i1:
; RV64: # %bb.0:
-; RV64-NEXT: vmv1r.v v0, v8
+; RV64-NEXT: vsetvli a0, zero, e8, m2, ta, ma
+; RV64-NEXT: vfirst.m a0, v8
+; RV64-NEXT: bgez a0, .LBB2_2
+; RV64-NEXT: # %bb.1:
; RV64-NEXT: csrr a0, vlenb
; RV64-NEXT: slli a0, a0, 1
-; RV64-NEXT: vsetvli a1, zero, e32, m8, ta, ma
-; RV64-NEXT: vmv.v.x v8, a0
-; RV64-NEXT: vid.v v16
-; RV64-NEXT: li a1, -1
-; RV64-NEXT: vmadd.vx v16, a1, v8
-; RV64-NEXT: vmv.v.i v8, 0
-; RV64-NEXT: vmerge.vvm v8, v8, v16, v0
-; RV64-NEXT: vredmaxu.vs v8, v8, v8
-; RV64-NEXT: vmv.x.s a1, v8
-; RV64-NEXT: subw a0, a0, a1
+; RV64-NEXT: .LBB2_2:
; RV64-NEXT: ret
%res = call i32 @llvm.experimental.cttz.elts.i32.nxv16i1(<vscale x 16 x i1> %a, i1 0)
ret i32 %res
}
+define i32 @ctz_nxv16i1_poison(<vscale x 16 x i1> %pg, <vscale x 16 x i1> %a) {
+; RV32-LABEL: ctz_nxv16i1_poison:
+; RV32: # %bb.0:
+; RV32-NEXT: vsetvli a0, zero, e8, m2, ta, ma
+; RV32-NEXT: vfirst.m a0, v8
+; RV32-NEXT: ret
+;
+; RV64-LABEL: ctz_nxv16i1_poison:
+; RV64: # %bb.0:
+; RV64-NEXT: vsetvli a0, zero, e8, m2, ta, ma
+; RV64-NEXT: vfirst.m a0, v8
+; RV64-NEXT: ret
+ %res = call i32 @llvm.experimental.cttz.elts.i32.nxv16i1(<vscale x 16 x i1> %a, i1 1)
+ ret i32 %res
+}
+
+define i32 @ctz_v16i1(<16 x i1> %pg, <16 x i1> %a) {
+; RV32-LABEL: ctz_v16i1:
+; RV32: # %bb.0:
+; RV32-NEXT: vsetivli zero, 16, e8, m1, ta, ma
+; RV32-NEXT: vfirst.m a0, v8
+; RV32-NEXT: bgez a0, .LBB4_2
+; RV32-NEXT: # %bb.1:
+; RV32-NEXT: li a0, 16
+; RV32-NEXT: .LBB4_2:
+; RV32-NEXT: ret
+;
+; RV64-LABEL: ctz_v16i1:
+; RV64: # %bb.0:
+; RV64-NEXT: vsetivli zero, 16, e8, m1, ta, ma
+; RV64-NEXT: vfirst.m a0, v8
+; RV64-NEXT: bgez a0, .LBB4_2
+; RV64-NEXT: # %bb.1:
+; RV64-NEXT: li a0, 16
+; RV64-NEXT: .LBB4_2:
+; RV64-NEXT: ret
+ %res = call i32 @llvm.experimental.cttz.elts.i32.v16i1(<16 x i1> %a, i1 0)
+ ret i32 %res
+}
+
+define i32 @ctz_v16i1_poison(<16 x i1> %pg, <16 x i1> %a) {
+; RV32-LABEL: ctz_v16i1_poison:
+; RV32: # %bb.0:
+; RV32-NEXT: vsetivli zero, 16, e8, m1, ta, ma
+; RV32-NEXT: vfirst.m a0, v8
+; RV32-NEXT: ret
+;
+; RV64-LABEL: ctz_v16i1_poison:
+; RV64: # %bb.0:
+; RV64-NEXT: vsetivli zero, 16, e8, m1, ta, ma
+; RV64-NEXT: vfirst.m a0, v8
+; RV64-NEXT: ret
+ %res = call i32 @llvm.experimental.cttz.elts.i32.v16i1(<16 x i1> %a, i1 1)
+ ret i32 %res
+}
+
declare i64 @llvm.experimental.cttz.elts.i64.nxv8i16(<vscale x 8 x i16>, i1)
declare i32 @llvm.experimental.cttz.elts.i32.nxv16i1(<vscale x 16 x i1>, i1)
declare i32 @llvm.experimental.cttz.elts.i32.nxv4i32(<vscale x 4 x i32>, i1)
+declare i32 @llvm.experimental.cttz.elts.i32.v16i1(<16 x i1>, i1)
attributes #0 = { vscale_range(2,1024) }
LGTM
```cpp
    if (!OpVT.isFixedLengthVector())
      VL = DAG.getVScale(
          DL, XLenVT,
          APInt(XLenVT.getSizeInBits(), OpVT.getVectorMinNumElements()));
```
I think you can also do:

```cpp
SDValue VL = DAG.getElementCount(DL, XLenVT, OpVT.getVectorElementCount());
```
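That should work because getElementCount folds to a plain constant for fixed-length types and to a vscale-scaled value for scalable ones, so it would cover both branches without the explicit isFixedLengthVector() check.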
LGTM w/minor comment and a possible follow up.
```diff
@@ -1484,6 +1484,11 @@ bool RISCVTargetLowering::shouldExpandGetVectorLength(EVT TripCountVT,
   return VF > MaxVF || !isPowerOf2_32(VF);
 }
 
+bool RISCVTargetLowering::shouldExpandCttzElements(EVT VT) const {
+  return !Subtarget.hasVInstructions() ||
+         VT.getVectorElementType() != MVT::i1 || !isTypeLegal(VT);
```
We could handle the legal non-i1 case via a simple vector comparison against zero, producing a mask result that then flows through this code. From the tests, this looks like a pretty huge improvement in codegen quality.
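Hypothetically, that follow-up might look something like this at the top of lowerCttzElts (a sketch under the assumption that integer element types are handled via setcc; not part of this patch):

```cpp
// Hypothetical follow-up sketch: reduce a legal non-i1 integer vector to a
// mask by comparing it against zero, then fall through to the existing
// vfirst-based lowering.
if (OpVT.getVectorElementType() != MVT::i1) {
  MVT MaskVT = OpVT.changeVectorElementType(MVT::i1);
  Op0 = DAG.getSetCC(DL, MaskVT, Op0, DAG.getConstant(0, DL, OpVT),
                     ISD::SETNE);
  OpVT = MaskVT;
}
```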
That's true. I was just trying to cover the new usage from #88385
```diff
@@ -12336,6 +12370,11 @@ void RISCVTargetLowering::ReplaceNodeResults(SDNode *N,
       Results.push_back(DAG.getNode(ISD::TRUNCATE, DL, MVT::i32, Res));
       return;
     }
+    case Intrinsic::experimental_cttz_elts: {
+      SDValue Res = lowerCttzElts(N, DAG, Subtarget);
+      Results.push_back(DAG.getNode(ISD::TRUNCATE, DL, MVT::i32, Res));
```
The i32 seems slightly weird here - I'd expect XLenVT unless there's something I didn't spot at first glance about when this code is reached?
This is called by type legalization. It needs to return the original type so that it can plug into the original users that haven't been legalized yet. Looking at it again, it should probably be N->getValueType(0). I think we just lack test coverage for return types other than i32.
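A sketch of that adjustment, assuming nothing else in the surrounding ReplaceNodeResults context needs to change:

```cpp
case Intrinsic::experimental_cttz_elts: {
  SDValue Res = lowerCttzElts(N, DAG, Subtarget);
  // Truncate to the node's original result type rather than hard-coding
  // MVT::i32, so non-i32 return types legalize correctly too.
  Results.push_back(DAG.getNode(ISD::TRUNCATE, DL, N->getValueType(0), Res));
  return;
}
```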
LGTM
```diff
@@ -12336,6 +12366,12 @@ void RISCVTargetLowering::ReplaceNodeResults(SDNode *N,
       Results.push_back(DAG.getNode(ISD::TRUNCATE, DL, MVT::i32, Res));
       return;
     }
+    case Intrinsic::experimental_cttz_elts: {
+      EVT VT = N->getValueType(0);
```
Single-use variable.