Skip to content

Commit

Permalink
[ConstantRange] Estimate tighter lower (upper) bounds for masked bina…
Browse files Browse the repository at this point in the history
…ry and (or)
  • Loading branch information
zsrkmyn committed Dec 19, 2024
1 parent 4a7673d commit 90f7539
Show file tree
Hide file tree
Showing 3 changed files with 197 additions and 15 deletions.
18 changes: 9 additions & 9 deletions clang/test/CodeGen/AArch64/fpm-helpers.c
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ extern "C" {
//
fpm_t test_init() { return __arm_fpm_init(); }

// CHECK-LABEL: define dso_local noundef i64 @test_src1_1(
// CHECK-LABEL: define dso_local noundef range(i64 0, -4) i64 @test_src1_1(
// CHECK-SAME: ) local_unnamed_addr #[[ATTR0]] {
// CHECK-NEXT: [[ENTRY:.*:]]
// CHECK-NEXT: ret i64 -8
Expand All @@ -44,7 +44,7 @@ fpm_t test_src1_1() {
return __arm_set_fpm_src1_format(INIT_ONES, __ARM_FPM_E5M2);
}

// CHECK-LABEL: define dso_local noundef i64 @test_src1_2(
// CHECK-LABEL: define dso_local noundef range(i64 0, -4) i64 @test_src1_2(
// CHECK-SAME: ) local_unnamed_addr #[[ATTR0]] {
// CHECK-NEXT: [[ENTRY:.*:]]
// CHECK-NEXT: ret i64 1
Expand All @@ -53,7 +53,7 @@ fpm_t test_src1_2() {
return __arm_set_fpm_src1_format(INIT_ZERO, __ARM_FPM_E4M3);
}

// CHECK-LABEL: define dso_local noundef i64 @test_src2_1(
// CHECK-LABEL: define dso_local noundef range(i64 0, -32) i64 @test_src2_1(
// CHECK-SAME: ) local_unnamed_addr #[[ATTR0]] {
// CHECK-NEXT: [[ENTRY:.*:]]
// CHECK-NEXT: ret i64 -57
Expand All @@ -62,7 +62,7 @@ fpm_t test_src2_1() {
return __arm_set_fpm_src2_format(INIT_ONES, __ARM_FPM_E5M2);
}

// CHECK-LABEL: define dso_local noundef i64 @test_src2_2(
// CHECK-LABEL: define dso_local noundef range(i64 0, -32) i64 @test_src2_2(
// CHECK-SAME: ) local_unnamed_addr #[[ATTR0]] {
// CHECK-NEXT: [[ENTRY:.*:]]
// CHECK-NEXT: ret i64 8
Expand All @@ -71,7 +71,7 @@ fpm_t test_src2_2() {
return __arm_set_fpm_src2_format(INIT_ZERO, __ARM_FPM_E4M3);
}

// CHECK-LABEL: define dso_local noundef i64 @test_dst1_1(
// CHECK-LABEL: define dso_local noundef range(i64 0, -256) i64 @test_dst1_1(
// CHECK-SAME: ) local_unnamed_addr #[[ATTR0]] {
// CHECK-NEXT: [[ENTRY:.*:]]
// CHECK-NEXT: ret i64 -449
Expand All @@ -80,7 +80,7 @@ fpm_t test_dst1_1() {
return __arm_set_fpm_dst_format(INIT_ONES, __ARM_FPM_E5M2);
}

// CHECK-LABEL: define dso_local noundef i64 @test_dst2_2(
// CHECK-LABEL: define dso_local noundef range(i64 0, -256) i64 @test_dst2_2(
// CHECK-SAME: ) local_unnamed_addr #[[ATTR0]] {
// CHECK-NEXT: [[ENTRY:.*:]]
// CHECK-NEXT: ret i64 64
Expand Down Expand Up @@ -139,21 +139,21 @@ fpm_t test_lscale() { return __arm_set_fpm_lscale(INIT_ZERO, 127); }
//
fpm_t test_lscale2() { return __arm_set_fpm_lscale2(INIT_ZERO, 63); }

// CHECK-LABEL: define dso_local noundef range(i64 0, 4294967296) i64 @test_nscale_1(
// CHECK-LABEL: define dso_local noundef range(i64 0, 4286578688) i64 @test_nscale_1(
// CHECK-SAME: ) local_unnamed_addr #[[ATTR0]] {
// CHECK-NEXT: [[ENTRY:.*:]]
// CHECK-NEXT: ret i64 2147483648
//
fpm_t test_nscale_1() { return __arm_set_fpm_nscale(INIT_ZERO, -128); }

// CHECK-LABEL: define dso_local noundef range(i64 0, 4294967296) i64 @test_nscale_2(
// CHECK-LABEL: define dso_local noundef range(i64 0, 4286578688) i64 @test_nscale_2(
// CHECK-SAME: ) local_unnamed_addr #[[ATTR0]] {
// CHECK-NEXT: [[ENTRY:.*:]]
// CHECK-NEXT: ret i64 2130706432
//
fpm_t test_nscale_2() { return __arm_set_fpm_nscale(INIT_ZERO, 127); }

// CHECK-LABEL: define dso_local noundef range(i64 0, 4294967296) i64 @test_nscale_3(
// CHECK-LABEL: define dso_local noundef range(i64 0, 4286578688) i64 @test_nscale_3(
// CHECK-SAME: ) local_unnamed_addr #[[ATTR0]] {
// CHECK-NEXT: [[ENTRY:.*:]]
// CHECK-NEXT: ret i64 4278190080
Expand Down
106 changes: 100 additions & 6 deletions llvm/lib/IR/ConstantRange.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1520,15 +1520,102 @@ ConstantRange ConstantRange::binaryNot() const {
return ConstantRange(APInt::getAllOnes(getBitWidth())).sub(*this);
}

/// Estimate an unsigned lower bound for `a & b`, where `a` is any value in
/// \p LHS and `b` is any value in \p RHS.
///
/// E.g., given two ranges as follows (single quotes are separators and
/// have no meaning here),
///
///   LHS = [10'001'010,  ; LLo
///          10'100'000]  ; LHi
///   RHS = [10'111'010,  ; RLo
///          10'111'100]  ; RHi
///
/// we know that the higher 2 bits of the result are always '10'. Since the LHS
/// range is contiguous, LHS[3:6] never drops below '001'; and since all bits
/// in RHS[3:6] are 1, the AND keeps LHS[3:6] intact. Hence the lower bound of
/// the result is 10'001'000.
///
/// The algorithm is as follows,
/// 1. calculate a mask that masks out the higher bits shared by all four
///    endpoints:
///       Mask = (LLo ^ LHi) | (RLo ^ RHi) | (LLo ^ RLo);
///       Mask = Mask with every bit below its highest set bit also set;
/// 2. with A = LHS and B = RHS, find the bit field in which A is known to have
///    at least one bit set after applying the mask (bits 3:6 in the example):
///       StartBit = BitWidth - (ALo & Mask).clz() - 1;
///       EndBit   = BitWidth - (AHi & Mask).clz();
/// 3. if BLo and BHi agree on every bit in [StartBit, BitWidth) and all of
///    their bits in [StartBit, EndBit) are 1, then
///       ALo & Keep
///    is a lower bound, where Keep clears the bits below StartBit (the lower
///    3 bits in the example); otherwise this side contributes no information
///    (i.e. the trivial bound 0);
/// 4. repeat steps 2 and 3 with A = RHS and B = LHS. Both results are valid
///    lower bounds of the same set of values, so the larger (tighter) one is
///    returned.
///
/// \returns a value that is unsigned-less-or-equal to every possible `a & b`;
/// zero when nothing better can be derived.
static APInt estimateBitMaskedAndLowerBound(const ConstantRange &LHS,
                                            const ConstantRange &RHS) {
  const unsigned BitWidth = LHS.getBitWidth();
  // If either is full set or unsigned wrapped, then the range must contain '0'
  // which leads the lower bound to 0.
  if (LHS.isFullSet() || RHS.isFullSet() || LHS.isWrappedSet() ||
      RHS.isWrappedSet())
    return APInt::getZero(BitWidth);

  // Inclusive unsigned endpoints of both ranges.
  const APInt LLo = LHS.getLower();
  const APInt LHi = LHS.getUpper() - 1;
  const APInt RLo = RHS.getLower();
  const APInt RHi = RHS.getUpper() - 1;

  // Every bit position at or below the highest position where any two of the
  // four endpoints differ is set in Mask; the bits above are common to all.
  APInt Mask = (LLo ^ LHi) | (RLo ^ RHi) | (LLo ^ RLo);
  Mask.setLowBits(BitWidth - Mask.countLeadingZeros());

  // Derive a lower bound from range A = [ALo, AHi], using B = [BLo, BHi] as
  // the bit mask. Returns zero (the trivial bound) when the preconditions of
  // step 3 above do not hold.
  auto estimateBound = [BitWidth, &Mask](const APInt &ALo, const APInt &AHi,
                                         const APInt &BLo,
                                         const APInt &BHi) -> APInt {
    const unsigned LoLeadingZeros = (ALo & Mask).countLeadingZeros();
    // ALo has no set bit below the common prefix: nothing to anchor on.
    if (LoLeadingZeros == BitWidth)
      return APInt::getZero(BitWidth);

    // Highest set bit of ALo beneath the common prefix.
    const unsigned StartBit = BitWidth - LoLeadingZeros - 1;
    const unsigned HighWidth = BitWidth - StartBit;

    // B must be constant on [StartBit, BitWidth), otherwise some value in B
    // could clear bits that the bound relies on.
    if (BLo.extractBits(HighWidth, StartBit) !=
        BHi.extractBits(HighWidth, StartBit))
      return APInt::getZero(BitWidth);

    // One past the highest set bit of AHi beneath the common prefix.
    const unsigned EndBit = BitWidth - (AHi & Mask).countLeadingZeros();
    const unsigned FieldWidth = EndBit - StartBit;

    // B must pass every bit of A in [StartBit, EndBit) through unchanged.
    const APInt FieldOnes = BLo.extractBits(FieldWidth, StartBit) &
                            BHi.extractBits(FieldWidth, StartBit);
    if (!FieldOnes.isAllOnes())
      return APInt::getZero(BitWidth);

    // Drop the bits of ALo below StartBit; those may be cleared by B.
    APInt Keep(BitWidth, 0);
    Keep.setBits(StartBit, BitWidth);
    return Keep & ALo;
  };

  const APInt LowerBoundByLHS = estimateBound(LLo, LHi, RLo, RHi);
  const APInt LowerBoundByRHS = estimateBound(RLo, RHi, LLo, LHi);

  // Each estimate bounds every possible result from below, so the maximum of
  // the two is still a valid bound and is the tightest one available here.
  return APIntOps::umax(LowerBoundByLHS, LowerBoundByRHS);
}

// Compute a range containing `a & b` for every `a` in this range and every
// `b` in Other, as the intersection of a known-bits based range and an
// unsigned-bounds based range.
ConstantRange ConstantRange::binaryAnd(const ConstantRange &Other) const {
// No values on one side means no possible results.
if (isEmptySet() || Other.isEmptySet())
return getEmpty();

// Range implied by AND-ing the known bits of both operands.
ConstantRange KnownBitsRange =
fromKnownBits(toKnownBits() & Other.toKnownBits(), false);
// NOTE(review): this listing is a rendered diff without +/- markers; the next
// three lines look like the removed (pre-change) definition of UMinUMaxRange,
// which used 0 as the lower bound -- TODO confirm against the actual file.
ConstantRange UMinUMaxRange =
getNonEmpty(APInt::getZero(getBitWidth()),
APIntOps::umin(Other.getUnsignedMax(), getUnsignedMax()) + 1);
// Lower bound comes from the bit-masked estimate; the exclusive upper bound is
// the smaller of the two unsigned maxima plus one, since a & b <= min(a, b).
auto LowerBound = estimateBitMaskedAndLowerBound(*this, Other);
ConstantRange UMinUMaxRange = getNonEmpty(
LowerBound, APIntOps::umin(Other.getUnsignedMax(), getUnsignedMax()) + 1);
// Both ranges contain every possible result, so their intersection does too.
return KnownBitsRange.intersectWith(UMinUMaxRange);
}

Expand All @@ -1538,10 +1625,17 @@ ConstantRange ConstantRange::binaryOr(const ConstantRange &Other) const {

ConstantRange KnownBitsRange =
fromKnownBits(toKnownBits() | Other.toKnownBits(), false);

// ~a & ~b >= x
// <=> ~(~a & ~b) <= ~x
// <=> a | b <= ~x
// <=> a | b < ~x + 1 = -x
// thus, UpperBound(a | b) == -LowerBound(~a & ~b)
auto UpperBound =
-estimateBitMaskedAndLowerBound(binaryNot(), Other.binaryNot());
// Upper wrapped range.
ConstantRange UMaxUMinRange =
getNonEmpty(APIntOps::umax(getUnsignedMin(), Other.getUnsignedMin()),
APInt::getZero(getBitWidth()));
ConstantRange UMaxUMinRange = getNonEmpty(
APIntOps::umax(getUnsignedMin(), Other.getUnsignedMin()), UpperBound);
return KnownBitsRange.intersectWith(UMaxUMinRange);
}

Expand Down
88 changes: 88 additions & 0 deletions llvm/test/Transforms/SCCP/range-and-or-bit-masked.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
; RUN: opt -S -passes=ipsccp %s | FileCheck %s

declare void @use(i1)

define i1 @test1(i64 %x) {
; CHECK-LABEL: @test1(
; CHECK-NEXT: entry:
; CHECK-NEXT: [[COND:%.*]] = icmp ugt i64 [[X:%.*]], 65535
; CHECK-NEXT: call void @llvm.assume(i1 [[COND]])
; CHECK-NEXT: [[MASK:%.*]] = and i64 [[X]], -65521
; CHECK-NEXT: ret i1 false
;
entry:
%cond = icmp ugt i64 %x, 65535
call void @llvm.assume(i1 %cond)
%mask = and i64 %x, -65521
%cmp = icmp eq i64 %mask, 0
ret i1 %cmp
}

define void @test.and(i64 %x, i64 %y) {
; CHECK-LABEL: @test.and(
; CHECK-NEXT: entry:
; CHECK-NEXT: [[C0:%.*]] = icmp uge i64 [[X:%.*]], 138
; CHECK-NEXT: [[C1:%.*]] = icmp ule i64 [[X]], 161
; CHECK-NEXT: call void @llvm.assume(i1 [[C0]])
; CHECK-NEXT: call void @llvm.assume(i1 [[C1]])
; CHECK-NEXT: [[C2:%.*]] = icmp uge i64 [[Y:%.*]], 186
; CHECK-NEXT: [[C3:%.*]] = icmp ule i64 [[Y]], 188
; CHECK-NEXT: call void @llvm.assume(i1 [[C2]])
; CHECK-NEXT: call void @llvm.assume(i1 [[C3]])
; CHECK-NEXT: [[AND:%.*]] = and i64 [[X]], [[Y]]
; CHECK-NEXT: call void @use(i1 false)
; CHECK-NEXT: [[R1:%.*]] = icmp ult i64 [[AND]], 137
; CHECK-NEXT: call void @use(i1 [[R1]])
; CHECK-NEXT: ret void
;
entry:
%c0 = icmp uge i64 %x, 138 ; 0b10001010
%c1 = icmp ule i64 %x, 161 ; 0b10100001
call void @llvm.assume(i1 %c0)
call void @llvm.assume(i1 %c1)
%c2 = icmp uge i64 %y, 186 ; 0b10111010
%c3 = icmp ule i64 %y, 188 ; 0b10111100
call void @llvm.assume(i1 %c2)
call void @llvm.assume(i1 %c3)
%and = and i64 %x, %y
%r0 = icmp ult i64 %and, 136 ; 0b10001000
call void @use(i1 %r0) ; false
%r1 = icmp ult i64 %and, 137
call void @use(i1 %r1) ; unknown
ret void
}

define void @test.or(i64 %x, i64 %y) {
; CHECK-LABEL: @test.or(
; CHECK-NEXT: entry:
; CHECK-NEXT: [[C0:%.*]] = icmp ule i64 [[X:%.*]], 117
; CHECK-NEXT: [[C1:%.*]] = icmp uge i64 [[X]], 95
; CHECK-NEXT: call void @llvm.assume(i1 [[C0]])
; CHECK-NEXT: call void @llvm.assume(i1 [[C1]])
; CHECK-NEXT: [[C2:%.*]] = icmp ule i64 [[Y:%.*]], 69
; CHECK-NEXT: [[C3:%.*]] = icmp uge i64 [[Y]], 67
; CHECK-NEXT: call void @llvm.assume(i1 [[C2]])
; CHECK-NEXT: call void @llvm.assume(i1 [[C3]])
; CHECK-NEXT: [[OR:%.*]] = or i64 [[X]], [[Y]]
; CHECK-NEXT: call void @use(i1 false)
; CHECK-NEXT: [[R1:%.*]] = icmp ugt i64 [[OR]], 118
; CHECK-NEXT: call void @use(i1 [[R1]])
; CHECK-NEXT: ret void
;
entry:
%c0 = icmp ule i64 %x, 117 ; 0b01110101
%c1 = icmp uge i64 %x, 95 ; 0b01011111
call void @llvm.assume(i1 %c0)
call void @llvm.assume(i1 %c1)
%c2 = icmp ule i64 %y, 69 ; 0b01000101
%c3 = icmp uge i64 %y, 67 ; 0b01000011
call void @llvm.assume(i1 %c2)
call void @llvm.assume(i1 %c3)
%or = or i64 %x, %y
%r0 = icmp ugt i64 %or, 119 ; 0b01110111
call void @use(i1 %r0) ; false
%r1 = icmp ugt i64 %or, 118
call void @use(i1 %r1) ; unknown
ret void
}

0 comments on commit 90f7539

Please sign in to comment.