Skip to content

Commit

Permalink
Chain methods for gemm and gemv to gemm! and gemv!
Browse files Browse the repository at this point in the history
Add gemm and gemv methods that have an implicit 1.0 multiplier
  • Loading branch information
dmbates committed Jan 18, 2013
1 parent d088d75 commit 5c1e646
Showing 1 changed file with 59 additions and 50 deletions.
109 changes: 59 additions & 50 deletions base/blas.jl
Original file line number Diff line number Diff line change
Expand Up @@ -290,72 +290,81 @@ for (fname, elty) in ((:dsbmv_,:Float64), (:ssbmv_,:Float32),
end
end

# (GE) general matrix-matrix multiplication
# SUBROUTINE DGEMM(TRANSA,TRANSB,M,N,K,ALPHA,A,LDA,B,LDB,BETA,C,LDC)
# * .. Scalar Arguments ..
# DOUBLE PRECISION ALPHA,BETA
# INTEGER K,LDA,LDB,LDC,M,N
# CHARACTER TRANSA,TRANSB
# * .. Array Arguments ..
# DOUBLE PRECISION A(LDA,*),B(LDB,*),C(LDC,*)
for (fname, elty) in ((:dgemm_,:Float64), (:sgemm_,:Float32),
(:zgemm_,:Complex128), (:cgemm_,:Complex64))
# (GE) general matrix-matrix and matrix-vector multiplication
for (gemm, gemv, elty) in
((:dgemm_,:dgemv_,:Float64),
(:sgemm_,:sgemv_,:Float32),
(:zgemm_,:zgemv_,:Complex128),
(:cgemm_,:cgemv_,:Complex64))
@eval begin
function gemm!(transA::BlasChar, transB::BlasChar, alpha::($elty), A::StridedMatrix{$elty},
B::StridedMatrix{$elty}, beta::($elty), C::StridedMatrix{$elty})
# SUBROUTINE DGEMM(TRANSA,TRANSB,M,N,K,ALPHA,A,LDA,B,LDB,BETA,C,LDC)
# * .. Scalar Arguments ..
# DOUBLE PRECISION ALPHA,BETA
# INTEGER K,LDA,LDB,LDC,M,N
# CHARACTER TRANSA,TRANSB
# * .. Array Arguments ..
# DOUBLE PRECISION A(LDA,*),B(LDB,*),C(LDC,*)
function gemm!(transA::BlasChar, transB::BlasChar,
alpha::($elty), A::StridedMatrix{$elty},
B::StridedMatrix{$elty},
beta::($elty), C::StridedMatrix{$elty})
# if any([stride(A,1), stride(B,1), stride(C,1)] .!= 1)
# error("gemm!: BLAS module requires contiguous matrix columns")
# end # should this be checked on every call?
m = size(A, transA == 'N' ? 1 : 2)
k = size(A, transA == 'N' ? 2 : 1)
n = size(B, transB == 'N' ? 2 : 1)
if m != size(C,1) || n != size(C,2) error("gemm!: mismatched dimensions") end
ccall(($(string(fname)),libblas), Void,
if m != size(C,1) || n != size(C,2)
error("gemm!: mismatched dimensions")
end
ccall(($(string(gemm)),libblas), Void,
(Ptr{Uint8}, Ptr{Uint8}, Ptr{BlasInt}, Ptr{BlasInt}, Ptr{BlasInt},
Ptr{$elty}, Ptr{$elty}, Ptr{BlasInt}, Ptr{$elty}, Ptr{BlasInt},
Ptr{$elty}, Ptr{$elty}, Ptr{BlasInt}),
&transA, &transB, &m, &n, &k, &alpha, A, &stride(A,2),
B, &stride(B,2), &beta, C, &stride(C,2))
C
end
function gemm(transA::BlasChar, transB::BlasChar, alpha::($elty), A::StridedMatrix{$elty}, B::StridedMatrix{$elty})
m = size(A, transA == 'N' ? 1 : 2)
k = size(A, transA == 'N' ? 2 : 1)
if k != size(B, transB == 'N' ? 1 : 2) error("gemm!: mismatched dimensions") end
n = size(B, transB == 'N' ? 2 : 1)
C = Array($elty, (m, n))
ccall(($(string(fname)),libblas), Void,
(Ptr{Uint8}, Ptr{Uint8}, Ptr{BlasInt}, Ptr{BlasInt}, Ptr{BlasInt},
Ptr{$elty}, Ptr{$elty}, Ptr{BlasInt}, Ptr{$elty}, Ptr{BlasInt},
Ptr{$elty}, Ptr{$elty}, Ptr{BlasInt}),
&transA, &transB, &m, &n, &k, &alpha, A, &stride(A,2),
B, &stride(B,2), &0., C, &stride(C,2))
C
function gemm(transA::BlasChar, transB::BlasChar,
alpha::($elty), A::StridedMatrix{$elty},
B::StridedMatrix{$elty})
gemm!(transA, transB, alpha, A, B, zero($elty),
Array($elty, (size(A, transA == 'N' ? 1 : 2),
size(B, transB == 'N' ? 2 : 1))))
end
end
end

#SUBROUTINE DGEMV(TRANS,M,N,ALPHA,A,LDA,X,INCX,BETA,Y,INCY)
#* .. Scalar Arguments ..
# DOUBLE PRECISION ALPHA,BETA
# INTEGER INCX,INCY,LDA,M,N
# CHARACTER TRANS
#* .. Array Arguments ..
# DOUBLE PRECISION A(LDA,*),X(*),Y(*)

for (fname, elty) in ((:dgemv_,:Float64), (:sgemv_,:Float32),
(:zgemv_,:Complex128), (:cgemv_,:Complex64))
@eval begin
function gemv!(trans::BlasChar, alpha::($elty), A::StridedMatrix{$elty},
X::StridedVector{$elty}, beta::($elty), Y::StridedVector{$elty})
ccall(($(string(fname)),libblas), Void,
(Ptr{Uint8}, Ptr{BlasInt}, Ptr{BlasInt}, Ptr{$elty}, Ptr{$elty}, Ptr{BlasInt},
Ptr{$elty}, Ptr{BlasInt}, Ptr{$elty}, Ptr{$elty}, Ptr{BlasInt}),
function gemm(transA::BlasChar, transB::BlasChar,
A::StridedMatrix{$elty}, B::StridedMatrix{$elty})
gemm(transA, transB, one($elty), A, B)
end
#SUBROUTINE DGEMV(TRANS,M,N,ALPHA,A,LDA,X,INCX,BETA,Y,INCY)
#* .. Scalar Arguments ..
# DOUBLE PRECISION ALPHA,BETA
# INTEGER INCX,INCY,LDA,M,N
# CHARACTER TRANS
#* .. Array Arguments ..
# DOUBLE PRECISION A(LDA,*),X(*),Y(*)
function gemv!(trans::BlasChar,
alpha::($elty), A::StridedMatrix{$elty},
X::StridedVector{$elty},
beta::($elty), Y::StridedVector{$elty})
ccall(($(string(gemv)),libblas), Void,
(Ptr{Uint8}, Ptr{BlasInt}, Ptr{BlasInt},
Ptr{$elty}, Ptr{$elty}, Ptr{BlasInt},
Ptr{$elty}, Ptr{BlasInt},
Ptr{$elty}, Ptr{$elty}, Ptr{BlasInt}),
&trans, &size(A,1), &size(A,2), &alpha, A, &stride(A,2),
X, &stride(X,1), &beta, Y, &stride(Y,1))
Y
end
function gemv(trans::BlasChar, alpha::($elty), A::StridedMatrix{$elty}, X::StridedVector{$elty})
Y = Array($elty, size(A,1))
gemv!(trans, alpha, A, X, zero($elty), Y)
Y
function gemv(trans::BlasChar,
alpha::($elty), A::StridedMatrix{$elty},
X::StridedVector{$elty})
gemv!(trans, alpha, A, X, zero($elty),
Array($elty, size(A, (trans == 'N' ? 1 : 2))))
end
function gemv(trans::BlasChar, A::StridedMatrix{$elty}, X::StridedVector{$elty})
gemv!(trans, one($elty), A, X, zero($elty),
Array($elty, size(A, (trans == 'N' ? 1 : 2))))
end
end
end
Expand Down

0 comments on commit 5c1e646

Please sign in to comment.