-
Notifications
You must be signed in to change notification settings - Fork 15.2k
[RISCV] Re-work how VWADD_W_VL and similar _W_VL nodes are handled in combineOp_VLToVWOp_VL. #159205
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-backend-risc-v Author: Craig Topper (topperc) ChangesThese instructions have one already narrow operand. Previously, we pretended like this operand was a supported extension. This could cause problems when we called getOrCreateExtendedOp on this narrow operand when creating the the VWADD_VL. If the narrow operand happened to be an extend of the opposite type, we would peek through it and then rebuild it with the wrong extension type. So (vwadd_w_vl (i32 (sext X)), (i16 (zext Y))) would become (vwadd_vl (i16 (sext X)), (i16 (sext Y))). To prevent this, we ignore the operand instead and pass std::nullopt for SupportsExt to getOrCreateExtendedOp so it won't peek through any extends on the narrow source. Fixes #159152. I'll add a test shortly. Full diff: https://github.com/llvm/llvm-project/pull/159205.diff 1 Files Affected:
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 66ebda7aa586b..863b6b5b36d3b 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -17325,18 +17325,9 @@ struct NodeExtensionHelper {
case RISCVISD::VWSUBU_W_VL:
case RISCVISD::VFWADD_W_VL:
case RISCVISD::VFWSUB_W_VL:
- if (OperandIdx == 1) {
- SupportsZExt =
- Opc == RISCVISD::VWADDU_W_VL || Opc == RISCVISD::VWSUBU_W_VL;
- SupportsSExt =
- Opc == RISCVISD::VWADD_W_VL || Opc == RISCVISD::VWSUB_W_VL;
- SupportsFPExt =
- Opc == RISCVISD::VFWADD_W_VL || Opc == RISCVISD::VFWSUB_W_VL;
- // There's no existing extension here, so we don't have to worry about
- // making sure it gets removed.
- EnforceOneUse = false;
+ // Operand 1 can't be changed.
+ if (OperandIdx == 1)
break;
- }
[[fallthrough]];
default:
fillUpExtensionSupport(Root, DAG, Subtarget);
@@ -17374,20 +17365,20 @@ struct NodeExtensionHelper {
case RISCVISD::ADD_VL:
case RISCVISD::MUL_VL:
case RISCVISD::OR_VL:
- case RISCVISD::VWADD_W_VL:
- case RISCVISD::VWADDU_W_VL:
case RISCVISD::FADD_VL:
case RISCVISD::FMUL_VL:
- case RISCVISD::VFWADD_W_VL:
case RISCVISD::VFMADD_VL:
case RISCVISD::VFNMSUB_VL:
case RISCVISD::VFNMADD_VL:
case RISCVISD::VFMSUB_VL:
return true;
+ case RISCVISD::VWADD_W_VL:
+ case RISCVISD::VWADDU_W_VL:
case ISD::SUB:
case RISCVISD::SUB_VL:
case RISCVISD::VWSUB_W_VL:
case RISCVISD::VWSUBU_W_VL:
+ case RISCVISD::VFWADD_W_VL:
case RISCVISD::FSUB_VL:
case RISCVISD::VFWSUB_W_VL:
case ISD::SHL:
@@ -17506,6 +17497,30 @@ canFoldToVWWithSameExtension(SDNode *Root, const NodeExtensionHelper &LHS,
Subtarget);
}
+/// Check if \p Root follows a pattern Root(bf16ext(LHS), bf16ext(RHS))
+///
+/// \returns std::nullopt if the pattern doesn't match or a CombineResult that
+/// can be used to apply the pattern.
+static std::optional<CombineResult>
+canFoldToVWWithSameExtZEXT(SDNode *Root, const NodeExtensionHelper &LHS,
+ const NodeExtensionHelper &RHS, SelectionDAG &DAG,
+ const RISCVSubtarget &Subtarget) {
+ return canFoldToVWWithSameExtensionImpl(Root, LHS, RHS, ExtKind::ZExt, DAG,
+ Subtarget);
+}
+
+/// Check if \p Root follows a pattern Root(ext(LHS), zext(RHS))
+///
+/// \returns std::nullopt if the pattern doesn't match or a CombineResult that
+/// can be used to apply the pattern.
+static std::optional<CombineResult>
+canFoldToVWWithSameExtBF16(SDNode *Root, const NodeExtensionHelper &LHS,
+ const NodeExtensionHelper &RHS, SelectionDAG &DAG,
+ const RISCVSubtarget &Subtarget) {
+ return canFoldToVWWithSameExtensionImpl(Root, LHS, RHS, ExtKind::BF16Ext, DAG,
+ Subtarget);
+}
+
/// Check if \p Root follows a pattern Root(LHS, ext(RHS))
///
/// \returns std::nullopt if the pattern doesn't match or a CombineResult that
@@ -17534,7 +17549,7 @@ canFoldToVW_W(SDNode *Root, const NodeExtensionHelper &LHS,
return std::nullopt;
}
-/// Check if \p Root follows a pattern Root(sext(LHS), sext(RHS))
+/// Check if \p Root follows a pattern Root(sext(LHS), RHS)
///
/// \returns std::nullopt if the pattern doesn't match or a CombineResult that
/// can be used to apply the pattern.
@@ -17542,11 +17557,14 @@ static std::optional<CombineResult>
canFoldToVWWithSEXT(SDNode *Root, const NodeExtensionHelper &LHS,
const NodeExtensionHelper &RHS, SelectionDAG &DAG,
const RISCVSubtarget &Subtarget) {
- return canFoldToVWWithSameExtensionImpl(Root, LHS, RHS, ExtKind::SExt, DAG,
- Subtarget);
+ if (LHS.SupportsSExt)
+ return CombineResult(NodeExtensionHelper::getSExtOpcode(Root->getOpcode()),
+ Root, LHS, /*LHSExt=*/{ExtKind::SExt}, RHS,
+ /*RHSExt=*/std::nullopt);
+ return std::nullopt;
}
-/// Check if \p Root follows a pattern Root(zext(LHS), zext(RHS))
+/// Check if \p Root follows a pattern Root(zext(LHS), RHS)
///
/// \returns std::nullopt if the pattern doesn't match or a CombineResult that
/// can be used to apply the pattern.
@@ -17554,11 +17572,14 @@ static std::optional<CombineResult>
canFoldToVWWithZEXT(SDNode *Root, const NodeExtensionHelper &LHS,
const NodeExtensionHelper &RHS, SelectionDAG &DAG,
const RISCVSubtarget &Subtarget) {
- return canFoldToVWWithSameExtensionImpl(Root, LHS, RHS, ExtKind::ZExt, DAG,
- Subtarget);
+ if (LHS.SupportsZExt)
+ return CombineResult(NodeExtensionHelper::getZExtOpcode(Root->getOpcode()),
+ Root, LHS, /*LHSExt=*/{ExtKind::ZExt}, RHS,
+ /*RHSExt=*/std::nullopt);
+ return std::nullopt;
}
-/// Check if \p Root follows a pattern Root(fpext(LHS), fpext(RHS))
+/// Check if \p Root follows a pattern Root(fpext(LHS), RHS)
///
/// \returns std::nullopt if the pattern doesn't match or a CombineResult that
/// can be used to apply the pattern.
@@ -17566,20 +17587,11 @@ static std::optional<CombineResult>
canFoldToVWWithFPEXT(SDNode *Root, const NodeExtensionHelper &LHS,
const NodeExtensionHelper &RHS, SelectionDAG &DAG,
const RISCVSubtarget &Subtarget) {
- return canFoldToVWWithSameExtensionImpl(Root, LHS, RHS, ExtKind::FPExt, DAG,
- Subtarget);
-}
-
-/// Check if \p Root follows a pattern Root(bf16ext(LHS), bf16ext(RHS))
-///
-/// \returns std::nullopt if the pattern doesn't match or a CombineResult that
-/// can be used to apply the pattern.
-static std::optional<CombineResult>
-canFoldToVWWithBF16EXT(SDNode *Root, const NodeExtensionHelper &LHS,
- const NodeExtensionHelper &RHS, SelectionDAG &DAG,
- const RISCVSubtarget &Subtarget) {
- return canFoldToVWWithSameExtensionImpl(Root, LHS, RHS, ExtKind::BF16Ext, DAG,
- Subtarget);
+ if (LHS.SupportsFPExt)
+ return CombineResult(NodeExtensionHelper::getFPExtOpcode(Root->getOpcode()),
+ Root, LHS, /*LHSExt=*/{ExtKind::FPExt}, RHS,
+ /*RHSExt=*/std::nullopt);
+ return std::nullopt;
}
/// Check if \p Root follows a pattern Root(sext(LHS), zext(RHS))
@@ -17622,7 +17634,7 @@ NodeExtensionHelper::getSupportedFoldings(const SDNode *Root) {
case RISCVISD::VFNMSUB_VL:
Strategies.push_back(canFoldToVWWithSameExtension);
if (Root->getOpcode() == RISCVISD::VFMADD_VL)
- Strategies.push_back(canFoldToVWWithBF16EXT);
+ Strategies.push_back(canFoldToVWWithSameExtBF16);
break;
case ISD::MUL:
case RISCVISD::MUL_VL:
@@ -17634,7 +17646,7 @@ NodeExtensionHelper::getSupportedFoldings(const SDNode *Root) {
case ISD::SHL:
case RISCVISD::SHL_VL:
// shl -> vwsll
- Strategies.push_back(canFoldToVWWithZEXT);
+ Strategies.push_back(canFoldToVWWithSameExtZEXT);
break;
case RISCVISD::VWADD_W_VL:
case RISCVISD::VWSUB_W_VL:
|
… combineOp_VLToVWOp_VL. These instructions have one already narrow operand. Previously, we pretended like this operand was a supported extension. This could cause problems when we called getOrCreateExtendedOp on this narrow operand when creating the the VWADD_VL. If the narrow operand happened to be an extend of the opposite type, we would peek through it and then rebuild it with the wrong extension type. So (vwadd_w_vl (i32 (sext X)), (i16 (zext Y))) would become (vwadd_vl (i16 (sext X)), (i16 (sext Y))). To prevent this, we ignore the operand instead and pass std::nullopt for SupportsExt to getOrCreateExtendedOp so it won't peek through any extends on the narrow source. Fixes llvm#159152.
ab40f44
to
a156475
Compare
; FOLDING-NEXT: ret | ||
%a = sext <4 x i8> %x to <4 x i16> | ||
%b = zext <4 x i16> %a to <4 x i32> | ||
%c = add <4 x i32> %b, <i32 9, i32 9, i32 9, i32 9> |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The constant vector allows a RISCVISD::VWADDU_W_VL to be formed between LegalizeVectorOps and LegalizeDAG. LegalizeDAG will turn the build_vector into RISCVISD::VMV_V_X_VL. Then we will try to turn the VWADDU_W_VL into VWADD_VL. If we don't use a constant vector we'll go straight to VWADD_VL after LegalizeVectorOps.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM.
I think this needs to be backported to llvm 21?
llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vw-web-simplification.ll
Outdated
Show resolved
Hide resolved
I think llvm 21 needs a fix, but I think I want this fix to be tested for a bit. |
Subtarget); | ||
} | ||
|
||
/// Check if \p Root follows a pattern Root(ext(LHS), zext(RHS)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Mixed up the comments here
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
… combineOp_VLToVWOp_VL. (llvm#159205) These instructions have one already narrow operand. Previously, we pretended like this operand was a supported extension. This could cause problems when we called getOrCreateExtendedOp on this narrow operand when creating the the VWADD_VL. If the narrow operand happened to be an extend of the opposite type, we would peek through it and then rebuild it with the wrong extension type. So (vwadd_w_vl (i32 (sext X)), (i16 (zext Y))) would become (vwadd_vl (i16 (sext X)), (i16 (sext Y))). To prevent this, we ignore the operand instead and pass std::nullopt for SupportsExt to getOrCreateExtendedOp so it won't peek through any extends on the narrow source. Fixes llvm#159152. (cherry picked from commit 6119d1f)
These instructions have one already narrow operand. Previously, we pretended like this operand was a supported extension.
This could cause problems when we called getOrCreateExtendedOp on this narrow operand when creating the the VWADD_VL. If the narrow operand happened to be an extend of the opposite type, we would peek through it and then rebuild it with the wrong extension type. So (vwadd_w_vl (i32 (sext X)), (i16 (zext Y))) would become (vwadd_vl (i16 (sext X)), (i16 (sext Y))).
To prevent this, we ignore the operand instead and pass std::nullopt for SupportsExt to getOrCreateExtendedOp so it won't peek through any extends on the narrow source.
Fixes #159152.