Skip to content

Commit

Permalink
TEST: Make sure all kernels are tested
Browse files Browse the repository at this point in the history
These tests are rudimentary so far, but they cover all the possible
kernels (avx, sse2, fallback) we have so far.
  • Loading branch information
bluss committed Nov 14, 2018
1 parent fdc80e6 commit b5cc042
Show file tree
Hide file tree
Showing 2 changed files with 114 additions and 19 deletions.
90 changes: 75 additions & 15 deletions src/dgemm_kernel.rs
Original file line number Diff line number Diff line change
Expand Up @@ -132,22 +132,82 @@ unsafe fn at(ptr: *const T, i: usize) -> T {
*ptr.offset(i as isize)
}

#[test]
fn test_gemm_kernel() {
let mut a = [1.; 32];
let mut b = [0.; 16];
for (i, x) in a.iter_mut().enumerate() {
*x = i as f64;
#[cfg(test)]
mod tests {
use super::*;
use aligned_alloc::Alloc;

fn aligned_alloc<T>(elt: T, n: usize) -> Alloc<T> where T: Copy
{
unsafe {
Alloc::new(n, Gemm::align_to()).init_with(elt)
}
}
for i in 0..4 {
b[i + i * 4] = 1.;

use super::T;
type KernelFn = unsafe fn(usize, T, *const T, *const T, T, *mut T, isize, isize);

fn test_a_kernel(_name: &str, kernel_fn: KernelFn) {
const K: usize = 4;
let mut a = aligned_alloc(1., MR * K);
let mut b = aligned_alloc(0., NR * K);
for (i, x) in a.iter_mut().enumerate() {
*x = i as _;
}

for i in 0..K {
b[i + i * NR] = 1.;
}
let mut c = [0.; MR * NR];
unsafe {
kernel_fn(K, 1., &a[0], &b[0], 0., &mut c[0], 1, MR as isize);
// col major C
}
assert_eq!(&a[..], &c[..a.len()]);
}
let mut c = [0.; 32];
unsafe {
kernel(4, 1., &a[0], &b[0],
0., &mut c[0], 1, 8);
// transposed C so that results line up

#[test]
fn test_native_kernel() {
test_a_kernel("kernel", kernel);
}
assert_eq!(&a, &c);
}

#[test]
fn test_kernel_fallback_impl() {
test_a_kernel("kernel", kernel_fallback_impl);
}

#[test]
fn test_loop_m_n() {
let mut m = [[0; NR]; MR];
loop_m!(i, loop_n!(j, m[i][j] += 1));
for arr in &m[..] {
for elt in &arr[..] {
assert_eq!(*elt, 1);
}
}
}

mod test_arch_kernels {
use super::test_a_kernel;
macro_rules! test_arch_kernels_x86 {
($($feature_name:tt, $function_name:ident),*) => {
$(
#[test]
fn $function_name() {
if is_x86_feature_detected_!($feature_name) {
test_a_kernel(stringify!($function_name), super::super::$function_name);
} else {
println!("Skipping, host does not have feature: {:?}", $feature_name);
}
}
)*
}
}

#[cfg(any(target_arch="x86", target_arch="x86_64"))]
test_arch_kernels_x86! {
"avx", kernel_target_avx,
"sse2", kernel_target_sse2
}
}
}
43 changes: 39 additions & 4 deletions src/sgemm_kernel.rs
Original file line number Diff line number Diff line change
Expand Up @@ -395,27 +395,38 @@ mod tests {
}
}

use super::T;
type KernelFn = unsafe fn(usize, T, *const T, *const T, T, *mut T, isize, isize);

#[test]
fn test_gemm_kernel() {
fn test_a_kernel(_name: &str, kernel_fn: KernelFn) {
const K: usize = 4;
let mut a = aligned_alloc(1., MR * K);
let mut b = aligned_alloc(0., NR * K);
for (i, x) in a.iter_mut().enumerate() {
*x = i as f32;
*x = i as _;
}

for i in 0..K {
b[i + i * NR] = 1.;
}
let mut c = [0.; MR * NR];
unsafe {
kernel(K, 1., &a[0], &b[0], 0., &mut c[0], 1, MR as isize);
kernel_fn(K, 1., &a[0], &b[0], 0., &mut c[0], 1, MR as isize);
// col major C
}
assert_eq!(&a[..], &c[..a.len()]);
}

#[test]
fn test_native_kernel() {
test_a_kernel("kernel", kernel);
}

#[test]
fn test_kernel_fallback_impl() {
test_a_kernel("kernel", kernel_fallback_impl);
}

#[test]
fn test_loop_m_n() {
let mut m = [[0; NR]; MR];
Expand All @@ -426,4 +437,28 @@ mod tests {
}
}
}

mod test_arch_kernels {
use super::test_a_kernel;
macro_rules! test_arch_kernels_x86 {
($($feature_name:tt, $function_name:ident),*) => {
$(
#[test]
fn $function_name() {
if is_x86_feature_detected_!($feature_name) {
test_a_kernel(stringify!($function_name), super::super::$function_name);
} else {
println!("Skipping, host does not have feature: {:?}", $feature_name);
}
}
)*
}
}

#[cfg(any(target_arch="x86", target_arch="x86_64"))]
test_arch_kernels_x86! {
"avx", kernel_target_avx,
"sse2", kernel_target_sse2
}
}
}

0 comments on commit b5cc042

Please sign in to comment.