Skip to content

Commit

Permalink
[NVVMReflect] Improve folding inside of the NVVMReflect pass
Browse files Browse the repository at this point in the history
Summary:
The previous patch did very simple folding that only worked for directly
used branches. This patch improves this by traversing the use-def chain
to simplify every constant subexpression until it reaches a terminator
we can delete. The support should work for all expected cases now.
  • Loading branch information
jhuber6 committed Feb 9, 2024
1 parent ffabcbc commit b00f01a
Show file tree
Hide file tree
Showing 3 changed files with 82 additions and 69 deletions.
3 changes: 1 addition & 2 deletions llvm/docs/NVPTXUsage.rst
Original file line number Diff line number Diff line change
Expand Up @@ -298,8 +298,7 @@ input IR module ``module.bc``, the following compilation flow is recommended:

The ``NVVMReflect`` pass will attempt to remove dead code even without
optimizations. This allows potentially incompatible instructions to be avoided
at all optimization levels. This currently only works for simple conditionals
like the above example.
at all optimization levels by using the ``__CUDA_ARCH`` argument.

1. Save list of external functions in ``module.bc``
2. Link ``module.bc`` with ``libdevice.compute_XX.YY.bc``
Expand Down
70 changes: 17 additions & 53 deletions llvm/lib/Target/NVPTX/NVVMReflect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ static bool runNVVMReflect(Function &F, unsigned SmVersion) {
}

SmallVector<Instruction *, 4> ToRemove;
SmallVector<ICmpInst *, 4> ToSimplify;
SmallVector<Instruction *, 4> ToSimplify;

// Go through the calls in this function. Each call to __nvvm_reflect or
// llvm.nvvm.reflect should be a CallInst with a ConstantArray argument.
Expand Down Expand Up @@ -177,9 +177,8 @@ static bool runNVVMReflect(Function &F, unsigned SmVersion) {
}

// If the immediate user is a simple comparison we want to simplify it.
// TODO: This currently does not handle switch instructions.
for (User *U : Call->users())
if (ICmpInst *I = dyn_cast<ICmpInst>(U))
if (Instruction *I = dyn_cast<Instruction>(U))
ToSimplify.push_back(I);

