diff --git a/src/codegen/llvm/codegen_llvm_visitor.cpp b/src/codegen/llvm/codegen_llvm_visitor.cpp index d99e519dca..430f3d78de 100644 --- a/src/codegen/llvm/codegen_llvm_visitor.cpp +++ b/src/codegen/llvm/codegen_llvm_visitor.cpp @@ -7,8 +7,8 @@ #include "codegen/llvm/codegen_llvm_visitor.hpp" #include "ast/all.hpp" +#include "codegen/codegen_helper_visitor.hpp" #include "visitors/rename_visitor.hpp" -#include "visitors/visitor_utils.hpp" #include "llvm/IR/BasicBlock.h" #include "llvm/IR/Constants.h" @@ -44,7 +44,56 @@ void CodegenLLVMVisitor::run_llvm_opt_passes() { } -void CodegenLLVMVisitor::visit_procedure_or_function(const ast::Block& node) { +void CodegenLLVMVisitor::create_external_method_call(const std::string& name, + const ast::ExpressionVector& arguments) { + std::vector argument_values; + std::vector argument_types; + for (const auto& arg: arguments) { + arg->accept(*this); + llvm::Value* value = values.back(); + llvm::Type* type = value->getType(); + values.pop_back(); + argument_types.push_back(type); + argument_values.push_back(value); + } + +#define DISPATCH(method_name, intrinsic) \ + if (name == method_name) { \ + llvm::Value* result = builder.CreateIntrinsic(intrinsic, argument_types, argument_values); \ + values.push_back(result); \ + return; \ + } + + DISPATCH("exp", llvm::Intrinsic::exp); + DISPATCH("pow", llvm::Intrinsic::pow); +#undef DISPATCH + + throw std::runtime_error("Error: External method" + name + " is not currently supported"); +} + +void CodegenLLVMVisitor::create_function_call(llvm::Function* func, + const std::string& name, + const ast::ExpressionVector& arguments) { + // Check that function is called with the expected number of arguments. + if (arguments.size() != func->arg_size()) { + throw std::runtime_error("Error: Incorrect number of arguments passed"); + } + + // Process each argument and add it to a vector to pass to the function call instruction. Note + // that type checks are not needed here as NMODL operates on doubles by default. + std::vector argument_values; + for (const auto& arg: arguments) { + arg->accept(*this); + llvm::Value* value = values.back(); + values.pop_back(); + argument_values.push_back(value); + } + + llvm::Value* call = builder.CreateCall(func, argument_values); + values.push_back(call); +} + +void CodegenLLVMVisitor::emit_procedure_or_function_declaration(const ast::Block& node) { const auto& name = node.get_node_name(); const auto& parameters = node.get_parameters(); @@ -57,11 +106,17 @@ void CodegenLLVMVisitor::visit_procedure_or_function(const ast::Block& node) { llvm::Type* return_type = node.is_function_block() ? llvm::Type::getDoubleTy(*context) : llvm::Type::getVoidTy(*context); - llvm::Function* func = - llvm::Function::Create(llvm::FunctionType::get(return_type, arg_types, /*isVarArg=*/false), - llvm::Function::ExternalLinkage, - name, - *module); + // Create a function that is automatically inserted into module's symbol table. + llvm::Function::Create(llvm::FunctionType::get(return_type, arg_types, /*isVarArg=*/false), + llvm::Function::ExternalLinkage, + name, + *module); +} + +void CodegenLLVMVisitor::visit_procedure_or_function(const ast::Block& node) { + const auto& name = node.get_node_name(); + const auto& parameters = node.get_parameters(); + llvm::Function* func = module->getFunction(name); // Create the entry basic block of the function/procedure and point the local named values table // to the symbol table. @@ -175,6 +230,22 @@ void CodegenLLVMVisitor::visit_function_block(const ast::FunctionBlock& node) { visit_procedure_or_function(node); } +void CodegenLLVMVisitor::visit_function_call(const ast::FunctionCall& node) { + const auto& name = node.get_node_name(); + auto func = module->getFunction(name); + if (func) { + create_function_call(func, name, node.get_arguments()); + } else { + auto symbol = sym_tab->lookup(name); + if (symbol && symbol->has_any_property(symtab::syminfo::NmodlType::extern_method)) { + create_external_method_call(name, node.get_arguments()); + } else { + throw std::runtime_error("Error: Unknown function name: " + name + + ". (External functions references are not supported)"); + } + } +} + void CodegenLLVMVisitor::visit_integer(const ast::Integer& node) { const auto& constant = llvm::ConstantInt::get(llvm::Type::getInt32Ty(*context), node.get_value()); @@ -191,6 +262,24 @@ void CodegenLLVMVisitor::visit_local_list_statement(const ast::LocalListStatemen } void CodegenLLVMVisitor::visit_program(const ast::Program& node) { + // Before generating LLVM, gather information about AST. For now, information about functions + // and procedures is used only. + CodegenHelperVisitor v; + CodegenInfo info = v.analyze(node); + + // For every function and procedure, generate its declaration. Thus, we can look up + // `llvm::Function` in the symbol table in the module. + for (const auto& func: info.functions) { + emit_procedure_or_function_declaration(*func); + } + for (const auto& proc: info.procedures) { + emit_procedure_or_function_declaration(*proc); + } + + // Set the AST symbol table. + sym_tab = node.get_symbol_table(); + + // Proceed with code generation. node.visit_children(*this); if (opt_passes) { diff --git a/src/codegen/llvm/codegen_llvm_visitor.hpp b/src/codegen/llvm/codegen_llvm_visitor.hpp index 6b94ecffbe..32347bdabd 100644 --- a/src/codegen/llvm/codegen_llvm_visitor.hpp +++ b/src/codegen/llvm/codegen_llvm_visitor.hpp @@ -18,6 +18,7 @@ #include #include +#include "symtab/symbol_table.hpp" #include "utils/logger.hpp" #include "visitors/ast_visitor.hpp" @@ -69,7 +70,10 @@ class CodegenLLVMVisitor: public visitor::ConstAstVisitor { // Pointer to the local symbol table. llvm::ValueSymbolTable* local_named_values = nullptr; - // Run optimisation passes if true + // Pointer to AST symbol table. + symtab::SymbolTable* sym_tab; + + // Run optimisation passes if true. bool opt_passes; /** @@ -96,6 +100,31 @@ class CodegenLLVMVisitor: public visitor::ConstAstVisitor { , builder(*context) , fpm(module.get()) {} + /** + * Create a function call to an external method + * \param name external method name + * \param arguments expressions passed as arguments to the given external method + */ + void create_external_method_call(const std::string& name, + const ast::ExpressionVector& arguments); + + /** + * Create a function call to NMODL function or procedure in the same mod file + * \param func LLVM function corresponding ti this call + * \param name function name + * \param arguments expressions passed as arguments to the function call + */ + void create_function_call(llvm::Function* func, + const std::string& name, + const ast::ExpressionVector& arguments); + + /** + * Emit function or procedure declaration in LLVM given the node + * + * \param node the AST node representing the function or procedure in NMODL + */ + void emit_procedure_or_function_declaration(const ast::Block& node); + /** * Visit nmodl function or procedure * \param node the AST node representing the function or procedure in NMODL @@ -107,6 +136,7 @@ class CodegenLLVMVisitor: public visitor::ConstAstVisitor { void visit_boolean(const ast::Boolean& node) override; void visit_double(const ast::Double& node) override; void visit_function_block(const ast::FunctionBlock& node) override; + void visit_function_call(const ast::FunctionCall& node) override; void visit_integer(const ast::Integer& node) override; void visit_local_list_statement(const ast::LocalListStatement& node) override; void visit_procedure_block(const ast::ProcedureBlock& node) override; diff --git a/test/unit/CMakeLists.txt b/test/unit/CMakeLists.txt index 29957a7530..f9c76827fd 100644 --- a/test/unit/CMakeLists.txt +++ b/test/unit/CMakeLists.txt @@ -109,13 +109,14 @@ if(NMODL_ENABLE_LLVM) add_executable(testllvm visitor/main.cpp codegen/llvm.cpp) target_link_libraries( testllvm + llvm_codegen + codegen visitor symtab lexer util test_util printer - llvm_codegen ${NMODL_WRAPPER_LIBS} ${LLVM_LIBS_TO_LINK}) set(CODEGEN_TEST testllvm) diff --git a/test/unit/codegen/llvm.cpp b/test/unit/codegen/llvm.cpp index 9c86e8c30a..d2c0a65e86 100644 --- a/test/unit/codegen/llvm.cpp +++ b/test/unit/codegen/llvm.cpp @@ -12,7 +12,6 @@ #include "codegen/llvm/codegen_llvm_visitor.hpp" #include "parser/nmodl_driver.hpp" #include "visitors/checkparent_visitor.hpp" -#include "visitors/inline_visitor.hpp" #include "visitors/symtab_visitor.hpp" using namespace nmodl; @@ -28,7 +27,6 @@ std::string run_llvm_visitor(const std::string& text, bool opt = false) { const auto& ast = driver.parse_string(text); SymtabVisitor().visit_program(*ast); - InlineVisitor().visit_program(*ast); codegen::CodegenLLVMVisitor llvm_visitor("unknown", ".", opt); llvm_visitor.visit_program(*ast); @@ -156,6 +154,108 @@ SCENARIO("Function", "[visitor][llvm]") { } } +//============================================================================= +// FunctionCall +//============================================================================= + +SCENARIO("Function call", "[visitor][llvm]") { + GIVEN("A call to procedure") { + std::string nmodl_text = R"( + PROCEDURE bar() {} + FUNCTION foo() { + bar() + } + )"; + + THEN("a void call instruction is created") { + std::string module_string = run_llvm_visitor(nmodl_text); + std::smatch m; + + // Check for call instruction. + std::regex call(R"(call void @bar\(\))"); + REQUIRE(std::regex_search(module_string, m, call)); + } + } + + GIVEN("A call to function declared below the caller") { + std::string nmodl_text = R"( + FUNCTION foo(x) { + foo = 4 * bar() + } + FUNCTION bar() { + bar = 5 + } + )"; + + THEN("a correct call instruction is created") { + std::string module_string = run_llvm_visitor(nmodl_text); + std::smatch m; + + // Check for call instruction. + std::regex call(R"(%[0-9]+ = call double @bar\(\))"); + REQUIRE(std::regex_search(module_string, m, call)); + } + } + + GIVEN("A call to function with arguments") { + std::string nmodl_text = R"( + FUNCTION foo(x, y) { + foo = 4 * x - y + } + FUNCTION bar(i) { + bar = foo(i, 4) + } + )"; + + THEN("arguments are processed before the call and passed to call instruction") { + std::string module_string = run_llvm_visitor(nmodl_text); + std::smatch m; + + // Check correct arguments. + std::regex i(R"(%1 = load double, double\* %i)"); + std::regex call(R"(call double @foo\(double %1, double 4.000000e\+00\))"); + REQUIRE(std::regex_search(module_string, m, i)); + REQUIRE(std::regex_search(module_string, m, call)); + } + } + + GIVEN("A call to external method") { + std::string nmodl_text = R"( + FUNCTION bar(i) { + bar = exp(i) + } + )"; + + THEN("LLVM intrinsic corresponding to this method is created") { + std::string module_string = run_llvm_visitor(nmodl_text); + std::smatch m; + + // Check for intrinsic declaration. + std::regex exp(R"(declare double @llvm\.exp\.f64\(double\))"); + REQUIRE(std::regex_search(module_string, m, exp)); + + // Check the correct call is made. + std::regex call(R"(call double @llvm\.exp\.f64\(double %[0-9]+\))"); + REQUIRE(std::regex_search(module_string, m, call)); + } + } + + GIVEN("A call to function with the wrong number of arguments") { + std::string nmodl_text = R"( + FUNCTION foo(x, y) { + foo = 4 * x - y + } + FUNCTION bar(i) { + bar = foo(i) + } + )"; + + THEN("a runtime error is thrown") { + REQUIRE_THROWS_AS(run_llvm_visitor(nmodl_text), std::runtime_error); + } + } +} + //============================================================================= // LocalList and LocalVar //=============================================================================