From 5c1e646050a283a69b2818acdfe9c52473757a9c Mon Sep 17 00:00:00 2001 From: Douglas Bates Date: Fri, 18 Jan 2013 09:29:02 -0500 Subject: [PATCH] Chain methods for gemm and gemv to gemm! and gemv! Add gemm and gemv methods that have an implicit 1.0 multiplier --- base/blas.jl | 109 ++++++++++++++++++++++++++++----------------------- 1 file changed, 59 insertions(+), 50 deletions(-) diff --git a/base/blas.jl b/base/blas.jl index 35cf99df810a5..a9830a3eff608 100644 --- a/base/blas.jl +++ b/base/blas.jl @@ -290,24 +290,34 @@ for (fname, elty) in ((:dsbmv_,:Float64), (:ssbmv_,:Float32), end end -# (GE) general matrix-matrix multiplication -# SUBROUTINE DGEMM(TRANSA,TRANSB,M,N,K,ALPHA,A,LDA,B,LDB,BETA,C,LDC) -# * .. Scalar Arguments .. -# DOUBLE PRECISION ALPHA,BETA -# INTEGER K,LDA,LDB,LDC,M,N -# CHARACTER TRANSA,TRANSB -# * .. Array Arguments .. -# DOUBLE PRECISION A(LDA,*),B(LDB,*),C(LDC,*) -for (fname, elty) in ((:dgemm_,:Float64), (:sgemm_,:Float32), - (:zgemm_,:Complex128), (:cgemm_,:Complex64)) +# (GE) general matrix-matrix and matrix-vector multiplication +for (gemm, gemv, elty) in + ((:dgemm_,:dgemv_,:Float64), + (:sgemm_,:sgemv_,:Float32), + (:zgemm_,:zgemv_,:Complex128), + (:cgemm_,:cgemv_,:Complex64)) @eval begin - function gemm!(transA::BlasChar, transB::BlasChar, alpha::($elty), A::StridedMatrix{$elty}, - B::StridedMatrix{$elty}, beta::($elty), C::StridedMatrix{$elty}) + # SUBROUTINE DGEMM(TRANSA,TRANSB,M,N,K,ALPHA,A,LDA,B,LDB,BETA,C,LDC) + # * .. Scalar Arguments .. + # DOUBLE PRECISION ALPHA,BETA + # INTEGER K,LDA,LDB,LDC,M,N + # CHARACTER TRANSA,TRANSB + # * .. Array Arguments .. + # DOUBLE PRECISION A(LDA,*),B(LDB,*),C(LDC,*) + function gemm!(transA::BlasChar, transB::BlasChar, + alpha::($elty), A::StridedMatrix{$elty}, + B::StridedMatrix{$elty}, + beta::($elty), C::StridedMatrix{$elty}) +# if any([stride(A,1), stride(B,1), stride(C,1)] .!= 1) +# error("gemm!: BLAS module requires contiguous matrix columns") +# end # should this be checked on every call? m = size(A, transA == 'N' ? 1 : 2) k = size(A, transA == 'N' ? 2 : 1) n = size(B, transB == 'N' ? 2 : 1) - if m != size(C,1) || n != size(C,2) error("gemm!: mismatched dimensions") end - ccall(($(string(fname)),libblas), Void, + if m != size(C,1) || n != size(C,2) + error("gemm!: mismatched dimensions") + end + ccall(($(string(gemm)),libblas), Void, (Ptr{Uint8}, Ptr{Uint8}, Ptr{BlasInt}, Ptr{BlasInt}, Ptr{BlasInt}, Ptr{$elty}, Ptr{$elty}, Ptr{BlasInt}, Ptr{$elty}, Ptr{BlasInt}, Ptr{$elty}, Ptr{$elty}, Ptr{BlasInt}), @@ -315,47 +325,46 @@ for (fname, elty) in ((:dgemm_,:Float64), (:sgemm_,:Float32), B, &stride(B,2), &beta, C, &stride(C,2)) C end - function gemm(transA::BlasChar, transB::BlasChar, alpha::($elty), A::StridedMatrix{$elty}, B::StridedMatrix{$elty}) - m = size(A, transA == 'N' ? 1 : 2) - k = size(A, transA == 'N' ? 2 : 1) - if k != size(B, transB == 'N' ? 1 : 2) error("gemm!: mismatched dimensions") end - n = size(B, transB == 'N' ? 2 : 1) - C = Array($elty, (m, n)) - ccall(($(string(fname)),libblas), Void, - (Ptr{Uint8}, Ptr{Uint8}, Ptr{BlasInt}, Ptr{BlasInt}, Ptr{BlasInt}, - Ptr{$elty}, Ptr{$elty}, Ptr{BlasInt}, Ptr{$elty}, Ptr{BlasInt}, - Ptr{$elty}, Ptr{$elty}, Ptr{BlasInt}), - &transA, &transB, &m, &n, &k, &alpha, A, &stride(A,2), - B, &stride(B,2), &0., C, &stride(C,2)) - C + function gemm(transA::BlasChar, transB::BlasChar, + alpha::($elty), A::StridedMatrix{$elty}, + B::StridedMatrix{$elty}) + gemm!(transA, transB, alpha, A, B, zero($elty), + Array($elty, (size(A, transA == 'N' ? 1 : 2), + size(B, transB == 'N' ? 2 : 1)))) end - end -end - -#SUBROUTINE DGEMV(TRANS,M,N,ALPHA,A,LDA,X,INCX,BETA,Y,INCY) -#* .. Scalar Arguments .. -# DOUBLE PRECISION ALPHA,BETA -# INTEGER INCX,INCY,LDA,M,N -# CHARACTER TRANS -#* .. Array Arguments .. -# DOUBLE PRECISION A(LDA,*),X(*),Y(*) - -for (fname, elty) in ((:dgemv_,:Float64), (:sgemv_,:Float32), - (:zgemv_,:Complex128), (:cgemv_,:Complex64)) - @eval begin - function gemv!(trans::BlasChar, alpha::($elty), A::StridedMatrix{$elty}, - X::StridedVector{$elty}, beta::($elty), Y::StridedVector{$elty}) - ccall(($(string(fname)),libblas), Void, - (Ptr{Uint8}, Ptr{BlasInt}, Ptr{BlasInt}, Ptr{$elty}, Ptr{$elty}, Ptr{BlasInt}, - Ptr{$elty}, Ptr{BlasInt}, Ptr{$elty}, Ptr{$elty}, Ptr{BlasInt}), + function gemm(transA::BlasChar, transB::BlasChar, + A::StridedMatrix{$elty}, B::StridedMatrix{$elty}) + gemm(transA, transB, one($elty), A, B) + end + #SUBROUTINE DGEMV(TRANS,M,N,ALPHA,A,LDA,X,INCX,BETA,Y,INCY) + #* .. Scalar Arguments .. + # DOUBLE PRECISION ALPHA,BETA + # INTEGER INCX,INCY,LDA,M,N + # CHARACTER TRANS + #* .. Array Arguments .. + # DOUBLE PRECISION A(LDA,*),X(*),Y(*) + function gemv!(trans::BlasChar, + alpha::($elty), A::StridedMatrix{$elty}, + X::StridedVector{$elty}, + beta::($elty), Y::StridedVector{$elty}) + ccall(($(string(gemv)),libblas), Void, + (Ptr{Uint8}, Ptr{BlasInt}, Ptr{BlasInt}, + Ptr{$elty}, Ptr{$elty}, Ptr{BlasInt}, + Ptr{$elty}, Ptr{BlasInt}, + Ptr{$elty}, Ptr{$elty}, Ptr{BlasInt}), &trans, &size(A,1), &size(A,2), &alpha, A, &stride(A,2), X, &stride(X,1), &beta, Y, &stride(Y,1)) Y end - function gemv(trans::BlasChar, alpha::($elty), A::StridedMatrix{$elty}, X::StridedVector{$elty}) - Y = Array($elty, size(A,1)) - gemv!(trans, alpha, A, X, zero($elty), Y) - Y + function gemv(trans::BlasChar, + alpha::($elty), A::StridedMatrix{$elty}, + X::StridedVector{$elty}) + gemv!(trans, alpha, A, X, zero($elty), + Array($elty, size(A, (trans == 'N' ? 1 : 2)))) + end + function gemv(trans::BlasChar, A::StridedMatrix{$elty}, X::StridedVector{$elty}) + gemv!(trans, one($elty), A, X, zero($elty), + Array($elty, size(A, (trans == 'N' ? 1 : 2)))) end end end