Skip to content

Commit 802e71f

Browse files
committed
LLVMCodeBuilder: Add main function pointer member
1 parent c8e7938 commit 802e71f

File tree

2 files changed

+63
-63
lines changed

2 files changed

+63
-63
lines changed

src/dev/engine/internal/llvm/llvmcodebuilder.cpp

Lines changed: 58 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -63,40 +63,39 @@ std::shared_ptr<ExecutableCode> LLVMCodeBuilder::finalize()
6363
// Create function
6464
std::string funcName = getMainFunctionName(m_procedurePrototype);
6565
llvm::FunctionType *funcType = getMainFunctionType(m_procedurePrototype);
66-
llvm::Function *func;
6766

6867
if (m_procedurePrototype)
69-
func = getOrCreateFunction(funcName, funcType);
68+
m_function = getOrCreateFunction(funcName, funcType);
7069
else
71-
func = llvm::Function::Create(funcType, llvm::Function::ExternalLinkage, funcName, m_module);
70+
m_function = llvm::Function::Create(funcType, llvm::Function::ExternalLinkage, funcName, m_module);
7271

73-
llvm::Value *executionContextPtr = func->getArg(0);
74-
llvm::Value *targetPtr = func->getArg(1);
75-
llvm::Value *targetVariables = func->getArg(2);
76-
llvm::Value *targetLists = func->getArg(3);
72+
llvm::Value *executionContextPtr = m_function->getArg(0);
73+
llvm::Value *targetPtr = m_function->getArg(1);
74+
llvm::Value *targetVariables = m_function->getArg(2);
75+
llvm::Value *targetLists = m_function->getArg(3);
7776
llvm::Value *warpArg = nullptr;
7877

7978
if (m_procedurePrototype)
80-
warpArg = func->getArg(4);
79+
warpArg = m_function->getArg(4);
8180

8281
if (m_procedurePrototype && m_warp)
83-
func->addFnAttr(llvm::Attribute::InlineHint);
82+
m_function->addFnAttr(llvm::Attribute::InlineHint);
8483
else {
8584
// NOTE: These attributes will be overriden by LLVMCompilerContext
8685
// TODO: Optimize all functions, maybe it doesn't take so long
87-
func->addFnAttr(llvm::Attribute::NoInline);
88-
func->addFnAttr(llvm::Attribute::OptimizeNone);
86+
m_function->addFnAttr(llvm::Attribute::NoInline);
87+
m_function->addFnAttr(llvm::Attribute::OptimizeNone);
8988
}
9089

91-
llvm::BasicBlock *entry = llvm::BasicBlock::Create(m_llvmCtx, "entry", func);
92-
llvm::BasicBlock *endBranch = llvm::BasicBlock::Create(m_llvmCtx, "end", func);
90+
llvm::BasicBlock *entry = llvm::BasicBlock::Create(m_llvmCtx, "entry", m_function);
91+
llvm::BasicBlock *endBranch = llvm::BasicBlock::Create(m_llvmCtx, "end", m_function);
9392
m_builder.SetInsertPoint(entry);
9493

9594
// Init coroutine
9695
std::unique_ptr<LLVMCoroutine> coro;
9796

9897
if (!m_warp)
99-
coro = std::make_unique<LLVMCoroutine>(m_module, &m_builder, func);
98+
coro = std::make_unique<LLVMCoroutine>(m_module, &m_builder, m_function);
10099

