Skip to content

Commit

Permalink
Add the Send trait to OdeSolverTrait. Improve multithreaded test with…
Browse files Browse the repository at this point in the history
… Rayon
  • Loading branch information
cpmech committed May 23, 2024
1 parent 498b7e0 commit b6d3fbe
Show file tree
Hide file tree
Showing 4 changed files with 48 additions and 39 deletions.
1 change: 1 addition & 0 deletions russell_ode/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,4 +25,5 @@ structopt = "0.3"

[dev-dependencies]
plotpy = "0.6"
rayon = "1.10"
serial_test = "3.0"
2 changes: 1 addition & 1 deletion russell_ode/src/ode_solver_trait.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use crate::{Params, Workspace};
use russell_lab::Vector;

/// Defines the numerical solver
pub(crate) trait OdeSolverTrait<A> {
pub(crate) trait OdeSolverTrait<A>: Send {
/// Enables dense output
fn enable_dense_output(&mut self) -> Result<(), StrError>;

Expand Down
19 changes: 10 additions & 9 deletions russell_ode/src/output.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ use std::collections::HashMap;
use std::fs::{self, File};
use std::io::BufReader;
use std::path::Path;
use std::sync::Arc;

/// Holds the data generated at an accepted step or during the dense output
#[derive(Clone, Debug, Deserialize)]
Expand Down Expand Up @@ -51,7 +52,7 @@ pub struct Output<'a, A> {

// --- step --------------------------------------------------------------------------------------------
/// Holds a callback function called on an accepted step
step_callback: Option<Box<dyn Fn(&Stats, f64, f64, &Vector, &mut A) -> Result<bool, StrError> + 'a>>,
step_callback: Option<Arc<dyn Fn(&Stats, f64, f64, &Vector, &mut A) -> Result<bool, StrError> + Send + Sync + 'a>>,

/// Save the results to a file (step)
step_file_key: Option<String>,
Expand Down Expand Up @@ -79,7 +80,7 @@ pub struct Output<'a, A> {

// --- dense -------------------------------------------------------------------------------------------
/// Holds a callback function for the dense output
dense_callback: Option<Box<dyn Fn(&Stats, f64, f64, &Vector, &mut A) -> Result<bool, StrError> + 'a>>,
dense_callback: Option<Arc<dyn Fn(&Stats, f64, f64, &Vector, &mut A) -> Result<bool, StrError> + Send + Sync + 'a>>,

/// Save the results to a file (dense)
dense_file_key: Option<String>,
Expand Down Expand Up @@ -123,7 +124,7 @@ pub struct Output<'a, A> {
y_aux: Vector,

/// Holds the y(x) function (e.g., to compute the correct/analytical solution)
yx_function: Option<Box<dyn Fn(&mut Vector, f64, &mut A) + 'a>>,
yx_function: Option<Arc<dyn Fn(&mut Vector, f64, &mut A) + Send + Sync + 'a>>,
}

impl OutData {
Expand Down Expand Up @@ -220,9 +221,9 @@ impl<'a, A> Output<'a, A> {
/// * `callback` -- function to be executed on an accepted step
pub fn set_step_callback(
&mut self,
callback: impl Fn(&Stats, f64, f64, &Vector, &mut A) -> Result<bool, StrError> + 'a,
callback: impl Fn(&Stats, f64, f64, &Vector, &mut A) -> Result<bool, StrError> + Send + Sync + 'a,
) -> &mut Self {
self.step_callback = Some(Box::new(callback));
self.step_callback = Some(Arc::new(callback));
self
}

Expand Down Expand Up @@ -309,9 +310,9 @@ impl<'a, A> Output<'a, A> {
/// * `callback` -- function to be executed on the selected output stations
pub fn set_dense_callback(
&mut self,
callback: impl Fn(&Stats, f64, f64, &Vector, &mut A) -> Result<bool, StrError> + 'a,
callback: impl Fn(&Stats, f64, f64, &Vector, &mut A) -> Result<bool, StrError> + Send + Sync + 'a,
) -> &mut Self {
self.dense_callback = Some(Box::new(callback));
self.dense_callback = Some(Arc::new(callback));
self
}

Expand Down Expand Up @@ -357,8 +358,8 @@ impl<'a, A> Output<'a, A> {
/// Sets the function to compute the correct/reference results y(x)
///
/// Use `|y, x, args|` or `|y: &mut Vector, x: f64, args, &mut A|`
pub fn set_yx_correct(&mut self, y_fn_x: impl Fn(&mut Vector, f64, &mut A) + 'a) -> &mut Self {
self.yx_function = Some(Box::new(y_fn_x));
pub fn set_yx_correct(&mut self, y_fn_x: impl Fn(&mut Vector, f64, &mut A) + Send + Sync + 'a) -> &mut Self {
self.yx_function = Some(Arc::new(y_fn_x));
self
}

Expand Down
65 changes: 36 additions & 29 deletions russell_ode/tests/test_multithreaded.rs
Original file line number Diff line number Diff line change
@@ -1,23 +1,23 @@
use rayon::iter::{IntoParallelRefMutIterator, ParallelIterator};
use russell_lab::{approx_eq, Vector};
use russell_ode::{Method, NoArgs, OdeSolver, Params, System};
use std::thread;

struct Simulator<'a> {
struct SimData<'a> {
solver: OdeSolver<'a, NoArgs>,
x0: f64,
x1: f64,
y: Vector,
a: u8,
}

impl<'a> Simulator<'a> {
impl<'a> SimData<'a> {
fn new(method: Method) -> Self {
let system = System::new(1, |f: &mut Vector, _x: f64, _y: &Vector, _args: &mut u8| {
f[0] = 1.0;
Ok(())
});
let params = Params::new(method);
Simulator {
SimData {
solver: OdeSolver::new(params, system).unwrap(),
x0: 0.0,
x1: 1.5,
Expand All @@ -27,31 +27,38 @@ impl<'a> Simulator<'a> {
}
}

struct Simulator<'a> {
data: SimData<'a>,
}

impl<'a> Simulator<'a> {
fn new(method: Method) -> Self {
Simulator {
data: SimData::new(method),
}
}
}

trait Runner: Send {
fn run_and_check(&mut self);
}

impl<'a> Runner for Simulator<'a> {
fn run_and_check(&mut self) {
self.data
.solver
.solve(&mut self.data.y, self.data.x0, self.data.x1, None, &mut self.data.a)
.unwrap();
approx_eq(self.data.y[0], self.data.x1, 1e-15);
}
}

#[test]
fn test_multithreaded() {
// run two simulations concurrently
thread::scope(|scope| {
let first = scope.spawn(move || {
let mut sim = Simulator::new(Method::FwEuler);
sim.solver.solve(&mut sim.y, sim.x0, sim.x1, None, &mut sim.a).unwrap();
approx_eq(sim.y[0], sim.x1, 1e-15);
});
let second = scope.spawn(move || {
let mut sim = Simulator::new(Method::MdEuler);
sim.solver.solve(&mut sim.y, sim.x0, sim.x1, None, &mut sim.a).unwrap();
approx_eq(sim.y[0], sim.x1, 1e-15);
});
let err1 = first.join();
let err2 = second.join();
if err1.is_err() && err2.is_err() {
Err("first and second failed")
} else if err1.is_err() {
Err("first failed")
} else if err2.is_err() {
Err("second failed")
} else {
Ok(())
}
})
.unwrap();
// run simulations concurrently
let mut runners: Vec<Box<dyn Runner>> = vec![
Box::new(Simulator::new(Method::FwEuler)),
Box::new(Simulator::new(Method::MdEuler)),
];
runners.par_iter_mut().for_each(|r| r.run_and_check());
}

0 comments on commit b6d3fbe

Please sign in to comment.