diff --git a/src/codegen/llvm/codegen_llvm_visitor.cpp b/src/codegen/llvm/codegen_llvm_visitor.cpp index 6228b39d04..6f134149e3 100644 --- a/src/codegen/llvm/codegen_llvm_visitor.cpp +++ b/src/codegen/llvm/codegen_llvm_visitor.cpp @@ -65,6 +65,12 @@ unsigned CodegenLLVMVisitor::get_array_index_or_length(const ast::IndexedName& i return static_cast(*macro->get_value()); } +llvm::Type* CodegenLLVMVisitor::get_default_fp_type() { + if (use_single_precision) + return llvm::Type::getFloatTy(*context); + return llvm::Type::getDoubleTy(*context); +} + void CodegenLLVMVisitor::run_llvm_opt_passes() { /// run some common optimisation passes that are commonly suggested fpm.add(llvm::createInstructionCombiningPass()); @@ -139,10 +145,10 @@ void CodegenLLVMVisitor::emit_procedure_or_function_declaration(const ast::Block // Procedure or function parameters are doubles by default. std::vector arg_types; for (size_t i = 0; i < parameters.size(); ++i) - arg_types.push_back(llvm::Type::getDoubleTy(*context)); + arg_types.push_back(get_default_fp_type()); // If visiting a function, the return type is a double by default. - llvm::Type* return_type = node.is_function_block() ? llvm::Type::getDoubleTy(*context) + llvm::Type* return_type = node.is_function_block() ? get_default_fp_type() : llvm::Type::getVoidTy(*context); // Create a function that is automatically inserted into module's symbol table. @@ -152,6 +158,90 @@ void CodegenLLVMVisitor::emit_procedure_or_function_declaration(const ast::Block *module); } +llvm::Value* CodegenLLVMVisitor::visit_arithmetic_bin_op(llvm::Value* lhs, + llvm::Value* rhs, + unsigned op) { + const auto& bin_op = static_cast(op); + llvm::Type* lhs_type = lhs->getType(); + llvm::Value* result; + + switch (bin_op) { +#define DISPATCH(binary_op, llvm_fp_op, llvm_int_op) \ + case binary_op: \ + if (lhs_type->isDoubleTy() || lhs_type->isFloatTy()) \ + result = llvm_fp_op(lhs, rhs); \ + else \ + result = llvm_int_op(lhs, rhs); \ + return result; + + DISPATCH(ast::BinaryOp::BOP_ADDITION, builder.CreateFAdd, builder.CreateAdd); + DISPATCH(ast::BinaryOp::BOP_DIVISION, builder.CreateFDiv, builder.CreateSDiv); + DISPATCH(ast::BinaryOp::BOP_MULTIPLICATION, builder.CreateFMul, builder.CreateMul); + DISPATCH(ast::BinaryOp::BOP_SUBTRACTION, builder.CreateFSub, builder.CreateSub); + +#undef DISPATCH + + default: + return nullptr; + } +} + +void CodegenLLVMVisitor::visit_assign_op(const ast::BinaryExpression& node, llvm::Value* rhs) { + auto var = dynamic_cast(node.get_lhs().get()); + if (!var) { + throw std::runtime_error("Error: only VarName assignment is currently supported.\n"); + } + + const auto& identifier = var->get_name(); + if (identifier->is_name()) { + llvm::Value* alloca = local_named_values->lookup(var->get_node_name()); + builder.CreateStore(rhs, alloca); + } else if (identifier->is_indexed_name()) { + auto indexed_name = std::dynamic_pointer_cast(identifier); + builder.CreateStore(rhs, codegen_indexed_name(*indexed_name)); + } else { + throw std::runtime_error("Error: Unsupported variable type"); + } +} + +llvm::Value* CodegenLLVMVisitor::visit_logical_bin_op(llvm::Value* lhs, + llvm::Value* rhs, + unsigned op) { + const auto& bin_op = static_cast(op); + return bin_op == ast::BinaryOp::BOP_AND ? builder.CreateAnd(lhs, rhs) + : builder.CreateOr(lhs, rhs); +} + +llvm::Value* CodegenLLVMVisitor::visit_comparison_bin_op(llvm::Value* lhs, + llvm::Value* rhs, + unsigned op) { + const auto& bin_op = static_cast(op); + llvm::Type* lhs_type = lhs->getType(); + llvm::Value* result; + + switch (bin_op) { +#define DISPATCH(binary_op, f_llvm_op, i_llvm_op) \ + case binary_op: \ + if (lhs_type->isDoubleTy() || lhs_type->isFloatTy()) \ + result = f_llvm_op(lhs, rhs); \ + else \ + result = i_llvm_op(lhs, rhs); \ + return result; + + DISPATCH(ast::BinaryOp::BOP_EXACT_EQUAL, builder.CreateICmpEQ, builder.CreateFCmpOEQ); + DISPATCH(ast::BinaryOp::BOP_GREATER, builder.CreateICmpSGT, builder.CreateFCmpOGT); + DISPATCH(ast::BinaryOp::BOP_GREATER_EQUAL, builder.CreateICmpSGE, builder.CreateFCmpOGE); + DISPATCH(ast::BinaryOp::BOP_LESS, builder.CreateICmpSLT, builder.CreateFCmpOLT); + DISPATCH(ast::BinaryOp::BOP_LESS_EQUAL, builder.CreateICmpSLE, builder.CreateFCmpOLE); + DISPATCH(ast::BinaryOp::BOP_NOT_EQUAL, builder.CreateICmpNE, builder.CreateFCmpONE); + +#undef DISPATCH + + default: + return nullptr; + } +} + void CodegenLLVMVisitor::visit_procedure_or_function(const ast::Block& node) { const auto& name = node.get_node_name(); const auto& parameters = node.get_parameters(); @@ -222,44 +312,39 @@ void CodegenLLVMVisitor::visit_binary_expression(const ast::BinaryExpression& no llvm::Value* rhs = values.back(); values.pop_back(); if (op == ast::BinaryOp::BOP_ASSIGN) { - auto var = dynamic_cast(node.get_lhs().get()); - if (!var) { - throw std::runtime_error("Error: only VarName assignment is currently supported.\n"); - } - - const auto& identifier = var->get_name(); - if (identifier->is_name()) { - llvm::Value* alloca = local_named_values->lookup(var->get_node_name()); - builder.CreateStore(rhs, alloca); - } else if (identifier->is_indexed_name()) { - auto indexed_name = std::dynamic_pointer_cast(identifier); - builder.CreateStore(rhs, codegen_indexed_name(*indexed_name)); - } else { - throw std::runtime_error("Error: Unsupported variable type"); - } + visit_assign_op(node, rhs); return; } node.get_lhs()->accept(*this); llvm::Value* lhs = values.back(); values.pop_back(); - llvm::Value* result; - // \todo: Support other binary operators + llvm::Value* result; switch (op) { -#define DISPATCH(binary_op, llvm_op) \ - case binary_op: \ - result = llvm_op(lhs, rhs); \ - values.push_back(result); \ + case ast::BOP_ADDITION: + case ast::BOP_DIVISION: + case ast::BOP_MULTIPLICATION: + case ast::BOP_SUBTRACTION: + result = visit_arithmetic_bin_op(lhs, rhs, op); break; - - DISPATCH(ast::BinaryOp::BOP_ADDITION, builder.CreateFAdd); - DISPATCH(ast::BinaryOp::BOP_DIVISION, builder.CreateFDiv); - DISPATCH(ast::BinaryOp::BOP_MULTIPLICATION, builder.CreateFMul); - DISPATCH(ast::BinaryOp::BOP_SUBTRACTION, builder.CreateFSub); - -#undef DISPATCH + case ast::BOP_AND: + case ast::BOP_OR: + result = visit_logical_bin_op(lhs, rhs, op); + break; + case ast::BOP_EXACT_EQUAL: + case ast::BOP_GREATER: + case ast::BOP_GREATER_EQUAL: + case ast::BOP_LESS: + case ast::BOP_LESS_EQUAL: + case ast::BOP_NOT_EQUAL: + result = visit_comparison_bin_op(lhs, rhs, op); + break; + default: + throw std::runtime_error("Error: binary operator is not supported\n"); } + + values.push_back(result); } void CodegenLLVMVisitor::visit_boolean(const ast::Boolean& node) { @@ -269,8 +354,7 @@ void CodegenLLVMVisitor::visit_boolean(const ast::Boolean& node) { } void CodegenLLVMVisitor::visit_double(const ast::Double& node) { - const auto& constant = llvm::ConstantFP::get(llvm::Type::getDoubleTy(*context), - node.get_value()); + const auto& constant = llvm::ConstantFP::get(get_default_fp_type(), node.get_value()); values.push_back(constant); } @@ -310,10 +394,10 @@ void CodegenLLVMVisitor::visit_local_list_statement(const ast::LocalListStatemen if (identifier->is_indexed_name()) { auto indexed_name = std::dynamic_pointer_cast(identifier); unsigned length = get_array_index_or_length(*indexed_name); - var_type = llvm::ArrayType::get(llvm::Type::getDoubleTy(*context), length); + var_type = llvm::ArrayType::get(get_default_fp_type(), length); } else if (identifier->is_name()) { // This case corresponds to a scalar local variable. Its type is double by default. - var_type = llvm::Type::getDoubleTy(*context); + var_type = get_default_fp_type(); } else { throw std::runtime_error("Error: Unsupported local variable type"); } @@ -367,10 +451,10 @@ void CodegenLLVMVisitor::visit_unary_expression(const ast::UnaryExpression& node llvm::Value* value = values.back(); values.pop_back(); if (op == ast::UOP_NEGATION) { - llvm::Value* result = builder.CreateFNeg(value); - values.push_back(result); + values.push_back(builder.CreateFNeg(value)); + } else if (op == ast::UOP_NOT) { + values.push_back(builder.CreateNot(value)); } else { - // Support only `double` operators for now. throw std::runtime_error("Error: unsupported unary operator\n"); } } diff --git a/src/codegen/llvm/codegen_llvm_visitor.hpp b/src/codegen/llvm/codegen_llvm_visitor.hpp index 599cfc7b58..066bdf35e3 100644 --- a/src/codegen/llvm/codegen_llvm_visitor.hpp +++ b/src/codegen/llvm/codegen_llvm_visitor.hpp @@ -76,6 +76,9 @@ class CodegenLLVMVisitor: public visitor::ConstAstVisitor { // Run optimisation passes if true. bool opt_passes; + // Use 32-bit floating-point type if true. Otherwise, use deafult 64-bit. + bool use_single_precision; + /** *\brief Run LLVM optimisation passes on generated IR * @@ -93,10 +96,12 @@ class CodegenLLVMVisitor: public visitor::ConstAstVisitor { */ CodegenLLVMVisitor(const std::string& mod_filename, const std::string& output_dir, - bool opt_passes) + bool opt_passes, + bool use_single_precision = false) : mod_filename(mod_filename) , output_dir(output_dir) , opt_passes(opt_passes) + , use_single_precision(use_single_precision) , builder(*context) , fpm(module.get()) {} @@ -129,6 +134,12 @@ class CodegenLLVMVisitor: public visitor::ConstAstVisitor { */ unsigned get_array_index_or_length(const ast::IndexedName& node); + /** + * Returns 64-bit or 32-bit LLVM floating type + * \return \c LLVM floating point type according to `use_single_precision` flag + */ + llvm::Type* get_default_fp_type(); + /** * Create a function call to an external method * \param name external method name @@ -162,6 +173,40 @@ class CodegenLLVMVisitor: public visitor::ConstAstVisitor { return std::move(module); } + /** + * Visit nmodl arithmetic binary operator + * \param lhs LLVM value of evaluated lhs expression + * \param rhs LLVM value of evaluated rhs expression + * \param op the AST binary operator (ADD, DIV, MUL, SUB) + * \return LLVM IR value result + */ + llvm::Value* visit_arithmetic_bin_op(llvm::Value* lhs, llvm::Value* rhs, unsigned op); + + /** + * Visit nmodl assignment operator (ASSIGN) + * \param node the AST node representing the binary expression in NMODL + * \param rhs LLVM value of evaluated rhs expression + */ + void visit_assign_op(const ast::BinaryExpression& node, llvm::Value* rhs); + + /** + * Visit nmodl logical binary operator + * \param lhs LLVM value of evaluated lhs expression + * \param rhs LLVM value of evaluated rhs expression + * \param op the AST binary operator (AND, OR) + * \return LLVM IR value result + */ + llvm::Value* visit_logical_bin_op(llvm::Value* lhs, llvm::Value* rhs, unsigned op); + + /** + * Visit nmodl comparison binary operator + * \param lhs LLVM value of evaluated lhs expression + * \param rhs LLVM value of evaluated rhs expression + * \param op the AST binary operator (EXACT_EQUAL, GREATER, GREATER_EQUAL, LESS, LESS_EQUAL, + * NOT_EQUAL) \return LLVM IR value result + */ + llvm::Value* visit_comparison_bin_op(llvm::Value* lhs, llvm::Value* rhs, unsigned op); + /** * Visit nmodl function or procedure * \param node the AST node representing the function or procedure in NMODL diff --git a/src/nmodl/main.cpp b/src/nmodl/main.cpp index 49cf3e333e..b9989062cf 100644 --- a/src/nmodl/main.cpp +++ b/src/nmodl/main.cpp @@ -170,6 +170,9 @@ int main(int argc, const char* argv[]) { /// generate llvm IR bool llvm_ir(false); + /// use single precision floating-point types + bool llvm_float_type(false); + /// run llvm optimisation passes bool llvm_opt_passes(false); #endif @@ -287,6 +290,9 @@ int main(int argc, const char* argv[]) { llvm_opt->add_flag("--opt", llvm_opt_passes, "Run LLVM optimisation passes ({})"_format(llvm_opt_passes))->ignore_case(); + llvm_opt->add_flag("--single-precision", + llvm_float_type, + "Use single precision floating-point types ({})"_format(llvm_float_type))->ignore_case(); #endif // clang-format on @@ -573,7 +579,7 @@ int main(int argc, const char* argv[]) { #ifdef NMODL_LLVM_BACKEND if (llvm_ir) { logger->info("Running LLVM backend code generator"); - CodegenLLVMVisitor visitor(modfile, output_dir, llvm_opt_passes); + CodegenLLVMVisitor visitor(modfile, output_dir, llvm_opt_passes, llvm_float_type); visitor.visit_program(*ast); ast_to_nmodl(*ast, filepath("llvm")); } diff --git a/test/unit/CMakeLists.txt b/test/unit/CMakeLists.txt index 1fe2778fc9..dbdc1b68d3 100644 --- a/test/unit/CMakeLists.txt +++ b/test/unit/CMakeLists.txt @@ -95,7 +95,7 @@ target_link_libraries( if(NMODL_ENABLE_LLVM) include_directories(${LLVM_INCLUDE_DIRS}) - add_executable(testllvm visitor/main.cpp codegen/llvm.cpp) + add_executable(testllvm visitor/main.cpp codegen/codegen_llvm_ir.cpp) add_executable(test_llvm_runner visitor/main.cpp codegen/codegen_llvm_execution.cpp) target_link_libraries( testllvm diff --git a/test/unit/codegen/llvm.cpp b/test/unit/codegen/codegen_llvm_ir.cpp similarity index 95% rename from test/unit/codegen/llvm.cpp rename to test/unit/codegen/codegen_llvm_ir.cpp index d644947e79..e44b2b15cd 100644 --- a/test/unit/codegen/llvm.cpp +++ b/test/unit/codegen/codegen_llvm_ir.cpp @@ -22,13 +22,18 @@ using nmodl::parser::NmodlDriver; // Utility to get LLVM module as a string //============================================================================= -std::string run_llvm_visitor(const std::string& text, bool opt = false) { +std::string run_llvm_visitor(const std::string& text, + bool opt = false, + bool use_single_precision = false) { NmodlDriver driver; const auto& ast = driver.parse_string(text); SymtabVisitor().visit_program(*ast); - codegen::CodegenLLVMVisitor llvm_visitor("unknown", ".", opt); + codegen::CodegenLLVMVisitor llvm_visitor(/*mod_filename=*/"unknown", + /*output_dir=*/".", + opt, + use_single_precision); llvm_visitor.visit_program(*ast); return llvm_visitor.print_module(); } @@ -47,14 +52,15 @@ SCENARIO("Binary expression", "[visitor][llvm]") { )"; THEN("variables are loaded and add instruction is created") { - std::string module_string = run_llvm_visitor(nmodl_text); + std::string module_string = + run_llvm_visitor(nmodl_text, /*opt=*/false, /*use_single_precision=*/true); std::smatch m; - std::regex rhs(R"(%1 = load double, double\* %b)"); - std::regex lhs(R"(%2 = load double, double\* %a)"); - std::regex res(R"(%3 = fadd double %2, %1)"); + std::regex rhs(R"(%1 = load float, float\* %b)"); + std::regex lhs(R"(%2 = load float, float\* %a)"); + std::regex res(R"(%3 = fadd float %2, %1)"); - // Check the values are loaded correctly and added + // Check the float values are loaded correctly and added REQUIRE(std::regex_search(module_string, m, rhs)); REQUIRE(std::regex_search(module_string, m, lhs)); REQUIRE(std::regex_search(module_string, m, res));