Skip to content

Commit

Permalink
Workaround for axis_windows bug
Browse files Browse the repository at this point in the history
  • Loading branch information
jonasBoss committed Jul 14, 2023
1 parent aa4f977 commit 893cbd6
Showing 1 changed file with 18 additions and 23 deletions.
41 changes: 18 additions & 23 deletions src/interp1d/strategies/cubic_spline.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,7 @@ use std::{
};

use ndarray::{
s, Array, ArrayBase, ArrayViewMut, Axis, Data, Dimension, Ix1, RemoveAxis, ScalarOperand,
Slice, Zip,
s, Array, ArrayBase, ArrayViewMut, Axis, Data, Dimension, Ix1, RemoveAxis, ScalarOperand, Zip,
};
use num_traits::{cast, Num, NumCast, Pow};

Expand Down Expand Up @@ -132,27 +131,23 @@ impl CubicSpline {
// RHS vector
let mut rhs: Array<Sd::Elem, D> = Array::zeros(dim.clone());

let mut inner_rhs = rhs.slice_axis_mut(AX0, Slice::from(1..-1));
Zip::from(inner_rhs.axis_iter_mut(AX0))
.and(x.windows(3))
.and(data.axis_windows(AX0, 3))
.for_each(|rhs, x, data| {
let y_left = data.index_axis(AX0, 0);
let y_mid = data.index_axis(AX0, 1);
let y_right = data.index_axis(AX0, 2);
let x_left = x[0];
let x_mid = x[1];
let x_right = x[2];

Zip::from(y_left).and(y_mid).and(y_right).map_assign_into(
rhs,
|&y_left, &y_mid, &y_right| {
three
* ((y_mid - y_left) / (x_mid - x_left).pow(two)
+ (y_right - y_mid) / (x_right - x_mid).pow(two))
},
);
});
for i in 1..len-1{
let rhs = rhs.index_axis_mut(AX0, i);
let y_left = data.index_axis(AX0, i-1);
let y_mid = data.index_axis(AX0, i);
let y_right = data.index_axis(AX0, i+1);
let x_left = x[i-1];
let x_mid = x[i];
let x_right = x[i+1];
Zip::from(y_left).and(y_mid).and(y_right).map_assign_into(
rhs,
|&y_left, &y_mid, &y_right| {
three
* ((y_mid - y_left) / (x_mid - x_left).pow(two)
+ (y_right - y_mid) / (x_right - x_mid).pow(two))
},
);
}

let rhs_0 = rhs.index_axis_mut(AX0, 0);
let data_0 = data.index_axis(AX0, 0);
Expand Down

0 comments on commit 893cbd6

Please sign in to comment.