@@ -1462,8 +1462,6 @@ void RISCVDAGToDAGISel::Select(SDNode *Node) {
14621462
14631463 const uint64_t C1 = N1C->getZExtValue ();
14641464
1465- // Turn (and (sra x, c2), c1) -> (srli (srai x, c2-c3), c3) if c1 is a mask
1466- // with c3 leading zeros and c2 is larger than c3.
14671465 if (N0.getOpcode () == ISD::SRA && isa<ConstantSDNode>(N0.getOperand (1 )) &&
14681466 N0.hasOneUse ()) {
14691467 unsigned C2 = N0.getConstantOperandVal (1 );
@@ -1477,6 +1475,8 @@ void RISCVDAGToDAGISel::Select(SDNode *Node) {
14771475 X.getOpcode () == ISD::SHL &&
14781476 isa<ConstantSDNode>(X.getOperand (1 )) &&
14791477 X.getConstantOperandVal (1 ) == 32 ;
1478+ // Turn (and (sra x, c2), c1) -> (srli (srai x, c2-c3), c3) if c1 is a
1479+ // mask with c3 leading zeros and c2 is larger than c3.
14801480 if (isMask_64 (C1) && !Skip) {
14811481 unsigned Leading = XLen - llvm::bit_width (C1);
14821482 if (C2 > Leading) {
@@ -1490,6 +1490,27 @@ void RISCVDAGToDAGISel::Select(SDNode *Node) {
14901490 return ;
14911491 }
14921492 }
1493+
1494+ // Look for (and (sra y, c2), c1) where c1 is a shifted mask with c3
1495+ // leading zeros and c4 trailing zeros. If c2 is greater than c3, we can
1496+ // use (slli (srli (srai y, c2 - c3), c3 + c4), c4).
1497+ if (isShiftedMask_64 (C1) && !Skip) {
1498+ unsigned Leading = XLen - llvm::bit_width (C1);
1499+ unsigned Trailing = llvm::countr_zero (C1);
1500+ if (C2 > Leading && Leading > 0 && Trailing > 0 ) {
1501+ SDNode *SRAI = CurDAG->getMachineNode (
1502+ RISCV::SRAI, DL, VT, N0.getOperand (0 ),
1503+ CurDAG->getTargetConstant (C2 - Leading, DL, VT));
1504+ SDNode *SRLI = CurDAG->getMachineNode (
1505+ RISCV::SRLI, DL, VT, SDValue (SRAI, 0 ),
1506+ CurDAG->getTargetConstant (Leading + Trailing, DL, VT));
1507+ SDNode *SLLI = CurDAG->getMachineNode (
1508+ RISCV::SLLI, DL, VT, SDValue (SRLI, 0 ),
1509+ CurDAG->getTargetConstant (Trailing, DL, VT));
1510+ ReplaceNode (Node, SLLI);
1511+ return ;
1512+ }
1513+ }
14931514 }
14941515
14951516 // If C1 masks off the upper bits only (but can't be formed as an
@@ -3032,6 +3053,33 @@ bool RISCVDAGToDAGISel::selectSHXADDOp(SDValue N, unsigned ShAmt,
30323053 return true ;
30333054 }
30343055 }
3056+ } else if (N0.getOpcode () == ISD::SRA && N0.hasOneUse () &&
3057+ isa<ConstantSDNode>(N.getOperand (1 ))) {
3058+ uint64_t Mask = N.getConstantOperandVal (1 );
3059+ unsigned C2 = N0.getConstantOperandVal (1 );
3060+
3061+ // Look for (and (sra y, c2), c1) where c1 is a shifted mask with c3
3062+ // leading zeros and c4 trailing zeros. If c2 is greater than c3, we can
3063+ // use (srli (srai y, c2 - c3), c3 + c4) followed by a SHXADD with c4 as
3064+ // the X amount.
3065+ if (isShiftedMask_64 (Mask)) {
3066+ unsigned XLen = Subtarget->getXLen ();
3067+ unsigned Leading = XLen - llvm::bit_width (Mask);
3068+ unsigned Trailing = llvm::countr_zero (Mask);
3069+ if (C2 > Leading && Leading > 0 && Trailing == ShAmt) {
3070+ SDLoc DL (N);
3071+ EVT VT = N.getValueType ();
3072+ Val = SDValue (CurDAG->getMachineNode (
3073+ RISCV::SRAI, DL, VT, N0.getOperand (0 ),
3074+ CurDAG->getTargetConstant (C2 - Leading, DL, VT)),
3075+ 0 );
3076+ Val = SDValue (CurDAG->getMachineNode (
3077+ RISCV::SRLI, DL, VT, Val,
3078+ CurDAG->getTargetConstant (Leading + ShAmt, DL, VT)),
3079+ 0 );
3080+ return true ;
3081+ }
3082+ }
30353083 }
30363084 } else if (bool LeftShift = N.getOpcode () == ISD::SHL;
30373085 (LeftShift || N.getOpcode () == ISD::SRL) &&
0 commit comments