Call->replaceAllUsesWith(ConstantInt::get(Call->getType(), ReflectVal));
Expand All @@ -190,56 +189,21 @@ static bool runNVVMReflect(Function &F, unsigned SmVersion) {
I->eraseFromParent();

// The code guarded by __nvvm_reflect may be invalid for the target machine.
// We need to do some basic dead code elimination to trim invalid code before
// it reaches the backend at all optimization levels.
SmallVector<BranchInst *> Simplified;
for (ICmpInst *Cmp : ToSimplify) {
Constant *LHS = dyn_cast<Constant>(Cmp->getOperand(0));
Constant *RHS = dyn_cast<Constant>(Cmp->getOperand(1));

if (!LHS || !RHS)
continue;

// If the comparison is a compile time constant we simply propagate it.
Constant *C = ConstantFoldCompareInstOperands(
Cmp->getPredicate(), LHS, RHS, Cmp->getModule()->getDataLayout());

if (!C)
continue;

for (User *U : Cmp->users())
if (BranchInst *I = dyn_cast<BranchInst>(U))
Simplified.push_back(I);

Cmp->replaceAllUsesWith(C);
Cmp->eraseFromParent();
}

// Each instruction here is a conditional branch off of a constant true or
// false value. Simply replace it with an unconditional branch to the
// appropriate basic block and delete the rest if it is trivially dead.
DenseSet<Instruction *> Removed;
for (BranchInst *Branch : Simplified) {
if (Removed.contains(Branch))
continue;

ConstantInt *C = dyn_cast<ConstantInt>(Branch->getCondition());
if (!C || (!C->isOne() && !C->isZero()))
continue;

BasicBlock *TrueBB =
C->isOne() ? Branch->getSuccessor(0) : Branch->getSuccessor(1);
BasicBlock *FalseBB =
C->isOne() ? Branch->getSuccessor(1) : Branch->getSuccessor(0);

// This transformation is only correct on simple edges.
if (!FalseBB->hasNPredecessors(1))
continue;

ReplaceInstWithInst(Branch, BranchInst::Create(TrueBB));
if (FalseBB->use_empty() && !FalseBB->getFirstNonPHIOrDbg()) {
Removed.insert(FalseBB->getFirstNonPHIOrDbg());
changeToUnreachable(FalseBB->getFirstNonPHIOrDbg());
// Traverse the use-def chain, continually simplifying constant expressions
// until we find a terminator that we can then remove.
while (!ToSimplify.empty()) {
Instruction *I = ToSimplify.pop_back_val();
if (Constant *C =
ConstantFoldInstruction(I, F.getParent()->getDataLayout())) {
for (User *U : I->users())
if (Instruction *I = dyn_cast<Instruction>(U))
ToSimplify.push_back(I);

I->replaceAllUsesWith(C);
if (isInstructionTriviallyDead(I))
I->eraseFromParent();
} else if (I->isTerminator()) {
ConstantFoldTerminator(I->getParent());
}
}

Expand Down
78 changes: 64 additions & 14 deletions llvm/test/CodeGen/NVPTX/nvvm-reflect-arch-O0.ll
Original file line number Diff line number Diff line change
Expand Up @@ -102,23 +102,24 @@ if.end:
ret void
}

; SM_52: .visible .func (.param .b32 func_retval0) qux()
; SM_52: mov.u32 %[[REG1:.+]], %[[REG2:.+]];
; SM_52: st.param.b32 [func_retval0+0], %[[REG1:.+]];
; SM_52: ret;
; SM_70: .visible .func (.param .b32 func_retval0) qux()
; SM_70: mov.u32 %[[REG1:.+]], %[[REG2:.+]];
; SM_70: st.param.b32 [func_retval0+0], %[[REG1:.+]];
; SM_70: ret;
; SM_90: .visible .func (.param .b32 func_retval0) qux()
; SM_90: st.param.b32 [func_retval0+0], %[[REG1:.+]];
; SM_90: ret;
; SM_52: .visible .func (.param .b32 func_retval0) qux()
; SM_52: mov.b32 %[[REG:.+]], 3;
; SM_52-NEXT: st.param.b32 [func_retval0+0], %[[REG:.+]];
; SM_52-NEXT: ret;
;
; SM_70: .visible .func (.param .b32 func_retval0) qux()
; SM_70: mov.b32 %[[REG:.+]], 2;
; SM_70-NEXT: st.param.b32 [func_retval0+0], %[[REG:.+]];
; SM_70-NEXT: ret;
;
; SM_90: .visible .func (.param .b32 func_retval0) qux()
; SM_90: mov.b32 %[[REG:.+]], 1;
; SM_90-NEXT: st.param.b32 [func_retval0+0], %[[REG:.+]];
; SM_90-NEXT: ret;
define i32 @qux() {
entry:
%call = call i32 @__nvvm_reflect(ptr noundef @.str)
%cmp = icmp uge i32 %call, 700
%conv = zext i1 %cmp to i32
switch i32 %conv, label %sw.default [
switch i32 %call, label %sw.default [
i32 900, label %sw.bb
i32 700, label %sw.bb1
i32 520, label %sw.bb2
Expand Down Expand Up @@ -173,3 +174,52 @@ if.exit:
exit:
ret float 0.000000e+00
}

; SM_52: .visible .func (.param .b32 func_retval0) prop()
; SM_52: mov.b32 %[[REG:.+]], 3;
; SM_52-NEXT: st.param.b32 [func_retval0+0], %[[REG:.+]];
; SM_52-NEXT: ret;
;
; SM_70: .visible .func (.param .b32 func_retval0) prop()
; SM_70: mov.b32 %[[REG:.+]], 2;
; SM_70-NEXT: st.param.b32 [func_retval0+0], %[[REG:.+]];
; SM_70-NEXT: ret;
;
; SM_90: .visible .func (.param .b32 func_retval0) prop()
; SM_90: mov.b32 %[[REG:.+]], 1;
; SM_90-NEXT: st.param.b32 [func_retval0+0], %[[REG:.+]];
; SM_90-NEXT: ret;
; Test that the result of __nvvm_reflect is propagated through a chain of
; non-comparison instructions (zext, then udiv) before it reaches the icmp and
; the conditional branch. The function returns 1, 2, 3 or 4 depending on which
; of the cascaded comparisons (quotient == 9, == 7, == 5, or none) succeeds.
; NOTE(review): @.str is defined outside this view; presumably it is the
; "__CUDA_ARCH" query string, which would make the quotient the SM major
; version -- confirm against the top of the test file.
define i32 @prop() {
entry:
; The reflect call is the only non-constant input to this function.
%call = call i32 @__nvvm_reflect(ptr @.str)
; Widen and divide by 100 so the comparison is two instructions removed from
; the call rather than using it directly.
%conv = zext i32 %call to i64
%div = udiv i64 %conv, 100
%cmp = icmp eq i64 %div, 9
br i1 %cmp, label %if.then, label %if.else

if.then:
; Quotient was 9: the phi in %return selects 1.
br label %return

if.else:
; Recompute the quotient from %conv and test for 7.
%div2 = udiv i64 %conv, 100
%cmp3 = icmp eq i64 %div2, 7
br i1 %cmp3, label %if.then5, label %if.else6

if.then5:
; Quotient was 7: the phi in %return selects 2.
br label %return

if.else6:
; Recompute the quotient from %conv and test for 5.
%div7 = udiv i64 %conv, 100
%cmp8 = icmp eq i64 %div7, 5
br i1 %cmp8, label %if.then10, label %if.else11

if.then10:
; Quotient was 5: the phi in %return selects 3.
br label %return

if.else11:
; No comparison matched: the phi in %return selects 4.
br label %return

return:
; Merge point: one incoming constant per predecessor block.
%retval = phi i32 [ 1, %if.then ], [ 2, %if.then5 ], [ 3, %if.then10 ], [ 4, %if.else11 ]
ret i32 %retval
}

0 comments on commit b00f01a

Please sign in to comment.