101100
std::vector<LLVMIfStatement> ifStatements;
102101
std::vector<LLVMLoop> loops;
@@ -677,14 +676,14 @@ std::shared_ptr<ExecutableCode> LLVMCodeBuilder::finalize()
677676
llvm::Value *allocatedSize = m_builder.CreateLoad(m_builder.getInt64Ty(), listPtr.allocatedSizePtr);
678677
llvm::Value *size = m_builder.CreateLoad(m_builder.getInt64Ty(), listPtr.sizePtr);
679678
llvm::Value *isAllocated = m_builder.CreateICmpUGT(allocatedSize, size);
680-
llvm::BasicBlock *ifBlock = llvm::BasicBlock::Create(m_llvmCtx, "", func);
681-
llvm::BasicBlock *elseBlock = llvm::BasicBlock::Create(m_llvmCtx, "", func);
682-
llvm::BasicBlock *nextBlock = llvm::BasicBlock::Create(m_llvmCtx, "", func);
679+
llvm::BasicBlock *ifBlock = llvm::BasicBlock::Create(m_llvmCtx, "", m_function);
680+
llvm::BasicBlock *elseBlock = llvm::BasicBlock::Create(m_llvmCtx, "", m_function);
681+
llvm::BasicBlock *nextBlock = llvm::BasicBlock::Create(m_llvmCtx, "", m_function);
683682
m_builder.CreateCondBr(isAllocated, ifBlock, elseBlock);
684683

685684
// If there's enough space, use the allocated memory
686685
m_builder.SetInsertPoint(ifBlock);
687-
llvm::Value *itemPtr = getListItem(listPtr, size, func);
686+
llvm::Value *itemPtr = getListItem(listPtr, size);
688687
createReusedValueStore(arg.second, itemPtr, type);
689688
m_builder.CreateStore(m_builder.CreateAdd(size, m_builder.getInt64(1)), listPtr.sizePtr);
690689
m_builder.CreateBr(nextBlock);
@@ -748,7 +747,7 @@ std::shared_ptr<ExecutableCode> LLVMCodeBuilder::finalize()
748747
Compiler::StaticType type = optimizeRegisterType(valueArg.second);
749748
LLVMListPtr &listPtr = m_listPtrs[step.workList];
750749
llvm::Value *index = m_builder.CreateFPToUI(castValue(indexArg.second, indexArg.first), m_builder.getInt64Ty());
751-
llvm::Value *itemPtr = getListItem(listPtr, index, func);
750+
llvm::Value *itemPtr = getListItem(listPtr, index);
752751
createValueStore(valueArg.second, itemPtr, type, listPtr.type);
753752

754753
auto &typeMap = m_scopeLists.back();
@@ -778,7 +777,7 @@ std::shared_ptr<ExecutableCode> LLVMCodeBuilder::finalize()
778777
const auto &arg = step.args[0];
779778
const LLVMListPtr &listPtr = m_listPtrs[step.workList];
780779
llvm::Value *index = m_builder.CreateFPToUI(castValue(arg.second, arg.first), m_builder.getInt64Ty());
781-
step.functionReturnReg->value = getListItem(listPtr, index, func);
780+
step.functionReturnReg->value = getListItem(listPtr, index);
782781
step.functionReturnReg->setType(listPtr.type);
783782
break;
784783
}
@@ -795,28 +794,28 @@ std::shared_ptr<ExecutableCode> LLVMCodeBuilder::finalize()
795794
assert(step.args.size() == 1);
796795
const auto &arg = step.args[0];
797796
const LLVMListPtr &listPtr = m_listPtrs[step.workList];
798-
step.functionReturnReg->value = m_builder.CreateSIToFP(getListItemIndex(listPtr, arg.second, func), m_builder.getDoubleTy());
797+
step.functionReturnReg->value = m_builder.CreateSIToFP(getListItemIndex(listPtr, arg.second), m_builder.getDoubleTy());
799798
break;
800799
}
801800

802801
case LLVMInstruction::Type::ListContainsItem: {
803802
assert(step.args.size() == 1);
804803
const auto &arg = step.args[0];
805804
const LLVMListPtr &listPtr = m_listPtrs[step.workList];
806-
llvm::Value *index = getListItemIndex(listPtr, arg.second, func);
805+
llvm::Value *index = getListItemIndex(listPtr, arg.second);
807806
step.functionReturnReg->value = m_builder.CreateICmpSGT(index, llvm::ConstantInt::get(m_builder.getInt64Ty(), -1, true));
808807
break;
809808
}
810809

