diff --git a/llvm/lib/Transforms/Scalar/InferAlignment.cpp b/llvm/lib/Transforms/Scalar/InferAlignment.cpp
index e9bf59c6850a3..b60b15b6c3a2b 100644
--- a/llvm/lib/Transforms/Scalar/InferAlignment.cpp
+++ b/llvm/lib/Transforms/Scalar/InferAlignment.cpp
@@ -15,6 +15,7 @@
 #include "llvm/Analysis/AssumptionCache.h"
 #include "llvm/Analysis/ValueTracking.h"
 #include "llvm/IR/Instructions.h"
+#include "llvm/IR/IntrinsicInst.h"
 #include "llvm/Support/KnownBits.h"
 #include "llvm/Transforms/Scalar.h"
 #include "llvm/Transforms/Utils/Local.h"
@@ -35,8 +36,38 @@
       return true;
     }
   }
-  // TODO: Also handle memory intrinsics.
-  return false;
+
+  IntrinsicInst *II = dyn_cast<IntrinsicInst>(I);
+  if (!II)
+    return false;
+
+  // TODO: Handle more memory intrinsics.
+  switch (II->getIntrinsicID()) {
+  case Intrinsic::masked_load:
+  case Intrinsic::masked_store: {
+    int AlignOpIdx = II->getIntrinsicID() == Intrinsic::masked_load ? 1 : 2;
+    Value *PtrOp = II->getIntrinsicID() == Intrinsic::masked_load
+                       ? II->getArgOperand(0)
+                       : II->getArgOperand(1);
+    Type *Type = II->getIntrinsicID() == Intrinsic::masked_load
+                     ? II->getType()
+                     : II->getArgOperand(0)->getType();
+
+    Align OldAlign =
+        cast<ConstantInt>(II->getArgOperand(AlignOpIdx))->getAlignValue();
+    Align PrefAlign = DL.getPrefTypeAlign(Type);
+    Align NewAlign = Fn(PtrOp, OldAlign, PrefAlign);
+    if (NewAlign <= OldAlign)
+      return false;
+
+    Value *V =
+        ConstantInt::get(Type::getInt32Ty(II->getContext()), NewAlign.value());
+    II->setOperand(AlignOpIdx, V);
+    return true;
+  }
+  default:
+    return false;
+  }
 }
 
 bool inferAlignment(Function &F, AssumptionCache &AC, DominatorTree &DT) {
diff --git a/llvm/test/Transforms/InferAlignment/masked.ll b/llvm/test/Transforms/InferAlignment/masked.ll
new file mode 100644
index 0000000000000..1b8d26417d75e
--- /dev/null
+++ b/llvm/test/Transforms/InferAlignment/masked.ll
@@ -0,0 +1,34 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
+; RUN: opt < %s -passes=infer-alignment -S | FileCheck %s
+
+define <2 x i32> @load(<2 x i1> %mask, ptr %ptr) {
+; CHECK-LABEL: define <2 x i32> @load(
+; CHECK-SAME: <2 x i1> [[MASK:%.*]], ptr [[PTR:%.*]]) {
+; CHECK-NEXT:  [[ENTRY:.*:]]
+; CHECK-NEXT:    call void @llvm.assume(i1 true) [ "align"(ptr [[PTR]], i64 64) ]
+; CHECK-NEXT:    [[MASKED_LOAD:%.*]] = call <2 x i32> @llvm.masked.load.v2i32.p0(ptr [[PTR]], i32 64, <2 x i1> [[MASK]], <2 x i32> poison)
+; CHECK-NEXT:    ret <2 x i32> [[MASKED_LOAD]]
+;
+entry:
+  call void @llvm.assume(i1 true) [ "align"(ptr %ptr, i64 64) ]
+  %masked_load = call <2 x i32> @llvm.masked.load.v2i32.p0(ptr %ptr, i32 1, <2 x i1> %mask, <2 x i32> poison)
+  ret <2 x i32> %masked_load
+}
+
+define void @store(<2 x i1> %mask, <2 x i32> %val, ptr %ptr) {
+; CHECK-LABEL: define void @store(
+; CHECK-SAME: <2 x i1> [[MASK:%.*]], <2 x i32> [[VAL:%.*]], ptr [[PTR:%.*]]) {
+; CHECK-NEXT:  [[ENTRY:.*:]]
+; CHECK-NEXT:    call void @llvm.assume(i1 true) [ "align"(ptr [[PTR]], i64 64) ]
+; CHECK-NEXT:    tail call void @llvm.masked.store.v2i32.p0(<2 x i32> [[VAL]], ptr [[PTR]], i32 64, <2 x i1> [[MASK]])
+; CHECK-NEXT:    ret void
+;
+entry:
+  call void @llvm.assume(i1 true) [ "align"(ptr %ptr, i64 64) ]
+  tail call void @llvm.masked.store.v2i32.p0(<2 x i32> %val, ptr %ptr, i32 1, <2 x i1> %mask)
+  ret void
+}
+
+declare void @llvm.assume(i1)
+declare <2 x i32> @llvm.masked.load.v2i32.p0(ptr, i32, <2 x i1>, <2 x i32>)
+declare void @llvm.masked.store.v2i32.p0(<2 x i32>, ptr, i32, <2 x i1>)
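
For context on the AlignOpIdx/PtrOp selection above: the alignment is passed as an explicit i32 argument to these intrinsics, at a different position for load and store. A quick reference, based on the intrinsic signatures already used in the test file (the "align"/"value" labels are annotations, not IR syntax):

  ; llvm.masked.load  operands: (ptr, i32 align, mask, passthru)  -> alignment is argument 1
  declare <2 x i32> @llvm.masked.load.v2i32.p0(ptr, i32, <2 x i1>, <2 x i32>)
  ; llvm.masked.store operands: (value, ptr, i32 align, mask)     -> alignment is argument 2
  declare void @llvm.masked.store.v2i32.p0(<2 x i32>, ptr, i32, <2 x i1>)

This is why the pass reads the pointer from argument 0 for masked loads but argument 1 for masked stores, and rewrites operand 1 or 2 respectively when a larger provable alignment is found.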