Skip to content

Commit

Permalink
[NFC] Refactor switch function creation to generate single return point (#2714)
Browse files Browse the repository at this point in the history

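This refactoring simplifies vectorization of the generated switch functions: instead of returning from every case block, each case now stores its value to a local `result` slot and branches to a single `exit` block, which loads the value and returns it.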

Signed-off-by: Cui, Dele <dele.cui@intel.com>
  • Loading branch information
delecui authored Sep 10, 2024
1 parent ea2fcc1 commit a595261
Show file tree
Hide file tree
Showing 6 changed files with 144 additions and 55 deletions.
29 changes: 20 additions & 9 deletions lib/SPIRV/OCLUtil.h
Original file line number Diff line number Diff line change
Expand Up @@ -535,17 +535,18 @@ getOrCreateSwitchFunc(StringRef MapName, Value *V,
F->setLinkage(GlobalValue::PrivateLinkage);

LLVMContext &Ctx = M->getContext();
BasicBlock *BB = BasicBlock::Create(Ctx, "entry", F);
IRBuilder<> IRB(BB);
BasicBlock *EntryBB = BasicBlock::Create(Ctx, "entry", F);
IRBuilder<> EntryIRB(EntryBB);
AllocaInst *Result = EntryIRB.CreateAlloca(Ty, nullptr, "result");
SwitchInst *SI;
F->arg_begin()->setName("key");
if (KeyMask) {
Value *MaskV = ConstantInt::get(Type::getInt32Ty(Ctx), KeyMask);
Value *NewKey = IRB.CreateAnd(MaskV, F->arg_begin());
Value *NewKey = EntryIRB.CreateAnd(MaskV, F->arg_begin());
NewKey->setName("key.masked");
SI = IRB.CreateSwitch(NewKey, BB);
SI = EntryIRB.CreateSwitch(NewKey, EntryBB);
} else {
SI = IRB.CreateSwitch(F->arg_begin(), BB);
SI = EntryIRB.CreateSwitch(F->arg_begin(), EntryBB);
}

if (!DefaultCase) {
Expand All @@ -555,17 +556,27 @@ getOrCreateSwitchFunc(StringRef MapName, Value *V,
SI->setDefaultDest(DefaultBB);
}

BasicBlock *ExitBB = BasicBlock::Create(Ctx, "exit", F);
BasicBlock *CaseBB = nullptr;
Map.foreach ([&](int Key, int Val) {
if (IsReverse)
std::swap(Key, Val);
BasicBlock *CaseBB = BasicBlock::Create(Ctx, "case." + Twine(Key), F);
CaseBB = BasicBlock::Create(Ctx, "case." + Twine(Key), F);
IRBuilder<> CaseIRB(CaseBB);
CaseIRB.CreateRet(CaseIRB.getInt32(Val));
SI->addCase(IRB.getInt32(Key), CaseBB);
CaseIRB.CreateStore(CaseIRB.getInt32(Val), Result);
CaseIRB.CreateBr(ExitBB);
SI->addCase(EntryIRB.getInt32(Key), CaseBB);
if (Key == DefaultCase)
SI->setDefaultDest(CaseBB);
});
assert(SI->getDefaultDest() != BB && "Invalid default destination in switch");

ExitBB->moveAfter(CaseBB);
IRBuilder<> ExitIRB(ExitBB);
LoadInst *RetVal = ExitIRB.CreateLoad(Ty, Result, "retVal");
ExitIRB.CreateRet(RetVal);

assert(SI->getDefaultDest() != EntryBB &&
"Invalid default destination in switch");
return addCallInst(M, MapName, Ty, V, nullptr, InsertPoint);
}

Expand Down
38 changes: 28 additions & 10 deletions test/atomic_explicit_arguments.spt
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@

; CHECK: define private spir_func i32 @__translate_spirv_memory_scope(i32 %key) {
; CHECK: entry:
; CHECK: %result = alloca i32, align 4
; CHECK: switch i32 %key, label %default [
; CHECK: i32 4, label %case.4
; CHECK: i32 2, label %case.2
Expand All @@ -54,19 +55,28 @@
; CHECK: default:
; CHECK: unreachable
; CHECK: case.4:
; CHECK: ret i32 0
; CHECK: store i32 0, ptr %result, align 4
; CHECK: br label %exit
; CHECK: case.2:
; CHECK: ret i32 1
; CHECK: store i32 1, ptr %result, align 4
; CHECK: br label %exit
; CHECK: case.1:
; CHECK: ret i32 2
; CHECK: store i32 2, ptr %result, align 4
; CHECK: br label %exit
; CHECK: case.0:
; CHECK: ret i32 3
; CHECK: store i32 3, ptr %result, align 4
; CHECK: br label %exit
; CHECK: case.3:
; CHECK: ret i32 4
; CHECK: store i32 4, ptr %result, align 4
; CHECK: br label %exit
; CHECK: exit:
; CHECK: %retVal = load i32, ptr %result, align 4
; CHECK: ret i32 %retVal
; CHECK: }

; CHECK: define private spir_func i32 @__translate_spirv_memory_order(i32 %key) {
; CHECK: entry:
; CHECK: %result = alloca i32, align 4
; CHECK: %key.masked = and i32 30, %key
; CHECK: switch i32 %key.masked, label %default [
; CHECK: i32 0, label %case.0
Expand All @@ -78,13 +88,21 @@
; CHECK: default:
; CHECK: unreachable
; CHECK: case.0:
; CHECK: ret i32 0
; CHECK: store i32 0, ptr %result, align 4
; CHECK: br label %exit
; CHECK: case.2:
; CHECK: ret i32 2
; CHECK: store i32 2, ptr %result, align 4
; CHECK: br label %exit
; CHECK: case.4:
; CHECK: ret i32 3
; CHECK: store i32 3, ptr %result, align 4
; CHECK: br label %exit
; CHECK: case.8:
; CHECK: ret i32 4
; CHECK: store i32 4, ptr %result, align 4
; CHECK: br label %exit
; CHECK: case.16:
; CHECK: ret i32 5
; CHECK: store i32 5, ptr %result, align 4
; CHECK: br label %exit
; CHECK: exit:
; CHECK: %retVal = load i32, ptr %result, align 4
; CHECK: ret i32 %retVal
; CHECK: }
27 changes: 19 additions & 8 deletions test/barrier_explicit_arguments.spt
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,9 @@
; CHECK-12: call spir_func void @_Z7barrierj(i32 %call1)
; CHECK-20: call spir_func void @_Z18work_group_barrierj12memory_scope(i32 %call
; CHECK-20: call spir_func void @_Z17sub_group_barrierj12memory_scope(i32 %call
; CHECK: define private spir_func i32 @__translate_spirv_memory_fence(i32 %key)
; CHECK: define private spir_func i32 @__translate_spirv_memory_fence(i32 %key) {
; CHECK: entry:
; CHECK: %result = alloca i32, align 4
; CHECK: %key.masked = and i32 2816, %key
; CHECK: switch i32 %key.masked, label %default [
; CHECK: i32 256, label %case.256
Expand All @@ -54,17 +55,27 @@
; CHECK: default:
; CHECK: unreachable
; CHECK: case.256:
; CHECK: ret i32 1
; CHECK: store i32 1, ptr %result, align 4
; CHECK: br label %exit
; CHECK: case.512:
; CHECK: ret i32 2
; CHECK: store i32 2, ptr %result, align 4
; CHECK: br label %exit
; CHECK: case.768:
; CHECK: ret i32 3
; CHECK: store i32 3, ptr %result, align 4
; CHECK: br label %exit
; CHECK: case.2048:
; CHECK: ret i32 4
; CHECK: store i32 4, ptr %result, align 4
; CHECK: br label %exit
; CHECK: case.2304:
; CHECK: ret i32 5
; CHECK: store i32 5, ptr %result, align 4
; CHECK: br label %exit
; CHECK: case.2560:
; CHECK: ret i32 6
; CHECK: store i32 6, ptr %result, align 4
; CHECK: br label %exit
; CHECK: case.2816:
; CHECK: ret i32 7
; CHECK: store i32 7, ptr %result, align 4
; CHECK: br label %exit
; CHECK: exit:
; CHECK: %retVal = load i32, ptr %result, align 4
; CHECK: ret i32 %retVal
; CHECK: }
46 changes: 33 additions & 13 deletions test/mem_fence_explicit_arguments.spt
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,9 @@
; CHECK-20: %call3 = call spir_func i32 @__translate_spirv_memory_fence(i32 %[[VAL1]])
; CHECK-20: %call4 = call spir_func i32 @__translate_spirv_memory_order(i32 %[[VAL1]])
; CHECK-20: call spir_func void @_Z22atomic_work_item_fencej12memory_order12memory_scope(i32 %call3, i32 %call4, i32 %call2)
; CHECK: define private spir_func i32 @__translate_spirv_memory_fence(i32 %key)
; CHECK: define private spir_func i32 @__translate_spirv_memory_fence(i32 %key) {
; CHECK: entry:
; CHECK: %result = alloca i32, align 4
; CHECK: %key.masked = and i32 2816, %key
; CHECK: switch i32 %key.masked, label %default [
; CHECK: i32 256, label %case.256
Expand All @@ -59,22 +60,33 @@
; CHECK: default:
; CHECK: unreachable
; CHECK: case.256:
; CHECK: ret i32 1
; CHECK: store i32 1, ptr %result, align 4
; CHECK: br label %exit
; CHECK: case.512:
; CHECK: ret i32 2
; CHECK: store i32 2, ptr %result, align 4
; CHECK: br label %exit
; CHECK: case.768:
; CHECK: ret i32 3
; CHECK: store i32 3, ptr %result, align 4
; CHECK: br label %exit
; CHECK: case.2048:
; CHECK: ret i32 4
; CHECK: store i32 4, ptr %result, align 4
; CHECK: br label %exit
; CHECK: case.2304:
; CHECK: ret i32 5
; CHECK: store i32 5, ptr %result, align 4
; CHECK: br label %exit
; CHECK: case.2560:
; CHECK: ret i32 6
; CHECK: store i32 6, ptr %result, align 4
; CHECK: br label %exit
; CHECK: case.2816:
; CHECK: ret i32 7
; CHECK: store i32 7, ptr %result, align 4
; CHECK: br label %exit
; CHECK: exit:
; CHECK: %retVal = load i32, ptr %result, align 4
; CHECK: ret i32 %retVal
; CHECK: }
; CHECK-20: define private spir_func i32 @__translate_spirv_memory_order(i32 %key) {
; CHECK-20: entry:
; CHECK-20: %result = alloca i32, align 4
; CHECK-20: %key.masked = and i32 30, %key
; CHECK-20: switch i32 %key.masked, label %default [
; CHECK-20: i32 0, label %case.0
Expand All @@ -86,13 +98,21 @@
; CHECK-20: default:
; CHECK-20: unreachable
; CHECK-20: case.0:
; CHECK-20: ret i32 0
; CHECK-20: store i32 0, ptr %result, align 4
; CHECK-20: br label %exit
; CHECK-20: case.2:
; CHECK-20: ret i32 2
; CHECK-20: store i32 2, ptr %result, align 4
; CHECK-20: br label %exit
; CHECK-20: case.4:
; CHECK-20: ret i32 3
; CHECK-20: store i32 3, ptr %result, align 4
; CHECK-20: br label %exit
; CHECK-20: case.8:
; CHECK-20: ret i32 4
; CHECK-20: store i32 4, ptr %result, align 4
; CHECK-20: br label %exit
; CHECK-20: case.16:
; CHECK-20: ret i32 5
; CHECK-20: store i32 5, ptr %result, align 4
; CHECK-20: br label %exit
; CHECK-20: exit:
; CHECK-20: %retVal = load i32, ptr %result, align 4
; CHECK-20: ret i32 %retVal
; CHECK-20: }
40 changes: 30 additions & 10 deletions test/transcoding/atomic_explicit_arguments.cl
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
// clang-format off

// RUN: %clang_cc1 -triple spir -cl-std=cl2.0 %s -fdeclare-opencl-builtins -finclude-default-header -emit-llvm-bc -o %t.bc
// RUN: llvm-spirv %t.bc -o %t.spv
// RUN: llvm-spirv %t.spv -to-text -o - | FileCheck %s --check-prefix=CHECK-SPIRV
Expand Down Expand Up @@ -31,32 +33,50 @@ int load (volatile atomic_int* obj, memory_order order, memory_scope scope) {

// CHECK-SPIRV: Function [[int]] [[TRANS_MEM_SCOPE]]
// CHECK-SPIRV: FunctionParameter [[int]] [[KEY:[0-9]+]]
// CHECK-SPIRV: Variable {{[0-9]+}} [[RES:[0-9]+]]
// CHECK-SPIRV: Switch [[KEY]] [[CASE_2:[0-9]+]] 0 [[CASE_0:[0-9]+]] 1 [[CASE_1:[0-9]+]] 2 [[CASE_2]] 3 [[CASE_3:[0-9]+]] 4 [[CASE_4:[0-9]+]]
// CHECK-SPIRV: Label [[CASE_0]]
// CHECK-SPIRV: ReturnValue [[FOUR]]
// CHECK-SPIRV: Store [[RES]] [[FOUR]]
// CHECK-SPIRV: Branch [[EXIT:[0-9]+]]
// CHECK-SPIRV: Label [[CASE_1]]
// CHECK-SPIRV: ReturnValue [[TWO]]
// CHECK-SPIRV: Store [[RES]] [[TWO]]
// CHECK-SPIRV: Branch [[EXIT]]
// CHECK-SPIRV: Label [[CASE_2]]
// CHECK-SPIRV: ReturnValue [[ONE]]
// CHECK-SPIRV: Store [[RES]] [[ONE]]
// CHECK-SPIRV: Branch [[EXIT]]
// CHECK-SPIRV: Label [[CASE_3]]
// CHECK-SPIRV: ReturnValue [[ZERO]]
// CHECK-SPIRV: Store [[RES]] [[ZERO]]
// CHECK-SPIRV: Branch [[EXIT]]
// CHECK-SPIRV: Label [[CASE_4]]
// CHECK-SPIRV: ReturnValue [[THREE]]
// CHECK-SPIRV: Store [[RES]] [[THREE]]
// CHECK-SPIRV: Branch [[EXIT]]
// CHECK-SPIRV: Label [[EXIT]]
// CHECK-SPIRV: Load [[int]] [[RET_VAR:[0-9]+]] [[RES]]
// CHECK-SPIRV: ReturnValue [[RET_VAR]]
// CHECK-SPIRV: FunctionEnd

// CHECK-SPIRV: Function [[int]] [[TRANS_MEM_ORDER]]
// CHECK-SPIRV: FunctionParameter [[int]] [[KEY:[0-9]+]]
// CHECK-SPIRV: Variable {{[0-9]+}} [[RES:[0-9]+]]
// CHECK-SPIRV: Switch [[KEY]] [[CASE_5:[0-9]+]] 0 [[CASE_0:[0-9]+]] 2 [[CASE_2:[0-9]+]] 3 [[CASE_3:[0-9]+]] 4 [[CASE_4:[0-9]+]] 5 [[CASE_5]]
// CHECK-SPIRV: Label [[CASE_0]]
// CHECK-SPIRV: ReturnValue [[ZERO]]
// CHECK-SPIRV: Store [[RES]] [[ZERO]]
// CHECK-SPIRV: Branch [[EXIT:[0-9]+]]
// CHECK-SPIRV: Label [[CASE_2]]
// CHECK-SPIRV: ReturnValue [[TWO]]
// CHECK-SPIRV: Store [[RES]] [[TWO]]
// CHECK-SPIRV: Branch [[EXIT]]
// CHECK-SPIRV: Label [[CASE_3]]
// CHECK-SPIRV: ReturnValue [[FOUR]]
// CHECK-SPIRV: Store [[RES]] [[FOUR]]
// CHECK-SPIRV: Branch [[EXIT]]
// CHECK-SPIRV: Label [[CASE_4]]
// CHECK-SPIRV: ReturnValue [[EIGHT]]
// CHECK-SPIRV: Store [[RES]] [[EIGHT]]
// CHECK-SPIRV: Branch [[EXIT]]
// CHECK-SPIRV: Label [[CASE_5]]
// CHECK-SPIRV: ReturnValue [[SIXTEEN]]
// CHECK-SPIRV: Store [[RES]] [[SIXTEEN]]
// CHECK-SPIRV: Branch [[EXIT]]
// CHECK-SPIRV: Label [[EXIT]]
// CHECK-SPIRV: Load [[int]] [[RET_VAR:[0-9]+]] [[RES]]
// CHECK-SPIRV: ReturnValue [[RET_VAR]]
// CHECK-SPIRV: FunctionEnd


Expand Down
19 changes: 14 additions & 5 deletions test/transcoding/barrier-runtime-scope.ll
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@

; CHECK-LLVM: define private spir_func i32 @__translate_spirv_memory_scope(i32 %key) {
; CHECK-LLVM: entry:
; CHECK-LLVM: %result = alloca i32, align 4
; CHECK-LLVM: switch i32 %key, label %default [
; CHECK-LLVM: i32 4, label %case.4
; CHECK-LLVM: i32 2, label %case.2
Expand All @@ -33,15 +34,23 @@
; CHECK-LLVM: default: ; preds = %entry
; CHECK-LLVM: unreachable
; CHECK-LLVM: case.4: ; preds = %entry
; CHECK-LLVM: ret i32 0
; CHECK-LLVM: store i32 0, ptr %result, align 4
; CHECK-LLVM: br label %exit
; CHECK-LLVM: case.2: ; preds = %entry
; CHECK-LLVM: ret i32 1
; CHECK-LLVM: store i32 1, ptr %result, align 4
; CHECK-LLVM: br label %exit
; CHECK-LLVM: case.1: ; preds = %entry
; CHECK-LLVM: ret i32 2
; CHECK-LLVM: store i32 2, ptr %result, align 4
; CHECK-LLVM: br label %exit
; CHECK-LLVM: case.0: ; preds = %entry
; CHECK-LLVM: ret i32 3
; CHECK-LLVM: store i32 3, ptr %result, align 4
; CHECK-LLVM: br label %exit
; CHECK-LLVM: case.3: ; preds = %entry
; CHECK-LLVM: ret i32 4
; CHECK-LLVM: store i32 4, ptr %result, align 4
; CHECK-LLVM: br label %exit
; CHECK-LLVM: exit: ; preds = %case.3, %case.0, %case.1, %case.2, %case.4
; CHECK-LLVM: %retVal = load i32, ptr %result, align 4
; CHECK-LLVM: ret i32 %retVal
; CHECK-LLVM: }

target datalayout = "e-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024"
Expand Down

0 comments on commit a595261

Please sign in to comment.