Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support for TABLE statements in FUNCTION blocks #913

Merged
merged 6 commits into from
Aug 24, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ def _config_exe(exe_name):
"jupyter-client",
"jupyter",
"myst_parser",
"mistune<2", # prevents a version conflict with nbconvert
"mistune<3", # prevents a version conflict with nbconvert
"nbconvert",
"nbsphinx>=0.3.2",
"pytest>=3.7.2",
Expand Down
201 changes: 117 additions & 84 deletions src/codegen/codegen_c_visitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1497,22 +1497,21 @@ void CodegenCVisitor::print_table_check_function(const Block& node) {
printer->start_block(
fmt::format("void check_{}({})", method_name(name), get_parameter_str(internal_params)));
{
printer->add_line(fmt::format("if ( {} == 0) {}", use_table_var, "{"));
printer->add_line(" return;");
printer->add_line("}");
printer->fmt_start_block("if ({} == 0)", use_table_var);
printer->add_line("return;");
printer->end_block(1);

printer->add_line("static bool make_table = true;");
for (const auto& variable: depend_variables) {
printer->add_line(
fmt::format("static {} save_{};", float_type, variable->get_node_name()));
printer->fmt_line("static {} save_{};", float_type, variable->get_node_name());
}

for (const auto& variable: depend_variables) {
auto name = variable->get_node_name();
auto instance_name = get_variable_name(name);
printer->add_line(fmt::format("if (save_{} != {}) {}", name, instance_name, "{"));
printer->add_line(" make_table = true;");
printer->add_line("}");
printer->fmt_start_block("if (save_{} != {})", name, instance_name);
printer->add_line("make_table = true;");
printer->end_block(1);
}

printer->start_block("if (make_table)");
Expand All @@ -1532,22 +1531,28 @@ void CodegenCVisitor::print_table_check_function(const Block& node) {
printer->add_newline();


printer->add_line(fmt::format("double dx = (tmax-{})/{}.0;", tmin_name, with));
printer->add_line(fmt::format("{} = 1.0/dx;", mfac_name));
printer->fmt_line("double dx = (tmax-{}) / {}.;", tmin_name, with);
printer->fmt_line("{} = 1./dx;", mfac_name);

printer->add_line("int i = 0;");
printer->add_line("double x = 0;");
printer->add_line(fmt::format(
"for(i = 0, x = {}; i < {}; x += dx, i++) {}", tmin_name, with + 1, "{"));
printer->fmt_line("double x = {};", tmin_name);
printer->fmt_start_block("for (size_t i = 0; i < {}; x += dx, i++)", with + 1);
auto function = method_name("f_" + name);
printer->add_line(fmt::format(" {}({}, x);", function, internal_method_arguments()));
for (const auto& variable: table_variables) {
auto name = variable->get_node_name();
auto instance_name = get_variable_name(name);
if (node.is_procedure_block()) {
printer->fmt_line("{}({}, x);", function, internal_method_arguments());
for (const auto& variable: table_variables) {
auto name = variable->get_node_name();
auto instance_name = get_variable_name(name);
auto table_name = get_variable_name("t_" + name);
printer->fmt_line("{}[i] = {};", table_name, instance_name);
}
} else {
auto table_name = get_variable_name("t_" + name);
printer->add_line(fmt::format(" {}[i] = {};", table_name, instance_name));
printer->fmt_line("{}[i] = {}({}, x);",
table_name,
function,
internal_method_arguments());
}
printer->add_line("}");
printer->end_block(1);

for (const auto& variable: depend_variables) {
auto name = variable->get_node_name();
Expand Down Expand Up @@ -1577,45 +1582,68 @@ void CodegenCVisitor::print_table_replacement_function(const ast::Block& node) {
printer->start_block();
{
const auto& params = node.get_parameters();
printer->add_line(fmt::format("if ( {} == 0) {{", use_table_var));
printer->add_line(fmt::format(" {}({}, {});",
function_name,
internal_method_arguments(),
params[0].get()->get_node_name()));
printer->add_line(" return 0;");
printer->add_line("}");
printer->fmt_start_block("if ({} == 0)", use_table_var);
if (node.is_procedure_block()) {
printer->fmt_line("{}({}, {});",
function_name,
internal_method_arguments(),
params[0].get()->get_node_name());
printer->add_line("return 0;");
} else {
printer->fmt_line("return {}({}, {});",
function_name,
internal_method_arguments(),
params[0].get()->get_node_name());
}
printer->end_block(1);

printer->add_line(fmt::format(
"double xi = {} * ({} - {});", mfac_name, params[0].get()->get_node_name(), tmin_name));
printer->add_line("if (isnan(xi)) {");
for (const auto& var: table_variables) {
auto name = get_variable_name(var->get_node_name());
printer->add_line(fmt::format(" {} = xi;", name));
printer->fmt_line("double xi = {} * ({} - {});",
mfac_name,
params[0].get()->get_node_name(),
tmin_name);
printer->start_block("if (isnan(xi))");
if (node.is_procedure_block()) {
for (const auto& var: table_variables) {
auto name = get_variable_name(var->get_node_name());
printer->fmt_line("{} = xi;", name);
}
printer->add_line("return 0;");
} else {
printer->add_line("return xi;");
}
printer->add_line(" return 0;");
printer->add_line("}");
printer->end_block(1);

printer->add_line(fmt::format("if (xi <= 0.0 || xi >= {}) {}", with, "{"));
printer->add_line(fmt::format(" int index = (xi <= 0.0) ? 0 : {};", with));
for (const auto& variable: table_variables) {
auto name = variable->get_node_name();
auto instance_name = get_variable_name(name);
printer->fmt_start_block("if (xi <= 0. || xi >= {}.)", with);
printer->fmt_line("int index = (xi <= 0.) ? 0 : {};", with);
if (node.is_procedure_block()) {
for (const auto& variable: table_variables) {
auto name = variable->get_node_name();
auto instance_name = get_variable_name(name);
auto table_name = get_variable_name("t_" + name);
printer->fmt_line("{} = {}[index];", instance_name, table_name);
}
printer->add_line("return 0;");
} else {
auto table_name = get_variable_name("t_" + name);
printer->add_line(fmt::format(" {} = {}[index];", instance_name, table_name));
printer->fmt_line("return {}[index];", table_name);
}
printer->add_line(" return 0;");
printer->add_line("}");
printer->end_block(1);

printer->add_line("int i = int(xi);");
printer->add_line("double theta = xi - double(i);");
for (const auto& var: table_variables) {
auto instance_name = get_variable_name(var->get_node_name());
auto table_name = get_variable_name("t_" + var->get_node_name());
printer->add_line(
fmt::format("{0} = {1}[i] + theta*({1}[i+1]-{1}[i]);", instance_name, table_name));
if (node.is_procedure_block()) {
for (const auto& var: table_variables) {
auto instance_name = get_variable_name(var->get_node_name());
auto table_name = get_variable_name("t_" + var->get_node_name());
printer->fmt_line("{0} = {1}[i] + theta*({1}[i+1]-{1}[i]);",
instance_name,
table_name);
}
printer->add_line("return 0;");
} else {
auto table_name = get_variable_name("t_" + name);
printer->fmt_line("return {0}[i] + theta * ({0}[i+1] - {0}[i]);", table_name);
}

printer->add_line("return 0;");
}
printer->end_block(1);
}
Expand All @@ -1630,25 +1658,25 @@ void CodegenCVisitor::print_check_table_thread_function() {
auto name = method_name("check_table_thread");
auto parameters = external_method_parameters(true);

printer->add_line(fmt::format("static void {} ({}) {}", name, parameters, "{"));
printer->add_line(" Memb_list* ml = nt->_ml_list[tml_id];");
printer->add_line(" setup_instance(nt, ml);");
printer->add_line(fmt::format(" {0}* inst = ({0}*) ml->instance;", instance_struct()));
printer->add_line(" double v = 0;");
printer->fmt_start_block("static void {} ({})", name, parameters);
printer->add_line("Memb_list* ml = nt->_ml_list[tml_id];");
printer->add_line("setup_instance(nt, ml);");
printer->fmt_line("{0}* inst = ({0}*) ml->instance;", instance_struct());
printer->add_line("double v = 0;");

for (const auto& function: info.functions_with_table) {
auto name = method_name("check_" + function->get_node_name());
auto arguments = internal_method_arguments();
printer->add_line(fmt::format(" {}({});", name, arguments));
printer->add_line(fmt::format("{}({});", name, arguments));
}

/**
* \todo `check_table_thread` is called multiple times from coreneuron including
* after `finitialize`. If we cleaup the instance then it will result in segfault
* but if we don't then there is memory leak
*/
printer->add_line(" // cleanup_instance(ml);");
printer->add_line("}");
printer->add_line("// cleanup_instance(ml);");
printer->end_block(1);
}


Expand Down Expand Up @@ -2447,19 +2475,18 @@ void CodegenCVisitor::print_mechanism_global_var_structure() {
auto float_type = default_float_data_type();
printer->add_newline(2);
printer->add_line("/** all global variables */");
printer->add_line(fmt::format("struct {} {}", global_struct(), "{"));
printer->increase_indent();
printer->fmt_start_block("struct {}", global_struct());

if (!info.ions.empty()) {
for (const auto& ion: info.ions) {
auto name = fmt::format("{}_type", ion.name);
printer->add_line(fmt::format("{}int {};", qualifier, name));
printer->fmt_line("{}int {};", qualifier, name);
codegen_global_variables.push_back(make_symbol(name));
}
}

if (info.point_process) {
printer->add_line(fmt::format("{}int point_type;", qualifier));
printer->fmt_line("{}int point_type;", qualifier);
codegen_global_variables.push_back(make_symbol("point_type"));
}

Expand All @@ -2468,7 +2495,7 @@ void CodegenCVisitor::print_mechanism_global_var_structure() {
auto name = var->get_name() + "0";
auto symbol = program_symtab->lookup(name);
if (symbol == nullptr) {
printer->add_line(fmt::format("{}{} {};", qualifier, float_type, name));
printer->fmt_line("{}{} {};", qualifier, float_type, name);
codegen_global_variables.push_back(make_symbol(name));
}
}
Expand All @@ -2484,28 +2511,30 @@ void CodegenCVisitor::print_mechanism_global_var_structure() {
auto name = var->get_name();
auto length = var->get_length();
if (var->is_array()) {
printer->add_line(fmt::format("{}{} {}[{}];", qualifier, float_type, name, length));
printer->fmt_line("{}{} {}[{}];", qualifier, float_type, name, length);
} else {
printer->add_line(fmt::format("{}{} {};", qualifier, float_type, name));
printer->fmt_line("{}{} {};", qualifier, float_type, name);
}
codegen_global_variables.push_back(var);
}
}

if (!info.thread_variables.empty()) {
printer->add_line(fmt::format("{}int thread_data_in_use;", qualifier));
printer->add_line(
fmt::format("{}{} thread_data[{}];", qualifier, float_type, info.thread_var_data_size));
printer->fmt_line("{}int thread_data_in_use;", qualifier);
printer->fmt_line("{}{} thread_data[{}];",
qualifier,
float_type,
info.thread_var_data_size);
codegen_global_variables.push_back(make_symbol("thread_data_in_use"));
auto symbol = make_symbol("thread_data");
symbol->set_as_array(info.thread_var_data_size);
codegen_global_variables.push_back(symbol);
}

printer->add_line(fmt::format("{}int reset;", qualifier));
printer->fmt_line("{}int reset;", qualifier);
codegen_global_variables.push_back(make_symbol("reset"));

printer->add_line(fmt::format("{}int mech_type;", qualifier));
printer->fmt_line("{}int mech_type;", qualifier);
codegen_global_variables.push_back(make_symbol("mech_type"));

auto& globals = info.global_variables;
Expand All @@ -2516,9 +2545,9 @@ void CodegenCVisitor::print_mechanism_global_var_structure() {
auto name = var->get_name();
auto length = var->get_length();
if (var->is_array()) {
printer->add_line(fmt::format("{}{} {}[{}];", qualifier, float_type, name, length));
printer->fmt_line("{}{} {}[{}];", qualifier, float_type, name, length);
} else {
printer->add_line(fmt::format("{}{} {};", qualifier, float_type, name));
printer->fmt_line("{}{} {};", qualifier, float_type, name);
}
codegen_global_variables.push_back(var);
}
Expand All @@ -2528,50 +2557,54 @@ void CodegenCVisitor::print_mechanism_global_var_structure() {
for (const auto& var: constants) {
auto name = var->get_name();
auto value_ptr = var->get_value();
printer->add_line(fmt::format("{}{} {};", qualifier, float_type, name));
printer->fmt_line("{}{} {};", qualifier, float_type, name);
codegen_global_variables.push_back(var);
}
}

if (info.primes_size != 0) {
printer->add_line(fmt::format("int* {}slist1;", qualifier));
printer->add_line(fmt::format("int* {}dlist1;", qualifier));
printer->fmt_line("int* {}slist1;", qualifier);
printer->fmt_line("int* {}dlist1;", qualifier);
codegen_global_variables.push_back(make_symbol("slist1"));
codegen_global_variables.push_back(make_symbol("dlist1"));
if (info.derivimplicit_used()) {
printer->add_line(fmt::format("int* {}slist2;", qualifier));
printer->fmt_line("int* {}slist2;", qualifier);
codegen_global_variables.push_back(make_symbol("slist2"));
}
}

if (info.table_count > 0) {
printer->add_line(fmt::format("{}double usetable;", qualifier));
printer->fmt_line("{}double usetable;", qualifier);
codegen_global_variables.push_back(make_symbol(naming::USE_TABLE_VARIABLE));

for (const auto& block: info.functions_with_table) {
auto name = block->get_node_name();
printer->add_line(fmt::format("{}{} tmin_{};", qualifier, float_type, name));
printer->add_line(fmt::format("{}{} mfac_{};", qualifier, float_type, name));
printer->fmt_line("{}{} tmin_{};", qualifier, float_type, name);
printer->fmt_line("{}{} mfac_{};", qualifier, float_type, name);
codegen_global_variables.push_back(make_symbol("tmin_" + name));
codegen_global_variables.push_back(make_symbol("mfac_" + name));
if (block->is_function_block()) {
printer->fmt_line("{}* {}t_{};", float_type, qualifier, name);
codegen_global_variables.push_back(make_symbol("t_" + name));
}
}

for (const auto& variable: info.table_statement_variables) {
auto name = "t_" + variable->get_name();
printer->add_line(fmt::format("{}* {}{};", float_type, qualifier, name));
printer->fmt_line("{}* {}{};", float_type, qualifier, name);
codegen_global_variables.push_back(make_symbol(name));
}
}

if (info.vectorize) {
printer->add_line(fmt::format("ThreadDatum* {}ext_call_thread;", qualifier));
printer->fmt_line("ThreadDatum* {}ext_call_thread;", qualifier);
codegen_global_variables.push_back(make_symbol("ext_call_thread"));
}

printer->decrease_indent();
printer->add_line("};");
printer->end_block(0);
printer->add_text(";");
printer->add_newline(2);

printer->add_newline(1);
printer->add_line("/** holds object of global variable */");
print_global_variable_device_create_annotation_pre();
print_global_var_struct_decl();
Expand Down
Loading