Skip to content

[InstCombine] Fold sext(trunc nsw) and zext(trunc nuw) #88609

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 5 commits into
base: main
Choose a base branch
from

Conversation

YanWQ-monad
Copy link
Contributor

@YanWQ-monad YanWQ-monad commented Apr 13, 2024

Fold

  • sext (trunc nsw X to Y) to Z to cast (nsw) X to Z, and
  • zext (trunc nuw X to Y) to Z to cast (nuw) X to Z

Alive2 proofs:

Closes #98017.

@llvmbot
Copy link
Member

llvmbot commented Apr 13, 2024

@llvm/pr-subscribers-llvm-transforms

Author: Monad (YanWQ-monad)

Changes

Fold

  • sext (trunc nsw X to Y) to Z to cast (nsw) X to Z, and
  • zext (trunc nuw X to Y) to Z to cast (nuw) X to Z

Alive2 proofs:


Full diff: https://github.com/llvm/llvm-project/pull/88609.diff

3 Files Affected:

  • (modified) llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp (+21-1)
  • (modified) llvm/test/Transforms/InstCombine/sext.ll (+41)
  • (modified) llvm/test/Transforms/InstCombine/zext.ll (+41)
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp
index 437e9b92c7032f..91c149305bb76c 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp
@@ -1188,9 +1188,20 @@ Instruction *InstCombinerImpl::visitZExt(ZExtInst &Zext) {
   if (auto *CSrc = dyn_cast<TruncInst>(Src)) {   // A->B->C cast
     // TODO: Subsume this into EvaluateInDifferentType.
 
+    Value *A = CSrc->getOperand(0);
+    // If TRUNC has nuw flag, then convert directly to final type.
+    if (CSrc->hasNoUnsignedWrap()) {
+      CastInst *I =
+          CastInst::CreateIntegerCast(A, DestTy, /* isSigned */ false);
+      if (auto *ZExt = dyn_cast<ZExtInst>(I))
+        ZExt->setNonNeg();
+      if (auto *Trunc = dyn_cast<TruncInst>(I))
+        Trunc->setHasNoUnsignedWrap(true);
+      return I;
+    }
+
     // Get the sizes of the types involved.  We know that the intermediate type
     // will be smaller than A or C, but don't know the relation between A and C.
-    Value *A = CSrc->getOperand(0);
     unsigned SrcSize = A->getType()->getScalarSizeInBits();
     unsigned MidSize = CSrc->getType()->getScalarSizeInBits();
     unsigned DstSize = DestTy->getScalarSizeInBits();
@@ -1467,6 +1478,15 @@ Instruction *InstCombinerImpl::visitSExt(SExtInst &Sext) {
     if (ComputeNumSignBits(X, 0, &Sext) > XBitSize - SrcBitSize)
       return CastInst::CreateIntegerCast(X, DestTy, /* isSigned */ true);
 
+    // If trunc has nsw flag, then convert directly to final type.
+    auto *CSrc = static_cast<TruncInst *>(Src);
+    if (CSrc->hasNoSignedWrap()) {
+      CastInst *I = CastInst::CreateIntegerCast(X, DestTy, /* isSigned */ true);
+      if (auto *Trunc = dyn_cast<TruncInst>(I))
+        Trunc->setHasNoSignedWrap(true);
+      return I;
+    }
+
     // If input is a trunc from the destination type, then convert into shifts.
     if (Src->hasOneUse() && X->getType() == DestTy) {
       // sext (trunc X) --> ashr (shl X, C), C
diff --git a/llvm/test/Transforms/InstCombine/sext.ll b/llvm/test/Transforms/InstCombine/sext.ll
index e3b6058ce7f806..9eae03470a4693 100644
--- a/llvm/test/Transforms/InstCombine/sext.ll
+++ b/llvm/test/Transforms/InstCombine/sext.ll
@@ -423,3 +423,44 @@ define i64 @smear_set_bit_different_dest_type_wider_dst(i32 %x) {
   %s = sext i8 %a to i64
   ret i64 %s
 }
+
+define i32 @sext_trunc_nsw(i16 %x) {
+; CHECK-LABEL: @sext_trunc_nsw(
+; CHECK-NEXT:    [[E:%.*]] = sext i16 [[X:%.*]] to i32
+; CHECK-NEXT:    ret i32 [[E]]
+;
+  %c = trunc nsw i16 %x to i8
+  %e = sext i8 %c to i32
+  ret i32 %e
+}
+
+define i16 @sext_trunc_nsw_2(i32 %x) {
+; CHECK-LABEL: @sext_trunc_nsw_2(
+; CHECK-NEXT:    [[E:%.*]] = trunc nsw i32 [[X:%.*]] to i16
+; CHECK-NEXT:    ret i16 [[E]]
+;
+  %c = trunc nsw i32 %x to i8
+  %e = sext i8 %c to i16
+  ret i16 %e
+}
+
+define <2 x i32> @sext_trunc_nsw_vec(<2 x i16> %x) {
+; CHECK-LABEL: @sext_trunc_nsw_vec(
+; CHECK-NEXT:    [[E:%.*]] = sext <2 x i16> [[X:%.*]] to <2 x i32>
+; CHECK-NEXT:    ret <2 x i32> [[E]]
+;
+  %c = trunc nsw <2 x i16> %x to <2 x i8>
+  %e = sext <2 x i8> %c to <2 x i32>
+  ret <2 x i32> %e
+}
+
+define i32 @sext_trunc(i16 %x) {
+; CHECK-LABEL: @sext_trunc(
+; CHECK-NEXT:    [[C:%.*]] = trunc i16 [[X:%.*]] to i8
+; CHECK-NEXT:    [[E:%.*]] = sext i8 [[C]] to i32
+; CHECK-NEXT:    ret i32 [[E]]
+;
+  %c = trunc i16 %x to i8
+  %e = sext i8 %c to i32
+  ret i32 %e
+}
diff --git a/llvm/test/Transforms/InstCombine/zext.ll b/llvm/test/Transforms/InstCombine/zext.ll
index 88cd9c70af40d8..16e7ef143cef9e 100644
--- a/llvm/test/Transforms/InstCombine/zext.ll
+++ b/llvm/test/Transforms/InstCombine/zext.ll
@@ -867,3 +867,44 @@ entry:
   %res = zext nneg i2 %x to i32
   ret i32 %res
 }
+
+define i32 @zext_trunc_nuw(i16 %x) {
+; CHECK-LABEL: @zext_trunc_nuw(
+; CHECK-NEXT:    [[E1:%.*]] = zext nneg i16 [[X:%.*]] to i32
+; CHECK-NEXT:    ret i32 [[E1]]
+;
+  %c = trunc nuw i16 %x to i8
+  %e = zext i8 %c to i32
+  ret i32 %e
+}
+
+define i16 @zext_trunc_nuw_2(i32 %x) {
+; CHECK-LABEL: @zext_trunc_nuw_2(
+; CHECK-NEXT:    [[E:%.*]] = trunc nuw i32 [[X:%.*]] to i16
+; CHECK-NEXT:    ret i16 [[E]]
+;
+  %c = trunc nuw i32 %x to i8
+  %e = zext i8 %c to i16
+  ret i16 %e
+}
+
+define <2 x i32> @zext_trunc_nuw_vec(<2 x i16> %x) {
+; CHECK-LABEL: @zext_trunc_nuw_vec(
+; CHECK-NEXT:    [[E1:%.*]] = zext nneg <2 x i16> [[X:%.*]] to <2 x i32>
+; CHECK-NEXT:    ret <2 x i32> [[E1]]
+;
+  %c = trunc nuw <2 x i16> %x to <2 x i8>
+  %e = zext <2 x i8> %c to <2 x i32>
+  ret <2 x i32> %e
+}
+
+define i32 @zext_trunc(i16 %x) {
+; CHECK-LABEL: @zext_trunc(
+; CHECK-NEXT:    [[E:%.*]] = and i16 [[X:%.*]], 255
+; CHECK-NEXT:    [[E1:%.*]] = zext nneg i16 [[E]] to i32
+; CHECK-NEXT:    ret i32 [[E1]]
+;
+  %c = trunc i16 %x to i8
+  %e = zext i8 %c to i32
+  ret i32 %e
+}

dtcxzyw added a commit to dtcxzyw/llvm-opt-benchmark that referenced this pull request Apr 13, 2024
@YanWQ-monad
Copy link
Contributor Author

This fold unfortunately causes information loss, which may hinder other folds.

For example, Y = zext (trunc nuw i8 X to i1) to i8 would be folded to Y = X, but the information that Y ranges in [0, 1] is lost. This loss might be hard to recover, unless we could retrieval it from X.

For the above reason, I am not sure if this fold is worth a try.

@goldsteinn
Copy link
Contributor

This fold unfortunately causes information loss, which may hinder other folds.

For example, Y = zext (trunc nuw i8 X to i1) to i8 would be folded to Y = X, but the information that Y ranges in [0, 1] is lost. This loss might be hard to recover, unless we could retrieval it from X.

For the above reason, I am not sure if this fold is worth a try.

maybe the right place is DAGCombiner then.

@nikic
Copy link
Contributor

nikic commented Apr 15, 2024

Doing this fold is the whole purpose of the flags, so if we can't do it, we may as well drop them again :)

Looking at the diffs, I think it looks ok on average? Note that we already essentially do this fold just via computeKnownBits/ComputeNumSignBits, which is also why there is such a small number of diffs overall. I think the extra changes you see are due to interactions of IPSCCP and InstCombine or something like that. I think we would get some more interesting cases after #88686.

A general issue I see in the diffs is that we're not very good at narrowing i8 "booleans" down to i1 when loop phis are involved, as computeKnownBits() can't look through them. I think this issue accounts for most of the regressions, but it's also tricky to solve...

@nikic
Copy link
Contributor

nikic commented Jun 25, 2024

@dtcxzyw Could you please rerun tests for this patch?

dtcxzyw added a commit to dtcxzyw/llvm-opt-benchmark that referenced this pull request Jun 26, 2024
@dtcxzyw
Copy link
Member

dtcxzyw commented Jun 26, 2024

@dtcxzyw Could you please rerun tests for this patch?

Done.

@andjo403
Copy link
Contributor

andjo403 commented Oct 26, 2024

noticed that the small number of changes in llvm-opt-benchmark is due to most cases is already handled in the

if (shouldChangeType(SrcTy, DestTy) &&
canEvaluateZExtd(Src, DestTy, BitsToClear, *this, &Zext)) {
assert(BitsToClear <= SrcTy->getScalarSizeInBits() &&
"Can't clear more bits than in SrcTy");
// Okay, we can transform this! Insert the new expression now.
LLVM_DEBUG(
dbgs() << "ICE: EvaluateInDifferentType converting expression type"
" to avoid zero extend: "
<< Zext << '\n');
Value *Res = EvaluateInDifferentType(Src, DestTy, false);
assert(Res->getType() == DestTy);
// Preserve debug values referring to Src if the zext is its last use.
if (auto *SrcOp = dyn_cast<Instruction>(Src))
if (SrcOp->hasOneUse())
replaceAllDbgUsesWith(*SrcOp, *Res, Zext, DT);
uint32_t SrcBitsKept = SrcTy->getScalarSizeInBits() - BitsToClear;
uint32_t DestBitSize = DestTy->getScalarSizeInBits();
// If the high bits are already filled with zeros, just replace this
// cast with the result.
if (MaskedValueIsZero(Res,
APInt::getHighBitsSet(DestBitSize,
DestBitSize - SrcBitsKept),
0, &Zext))
return replaceInstUsesWith(Zext, Res);
// We need to emit an AND to clear the high bits.
Constant *C = ConstantInt::get(Res->getType(),
APInt::getLowBitsSet(DestBitSize, SrcBitsKept));
return BinaryOperator::CreateAnd(Res, C);
}
and
https://github.com/llvm/llvm-project/blob/main/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp#L1494-L1513
so this fold is not executed that mush.

the only cases that I can see will be handled by this fold when it is placed after the linked code is:

  • when the src or target type is not defined in target Datalayout
  • when the src or target type is a vector
  • when there is multiple uses of the trunc instruction

if the "target datalayout" in the test files is updated with "n8:16:32:64" most of the test will no longer pass.

I tried and moved the fold before the code linked and got 3k files update in llvm-opt-benchmark.
see this commit main...andjo403:llvm-project:truncNuwNswFold

@goldsteinn
Copy link
Contributor

your proofs seem overly complex. I think the following suffice: https://alive2.llvm.org/ce/z/WbDanN

@@ -1188,9 +1188,19 @@ Instruction *InstCombinerImpl::visitZExt(ZExtInst &Zext) {
if (auto *CSrc = dyn_cast<TruncInst>(Src)) { // A->B->C cast
// TODO: Subsume this into EvaluateInDifferentType.

Value *A = CSrc->getOperand(0);
// If trunc has nuw flag, then convert directly to final type.
if (CSrc->hasNoUnsignedWrap()) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can also handle nsw if the zext has nneg.

@YanWQ-monad
Copy link
Contributor Author

your proofs seem overly complex. I think the following suffice: https://alive2.llvm.org/ce/z/WbDanN

That's because Alive2 didn't support nsw and nuw at that time, then I used some tricks to emulate it. :) And yes, I should update the proofs. Thanks for the reminder.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

[InstCombine] Missed optimization for zext(trunc nuw(x))
6 participants