Skip to content

Commit

Permalink
with packing, up to 2.5GF
Browse files Browse the repository at this point in the history
  • Loading branch information
kali committed Sep 6, 2022
1 parent 6f34d8a commit 625fae1
Show file tree
Hide file tree
Showing 4 changed files with 122 additions and 37 deletions.
1 change: 1 addition & 0 deletions linalg/matmul-bench/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ matrixmultiply = "*"
opencl3 = { version = "0.8.1", optional = true }
lazy_static = "1.4.0"
paste = "1.0.5"
itertools = "0.10.3"

[features]
default = [ ]
Expand Down
1 change: 1 addition & 0 deletions linalg/matmul-bench/benches/matmul.rs
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,7 @@ fn matmul(crit: &mut Criterion, m: usize, k: usize, n: usize) {
{
b!(opencl_gemm1);
b!(opencl_gemm_1_with_local_2x2, Some((2, 2)));
b!(opencl_gemm_2_pack, Some((4,4)));
}
tract_blaslike(&mut crit, m, k, n, f32::datum_type());
tract_blaslike(&mut crit, m, k, n, f16::datum_type());
Expand Down
64 changes: 34 additions & 30 deletions linalg/matmul-bench/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -456,41 +456,41 @@ pub fn tract(m: usize, k: usize, n: usize, a: &[f32], b: &[f32], c: &mut [f32])
}
}

