Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update the code for n-d #4

Merged
merged 9 commits into from
Jul 20, 2023
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion ndinterp/Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,16 +1,18 @@
[package]
name = "ndinterp"
version = "0.0.1"
authors = ["Alessandro Candido <alessandro.candido@sns.it>"]
authors = ["Alessandro Candido <alessandro.candido@sns.it>", "Juan Cruz-Martinez <juacrumar@lairen.eu>"]
edition = "2021"
license = "GPL-3.0-or-later"
repository = "https://github.com/AleCandido/ndinterp"
readme = "README.md"
categories = ["science"]
description = "N-dimensional interpolation library"
keywords = ["math", "science"]

[dependencies]
ndarray = "0.15.4"
ndarray-linalg = "0.14.1"
petgraph = "0.6.2"
thiserror = "1.0.40"
itertools = "0.11.0"
2 changes: 1 addition & 1 deletion ndinterp/examples/alphas.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ fn main() {
];

let grid = Grid {
input: logq2,
input: vec![logq2.to_vec()],
values: alpha_s_vals,
};

Expand Down
61 changes: 38 additions & 23 deletions ndinterp/src/grid.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,19 +10,25 @@
//! y = f(x1, x2, x3...)
//!
use crate::interpolate::InterpolationError;
use ndarray::Array1;
use itertools::izip;
use ndarray::{Array, Dimension, Ix1};

// Make public the families of interpolation algorithms implemented for grids
pub mod cubic;

