@@ -8585,68 +8585,43 @@ static SDValue performPostLD1Combine(SDNode *N,
8585
8585
return SDValue ();
8586
8586
}
8587
8587
8588
- // / Target-specific DAG combine for the across vector reduction.
8589
- // / This function specifically handles the final clean-up step of a vector
8590
- // / reduction produced by the LoopVectorizer. It is the log2-shuffle pattern,
8591
- // / consisting of log2(NumVectorElements) steps and, in each step, 2^(s)
8592
- // / elements are reduced, where s is an induction variable from 0
8593
- // / to log2(NumVectorElements).
8594
- // / For example,
8595
- // / %1 = vector_shuffle %0, <2,3,u,u>
8596
- // / %2 = add %0, %1
8597
- // / %3 = vector_shuffle %2, <1,u,u,u>
8598
- // / %4 = add %2, %3
8599
- // / %5 = extract_vector_elt %4, 0
8600
- // / becomes :
8601
- // / %0 = uaddv %0
8602
- // / %1 = extract_vector_elt %0, 0
8603
- // /
8604
- // / FIXME: Currently this function is implemented and tested specifically
8605
- // / for the add reduction. We could also support other types of across lane
8606
- // / reduction available in AArch64, including SMAXV, SMINV, UMAXV, UMINV,
8607
- // / SADDLV, UADDLV, FMAXNMV, FMAXV, FMINNMV, FMINV.
8608
- static SDValue
8609
- performAcrossLaneReductionCombine (SDNode *N, SelectionDAG &DAG,
8610
- const AArch64Subtarget *Subtarget) {
8611
- if (!Subtarget->hasNEON ())
8588
+ // / This function handles the log2-shuffle pattern produced by the
8589
+ // / LoopVectorizer for the across vector reduction. It consists of
8590
+ // / log2(NumVectorElements) steps and, in each step, 2^(s) elements
8591
+ // / are reduced, where s is an induction variable from 0 to
8592
+ // / log2(NumVectorElements).
8593
+ static SDValue tryMatchAcrossLaneShuffleForReduction (SDNode *N, SDValue OpV,
8594
+ unsigned Op,
8595
+ SelectionDAG &DAG) {
8596
+ EVT VTy = OpV->getOperand (0 ).getValueType ();
8597
+ if (!VTy.isVector ())
8612
8598
return SDValue ();
8613
- SDValue N0 = N->getOperand (0 );
8614
- SDValue N1 = N->getOperand (1 );
8615
8599
8616
- // Check if the input vector is fed by the operator we want to handle.
8617
- // We specifically check only ADD for now.
8618
- if (N0->getOpcode () != ISD::ADD)
8619
- return SDValue ();
8620
-
8621
- // The vector extract idx must constant zero because we only expect the final
8622
- // result of the reduction is placed in lane 0.
8623
- if (!isa<ConstantSDNode>(N1) || cast<ConstantSDNode>(N1)->getZExtValue ())
8624
- return SDValue ();
8625
-
8626
- EVT EltTy = N0.getValueType ().getVectorElementType ();
8627
- if (EltTy != MVT::i32 && EltTy != MVT::i16 && EltTy != MVT::i8)
8628
- return SDValue ();
8629
-
8630
- int NumVecElts = N0.getValueType ().getVectorNumElements ();
8600
+ int NumVecElts = VTy.getVectorNumElements ();
8631
8601
if (NumVecElts != 4 && NumVecElts != 8 && NumVecElts != 16 )
8632
8602
return SDValue ();
8633
8603
8634
8604
int NumExpectedSteps = APInt (8 , NumVecElts).logBase2 ();
8635
- SDValue PreOp = N0 ;
8605
+ SDValue PreOp = OpV ;
8636
8606
// Iterate over each step of the across vector reduction.
8637
8607
for (int CurStep = 0 ; CurStep != NumExpectedSteps; ++CurStep) {
8638
- // We specifically check ADD for now.
8639
- if (PreOp.getOpcode () != ISD::ADD)
8640
- return SDValue ();
8641
8608
SDValue CurOp = PreOp.getOperand (0 );
8642
8609
SDValue Shuffle = PreOp.getOperand (1 );
8643
8610
if (Shuffle.getOpcode () != ISD::VECTOR_SHUFFLE) {
8644
- // Try to swap the 1st and 2nd operand as add is commutative.
8611
+ // Try to swap the 1st and 2nd operand as add and min/max instructions
8612
+ // are commutative.
8645
8613
CurOp = PreOp.getOperand (1 );
8646
8614
Shuffle = PreOp.getOperand (0 );
8647
8615
if (Shuffle.getOpcode () != ISD::VECTOR_SHUFFLE)
8648
8616
return SDValue ();
8649
8617
}
8618
+
8619
+ // Check if the input vector is fed by the operator we want to handle,
8620
+ // except the last step; the very first input vector is not necessarily
8621
+ // the same operator we are handling.
8622
+ if (CurOp.getOpcode () != Op && (CurStep != (NumExpectedSteps - 1 )))
8623
+ return SDValue ();
8624
+
8650
8625
// Check if it forms one step of the across vector reduction.
8651
8626
// E.g.,
8652
8627
// %cur = add %1, %0
@@ -8674,11 +8649,169 @@ performAcrossLaneReductionCombine(SDNode *N, SelectionDAG &DAG,
8674
8649
8675
8650
PreOp = CurOp;
8676
8651
}
8652
+ unsigned Opcode;
8653
+ switch (Op) {
8654
+ default :
8655
+ llvm_unreachable (" Unexpected operator for across vector reduction" );
8656
+ case ISD::ADD:
8657
+ Opcode = AArch64ISD::UADDV;
8658
+ break ;
8659
+ case ISD::SMAX:
8660
+ Opcode = AArch64ISD::SMAXV;
8661
+ break ;
8662
+ case ISD::UMAX:
8663
+ Opcode = AArch64ISD::UMAXV;
8664
+ break ;
8665
+ case ISD::SMIN:
8666
+ Opcode = AArch64ISD::SMINV;
8667
+ break ;
8668
+ case ISD::UMIN:
8669
+ Opcode = AArch64ISD::UMINV;
8670
+ break ;
8671
+ }
8677
8672
SDLoc DL (N);
8678
- return DAG.getNode (
8679
- ISD::EXTRACT_VECTOR_ELT, DL, N->getValueType (0 ),
8680
- DAG.getNode (AArch64ISD::UADDV, DL, PreOp.getSimpleValueType (), PreOp),
8681
- DAG.getConstant (0 , DL, MVT::i64));
8673
+ return DAG.getNode (ISD::EXTRACT_VECTOR_ELT, DL, N->getValueType (0 ),
8674
+ DAG.getNode (Opcode, DL, PreOp.getSimpleValueType (), PreOp),
8675
+ DAG.getConstant (0 , DL, MVT::i64));
8676
+ }
8677
+
8678
+ // / Target-specific DAG combine for the across vector min/max reductions.
8679
+ // / This function specifically handles the final clean-up step of the vector
8680
+ // / min/max reductions produced by the LoopVectorizer. It is the log2-shuffle
8681
+ // / pattern, which narrows down and finds the final min/max value from all
8682
+ // / elements of the vector.
8683
+ // / For example, for a <16 x i8> vector :
8684
+ // / %svn0 = vector_shuffle %0, undef<8,9,10,11,12,13,14,15,u,u,u,u,u,u,u,u>
8685
+ // / %smax0 = smax %0, %svn0
8686
+ // / %svn1 = vector_shuffle %smax0, undef<4,5,6,7,u,u,u,u,u,u,u,u,u,u,u,u>
8687
+ // / %smax1 = smax %smax0, %svn1
8688
+ // / %svn2 = vector_shuffle %smax1, undef<2,3,u,u,u,u,u,u,u,u,u,u,u,u,u,u>
8689
+ // / %smax2 = smax %smax1, %svn2
8690
+ // / %svn3 = vector_shuffle %smax2, undef<1,u,u,u,u,u,u,u,u,u,u,u,u,u,u,u>
8691
+ // / %sc = setcc %smax2, %svn3, gt
8692
+ // / %n0 = extract_vector_elt %sc, #0
8693
+ // / %n1 = extract_vector_elt %smax2, #0
8694
+ // / %n2 = extract_vector_elt %smax2, #1
8695
+ // / %result = select %n0, %n1, %n2
8696
+ // / becomes :
8697
+ // / %1 = smaxv %0
8698
+ // / %result = extract_vector_elt %1, 0
8699
+ // / FIXME: Currently this function matches only SMAXV, UMAXV, SMINV, and UMINV.
8700
+ // / We could also support other types of across lane reduction available
8701
+ // / in AArch64, including FMAXNMV, FMAXV, FMINNMV, and FMINV.
8702
+ static SDValue
8703
+ performAcrossLaneMinMaxReductionCombine (SDNode *N, SelectionDAG &DAG,
8704
+ const AArch64Subtarget *Subtarget) {
8705
+ if (!Subtarget->hasNEON ())
8706
+ return SDValue ();
8707
+
8708
+ SDValue N0 = N->getOperand (0 );
8709
+ SDValue IfTrue = N->getOperand (1 );
8710
+ SDValue IfFalse = N->getOperand (2 );
8711
+
8712
+ // Check if the SELECT merges up the final result of the min/max
8713
+ // from a vector.
8714
+ if (N0.getOpcode () != ISD::EXTRACT_VECTOR_ELT ||
8715
+ IfTrue.getOpcode () != ISD::EXTRACT_VECTOR_ELT ||
8716
+ IfFalse.getOpcode () != ISD::EXTRACT_VECTOR_ELT)
8717
+ return SDValue ();
8718
+
8719
+ // Expect N0 is fed by SETCC.
8720
+ SDValue SetCC = N0.getOperand (0 );
8721
+ EVT SetCCVT = SetCC.getValueType ();
8722
+ if (SetCC.getOpcode () != ISD::SETCC || !SetCCVT.isVector () ||
8723
+ SetCCVT.getVectorElementType () != MVT::i1)
8724
+ return SDValue ();
8725
+
8726
+ SDValue VectorOp = SetCC.getOperand (0 );
8727
+ unsigned Op = VectorOp->getOpcode ();
8728
+ // Check if the input vector is fed by the operator we want to handle.
8729
+ if (Op != ISD::SMAX && Op != ISD::UMAX && Op != ISD::SMIN && Op != ISD::UMIN)
8730
+ return SDValue ();
8731
+
8732
+ EVT VTy = VectorOp.getValueType ();
8733
+ if (!VTy.isVector ())
8734
+ return SDValue ();
8735
+
8736
+ EVT EltTy = VTy.getVectorElementType ();
8737
+ if (EltTy != MVT::i32 && EltTy != MVT::i16 && EltTy != MVT::i8)
8738
+ return SDValue ();
8739
+
8740
+ // Check if extracting from the same vector.
8741
+ // For example,
8742
+ // %sc = setcc %vector, %svn1, gt
8743
+ // %n0 = extract_vector_elt %sc, #0
8744
+ // %n1 = extract_vector_elt %vector, #0
8745
+ // %n2 = extract_vector_elt %vector, #1
8746
+ if (!(VectorOp == IfTrue->getOperand (0 ) &&
8747
+ VectorOp == IfFalse->getOperand (0 )))
8748
+ return SDValue ();
8749
+
8750
+ // Check if the condition code is matched with the operator type.
8751
+ ISD::CondCode CC = cast<CondCodeSDNode>(SetCC->getOperand (2 ))->get ();
8752
+ if ((Op == ISD::SMAX && CC != ISD::SETGT && CC != ISD::SETGE) ||
8753
+ (Op == ISD::UMAX && CC != ISD::SETUGT && CC != ISD::SETUGE) ||
8754
+ (Op == ISD::SMIN && CC != ISD::SETLT && CC != ISD::SETLE) ||
8755
+ (Op == ISD::UMIN && CC != ISD::SETULT && CC != ISD::SETULE))
8756
+ return SDValue ();
8757
+
8758
+ // Expect to check only lane 0 from the vector SETCC.
8759
+ if (!isa<ConstantSDNode>(N0.getOperand (1 )) ||
8760
+ cast<ConstantSDNode>(N0.getOperand (1 ))->getZExtValue () != 0 )
8761
+ return SDValue ();
8762
+
8763
+ // Expect to extract the true value from lane 0.
8764
+ if (!isa<ConstantSDNode>(IfTrue.getOperand (1 )) ||
8765
+ cast<ConstantSDNode>(IfTrue.getOperand (1 ))->getZExtValue () != 0 )
8766
+ return SDValue ();
8767
+
8768
+ // Expect to extract the false value from lane 1.
8769
+ if (!isa<ConstantSDNode>(IfFalse.getOperand (1 )) ||
8770
+ cast<ConstantSDNode>(IfFalse.getOperand (1 ))->getZExtValue () != 1 )
8771
+ return SDValue ();
8772
+
8773
+ return tryMatchAcrossLaneShuffleForReduction (N, SetCC, Op, DAG);
8774
+ }
8775
+
8776
+ // / Target-specific DAG combine for the across vector add reduction.
8777
+ // / This function specifically handles the final clean-up step of the vector
8778
+ // / add reduction produced by the LoopVectorizer. It is the log2-shuffle
8779
+ // / pattern, which adds all elements of a vector together.
8780
+ // / For example, for a <4 x i32> vector :
8781
+ // / %1 = vector_shuffle %0, <2,3,u,u>
8782
+ // / %2 = add %0, %1
8783
+ // / %3 = vector_shuffle %2, <1,u,u,u>
8784
+ // / %4 = add %2, %3
8785
+ // / %result = extract_vector_elt %4, 0
8786
+ // / becomes :
8787
+ // / %0 = uaddv %0
8788
+ // / %result = extract_vector_elt %0, 0
8789
+ static SDValue
8790
+ performAcrossLaneAddReductionCombine (SDNode *N, SelectionDAG &DAG,
8791
+ const AArch64Subtarget *Subtarget) {
8792
+ if (!Subtarget->hasNEON ())
8793
+ return SDValue ();
8794
+ SDValue N0 = N->getOperand (0 );
8795
+ SDValue N1 = N->getOperand (1 );
8796
+
8797
+ // Check if the input vector is fed by the ADD.
8798
+ if (N0->getOpcode () != ISD::ADD)
8799
+ return SDValue ();
8800
+
8801
+ // The vector extract idx must be constant zero because we only expect the final
8802
+ // result of the reduction is placed in lane 0.
8803
+ if (!isa<ConstantSDNode>(N1) || cast<ConstantSDNode>(N1)->getZExtValue () != 0 )
8804
+ return SDValue ();
8805
+
8806
+ EVT VTy = N0.getValueType ();
8807
+ if (!VTy.isVector ())
8808
+ return SDValue ();
8809
+
8810
+ EVT EltTy = VTy.getVectorElementType ();
8811
+ if (EltTy != MVT::i32 && EltTy != MVT::i16 && EltTy != MVT::i8)
8812
+ return SDValue ();
8813
+
8814
+ return tryMatchAcrossLaneShuffleForReduction (N, N0, ISD::ADD, DAG);
8682
8815
}
8683
8816
8684
8817
// / Target-specific DAG combine function for NEON load/store intrinsics
@@ -9259,8 +9392,12 @@ SDValue AArch64TargetLowering::PerformDAGCombine(SDNode *N,
9259
9392
return performBitcastCombine (N, DCI, DAG);
9260
9393
case ISD::CONCAT_VECTORS:
9261
9394
return performConcatVectorsCombine (N, DCI, DAG);
9262
- case ISD::SELECT:
9263
- return performSelectCombine (N, DCI);
9395
+ case ISD::SELECT: {
9396
+ SDValue RV = performSelectCombine (N, DCI);
9397
+ if (!RV.getNode ())
9398
+ RV = performAcrossLaneMinMaxReductionCombine (N, DAG, Subtarget);
9399
+ return RV;
9400
+ }
9264
9401
case ISD::VSELECT:
9265
9402
return performVSelectCombine (N, DCI.DAG );
9266
9403
case ISD::STORE:
@@ -9276,7 +9413,7 @@ SDValue AArch64TargetLowering::PerformDAGCombine(SDNode *N,
9276
9413
case ISD::INSERT_VECTOR_ELT:
9277
9414
return performPostLD1Combine (N, DCI, true );
9278
9415
case ISD::EXTRACT_VECTOR_ELT:
9279
- return performAcrossLaneReductionCombine (N, DAG, Subtarget);
9416
+ return performAcrossLaneAddReductionCombine (N, DAG, Subtarget);
9280
9417
case ISD::INTRINSIC_VOID:
9281
9418
case ISD::INTRINSIC_W_CHAIN:
9282
9419
switch (cast<ConstantSDNode>(N->getOperand (1 ))->getZExtValue ()) {
0 commit comments