diff --git a/src/codegen/llvm/codegen_llvm_helper_visitor.cpp b/src/codegen/llvm/codegen_llvm_helper_visitor.cpp index b3f75b9372..c34ae2c873 100644 --- a/src/codegen/llvm/codegen_llvm_helper_visitor.cpp +++ b/src/codegen/llvm/codegen_llvm_helper_visitor.cpp @@ -146,7 +146,7 @@ void CodegenLLVMHelperVisitor::create_function_for_node(ast::Block& node) { /// create new type and name for creating new ast node auto type = new ast::CodegenVarType(FLOAT_TYPE); auto var = param->get_name()->clone(); - arguments.emplace_back(new ast::CodegenVarWithType(type, 0, var)); + arguments.emplace_back(new ast::CodegenVarWithType(type, /*is_pointer=*/0, var)); } /// return type of the function is same as return variable type @@ -170,31 +170,31 @@ std::shared_ptr CodegenLLVMHelperVisitor::create_instance_s }; /// float variables are standard pointers to float vectors - for (auto& float_var: info.codegen_float_variables) { - add_var_with_type(float_var->get_name(), FLOAT_TYPE, 1); + for (const auto& float_var: info.codegen_float_variables) { + add_var_with_type(float_var->get_name(), FLOAT_TYPE, /*is_pointer=*/1); } /// int variables are pointers to indexes for other vectors - for (auto& int_var: info.codegen_int_variables) { - add_var_with_type(int_var.symbol->get_name(), FLOAT_TYPE, 1); + for (const auto& int_var: info.codegen_int_variables) { + add_var_with_type(int_var.symbol->get_name(), FLOAT_TYPE, /*is_pointer=*/1); } // for integer variables, there should be index - for (auto& int_var: info.codegen_int_variables) { + for (const auto& int_var: info.codegen_int_variables) { std::string var_name = int_var.symbol->get_name() + "_index"; - add_var_with_type(var_name, INTEGER_TYPE, 1); + add_var_with_type(var_name, INTEGER_TYPE, /*is_pointer=*/1); } // add voltage and node index - add_var_with_type("voltage", FLOAT_TYPE, 1); - add_var_with_type("node_index", INTEGER_TYPE, 1); + add_var_with_type("voltage", FLOAT_TYPE, /*is_pointer=*/1); + add_var_with_type("node_index", INTEGER_TYPE, /*is_pointer=*/1); // add dt, t, celsius - add_var_with_type(naming::NTHREAD_T_VARIABLE, FLOAT_TYPE, 0); - add_var_with_type(naming::NTHREAD_DT_VARIABLE, FLOAT_TYPE, 0); - add_var_with_type(naming::CELSIUS_VARIABLE, FLOAT_TYPE, 0); - add_var_with_type(naming::SECOND_ORDER_VARIABLE, INTEGER_TYPE, 0); - add_var_with_type(MECH_NODECOUNT_VAR, INTEGER_TYPE, 0); + add_var_with_type(naming::NTHREAD_T_VARIABLE, FLOAT_TYPE, /*is_pointer=*/0); + add_var_with_type(naming::NTHREAD_DT_VARIABLE, FLOAT_TYPE, /*is_pointer=*/0); + add_var_with_type(naming::CELSIUS_VARIABLE, FLOAT_TYPE, /*is_pointer=*/0); + add_var_with_type(naming::SECOND_ORDER_VARIABLE, INTEGER_TYPE, /*is_pointer=*/0); + add_var_with_type(MECH_NODECOUNT_VAR, INTEGER_TYPE, /*is_pointer=*/0); return std::make_shared(codegen_vars); } @@ -384,7 +384,7 @@ void CodegenLLVMHelperVisitor::convert_to_instance_variable(ast::Node& node, std::string& index_var) { /// collect all variables in the node of type ast::VarName auto variables = collect_nodes(node, {ast::AstNodeType::VAR_NAME}); - for (auto& v: variables) { + for (const auto& v: variables) { auto variable = std::dynamic_pointer_cast(v); auto variable_name = variable->get_node_name(); @@ -450,6 +450,44 @@ void CodegenLLVMHelperVisitor::visit_function_block(ast::FunctionBlock& node) { create_function_for_node(node); } +/// Create asr::Varname node with given a given variable name +static ast::VarName* create_varname(const std::string& varname) { + return new ast::VarName(new ast::Name(new ast::String(varname)), nullptr, nullptr); +} + +/** + * Create for loop initialization expression + * @param code Usually "id = 0" as a string + * @return Expression representing code + * \todo : we can not use `create_statement_as_expression` function because + * NMODL parser is using `ast::Double` type to represent all variables + * including Integer. See #542. + */ +static std::shared_ptr loop_initialization_expression( + const std::string& induction_var) { + // create id = 0 + const auto& id = create_varname(induction_var); + const auto& zero = new ast::Integer(0, nullptr); + return std::make_shared(id, ast::BinaryOperator(ast::BOP_ASSIGN), zero); +} + +/** + * Create loop increment expression `id = id + width` + * \todo : same as loop_initialization_expression() + */ +static std::shared_ptr loop_increment_expression(const std::string& induction_var, + int vector_width) { + // first create id + x + const auto& id = create_varname(induction_var); + const auto& inc = new ast::Integer(vector_width, nullptr); + const auto& inc_expr = + new ast::BinaryExpression(id, ast::BinaryOperator(ast::BOP_ADDITION), inc); + // now create id = id + x + return std::make_shared(id->clone(), + ast::BinaryOperator(ast::BOP_ASSIGN), + inc_expr); +} + /** * \brief Convert ast::NrnStateBlock to corresponding code generation function nrn_state * @param node AST node representing ast::NrnStateBlock @@ -471,9 +509,9 @@ void CodegenLLVMHelperVisitor::visit_nrn_state_block(ast::NrnStateBlock& node) { /// create now main compute part : for loop over channel instances /// loop constructs : initialization, condition and increment - const auto& initialization = create_statement_as_expression("id = 0"); - const auto& condition = create_expression("id < node_count"); - const auto& increment = create_statement_as_expression("id = id + {}"_format(vector_width)); + const auto& initialization = loop_initialization_expression(INDUCTION_VAR); + const auto& condition = create_expression("{} < {}"_format(INDUCTION_VAR, MECH_NODECOUNT_VAR)); + const auto& increment = loop_increment_expression(INDUCTION_VAR, vector_width); /// loop body : initialization + solve blocks ast::StatementVector loop_def_statements; @@ -484,7 +522,8 @@ void CodegenLLVMHelperVisitor::visit_nrn_state_block(ast::NrnStateBlock& node) { std::vector double_variables{"v"}; /// access node index and corresponding voltage - loop_index_statements.push_back(visitor::create_statement("node_id = node_index[id]")); + loop_index_statements.push_back( + visitor::create_statement("node_id = node_index[{}]"_format(INDUCTION_VAR))); loop_body_statements.push_back(visitor::create_statement("v = voltage[node_id]")); /// read ion variables @@ -558,7 +597,7 @@ void CodegenLLVMHelperVisitor::visit_nrn_state_block(ast::NrnStateBlock& node) { ast::CodegenVarWithTypeVector code_arguments; auto instance_var_type = new ast::CodegenVarType(ast::AstNodeType::INSTANCE_STRUCT); - auto instance_var_name = new ast::Name(new ast::String("mech")); + auto instance_var_name = new ast::Name(new ast::String(MECH_INSTANCE_VAR)); auto instance_var = new ast::CodegenVarWithType(instance_var_type, 1, instance_var_name); code_arguments.emplace_back(instance_var); @@ -567,7 +606,7 @@ void CodegenLLVMHelperVisitor::visit_nrn_state_block(ast::NrnStateBlock& node) { std::make_shared(return_type, name, code_arguments, function_block); codegen_functions.push_back(function); - std::cout << nmodl::to_nmodl(function); + std::cout << nmodl::to_nmodl(function) << std::endl; } void CodegenLLVMHelperVisitor::visit_program(ast::Program& node) { @@ -583,8 +622,6 @@ void CodegenLLVMHelperVisitor::visit_program(ast::Program& node) { for (auto& fun: codegen_functions) { node.emplace_back_node(fun); } - - std::cout << nmodl::to_nmodl(node); } diff --git a/src/codegen/llvm/codegen_llvm_helper_visitor.hpp b/src/codegen/llvm/codegen_llvm_helper_visitor.hpp index 981372b4d5..b67aa7ee09 100644 --- a/src/codegen/llvm/codegen_llvm_helper_visitor.hpp +++ b/src/codegen/llvm/codegen_llvm_helper_visitor.hpp @@ -120,6 +120,9 @@ class CodegenLLVMHelperVisitor: public visitor::AstVisitor { const std::string MECH_INSTANCE_VAR = "mech"; const std::string MECH_NODECOUNT_VAR = "node_count"; + /// name of induction variable used in the kernel. + const std::string INDUCTION_VAR = "id"; + /// create new function for FUNCTION or PROCEDURE block void create_function_for_node(ast::Block& node); @@ -134,6 +137,10 @@ class CodegenLLVMHelperVisitor: public visitor::AstVisitor { return instance_var_helper; } + std::string get_kernel_id() { + return INDUCTION_VAR; + } + /// run visitor and return code generation functions CodegenFunctionVector get_codegen_functions(const ast::Program& node); diff --git a/src/codegen/llvm/codegen_llvm_visitor.cpp b/src/codegen/llvm/codegen_llvm_visitor.cpp index 80bdfd20e3..62e69449b7 100644 --- a/src/codegen/llvm/codegen_llvm_visitor.cpp +++ b/src/codegen/llvm/codegen_llvm_visitor.cpp @@ -21,14 +21,22 @@ namespace nmodl { namespace codegen { +static constexpr const char instance_struct_type_name[] = "__instance_var__type"; + +// The prefix is used to create a vectorised id that can be used as index to GEPs. However, for +// simple aligned vector loads and stores vector id is not needed. This is because we can bitcast +// the pointer to the vector pointer! \todo: Consider removing this. +static constexpr const char kernel_id_prefix[] = "__vec_"; + + /****************************************************************************************/ /* Helper routines */ /****************************************************************************************/ static bool is_supported_statement(const ast::Statement& statement) { return statement.is_codegen_var_list_statement() || statement.is_expression_statement() || - statement.is_codegen_return_statement() || statement.is_if_statement() || - statement.is_while_statement(); + statement.is_codegen_for_statement() || statement.is_codegen_return_statement() || + statement.is_if_statement() || statement.is_while_statement(); } bool CodegenLLVMVisitor::check_array_bounds(const ast::IndexedName& node, unsigned index) { @@ -56,10 +64,82 @@ llvm::Value* CodegenLLVMVisitor::codegen_indexed_name(const ast::IndexedName& no return create_gep(node.get_node_name(), index); } +llvm::Value* CodegenLLVMVisitor::codegen_instance_var(const ast::CodegenInstanceVar& node) { + const auto& member_node = node.get_member_var(); + const auto& instance_name = node.get_instance_var()->get_node_name(); + const auto& member_name = member_node->get_node_name(); + + if (!instance_var_helper.is_an_instance_variable(member_name)) + throw std::runtime_error("Error: " + member_name + " is not a member of the instance!"); + + // Load the instance struct given its name from the ValueSymbolTable. + llvm::Value* instance_ptr = builder.CreateLoad(lookup(instance_name)); + + // Create a GEP instruction to get a pointer to the member. + int member_index = instance_var_helper.get_variable_index(member_name); + llvm::Type* index_type = llvm::Type::getInt32Ty(*context); + + std::vector indices; + indices.push_back(llvm::ConstantInt::get(index_type, 0)); + indices.push_back(llvm::ConstantInt::get(index_type, member_index)); + llvm::Value* member_ptr = builder.CreateInBoundsGEP(instance_ptr, indices); + + // Get the member AST node from the instance AST node, for which we proceed with the code + // generation. If the member is scalar, return the pointer to it straight away. + auto codegen_var_with_type = instance_var_helper.get_variable(member_name); + if (!codegen_var_with_type->get_is_pointer()) { + return member_ptr; + } + + // Otherwise, the codegen variable is a pointer, and the member AST node must be an IndexedName. + auto member_var_name = std::dynamic_pointer_cast(member_node); + if (!member_var_name->get_name()->is_indexed_name()) + throw std::runtime_error("Error: " + member_name + " is not an IndexedName!"); + + // Proceed to creating a GEP instruction to get the pointer to the member's element. While LLVM + // Helper set the indices to be Name nodes, a sanity check is added here. Note that this step + // can be avoided if using `get_array_index_or_length()`. However, it does not support indexing + // with Name/Expression at the moment. \todo: Reuse `get_array_index_or_length()` here. + auto member_indexed_name = std::dynamic_pointer_cast( + member_var_name->get_name()); + if (!member_indexed_name->get_length()->is_name()) + throw std::runtime_error("Error: " + member_name + " has a non-Name index!"); + + // Load the index variable that will be used to access the member's element. Since we index a + // pointer variable, we need to extend the 32-bit integer index variable to 64-bit. + llvm::Value* i32_index = builder.CreateLoad( + lookup(member_indexed_name->get_length()->get_node_name())); + llvm::Value* i64_index = builder.CreateSExt(i32_index, llvm::Type::getInt64Ty(*context)); + + // Create a indices vector for GEP to return the pointer to the element at the specified index. + std::vector member_indices; + member_indices.push_back(i64_index); + + // The codegen variable type is always a scalar, so we need to transform it to a pointer. Then + // load the member which would be indexed later. + llvm::Type* type = get_codegen_var_type(*codegen_var_with_type->get_type()); + llvm::Value* instance_member = + builder.CreateLoad(llvm::PointerType::get(type, /*AddressSpace=*/0), member_ptr); + + + // If the code is vectorised, then bitcast to a vector pointer. + if (is_kernel_code && vector_width > 1) { + llvm::Type* vector_type = + llvm::PointerType::get(llvm::FixedVectorType::get(type, vector_width), + /*AddressSpace=*/0); + llvm::Value* instance_member_bitcasted = builder.CreateBitCast(instance_member, + vector_type); + return builder.CreateInBoundsGEP(instance_member_bitcasted, member_indices); + } + + return builder.CreateInBoundsGEP(instance_member, member_indices); +} + unsigned CodegenLLVMVisitor::get_array_index_or_length(const ast::IndexedName& indexed_name) { + // \todo: Support indices with expressions and names: k[i + j] = ... auto integer = std::dynamic_pointer_cast(indexed_name.get_length()); if (!integer) - throw std::runtime_error("Error: expecting integer index or length"); + throw std::runtime_error("Error: only integer indices/length are supported!"); // Check if integer value is taken from a macro. if (!integer->get_macro()) @@ -74,6 +154,8 @@ llvm::Type* CodegenLLVMVisitor::get_codegen_var_type(const ast::CodegenVarType& return llvm::Type::getInt1Ty(*context); case ast::AstNodeType::DOUBLE: return get_default_fp_type(); + case ast::AstNodeType::INSTANCE_STRUCT: + return get_instance_struct_type(); case ast::AstNodeType::INTEGER: return llvm::Type::getInt32Ty(*context); case ast::AstNodeType::VOID: @@ -85,6 +167,26 @@ llvm::Type* CodegenLLVMVisitor::get_codegen_var_type(const ast::CodegenVarType& } } +llvm::Value* CodegenLLVMVisitor::get_constant_int_vector(int value) { + llvm::Type* i32_type = llvm::Type::getInt32Ty(*context); + std::vector constants; + for (unsigned i = 0; i < vector_width; ++i) { + const auto& element = llvm::ConstantInt::get(i32_type, value); + constants.push_back(element); + } + return llvm::ConstantVector::get(constants); +} + +llvm::Value* CodegenLLVMVisitor::get_constant_fp_vector(const std::string& value) { + llvm::Type* fp_type = get_default_fp_type(); + std::vector constants; + for (unsigned i = 0; i < vector_width; ++i) { + const auto& element = llvm::ConstantFP::get(fp_type, value); + constants.push_back(element); + } + return llvm::ConstantVector::get(constants); +} + llvm::Type* CodegenLLVMVisitor::get_default_fp_type() { if (use_single_precision) return llvm::Type::getFloatTy(*context); @@ -97,6 +199,59 @@ llvm::Type* CodegenLLVMVisitor::get_default_fp_ptr_type() { return llvm::Type::getDoublePtrTy(*context); } +llvm::Type* CodegenLLVMVisitor::get_instance_struct_type() { + std::vector members; + for (const auto& variable: instance_var_helper.instance->get_codegen_vars()) { + auto is_pointer = variable->get_is_pointer(); + auto nmodl_type = variable->get_type()->get_type(); + + llvm::Type* i32_type = llvm::Type::getInt32Ty(*context); + llvm::Type* i32ptr_type = llvm::Type::getInt32PtrTy(*context); + + switch (nmodl_type) { +#define DISPATCH(type, llvm_ptr_type, llvm_type) \ + case type: \ + members.push_back(is_pointer ? (llvm_ptr_type) : (llvm_type)); \ + break; + + DISPATCH(ast::AstNodeType::DOUBLE, get_default_fp_ptr_type(), get_default_fp_type()); + DISPATCH(ast::AstNodeType::INTEGER, i32ptr_type, i32_type); + +#undef DISPATCH + default: + throw std::runtime_error("Error: unsupported type found in instance struct"); + } + } + + llvm::StructType* llvm_struct_type = + llvm::StructType::create(*context, mod_filename + instance_struct_type_name); + llvm_struct_type->setBody(members); + return llvm::PointerType::get(llvm_struct_type, /*AddressSpace=*/0); +} + +llvm::Value* CodegenLLVMVisitor::get_variable_ptr(const ast::VarName& node) { + const auto& identifier = node.get_name(); + if (!identifier->is_name() && !identifier->is_indexed_name() && + !identifier->is_codegen_instance_var()) { + throw std::runtime_error("Error: Unsupported variable type - " + node.get_node_name()); + } + + llvm::Value* ptr; + if (identifier->is_name()) + ptr = lookup(node.get_node_name()); + + if (identifier->is_indexed_name()) { + auto indexed_name = std::dynamic_pointer_cast(identifier); + ptr = codegen_indexed_name(*indexed_name); + } + + if (identifier->is_codegen_instance_var()) { + auto instance_var = std::dynamic_pointer_cast(identifier); + ptr = codegen_instance_var(*instance_var); + } + return ptr; +} + void CodegenLLVMVisitor::run_llvm_opt_passes() { /// run some common optimisation passes that are commonly suggested fpm.add(llvm::createInstructionCombiningPass()); @@ -134,7 +289,7 @@ void CodegenLLVMVisitor::create_external_method_call(const std::string& name, } #define DISPATCH(method_name, intrinsic) \ - if (name == method_name) { \ + if (name == (method_name)) { \ llvm::Value* result = builder.CreateIntrinsic(intrinsic, argument_types, argument_values); \ values.push_back(result); \ return; \ @@ -234,12 +389,12 @@ llvm::Value* CodegenLLVMVisitor::visit_arithmetic_bin_op(llvm::Value* lhs, llvm::Value* result; switch (bin_op) { -#define DISPATCH(binary_op, llvm_fp_op, llvm_int_op) \ - case binary_op: \ - if (lhs_type->isDoubleTy() || lhs_type->isFloatTy()) \ - result = llvm_fp_op(lhs, rhs); \ - else \ - result = llvm_int_op(lhs, rhs); \ +#define DISPATCH(binary_op, llvm_fp_op, llvm_int_op) \ + case binary_op: \ + if (lhs_type->isIntOrIntVectorTy()) \ + result = llvm_int_op(lhs, rhs); \ + else \ + result = llvm_fp_op(lhs, rhs); \ return result; DISPATCH(ast::BinaryOp::BOP_ADDITION, builder.CreateFAdd, builder.CreateAdd); @@ -256,20 +411,11 @@ llvm::Value* CodegenLLVMVisitor::visit_arithmetic_bin_op(llvm::Value* lhs, void CodegenLLVMVisitor::visit_assign_op(const ast::BinaryExpression& node, llvm::Value* rhs) { auto var = dynamic_cast(node.get_lhs().get()); - if (!var) { - throw std::runtime_error("Error: only VarName assignment is currently supported.\n"); - } + if (!var) + throw std::runtime_error("Error: only VarName assignment is supported!"); - const auto& identifier = var->get_name(); - if (identifier->is_name()) { - llvm::Value* alloca = lookup(var->get_node_name()); - builder.CreateStore(rhs, alloca); - } else if (identifier->is_indexed_name()) { - auto indexed_name = std::dynamic_pointer_cast(identifier); - builder.CreateStore(rhs, codegen_indexed_name(*indexed_name)); - } else { - throw std::runtime_error("Error: Unsupported variable type"); - } + llvm::Value* ptr = get_variable_ptr(*var); + builder.CreateStore(rhs, ptr); } llvm::Value* CodegenLLVMVisitor::visit_logical_bin_op(llvm::Value* lhs, @@ -373,6 +519,117 @@ void CodegenLLVMVisitor::visit_boolean(const ast::Boolean& node) { values.push_back(constant); } +// Generating FOR loop in LLVM IR creates the following structure: +// +// +---------------------------+ +// | | +// | | +// | br %cond | +// +---------------------------+ +// | +// V +// +-----------------------------+ +// | | +// | %cond = ... |<------+ +// | cond_br %cond, %body, %exit | | +// +-----------------------------+ | +// | | | +// | V | +// | +------------------------+ | +// | | | | +// | | br %inc | | +// | +------------------------+ | +// | | | +// | V | +// | +------------------------+ | +// | | | | +// | | br %cond | | +// | +------------------------+ | +// | | | +// | +---------------+ +// V +// +---------------------------+ +// | | +// +---------------------------+ +void CodegenLLVMVisitor::visit_codegen_for_statement(const ast::CodegenForStatement& node) { + // Get the current and the next blocks within the function. + llvm::BasicBlock* curr_block = builder.GetInsertBlock(); + llvm::BasicBlock* next = curr_block->getNextNode(); + llvm::Function* func = curr_block->getParent(); + + // Create the basic blocks for FOR loop. + llvm::BasicBlock* for_cond = + llvm::BasicBlock::Create(*context, /*Name=*/"for.cond", func, next); + llvm::BasicBlock* for_body = + llvm::BasicBlock::Create(*context, /*Name=*/"for.body", func, next); + llvm::BasicBlock* for_inc = llvm::BasicBlock::Create(*context, /*Name=*/"for.inc", func, next); + llvm::BasicBlock* exit = llvm::BasicBlock::Create(*context, /*Name=*/"for.exit", func, next); + + // First, initialise the loop in the same basic block. + node.get_initialization()->accept(*this); + + // If the loop is to be vectorised, create a separate vector induction variable. + // \todo: See the comment for `kernel_id_prefix`. + if (vector_width > 1) { + // First, create a vector type and alloca for it. + llvm::Type* i32_type = llvm::Type::getInt32Ty(*context); + llvm::Type* vec_type = llvm::FixedVectorType::get(i32_type, vector_width); + llvm::Value* vec_alloca = builder.CreateAlloca(vec_type, + /*ArraySize=*/nullptr, + /*Name=*/kernel_id_prefix + kernel_id); + + // Then, store the initial value of <0, 1, ..., [W-1]> o the alloca pointer, where W is the + // vector width. + std::vector constants; + for (unsigned i = 0; i < vector_width; ++i) { + const auto& element = llvm::ConstantInt::get(i32_type, i); + constants.push_back(element); + } + llvm::Value* vector_id = llvm::ConstantVector::get(constants); + builder.CreateStore(vector_id, vec_alloca); + } + // Branch to condition basic block and insert condition code there. + builder.CreateBr(for_cond); + builder.SetInsertPoint(for_cond); + node.get_condition()->accept(*this); + + // Extract the condition to decide whether to branch to the loop body or loop exit. + llvm::Value* cond = values.back(); + values.pop_back(); + builder.CreateCondBr(cond, for_body, exit); + + // Generate code for the loop body and create the basic block for the increment. + builder.SetInsertPoint(for_body); + is_kernel_code = true; + const auto& statement_block = node.get_statement_block(); + statement_block->accept(*this); + is_kernel_code = false; + builder.CreateBr(for_inc); + + // Process increment. + builder.SetInsertPoint(for_inc); + node.get_increment()->accept(*this); + + // If the code is vectorised, then increment the vector id by where W is the + // vector width. + // \todo: See the comment for `kernel_id_prefix`. + if (vector_width > 1) { + // First, create an increment vector. + llvm::Value* vector_inc = get_constant_int_vector(vector_width); + + // Increment the kernel id elements by a constant vector width. + llvm::Value* vector_id_ptr = lookup(kernel_id_prefix + kernel_id); + llvm::Value* vector_id = builder.CreateLoad(vector_id_ptr); + llvm::Value* incremented = builder.CreateAdd(vector_id, vector_inc); + builder.CreateStore(incremented, vector_id_ptr); + } + + // Create a branch to condition block, then generate exit code out of the loop. + builder.CreateBr(for_cond); + builder.SetInsertPoint(exit); +} + + void CodegenLLVMVisitor::visit_codegen_function(const ast::CodegenFunction& node) { const auto& name = node.get_node_name(); const auto& arguments = node.get_arguments(); @@ -406,7 +663,7 @@ void CodegenLLVMVisitor::visit_codegen_function(const ast::CodegenFunction& node block->accept(*this); // If function has a void return type, add a terminator not handled by CodegenReturnVar. - if (node.is_void()) + if (node.get_return_type()->get_type() == ast::AstNodeType::VOID) builder.CreateRetVoid(); // Clear local values stack and remove the pointer to the local symbol table. @@ -419,7 +676,7 @@ void CodegenLLVMVisitor::visit_codegen_return_statement(const ast::CodegenReturn throw std::runtime_error("Error: CodegenReturnStatement must contain a name node\n"); std::string ret = "ret_" + current_func->getName().str(); - llvm::Value* ret_value = builder.CreateLoad(current_func->getValueSymbolTable()->lookup(ret)); + llvm::Value* ret_value = builder.CreateLoad(lookup(ret)); builder.CreateRet(ret_value); } @@ -456,6 +713,10 @@ void CodegenLLVMVisitor::visit_codegen_var_list_statement( } void CodegenLLVMVisitor::visit_double(const ast::Double& node) { + if (is_kernel_code && vector_width > 1) { + values.push_back(get_constant_fp_vector(node.get_value())); + return; + } const auto& constant = llvm::ConstantFP::get(get_default_fp_type(), node.get_value()); values.push_back(constant); } @@ -547,6 +808,10 @@ void CodegenLLVMVisitor::visit_if_statement(const ast::IfStatement& node) { } void CodegenLLVMVisitor::visit_integer(const ast::Integer& node) { + if (is_kernel_code && vector_width > 1) { + values.push_back(get_constant_int_vector(node.get_value())); + return; + } const auto& constant = llvm::ConstantInt::get(llvm::Type::getInt32Ty(*context), node.get_value()); values.push_back(constant); @@ -561,9 +826,7 @@ void CodegenLLVMVisitor::visit_program(const ast::Program& node) { const auto& functions = v.get_codegen_functions(node); instance_var_helper = v.get_instance_var_helper(); - // TODO :: George / Ioannis :: before emitting procedures, we have - // to emmit INSTANCE_STRUCT type as it's used as an argument. - // Currently it's done in node.visit_children which is late. + kernel_id = v.get_kernel_id(); // For every function, generate its declaration. Thus, we can look up // `llvm::Function` in the symbol table in the module. @@ -574,8 +837,15 @@ void CodegenLLVMVisitor::visit_program(const ast::Program& node) { // Set the AST symbol table. sym_tab = node.get_symbol_table(); - // Proceed with code generation. - node.visit_children(*this); + // Proceed with code generation. Right now, we do not do + // node.visit_children(*this); + // The reason is that the node may contain AST nodes for which the visitor functions have been + // defined. In our implementation we assume that the code generation is happening within the + // function scope. To avoid generating code outside of functions, visit only them for now. + // \todo: Handle what is mentioned here. + for (const auto& func: functions) { + visit_codegen_function(*func); + } if (opt_passes) { logger->info("Running LLVM optimisation passes"); @@ -605,60 +875,21 @@ void CodegenLLVMVisitor::visit_unary_expression(const ast::UnaryExpression& node } void CodegenLLVMVisitor::visit_var_name(const ast::VarName& node) { - const auto& identifier = node.get_name(); - if (!identifier->is_name() && !identifier->is_indexed_name()) - throw std::runtime_error("Error: Unsupported variable type"); - - // TODO :: George :: here instance_var_helper can be used to query - // variable type and it's index into structure - auto name = node.get_node_name(); - - auto codegen_var_with_type = instance_var_helper.get_variable(name); - auto codegen_var_index = instance_var_helper.get_variable_index(name); - // this will be INTEGER or DOUBLE - auto var_type = codegen_var_with_type->get_type()->get_type(); - auto is_pointer = codegen_var_with_type->get_is_pointer(); - - llvm::Value* ptr; - if (identifier->is_name()) - ptr = lookup(node.get_node_name()); - - if (identifier->is_indexed_name()) { - auto indexed_name = std::dynamic_pointer_cast(identifier); - ptr = codegen_indexed_name(*indexed_name); - } + llvm::Value* ptr = get_variable_ptr(node); // Finally, load the variable from the pointer value. llvm::Value* var = builder.CreateLoad(ptr); - values.push_back(var); -} -void CodegenLLVMVisitor::visit_instance_struct(const ast::InstanceStruct& node) { - std::vector members; - for (const auto& variable: node.get_codegen_vars()) { - // TODO :: Ioannis / George :: we have now double*, int*, double and int - // variables in the instance structure. Each variable is of type - // ast::CodegenVarWithType. So we can query variable type and if - // it's pointer. - auto is_pointer = variable->get_is_pointer(); - auto type = variable->get_type()->get_type(); - - // todo : clean up ? - if (type == ast::AstNodeType::DOUBLE) { - auto llvm_type = is_pointer ? get_default_fp_ptr_type() : get_default_fp_type(); - members.push_back(llvm_type); - } else { - if (is_pointer) { - members.push_back(llvm::Type::getInt32PtrTy(*context)); - } else { - members.push_back(llvm::Type::getInt32Ty(*context)); - } - } + // If the vale should not be vectorised, or it is already a vector, add it to the stack. + if (!is_kernel_code || vector_width <= 1 || var->getType()->isVectorTy()) { + values.push_back(var); + return; } - llvm_struct = llvm::StructType::create(*context, mod_filename + "_Instance"); - llvm_struct->setBody(members); - module->getOrInsertGlobal("inst", llvm_struct); + // Otherwise, if we are generating vectorised inside the loop, replicate the value to form a + // vector of `vector_width`. + llvm::Value* vector_var = builder.CreateVectorSplat(vector_width, var); + values.push_back(vector_var); } void CodegenLLVMVisitor::visit_while_statement(const ast::WhileStatement& node) { diff --git a/src/codegen/llvm/codegen_llvm_visitor.hpp b/src/codegen/llvm/codegen_llvm_visitor.hpp index b20a19bac7..c93b76b1d6 100644 --- a/src/codegen/llvm/codegen_llvm_visitor.hpp +++ b/src/codegen/llvm/codegen_llvm_visitor.hpp @@ -82,11 +82,14 @@ class CodegenLLVMVisitor: public visitor::ConstAstVisitor { // Use 32-bit floating-point type if true. Otherwise, use deafult 64-bit. bool use_single_precision; - // explicit vectorisation width + // Explicit vectorisation width. int vector_width; - // LLVM mechanism struct - llvm::StructType* llvm_struct; + // The name of induction variable used in the kernel functions. + std::string kernel_id; + + // A flag to indicate that the code is generated for the kernel. + bool is_kernel_code = false; /** *\brief Run LLVM optimisation passes on generated IR @@ -106,8 +109,8 @@ class CodegenLLVMVisitor: public visitor::ConstAstVisitor { CodegenLLVMVisitor(const std::string& mod_filename, const std::string& output_dir, bool opt_passes, - int vector_width = 1, - bool use_single_precision = false) + bool use_single_precision = false, + int vector_width = 1) : mod_filename(mod_filename) , output_dir(output_dir) , opt_passes(opt_passes) @@ -130,6 +133,13 @@ class CodegenLLVMVisitor: public visitor::ConstAstVisitor { */ llvm::Value* codegen_indexed_name(const ast::IndexedName& node); + /** + * Generates LLVM code for the given Instance variable + * \param node CodegenInstanceVar NMODL AST node + * \return LLVM code generated for this AST node + */ + llvm::Value* codegen_instance_var(const ast::CodegenInstanceVar& node); + /** * Returns GEP instruction to 1D array * \param name 1D array name @@ -152,6 +162,20 @@ class CodegenLLVMVisitor: public visitor::ConstAstVisitor { */ llvm::Type* get_codegen_var_type(const ast::CodegenVarType& node); + /** + * Returns LLVM vector with `vector_width` int values. + * \param int value to replicate + * \return LLVM value + */ + llvm::Value* get_constant_int_vector(int value); + + /** + * Returns LLVM vector with `vector_width` double values. + * \param string a double value to replicate + * \return LLVM value + */ + llvm::Value* get_constant_fp_vector(const std::string& value); + /** * Returns 64-bit or 32-bit LLVM floating type * \return \c LLVM floating point type according to `use_single_precision` flag @@ -164,6 +188,18 @@ class CodegenLLVMVisitor: public visitor::ConstAstVisitor { */ llvm::Type* get_default_fp_ptr_type(); + /** + * Returns a pointer to LLVM struct type + * \return LLVM pointer type + */ + llvm::Type* get_instance_struct_type(); + + /** + * Returns a LLVM value corresponding to the VarName node + * \return LLVM value + */ + llvm::Value* get_variable_ptr(const ast::VarName& node); + /** * Create a function call to an external method * \param name external method name @@ -255,6 +291,7 @@ class CodegenLLVMVisitor: public visitor::ConstAstVisitor { void visit_binary_expression(const ast::BinaryExpression& node) override; void visit_boolean(const ast::Boolean& node) override; void visit_statement_block(const ast::StatementBlock& node) override; + void visit_codegen_for_statement(const ast::CodegenForStatement& node) override; void visit_codegen_function(const ast::CodegenFunction& node) override; void visit_codegen_return_statement(const ast::CodegenReturnStatement& node) override; void visit_codegen_var_list_statement(const ast::CodegenVarListStatement& node) override; @@ -267,7 +304,6 @@ class CodegenLLVMVisitor: public visitor::ConstAstVisitor { void visit_program(const ast::Program& node) override; void visit_unary_expression(const ast::UnaryExpression& node) override; void visit_var_name(const ast::VarName& node) override; - void visit_instance_struct(const ast::InstanceStruct& node) override; void visit_while_statement(const ast::WhileStatement& node) override; // \todo: move this to debug mode (e.g. -v option or --dump-ir) diff --git a/src/main.cpp b/src/main.cpp index f9e083f930..5fa5304776 100644 --- a/src/main.cpp +++ b/src/main.cpp @@ -594,7 +594,7 @@ int main(int argc, const char* argv[]) { if (llvm_ir) { logger->info("Running LLVM backend code generator"); CodegenLLVMVisitor visitor( - modfile, output_dir, llvm_opt_passes, llvm_vec_width, llvm_float_type); + modfile, output_dir, llvm_opt_passes, llvm_float_type, llvm_vec_width); visitor.visit_program(*ast); ast_to_nmodl(*ast, filepath("llvm", "mod")); ast_to_json(*ast, filepath("llvm", "json")); diff --git a/test/unit/codegen/codegen_llvm_ir.cpp b/test/unit/codegen/codegen_llvm_ir.cpp index ba0c725c0c..a376bd3f5c 100644 --- a/test/unit/codegen/codegen_llvm_ir.cpp +++ b/test/unit/codegen/codegen_llvm_ir.cpp @@ -794,39 +794,3 @@ SCENARIO("Dead code removal", "[visitor][llvm][opt]") { } } } - -//============================================================================= -// Create Instance Struct -//============================================================================= - -SCENARIO("Creation of Instance Struct", "[visitor][llvm][instance_struct]") { - GIVEN("NEURON block with RANGE variables and IONS") { - std::string nmodl_text = R"( - NEURON { - USEION na READ ena WRITE ina - NONSPECIFIC_CURRENT il - RANGE minf, hinf - } - - STATE { - m - } - - ASSIGNED { - v (mV) - celsius (degC) - minf - hinf - } - )"; - - THEN("create struct with the declared variables") { - std::string module_string = run_llvm_visitor(nmodl_text, true); - std::smatch m; - - std::regex instance_struct_declaration( - R"(%unknown_Instance = type \{ double\*, double\*, double\*, double\*, double\*, double\*, double\*, double\*, double\*, double\* \})"); - REQUIRE(std::regex_search(module_string, m, instance_struct_declaration)); - } - } -}