Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[NVVMReflect] Improve folding inside of the NVVMReflect pass #81253

Merged
merged 1 commit into from
Feb 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions llvm/docs/NVPTXUsage.rst
Original file line number Diff line number Diff line change
Expand Up @@ -298,8 +298,7 @@ input IR module ``module.bc``, the following compilation flow is recommended:

The ``NVVMReflect`` pass will attempt to remove dead code even without
optimizations. This allows potentially incompatible instructions to be avoided
at all optimizations levels. This currently only works for simple conditionals
like the above example.
at all optimizations levels by using the ``__CUDA_ARCH`` argument.

1. Save list of external functions in ``module.bc``
2. Link ``module.bc`` with ``libdevice.compute_XX.YY.bc``
Expand Down
70 changes: 17 additions & 53 deletions llvm/lib/Target/NVPTX/NVVMReflect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ static bool runNVVMReflect(Function &F, unsigned SmVersion) {
}

SmallVector<Instruction *, 4> ToRemove;
SmallVector<ICmpInst *, 4> ToSimplify;
SmallVector<Instruction *, 4> ToSimplify;

// Go through the calls in this function. Each call to __nvvm_reflect or
// llvm.nvvm.reflect should be a CallInst with a ConstantArray argument.
Expand Down Expand Up @@ -177,9 +177,8 @@ static bool runNVVMReflect(Function &F, unsigned SmVersion) {
}

// If the immediate user is a simple comparison we want to simplify it.
// TODO: This currently does not handle switch instructions.
for (User *U : Call->users())
if (ICmpInst *I = dyn_cast<ICmpInst>(U))
if (Instruction *I = dyn_cast<Instruction>(U))
ToSimplify.push_back(I);

