Skip to content
5 changes: 3 additions & 2 deletions llvm/include/llvm/Support/TypeSize.h
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ template <typename LeafTy, typename ValueTy> class FixedOrScalableQuantity {
/// This function tells the caller whether the element count is known at
/// compile time to be a multiple of the scalar value RHS.
constexpr bool isKnownMultipleOf(ScalarTy RHS) const {
return getKnownMinValue() % RHS == 0;
return RHS != 0 && getKnownMinValue() % RHS == 0;
}

/// Returns whether or not the callee is known to be a multiple of RHS.
Expand All @@ -191,7 +191,8 @@ template <typename LeafTy, typename ValueTy> class FixedOrScalableQuantity {
// x % y == 0 !=> x % (vscale * y) == 0
if (!isScalable() && RHS.isScalable())
return false;
return getKnownMinValue() % RHS.getKnownMinValue() == 0;
return RHS.getKnownMinValue() != 0 &&
getKnownMinValue() % RHS.getKnownMinValue() == 0;
}

// Return the minimum value with the assumption that the count is exact.
Expand Down
39 changes: 25 additions & 14 deletions llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18785,21 +18785,25 @@ performActiveLaneMaskCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI,
(!ST->hasSVE2p1() && !(ST->hasSME2() && ST->isStreaming())))
return SDValue();

unsigned NumUses = N->use_size();
// Count the number of users which are extract_vectors.
unsigned NumExts = count_if(N->users(), [](SDNode *Use) {
return Use->getOpcode() == ISD::EXTRACT_SUBVECTOR;
});

auto MaskEC = N->getValueType(0).getVectorElementCount();
if (!MaskEC.isKnownMultipleOf(NumUses))
if (!MaskEC.isKnownMultipleOf(NumExts))
return SDValue();

ElementCount ExtMinEC = MaskEC.divideCoefficientBy(NumUses);
ElementCount ExtMinEC = MaskEC.divideCoefficientBy(NumExts);
if (ExtMinEC.getKnownMinValue() < 2)
return SDValue();

SmallVector<SDNode *> Extracts(NumUses, nullptr);
SmallVector<SDNode *> Extracts(NumExts, nullptr);
for (SDNode *Use : N->users()) {
if (Use->getOpcode() != ISD::EXTRACT_SUBVECTOR)
return SDValue();
continue;

// Ensure the extract type is correct (e.g. if NumUses is 4 and
// Ensure the extract type is correct (e.g. if NumExts is 4 and
// the mask return type is nxv8i1, each extract should be nxv2i1.
if (Use->getValueType(0).getVectorElementCount() != ExtMinEC)
return SDValue();
Expand All @@ -18820,32 +18824,39 @@ performActiveLaneMaskCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI,

SDValue Idx = N->getOperand(0);
SDValue TC = N->getOperand(1);
EVT OpVT = Idx.getValueType();
if (OpVT != MVT::i64) {
if (Idx.getValueType() != MVT::i64) {
Idx = DAG.getNode(ISD::ZERO_EXTEND, DL, MVT::i64, Idx);
TC = DAG.getNode(ISD::ZERO_EXTEND, DL, MVT::i64, TC);
}

// Create the whilelo_x2 intrinsics from each pair of extracts
EVT ExtVT = Extracts[0]->getValueType(0);
EVT DoubleExtVT = ExtVT.getDoubleNumVectorElementsVT(*DAG.getContext());
auto R =
DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, {ExtVT, ExtVT}, {ID, Idx, TC});
DCI.CombineTo(Extracts[0], R.getValue(0));
DCI.CombineTo(Extracts[1], R.getValue(1));
SmallVector<SDValue> Concats = {DAG.getNode(
ISD::CONCAT_VECTORS, DL, DoubleExtVT, R.getValue(0), R.getValue(1))};

if (NumUses == 2)
return SDValue(N, 0);
if (NumExts == 2) {
assert(N->getValueType(0) == DoubleExtVT);
return Concats[0];
}

auto Elts = DAG.getElementCount(DL, OpVT, ExtVT.getVectorElementCount() * 2);
for (unsigned I = 2; I < NumUses; I += 2) {
auto Elts =
DAG.getElementCount(DL, MVT::i64, ExtVT.getVectorElementCount() * 2);
for (unsigned I = 2; I < NumExts; I += 2) {
// After the first whilelo_x2, we need to increment the starting value.
Idx = DAG.getNode(ISD::UADDSAT, DL, OpVT, Idx, Elts);
Idx = DAG.getNode(ISD::UADDSAT, DL, MVT::i64, Idx, Elts);
R = DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, {ExtVT, ExtVT}, {ID, Idx, TC});
DCI.CombineTo(Extracts[I], R.getValue(0));
DCI.CombineTo(Extracts[I + 1], R.getValue(1));
Concats.push_back(DAG.getNode(ISD::CONCAT_VECTORS, DL, DoubleExtVT,
R.getValue(0), R.getValue(1)));
}

return SDValue(N, 0);
return DAG.getNode(ISD::CONCAT_VECTORS, DL, N->getValueType(0), Concats);
}

