Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add BLAS C bindings #37

Draft
wants to merge 7 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions faer-libs/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ members = [
"faer-svd",
"faer-evd",
"faer-sparse",
"faer-blas",

"faer",
]
Expand Down
13 changes: 13 additions & 0 deletions faer-libs/faer-blas/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
[package]
name = "faer-blas"
version = "0.16.0"
edition = "2021"

[lib]
crate-type = ["cdylib"]

[dependencies]
faer-core = { version = "0.16.0", path = "../faer-core" }
gemm = "0.17"
num-traits = "0.2.15"
paste = "1.0.12"
2 changes: 2 additions & 0 deletions faer-libs/faer-blas/examples/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
*.o
gemm
9 changes: 9 additions & 0 deletions faer-libs/faer-blas/examples/Makefile
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
# Builds the cblas_sgemm example against the faer-blas shared library found in
# ../../target/release (presumably produced by a release build of the crate).

# `clean` names an action, not a file, so it must run even if a file called
# `clean` exists.
.PHONY: clean

# The library is listed AFTER the object that references it: linkers resolve
# symbols left to right, and with --as-needed a library named before its users
# is discarded, leaving cblas_sgemm undefined.
gemm: gemm.o
	$(CC) -o $@ gemm.o -L../../target/release -lfaer_blas

gemm.o: gemm.c
	$(CC) -c -O2 -I../include -o $@ gemm.c

clean:
	rm -f *.o
	rm -f gemm
33 changes: 33 additions & 0 deletions faer-libs/faer-blas/examples/gemm.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
// From gsl/doc/examples/cblas.c
#include <stdio.h>
#include "cblas.h"

/*
 * Multiplies a 2x3 matrix A by a 3x2 matrix B with cblas_sgemm and prints the
 * resulting 2x2 matrix C. Every operand is stored in row-major order.
 */
int main(void)
{
    /* Problem shape: C (rows x cols) = A (rows x inner) * B (inner x cols). */
    enum { rows = 2, cols = 2, inner = 3 };

    /* Left operand, 2x3. */
    float A[] = {0.11, 0.12, 0.13,
                 0.21, 0.22, 0.23};

    /* Right operand, 3x2. */
    float B[] = {1011, 1012,
                 1021, 1022,
                 1031, 1032};

    /* Output, 2x2, zero-initialised (beta is 0 so its contents are ignored). */
    float C[] = {0.00, 0.00,
                 0.00, 0.00};

    /* Leading dimensions: for row-major storage, the length of one row. */
    int lda = 3;
    int ldb = 2;
    int ldc = 2;

    /* Compute C = 1.0 * A * B + 0.0 * C */
    cblas_sgemm(CblasRowMajor,
                CblasNoTrans, CblasNoTrans,
                rows, cols, inner,
                1.0, A, lda,
                B, ldb,
                0.0, C, ldc);

    printf("[ %g, %g\n", C[0], C[1]);
    printf(" %g, %g ]\n", C[2], C[3]);

    return 0;
}
614 changes: 614 additions & 0 deletions faer-libs/faer-blas/include/cblas.h

Large diffs are not rendered by default.

173 changes: 173 additions & 0 deletions faer-libs/faer-blas/src/conversions.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,173 @@
use faer_core::{Entity, MatMut, MatRef, mat::{from_raw_parts, from_raw_parts_mut}, GroupFor};

use crate::{CblasInt, CblasLayout};

#[inline(always)]
pub unsafe fn from_blas<'a, E: Entity>(
layout: CblasLayout,
ptr: GroupFor<E, *const E::Unit>,
nrows: CblasInt,
ncols: CblasInt,
leading_dim: CblasInt,
) -> MatRef<'a, E> {
let stride = Stride::from_leading_dim(layout, leading_dim);

from_raw_parts(
ptr,
nrows as usize,
ncols as usize,
stride.row as isize,
stride.col as isize,
)
}

#[inline(always)]
pub unsafe fn from_blas_mut<'a, E: Entity>(
layout: CblasLayout,
ptr: GroupFor<E, *mut E::Unit>,
nrows: CblasInt,
ncols: CblasInt,
leading_dim: CblasInt,
) -> MatMut<'a, E> {
let stride = Stride::from_leading_dim(layout, leading_dim);
from_raw_parts_mut(
ptr,
nrows as usize,
ncols as usize,
stride.row as isize,
stride.col as isize,
)
}

/// Wraps a BLAS-style strided vector in an immutable `n x 1` faer view.
///
/// `inc` becomes the row stride of the view; the column stride is `0` because
/// there is only a single column.
///
/// # Safety
///
/// No argument is validated here. The caller must satisfy the requirements of
/// [`from_raw_parts`] for `ptr`, `n` and `inc`, for the whole of the chosen
/// lifetime `'a`.
///
/// NOTE(review): `inc` is used directly as the stride from `ptr`. BLAS
/// callers that pass a negative increment may expect `ptr` to be offset
/// first — confirm how call sites handle that.
#[inline(always)]
pub unsafe fn from_blas_vec<'a, E: Entity>(
    ptr: GroupFor<E, *const E::Unit>,
    n: CblasInt,
    inc: CblasInt,
) -> MatRef<'a, E> {
    let len = n as usize;
    let row_stride = inc as isize;

    from_raw_parts(ptr, len, 1, row_stride, 0)
}

