Skip to content

Commit

Permalink
Merge pull request #1105 from ethanhs/kron
Browse files Browse the repository at this point in the history
Implement Kronecker product
  • Loading branch information
bluss authored Nov 11, 2021
2 parents 209d171 + 7d6fd72 commit 1c685ef
Show file tree
Hide file tree
Showing 3 changed files with 98 additions and 0 deletions.
34 changes: 34 additions & 0 deletions src/linalg/impl_linalg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ use crate::dimension::offset_from_low_addr_ptr_to_logical_ptr;
use crate::{LinalgScalar, Zip};

use std::any::TypeId;
use std::mem::MaybeUninit;
use alloc::vec::Vec;

#[cfg(feature = "blas")]
Expand Down Expand Up @@ -699,6 +700,39 @@ unsafe fn general_mat_vec_mul_impl<A, S1, S2>(
}
}


/// Kronecker product of 2D matrices.
///
/// The kronecker product of a LxN matrix A and a MxR matrix B is a (L*M)x(N*R)
/// matrix K formed by the block multiplication A_ij * B.
pub fn kron<A, S1, S2>(a: &ArrayBase<S1, Ix2>, b: &ArrayBase<S2, Ix2>) -> Array<A, Ix2>
where
S1: Data<Elem = A>,
S2: Data<Elem = A>,
A: LinalgScalar,
{
let dimar = a.shape()[0];
let dimac = a.shape()[1];
let dimbr = b.shape()[0];
let dimbc = b.shape()[1];
let mut out: Array2<MaybeUninit<A>> = Array2::uninit((
dimar
.checked_mul(dimbr)
.expect("Dimensions of kronecker product output array overflows usize."),
dimac
.checked_mul(dimbc)
.expect("Dimensions of kronecker product output array overflows usize."),
));
Zip::from(out.exact_chunks_mut((dimbr, dimbc)))
.and(a)
.for_each(|out, &a| {
Zip::from(out).and(b).for_each(|out, &b| {
*out = MaybeUninit::new(a * b);
})
});
unsafe { out.assume_init() }
}

#[inline(always)]
/// Return `true` if `A` and `B` are the same type
fn same_type<A: 'static, B: 'static>() -> bool {
Expand Down
1 change: 1 addition & 0 deletions src/linalg/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,5 +11,6 @@
pub use self::impl_linalg::general_mat_mul;
pub use self::impl_linalg::general_mat_vec_mul;
pub use self::impl_linalg::Dot;
pub use self::impl_linalg::kron;

mod impl_linalg;
63 changes: 63 additions & 0 deletions tests/oper.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
)]
#![cfg(feature = "std")]
use ndarray::linalg::general_mat_mul;
use ndarray::linalg::kron;
use ndarray::prelude::*;
use ndarray::{rcarr1, rcarr2};
use ndarray::{Data, LinalgScalar};
Expand Down Expand Up @@ -820,3 +821,65 @@ fn vec_mat_mul() {
}
}
}

#[test]
fn kron_square_f64() {
let a = arr2(&[[1.0, 0.0], [0.0, 1.0]]);
let b = arr2(&[[0.0, 1.0], [1.0, 0.0]]);

assert_eq!(
kron(&a, &b),
arr2(&[
[0.0, 1.0, 0.0, 0.0],
[1.0, 0.0, 0.0, 0.0],
[0.0, 0.0, 0.0, 1.0],
[0.0, 0.0, 1.0, 0.0]
]),
);

assert_eq!(
kron(&b, &a),
arr2(&[
[0.0, 0.0, 1.0, 0.0],
[0.0, 0.0, 0.0, 1.0],
[1.0, 0.0, 0.0, 0.0],
[0.0, 1.0, 0.0, 0.0]
]),
)
}

#[test]
fn kron_square_i64() {
let a = arr2(&[[1, 0], [0, 1]]);
let b = arr2(&[[0, 1], [1, 0]]);

assert_eq!(
kron(&a, &b),
arr2(&[[0, 1, 0, 0], [1, 0, 0, 0], [0, 0, 0, 1], [0, 0, 1, 0]]),
);

assert_eq!(
kron(&b, &a),
arr2(&[[0, 0, 1, 0], [0, 0, 0, 1], [1, 0, 0, 0], [0, 1, 0, 0]]),
)
}

#[test]
fn kron_i64() {
let a = arr2(&[[1, 0]]);
let b = arr2(&[[0, 1], [1, 0]]);
let r = arr2(&[[0, 1, 0, 0], [1, 0, 0, 0]]);
assert_eq!(kron(&a, &b), r);

let a = arr2(&[[1, 0], [0, 0], [0, 1]]);
let b = arr2(&[[0, 1], [1, 0]]);
let r = arr2(&[
[0, 1, 0, 0],
[1, 0, 0, 0],
[0, 0, 0, 0],
[0, 0, 0, 0],
[0, 0, 0, 1],
[0, 0, 1, 0],
]);
assert_eq!(kron(&a, &b), r);
}

0 comments on commit 1c685ef

Please sign in to comment.