Skip to content

Commit

Permalink
Merge pull request #4 from AleCandido/prepare_for_ndinterpolation
Browse files Browse the repository at this point in the history
Update the code for n-d
  • Loading branch information
scarlehoff authored Jul 20, 2023
2 parents 4b2bb40 + 2c0dae3 commit 381fa8f
Show file tree
Hide file tree
Showing 11 changed files with 178 additions and 107 deletions.
8 changes: 8 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,11 @@ members = ["ndinterp", "ndinterp_capi"] # "ndinterp_py"]

[profile.release]
lto = true

[workspace.package]
authors = ["Alessandro Candido <alessandro.candido@sns.it>", "Juan Cruz-Martinez <juacrumar@lairen.eu>", "Christopher Schwan <handgranaten-herbert@posteo.de>"]
edition = "2021"
license = "GPL-3.0-or-later"
repository = "https://github.com/AleCandido/ndinterp"
categories = ["science"]
description = "N-dimensional interpolation library"
12 changes: 7 additions & 5 deletions ndinterp/Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,16 +1,18 @@
[package]
name = "ndinterp"
version = "0.0.1"
authors = ["Alessandro Candido <alessandro.candido@sns.it>"]
edition = "2021"
license = "GPL-3.0-or-later"
repository = "https://github.com/AleCandido/ndinterp"
edition.workspace = true
authors.workspace = true
license.workspace = true
repository.workspace = true
categories.workspace = true
readme = "README.md"
categories = ["science"]
description = "N-dimensional interpolation library"
keywords = ["math", "science"]

[dependencies]
ndarray = "0.15.4"
ndarray-linalg = "0.14.1"
petgraph = "0.6.2"
thiserror = "1.0.40"
itertools = "0.11.0"
2 changes: 1 addition & 1 deletion ndinterp/examples/alphas.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ fn main() {
];

let grid = Grid {
input: logq2,
xgrid: vec![logq2.to_vec()],
values: alpha_s_vals,
};

Expand Down
90 changes: 59 additions & 31 deletions ndinterp/src/grid.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,19 +10,34 @@
//! y = f(x1, x2, x3...)
//!
use crate::interpolate::InterpolationError;
use ndarray::Array1;
use itertools::izip;
use ndarray::{Array, ArrayView1, Dimension};

// Make public the families of interpolation algorithms implemented for grids
pub mod cubic;

/// A grid is made of two components:
/// A d-dimensional vector of 1-dimensional sorted vectors for the input points
/// A d-dimensional array for the grid values of
#[derive(Debug)]
pub struct Grid {
/// A grid is made of two (1-dimensional) sorted arrays.
pub input: Array1<f64>,
pub values: Array1<f64>,
pub struct Grid<D: Dimension> {
/// Arrays with the input vectors (x_i)
pub xgrid: Vec<Vec<f64>>,

/// Output points
pub values: Array<f64, D>,
}

/// A grid slice is always 1-Dimensional
#[derive(Debug)]
pub struct GridSlice<'a> {
/// A reference to one of the input vectors of the grid
pub x: &'a Vec<f64>,
/// A view of the slice of values corresponding to x
pub y: ArrayView1<'a, f64>,
}