/// Wraps a BLAS-style strided vector in a mutable `n x 1` faer view.
///
/// `inc` becomes the row stride of the view; the column stride is `0` because
/// there is only a single column.
///
/// # Safety
///
/// No argument is validated here. The caller must satisfy the requirements of
/// [`from_raw_parts_mut`] for `ptr`, `n` and `inc`, for the whole of the
/// chosen lifetime `'a`.
///
/// NOTE(review): `inc` is used directly as the stride from `ptr`. BLAS
/// callers that pass a negative increment may expect `ptr` to be offset
/// first — confirm how call sites handle that.
#[inline(always)]
pub unsafe fn from_blas_vec_mut<'a, E: Entity>(
    ptr: GroupFor<E, *mut E::Unit>,
    n: CblasInt,
    inc: CblasInt,
) -> MatMut<'a, E> {
    let len = n as usize;
    let row_stride = inc as isize;

    from_raw_parts_mut(ptr, len, 1, row_stride, 0)
}

/// Row and column strides, in elements, describing how a matrix is laid out
/// in memory.
#[derive(Debug, Clone, Copy)]
pub struct Stride {
    /// Distance between vertically adjacent elements (same column, next row).
    pub row: isize,
    /// Distance between horizontally adjacent elements (same row, next column).
    pub col: isize,
}

impl Stride {
    /// Converts a BLAS leading dimension into explicit strides.
    ///
    /// The leading dimension becomes the stride of the slow axis: the row
    /// stride for row-major storage, the column stride for column-major
    /// storage. The other stride is `1`.
    #[inline(always)]
    pub fn from_leading_dim(layout: CblasLayout, leading_dim: CblasInt) -> Self {
        let slow = leading_dim as isize;
        let (row, col) = match layout {
            CblasLayout::RowMajor => (slow, 1),
            CblasLayout::ColMajor => (1, slow),
        };
        Self { row, col }
    }

    /// Returns the strides with the row and column roles exchanged.
    #[inline(always)]
    pub fn transposed(self) -> Self {
        let Self { row, col } = self;
        Self { row: col, col: row }
    }
}

#[cfg(test)]
mod tests {
    use crate::impls::CblasLayout;

    use super::{from_blas, from_blas_vec};

    // Checks that a view is the 2x3 sample matrix
    //
    //   | 0.11 0.12 0.13 |
    //   | 0.21 0.22 0.23 |
    //
    // by testing its shape and three of its entries.
    macro_rules! assert_sample_2x3 {
        ($view:expr) => {{
            let view = $view;
            assert_eq!(view.nrows(), 2);
            assert_eq!(view.ncols(), 3);
            assert_eq!(*view.get(0, 0), 0.11);
            assert_eq!(*view.get(0, 2), 0.13);
            assert_eq!(*view.get(1, 2), 0.23);
        }};
    }

    /// Sample matrix stored row by row, tightly packed (leading dimension 3).
    #[test]
    fn test_row_major() {
        let storage: [f64; 6] = [0.11, 0.12, 0.13, 0.21, 0.22, 0.23];
        let (rows, cols, ld) = (2, 3, 3);
        let view =
            unsafe { from_blas::<f64>(CblasLayout::RowMajor, storage.as_ptr(), rows, cols, ld) };
        assert_sample_2x3!(view);
    }

    /// Sample matrix stored column by column, tightly packed (leading dimension 2).
    #[test]
    fn test_col_major() {
        let storage: [f64; 6] = [0.11, 0.21, 0.12, 0.22, 0.13, 0.23];
        let (rows, cols, ld) = (2, 3, 2);
        let view =
            unsafe { from_blas::<f64>(CblasLayout::ColMajor, storage.as_ptr(), rows, cols, ld) };
        assert_sample_2x3!(view);
    }

    /// Sample matrix stored row by row with two padding elements after each
    /// row (leading dimension 5); the padding zeros are not part of the matrix.
    #[test]
    fn test_mat_excess_storage() {
        let storage: [f64; 10] = [0.11, 0.12, 0.13, 0.0, 0.0, 0.21, 0.22, 0.23, 0.0, 0.0];
        let (rows, cols, ld) = (2, 3, 5);
        let view =
            unsafe { from_blas::<f64>(CblasLayout::RowMajor, storage.as_ptr(), rows, cols, ld) };
        assert_sample_2x3!(view);
    }

    /// The vector [ 0.1 0.2 0.3 ] read with increment 1 must equal the same
    /// vector read with increment 2 from storage that has a padding element
    /// between consecutive entries.
    #[test]
    fn test_vec() {
        let len = 3;

        // Contiguous storage.
        let packed: [f64; 3] = [0.1, 0.2, 0.3];
        let packed_view = unsafe { from_blas_vec::<f64>(packed.as_ptr(), len, 1) };
        assert_eq!(*packed_view.get(2, 0), 0.3);

        // Every other element is excess storage.
        let spaced: [f64; 5] = [0.1, 0.0, 0.2, 0.0, 0.3];
        let spaced_view = unsafe { from_blas_vec::<f64>(spaced.as_ptr(), len, 2) };
        assert_eq!(packed_view, spaced_view);
    }
}
Loading