Skip to content

Conversation

kmclaughlin-arm
Copy link
Contributor

@kmclaughlin-arm kmclaughlin-arm commented Sep 17, 2025

The combine replaces a get_active_lane_mask used by two extract subvectors with
a single paired whilelo intrinsic. When the instruction is used for control
flow in a vector loop, an additional extract of element 0 may introduce
other uses of the intrinsic such as ptest and reinterpret cast, which
is currently not supported.

This patch changes performActiveLaneMaskCombine to count the number of
extract subvectors using the mask instead of the total number of uses,
and returns the concatenated results of get_active_lane_mask.

@llvmbot
Copy link
Member

llvmbot commented Sep 17, 2025

@llvm/pr-subscribers-llvm-support

@llvm/pr-subscribers-backend-aarch64

Author: Kerry McLaughlin (kmclaughlin-arm)

Changes

The combine replaces a get_active_lane_mask used by two extract subvectors with
a single paired whilelo intrinsic. When the instruction is used for control
flow in a vector loop, an additional extract of element 0 may introduce
other uses of the intrinsic such as ptest and reinterpret cast, which
is currently not supported.

This patch changes performActiveLaneMaskCombine to count the number of
extract subvectors using the mask instead of the total number of uses,
and allows other uses by these additional operations.


Full diff: https://github.com/llvm/llvm-project/pull/159360.diff

3 Files Affected:

  • (modified) llvm/lib/Target/AArch64/AArch64ISelLowering.cpp (+21-8)
  • (modified) llvm/lib/Target/AArch64/AArch64InstrInfo.cpp (+13-7)
  • (modified) llvm/test/CodeGen/AArch64/get-active-lane-mask-extract.ll (+75)
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index c9a756da0078d..9c7ecf944e763 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -18693,21 +18693,31 @@ performActiveLaneMaskCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI,
       (!ST->hasSVE2p1() && !(ST->hasSME2() && ST->isStreaming())))
     return SDValue();
 
-  unsigned NumUses = N->use_size();
+  // Count the number of users which are extract_vectors
+  // The only other valid users for this combine are ptest_first
+  // and reinterpret_cast.
+  unsigned NumExts = count_if(N->users(), [](SDNode *Use) {
+    return Use->getOpcode() == ISD::EXTRACT_SUBVECTOR;
+  });
+
   auto MaskEC = N->getValueType(0).getVectorElementCount();
-  if (!MaskEC.isKnownMultipleOf(NumUses))
+  if (!MaskEC.isKnownMultipleOf(NumExts))
     return SDValue();
 
-  ElementCount ExtMinEC = MaskEC.divideCoefficientBy(NumUses);
+  ElementCount ExtMinEC = MaskEC.divideCoefficientBy(NumExts);
   if (ExtMinEC.getKnownMinValue() < 2)
     return SDValue();
 
-  SmallVector<SDNode *> Extracts(NumUses, nullptr);
+  SmallVector<SDNode *> Extracts(NumExts, nullptr);
   for (SDNode *Use : N->users()) {
+    if (Use->getOpcode() == AArch64ISD::PTEST_FIRST ||
+        Use->getOpcode() == AArch64ISD::REINTERPRET_CAST)
+      continue;
+
     if (Use->getOpcode() != ISD::EXTRACT_SUBVECTOR)
       return SDValue();
 
-    // Ensure the extract type is correct (e.g. if NumUses is 4 and
+    // Ensure the extract type is correct (e.g. if NumExts is 4 and
     // the mask return type is nxv8i1, each extract should be nxv2i1.
     if (Use->getValueType(0).getVectorElementCount() != ExtMinEC)
       return SDValue();
@@ -18741,11 +18751,13 @@ performActiveLaneMaskCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI,
   DCI.CombineTo(Extracts[0], R.getValue(0));
   DCI.CombineTo(Extracts[1], R.getValue(1));
 
-  if (NumUses == 2)
-    return SDValue(N, 0);
+  if (NumExts == 2) {
+    DAG.ReplaceAllUsesOfValueWith(SDValue(N, 0), R.getValue(0));
+    return SDValue(SDValue(N, 0));
+  }
 
   auto Elts = DAG.getElementCount(DL, OpVT, ExtVT.getVectorElementCount() * 2);
-  for (unsigned I = 2; I < NumUses; I += 2) {
+  for (unsigned I = 2; I < NumExts; I += 2) {
     // After the first whilelo_x2, we need to increment the starting value.
     Idx = DAG.getNode(ISD::UADDSAT, DL, OpVT, Idx, Elts);
     R = DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, {ExtVT, ExtVT}, {ID, Idx, TC});
@@ -18753,6 +18765,7 @@ performActiveLaneMaskCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI,
     DCI.CombineTo(Extracts[I + 1], R.getValue(1));
   }
 
+  DAG.ReplaceAllUsesOfValueWith(SDValue(N, 0), R.getValue(0));
   return SDValue(N, 0);
 }
 