impl Grid {
impl<'a> GridSlice<'a> {
// TODO: at the moment we are using here the derivatives that LHAPDF is using for the
// interpolation in alpha_s, these are probably enough for this use case but not in general
// - [ ] Implement a more robust form of the derivative
Expand All @@ -33,9 +48,9 @@ impl Grid {
/// input at position index as the ratio between the differences dy/dx computed as:
/// dy = y_{i} - y_{i-1}
/// dx = x_{i} - x_{x-1}
pub fn derivative_at(&self, index: usize) -> f64 {
let dx = self.input[index] - self.input[index - 1];
let dy = self.values[index] - self.values[index - 1];
pub fn derivative_at(&'a self, index: usize) -> f64 {
let dx = self.x[index] - self.x[index - 1];
let dy = self.y[index] - self.y[index - 1];
dy / dx
}

Expand All @@ -44,53 +59,66 @@ impl Grid {
///
/// Dx_{i} = \Delta x_{i} = x_{i} - x_{i-}
/// y'_{i} = 1/2 * ( (y_{i+1}-y_{i})/Dx_{i+1} + (y_{i}-y_{i-1})/Dx_{i} )
pub fn central_derivative_at(&self, index: usize) -> f64 {
pub fn central_derivative_at(&'a self, index: usize) -> f64 {
let dy_f = self.derivative_at(index + 1);
let dy_b = self.derivative_at(index);
0.5 * (dy_f + dy_b)
}
}

/// Find the index of the last value in the input such that input(idx) < query
impl<D: Dimension> Grid<D> {
/// Find the index of the last value in the input xgrid such that xgrid(idx) < query
/// If the query is outside the grid returns an extrapolation error
pub fn closest_below(&self, query: f64) -> Result<usize, InterpolationError> {
if query > self.input[self.input.len() - 1] {
Err(InterpolationError::ExtrapolationAbove(query))
} else if query < self.input[0] {
Err(InterpolationError::ExtrapolationBelow(query))
} else {
let u_idx = self.input.iter().position(|x| x > &query).unwrap();
let idx = u_idx - 1;
Ok(idx)
pub fn closest_below<const N: usize>(
&self,
input_query: &[f64],
) -> Result<[usize; N], InterpolationError> {
let mut ret = [0; N];

for (r, &query, igrid) in izip!(&mut ret, input_query, &self.xgrid) {
if query > *igrid.last().unwrap() {
return Err(InterpolationError::ExtrapolationAbove(query));
} else if query < igrid[0] {
return Err(InterpolationError::ExtrapolationBelow(query));
}

let u_idx = igrid.partition_point(|&x| x < query);
*r = u_idx - 1;
}
Ok(ret)
}
}

#[cfg(test)]
mod tests {
use super::*;
use ndarray::array;
use ndarray::{array, Ix1};

fn gen_grid() -> Grid {
let x = array![0., 1., 2., 3., 4.];
fn gen_grid() -> Grid<Ix1> {
let x = vec![vec![0., 1., 2., 3., 4.]];
let y = array![4., 3., 2., 1., 1.];
let grid = Grid {
input: x,

Grid {
xgrid: x,
values: y,
};
grid
}
}

#[test]
fn check_derivative() {
let grid = gen_grid();
assert_eq!(grid.central_derivative_at(1), -1.);
assert_eq!(grid.central_derivative_at(3), -0.5);
let grid_slice = GridSlice {
x: &grid.xgrid[0],
y: grid.values.view(),
};
assert_eq!(grid_slice.central_derivative_at(1), -1.);
assert_eq!(grid_slice.central_derivative_at(3), -0.5);
}

#[test]
fn check_index_search() {
let grid = gen_grid();
assert_eq!(grid.closest_below(0.5).unwrap(), 0);
assert_eq!(grid.closest_below(3.2).unwrap(), 3);
assert_eq!(grid.closest_below::<1>(&[0.5]).unwrap()[0], 0);
assert_eq!(grid.closest_below::<1>(&[3.2]).unwrap()[0], 3);
}
}
65 changes: 42 additions & 23 deletions ndinterp/src/grid/cubic.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
//! Implements cubic interpolation algorithms
use crate::grid::Grid;
use crate::grid::{Grid, GridSlice};
use crate::interpolate::InterpolationError;
pub use crate::interpolate::Interpolator;
use ndarray::Ix1;

/// Cubic interpolation in 1D
///
Expand All @@ -18,40 +19,35 @@ pub use crate::interpolate::Interpolator;
/// with hij the Hermite basis functions
#[derive(Debug)]
pub struct Cubic1d {
pub grid: Grid,
/// The grid object contains all necessary information to perform the interpolation
pub grid: Grid<Ix1>,
}

impl Interpolator<f64> for Cubic1d {
/// Use Cubic interpolation 1d to compute y(query)
/// The interpolation uses the two nearest neighbours and their derivatives computed as an
/// average of the differences above and below.
///
/// Special cases are considered when the interpolation occurs between the first (last) two
/// bins, where the derivative would involve points outside the grids.
///
/// Two special are considered, when the interpolation occurs between the first (last) two
/// bins, the derivative at the boundary is approximated by the forward (backward) difference
fn interpolate(&self, query: f64) -> Result<f64, InterpolationError> {
let idx = self.grid.closest_below(query)?;
let dx = self.grid.input[idx + 1] - self.grid.input[idx];
impl<'a> GridSlice<'a> {
/// Implements utilities for a GridSlice that can be used by cubic interpolation Nd
/// Takes as input the value being queried and its index within the given slice
fn cubic_interpolate_1d(&'a self, query: f64, idx: usize) -> Result<f64, InterpolationError> {
// grid slice utilities are expected to be called multipled times for the same
// query and so it is convient to pass idx from the outside to avoid expensive searches
let dx = self.x[idx + 1] - self.x[idx];

// Upper and lower bounds and derivatives
let yu = self.grid.values[idx + 1];
let yl = self.grid.values[idx];
let yu = self.y[idx + 1];
let yl = self.y[idx];

let dydxu = if idx == self.grid.input.len() - 2 {
dx * self.grid.derivative_at(idx + 1)
let dydxu = if idx == self.x.len() - 2 {
dx * self.derivative_at(idx + 1)
} else {
dx * self.grid.central_derivative_at(idx + 1)
dx * self.central_derivative_at(idx + 1)
};

let dydxl = if idx == 0 {
dx * self.grid.derivative_at(idx + 1)
dx * self.derivative_at(idx + 1)
} else {
dx * self.grid.central_derivative_at(idx)
dx * self.central_derivative_at(idx)
};

let t = (query - self.grid.input[idx]) / dx;
let t = (query - self.x[idx]) / dx;
let t2 = t * t;
let t3 = t2 * t;

Expand All @@ -63,3 +59,26 @@ impl Interpolator<f64> for Cubic1d {
Ok(p0 + p1 + m0 + m1)
}
}

impl Interpolator<f64> for Cubic1d {
/// Use Cubic interpolation 1d to compute y(query)
/// The interpolation uses the two nearest neighbours and their derivatives computed as an
/// average of the differences above and below.
///
/// Special cases are considered when the interpolation occurs between the first (last) two
/// bins, where the derivative would involve points outside the grids.
///
/// Two special are considered, when the interpolation occurs between the first (last) two
/// bins, the derivative at the boundary is approximated by the forward (backward) difference
fn interpolate(&self, query: f64) -> Result<f64, InterpolationError> {
let raw_idx = self.grid.closest_below::<1>(&[query])?;
let idx = raw_idx[0];

let grid_sl = GridSlice {
x: &self.grid.xgrid[0],
y: self.grid.values.view(),
};

grid_sl.cubic_interpolate_1d(query, idx)
}
}
65 changes: 34 additions & 31 deletions ndinterp/src/interpolate.rs
Original file line number Diff line number Diff line change
@@ -1,46 +1,49 @@
use std::iter::zip;
//! This module implements interpolation rutines
use thiserror::Error;

use crate::metric::Metric;

/// Errors encountered during interpolation
#[derive(Debug, Error)]
pub enum InterpolationError {
/// Raised when the queried value is above the maximum
#[error("The value queried ({0}) is above the maximum")]
ExtrapolationAbove(f64),

/// Raised when the queried value is below the minimum
#[error("The value queried ({0}) is below the minimum")]
ExtrapolationBelow(f64),
}

/// Methods which all interpolator must implement
pub trait Interpolator<T> {
/// Produce the result of the inteprolation given a (nd) point 'query'
fn interpolate(&self, query: T) -> Result<f64, InterpolationError>;
}

/// ---- deal with the stuff below later ----
pub trait Interpolate {
type Point: Metric;

fn interpolate(&self, query: &Self::Point) -> f64;
}

pub struct Input<Point: Metric> {
pub point: Point,
pub value: f64,
}

impl<Point: Metric> From<(Point, f64)> for Input<Point> {
fn from(item: (Point, f64)) -> Self {
Self {
point: item.0,
value: item.1,
}
}
}

impl<Point: Metric> Input<Point> {
pub fn stack(points: Vec<Point>, values: Vec<f64>) -> Vec<Self> {
zip(points.into_iter(), values.into_iter())
.map(|t| t.into())
.collect()
}
}
///// ---- deal with the stuff below later ----
//pub trait Interpolate {
// type Point: Metric;
//
// fn interpolate(&self, query: &Self::Point) -> f64;
//}
//
//pub struct Input<Point: Metric> {
// pub point: Point,
// pub value: f64,
//}
//
//impl<Point: Metric> From<(Point, f64)> for Input<Point> {
// fn from(item: (Point, f64)) -> Self {
// Self {
// point: item.0,
// value: item.1,
// }
// }
//}
//
//impl<Point: Metric> Input<Point> {
// pub fn stack(points: Vec<Point>, values: Vec<f64>) -> Vec<Self> {
// zip(points.into_iter(), values.into_iter())
// .map(|t| t.into())
// .collect()
// }
//}
8 changes: 3 additions & 5 deletions ndinterp/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
#![warn(clippy::all, clippy::cargo, clippy::nursery, clippy::pedantic)]
#![warn(clippy::all, clippy::cargo)]
#![warn(missing_docs)]
#![doc = include_str!("../README.md")]
#![warn(clippy::all, clippy::pedantic, clippy::restriction, clippy::cargo)]
#![warn(missing_docs)]

pub mod grid;
pub mod interpolate;
pub mod metric;
pub mod scatter;
//pub mod metric;
//pub mod scatter;
Loading

0 comments on commit 381fa8f

Please sign in to comment.