@@ -47035,10 +47035,13 @@ static SDValue combineShiftRightArithmetic(SDNode *N, SelectionDAG &DAG,
47035
47035
if (SDValue V = combineShiftToPMULH(N, DAG, Subtarget))
47036
47036
return V;
47037
47037
47038
- // fold (ashr (shl, a, [56,48,32,24,16]), SarConst)
47039
- // into (shl, (sext (a), [56,48,32,24,16] - SarConst)) or
47040
- // into (lshr, (sext (a), SarConst - [56,48,32,24,16]))
47041
- // depending on sign of (SarConst - [56,48,32,24,16])
47038
+ // fold (SRA (SHL X, ShlConst), SraConst)
47039
+ // into (SHL (sext_in_reg X), ShlConst - SraConst)
47040
+ // or (sext_in_reg X)
47041
+ // or (SRA (sext_in_reg X), SraConst - ShlConst)
47042
+ // depending on relation between SraConst and ShlConst.
47043
+ // We only do this if (Size - ShlConst) is equal to 8, 16 or 32. That allows
47044
+ // us to do the sext_in_reg from corresponding bit.
47042
47045
47043
47046
// sexts in X86 are MOVs. The MOVs have the same code size
47044
47047
// as above SHIFTs (only SHIFT on 1 has lower code size).
@@ -47054,29 +47057,29 @@ static SDValue combineShiftRightArithmetic(SDNode *N, SelectionDAG &DAG,
47054
47057
SDValue N00 = N0.getOperand(0);
47055
47058
SDValue N01 = N0.getOperand(1);
47056
47059
APInt ShlConst = N01->getAsAPIntVal();
47057
- APInt SarConst = N1->getAsAPIntVal();
47060
+ APInt SraConst = N1->getAsAPIntVal();
47058
47061
EVT CVT = N1.getValueType();
47059
47062
47060
- if (SarConst.isNegative())
47063
+ if (CVT != N01.getValueType())
47064
+ return SDValue();
47065
+ if (SraConst.isNegative())
47061
47066
return SDValue();
47062
47067
47063
47068
for (MVT SVT : { MVT::i8, MVT::i16, MVT::i32 }) {
47064
47069
unsigned ShiftSize = SVT.getSizeInBits();
47065
- // skipping types without corresponding sext/zext and
47066
- // ShlConst that is not one of [56,48,32,24,16]
47070
+ // Only deal with (Size - ShlConst) being equal to 8, 16 or 32.
47067
47071
if (ShiftSize >= Size || ShlConst != Size - ShiftSize)
47068
47072
continue;
47069
47073
SDLoc DL(N);
47070
47074
SDValue NN =
47071
47075
DAG.getNode(ISD::SIGN_EXTEND_INREG, DL, VT, N00, DAG.getValueType(SVT));
47072
- SarConst = SarConst - (Size - ShiftSize);
47073
- if (SarConst == 0)
47076
+ if (SraConst.eq(ShlConst))
47074
47077
return NN;
47075
- if (SarConst.isNegative( ))
47078
+ if (SraConst.ult(ShlConst ))
47076
47079
return DAG.getNode(ISD::SHL, DL, VT, NN,
47077
- DAG.getConstant(-SarConst , DL, CVT));
47080
+ DAG.getConstant(ShlConst - SraConst , DL, CVT));
47078
47081
return DAG.getNode(ISD::SRA, DL, VT, NN,
47079
- DAG.getConstant(SarConst , DL, CVT));
47082
+ DAG.getConstant(SraConst - ShlConst , DL, CVT));
47080
47083
}
47081
47084
return SDValue();
47082
47085
}
0 commit comments