diff --git a/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp b/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp
index bf3d47ac43607..069d08663fdea 100644
--- a/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp
@@ -1495,13 +1495,19 @@ AArch64InstrInfo::canRemovePTestInstr(MachineInstr *PTest, MachineInstr *Mask,
     if ((Mask == Pred) && PTest->getOpcode() == AArch64::PTEST_PP_ANY)
       return PredOpcode;
 
-    // For PTEST(PTRUE_ALL, WHILE), if the element size matches, the PTEST is
-    // redundant since WHILE performs an implicit PTEST with an all active
-    // mask.
-    if (isPTrueOpcode(MaskOpcode) && Mask->getOperand(1).getImm() == 31 &&
-        getElementSizeForOpcode(MaskOpcode) ==
-            getElementSizeForOpcode(PredOpcode))
-      return PredOpcode;
+    if (isPTrueOpcode(MaskOpcode) && Mask->getOperand(1).getImm() == 31) {
+      auto PTestOp = MRI->getUniqueVRegDef(PTest->getOperand(1).getReg());
+      if (PTest->getOpcode() == AArch64::PTEST_PP_FIRST && PTestOp->isCopy() &&
+          PTestOp->getOperand(1).getSubReg() == AArch64::psub0)
+        return PredOpcode;
+
+      // For PTEST(PTRUE_ALL, WHILE), if the element size matches, the PTEST is
+      // redundant since WHILE performs an implicit PTEST with an all active
+      // mask.
+      if (getElementSizeForOpcode(MaskOpcode) ==
+          getElementSizeForOpcode(PredOpcode))
+        return PredOpcode;
+    }
 
     return {};
   }
diff --git a/llvm/test/CodeGen/AArch64/get-active-lane-mask-extract.ll b/llvm/test/CodeGen/AArch64/get-active-lane-mask-extract.ll
index 5e01612e3881a..3b18008605413 100644
--- a/llvm/test/CodeGen/AArch64/get-active-lane-mask-extract.ll
+++ b/llvm/test/CodeGen/AArch64/get-active-lane-mask-extract.ll
@@ -310,6 +310,81 @@ define void @test_2x32bit_mask_with_32bit_index_and_trip_count(i32 %i, i32 %n) #
   ret void
 }
 