811810
case LLVMInstruction::Type::Yield:
812811
// TODO: Do not allow use after suspend (use after free)
813-
createSuspend(coro.get(), func, warpArg, targetVariables);
812+
createSuspend(coro.get(), warpArg, targetVariables);
814813
break;
815814

816815
case LLVMInstruction::Type::BeginIf: {
817816
LLVMIfStatement statement;
818817
statement.beforeIf = m_builder.GetInsertBlock();
819-
statement.body = llvm::BasicBlock::Create(m_llvmCtx, "", func);
818+
statement.body = llvm::BasicBlock::Create(m_llvmCtx, "", m_function);
820819

821820
// Use last reg
822821
assert(step.args.size() == 1);
@@ -847,13 +846,13 @@ std::shared_ptr<ExecutableCode> LLVMCodeBuilder::finalize()
847846

848847
// Jump to the branch after the if statement
849848
assert(!statement.afterIf);
850-
statement.afterIf = llvm::BasicBlock::Create(m_llvmCtx, "", func);
849+
statement.afterIf = llvm::BasicBlock::Create(m_llvmCtx, "", m_function);
851850
freeScopeHeap();
852851
m_builder.CreateBr(statement.afterIf);
853852

854853
// Create else branch
855854
assert(!statement.elseBranch);
856-
statement.elseBranch = llvm::BasicBlock::Create(m_llvmCtx, "", func);
855+
statement.elseBranch = llvm::BasicBlock::Create(m_llvmCtx, "", m_function);
857856

858857
// Since there's an else branch, the conditional instruction should jump to it
859858
m_builder.SetInsertPoint(statement.beforeIf);
@@ -871,7 +870,7 @@ std::shared_ptr<ExecutableCode> LLVMCodeBuilder::finalize()
871870

872871
// Jump to the branch after the if statement
873872
if (!statement.afterIf)
874-
statement.afterIf = llvm::BasicBlock::Create(m_llvmCtx, "", func);
873+
statement.afterIf = llvm::BasicBlock::Create(m_llvmCtx, "", m_function);
875874

876875
m_builder.CreateBr(statement.afterIf);
877876

@@ -900,9 +899,9 @@ std::shared_ptr<ExecutableCode> LLVMCodeBuilder::finalize()
900899
m_builder.CreateStore(zero, loop.index);
901900

902901
// Create branches
903-
llvm::BasicBlock *roundBranch = llvm::BasicBlock::Create(m_llvmCtx, "", func);
904-
loop.conditionBranch = llvm::BasicBlock::Create(m_llvmCtx, "", func);
905-
loop.afterLoop = llvm::BasicBlock::Create(m_llvmCtx, "", func);
902+
llvm::BasicBlock *roundBranch = llvm::BasicBlock::Create(m_llvmCtx, "", m_function);
903+
loop.conditionBranch = llvm::BasicBlock::Create(m_llvmCtx, "", m_function);
904+
loop.afterLoop = llvm::BasicBlock::Create(m_llvmCtx, "", m_function);
906905

907906
// Use last reg for count
908907
assert(step.args.size() == 1);
@@ -928,10 +927,10 @@ std::shared_ptr<ExecutableCode> LLVMCodeBuilder::finalize()
928927
// Check index
929928
m_builder.SetInsertPoint(loop.conditionBranch);
930929

931-
llvm::BasicBlock *body = llvm::BasicBlock::Create(m_llvmCtx, "", func);
930+
llvm::BasicBlock *body = llvm::BasicBlock::Create(m_llvmCtx, "", m_function);
932931

933932
if (!loop.afterLoop)
934-
loop.afterLoop = llvm::BasicBlock::Create(m_llvmCtx, "", func);
933+
loop.afterLoop = llvm::BasicBlock::Create(m_llvmCtx, "", m_function);
935934

936935
llvm::Value *currentIndex = m_builder.CreateLoad(m_builder.getInt64Ty(), loop.index);
937936
comparison = m_builder.CreateOr(isInf, m_builder.CreateICmpULT(currentIndex, count));
@@ -958,8 +957,8 @@ std::shared_ptr<ExecutableCode> LLVMCodeBuilder::finalize()
958957
LLVMLoop &loop = loops.back();
959958

960959
// Create branches
961-
llvm::BasicBlock *body = llvm::BasicBlock::Create(m_llvmCtx, "", func);
962-
loop.afterLoop = llvm::BasicBlock::Create(m_llvmCtx, "", func);
960+
llvm::BasicBlock *body = llvm::BasicBlock::Create(m_llvmCtx, "", m_function);
961+
loop.afterLoop = llvm::BasicBlock::Create(m_llvmCtx, "", m_function);
963962

964963
// Use last reg
965964
assert(step.args.size() == 1);
@@ -979,8 +978,8 @@ std::shared_ptr<ExecutableCode> LLVMCodeBuilder::finalize()
979978
LLVMLoop &loop = loops.back();
980979

981980
// Create branches
982-
llvm::BasicBlock *body = llvm::BasicBlock::Create(m_llvmCtx, "", func);
983-
loop.afterLoop = llvm::BasicBlock::Create(m_llvmCtx, "", func);
981+
llvm::BasicBlock *body = llvm::BasicBlock::Create(m_llvmCtx, "", m_function);
982+
loop.afterLoop = llvm::BasicBlock::Create(m_llvmCtx, "", m_function);
984983

985984
// Use last reg
986985
assert(step.args.size() == 1);
@@ -998,7 +997,7 @@ std::shared_ptr<ExecutableCode> LLVMCodeBuilder::finalize()
998997
case LLVMInstruction::Type::BeginLoopCondition: {
999998
LLVMLoop loop;
1000999
loop.isRepeatLoop = false;
1001-
loop.conditionBranch = llvm::BasicBlock::Create(m_llvmCtx, "", func);
1000+
loop.conditionBranch = llvm::BasicBlock::Create(m_llvmCtx, "", m_function);
10021001
m_builder.CreateBr(loop.conditionBranch);
10031002
m_builder.SetInsertPoint(loop.conditionBranch);
10041003
loops.push_back(loop);
@@ -1030,7 +1029,7 @@ std::shared_ptr<ExecutableCode> LLVMCodeBuilder::finalize()
10301029

10311030
case LLVMInstruction::Type::Stop: {
10321031
m_builder.CreateBr(endBranch);
1033-
llvm::BasicBlock *nextBranch = llvm::BasicBlock::Create(m_llvmCtx, "", func);
1032+
llvm::BasicBlock *nextBranch = llvm::BasicBlock::Create(m_llvmCtx, "", m_function);
10341033
m_builder.SetInsertPoint(nextBranch);
10351034
break;
10361035
}
@@ -1046,7 +1045,7 @@ std::shared_ptr<ExecutableCode> LLVMCodeBuilder::finalize()
10461045
std::vector<llvm::Value *> args;
10471046

10481047
for (size_t i = 0; i < m_defaultArgCount; i++)
1049-
args.push_back(func->getArg(i));
1048+
args.push_back(m_function->getArg(i));
10501049

10511050
// Add warp arg
10521051
if (m_warp)
@@ -1065,12 +1064,12 @@ std::shared_ptr<ExecutableCode> LLVMCodeBuilder::finalize()
10651064
llvm::Value *handle = m_builder.CreateCall(resolveFunction(name, type), args);
10661065

10671066
if (!m_warp && !step.procedurePrototype->warp()) {
1068-
llvm::BasicBlock *suspendBranch = llvm::BasicBlock::Create(m_llvmCtx, "", func);
1069-
llvm::BasicBlock *nextBranch = llvm::BasicBlock::Create(m_llvmCtx, "", func);
1067+
llvm::BasicBlock *suspendBranch = llvm::BasicBlock::Create(m_llvmCtx, "", m_function);
1068+
llvm::BasicBlock *nextBranch = llvm::BasicBlock::Create(m_llvmCtx, "", m_function);
10701069
m_builder.CreateCondBr(m_builder.CreateIsNull(handle), nextBranch, suspendBranch);
10711070

10721071
m_builder.SetInsertPoint(suspendBranch);
1073-
createSuspend(coro.get(), func, warpArg, targetVariables);
1072+
createSuspend(coro.get(), warpArg, targetVariables);
10741073
name = getResumeFunctionName(step.procedurePrototype);
10751074
llvm::Value *done = m_builder.CreateCall(resolveFunction(name, m_resumeFuncType), { handle });
10761075
m_builder.CreateCondBr(done, nextBranch, suspendBranch);
@@ -1085,7 +1084,7 @@ std::shared_ptr<ExecutableCode> LLVMCodeBuilder::finalize()
10851084

10861085
case LLVMInstruction::Type::ProcedureArg: {
10871086
assert(m_procedurePrototype);
1088-
llvm::Value *arg = func->getArg(m_defaultArgCount + 1 + step.procedureArgIndex); // omit warp arg
1087+
llvm::Value *arg = m_function->getArg(m_defaultArgCount + 1 + step.procedureArgIndex); // omit warp arg
10891088
step.functionReturnReg->value = arg;
10901089
break;
10911090
}
@@ -1107,7 +1106,7 @@ std::shared_ptr<ExecutableCode> LLVMCodeBuilder::finalize()
11071106
else
11081107
coro->end();
11091108

1110-
verifyFunction(func);
1109+
verifyFunction(m_function);
11111110

11121111
// Create resume function
11131112
// bool resume(void *)
@@ -1126,7 +1125,7 @@ std::shared_ptr<ExecutableCode> LLVMCodeBuilder::finalize()
11261125

11271126
verifyFunction(resumeFunc);
11281127

1129-
return std::make_shared<LLVMExecutableCode>(m_ctx, func->getName().str(), resumeFunc->getName().str());
1128+
return std::make_shared<LLVMExecutableCode>(m_ctx, m_function->getName().str(), resumeFunc->getName().str());
11301129
}
11311130

11321131
CompilerValue *LLVMCodeBuilder::addFunctionCall(const std::string &functionName, Compiler::StaticType returnType, const Compiler::ArgTypes &argTypes, const Compiler::Args &args)
@@ -2029,7 +2028,7 @@ void LLVMCodeBuilder::reloadLists()
20292028
}
20302029
}
20312030

