Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Extended support for binary ops and refactoring #489

Merged
merged 6 commits into from
Jan 12, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
158 changes: 121 additions & 37 deletions src/codegen/llvm/codegen_llvm_visitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,12 @@ unsigned CodegenLLVMVisitor::get_array_index_or_length(const ast::IndexedName& i
return static_cast<unsigned>(*macro->get_value());
}

llvm::Type* CodegenLLVMVisitor::get_default_fp_type() {
if (use_single_precision)
return llvm::Type::getFloatTy(*context);
return llvm::Type::getDoubleTy(*context);
}

void CodegenLLVMVisitor::run_llvm_opt_passes() {
/// run some common optimisation passes that are commonly suggested
fpm.add(llvm::createInstructionCombiningPass());
Expand Down Expand Up @@ -139,10 +145,10 @@ void CodegenLLVMVisitor::emit_procedure_or_function_declaration(const ast::Block
// Procedure or function parameters are doubles by default.
std::vector<llvm::Type*> arg_types;
for (size_t i = 0; i < parameters.size(); ++i)
arg_types.push_back(llvm::Type::getDoubleTy(*context));
arg_types.push_back(get_default_fp_type());

// If visiting a function, the return type is a double by default.
llvm::Type* return_type = node.is_function_block() ? llvm::Type::getDoubleTy(*context)
llvm::Type* return_type = node.is_function_block() ? get_default_fp_type()
: llvm::Type::getVoidTy(*context);

// Create a function that is automatically inserted into module's symbol table.
Expand All @@ -152,6 +158,90 @@ void CodegenLLVMVisitor::emit_procedure_or_function_declaration(const ast::Block
*module);
}

llvm::Value* CodegenLLVMVisitor::visit_arithmetic_bin_op(llvm::Value* lhs,
llvm::Value* rhs,
unsigned op) {
const auto& bin_op = static_cast<ast::BinaryOp>(op);
llvm::Type* lhs_type = lhs->getType();
llvm::Value* result;

switch (bin_op) {
#define DISPATCH(binary_op, llvm_fp_op, llvm_int_op) \
case binary_op: \
if (lhs_type->isDoubleTy() || lhs_type->isFloatTy()) \
result = llvm_fp_op(lhs, rhs); \
else \
result = llvm_int_op(lhs, rhs); \
return result;

DISPATCH(ast::BinaryOp::BOP_ADDITION, builder.CreateFAdd, builder.CreateAdd);
DISPATCH(ast::BinaryOp::BOP_DIVISION, builder.CreateFDiv, builder.CreateSDiv);
DISPATCH(ast::BinaryOp::BOP_MULTIPLICATION, builder.CreateFMul, builder.CreateMul);
DISPATCH(ast::BinaryOp::BOP_SUBTRACTION, builder.CreateFSub, builder.CreateSub);

#undef DISPATCH

default:
return nullptr;
}
}

void CodegenLLVMVisitor::visit_assign_op(const ast::BinaryExpression& node, llvm::Value* rhs) {
auto var = dynamic_cast<ast::VarName*>(node.get_lhs().get());
if (!var) {
throw std::runtime_error("Error: only VarName assignment is currently supported.\n");
}
pramodk marked this conversation as resolved.
Show resolved Hide resolved

const auto& identifier = var->get_name();
if (identifier->is_name()) {
llvm::Value* alloca = local_named_values->lookup(var->get_node_name());
builder.CreateStore(rhs, alloca);
} else if (identifier->is_indexed_name()) {
auto indexed_name = std::dynamic_pointer_cast<ast::IndexedName>(identifier);
builder.CreateStore(rhs, codegen_indexed_name(*indexed_name));
} else {
throw std::runtime_error("Error: Unsupported variable type");
}
}

llvm::Value* CodegenLLVMVisitor::visit_logical_bin_op(llvm::Value* lhs,
llvm::Value* rhs,
unsigned op) {
const auto& bin_op = static_cast<ast::BinaryOp>(op);
return bin_op == ast::BinaryOp::BOP_AND ? builder.CreateAnd(lhs, rhs)
: builder.CreateOr(lhs, rhs);
}

llvm::Value* CodegenLLVMVisitor::visit_comparison_bin_op(llvm::Value* lhs,
llvm::Value* rhs,
unsigned op) {
const auto& bin_op = static_cast<ast::BinaryOp>(op);
llvm::Type* lhs_type = lhs->getType();
llvm::Value* result;

switch (bin_op) {
#define DISPATCH(binary_op, f_llvm_op, i_llvm_op) \
case binary_op: \
if (lhs_type->isDoubleTy() || lhs_type->isFloatTy()) \
result = f_llvm_op(lhs, rhs); \
else \
result = i_llvm_op(lhs, rhs); \
return result;

DISPATCH(ast::BinaryOp::BOP_EXACT_EQUAL, builder.CreateICmpEQ, builder.CreateFCmpOEQ);
DISPATCH(ast::BinaryOp::BOP_GREATER, builder.CreateICmpSGT, builder.CreateFCmpOGT);
DISPATCH(ast::BinaryOp::BOP_GREATER_EQUAL, builder.CreateICmpSGE, builder.CreateFCmpOGE);
DISPATCH(ast::BinaryOp::BOP_LESS, builder.CreateICmpSLT, builder.CreateFCmpOLT);
DISPATCH(ast::BinaryOp::BOP_LESS_EQUAL, builder.CreateICmpSLE, builder.CreateFCmpOLE);
DISPATCH(ast::BinaryOp::BOP_NOT_EQUAL, builder.CreateICmpNE, builder.CreateFCmpONE);

#undef DISPATCH

default:
return nullptr;
}
}

void CodegenLLVMVisitor::visit_procedure_or_function(const ast::Block& node) {
const auto& name = node.get_node_name();
const auto& parameters = node.get_parameters();
Expand Down Expand Up @@ -222,44 +312,39 @@ void CodegenLLVMVisitor::visit_binary_expression(const ast::BinaryExpression& no
llvm::Value* rhs = values.back();
values.pop_back();
if (op == ast::BinaryOp::BOP_ASSIGN) {
auto var = dynamic_cast<ast::VarName*>(node.get_lhs().get());
if (!var) {
throw std::runtime_error("Error: only VarName assignment is currently supported.\n");
}

const auto& identifier = var->get_name();
if (identifier->is_name()) {
llvm::Value* alloca = local_named_values->lookup(var->get_node_name());
builder.CreateStore(rhs, alloca);
} else if (identifier->is_indexed_name()) {
auto indexed_name = std::dynamic_pointer_cast<ast::IndexedName>(identifier);
builder.CreateStore(rhs, codegen_indexed_name(*indexed_name));
} else {
throw std::runtime_error("Error: Unsupported variable type");
}
visit_assign_op(node, rhs);
return;
}

node.get_lhs()->accept(*this);
llvm::Value* lhs = values.back();
values.pop_back();
llvm::Value* result;

// \todo: Support other binary operators
llvm::Value* result;
switch (op) {
#define DISPATCH(binary_op, llvm_op) \
case binary_op: \
result = llvm_op(lhs, rhs); \
values.push_back(result); \
case ast::BOP_ADDITION:
case ast::BOP_DIVISION:
case ast::BOP_MULTIPLICATION:
case ast::BOP_SUBTRACTION:
result = visit_arithmetic_bin_op(lhs, rhs, op);
break;

DISPATCH(ast::BinaryOp::BOP_ADDITION, builder.CreateFAdd);
DISPATCH(ast::BinaryOp::BOP_DIVISION, builder.CreateFDiv);
DISPATCH(ast::BinaryOp::BOP_MULTIPLICATION, builder.CreateFMul);
DISPATCH(ast::BinaryOp::BOP_SUBTRACTION, builder.CreateFSub);

#undef DISPATCH
case ast::BOP_AND:
case ast::BOP_OR:
result = visit_logical_bin_op(lhs, rhs, op);
break;
case ast::BOP_EXACT_EQUAL:
case ast::BOP_GREATER:
case ast::BOP_GREATER_EQUAL:
case ast::BOP_LESS:
case ast::BOP_LESS_EQUAL:
case ast::BOP_NOT_EQUAL:
result = visit_comparison_bin_op(lhs, rhs, op);
break;
default:
throw std::runtime_error("Error: binary operator is not supported\n");
}

values.push_back(result);
}

void CodegenLLVMVisitor::visit_boolean(const ast::Boolean& node) {
Expand All @@ -269,8 +354,7 @@ void CodegenLLVMVisitor::visit_boolean(const ast::Boolean& node) {
}

void CodegenLLVMVisitor::visit_double(const ast::Double& node) {
const auto& constant = llvm::ConstantFP::get(llvm::Type::getDoubleTy(*context),
node.get_value());
const auto& constant = llvm::ConstantFP::get(get_default_fp_type(), node.get_value());
values.push_back(constant);
}

Expand Down Expand Up @@ -310,10 +394,10 @@ void CodegenLLVMVisitor::visit_local_list_statement(const ast::LocalListStatemen
if (identifier->is_indexed_name()) {
auto indexed_name = std::dynamic_pointer_cast<ast::IndexedName>(identifier);
unsigned length = get_array_index_or_length(*indexed_name);
var_type = llvm::ArrayType::get(llvm::Type::getDoubleTy(*context), length);
var_type = llvm::ArrayType::get(get_default_fp_type(), length);
} else if (identifier->is_name()) {
// This case corresponds to a scalar local variable. Its type is double by default.
var_type = llvm::Type::getDoubleTy(*context);
var_type = get_default_fp_type();
} else {
throw std::runtime_error("Error: Unsupported local variable type");
}
Expand Down Expand Up @@ -367,10 +451,10 @@ void CodegenLLVMVisitor::visit_unary_expression(const ast::UnaryExpression& node
llvm::Value* value = values.back();
values.pop_back();
if (op == ast::UOP_NEGATION) {
llvm::Value* result = builder.CreateFNeg(value);
values.push_back(result);
values.push_back(builder.CreateFNeg(value));
} else if (op == ast::UOP_NOT) {
values.push_back(builder.CreateNot(value));
} else {
// Support only `double` operators for now.
throw std::runtime_error("Error: unsupported unary operator\n");
}
}
Expand Down
47 changes: 46 additions & 1 deletion src/codegen/llvm/codegen_llvm_visitor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,9 @@ class CodegenLLVMVisitor: public visitor::ConstAstVisitor {
// Run optimisation passes if true.
bool opt_passes;

// Use 32-bit floating-point type if true. Otherwise, use deafult 64-bit.
bool use_single_precision;

/**
*\brief Run LLVM optimisation passes on generated IR
*
Expand All @@ -93,10 +96,12 @@ class CodegenLLVMVisitor: public visitor::ConstAstVisitor {
*/
CodegenLLVMVisitor(const std::string& mod_filename,
const std::string& output_dir,
bool opt_passes)
bool opt_passes,
bool use_single_precision = false)
: mod_filename(mod_filename)
, output_dir(output_dir)
, opt_passes(opt_passes)
, use_single_precision(use_single_precision)
, builder(*context)
, fpm(module.get()) {}

Expand Down Expand Up @@ -129,6 +134,12 @@ class CodegenLLVMVisitor: public visitor::ConstAstVisitor {
*/
unsigned get_array_index_or_length(const ast::IndexedName& node);

/**
* Returns 64-bit or 32-bit LLVM floating type
* \return \c LLVM floating point type according to `use_single_precision` flag
*/
llvm::Type* get_default_fp_type();

/**
* Create a function call to an external method
* \param name external method name
Expand Down Expand Up @@ -162,6 +173,40 @@ class CodegenLLVMVisitor: public visitor::ConstAstVisitor {
return std::move(module);
}

/**
* Visit nmodl arithmetic binary operator
* \param lhs LLVM value of evaluated lhs expression
* \param rhs LLVM value of evaluated rhs expression
* \param op the AST binary operator (ADD, DIV, MUL, SUB)
* \return LLVM IR value result
*/
llvm::Value* visit_arithmetic_bin_op(llvm::Value* lhs, llvm::Value* rhs, unsigned op);

/**
* Visit nmodl assignment operator (ASSIGN)
* \param node the AST node representing the binary expression in NMODL
* \param rhs LLVM value of evaluated rhs expression
*/
void visit_assign_op(const ast::BinaryExpression& node, llvm::Value* rhs);

/**
* Visit nmodl logical binary operator
* \param lhs LLVM value of evaluated lhs expression
* \param rhs LLVM value of evaluated rhs expression
* \param op the AST binary operator (AND, OR)
* \return LLVM IR value result
*/
llvm::Value* visit_logical_bin_op(llvm::Value* lhs, llvm::Value* rhs, unsigned op);

/**
* Visit nmodl comparison binary operator
* \param lhs LLVM value of evaluated lhs expression
* \param rhs LLVM value of evaluated rhs expression
* \param op the AST binary operator (EXACT_EQUAL, GREATER, GREATER_EQUAL, LESS, LESS_EQUAL,
* NOT_EQUAL) \return LLVM IR value result
*/
llvm::Value* visit_comparison_bin_op(llvm::Value* lhs, llvm::Value* rhs, unsigned op);

/**
* Visit nmodl function or procedure
* \param node the AST node representing the function or procedure in NMODL
Expand Down
8 changes: 7 additions & 1 deletion src/nmodl/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,9 @@ int main(int argc, const char* argv[]) {
/// generate llvm IR
bool llvm_ir(false);

/// use single precision floating-point types
bool llvm_float_type(false);

/// run llvm optimisation passes
bool llvm_opt_passes(false);
#endif
Expand Down Expand Up @@ -287,6 +290,9 @@ int main(int argc, const char* argv[]) {
llvm_opt->add_flag("--opt",
llvm_opt_passes,
"Run LLVM optimisation passes ({})"_format(llvm_opt_passes))->ignore_case();
llvm_opt->add_flag("--single-precision",
llvm_float_type,
"Use single precision floating-point types ({})"_format(llvm_float_type))->ignore_case();
#endif
// clang-format on

Expand Down Expand Up @@ -573,7 +579,7 @@ int main(int argc, const char* argv[]) {
#ifdef NMODL_LLVM_BACKEND
if (llvm_ir) {
logger->info("Running LLVM backend code generator");
CodegenLLVMVisitor visitor(modfile, output_dir, llvm_opt_passes);
CodegenLLVMVisitor visitor(modfile, output_dir, llvm_opt_passes, llvm_float_type);
visitor.visit_program(*ast);
ast_to_nmodl(*ast, filepath("llvm"));
}
Expand Down
2 changes: 1 addition & 1 deletion test/unit/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ target_link_libraries(

if(NMODL_ENABLE_LLVM)
include_directories(${LLVM_INCLUDE_DIRS})
add_executable(testllvm visitor/main.cpp codegen/llvm.cpp)
add_executable(testllvm visitor/main.cpp codegen/codegen_llvm_ir.cpp)
add_executable(test_llvm_runner visitor/main.cpp codegen/codegen_llvm_execution.cpp)
target_link_libraries(
testllvm
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,18 @@ using nmodl::parser::NmodlDriver;
// Utility to get LLVM module as a string
//=============================================================================

std::string run_llvm_visitor(const std::string& text, bool opt = false) {
std::string run_llvm_visitor(const std::string& text,
bool opt = false,
bool use_single_precision = false) {
NmodlDriver driver;
const auto& ast = driver.parse_string(text);

SymtabVisitor().visit_program(*ast);

codegen::CodegenLLVMVisitor llvm_visitor("unknown", ".", opt);
codegen::CodegenLLVMVisitor llvm_visitor(/*mod_filename=*/"unknown",
/*output_dir=*/".",
opt,
use_single_precision);
llvm_visitor.visit_program(*ast);
return llvm_visitor.print_module();
}
Expand All @@ -47,14 +52,15 @@ SCENARIO("Binary expression", "[visitor][llvm]") {
)";

THEN("variables are loaded and add instruction is created") {
std::string module_string = run_llvm_visitor(nmodl_text);
std::string module_string =
run_llvm_visitor(nmodl_text, /*opt=*/false, /*use_single_precision=*/true);
std::smatch m;

std::regex rhs(R"(%1 = load double, double\* %b)");
std::regex lhs(R"(%2 = load double, double\* %a)");
std::regex res(R"(%3 = fadd double %2, %1)");
std::regex rhs(R"(%1 = load float, float\* %b)");
std::regex lhs(R"(%2 = load float, float\* %a)");
std::regex res(R"(%3 = fadd float %2, %1)");

// Check the values are loaded correctly and added
// Check the float values are loaded correctly and added
REQUIRE(std::regex_search(module_string, m, rhs));
REQUIRE(std::regex_search(module_string, m, lhs));
REQUIRE(std::regex_search(module_string, m, res));
Expand Down