Skip to content

Commit

Permalink
LLVMCodeBuilder: Add main function pointer member
Browse files Browse the repository at this point in the history
  • Loading branch information
adazem009 committed Jan 7, 2025
1 parent c8e7938 commit 802e71f
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 63 deletions.
117 changes: 58 additions & 59 deletions src/dev/engine/internal/llvm/llvmcodebuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -63,40 +63,39 @@ std::shared_ptr<ExecutableCode> LLVMCodeBuilder::finalize()
// Create function
std::string funcName = getMainFunctionName(m_procedurePrototype);
llvm::FunctionType *funcType = getMainFunctionType(m_procedurePrototype);
llvm::Function *func;

if (m_procedurePrototype)
func = getOrCreateFunction(funcName, funcType);
m_function = getOrCreateFunction(funcName, funcType);
else
func = llvm::Function::Create(funcType, llvm::Function::ExternalLinkage, funcName, m_module);
m_function = llvm::Function::Create(funcType, llvm::Function::ExternalLinkage, funcName, m_module);

llvm::Value *executionContextPtr = func->getArg(0);
llvm::Value *targetPtr = func->getArg(1);
llvm::Value *targetVariables = func->getArg(2);
llvm::Value *targetLists = func->getArg(3);
llvm::Value *executionContextPtr = m_function->getArg(0);
llvm::Value *targetPtr = m_function->getArg(1);
llvm::Value *targetVariables = m_function->getArg(2);
llvm::Value *targetLists = m_function->getArg(3);
llvm::Value *warpArg = nullptr;

if (m_procedurePrototype)
warpArg = func->getArg(4);
warpArg = m_function->getArg(4);

if (m_procedurePrototype && m_warp)
func->addFnAttr(llvm::Attribute::InlineHint);
m_function->addFnAttr(llvm::Attribute::InlineHint);
else {
// NOTE: These attributes will be overriden by LLVMCompilerContext
// TODO: Optimize all functions, maybe it doesn't take so long
func->addFnAttr(llvm::Attribute::NoInline);
func->addFnAttr(llvm::Attribute::OptimizeNone);
m_function->addFnAttr(llvm::Attribute::NoInline);
m_function->addFnAttr(llvm::Attribute::OptimizeNone);
}

llvm::BasicBlock *entry = llvm::BasicBlock::Create(m_llvmCtx, "entry", func);
llvm::BasicBlock *endBranch = llvm::BasicBlock::Create(m_llvmCtx, "end", func);
llvm::BasicBlock *entry = llvm::BasicBlock::Create(m_llvmCtx, "entry", m_function);
llvm::BasicBlock *endBranch = llvm::BasicBlock::Create(m_llvmCtx, "end", m_function);
m_builder.SetInsertPoint(entry);

// Init coroutine
std::unique_ptr<LLVMCoroutine> coro;

if (!m_warp)
coro = std::make_unique<LLVMCoroutine>(m_module, &m_builder, func);
coro = std::make_unique<LLVMCoroutine>(m_module, &m_builder, m_function);

std::vector<LLVMIfStatement> ifStatements;
std::vector<LLVMLoop> loops;
Expand Down Expand Up @@ -677,14 +676,14 @@ std::shared_ptr<ExecutableCode> LLVMCodeBuilder::finalize()
llvm::Value *allocatedSize = m_builder.CreateLoad(m_builder.getInt64Ty(), listPtr.allocatedSizePtr);
llvm::Value *size = m_builder.CreateLoad(m_builder.getInt64Ty(), listPtr.sizePtr);
llvm::Value *isAllocated = m_builder.CreateICmpUGT(allocatedSize, size);
llvm::BasicBlock *ifBlock = llvm::BasicBlock::Create(m_llvmCtx, "", func);
llvm::BasicBlock *elseBlock = llvm::BasicBlock::Create(m_llvmCtx, "", func);
llvm::BasicBlock *nextBlock = llvm::BasicBlock::Create(m_llvmCtx, "", func);
llvm::BasicBlock *ifBlock = llvm::BasicBlock::Create(m_llvmCtx, "", m_function);
llvm::BasicBlock *elseBlock = llvm::BasicBlock::Create(m_llvmCtx, "", m_function);
llvm::BasicBlock *nextBlock = llvm::BasicBlock::Create(m_llvmCtx, "", m_function);
m_builder.CreateCondBr(isAllocated, ifBlock, elseBlock);