// Turn a v8i8/v16i8 extended vecreduce into a udot/sdot and vecreduce
Expand Down
181 changes: 181 additions & 0 deletions llvm/test/CodeGen/AArch64/get-active-lane-mask-extract.ll
Original file line number Diff line number Diff line change
Expand Up @@ -310,6 +310,187 @@ define void @test_2x32bit_mask_with_32bit_index_and_trip_count(i32 %i, i32 %n) #
ret void
}

; Extra use of the get_active_lane_mask from an extractelement, which is replaced with ptest_first.

define void @test_2x8bit_mask_with_extracts_and_ptest(i64 %i, i64 %n) {
; CHECK-SVE-LABEL: test_2x8bit_mask_with_extracts_and_ptest:
; CHECK-SVE: // %bb.0: // %entry
; CHECK-SVE-NEXT: whilelo p1.b, x0, x1
; CHECK-SVE-NEXT: b.pl .LBB11_2
; CHECK-SVE-NEXT: // %bb.1: // %if.then
; CHECK-SVE-NEXT: punpklo p0.h, p1.b
; CHECK-SVE-NEXT: punpkhi p1.h, p1.b
; CHECK-SVE-NEXT: b use
; CHECK-SVE-NEXT: .LBB11_2: // %if.end
; CHECK-SVE-NEXT: ret
;
; CHECK-SVE2p1-SME2-LABEL: test_2x8bit_mask_with_extracts_and_ptest:
; CHECK-SVE2p1-SME2: // %bb.0: // %entry
; CHECK-SVE2p1-SME2-NEXT: whilelo { p0.h, p1.h }, x0, x1
; CHECK-SVE2p1-SME2-NEXT: ptrue p2.b
; CHECK-SVE2p1-SME2-NEXT: uzp1 p3.b, p0.b, p1.b
; CHECK-SVE2p1-SME2-NEXT: ptest p2, p3.b
; CHECK-SVE2p1-SME2-NEXT: b.pl .LBB11_2
; CHECK-SVE2p1-SME2-NEXT: // %bb.1: // %if.then
; CHECK-SVE2p1-SME2-NEXT: b use
; CHECK-SVE2p1-SME2-NEXT: .LBB11_2: // %if.end
; CHECK-SVE2p1-SME2-NEXT: ret
entry:
%r = call <vscale x 16 x i1> @llvm.get.active.lane.mask.nxv16i1.i32(i64 %i, i64 %n)
%v0 = call <vscale x 8 x i1> @llvm.vector.extract.nxv8i1.nxv16i1.i64(<vscale x 16 x i1> %r, i64 0)
%v1 = call <vscale x 8 x i1> @llvm.vector.extract.nxv8i1.nxv16i1.i64(<vscale x 16 x i1> %r, i64 8)
%elt0 = extractelement <vscale x 16 x i1> %r, i32 0
br i1 %elt0, label %if.then, label %if.end

if.then:
tail call void @use(<vscale x 8 x i1> %v0, <vscale x 8 x i1> %v1)
br label %if.end

if.end:
ret void
}

; Extra use of the get_active_lane_mask from an extractelement, which is
; replaced with ptest_first and reinterpret_casts because the extract is not nxv16i1.

define void @test_2x8bit_mask_with_extracts_and_reinterpret_casts(i64 %i, i64 %n) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's worth adding a similar test for the NumExts != 2 case, if only to see if that better exposes the issues I believe exist in the PR as it stands today.

