Skip to content

Conversation

topperc
Copy link
Collaborator

@topperc topperc commented Sep 16, 2025

These instructions have one already narrow operand. Previously, we pretended like this operand was a supported extension.

This could cause problems when we called getOrCreateExtendedOp on this narrow operand when creating the the VWADD_VL. If the narrow operand happened to be an extend of the opposite type, we would peek through it and then rebuild it with the wrong extension type. So (vwadd_w_vl (i32 (sext X)), (i16 (zext Y))) would become (vwadd_vl (i16 (sext X)), (i16 (sext Y))).

To prevent this, we ignore the operand instead and pass std::nullopt for SupportsExt to getOrCreateExtendedOp so it won't peek through any extends on the narrow source.

Fixes #159152.

@llvmbot
Copy link
Member

llvmbot commented Sep 16, 2025

@llvm/pr-subscribers-backend-risc-v

Author: Craig Topper (topperc)

Changes

These instructions have one already narrow operand. Previously, we pretended like this operand was a supported extension.

This could cause problems when we called getOrCreateExtendedOp on this narrow operand when creating the the VWADD_VL. If the narrow operand happened to be an extend of the opposite type, we would peek through it and then rebuild it with the wrong extension type. So (vwadd_w_vl (i32 (sext X)), (i16 (zext Y))) would become (vwadd_vl (i16 (sext X)), (i16 (sext Y))).

To prevent this, we ignore the operand instead and pass std::nullopt for SupportsExt to getOrCreateExtendedOp so it won't peek through any extends on the narrow source.

Fixes #159152.

I'll add a test shortly.


Full diff: https://github.com/llvm/llvm-project/pull/159205.diff