+; Extra use of the get_active_lane_mask from an extractelement, which is replaced with ptest_first.
+
+define void @test_2x8bit_mask_with_extracts_and_ptest(i64 %i, i64 %n) {
+; CHECK-SVE-LABEL: test_2x8bit_mask_with_extracts_and_ptest:
+; CHECK-SVE:       // %bb.0: // %entry
+; CHECK-SVE-NEXT:    whilelo p1.b, x0, x1
+; CHECK-SVE-NEXT:    b.pl .LBB11_2
+; CHECK-SVE-NEXT:  // %bb.1: // %if.then
+; CHECK-SVE-NEXT:    punpklo p0.h, p1.b
+; CHECK-SVE-NEXT:    punpkhi p1.h, p1.b
+; CHECK-SVE-NEXT:    b use
+; CHECK-SVE-NEXT:  .LBB11_2: // %if.end
+; CHECK-SVE-NEXT:    ret
+;
+; CHECK-SVE2p1-SME2-LABEL: test_2x8bit_mask_with_extracts_and_ptest:
+; CHECK-SVE2p1-SME2:       // %bb.0: // %entry
+; CHECK-SVE2p1-SME2-NEXT:    whilelo { p0.h, p1.h }, x0, x1
+; CHECK-SVE2p1-SME2-NEXT:    b.pl .LBB11_2
+; CHECK-SVE2p1-SME2-NEXT:  // %bb.1: // %if.then
+; CHECK-SVE2p1-SME2-NEXT:    b use
+; CHECK-SVE2p1-SME2-NEXT:  .LBB11_2: // %if.end
+; CHECK-SVE2p1-SME2-NEXT:    ret
+entry:
+    %r = call <vscale x 16 x i1> @llvm.get.active.lane.mask.nxv16i1.i32(i64 %i, i64 %n)
+    %v0 = call <vscale x 8 x i1> @llvm.vector.extract.nxv8i1.nxv16i1.i64(<vscale x 16 x i1> %r, i64 0)
+    %v1 = call <vscale x 8 x i1> @llvm.vector.extract.nxv8i1.nxv16i1.i64(<vscale x 16 x i1> %r, i64 8)
+    %elt0 = extractelement <vscale x 16 x i1> %r, i32 0
+    br i1 %elt0, label %if.then, label %if.end
+
+if.then:
+    tail call void @use(<vscale x 8 x i1> %v0, <vscale x 8 x i1> %v1)
+    br label %if.end
+
+if.end:
+    ret void
+}
+
+; Extra use of the get_active_lane_mask from an extractelement, which is
+; replaced with ptest_first and reinterpret_casts because the extract is not nxv16i1.
+
+define void @test_2x8bit_mask_with_extracts_and_reinterpret_casts(i64 %i, i64 %n) {
+; CHECK-SVE-LABEL: test_2x8bit_mask_with_extracts_and_reinterpret_casts:
+; CHECK-SVE:       // %bb.0: // %entry
+; CHECK-SVE-NEXT:    whilelo p1.h, x0, x1
+; CHECK-SVE-NEXT:    b.pl .LBB12_2
+; CHECK-SVE-NEXT:  // %bb.1: // %if.then
+; CHECK-SVE-NEXT:    punpklo p0.h, p1.b
+; CHECK-SVE-NEXT:    punpkhi p1.h, p1.b
+; CHECK-SVE-NEXT:    b use
+; CHECK-SVE-NEXT:  .LBB12_2: // %if.end
+; CHECK-SVE-NEXT:    ret
+;
+; CHECK-SVE2p1-SME2-LABEL: test_2x8bit_mask_with_extracts_and_reinterpret_casts:
+; CHECK-SVE2p1-SME2:       // %bb.0: // %entry
+; CHECK-SVE2p1-SME2-NEXT:    whilelo { p0.s, p1.s }, x0, x1
+; CHECK-SVE2p1-SME2-NEXT:    b.pl .LBB12_2
+; CHECK-SVE2p1-SME2-NEXT:  // %bb.1: // %if.then
+; CHECK-SVE2p1-SME2-NEXT:    b use
+; CHECK-SVE2p1-SME2-NEXT:  .LBB12_2: // %if.end
+; CHECK-SVE2p1-SME2-NEXT:    ret
+entry:
+    %r = call <vscale x 8 x i1> @llvm.get.active.lane.mask.nxv8i1.i64(i64 %i, i64 %n)
+    %v0 = tail call <vscale x 4 x i1> @llvm.vector.extract.nxv4i1.nxv8i1(<vscale x 8 x i1> %r, i64 0)
+    %v1 = tail call <vscale x 4 x i1> @llvm.vector.extract.nxv4i1.nxv8i1(<vscale x 8 x i1> %r, i64 4)
+    %elt0 = extractelement <vscale x 8 x i1> %r, i64 0
+    br i1 %elt0, label %if.then, label %if.end
+
+if.then:
+    tail call void @use(<vscale x 4 x i1> %v0, <vscale x 4 x i1> %v1)
+    br label %if.end
+
+if.end:
+    ret void
+}
+
 declare void @use(...)
 
 attributes #0 = { nounwind }

// After the first whilelo_x2, we need to increment the starting value.
Idx = DAG.getNode(ISD::UADDSAT, DL, OpVT, Idx, Elts);
R = DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, {ExtVT, ExtVT}, {ID, Idx, TC});
DCI.CombineTo(Extracts[I], R.getValue(0));
DCI.CombineTo(Extracts[I + 1], R.getValue(1));
}

DAG.ReplaceAllUsesOfValueWith(SDValue(N, 0), R.getValue(0));
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Similar to the above but in this case there's the extra problem that you're replacing N with the first result of the last instance of emitted while_pair, which is even more wrong if the uses happened to be a PTEST_FIRST.

Comment on lines 18805 to 18807
if (Use->getOpcode() == AArch64ISD::PTEST_FIRST ||
Use->getOpcode() == AArch64ISD::REINTERPRET_CAST)
continue;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think it's worth trying to special case based on the users beyond verifying the presence of the relevant ISD::EXTRACT_SUBVECTOR to prove the value of using the while_pair instructions. Even if the original while remains, the resulting code might be better because multiple extracts have been replaced by a single while_pair?

