Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update the code for n-d #4

Merged
merged 9 commits into from
Jul 20, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,11 @@ members = ["ndinterp", "ndinterp_capi"] # "ndinterp_py"]

[profile.release]
lto = true

[workspace.package]
authors = ["Alessandro Candido <alessandro.candido@sns.it>", "Juan Cruz-Martinez <juacrumar@lairen.eu>", "Christopher Schwan <handgranaten-herbert@posteo.de>"]
edition = "2021"
license = "GPL-3.0-or-later"
repository = "https://github.com/AleCandido/ndinterp"
categories = ["science"]
description = "N-dimensional interpolation library"
12 changes: 7 additions & 5 deletions ndinterp/Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,16 +1,18 @@
[package]
name = "ndinterp"
version = "0.0.1"
authors = ["Alessandro Candido <alessandro.candido@sns.it>"]
edition = "2021"
license = "GPL-3.0-or-later"
repository = "https://github.com/AleCandido/ndinterp"
edition.workspace = true
authors.workspace = true
license.workspace = true
repository.workspace = true
categories.workspace = true
readme = "README.md"
categories = ["science"]
description = "N-dimensional interpolation library"
keywords = ["math", "science"]

[dependencies]
ndarray = "0.15.4"
ndarray-linalg = "0.14.1"
petgraph = "0.6.2"
thiserror = "1.0.40"
itertools = "0.11.0"
2 changes: 1 addition & 1 deletion ndinterp/examples/alphas.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ fn main() {
];

let grid = Grid {
input: logq2,
xgrid: vec![logq2.to_vec()],
values: alpha_s_vals,
};

Expand Down
90 changes: 59 additions & 31 deletions ndinterp/src/grid.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,19 +10,34 @@
//! y = f(x1, x2, x3...)
//!
use crate::interpolate::InterpolationError;
use ndarray::Array1;
use itertools::izip;
use ndarray::{Array, ArrayView1, Dimension};

// Make public the families of interpolation algorithms implemented for grids
pub mod cubic;

