Skip to content

Commit

Permalink
Add function call LLVM code generation (#477)
Browse files Browse the repository at this point in the history
This patch adds support for function call code generation, particularly:

- User-defined procedures and functions can now lowered to LLVM IR.
- A framework for external method calls (e.g. sin, exp, etc.) has been created, currently `exp` and `pow` are supported.
- Corresponding tests added.

fixes #472
  • Loading branch information
georgemitenkov authored Dec 30, 2020
1 parent 3ef4d9c commit 183d647
Show file tree
Hide file tree
Showing 4 changed files with 231 additions and 11 deletions.
103 changes: 96 additions & 7 deletions src/codegen/llvm/codegen_llvm_visitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@

#include "codegen/llvm/codegen_llvm_visitor.hpp"
#include "ast/all.hpp"
#include "codegen/codegen_helper_visitor.hpp"
#include "visitors/rename_visitor.hpp"
#include "visitors/visitor_utils.hpp"

#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Constants.h"
Expand Down Expand Up @@ -44,7 +44,56 @@ void CodegenLLVMVisitor::run_llvm_opt_passes() {
}


void CodegenLLVMVisitor::visit_procedure_or_function(const ast::Block& node) {
void CodegenLLVMVisitor::create_external_method_call(const std::string& name,
const ast::ExpressionVector& arguments) {
std::vector<llvm::Value*> argument_values;
std::vector<llvm::Type*> argument_types;
for (const auto& arg: arguments) {
arg->accept(*this);
llvm::Value* value = values.back();
llvm::Type* type = value->getType();
values.pop_back();
argument_types.push_back(type);
argument_values.push_back(value);
}

#define DISPATCH(method_name, intrinsic) \
if (name == method_name) { \
llvm::Value* result = builder.CreateIntrinsic(intrinsic, argument_types, argument_values); \
values.push_back(result); \
return; \
}

DISPATCH("exp", llvm::Intrinsic::exp);
DISPATCH("pow", llvm::Intrinsic::pow);
#undef DISPATCH

throw std::runtime_error("Error: External method" + name + " is not currently supported");
}

void CodegenLLVMVisitor::create_function_call(llvm::Function* func,
const std::string& name,
const ast::ExpressionVector& arguments) {
// Check that function is called with the expected number of arguments.
if (arguments.size() != func->arg_size()) {
throw std::runtime_error("Error: Incorrect number of arguments passed");
}

// Process each argument and add it to a vector to pass to the function call instruction. Note
// that type checks are not needed here as NMODL operates on doubles by default.
std::vector<llvm::Value*> argument_values;
for (const auto& arg: arguments) {
arg->accept(*this);
llvm::Value* value = values.back();
values.pop_back();
argument_values.push_back(value);
}

llvm::Value* call = builder.CreateCall(func, argument_values);
values.push_back(call);
}

void CodegenLLVMVisitor::emit_procedure_or_function_declaration(const ast::Block& node) {
const auto& name = node.get_node_name();
const auto& parameters = node.get_parameters();

Expand All @@ -57,11 +106,17 @@ void CodegenLLVMVisitor::visit_procedure_or_function(const ast::Block& node) {
llvm::Type* return_type = node.is_function_block() ? llvm::Type::getDoubleTy(*context)
: llvm::Type::getVoidTy(*context);

llvm::Function* func =
llvm::Function::Create(llvm::FunctionType::get(return_type, arg_types, /*isVarArg=*/false),
llvm::Function::ExternalLinkage,
name,
*module);
// Create a function that is automatically inserted into module's symbol table.
llvm::Function::Create(llvm::FunctionType::get(return_type, arg_types, /*isVarArg=*/false),
llvm::Function::ExternalLinkage,
name,
*module);
}

void CodegenLLVMVisitor::visit_procedure_or_function(const ast::Block& node) {
const auto& name = node.get_node_name();
const auto& parameters = node.get_parameters();
llvm::Function* func = module->getFunction(name);

// Create the entry basic block of the function/procedure and point the local named values table
// to the symbol table.
Expand Down Expand Up @@ -175,6 +230,22 @@ void CodegenLLVMVisitor::visit_function_block(const ast::FunctionBlock& node) {
visit_procedure_or_function(node);
}

void CodegenLLVMVisitor::visit_function_call(const ast::FunctionCall& node) {
const auto& name = node.get_node_name();
auto func = module->getFunction(name);
if (func) {
create_function_call(func, name, node.get_arguments());
} else {
auto symbol = sym_tab->lookup(name);
if (symbol && symbol->has_any_property(symtab::syminfo::NmodlType::extern_method)) {
create_external_method_call(name, node.get_arguments());
} else {
throw std::runtime_error("Error: Unknown function name: " + name +
". (External functions references are not supported)");
}
}
}

void CodegenLLVMVisitor::visit_integer(const ast::Integer& node) {
const auto& constant = llvm::ConstantInt::get(llvm::Type::getInt32Ty(*context),
node.get_value());
Expand All @@ -191,6 +262,24 @@ void CodegenLLVMVisitor::visit_local_list_statement(const ast::LocalListStatemen
}

void CodegenLLVMVisitor::visit_program(const ast::Program& node) {
// Before generating LLVM, gather information about AST. For now, information about functions
// and procedures is used only.
CodegenHelperVisitor v;
CodegenInfo info = v.analyze(node);

// For every function and procedure, generate its declaration. Thus, we can look up
// `llvm::Function` in the symbol table in the module.
for (const auto& func: info.functions) {
emit_procedure_or_function_declaration(*func);
}
for (const auto& proc: info.procedures) {
emit_procedure_or_function_declaration(*proc);
}

// Set the AST symbol table.
sym_tab = node.get_symbol_table();

// Proceed with code generation.
node.visit_children(*this);

if (opt_passes) {
Expand Down
32 changes: 31 additions & 1 deletion src/codegen/llvm/codegen_llvm_visitor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#include <ostream>
#include <string>

#include "symtab/symbol_table.hpp"
#include "utils/logger.hpp"
#include "visitors/ast_visitor.hpp"

Expand Down Expand Up @@ -69,7 +70,10 @@ class CodegenLLVMVisitor: public visitor::ConstAstVisitor {
// Pointer to the local symbol table.
llvm::ValueSymbolTable* local_named_values = nullptr;

// Run optimisation passes if true
// Pointer to AST symbol table.
symtab::SymbolTable* sym_tab;

// Run optimisation passes if true.
bool opt_passes;

/**
Expand All @@ -96,6 +100,31 @@ class CodegenLLVMVisitor: public visitor::ConstAstVisitor {
, builder(*context)
, fpm(module.get()) {}

/**
* Create a function call to an external method
* \param name external method name
* \param arguments expressions passed as arguments to the given external method
*/
void create_external_method_call(const std::string& name,
const ast::ExpressionVector& arguments);

/**
* Create a function call to NMODL function or procedure in the same mod file
* \param func LLVM function corresponding ti this call
* \param name function name
* \param arguments expressions passed as arguments to the function call
*/
void create_function_call(llvm::Function* func,
const std::string& name,
const ast::ExpressionVector& arguments);

/**
* Emit function or procedure declaration in LLVM given the node
*
* \param node the AST node representing the function or procedure in NMODL
*/
void emit_procedure_or_function_declaration(const ast::Block& node);

/**
* Visit nmodl function or procedure
* \param node the AST node representing the function or procedure in NMODL
Expand All @@ -107,6 +136,7 @@ class CodegenLLVMVisitor: public visitor::ConstAstVisitor {
void visit_boolean(const ast::Boolean& node) override;
void visit_double(const ast::Double& node) override;
void visit_function_block(const ast::FunctionBlock& node) override;
void visit_function_call(const ast::FunctionCall& node) override;
void visit_integer(const ast::Integer& node) override;
void visit_local_list_statement(const ast::LocalListStatement& node) override;
void visit_procedure_block(const ast::ProcedureBlock& node) override;
Expand Down
3 changes: 2 additions & 1 deletion test/unit/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -98,13 +98,14 @@ if(NMODL_ENABLE_LLVM)
add_executable(testllvm visitor/main.cpp codegen/llvm.cpp)
target_link_libraries(
testllvm
llvm_codegen
codegen
visitor
symtab
lexer
util
test_util
printer
llvm_codegen
${NMODL_WRAPPER_LIBS}
${LLVM_LIBS_TO_LINK})
set(CODEGEN_TEST testllvm)
Expand Down
104 changes: 102 additions & 2 deletions test/unit/codegen/llvm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
#include "codegen/llvm/codegen_llvm_visitor.hpp"
#include "parser/nmodl_driver.hpp"
#include "visitors/checkparent_visitor.hpp"
#include "visitors/inline_visitor.hpp"
#include "visitors/symtab_visitor.hpp"

using namespace nmodl;
Expand All @@ -28,7 +27,6 @@ std::string run_llvm_visitor(const std::string& text, bool opt = false) {
const auto& ast = driver.parse_string(text);

SymtabVisitor().visit_program(*ast);
InlineVisitor().visit_program(*ast);

codegen::CodegenLLVMVisitor llvm_visitor("unknown", ".", opt);
llvm_visitor.visit_program(*ast);
Expand Down Expand Up @@ -156,6 +154,108 @@ SCENARIO("Function", "[visitor][llvm]") {
}
}

//=============================================================================
// FunctionCall
//=============================================================================

SCENARIO("Function call", "[visitor][llvm]") {
GIVEN("A call to procedure") {
std::string nmodl_text = R"(
PROCEDURE bar() {}
FUNCTION foo() {
bar()
}
)";

THEN("a void call instruction is created") {
std::string module_string = run_llvm_visitor(nmodl_text);
std::smatch m;

// Check for call instruction.
std::regex call(R"(call void @bar\(\))");
REQUIRE(std::regex_search(module_string, m, call));
}
}

GIVEN("A call to function declared below the caller") {
std::string nmodl_text = R"(
FUNCTION foo(x) {
foo = 4 * bar()
}
FUNCTION bar() {
bar = 5
}
)";

THEN("a correct call instruction is created") {
std::string module_string = run_llvm_visitor(nmodl_text);
std::smatch m;

// Check for call instruction.
std::regex call(R"(%[0-9]+ = call double @bar\(\))");
REQUIRE(std::regex_search(module_string, m, call));
}
}

GIVEN("A call to function with arguments") {
std::string nmodl_text = R"(
FUNCTION foo(x, y) {
foo = 4 * x - y
}
FUNCTION bar(i) {
bar = foo(i, 4)
}
)";

THEN("arguments are processed before the call and passed to call instruction") {
std::string module_string = run_llvm_visitor(nmodl_text);
std::smatch m;

// Check correct arguments.
std::regex i(R"(%1 = load double, double\* %i)");
std::regex call(R"(call double @foo\(double %1, double 4.000000e\+00\))");
REQUIRE(std::regex_search(module_string, m, i));
REQUIRE(std::regex_search(module_string, m, call));
}
}

GIVEN("A call to external method") {
std::string nmodl_text = R"(
FUNCTION bar(i) {
bar = exp(i)
}
)";

THEN("LLVM intrinsic corresponding to this method is created") {
std::string module_string = run_llvm_visitor(nmodl_text);
std::smatch m;

// Check for intrinsic declaration.
std::regex exp(R"(declare double @llvm\.exp\.f64\(double\))");
REQUIRE(std::regex_search(module_string, m, exp));

// Check the correct call is made.
std::regex call(R"(call double @llvm\.exp\.f64\(double %[0-9]+\))");
REQUIRE(std::regex_search(module_string, m, call));
}
}

GIVEN("A call to function with the wrong number of arguments") {
std::string nmodl_text = R"(
FUNCTION foo(x, y) {
foo = 4 * x - y
}
FUNCTION bar(i) {
bar = foo(i)
}
)";

THEN("a runtime error is thrown") {
REQUIRE_THROWS_AS(run_llvm_visitor(nmodl_text), std::runtime_error);
}
}
}

//=============================================================================
// LocalList and LocalVar
//=============================================================================
Expand Down

0 comments on commit 183d647

Please sign in to comment.