diff --git a/src/tensor/mod.rs b/src/tensor/mod.rs
index 3870105f..46f4b6cc 100644
--- a/src/tensor/mod.rs
+++ b/src/tensor/mod.rs
@@ -289,5 +289,67 @@ impl Tensor {
     }
 }
 
+/// Asserts that two tensors have the same size and element-wise equal values.
+///
+/// Supported forms:
+/// - `assert_tensor_eq!(T, a, b, eps)`: converts both tensors to `Vec<T>` and
+///   requires `(x - y).abs() <= eps` for every pair of elements.
+/// - `assert_tensor_eq!(f64, a, b)` and `assert_tensor_eq!(f32, a, b)`: same
+///   check with a tolerance of `1e-5`.
+/// - `assert_tensor_eq!(T, a, b)` for any other `T`: uses `T::default()` as
+///   the tolerance.
+///
+/// # Panics
+///
+/// Panics if the sizes differ, or if a pair of elements differs by more than
+/// the tolerance; the message then includes the index of that element.
+#[macro_export]
+macro_rules! assert_tensor_eq {
+    ($generic:ty, $a:expr, $b:expr, $eps:expr) => {{
+        // Turns a flat element position into one index per dimension, with
+        // the last dimension varying fastest.
+        fn unravel_index(index: usize, shape: &[i64]) -> Vec<i64> {
+            let mut result = Vec::with_capacity(shape.len());
+            let mut index = index as i64;
+            for dim in shape.iter().rev() {
+                result.push(index % dim);
+                index /= dim;
+            }
+            result.reverse();
+            result
+        }
+
+        // `$crate::Tensor` keeps the macro usable when the caller has not
+        // imported `Tensor` itself.
+        let (a, b): (&$crate::Tensor, &$crate::Tensor) = (&$a, &$b);
+        let eps = $eps;
+        assert_eq!(a.size(), b.size(), "Tensor size mismatch");
+        let shape = a.size();
+
+        for (i, (&a, &b)) in
+            Vec::<$generic>::from(a).iter().zip(Vec::<$generic>::from(b).iter()).enumerate()
+        {
+            assert!(
+                // `<=` rather than `<`: with a zero tolerance (the default for
+                // integer element types) `<` could never hold.
+                (a - b).abs() <= eps,
+                "Tensor mismatch at index {:?}: {} != {}",
+                unravel_index(i, &shape),
+                a,
+                b
+            );
+        }
+    }};
+    (f64, $a:expr, $b:expr) => {{
+        $crate::assert_tensor_eq!(f64, $a, $b, 1e-5);
+    }};
+    (f32, $a:expr, $b:expr) => {{
+        $crate::assert_tensor_eq!(f32, $a, $b, 1e-5);
+    }};
+    ($generic:ty, $a:expr, $b:expr) => {{
+        $crate::assert_tensor_eq!($generic, $a, $b, <$generic>::default());
+    }};
+}
+
 #[used]
 static INIT_ARRAY: [unsafe extern "C" fn(); 1] = [dummy_cuda_dependency];
diff --git a/tests/tensor_tests.rs b/tests/tensor_tests.rs
index d5c30f04..bc8a1988 100644
--- a/tests/tensor_tests.rs
+++ b/tests/tensor_tests.rs
@@ -2,7 +2,7 @@ use anyhow::Result;
 use half::f16;
 use std::convert::{TryFrom, TryInto};
 use std::f32;
-use tch::{Device, Tensor};
+use tch::{assert_tensor_eq, Device, Tensor};
 
 #[test]
 #[cfg(feature = "cuda-tests")]
@@ -351,15 +351,18 @@ fn sparse() {
 fn einsum() {
     // Element-wise squaring of a vector.
     let t = Tensor::of_slice(&[1.0, 2.0, 3.0]);
-    let t = Tensor::einsum("i, i -> i", &[&t, &t]);
-    assert_eq!(Vec::<f64>::from(&t), [1.0, 4.0, 9.0]);
+    let e = Tensor::einsum("i, i -> i", &[&t, &t]);
+    assert_eq!(Vec::<f64>::from(&e), [1.0, 4.0, 9.0]);
+    assert_tensor_eq!(f64, &t.multiply(&t), e);
     // Matrix transpose
     let t = Tensor::of_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).reshape(&[2, 3]);
-    let t = Tensor::einsum("ij -> ji", &[t]);
-    assert_eq!(Vec::<f64>::from(&t), [1.0, 4.0, 2.0, 5.0, 3.0, 6.0]);
+    let e = Tensor::einsum("ij -> ji", &[&t]);
+    assert_eq!(Vec::<f64>::from(&e), [1.0, 4.0, 2.0, 5.0, 3.0, 6.0]);
+    assert_tensor_eq!(f64, &t.transpose(0, 1), e);
     // Sum all elements
-    let t = Tensor::einsum("ij -> ", &[t]);
-    assert_eq!(Vec::<f64>::from(&t), [21.0]);
+    let e = Tensor::einsum("ij -> ", &[&t]);
+    assert_eq!(Vec::<f64>::from(&e), [21.0]);
+    assert_tensor_eq!(f64, &t.sum(tch::Kind::Float), e);
 }
 
 #[test]
