
X^T*B doesn't call BLAS gemm #1278

Closed
fardream opened this issue Apr 6, 2023 · 2 comments · Fixed by #1419
fardream commented Apr 6, 2023

For matrices $A$ and $B$, calculating $A^T B$ should be able to use the BLAS gemm routine; however, it doesn't:

use ndarray::Array2;

let a = Array2::<f64>::zeros((10000, 1000));
let b = Array2::<f64>::zeros((10000, 1000));
let _ = a.t().dot(&b); // a.t() is a column-major view: this product does not dispatch to gemm

However, if $A^T$ is provided directly, it will call gemm:

let b = Array2::<f64>::zeros((10000, 1000));
let at = Array2::<f64>::zeros((1000, 10000));
let _ = at.dot(&b); // at is row major: this product does dispatch to gemm

See the code example here:

https://github.com/fardream/rust-ndarray-t-dot

Related to #445


bluss commented Apr 23, 2023

This code here needs to change:

// Use `c` for a c-order and `f` for an f-order matrix.
// We can handle c * c, f * f generally, and
// c * f and f * c if the `f` matrix is square.
let mut lhs_ = lhs.view();
let mut rhs_ = rhs.view();
let mut c_ = c.view_mut();
let lhs_s0 = lhs_.strides()[0];
let rhs_s0 = rhs_.strides()[0];
let both_f = lhs_s0 == 1 && rhs_s0 == 1;
let mut lhs_trans = CblasNoTrans;
let mut rhs_trans = CblasNoTrans;
if both_f {
    // A^t B^t = C^t => B A = C
    let lhs_t = lhs_.reversed_axes();
    lhs_ = rhs_.reversed_axes();
    rhs_ = lhs_t;
    c_ = c_.reversed_axes();
    swap(&mut m, &mut n);
} else if lhs_s0 == 1 && m == a {
    lhs_ = lhs_.reversed_axes();
    lhs_trans = CblasTrans;
} else if rhs_s0 == 1 && a == n {
    rhs_ = rhs_.reversed_axes();
    rhs_trans = CblasTrans;
}

It needs to be rewritten to be more general. It has a comment there that I guess explains why it doesn't cover this case right now.

ndarray arrays can have more general strides than BLAS can handle, so there will always be arrays that can't be passed to BLAS. ndarray therefore needs to examine the arguments and figure out if and how the arrays can be used with BLAS.
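To illustrate the kind of examination meant here, a minimal sketch (not ndarray's actual implementation; `BlasLayout` and `classify` are hypothetical names) of classifying a 2-D shape/stride pair as BLAS-compatible or not. BLAS needs stride 1 on one axis and a leading dimension at least as large as the other axis' extent:

```rust
// Hypothetical sketch: decide whether a 2-D view with the given shape and
// strides can be passed to BLAS, and in which layout.
#[derive(Debug, PartialEq)]
enum BlasLayout {
    RowMajor, // stride 1 on the last axis
    ColMajor, // stride 1 on the first axis
    Unsupported,
}

fn classify(shape: (usize, usize), strides: (isize, isize)) -> BlasLayout {
    let (rows, cols) = shape;
    let (s0, s1) = strides;
    if s1 == 1 && s0 >= cols as isize {
        // Inner axis contiguous; s0 is a valid leading dimension.
        BlasLayout::RowMajor
    } else if s0 == 1 && s1 >= rows as isize {
        BlasLayout::ColMajor
    } else {
        // e.g. negative strides (reversed axes) or non-unit inner stride
        BlasLayout::Unsupported
    }
}

fn main() {
    // A 3x4 row-major array has strides (4, 1);
    // its transpose has shape (4, 3) and strides (1, 4), i.e. column major.
    assert_eq!(classify((3, 4), (4, 1)), BlasLayout::RowMajor);
    assert_eq!(classify((4, 3), (1, 4)), BlasLayout::ColMajor);
    // A view with a negative stride can never go to BLAS directly.
    assert_eq!(classify((3, 4), (-4, 1)), BlasLayout::Unsupported);
}
```

Anything classified `Unsupported` would fall back to the non-BLAS path.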

In your example, the $A^T B$ product comes in with non-square operands, the first operand in column-major layout and the second in row-major layout. The impl just needs to figure that out and determine how to call BLAS with it. I wonder if the check for square dimensions can be removed; I'm not sure why it's there.

@bluss bluss self-assigned this Mar 10, 2024
bluss added a commit that referenced this issue Aug 7, 2024
We compute A B -> C with matrices A, B, C.

The blas (cblas) interface supports matrices that adhere to certain criteria: they should be contiguous on one dimension (stride = 1).

We glance a little at how numpy does this to try to catch all cases.

In short, we accept A and B contiguous on either axis (row or column major). We use the case where C is (weakly) row major; if it is column major we transpose A, B, C => A^t, B^t, C^t so that we are back to the C row-major case.

(Weakly = contiguous with stride = 1 on the inner dimension, but the stride for the other dimension can be larger; to differentiate from strictly whole-array contiguous.)

Minor change to the gemv function, no functional change, only updating due to the refactoring of blas layout functions.

Fixes #1278

bluss commented Aug 7, 2024

Thanks for the good test case. I've verified that this is fixed for your test case.
BLAS does better than the matrixmultiply fallback, but the fallback can do well if the matrixmultiply-threading feature is enabled.

bluss closed this as completed in 27e347c on Aug 8, 2024