diff --git a/src/complex/matrix.rs b/src/complex/matrix.rs index 860ebb4b..cd224d6c 100644 --- a/src/complex/matrix.rs +++ b/src/complex/matrix.rs @@ -7,14 +7,15 @@ use std::{ use anyhow::{bail, Result}; use matrixmultiply::CGemmOption; use num_complex::Complex; +use peroxide_num::{ExpLogOps, PowOps, TrigOps}; use rand_distr::num_traits::{One, Zero}; use crate::{ fuga::{ - nearly_eq, tab, ConcatenateError, InnerProduct, LinearOp, MatrixProduct, Norm, Normed, - Shape, Vector, + copy_vec_ptr, nearly_eq, swap_vec_ptr, tab, Algorithm, ConcatenateError, FPMatrix, + InnerProduct, LinearOp, MatrixProduct, Norm, Normed, Shape, Vector, }, - traits::{fp::FPMatrix, mutable::MutMatrix}, + traits::{fp::FPVector, mutable::MutMatrix}, }; /// R-like complex matrix structure @@ -1761,15 +1762,122 @@ impl FPMatrix for ComplexMatrix { } } +pub fn diag(n: usize) -> ComplexMatrix { + let mut v: Vec> = vec![Complex::zero(); n * n]; + for i in 0..n { + let idx = i * (n + 1); + v[idx] = Complex::one(); + } + complex_matrix(v, n, n, Shape::Row) +} + +/// Data structure for Complete Pivoting LU decomposition +/// +/// # Usage +/// ```rust +/// extern crate peroxide; +/// use peroxide::fuga::*; +/// use num_complex::Complex64; +/// use peroxide::complex::matrix::*; +/// use peroxide::complex::matrix::LinearAlgebra; +/// +/// let a = complex_matrix(vec![Complex64::new(1f64, 1f64), +/// Complex64::new(2f64, 2f64), +/// Complex64::new(3f64, 3f64), +/// Complex64::new(4f64, 4f64)], +/// 2, 2, Row +/// ); +/// let pqlu = a.lu(); +/// let (p, q, l, u) = pqlu.extract(); +/// // p, q are permutations +/// // l, u are matrices +/// println!("{}", l); // lower triangular +/// println!("{}", u); // upper triangular +/// ``` +#[derive(Debug, Clone)] +pub struct PQLU { + pub p: Vec, + pub q: Vec, + pub l: ComplexMatrix, + pub u: ComplexMatrix, +} + +impl PQLU { + pub fn extract(&self) -> (Vec, Vec, ComplexMatrix, ComplexMatrix) { + ( + self.p.clone(), + self.q.clone(), + self.l.clone(), + self.u.clone(), + ) + } + + pub fn det(&self) -> Complex { + // sgn of perms + let mut sgn_p = 1f64; + let mut sgn_q = 1f64; + for (i, &j) in self.p.iter().enumerate() { + if i != j { + sgn_p *= -1f64; + } + } + for (i, &j) in self.q.iter().enumerate() { + if i != j { + sgn_q *= -1f64; + } + } + + self.u.diag().reduce(Complex::one(), |x, y| x * y) * sgn_p * sgn_q + } + + pub fn inv(&self) -> ComplexMatrix { + let (p, q, l, u) = self.extract(); + let mut m = complex_inv_u(u) * complex_inv_l(l); + // Q = Q1 Q2 Q3 .. + for (idx1, idx2) in q.into_iter().enumerate().rev() { + unsafe { + m.swap(idx1, idx2, Shape::Row); + } + } + // P = Pn-1 .. P3 P2 P1 + for (idx1, idx2) in p.into_iter().enumerate().rev() { + unsafe { + m.swap(idx1, idx2, Shape::Col); + } + } + m + } +} + +/// MATLAB like eye - Identity matrix +pub fn eye(n: usize) -> ComplexMatrix { + let mut m = complex_matrix(vec![Complex::zero(); n * n], n, n, Shape::Row); + for i in 0..n { + m[(i, i)] = Complex::one(); + } + m +} + // ============================================================================= // Linear Algebra // ============================================================================= +#[derive(Debug, Copy, Clone)] +pub enum SolveKind { + LU, + WAZ, +} + /// Linear algebra trait pub trait LinearAlgebra { fn back_subs(&self, b: &Vec>) -> Vec>; fn forward_subs(&self, b: &Vec>) -> Vec>; + fn lu(&self) -> PQLU; + fn det(&self) -> Complex; fn block(&self) -> (ComplexMatrix, ComplexMatrix, ComplexMatrix, ComplexMatrix); + fn inv(&self) -> ComplexMatrix; + fn solve(&self, b: &Vec>, sk: SolveKind) -> Vec>; + fn solve_mat(&self, m: &ComplexMatrix, sk: SolveKind) -> ComplexMatrix; fn is_symmetric(&self) -> bool; // ToDo: Add other fn of this trait from src/structure/matrix.rs } @@ -1805,6 +1913,112 @@ impl LinearAlgebra for ComplexMatrix { y } + /// LU Decomposition Implements (Complete Pivot) + /// + /// # Description + /// It use complete pivoting LU decomposition. + /// You can get two permutations, and LU matrices. + /// + /// # Caution + /// It returns `Option` - You should unwrap to obtain real value. + /// `PQLU` has four field - `p`, `q`, `l`, `u`. + /// `p`, `q` are permutations. + /// `l`, `u` are matrices. + /// + /// # Examples + /// ``` + /// #[macro_use] + /// use peroxide::fuga::*; + /// use num_complex::Complex64; + /// use peroxide::complex::matrix::*; + /// use peroxide::complex::matrix::LinearAlgebra; + /// + /// fn main() { + /// let a = complex_matrix(vec![Complex64::new(1f64, 1f64), + /// Complex64::new(2f64, 2f64), + /// Complex64::new(3f64, 3f64), + /// Complex64::new(4f64, 4f64)], + /// 2, 2, Row + /// ); + /// + /// let l_exp = complex_matrix(vec![Complex64::new(1f64, 0f64), + /// Complex64::new(0f64, 0f64), + /// Complex64::new(0.5f64, -0.0f64), + /// Complex64::new(1f64, 0f64)], + /// 2, 2, Row + /// ); + /// + /// let u_exp = complex_matrix(vec![Complex64::new(4f64, 4f64), + /// Complex64::new(3f64, 3f64), + /// Complex64::new(0f64, 0f64), + /// Complex64::new(-0.5f64, -0.5f64)], + /// 2, 2, Row + /// ); + /// let pqlu = a.lu(); + /// let (p,q,l,u) = (pqlu.p, pqlu.q, pqlu.l, pqlu.u); + /// assert_eq!(p, vec![1]); // swap 0 & 1 (Row) + /// assert_eq!(q, vec![1]); // swap 0 & 1 (Col) + /// assert_eq!(l, l_exp); + /// assert_eq!(u, u_exp); + /// } + /// ``` + fn lu(&self) -> PQLU { + assert_eq!(self.col, self.row); + let n = self.row; + let len: usize = n * n; + + let mut l = eye(n); + let mut u = complex_matrix(vec![Complex::zero(); len], n, n, self.shape); + + let mut temp = self.clone(); + let (p, q) = gecp(&mut temp); + for i in 0..n { + for j in 0..i { + // Inverse multiplier + l[(i, j)] = -temp[(i, j)]; + } + for j in i..n { + u[(i, j)] = temp[(i, j)]; + } + } + // Pivoting L + for i in 0..n - 1 { + unsafe { + let l_i = l.col_mut(i); + for j in i + 1..l.col - 1 { + let dst = p[j]; + std::ptr::swap(l_i[j], l_i[dst]); + } + } + } + PQLU { p, q, l, u } + } + + /// Determinant + /// + /// # Examples + /// ``` + /// #[macro_use] + /// use peroxide::fuga::*; + /// use num_complex::Complex64; + /// use peroxide::complex::matrix::*; + /// use peroxide::complex::matrix::LinearAlgebra; + /// + /// fn main() { + /// let a = complex_matrix(vec![Complex64::new(1f64, 1f64), + /// Complex64::new(2f64, 2f64), + /// Complex64::new(3f64, 3f64), + /// Complex64::new(4f64, 4f64)], + /// 2, 2, Row + /// ); + /// assert_eq!(a.det().norm(), 4f64); + /// } + /// ``` + fn det(&self) -> Complex { + assert_eq!(self.row, self.col); + self.lu().det() + } + /// Block Partition /// /// # Examples @@ -1865,6 +2079,102 @@ impl LinearAlgebra for ComplexMatrix { (m1, m2, m3, m4) } + /// Inverse of Matrix + /// + /// # Caution + /// + /// `inv` function returns `Option` + /// Thus, you should use pattern matching or `unwrap` to obtain inverse. + /// + /// # Examples + /// ``` + /// #[macro_use] + /// extern crate peroxide; + /// use peroxide::fuga::*; + /// use num_complex::Complex64; + /// use peroxide::complex::matrix::*; + /// use peroxide::complex::matrix::LinearAlgebra; + /// + /// fn main() { + /// // Non-singular + /// let a = complex_matrix(vec![Complex64::new(1f64, 1f64), + /// Complex64::new(2f64, 2f64), + /// Complex64::new(3f64, 3f64), + /// Complex64::new(4f64, 4f64)], + /// 2, 2, Row + /// ); + /// + /// let a_inv_exp = complex_matrix(vec![Complex64::new(-1.0f64, 1f64), + /// Complex64::new(0.5f64, -0.5f64), + /// Complex64::new(0.75f64, -0.75f64), + /// Complex64::new(-0.25f64, 0.25f64)], + /// 2, 2, Row + /// ); + /// assert_eq!(a.inv(), a_inv_exp); + /// } + /// ``` + fn inv(&self) -> Self { + self.lu().inv() + } + + /// Solve with Vector + /// + /// # Solve options + /// + /// * LU: Gaussian elimination with Complete pivoting LU (GECP) + /// * WAZ: Solve with WAZ decomposition + fn solve(&self, b: &Vec>, sk: SolveKind) -> Vec> { + match sk { + SolveKind::LU => { + let lu = self.lu(); + let (p, q, l, u) = lu.extract(); + let mut v = b.clone(); + v.swap_with_perm(&p.into_iter().enumerate().collect()); + let z = l.forward_subs(&v); + let mut y = u.back_subs(&z); + y.swap_with_perm(&q.into_iter().enumerate().rev().collect()); + y + } + SolveKind::WAZ => { + unimplemented!() + } + } + } + + fn solve_mat(&self, m: &ComplexMatrix, sk: SolveKind) -> ComplexMatrix { + match sk { + SolveKind::LU => { + let lu = self.lu(); + let (p, q, l, u) = lu.extract(); + let mut x = complex_matrix( + vec![Complex::zero(); self.col * m.col], + self.col, + m.col, + Shape::Col, + ); + for i in 0..m.col { + let mut v = m.col(i).clone(); + for (r, &s) in p.iter().enumerate() { + v.swap(r, s); + } + let z = l.forward_subs(&v); + let mut y = u.back_subs(&z); + for (r, &s) in q.iter().enumerate() { + y.swap(r, s); + } + unsafe { + let mut c = x.col_mut(i); + copy_vec_ptr(&mut c, &y); + } + } + x + } + SolveKind::WAZ => { + unimplemented!() + } + } + } + fn is_symmetric(&self) -> bool { if self.row != self.col { return false; @@ -1883,6 +2193,11 @@ impl LinearAlgebra for ComplexMatrix { } } +#[allow(non_snake_case)] +pub fn solve(A: &ComplexMatrix, b: &ComplexMatrix, sk: SolveKind) -> ComplexMatrix { + A.solve_mat(b, sk) +} + impl MutMatrix for ComplexMatrix { type Scalar = Complex; @@ -1934,8 +2249,8 @@ impl MutMatrix for ComplexMatrix { unsafe fn swap(&mut self, idx1: usize, idx2: usize, shape: Shape) { match shape { - Shape::Col => swap_complex_vec_ptr(&mut self.col_mut(idx1), &mut self.col_mut(idx2)), - Shape::Row => swap_complex_vec_ptr(&mut self.row_mut(idx1), &mut self.row_mut(idx2)), + Shape::Col => swap_vec_ptr(&mut self.col_mut(idx1), &mut self.col_mut(idx2)), + Shape::Row => swap_vec_ptr(&mut self.row_mut(idx1), &mut self.row_mut(idx2)), } } @@ -1946,14 +2261,101 @@ impl MutMatrix for ComplexMatrix { } } -// ToDo: Move swap_complex_vec_ptr to low_level.rs -pub unsafe fn swap_complex_vec_ptr( - lhs: &mut Vec<*mut Complex>, - rhs: &mut Vec<*mut Complex>, -) { - assert_eq!(lhs.len(), rhs.len(), "Should use same length vectors"); - for (&mut l, &mut r) in lhs.iter_mut().zip(rhs.iter_mut()) { - std::ptr::swap(l, r); +impl ExpLogOps for ComplexMatrix { + type Float = Complex; + + fn exp(&self) -> Self { + self.fmap(|x| x.exp()) + } + fn ln(&self) -> Self { + self.fmap(|x| x.ln()) + } + fn log(&self, base: Self::Float) -> Self { + self.fmap(|x| x.ln() / base.ln()) // Using `Log: change of base` formula + } + fn log2(&self) -> Self { + self.fmap(|x| x.ln() / 2.0.ln()) // Using `Log: change of base` formula + } + fn log10(&self) -> Self { + self.fmap(|x| x.ln() / 10.0.ln()) // Using `Log: change of base` formula + } +} + +impl PowOps for ComplexMatrix { + type Float = Complex; + + fn powi(&self, n: i32) -> Self { + self.fmap(|x| x.powi(n)) + } + + fn powf(&self, f: Self::Float) -> Self { + self.fmap(|x| x.powc(f)) + } + + fn pow(&self, _f: Self) -> Self { + unimplemented!() + } + + fn sqrt(&self) -> Self { + self.fmap(|x| x.sqrt()) + } +} + +impl TrigOps for ComplexMatrix { + fn sin_cos(&self) -> (Self, Self) { + let (sin, cos) = self.data.iter().map(|x| (x.sin(), x.cos())).unzip(); + ( + complex_matrix(sin, self.row, self.col, self.shape), + complex_matrix(cos, self.row, self.col, self.shape), + ) + } + + fn sin(&self) -> Self { + self.fmap(|x| x.sin()) + } + + fn cos(&self) -> Self { + self.fmap(|x| x.cos()) + } + + fn tan(&self) -> Self { + self.fmap(|x| x.tan()) + } + + fn sinh(&self) -> Self { + self.fmap(|x| x.sinh()) + } + + fn cosh(&self) -> Self { + self.fmap(|x| x.cosh()) + } + + fn tanh(&self) -> Self { + self.fmap(|x| x.tanh()) + } + + fn asin(&self) -> Self { + self.fmap(|x| x.asin()) + } + + fn acos(&self) -> Self { + self.fmap(|x| x.acos()) + } + + fn atan(&self) -> Self { + self.fmap(|x| x.atan()) + } + + fn asinh(&self) -> Self { + self.fmap(|x| x.asinh()) + } + + fn acosh(&self) -> Self { + self.fmap(|x| x.acosh()) + } + + fn atanh(&self) -> Self { + self.fmap(|x| x.atanh()) } } @@ -2281,3 +2683,124 @@ pub fn complex_gevm( ) } } + +/// LU via Gaussian Elimination with Partial Pivoting +#[allow(dead_code)] +fn gepp(m: &mut ComplexMatrix) -> Vec { + let mut r = vec![0usize; m.col - 1]; + for k in 0..(m.col - 1) { + // Find the pivot row + let r_k = m + .col(k) + .into_iter() + .skip(k) + .enumerate() + .max_by(|x1, x2| x1.1.norm().partial_cmp(&x2.1.norm()).unwrap()) + .unwrap() + .0 + + k; + r[k] = r_k; + + // Interchange the rows r_k and k + for j in k..m.col { + unsafe { + std::ptr::swap(&mut m[(k, j)], &mut m[(r_k, j)]); + println!("Swap! k:{}, r_k:{}", k, r_k); + } + } + // Form the multipliers + for i in k + 1..m.col { + m[(i, k)] = -m[(i, k)] / m[(k, k)]; + } + // Update the entries + for i in k + 1..m.col { + for j in k + 1..m.col { + let local_m = m[(i, k)] * m[(k, j)]; + m[(i, j)] += local_m; + } + } + } + r +} + +/// LU via Gauss Elimination with Complete Pivoting +fn gecp(m: &mut ComplexMatrix) -> (Vec, Vec) { + let n = m.col; + let mut r = vec![0usize; n - 1]; + let mut s = vec![0usize; n - 1]; + for k in 0..n - 1 { + // Find pivot + let (r_k, s_k) = match m.shape { + Shape::Col => { + let mut row_ics = 0usize; + let mut col_ics = 0usize; + let mut max_val = 0f64; + for i in k..n { + let c = m + .col(i) + .into_iter() + .skip(k) + .enumerate() + .max_by(|x1, x2| x1.1.norm().partial_cmp(&x2.1.norm()).unwrap()) + .unwrap(); + let c_ics = c.0 + k; + let c_val = c.1.norm(); + if c_val > max_val { + row_ics = c_ics; + col_ics = i; + max_val = c_val; + } + } + (row_ics, col_ics) + } + Shape::Row => { + let mut row_ics = 0usize; + let mut col_ics = 0usize; + let mut max_val = 0f64; + for i in k..n { + let c = m + .row(i) + .into_iter() + .skip(k) + .enumerate() + .max_by(|x1, x2| x1.1.norm().partial_cmp(&x2.1.norm()).unwrap()) + .unwrap(); + let c_ics = c.0 + k; + let c_val = c.1.norm(); + if c_val > max_val { + col_ics = c_ics; + row_ics = i; + max_val = c_val; + } + } + (row_ics, col_ics) + } + }; + r[k] = r_k; + s[k] = s_k; + + // Interchange rows + for j in k..n { + unsafe { + std::ptr::swap(&mut m[(k, j)], &mut m[(r_k, j)]); + } + } + + // Interchange cols + for i in 0..n { + unsafe { + std::ptr::swap(&mut m[(i, k)], &mut m[(i, s_k)]); + } + } + + // Form the multipliers + for i in k + 1..n { + m[(i, k)] = -m[(i, k)] / m[(k, k)]; + for j in k + 1..n { + let local_m = m[(i, k)] * m[(k, j)]; + m[(i, j)] += local_m; + } + } + } + (r, s) +} diff --git a/src/complex/vector.rs b/src/complex/vector.rs index 58110730..e5836643 100644 --- a/src/complex/vector.rs +++ b/src/complex/vector.rs @@ -1,3 +1,4 @@ +use crate::fuga::Algorithm; use crate::traits::fp::FPVector; use crate::traits::math::{InnerProduct, Norm, Normed, Vector}; use crate::traits::sugar::VecOps; @@ -137,3 +138,37 @@ impl InnerProduct for Vec> { } impl VecOps for Vec> {} + +impl Algorithm for Vec> { + type Scalar = Complex; + + fn rank(&self) -> Vec { + unimplemented!() + } + + fn sign(&self) -> Complex { + unimplemented!() + } + + fn arg_max(&self) -> usize { + unimplemented!() + } + + fn arg_min(&self) -> usize { + unimplemented!() + } + + fn max(&self) -> Complex { + unimplemented!() + } + + fn min(&self) -> Complex { + unimplemented!() + } + + fn swap_with_perm(&mut self, p: &Vec<(usize, usize)>) { + for (i, j) in p.iter() { + self.swap(*i, *j); + } + } +} diff --git a/src/structure/vector.rs b/src/structure/vector.rs index 0e977361..9a0dcd71 100644 --- a/src/structure/vector.rs +++ b/src/structure/vector.rs @@ -473,6 +473,8 @@ impl MutFP for Vec { } impl Algorithm for Vec { + type Scalar = f64; + /// Assign rank /// /// # Examples diff --git a/src/traits/general.rs b/src/traits/general.rs index 2737f97b..4712d304 100644 --- a/src/traits/general.rs +++ b/src/traits/general.rs @@ -1,10 +1,12 @@ /// Some algorithms for Vector pub trait Algorithm { + type Scalar; + fn rank(&self) -> Vec; - fn sign(&self) -> f64; + fn sign(&self) -> Self::Scalar; fn arg_max(&self) -> usize; fn arg_min(&self) -> usize; - fn max(&self) -> f64; - fn min(&self) -> f64; + fn max(&self) -> Self::Scalar; + fn min(&self) -> Self::Scalar; fn swap_with_perm(&mut self, p: &Vec<(usize, usize)>); } diff --git a/src/util/low_level.rs b/src/util/low_level.rs index 8b2a1f52..53fe5016 100644 --- a/src/util/low_level.rs +++ b/src/util/low_level.rs @@ -1,17 +1,26 @@ -pub unsafe fn copy_vec_ptr(dst: &mut Vec<*mut f64>, src: &Vec) { +pub unsafe fn copy_vec_ptr(dst: &mut Vec<*mut T>, src: &Vec) +where + T: Copy, +{ assert_eq!(dst.len(), src.len(), "Should use same length vectors"); for (&mut p, &s) in dst.iter_mut().zip(src) { *p = s; } } -pub unsafe fn swap_vec_ptr(lhs: &mut Vec<*mut f64>, rhs: &mut Vec<*mut f64>) { +pub unsafe fn swap_vec_ptr(lhs: &mut Vec<*mut T>, rhs: &mut Vec<*mut T>) +where + T: Copy, +{ assert_eq!(lhs.len(), rhs.len(), "Should use same length vectors"); for (&mut l, &mut r) in lhs.iter_mut().zip(rhs.iter_mut()) { std::ptr::swap(l, r); } } -pub unsafe fn ptr_to_vec(pv: &Vec<*const f64>) -> Vec { +pub unsafe fn ptr_to_vec(pv: &Vec<*const T>) -> Vec +where + T: Copy, +{ pv.iter().map(|&x| *x).collect() }