1 Files Affected:

  • (modified) llvm/lib/Target/RISCV/RISCVISelLowering.cpp (+49-37)
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 66ebda7aa586b..863b6b5b36d3b 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -17325,18 +17325,9 @@ struct NodeExtensionHelper {
     case RISCVISD::VWSUBU_W_VL:
     case RISCVISD::VFWADD_W_VL:
     case RISCVISD::VFWSUB_W_VL:
-      if (OperandIdx == 1) {
-        SupportsZExt =
-            Opc == RISCVISD::VWADDU_W_VL || Opc == RISCVISD::VWSUBU_W_VL;
-        SupportsSExt =
-            Opc == RISCVISD::VWADD_W_VL || Opc == RISCVISD::VWSUB_W_VL;
-        SupportsFPExt =
-            Opc == RISCVISD::VFWADD_W_VL || Opc == RISCVISD::VFWSUB_W_VL;
-        // There's no existing extension here, so we don't have to worry about
-        // making sure it gets removed.
-        EnforceOneUse = false;
+      // Operand 1 can't be changed.
+      if (OperandIdx == 1)
         break;
-      }
       [[fallthrough]];
     default:
       fillUpExtensionSupport(Root, DAG, Subtarget);
@@ -17374,20 +17365,20 @@ struct NodeExtensionHelper {
     case RISCVISD::ADD_VL:
     case RISCVISD::MUL_VL:
     case RISCVISD::OR_VL:
-    case RISCVISD::VWADD_W_VL:
-    case RISCVISD::VWADDU_W_VL:
     case RISCVISD::FADD_VL:
     case RISCVISD::FMUL_VL:
-    case RISCVISD::VFWADD_W_VL:
     case RISCVISD::VFMADD_VL:
     case RISCVISD::VFNMSUB_VL:
     case RISCVISD::VFNMADD_VL:
     case RISCVISD::VFMSUB_VL:
       return true;
+    case RISCVISD::VWADD_W_VL:
+    case RISCVISD::VWADDU_W_VL:
     case ISD::SUB:
     case RISCVISD::SUB_VL:
     case RISCVISD::VWSUB_W_VL:
     case RISCVISD::VWSUBU_W_VL:
+    case RISCVISD::VFWADD_W_VL:
     case RISCVISD::FSUB_VL:
     case RISCVISD::VFWSUB_W_VL:
     case ISD::SHL:
@@ -17506,6 +17497,30 @@ canFoldToVWWithSameExtension(SDNode *Root, const NodeExtensionHelper &LHS,
       Subtarget);
 }
 
+/// Check if \p Root follows a pattern Root(bf16ext(LHS), bf16ext(RHS))
+///
+/// \returns std::nullopt if the pattern doesn't match or a CombineResult that
+/// can be used to apply the pattern.
+static std::optional<CombineResult>
+canFoldToVWWithSameExtZEXT(SDNode *Root, const NodeExtensionHelper &LHS,
+                           const NodeExtensionHelper &RHS, SelectionDAG &DAG,
+                           const RISCVSubtarget &Subtarget) {
+  return canFoldToVWWithSameExtensionImpl(Root, LHS, RHS, ExtKind::ZExt, DAG,
+                                          Subtarget);
+}
+
+/// Check if \p Root follows a pattern Root(ext(LHS), zext(RHS))
+///
+/// \returns std::nullopt if the pattern doesn't match or a CombineResult that
+/// can be used to apply the pattern.
+static std::optional<CombineResult>
+canFoldToVWWithSameExtBF16(SDNode *Root, const NodeExtensionHelper &LHS,
+                           const NodeExtensionHelper &RHS, SelectionDAG &DAG,
+                           const RISCVSubtarget &Subtarget) {
+  return canFoldToVWWithSameExtensionImpl(Root, LHS, RHS, ExtKind::BF16Ext, DAG,
+                                          Subtarget);
+}
+
 /// Check if \p Root follows a pattern Root(LHS, ext(RHS))
 ///
 /// \returns std::nullopt if the pattern doesn't match or a CombineResult that
@@ -17534,7 +17549,7 @@ canFoldToVW_W(SDNode *Root, const NodeExtensionHelper &LHS,
   return std::nullopt;
 }
 
-/// Check if \p Root follows a pattern Root(sext(LHS), sext(RHS))
+/// Check if \p Root follows a pattern Root(sext(LHS), RHS)
 ///
 /// \returns std::nullopt if the pattern doesn't match or a CombineResult that
 /// can be used to apply the pattern.
@@ -17542,11 +17557,14 @@ static std::optional<CombineResult>
 canFoldToVWWithSEXT(SDNode *Root, const NodeExtensionHelper &LHS,
                     const NodeExtensionHelper &RHS, SelectionDAG &DAG,
                     const RISCVSubtarget &Subtarget) {
-  return canFoldToVWWithSameExtensionImpl(Root, LHS, RHS, ExtKind::SExt, DAG,
-                                          Subtarget);
+  if (LHS.SupportsSExt)
+    return CombineResult(NodeExtensionHelper::getSExtOpcode(Root->getOpcode()),
+                         Root, LHS, /*LHSExt=*/{ExtKind::SExt}, RHS,
+                         /*RHSExt=*/std::nullopt);
+  return std::nullopt;
 }
 
-/// Check if \p Root follows a pattern Root(zext(LHS), zext(RHS))
+/// Check if \p Root follows a pattern Root(zext(LHS), RHS)
 ///
 /// \returns std::nullopt if the pattern doesn't match or a CombineResult that
 /// can be used to apply the pattern.
@@ -17554,11 +17572,14 @@ static std::optional<CombineResult>
 canFoldToVWWithZEXT(SDNode *Root, const NodeExtensionHelper &LHS,
                     const NodeExtensionHelper &RHS, SelectionDAG &DAG,
                     const RISCVSubtarget &Subtarget) {
-  return canFoldToVWWithSameExtensionImpl(Root, LHS, RHS, ExtKind::ZExt, DAG,
-                                          Subtarget);
+  if (LHS.SupportsZExt)
+    return CombineResult(NodeExtensionHelper::getZExtOpcode(Root->getOpcode()),
+                         Root, LHS, /*LHSExt=*/{ExtKind::ZExt}, RHS,
+                         /*RHSExt=*/std::nullopt);
+  return std::nullopt;
 }
 
-/// Check if \p Root follows a pattern Root(fpext(LHS), fpext(RHS))
+/// Check if \p Root follows a pattern Root(fpext(LHS), RHS)
 ///
 /// \returns std::nullopt if the pattern doesn't match or a CombineResult that
 /// can be used to apply the pattern.
@@ -17566,20 +17587,11 @@ static std::optional<CombineResult>
 canFoldToVWWithFPEXT(SDNode *Root, const NodeExtensionHelper &LHS,
                      const NodeExtensionHelper &RHS, SelectionDAG &DAG,
                      const RISCVSubtarget &Subtarget) {
-  return canFoldToVWWithSameExtensionImpl(Root, LHS, RHS, ExtKind::FPExt, DAG,
-                                          Subtarget);
-}
-
-/// Check if \p Root follows a pattern Root(bf16ext(LHS), bf16ext(RHS))
-///
-/// \returns std::nullopt if the pattern doesn't match or a CombineResult that
-/// can be used to apply the pattern.
-static std::optional<CombineResult>
-canFoldToVWWithBF16EXT(SDNode *Root, const NodeExtensionHelper &LHS,
-                       const NodeExtensionHelper &RHS, SelectionDAG &DAG,
-                       const RISCVSubtarget &Subtarget) {
-  return canFoldToVWWithSameExtensionImpl(Root, LHS, RHS, ExtKind::BF16Ext, DAG,
-                                          Subtarget);
+  if (LHS.SupportsFPExt)
+    return CombineResult(NodeExtensionHelper::getFPExtOpcode(Root->getOpcode()),
+                         Root, LHS, /*LHSExt=*/{ExtKind::FPExt}, RHS,
+                         /*RHSExt=*/std::nullopt);
+  return std::nullopt;
 }
 
 /// Check if \p Root follows a pattern Root(sext(LHS), zext(RHS))
@@ -17622,7 +17634,7 @@ NodeExtensionHelper::getSupportedFoldings(const SDNode *Root) {
   case RISCVISD::VFNMSUB_VL:
     Strategies.push_back(canFoldToVWWithSameExtension);
     if (Root->getOpcode() == RISCVISD::VFMADD_VL)
-      Strategies.push_back(canFoldToVWWithBF16EXT);
+      Strategies.push_back(canFoldToVWWithSameExtBF16);
     break;
   case ISD::MUL:
   case RISCVISD::MUL_VL:
@@ -17634,7 +17646,7 @@ NodeExtensionHelper::getSupportedFoldings(const SDNode *Root) {
   case ISD::SHL:
   case RISCVISD::SHL_VL:
     // shl -> vwsll
-    Strategies.push_back(canFoldToVWWithZEXT);
+    Strategies.push_back(canFoldToVWWithSameExtZEXT);
     break;
   case RISCVISD::VWADD_W_VL:
   case RISCVISD::VWSUB_W_VL:

… combineOp_VLToVWOp_VL.

These instructions have one already narrow operand. Previously,
we pretended like this operand was a supported extension.

This could cause problems when we called getOrCreateExtendedOp on
this narrow operand when creating the the VWADD_VL. If the narrow
operand happened to be an extend of the opposite type, we would
peek through it and then rebuild it with the wrong extension
type. So (vwadd_w_vl (i32 (sext X)), (i16 (zext Y))) would become
(vwadd_vl (i16 (sext X)), (i16 (sext Y))).

To prevent this, we ignore the operand instead and pass std::nullopt
for SupportsExt to getOrCreateExtendedOp so it won't peek through any
extends on the narrow source.

Fixes llvm#159152.
@topperc topperc force-pushed the pr/combineOp_VLToVWOp_VL branch from ab40f44 to a156475 Compare September 16, 2025 22:45
; FOLDING-NEXT: ret
%a = sext <4 x i8> %x to <4 x i16>
%b = zext <4 x i16> %a to <4 x i32>
%c = add <4 x i32> %b, <i32 9, i32 9, i32 9, i32 9>
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The constant vector allows a RISCVISD::VWADDU_W_VL to be formed between LegalizeVectorOps and LegalizeDAG. LegalizeDAG will turn the build_vector into RISCVISD::VMV_V_X_VL. Then we will try to turn the VWADDU_W_VL into VWADD_VL. If we don't use a constant vector we'll go straight to VWADD_VL after LegalizeVectorOps.

Copy link
Contributor

@wangpc-pp wangpc-pp left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM.

I think this needs to be backported to llvm 21?

@topperc
Copy link
Collaborator Author

topperc commented Sep 17, 2025

LGTM.

I think this needs to be backported to llvm 21?

I think llvm 21 needs a fix, but I think I want this fix to be tested for a bit.

Subtarget);
}

/// Check if \p Root follows a pattern Root(ext(LHS), zext(RHS))
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Mixed up the comments here

Copy link
Contributor

@lukel97 lukel97 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@topperc topperc merged commit 6119d1f into llvm:main Sep 19, 2025
9 checks passed
@topperc topperc deleted the pr/combineOp_VLToVWOp_VL branch September 19, 2025 16:20
swift-ci pushed a commit to swiftlang/llvm-project that referenced this pull request Sep 23, 2025
… combineOp_VLToVWOp_VL. (llvm#159205)

These instructions have one already narrow operand. Previously, we
pretended like this operand was a supported extension.

This could cause problems when we called getOrCreateExtendedOp on this
narrow operand when creating the the VWADD_VL. If the narrow operand
happened to be an extend of the opposite type, we would peek through it
and then rebuild it with the wrong extension type. So (vwadd_w_vl (i32
(sext X)), (i16 (zext Y))) would become (vwadd_vl (i16 (sext X)), (i16
(sext Y))).

To prevent this, we ignore the operand instead and pass std::nullopt for
SupportsExt to getOrCreateExtendedOp so it won't peek through any
extends on the narrow source.

Fixes llvm#159152.

(cherry picked from commit 6119d1f)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

[RISC-V] Miscompile at -O3 with -flto

4 participants