From a5952614c12594c6452aa2ef611eaa1f00cc5364 Mon Sep 17 00:00:00 2001 From: "Cui, Dele" Date: Tue, 10 Sep 2024 21:49:17 +0800 Subject: [PATCH] [NFC] Refactor switch function creation to generate single return point (#2714) The refactoring is to simplify the vectorization of generated functions. Signed-off-by: Cui, Dele --- lib/SPIRV/OCLUtil.h | 29 ++++++++---- test/atomic_explicit_arguments.spt | 38 +++++++++++---- test/barrier_explicit_arguments.spt | 27 +++++++---- test/mem_fence_explicit_arguments.spt | 46 +++++++++++++------ test/transcoding/atomic_explicit_arguments.cl | 40 ++++++++++++---- test/transcoding/barrier-runtime-scope.ll | 19 ++++++-- 6 files changed, 144 insertions(+), 55 deletions(-) diff --git a/lib/SPIRV/OCLUtil.h b/lib/SPIRV/OCLUtil.h index 31bde8e591..d8052cb772 100644 --- a/lib/SPIRV/OCLUtil.h +++ b/lib/SPIRV/OCLUtil.h @@ -535,17 +535,18 @@ getOrCreateSwitchFunc(StringRef MapName, Value *V, F->setLinkage(GlobalValue::PrivateLinkage); LLVMContext &Ctx = M->getContext(); - BasicBlock *BB = BasicBlock::Create(Ctx, "entry", F); - IRBuilder<> IRB(BB); + BasicBlock *EntryBB = BasicBlock::Create(Ctx, "entry", F); + IRBuilder<> EntryIRB(EntryBB); + AllocaInst *Result = EntryIRB.CreateAlloca(Ty, nullptr, "result"); SwitchInst *SI; F->arg_begin()->setName("key"); if (KeyMask) { Value *MaskV = ConstantInt::get(Type::getInt32Ty(Ctx), KeyMask); - Value *NewKey = IRB.CreateAnd(MaskV, F->arg_begin()); + Value *NewKey = EntryIRB.CreateAnd(MaskV, F->arg_begin()); NewKey->setName("key.masked"); - SI = IRB.CreateSwitch(NewKey, BB); + SI = EntryIRB.CreateSwitch(NewKey, EntryBB); } else { - SI = IRB.CreateSwitch(F->arg_begin(), BB); + SI = EntryIRB.CreateSwitch(F->arg_begin(), EntryBB); } if (!DefaultCase) { @@ -555,17 +556,27 @@ getOrCreateSwitchFunc(StringRef MapName, Value *V, SI->setDefaultDest(DefaultBB); } + BasicBlock *ExitBB = BasicBlock::Create(Ctx, "exit", F); + BasicBlock *CaseBB = nullptr; Map.foreach ([&](int Key, int Val) { if (IsReverse) std::swap(Key, Val); - BasicBlock *CaseBB = BasicBlock::Create(Ctx, "case." + Twine(Key), F); + CaseBB = BasicBlock::Create(Ctx, "case." + Twine(Key), F); IRBuilder<> CaseIRB(CaseBB); - CaseIRB.CreateRet(CaseIRB.getInt32(Val)); - SI->addCase(IRB.getInt32(Key), CaseBB); + CaseIRB.CreateStore(CaseIRB.getInt32(Val), Result); + CaseIRB.CreateBr(ExitBB); + SI->addCase(EntryIRB.getInt32(Key), CaseBB); if (Key == DefaultCase) SI->setDefaultDest(CaseBB); }); - assert(SI->getDefaultDest() != BB && "Invalid default destination in switch"); + + ExitBB->moveAfter(CaseBB); + IRBuilder<> ExitIRB(ExitBB); + LoadInst *RetVal = ExitIRB.CreateLoad(Ty, Result, "retVal"); + ExitIRB.CreateRet(RetVal); + + assert(SI->getDefaultDest() != EntryBB && + "Invalid default destination in switch"); return addCallInst(M, MapName, Ty, V, nullptr, InsertPoint); } diff --git a/test/atomic_explicit_arguments.spt b/test/atomic_explicit_arguments.spt index 60945b21be..cd35268884 100644 --- a/test/atomic_explicit_arguments.spt +++ b/test/atomic_explicit_arguments.spt @@ -44,6 +44,7 @@ ; CHECK: define private spir_func i32 @__translate_spirv_memory_scope(i32 %key) { ; CHECK: entry: +; CHECK: %result = alloca i32, align 4 ; CHECK: switch i32 %key, label %default [ ; CHECK: i32 4, label %case.4 ; CHECK: i32 2, label %case.2 @@ -54,19 +55,28 @@ ; CHECK: default: ; CHECK: unreachable ; CHECK: case.4: -; CHECK: ret i32 0 +; CHECK: store i32 0, ptr %result, align 4 +; CHECK: br label %exit ; CHECK: case.2: -; CHECK: ret i32 1 +; CHECK: store i32 1, ptr %result, align 4 +; CHECK: br label %exit ; CHECK: case.1: -; CHECK: ret i32 2 +; CHECK: store i32 2, ptr %result, align 4 +; CHECK: br label %exit ; CHECK: case.0: -; CHECK: ret i32 3 +; CHECK: store i32 3, ptr %result, align 4 +; CHECK: br label %exit ; CHECK: case.3: -; CHECK: ret i32 4 +; CHECK: store i32 4, ptr %result, align 4 +; CHECK: br label %exit +; CHECK: exit: +; CHECK: %retVal = load i32, ptr %result, align 4 +; CHECK: ret i32 %retVal ; CHECK: } ; CHECK: define private spir_func i32 @__translate_spirv_memory_order(i32 %key) { ; CHECK: entry: +; CHECK: %result = alloca i32, align 4 ; CHECK: %key.masked = and i32 30, %key ; CHECK: switch i32 %key.masked, label %default [ ; CHECK: i32 0, label %case.0 @@ -78,13 +88,21 @@ ; CHECK: default: ; CHECK: unreachable ; CHECK: case.0: -; CHECK: ret i32 0 +; CHECK: store i32 0, ptr %result, align 4 +; CHECK: br label %exit ; CHECK: case.2: -; CHECK: ret i32 2 +; CHECK: store i32 2, ptr %result, align 4 +; CHECK: br label %exit ; CHECK: case.4: -; CHECK: ret i32 3 +; CHECK: store i32 3, ptr %result, align 4 +; CHECK: br label %exit ; CHECK: case.8: -; CHECK: ret i32 4 +; CHECK: store i32 4, ptr %result, align 4 +; CHECK: br label %exit ; CHECK: case.16: -; CHECK: ret i32 5 +; CHECK: store i32 5, ptr %result, align 4 +; CHECK: br label %exit +; CHECK: exit: +; CHECK: %retVal = load i32, ptr %result, align 4 +; CHECK: ret i32 %retVal ; CHECK: } diff --git a/test/barrier_explicit_arguments.spt b/test/barrier_explicit_arguments.spt index e9c3be0f8d..36737c289d 100644 --- a/test/barrier_explicit_arguments.spt +++ b/test/barrier_explicit_arguments.spt @@ -39,8 +39,9 @@ ; CHECK-12: call spir_func void @_Z7barrierj(i32 %call1) ; CHECK-20: call spir_func void @_Z18work_group_barrierj12memory_scope(i32 %call ; CHECK-20: call spir_func void @_Z17sub_group_barrierj12memory_scope(i32 %call -; CHECK: define private spir_func i32 @__translate_spirv_memory_fence(i32 %key) +; CHECK: define private spir_func i32 @__translate_spirv_memory_fence(i32 %key) { ; CHECK: entry: +; CHECK: %result = alloca i32, align 4 ; CHECK: %key.masked = and i32 2816, %key ; CHECK: switch i32 %key.masked, label %default [ ; CHECK: i32 256, label %case.256 @@ -54,17 +55,27 @@ ; CHECK: default: ; CHECK: unreachable ; CHECK: case.256: -; CHECK: ret i32 1 +; CHECK: store i32 1, ptr %result, align 4 +; CHECK: br label %exit ; CHECK: case.512: -; CHECK: ret i32 2 +; CHECK: store i32 2, ptr %result, align 4 +; CHECK: br label %exit ; CHECK: case.768: -; CHECK: ret i32 3 +; CHECK: store i32 3, ptr %result, align 4 +; CHECK: br label %exit ; CHECK: case.2048: -; CHECK: ret i32 4 +; CHECK: store i32 4, ptr %result, align 4 +; CHECK: br label %exit ; CHECK: case.2304: -; CHECK: ret i32 5 +; CHECK: store i32 5, ptr %result, align 4 +; CHECK: br label %exit ; CHECK: case.2560: -; CHECK: ret i32 6 +; CHECK: store i32 6, ptr %result, align 4 +; CHECK: br label %exit ; CHECK: case.2816: -; CHECK: ret i32 7 +; CHECK: store i32 7, ptr %result, align 4 +; CHECK: br label %exit +; CHECK: exit: +; CHECK: %retVal = load i32, ptr %result, align 4 +; CHECK: ret i32 %retVal ; CHECK: } diff --git a/test/mem_fence_explicit_arguments.spt b/test/mem_fence_explicit_arguments.spt index dd38dd77d8..afddb736e6 100644 --- a/test/mem_fence_explicit_arguments.spt +++ b/test/mem_fence_explicit_arguments.spt @@ -44,8 +44,9 @@ ; CHECK-20: %call3 = call spir_func i32 @__translate_spirv_memory_fence(i32 %[[VAL1]]) ; CHECK-20: %call4 = call spir_func i32 @__translate_spirv_memory_order(i32 %[[VAL1]]) ; CHECK-20: call spir_func void @_Z22atomic_work_item_fencej12memory_order12memory_scope(i32 %call3, i32 %call4, i32 %call2) -; CHECK: define private spir_func i32 @__translate_spirv_memory_fence(i32 %key) +; CHECK: define private spir_func i32 @__translate_spirv_memory_fence(i32 %key) { ; CHECK: entry: +; CHECK: %result = alloca i32, align 4 ; CHECK: %key.masked = and i32 2816, %key ; CHECK: switch i32 %key.masked, label %default [ ; CHECK: i32 256, label %case.256 @@ -59,22 +60,33 @@ ; CHECK: default: ; CHECK: unreachable ; CHECK: case.256: -; CHECK: ret i32 1 +; CHECK: store i32 1, ptr %result, align 4 +; CHECK: br label %exit ; CHECK: case.512: -; CHECK: ret i32 2 +; CHECK: store i32 2, ptr %result, align 4 +; CHECK: br label %exit ; CHECK: case.768: -; CHECK: ret i32 3 +; CHECK: store i32 3, ptr %result, align 4 +; CHECK: br label %exit ; CHECK: case.2048: -; CHECK: ret i32 4 +; CHECK: store i32 4, ptr %result, align 4 +; CHECK: br label %exit ; CHECK: case.2304: -; CHECK: ret i32 5 +; CHECK: store i32 5, ptr %result, align 4 +; CHECK: br label %exit ; CHECK: case.2560: -; CHECK: ret i32 6 +; CHECK: store i32 6, ptr %result, align 4 +; CHECK: br label %exit ; CHECK: case.2816: -; CHECK: ret i32 7 +; CHECK: store i32 7, ptr %result, align 4 +; CHECK: br label %exit +; CHECK: exit: +; CHECK: %retVal = load i32, ptr %result, align 4 +; CHECK: ret i32 %retVal ; CHECK: } ; CHECK-20: define private spir_func i32 @__translate_spirv_memory_order(i32 %key) { ; CHECK-20: entry: +; CHECK-20: %result = alloca i32, align 4 ; CHECK-20: %key.masked = and i32 30, %key ; CHECK-20: switch i32 %key.masked, label %default [ ; CHECK-20: i32 0, label %case.0 @@ -86,13 +98,21 @@ ; CHECK-20: default: ; CHECK-20: unreachable ; CHECK-20: case.0: -; CHECK-20: ret i32 0 +; CHECK-20: store i32 0, ptr %result, align 4 +; CHECK-20: br label %exit ; CHECK-20: case.2: -; CHECK-20: ret i32 2 +; CHECK-20: store i32 2, ptr %result, align 4 +; CHECK-20: br label %exit ; CHECK-20: case.4: -; CHECK-20: ret i32 3 +; CHECK-20: store i32 3, ptr %result, align 4 +; CHECK-20: br label %exit ; CHECK-20: case.8: -; CHECK-20: ret i32 4 +; CHECK-20: store i32 4, ptr %result, align 4 +; CHECK-20: br label %exit ; CHECK-20: case.16: -; CHECK-20: ret i32 5 +; CHECK-20: store i32 5, ptr %result, align 4 +; CHECK-20: br label %exit +; CHECK-20: exit: +; CHECK-20: %retVal = load i32, ptr %result, align 4 +; CHECK-20: ret i32 %retVal ; CHECK-20: } diff --git a/test/transcoding/atomic_explicit_arguments.cl b/test/transcoding/atomic_explicit_arguments.cl index 2b4cc8be89..15c77883c5 100644 --- a/test/transcoding/atomic_explicit_arguments.cl +++ b/test/transcoding/atomic_explicit_arguments.cl @@ -1,3 +1,5 @@ +// clang-format off + // RUN: %clang_cc1 -triple spir -cl-std=cl2.0 %s -fdeclare-opencl-builtins -finclude-default-header -emit-llvm-bc -o %t.bc // RUN: llvm-spirv %t.bc -o %t.spv // RUN: llvm-spirv %t.spv -to-text -o - | FileCheck %s --check-prefix=CHECK-SPIRV @@ -31,32 +33,50 @@ int load (volatile atomic_int* obj, memory_order order, memory_scope scope) { // CHECK-SPIRV: Function [[int]] [[TRANS_MEM_SCOPE]] // CHECK-SPIRV: FunctionParameter [[int]] [[KEY:[0-9]+]] +// CHECK-SPIRV: Variable {{[0-9]+}} [[RES:[0-9]+]] // CHECK-SPIRV: Switch [[KEY]] [[CASE_2:[0-9]+]] 0 [[CASE_0:[0-9]+]] 1 [[CASE_1:[0-9]+]] 2 [[CASE_2]] 3 [[CASE_3:[0-9]+]] 4 [[CASE_4:[0-9]+]] // CHECK-SPIRV: Label [[CASE_0]] -// CHECK-SPIRV: ReturnValue [[FOUR]] +// CHECK-SPIRV: Store [[RES]] [[FOUR]] +// CHECK-SPIRV: Branch [[EXIT:[0-9]+]] // CHECK-SPIRV: Label [[CASE_1]] -// CHECK-SPIRV: ReturnValue [[TWO]] +// CHECK-SPIRV: Store [[RES]] [[TWO]] +// CHECK-SPIRV: Branch [[EXIT]] // CHECK-SPIRV: Label [[CASE_2]] -// CHECK-SPIRV: ReturnValue [[ONE]] +// CHECK-SPIRV: Store [[RES]] [[ONE]] +// CHECK-SPIRV: Branch [[EXIT]] // CHECK-SPIRV: Label [[CASE_3]] -// CHECK-SPIRV: ReturnValue [[ZERO]] +// CHECK-SPIRV: Store [[RES]] [[ZERO]] +// CHECK-SPIRV: Branch [[EXIT]] // CHECK-SPIRV: Label [[CASE_4]] -// CHECK-SPIRV: ReturnValue [[THREE]] +// CHECK-SPIRV: Store [[RES]] [[THREE]] +// CHECK-SPIRV: Branch [[EXIT]] +// CHECK-SPIRV: Label [[EXIT]] +// CHECK-SPIRV: Load [[int]] [[RET_VAR:[0-9]+]] [[RES]] +// CHECK-SPIRV: ReturnValue [[RET_VAR]] // CHECK-SPIRV: FunctionEnd // CHECK-SPIRV: Function [[int]] [[TRANS_MEM_ORDER]] // CHECK-SPIRV: FunctionParameter [[int]] [[KEY:[0-9]+]] +// CHECK-SPIRV: Variable {{[0-9]+}} [[RES:[0-9]+]] // CHECK-SPIRV: Switch [[KEY]] [[CASE_5:[0-9]+]] 0 [[CASE_0:[0-9]+]] 2 [[CASE_2:[0-9]+]] 3 [[CASE_3:[0-9]+]] 4 [[CASE_4:[0-9]+]] 5 [[CASE_5]] // CHECK-SPIRV: Label [[CASE_0]] -// CHECK-SPIRV: ReturnValue [[ZERO]] +// CHECK-SPIRV: Store [[RES]] [[ZERO]] +// CHECK-SPIRV: Branch [[EXIT:[0-9]+]] // CHECK-SPIRV: Label [[CASE_2]] -// CHECK-SPIRV: ReturnValue [[TWO]] +// CHECK-SPIRV: Store [[RES]] [[TWO]] +// CHECK-SPIRV: Branch [[EXIT]] // CHECK-SPIRV: Label [[CASE_3]] -// CHECK-SPIRV: ReturnValue [[FOUR]] +// CHECK-SPIRV: Store [[RES]] [[FOUR]] +// CHECK-SPIRV: Branch [[EXIT]] // CHECK-SPIRV: Label [[CASE_4]] -// CHECK-SPIRV: ReturnValue [[EIGHT]] +// CHECK-SPIRV: Store [[RES]] [[EIGHT]] +// CHECK-SPIRV: Branch [[EXIT]] // CHECK-SPIRV: Label [[CASE_5]] -// CHECK-SPIRV: ReturnValue [[SIXTEEN]] +// CHECK-SPIRV: Store [[RES]] [[SIXTEEN]] +// CHECK-SPIRV: Branch [[EXIT]] +// CHECK-SPIRV: Label [[EXIT]] +// CHECK-SPIRV: Load [[int]] [[RET_VAR:[0-9]+]] [[RES]] +// CHECK-SPIRV: ReturnValue [[RET_VAR]] // CHECK-SPIRV: FunctionEnd diff --git a/test/transcoding/barrier-runtime-scope.ll b/test/transcoding/barrier-runtime-scope.ll index 1a56c4413e..1f451f2e38 100644 --- a/test/transcoding/barrier-runtime-scope.ll +++ b/test/transcoding/barrier-runtime-scope.ll @@ -23,6 +23,7 @@ ; CHECK-LLVM: define private spir_func i32 @__translate_spirv_memory_scope(i32 %key) { ; CHECK-LLVM: entry: +; CHECK-LLVM: %result = alloca i32, align 4 ; CHECK-LLVM: switch i32 %key, label %default [ ; CHECK-LLVM: i32 4, label %case.4 ; CHECK-LLVM: i32 2, label %case.2 @@ -33,15 +34,23 @@ ; CHECK-LLVM: default: ; preds = %entry ; CHECK-LLVM: unreachable ; CHECK-LLVM: case.4: ; preds = %entry -; CHECK-LLVM: ret i32 0 +; CHECK-LLVM: store i32 0, ptr %result, align 4 +; CHECK-LLVM: br label %exit ; CHECK-LLVM: case.2: ; preds = %entry -; CHECK-LLVM: ret i32 1 +; CHECK-LLVM: store i32 1, ptr %result, align 4 +; CHECK-LLVM: br label %exit ; CHECK-LLVM: case.1: ; preds = %entry -; CHECK-LLVM: ret i32 2 +; CHECK-LLVM: store i32 2, ptr %result, align 4 +; CHECK-LLVM: br label %exit ; CHECK-LLVM: case.0: ; preds = %entry -; CHECK-LLVM: ret i32 3 +; CHECK-LLVM: store i32 3, ptr %result, align 4 +; CHECK-LLVM: br label %exit ; CHECK-LLVM: case.3: ; preds = %entry -; CHECK-LLVM: ret i32 4 +; CHECK-LLVM: store i32 4, ptr %result, align 4 +; CHECK-LLVM: br label %exit +; CHECK-LLVM: exit: ; preds = %case.3, %case.0, %case.1, %case.2, %case.4 +; CHECK-LLVM: %retVal = load i32, ptr %result, align 4 +; CHECK-LLVM: ret i32 %retVal ; CHECK-LLVM: } target datalayout = "e-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024"