2032-
void LLVMCodeBuilder::updateListDataPtr(const LLVMListPtr &listPtr, llvm::Function *func)
2031+
void LLVMCodeBuilder::updateListDataPtr(const LLVMListPtr &listPtr)
20332032
{
20342033
// dataPtr = dirty ? list_data(list) : dataPtr
20352034
// dirty = false
@@ -2190,21 +2189,21 @@ void LLVMCodeBuilder::copyStructField(llvm::Value *source, llvm::Value *target,
21902189
m_builder.CreateStore(m_builder.CreateLoad(fieldType, sourceField), targetField);
21912190
}
21922191

2193-
llvm::Value *LLVMCodeBuilder::getListItem(const LLVMListPtr &listPtr, llvm::Value *index, llvm::Function *func)
2192+
llvm::Value *LLVMCodeBuilder::getListItem(const LLVMListPtr &listPtr, llvm::Value *index)
21942193
{
2195-
updateListDataPtr(listPtr, func);
2194+
updateListDataPtr(listPtr);
21962195
return m_builder.CreateGEP(m_valueDataType, m_builder.CreateLoad(m_valueDataType->getPointerTo(), listPtr.dataPtr), index);
21972196
}
21982197

2199-
llvm::Value *LLVMCodeBuilder::getListItemIndex(const LLVMListPtr &listPtr, LLVMRegister *item, llvm::Function *func)
2198+
llvm::Value *LLVMCodeBuilder::getListItemIndex(const LLVMListPtr &listPtr, LLVMRegister *item)
22002199
{
22012200
llvm::Value *size = m_builder.CreateLoad(m_builder.getInt64Ty(), listPtr.sizePtr);
2202-
llvm::BasicBlock *condBlock = llvm::BasicBlock::Create(m_llvmCtx, "", func);
2203-
llvm::BasicBlock *bodyBlock = llvm::BasicBlock::Create(m_llvmCtx, "", func);
2204-
llvm::BasicBlock *cmpIfBlock = llvm::BasicBlock::Create(m_llvmCtx, "", func);
2205-
llvm::BasicBlock *cmpElseBlock = llvm::BasicBlock::Create(m_llvmCtx, "", func);
2206-
llvm::BasicBlock *notFoundBlock = llvm::BasicBlock::Create(m_llvmCtx, "", func);
2207-
llvm::BasicBlock *nextBlock = llvm::BasicBlock::Create(m_llvmCtx, "", func);
2201+
llvm::BasicBlock *condBlock = llvm::BasicBlock::Create(m_llvmCtx, "", m_function);
2202+
llvm::BasicBlock *bodyBlock = llvm::BasicBlock::Create(m_llvmCtx, "", m_function);
2203+
llvm::BasicBlock *cmpIfBlock = llvm::BasicBlock::Create(m_llvmCtx, "", m_function);
2204+
llvm::BasicBlock *cmpElseBlock = llvm::BasicBlock::Create(m_llvmCtx, "", m_function);
2205+
llvm::BasicBlock *notFoundBlock = llvm::BasicBlock::Create(m_llvmCtx, "", m_function);
2206+
llvm::BasicBlock *nextBlock = llvm::BasicBlock::Create(m_llvmCtx, "", m_function);
22082207

22092208
// index = 0
22102209
llvm::Value *index = m_builder.CreateAlloca(m_builder.getInt64Ty());
@@ -2220,7 +2219,7 @@ llvm::Value *LLVMCodeBuilder::getListItemIndex(const LLVMListPtr &listPtr, LLVMR
22202219
m_builder.SetInsertPoint(bodyBlock);
22212220
LLVMRegister currentItem(listPtr.type);
22222221
currentItem.isRawValue = false;
2223-
currentItem.value = getListItem(listPtr, m_builder.CreateLoad(m_builder.getInt64Ty(), index), func);
2222+
currentItem.value = getListItem(listPtr, m_builder.CreateLoad(m_builder.getInt64Ty(), index));
22242223
llvm::Value *cmp = createComparison(&currentItem, item, Comparison::EQ);
22252224
m_builder.CreateCondBr(cmp, cmpIfBlock, cmpElseBlock);
22262225

@@ -2481,14 +2480,14 @@ llvm::Value *LLVMCodeBuilder::createComparison(LLVMRegister *arg1, LLVMRegister
24812480
}
24822481
}
24832482

2484-
void LLVMCodeBuilder::createSuspend(LLVMCoroutine *coro, llvm::Function *func, llvm::Value *warpArg, llvm::Value *targetVariables)
2483+
void LLVMCodeBuilder::createSuspend(LLVMCoroutine *coro, llvm::Value *warpArg, llvm::Value *targetVariables)
24852484
{
24862485
if (!m_warp) {
24872486
llvm::BasicBlock *suspendBranch, *nextBranch;
24882487

24892488
if (warpArg) {
2490-
suspendBranch = llvm::BasicBlock::Create(m_llvmCtx, "", func);
2491-
nextBranch = llvm::BasicBlock::Create(m_llvmCtx, "", func);
2489+
suspendBranch = llvm::BasicBlock::Create(m_llvmCtx, "", m_function);
2490+
nextBranch = llvm::BasicBlock::Create(m_llvmCtx, "", m_function);
24922491
m_builder.CreateCondBr(warpArg, nextBranch, suspendBranch);
24932492
m_builder.SetInsertPoint(suspendBranch);
24942493
}

0 commit comments

Comments
 (0)