diff --git a/src/linalg/impl_linalg.rs b/src/linalg/impl_linalg.rs index d851f2f08..fe87567f0 100644 --- a/src/linalg/impl_linalg.rs +++ b/src/linalg/impl_linalg.rs @@ -14,6 +14,7 @@ use crate::dimension::offset_from_low_addr_ptr_to_logical_ptr; use crate::{LinalgScalar, Zip}; use std::any::TypeId; +use std::mem::MaybeUninit; use alloc::vec::Vec; #[cfg(feature = "blas")] @@ -699,6 +700,39 @@ unsafe fn general_mat_vec_mul_impl( } } + +/// Kronecker product of 2D matrices. +/// +/// The kronecker product of a LxN matrix A and a MxR matrix B is a (L*M)x(N*R) +/// matrix K formed by the block multiplication A_ij * B. +pub fn kron(a: &ArrayBase, b: &ArrayBase) -> Array +where + S1: Data, + S2: Data, + A: LinalgScalar, +{ + let dimar = a.shape()[0]; + let dimac = a.shape()[1]; + let dimbr = b.shape()[0]; + let dimbc = b.shape()[1]; + let mut out: Array2> = Array2::uninit(( + dimar + .checked_mul(dimbr) + .expect("Dimensions of kronecker product output array overflows usize."), + dimac + .checked_mul(dimbc) + .expect("Dimensions of kronecker product output array overflows usize."), + )); + Zip::from(out.exact_chunks_mut((dimbr, dimbc))) + .and(a) + .for_each(|out, &a| { + Zip::from(out).and(b).for_each(|out, &b| { + *out = MaybeUninit::new(a * b); + }) + }); + unsafe { out.assume_init() } +} + #[inline(always)] /// Return `true` if `A` and `B` are the same type fn same_type() -> bool { diff --git a/src/linalg/mod.rs b/src/linalg/mod.rs index 8575905cd..abd7b2b9d 100644 --- a/src/linalg/mod.rs +++ b/src/linalg/mod.rs @@ -11,5 +11,6 @@ pub use self::impl_linalg::general_mat_mul; pub use self::impl_linalg::general_mat_vec_mul; pub use self::impl_linalg::Dot; +pub use self::impl_linalg::kron; mod impl_linalg; diff --git a/tests/oper.rs b/tests/oper.rs index ed612bad2..051728680 100644 --- a/tests/oper.rs +++ b/tests/oper.rs @@ -6,6 +6,7 @@ )] #![cfg(feature = "std")] use ndarray::linalg::general_mat_mul; +use ndarray::linalg::kron; use ndarray::prelude::*; use ndarray::{rcarr1, rcarr2}; use ndarray::{Data, LinalgScalar}; @@ -820,3 +821,65 @@ fn vec_mat_mul() { } } } + +#[test] +fn kron_square_f64() { + let a = arr2(&[[1.0, 0.0], [0.0, 1.0]]); + let b = arr2(&[[0.0, 1.0], [1.0, 0.0]]); + + assert_eq!( + kron(&a, &b), + arr2(&[ + [0.0, 1.0, 0.0, 0.0], + [1.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 1.0], + [0.0, 0.0, 1.0, 0.0] + ]), + ); + + assert_eq!( + kron(&b, &a), + arr2(&[ + [0.0, 0.0, 1.0, 0.0], + [0.0, 0.0, 0.0, 1.0], + [1.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 0.0, 0.0] + ]), + ) +} + +#[test] +fn kron_square_i64() { + let a = arr2(&[[1, 0], [0, 1]]); + let b = arr2(&[[0, 1], [1, 0]]); + + assert_eq!( + kron(&a, &b), + arr2(&[[0, 1, 0, 0], [1, 0, 0, 0], [0, 0, 0, 1], [0, 0, 1, 0]]), + ); + + assert_eq!( + kron(&b, &a), + arr2(&[[0, 0, 1, 0], [0, 0, 0, 1], [1, 0, 0, 0], [0, 1, 0, 0]]), + ) +} + +#[test] +fn kron_i64() { + let a = arr2(&[[1, 0]]); + let b = arr2(&[[0, 1], [1, 0]]); + let r = arr2(&[[0, 1, 0, 0], [1, 0, 0, 0]]); + assert_eq!(kron(&a, &b), r); + + let a = arr2(&[[1, 0], [0, 0], [0, 1]]); + let b = arr2(&[[0, 1], [1, 0]]); + let r = arr2(&[ + [0, 1, 0, 0], + [1, 0, 0, 0], + [0, 0, 0, 0], + [0, 0, 0, 0], + [0, 0, 0, 1], + [0, 0, 1, 0], + ]); + assert_eq!(kron(&a, &b), r); +}