; CHECK-SVE-LABEL: test_2x8bit_mask_with_extracts_and_reinterpret_casts:
; CHECK-SVE: // %bb.0: // %entry
; CHECK-SVE-NEXT: whilelo p1.h, x0, x1
; CHECK-SVE-NEXT: b.pl .LBB12_2
; CHECK-SVE-NEXT: // %bb.1: // %if.then
; CHECK-SVE-NEXT: punpklo p0.h, p1.b
; CHECK-SVE-NEXT: punpkhi p1.h, p1.b
; CHECK-SVE-NEXT: b use
; CHECK-SVE-NEXT: .LBB12_2: // %if.end
; CHECK-SVE-NEXT: ret
;
; CHECK-SVE2p1-SME2-LABEL: test_2x8bit_mask_with_extracts_and_reinterpret_casts:
; CHECK-SVE2p1-SME2: // %bb.0: // %entry
; CHECK-SVE2p1-SME2-NEXT: whilelo { p0.s, p1.s }, x0, x1
; CHECK-SVE2p1-SME2-NEXT: ptrue p2.h
; CHECK-SVE2p1-SME2-NEXT: uzp1 p3.h, p0.h, p1.h
; CHECK-SVE2p1-SME2-NEXT: ptest p2, p3.b
; CHECK-SVE2p1-SME2-NEXT: b.pl .LBB12_2
; CHECK-SVE2p1-SME2-NEXT: // %bb.1: // %if.then
; CHECK-SVE2p1-SME2-NEXT: b use
; CHECK-SVE2p1-SME2-NEXT: .LBB12_2: // %if.end
; CHECK-SVE2p1-SME2-NEXT: ret
entry:
%r = call <vscale x 8 x i1> @llvm.get.active.lane.mask.nxv8i1.i64(i64 %i, i64 %n)
%v0 = tail call <vscale x 4 x i1> @llvm.vector.extract.nxv4i1.nxv8i1(<vscale x 8 x i1> %r, i64 0)
%v1 = tail call <vscale x 4 x i1> @llvm.vector.extract.nxv4i1.nxv8i1(<vscale x 8 x i1> %r, i64 4)
%elt0 = extractelement <vscale x 8 x i1> %r, i64 0
br i1 %elt0, label %if.then, label %if.end

if.then:
tail call void @use(<vscale x 4 x i1> %v0, <vscale x 4 x i1> %v1)
br label %if.end

if.end:
ret void
}

define void @test_4x4bit_mask_with_extracts_and_ptest(i64 %i, i64 %n) {
; CHECK-SVE-LABEL: test_4x4bit_mask_with_extracts_and_ptest:
; CHECK-SVE: // %bb.0: // %entry
; CHECK-SVE-NEXT: whilelo p0.b, x0, x1
; CHECK-SVE-NEXT: b.pl .LBB13_2
; CHECK-SVE-NEXT: // %bb.1: // %if.then
; CHECK-SVE-NEXT: punpklo p1.h, p0.b
; CHECK-SVE-NEXT: punpkhi p3.h, p0.b
; CHECK-SVE-NEXT: punpklo p0.h, p1.b
; CHECK-SVE-NEXT: punpkhi p1.h, p1.b
; CHECK-SVE-NEXT: punpklo p2.h, p3.b
; CHECK-SVE-NEXT: punpkhi p3.h, p3.b
; CHECK-SVE-NEXT: b use
; CHECK-SVE-NEXT: .LBB13_2: // %if.end
; CHECK-SVE-NEXT: ret
;
; CHECK-SVE2p1-SME2-LABEL: test_4x4bit_mask_with_extracts_and_ptest:
; CHECK-SVE2p1-SME2: // %bb.0: // %entry
; CHECK-SVE2p1-SME2-NEXT: cnth x8
; CHECK-SVE2p1-SME2-NEXT: adds x8, x0, x8
; CHECK-SVE2p1-SME2-NEXT: csinv x8, x8, xzr, lo
; CHECK-SVE2p1-SME2-NEXT: whilelo { p0.s, p1.s }, x0, x1
; CHECK-SVE2p1-SME2-NEXT: whilelo { p2.s, p3.s }, x8, x1
; CHECK-SVE2p1-SME2-NEXT: uzp1 p4.h, p0.h, p1.h
; CHECK-SVE2p1-SME2-NEXT: uzp1 p5.h, p2.h, p3.h
; CHECK-SVE2p1-SME2-NEXT: uzp1 p4.b, p4.b, p5.b
; CHECK-SVE2p1-SME2-NEXT: ptrue p5.b
; CHECK-SVE2p1-SME2-NEXT: ptest p5, p4.b
; CHECK-SVE2p1-SME2-NEXT: b.pl .LBB13_2
; CHECK-SVE2p1-SME2-NEXT: // %bb.1: // %if.then
; CHECK-SVE2p1-SME2-NEXT: b use
; CHECK-SVE2p1-SME2-NEXT: .LBB13_2: // %if.end
; CHECK-SVE2p1-SME2-NEXT: ret
entry:
%r = call <vscale x 16 x i1> @llvm.get.active.lane.mask.nxv16i1.i32(i64 %i, i64 %n)
%v0 = call <vscale x 4 x i1> @llvm.vector.extract.nxv4i1.nxv16i1.i64(<vscale x 16 x i1> %r, i64 0)
%v1 = call <vscale x 4 x i1> @llvm.vector.extract.nxv4i1.nxv16i1.i64(<vscale x 16 x i1> %r, i64 4)
%v2 = call <vscale x 4 x i1> @llvm.vector.extract.nxv4i1.nxv16i1.i64(<vscale x 16 x i1> %r, i64 8)
%v3 = call <vscale x 4 x i1> @llvm.vector.extract.nxv4i1.nxv16i1.i64(<vscale x 16 x i1> %r, i64 12)
%elt0 = extractelement <vscale x 16 x i1> %r, i32 0
br i1 %elt0, label %if.then, label %if.end

if.then:
tail call void @use(<vscale x 4 x i1> %v0, <vscale x 4 x i1> %v1, <vscale x 4 x i1> %v2, <vscale x 4 x i1> %v3)
br label %if.end

if.end:
ret void
}