/// A grid is made of two components:
/// A d-dimensional vector of 1-dimensional sorted vectors for the input points
/// A d-dimensional array for the grid values of
#[derive(Debug)]
pub struct Grid {
/// A grid is made of two (1-dimensional) sorted arrays.
pub input: Array1<f64>,
pub values: Array1<f64>,
pub struct Grid<D: Dimension> {
/// Arrays with the input vectors (x_i)
pub xgrid: Vec<Vec<f64>>,
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here an even nicer option (for later on) would be something like Vec<[f64; N]>, with N: usize determined from D.

But it is again the same problem we already discussed before...


/// Output points
pub values: Array<f64, D>,
}

/// A grid slice is always 1-Dimensional
#[derive(Debug)]
pub struct GridSlice<'a> {
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Eventually, you started playing with lifetimes...

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't understand why the compiler is not able to deduce all of them though.

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There are rules to infer references lifetimes for methods, but even those are limited (to simplest but also most common cases, so they are very useful). You definitely want to infer as much as possible, but consistency is hard to impose when you have many rules (I recently read about attempts to introduce linear types, and they are a mess...).

I'm not sure, but I believe the reason you need to make them explicit should be that there is an alternative way of doing, e.g. having more lifetimes and some of them explicitly specified (like generics), or something like that. Maybe?

/// A reference to one of the input vectors of the grid
pub x: &'a Vec<f64>,
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are you sure you need a reference to a vector, and not a slice?

(with slice I mean &'a [f64], as per standard library)

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Depends on what you mean with “need”, that might also work (and be more correct?), but x is a reference to one of the quantities in the xgrid which are vectors hence the choice.

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If you need to use vector-specific operations, then yes, you need the reference to a vector.
But if you need to reference just a collection of points, the slice is more general, and it relies less on the underlying structure (while still be plenty of possible methods, https://doc.rust-lang.org/std/primitive.slice.html).

However, if it has to be one entire element (the full x coordinates) maybe it makes sense to keep as it is. I'd suggest instead to give it a name:

type XPoint = Vec<f64>;

(if you're able to find a better name, you're welcome). It's just an alias, for greater clarity, but here a newtype would be too much (I believe).

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, no. I do need a vector because of the search function.

But it is not a point, it is actually the x grid for \vec{y} = f(\vec{x}). I could call them xgrid, I decided to change it to x since it is not exposed to the user and only used internally (as opposed to the xgrid which needs to be filled by the user).

But I can use xgrid as well here if you prefer (but I'll do the change directly in #5)

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

"better name" was referring to the name of the type alias I was proposing. x for the attribute is just fine.

The reason behind the alias is that an XPoint is isomorphic to a Vec<f64>, but not all Vec<f64> are XPoints. It doesn't help with automatic check, because an alias is interchangeable with the aliased type, but it helps the developer reading it (to have help also from the compiler, we could use a newtype, but as I said it seems to me that it's not worth, at the moment)

/// A view of the slice of values corresponding to x
pub y: ArrayView1<'a, f64>,
}

impl Grid {
impl<'a> GridSlice<'a> {
// TODO: at the moment we are using here the derivatives that LHAPDF is using for the
// interpolation in alpha_s, these are probably enough for this use case but not in general
// - [ ] Implement a more robust form of the derivative
Expand All @@ -33,9 +48,9 @@ impl Grid {
/// input at position index as the ratio between the differences dy/dx computed as:
/// dy = y_{i} - y_{i-1}
/// dx = x_{i} - x_{x-1}
pub fn derivative_at(&self, index: usize) -> f64 {
let dx = self.input[index] - self.input[index - 1];
let dy = self.values[index] - self.values[index - 1];
pub fn derivative_at(&'a self, index: usize) -> f64 {
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The annoying part of lifetimes is that you then need to specify them everywhere. Luckily not in calls, but in all types yes...

let dx = self.x[index] - self.x[index - 1];
let dy = self.y[index] - self.y[index - 1];
dy / dx
}

Expand All @@ -44,53 +59,66 @@ impl Grid {
///
/// Dx_{i} = \Delta x_{i} = x_{i} - x_{i-}
/// y'_{i} = 1/2 * ( (y_{i+1}-y_{i})/Dx_{i+1} + (y_{i}-y_{i-1})/Dx_{i} )
pub fn central_derivative_at(&self, index: usize) -> f64 {
pub fn central_derivative_at(&'a self, index: usize) -> f64 {
let dy_f = self.derivative_at(index + 1);
let dy_b = self.derivative_at(index);
0.5 * (dy_f + dy_b)
}
}

/// Find the index of the last value in the input such that input(idx) < query
impl<D: Dimension> Grid<D> {
/// Find the index of the last value in the input xgrid such that xgrid(idx) < query
/// If the query is outside the grid returns an extrapolation error
pub fn closest_below(&self, query: f64) -> Result<usize, InterpolationError> {
if query > self.input[self.input.len() - 1] {
Err(InterpolationError::ExtrapolationAbove(query))
} else if query < self.input[0] {
Err(InterpolationError::ExtrapolationBelow(query))
} else {
let u_idx = self.input.iter().position(|x| x > &query).unwrap();
let idx = u_idx - 1;
Ok(idx)
pub fn closest_below<const N: usize>(
&self,
input_query: &[f64],
) -> Result<[usize; N], InterpolationError> {
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here, I believe you'd like to use something like D::NDIM.

Unfortunately, D::NDIM: Option<usize>, so you'd need something like .unwrap(), but executed at compile time.

https://docs.rs/ndarray/0.15.6/ndarray/trait.Dimension.html#associatedconstant.NDIM

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A simpler alternative might be to split the bulk of the function, and then implement thin wrappers around (one per Dimension).
Even the implementors of Dimension itself are just 7 (+ 1).

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Otherwise, you should be able to use a const function to unwrap it, defining an associated constant, and try to use the associated constant in the signature.

https://users.rust-lang.org/t/compile-time-const-unwrapping/51619

EDIT: Unfortunately, const panic is not stabilized, because of an issue with Rust 2021 rust-lang/rust#85194, maybe you could check if const pattern matching is allowed

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A const function to unwrap IxN at compile time you mean?

Copy link
Owner

@alecandido alecandido Jul 8, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Indeed, in order to turn it into a usize (if you can not panic, you can always turn IxDyn into a 0, it's useless anyhow, and it should not happen, so I wouldn't care that much).

Copy link
Collaborator Author

@scarlehoff scarlehoff Jul 11, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this cannot be done. Or at least I cannot find a way of passing the D to the const function:

generic parameters may not be used in const operations
type parameter may not be used in const expressions

and in any case, if I use a function that could depend on one of the parameters:

const parameters may only be used as standalone arguments

Not sure whether that means there's no way of doing what I want to do (without creating a path per dimension, at that point I rather leave the N)

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd really like to solve this issue. But if it's not simple, we could postpone.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd like to as well, but I haven't been able to...

let mut ret = [0; N];

for (r, &query, igrid) in izip!(&mut ret, input_query, &self.xgrid) {
if query > *igrid.last().unwrap() {
return Err(InterpolationError::ExtrapolationAbove(query));
} else if query < igrid[0] {
return Err(InterpolationError::ExtrapolationBelow(query));
}

let u_idx = igrid.partition_point(|&x| x < query);
*r = u_idx - 1;
}
Ok(ret)
}
}

#[cfg(test)]
mod tests {
use super::*;
use ndarray::array;
use ndarray::{array, Ix1};

fn gen_grid() -> Grid {
let x = array![0., 1., 2., 3., 4.];
fn gen_grid() -> Grid<Ix1> {
let x = vec![vec![0., 1., 2., 3., 4.]];
let y = array![4., 3., 2., 1., 1.];
let grid = Grid {
input: x,

Grid {
xgrid: x,
values: y,
};
grid
}
}

#[test]
fn check_derivative() {
let grid = gen_grid();
assert_eq!(grid.central_derivative_at(1), -1.);
assert_eq!(grid.central_derivative_at(3), -0.5);
let grid_slice = GridSlice {
x: &grid.xgrid[0],
y: grid.values.view(),
};
assert_eq!(grid_slice.central_derivative_at(1), -1.);
assert_eq!(grid_slice.central_derivative_at(3), -0.5);
}

#[test]
fn check_index_search() {
let grid = gen_grid();
assert_eq!(grid.closest_below(0.5).unwrap(), 0);
assert_eq!(grid.closest_below(3.2).unwrap(), 3);
assert_eq!(grid.closest_below::<1>(&[0.5]).unwrap()[0], 0);
assert_eq!(grid.closest_below::<1>(&[3.2]).unwrap()[0], 3);
}
}
65 changes: 42 additions & 23 deletions ndinterp/src/grid/cubic.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
//! Implements cubic interpolation algorithms

use crate::grid::Grid;
use crate::grid::{Grid, GridSlice};
use crate::interpolate::InterpolationError;
pub use crate::interpolate::Interpolator;
use ndarray::Ix1;

/// Cubic interpolation in 1D
///
Expand All @@ -18,40 +19,35 @@ pub use crate::interpolate::Interpolator;
/// with hij the Hermite basis functions
#[derive(Debug)]
pub struct Cubic1d {
pub grid: Grid,
/// The grid object contains all necessary information to perform the interpolation
pub grid: Grid<Ix1>,
}

impl Interpolator<f64> for Cubic1d {
/// Use Cubic interpolation 1d to compute y(query)
/// The interpolation uses the two nearest neighbours and their derivatives computed as an
/// average of the differences above and below.
///
/// Special cases are considered when the interpolation occurs between the first (last) two
/// bins, where the derivative would involve points outside the grids.
///
/// Two special are considered, when the interpolation occurs between the first (last) two
/// bins, the derivative at the boundary is approximated by the forward (backward) difference
fn interpolate(&self, query: f64) -> Result<f64, InterpolationError> {
let idx = self.grid.closest_below(query)?;
let dx = self.grid.input[idx + 1] - self.grid.input[idx];
impl<'a> GridSlice<'a> {
/// Implements utilities for a GridSlice that can be used by cubic interpolation Nd
/// Takes as input the value being queried and its index within the given slice
fn cubic_interpolate_1d(&'a self, query: f64, idx: usize) -> Result<f64, InterpolationError> {
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does it make sense to call it 1d? A slice is always 1D... (or are you already planning to extend beyond?)

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Who knows… might be necessary in the future.

// grid slice utilities are expected to be called multipled times for the same
// query and so it is convient to pass idx from the outside to avoid expensive searches
let dx = self.x[idx + 1] - self.x[idx];

// Upper and lower bounds and derivatives
let yu = self.grid.values[idx + 1];
let yl = self.grid.values[idx];
let yu = self.y[idx + 1];
let yl = self.y[idx];

let dydxu = if idx == self.grid.input.len() - 2 {
dx * self.grid.derivative_at(idx + 1)
let dydxu = if idx == self.x.len() - 2 {
dx * self.derivative_at(idx + 1)
} else {
dx * self.grid.central_derivative_at(idx + 1)
dx * self.central_derivative_at(idx + 1)
};

let dydxl = if idx == 0 {
dx * self.grid.derivative_at(idx + 1)
dx * self.derivative_at(idx + 1)
} else {
dx * self.grid.central_derivative_at(idx)
dx * self.central_derivative_at(idx)
};

let t = (query - self.grid.input[idx]) / dx;
let t = (query - self.x[idx]) / dx;
let t2 = t * t;
let t3 = t2 * t;

Expand All @@ -63,3 +59,26 @@ impl Interpolator<f64> for Cubic1d {
Ok(p0 + p1 + m0 + m1)
}
}

impl Interpolator<f64> for Cubic1d {
/// Use Cubic interpolation 1d to compute y(query)
/// The interpolation uses the two nearest neighbours and their derivatives computed as an
/// average of the differences above and below.
///
/// Special cases are considered when the interpolation occurs between the first (last) two
/// bins, where the derivative would involve points outside the grids.
///
/// Two special are considered, when the interpolation occurs between the first (last) two
/// bins, the derivative at the boundary is approximated by the forward (backward) difference
fn interpolate(&self, query: f64) -> Result<f64, InterpolationError> {
let raw_idx = self.grid.closest_below::<1>(&[query])?;
let idx = raw_idx[0];

let grid_sl = GridSlice {
x: &self.grid.xgrid[0],
y: self.grid.values.view(),
};

grid_sl.cubic_interpolate_1d(query, idx)
}
}
65 changes: 34 additions & 31 deletions ndinterp/src/interpolate.rs
Original file line number Diff line number Diff line change
@@ -1,46 +1,49 @@
use std::iter::zip;
//! This module implements interpolation rutines
use thiserror::Error;

use crate::metric::Metric;

/// Errors encountered during interpolation
#[derive(Debug, Error)]
pub enum InterpolationError {
/// Raised when the queried value is above the maximum
#[error("The value queried ({0}) is above the maximum")]
ExtrapolationAbove(f64),

/// Raised when the queried value is below the minimum
#[error("The value queried ({0}) is below the minimum")]
ExtrapolationBelow(f64),
}

/// Methods which all interpolator must implement
pub trait Interpolator<T> {
/// Produce the result of the inteprolation given a (nd) point 'query'
fn interpolate(&self, query: T) -> Result<f64, InterpolationError>;
}

/// ---- deal with the stuff below later ----
pub trait Interpolate {
type Point: Metric;

fn interpolate(&self, query: &Self::Point) -> f64;
}

pub struct Input<Point: Metric> {
pub point: Point,
pub value: f64,
}

impl<Point: Metric> From<(Point, f64)> for Input<Point> {
fn from(item: (Point, f64)) -> Self {
Self {
point: item.0,
value: item.1,
}
}
}

impl<Point: Metric> Input<Point> {
pub fn stack(points: Vec<Point>, values: Vec<f64>) -> Vec<Self> {
zip(points.into_iter(), values.into_iter())
.map(|t| t.into())
.collect()
}
}
///// ---- deal with the stuff below later ----
//pub trait Interpolate {
// type Point: Metric;
//
// fn interpolate(&self, query: &Self::Point) -> f64;
//}
//
//pub struct Input<Point: Metric> {
// pub point: Point,
// pub value: f64,
//}
//
//impl<Point: Metric> From<(Point, f64)> for Input<Point> {
// fn from(item: (Point, f64)) -> Self {
// Self {
// point: item.0,
// value: item.1,
// }
// }
//}
//
//impl<Point: Metric> Input<Point> {
// pub fn stack(points: Vec<Point>, values: Vec<f64>) -> Vec<Self> {
// zip(points.into_iter(), values.into_iter())
// .map(|t| t.into())
// .collect()
// }
//}
8 changes: 3 additions & 5 deletions ndinterp/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
#![warn(clippy::all, clippy::cargo, clippy::nursery, clippy::pedantic)]
#![warn(clippy::all, clippy::cargo)]
#![warn(missing_docs)]
#![doc = include_str!("../README.md")]
#![warn(clippy::all, clippy::pedantic, clippy::restriction, clippy::cargo)]
#![warn(missing_docs)]

pub mod grid;
pub mod interpolate;
pub mod metric;
pub mod scatter;
//pub mod metric;
//pub mod scatter;
Loading