Call->replaceAllUsesWith(ConstantInt::get(Call->getType(), ReflectVal));
Expand All @@ -190,56 +189,21 @@ static bool runNVVMReflect(Function &F, unsigned SmVersion) {
I->eraseFromParent();

// The code guarded by __nvvm_reflect may be invalid for the target machine.
// We need to do some basic dead code elimination to trim invalid code before
// it reaches the backend at all optimization levels.
SmallVector<BranchInst *> Simplified;
for (ICmpInst *Cmp : ToSimplify) {
Constant *LHS = dyn_cast<Constant>(Cmp->getOperand(0));
Constant *RHS = dyn_cast<Constant>(Cmp->getOperand(1));

if (!LHS || !RHS)
continue;

// If the comparison is a compile time constant we simply propagate it.
Constant *C = ConstantFoldCompareInstOperands(
Cmp->getPredicate(), LHS, RHS, Cmp->getModule()->getDataLayout());

if (!C)
continue;

for (User *U : Cmp->users())
if (BranchInst *I = dyn_cast<BranchInst>(U))
Simplified.push_back(I);

Cmp->replaceAllUsesWith(C);
Cmp->eraseFromParent();
}

// Each instruction here is a conditional branch off of a constant true or
// false value. Simply replace it with an unconditional branch to the
// appropriate basic block and delete the rest if it is trivially dead.
DenseSet<Instruction *> Removed;
for (BranchInst *Branch : Simplified) {
if (Removed.contains(Branch))
continue;

ConstantInt *C = dyn_cast<ConstantInt>(Branch->getCondition());
if (!C || (!C->isOne() && !C->isZero()))
continue;

BasicBlock *TrueBB =
C->isOne() ? Branch->getSuccessor(0) : Branch->getSuccessor(1);
BasicBlock *FalseBB =
C->isOne() ? Branch->getSuccessor(1) : Branch->getSuccessor(0);

// This transformation is only correct on simple edges.
if (!FalseBB->hasNPredecessors(1))
continue;

ReplaceInstWithInst(Branch, BranchInst::Create(TrueBB));
if (FalseBB->use_empty() && !FalseBB->getFirstNonPHIOrDbg()) {
Removed.insert(FalseBB->getFirstNonPHIOrDbg());
changeToUnreachable(FalseBB->getFirstNonPHIOrDbg());
// Traverse the use-def chain, continually simplifying constant expressions
// until we find a terminator that we can then remove.
while (!ToSimplify.empty()) {
Instruction *I = ToSimplify.pop_back_val();
if (Constant *C =
ConstantFoldInstruction(I, F.getParent()->getDataLayout())) {
for (User *U : I->users())
if (Instruction *I = dyn_cast<Instruction>(U))
ToSimplify.push_back(I);

I->replaceAllUsesWith(C);
if (isInstructionTriviallyDead(I))
I->eraseFromParent();
} else if (I->isTerminator()) {
ConstantFoldTerminator(I->getParent());
}
}

Expand Down
78 changes: 64 additions & 14 deletions llvm/test/CodeGen/NVPTX/nvvm-reflect-arch-O0.ll
Original file line number Diff line number Diff line change
Expand Up @@ -102,23 +102,24 @@ if.end:
ret void
}

; SM_52: .visible .func (.param .b32 func_retval0) qux()
; SM_52: mov.u32 %[[REG1:.+]], %[[REG2:.+]];
; SM_52: st.param.b32 [func_retval0+0], %[[REG1:.+]];
; SM_52: ret;
; SM_70: .visible .func (.param .b32 func_retval0) qux()
; SM_70: mov.u32 %[[REG1:.+]], %[[REG2:.+]];
; SM_70: st.param.b32 [func_retval0+0], %[[REG1:.+]];
; SM_70: ret;
; SM_90: .visible .func (.param .b32 func_retval0) qux()
; SM_90: st.param.b32 [func_retval0+0], %[[REG1:.+]];
; SM_90: ret;
; SM_52: .visible .func (.param .b32 func_retval0) qux()
; SM_52: mov.b32 %[[REG:.+]], 3;
; SM_52-NEXT: st.param.b32 [func_retval0+0], %[[REG:.+]];
; SM_52-NEXT: ret;
;
; SM_70: .visible .func (.param .b32 func_retval0) qux()
; SM_70: mov.b32 %[[REG:.+]], 2;
; SM_70-NEXT: st.param.b32 [func_retval0+0], %[[REG:.+]];
; SM_70-NEXT: ret;
;
; SM_90: .visible .func (.param .b32 func_retval0) qux()
; SM_90: mov.b32 %[[REG:.+]], 1;
; SM_90-NEXT: st.param.b32 [func_retval0+0], %[[REG:.+]];
; SM_90-NEXT: ret;
define i32 @qux() {
entry:
%call = call i32 @__nvvm_reflect(ptr noundef @.str)
%cmp = icmp uge i32 %call, 700
%conv = zext i1 %cmp to i32
switch i32 %conv, label %sw.default [
switch i32 %call, label %sw.default [
i32 900, label %sw.bb
i32 700, label %sw.bb1
i32 520, label %sw.bb2
Expand Down Expand Up @@ -173,3 +174,52 @@ if.exit:
exit:
ret float 0.000000e+00
}

; SM_52: .visible .func (.param .b32 func_retval0) prop()
; SM_52: mov.b32 %[[REG:.+]], 3;
; SM_52-NEXT: st.param.b32 [func_retval0+0], %[[REG:.+]];
; SM_52-NEXT: ret;
;
; SM_70: .visible .func (.param .b32 func_retval0) prop()
; SM_70: mov.b32 %[[REG:.+]], 2;
; SM_70-NEXT: st.param.b32 [func_retval0+0], %[[REG:.+]];
; SM_70-NEXT: ret;
;
; SM_90: .visible .func (.param .b32 func_retval0) prop()
; SM_90: mov.b32 %[[REG:.+]], 1;
; SM_90-NEXT: st.param.b32 [func_retval0+0], %[[REG:.+]];
; SM_90-NEXT: ret;
define i32 @prop() {
entry:
%call = call i32 @__nvvm_reflect(ptr @.str)
%conv = zext i32 %call to i64
%div = udiv i64 %conv, 100
%cmp = icmp eq i64 %div, 9
br i1 %cmp, label %if.then, label %if.else

if.then:
br label %return

if.else:
%div2 = udiv i64 %conv, 100
%cmp3 = icmp eq i64 %div2, 7
br i1 %cmp3, label %if.then5, label %if.else6

if.then5:
br label %return

if.else6:
%div7 = udiv i64 %conv, 100
%cmp8 = icmp eq i64 %div7, 5
br i1 %cmp8, label %if.then10, label %if.else11

if.then10:
br label %return

if.else11:
br label %return

return:
%retval = phi i32 [ 1, %if.then ], [ 2, %if.then5 ], [ 3, %if.then10 ], [ 4, %if.else11 ]
ret i32 %retval
}
Loading