// If there's enough space, use the allocated memory
m_builder.SetInsertPoint(ifBlock);
llvm::Value *itemPtr = getListItem(listPtr, size, func);
llvm::Value *itemPtr = getListItem(listPtr, size);
createReusedValueStore(arg.second, itemPtr, type);
m_builder.CreateStore(m_builder.CreateAdd(size, m_builder.getInt64(1)), listPtr.sizePtr);
m_builder.CreateBr(nextBlock);
Expand Down Expand Up @@ -748,7 +747,7 @@ std::shared_ptr<ExecutableCode> LLVMCodeBuilder::finalize()
Compiler::StaticType type = optimizeRegisterType(valueArg.second);
LLVMListPtr &listPtr = m_listPtrs[step.workList];
llvm::Value *index = m_builder.CreateFPToUI(castValue(indexArg.second, indexArg.first), m_builder.getInt64Ty());
llvm::Value *itemPtr = getListItem(listPtr, index, func);
llvm::Value *itemPtr = getListItem(listPtr, index);
createValueStore(valueArg.second, itemPtr, type, listPtr.type);

auto &typeMap = m_scopeLists.back();
Expand Down Expand Up @@ -778,7 +777,7 @@ std::shared_ptr<ExecutableCode> LLVMCodeBuilder::finalize()
const auto &arg = step.args[0];
const LLVMListPtr &listPtr = m_listPtrs[step.workList];
llvm::Value *index = m_builder.CreateFPToUI(castValue(arg.second, arg.first), m_builder.getInt64Ty());
step.functionReturnReg->value = getListItem(listPtr, index, func);
step.functionReturnReg->value = getListItem(listPtr, index);
step.functionReturnReg->setType(listPtr.type);
break;
}
Expand All @@ -795,28 +794,28 @@ std::shared_ptr<ExecutableCode> LLVMCodeBuilder::finalize()
assert(step.args.size() == 1);
const auto &arg = step.args[0];
const LLVMListPtr &listPtr = m_listPtrs[step.workList];
step.functionReturnReg->value = m_builder.CreateSIToFP(getListItemIndex(listPtr, arg.second, func), m_builder.getDoubleTy());
step.functionReturnReg->value = m_builder.CreateSIToFP(getListItemIndex(listPtr, arg.second), m_builder.getDoubleTy());
break;
}

case LLVMInstruction::Type::ListContainsItem: {
assert(step.args.size() == 1);
const auto &arg = step.args[0];
const LLVMListPtr &listPtr = m_listPtrs[step.workList];
llvm::Value *index = getListItemIndex(listPtr, arg.second, func);
llvm::Value *index = getListItemIndex(listPtr, arg.second);
step.functionReturnReg->value = m_builder.CreateICmpSGT(index, llvm::ConstantInt::get(m_builder.getInt64Ty(), -1, true));
break;
}

case LLVMInstruction::Type::Yield:
// TODO: Do not allow use after suspend (use after free)
createSuspend(coro.get(), func, warpArg, targetVariables);
createSuspend(coro.get(), warpArg, targetVariables);
break;