@@ -458,3 +461,111 @@ fn set_data() {
     t.set_data(&t.to_kind(tch::Kind::BFloat16));
     assert_eq!(t.kind(), tch::Kind::BFloat16);
 }
+
+#[test]
+fn einsum_matrix_multiply() -> Result<()> {
+    let v = Tensor::rand(&[1, 4], (tch::Kind::Float, Device::Cpu));
+    let a = Tensor::rand(&[3, 4], (tch::Kind::Float, Device::Cpu));
+    let b = Tensor::rand(&[4, 3], (tch::Kind::Float, Device::Cpu));
+    let ab = a.mm(&b);
+    let av = a.mm(&v.transpose(0, 1));
+
+    let einsum_ab = Tensor::einsum("ij,jk->ik", &[&a, &b]);
+    let einsum_av = Tensor::einsum("ij,kj->ik", &[&a, &v]);
+    let einsum_hadamard = Tensor::einsum("ij,ij->ij", &[&a, &a]);
+
+    assert_tensor_eq!(f64, &ab, &einsum_ab);
+    assert_tensor_eq!(f64, &av, &einsum_av);
+    assert_tensor_eq!(f64, &a.multiply(&a), &einsum_hadamard);
+
+    Ok(())
+}
+
+#[test]
+fn einsum_batch_matrix_multiply() -> Result<()> {
+    let a = Tensor::rand(&[3, 2, 5], (tch::Kind::Float, Device::Cpu));
+    let b = Tensor::rand(&[3, 5, 2], (tch::Kind::Float, Device::Cpu));
+    let ab = a.bmm(&b);
+
+    let einsum_ab = Tensor::einsum("ijk,ikl->ijl", &[&a, &b]);
+
+    assert_tensor_eq!(f64, &ab, &einsum_ab);
+
+    Ok(())
+}
+
+#[test]
+fn einsum_outer_product() -> Result<()> {
+    let a = Tensor::rand(&[3], (tch::Kind::Float, Device::Cpu));
+    let b = Tensor::rand(&[3], (tch::Kind::Float, Device::Cpu));
+    let ab = a.outer(&b);
+    let einsum_ab = Tensor::einsum("i,j->ij", &[&a, &b]);
+
+    assert_tensor_eq!(f64, &ab, &einsum_ab);
+
+    Ok(())
+}
+
+#[test]
+fn einsum_diagonal() -> Result<()> {
+    let a = Tensor::rand(&[3, 3], (tch::Kind::Float, Device::Cpu));
+    let diag = a.diag(0);
+    let einsum_diag = Tensor::einsum("ii->i", &[&a]);
+
+    assert_tensor_eq!(f64, &diag, &einsum_diag);
+
+    Ok(())
+}
+
+#[test]
+fn einsum_trace() -> Result<()> {
+    let a = Tensor::rand(&[3, 3], (tch::Kind::Float, Device::Cpu));
+    let trace = a.trace();
+    let einsum_trace = Tensor::einsum("ii->", &[&a]);
+
+    assert_tensor_eq!(f64, &trace, &einsum_trace);
+
+    Ok(())
+}
+
+#[test]
+fn einsum_dot_product() -> Result<()> {
+    let a = Tensor::rand(&[3], (tch::Kind::Float, Device::Cpu));
+    let b = Tensor::rand(&[3], (tch::Kind::Float, Device::Cpu));
+    let m = Tensor::rand(&[3, 4], (tch::Kind::Float, Device::Cpu));
+
+    let einsum_m = Tensor::einsum("ij,ij->", &[&m, &m]);
+    let einsum_dot = Tensor::einsum("i,i->", &[&a, &b]);
+
+    assert_tensor_eq!(f64, &a.dot(&b), &einsum_dot);
+    assert_tensor_eq!(f64, &m.multiply(&m).sum(tch::Kind::Float), &einsum_m);
+
+    Ok(())
+}
+
+#[test]
+fn einsum_permute() -> Result<()> {
+    let a = Tensor::rand(&[5, 4, 3], (tch::Kind::Float, Device::Cpu));
+    let einsum_a = Tensor::einsum("ijk->kji", &[&a]);
+    assert_eq!(vec![3, 4, 5], einsum_a.size());
+
+    Ok(())
+}
+
+#[test]
+fn einsum_sum() -> Result<()> {
+    let a = Tensor::rand(&[3], (tch::Kind::Float, Device::Cpu));
+    let b = Tensor::rand(&[3, 3], (tch::Kind::Float, Device::Cpu));
+
+    let einsum_a = Tensor::einsum("i->", &[&a]);
+    let einsum_b = Tensor::einsum("ij->", &[&b]);
+    let column_einsum_b = Tensor::einsum("ij->j", &[&b]);
+    let row_einsum_b = Tensor::einsum("ij->i", &[&b]);
+
+    assert_tensor_eq!(f64, &a.sum(tch::Kind::Float), &einsum_a);
+    assert_tensor_eq!(f64, &b.sum(tch::Kind::Float), &einsum_b);
+    assert_tensor_eq!(f64, &b.sum_dim_intlist(&[0], false, tch::Kind::Float), &column_einsum_b);
+    assert_tensor_eq!(f64, &b.sum_dim_intlist(&[1], false, tch::Kind::Float), &row_einsum_b);
+
+    Ok(())
+}