diff --git a/lax/src/cholesky.rs b/lax/src/cholesky.rs index 9b213246..785f6e5e 100644 --- a/lax/src/cholesky.rs +++ b/lax/src/cholesky.rs @@ -1,59 +1,25 @@ +//! Factorize positive-definite symmetric/Hermitian matrices using Cholesky algorithm + use super::*; use crate::{error::*, layout::*}; use cauchy::*; -#[cfg_attr(doc, katexit::katexit)] -/// Solve symmetric/hermite positive-definite linear equations using Cholesky decomposition -/// -/// For a given positive definite matrix $A$, -/// Cholesky decomposition is described as $A = U^T U$ or $A = LL^T$ where +/// Compute Cholesky decomposition according to [UPLO] /// -/// - $L$ is lower matrix -/// - $U$ is upper matrix +/// LAPACK correspondance +/// ---------------------- /// -/// This is designed as two step computation according to LAPACK API +/// | f32 | f64 | c32 | c64 | +/// |:-------|:-------|:-------|:-------| +/// | spotrf | dpotrf | cpotrf | zpotrf | /// -/// 1. Factorize input matrix $A$ into $L$ or $U$ -/// 2. Solve linear equation $Ax = b$ or compute inverse matrix $A^{-1}$ -/// using $U$ or $L$. -pub trait Cholesky_: Sized { - /// Compute Cholesky decomposition $A = U^T U$ or $A = L L^T$ according to [UPLO] - /// - /// LAPACK correspondance - /// ---------------------- - /// - /// | f32 | f64 | c32 | c64 | - /// |:-------|:-------|:-------|:-------| - /// | spotrf | dpotrf | cpotrf | zpotrf | - /// +pub trait CholeskyImpl: Scalar { fn cholesky(l: MatrixLayout, uplo: UPLO, a: &mut [Self]) -> Result<()>; - - /// Compute inverse matrix $A^{-1}$ using $U$ or $L$ - /// - /// LAPACK correspondance - /// ---------------------- - /// - /// | f32 | f64 | c32 | c64 | - /// |:-------|:-------|:-------|:-------| - /// | spotri | dpotri | cpotri | zpotri | - /// - fn inv_cholesky(l: MatrixLayout, uplo: UPLO, a: &mut [Self]) -> Result<()>; - - /// Solve linear equation $Ax = b$ using $U$ or $L$ - /// - /// LAPACK correspondance - /// ---------------------- - /// - /// | f32 | f64 | c32 | c64 | - /// |:-------|:-------|:-------|:-------| - /// | spotrs | dpotrs | cpotrs | zpotrs | - /// - fn solve_cholesky(l: MatrixLayout, uplo: UPLO, a: &[Self], b: &mut [Self]) -> Result<()>; } -macro_rules! impl_cholesky { - ($scalar:ty, $trf:path, $tri:path, $trs:path) => { - impl Cholesky_ for $scalar { +macro_rules! impl_cholesky_ { + ($s:ty, $trf:path) => { + impl CholeskyImpl for $s { fn cholesky(l: MatrixLayout, uplo: UPLO, a: &mut [Self]) -> Result<()> { let (n, _) = l.size(); if matches!(l, MatrixLayout::C { .. }) { @@ -69,7 +35,30 @@ macro_rules! impl_cholesky { } Ok(()) } + } + }; +} +impl_cholesky_!(c64, lapack_sys::zpotrf_); +impl_cholesky_!(c32, lapack_sys::cpotrf_); +impl_cholesky_!(f64, lapack_sys::dpotrf_); +impl_cholesky_!(f32, lapack_sys::spotrf_); + +/// Compute inverse matrix using Cholesky factroization result +/// +/// LAPACK correspondance +/// ---------------------- +/// +/// | f32 | f64 | c32 | c64 | +/// |:-------|:-------|:-------|:-------| +/// | spotri | dpotri | cpotri | zpotri | +/// +pub trait InvCholeskyImpl: Scalar { + fn inv_cholesky(l: MatrixLayout, uplo: UPLO, a: &mut [Self]) -> Result<()>; +} +macro_rules! impl_inv_cholesky { + ($s:ty, $tri:path) => { + impl InvCholeskyImpl for $s { fn inv_cholesky(l: MatrixLayout, uplo: UPLO, a: &mut [Self]) -> Result<()> { let (n, _) = l.size(); if matches!(l, MatrixLayout::C { .. }) { @@ -85,7 +74,30 @@ macro_rules! impl_cholesky { } Ok(()) } + } + }; +} +impl_inv_cholesky!(c64, lapack_sys::zpotri_); +impl_inv_cholesky!(c32, lapack_sys::cpotri_); +impl_inv_cholesky!(f64, lapack_sys::dpotri_); +impl_inv_cholesky!(f32, lapack_sys::spotri_); +/// Solve linear equation using Cholesky factroization result +/// +/// LAPACK correspondance +/// ---------------------- +/// +/// | f32 | f64 | c32 | c64 | +/// |:-------|:-------|:-------|:-------| +/// | spotrs | dpotrs | cpotrs | zpotrs | +/// +pub trait SolveCholeskyImpl: Scalar { + fn solve_cholesky(l: MatrixLayout, uplo: UPLO, a: &[Self], b: &mut [Self]) -> Result<()>; +} + +macro_rules! impl_solve_cholesky { + ($s:ty, $trs:path) => { + impl SolveCholeskyImpl for $s { fn solve_cholesky( l: MatrixLayout, mut uplo: UPLO, @@ -123,29 +135,8 @@ macro_rules! impl_cholesky { } } }; -} // end macro_rules - -impl_cholesky!( - f64, - lapack_sys::dpotrf_, - lapack_sys::dpotri_, - lapack_sys::dpotrs_ -); -impl_cholesky!( - f32, - lapack_sys::spotrf_, - lapack_sys::spotri_, - lapack_sys::spotrs_ -); -impl_cholesky!( - c64, - lapack_sys::zpotrf_, - lapack_sys::zpotri_, - lapack_sys::zpotrs_ -); -impl_cholesky!( - c32, - lapack_sys::cpotrf_, - lapack_sys::cpotri_, - lapack_sys::cpotrs_ -); +} +impl_solve_cholesky!(c64, lapack_sys::zpotrs_); +impl_solve_cholesky!(c32, lapack_sys::cpotrs_); +impl_solve_cholesky!(f64, lapack_sys::dpotrs_); +impl_solve_cholesky!(f32, lapack_sys::spotrs_); diff --git a/lax/src/lib.rs b/lax/src/lib.rs index 199f2dc2..e673d261 100644 --- a/lax/src/lib.rs +++ b/lax/src/lib.rs @@ -1,21 +1,24 @@ -//! ndarray-free safe Rust wrapper for LAPACK FFI +//! Safe Rust wrapper for LAPACK without external dependency. //! -//! `Lapack` trait and sub-traits -//! ------------------------------- +//! [Lapack] trait +//! ---------------- //! -//! This crates provides LAPACK wrapper as `impl` of traits to base scalar types. -//! For example, LU decomposition to double-precision matrix is provided like: +//! This crates provides LAPACK wrapper as a traits. +//! For example, LU decomposition of general matrices is provided like: //! //! ```ignore -//! impl Solve_ for f64 { -//! fn lu(l: MatrixLayout, a: &mut [Self]) -> Result { ... } +//! pub trait Lapack { +//! fn lu(l: MatrixLayout, a: &mut [Self]) -> Result; //! } //! ``` //! -//! see [Solve_] for detail. You can use it like `f64::lu`: +//! see [Lapack] for detail. +//! This trait is implemented for [f32], [f64], [c32] which is an alias to `num::Complex`, +//! and [c64] which is an alias to `num::Complex`. +//! You can use it like `f64::lu`: //! //! ``` -//! use lax::{Solve_, layout::MatrixLayout, Transpose}; +//! use lax::{Lapack, layout::MatrixLayout, Transpose}; //! //! let mut a = vec![ //! 1.0, 2.0, @@ -31,9 +34,9 @@ //! this trait can be used as a trait bound: //! //! ``` -//! use lax::{Solve_, layout::MatrixLayout, Transpose}; +//! use lax::{Lapack, layout::MatrixLayout, Transpose}; //! -//! fn solve_at_once(layout: MatrixLayout, a: &mut [T], b: &mut [T]) -> Result<(), lax::error::Error> { +//! fn solve_at_once(layout: MatrixLayout, a: &mut [T], b: &mut [T]) -> Result<(), lax::error::Error> { //! let pivot = T::lu(layout, a)?; //! T::solve(layout, Transpose::No, a, &pivot, b)?; //! Ok(()) @@ -48,9 +51,9 @@ //! //! According to the property input metrix, several types of triangular decomposition are used: //! -//! - [Solve_] trait provides methods for LU-decomposition for general matrix. -//! - [Solveh_] triat provides methods for Bunch-Kaufman diagonal pivoting method for symmetric/hermite indefinite matrix. -//! - [Cholesky_] triat provides methods for Cholesky decomposition for symmetric/hermite positive dinite matrix. +//! - [solve] module provides methods for LU-decomposition for general matrix. +//! - [solveh] module provides methods for Bunch-Kaufman diagonal pivoting method for symmetric/Hermitian indefinite matrix. +//! - [cholesky] module provides methods for Cholesky decomposition for symmetric/Hermitian positive dinite matrix. //! //! Eigenvalue Problem //! ------------------- @@ -59,8 +62,8 @@ //! there are several types of eigenvalue problem API //! //! - [eig] module for eigenvalue problem for general matrix. -//! - [eigh] module for eigenvalue problem for symmetric/hermite matrix. -//! - [eigh_generalized] module for generalized eigenvalue problem for symmetric/hermite matrix. +//! - [eigh] module for eigenvalue problem for symmetric/Hermitian matrix. +//! - [eigh_generalized] module for generalized eigenvalue problem for symmetric/Hermitian matrix. //! //! Singular Value Decomposition //! ----------------------------- @@ -85,20 +88,20 @@ pub mod error; pub mod flags; pub mod layout; +pub mod cholesky; pub mod eig; pub mod eigh; pub mod eigh_generalized; pub mod least_squares; pub mod qr; +pub mod solve; +pub mod solveh; pub mod svd; pub mod svddc; mod alloc; -mod cholesky; mod opnorm; mod rcond; -mod solve; -mod solveh; mod triangular; mod tridiagonal; @@ -107,8 +110,6 @@ pub use self::flags::*; pub use self::least_squares::LeastSquaresOwned; pub use self::opnorm::*; pub use self::rcond::*; -pub use self::solve::*; -pub use self::solveh::*; pub use self::svd::{SvdOwned, SvdRef}; pub use self::triangular::*; pub use self::tridiagonal::*; @@ -121,9 +122,7 @@ pub type Pivot = Vec; #[cfg_attr(doc, katexit::katexit)] /// Trait for primitive types which implements LAPACK subroutines -pub trait Lapack: - OperatorNorm_ + Solve_ + Solveh_ + Cholesky_ + Triangular_ + Tridiagonal_ + Rcond_ -{ +pub trait Lapack: OperatorNorm_ + Triangular_ + Tridiagonal_ + Rcond_ { /// Compute right eigenvalue and eigenvectors for a general matrix fn eig( calc_v: bool, @@ -131,7 +130,7 @@ pub trait Lapack: a: &mut [Self], ) -> Result<(Vec, Vec)>; - /// Compute right eigenvalue and eigenvectors for a symmetric or hermite matrix + /// Compute right eigenvalue and eigenvectors for a symmetric or Hermitian matrix fn eigh( calc_eigenvec: bool, layout: MatrixLayout, @@ -139,7 +138,7 @@ pub trait Lapack: a: &mut [Self], ) -> Result>; - /// Compute right eigenvalue and eigenvectors for a symmetric or hermite matrix + /// Compute right eigenvalue and eigenvectors for a symmetric or Hermitian matrix fn eigh_generalized( calc_eigenvec: bool, layout: MatrixLayout, @@ -181,6 +180,83 @@ pub trait Lapack: b_layout: MatrixLayout, b: &mut [Self], ) -> Result>; + + /// Computes the LU decomposition of a general $m \times n$ matrix + /// with partial pivoting with row interchanges. + /// + /// For a given matrix $A$, LU decomposition is described as $A = PLU$ where: + /// + /// - $L$ is lower matrix + /// - $U$ is upper matrix + /// - $P$ is permutation matrix represented by [Pivot] + /// + /// This is designed as two step computation according to LAPACK API: + /// + /// 1. Factorize input matrix $A$ into $L$, $U$, and $P$. + /// 2. Solve linear equation $Ax = b$ by [Lapack::solve] + /// or compute inverse matrix $A^{-1}$ by [Lapack::inv] using the output of LU decomposition. + /// + /// Output + /// ------- + /// - $U$ and $L$ are stored in `a` after LU decomposition has succeeded. + /// - $P$ is returned as [Pivot] + /// + /// Error + /// ------ + /// - if the matrix is singular + /// - On this case, `return_code` in [Error::LapackComputationalFailure] means + /// `return_code`-th diagonal element of $U$ becomes zero. + /// + fn lu(l: MatrixLayout, a: &mut [Self]) -> Result; + + /// Compute inverse matrix $A^{-1}$ from the output of LU-decomposition + fn inv(l: MatrixLayout, a: &mut [Self], p: &Pivot) -> Result<()>; + + /// Solve linear equations $Ax = b$ using the output of LU-decomposition + fn solve(l: MatrixLayout, t: Transpose, a: &[Self], p: &Pivot, b: &mut [Self]) -> Result<()>; + + /// Factorize symmetric/Hermitian matrix using Bunch-Kaufman diagonal pivoting method + /// + /// For a given symmetric matrix $A$, + /// this method factorizes $A = U^T D U$ or $A = L D L^T$ where + /// + /// - $U$ (or $L$) are is a product of permutation and unit upper (lower) triangular matrices + /// - $D$ is symmetric and block diagonal with 1-by-1 and 2-by-2 diagonal blocks. + /// + /// This takes two-step approach based in LAPACK: + /// + /// 1. Factorize given matrix $A$ into upper ($U$) or lower ($L$) form with diagonal matrix $D$ + /// 2. Then solve linear equation $Ax = b$, and/or calculate inverse matrix $A^{-1}$ + /// + fn bk(l: MatrixLayout, uplo: UPLO, a: &mut [Self]) -> Result; + + /// Compute inverse matrix $A^{-1}$ using the result of [Lapack::bk] + fn invh(l: MatrixLayout, uplo: UPLO, a: &mut [Self], ipiv: &Pivot) -> Result<()>; + + /// Solve symmetric/Hermitian linear equation $Ax = b$ using the result of [Lapack::bk] + fn solveh(l: MatrixLayout, uplo: UPLO, a: &[Self], ipiv: &Pivot, b: &mut [Self]) -> Result<()>; + + /// Solve symmetric/Hermitian positive-definite linear equations using Cholesky decomposition + /// + /// For a given positive definite matrix $A$, + /// Cholesky decomposition is described as $A = U^T U$ or $A = LL^T$ where + /// + /// - $L$ is lower matrix + /// - $U$ is upper matrix + /// + /// This is designed as two step computation according to LAPACK API + /// + /// 1. Factorize input matrix $A$ into $L$ or $U$ + /// 2. Solve linear equation $Ax = b$ by [Lapack::solve_cholesky] + /// or compute inverse matrix $A^{-1}$ by [Lapack::inv_cholesky] + /// + fn cholesky(l: MatrixLayout, uplo: UPLO, a: &mut [Self]) -> Result<()>; + + /// Compute inverse matrix $A^{-1}$ using $U$ or $L$ calculated by [Lapack::cholesky] + fn inv_cholesky(l: MatrixLayout, uplo: UPLO, a: &mut [Self]) -> Result<()>; + + /// Solve linear equation $Ax = b$ using $U$ or $L$ calculated by [Lapack::cholesky] + fn solve_cholesky(l: MatrixLayout, uplo: UPLO, a: &[Self], b: &mut [Self]) -> Result<()>; } macro_rules! impl_lapack { @@ -276,6 +352,72 @@ macro_rules! impl_lapack { let work = LeastSquaresWork::<$s>::new(a_layout, b_layout)?; work.eval(a, b) } + + fn lu(l: MatrixLayout, a: &mut [Self]) -> Result { + use solve::*; + LuImpl::lu(l, a) + } + + fn inv(l: MatrixLayout, a: &mut [Self], p: &Pivot) -> Result<()> { + use solve::*; + let mut work = InvWork::<$s>::new(l)?; + work.calc(a, p)?; + Ok(()) + } + + fn solve( + l: MatrixLayout, + t: Transpose, + a: &[Self], + p: &Pivot, + b: &mut [Self], + ) -> Result<()> { + use solve::*; + SolveImpl::solve(l, t, a, p, b) + } + + fn bk(l: MatrixLayout, uplo: UPLO, a: &mut [Self]) -> Result { + use solveh::*; + let work = BkWork::<$s>::new(l)?; + work.eval(uplo, a) + } + + fn invh(l: MatrixLayout, uplo: UPLO, a: &mut [Self], ipiv: &Pivot) -> Result<()> { + use solveh::*; + let mut work = InvhWork::<$s>::new(l)?; + work.calc(uplo, a, ipiv) + } + + fn solveh( + l: MatrixLayout, + uplo: UPLO, + a: &[Self], + ipiv: &Pivot, + b: &mut [Self], + ) -> Result<()> { + use solveh::*; + SolvehImpl::solveh(l, uplo, a, ipiv, b) + } + + fn cholesky(l: MatrixLayout, uplo: UPLO, a: &mut [Self]) -> Result<()> { + use cholesky::*; + CholeskyImpl::cholesky(l, uplo, a) + } + + fn inv_cholesky(l: MatrixLayout, uplo: UPLO, a: &mut [Self]) -> Result<()> { + use cholesky::*; + InvCholeskyImpl::inv_cholesky(l, uplo, a) + } + + fn solve_cholesky( + l: MatrixLayout, + uplo: UPLO, + a: &[Self], + b: &mut [Self], + ) -> Result<()> { + use cholesky::*; + SolveCholeskyImpl::solve_cholesky(l, uplo, a, b) + } } }; } diff --git a/lax/src/solve.rs b/lax/src/solve.rs index d0f764fd..63f69983 100644 --- a/lax/src/solve.rs +++ b/lax/src/solve.rs @@ -1,72 +1,25 @@ +//! Solve linear equations using LU-decomposition + use crate::{error::*, layout::MatrixLayout, *}; use cauchy::*; use num_traits::{ToPrimitive, Zero}; -#[cfg_attr(doc, katexit::katexit)] -/// Solve linear equations using LU-decomposition -/// -/// For a given matrix $A$, LU decomposition is described as $A = PLU$ where: -/// -/// - $L$ is lower matrix -/// - $U$ is upper matrix -/// - $P$ is permutation matrix represented by [Pivot] +/// Helper trait to abstract `*getrf` LAPACK routines for implementing [Lapack::lu] /// -/// This is designed as two step computation according to LAPACK API: +/// LAPACK correspondance +/// ---------------------- /// -/// 1. Factorize input matrix $A$ into $L$, $U$, and $P$. -/// 2. Solve linear equation $Ax = b$ or compute inverse matrix $A^{-1}$ -/// using the output of LU decomposition. +/// | f32 | f64 | c32 | c64 | +/// |:-------|:-------|:-------|:-------| +/// | sgetrf | dgetrf | cgetrf | zgetrf | /// -pub trait Solve_: Scalar + Sized { - /// Computes the LU decomposition of a general $m \times n$ matrix - /// with partial pivoting with row interchanges. - /// - /// Output - /// ------- - /// - $U$ and $L$ are stored in `a` after LU decomposition has succeeded. - /// - $P$ is returned as [Pivot] - /// - /// Error - /// ------ - /// - if the matrix is singular - /// - On this case, `return_code` in [Error::LapackComputationalFailure] means - /// `return_code`-th diagonal element of $U$ becomes zero. - /// - /// LAPACK correspondance - /// ---------------------- - /// - /// | f32 | f64 | c32 | c64 | - /// |:-------|:-------|:-------|:-------| - /// | sgetrf | dgetrf | cgetrf | zgetrf | - /// +pub trait LuImpl: Scalar { fn lu(l: MatrixLayout, a: &mut [Self]) -> Result; - - /// Compute inverse matrix $A^{-1}$ from the output of LU-decomposition - /// - /// LAPACK correspondance - /// ---------------------- - /// - /// | f32 | f64 | c32 | c64 | - /// |:-------|:-------|:-------|:-------| - /// | sgetri | dgetri | cgetri | zgetri | - /// - fn inv(l: MatrixLayout, a: &mut [Self], p: &Pivot) -> Result<()>; - - /// Solve linear equations $Ax = b$ using the output of LU-decomposition - /// - /// LAPACK correspondance - /// ---------------------- - /// - /// | f32 | f64 | c32 | c64 | - /// |:-------|:-------|:-------|:-------| - /// | sgetrs | dgetrs | cgetrs | zgetrs | - /// - fn solve(l: MatrixLayout, t: Transpose, a: &[Self], p: &Pivot, b: &mut [Self]) -> Result<()>; } -macro_rules! impl_solve { - ($scalar:ty, $getrf:path, $getri:path, $getrs:path) => { - impl Solve_ for $scalar { +macro_rules! impl_lu { + ($scalar:ty, $getrf:path) => { + impl LuImpl for $scalar { fn lu(l: MatrixLayout, a: &mut [Self]) -> Result { let (row, col) = l.size(); assert_eq!(a.len() as i32, row * col); @@ -91,49 +44,55 @@ macro_rules! impl_solve { let ipiv = unsafe { ipiv.assume_init() }; Ok(ipiv) } + } + }; +} - fn inv(l: MatrixLayout, a: &mut [Self], ipiv: &Pivot) -> Result<()> { - let (n, _) = l.size(); - if n == 0 { - // Do nothing for empty matrices. - return Ok(()); - } - - // calc work size - let mut info = 0; - let mut work_size = [Self::zero()]; - unsafe { - $getri( - &n, - AsPtr::as_mut_ptr(a), - &l.lda(), - ipiv.as_ptr(), - AsPtr::as_mut_ptr(&mut work_size), - &(-1), - &mut info, - ) - }; - info.as_lapack_result()?; - - // actual - let lwork = work_size[0].to_usize().unwrap(); - let mut work: Vec> = vec_uninit(lwork); - unsafe { - $getri( - &l.len(), - AsPtr::as_mut_ptr(a), - &l.lda(), - ipiv.as_ptr(), - AsPtr::as_mut_ptr(&mut work), - &(lwork as i32), - &mut info, - ) - }; - info.as_lapack_result()?; +impl_lu!(c64, lapack_sys::zgetrf_); +impl_lu!(c32, lapack_sys::cgetrf_); +impl_lu!(f64, lapack_sys::dgetrf_); +impl_lu!(f32, lapack_sys::sgetrf_); - Ok(()) - } +#[cfg_attr(doc, katexit::katexit)] +/// Helper trait to abstract `*getrs` LAPACK routines for implementing [Lapack::solve] +/// +/// If the array has C layout, then it needs to be handled +/// specially, since LAPACK expects a Fortran-layout array. +/// Reinterpreting a C layout array as Fortran layout is +/// equivalent to transposing it. So, we can handle the "no +/// transpose" and "transpose" cases by swapping to "transpose" +/// or "no transpose", respectively. For the "Hermite" case, we +/// can take advantage of the following: +/// +/// $$ +/// \begin{align*} +/// A^H x &= b \\\\ +/// \Leftrightarrow \overline{A^T} x &= b \\\\ +/// \Leftrightarrow \overline{\overline{A^T} x} &= \overline{b} \\\\ +/// \Leftrightarrow \overline{\overline{A^T}} \overline{x} &= \overline{b} \\\\ +/// \Leftrightarrow A^T \overline{x} &= \overline{b} +/// \end{align*} +/// $$ +/// +/// So, we can handle this case by switching to "no transpose" +/// (which is equivalent to transposing the array since it will +/// be reinterpreted as Fortran layout) and applying the +/// elementwise conjugate to `x` and `b`. +/// +pub trait SolveImpl: Scalar { + /// LAPACK correspondance + /// ---------------------- + /// + /// | f32 | f64 | c32 | c64 | + /// |:-------|:-------|:-------|:-------| + /// | sgetrs | dgetrs | cgetrs | zgetrs | + /// + fn solve(l: MatrixLayout, t: Transpose, a: &[Self], p: &Pivot, b: &mut [Self]) -> Result<()>; +} +macro_rules! impl_solve { + ($scalar:ty, $getrs:path) => { + impl SolveImpl for $scalar { fn solve( l: MatrixLayout, t: Transpose, @@ -141,26 +100,6 @@ macro_rules! impl_solve { ipiv: &Pivot, b: &mut [Self], ) -> Result<()> { - // If the array has C layout, then it needs to be handled - // specially, since LAPACK expects a Fortran-layout array. - // Reinterpreting a C layout array as Fortran layout is - // equivalent to transposing it. So, we can handle the "no - // transpose" and "transpose" cases by swapping to "transpose" - // or "no transpose", respectively. For the "Hermite" case, we - // can take advantage of the following: - // - // ```text - // A^H x = b - // ⟺ conj(A^T) x = b - // ⟺ conj(conj(A^T) x) = conj(b) - // ⟺ conj(conj(A^T)) conj(x) = conj(b) - // ⟺ A^T conj(x) = conj(b) - // ``` - // - // So, we can handle this case by switching to "no transpose" - // (which is equivalent to transposing the array since it will - // be reinterpreted as Fortran layout) and applying the - // elementwise conjugate to `x` and `b`. let (t, conj) = match l { MatrixLayout::C { .. } => match t { Transpose::No => (Transpose::Transpose, false), @@ -203,27 +142,83 @@ macro_rules! impl_solve { }; } // impl_solve! -impl_solve!( - f64, - lapack_sys::dgetrf_, - lapack_sys::dgetri_, - lapack_sys::dgetrs_ -); -impl_solve!( - f32, - lapack_sys::sgetrf_, - lapack_sys::sgetri_, - lapack_sys::sgetrs_ -); -impl_solve!( - c64, - lapack_sys::zgetrf_, - lapack_sys::zgetri_, - lapack_sys::zgetrs_ -); -impl_solve!( - c32, - lapack_sys::cgetrf_, - lapack_sys::cgetri_, - lapack_sys::cgetrs_ -); +impl_solve!(f64, lapack_sys::dgetrs_); +impl_solve!(f32, lapack_sys::sgetrs_); +impl_solve!(c64, lapack_sys::zgetrs_); +impl_solve!(c32, lapack_sys::cgetrs_); + +/// Working memory for computing inverse matrix +pub struct InvWork { + pub layout: MatrixLayout, + pub work: Vec>, +} + +/// Helper trait to abstract `*getri` LAPACK rotuines for implementing [Lapack::inv] +/// +/// LAPACK correspondance +/// ---------------------- +/// +/// | f32 | f64 | c32 | c64 | +/// |:-------|:-------|:-------|:-------| +/// | sgetri | dgetri | cgetri | zgetri | +/// +pub trait InvWorkImpl: Sized { + type Elem: Scalar; + fn new(layout: MatrixLayout) -> Result; + fn calc(&mut self, a: &mut [Self::Elem], p: &Pivot) -> Result<()>; +} + +macro_rules! impl_inv_work { + ($s:ty, $tri:path) => { + impl InvWorkImpl for InvWork<$s> { + type Elem = $s; + + fn new(layout: MatrixLayout) -> Result { + let (n, _) = layout.size(); + let mut info = 0; + let mut work_size = [Self::Elem::zero()]; + unsafe { + $tri( + &n, + std::ptr::null_mut(), + &layout.lda(), + std::ptr::null(), + AsPtr::as_mut_ptr(&mut work_size), + &(-1), + &mut info, + ) + }; + info.as_lapack_result()?; + let lwork = work_size[0].to_usize().unwrap(); + let work = vec_uninit(lwork); + Ok(InvWork { layout, work }) + } + + fn calc(&mut self, a: &mut [Self::Elem], ipiv: &Pivot) -> Result<()> { + if self.layout.len() == 0 { + return Ok(()); + } + let lwork = self.work.len().to_i32().unwrap(); + let mut info = 0; + unsafe { + $tri( + &self.layout.len(), + AsPtr::as_mut_ptr(a), + &self.layout.lda(), + ipiv.as_ptr(), + AsPtr::as_mut_ptr(&mut self.work), + &lwork, + &mut info, + ) + }; + info.as_lapack_result()?; + Ok(()) + } + } + }; +} + +impl_inv_work!(c64, lapack_sys::zgetri_); +impl_inv_work!(c32, lapack_sys::cgetri_); +impl_inv_work!(f64, lapack_sys::dgetri_); +impl_inv_work!(f32, lapack_sys::sgetri_); diff --git a/lax/src/solveh.rs b/lax/src/solveh.rs index bbc6f363..abb75cb8 100644 --- a/lax/src/solveh.rs +++ b/lax/src/solveh.rs @@ -1,124 +1,169 @@ +//! Factorize symmetric/Hermitian matrix using [Bunch-Kaufman diagonal pivoting method][BK] +//! +//! [BK]: https://doi.org/10.2307/2005787 +//! + use crate::{error::*, layout::MatrixLayout, *}; use cauchy::*; use num_traits::{ToPrimitive, Zero}; -#[cfg_attr(doc, katexit::katexit)] -/// Solve symmetric/hermite indefinite linear problem using the [Bunch-Kaufman diagonal pivoting method][BK]. -/// -/// For a given symmetric matrix $A$, -/// this method factorizes $A = U^T D U$ or $A = L D L^T$ where -/// -/// - $U$ (or $L$) are is a product of permutation and unit upper (lower) triangular matrices -/// - $D$ is symmetric and block diagonal with 1-by-1 and 2-by-2 diagonal blocks. -/// -/// This takes two-step approach based in LAPACK: +pub struct BkWork { + pub layout: MatrixLayout, + pub work: Vec>, + pub ipiv: Vec>, +} + +/// Factorize symmetric/Hermitian matrix using Bunch-Kaufman diagonal pivoting method /// -/// 1. Factorize given matrix $A$ into upper ($U$) or lower ($L$) form with diagonal matrix $D$ -/// 2. Then solve linear equation $Ax = b$, and/or calculate inverse matrix $A^{-1}$ +/// LAPACK correspondance +/// ---------------------- /// -/// [BK]: https://doi.org/10.2307/2005787 +/// | f32 | f64 | c32 | c64 | +/// |:-------|:-------|:-------|:-------| +/// | ssytrf | dsytrf | chetrf | zhetrf | /// -pub trait Solveh_: Sized { - /// Factorize input matrix using Bunch-Kaufman diagonal pivoting method - /// - /// LAPACK correspondance - /// ---------------------- - /// - /// | f32 | f64 | c32 | c64 | - /// |:-------|:-------|:-------|:-------| - /// | ssytrf | dsytrf | chetrf | zhetrf | - /// - fn bk(l: MatrixLayout, uplo: UPLO, a: &mut [Self]) -> Result; - - /// Compute inverse matrix $A^{-1}$ from factroized result - /// - /// LAPACK correspondance - /// ---------------------- - /// - /// | f32 | f64 | c32 | c64 | - /// |:-------|:-------|:-------|:-------| - /// | ssytri | dsytri | chetri | zhetri | - /// - fn invh(l: MatrixLayout, uplo: UPLO, a: &mut [Self], ipiv: &Pivot) -> Result<()>; - - /// Solve linear equation $Ax = b$ using factroized result - /// - /// LAPACK correspondance - /// ---------------------- - /// - /// | f32 | f64 | c32 | c64 | - /// |:-------|:-------|:-------|:-------| - /// | ssytrs | dsytrs | chetrs | zhetrs | - /// - fn solveh(l: MatrixLayout, uplo: UPLO, a: &[Self], ipiv: &Pivot, b: &mut [Self]) -> Result<()>; +pub trait BkWorkImpl: Sized { + type Elem: Scalar; + fn new(l: MatrixLayout) -> Result; + fn calc(&mut self, uplo: UPLO, a: &mut [Self::Elem]) -> Result<&[i32]>; + fn eval(self, uplo: UPLO, a: &mut [Self::Elem]) -> Result; } -macro_rules! impl_solveh { - ($scalar:ty, $trf:path, $tri:path, $trs:path) => { - impl Solveh_ for $scalar { - fn bk(l: MatrixLayout, uplo: UPLO, a: &mut [Self]) -> Result { - let (n, _) = l.size(); - let mut ipiv = vec_uninit(n as usize); - if n == 0 { - return Ok(Vec::new()); - } +macro_rules! impl_bk_work { + ($s:ty, $trf:path) => { + impl BkWorkImpl for BkWork<$s> { + type Elem = $s; - // calc work size + fn new(layout: MatrixLayout) -> Result { + let (n, _) = layout.size(); + let ipiv = vec_uninit(n as usize); let mut info = 0; - let mut work_size = [Self::zero()]; + let mut work_size = [Self::Elem::zero()]; unsafe { $trf( - uplo.as_ptr(), + UPLO::Upper.as_ptr(), &n, - AsPtr::as_mut_ptr(a), - &l.lda(), - AsPtr::as_mut_ptr(&mut ipiv), + std::ptr::null_mut(), + &layout.lda(), + std::ptr::null_mut(), AsPtr::as_mut_ptr(&mut work_size), &(-1), &mut info, ) }; info.as_lapack_result()?; - - // actual let lwork = work_size[0].to_usize().unwrap(); - let mut work: Vec> = vec_uninit(lwork); + let work = vec_uninit(lwork); + Ok(BkWork { layout, work, ipiv }) + } + + fn calc(&mut self, uplo: UPLO, a: &mut [Self::Elem]) -> Result<&[i32]> { + let (n, _) = self.layout.size(); + let lwork = self.work.len().to_i32().unwrap(); + if lwork == 0 { + return Ok(&[]); + } + let mut info = 0; unsafe { $trf( uplo.as_ptr(), &n, AsPtr::as_mut_ptr(a), - &l.lda(), - AsPtr::as_mut_ptr(&mut ipiv), - AsPtr::as_mut_ptr(&mut work), - &(lwork as i32), + &self.layout.lda(), + AsPtr::as_mut_ptr(&mut self.ipiv), + AsPtr::as_mut_ptr(&mut self.work), + &lwork, &mut info, ) }; info.as_lapack_result()?; - let ipiv = unsafe { ipiv.assume_init() }; - Ok(ipiv) + Ok(unsafe { self.ipiv.slice_assume_init_ref() }) } - fn invh(l: MatrixLayout, uplo: UPLO, a: &mut [Self], ipiv: &Pivot) -> Result<()> { - let (n, _) = l.size(); + fn eval(mut self, uplo: UPLO, a: &mut [Self::Elem]) -> Result { + let _ref = self.calc(uplo, a)?; + Ok(unsafe { self.ipiv.assume_init() }) + } + } + }; +} +impl_bk_work!(c64, lapack_sys::zhetrf_); +impl_bk_work!(c32, lapack_sys::chetrf_); +impl_bk_work!(f64, lapack_sys::dsytrf_); +impl_bk_work!(f32, lapack_sys::ssytrf_); + +pub struct InvhWork { + pub layout: MatrixLayout, + pub work: Vec>, +} + +/// Compute inverse matrix of symmetric/Hermitian matrix +/// +/// LAPACK correspondance +/// ---------------------- +/// +/// | f32 | f64 | c32 | c64 | +/// |:-------|:-------|:-------|:-------| +/// | ssytri | dsytri | chetri | zhetri | +/// +pub trait InvhWorkImpl: Sized { + type Elem; + fn new(layout: MatrixLayout) -> Result; + fn calc(&mut self, uplo: UPLO, a: &mut [Self::Elem], ipiv: &Pivot) -> Result<()>; +} + +macro_rules! impl_invh_work { + ($s:ty, $tri:path) => { + impl InvhWorkImpl for InvhWork<$s> { + type Elem = $s; + + fn new(layout: MatrixLayout) -> Result { + let (n, _) = layout.size(); + let work = vec_uninit(n as usize); + Ok(InvhWork { layout, work }) + } + + fn calc(&mut self, uplo: UPLO, a: &mut [Self::Elem], ipiv: &Pivot) -> Result<()> { + let (n, _) = self.layout.size(); let mut info = 0; - let mut work: Vec> = vec_uninit(n as usize); unsafe { $tri( uplo.as_ptr(), &n, AsPtr::as_mut_ptr(a), - &l.lda(), + &self.layout.lda(), ipiv.as_ptr(), - AsPtr::as_mut_ptr(&mut work), + AsPtr::as_mut_ptr(&mut self.work), &mut info, ) }; info.as_lapack_result()?; Ok(()) } + } + }; +} +impl_invh_work!(c64, lapack_sys::zhetri_); +impl_invh_work!(c32, lapack_sys::chetri_); +impl_invh_work!(f64, lapack_sys::dsytri_); +impl_invh_work!(f32, lapack_sys::ssytri_); +/// Solve symmetric/Hermitian linear equation +/// +/// LAPACK correspondance +/// ---------------------- +/// +/// | f32 | f64 | c32 | c64 | +/// |:-------|:-------|:-------|:-------| +/// | ssytrs | dsytrs | chetrs | zhetrs | +/// +pub trait SolvehImpl: Scalar { + fn solveh(l: MatrixLayout, uplo: UPLO, a: &[Self], ipiv: &Pivot, b: &mut [Self]) -> Result<()>; +} + +macro_rules! impl_solveh_ { + ($s:ty, $trs:path) => { + impl SolvehImpl for $s { fn solveh( l: MatrixLayout, uplo: UPLO, @@ -146,29 +191,9 @@ macro_rules! impl_solveh { } } }; -} // impl_solveh! +} -impl_solveh!( - f64, - lapack_sys::dsytrf_, - lapack_sys::dsytri_, - lapack_sys::dsytrs_ -); -impl_solveh!( - f32, - lapack_sys::ssytrf_, - lapack_sys::ssytri_, - lapack_sys::ssytrs_ -); -impl_solveh!( - c64, - lapack_sys::zhetrf_, - lapack_sys::zhetri_, - lapack_sys::zhetrs_ -); -impl_solveh!( - c32, - lapack_sys::chetrf_, - lapack_sys::chetri_, - lapack_sys::chetrs_ -); +impl_solveh_!(c64, lapack_sys::zhetrs_); +impl_solveh_!(c32, lapack_sys::chetrs_); +impl_solveh_!(f64, lapack_sys::dsytrs_); +impl_solveh_!(f32, lapack_sys::ssytrs_);