case LLVMInstruction::Type::BeginIf: {
LLVMIfStatement statement;
statement.beforeIf = m_builder.GetInsertBlock();
statement.body = llvm::BasicBlock::Create(m_llvmCtx, "", func);
statement.body = llvm::BasicBlock::Create(m_llvmCtx, "", m_function);

// Use last reg
assert(step.args.size() == 1);
Expand Down Expand Up @@ -847,13 +846,13 @@ std::shared_ptr<ExecutableCode> LLVMCodeBuilder::finalize()

// Jump to the branch after the if statement
assert(!statement.afterIf);
statement.afterIf = llvm::BasicBlock::Create(m_llvmCtx, "", func);
statement.afterIf = llvm::BasicBlock::Create(m_llvmCtx, "", m_function);
freeScopeHeap();
m_builder.CreateBr(statement.afterIf);

// Create else branch
assert(!statement.elseBranch);
statement.elseBranch = llvm::BasicBlock::Create(m_llvmCtx, "", func);
statement.elseBranch = llvm::BasicBlock::Create(m_llvmCtx, "", m_function);

// Since there's an else branch, the conditional instruction should jump to it
m_builder.SetInsertPoint(statement.beforeIf);
Expand All @@ -871,7 +870,7 @@ std::shared_ptr<ExecutableCode> LLVMCodeBuilder::finalize()

// Jump to the branch after the if statement
if (!statement.afterIf)
statement.afterIf = llvm::BasicBlock::Create(m_llvmCtx, "", func);
statement.afterIf = llvm::BasicBlock::Create(m_llvmCtx, "", m_function);

m_builder.CreateBr(statement.afterIf);

Expand Down Expand Up @@ -900,9 +899,9 @@ std::shared_ptr<ExecutableCode> LLVMCodeBuilder::finalize()
m_builder.CreateStore(zero, loop.index);

// Create branches
llvm::BasicBlock *roundBranch = llvm::BasicBlock::Create(m_llvmCtx, "", func);
loop.conditionBranch = llvm::BasicBlock::Create(m_llvmCtx, "", func);
loop.afterLoop = llvm::BasicBlock::Create(m_llvmCtx, "", func);
llvm::BasicBlock *roundBranch = llvm::BasicBlock::Create(m_llvmCtx, "", m_function);
loop.conditionBranch = llvm::BasicBlock::Create(m_llvmCtx, "", m_function);
loop.afterLoop = llvm::BasicBlock::Create(m_llvmCtx, "", m_function);

// Use last reg for count
assert(step.args.size() == 1);
Expand All @@ -928,10 +927,10 @@ std::shared_ptr<ExecutableCode> LLVMCodeBuilder::finalize()
// Check index
m_builder.SetInsertPoint(loop.conditionBranch);

llvm::BasicBlock *body = llvm::BasicBlock::Create(m_llvmCtx, "", func);
llvm::BasicBlock *body = llvm::BasicBlock::Create(m_llvmCtx, "", m_function);

if (!loop.afterLoop)
loop.afterLoop = llvm::BasicBlock::Create(m_llvmCtx, "", func);
loop.afterLoop = llvm::BasicBlock::Create(m_llvmCtx, "", m_function);

llvm::Value *currentIndex = m_builder.CreateLoad(m_builder.getInt64Ty(), loop.index);
comparison = m_builder.CreateOr(isInf, m_builder.CreateICmpULT(currentIndex, count));
Expand All @@ -958,8 +957,8 @@ std::shared_ptr<ExecutableCode> LLVMCodeBuilder::finalize()
LLVMLoop &loop = loops.back();

// Create branches
llvm::BasicBlock *body = llvm::BasicBlock::Create(m_llvmCtx, "", func);
loop.afterLoop = llvm::BasicBlock::Create(m_llvmCtx, "", func);
llvm::BasicBlock *body = llvm::BasicBlock::Create(m_llvmCtx, "", m_function);
loop.afterLoop = llvm::BasicBlock::Create(m_llvmCtx, "", m_function);

// Use last reg
assert(step.args.size() == 1);
Expand All @@ -979,8 +978,8 @@ std::shared_ptr<ExecutableCode> LLVMCodeBuilder::finalize()
LLVMLoop &loop = loops.back();

// Create branches
llvm::BasicBlock *body = llvm::BasicBlock::Create(m_llvmCtx, "", func);
loop.afterLoop = llvm::BasicBlock::Create(m_llvmCtx, "", func);
llvm::BasicBlock *body = llvm::BasicBlock::Create(m_llvmCtx, "", m_function);
loop.afterLoop = llvm::BasicBlock::Create(m_llvmCtx, "", m_function);

// Use last reg
assert(step.args.size() == 1);
Expand All @@ -998,7 +997,7 @@ std::shared_ptr<ExecutableCode> LLVMCodeBuilder::finalize()
case LLVMInstruction::Type::BeginLoopCondition: {
LLVMLoop loop;
loop.isRepeatLoop = false;
loop.conditionBranch = llvm::BasicBlock::Create(m_llvmCtx, "", func);
loop.conditionBranch = llvm::BasicBlock::Create(m_llvmCtx, "", m_function);
m_builder.CreateBr(loop.conditionBranch);
m_builder.SetInsertPoint(loop.conditionBranch);
loops.push_back(loop);
Expand Down Expand Up @@ -1030,7 +1029,7 @@ std::shared_ptr<ExecutableCode> LLVMCodeBuilder::finalize()

case LLVMInstruction::Type::Stop: {
m_builder.CreateBr(endBranch);
llvm::BasicBlock *nextBranch = llvm::BasicBlock::Create(m_llvmCtx, "", func);
llvm::BasicBlock *nextBranch = llvm::BasicBlock::Create(m_llvmCtx, "", m_function);
m_builder.SetInsertPoint(nextBranch);
break;
}
Expand All @@ -1046,7 +1045,7 @@ std::shared_ptr<ExecutableCode> LLVMCodeBuilder::finalize()
std::vector<llvm::Value *> args;

for (size_t i = 0; i < m_defaultArgCount; i++)
args.push_back(func->getArg(i));
args.push_back(m_function->getArg(i));

// Add warp arg
if (m_warp)
Expand All @@ -1065,12 +1064,12 @@ std::shared_ptr<ExecutableCode> LLVMCodeBuilder::finalize()
llvm::Value *handle = m_builder.CreateCall(resolveFunction(name, type), args);

if (!m_warp && !step.procedurePrototype->warp()) {
llvm::BasicBlock *suspendBranch = llvm::BasicBlock::Create(m_llvmCtx, "", func);
llvm::BasicBlock *nextBranch = llvm::BasicBlock::Create(m_llvmCtx, "", func);
llvm::BasicBlock *suspendBranch = llvm::BasicBlock::Create(m_llvmCtx, "", m_function);
llvm::BasicBlock *nextBranch = llvm::BasicBlock::Create(m_llvmCtx, "", m_function);
m_builder.CreateCondBr(m_builder.CreateIsNull(handle), nextBranch, suspendBranch);

m_builder.SetInsertPoint(suspendBranch);
createSuspend(coro.get(), func, warpArg, targetVariables);
createSuspend(coro.get(), warpArg, targetVariables);
name = getResumeFunctionName(step.procedurePrototype);
llvm::Value *done = m_builder.CreateCall(resolveFunction(name, m_resumeFuncType), { handle });
m_builder.CreateCondBr(done, nextBranch, suspendBranch);
Expand All @@ -1085,7 +1084,7 @@ std::shared_ptr<ExecutableCode> LLVMCodeBuilder::finalize()

case LLVMInstruction::Type::ProcedureArg: {
assert(m_procedurePrototype);
llvm::Value *arg = func->getArg(m_defaultArgCount + 1 + step.procedureArgIndex); // omit warp arg
llvm::Value *arg = m_function->getArg(m_defaultArgCount + 1 + step.procedureArgIndex); // omit warp arg
step.functionReturnReg->value = arg;
break;
}
Expand All @@ -1107,7 +1106,7 @@ std::shared_ptr<ExecutableCode> LLVMCodeBuilder::finalize()
else
coro->end();

verifyFunction(func);
verifyFunction(m_function);

// Create resume function
// bool resume(void *)
Expand All @@ -1126,7 +1125,7 @@ std::shared_ptr<ExecutableCode> LLVMCodeBuilder::finalize()

verifyFunction(resumeFunc);

return std::make_shared<LLVMExecutableCode>(m_ctx, func->getName().str(), resumeFunc->getName().str());
return std::make_shared<LLVMExecutableCode>(m_ctx, m_function->getName().str(), resumeFunc->getName().str());
}

CompilerValue *LLVMCodeBuilder::addFunctionCall(const std::string &functionName, Compiler::StaticType returnType, const Compiler::ArgTypes &argTypes, const Compiler::Args &args)
Expand Down Expand Up @@ -2029,7 +2028,7 @@ void LLVMCodeBuilder::reloadLists()
}
}

