-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: Add r4k script transfom target
- Loading branch information
Showing
3 changed files
with
81 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1,2 @@ | ||
pub mod ode; | ||
pub mod r4k; |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,43 @@ | ||
use minijinja::{context, Environment}; | ||
|
||
use crate::Model; | ||
|
||
const ODE_PY_TEMPLATE: &str = include_str!("../../templates/ode.py.txt"); | ||
|
||
pub fn render_ode(model: Model) -> String { | ||
let env = Environment::new(); | ||
|
||
let mut ctx = context! { | ||
model => model, | ||
}; | ||
|
||
env.render_str(ODE_PY_TEMPLATE, &mut ctx).unwrap() | ||
} | ||
|
||
#[cfg(test)] | ||
mod tests { | ||
#[test] | ||
fn test_render_ode_abc_json() { | ||
use super::*; | ||
|
||
const ABC_JSON_STR: &str = include_str!("../../tests/fixtures/abc.json"); | ||
|
||
let model = serde_json::from_str::<Model>(ABC_JSON_STR).unwrap(); | ||
|
||
let ode = render_ode(model); | ||
|
||
const EXPECTED: &str = r#"import numpy as np | ||
from numpy import float64, ndarray | ||
def system(t: float64, P: ndarray, *args): | ||
A, B, C = P | ||
dA = (A * B ) | ||
dB = (A * B ) | ||
dC = (A * B / C ) | ||
return ndarray([dA, dB, dC])"#; | ||
|
||
assert_eq!(ode, EXPECTED); | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
import numpy as np | ||
from numpy import float64, ndarray | ||
|
||
def system(t: float64, P: ndarray, *args): | ||
{% for node_id, node in model.nodes|items if node.related_constant_name is defined -%} | ||
{{ node.name }} {%- if not loop.last %}, {% endif %} | ||
{%- endfor %} = P | ||
{% for node_id, node in model.nodes|items -%} | ||
{% if node.related_constant_name %} | ||
d{{ node.name }} = {% for link in node.links -%} | ||
{%- set link_node = model.nodes[link.node_id] -%} | ||
{%- if link_node.related_constant_name -%} | ||
({{ link.sign }} {{ link_node.name }}) | ||
{%- else -%} | ||
( | ||
{%- for input in link_node.inputs recursive -%} | ||
{%- set inner_link_node = model.nodes[input] %} | ||
{%- if inner_link_node.related_constant_name -%} | ||
{{ inner_link_node.name }} {% if not loop.last %} {{ link_node.operation }} {% endif %} | ||
{%- else -%} | ||
{%- set old_link_node = link_node -%} | ||
{%- set link_node = inner_link_node -%} | ||
{{ loop(inner_link_node.inputs) }} | ||
{%- set link_node = old_link_node -%} | ||
{% if not loop.last %} {{ link_node.operation }} {% endif %} | ||
{%- endif -%} | ||
{%- endfor -%} | ||
) | ||
{%- endif -%} | ||
{% endfor -%} | ||
{% endif %} | ||
{%- endfor %} | ||
|
||
return ndarray([ | ||
{%- for node_id, node in model.nodes|items if node.related_constant_name is defined -%} | ||
d{{ node.name }} {%- if not loop.last %}, {% endif %} | ||
{%- endfor -%}]) |