diff --git a/Cargo.toml b/Cargo.toml index 0f8bd72..612608a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,5 +1,6 @@ [workspace] members = ["gemm", "gemm-common", "gemm-f16", "gemm-f32", "gemm-f64", "gemm-c32", "gemm-c64"] +resolver = "2" [workspace.dependencies] lazy_static = "1.4" @@ -13,3 +14,4 @@ paste = "1.0" [profile.dev] opt-level = 3 + diff --git a/gemm-common/src/gemm.rs b/gemm-common/src/gemm.rs index f9eced1..b4e1162 100644 --- a/gemm-common/src/gemm.rs +++ b/gemm-common/src/gemm.rs @@ -814,7 +814,7 @@ macro_rules! gemm_def { $crate::__inject_mod!(avx512f, $ty, 8 * $multiplier, Avx512f); #[cfg(target_arch = "aarch64")] - $crate::__inject_mod!(neon, $ty, 2 * $multiplier, Scalar); + $crate::__inject_mod!(neon, $ty, 2 * $multiplier, Neon); #[cfg(target_arch = "wasm32")] $crate::__inject_mod!(simd128, $ty, 2 * $multiplier, Simd128); diff --git a/gemm-common/src/microkernel.rs b/gemm-common/src/microkernel.rs index 45b42a3..3998c1e 100644 --- a/gemm-common/src/microkernel.rs +++ b/gemm-common/src/microkernel.rs @@ -414,6 +414,312 @@ macro_rules! microkernel { }; } +#[macro_export] +macro_rules! microkernel_f16 { + ($([$target: tt])?, $unroll: tt, $name: ident, $mr_div_n: tt, $nr: tt $(, $nr_div_n: tt, $n: tt)?) => { + #[inline] + $(#[target_feature(enable = $target)])? + // 0, 1, or 2 for generic alpha + pub unsafe fn $name( + m: usize, + n: usize, + k: usize, + dst: *mut T, + mut packed_lhs: *const T, + mut packed_rhs: *const T, + dst_cs: isize, + dst_rs: isize, + lhs_cs: isize, + rhs_rs: isize, + rhs_cs: isize, + alpha: T, + beta: T, + alpha_status: u8, + _conj_dst: bool, + _conj_lhs: bool, + _conj_rhs: bool, + mut next_lhs: *const T, + ) { + let mut accum_storage = [[splat(T::ZERO); $mr_div_n]; $nr]; + let accum = accum_storage.as_mut_ptr() as *mut Pack; + + let mut lhs = [::core::mem::MaybeUninit::::uninit(); $mr_div_n]; + let mut rhs = ::core::mem::MaybeUninit::::uninit(); + + #[derive(Copy, Clone)] + struct KernelIter { + packed_lhs: *const T, + packed_rhs: *const T, + next_lhs: *const T, + lhs_cs: isize, + rhs_rs: isize, + rhs_cs: isize, + accum: *mut Pack, + lhs: *mut Pack, + rhs: *mut Pack, + } + + impl KernelIter { + #[inline(always)] + unsafe fn execute(self, iter: usize) { + let packed_lhs = self.packed_lhs.wrapping_offset(iter as isize * self.lhs_cs); + let packed_rhs = self.packed_rhs.wrapping_offset(iter as isize * self.rhs_rs); + let next_lhs = self.next_lhs.wrapping_offset(iter as isize * self.lhs_cs); + + seq_macro::seq!(M_ITER in 0..$mr_div_n {{ + *self.lhs.add(M_ITER) = (packed_lhs.add(M_ITER * N) as *const Pack).read_unaligned(); + }}); + + seq_macro::seq!(N_ITER in 0..$nr {{ + *self.rhs = splat(*packed_rhs.wrapping_offset(N_ITER * self.rhs_cs)); + let accum = self.accum.add(N_ITER * $mr_div_n); + seq_macro::seq!(M_ITER in 0..$mr_div_n {{ + let accum = &mut *accum.add(M_ITER); + *accum = mul_add( + *self.lhs.add(M_ITER), + *self.rhs, + *accum, + ); + }}); + }}); + + let _ = next_lhs; + } + + $( + #[inline(always)] + unsafe fn execute_neon(self, iter: usize) { + debug_assert_eq!(self.rhs_cs, 1); + let packed_lhs = self.packed_lhs.wrapping_offset(iter as isize * self.lhs_cs); + let packed_rhs = self.packed_rhs.wrapping_offset(iter as isize * self.rhs_rs); + + seq_macro::seq!(M_ITER in 0..$mr_div_n {{ + *self.lhs.add(M_ITER) = (packed_lhs.add(M_ITER * N) as *const Pack).read_unaligned(); + }}); + + seq_macro::seq!(N_ITER0 in 0..$nr_div_n {{ + *self.rhs = (packed_rhs.wrapping_offset(N_ITER0 * $n) as *const Pack).read_unaligned(); + + seq_macro::seq!(N_ITER1 in 0..$n 
{{ + const N_ITER: usize = N_ITER0 * $n + N_ITER1; + let accum = self.accum.add(N_ITER * $mr_div_n); + seq_macro::seq!(M_ITER in 0..$mr_div_n {{ + let accum = &mut *accum.add(M_ITER); + *accum = mul_add_lane::( + *self.lhs.add(M_ITER), + *self.rhs, + *accum, + ); + }}); + }}); + }}); + } + )? + } + + let k_unroll = k / $unroll; + let k_leftover = k % $unroll; + + loop { + $( + let _ = $nr_div_n; + if rhs_cs == 1 { + let mut depth = k_unroll; + if depth != 0 { + loop { + let iter = KernelIter { + packed_lhs, + next_lhs, + packed_rhs, + lhs_cs, + rhs_rs, + rhs_cs, + accum, + lhs: lhs.as_mut_ptr() as _, + rhs: &mut rhs as *mut _ as _, + }; + + seq_macro::seq!(UNROLL_ITER in 0..$unroll {{ + iter.execute_neon(UNROLL_ITER); + }}); + + packed_lhs = packed_lhs.wrapping_offset($unroll * lhs_cs); + packed_rhs = packed_rhs.wrapping_offset($unroll * rhs_rs); + next_lhs = next_lhs.wrapping_offset($unroll * lhs_cs); + + depth -= 1; + if depth == 0 { + break; + } + } + } + depth = k_leftover; + if depth != 0 { + loop { + KernelIter { + packed_lhs, + next_lhs, + packed_rhs, + lhs_cs, + rhs_rs, + rhs_cs, + accum, + lhs: lhs.as_mut_ptr() as _, + rhs: &mut rhs as *mut _ as _, + } + .execute_neon(0); + + packed_lhs = packed_lhs.wrapping_offset(lhs_cs); + packed_rhs = packed_rhs.wrapping_offset(rhs_rs); + next_lhs = next_lhs.wrapping_offset(lhs_cs); + + depth -= 1; + if depth == 0 { + break; + } + } + } + break; + } + )? + + let mut depth = k_unroll; + if depth != 0 { + loop { + let iter = KernelIter { + packed_lhs, + next_lhs, + packed_rhs, + lhs_cs, + rhs_rs, + rhs_cs, + accum, + lhs: lhs.as_mut_ptr() as _, + rhs: &mut rhs as *mut _ as _, + }; + + seq_macro::seq!(UNROLL_ITER in 0..$unroll {{ + iter.execute(UNROLL_ITER); + }}); + + packed_lhs = packed_lhs.wrapping_offset($unroll * lhs_cs); + packed_rhs = packed_rhs.wrapping_offset($unroll * rhs_rs); + next_lhs = next_lhs.wrapping_offset($unroll * lhs_cs); + + depth -= 1; + if depth == 0 { + break; + } + } + } + depth = k_leftover; + if depth != 0 { + loop { + KernelIter { + packed_lhs, + next_lhs, + packed_rhs, + lhs_cs, + rhs_rs, + rhs_cs, + accum, + lhs: lhs.as_mut_ptr() as _, + rhs: &mut rhs as *mut _ as _, + } + .execute(0); + + packed_lhs = packed_lhs.wrapping_offset(lhs_cs); + packed_rhs = packed_rhs.wrapping_offset(rhs_rs); + next_lhs = next_lhs.wrapping_offset(lhs_cs); + + depth -= 1; + if depth == 0 { + break; + } + } + } + break; + } + + if m == $mr_div_n * N && n == $nr && dst_rs == 1 { + let alpha = splat(alpha); + let beta = splat(beta); + if alpha_status == 2 { + seq_macro::seq!(N_ITER in 0..$nr {{ + seq_macro::seq!(M_ITER in 0..$mr_div_n {{ + let dst = dst.offset(M_ITER * N as isize + N_ITER * dst_cs) as *mut Pack; + dst.write_unaligned(add( + mul(alpha, dst.read_unaligned()), + mul(beta, *accum.offset(M_ITER + $mr_div_n * N_ITER)), + )); + }}); + }}); + } else if alpha_status == 1 { + seq_macro::seq!(N_ITER in 0..$nr {{ + seq_macro::seq!(M_ITER in 0..$mr_div_n {{ + let dst = dst.offset(M_ITER * N as isize + N_ITER * dst_cs) as *mut Pack; + dst.write_unaligned(mul_add( + beta, + *accum.offset(M_ITER + $mr_div_n * N_ITER), + dst.read_unaligned(), + )); + }}); + }}); + } else { + seq_macro::seq!(N_ITER in 0..$nr {{ + seq_macro::seq!(M_ITER in 0..$mr_div_n {{ + let dst = dst.offset(M_ITER * N as isize + N_ITER * dst_cs) as *mut Pack; + dst.write_unaligned(mul(beta, *accum.offset(M_ITER + $mr_div_n * N_ITER))); + }}); + }}); + } + } else { + let src = accum_storage; // write to stack + let src = src.as_ptr() as *const T; + + if alpha_status == 2 { 
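+                    // Scalar writeback for partial tiles or non-unit dst_rs: the accumulators
+                    // were spilled to `accum_storage` on the stack above and are stored back
+                    // element by element. This branch handles general alpha:
+                    // dst = alpha * dst + beta * accum.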
+ for j in 0..n { + let dst_j = dst.offset(dst_cs * j as isize); + let src_j = src.add(j * $mr_div_n * N); + + for i in 0..m { + let dst_ij = dst_j.offset(dst_rs * i as isize); + let src_ij = src_j.add(i); + + *dst_ij = alpha * *dst_ij + beta * *src_ij; + } + } + } else if alpha_status == 1 { + for j in 0..n { + let dst_j = dst.offset(dst_cs * j as isize); + let src_j = src.add(j * $mr_div_n * N); + + for i in 0..m { + let dst_ij = dst_j.offset(dst_rs * i as isize); + let src_ij = src_j.add(i); + + *dst_ij = *dst_ij + beta * *src_ij; + } + } + } else { + for j in 0..n { + let dst_j = dst.offset(dst_cs * j as isize); + let src_j = src.add(j * $mr_div_n * N); + + for i in 0..m { + let dst_ij = dst_j.offset(dst_rs * i as isize); + let src_ij = src_j.add(i); + + *dst_ij = beta * *src_ij; + } + } + } + } + + } + }; +} + #[macro_export] macro_rules! microkernel_cplx { ($([$target: tt])?, $unroll: tt, $name: ident, $mr_div_n: tt, $nr: tt) => { diff --git a/gemm-common/src/simd.rs b/gemm-common/src/simd.rs index cad5d23..5e1e000 100644 --- a/gemm-common/src/simd.rs +++ b/gemm-common/src/simd.rs @@ -64,6 +64,24 @@ mod x86 { #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] pub use x86::*; +#[cfg(target_arch = "aarch64")] +mod aarch64 { + use super::*; + + #[derive(Copy, Clone)] + pub struct Neon; + + impl Simd for Neon { + #[inline] + #[target_feature(enable = "neon")] + unsafe fn vectorize(f: impl FnOnce()) { + f() + } + } +} +#[cfg(target_arch = "aarch64")] +pub use aarch64::*; + #[cfg(target_arch = "wasm32")] mod wasm32 { use super::*; diff --git a/gemm-f16/src/gemm.rs b/gemm-f16/src/gemm.rs index e83ad74..50b6f87 100644 --- a/gemm-f16/src/gemm.rs +++ b/gemm-f16/src/gemm.rs @@ -89,6 +89,89 @@ unsafe fn pack_generic_inner_loop( } } +// DIRECT copy of [`pack_generic_inner_loop`] but adapted for pure f16 inner +#[inline(always)] +unsafe fn pack_generic_inner_loop_f16( + mut dst: *mut T, + mut src: *const T, + src_rs: isize, + src_cs: isize, + src_width: usize, + k: usize, +) { + if src_width == DST_WIDTH { + if src_rs == 1 { + for _ in 0..k { + // let val = (src as *const [T; DST_WIDTH]).read(); + // val.convert_to_f32_slice(core::slice::from_raw_parts_mut(dst, DST_WIDTH)); + std::ptr::copy_nonoverlapping(src, dst, DST_WIDTH); + + src = src.wrapping_offset(src_cs); + dst = dst.add(DST_WIDTH); + } + } else { + for _ in 0..k { + for j in 0..DST_WIDTH { + *dst.add(j) = (*src.offset(j as isize * src_rs)).into(); + } + src = src.wrapping_offset(src_cs); + dst = dst.add(DST_WIDTH); + } + } + } else if src_width == N { + if src_rs == 1 { + for _ in 0..k { + // let val = (src as *const [T; N]).read(); + // val.convert_to_f32_slice(core::slice::from_raw_parts_mut(dst, N)); + std::ptr::copy_nonoverlapping(src, dst, N); + src = src.wrapping_offset(src_cs); + dst = dst.add(DST_WIDTH); + } + } else { + for _ in 0..k { + for j in 0..N { + *dst.add(j) = (*src.offset(j as isize * src_rs)).into(); + } + src = src.wrapping_offset(src_cs); + dst = dst.add(DST_WIDTH); + } + } + } else if src_width == 2 * N { + if src_rs == 1 { + for _ in 0..k { + // let val0 = (src as *const [T; N]).read(); + // let val1 = (src.add(N) as *const [T; N]).read(); + // val0.convert_to_f32_slice(core::slice::from_raw_parts_mut(dst, N)); + // val1.convert_to_f32_slice(core::slice::from_raw_parts_mut(dst.add(N), N)); + std::ptr::copy_nonoverlapping(src, dst, 2 * N); + + src = src.wrapping_offset(src_cs); + dst = dst.add(DST_WIDTH); + } + } else { + for _ in 0..k { + for j in 0..2 * N { + *dst.add(j) = (*src.offset(j as isize * 
src_rs)).into(); + } + src = src.wrapping_offset(src_cs); + dst = dst.add(DST_WIDTH); + } + } + } else { + for _ in 0..k { + for j in 0..src_width { + *dst.add(j) = (*src.offset(j as isize * src_rs)).into(); + } + quick_zero(core::slice::from_raw_parts_mut( + dst.add(src_width), + DST_WIDTH - src_width, + )); + src = src.wrapping_offset(src_cs); + dst = dst.add(DST_WIDTH); + } + } +} + #[inline(always)] unsafe fn pack_generic( m: usize, @@ -114,6 +197,32 @@ unsafe fn pack_generic( } } +// DIRECT copy of [`pack_generic`] but adapted for pure f16 +#[inline(always)] +unsafe fn pack_generic_f16( + m: usize, + k: usize, + mut dst: *mut T, + mut src: *const T, + src_cs: isize, + src_rs: isize, + dst_stride: usize, +) { + let m_width = m / DST_WIDTH * DST_WIDTH; + + let mut i = 0; + while i < m_width { + pack_generic_inner_loop_f16::(dst, src, src_rs, src_cs, DST_WIDTH, k); + src = src.wrapping_offset(src_rs * DST_WIDTH as isize); + dst = dst.add(dst_stride); + + i += DST_WIDTH; + } + if i < m { + pack_generic_inner_loop_f16::(dst, src, src_rs, src_cs, m - i, k); + } +} + #[inline(never)] pub unsafe fn pack_lhs( m: usize, @@ -144,6 +253,38 @@ pub unsafe fn pack_rhs( pack_generic::(n, k, dst, src, src_rs, src_cs, dst_stride); } +// DIRECT copy of [`pack_lhs`] but adapted for pure f16 +#[inline(never)] +pub unsafe fn pack_lhs_f16( + m: usize, + k: usize, + dst: Ptr, + src: Ptr, + src_cs: isize, + src_rs: isize, + dst_stride: usize, +) { + let dst = dst.0; + let src = src.0; + pack_generic_f16::(m, k, dst, src, src_cs, src_rs, dst_stride); +} + +// DIRECT copy of [`pack_rhs`] but adapted for pure f16 +#[inline(never)] +pub unsafe fn pack_rhs_f16( + n: usize, + k: usize, + dst: Ptr, + src: Ptr, + src_cs: isize, + src_rs: isize, + dst_stride: usize, +) { + let dst = dst.0; + let src = src.0; + pack_generic_f16::(n, k, dst, src, src_rs, src_cs, dst_stride); +} + #[inline(always)] pub unsafe fn gemm_basic_generic< const N: usize, @@ -508,6 +649,373 @@ pub unsafe fn gemm_basic_generic< } } +// DIRECT copy of [`gemm_basic`] but adapted for pure f16 +#[inline(always)] +pub unsafe fn gemm_basic_f16< + const N: usize, + const MR: usize, + const NR: usize, + const MR_DIV_N: usize, +>( + m: usize, + n: usize, + k: usize, + dst: *mut T, + dst_cs: isize, + dst_rs: isize, + read_dst: bool, + lhs: *const T, + lhs_cs: isize, + lhs_rs: isize, + rhs: *const T, + rhs_cs: isize, + rhs_rs: isize, + mut alpha: T, + beta: T, + dispatcher: &[[MicroKernelFn; NR]; MR_DIV_N], + parallelism: Parallelism, +) { + // println!("-- {m} {n} {k} \n lhs: {:?}\n {:?}", std::slice::from_raw_parts(lhs, m * k), std::slice::from_raw_parts(rhs, n * k)); + if m == 0 || n == 0 { + return; + } + if !read_dst { + alpha = T::ZERO; + } + + if k == 0 { + if alpha == T::ZERO { + for j in 0..n { + for i in 0..m { + *dst.offset(i as isize * dst_rs + j as isize * dst_cs) = T::ZERO; + } + } + return; + } + if alpha == T::ONE { + return; + } + + for j in 0..n { + for i in 0..m { + let dst = dst.offset(i as isize * dst_rs + j as isize * dst_cs); + *dst = alpha * *dst; + } + } + return; + } + + let KernelParams { kc, mc, nc } = kernel_params(m, n, k, MR, NR, core::mem::size_of::()); + let nc = if nc > 0 { + nc + } else { + match parallelism { + Parallelism::None => 128 * NR, + Parallelism::Rayon(_) => div_ceil(n, NR) * NR, + } + }; + + let simd_align = CACHELINE_ALIGN; + + let packed_rhs_stride = kc * NR; + let packed_lhs_stride = kc * MR; + + let dst = Ptr(dst); + let lhs = Ptr(lhs as *mut T); + let rhs = Ptr(rhs as *mut T); + + let mut mem = 
GlobalMemBuffer::new(StackReq::new_aligned::( + packed_rhs_stride * (nc / NR), + simd_align, + )); + + let stack = DynStack::new(&mut mem); + let mut packed_rhs_storage = stack + .make_aligned_uninit::(packed_rhs_stride * (nc / NR), simd_align) + .0; + + let packed_rhs = Ptr(packed_rhs_storage.as_mut_ptr() as *mut T); + + let packed_rhs_rs = NR as isize; + let packed_rhs_cs = 1; + + let mut col_outer = 0; + while col_outer != n { + let n_chunk = nc.min(n - col_outer); + + let mut alpha = alpha; + + let mut depth_outer = 0; + while depth_outer != k { + let k_chunk = kc.min(k - depth_outer); + let alpha_status = if alpha == T::ZERO { + 0 + } else if alpha == T::ONE { + 1 + } else { + 2 + }; + + let n_threads = match parallelism { + Parallelism::None => 1, + Parallelism::Rayon(n_threads) => { + let threading_threshold = get_threading_threshold(); + if m * n_chunk * k_chunk <= threading_threshold { + 1 + } else { + if n_threads == 0 { + rayon::current_num_threads() + } else { + n_threads + } + } + } + }; + + // pack rhs + if n_threads <= 1 { + pack_rhs_f16::<1, NR>( + n_chunk, + k_chunk, + packed_rhs, + rhs.wrapping_offset( + depth_outer as isize * rhs_rs + col_outer as isize * rhs_cs, + ), + rhs_cs, + rhs_rs, + packed_rhs_stride, + ); + } else { + let n_tasks = div_ceil(n_chunk, NR); + let base = n_tasks / n_threads; + let rem = n_tasks % n_threads; + + let tid_to_col_inner = |tid: usize| { + if tid == n_threads { + return n_chunk; + } + + let col = if tid < rem { + NR * tid * (base + 1) + } else { + NR * (rem + tid * base) + }; + + col.min(n_chunk) + }; + + let func = |tid: usize| { + let col_inner = tid_to_col_inner(tid); + let ncols = tid_to_col_inner(tid + 1) - col_inner; + let j = col_inner / NR; + + if ncols > 0 { + pack_rhs_f16::<1, NR>( + ncols, + k_chunk, + packed_rhs.wrapping_add(j * packed_rhs_stride), + rhs.wrapping_offset( + depth_outer as isize * rhs_rs + + (col_outer + col_inner) as isize * rhs_cs, + ), + rhs_cs, + rhs_rs, + packed_rhs_stride, + ); + } + }; + par_for_each(n_threads, func); + } + + let n_col_mini_chunks = (n_chunk + (NR - 1)) / NR; + + let mut n_jobs = 0; + let mut row_outer = 0; + while row_outer != m { + let mut m_chunk = mc.min(m - row_outer); + if m_chunk > N { + m_chunk = m_chunk / N * N; + } + let n_row_mini_chunks = (m_chunk + (MR - 1)) / MR; + n_jobs += n_col_mini_chunks * n_row_mini_chunks; + row_outer += m_chunk; + } + + // use a single thread for small workloads + + let func = move |tid| { + L2_SLAB.with(|mem| { + let mut mem = mem.borrow_mut(); + let stack = DynStack::new(&mut **mem); + + let (mut packed_lhs_storage, _) = + stack.make_aligned_uninit::(packed_lhs_stride * (mc / MR), simd_align); + + let packed_lhs = Ptr(packed_lhs_storage.as_mut_ptr() as *mut T); + + let min_jobs_per_thread = n_jobs / n_threads; + let rem = n_jobs - n_threads * min_jobs_per_thread; + + // thread `tid` takes min_jobs_per_thread or min_jobs_per_thread + 1 + let (job_start, job_end) = if tid < rem { + let start = tid * (min_jobs_per_thread + 1); + (start, start + min_jobs_per_thread + 1) + } else { + // start = rem * (min_jobs_per_thread + 1) + (tid - rem) * min_jobs_per_thread; + let start = tid * min_jobs_per_thread + rem; + (start, start + min_jobs_per_thread) + }; + + let mut row_outer = 0; + let mut job_id = 0; + while row_outer != m { + let mut m_chunk = mc.min(m - row_outer); + if m_chunk > N { + m_chunk = m_chunk / N * N; + } + let n_row_mini_chunks = (m_chunk + (MR - 1)) / MR; + + let n_mini_jobs = n_col_mini_chunks * n_row_mini_chunks; + + if job_id >= job_end 
{ + return; + } + if job_id + n_mini_jobs < job_start { + row_outer += m_chunk; + job_id += n_mini_jobs; + continue; + } + + let packed_lhs_cs = MR as isize; + + pack_lhs_f16::( + m_chunk, + k_chunk, + packed_lhs, + lhs.wrapping_offset( + row_outer as isize * lhs_rs + depth_outer as isize * lhs_cs, + ), + lhs_cs, + lhs_rs, + packed_lhs_stride, + ); + + let mut j = 0; + while j < n_col_mini_chunks { + let mut i = 0; + while i < n_row_mini_chunks { + let col_inner = NR * j; + let n_chunk_inner = NR.min(n_chunk - col_inner); + + let row_inner = MR * i; + let m_chunk_inner = MR.min(m_chunk - row_inner); + + let inner_idx = &mut i; + if job_id < job_start || job_id >= job_end { + job_id += 1; + *inner_idx += 1; + continue; + } + job_id += 1; + + let dst = dst.wrapping_offset( + (row_outer + row_inner) as isize * dst_rs + + (col_outer + col_inner) as isize * dst_cs, + ); + + let func = dispatcher[(m_chunk_inner + (N - 1)) / N - 1] + [n_chunk_inner - 1]; + + let mut tmp = [[T::ZERO; MR]; NR]; + + func( + m_chunk_inner, + n_chunk_inner, + k_chunk, + tmp.as_mut_ptr() as *mut T, + packed_lhs.wrapping_add(i * packed_lhs_stride).0, + packed_rhs.wrapping_add(j * packed_rhs_stride).0, + MR as isize, + 1, + packed_lhs_cs, + packed_rhs_rs, + packed_rhs_cs, + T::ZERO, + beta, + 0, + false, + false, + false, + packed_lhs.wrapping_add((i + 1) * packed_lhs_stride).0, + ); + + match alpha_status { + 0 => { + for j in 0..n_chunk_inner { + for i in 0..m_chunk_inner { + let dst = dst + .wrapping_offset(j as isize * dst_cs) + .wrapping_offset(i as isize * dst_rs) + .0; + *dst = tmp[j][i]; + } + } + } + 1 => { + for j in 0..n_chunk_inner { + for i in 0..m_chunk_inner { + let dst = dst + .wrapping_offset(j as isize * dst_cs) + .wrapping_offset(i as isize * dst_rs) + .0; + *dst = (*dst) + tmp[j][i]; + } + } + } + _ => { + for j in 0..n_chunk_inner { + for i in 0..m_chunk_inner { + let dst = dst + .wrapping_offset(j as isize * dst_cs) + .wrapping_offset(i as isize * dst_rs) + .0; + *dst = + alpha * (*dst) + tmp[j][i] + ; + } + } + } + } + + i += 1; + } + j += 1; + } + + row_outer += m_chunk; + } + }); + }; + + match parallelism { + Parallelism::None => func(0), + Parallelism::Rayon(_) => { + if n_threads == 1 { + func(0); + } else { + par_for_each(n_threads, func); + } + } + } + + alpha = T::ONE; + depth_outer += k_chunk; + } + col_outer += n_chunk; + } +} + + pub mod f16 { use super::gemm_basic_generic; use gemm_common::Parallelism; @@ -555,6 +1063,7 @@ pub mod f16 { } } + #[cfg(target_arch = "aarch64")] { if gemm_common::feature_detected!("neon") { @@ -626,8 +1135,8 @@ pub mod f16 { #[cfg(target_arch = "aarch64")] mod neon { use super::*; - use gemm_f32::microkernel::neon::f32::*; - const N: usize = 4; + use crate::microkernel::neon::f16::{MR_DIV_N, NR, UKR}; + const N: usize = 8; #[inline(never)] pub unsafe fn gemm_basic( @@ -651,7 +1160,7 @@ pub mod f16 { _conj_rhs: bool, parallelism: gemm_common::Parallelism, ) { - gemm_basic_generic::( + crate::gemm::gemm_basic_f16::( m, n, k, diff --git a/gemm-f16/src/lib.rs b/gemm-f16/src/lib.rs index 8d6b2f4..fe48aa8 100644 --- a/gemm-f16/src/lib.rs +++ b/gemm-f16/src/lib.rs @@ -1,4 +1,8 @@ #![cfg_attr(feature = "nightly", feature(stdsimd), feature(avx512_target_feature))] pub mod gemm; +pub mod microkernel; pub use half::f16; + +#[macro_use] +extern crate gemm_common; diff --git a/gemm-f16/src/microkernel.rs b/gemm-f16/src/microkernel.rs index e1356f1..f45bafd 100644 --- a/gemm-f16/src/microkernel.rs +++ b/gemm-f16/src/microkernel.rs @@ -343,4 +343,137 @@ pub mod neon { 
            [x3x1, x3x2, x3x3, x3x4, x3x5, x3x6, x3x7, x3x8,],
        }
    }

    pub mod f16 {
        use core::mem::transmute;
        use core::{
            arch::{
                aarch64::uint16x8_t,
                asm,
            },
        };

        pub type T = half::f16;
        pub const N: usize = 8;
        pub type Pack = [T; N];

        // Stable `core::arch::aarch64` does not expose the fp16 NEON intrinsics,
        // so the vector ops below are hand-rolled with inline `asm!` operating on
        // the `uint16x8_t` bit pattern.
        #[allow(non_camel_case_types)]
        type float16x8_t = uint16x8_t;

        /// Floating point multiplication
        /// [doc](https://developer.arm.com/documentation/dui0801/g/A64-SIMD-Vector-Instructions/FMUL--vector-)
        #[inline]
        pub unsafe fn vmulq_f16(a: float16x8_t, b: float16x8_t) -> float16x8_t {
            let result: float16x8_t;
            asm!(
                "fmul {0:v}.8h, {1:v}.8h, {2:v}.8h",
                out(vreg) result,
                in(vreg) a,
                in(vreg) b,
                options(pure, nomem, nostack, preserves_flags));
            result
        }

        /// Floating point addition
        /// [doc](https://developer.arm.com/documentation/dui0801/g/A64-SIMD-Vector-Instructions/FADD--vector-)
        #[inline]
        pub unsafe fn vaddq_f16(a: float16x8_t, b: float16x8_t) -> float16x8_t {
            let result: float16x8_t;
            asm!(
                "fadd {0:v}.8h, {1:v}.8h, {2:v}.8h",
                out(vreg) result,
                in(vreg) a,
                in(vreg) b,
                options(pure, nomem, nostack, preserves_flags));
            result
        }

        /// Fused multiply add [doc](https://developer.arm.com/documentation/dui0801/g/A64-SIMD-Vector-Instructions/FMLA--vector-)
        #[inline]
        pub unsafe fn vfmaq_f16(mut a: float16x8_t, b: float16x8_t, c: float16x8_t) -> float16x8_t {
            asm!(
                "fmla {0:v}.8h, {1:v}.8h, {2:v}.8h",
                inout(vreg) a,
                in(vreg) b,
                in(vreg) c,
                options(nomem, nostack, preserves_flags));
            a
        }

        /// Extracts lane `LANE` of the vector as its raw `u16` bit pattern.
        pub unsafe fn vget_lane_f16<const LANE: usize>(a: float16x8_t) -> u16 {
            let mut result: u16 = 0;
            let a: *const u16 = transmute(&a as *const float16x8_t);
            std::ptr::copy_nonoverlapping(a.add(LANE), &mut result as *mut u16, 1);
            result
        }

        #[inline]
        pub unsafe fn vfmaq_laneq_f16<const LANE: usize>(
            a: float16x8_t,
            b: float16x8_t,
            c: float16x8_t,
        ) -> float16x8_t {
            // Broadcast lane `LANE` of `c`, then fused multiply-accumulate into `a`.
            let c = vget_lane_f16::<LANE>(c);
            let result = core::mem::transmute([c, c, c, c, c, c, c, c]);
            vfmaq_f16(a, b, result)
        }

        #[inline(always)]
        pub unsafe fn splat(value: T) -> Pack {
            [value, value, value, value, value, value, value, value]
        }

        #[inline(always)]
        pub unsafe fn mul(lhs: Pack, rhs: Pack) -> Pack {
            transmute(vmulq_f16(transmute(lhs), transmute(rhs)))
        }

        #[inline(always)]
        pub unsafe fn add(lhs: Pack, rhs: Pack) -> Pack {
            transmute(vaddq_f16(transmute(lhs), transmute(rhs)))
        }

        #[inline(always)]
        pub unsafe fn mul_add(a: Pack, b: Pack, c: Pack) -> Pack {
            transmute(vfmaq_f16(transmute(c), transmute(a), transmute(b)))
        }

        #[inline(always)]
        pub unsafe fn mul_add_lane<const LANE: usize>(a: Pack, b: Pack, c: Pack) -> Pack {
            transmute(vfmaq_laneq_f16::<LANE>(
                transmute(c),
                transmute(a),
                transmute(b),
            ))
        }

        // Arguments: unroll factor, kernel name, mr / N, nr. The optional trailing
        // (nr_div_n, n) pair enables the contiguous-rhs path (`execute_neon`), which
        // loads the packed rhs as whole vectors and multiplies by individual lanes
        // via `mul_add_lane` when `rhs_cs == 1`.
        microkernel_f16!(["neon"], 2, x1x1, 1, 1);
        microkernel_f16!(["neon"], 2, x1x2, 1, 2);
        microkernel_f16!(["neon"], 2, x1x3, 1, 3);
        microkernel_f16!(["neon"], 2, x1x4, 1, 4, 1, 4);
        microkernel_f16!(["neon"], 2, x1x5, 1, 5);
        microkernel_f16!(["neon"], 2, x1x6, 1, 6);
        microkernel_f16!(["neon"], 2, x1x7, 1, 7);
        microkernel_f16!(["neon"], 2, x1x8, 1, 8, 2, 4);

        microkernel_f16!(["neon"], 2, x2x1, 2, 1);
        microkernel_f16!(["neon"], 2, x2x2, 2, 2);
        microkernel_f16!(["neon"], 2, x2x3, 2, 3);
        microkernel_f16!(["neon"], 2, x2x4, 2, 4, 1, 4);
        microkernel_f16!(["neon"], 2, x2x5, 2, 5);
        microkernel_f16!(["neon"], 2, x2x6, 2, 6);
        microkernel_f16!(["neon"], 2, x2x7, 2, 7);
        microkernel_f16!(["neon"], 2, x2x8, 2, 8, 2, 4);

        microkernel_f16!(["neon"], 2, x3x1, 3, 1);
        microkernel_f16!(["neon"], 2, x3x2, 3, 2);
+ microkernel_f16!(["neon"], 2, x3x3, 3, 3); + microkernel_f16!(["neon"], 2, x3x4, 3, 4, 1, 4); + microkernel_f16!(["neon"], 2, x3x5, 3, 5); + microkernel_f16!(["neon"], 2, x3x6, 3, 6); + microkernel_f16!(["neon"], 2, x3x7, 3, 7); + microkernel_f16!(["neon"], 2, x3x8, 3, 8, 2, 4); + + microkernel_fn_array! { + [x1x1, x1x2, x1x3, x1x4, x1x5, x1x6, x1x7, x1x8,], + [x2x1, x2x2, x2x3, x2x4, x2x5, x2x6, x2x7, x2x8,], + [x3x1, x3x2, x3x3, x3x4, x3x5, x3x6, x3x7, x3x8,], + } + } } diff --git a/gemm/benches/bench.rs b/gemm/benches/bench.rs index 415cf26..935a1d2 100644 --- a/gemm/benches/bench.rs +++ b/gemm/benches/bench.rs @@ -257,7 +257,8 @@ pub fn criterion_benchmark(c: &mut Criterion) { } pub fn criterion_benchmark_parallelism(c: &mut Criterion) { - let mnks = vec![(6, 768 * 3, 768)]; + // let mnks = vec![(6, 768 * 3, 768)]; + let mnks = vec![(4096, 128, 11108)]; // let mut push = |m, n, k| { // mnks.push((m, n, k)); // }; @@ -294,7 +295,7 @@ pub fn criterion_benchmark_parallelism(c: &mut Criterion) { for (rhs_label, rhs_cs, rhs_rs) in [("n", k, 1), ("t", 1, n)] { c.bench_function( &format!( - "parallelism-{}-f32-{}{}{}-gemm-{}×{}×{}", + "parallelism-f32-{}-{}{}{}-gemm-{}×{}×{}", n_cpus, dst_label, lhs_label, rhs_label, m, n, k ), |b| { @@ -325,7 +326,7 @@ pub fn criterion_benchmark_parallelism(c: &mut Criterion) { ); c.bench_function( &format!( - "parallelism-none-f32-{}{}{}-gemm-{}×{}×{}", + "parallelism-f32-none-{}{}{}-gemm-{}×{}×{}", dst_label, lhs_label, rhs_label, m, n, k ), |b| { @@ -358,6 +359,83 @@ pub fn criterion_benchmark_parallelism(c: &mut Criterion) { } } } + + let n_cpus = num_cpus::get(); + + for (m, n, k) in mnks.iter().copied() { + let a_vec = vec![f16::from_f32(0.0); m * k]; + let b_vec = vec![f16::from_f32(0.0); k * n]; + let mut c_vec = vec![f16::from_f32(0.0); m * n]; + + for (dst_label, dst_cs, dst_rs) in [("n", m, 1), ("t", 1, n)] { + for (lhs_label, lhs_cs, lhs_rs) in [("n", m, 1), ("t", 1, k)] { + for (rhs_label, rhs_cs, rhs_rs) in [("n", k, 1), ("t", 1, n)] { + c.bench_function( + &format!( + "parallelism-f16-{}-{}{}{}-gemm-{}×{}×{}", + n_cpus, dst_label, lhs_label, rhs_label, m, n, k + ), + |b| { + b.iter(|| unsafe { + gemm( + m, + n, + k, + c_vec.as_mut_ptr(), + dst_cs as isize, + dst_rs as isize, + true, + a_vec.as_ptr(), + lhs_cs as isize, + lhs_rs as isize, + b_vec.as_ptr(), + rhs_cs as isize, + rhs_rs as isize, + f16::from_f32(0.0), + f16::from_f32(0.0), + false, + false, + false, + gemm::Parallelism::Rayon(n_cpus), + ) + }) + }, + ); + c.bench_function( + &format!( + "parallelism-f16-none-{}{}{}-gemm-{}×{}×{}", + dst_label, lhs_label, rhs_label, m, n, k + ), + |b| { + b.iter(|| unsafe { + gemm( + m, + n, + k, + c_vec.as_mut_ptr(), + dst_cs as isize, + dst_rs as isize, + true, + a_vec.as_ptr(), + lhs_cs as isize, + lhs_rs as isize, + b_vec.as_ptr(), + rhs_cs as isize, + rhs_rs as isize, + f16::from_f32(0.0), + f16::from_f32(0.0), + false, + false, + false, + gemm::Parallelism::None, + ) + }) + }, + ); + } + } + } + } } criterion_group!(