Update argmin in linfa-linear to version 0.8.0 #289

Merged · 5 commits · Feb 10, 2023
2 changes: 1 addition & 1 deletion .github/workflows/checking.yml
@@ -39,4 +39,4 @@ jobs:
run: cargo check --workspace --all-targets

- name: Run cargo check (with serde)
run: cargo check --workspace --all-targets --features "linfa-clustering/serde linfa-ica/serde linfa-kernel/serde linfa-reduction/serde linfa-svm/serde linfa-elasticnet/serde linfa-pls/serde linfa-trees/serde linfa-nn/serde"
run: cargo check --workspace --all-targets --features "linfa-clustering/serde linfa-ica/serde linfa-kernel/serde linfa-reduction/serde linfa-svm/serde linfa-elasticnet/serde linfa-pls/serde linfa-trees/serde linfa-nn/serde linfa-linear/serde"
16 changes: 12 additions & 4 deletions algorithms/linfa-linear/Cargo.toml
@@ -17,18 +17,26 @@ keywords = ["machine-learning", "linfa", "ai", "ml", "linear"]
categories = ["algorithms", "mathematics", "science"]

[features]
blas = ["ndarray-linalg", "linfa/ndarray-linalg", "argmin/ndarray-linalg"]
blas = ["ndarray-linalg", "linfa/ndarray-linalg"]
serde = ["serde_crate", "linfa/serde", "ndarray/serde", "argmin/serde1"]

[dependencies.serde_crate]
package = "serde"
optional = true
version = "1.0"
default-features = false
features = ["std", "derive"]

[dependencies]
ndarray = { version = "0.15", features = ["approx"] }
linfa-linalg = { version = "0.1", default-features = false }
ndarray-linalg = { version = "0.15", optional = true }
num-traits = "0.2"
argmin = { version = "0.4.6", features = ["ndarray", "ndarray-rand"] }
serde = { version = "1.0", default-features = false, features = ["derive"] }
argmin = { version = "0.7", default-features = false }
argmin-math = { version = "0.2", features = ["ndarray_v0_15-nolinalg"] }
thiserror = "1.0"

linfa = { version = "0.6.1", path = "../..", features=["serde"] }
linfa = { version = "0.6.1", path = "../.." }

[dev-dependencies]
linfa-datasets = { version = "0.6.1", path = "../../datasets", features = ["diabetes"] }
74 changes: 3 additions & 71 deletions algorithms/linfa-linear/src/float.rs
@@ -1,20 +1,12 @@
use argmin::prelude::{ArgminAdd, ArgminDot, ArgminFloat, ArgminMul, ArgminNorm, ArgminSub};
use ndarray::{Array1, NdFloat};
use argmin::core::ArgminFloat;
use ndarray::NdFloat;
use num_traits::float::FloatConst;
use num_traits::FromPrimitive;
use serde::{Deserialize, Serialize};

// A Float trait that captures the requirements we need for the various places
// we need floats. These requirements are imposed by ndarray and argmin.
pub trait Float:
ArgminFloat
+ FloatConst
+ NdFloat
+ Default
+ Clone
+ FromPrimitive
+ ArgminMul<ArgminParam<Self>, ArgminParam<Self>>
+ linfa::Float
ArgminFloat + FloatConst + NdFloat + Default + Clone + FromPrimitive + linfa::Float
{
const POSITIVE_LABEL: Self;
const NEGATIVE_LABEL: Self;
@@ -29,63 +21,3 @@ impl Float for f64 {
const POSITIVE_LABEL: Self = 1.0;
const NEGATIVE_LABEL: Self = -1.0;
}

impl ArgminMul<ArgminParam<Self>, ArgminParam<Self>> for f64 {
fn mul(&self, other: &ArgminParam<Self>) -> ArgminParam<Self> {
ArgminParam(&other.0 * *self)
}
}

impl ArgminMul<ArgminParam<Self>, ArgminParam<Self>> for f32 {
fn mul(&self, other: &ArgminParam<Self>) -> ArgminParam<Self> {
ArgminParam(&other.0 * *self)
}
}

// Here we create a new type over ndarray's Array1. This is required
// to implement traits required by argmin
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
pub struct ArgminParam<A>(pub Array1<A>);

impl<A> ArgminParam<A> {
#[inline]
pub fn as_array(&self) -> &Array1<A> {
&self.0
}
}

impl<A: Float> ArgminSub<ArgminParam<A>, ArgminParam<A>> for ArgminParam<A> {
fn sub(&self, other: &ArgminParam<A>) -> ArgminParam<A> {
ArgminParam(&self.0 - &other.0)
}
}

impl<A: Float> ArgminAdd<ArgminParam<A>, ArgminParam<A>> for ArgminParam<A> {
fn add(&self, other: &ArgminParam<A>) -> ArgminParam<A> {
ArgminParam(&self.0 + &other.0)
}
}

impl<A: Float> ArgminDot<ArgminParam<A>, A> for ArgminParam<A> {
fn dot(&self, other: &ArgminParam<A>) -> A {
self.0.dot(&other.0)
}
}

impl<A: Float> ArgminNorm<A> for ArgminParam<A> {
fn norm(&self) -> A {
self.0.dot(&self.0)
}
}

impl<A: Float> ArgminMul<A, ArgminParam<A>> for ArgminParam<A> {
fn mul(&self, other: &A) -> ArgminParam<A> {
ArgminParam(&self.0 * *other)
}
}

impl<A: Float> ArgminMul<ArgminParam<A>, ArgminParam<A>> for ArgminParam<A> {
fn mul(&self, other: &ArgminParam<A>) -> ArgminParam<A> {
ArgminParam(&self.0 * &other.0)
}
}
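The deleted `ArgminParam` newtype existed only to hand-implement argmin's math traits for `Array1`. In argmin 0.7 those impls live in the separate argmin-math crate, so `Array1<F>` can serve as the parameter type directly. A minimal sketch of what the migration relies on, assuming argmin-math 0.2 with the `ndarray_v0_15-nolinalg` feature as declared in the updated Cargo.toml above (the trait bounds used here match the where-clause added to glm/mod.rs below):

```rust
use argmin_math::{ArgminAdd, ArgminDot, ArgminL2Norm, ArgminMul, ArgminSub};
use ndarray::{array, Array1};

fn main() {
    let a: Array1<f64> = array![1.0, 2.0];
    let b: Array1<f64> = array![3.0, 4.0];

    // argmin-math provides these impls for Array1 out of the box,
    // so no newtype wrapper is needed anymore.
    let sum: Array1<f64> = ArgminAdd::add(&a, &b); // element-wise addition
    let diff: Array1<f64> = ArgminSub::sub(&a, &b); // element-wise subtraction
    let dot: f64 = ArgminDot::dot(&a, &b); // inner product
    let scaled: Array1<f64> = ArgminMul::mul(&a, &2.0); // scalar multiplication
    let norm: f64 = a.l2_norm(); // Euclidean norm (successor of the old ArgminNorm)

    println!("{sum} {diff} {dot} {scaled} {norm}");
}
```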
10 changes: 8 additions & 2 deletions algorithms/linfa-linear/src/glm/hyperparams.rs
@@ -1,9 +1,15 @@
use crate::{glm::link::Link, LinearError, TweedieRegressor};
use linfa::{Float, ParamGuard};
use serde::{Deserialize, Serialize};
#[cfg(feature = "serde")]
use serde_crate::{Deserialize, Serialize};

/// The set of hyperparameters that can be specified for the execution of the Tweedie Regressor.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(
feature = "serde",
derive(Serialize, Deserialize),
serde(crate = "serde_crate")
)]
pub struct TweedieRegressorValidParams<F> {
alpha: F,
fit_intercept: bool,
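The same serde pattern recurs in every file below: serde is imported under the renamed package `serde_crate` (see the Cargo.toml diff above), the derives are gated behind the `serde` feature with `cfg_attr`, and `serde(crate = "serde_crate")` points the generated code at the renamed import. A self-contained sketch of the pattern (the `Example` struct is hypothetical):

```rust
// Cargo.toml (as in the diff above):
//
// [features]
// serde = ["serde_crate"]
//
// [dependencies.serde_crate]
// package = "serde"          # import `serde` under the name `serde_crate`
// version = "1.0"
// optional = true
// default-features = false
// features = ["std", "derive"]

#[cfg(feature = "serde")]
use serde_crate::{Deserialize, Serialize};

// Without the `serde` feature this compiles to a plain struct; with it,
// Serialize/Deserialize are derived against the renamed crate.
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(
    feature = "serde",
    derive(Serialize, Deserialize),
    serde(crate = "serde_crate")
)]
pub struct Example {
    pub value: f64,
}
```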
10 changes: 8 additions & 2 deletions algorithms/linfa-linear/src/glm/link.rs
@@ -1,11 +1,17 @@
//! Link functions used by GLM

use ndarray::Array1;
use serde::{Deserialize, Serialize};
#[cfg(feature = "serde")]
use serde_crate::{Deserialize, Serialize};

use crate::float::Float;

#[derive(Copy, Clone, Debug, Serialize, Deserialize, PartialEq, Eq, Hash)]
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
#[cfg_attr(
feature = "serde",
derive(Serialize, Deserialize),
serde(crate = "serde_crate")
)]
/// Link functions used by GLM
pub enum Link {
/// The identity link function `g(x)=x`
59 changes: 40 additions & 19 deletions algorithms/linfa-linear/src/glm/mod.rs
@@ -5,25 +5,43 @@ mod hyperparams;
mod link;

use crate::error::{LinearError, Result};
use crate::float::{ArgminParam, Float};
use crate::float::Float;
use argmin_math::{
ArgminAdd, ArgminDot, ArgminL1Norm, ArgminL2Norm, ArgminMinMax, ArgminMul, ArgminSignum,
ArgminSub, ArgminZero,
};
use distribution::TweedieDistribution;
pub use hyperparams::TweedieRegressorParams;
pub use hyperparams::TweedieRegressorValidParams;
use linfa::dataset::AsSingleTargets;
pub use link::Link;

use argmin::core::{ArgminOp, Executor};
use argmin::core::{CostFunction, Executor, Gradient};
use argmin::solver::linesearch::MoreThuenteLineSearch;
use argmin::solver::quasinewton::LBFGS;
use ndarray::{array, concatenate, s};
use ndarray::{Array, Array1, ArrayBase, ArrayView1, ArrayView2, Axis, Data, Ix2};
use serde::{Deserialize, Serialize};
#[cfg(feature = "serde")]
use serde_crate::{Deserialize, Serialize};

use linfa::traits::*;
use linfa::DatasetBase;

impl<F: Float, D: Data<Elem = F>, T: AsSingleTargets<Elem = F>>
Fit<ArrayBase<D, Ix2>, T, LinearError<F>> for TweedieRegressorValidParams<F>
where
Array1<F>: ArgminAdd<Array1<F>, Array1<F>>
+ ArgminSub<Array1<F>, Array1<F>>
+ ArgminSub<F, Array1<F>>
+ ArgminAdd<F, Array1<F>>
+ ArgminMul<F, Array1<F>>
+ ArgminMul<Array1<F>, Array1<F>>
+ ArgminDot<Array1<F>, F>
+ ArgminL2Norm<F>
+ ArgminL1Norm<F>
+ ArgminSignum
+ ArgminMinMax,
F: ArgminMul<Array1<F>, Array1<F>> + ArgminZero,
{
type Object = TweedieRegressor<F>;

@@ -65,12 +83,12 @@ impl<F: Float, D: Data<Elem = F>, T: AsSingleTargets<Elem = F>>
// position x and gradient ∇f(x), where generally the history
// size m can be small (often m < 10)
// For our problem we set m to 7
let solver = LBFGS::new(linesearch, 7).with_tol_grad(F::cast(self.tol()));
let solver = LBFGS::new(linesearch, 7).with_tolerance_grad(F::cast(self.tol()))?;

let result = Executor::new(problem, solver, ArgminParam(coef))
.max_iters(self.max_iter() as u64)
let mut result = Executor::new(problem, solver)
.configure(|state| state.param(coef).max_iters(self.max_iter() as u64))
.run()?;
coef = result.state.get_best_param().as_array().to_owned();
coef = result.state.take_best_param().unwrap();

if self.fit_intercept() {
Ok(TweedieRegressor {
@@ -116,12 +134,9 @@ impl<'a, A: Float> TweedieProblem<'a, A> {
}
}

impl<'a, A: Float> ArgminOp for TweedieProblem<'a, A> {
type Param = ArgminParam<A>;
impl<'a, A: Float> CostFunction for TweedieProblem<'a, A> {
type Param = Array1<A>;
type Output = A;
type Hessian = ();
type Jacobian = Array1<A>;
type Float = A;

// This function calculates the value of the objective function we are trying
// to minimize,
@@ -130,9 +145,7 @@ impl<'a, A: Float> ArgminOp for TweedieProblem<'a, A> {
//
// - `p` is the parameter we are optimizing (coefficients and intercept)
// - `alpha` is the regularization hyperparameter
fn apply(&self, p: &Self::Param) -> std::result::Result<Self::Output, argmin::core::Error> {
let p = p.as_array();

fn cost(&self, p: &Self::Param) -> std::result::Result<Self::Output, argmin::core::Error> {
let (ypred, _, offset) = self.ypred(p);

let dev = self.dist.deviance(self.y, ypred.view())?;
@@ -145,10 +158,13 @@

Ok(obj)
}
}
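For orientation, the objective assembled by `cost` is, up to the exact scaling hidden in the elided lines (assuming linfa follows the usual penalized-deviance GLM formulation):

$$\mathrm{obj}(w) = \frac{1}{2n}\sum_{i=1}^{n} d\!\left(y_i, \hat{y}_i\right) + \frac{\alpha}{2}\,\lVert w \rVert_2^2$$

where $d$ is the Tweedie deviance computed by `self.dist.deviance`, $\hat{y}_i$ are the predictions from `self.ypred(p)`, and $\alpha$ is the regularization strength.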

fn gradient(&self, p: &Self::Param) -> std::result::Result<Self::Param, argmin::core::Error> {
let p = p.as_array();
impl<'a, A: Float> Gradient for TweedieProblem<'a, A> {
type Param = Array1<A>;
type Gradient = Array1<A>;

fn gradient(&self, p: &Self::Param) -> std::result::Result<Self::Param, argmin::core::Error> {
let (ypred, lin_pred, offset) = self.ypred(p);

let devp;
@@ -168,7 +184,7 @@ impl<'a, A: Float> ArgminOp for TweedieProblem<'a, A> {
objp.slice_mut(s![offset..])
.zip_mut_with(&pscaled, |x, y| *x += *y);

Ok(ArgminParam(objp))
Ok(objp)
}
}

@@ -204,7 +220,12 @@ impl<'a, A: Float> ArgminOp for TweedieProblem<'a, A> {
/// let r2 = pred.r2(&dataset).unwrap();
/// println!("r2 from prediction: {}", r2);
/// ```
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(
feature = "serde",
derive(Serialize, Deserialize),
serde(crate = "serde_crate")
)]
pub struct TweedieRegressor<A> {
/// Estimated coefficients for the linear predictor
pub coef: Array1<A>,
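Taken together, the glm/mod.rs changes replace the single `ArgminOp` trait with separate `CostFunction` and `Gradient` impls, move the initial parameter and iteration limit into `Executor::configure`, and make `with_tolerance_grad` fallible. A minimal end-to-end sketch of the new API against argmin 0.7 and argmin-math 0.2 (the `Quadratic` problem and its numbers are made up for illustration; only the calls also visible in the diff are taken as given):

```rust
use argmin::core::{CostFunction, Error, Executor, Gradient};
use argmin::solver::linesearch::MoreThuenteLineSearch;
use argmin::solver::quasinewton::LBFGS;
use ndarray::Array1;

// Toy convex problem: f(p) = ||p - target||^2.
struct Quadratic {
    target: Array1<f64>,
}

impl CostFunction for Quadratic {
    type Param = Array1<f64>;
    type Output = f64;

    fn cost(&self, p: &Self::Param) -> Result<Self::Output, Error> {
        let d = p - &self.target;
        Ok(d.dot(&d))
    }
}

impl Gradient for Quadratic {
    type Param = Array1<f64>;
    type Gradient = Array1<f64>;

    fn gradient(&self, p: &Self::Param) -> Result<Self::Gradient, Error> {
        Ok((p - &self.target) * 2.0)
    }
}

fn main() -> Result<(), Error> {
    let problem = Quadratic { target: Array1::from(vec![1.0, -2.0]) };
    let linesearch = MoreThuenteLineSearch::new();
    // History size 7 mirrors the call in glm/mod.rs;
    // with_tolerance_grad now returns a Result, hence the `?`.
    let solver = LBFGS::new(linesearch, 7).with_tolerance_grad(1e-6)?;

    let mut result = Executor::new(problem, solver)
        .configure(|state| state.param(Array1::zeros(2)).max_iters(100))
        .run()?;

    // As in the diff: the best parameter is moved out of the final state.
    let best = result.state.take_best_param().unwrap();
    println!("best parameter: {best}");
    Ok(())
}
```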
18 changes: 15 additions & 3 deletions algorithms/linfa-linear/src/isotonic.rs
@@ -2,7 +2,9 @@
#![allow(non_snake_case)]
use crate::error::{LinearError, Result};
use ndarray::{s, stack, Array1, ArrayBase, Axis, Data, Ix1, Ix2};
use serde::{Deserialize, Serialize};
#[cfg(feature = "serde")]
use serde_crate::{Deserialize, Serialize};

use std::cmp::Ordering;

use linfa::dataset::{AsSingleTargets, DatasetBase};
@@ -74,7 +76,12 @@
(V, J_index)
}

#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq, Default)]
#[derive(Debug, Clone, PartialEq, Eq, Default)]
#[cfg_attr(
feature = "serde",
derive(Serialize, Deserialize),
serde(crate = "serde_crate")
)]
/// An isotonic regression model.
///
/// IsotonicRegression solves an isotonic regression problem using the pool
@@ -103,7 +110,12 @@
/// A unifying framework. Mathematical Programming 47, 425–439 (1990).
pub struct IsotonicRegression {}

#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(
feature = "serde",
derive(Serialize, Deserialize),
serde(crate = "serde_crate")
)]
/// A fitted isotonic regression model which can be used for making predictions.
pub struct FittedIsotonicRegression<F> {
regressor: Array1<F>,
17 changes: 14 additions & 3 deletions algorithms/linfa-linear/src/ols.rs
@@ -9,12 +9,18 @@ use linfa_linalg::qr::LeastSquaresQrInto;
use ndarray::{concatenate, s, Array, Array1, Array2, ArrayBase, Axis, Data, Ix1, Ix2};
#[cfg(feature = "blas")]
use ndarray_linalg::LeastSquaresSvdInto;
use serde::{Deserialize, Serialize};
#[cfg(feature = "serde")]
use serde_crate::{Deserialize, Serialize};

use linfa::dataset::{AsSingleTargets, DatasetBase};
use linfa::traits::{Fit, PredictInplace};

#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)]
#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg_attr(
feature = "serde",
derive(Serialize, Deserialize),
serde(crate = "serde_crate")
)]
/// An ordinary least squares linear regression model.
///
/// LinearRegression fits a linear model to minimize the residual sum of
@@ -48,7 +54,12 @@ pub struct LinearRegression {
fit_intercept: bool,
}

#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(
feature = "serde",
derive(Serialize, Deserialize),
serde(crate = "serde_crate")
)]
/// A fitted linear regression model which can be used for making predictions.
pub struct FittedLinearRegression<F> {
intercept: F,
2 changes: 1 addition & 1 deletion src/dataset/impl_dataset.rs
@@ -65,7 +65,7 @@ impl<R: Records, S> DatasetBase<R, S> {
self.feature_names.clone()
} else {
(0..self.records.nfeatures())
.map(|idx| format!("feature-{}", idx))
.map(|idx| format!("feature-{idx}"))
.collect()
}
}
2 changes: 1 addition & 1 deletion src/metrics_classification.rs
@@ -613,7 +613,7 @@ mod tests {
assert_abs_diff_eq!(x.accuracy(), 5.0 / 6.0_f32);
assert_abs_diff_eq!(
x.mcc(),
(2. * 3. - 1. * 0.) / (2.0f32 * 3. * 3. * 4.).sqrt() as f32
(2. * 3. - 1. * 0.) / (2.0f32 * 3. * 3. * 4.).sqrt()
);

assert_split_eq(