Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 49 additions & 37 deletions llvm/lib/Target/RISCV/RISCVISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17325,18 +17325,9 @@ struct NodeExtensionHelper {
case RISCVISD::VWSUBU_W_VL:
case RISCVISD::VFWADD_W_VL:
case RISCVISD::VFWSUB_W_VL:
if (OperandIdx == 1) {
SupportsZExt =
Opc == RISCVISD::VWADDU_W_VL || Opc == RISCVISD::VWSUBU_W_VL;
SupportsSExt =
Opc == RISCVISD::VWADD_W_VL || Opc == RISCVISD::VWSUB_W_VL;
SupportsFPExt =
Opc == RISCVISD::VFWADD_W_VL || Opc == RISCVISD::VFWSUB_W_VL;
// There's no existing extension here, so we don't have to worry about
// making sure it gets removed.
EnforceOneUse = false;
// Operand 1 can't be changed.
if (OperandIdx == 1)
break;
}
[[fallthrough]];
default:
fillUpExtensionSupport(Root, DAG, Subtarget);
Expand Down Expand Up @@ -17374,20 +17365,20 @@ struct NodeExtensionHelper {
case RISCVISD::ADD_VL:
case RISCVISD::MUL_VL:
case RISCVISD::OR_VL:
case RISCVISD::VWADD_W_VL:
case RISCVISD::VWADDU_W_VL:
case RISCVISD::FADD_VL:
case RISCVISD::FMUL_VL:
case RISCVISD::VFWADD_W_VL:
case RISCVISD::VFMADD_VL:
case RISCVISD::VFNMSUB_VL:
case RISCVISD::VFNMADD_VL:
case RISCVISD::VFMSUB_VL:
return true;
case RISCVISD::VWADD_W_VL:
case RISCVISD::VWADDU_W_VL:
case ISD::SUB:
case RISCVISD::SUB_VL:
case RISCVISD::VWSUB_W_VL:
case RISCVISD::VWSUBU_W_VL:
case RISCVISD::VFWADD_W_VL:
case RISCVISD::FSUB_VL:
case RISCVISD::VFWSUB_W_VL:
case ISD::SHL:
Expand Down Expand Up @@ -17506,6 +17497,30 @@ canFoldToVWWithSameExtension(SDNode *Root, const NodeExtensionHelper &LHS,
Subtarget);
}

/// Check if \p Root follows a pattern Root(zext(LHS), zext(RHS))
///
/// \returns std::nullopt if the pattern doesn't match or a CombineResult that
/// can be used to apply the pattern.
static std::optional<CombineResult>
canFoldToVWWithSameExtZEXT(SDNode *Root, const NodeExtensionHelper &LHS,
const NodeExtensionHelper &RHS, SelectionDAG &DAG,
const RISCVSubtarget &Subtarget) {
return canFoldToVWWithSameExtensionImpl(Root, LHS, RHS, ExtKind::ZExt, DAG,
Subtarget);
}

/// Check if \p Root follows a pattern Root(bf16ext(LHS), bf16ext(RHS))
///
/// \returns std::nullopt if the pattern doesn't match or a CombineResult that
/// can be used to apply the pattern.
static std::optional<CombineResult>
canFoldToVWWithSameExtBF16(SDNode *Root, const NodeExtensionHelper &LHS,
const NodeExtensionHelper &RHS, SelectionDAG &DAG,
const RISCVSubtarget &Subtarget) {
return canFoldToVWWithSameExtensionImpl(Root, LHS, RHS, ExtKind::BF16Ext, DAG,
Subtarget);
}

/// Check if \p Root follows a pattern Root(LHS, ext(RHS))
///
/// \returns std::nullopt if the pattern doesn't match or a CombineResult that
Expand Down Expand Up @@ -17534,52 +17549,49 @@ canFoldToVW_W(SDNode *Root, const NodeExtensionHelper &LHS,
return std::nullopt;
}

/// Check if \p Root follows a pattern Root(sext(LHS), sext(RHS))
/// Check if \p Root follows a pattern Root(sext(LHS), RHS)
///
/// \returns std::nullopt if the pattern doesn't match or a CombineResult that
/// can be used to apply the pattern.
static std::optional<CombineResult>
canFoldToVWWithSEXT(SDNode *Root, const NodeExtensionHelper &LHS,
const NodeExtensionHelper &RHS, SelectionDAG &DAG,
const RISCVSubtarget &Subtarget) {
return canFoldToVWWithSameExtensionImpl(Root, LHS, RHS, ExtKind::SExt, DAG,
Subtarget);
if (LHS.SupportsSExt)
return CombineResult(NodeExtensionHelper::getSExtOpcode(Root->getOpcode()),
Root, LHS, /*LHSExt=*/{ExtKind::SExt}, RHS,
/*RHSExt=*/std::nullopt);
return std::nullopt;
}

/// Check if \p Root follows a pattern Root(zext(LHS), zext(RHS))
/// Check if \p Root follows a pattern Root(zext(LHS), RHS)
///
/// \returns std::nullopt if the pattern doesn't match or a CombineResult that
/// can be used to apply the pattern.
static std::optional<CombineResult>
canFoldToVWWithZEXT(SDNode *Root, const NodeExtensionHelper &LHS,
const NodeExtensionHelper &RHS, SelectionDAG &DAG,
const RISCVSubtarget &Subtarget) {
return canFoldToVWWithSameExtensionImpl(Root, LHS, RHS, ExtKind::ZExt, DAG,
Subtarget);
if (LHS.SupportsZExt)
return CombineResult(NodeExtensionHelper::getZExtOpcode(Root->getOpcode()),
Root, LHS, /*LHSExt=*/{ExtKind::ZExt}, RHS,
/*RHSExt=*/std::nullopt);
return std::nullopt;
}

/// Check if \p Root follows a pattern Root(fpext(LHS), fpext(RHS))
/// Check if \p Root follows a pattern Root(fpext(LHS), RHS)
///
/// \returns std::nullopt if the pattern doesn't match or a CombineResult that
/// can be used to apply the pattern.
static std::optional<CombineResult>
canFoldToVWWithFPEXT(SDNode *Root, const NodeExtensionHelper &LHS,
const NodeExtensionHelper &RHS, SelectionDAG &DAG,
const RISCVSubtarget &Subtarget) {
return canFoldToVWWithSameExtensionImpl(Root, LHS, RHS, ExtKind::FPExt, DAG,
Subtarget);
}

/// Check if \p Root follows a pattern Root(bf16ext(LHS), bf16ext(RHS))
///
/// \returns std::nullopt if the pattern doesn't match or a CombineResult that
/// can be used to apply the pattern.
static std::optional<CombineResult>
canFoldToVWWithBF16EXT(SDNode *Root, const NodeExtensionHelper &LHS,
const NodeExtensionHelper &RHS, SelectionDAG &DAG,
const RISCVSubtarget &Subtarget) {
return canFoldToVWWithSameExtensionImpl(Root, LHS, RHS, ExtKind::BF16Ext, DAG,
Subtarget);
if (LHS.SupportsFPExt)
return CombineResult(NodeExtensionHelper::getFPExtOpcode(Root->getOpcode()),
Root, LHS, /*LHSExt=*/{ExtKind::FPExt}, RHS,
/*RHSExt=*/std::nullopt);
return std::nullopt;
}

/// Check if \p Root follows a pattern Root(sext(LHS), zext(RHS))
Expand Down Expand Up @@ -17622,7 +17634,7 @@ NodeExtensionHelper::getSupportedFoldings(const SDNode *Root) {
case RISCVISD::VFNMSUB_VL:
Strategies.push_back(canFoldToVWWithSameExtension);
if (Root->getOpcode() == RISCVISD::VFMADD_VL)
Strategies.push_back(canFoldToVWWithBF16EXT);
Strategies.push_back(canFoldToVWWithSameExtBF16);
break;
case ISD::MUL:
case RISCVISD::MUL_VL:
Expand All @@ -17634,7 +17646,7 @@ NodeExtensionHelper::getSupportedFoldings(const SDNode *Root) {
case ISD::SHL:
case RISCVISD::SHL_VL:
// shl -> vwsll
Strategies.push_back(canFoldToVWWithZEXT);
Strategies.push_back(canFoldToVWWithSameExtZEXT);
break;
case RISCVISD::VWADD_W_VL:
case RISCVISD::VWSUB_W_VL:
Expand Down
23 changes: 23 additions & 0 deletions llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vw-web-simplification.ll
Original file line number Diff line number Diff line change
Expand Up @@ -58,3 +58,26 @@ define <2 x i16> @vwmul_v2i16_multiple_users(ptr %x, ptr %y, ptr %z) {
%i = or <2 x i16> %h, %g
ret <2 x i16> %i
}

; Make sure we have a vsext.vl and a vwaddu.vx.
define <4 x i32> @pr159152(<4 x i8> %x) {
; NO_FOLDING-LABEL: pr159152:
; NO_FOLDING: # %bb.0:
; NO_FOLDING-NEXT: vsetivli zero, 4, e16, mf2, ta, ma
; NO_FOLDING-NEXT: vsext.vf2 v9, v8
; NO_FOLDING-NEXT: li a0, 9
; NO_FOLDING-NEXT: vwaddu.vx v8, v9, a0
; NO_FOLDING-NEXT: ret
;
; FOLDING-LABEL: pr159152:
; FOLDING: # %bb.0:
; FOLDING-NEXT: vsetivli zero, 4, e16, mf2, ta, ma
; FOLDING-NEXT: vsext.vf2 v9, v8
; FOLDING-NEXT: li a0, 9
; FOLDING-NEXT: vwaddu.vx v8, v9, a0
; FOLDING-NEXT: ret
%a = sext <4 x i8> %x to <4 x i16>
%b = zext <4 x i16> %a to <4 x i32>
%c = add <4 x i32> %b, <i32 9, i32 9, i32 9, i32 9>
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The constant vector allows a RISCVISD::VWADDU_W_VL to be formed between LegalizeVectorOps and LegalizeDAG. LegalizeDAG will turn the build_vector into RISCVISD::VMV_V_X_VL. Then we will try to turn the VWADDU_W_VL into VWADD_VL. If we don't use a constant vector we'll go straight to VWADD_VL after LegalizeVectorOps.

ret <4 x i32> %c
}