// Patch summary (commentary placed before the first `diff --git` header).
//
// Purpose: refactor the micro-kernel unit tests in src/dgemm_kernel.rs and
// src/sgemm_kernel.rs so that one shared helper, `test_a_kernel`, can be run
// against several kernel implementations instead of only `kernel`.
//
// src/dgemm_kernel.rs
//   * Removes the free-standing `#[test] fn test_gemm_kernel`, which used the
//     literal sizes 32 / 16 / 4 and the strides (1, 8).
//   * Adds `#[cfg(test)] mod tests` containing:
//       - `aligned_alloc(elt, n)`: builds an `Alloc` of `n` elements via
//         `Alloc::new(n, Gemm::align_to()).init_with(elt)`.
//       - `KernelFn`: alias for the
//         `unsafe fn(usize, T, *const T, *const T, T, *mut T, isize, isize)`
//         signature.
//       - `test_a_kernel(_name, kernel_fn)`: fills `a` (MR * K elements) with
//         0, 1, 2, ..., sets `b[i + i * NR] = 1.` for i in 0..K, calls
//         `kernel_fn(K, 1., a, b, 0., c, 1, MR)` and asserts that the first
//         `a.len()` elements of `c` equal `a`.
//       - `test_native_kernel`, `test_kernel_fallback_impl`: run the helper on
//         `kernel` and `kernel_fallback_impl`.
//       - `test_loop_m_n`: checks that `loop_m!(i, loop_n!(j, ..))` visits
//         every cell of an MR x NR array exactly once.
//       - `mod test_arch_kernels`: the macro `test_arch_kernels_x86` emits one
//         `#[test]` per (feature, function) pair; each test runs the helper
//         only when `is_x86_feature_detected_!(feature)` is true and otherwise
//         prints a "Skipping" message. It is invoked for
//         ("avx", kernel_target_avx) and ("sse2", kernel_target_sse2) under
//         `cfg(any(target_arch = "x86", target_arch = "x86_64"))`.
//
// src/sgemm_kernel.rs
//   * Turns the existing `test_gemm_kernel` test into the same
//     `test_a_kernel(_name, kernel_fn)` helper (`i as f32` becomes `i as _`,
//     `kernel(..)` becomes `kernel_fn(..)`), and adds the same `KernelFn`
//     alias, the `test_native_kernel` / `test_kernel_fallback_impl` tests and
//     the same `test_arch_kernels` submodule.
//
// NOTE(review): the line structure of this patch appears to have been
//   collapsed (hunks run together on three physical lines), so it is unlikely
//   to apply as-is -- TODO confirm against the original commit.
// NOTE(review): generic parameter lists look stripped (e.g.
//   `fn aligned_alloc(elt: T, n: usize) -> Alloc where T: Copy` and
//   `unsafe fn at(ptr: *const T, ..)`); presumably angle-bracketed text was
//   lost in extraction -- verify against the repository.
// NOTE(review): `test_kernel_fallback_impl` passes the label "kernel" rather
//   than "kernel_fallback_impl"; the `_name` parameter is unused in the
//   helper, so this has no effect on the test.
// NOTE(review): `test_a_kernel` implicitly requires K <= NR: the largest index
//   written by `b[i + i * NR]` is (K - 1) * (NR + 1), and `c[..a.len()]` needs
//   MR * K <= MR * NR.
// NOTE(review): `test_arch_kernels_x86!` and `use super::test_a_kernel` are
//   defined unconditionally while the macro invocation is x86-only; this may
//   produce unused-macro / unused-import warnings on other targets -- confirm.
diff --git a/src/dgemm_kernel.rs b/src/dgemm_kernel.rs index 8bdba34..4464af2 100644 --- a/src/dgemm_kernel.rs +++ b/src/dgemm_kernel.rs @@ -132,22 +132,82 @@ unsafe fn at(ptr: *const T, i: usize) -> T { *ptr.offset(i as isize) } -#[test] -fn test_gemm_kernel() { - let mut a = [1.; 32]; - let mut b = [0.; 16]; - for (i, x) in a.iter_mut().enumerate() { - *x = i as f64; +#[cfg(test)] +mod tests { + use super::*; + use aligned_alloc::Alloc; + + fn aligned_alloc(elt: T, n: usize) -> Alloc where T: Copy + { + unsafe { + Alloc::new(n, Gemm::align_to()).init_with(elt) + } } - for i in 0..4 { - b[i + i * 4] = 1.; + + use super::T; + type KernelFn = unsafe fn(usize, T, *const T, *const T, T, *mut T, isize, isize); + + fn test_a_kernel(_name: &str, kernel_fn: KernelFn) { + const K: usize = 4; + let mut a = aligned_alloc(1., MR * K); + let mut b = aligned_alloc(0., NR * K); + for (i, x) in a.iter_mut().enumerate() { + *x = i as _; + } + + for i in 0..K { + b[i + i * NR] = 1.; + } + let mut c = [0.; MR * NR]; + unsafe { + kernel_fn(K, 1., &a[0], &b[0], 0., &mut c[0], 1, MR as isize); + // col major C + } + assert_eq!(&a[..], &c[..a.len()]); } - let mut c = [0.; 32]; - unsafe { - kernel(4, 1., &a[0], &b[0], - 0., &mut c[0], 1, 8); - // transposed C so that results line up + + #[test] + fn test_native_kernel() { + test_a_kernel("kernel", kernel); } - assert_eq!(&a, &c); -} + #[test] + fn test_kernel_fallback_impl() { + test_a_kernel("kernel", kernel_fallback_impl); + } + + #[test] + fn test_loop_m_n() { + let mut m = [[0; NR]; MR]; + loop_m!(i, loop_n!(j, m[i][j] += 1)); + for arr in &m[..] { + for elt in &arr[..] { + assert_eq!(*elt, 1); + } + } + } + + mod test_arch_kernels { + use super::test_a_kernel; + macro_rules! 
test_arch_kernels_x86 { + ($($feature_name:tt, $function_name:ident),*) => { + $( + #[test] + fn $function_name() { + if is_x86_feature_detected_!($feature_name) { + test_a_kernel(stringify!($function_name), super::super::$function_name); + } else { + println!("Skipping, host does not have feature: {:?}", $feature_name); + } + } + )* + } + } + + #[cfg(any(target_arch="x86", target_arch="x86_64"))] + test_arch_kernels_x86! { + "avx", kernel_target_avx, + "sse2", kernel_target_sse2 + } + } +} diff --git a/src/sgemm_kernel.rs b/src/sgemm_kernel.rs index 83e8fdf..ca85ae5 100644 --- a/src/sgemm_kernel.rs +++ b/src/sgemm_kernel.rs @@ -395,14 +395,15 @@ mod tests { } } + use super::T; + type KernelFn = unsafe fn(usize, T, *const T, *const T, T, *mut T, isize, isize); - #[test] - fn test_gemm_kernel() { + fn test_a_kernel(_name: &str, kernel_fn: KernelFn) { const K: usize = 4; let mut a = aligned_alloc(1., MR * K); let mut b = aligned_alloc(0., NR * K); for (i, x) in a.iter_mut().enumerate() { - *x = i as f32; + *x = i as _; } for i in 0..K { @@ -410,12 +411,22 @@ mod tests { } let mut c = [0.; MR * NR]; unsafe { - kernel(K, 1., &a[0], &b[0], 0., &mut c[0], 1, MR as isize); + kernel_fn(K, 1., &a[0], &b[0], 0., &mut c[0], 1, MR as isize); // col major C } assert_eq!(&a[..], &c[..a.len()]); } + #[test] + fn test_native_kernel() { + test_a_kernel("kernel", kernel); + } + + #[test] + fn test_kernel_fallback_impl() { + test_a_kernel("kernel", kernel_fallback_impl); + } + #[test] fn test_loop_m_n() { let mut m = [[0; NR]; MR]; @@ -426,4 +437,28 @@ mod tests { } } } + + mod test_arch_kernels { + use super::test_a_kernel; + macro_rules! 
test_arch_kernels_x86 { + ($($feature_name:tt, $function_name:ident),*) => { + $( + #[test] + fn $function_name() { + if is_x86_feature_detected_!($feature_name) { + test_a_kernel(stringify!($function_name), super::super::$function_name); + } else { + println!("Skipping, host does not have feature: {:?}", $feature_name); + } + } + )* + } + } + + #[cfg(any(target_arch="x86", target_arch="x86_64"))] + test_arch_kernels_x86! { + "avx", kernel_target_avx, + "sse2", kernel_target_sse2 + } + } }