diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp index 437e9b92c7032..022c2b937f191 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp @@ -1188,9 +1188,19 @@ Instruction *InstCombinerImpl::visitZExt(ZExtInst &Zext) { if (auto *CSrc = dyn_cast(Src)) { // A->B->C cast // TODO: Subsume this into EvaluateInDifferentType. + Value *A = CSrc->getOperand(0); + // If trunc has nuw flag, then convert directly to final type. + if (CSrc->hasNoUnsignedWrap()) { + CastInst *I = CastInst::CreateIntegerCast(A, DestTy, /*isSigned=*/false); + if (auto *ZExt = dyn_cast(I)) + ZExt->setNonNeg(); + if (auto *Trunc = dyn_cast(I)) + Trunc->setHasNoUnsignedWrap(true); + return I; + } + // Get the sizes of the types involved. We know that the intermediate type // will be smaller than A or C, but don't know the relation between A and C. - Value *A = CSrc->getOperand(0); unsigned SrcSize = A->getType()->getScalarSizeInBits(); unsigned MidSize = CSrc->getType()->getScalarSizeInBits(); unsigned DstSize = DestTy->getScalarSizeInBits(); @@ -1461,11 +1471,14 @@ Instruction *InstCombinerImpl::visitSExt(SExtInst &Sext) { Value *X; if (match(Src, m_Trunc(m_Value(X)))) { - // If the input has more sign bits than bits truncated, then convert - // directly to final type. - unsigned XBitSize = X->getType()->getScalarSizeInBits(); - if (ComputeNumSignBits(X, 0, &Sext) > XBitSize - SrcBitSize) - return CastInst::CreateIntegerCast(X, DestTy, /* isSigned */ true); + // If trunc has nsw flag, then convert directly to final type. + auto *CSrc = cast(Src); + if (CSrc->hasNoSignedWrap()) { + CastInst *I = CastInst::CreateIntegerCast(X, DestTy, /*isSigned=*/true); + if (auto *Trunc = dyn_cast(I)) + Trunc->setHasNoSignedWrap(true); + return I; + } // If input is a trunc from the destination type, then convert into shifts. if (Src->hasOneUse() && X->getType() == DestTy) { @@ -1478,6 +1491,7 @@ Instruction *InstCombinerImpl::visitSExt(SExtInst &Sext) { // the logic shift to arithmetic shift and eliminate the cast to // intermediate type: // sext (trunc (lshr Y, C)) --> sext/trunc (ashr Y, C) + unsigned XBitSize = X->getType()->getScalarSizeInBits(); Value *Y; if (Src->hasOneUse() && match(X, m_LShr(m_Value(Y), diff --git a/llvm/test/Transforms/InstCombine/sext.ll b/llvm/test/Transforms/InstCombine/sext.ll index e3b6058ce7f80..ad02594f020bc 100644 --- a/llvm/test/Transforms/InstCombine/sext.ll +++ b/llvm/test/Transforms/InstCombine/sext.ll @@ -423,3 +423,53 @@ define i64 @smear_set_bit_different_dest_type_wider_dst(i32 %x) { %s = sext i8 %a to i64 ret i64 %s } + +define i32 @sext_trunc_nsw(i16 %x) { +; CHECK-LABEL: @sext_trunc_nsw( +; CHECK-NEXT: [[E:%.*]] = sext i16 [[X:%.*]] to i32 +; CHECK-NEXT: ret i32 [[E]] +; + %c = trunc nsw i16 %x to i8 + %e = sext i8 %c to i32 + ret i32 %e +} + +define i16 @sext_trunc_nsw_2(i32 %x) { +; CHECK-LABEL: @sext_trunc_nsw_2( +; CHECK-NEXT: [[E:%.*]] = trunc nsw i32 [[X:%.*]] to i16 +; CHECK-NEXT: ret i16 [[E]] +; + %c = trunc nsw i32 %x to i8 + %e = sext i8 %c to i16 + ret i16 %e +} + +define i16 @sext_trunc_nsw_3(i16 %x) { +; CHECK-LABEL: @sext_trunc_nsw_3( +; CHECK-NEXT: ret i16 [[E:%.*]] +; + %c = trunc nsw i16 %x to i8 + %e = sext i8 %c to i16 + ret i16 %e +} + +define <2 x i32> @sext_trunc_nsw_vec(<2 x i16> %x) { +; CHECK-LABEL: @sext_trunc_nsw_vec( +; CHECK-NEXT: [[E:%.*]] = sext <2 x i16> [[X:%.*]] to <2 x i32> +; CHECK-NEXT: ret <2 x i32> [[E]] +; + %c = trunc nsw <2 x i16> %x to <2 x i8> + %e = sext <2 x i8> %c to <2 x i32> + ret <2 x i32> %e +} + +define i32 @sext_trunc(i16 %x) { +; CHECK-LABEL: @sext_trunc( +; CHECK-NEXT: [[C:%.*]] = trunc i16 [[X:%.*]] to i8 +; CHECK-NEXT: [[E:%.*]] = sext i8 [[C]] to i32 +; CHECK-NEXT: ret i32 [[E]] +; + %c = trunc i16 %x to i8 + %e = sext i8 %c to i32 + ret i32 %e +} diff --git a/llvm/test/Transforms/InstCombine/zext.ll b/llvm/test/Transforms/InstCombine/zext.ll index 88cd9c70af40d..07e06e6d26a27 100644 --- a/llvm/test/Transforms/InstCombine/zext.ll +++ b/llvm/test/Transforms/InstCombine/zext.ll @@ -867,3 +867,53 @@ entry: %res = zext nneg i2 %x to i32 ret i32 %res } + +define i32 @zext_trunc_nuw(i16 %x) { +; CHECK-LABEL: @zext_trunc_nuw( +; CHECK-NEXT: [[E1:%.*]] = zext nneg i16 [[X:%.*]] to i32 +; CHECK-NEXT: ret i32 [[E1]] +; + %c = trunc nuw i16 %x to i8 + %e = zext i8 %c to i32 + ret i32 %e +} + +define i16 @zext_trunc_nuw_2(i32 %x) { +; CHECK-LABEL: @zext_trunc_nuw_2( +; CHECK-NEXT: [[E:%.*]] = trunc nuw i32 [[X:%.*]] to i16 +; CHECK-NEXT: ret i16 [[E]] +; + %c = trunc nuw i32 %x to i8 + %e = zext i8 %c to i16 + ret i16 %e +} + +define i16 @zext_trunc_nuw_3(i16 %x) { +; CHECK-LABEL: @zext_trunc_nuw_3( +; CHECK-NEXT: ret i16 [[E:%.*]] +; + %c = trunc nuw i16 %x to i8 + %e = zext i8 %c to i16 + ret i16 %e +} + +define <2 x i32> @zext_trunc_nuw_vec(<2 x i16> %x) { +; CHECK-LABEL: @zext_trunc_nuw_vec( +; CHECK-NEXT: [[E1:%.*]] = zext nneg <2 x i16> [[X:%.*]] to <2 x i32> +; CHECK-NEXT: ret <2 x i32> [[E1]] +; + %c = trunc nuw <2 x i16> %x to <2 x i8> + %e = zext <2 x i8> %c to <2 x i32> + ret <2 x i32> %e +} + +define i32 @zext_trunc(i16 %x) { +; CHECK-LABEL: @zext_trunc( +; CHECK-NEXT: [[E:%.*]] = and i16 [[X:%.*]], 255 +; CHECK-NEXT: [[E1:%.*]] = zext nneg i16 [[E]] to i32 +; CHECK-NEXT: ret i32 [[E1]] +; + %c = trunc i16 %x to i8 + %e = zext i8 %c to i32 + ret i32 %e +}