diff --git a/src/linalg/impl_linalg.rs b/src/linalg/impl_linalg.rs
index d851f2f08..fe87567f0 100644
--- a/src/linalg/impl_linalg.rs
+++ b/src/linalg/impl_linalg.rs
@@ -14,6 +14,7 @@ use crate::dimension::offset_from_low_addr_ptr_to_logical_ptr;
use crate::{LinalgScalar, Zip};
use std::any::TypeId;
+use std::mem::MaybeUninit;
use alloc::vec::Vec;
#[cfg(feature = "blas")]
@@ -699,6 +700,39 @@ unsafe fn general_mat_vec_mul_impl(
}
}
+
+/// Kronecker product of 2D matrices.
+///
+/// The kronecker product of a LxN matrix A and a MxR matrix B is a (L*M)x(N*R)
+/// matrix K formed by the block multiplication A_ij * B.
+pub fn kron(a: &ArrayBase, b: &ArrayBase) -> Array
+where
+ S1: Data,
+ S2: Data,
+ A: LinalgScalar,
+{
+ let dimar = a.shape()[0];
+ let dimac = a.shape()[1];
+ let dimbr = b.shape()[0];
+ let dimbc = b.shape()[1];
+ let mut out: Array2> = Array2::uninit((
+ dimar
+ .checked_mul(dimbr)
+ .expect("Dimensions of kronecker product output array overflows usize."),
+ dimac
+ .checked_mul(dimbc)
+ .expect("Dimensions of kronecker product output array overflows usize."),
+ ));
+ Zip::from(out.exact_chunks_mut((dimbr, dimbc)))
+ .and(a)
+ .for_each(|out, &a| {
+ Zip::from(out).and(b).for_each(|out, &b| {
+ *out = MaybeUninit::new(a * b);
+ })
+ });
+ unsafe { out.assume_init() }
+}
+
#[inline(always)]
/// Return `true` if `A` and `B` are the same type
fn same_type() -> bool {
diff --git a/src/linalg/mod.rs b/src/linalg/mod.rs
index 8575905cd..abd7b2b9d 100644
--- a/src/linalg/mod.rs
+++ b/src/linalg/mod.rs
@@ -11,5 +11,6 @@
pub use self::impl_linalg::general_mat_mul;
pub use self::impl_linalg::general_mat_vec_mul;
pub use self::impl_linalg::Dot;
+pub use self::impl_linalg::kron;
mod impl_linalg;
diff --git a/tests/oper.rs b/tests/oper.rs
index ed612bad2..051728680 100644
--- a/tests/oper.rs
+++ b/tests/oper.rs
@@ -6,6 +6,7 @@
)]
#![cfg(feature = "std")]
use ndarray::linalg::general_mat_mul;
+use ndarray::linalg::kron;
use ndarray::prelude::*;
use ndarray::{rcarr1, rcarr2};
use ndarray::{Data, LinalgScalar};
@@ -820,3 +821,65 @@ fn vec_mat_mul() {
}
}
}
+
+#[test]
+fn kron_square_f64() {
+ let a = arr2(&[[1.0, 0.0], [0.0, 1.0]]);
+ let b = arr2(&[[0.0, 1.0], [1.0, 0.0]]);
+
+ assert_eq!(
+ kron(&a, &b),
+ arr2(&[
+ [0.0, 1.0, 0.0, 0.0],
+ [1.0, 0.0, 0.0, 0.0],
+ [0.0, 0.0, 0.0, 1.0],
+ [0.0, 0.0, 1.0, 0.0]
+ ]),
+ );
+
+ assert_eq!(
+ kron(&b, &a),
+ arr2(&[
+ [0.0, 0.0, 1.0, 0.0],
+ [0.0, 0.0, 0.0, 1.0],
+ [1.0, 0.0, 0.0, 0.0],
+ [0.0, 1.0, 0.0, 0.0]
+ ]),
+ )
+}
+
+#[test]
+fn kron_square_i64() {
+ let a = arr2(&[[1, 0], [0, 1]]);
+ let b = arr2(&[[0, 1], [1, 0]]);
+
+ assert_eq!(
+ kron(&a, &b),
+ arr2(&[[0, 1, 0, 0], [1, 0, 0, 0], [0, 0, 0, 1], [0, 0, 1, 0]]),
+ );
+
+ assert_eq!(
+ kron(&b, &a),
+ arr2(&[[0, 0, 1, 0], [0, 0, 0, 1], [1, 0, 0, 0], [0, 1, 0, 0]]),
+ )
+}
+
+#[test]
+fn kron_i64() {
+ let a = arr2(&[[1, 0]]);
+ let b = arr2(&[[0, 1], [1, 0]]);
+ let r = arr2(&[[0, 1, 0, 0], [1, 0, 0, 0]]);
+ assert_eq!(kron(&a, &b), r);
+
+ let a = arr2(&[[1, 0], [0, 0], [0, 1]]);
+ let b = arr2(&[[0, 1], [1, 0]]);
+ let r = arr2(&[
+ [0, 1, 0, 0],
+ [1, 0, 0, 0],
+ [0, 0, 0, 0],
+ [0, 0, 0, 0],
+ [0, 0, 0, 1],
+ [0, 0, 1, 0],
+ ]);
+ assert_eq!(kron(&a, &b), r);
+}