#[cfg(test)]
mod test {
use super::*;

pub fn pack_a(a: &[f32], m: usize, k: usize, r: usize) -> Vec<f32> {
let panels = m.divceil(r);
let mut pa = vec![0f32; m * k];
for p in 0..panels {
for ik in 0..k {
for ir in 0..r {
let row = p * r + ir;
let col = ik;
let v = a[row * k + col];
pa[p * k * r + ik * r + ir] = v;
}
pub fn pack_a(a: &[f32], m: usize, k: usize, r: usize) -> Vec<f32> {
let panels = m.divceil(r);
let mut pa = vec![0f32; m * k];
for p in 0..panels {
for ik in 0..k {
for ir in 0..r {
let row = p * r + ir;
let col = ik;
let v = a[row * k + col];
pa[p * k * r + ik * r + ir] = v;
}
}
pa
}
pa
}

pub fn pack_b(b: &[f32], k: usize, n: usize, r: usize) -> Vec<f32> {
let panels = n.divceil(r);
let mut pb = vec![0f32; k * n];
for p in 0..panels {
for ik in 0..k {
for ir in 0..r {
let row = ik;
let col = p * r + ir;
let v = b[row * n + col];
pb[p * k * r + ik * r + ir] = v;
}
pub fn pack_b(b: &[f32], k: usize, n: usize, r: usize) -> Vec<f32> {
let panels = n.divceil(r);
let mut pb = vec![0f32; k * n];
for p in 0..panels {
for ik in 0..k {
for ir in 0..r {
let row = ik;
let col = p * r + ir;
let v = b[row * n + col];
pb[p * k * r + ik * r + ir] = v;
}
}
pb
}
pb
}

#[cfg(test)]
mod test {
use super::*;

#[macro_export]
macro_rules! t {
Expand All @@ -517,10 +517,14 @@ mod test {
}
}
if let Some(r) = $pack {
a = $crate::test::pack_a(&*a, m, k, r);
b = $crate::test::pack_b(&*b, k, n, r);
a = $crate::pack_a(&*a, m, k, r);
b = $crate::pack_b(&*b, k, n, r);
}
$id(m, k, n, &a, &b, &mut found);
for im in 0..m {
eprint!("{} | ", found[im * n..][..n].iter().map(|x| format!("{:6}", x)).collect::<String>());
eprintln!("{}", expected[im * n..][..n].iter().map(|x| format!("{:6}", x)).collect::<String>());
}
assert_eq!(found, expected);
}
}
Expand Down
93 changes: 86 additions & 7 deletions linalg/matmul-bench/src/opencl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,61 @@ impl Gpu {
C[m * N + n + M + 3] = acc13;
C[m * N + n + 2 * M + 3] = acc23;
C[m * N + n + 3 * M + 3] = acc33;
}"#;
}
// packed
__kernel void gemm_2(const int M, const int K, const int N,
const __global float* A,
const __global float* B,
__global float* C) {
const int m = get_global_id(0);
const int n = get_global_id(1);
#pragma promote_to_registers
float4 acc[4];
for (int i=0; i<4; i++) {
acc[i].x = 0;
acc[i].y = 0;
acc[i].z = 0;
acc[i].w = 0;
}
const __global float *pa = &A[m*K*4];
const __global float *pb = &B[n*K*4];
for (int k=0; k<K; k++) {
#pragma promote_to_registers
float4 a = vload4(k, pa);
#pragma promote_to_registers
float4 b = vload4(k, pb);
// #define mac(a, b, c) c += a * b;
#define mac(a, b, c) c = mad(a, b, c);
#pragma unroll
for (int i = 0; i<4; i++) {
float va;
switch(i) {
case 0: va = a.x; break;
case 1: va = a.y; break;
case 2: va = a.z; break;
case 3: va = a.w; break;
}
mac(va, b.x, acc[i].x)
mac(va, b.y, acc[i].y)
mac(va, b.z, acc[i].z)
mac(va, b.w, acc[i].w)
}
}
for (int i = 0; i<4; i++) {
int offset = n + i * N / 4 + m * N;
vstore4(acc[i], offset, C);
}
}
"#;

let program = Program::create_and_build_from_source(&context, kernel_cl, "").unwrap();
let kernel = Kernel::create(&program, k).expect("Kernel::create failed");
Expand All @@ -145,6 +199,12 @@ impl Gpu {
}
}

#[derive(Default)]
struct Params {
packed: bool,
local_sizes: Option<(usize, usize)>,
}

impl Gpu {
fn run(
&self,
Expand All @@ -154,18 +214,23 @@ impl Gpu {
a: &[f32],
b: &[f32],
c: &mut [f32],
local_sizes: Option<(usize, usize)>,
params: Params,
) -> Result<(), ClError> {
let mut a_cl =
Buffer::<cl_float>::create(&self.context, CL_MEM_READ_ONLY, m * k, null_mut())?;
let mut b_cl =
Buffer::<cl_float>::create(&self.context, CL_MEM_READ_ONLY, k * n, null_mut())?;

let packed_a = crate::pack_a(a, m, k, self.mr);
let packed_b = crate::pack_b(b, k, n, self.nr);

let (pa, pb) = if params.packed { (&*packed_a, &*packed_b) } else { (a, b) };

let mut c_cl =
Buffer::<cl_float>::create(&self.context, CL_MEM_READ_WRITE, m * n, null_mut())?;

let write_a = self.queue.enqueue_write_buffer(&mut a_cl, CL_NON_BLOCKING, 0, a, &[])?;
let write_b = self.queue.enqueue_write_buffer(&mut b_cl, CL_NON_BLOCKING, 0, b, &[])?;
let write_a = self.queue.enqueue_write_buffer(&mut a_cl, CL_NON_BLOCKING, 0, pa, &[])?;
let write_b = self.queue.enqueue_write_buffer(&mut b_cl, CL_NON_BLOCKING, 0, pb, &[])?;

let mut run = ExecuteKernel::new(&self.kernel);
run.set_arg(&(m as i32))
Expand All @@ -176,7 +241,7 @@ impl Gpu {
.set_arg(&c_cl)
.set_global_work_sizes(&[m / self.mr, n / self.nr])
.set_event_wait_list(&[write_a.get(), write_b.get()]);
if let Some((mr, nr)) = local_sizes {
if let Some((mr, nr)) = params.local_sizes {
run.set_local_work_sizes(&[mr, nr]);
}
let run = run.enqueue_nd_range(&self.queue).unwrap();
Expand Down Expand Up @@ -204,10 +269,11 @@ mod kernels {
}

kernel!(gemm_1, 4, 4);
kernel!(gemm_2, 4, 4);
}

pub fn opencl_gemm1(m: usize, k: usize, n: usize, a: &[f32], b: &[f32], c: &mut [f32]) {
kernels::gemm_1.lock().unwrap().run(m, k, n, a, b, c, None).unwrap();
kernels::gemm_1.lock().unwrap().run(m, k, n, a, b, c, Params::default()).unwrap();
}

pub fn opencl_gemm_1_with_local_2x2(
Expand All @@ -218,7 +284,19 @@ pub fn opencl_gemm_1_with_local_2x2(
b: &[f32],
c: &mut [f32],
) {
kernels::gemm_1.lock().unwrap().run(m, k, n, a, b, c, Some((2, 2))).unwrap();
kernels::gemm_1
.lock()
.unwrap()
.run(m, k, n, a, b, c, Params { local_sizes: Some((2, 2)), ..Params::default() })
.unwrap();
}

pub fn opencl_gemm_2_pack(m: usize, k: usize, n: usize, a: &[f32], b: &[f32], c: &mut [f32]) {
kernels::gemm_2
.lock()
.unwrap()
.run(m, k, n, a, b, c, Params { packed: true, local_sizes: Some((2,2)), ..Params::default() })
.unwrap();
}

#[cfg(test)]
Expand All @@ -228,4 +306,5 @@ mod test {

t!(opencl_gemm1);
t!(opencl_gemm_1_with_local_2x2);
t!(opencl_gemm_2_pack);
}

0 comments on commit 625fae1

Please sign in to comment.