From 5e225394c3ff28d886f727cc882419707de0fb71 Mon Sep 17 00:00:00 2001 From: Richard Janis Goldschmidt Date: Thu, 6 Dec 2018 21:09:39 +0100 Subject: [PATCH 1/6] WIP --- src/igemm_kernel.rs | 276 +++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 273 insertions(+), 3 deletions(-) diff --git a/src/igemm_kernel.rs b/src/igemm_kernel.rs index c91bfec..28e7505 100644 --- a/src/igemm_kernel.rs +++ b/src/igemm_kernel.rs @@ -7,7 +7,6 @@ // except according to those terms. use kernel::GemmKernel; -use kernel::Element; use archparam; @@ -21,7 +20,7 @@ pub enum Gemm { } pub type T = i32; const MR: usize = 8; -const NR: usize = 4; +const NR: usize = 8; macro_rules! loop_m { ($i:ident, $e:expr) => { loop8!($i, $e) }; } macro_rules! loop_n { ($j:ident, $e:expr) => { loop4!($j, $e) }; } @@ -78,7 +77,9 @@ pub unsafe fn kernel(k: usize, alpha: T, a: *const T, b: *const T, // dispatch to specific compiled versions #[cfg(any(target_arch="x86", target_arch="x86_64"))] { - if is_x86_feature_detected_!("avx") { + if is_x86_feature_detected_!("avx2") { + return kernel_target_avx2(k, alpha, a, b, beta, c, rsc, csc); + } else if is_x86_feature_detected_!("avx") { return kernel_target_avx(k, alpha, a, b, beta, c, rsc, csc); } else if is_x86_feature_detected_!("sse2") { return kernel_target_sse2(k, alpha, a, b, beta, c, rsc, csc); @@ -87,6 +88,15 @@ pub unsafe fn kernel(k: usize, alpha: T, a: *const T, b: *const T, kernel_fallback_impl(k, alpha, a, b, beta, c, rsc, csc); } +#[inline] +#[target_feature(enable="avx2")] +#[cfg(any(target_arch="x86", target_arch="x86_64"))] +unsafe fn kernel_target_avx2(k: usize, alpha: T, a: *const T, b: *const T, + beta: T, c: *mut T, rsc: isize, csc: isize) +{ + kernel_x86_avx2(k, alpha, a, b, beta, c, rsc, csc) +} + #[inline] #[target_feature(enable="avx")] #[cfg(any(target_arch="x86", target_arch="x86_64"))] @@ -133,6 +143,265 @@ unsafe fn kernel_fallback_impl(k: usize, alpha: T, a: *const T, b: *const T, loop_n!(j, loop_m!(i, *c![i, j] = alpha.wrapping_mul(ab[i][j]))); } +#[inline(always)] +#[cfg(any(target_arch="x86", target_arch="x86_64"))] +unsafe fn kernel_x86_avx2(k: usize, alpha: T, a: *const T, b: *const T, + beta: T, c: *mut T, rsc: isize, csc: isize) +{ + debug_assert_ne!(k, 0); + + let mut ab = [_mm256_setzero_si256(); MR]; + + let (mut a, mut b) = (a, b); + + println!("a_slice: {:?}", std::slice::from_raw_parts(a, 8)); + let mut a_0123_4567 = _mm256_loadu_si256(a as *const __m256i); + println!("a_simd: {:?}", a_0123_4567); + + println!("b_slice: {:?}", std::slice::from_raw_parts(b, 8)); + let mut b_0123_4567 = _mm256_loadu_si256(b as *const __m256i); + println!("b_simd: {:?}", b_0123_4567); + + // The task in the loop below is to multiply every number packed in b_0123_4567 with + // every number in the two a_* vectors. With a_* loading a column, and b_* loading a row, + // this is exactly equivalent to matrix multiplication defined by C_ij = ∑_k A_ik B_kj + C_ij, but + // where you fix k, perform all multiplications in i and j for a fixed k, add the result to + // C, and then increment k to and repeat. + // + // Reapeted indices are shortened, e.g. b_02_46 = b_0022_4466 + + unroll_by_with_last!(4 => k, is_last, { + // two bits select the value of one i32, so 4 bits select two adjacent i32. 4 bits can be + // represented in hexadecimal notation. The first (from back to font) 4 pairs of bits + // select the first 128bit lane, the last 4 pairs the second lane. 
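+ //
+ // For example, with b_0123_4567 = [b0, b1, b2, b3 | b4, b5, b6, b7], the selector below
+ // yields b_02_46 = [b0, b0, b2, b2 | b4, b4, b6, b6] (shortened notation as described
+ // above). The _mm256_permute2x128_si256 calls with control 0x03 then swap the two
+ // 128-bit halves of a vector with itself, e.g. b_02_46 becomes
+ // b_46_02 = [b4, b4, b6, b6 | b0, b0, b2, b2].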
+ let b_02_46 = _mm256_shuffle_epi32( + b_0123_4567, + 0b_10_10_00_00__10_10_00_00 + ); + + let b_20_64 = _mm256_shuffle_epi32( + b_0123_4567, + 0b_00_00_10_10__00_00_10_10 + ); + + let b_46_02 = _mm256_permute2x128_si256( + b_02_46, + b_02_46, + 0x03 + ); + + let b_64_20 = _mm256_permute2x128_si256( + b_20_64, + b_20_64, + 0x03 + ); + + let b_13_57 = _mm256_shuffle_epi32( + b_0123_4567, + 0b_11_11_01_01__11_11_01_01 + ); + + let b_31_75 = _mm256_shuffle_epi32( + b_0123_4567, + 0b_01_01_11_11__01_01_11_11 + ); + + let b_57_13 = _mm256_permute2x128_si256( + b_13_57, + b_13_57, + 0x03 + ); + + let b_75_31 = _mm256_permute2x128_si256( + b_31_75, + b_31_75, + 0x03 + ); + + // Add and multiply in one go + ab[0] = _mm256_add_epi32(ab[0], _mm256_mul_epi32(a_0123_4567, b_02_46)); + ab[1] = _mm256_add_epi32(ab[1], _mm256_mul_epi32(a_0123_4567, b_20_64)); + ab[2] = _mm256_add_epi32(ab[2], _mm256_mul_epi32(a_0123_4567, b_46_02)); + ab[3] = _mm256_add_epi32(ab[3], _mm256_mul_epi32(a_0123_4567, b_64_20)); + + ab[4] = _mm256_add_epi32(ab[4], _mm256_mul_epi32(a_0123_4567, b_13_57)); + ab[5] = _mm256_add_epi32(ab[5], _mm256_mul_epi32(a_0123_4567, b_31_75)); + ab[6] = _mm256_add_epi32(ab[6], _mm256_mul_epi32(a_0123_4567, b_57_13)); + ab[7] = _mm256_add_epi32(ab[7], _mm256_mul_epi32(a_0123_4567, b_75_31)); + + if !is_last { + a = a.add(MR); + b = b.add(NR); + + a_0123_4567 = _mm256_loadu_si256(a as _); + b_0123_4567 = _mm256_loadu_si256(b as _); + } + }); + + let a0b0_a1b0_a2b0_a3b0_a4b4_a5b4_a6b4_a7b4 = _mm256_blend_epi32( + ab[0], + ab[1], + 0b_1100_1100 + ); + + let a0b2_a1b2_a2b2_a3b2_a4b6_a5b6_a6b6_a7b6 = _mm256_blend_epi32( + ab[0], + ab[1], + 0b_0011_0011 + ); + + let a0b4_a1b4_a2b4_a3b4_a4b0_a5b0_a6b0_a7b0 = _mm256_blend_epi32( + ab[2], + ab[3], + 0b_1100_1100 + ); + + let a0b6_a1b6_a2b6_a3b6_a4b2_a5b2_a6b2_a7b2 = _mm256_blend_epi32( + ab[2], + ab[3], + 0b_0011_0011 + ); + + // a0b0_a1b0_a2b0_a3b0_b4b0_b5b0_b6b0_b7b0 + ab[0] = _mm256_permute2f128_si256( + a0b0_a1b0_a2b0_a3b0_a4b4_a5b4_a6b4_a7b4, + a0b4_a1b4_a2b4_a3b4_a4b0_a5b0_a6b0_a7b0, + 0x30 + ); + // a0b4_a1b4_a2b4_a3b4_b4b4_b5b4_b6b4_b7b4 + ab[4] = _mm256_permute2f128_si256( + a0b0_a1b0_a2b0_a3b0_a4b4_a5b4_a6b4_a7b4, + a0b4_a1b4_a2b4_a3b4_a4b0_a5b0_a6b0_a7b0, + 0x12 + ); + // a0b2_a1b2_a2b2_a3b2_b4b2_b5b2_b6b2_b7b2 + ab[2] = _mm256_permute2f128_si256( + a0b2_a1b2_a2b2_a3b2_a4b6_a5b6_a6b6_a7b6, + a0b6_a1b6_a2b6_a3b6_a4b2_a5b2_a6b2_a7b2, + 0x30 + ); + // a0b6_a1b6_a2b6_a3b6_b4b6_b5b6_b6b6_b7b6 + ab[6] = _mm256_permute2f128_si256( + a0b2_a1b2_a2b2_a3b2_a4b6_a5b6_a6b6_a7b6, + a0b6_a1b6_a2b6_a3b6_a4b2_a5b2_a6b2_a7b2, + 0x12 + ); + + let a0b1_a1b1_a2b1_a3b1_a4b5_a5b5_a6b5_a7b5 = _mm256_blend_epi32( + ab[4], + ab[5], + 0b_1100_1100 + ); + + let a0b3_a1b3_a2b3_a3b3_a4b7_a5b7_a6b7_a7b7 = _mm256_blend_epi32( + ab[4], + ab[5], + 0b_0011_0011 + ); + + let a0b5_a1b5_a2b5_a3b5_a4b1_a5b1_a6b1_a7b1 = _mm256_blend_epi32( + ab[6], + ab[7], + 0b_1100_1100 + ); + + let a0b7_a1b7_a2b7_a3b7_a4b3_a5b3_a6b3_a7b3 = _mm256_blend_epi32( + ab[6], + ab[7], + 0b_0011_0011 + ); + + // a0b1_a1b1_a2b1_a3b1_b4b1_b5b1_b6b1_b7b1 + ab[1] = _mm256_permute2f128_si256( + a0b1_a1b1_a2b1_a3b1_a4b5_a5b5_a6b5_a7b5, + a0b5_a1b5_a2b5_a3b5_a4b1_a5b1_a6b1_a7b1, + 0x30 + ); + // a0b5_a1b5_a2b5_a3b5_b4b5_b5b5_b6b5_b7b5 + ab[5] = _mm256_permute2f128_si256( + a0b1_a1b1_a2b1_a3b1_a4b5_a5b5_a6b5_a7b5, + a0b5_a1b5_a2b5_a3b5_a4b1_a5b1_a6b1_a7b1, + 0x12 + ); + // a0b3_a1b3_a2b3_a3b3_b4b3_b5b3_b6b3_b7b3 + ab[3] = _mm256_permute2f128_si256( + a0b3_a1b3_a2b3_a3b3_a4b7_a5b7_a6b7_a7b7, + 
a0b7_a1b7_a2b7_a3b7_a4b3_a5b3_a6b3_a7b3, + 0x30 + ); + // a0b7_a1b7_a2b7_a3b7_b4b7_b5b7_b6b7_b7b7 + ab[7] = _mm256_permute2f128_si256( + a0b3_a1b3_a2b3_a3b3_a4b7_a5b7_a6b7_a7b7, + a0b7_a1b7_a2b7_a3b7_a4b3_a5b3_a6b3_a7b3, + 0x12 + ); + + // Compute α (A B) + let alpha_v = _mm256_set1_epi32(alpha); + loop_m!(i, ab[i] = _mm256_mul_epi32(alpha_v, ab[i])); + + macro_rules! c { + ($i:expr, $j:expr) => + (c.offset(rsc * $i as isize + csc * $j as isize)); + } + + // C ← α A B + β C + let mut cv = [_mm256_setzero_si256(); MR]; + + if beta != 0 { + let beta_v = _mm256_set1_epi32(beta); + + // Read C + if rsc == 1 { + loop_m!(i, cv[i] = _mm256_loadu_si256(c![0, i] as _)); + // } else if csc == 1 { + // loop4!(i, cv[i] = _mm256_loadu_pd(c![i, 0])); + // loop4!(i, cv[i+4] = _mm256_loadu_pd(c![i+4, 0])); + } else { + loop_m!(i, cv[i] = _mm256_setr_epi32( + *c![0, i], + *c![1, i], + *c![2, i], + *c![3, i], + *c![4, i], + *c![5, i], + *c![6, i], + *c![7, i], + )); + } + // Compute β C + loop_m!(i, cv[i] = _mm256_mul_epi32(cv[i], beta_v)); + } + + // Compute (α A B) + (β C) + loop_m!(i, cv[i] = _mm256_add_epi32(cv[i], ab[i])); + + if rsc == 1 { + loop_m!(i, _mm256_storeu_si256(c![0, i] as _, cv[i])); + // } else if csc == 1 { + // loop4!(i, _mm256_storeu_pd(c![i, 0], cv[i])); + // loop4!(i, _mm256_storeu_pd(c![i+4, 0], cv[i + 4])); + } else { + // TODO: This inner unrolled loop should be replaced by + // `loop_n!(j, *c![i, j] = _mm256_extract_epi32(cv[i], j);` + // However, rustc currently errors with: + // > error: argument 2 is required to be a constant + // Some reading: + // + https://internals.rust-lang.org/t/pre-rfc-const-function-arguments/6709/12 + // + https://www.reddit.com/r/rust/comments/9pxuoj/simd_instructions_requiring_a_constant_parameter/ + loop_m!(i, { + *c![i, 0] = _mm256_extract_epi32(cv[i], 0); + *c![i, 1] = _mm256_extract_epi32(cv[i], 1); + *c![i, 2] = _mm256_extract_epi32(cv[i], 2); + *c![i, 3] = _mm256_extract_epi32(cv[i], 3); + *c![i, 4] = _mm256_extract_epi32(cv[i], 4); + *c![i, 5] = _mm256_extract_epi32(cv[i], 5); + *c![i, 6] = _mm256_extract_epi32(cv[i], 6); + *c![i, 7] = _mm256_extract_epi32(cv[i], 7); + }) + } +} + #[inline(always)] unsafe fn at(ptr: *const T, i: usize) -> T { *ptr.offset(i as isize) @@ -212,6 +481,7 @@ mod tests { #[cfg(any(target_arch="x86", target_arch="x86_64"))] test_arch_kernels_x86! { + "avx2", kernel_target_avx2, "avx", kernel_target_avx, "sse2", kernel_target_sse2 } From 0c19f4203999bef8c3d45843d1b9778906ab361c Mon Sep 17 00:00:00 2001 From: Richard Janis Goldschmidt Date: Sat, 8 Dec 2018 20:26:18 +0100 Subject: [PATCH 2/6] Generalize gemm to allow different input and output arrays --- src/dgemm_kernel.rs | 3 +- src/gemm.rs | 84 ++++++++++++++-------------- src/igemm_kernel.rs | 24 ++++---- src/kernel.rs | 132 +++++++++++++++++++++++++++++++------------- src/lib.rs | 4 +- src/sgemm_kernel.rs | 3 +- tests/sgemm.rs | 44 +++++++-------- 7 files changed, 174 insertions(+), 120 deletions(-) diff --git a/src/dgemm_kernel.rs b/src/dgemm_kernel.rs index 4f1e6a7..8c679d8 100644 --- a/src/dgemm_kernel.rs +++ b/src/dgemm_kernel.rs @@ -24,7 +24,8 @@ macro_rules! 
loop_n { } impl GemmKernel for Gemm { - type Elem = T; + type ElemIn = T; + type ElemOut = T; #[inline(always)] fn align_to() -> usize { 0 } diff --git a/src/gemm.rs b/src/gemm.rs index 4dc11bd..167ceb5 100644 --- a/src/gemm.rs +++ b/src/gemm.rs @@ -19,7 +19,7 @@ use kernel::GemmKernel; use kernel::Element; use sgemm_kernel; use dgemm_kernel; -use igemm_kernel; +// use igemm_kernel; use rawpointer::PointerExt; /// General matrix multiplication (f32) @@ -88,22 +88,22 @@ pub unsafe fn dgemm( c, rsc, csc) } -pub unsafe fn igemm( - m: usize, k: usize, n: usize, - alpha: i32, - a: *const i32, rsa: isize, csa: isize, - b: *const i32, rsb: isize, csb: isize, - beta: i32, - c: *mut i32, rsc: isize, csc: isize) -{ - gemm_loop::( - m, k, n, - alpha, - a, rsa, csa, - b, rsb, csb, - beta, - c, rsc, csc) -} +// pub unsafe fn igemm( +// m: usize, k: usize, n: usize, +// alpha: i32, +// a: *const i32, rsa: isize, csa: isize, +// b: *const i32, rsb: isize, csb: isize, +// beta: i32, +// c: *mut i32, rsc: isize, csc: isize) +// { +// gemm_loop::( +// m, k, n, +// alpha, +// a, rsa, csa, +// b, rsb, csb, +// beta, +// c, rsc, csc) +// } /// Ensure that GemmKernel parameters are supported /// (alignment, microkernel size). @@ -117,10 +117,10 @@ fn ensure_kernel_params() let nr = K::nr(); assert!(mr > 0 && mr <= 8); assert!(nr > 0 && nr <= 8); - assert!(mr * nr * size_of::() <= 8 * 4 * 8); + assert!(mr * nr * size_of::() <= 8 * 4 * 8); assert!(K::align_to() <= 32); // one row/col of the kernel is limiting the max align we can provide - let max_align = size_of::() * min(mr, nr); + let max_align = size_of::() * min(mr, nr); assert!(K::align_to() <= max_align); } @@ -128,11 +128,11 @@ fn ensure_kernel_params() /// strategy, the type parameter `K` is the gemm microkernel. unsafe fn gemm_loop( m: usize, k: usize, n: usize, - alpha: K::Elem, - a: *const K::Elem, rsa: isize, csa: isize, - b: *const K::Elem, rsb: isize, csb: isize, - beta: K::Elem, - c: *mut K::Elem, rsc: isize, csc: isize) + alpha: K::ElemOut, + a: *const K::ElemIn, rsa: isize, csa: isize, + b: *const K::ElemIn, rsb: isize, csb: isize, + beta: K::ElemOut, + c: *mut K::ElemOut, rsc: isize, csc: isize) where K: GemmKernel { debug_assert!(m <= 1 || n == 0 || rsc != 0); @@ -198,18 +198,18 @@ unsafe fn gemm_loop( /// + kc: columns of packed A / rows of packed B /// + mc: rows of packed A unsafe fn gemm_packed(nc: usize, kc: usize, mc: usize, - alpha: K::Elem, - app: *const K::Elem, bpp: *const K::Elem, - beta: K::Elem, - c: *mut K::Elem, rsc: isize, csc: isize) + alpha: K::ElemOut, + app: *const K::ElemIn, bpp: *const K::ElemIn, + beta: K::ElemOut, + c: *mut K::ElemOut, rsc: isize, csc: isize) where K: GemmKernel, { let mr = K::mr(); let nr = K::nr(); // make a mask buffer that fits 8 x 8 f32 and 8 x 4 f64 kernels and alignment - assert!(mr * nr * size_of::() <= 256 && K::align_to() <= 32); + assert!(mr * nr * size_of::() <= 256 && K::align_to() <= 32); let mut mask_buf = [0u8; 256 + 31]; - let mask_ptr = align_ptr(32, mask_buf.as_mut_ptr()) as *mut K::Elem; + let mask_ptr = align_ptr(32, mask_buf.as_mut_ptr()) as *mut K::ElemOut; // LOOP 2: through micropanels in packed `b` for (l2, nr_) in range_chunk(nc, nr) { @@ -225,7 +225,7 @@ unsafe fn gemm_packed(nc: usize, kc: usize, mc: usize, // NOTE: For the rust kernels, it performs better to simply // always use the masked kernel function! 
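 // (For reference: masked_kernel first runs the full MR x NR micro-kernel into the
 // aligned mask_buf scratch buffer with alpha = one and beta = zero, and then copies
 // only the mr_ x nr_ entries that fall inside C back out, applying alpha and beta
 // during that write-back.)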
if K::always_masked() || nr_ < nr || mr_ < mr { - masked_kernel::<_, K>(kc, alpha, &*app, &*bpp, + masked_kernel::<_, _, K>(kc, alpha, &*app, &*bpp, beta, &mut *c, rsc, csc, mr_, nr_, mask_ptr); continue; @@ -244,7 +244,7 @@ unsafe fn gemm_packed(nc: usize, kc: usize, mc: usize, /// we have rounded up to a multiple of the kernel size). /// /// Return packing buffer and offset to start of b -unsafe fn make_packing_buffer(m: usize, k: usize, n: usize) -> (Alloc, usize) +unsafe fn make_packing_buffer(m: usize, k: usize, n: usize) -> (Alloc, usize) where K: GemmKernel, { // max alignment requirement is a multiple of min(MR, NR) * sizeof @@ -349,19 +349,21 @@ unsafe fn pack(kc: usize, mc: usize, mr: usize, pack: *mut T, /// + rows: rows of kernel unmasked /// + cols: cols of kernel unmasked #[inline(never)] -unsafe fn masked_kernel(k: usize, alpha: T, - a: *const T, - b: *const T, - beta: T, - c: *mut T, rsc: isize, csc: isize, +unsafe fn masked_kernel(k: usize, alpha: Tout, + a: *const Tin, + b: *const Tin, + beta: Tout, + c: *mut Tout, rsc: isize, csc: isize, rows: usize, cols: usize, - mask_buf: *mut T) - where K: GemmKernel, T: Element, + mask_buf: *mut Tout) + where K: GemmKernel, + Tin: Element, + Tout: Element, { let mr = K::mr(); let nr = K::nr(); // use column major order for `mask_buf` - K::kernel(k, T::one(), a, b, T::zero(), mask_buf, 1, mr as isize); + K::kernel(k, Tout::one(), a, b, Tout::zero(), mask_buf, 1, mr as isize); let mut ab = mask_buf; for j in 0..nr { for i in 0..mr { @@ -369,7 +371,7 @@ unsafe fn masked_kernel(k: usize, alpha: T, let cptr = c.stride_offset(rsc, i) .stride_offset(csc, j); if beta.is_zero() { - *cptr = T::zero(); // initialize C + *cptr = Tout::zero(); // initialize C } else { (*cptr).scale_by(beta); } diff --git a/src/igemm_kernel.rs b/src/igemm_kernel.rs index 28e7505..26a7bf6 100644 --- a/src/igemm_kernel.rs +++ b/src/igemm_kernel.rs @@ -17,7 +17,8 @@ use std::arch::x86_64::*; pub enum Gemm { } -pub type T = i32; +pub type Tin = i8; +pub type Tout = i16 const MR: usize = 8; const NR: usize = 8; @@ -154,13 +155,8 @@ unsafe fn kernel_x86_avx2(k: usize, alpha: T, a: *const T, b: *const T, let (mut a, mut b) = (a, b); - println!("a_slice: {:?}", std::slice::from_raw_parts(a, 8)); let mut a_0123_4567 = _mm256_loadu_si256(a as *const __m256i); - println!("a_simd: {:?}", a_0123_4567); - - println!("b_slice: {:?}", std::slice::from_raw_parts(b, 8)); let mut b_0123_4567 = _mm256_loadu_si256(b as *const __m256i); - println!("b_simd: {:?}", b_0123_4567); // The task in the loop below is to multiply every number packed in b_0123_4567 with // every number in the two a_* vectors. 
With a_* loading a column, and b_* loading a row, @@ -263,25 +259,25 @@ unsafe fn kernel_x86_avx2(k: usize, alpha: T, a: *const T, b: *const T, ); // a0b0_a1b0_a2b0_a3b0_b4b0_b5b0_b6b0_b7b0 - ab[0] = _mm256_permute2f128_si256( + ab[0] = _mm256_permute2x128_si256( a0b0_a1b0_a2b0_a3b0_a4b4_a5b4_a6b4_a7b4, a0b4_a1b4_a2b4_a3b4_a4b0_a5b0_a6b0_a7b0, 0x30 ); // a0b4_a1b4_a2b4_a3b4_b4b4_b5b4_b6b4_b7b4 - ab[4] = _mm256_permute2f128_si256( + ab[4] = _mm256_permute2x128_si256( a0b0_a1b0_a2b0_a3b0_a4b4_a5b4_a6b4_a7b4, a0b4_a1b4_a2b4_a3b4_a4b0_a5b0_a6b0_a7b0, 0x12 ); // a0b2_a1b2_a2b2_a3b2_b4b2_b5b2_b6b2_b7b2 - ab[2] = _mm256_permute2f128_si256( + ab[2] = _mm256_permute2x128_si256( a0b2_a1b2_a2b2_a3b2_a4b6_a5b6_a6b6_a7b6, a0b6_a1b6_a2b6_a3b6_a4b2_a5b2_a6b2_a7b2, 0x30 ); // a0b6_a1b6_a2b6_a3b6_b4b6_b5b6_b6b6_b7b6 - ab[6] = _mm256_permute2f128_si256( + ab[6] = _mm256_permute2x128_si256( a0b2_a1b2_a2b2_a3b2_a4b6_a5b6_a6b6_a7b6, a0b6_a1b6_a2b6_a3b6_a4b2_a5b2_a6b2_a7b2, 0x12 @@ -312,25 +308,25 @@ unsafe fn kernel_x86_avx2(k: usize, alpha: T, a: *const T, b: *const T, ); // a0b1_a1b1_a2b1_a3b1_b4b1_b5b1_b6b1_b7b1 - ab[1] = _mm256_permute2f128_si256( + ab[1] = _mm256_permute2x128_si256( a0b1_a1b1_a2b1_a3b1_a4b5_a5b5_a6b5_a7b5, a0b5_a1b5_a2b5_a3b5_a4b1_a5b1_a6b1_a7b1, 0x30 ); // a0b5_a1b5_a2b5_a3b5_b4b5_b5b5_b6b5_b7b5 - ab[5] = _mm256_permute2f128_si256( + ab[5] = _mm256_permute2x128_si256( a0b1_a1b1_a2b1_a3b1_a4b5_a5b5_a6b5_a7b5, a0b5_a1b5_a2b5_a3b5_a4b1_a5b1_a6b1_a7b1, 0x12 ); // a0b3_a1b3_a2b3_a3b3_b4b3_b5b3_b6b3_b7b3 - ab[3] = _mm256_permute2f128_si256( + ab[3] = _mm256_permute2x128_si256( a0b3_a1b3_a2b3_a3b3_a4b7_a5b7_a6b7_a7b7, a0b7_a1b7_a2b7_a3b7_a4b3_a5b3_a6b3_a7b3, 0x30 ); // a0b7_a1b7_a2b7_a3b7_b4b7_b5b7_b6b7_b7b7 - ab[7] = _mm256_permute2f128_si256( + ab[7] = _mm256_permute2x128_si256( a0b3_a1b3_a2b3_a3b3_a4b7_a5b7_a6b7_a7b7, a0b7_a1b7_a2b7_a3b7_a4b3_a5b3_a6b3_a7b3, 0x12 diff --git a/src/kernel.rs b/src/kernel.rs index 801b3d9..32345db 100644 --- a/src/kernel.rs +++ b/src/kernel.rs @@ -8,7 +8,8 @@ /// General matrix multiply kernel pub trait GemmKernel { - type Elem: Element; + type ElemIn: Element; + type ElemOut: Element; /// align inputs to this fn align_to() -> usize; @@ -41,11 +42,11 @@ pub trait GemmKernel { /// + if `beta` is `0.`, then c does not need to be initialized unsafe fn kernel( k: usize, - alpha: Self::Elem, - a: *const Self::Elem, - b: *const Self::Elem, - beta: Self::Elem, - c: *mut Self::Elem, rsc: isize, csc: isize); + alpha: Self::ElemOut, + a: *const Self::ElemIn, + b: *const Self::ElemIn, + beta: Self::ElemOut, + c: *mut Self::ElemOut, rsc: isize, csc: isize); } pub trait Element : Copy { @@ -56,38 +57,91 @@ pub trait Element : Copy { fn scaled_add(&mut self, alpha: Self, a: Self); } -impl Element for f32 { - fn zero() -> Self { 0. } - fn one() -> Self { 1. } - fn is_zero(&self) -> bool { *self == 0. } - fn scale_by(&mut self, x: Self) { - *self *= x; - } - fn scaled_add(&mut self, alpha: Self, a: Self) { - *self += alpha * a; - } -} +// impl Element for f32 { +// fn zero() -> Self { 0. } +// fn one() -> Self { 1. } +// fn is_zero(&self) -> bool { *self == 0. } +// fn scale_by(&mut self, x: Self) { +// *self *= x; +// } +// fn scaled_add(&mut self, alpha: Self, a: Self) { +// *self += alpha * a; +// } +// } -impl Element for f64 { - fn zero() -> Self { 0. } - fn one() -> Self { 1. } - fn is_zero(&self) -> bool { *self == 0. 
} - fn scale_by(&mut self, x: Self) { - *self *= x; - } - fn scaled_add(&mut self, alpha: Self, a: Self) { - *self += alpha * a; - } -} +// impl Element for f64 { +// fn zero() -> Self { 0. } +// fn one() -> Self { 1. } +// fn is_zero(&self) -> bool { *self == 0. } +// fn scale_by(&mut self, x: Self) { +// *self *= x; +// } +// fn scaled_add(&mut self, alpha: Self, a: Self) { +// *self += alpha * a; +// } +// } -impl Element for i32 { - fn zero() -> Self { 0 } - fn one() -> Self { 1 } - fn is_zero(&self) -> bool { *self == 0 } - fn scale_by(&mut self, x: Self) { - *self = self.wrapping_mul(x); - } - fn scaled_add(&mut self, alpha: Self, a: Self) { - *self = self.wrapping_add(alpha.wrapping_mul(a)); - } -} +// impl Element for i32 { +// fn zero() -> Self { 0 } +// fn one() -> Self { 1 } +// fn is_zero(&self) -> bool { *self == 0 } +// fn scale_by(&mut self, x: Self) { +// *self = self.wrapping_mul(x); +// } +// fn scaled_add(&mut self, alpha: Self, a: Self) { +// *self = self.wrapping_add(alpha.wrapping_mul(a)); +// } +// } + +// impl Element for i32 { +// fn zero() -> Self { 0 } +// fn one() -> Self { 1 } +// fn is_zero(&self) -> bool { *self == 0 } +// fn scale_by(&mut self, x: Self) { +// *self = self.wrapping_mul(x); +// } +// fn scaled_add(&mut self, alpha: Self, a: Self) { +// *self = self.wrapping_add(alpha.wrapping_mul(a)); +// } +// } + +macro_rules! impl_element_f { + ($($t:ty),+) => { + $( + impl Element for $t { + fn zero() -> Self { 0.0 } + fn one() -> Self { 1.0 } + fn is_zero(&self) -> bool { *self == 0.0 } + fn scale_by(&mut self, x: Self) { + // TODO: Change the semantics + *self *= x; + } + // TODO: Change the semantics + fn scaled_add(&mut self, alpha: Self, a: Self) { + *self += alpha * a; + } + } + )+ +};} + +macro_rules! impl_element_i { + ($($t:ty),+) => { + $( + impl Element for $t { + fn zero() -> Self { 0 } + fn one() -> Self { 1 } + fn is_zero(&self) -> bool { *self == 0 } + fn scale_by(&mut self, x: Self) { + // TODO: Change the semantics + *self = self.wrapping_mul(x); + } + // TODO: Change the semantics + fn scaled_add(&mut self, alpha: Self, a: Self) { + *self = self.wrapping_add(alpha.wrapping_mul(a)); + } + } + )+ +};} + +impl_element_f!(f32, f64); +impl_element_i!(i8, i16, i32); diff --git a/src/lib.rs b/src/lib.rs index f263bd8..331e3ac 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -62,10 +62,10 @@ mod kernel; mod gemm; mod sgemm_kernel; mod dgemm_kernel; -mod igemm_kernel; +// mod igemm_kernel; mod util; mod aligned_alloc; pub use gemm::sgemm; pub use gemm::dgemm; -pub use gemm::igemm; +// pub use gemm::igemm; diff --git a/src/sgemm_kernel.rs b/src/sgemm_kernel.rs index 6064869..be05bfb 100644 --- a/src/sgemm_kernel.rs +++ b/src/sgemm_kernel.rs @@ -26,7 +26,8 @@ macro_rules! loop_m { ($i:ident, $e:expr) => { loop8!($i, $e) }; } macro_rules! 
loop_n { ($j:ident, $e:expr) => { loop8!($j, $e) }; } impl GemmKernel for Gemm { - type Elem = T; + type ElemIn = T; + type ElemOut = T; #[inline(always)] fn align_to() -> usize { 32 } diff --git a/tests/sgemm.rs b/tests/sgemm.rs index b4314a0..5dc996b 100644 --- a/tests/sgemm.rs +++ b/tests/sgemm.rs @@ -1,7 +1,7 @@ extern crate itertools; extern crate matrixmultiply; -use matrixmultiply::{sgemm, dgemm, igemm}; +use matrixmultiply::{sgemm, dgemm}; //, igemm}; use itertools::Itertools; use itertools::{ @@ -71,23 +71,23 @@ impl Gemm for f32 { } } -impl Gemm for i32 { - unsafe fn gemm( - m: usize, k: usize, n: usize, - alpha: Self, - a: *const Self, rsa: isize, csa: isize, - b: *const Self, rsb: isize, csb: isize, - beta: Self, - c: *mut Self, rsc: isize, csc: isize) { - igemm( - m, k, n, - alpha, - a, rsa, csa, - b, rsb, csb, - beta, - c, rsc, csc) - } -} +// impl Gemm for i32 { +// unsafe fn gemm( +// m: usize, k: usize, n: usize, +// alpha: Self, +// a: *const Self, rsa: isize, csa: isize, +// b: *const Self, rsb: isize, csb: isize, +// beta: Self, +// c: *mut Self, rsc: isize, csc: isize) { +// igemm( +// m, k, n, +// alpha, +// a, rsa, csa, +// b, rsb, csb, +// beta, +// c, rsc, csc) +// } +// } impl Gemm for f64 { unsafe fn gemm( @@ -124,10 +124,10 @@ fn test_dgemm_strides() { test_gemm_strides::(); } -#[test] -fn test_i32gemm_strides() { - test_gemm_strides::(); -} +// #[test] +// fn test_i32gemm_strides() { +// test_gemm_strides::(); +// } fn test_gemm_strides() where F: Gemm + Float { for n in 0..10 { From 86d1f5cd3e43d119d2631cb6c6602d300b2febba Mon Sep 17 00:00:00 2001 From: Richard Janis Goldschmidt Date: Sun, 9 Dec 2018 12:10:49 +0100 Subject: [PATCH 3/6] Remove kernel for i32, implement i8 --- src/gemm.rs | 34 +-- src/i8gemm_kernel.rs | 568 +++++++++++++++++++++++++++++++++++++++++++ src/igemm_kernel.rs | 485 ------------------------------------ src/kernel.rs | 4 +- src/lib.rs | 4 +- src/loopmacros.rs | 74 ++++++ 6 files changed, 663 insertions(+), 506 deletions(-) create mode 100644 src/i8gemm_kernel.rs delete mode 100644 src/igemm_kernel.rs diff --git a/src/gemm.rs b/src/gemm.rs index 167ceb5..2ab91ec 100644 --- a/src/gemm.rs +++ b/src/gemm.rs @@ -19,7 +19,7 @@ use kernel::GemmKernel; use kernel::Element; use sgemm_kernel; use dgemm_kernel; -// use igemm_kernel; +use i8gemm_kernel; use rawpointer::PointerExt; /// General matrix multiplication (f32) @@ -88,22 +88,22 @@ pub unsafe fn dgemm( c, rsc, csc) } -// pub unsafe fn igemm( -// m: usize, k: usize, n: usize, -// alpha: i32, -// a: *const i32, rsa: isize, csa: isize, -// b: *const i32, rsb: isize, csb: isize, -// beta: i32, -// c: *mut i32, rsc: isize, csc: isize) -// { -// gemm_loop::( -// m, k, n, -// alpha, -// a, rsa, csa, -// b, rsb, csb, -// beta, -// c, rsc, csc) -// } +pub unsafe fn i8gemm( + m: usize, k: usize, n: usize, + alpha: i16, + a: *const i8, rsa: isize, csa: isize, + b: *const i8, rsb: isize, csb: isize, + beta: i16, + c: *mut i16, rsc: isize, csc: isize) +{ + gemm_loop::( + m, k, n, + alpha, + a, rsa, csa, + b, rsb, csb, + beta, + c, rsc, csc) +} /// Ensure that GemmKernel parameters are supported /// (alignment, microkernel size). diff --git a/src/i8gemm_kernel.rs b/src/i8gemm_kernel.rs new file mode 100644 index 0000000..3ccc584 --- /dev/null +++ b/src/i8gemm_kernel.rs @@ -0,0 +1,568 @@ +// Copyright 2016 - 2018 Ulrik Sverdrup "bluss" +// +// Licensed under the Apache License, Version 2.0 or the MIT license +// , at your +// option. 
This file may not be copied, modified, or distributed +// except according to those terms. + +use kernel::GemmKernel; +use archparam; + + +#[cfg(target_arch="x86")] +use std::arch::x86::*; +#[cfg(target_arch="x86_64")] +use std::arch::x86_64::*; + +pub enum Gemm { } + +pub type Tin = i8; +pub type Tout = i16; + +const MR: usize = 16; +const NR: usize = 32; + +macro_rules! loop_m { ($i:ident, $e:expr) => { loop16!($i, $e) }; } +macro_rules! loop_n { ($j:ident, $e:expr) => { loop32!($j, $e) }; } + +impl GemmKernel for Gemm { + type ElemIn = Tin; + type ElemOut = Tout; + + #[inline(always)] + fn align_to() -> usize { 16 } + + #[inline(always)] + fn mr() -> usize { MR } + #[inline(always)] + fn nr() -> usize { NR } + + #[inline(always)] + fn always_masked() -> bool { true } + + #[inline(always)] + fn nc() -> usize { archparam::S_NC } + #[inline(always)] + fn kc() -> usize { archparam::S_KC } + #[inline(always)] + fn mc() -> usize { archparam::S_MC } + + #[inline(always)] + unsafe fn kernel( + k: usize, + alpha: Tout, + a: *const Tin, + b: *const Tin, + beta: Tout, + c: *mut Tout, rsc: isize, csc: isize) { + kernel(k, alpha, a, b, beta, c, rsc, csc) + } +} + +/// Multiply two 128-bit vectors of 16 8-bit integers each,by sign-extending them to 256-bit +/// vectors of 16-bit integers, and then multiplying these temporaries. +#[inline(always)] +unsafe fn _mm256_mulepi8_epi16(a: __m128i, b: __m128i) -> __m256i +{ + let tmp0 = _mm256_cvtepi8_epi16(a); + let tmp1 = _mm256_cvtepi8_epi16(b); + + _mm256_mullo_epi16(tmp0, tmp1) +} + +/// matrix multiplication kernel +/// +/// This does the matrix multiplication: +/// +/// C ← α A B + β C +/// +/// + k: length of data in a, b +/// + a, b are packed +/// + c has general strides +/// + rsc: row stride of c +/// + csc: col stride of c +/// + if beta is 0, then c does not need to be initialized +#[inline(never)] +pub unsafe fn kernel(k: usize, alpha: Tout, a: *const Tin, b: *const Tin, + beta: Tout, c: *mut Tout, rsc: isize, csc: isize) +{ + // dispatch to specific compiled versions + #[cfg(any(target_arch="x86", target_arch="x86_64"))] + { + if is_x86_feature_detected_!("avx2") { + return kernel_target_avx2(k, alpha, a, b, beta, c, rsc, csc); + } else if is_x86_feature_detected_!("avx") { + return kernel_target_avx(k, alpha, a, b, beta, c, rsc, csc); + } else if is_x86_feature_detected_!("sse2") { + return kernel_target_sse2(k, alpha, a, b, beta, c, rsc, csc); + } + } + kernel_fallback_impl(k, alpha, a, b, beta, c, rsc, csc); +} + +#[inline] +#[target_feature(enable="avx2")] +#[cfg(any(target_arch="x86", target_arch="x86_64"))] +unsafe fn kernel_target_avx2(k: usize, alpha: Tout, a: *const Tin, b: *const Tin, + beta: Tout, c: *mut Tout, rsc: isize, csc: isize) +{ + kernel_x86_avx2(k, alpha, a, b, beta, c, rsc, csc) +} + +#[inline] +#[target_feature(enable="avx")] +#[cfg(any(target_arch="x86", target_arch="x86_64"))] +unsafe fn kernel_target_avx(k: usize, alpha: Tout, a: *const Tin, b: *const Tin, + beta: Tout, c: *mut Tout, rsc: isize, csc: isize) +{ + kernel_fallback_impl(k, alpha, a, b, beta, c, rsc, csc) +} + +#[inline] +#[target_feature(enable="sse2")] +#[cfg(any(target_arch="x86", target_arch="x86_64"))] +unsafe fn kernel_target_sse2(k: usize, alpha: Tout, a: *const Tin, b: *const Tin, + beta: Tout, c: *mut Tout, rsc: isize, csc: isize) +{ + kernel_fallback_impl(k, alpha, a, b, beta, c, rsc, csc) +} + + +#[inline(always)] +unsafe fn kernel_fallback_impl(k: usize, alpha: Tout, a: *const Tin, b: *const Tin, + beta: Tout, c: *mut Tout, rsc: isize, csc: 
isize) +{ + let mut ab: [[Tout; NR]; MR] = [[0; NR]; MR]; + let mut a = a; + let mut b = b; + debug_assert_eq!(beta, 0); + + // Compute A B into ab[i][j] + unroll_by!(4 => k, { + loop_m!(i, loop_n!(j, { + ab[i][j] = ab[i][j].saturating_add( + (at(a, i) as i16) + .saturating_mul( + at(b, j) as i16 + ));})); + + a = a.offset(MR as isize); + b = b.offset(NR as isize); + }); + + macro_rules! c { + ($i:expr, $j:expr) => (c.offset(rsc * $i as isize + csc * $j as isize)); + } + + // set C = α A B + β C + loop_n!(j, loop_m!(i, *c![i, j] = alpha.wrapping_mul(ab[i][j]))); +} + +#[inline(always)] +#[cfg(any(target_arch="x86", target_arch="x86_64"))] +unsafe fn kernel_x86_avx2(k: usize, alpha: Tout, a: *const Tin, b: *const Tin, + beta: Tout, c: *mut Tout, rsc: isize, csc: isize) +{ + debug_assert_ne!(k, 0); + + let mut ab = [_mm256_setzero_si256(); NR]; + + let (mut a, mut b) = (a, b); + + let mut a_col = _mm_loadu_si128(a as *const __m128i); + + // Load two rows from b at a time. + let mut b_row = _mm256_loadu_si256(b as *const __m256i); + + // FIXME: Is this k a meaningful number in this context? + unroll_by_with_last!(4 => k, is_last, { + let b0_b16 = _mm256_shuffle_epi8( + b_row, + _mm256_set_epi8( + 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, + 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, + ) + ); + + let b0 = _mm256_extracti128_si256(b0_b16, 0); + let b16 = _mm256_extracti128_si256(b0_b16, 1); + + let b1_b17 = _mm256_shuffle_epi8( + b_row, + _mm256_set_epi8( + 0x1, 0x1, 0x1, 0x1, 0x1, 0x1, 0x1, 0x1, 0x1, 0x1, 0x1, 0x1, 0x1, 0x1, 0x1, 0x1, + 0x1, 0x1, 0x1, 0x1, 0x1, 0x1, 0x1, 0x1, 0x1, 0x1, 0x1, 0x1, 0x1, 0x1, 0x1, 0x1, + ) + ); + + let b1 = _mm256_extracti128_si256(b1_b17, 0); + let b17 = _mm256_extracti128_si256(b1_b17, 1); + + let b2_b18 = _mm256_shuffle_epi8( + b_row, + _mm256_set_epi8( + 0x2, 0x2, 0x2, 0x2, 0x2, 0x2, 0x2, 0x2, 0x2, 0x2, 0x2, 0x2, 0x2, 0x2, 0x2, 0x2, + 0x2, 0x2, 0x2, 0x2, 0x2, 0x2, 0x2, 0x2, 0x2, 0x2, 0x2, 0x2, 0x2, 0x2, 0x2, 0x2, + ) + ); + + let b2 = _mm256_extracti128_si256(b2_b18, 0); + let b18 = _mm256_extracti128_si256(b2_b18, 1); + + let b3_b19 = _mm256_shuffle_epi8( + b_row, + _mm256_set_epi8( + 0x3, 0x3, 0x3, 0x3, 0x3, 0x3, 0x3, 0x3, 0x3, 0x3, 0x3, 0x3, 0x3, 0x3, 0x3, 0x3, + 0x3, 0x3, 0x3, 0x3, 0x3, 0x3, 0x3, 0x3, 0x3, 0x3, 0x3, 0x3, 0x3, 0x3, 0x3, 0x3, + ) + ); + + let b3 = _mm256_extracti128_si256(b3_b19, 0); + let b19 = _mm256_extracti128_si256(b3_b19, 1); + + let b4_b20 = _mm256_shuffle_epi8( + b_row, + _mm256_set_epi8( + 0x4, 0x4, 0x4, 0x4, 0x4, 0x4, 0x4, 0x4, 0x4, 0x4, 0x4, 0x4, 0x4, 0x4, 0x4, 0x4, + 0x4, 0x4, 0x4, 0x4, 0x4, 0x4, 0x4, 0x4, 0x4, 0x4, 0x4, 0x4, 0x4, 0x4, 0x4, 0x4, + ) + ); + + let b4 = _mm256_extracti128_si256(b4_b20, 0); + let b20 = _mm256_extracti128_si256(b4_b20, 1); + + let b5_b21 = _mm256_shuffle_epi8( + b_row, + _mm256_set_epi8( + 0x5, 0x5, 0x5, 0x5, 0x5, 0x5, 0x5, 0x5, 0x5, 0x5, 0x5, 0x5, 0x5, 0x5, 0x5, 0x5, + 0x5, 0x5, 0x5, 0x5, 0x5, 0x5, 0x5, 0x5, 0x5, 0x5, 0x5, 0x5, 0x5, 0x5, 0x5, 0x5, + ) + ); + + let b5 = _mm256_extracti128_si256(b5_b21, 0); + let b21 = _mm256_extracti128_si256(b5_b21, 1); + + let b6_b22 = _mm256_shuffle_epi8( + b_row, + _mm256_set_epi8( + 0x6, 0x6, 0x6, 0x6, 0x6, 0x6, 0x6, 0x6, 0x6, 0x6, 0x6, 0x6, 0x6, 0x6, 0x6, 0x6, + 0x6, 0x6, 0x6, 0x6, 0x6, 0x6, 0x6, 0x6, 0x6, 0x6, 0x6, 0x6, 0x6, 0x6, 0x6, 0x6, + ) + ); + + let b6 = _mm256_extracti128_si256(b6_b22, 0); + let b22 = _mm256_extracti128_si256(b6_b22, 1); + + let b7_b23 = _mm256_shuffle_epi8( + 
b_row, + _mm256_set_epi8( + 0x7, 0x7, 0x7, 0x7, 0x7, 0x7, 0x7, 0x7, 0x7, 0x7, 0x7, 0x7, 0x7, 0x7, 0x7, 0x7, + 0x7, 0x7, 0x7, 0x7, 0x7, 0x7, 0x7, 0x7, 0x7, 0x7, 0x7, 0x7, 0x7, 0x7, 0x7, 0x7, + ) + ); + + let b7 = _mm256_extracti128_si256(b7_b23, 0); + let b23 = _mm256_extracti128_si256(b7_b23, 1); + + let b8_b24 = _mm256_shuffle_epi8( + b_row, + _mm256_set_epi8( + 0x8, 0x8, 0x8, 0x8, 0x8, 0x8, 0x8, 0x8, 0x8, 0x8, 0x8, 0x8, 0x8, 0x8, 0x8, 0x8, + 0x8, 0x8, 0x8, 0x8, 0x8, 0x8, 0x8, 0x8, 0x8, 0x8, 0x8, 0x8, 0x8, 0x8, 0x8, 0x8, + ) + ); + + let b8 = _mm256_extracti128_si256(b8_b24, 0); + let b24 = _mm256_extracti128_si256(b8_b24, 1); + + let b9_b25 = _mm256_shuffle_epi8( + b_row, + _mm256_set_epi8( + 0x9, 0x9, 0x9, 0x9, 0x9, 0x9, 0x9, 0x9, 0x9, 0x9, 0x9, 0x9, 0x9, 0x9, 0x9, 0x9, + 0x9, 0x9, 0x9, 0x9, 0x9, 0x9, 0x9, 0x9, 0x9, 0x9, 0x9, 0x9, 0x9, 0x9, 0x9, 0x9, + ) + ); + + let b9 = _mm256_extracti128_si256(b9_b25, 0); + let b25 = _mm256_extracti128_si256(b9_b25, 1); + + let b10_b26 = _mm256_shuffle_epi8( + b_row, + _mm256_set_epi8( + 0xa, 0xa, 0xa, 0xa, 0xa, 0xa, 0xa, 0xa, 0xa, 0xa, 0xa, 0xa, 0xa, 0xa, 0xa, 0xa, + 0xa, 0xa, 0xa, 0xa, 0xa, 0xa, 0xa, 0xa, 0xa, 0xa, 0xa, 0xa, 0xa, 0xa, 0xa, 0xa, + ) + ); + + let b10 = _mm256_extracti128_si256(b10_b26, 0); + let b26 = _mm256_extracti128_si256(b10_b26, 1); + + let b11_b27 = _mm256_shuffle_epi8( + b_row, + _mm256_set_epi8( + 0xb, 0xb, 0xb, 0xb, 0xb, 0xb, 0xb, 0xb, 0xb, 0xb, 0xb, 0xb, 0xb, 0xb, 0xb, 0xb, + 0xb, 0xb, 0xb, 0xb, 0xb, 0xb, 0xb, 0xb, 0xb, 0xb, 0xb, 0xb, 0xb, 0xb, 0xb, 0xb, + ) + ); + + let b11 = _mm256_extracti128_si256(b11_b27, 0); + let b27 = _mm256_extracti128_si256(b11_b27, 1); + + let b12_b28 = _mm256_shuffle_epi8( + b_row, + _mm256_set_epi8( + 0xc, 0xc, 0xc, 0xc, 0xc, 0xc, 0xc, 0xc, 0xc, 0xc, 0xc, 0xc, 0xc, 0xc, 0xc, 0xc, + 0xc, 0xc, 0xc, 0xc, 0xc, 0xc, 0xc, 0xc, 0xc, 0xc, 0xc, 0xc, 0xc, 0xc, 0xc, 0xc, + ) + ); + + let b12 = _mm256_extracti128_si256(b12_b28, 0); + let b28 = _mm256_extracti128_si256(b12_b28, 1); + + let b13_b29 = _mm256_shuffle_epi8( + b_row, + _mm256_set_epi8( + 0xd, 0xd, 0xd, 0xd, 0xd, 0xd, 0xd, 0xd, 0xd, 0xd, 0xd, 0xd, 0xd, 0xd, 0xd, 0xd, + 0xd, 0xd, 0xd, 0xd, 0xd, 0xd, 0xd, 0xd, 0xd, 0xd, 0xd, 0xd, 0xd, 0xd, 0xd, 0xd, + ) + ); + + let b13 = _mm256_extracti128_si256(b13_b29, 0); + let b29 = _mm256_extracti128_si256(b13_b29, 1); + + let b14_b30 = _mm256_shuffle_epi8( + b_row, + _mm256_set_epi8( + 0xe, 0xe, 0xe, 0xe, 0xe, 0xe, 0xe, 0xe, 0xe, 0xe, 0xe, 0xe, 0xe, 0xe, 0xe, 0xe, + 0xe, 0xe, 0xe, 0xe, 0xe, 0xe, 0xe, 0xe, 0xe, 0xe, 0xe, 0xe, 0xe, 0xe, 0xe, 0xe, + ) + ); + + let b14 = _mm256_extracti128_si256(b14_b30, 0); + let b30 = _mm256_extracti128_si256(b14_b30, 1); + + let b15_b31 = _mm256_shuffle_epi8( + b_row, + _mm256_set_epi8( + 0xf, 0xf, 0xf, 0xf, 0xf, 0xf, 0xf, 0xf, 0xf, 0xf, 0xf, 0xf, 0xf, 0xf, 0xf, 0xf, + 0xf, 0xf, 0xf, 0xf, 0xf, 0xf, 0xf, 0xf, 0xf, 0xf, 0xf, 0xf, 0xf, 0xf, 0xf, 0xf, + ) + ); + + let b15 = _mm256_extracti128_si256(b15_b31, 0); + let b31 = _mm256_extracti128_si256(b15_b31, 1); + + // Multiplication and addition with the first row. 
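+ // Each b_j above is one byte of b_row broadcast to all 16 byte lanes.
+ // _mm256_mulepi8_epi16 sign-extends a_col and that broadcast byte to 16-bit lanes and
+ // multiplies them element-wise, and _mm256_adds_epi16 then accumulates the 16-bit
+ // products with signed saturation.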
+ ab[0] = _mm256_adds_epi16(ab[0], _mm256_mulepi8_epi16(a_col, b0)); + ab[1] = _mm256_adds_epi16(ab[1], _mm256_mulepi8_epi16(a_col, b1)); + ab[2] = _mm256_adds_epi16(ab[2], _mm256_mulepi8_epi16(a_col, b2)); + ab[3] = _mm256_adds_epi16(ab[3], _mm256_mulepi8_epi16(a_col, b3)); + + ab[4] = _mm256_adds_epi16(ab[4], _mm256_mulepi8_epi16(a_col, b4)); + ab[5] = _mm256_adds_epi16(ab[5], _mm256_mulepi8_epi16(a_col, b5)); + ab[6] = _mm256_adds_epi16(ab[6], _mm256_mulepi8_epi16(a_col, b6)); + ab[7] = _mm256_adds_epi16(ab[7], _mm256_mulepi8_epi16(a_col, b7)); + + ab[8] = _mm256_adds_epi16(ab[8], _mm256_mulepi8_epi16(a_col, b8)); + ab[9] = _mm256_adds_epi16(ab[9], _mm256_mulepi8_epi16(a_col, b9)); + ab[10] = _mm256_adds_epi16(ab[10], _mm256_mulepi8_epi16(a_col, b10)); + ab[11] = _mm256_adds_epi16(ab[11], _mm256_mulepi8_epi16(a_col, b11)); + + ab[12] = _mm256_adds_epi16(ab[12], _mm256_mulepi8_epi16(a_col, b12)); + ab[13] = _mm256_adds_epi16(ab[13], _mm256_mulepi8_epi16(a_col, b13)); + ab[14] = _mm256_adds_epi16(ab[14], _mm256_mulepi8_epi16(a_col, b14)); + ab[15] = _mm256_adds_epi16(ab[15], _mm256_mulepi8_epi16(a_col, b15)); + + // Multiplication and addition with the second row.); + ab[16] = _mm256_adds_epi16(ab[0], _mm256_mulepi8_epi16(a_col, b16)); + ab[17] = _mm256_adds_epi16(ab[1], _mm256_mulepi8_epi16(a_col, b17)); + ab[18] = _mm256_adds_epi16(ab[2], _mm256_mulepi8_epi16(a_col, b18)); + ab[19] = _mm256_adds_epi16(ab[3], _mm256_mulepi8_epi16(a_col, b19)); + + ab[20] = _mm256_adds_epi16(ab[4], _mm256_mulepi8_epi16(a_col, b20)); + ab[21] = _mm256_adds_epi16(ab[5], _mm256_mulepi8_epi16(a_col, b21)); + ab[22] = _mm256_adds_epi16(ab[6], _mm256_mulepi8_epi16(a_col, b22)); + ab[23] = _mm256_adds_epi16(ab[7], _mm256_mulepi8_epi16(a_col, b23)); + + ab[24] = _mm256_adds_epi16(ab[8], _mm256_mulepi8_epi16(a_col, b24)); + ab[25] = _mm256_adds_epi16(ab[9], _mm256_mulepi8_epi16(a_col, b25)); + ab[26] = _mm256_adds_epi16(ab[10], _mm256_mulepi8_epi16(a_col, b26)); + ab[27] = _mm256_adds_epi16(ab[11], _mm256_mulepi8_epi16(a_col, b27)); + + ab[28] = _mm256_adds_epi16(ab[12], _mm256_mulepi8_epi16(a_col, b28)); + ab[29] = _mm256_adds_epi16(ab[13], _mm256_mulepi8_epi16(a_col, b29)); + ab[30] = _mm256_adds_epi16(ab[14], _mm256_mulepi8_epi16(a_col, b30)); + ab[31] = _mm256_adds_epi16(ab[15], _mm256_mulepi8_epi16(a_col, b31)); + + if !is_last { + a = a.add(MR); + b = b.add(NR); + + a_col = _mm_loadu_si128(a as _); + b_row = _mm256_loadu_si256(b as _); + } + }); + + // Compute α (A B) + let alpha_v = _mm256_set1_epi16(alpha); + loop_m!(i, ab[i] = _mm256_mullo_epi16(alpha_v, ab[i])); + + macro_rules! 
c { + ($i:expr, $j:expr) => + (c.offset(rsc * $i as isize + csc * $j as isize)); + } + + // C ← α A B + β C + let mut cv = [_mm256_setzero_si256(); MR]; + + if beta != 0 { + let beta_v = _mm256_set1_epi16(beta); + + // Read C + if rsc == 1 { + loop_m!(i, cv[i] = _mm256_loadu_si256(c![0, i] as _)); + // } else if csc == 1 { + // loop4!(i, cv[i] = _mm256_loadu_pd(c![i, 0])); + // loop4!(i, cv[i+4] = _mm256_loadu_pd(c![i+4, 0])); + } else { + loop_m!(i, cv[i] = + _mm256_setr_epi16( + *c![0, i], + *c![1, i], + *c![2, i], + *c![3, i], + *c![4, i], + *c![5, i], + *c![6, i], + *c![7, i], + *c![8, i], + *c![9, i], + *c![10, i], + *c![11, i], + *c![12, i], + *c![13, i], + *c![14, i], + *c![15, i], + )); + } + // Compute β C + loop_m!(i, cv[i] = _mm256_mullo_epi16(cv[i], beta_v)); + } + + // Compute (α A B) + (β C) + loop_m!(i, cv[i] = _mm256_add_epi32(cv[i], ab[i])); + + if rsc == 1 { + loop_m!(i, _mm256_storeu_si256(c![0, i] as _, cv[i])); + // } else if csc == 1 { + // loop4!(i, _mm256_storeu_pd(c![i, 0], cv[i])); + // loop4!(i, _mm256_storeu_pd(c![i+4, 0], cv[i + 4])); + } else { + // TODO: This inner unrolled loop should be replaced by + // `loop_n!(j, *c![i, j] = _mm256_extract_epi32(cv[i], j);` + // However, rustc currently errors with: + // > error: argument 2 is required to be a constant + // Some reading: + // + https://internals.rust-lang.org/t/pre-rfc-const-function-arguments/6709/12 + // + https://www.reddit.com/r/rust/comments/9pxuoj/simd_instructions_requiring_a_constant_parameter/ + loop_m!(i, { + *c![i, 0] = _mm256_extract_epi16(cv[i], 0); + *c![i, 1] = _mm256_extract_epi16(cv[i], 1); + *c![i, 2] = _mm256_extract_epi16(cv[i], 2); + *c![i, 3] = _mm256_extract_epi16(cv[i], 3); + *c![i, 4] = _mm256_extract_epi16(cv[i], 4); + *c![i, 5] = _mm256_extract_epi16(cv[i], 5); + *c![i, 6] = _mm256_extract_epi16(cv[i], 6); + *c![i, 7] = _mm256_extract_epi16(cv[i], 7); + }) + } +} + +#[inline(always)] +unsafe fn at(ptr: *const Tin, i: usize) -> Tin { + *ptr.offset(i as isize) +} + +#[cfg(test)] +mod tests { + use super::*; + use aligned_alloc::Alloc; + + fn aligned_alloc(elt: T, n: usize) -> Alloc where T: Copy + { + unsafe { + Alloc::new(n, Gemm::align_to()).init_with(elt) + } + } + + use super::Tin; + use super::Tout; + type KernelFn = unsafe fn(usize, Tout, *const Tin, *const Tin, Tout, *mut Tout, isize, isize); + + fn test_a_kernel(_name: &str, kernel_fn: KernelFn) { + const K: usize = 4; + let mut a = aligned_alloc(1, MR * K); + let mut b = aligned_alloc(0, NR * K); + for (i, x) in a.iter_mut().enumerate() { + *x = i as _; + } + + for i in 0..K { + b[i + i * NR] = 1; + } + let mut c = [0; MR * NR]; + unsafe { + kernel_fn(K, 1, &a[0], &b[0], 0, &mut c[0], 1, MR as isize); + // col major C + } + let a: Vec<_> = a.iter().map(|x| *x as i16).collect(); + assert_eq!(&a[..], &c[..a.len()]); + } + + #[test] + fn test_native_kernel() { + test_a_kernel("kernel", kernel); + } + + #[test] + fn test_kernel_fallback_impl() { + test_a_kernel("kernel", kernel_fallback_impl); + } + + #[test] + fn test_loop_m_n() { + let mut m = [[0; NR]; MR]; + loop_m!(i, loop_n!(j, m[i][j] += 1)); + for arr in &m[..] { + for elt in &arr[..] { + assert_eq!(*elt, 1); + } + } + } + + mod test_arch_kernels { + use super::test_a_kernel; + macro_rules! 
test_arch_kernels_x86 { + ($($feature_name:tt, $function_name:ident),*) => { + $( + #[test] + fn $function_name() { + if is_x86_feature_detected_!($feature_name) { + test_a_kernel(stringify!($function_name), super::super::$function_name); + } else { + println!("Skipping, host does not have feature: {:?}", $feature_name); + } + } + )* + } + } + + #[cfg(any(target_arch="x86", target_arch="x86_64"))] + test_arch_kernels_x86! { + "avx2", kernel_target_avx2, + "avx", kernel_target_avx, + "sse2", kernel_target_sse2 + } + } +} diff --git a/src/igemm_kernel.rs b/src/igemm_kernel.rs deleted file mode 100644 index 26a7bf6..0000000 --- a/src/igemm_kernel.rs +++ /dev/null @@ -1,485 +0,0 @@ -// Copyright 2016 - 2018 Ulrik Sverdrup "bluss" -// -// Licensed under the Apache License, Version 2.0 or the MIT license -// , at your -// option. This file may not be copied, modified, or distributed -// except according to those terms. - -use kernel::GemmKernel; -use archparam; - - -#[cfg(target_arch="x86")] -use std::arch::x86::*; -#[cfg(target_arch="x86_64")] -use std::arch::x86_64::*; - -pub enum Gemm { } - -pub type Tin = i8; -pub type Tout = i16 - -const MR: usize = 8; -const NR: usize = 8; - -macro_rules! loop_m { ($i:ident, $e:expr) => { loop8!($i, $e) }; } -macro_rules! loop_n { ($j:ident, $e:expr) => { loop4!($j, $e) }; } - -impl GemmKernel for Gemm { - type Elem = T; - - #[inline(always)] - fn align_to() -> usize { 16 } - - #[inline(always)] - fn mr() -> usize { MR } - #[inline(always)] - fn nr() -> usize { NR } - - #[inline(always)] - fn always_masked() -> bool { true } - - #[inline(always)] - fn nc() -> usize { archparam::S_NC } - #[inline(always)] - fn kc() -> usize { archparam::S_KC } - #[inline(always)] - fn mc() -> usize { archparam::S_MC } - - #[inline(always)] - unsafe fn kernel( - k: usize, - alpha: T, - a: *const T, - b: *const T, - beta: T, - c: *mut T, rsc: isize, csc: isize) { - kernel(k, alpha, a, b, beta, c, rsc, csc) - } -} - -/// matrix multiplication kernel -/// -/// This does the matrix multiplication: -/// -/// C ← α A B + β C -/// -/// + k: length of data in a, b -/// + a, b are packed -/// + c has general strides -/// + rsc: row stride of c -/// + csc: col stride of c -/// + if beta is 0, then c does not need to be initialized -#[inline(never)] -pub unsafe fn kernel(k: usize, alpha: T, a: *const T, b: *const T, - beta: T, c: *mut T, rsc: isize, csc: isize) -{ - // dispatch to specific compiled versions - #[cfg(any(target_arch="x86", target_arch="x86_64"))] - { - if is_x86_feature_detected_!("avx2") { - return kernel_target_avx2(k, alpha, a, b, beta, c, rsc, csc); - } else if is_x86_feature_detected_!("avx") { - return kernel_target_avx(k, alpha, a, b, beta, c, rsc, csc); - } else if is_x86_feature_detected_!("sse2") { - return kernel_target_sse2(k, alpha, a, b, beta, c, rsc, csc); - } - } - kernel_fallback_impl(k, alpha, a, b, beta, c, rsc, csc); -} - -#[inline] -#[target_feature(enable="avx2")] -#[cfg(any(target_arch="x86", target_arch="x86_64"))] -unsafe fn kernel_target_avx2(k: usize, alpha: T, a: *const T, b: *const T, - beta: T, c: *mut T, rsc: isize, csc: isize) -{ - kernel_x86_avx2(k, alpha, a, b, beta, c, rsc, csc) -} - -#[inline] -#[target_feature(enable="avx")] -#[cfg(any(target_arch="x86", target_arch="x86_64"))] -unsafe fn kernel_target_avx(k: usize, alpha: T, a: *const T, b: *const T, - beta: T, c: *mut T, rsc: isize, csc: isize) -{ - kernel_fallback_impl(k, alpha, a, b, beta, c, rsc, csc) -} - -#[inline] -#[target_feature(enable="sse2")] -#[cfg(any(target_arch="x86", 
target_arch="x86_64"))] -unsafe fn kernel_target_sse2(k: usize, alpha: T, a: *const T, b: *const T, - beta: T, c: *mut T, rsc: isize, csc: isize) -{ - kernel_fallback_impl(k, alpha, a, b, beta, c, rsc, csc) -} - - -#[inline(always)] -unsafe fn kernel_fallback_impl(k: usize, alpha: T, a: *const T, b: *const T, - beta: T, c: *mut T, rsc: isize, csc: isize) -{ - let mut ab: [[T; NR]; MR] = [[0; NR]; MR]; - let mut a = a; - let mut b = b; - debug_assert_eq!(beta, 0); - - // Compute A B into ab[i][j] - unroll_by!(4 => k, { - loop_m!(i, loop_n!(j, { - ab[i][j] = ab[i][j].wrapping_add(at(a, i).wrapping_mul(at(b, j))); - })); - - a = a.offset(MR as isize); - b = b.offset(NR as isize); - }); - - macro_rules! c { - ($i:expr, $j:expr) => (c.offset(rsc * $i as isize + csc * $j as isize)); - } - - // set C = α A B + β C - loop_n!(j, loop_m!(i, *c![i, j] = alpha.wrapping_mul(ab[i][j]))); -} - -#[inline(always)] -#[cfg(any(target_arch="x86", target_arch="x86_64"))] -unsafe fn kernel_x86_avx2(k: usize, alpha: T, a: *const T, b: *const T, - beta: T, c: *mut T, rsc: isize, csc: isize) -{ - debug_assert_ne!(k, 0); - - let mut ab = [_mm256_setzero_si256(); MR]; - - let (mut a, mut b) = (a, b); - - let mut a_0123_4567 = _mm256_loadu_si256(a as *const __m256i); - let mut b_0123_4567 = _mm256_loadu_si256(b as *const __m256i); - - // The task in the loop below is to multiply every number packed in b_0123_4567 with - // every number in the two a_* vectors. With a_* loading a column, and b_* loading a row, - // this is exactly equivalent to matrix multiplication defined by C_ij = ∑_k A_ik B_kj + C_ij, but - // where you fix k, perform all multiplications in i and j for a fixed k, add the result to - // C, and then increment k to and repeat. - // - // Reapeted indices are shortened, e.g. b_02_46 = b_0022_4466 - - unroll_by_with_last!(4 => k, is_last, { - // two bits select the value of one i32, so 4 bits select two adjacent i32. 4 bits can be - // represented in hexadecimal notation. The first (from back to font) 4 pairs of bits - // select the first 128bit lane, the last 4 pairs the second lane. 
- let b_02_46 = _mm256_shuffle_epi32( - b_0123_4567, - 0b_10_10_00_00__10_10_00_00 - ); - - let b_20_64 = _mm256_shuffle_epi32( - b_0123_4567, - 0b_00_00_10_10__00_00_10_10 - ); - - let b_46_02 = _mm256_permute2x128_si256( - b_02_46, - b_02_46, - 0x03 - ); - - let b_64_20 = _mm256_permute2x128_si256( - b_20_64, - b_20_64, - 0x03 - ); - - let b_13_57 = _mm256_shuffle_epi32( - b_0123_4567, - 0b_11_11_01_01__11_11_01_01 - ); - - let b_31_75 = _mm256_shuffle_epi32( - b_0123_4567, - 0b_01_01_11_11__01_01_11_11 - ); - - let b_57_13 = _mm256_permute2x128_si256( - b_13_57, - b_13_57, - 0x03 - ); - - let b_75_31 = _mm256_permute2x128_si256( - b_31_75, - b_31_75, - 0x03 - ); - - // Add and multiply in one go - ab[0] = _mm256_add_epi32(ab[0], _mm256_mul_epi32(a_0123_4567, b_02_46)); - ab[1] = _mm256_add_epi32(ab[1], _mm256_mul_epi32(a_0123_4567, b_20_64)); - ab[2] = _mm256_add_epi32(ab[2], _mm256_mul_epi32(a_0123_4567, b_46_02)); - ab[3] = _mm256_add_epi32(ab[3], _mm256_mul_epi32(a_0123_4567, b_64_20)); - - ab[4] = _mm256_add_epi32(ab[4], _mm256_mul_epi32(a_0123_4567, b_13_57)); - ab[5] = _mm256_add_epi32(ab[5], _mm256_mul_epi32(a_0123_4567, b_31_75)); - ab[6] = _mm256_add_epi32(ab[6], _mm256_mul_epi32(a_0123_4567, b_57_13)); - ab[7] = _mm256_add_epi32(ab[7], _mm256_mul_epi32(a_0123_4567, b_75_31)); - - if !is_last { - a = a.add(MR); - b = b.add(NR); - - a_0123_4567 = _mm256_loadu_si256(a as _); - b_0123_4567 = _mm256_loadu_si256(b as _); - } - }); - - let a0b0_a1b0_a2b0_a3b0_a4b4_a5b4_a6b4_a7b4 = _mm256_blend_epi32( - ab[0], - ab[1], - 0b_1100_1100 - ); - - let a0b2_a1b2_a2b2_a3b2_a4b6_a5b6_a6b6_a7b6 = _mm256_blend_epi32( - ab[0], - ab[1], - 0b_0011_0011 - ); - - let a0b4_a1b4_a2b4_a3b4_a4b0_a5b0_a6b0_a7b0 = _mm256_blend_epi32( - ab[2], - ab[3], - 0b_1100_1100 - ); - - let a0b6_a1b6_a2b6_a3b6_a4b2_a5b2_a6b2_a7b2 = _mm256_blend_epi32( - ab[2], - ab[3], - 0b_0011_0011 - ); - - // a0b0_a1b0_a2b0_a3b0_b4b0_b5b0_b6b0_b7b0 - ab[0] = _mm256_permute2x128_si256( - a0b0_a1b0_a2b0_a3b0_a4b4_a5b4_a6b4_a7b4, - a0b4_a1b4_a2b4_a3b4_a4b0_a5b0_a6b0_a7b0, - 0x30 - ); - // a0b4_a1b4_a2b4_a3b4_b4b4_b5b4_b6b4_b7b4 - ab[4] = _mm256_permute2x128_si256( - a0b0_a1b0_a2b0_a3b0_a4b4_a5b4_a6b4_a7b4, - a0b4_a1b4_a2b4_a3b4_a4b0_a5b0_a6b0_a7b0, - 0x12 - ); - // a0b2_a1b2_a2b2_a3b2_b4b2_b5b2_b6b2_b7b2 - ab[2] = _mm256_permute2x128_si256( - a0b2_a1b2_a2b2_a3b2_a4b6_a5b6_a6b6_a7b6, - a0b6_a1b6_a2b6_a3b6_a4b2_a5b2_a6b2_a7b2, - 0x30 - ); - // a0b6_a1b6_a2b6_a3b6_b4b6_b5b6_b6b6_b7b6 - ab[6] = _mm256_permute2x128_si256( - a0b2_a1b2_a2b2_a3b2_a4b6_a5b6_a6b6_a7b6, - a0b6_a1b6_a2b6_a3b6_a4b2_a5b2_a6b2_a7b2, - 0x12 - ); - - let a0b1_a1b1_a2b1_a3b1_a4b5_a5b5_a6b5_a7b5 = _mm256_blend_epi32( - ab[4], - ab[5], - 0b_1100_1100 - ); - - let a0b3_a1b3_a2b3_a3b3_a4b7_a5b7_a6b7_a7b7 = _mm256_blend_epi32( - ab[4], - ab[5], - 0b_0011_0011 - ); - - let a0b5_a1b5_a2b5_a3b5_a4b1_a5b1_a6b1_a7b1 = _mm256_blend_epi32( - ab[6], - ab[7], - 0b_1100_1100 - ); - - let a0b7_a1b7_a2b7_a3b7_a4b3_a5b3_a6b3_a7b3 = _mm256_blend_epi32( - ab[6], - ab[7], - 0b_0011_0011 - ); - - // a0b1_a1b1_a2b1_a3b1_b4b1_b5b1_b6b1_b7b1 - ab[1] = _mm256_permute2x128_si256( - a0b1_a1b1_a2b1_a3b1_a4b5_a5b5_a6b5_a7b5, - a0b5_a1b5_a2b5_a3b5_a4b1_a5b1_a6b1_a7b1, - 0x30 - ); - // a0b5_a1b5_a2b5_a3b5_b4b5_b5b5_b6b5_b7b5 - ab[5] = _mm256_permute2x128_si256( - a0b1_a1b1_a2b1_a3b1_a4b5_a5b5_a6b5_a7b5, - a0b5_a1b5_a2b5_a3b5_a4b1_a5b1_a6b1_a7b1, - 0x12 - ); - // a0b3_a1b3_a2b3_a3b3_b4b3_b5b3_b6b3_b7b3 - ab[3] = _mm256_permute2x128_si256( - a0b3_a1b3_a2b3_a3b3_a4b7_a5b7_a6b7_a7b7, - 
a0b7_a1b7_a2b7_a3b7_a4b3_a5b3_a6b3_a7b3, - 0x30 - ); - // a0b7_a1b7_a2b7_a3b7_b4b7_b5b7_b6b7_b7b7 - ab[7] = _mm256_permute2x128_si256( - a0b3_a1b3_a2b3_a3b3_a4b7_a5b7_a6b7_a7b7, - a0b7_a1b7_a2b7_a3b7_a4b3_a5b3_a6b3_a7b3, - 0x12 - ); - - // Compute α (A B) - let alpha_v = _mm256_set1_epi32(alpha); - loop_m!(i, ab[i] = _mm256_mul_epi32(alpha_v, ab[i])); - - macro_rules! c { - ($i:expr, $j:expr) => - (c.offset(rsc * $i as isize + csc * $j as isize)); - } - - // C ← α A B + β C - let mut cv = [_mm256_setzero_si256(); MR]; - - if beta != 0 { - let beta_v = _mm256_set1_epi32(beta); - - // Read C - if rsc == 1 { - loop_m!(i, cv[i] = _mm256_loadu_si256(c![0, i] as _)); - // } else if csc == 1 { - // loop4!(i, cv[i] = _mm256_loadu_pd(c![i, 0])); - // loop4!(i, cv[i+4] = _mm256_loadu_pd(c![i+4, 0])); - } else { - loop_m!(i, cv[i] = _mm256_setr_epi32( - *c![0, i], - *c![1, i], - *c![2, i], - *c![3, i], - *c![4, i], - *c![5, i], - *c![6, i], - *c![7, i], - )); - } - // Compute β C - loop_m!(i, cv[i] = _mm256_mul_epi32(cv[i], beta_v)); - } - - // Compute (α A B) + (β C) - loop_m!(i, cv[i] = _mm256_add_epi32(cv[i], ab[i])); - - if rsc == 1 { - loop_m!(i, _mm256_storeu_si256(c![0, i] as _, cv[i])); - // } else if csc == 1 { - // loop4!(i, _mm256_storeu_pd(c![i, 0], cv[i])); - // loop4!(i, _mm256_storeu_pd(c![i+4, 0], cv[i + 4])); - } else { - // TODO: This inner unrolled loop should be replaced by - // `loop_n!(j, *c![i, j] = _mm256_extract_epi32(cv[i], j);` - // However, rustc currently errors with: - // > error: argument 2 is required to be a constant - // Some reading: - // + https://internals.rust-lang.org/t/pre-rfc-const-function-arguments/6709/12 - // + https://www.reddit.com/r/rust/comments/9pxuoj/simd_instructions_requiring_a_constant_parameter/ - loop_m!(i, { - *c![i, 0] = _mm256_extract_epi32(cv[i], 0); - *c![i, 1] = _mm256_extract_epi32(cv[i], 1); - *c![i, 2] = _mm256_extract_epi32(cv[i], 2); - *c![i, 3] = _mm256_extract_epi32(cv[i], 3); - *c![i, 4] = _mm256_extract_epi32(cv[i], 4); - *c![i, 5] = _mm256_extract_epi32(cv[i], 5); - *c![i, 6] = _mm256_extract_epi32(cv[i], 6); - *c![i, 7] = _mm256_extract_epi32(cv[i], 7); - }) - } -} - -#[inline(always)] -unsafe fn at(ptr: *const T, i: usize) -> T { - *ptr.offset(i as isize) -} - -#[cfg(test)] -mod tests { - use super::*; - use aligned_alloc::Alloc; - - fn aligned_alloc(elt: T, n: usize) -> Alloc where T: Copy - { - unsafe { - Alloc::new(n, Gemm::align_to()).init_with(elt) - } - } - - use super::T; - type KernelFn = unsafe fn(usize, T, *const T, *const T, T, *mut T, isize, isize); - - fn test_a_kernel(_name: &str, kernel_fn: KernelFn) { - const K: usize = 4; - let mut a = aligned_alloc(1, MR * K); - let mut b = aligned_alloc(0, NR * K); - for (i, x) in a.iter_mut().enumerate() { - *x = i as _; - } - - for i in 0..K { - b[i + i * NR] = 1; - } - let mut c = [0; MR * NR]; - unsafe { - kernel_fn(K, 1, &a[0], &b[0], 0, &mut c[0], 1, MR as isize); - // col major C - } - assert_eq!(&a[..], &c[..a.len()]); - } - - #[test] - fn test_native_kernel() { - test_a_kernel("kernel", kernel); - } - - #[test] - fn test_kernel_fallback_impl() { - test_a_kernel("kernel", kernel_fallback_impl); - } - - #[test] - fn test_loop_m_n() { - let mut m = [[0; NR]; MR]; - loop_m!(i, loop_n!(j, m[i][j] += 1)); - for arr in &m[..] { - for elt in &arr[..] { - assert_eq!(*elt, 1); - } - } - } - - mod test_arch_kernels { - use super::test_a_kernel; - macro_rules! 
test_arch_kernels_x86 { - ($($feature_name:tt, $function_name:ident),*) => { - $( - #[test] - fn $function_name() { - if is_x86_feature_detected_!($feature_name) { - test_a_kernel(stringify!($function_name), super::super::$function_name); - } else { - println!("Skipping, host does not have feature: {:?}", $feature_name); - } - } - )* - } - } - - #[cfg(any(target_arch="x86", target_arch="x86_64"))] - test_arch_kernels_x86! { - "avx2", kernel_target_avx2, - "avx", kernel_target_avx, - "sse2", kernel_target_sse2 - } - } -} diff --git a/src/kernel.rs b/src/kernel.rs index 32345db..83a71bd 100644 --- a/src/kernel.rs +++ b/src/kernel.rs @@ -133,11 +133,11 @@ macro_rules! impl_element_i { fn is_zero(&self) -> bool { *self == 0 } fn scale_by(&mut self, x: Self) { // TODO: Change the semantics - *self = self.wrapping_mul(x); + *self = self.saturating_mul(x); } // TODO: Change the semantics fn scaled_add(&mut self, alpha: Self, a: Self) { - *self = self.wrapping_add(alpha.wrapping_mul(a)); + *self = self.saturating_add(alpha.saturating_mul(a)); } } )+ diff --git a/src/lib.rs b/src/lib.rs index 331e3ac..4acc8c8 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -62,10 +62,10 @@ mod kernel; mod gemm; mod sgemm_kernel; mod dgemm_kernel; -// mod igemm_kernel; +mod i8gemm_kernel; mod util; mod aligned_alloc; pub use gemm::sgemm; pub use gemm::dgemm; -// pub use gemm::igemm; +pub use gemm::i8gemm; diff --git a/src/loopmacros.rs b/src/loopmacros.rs index 8e40d81..917d2c8 100644 --- a/src/loopmacros.rs +++ b/src/loopmacros.rs @@ -50,6 +50,80 @@ macro_rules! loop8 { }} } +#[cfg(debug_assertions)] +macro_rules! loop16 { + ($i:ident, $e:expr) => { + for $i in 0..16 { $e } + } +} + +#[cfg(not(debug_assertions))] +macro_rules! loop16 { + ($i:ident, $e:expr) => {{ + let $i = 0; $e; + let $i = 1; $e; + let $i = 2; $e; + let $i = 3; $e; + let $i = 4; $e; + let $i = 5; $e; + let $i = 6; $e; + let $i = 7; $e; + let $i = 8; $e; + let $i = 9; $e; + let $i = 10; $e; + let $i = 11; $e; + let $i = 12; $e; + let $i = 13; $e; + let $i = 14; $e; + let $i = 15; $e; + }} +} + +#[cfg(debug_assertions)] +macro_rules! loop32 { + ($i:ident, $e:expr) => { + for $i in 0..32 { $e } + } +} + +#[cfg(not(debug_assertions))] +macro_rules! loop32 { + ($i:ident, $e:expr) => {{ + let $i = 0; $e; + let $i = 1; $e; + let $i = 2; $e; + let $i = 3; $e; + let $i = 4; $e; + let $i = 5; $e; + let $i = 6; $e; + let $i = 7; $e; + let $i = 8; $e; + let $i = 9; $e; + let $i = 10; $e; + let $i = 11; $e; + let $i = 12; $e; + let $i = 13; $e; + let $i = 14; $e; + let $i = 15; $e; + let $i = 16; $e; + let $i = 17; $e; + let $i = 18; $e; + let $i = 19; $e; + let $i = 20; $e; + let $i = 21; $e; + let $i = 22; $e; + let $i = 23; $e; + let $i = 24; $e; + let $i = 25; $e; + let $i = 26; $e; + let $i = 27; $e; + let $i = 28; $e; + let $i = 29; $e; + let $i = 30; $e; + let $i = 31; $e; + }} +} + #[cfg(debug_assertions)] macro_rules! 
unroll_by { ($by:tt => $ntimes:expr, $e:expr) => { From 1ae7f38a2a3a17f0268f36a2bc225773ce12c1d6 Mon Sep 17 00:00:00 2001 From: Richard Janis Goldschmidt Date: Sun, 9 Dec 2018 13:26:20 +0100 Subject: [PATCH 4/6] First shot at generalizing and fixing the integration tests --- tests/sgemm.rs | 361 ++++++++++++++++++++++++++++--------------------- 1 file changed, 206 insertions(+), 155 deletions(-) diff --git a/tests/sgemm.rs b/tests/sgemm.rs index 5dc996b..e15bbb5 100644 --- a/tests/sgemm.rs +++ b/tests/sgemm.rs @@ -1,7 +1,11 @@ extern crate itertools; extern crate matrixmultiply; -use matrixmultiply::{sgemm, dgemm}; //, igemm}; +use matrixmultiply::{ + sgemm, + dgemm, + i8gemm, +}; use itertools::Itertools; use itertools::{ @@ -11,93 +15,129 @@ use itertools::{ }; use std::fmt::{Display, Debug}; -trait Float : Copy + Display + Debug + PartialEq { - fn zero() -> Self; - fn one() -> Self; - fn from(x: i64) -> Self; - fn nan() -> Self; - fn is_nan(self) -> bool; +trait GemmElement : Copy + Display + Debug + PartialEq { + // TODO: Provide default associated types once the following RFCs are merged and implemented: + // https://github.com/rust-lang/rfcs/pull/2532 + // https://github.com/rust-lang/rust/issues/29661 + // I.e., then we can do something like: + // type Output = Self; + // + // XXX: Is it somehow possible to already provide default impls for the _out functions in terms + // of the _in functions, where we assume that Input = Output? + type Output: Copy + Display + Debug + PartialEq; + + fn zero_in() -> Self; + fn one_in() -> Self; + fn nan_in() -> Self; + fn from_in(x: i64) -> Self; + fn is_nan_in(Self) -> bool; + + fn zero_out() -> Self::Output; + fn one_out() -> Self::Output; + fn nan_out() -> Self::Output; + fn from_out(x: i64) -> Self::Output; + fn is_nan_out(Self::Output) -> bool; + + fn to_out(Self) -> Self::Output; } -impl Float for f32 { - fn zero() -> Self { 0. } - fn one() -> Self { 1. } - fn from(x: i64) -> Self { x as Self } - fn nan() -> Self { 0./0. } - fn is_nan(self) -> bool { self.is_nan() } +macro_rules! impl_gemm_element_f { + ($($t:ty),+) => { + $( + impl GemmElement for $t { + type Output = Self; + + fn zero_in() -> Self { 0. } + fn one_in() -> Self { 1. } + fn from_in(x: i64) -> Self { x as Self } + fn nan_in() -> Self { 0./0. } + fn is_nan_in(var: Self) -> bool { var.is_nan() } + + fn zero_out() -> Self::Output { 0. } + fn one_out() -> Self::Output { 1. } + fn from_out(x: i64) -> Self::Output { x as Self::Output } + fn nan_out() -> Self { 0./0. } + fn is_nan_out(var: Self::Output) -> bool { var.is_nan() } + + fn to_out(var: Self) -> Self::Output { + var + } + } + )+ + }; } -impl Float for f64 { - fn zero() -> Self { 0. } - fn one() -> Self { 1. } - fn from(x: i64) -> Self { x as Self } - fn nan() -> Self { 0./0. 
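The GemmElement trait above splits the input element type from the output (accumulator) type, so the same tests can drive f32/f64 (where the two coincide) and i8-in/i16-out. A minimal sketch of how a comparison helper can lift inputs through to_out, assuming the trait exactly as defined in this test file; the helper itself is hypothetical:

    fn assert_matches<F: GemmElement>(input: &[F], output: &[F::Output]) {
        for (x, y) in input.iter().zip(output) {
            // lift the i8 (or f32/f64) input into the output domain before comparing
            assert_eq!(F::to_out(*x), *y);
        }
    }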
} - fn is_nan(self) -> bool { self.is_nan() } -} +impl_gemm_element_f!(f32, f64); + +impl GemmElement for i8 { + type Output = i16; + + fn zero_in() -> Self { 0 } + fn one_in() -> Self { 1 } + fn from_in(x: i64) -> Self { x as Self } + fn nan_in() -> Self { i8::min_value() } // hack + fn is_nan_in(var: Self) -> bool { var == Self::nan_in() } + + fn zero_out() -> Self::Output { 0 } + fn one_out() -> Self::Output { 1 } + fn from_out(x: i64) -> Self::Output { x as Self::Output } + fn nan_out() -> Self::Output { i16::min_value() } // hack + fn is_nan_out(var: Self::Output) -> bool { var == Self::nan_out() } -impl Float for i32 { - fn zero() -> Self { 0 } - fn one() -> Self { 1 } - fn from(x: i64) -> Self { x as Self } - fn nan() -> Self { i32::min_value() } // hack - fn is_nan(self) -> bool { self == i32::min_value() } + fn to_out(var: Self) -> Self::Output { + var as Self::Output + } } trait Gemm : Sized { + type Output; + unsafe fn gemm( m: usize, k: usize, n: usize, - alpha: Self, + alpha: Self::Output, a: *const Self, rsa: isize, csa: isize, b: *const Self, rsb: isize, csb: isize, - beta: Self, - c: *mut Self, rsc: isize, csc: isize); + beta: Self::Output, + c: *mut Self::Output, rsc: isize, csc: isize); } -impl Gemm for f32 { - unsafe fn gemm( - m: usize, k: usize, n: usize, - alpha: Self, - a: *const Self, rsa: isize, csa: isize, - b: *const Self, rsb: isize, csb: isize, - beta: Self, - c: *mut Self, rsc: isize, csc: isize) { - sgemm( - m, k, n, - alpha, - a, rsa, csa, - b, rsb, csb, - beta, - c, rsc, csc) - } +macro_rules! impl_gemm_f { + ($(($t:ty, $f:ident)),+) => { + $( + impl Gemm for $t { + type Output = Self; + unsafe fn gemm( + m: usize, k: usize, n: usize, + alpha: Self, + a: *const Self, rsa: isize, csa: isize, + b: *const Self, rsb: isize, csb: isize, + beta: Self, + c: *mut Self, rsc: isize, csc: isize) { + $f( + m, k, n, + alpha, + a, rsa, csa, + b, rsb, csb, + beta, + c, rsc, csc) + } + } + )+ + }; } -// impl Gemm for i32 { -// unsafe fn gemm( -// m: usize, k: usize, n: usize, -// alpha: Self, -// a: *const Self, rsa: isize, csa: isize, -// b: *const Self, rsb: isize, csb: isize, -// beta: Self, -// c: *mut Self, rsc: isize, csc: isize) { -// igemm( -// m, k, n, -// alpha, -// a, rsa, csa, -// b, rsb, csb, -// beta, -// c, rsc, csc) -// } -// } - -impl Gemm for f64 { +impl_gemm_f!((f32, sgemm), (f64, dgemm)); + +impl Gemm for i8 { + type Output = i16; unsafe fn gemm( m: usize, k: usize, n: usize, - alpha: Self, + alpha: i16, a: *const Self, rsa: isize, csa: isize, b: *const Self, rsb: isize, csb: isize, - beta: Self, - c: *mut Self, rsc: isize, csc: isize) { - dgemm( + beta: i16, + c: *mut i16, rsc: isize, csc: isize) { + i8gemm( m, k, n, alpha, a, rsa, csa, @@ -109,105 +149,112 @@ impl Gemm for f64 { #[test] fn test_sgemm() { - test_gemm::(); + test_gemm::(); } #[test] fn test_dgemm() { - test_gemm::(); + test_gemm::(); } #[test] fn test_sgemm_strides() { - test_gemm_strides::(); + test_gemm_strides::(); } #[test] fn test_dgemm_strides() { - test_gemm_strides::(); + test_gemm_strides::(); } -// #[test] -// fn test_i32gemm_strides() { -// test_gemm_strides::(); -// } +#[test] +fn test_i8gemm_strides() { + test_gemm_strides::(); +} -fn test_gemm_strides() where F: Gemm + Float { +fn test_gemm_strides() + where F: Gemm + GemmElement, + Tout: Copy + Display + Debug + PartialEq +{ for n in 0..10 { - test_strides::(n, n, n); + test_strides::(n, n, n); } for n in (3..12).map(|x| x * 7) { - test_strides::(n, n, n); + test_strides::(n, n, n); } - test_strides::(8, 12, 16); - 
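For reference, a minimal call of the new entry point, assuming the i8gemm signature implied by the Gemm for i8 impl above (i8 inputs, i16 α, β and output) and the pub use gemm::i8gemm re-export added in this series:

    extern crate matrixmultiply;

    fn main() {
        let (m, k, n) = (2, 2, 2);
        let a: [i8; 4] = [1, 2, 3, 4];      // row-major 2x2
        let b: [i8; 4] = [5, 6, 7, 8];      // row-major 2x2
        let mut c: [i16; 4] = [0; 4];       // row-major 2x2 output
        unsafe {
            matrixmultiply::i8gemm(
                m, k, n,
                1,                           // alpha: i16
                a.as_ptr(), k as isize, 1,
                b.as_ptr(), n as isize, 1,
                0,                           // beta: i16
                c.as_mut_ptr(), n as isize, 1,
            );
        }
        assert_eq!(c, [19, 22, 43, 50]);     // [[1,2],[3,4]] * [[5,6],[7,8]]
    }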
test_strides::(8, 0, 10); + test_strides::(8, 12, 16); + test_strides::(8, 0, 10); } -fn test_gemm() where F: Gemm + Float { - test_mul_with_id::(4, 4, true); - test_mul_with_id::(8, 8, true); - test_mul_with_id::(32, 32, false); - test_mul_with_id::(128, 128, false); - test_mul_with_id::(17, 128, false); +fn test_gemm() + where F: Gemm + GemmElement, + Tout: Copy + Display + Debug + PartialEq +{ + test_mul_with_id::(4, 4, true); + test_mul_with_id::(8, 8, true); + test_mul_with_id::(32, 32, false); + test_mul_with_id::(128, 128, false); + test_mul_with_id::(17, 128, false); for i in 0..12 { for j in 0..12 { - test_mul_with_id::(i, j, true); + test_mul_with_id::(i, j, true); } } /* */ - test_mul_with_id::(17, 257, false); - test_mul_with_id::(24, 512, false); + test_mul_with_id::(17, 257, false); + test_mul_with_id::(24, 512, false); for i in 0..10 { for j in 0..10 { - test_mul_with_id::(i * 4, j * 4, true); + test_mul_with_id::(i * 4, j * 4, true); } } - test_mul_with_id::(266, 265, false); - test_mul_id_with::(4, 4, true); + test_mul_with_id::(266, 265, false); + test_mul_id_with::(4, 4, true); for i in 0..12 { for j in 0..12 { - test_mul_id_with::(i, j, true); + test_mul_id_with::(i, j, true); } } - test_mul_id_with::(266, 265, false); - test_scale::(0, 4, 4, true); - test_scale::(4, 0, 4, true); - test_scale::(4, 4, 0, true); - test_scale::(4, 4, 4, true); - test_scale::(19, 20, 16, true); - test_scale::(150, 140, 128, false); + test_mul_id_with::(266, 265, false); + test_scale::(0, 4, 4, true); + test_scale::(4, 0, 4, true); + test_scale::(4, 4, 0, true); + test_scale::(4, 4, 4, true); + test_scale::(19, 20, 16, true); + test_scale::(150, 140, 128, false); } /// multiply a M x N matrix with an N x N id matrix #[cfg(test)] -fn test_mul_with_id(m: usize, n: usize, small: bool) - where F: Gemm + Float +fn test_mul_with_id(m: usize, n: usize, small: bool) + where F: Gemm + GemmElement, + Tout: Copy + Display + Debug + PartialEq { let (m, k, n) = (m, n, n); - let mut a = vec![F::zero(); m * k]; - let mut b = vec![F::zero(); k * n]; - let mut c = vec![F::zero(); m * n]; + let mut a = vec![F::zero_in(); m * k]; + let mut b = vec![F::zero_in(); k * n]; + let mut c = vec![F::zero_out(); m * n]; println!("test matrix with id input M={}, N={}", m, n); for (i, elt) in a.iter_mut().enumerate() { - *elt = F::from(i as i64); + *elt = F::from_in(i as i64); } for i in 0..k { - b[i + i * k] = F::one(); + b[i + i * k] = F::one_in(); } unsafe { F::gemm( m, k, n, - F::one(), + F::one_out(), a.as_ptr(), k as isize, 1, b.as_ptr(), n as isize, 1, - F::zero(), + F::zero_out(), c.as_mut_ptr(), n as isize, 1, ) } for (i, (x, y)) in a.iter().zip(&c).enumerate() { - if x != y { + if F::to_out(*x) != *y { if k != 0 && n != 0 && small { for row in a.chunks(k) { println!("{:?}", row); @@ -228,33 +275,34 @@ fn test_mul_with_id(m: usize, n: usize, small: bool) /// multiply a K x K id matrix with an K x N matrix #[cfg(test)] -fn test_mul_id_with(k: usize, n: usize, small: bool) - where F: Gemm + Float +fn test_mul_id_with(k: usize, n: usize, small: bool) + where F: Gemm + GemmElement, + Tout: Copy + Display + Debug + PartialEq { let (m, k, n) = (k, k, n); - let mut a = vec![F::zero(); m * k]; - let mut b = vec![F::zero(); k * n]; - let mut c = vec![F::zero(); m * n]; + let mut a = vec![F::zero_in(); m * k]; + let mut b = vec![F::zero_in(); k * n]; + let mut c = vec![F::zero_out(); m * n]; for i in 0..k { - a[i + i * k] = F::one(); + a[i + i * k] = F::one_in(); } for (i, elt) in b.iter_mut().enumerate() { - *elt = 
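One detail in test_mul_with_id above: indexing the diagonal as b[i + i * k] is only correct because the test fixes (m, k, n) = (m, n, n), so k == n. A shape-agnostic version for a row-major k x n buffer would look like this (hypothetical helper, not part of the patch):

    fn fill_identity<T: Copy>(buf: &mut [T], k: usize, n: usize, one: T) {
        // diagonal of a row-major k x n matrix: element (i, i) lives at i * n + i
        for i in 0..k.min(n) {
            buf[i * n + i] = one;
        }
    }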
F::from(i as i64); + *elt = F::from_in(i as i64); } unsafe { F::gemm( m, k, n, - F::one(), + F::one_out(), a.as_ptr(), k as isize, 1, b.as_ptr(), n as isize, 1, - F::zero(), + F::zero_out(), c.as_mut_ptr(), n as isize, 1, ) } for (i, (x, y)) in b.iter().zip(&c).enumerate() { - if x != y { + if F::to_out(*x) != *y { if k != 0 && n != 0 && small { for row in a.chunks(k) { println!("{:?}", row); @@ -274,55 +322,56 @@ fn test_mul_id_with(k: usize, n: usize, small: bool) } #[cfg(test)] -fn test_scale(m: usize, k: usize, n: usize, small: bool) - where F: Gemm + Float +fn test_scale(m: usize, k: usize, n: usize, small: bool) + where F: Gemm + GemmElement, + Tout: Copy + Display + Debug + PartialEq { let (m, k, n) = (m, k, n); - let mut a = vec![F::zero(); m * k]; - let mut b = vec![F::zero(); k * n]; - let mut c1 = vec![F::one(); m * n]; - let mut c2 = vec![F::nan(); m * n]; + let mut a = vec![F::zero_in(); m * k]; + let mut b = vec![F::zero_in(); k * n]; + let mut c1 = vec![F::one_out(); m * n]; + let mut c2 = vec![F::nan_out(); m * n]; // init c2 with NaN to test the overwriting behavior when beta = 0. for (i, elt) in a.iter_mut().enumerate() { - *elt = F::from(i as i64); + *elt = F::from_in(i as i64); } for (i, elt) in b.iter_mut().enumerate() { - *elt = F::from(i as i64); + *elt = F::from_in(i as i64); } unsafe { // C1 = 3 A B F::gemm( m, k, n, - F::from(3), + F::from_out(3), a.as_ptr(), k as isize, 1, b.as_ptr(), n as isize, 1, - F::zero(), + F::zero_out(), c1.as_mut_ptr(), n as isize, 1, ); // C2 = A B F::gemm( m, k, n, - F::one(), + F::one_out(), a.as_ptr(), k as isize, 1, b.as_ptr(), n as isize, 1, - F::zero(), + F::zero_out(), c2.as_mut_ptr(), n as isize, 1, ); // C2 = A B + 2 C2 F::gemm( m, k, n, - F::one(), + F::one_out(), a.as_ptr(), k as isize, 1, b.as_ptr(), n as isize, 1, - F::from(2), + F::from_out(2), c2.as_mut_ptr(), n as isize, 1, ); } for (i, (x, y)) in c1.iter().zip(&c2).enumerate() { - if x != y || x.is_nan() || y.is_nan() { + if x != y || F::is_nan_out(*x) || F::is_nan_out(*y) { if k != 0 && n != 0 && small { for row in a.chunks(k) { println!("{:?}", row); @@ -369,8 +418,9 @@ impl Default for Layout { #[cfg(test)] -fn test_strides(m: usize, k: usize, n: usize) - where F: Gemm + Float +fn test_strides(m: usize, k: usize, n: usize) + where F: Gemm + GemmElement, + Tout: Copy + Display + Debug + PartialEq { let (m, k, n) = (m, k, n); @@ -383,15 +433,16 @@ fn test_strides(m: usize, k: usize, n: usize) for elt in layouts_iter { let layouts = [elt[0], elt[1], elt[2], elt[3]]; let (m0, m1, m2, m3) = multipliers_iter.next_tuple().unwrap(); - test_strides_inner::(m, k, n, [m0, m1, m2, m3], layouts); + test_strides_inner::(m, k, n, [m0, m1, m2, m3], layouts); } } -fn test_strides_inner(m: usize, k: usize, n: usize, +fn test_strides_inner(m: usize, k: usize, n: usize, stride_multipliers: [[usize; 2]; 4], layouts: [Layout; 4]) - where F: Gemm + Float + where F: Gemm + GemmElement, + Tout: Copy + Display + Debug + PartialEq { let (m, k, n) = (m, k, n); @@ -401,16 +452,16 @@ fn test_strides_inner(m: usize, k: usize, n: usize, let mstridec = stride_multipliers[2]; let mstridec2 = stride_multipliers[3]; - let mut a = vec![F::zero(); m * k * mstridea[0] * mstridea[1]]; - let mut b = vec![F::zero(); k * n * mstrideb[0] * mstrideb[1]]; - let mut c1 = vec![F::nan(); m * n * mstridec[0] * mstridec[1]]; - let mut c2 = vec![F::nan(); m * n * mstridec2[0] * mstridec2[1]]; + let mut a = vec![F::zero_in(); m * k * mstridea[0] * mstridea[1]]; + let mut b = vec![F::zero_in(); k * n * mstrideb[0] * 
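The invariant test_scale checks, written out on a single scalar: building C2 as A B and then folding in another A B with β = 2 gives 1·AB + 2·AB = 3·AB, which must match C1 = 3·AB computed in one call.

    fn main() {
        let ab = 7i32;          // stands in for one element of the product A B
        let c1 = 3 * ab;        // C1 = 3 A B          (alpha = 3, beta = 0)
        let mut c2 = ab;        // C2 = A B            (alpha = 1, beta = 0)
        c2 = ab + 2 * c2;       // C2 = A B + 2 C2     (alpha = 1, beta = 2)
        assert_eq!(c1, c2);     // 3 AB == AB + 2 AB
    }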
mstrideb[1]]; + let mut c1 = vec![F::nan_out(); m * n * mstridec[0] * mstridec[1]]; + let mut c2 = vec![F::nan_out(); m * n * mstridec2[0] * mstridec2[1]]; for (i, elt) in a.iter_mut().enumerate() { - *elt = F::from(i as i64); + *elt = F::from_in(i as i64); } for (i, elt) in b.iter_mut().enumerate() { - *elt = F::from(i as i64); + *elt = F::from_in(i as i64); } let la = layouts[0]; @@ -439,30 +490,30 @@ fn test_strides_inner(m: usize, k: usize, n: usize, // C1 = A B F::gemm( m, k, n, - F::from(1), + F::from_out(1), a.as_ptr(), rs_a, cs_a, b.as_ptr(), rs_b, cs_b, - F::zero(), + F::zero_out(), c1.as_mut_ptr(), rs_c1, cs_c1, ); - + // C1 += 2 A B F::gemm( m, k, n, - F::from(2), + F::from_out(2), a.as_ptr(), rs_a, cs_a, b.as_ptr(), rs_b, cs_b, - F::from(1), + F::from_out(1), c1.as_mut_ptr(), rs_c1, cs_c1, ); // C2 = 3 A B F::gemm( m, k, n, - F::from(3), + F::from_out(3), a.as_ptr(), rs_a, cs_a, b.as_ptr(), rs_b, cs_b, - F::zero(), + F::zero_out(), c2.as_mut_ptr(), rs_c2, cs_c2, ); } @@ -488,7 +539,7 @@ fn test_strides_inner(m: usize, k: usize, n: usize, let irem = index % rs_c1 as usize; let jrem = index % cs_c1 as usize; if irem != 0 && jrem != 0 { - assert!(elt.is_nan(), + assert!(F::is_nan_out(*elt), "Element at index={} ({}, {}) should be NaN, but was {}\n\ c1: {:?}\n", index, i, j, elt, From 26e6c82d1e4c4949560907f3c60a4f0dc012bbde Mon Sep 17 00:00:00 2001 From: Richard Janis Goldschmidt Date: Sun, 9 Dec 2018 14:02:35 +0100 Subject: [PATCH 5/6] Make gemm loop work with different kernel dimensions --- src/dgemm_kernel.rs | 8 +++----- src/gemm.rs | 42 ++++++++++++++++++++++-------------------- src/i8gemm_kernel.rs | 8 +++----- src/kernel.rs | 11 ++++++----- src/lib.rs | 2 ++ src/sgemm_kernel.rs | 8 +++----- 6 files changed, 39 insertions(+), 40 deletions(-) diff --git a/src/dgemm_kernel.rs b/src/dgemm_kernel.rs index 8c679d8..eee30dd 100644 --- a/src/dgemm_kernel.rs +++ b/src/dgemm_kernel.rs @@ -27,13 +27,11 @@ impl GemmKernel for Gemm { type ElemIn = T; type ElemOut = T; - #[inline(always)] - fn align_to() -> usize { 0 } + const MR: usize = MR; + const NR: usize = NR; #[inline(always)] - fn mr() -> usize { MR } - #[inline(always)] - fn nr() -> usize { NR } + fn align_to() -> usize { 0 } #[inline(always)] fn always_masked() -> bool { true } diff --git a/src/gemm.rs b/src/gemm.rs index 2ab91ec..251be86 100644 --- a/src/gemm.rs +++ b/src/gemm.rs @@ -113,8 +113,8 @@ pub unsafe fn i8gemm( fn ensure_kernel_params() where K: GemmKernel { - let mr = K::mr(); - let nr = K::nr(); + let mr = K::MR; + let nr = K::NR; assert!(mr > 0 && mr <= 8); assert!(nr > 0 && nr <= 8); assert!(mr * nr * size_of::() <= 8 * 4 * 8); @@ -146,7 +146,7 @@ unsafe fn gemm_loop( let knc = K::nc(); let kkc = K::kc(); let kmc = K::mc(); - ensure_kernel_params::(); + // ensure_kernel_params::(); let (mut packing_buffer, bp_offset) = make_packing_buffer::(m, k, n); let app = packing_buffer.ptr_mut(); @@ -165,7 +165,7 @@ unsafe fn gemm_loop( let a = a.stride_offset(csa, kkc * l4); // Pack B -> B~ - pack(kc, nc, K::nr(), bpp, b, csb, rsb); + pack(kc, nc, K::NR, bpp, b, csb, rsb); // LOOP 3: split m into mc parts for (l3, mc) in range_chunk(m, kmc) { @@ -174,13 +174,13 @@ unsafe fn gemm_loop( let c = c.stride_offset(rsc, kmc * l3); // Pack A -> A~ - pack(kc, mc, K::mr(), app, a, rsa, csa); + pack(kc, mc, K::MR, app, a, rsa, csa); // First time writing to C, use user's `beta`, else accumulate let betap = if l4 == 0 { beta } else { <_>::one() }; // LOOP 2 and 1 - gemm_packed::(nc, kc, mc, + gemm_packed::(nc, kc, mc, alpha, app, bpp, 
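The packing calls above now take the micropanel width from the kernel's K::NR / K::MR consts instead of the old nr() / mr() methods. As a reminder of the layout the kernel expects (it advances b by NR elements per k step), here is a simplified stand-in for packing B; the crate's real pack also handles arbitrary strides and partial panels at the edges:

    // Simplified illustration: full panels only, unit-stride row-major source.
    fn pack_b_panels(kc: usize, nc: usize, nr: usize, src: &[i8], dst: &mut Vec<i8>) {
        for panel in 0..nc / nr {
            for row in 0..kc {
                for col in 0..nr {
                    // panel `panel` covers columns [panel*nr, panel*nr + nr);
                    // each k step of the kernel then reads nr contiguous elements
                    dst.push(src[row * nc + panel * nr + col]);
                }
            }
        }
    }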
betap, @@ -197,18 +197,20 @@ unsafe fn gemm_loop( /// + nc: columns of packed B /// + kc: columns of packed A / rows of packed B /// + mc: rows of packed A -unsafe fn gemm_packed(nc: usize, kc: usize, mc: usize, - alpha: K::ElemOut, - app: *const K::ElemIn, bpp: *const K::ElemIn, - beta: K::ElemOut, - c: *mut K::ElemOut, rsc: isize, csc: isize) - where K: GemmKernel, +unsafe fn gemm_packed(nc: usize, kc: usize, mc: usize, + alpha: Tout, + app: *const Tin, bpp: *const Tin, + beta: Tout, + c: *mut Tout, rsc: isize, csc: isize) + where K: GemmKernel, + Tin: Element, + Tout: Element, { - let mr = K::mr(); - let nr = K::nr(); + let mr = K::MR; + let nr = K::NR; // make a mask buffer that fits 8 x 8 f32 and 8 x 4 f64 kernels and alignment - assert!(mr * nr * size_of::() <= 256 && K::align_to() <= 32); - let mut mask_buf = [0u8; 256 + 31]; + // assert!(mr * nr * size_of::() <= 256 && K::align_to() <= 32); + let mut mask_buf = [0u8; K::MR * K::NR * size_of::() + 31]; let mask_ptr = align_ptr(32, mask_buf.as_mut_ptr()) as *mut K::ElemOut; // LOOP 2: through micropanels in packed `b` @@ -254,8 +256,8 @@ unsafe fn make_packing_buffer(m: usize, k: usize, n: usize) -> (Alloc(k: usize, alpha: Tout, Tin: Element, Tout: Element, { - let mr = K::mr(); - let nr = K::nr(); + let mr = K::MR; + let nr = K::NR; // use column major order for `mask_buf` K::kernel(k, Tout::one(), a, b, Tout::zero(), mask_buf, 1, mr as isize); let mut ab = mask_buf; diff --git a/src/i8gemm_kernel.rs b/src/i8gemm_kernel.rs index 3ccc584..532e5c3 100644 --- a/src/i8gemm_kernel.rs +++ b/src/i8gemm_kernel.rs @@ -30,13 +30,11 @@ impl GemmKernel for Gemm { type ElemIn = Tin; type ElemOut = Tout; - #[inline(always)] - fn align_to() -> usize { 16 } + const MR: usize = MR; + const NR: usize = NR; #[inline(always)] - fn mr() -> usize { MR } - #[inline(always)] - fn nr() -> usize { NR } + fn align_to() -> usize { 16 } #[inline(always)] fn always_masked() -> bool { true } diff --git a/src/kernel.rs b/src/kernel.rs index 83a71bd..5b1c233 100644 --- a/src/kernel.rs +++ b/src/kernel.rs @@ -11,14 +11,15 @@ pub trait GemmKernel { type ElemIn: Element; type ElemOut: Element; + /// Number of kernel rows + const MR: usize; + + /// Number of kernel columns + const NR: usize; + /// align inputs to this fn align_to() -> usize; - /// Kernel rows - fn mr() -> usize; - /// Kernel cols - fn nr() -> usize; - /// Whether to always use the masked wrapper around the kernel. 
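On the mask_buf sizing above: a stack array's length cannot mention the function's generic parameters, so [0u8; K::MR * K::NR * size_of::<K::ElemOut>() + 31] is rejected by the compiler, and the #![feature(const_fn)] attribute added later in this series does not by itself lift that restriction; that is presumably why the final patch falls back to a fixed upper bound. A sketch of the fixed-bound approach, with the 16 x 32 x 2-byte numbers taken from the last hunk of the series rather than from any public constant:

    // Worst-case scratch tile: largest kernel shape times largest output element,
    // plus 31 bytes of slack so the pointer can be aligned up to 32 bytes.
    const MAX_MR: usize = 16;
    const MAX_NR: usize = 32;
    const MAX_ELEM_SIZE: usize = 2; // i16 output
    const MASK_BUF_LEN: usize = MAX_MR * MAX_NR * MAX_ELEM_SIZE + 31;

    fn main() {
        let mask_buf = [0u8; MASK_BUF_LEN];
        assert_eq!(mask_buf.len(), 16 * 32 * 2 + 31);
    }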
/// /// If masked, the kernel is always called with α=1, β=0 diff --git a/src/lib.rs b/src/lib.rs index 4acc8c8..9ab1695 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -52,6 +52,8 @@ #![doc(html_root_url = "https://docs.rs/matrixmultiply/0.2/")] +#![feature(const_fn)] + extern crate rawpointer; #[macro_use] mod archmacros_x86; diff --git a/src/sgemm_kernel.rs b/src/sgemm_kernel.rs index be05bfb..b23f490 100644 --- a/src/sgemm_kernel.rs +++ b/src/sgemm_kernel.rs @@ -29,13 +29,11 @@ impl GemmKernel for Gemm { type ElemIn = T; type ElemOut = T; - #[inline(always)] - fn align_to() -> usize { 32 } + const MR: usize = MR; + const NR: usize = NR; #[inline(always)] - fn mr() -> usize { MR } - #[inline(always)] - fn nr() -> usize { NR } + fn align_to() -> usize { 32 } #[inline(always)] fn always_masked() -> bool { false } From 49ad58a70eed49d3b714bad7097a2de1dda0788a Mon Sep 17 00:00:00 2001 From: Richard Janis Goldschmidt Date: Mon, 10 Dec 2018 13:34:07 +0100 Subject: [PATCH 6/6] Increase the buffer size to accomodate very large kernels --- src/gemm.rs | 19 +++++++++---------- 1 file changed, 9 insertions(+), 10 deletions(-) diff --git a/src/gemm.rs b/src/gemm.rs index 251be86..c3337cc 100644 --- a/src/gemm.rs +++ b/src/gemm.rs @@ -180,7 +180,7 @@ unsafe fn gemm_loop( let betap = if l4 == 0 { beta } else { <_>::one() }; // LOOP 2 and 1 - gemm_packed::(nc, kc, mc, + gemm_packed::(nc, kc, mc, alpha, app, bpp, betap, @@ -197,20 +197,19 @@ unsafe fn gemm_loop( /// + nc: columns of packed B /// + kc: columns of packed A / rows of packed B /// + mc: rows of packed A -unsafe fn gemm_packed(nc: usize, kc: usize, mc: usize, - alpha: Tout, - app: *const Tin, bpp: *const Tin, - beta: Tout, - c: *mut Tout, rsc: isize, csc: isize) - where K: GemmKernel, - Tin: Element, - Tout: Element, +unsafe fn gemm_packed(nc: usize, kc: usize, mc: usize, + alpha: K::ElemOut, + app: *const K::ElemIn, bpp: *const K::ElemIn, + beta: K::ElemOut, + c: *mut K::ElemOut, rsc: isize, csc: isize) + where K: GemmKernel, { let mr = K::MR; let nr = K::NR; // make a mask buffer that fits 8 x 8 f32 and 8 x 4 f64 kernels and alignment // assert!(mr * nr * size_of::() <= 256 && K::align_to() <= 32); - let mut mask_buf = [0u8; K::MR * K::NR * size_of::() + 31]; + // let mut mask_buf = [0u8; 256 + 31]; + let mut mask_buf = [0u8; 16*32*2 + 31]; let mask_ptr = align_ptr(32, mask_buf.as_mut_ptr()) as *mut K::ElemOut; // LOOP 2: through micropanels in packed `b`
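To make the α=1, β=0 convention for the masked wrapper concrete: the micro-kernel writes a full MR x NR tile into mask_buf (column major, row stride 1, column stride MR), and the wrapper then applies the caller's α and β while copying only the rows and columns that actually exist in C. Below is a hypothetical copy-out for the i16 output type, using the saturating semantics adopted earlier in this series; the crate's real loop differs in the details.

    unsafe fn copy_masked_i16(ab: *const i16, mr: usize, nr: usize,
                              rows: usize, cols: usize,
                              alpha: i16, beta: i16,
                              c: *mut i16, rsc: isize, csc: isize) {
        for j in 0..cols.min(nr) {
            for i in 0..rows.min(mr) {
                // mask_buf tile is column major: element (i, j) sits at i + j * mr
                let x = *ab.add(i + j * mr);
                let dst = c.offset(rsc * i as isize + csc * j as isize);
                // beta == 0 must not read C at all, so NaN-like sentinels (or
                // uninitialized memory) in C are overwritten cleanly
                let old = if beta == 0 { 0 } else { beta.saturating_mul(*dst) };
                *dst = old.saturating_add(alpha.saturating_mul(x));
            }
        }
    }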