@@ -47035,10 +47035,13 @@ static SDValue combineShiftRightArithmetic(SDNode *N, SelectionDAG &DAG,
4703547035  if (SDValue V = combineShiftToPMULH(N, DAG, Subtarget))
4703647036    return V;
4703747037
47038-   // fold (ashr (shl, a, [56,48,32,24,16]), SarConst)
47039-   // into (shl, (sext (a), [56,48,32,24,16] - SarConst)) or
47040-   // into (lshr, (sext (a), SarConst - [56,48,32,24,16]))
47041-   // depending on sign of (SarConst - [56,48,32,24,16])
47038+   // fold (SRA (SHL X, ShlConst), SraConst)
47039+   // into (SHL (sext_in_reg X), ShlConst - SraConst)
47040+   //   or (sext_in_reg X)
47041+   //   or (SRA (sext_in_reg X), SraConst - ShlConst)
47042+   // depending on relation between SraConst and ShlConst.
47043+   // We only do this if (Size - ShlConst) is equal to 8, 16 or 32. That allows
47044+   // us to do the sext_in_reg from corresponding bit.
4704247045
4704347046  // sexts in X86 are MOVs. The MOVs have the same code size
4704447047  // as above SHIFTs (only SHIFT on 1 has lower code size).
@@ -47054,29 +47057,29 @@ static SDValue combineShiftRightArithmetic(SDNode *N, SelectionDAG &DAG,
4705447057  SDValue N00 = N0.getOperand(0);
4705547058  SDValue N01 = N0.getOperand(1);
4705647059  APInt ShlConst = N01->getAsAPIntVal();
47057-   APInt SarConst  = N1->getAsAPIntVal();
47060+   APInt SraConst  = N1->getAsAPIntVal();
4705847061  EVT CVT = N1.getValueType();
4705947062
47060-   if (SarConst.isNegative())
47063+   if (CVT != N01.getValueType())
47064+     return SDValue();
47065+   if (SraConst.isNegative())
4706147066    return SDValue();
4706247067
4706347068  for (MVT SVT : { MVT::i8, MVT::i16, MVT::i32 }) {
4706447069    unsigned ShiftSize = SVT.getSizeInBits();
47065-     // skipping types without corresponding sext/zext and
47066-     // ShlConst that is not one of [56,48,32,24,16]
47070+     // Only deal with (Size - ShlConst) being equal to 8, 16 or 32.
4706747071    if (ShiftSize >= Size || ShlConst != Size - ShiftSize)
4706847072      continue;
4706947073    SDLoc DL(N);
4707047074    SDValue NN =
4707147075        DAG.getNode(ISD::SIGN_EXTEND_INREG, DL, VT, N00, DAG.getValueType(SVT));
47072-     SarConst = SarConst - (Size - ShiftSize);
47073-     if (SarConst == 0)
47076+     if (SraConst.eq(ShlConst))
4707447077      return NN;
47075-     if (SarConst.isNegative( ))
47078+     if (SraConst.ult(ShlConst ))
4707647079      return DAG.getNode(ISD::SHL, DL, VT, NN,
47077-                          DAG.getConstant(-SarConst , DL, CVT));
47080+                          DAG.getConstant(ShlConst - SraConst , DL, CVT));
4707847081    return DAG.getNode(ISD::SRA, DL, VT, NN,
47079-                        DAG.getConstant(SarConst , DL, CVT));
47082+                        DAG.getConstant(SraConst - ShlConst , DL, CVT));
4708047083  }
4708147084  return SDValue();
4708247085}
0 commit comments