/// A grid is made of two components:
/// A d-dimensional vector of 1-dimensional sorted vectors for the input points
/// A d-dimensional array for the grid values of
#[derive(Debug)]
pub struct Grid {
/// A grid is made of two (1-dimensional) sorted arrays.
pub input: Array1<f64>,
pub values: Array1<f64>,
pub struct Grid<D: Dimension> {
/// Arrays with the input vectors
pub input: Vec<Vec<f64>>,
scarlehoff marked this conversation as resolved.
Show resolved Hide resolved

/// Output points
pub values: Array<f64, D>,
}

impl Grid {
impl Grid<Ix1> {
// TODO: at the moment we are using here the derivatives that LHAPDF is using for the
// interpolation in alpha_s, these are probably enough for this use case but not in general
// - [ ] Implement a more robust form of the derivative
Expand All @@ -34,7 +40,7 @@ impl Grid {
/// dy = y_{i} - y_{i-1}
/// dx = x_{i} - x_{x-1}
pub fn derivative_at(&self, index: usize) -> f64 {
let dx = self.input[index] - self.input[index - 1];
let dx = self.input[0][index] - self.input[0][index - 1];
let dy = self.values[index] - self.values[index - 1];
dy / dx
}
Expand All @@ -49,19 +55,28 @@ impl Grid {
let dy_b = self.derivative_at(index);
0.5 * (dy_f + dy_b)
}
}

impl<D: Dimension> Grid<D> {
/// Find the index of the last value in the input such that input(idx) < query
/// If the query is outside the grid returns an extrapolation error
pub fn closest_below(&self, query: f64) -> Result<usize, InterpolationError> {
if query > self.input[self.input.len() - 1] {
Err(InterpolationError::ExtrapolationAbove(query))
} else if query < self.input[0] {
Err(InterpolationError::ExtrapolationBelow(query))
} else {
let u_idx = self.input.iter().position(|x| x > &query).unwrap();
let idx = u_idx - 1;
Ok(idx)
pub fn closest_below<const N: usize>(
&self,
input_query: &[f64],
) -> Result<[usize; N], InterpolationError> {
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here, I believe you'd like to use something like D::NDIM.

Unfortunately, D::NDIM: Option<usize>, so you'd need something like .unwrap(), but executed at compile time.

https://docs.rs/ndarray/0.15.6/ndarray/trait.Dimension.html#associatedconstant.NDIM

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A simpler alternative might be to split the bulk of the function, and then implement thin wrappers around (one per Dimension).
Even the implementors of Dimension itself are just 7 (+ 1).

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Otherwise, you should be able to use a const function to unwrap it, defining an associated constant, and try to use the associated constant in the signature.

https://users.rust-lang.org/t/compile-time-const-unwrapping/51619

EDIT: Unfortunately, const panic is not stabilized, because of an issue with Rust 2021 rust-lang/rust#85194, maybe you could check if const pattern matching is allowed

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A const function to unwrap IxN at compile time you mean?

Copy link
Owner

@alecandido alecandido Jul 8, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Indeed, in order to turn it into a usize (if you can not panic, you can always turn IxDyn into a 0, it's useless anyhow, and it should not happen, so I wouldn't care that much).

Copy link
Collaborator Author

@scarlehoff scarlehoff Jul 11, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this cannot be done. Or at least I cannot find a way of passing the D to the const function:

generic parameters may not be used in const operations
type parameter may not be used in const expressions

and in any case, if I use a function that could depend on one of the parameters:

const parameters may only be used as standalone arguments

Not sure whether that means there's no way of doing what I want to do (without creating a path per dimension, at that point I rather leave the N)

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd really like to solve this issue. But if it's not simple, we could postpone.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd like to as well, but I haven't been able to...

let mut ret = [0; N];

for (r, &query, igrid) in izip!(&mut ret, input_query, &self.input) {
if query > *igrid.last().unwrap() {
return Err(InterpolationError::ExtrapolationAbove(query));
} else if query < igrid[0] {
return Err(InterpolationError::ExtrapolationBelow(query));
}

let u_idx = igrid.partition_point(|&x| x < query);
*r = u_idx - 1;
}
Ok(ret)
}
}

Expand All @@ -70,14 +85,14 @@ mod tests {
use super::*;
use ndarray::array;

fn gen_grid() -> Grid {
let x = array![0., 1., 2., 3., 4.];
fn gen_grid() -> Grid<Ix1> {
let x = vec![vec![0., 1., 2., 3., 4.]];
let y = array![4., 3., 2., 1., 1.];
let grid = Grid {

Grid {
input: x,
values: y,
};
grid
}
}

#[test]
Expand All @@ -90,7 +105,7 @@ mod tests {
#[test]
fn check_index_search() {
let grid = gen_grid();
assert_eq!(grid.closest_below(0.5).unwrap(), 0);
assert_eq!(grid.closest_below(3.2).unwrap(), 3);
assert_eq!(grid.closest_below::<1>(&[0.5]).unwrap()[0], 0);
assert_eq!(grid.closest_below::<1>(&[3.2]).unwrap()[0], 3);
}
}
13 changes: 8 additions & 5 deletions ndinterp/src/grid/cubic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
use crate::grid::Grid;
use crate::interpolate::InterpolationError;
pub use crate::interpolate::Interpolator;
use ndarray::Ix1;

/// Cubic interpolation in 1D
///
Expand All @@ -18,7 +19,8 @@ pub use crate::interpolate::Interpolator;
/// with hij the Hermite basis functions
#[derive(Debug)]
pub struct Cubic1d {
pub grid: Grid,
/// The grid object contains all necessary information to perform the interpolation
pub grid: Grid<Ix1>,
}

impl Interpolator<f64> for Cubic1d {
Expand All @@ -32,14 +34,15 @@ impl Interpolator<f64> for Cubic1d {
/// Two special are considered, when the interpolation occurs between the first (last) two
/// bins, the derivative at the boundary is approximated by the forward (backward) difference
fn interpolate(&self, query: f64) -> Result<f64, InterpolationError> {
let idx = self.grid.closest_below(query)?;
let dx = self.grid.input[idx + 1] - self.grid.input[idx];
let raw_idx = self.grid.closest_below::<1>(&[query])?;
let idx = raw_idx[0];
let dx = self.grid.input[0][idx + 1] - self.grid.input[0][idx];
scarlehoff marked this conversation as resolved.
Show resolved Hide resolved

// Upper and lower bounds and derivatives
let yu = self.grid.values[idx + 1];
let yl = self.grid.values[idx];

let dydxu = if idx == self.grid.input.len() - 2 {
let dydxu = if idx == self.grid.input[0].len() - 2 {
dx * self.grid.derivative_at(idx + 1)
} else {
dx * self.grid.central_derivative_at(idx + 1)
Expand All @@ -51,7 +54,7 @@ impl Interpolator<f64> for Cubic1d {
dx * self.grid.central_derivative_at(idx)
};

let t = (query - self.grid.input[idx]) / dx;
let t = (query - self.grid.input[0][idx]) / dx;
let t2 = t * t;
let t3 = t2 * t;

Expand Down
65 changes: 34 additions & 31 deletions ndinterp/src/interpolate.rs
Original file line number Diff line number Diff line change
@@ -1,46 +1,49 @@
use std::iter::zip;
//! This module implements interpolation rutines
use thiserror::Error;

use crate::metric::Metric;

/// Errors encountered during interpolation
#[derive(Debug, Error)]
pub enum InterpolationError {
/// Raised when the queried value is above the maximum
#[error("The value queried ({0}) is above the maximum")]
ExtrapolationAbove(f64),

/// Raised when the queried value is below the minimum
#[error("The value queried ({0}) is below the minimum")]
ExtrapolationBelow(f64),
}

/// Methods which all interpolator must implement
pub trait Interpolator<T> {
/// Produce the result of the inteprolation given a (nd) point 'query'
fn interpolate(&self, query: T) -> Result<f64, InterpolationError>;
}

/// ---- deal with the stuff below later ----
pub trait Interpolate {
type Point: Metric;

fn interpolate(&self, query: &Self::Point) -> f64;
}

pub struct Input<Point: Metric> {
pub point: Point,
pub value: f64,
}

impl<Point: Metric> From<(Point, f64)> for Input<Point> {
fn from(item: (Point, f64)) -> Self {
Self {
point: item.0,
value: item.1,
}
}
}

impl<Point: Metric> Input<Point> {
pub fn stack(points: Vec<Point>, values: Vec<f64>) -> Vec<Self> {
zip(points.into_iter(), values.into_iter())
.map(|t| t.into())
.collect()
}
}
///// ---- deal with the stuff below later ----
//pub trait Interpolate {
// type Point: Metric;
//
// fn interpolate(&self, query: &Self::Point) -> f64;
//}
//
//pub struct Input<Point: Metric> {
// pub point: Point,
// pub value: f64,
//}
//
//impl<Point: Metric> From<(Point, f64)> for Input<Point> {
// fn from(item: (Point, f64)) -> Self {
// Self {
// point: item.0,
// value: item.1,
// }
// }
//}
//
//impl<Point: Metric> Input<Point> {
// pub fn stack(points: Vec<Point>, values: Vec<f64>) -> Vec<Self> {
// zip(points.into_iter(), values.into_iter())
// .map(|t| t.into())
// .collect()
// }
//}
8 changes: 3 additions & 5 deletions ndinterp/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
#![warn(clippy::all, clippy::cargo, clippy::nursery, clippy::pedantic)]
#![warn(clippy::all, clippy::cargo)]
#![warn(missing_docs)]
#![doc = include_str!("../README.md")]
#![warn(clippy::all, clippy::pedantic, clippy::restriction, clippy::cargo)]
#![warn(missing_docs)]

pub mod grid;
pub mod interpolate;
pub mod metric;
pub mod scatter;
//pub mod metric;
//pub mod scatter;
9 changes: 8 additions & 1 deletion ndinterp_capi/Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,14 +1,21 @@
[package]
name = "ndinterp_capi"
version = "0.0.1"
authors = ["Alessandro Candido <alessandro.candido@sns.it>", "Juan Cruz-Martinez <juacrumar@lairen.eu>"]
scarlehoff marked this conversation as resolved.
Show resolved Hide resolved
edition = "2021"
license = "GPL-3.0-or-later"
repository = "https://github.com/AleCandido/ndinterp/ndinterp_capi"
readme = "README.md"
categories = ["science"]
description = "C bindings for ndinterp"
keywords = ["math", "science"]

[dependencies]
ndarray = "0.15.4"
ndinterp = { path = "../ndinterp/" }

[lib]
name = "ndinterp"
name = "ndinterp_capi"
crate-type = ["cdylib"]

[features]
Expand Down
4 changes: 2 additions & 2 deletions ndinterp_capi/benches/Makefile
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
CXX = c++
CXXFLAGS = -std=c++17
# Assuming the lib has been installed with cinstall
NDFLAGS = $(shell pkg-config ndinterp --libs) $(shell pkg-config ndinterp --cflags)
NDFLAGS = $(shell pkg-config ndinterp_capi --libs) $(shell pkg-config ndinterp_capi --cflags)
LHAFLAGS = $(shell lhapdf-config --libs --cflags)

lharun: lhacheck
./lhacheck

lhacheck: lhacheck.cpp
$(CXX) $(CXXFLAGS) $< $(NDFLAGS) $(LHAFLAGS) -o $@
$(CXX) $(CXXFLAGS) -g $< $(NDFLAGS) $(LHAFLAGS) -o $@

clean:
rm -f lhacheck
2 changes: 1 addition & 1 deletion ndinterp_capi/benches/lhacheck.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#include "LHAPDF/LHAPDF.h"
#include "ndinterp.h"
#include "ndinterp_capi.h"
#include <algorithm>
#include <chrono>
#include <cmath>
Expand Down
14 changes: 10 additions & 4 deletions ndinterp_capi/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
#![warn(clippy::all, clippy::cargo, clippy::nursery, clippy::pedantic)]
//! C interface for ndinterp
#![warn(clippy::all, clippy::cargo)]
#![warn(missing_docs)]

use core::slice;

use ndarray::ArrayView1;
/// C interface for ndinterp
use ndinterp::grid;
use ndinterp::interpolate::Interpolator;

/// cubic1d inteprolator
pub struct Cubic1d;

/// Creates a cubic1d interpolator given the nodes
Expand All @@ -21,11 +24,14 @@ pub unsafe extern "C" fn create_cubic_interpolator1d(
values_c: *const f64,
size: usize,
) -> Box<grid::cubic::Cubic1d> {
let input = ArrayView1::from_shape_ptr(size, input_c);
// Use slice instead of vec so that rust doesn't take ownership of the data and releases
// the burden is on the function calling them
let slice_input = unsafe { slice::from_raw_parts(input_c, size) };
scarlehoff marked this conversation as resolved.
Show resolved Hide resolved
let input = vec![slice_input.to_vec()];
cschwan marked this conversation as resolved.
Show resolved Hide resolved
let values = ArrayView1::from_shape_ptr(size, values_c);

let grid = grid::Grid {
input: input.into_owned(),
input,
values: values.into_owned(),
};
let cubic_interpolator = grid::cubic::Cubic1d { grid };
Expand Down