Skip to content

Commit

Permalink
separate DenseMatrix trait
Browse files Browse the repository at this point in the history
  • Loading branch information
martinjrobins committed Mar 5, 2024
1 parent ad35b21 commit 5933203
Show file tree
Hide file tree
Showing 17 changed files with 270 additions and 140 deletions.
1 change: 1 addition & 0 deletions .vscode/settings.json
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
{
"rust-analyzer.linkedProjects": [
"./Cargo.toml",
"./Cargo.toml"
]
}
4 changes: 4 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,10 @@ anyhow = ">=1.0.77"
num-traits = "0.2.17"
ouroboros = "0.18.2"
serde = { version = "1.0.196", features = ["derive"] }
iree-rs = { verson = "0.1.1", optional = true }

[dev-dependencies]
insta = { version = "1.34.0", features = ["yaml"] }

[features]
iree = ["dep:iree-rs"]
2 changes: 1 addition & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ use num_traits::{Signed, Pow};
use vector::{Vector, VectorView, VectorViewMut, VectorIndex, VectorRef};
use nonlinear_solver::{NonLinearSolver, newton::NewtonNonlinearSolver};
use op::{NonLinearOp, LinearOp, ConstantOp};
use matrix::{Matrix, MatrixViewMut, MatrixCommon};
use matrix::{DenseMatrix, MatrixViewMut, Matrix};
use solver::SolverProblem;
use linear_solver::{lu::LU, LinearSolver};
pub use ode_solver::{OdeSolverProblem, OdeSolverState, bdf::Bdf, OdeSolverMethod};
Expand Down
4 changes: 2 additions & 2 deletions src/linear_solver/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,12 +52,12 @@ impl <V> LinearSolveSolution<V> {
pub mod tests {
use std::rc::Rc;

use crate::{op::{linear_closure::LinearClosure, LinearOp}, LinearSolver, vector::VectorRef, Matrix, SolverProblem, Vector, LU};
use crate::{op::{linear_closure::LinearClosure, LinearOp}, LinearSolver, vector::VectorRef, DenseMatrix, SolverProblem, Vector, LU};
use num_traits::{One, Zero};

use super::LinearSolveSolution;

fn linear_problem<M: Matrix + 'static>() -> (SolverProblem<impl LinearOp<M = M, V = M::V, T = M::T>>, Vec<LinearSolveSolution<M::V>>) {
fn linear_problem<M: DenseMatrix + 'static>() -> (SolverProblem<impl LinearOp<M = M, V = M::V, T = M::T>>, Vec<LinearSolveSolution<M::V>>) {
let diagonal = M::V::from_vec(vec![2.0.into(), 2.0.into()]);
let jac = M::from_diagonal(&diagonal);
let op = Rc::new(LinearClosure::new(
Expand Down
67 changes: 31 additions & 36 deletions src/matrix/dense_serial.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,12 @@ use anyhow::Result;

use crate::{Scalar, IndexType};

use super::{Matrix, MatrixView, MatrixCommon, MatrixViewMut};
use super::{DenseMatrix, Matrix, MatrixCommon, MatrixView, MatrixViewMut};

impl<'a, T: Scalar> MatrixCommon for DMatrixViewMut<'a, T> {
type V = DVector<T>;
type T = T;
fn diagonal(&self) -> DVector<T> {
self.diagonal()
}

fn ncols(&self) -> IndexType {
self.ncols()
}
Expand All @@ -33,11 +31,7 @@ impl<'a, T: Scalar> MatrixViewMut<'a> for DMatrixViewMut<'a, T> {
impl<'a, T: Scalar> MatrixCommon for DMatrixView<'a, T> {
type V = DVector<T>;
type T = T;
fn diagonal(&self) -> DVector<T> {
self.diagonal()
}

fn ncols(&self) -> IndexType {
fn ncols(&self) -> IndexType {
self.ncols()
}
fn nrows(&self) -> IndexType {
Expand All @@ -52,26 +46,48 @@ impl<'a, T: Scalar> MatrixView<'a> for DMatrixView<'a, T> {
impl<T: Scalar> MatrixCommon for DMatrix<T> {
type V = DVector<T>;
type T = T;
fn diagonal(&self) -> DVector<T> {
self.diagonal()
}

fn ncols(&self) -> IndexType {
self.ncols()
}
fn nrows(&self) -> IndexType {
self.nrows()
}

}

impl<T: Scalar> Matrix for DMatrix<T> {
fn try_from_triplets(nrows: IndexType, ncols: IndexType, triplets: Vec<(IndexType, IndexType, T)>) -> Result<Self> {
let mut m = Self::zeros(nrows, ncols);
for (i, j, v) in triplets {
m[(i, j)] = v;
}
Ok(m)
}
fn zeros(nrows: IndexType, ncols: IndexType) -> Self {
Self::zeros(nrows, ncols)
}
fn from_diagonal(v: &DVector<T>) -> Self {
Self::from_diagonal(v)
}
fn diagonal(&self) -> Self::V {
self.diagonal()
}
}


impl<T: Scalar> DenseMatrix for DMatrix<T> {
type View<'a> = DMatrixView<'a, T>;
type ViewMut<'a> = DMatrixViewMut<'a, T>;


fn gemv(&self, alpha: T, x: &DVector<T>, beta: T, y: &mut DVector<T>) {

fn gemm(&mut self, alpha: Self::T, a: &Self, b: &Self, beta: Self::T) {
self.gemm(alpha, a, b, beta);
}
fn gemv(&self, alpha: Self::T, x: &Self::V, beta: Self::T, y: &mut Self::V) {
y.gemv(alpha, self, x, beta);
}


fn column_mut(&mut self, i: IndexType) -> DVectorViewMut<'_, T> {
self.column_mut(i)
}
Expand All @@ -86,25 +102,4 @@ impl<T: Scalar> Matrix for DMatrix<T> {
fn columns(&self, start: IndexType, nrows: IndexType) -> Self::View<'_> {
self.columns(start, nrows)
}

fn try_from_triplets(nrows: IndexType, ncols: IndexType, triplets: Vec<(IndexType, IndexType, T)>) -> Result<Self> {
let mut m = Self::zeros(nrows, ncols);
for (i, j, v) in triplets {
m[(i, j)] = v;
}
Ok(m)
}
fn zeros(nrows: IndexType, ncols: IndexType) -> Self {
Self::zeros(nrows, ncols)
}
fn from_diagonal(v: &DVector<T>) -> Self {
Self::from_diagonal(v)
}
fn gemm(&mut self, alpha: T, a: &Self, b: &Self, beta: T) {
self.gemm(alpha, a, b, beta);
}
fn diagonal(&self) -> Self::V {
self.diagonal()
}

}
121 changes: 65 additions & 56 deletions src/matrix/mod.rs
Original file line number Diff line number Diff line change
@@ -1,27 +1,32 @@
use std::ops::{Add, AddAssign, Div, DivAssign, Index, IndexMut, Mul, MulAssign, Sub, SubAssign};
use std::fmt::{Debug, Display};
use num_traits::{One, Zero};
use std::fmt::Debug;

use crate::{IndexType, Scalar, Vector};
use anyhow::Result;

mod dense_serial;
mod sparse_serial;

pub trait MatrixCommon: Sized + Debug + Display
pub trait MatrixCommon: Sized + Debug
{
type V: Vector<T = Self::T>;
type T: Scalar;
fn diagonal(&self) -> Self::V;



/// Get the number of columns of the matrix
fn nrows(&self) -> IndexType;

/// Get the number of rows of the matrix
fn ncols(&self) -> IndexType;


}

impl <'a, M> MatrixCommon for &'a M where M: MatrixCommon {
type T = M::T;
type V = M::V;
fn diagonal(&self) -> Self::V {
M::diagonal(self)
}

fn ncols(&self) -> IndexType {
M::ncols(self)
}
Expand All @@ -33,9 +38,6 @@ impl <'a, M> MatrixCommon for &'a M where M: MatrixCommon {
impl <'a, M> MatrixCommon for &'a mut M where M: MatrixCommon {
type T = M::T;
type V = M::V;
fn diagonal(&self) -> Self::V {
M::diagonal(self)
}
fn ncols(&self) -> IndexType {
M::ncols(self)
}
Expand All @@ -47,11 +49,13 @@ impl <'a, M> MatrixCommon for &'a mut M where M: MatrixCommon {
pub trait MatrixOpsByValue<Rhs = Self, Output = Self>: MatrixCommon
+ Add<Rhs, Output = Output>
+ Sub<Rhs, Output = Output>
+ Mul<Rhs, Output = Output>
{}

impl <M, Rhs, Output> MatrixOpsByValue<Rhs, Output> for M where M: MatrixCommon
+ Add<Rhs, Output = Output>
+ Sub<Rhs, Output = Output>
+ Mul<Rhs, Output = Output>
{}

pub trait MatrixMutOpsByValue<Rhs = Self>: MatrixCommon
Expand All @@ -64,11 +68,9 @@ impl <M, Rhs> MatrixMutOpsByValue<Rhs> for M where M: MatrixCommon
+ SubAssign<Rhs>
{}

pub trait MatrixMutOps<View>:
MatrixMutOpsByValue<Self>
+ for<'a> MatrixMutOpsByValue<&'a Self>
+ MatrixMutOpsByValue<View>
+ for<'a> MatrixMutOpsByValue<&'a View>
pub trait MatrixMutOps<Other>:
MatrixMutOpsByValue<Other>
+ for<'a> MatrixMutOpsByValue<&'a Other>
+ MulAssign<Self::T>
+ DivAssign<Self::T>
{}
Expand All @@ -83,23 +85,19 @@ where
+ DivAssign<Self::T>
{}

pub trait MatrixOps<View>:
MatrixOpsByValue<Self>
+ for<'a> MatrixOpsByValue<&'a Self>
+ MatrixOpsByValue<View>
+ for<'a> MatrixOpsByValue<&'a View>

pub trait MatrixOps<Rhs>:
MatrixOpsByValue<Rhs>
+ for<'a> MatrixOpsByValue<&'a Rhs>
+ Mul<Self::T, Output = Self>
+ Div<Self::T, Output = Self>
{}

impl <M, View> MatrixOps<View> for M
where
M: MatrixOpsByValue<Self>
+ for<'a> MatrixOpsByValue<&'a Self>
+ MatrixOpsByValue<View>
+ for<'a> MatrixOpsByValue<&'a View>
+ Mul<Self::T, Output = Self>
+ Div<Self::T, Output = Self>
impl <M, Rhs> MatrixOps<Rhs> for M where
M: MatrixOpsByValue<Rhs>
+ for<'a> MatrixOpsByValue<&'a Rhs>
+ Mul<Self::T, Output = M>
+ Div<Self::T, Output = M>
{}

/// A trait allowing for references to implement matrix operations
Expand All @@ -120,10 +118,12 @@ impl <RefT, M: MatrixCommon> MatrixRef<M> for RefT where

/// A mutable view of a dense matrix [Matrix]
pub trait MatrixViewMut<'a>:
MatrixMutOps<Self::View>
MatrixMutOps<Self>
+ MatrixMutOps<Self::View>
where Self: 'a
{
type Owned: Matrix<V = Self::V>;
type View: MatrixView<'a, V = Self::V, Owned = Self::Owned, T = Self::T>;
type Owned: DenseMatrix<T = Self::T, V = Self::V, ViewMut<'a> = Self>;
type View: MatrixView<'a, Owned = Self::Owned, T = Self::T>;
fn gemm_oo(&mut self, alpha: Self::T, a: &Self::Owned, b: &Self::Owned, beta: Self::T);
fn gemm_vo(&mut self, alpha: Self::T, a: &Self::View, b: &Self::Owned, beta: Self::T);
}
Expand All @@ -132,23 +132,19 @@ pub trait MatrixViewMut<'a>:
pub trait MatrixView<'a>:
MatrixRef<Self::Owned>
+ Clone
where Self: 'a
{
type Owned: Matrix<V = Self::V>;
type Owned: DenseMatrix<T = Self::T, V = Self::V>;
}

/// A dense matrix. The assumption is that the underlying matrix is stored in column-major order, so functions for taking columns views are efficient
/// A base matrix trait (including sparse and dense matrices)
pub trait Matrix:
for <'a> MatrixOps<Self::View<'a>>
+ for <'a> MatrixMutOps<Self::View<'a>>
+ Index<(IndexType, IndexType), Output = Self::T>
+ IndexMut<(IndexType, IndexType), Output = Self::T>
MatrixOps<Self>
+ Clone
{
/// A view of this matrix type
type View<'a>: MatrixView<'a, Owned = Self, T = Self::T> where Self: 'a;

/// A mutable view of this matrix type
type ViewMut<'a>: MatrixViewMut<'a, Owned = Self, T = Self::T, View = Self::View<'a>> where Self: 'a;
/// Extract the diagonal of the matrix as an owned vector
fn diagonal(&self) -> Self::V;


/// Create a new matrix of shape `nrows` x `ncols` filled with zeros
fn zeros(nrows: IndexType, ncols: IndexType) -> Self;
Expand All @@ -158,6 +154,31 @@ pub trait Matrix:

/// Create a new matrix from a vector of triplets (i, j, value) where i and j are the row and column indices of the value
fn try_from_triplets(nrows: IndexType, ncols: IndexType, triplets: Vec<(IndexType, IndexType, Self::T)>) -> Result<Self>;
}

/// A dense column-major matrix. The assumption is that the underlying matrix is stored in column-major order, so functions for taking columns views are efficient
pub trait DenseMatrix:
Matrix
+ for <'a> MatrixOps<Self::View<'a>>
+ for <'a> MatrixMutOps<Self::View<'a>>
+ Index<(IndexType, IndexType), Output = Self::T>
+ IndexMut<(IndexType, IndexType), Output = Self::T>
{

/// A view of the dense matrix type
type View<'a>: MatrixView<'a, Owned = Self, T = Self::T> where Self: 'a;

/// A mutable view of the dense matrix type
type ViewMut<'a>: MatrixViewMut<'a, Owned = Self, T = Self::T, View = Self::View<'a>> where Self: 'a;


/// Perform a matrix-matrix multiplication `self = alpha * a * b + beta * self`, where `alpha` and `beta` are scalars, and `a` and `b` are matrices
fn gemm(&mut self, alpha: Self::T, a: &Self, b: &Self, beta: Self::T);


/// Perform a matrix-vector multiplication `y = self * x + beta * y`.
fn gemv(&self, alpha: Self::T, x: &Self::V, beta: Self::T, y: &mut Self::V);


/// Get a matrix view of the columns starting at `start` and ending at `start + ncols`
fn columns(&self, start: IndexType, ncols: IndexType) -> Self::View<'_>;
Expand All @@ -171,19 +192,7 @@ pub trait Matrix:
/// Get a mutable vector view of the column `i`
fn column_mut(&mut self, i: IndexType) -> <Self::V as Vector>::ViewMut<'_>;

/// Perform a matrix-matrix multiplication `self = alpha * a * b + beta * self`, where `alpha` and `beta` are scalars, and `a` and `b` are matrices
fn gemm(&mut self, alpha: Self::T, a: &Self, b: &Self, beta: Self::T);

/// Extract the diagonal of the matrix as an owned vector
fn diagonal(&self) -> Self::V;

/// Perform a matrix-matrix multiplication `result = self * x`.
fn mat_mul(&self, x: &Self) -> Self {
let mut y = Self::zeros(self.nrows(), x.ncols());
y.gemm(Self::T::one(), self, x, Self::T::zero());
y
}

/// Perform a matrix-vector multiplication `y = self * x + beta * y`.
fn gemv(&self, alpha: Self::T, x: &Self::V, beta: Self::T, y: &mut Self::V);
}


Loading

0 comments on commit 5933203

Please sign in to comment.