void LLVMCodeBuilder::updateListDataPtr(const LLVMListPtr &listPtr, llvm::Function *func)
void LLVMCodeBuilder::updateListDataPtr(const LLVMListPtr &listPtr)
{
// dataPtr = dirty ? list_data(list) : dataPtr
// dirty = false
Expand Down Expand Up @@ -2190,21 +2189,21 @@ void LLVMCodeBuilder::copyStructField(llvm::Value *source, llvm::Value *target,
m_builder.CreateStore(m_builder.CreateLoad(fieldType, sourceField), targetField);
}

llvm::Value *LLVMCodeBuilder::getListItem(const LLVMListPtr &listPtr, llvm::Value *index, llvm::Function *func)
llvm::Value *LLVMCodeBuilder::getListItem(const LLVMListPtr &listPtr, llvm::Value *index)
{
updateListDataPtr(listPtr, func);
updateListDataPtr(listPtr);
return m_builder.CreateGEP(m_valueDataType, m_builder.CreateLoad(m_valueDataType->getPointerTo(), listPtr.dataPtr), index);
}

llvm::Value *LLVMCodeBuilder::getListItemIndex(const LLVMListPtr &listPtr, LLVMRegister *item, llvm::Function *func)
llvm::Value *LLVMCodeBuilder::getListItemIndex(const LLVMListPtr &listPtr, LLVMRegister *item)
{
llvm::Value *size = m_builder.CreateLoad(m_builder.getInt64Ty(), listPtr.sizePtr);
llvm::BasicBlock *condBlock = llvm::BasicBlock::Create(m_llvmCtx, "", func);
llvm::BasicBlock *bodyBlock = llvm::BasicBlock::Create(m_llvmCtx, "", func);
llvm::BasicBlock *cmpIfBlock = llvm::BasicBlock::Create(m_llvmCtx, "", func);
llvm::BasicBlock *cmpElseBlock = llvm::BasicBlock::Create(m_llvmCtx, "", func);
llvm::BasicBlock *notFoundBlock = llvm::BasicBlock::Create(m_llvmCtx, "", func);
llvm::BasicBlock *nextBlock = llvm::BasicBlock::Create(m_llvmCtx, "", func);
llvm::BasicBlock *condBlock = llvm::BasicBlock::Create(m_llvmCtx, "", m_function);
llvm::BasicBlock *bodyBlock = llvm::BasicBlock::Create(m_llvmCtx, "", m_function);
llvm::BasicBlock *cmpIfBlock = llvm::BasicBlock::Create(m_llvmCtx, "", m_function);
llvm::BasicBlock *cmpElseBlock = llvm::BasicBlock::Create(m_llvmCtx, "", m_function);
llvm::BasicBlock *notFoundBlock = llvm::BasicBlock::Create(m_llvmCtx, "", m_function);
llvm::BasicBlock *nextBlock = llvm::BasicBlock::Create(m_llvmCtx, "", m_function);

// index = 0
llvm::Value *index = m_builder.CreateAlloca(m_builder.getInt64Ty());
Expand All @@ -2220,7 +2219,7 @@ llvm::Value *LLVMCodeBuilder::getListItemIndex(const LLVMListPtr &listPtr, LLVMR
m_builder.SetInsertPoint(bodyBlock);
LLVMRegister currentItem(listPtr.type);
currentItem.isRawValue = false;
currentItem.value = getListItem(listPtr, m_builder.CreateLoad(m_builder.getInt64Ty(), index), func);
currentItem.value = getListItem(listPtr, m_builder.CreateLoad(m_builder.getInt64Ty(), index));
llvm::Value *cmp = createComparison(&currentItem, item, Comparison::EQ);
m_builder.CreateCondBr(cmp, cmpIfBlock, cmpElseBlock);

Expand Down Expand Up @@ -2481,14 +2480,14 @@ llvm::Value *LLVMCodeBuilder::createComparison(LLVMRegister *arg1, LLVMRegister
}
}

void LLVMCodeBuilder::createSuspend(LLVMCoroutine *coro, llvm::Function *func, llvm::Value *warpArg, llvm::Value *targetVariables)
void LLVMCodeBuilder::createSuspend(LLVMCoroutine *coro, llvm::Value *warpArg, llvm::Value *targetVariables)
{
if (!m_warp) {
llvm::BasicBlock *suspendBranch, *nextBranch;

if (warpArg) {
suspendBranch = llvm::BasicBlock::Create(m_llvmCtx, "", func);
nextBranch = llvm::BasicBlock::Create(m_llvmCtx, "", func);
suspendBranch = llvm::BasicBlock::Create(m_llvmCtx, "", m_function);
nextBranch = llvm::BasicBlock::Create(m_llvmCtx, "", m_function);
m_builder.CreateCondBr(warpArg, nextBranch, suspendBranch);
m_builder.SetInsertPoint(suspendBranch);
}
Expand Down
Loading

0 comments on commit 802e71f

Please sign in to comment.