Skip to content

Commit

Permalink
feat: Add r4k script transfom target
Browse files Browse the repository at this point in the history
  • Loading branch information
Diegovsky committed Aug 19, 2023
1 parent ea5cc62 commit 591ec51
Show file tree
Hide file tree
Showing 3 changed files with 81 additions and 0 deletions.
1 change: 1 addition & 0 deletions src/transformations/mod.rs
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
pub mod ode;
pub mod r4k;
43 changes: 43 additions & 0 deletions src/transformations/r4k.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
use minijinja::{context, Environment};

use crate::Model;

const ODE_PY_TEMPLATE: &str = include_str!("../../templates/ode.py.txt");

pub fn render_ode(model: Model) -> String {
let env = Environment::new();

let mut ctx = context! {
model => model,
};

env.render_str(ODE_PY_TEMPLATE, &mut ctx).unwrap()
}

#[cfg(test)]
mod tests {
#[test]
fn test_render_ode_abc_json() {
use super::*;

const ABC_JSON_STR: &str = include_str!("../../tests/fixtures/abc.json");

let model = serde_json::from_str::<Model>(ABC_JSON_STR).unwrap();

let ode = render_ode(model);

const EXPECTED: &str = r#"import numpy as np
from numpy import float64, ndarray
def system(t: float64, P: ndarray, *args):
A, B, C = P
dA = (A * B )
dB = (A * B )
dC = (A * B / C )
return ndarray([dA, dB, dC])"#;

assert_eq!(ode, EXPECTED);
}
}
37 changes: 37 additions & 0 deletions templates/ode.py.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
import numpy as np
from numpy import float64, ndarray

def system(t: float64, P: ndarray, *args):
{% for node_id, node in model.nodes|items if node.related_constant_name is defined -%}
{{ node.name }} {%- if not loop.last %}, {% endif %}
{%- endfor %} = P
{% for node_id, node in model.nodes|items -%}
{% if node.related_constant_name %}
d{{ node.name }} = {% for link in node.links -%}
{%- set link_node = model.nodes[link.node_id] -%}
{%- if link_node.related_constant_name -%}
({{ link.sign }} {{ link_node.name }})
{%- else -%}
(
{%- for input in link_node.inputs recursive -%}
{%- set inner_link_node = model.nodes[input] %}
{%- if inner_link_node.related_constant_name -%}
{{ inner_link_node.name }} {% if not loop.last %} {{ link_node.operation }} {% endif %}
{%- else -%}
{%- set old_link_node = link_node -%}
{%- set link_node = inner_link_node -%}
{{ loop(inner_link_node.inputs) }}
{%- set link_node = old_link_node -%}
{% if not loop.last %} {{ link_node.operation }} {% endif %}
{%- endif -%}
{%- endfor -%}
)
{%- endif -%}
{% endfor -%}
{% endif %}
{%- endfor %}

return ndarray([
{%- for node_id, node in model.nodes|items if node.related_constant_name is defined -%}
d{{ node.name }} {%- if not loop.last %}, {% endif %}
{%- endfor -%}])

0 comments on commit 591ec51

Please sign in to comment.