diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp index 3cd9ecb9dd681..e48ca4a905ce9 100644 --- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp +++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp @@ -13530,7 +13530,7 @@ struct CombineResult; enum ExtKind : uint8_t { ZExt = 1 << 0, SExt = 1 << 1, FPExt = 1 << 2 }; /// Helper class for folding sign/zero extensions. /// In particular, this class is used for the following combines: -/// add | add_vl -> vwadd(u) | vwadd(u)_w +/// add | add_vl | or disjoint -> vwadd(u) | vwadd(u)_w /// sub | sub_vl -> vwsub(u) | vwsub(u)_w /// mul | mul_vl -> vwmul(u) | vwmul_su /// fadd -> vfwadd | vfwadd_w @@ -13678,6 +13678,7 @@ struct NodeExtensionHelper { case RISCVISD::ADD_VL: case RISCVISD::VWADD_W_VL: case RISCVISD::VWADDU_W_VL: + case ISD::OR: return RISCVISD::VWADD_VL; case ISD::SUB: case RISCVISD::SUB_VL: @@ -13700,6 +13701,7 @@ struct NodeExtensionHelper { case RISCVISD::ADD_VL: case RISCVISD::VWADD_W_VL: case RISCVISD::VWADDU_W_VL: + case ISD::OR: return RISCVISD::VWADDU_VL; case ISD::SUB: case RISCVISD::SUB_VL: @@ -13745,6 +13747,7 @@ struct NodeExtensionHelper { switch (Opcode) { case ISD::ADD: case RISCVISD::ADD_VL: + case ISD::OR: return SupportsExt == ExtKind::SExt ? RISCVISD::VWADD_W_VL : RISCVISD::VWADDU_W_VL; case ISD::SUB: @@ -13865,6 +13868,10 @@ struct NodeExtensionHelper { case ISD::MUL: { return Root->getValueType(0).isScalableVector(); } + case ISD::OR: { + return Root->getValueType(0).isScalableVector() && + Root->getFlags().hasDisjoint(); + } // Vector Widening Integer Add/Sub/Mul Instructions case RISCVISD::ADD_VL: case RISCVISD::MUL_VL: @@ -13945,7 +13952,8 @@ struct NodeExtensionHelper { switch (Root->getOpcode()) { case ISD::ADD: case ISD::SUB: - case ISD::MUL: { + case ISD::MUL: + case ISD::OR: { SDLoc DL(Root); MVT VT = Root->getSimpleValueType(0); return getDefaultScalableVLOps(VT, DL, DAG, Subtarget); @@ -13968,6 +13976,7 @@ struct NodeExtensionHelper { switch (N->getOpcode()) { case ISD::ADD: case ISD::MUL: + case ISD::OR: case RISCVISD::ADD_VL: case RISCVISD::MUL_VL: case RISCVISD::VWADD_W_VL: @@ -14034,6 +14043,7 @@ struct CombineResult { case ISD::ADD: case ISD::SUB: case ISD::MUL: + case ISD::OR: Merge = DAG.getUNDEF(Root->getValueType(0)); break; } @@ -14184,6 +14194,7 @@ NodeExtensionHelper::getSupportedFoldings(const SDNode *Root) { switch (Root->getOpcode()) { case ISD::ADD: case ISD::SUB: + case ISD::OR: case RISCVISD::ADD_VL: case RISCVISD::SUB_VL: case RISCVISD::FADD_VL: @@ -14227,9 +14238,9 @@ NodeExtensionHelper::getSupportedFoldings(const SDNode *Root) { /// Combine a binary operation to its equivalent VW or VW_W form. /// The supported combines are: -/// add_vl -> vwadd(u) | vwadd(u)_w -/// sub_vl -> vwsub(u) | vwsub(u)_w -/// mul_vl -> vwmul(u) | vwmul_su +/// add | add_vl | or disjoint -> vwadd(u) | vwadd(u)_w +/// sub | sub_vl -> vwsub(u) | vwsub(u)_w +/// mul | mul_vl -> vwmul(u) | vwmul_su /// fadd_vl -> vfwadd | vfwadd_w /// fsub_vl -> vfwsub | vfwsub_w /// fmul_vl -> vfwmul @@ -15889,8 +15900,11 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N, } case ISD::AND: return performANDCombine(N, DCI, Subtarget); - case ISD::OR: + case ISD::OR: { + if (SDValue V = combineBinOp_VLToVWBinOp_VL(N, DCI, Subtarget)) + return V; return performORCombine(N, DCI, Subtarget); + } case ISD::XOR: return performXORCombine(N, DAG, Subtarget); case ISD::MUL: diff --git a/llvm/test/CodeGen/RISCV/rvv/vwadd-sdnode.ll b/llvm/test/CodeGen/RISCV/rvv/vwadd-sdnode.ll index ed12afdd95956..66e6883dd1d3e 100644 --- a/llvm/test/CodeGen/RISCV/rvv/vwadd-sdnode.ll +++ b/llvm/test/CodeGen/RISCV/rvv/vwadd-sdnode.ll @@ -1401,11 +1401,9 @@ define @vwaddu_vv_disjoint_or_add( %x.i8, %x.i8 to %x.shl = shl %x.i16, shufflevector( insertelement( poison, i16 8, i32 0), poison, zeroinitializer) @@ -1450,9 +1448,8 @@ define @vwadd_vv_disjoint_or( %x.i16, @vwaddu_wv_disjoint_or( %x.i32, %y.i16) { ; CHECK-LABEL: vwaddu_wv_disjoint_or: ; CHECK: # %bb.0: -; CHECK-NEXT: vsetvli a0, zero, e32, m1, ta, ma -; CHECK-NEXT: vzext.vf2 v10, v9 -; CHECK-NEXT: vor.vv v8, v8, v10 +; CHECK-NEXT: vsetvli a0, zero, e16, mf2, ta, ma +; CHECK-NEXT: vwaddu.wv v8, v8, v9 ; CHECK-NEXT: ret %y.i32 = zext %y.i16 to %or = or disjoint %x.i32, %y.i32 @@ -1462,9 +1459,8 @@ define @vwaddu_wv_disjoint_or( %x.i32, @vwadd_wv_disjoint_or( %x.i32, %y.i16) { ; CHECK-LABEL: vwadd_wv_disjoint_or: ; CHECK: # %bb.0: -; CHECK-NEXT: vsetvli a0, zero, e32, m1, ta, ma -; CHECK-NEXT: vsext.vf2 v10, v9 -; CHECK-NEXT: vor.vv v8, v8, v10 +; CHECK-NEXT: vsetvli a0, zero, e16, mf2, ta, ma +; CHECK-NEXT: vwadd.wv v8, v8, v9 ; CHECK-NEXT: ret %y.i32 = sext %y.i16 to %or = or disjoint %x.i32, %y.i32