define void @test_4x2bit_mask_with_extracts_and_reinterpret_casts(i64 %i, i64 %n) {
; CHECK-SVE-LABEL: test_4x2bit_mask_with_extracts_and_reinterpret_casts:
; CHECK-SVE: // %bb.0: // %entry
; CHECK-SVE-NEXT: whilelo p0.h, x0, x1
; CHECK-SVE-NEXT: b.pl .LBB14_2
; CHECK-SVE-NEXT: // %bb.1: // %if.then
; CHECK-SVE-NEXT: punpklo p1.h, p0.b
; CHECK-SVE-NEXT: punpkhi p3.h, p0.b
; CHECK-SVE-NEXT: punpklo p0.h, p1.b
; CHECK-SVE-NEXT: punpkhi p1.h, p1.b
; CHECK-SVE-NEXT: punpklo p2.h, p3.b
; CHECK-SVE-NEXT: punpkhi p3.h, p3.b
; CHECK-SVE-NEXT: b use
; CHECK-SVE-NEXT: .LBB14_2: // %if.end
; CHECK-SVE-NEXT: ret
;
; CHECK-SVE2p1-SME2-LABEL: test_4x2bit_mask_with_extracts_and_reinterpret_casts:
; CHECK-SVE2p1-SME2: // %bb.0: // %entry
; CHECK-SVE2p1-SME2-NEXT: cntw x8
; CHECK-SVE2p1-SME2-NEXT: adds x8, x0, x8
; CHECK-SVE2p1-SME2-NEXT: csinv x8, x8, xzr, lo
; CHECK-SVE2p1-SME2-NEXT: whilelo { p0.d, p1.d }, x0, x1
; CHECK-SVE2p1-SME2-NEXT: whilelo { p2.d, p3.d }, x8, x1
; CHECK-SVE2p1-SME2-NEXT: uzp1 p4.s, p0.s, p1.s
; CHECK-SVE2p1-SME2-NEXT: uzp1 p5.s, p2.s, p3.s
; CHECK-SVE2p1-SME2-NEXT: uzp1 p4.h, p4.h, p5.h
; CHECK-SVE2p1-SME2-NEXT: ptrue p5.h
; CHECK-SVE2p1-SME2-NEXT: ptest p5, p4.b
; CHECK-SVE2p1-SME2-NEXT: b.pl .LBB14_2
; CHECK-SVE2p1-SME2-NEXT: // %bb.1: // %if.then
; CHECK-SVE2p1-SME2-NEXT: b use
; CHECK-SVE2p1-SME2-NEXT: .LBB14_2: // %if.end
; CHECK-SVE2p1-SME2-NEXT: ret
entry:
%r = call <vscale x 8 x i1> @llvm.get.active.lane.mask.nxv8i1.i32(i64 %i, i64 %n)
%v0 = call <vscale x 2 x i1> @llvm.vector.extract.nxv2i1.nxv8i1.i64(<vscale x 8 x i1> %r, i64 0)
%v1 = call <vscale x 2 x i1> @llvm.vector.extract.nxv2i1.nxv8i1.i64(<vscale x 8 x i1> %r, i64 2)
%v2 = call <vscale x 2 x i1> @llvm.vector.extract.nxv2i1.nxv8i1.i64(<vscale x 8 x i1> %r, i64 4)
%v3 = call <vscale x 2 x i1> @llvm.vector.extract.nxv2i1.nxv8i1.i64(<vscale x 8 x i1> %r, i64 6)
%elt0 = extractelement <vscale x 8 x i1> %r, i32 0
br i1 %elt0, label %if.then, label %if.end

if.then:
tail call void @use(<vscale x 2 x i1> %v0, <vscale x 2 x i1> %v1, <vscale x 2 x i1> %v2, <vscale x 2 x i1> %v3)
br label %if.end

if.end:
ret void
}

declare void @use(...)

attributes #0 = { nounwind }
1 change: 1 addition & 0 deletions llvm/unittests/Support/TypeSizeTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ static_assert(ElementCount::getFixed(8).divideCoefficientBy(2) ==
static_assert(ElementCount::getFixed(8).multiplyCoefficientBy(3) ==
ElementCount::getFixed(24));
static_assert(ElementCount::getFixed(8).isKnownMultipleOf(2));
static_assert(!ElementCount::getFixed(8).isKnownMultipleOf(0));

constexpr TypeSize TSFixed0 = TypeSize::getFixed(0);
constexpr TypeSize TSFixed1 = TypeSize::getFixed(1);
Expand Down