if (NumUses == 2)
return SDValue(N, 0);
if (NumExts == 2) {
DAG.ReplaceAllUsesOfValueWith(SDValue(N, 0), R.getValue(0));
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure if I'm misunderstanding something but I don't know how this works because SDValue(N, 0) and R.getValue(0) are going to have different result types? so the post combine DAG is likely to be broken.

Comment on lines 1498 to 1510
if (isPTrueOpcode(MaskOpcode) && Mask->getOperand(1).getImm() == 31) {
auto PTestOp = MRI->getUniqueVRegDef(PTest->getOperand(1).getReg());
if (PTest->getOpcode() == AArch64::PTEST_PP_FIRST && PTestOp->isCopy() &&
PTestOp->getOperand(1).getSubReg() == AArch64::psub0)
return PredOpcode;

// For PTEST(PTRUE_ALL, WHILE), if the element size matches, the PTEST is
// redundant since WHILE performs an implicit PTEST with an all active
// mask.
if (getElementSizeForOpcode(MaskOpcode) ==
getElementSizeForOpcode(PredOpcode))
return PredOpcode;
}
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As discussed, I'm kind of hoping this is just fallout from the DAG being broken and will not be necessary once fixed.

; Extra use of the get_active_lane_mask from an extractelement, which is
; replaced with ptest_first and reinterpret_casts because the extract is not nxv16i1.

define void @test_2x8bit_mask_with_extracts_and_reinterpret_casts(i64 %i, i64 %n) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's worth adding a similar test for the NumExts != 2 case, if only to see if that better exposes the issues I believe exist in the PR as it stands today.

…bine

The combine replaces a get_active_lane_mask used by two extract subvectors with
a single paired whilelo intrinsic. When the instruction is used for control
flow in a vector loop, an additional extract of element 0 may introduce
other uses of the intrinsic such as ptest and reinterpret cast, which
is currently not supported.

This patch changes performActiveLaneMaskCombine to count the number of
extract subvectors using the mask instead of the total number of uses,
and allows other uses by these additional operations.
…performActiveLaneMaskCombine

- Add tests for the 4 extracts case which will use ptest & reinterpret_cast
- Remove changes to canRemovePTestInstr
performActiveLaneMaskCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI,
const AArch64Subtarget *ST) {
if (DCI.isBeforeLegalize())
if (DCI.isBeforeLegalize() && !!DCI.isBeforeLegalizeOps())
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: The !! looks a little odd. Is it possible to just use DCI.isBeforeLegalizeOps()?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This was a mistake, it should be !DCI.isBeforeLegalizeOps()

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think this is doing what you intend. Given !DCI.isBeforeLegalizeOps()) means AfterLegalizeVectorOps, you've effectively written if "before type legalisation" and "after vector ops legalisation", which is always going to be false because they are are opposite ends of selection.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've now removed this altogether, since finding that a concat with 4+ inputs will be split after this combine without changing the stage at which it applies (although I've also made changes to create multiple concat_vectors in the latest commit too).

- Create multiple concats in performActiveLaneMaskCombine when there are more than 2 extracts

auto MaskEC = N->getValueType(0).getVectorElementCount();
if (!MaskEC.isKnownMultipleOf(NumUses))
if (NumExts == 0 || !MaskEC.isKnownMultipleOf(NumExts))
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it possible to move the zero check into isKnownMultipleOf?

DCI.CombineTo(Extracts[0], R.getValue(0));
DCI.CombineTo(Extracts[1], R.getValue(1));
SmallVector<SDValue> Concats = {DAG.getNode(
ISD::CONCAT_VECTORS, DL, DoubleExtVT, {R.getValue(0), R.getValue(1)})};
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is wrapping the operands in {} necessary?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, I've removed them in both places

Comment on lines 18857 to 18858
Concats.push_back(DAG.getNode(ISD::CONCAT_VECTORS, DL, DoubleExtVT,
{R.getValue(0), R.getValue(1)}));
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is wrapping the operands in {} necessary?

…t.cpp

- Remove unnecessary {} from getNode in performActiveLaneMaskCombine
Comment on lines 18789 to 18790
// The only other valid users for this combine are ptest_first
// and reinterpret_cast.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This last bit can be removed because the ptest_first restriction no longer applies.

@kmclaughlin-arm kmclaughlin-arm merged commit cf50bbf into llvm:main Sep 30, 2025
9 checks passed
mahesh-attarde pushed a commit to mahesh-attarde/llvm-project that referenced this pull request Oct 3, 2025
…bine (llvm#159360)

The combine replaces a get_active_lane_mask used by two extract
subvectors with a single paired whilelo intrinsic. When the instruction
is used for control flow in a vector loop, an additional extract of element
0 may introduce other uses of the intrinsic such as ptest and reinterpret
cast, which is currently not supported.

This patch changes performActiveLaneMaskCombine to count the number
of extract subvectors using the mask instead of the total number of uses,
and returns the concatenated results of get_active_lane_mask.
@kmclaughlin-arm kmclaughlin-arm deleted the alm-combine-ptest branch October 7, 2025 08:39
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants