Skip to content

Commit

Permalink
add blasfunc for appending suffix to all blas and lapack symbols
Browse files Browse the repository at this point in the history
  • Loading branch information
tkelman committed Oct 20, 2014
1 parent 7f596d4 commit 2af8d20
Show file tree
Hide file tree
Showing 3 changed files with 159 additions and 155 deletions.
76 changes: 38 additions & 38 deletions base/linalg/blas.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
module BLAS

import ..axpy!
import Base.copy!
import Base: copy!, blas_suffix

export
# Level 1
Expand Down Expand Up @@ -57,14 +57,14 @@ import ..LinAlg: BlasReal, BlasComplex, BlasFloat, BlasChar, BlasInt, blas_int,

# Level 1
## copy
for (fname, elty) in ((:dcopy_,:Float64),
for (fname, elty) in ((:dcopy_,:Float64),
(:scopy_,:Float32),
(:zcopy_,:Complex128),
(:zcopy_,:Complex128),
(:ccopy_,:Complex64))
@eval begin
# SUBROUTINE DCOPY(N,DX,INCX,DY,INCY)
function blascopy!(n::Integer, DX::Union(Ptr{$elty},StridedArray{$elty}), incx::Integer, DY::Union(Ptr{$elty},StridedArray{$elty}), incy::Integer)
ccall(($(string(fname)),libblas), Void,
ccall(($(blasfunc(fname)), libblas), Void,
(Ptr{BlasInt}, Ptr{$elty}, Ptr{BlasInt}, Ptr{$elty}, Ptr{BlasInt}),
&n, DX, &incx, DY, &incy)
DY
Expand All @@ -73,14 +73,14 @@ for (fname, elty) in ((:dcopy_,:Float64),
end

## scal
for (fname, elty) in ((:dscal_,:Float64),
for (fname, elty) in ((:dscal_,:Float64),
(:sscal_,:Float32),
(:zscal_,:Complex128),
(:zscal_,:Complex128),
(:cscal_,:Complex64))
@eval begin
# SUBROUTINE DSCAL(N,DA,DX,INCX)
function scal!(n::Integer, DA::$elty, DX::Union(Ptr{$elty},StridedArray{$elty}), incx::Integer)
ccall(($(string(fname)),libblas), Void,
ccall(($(blasfunc(fname)), libblas), Void,
(Ptr{BlasInt}, Ptr{$elty}, Ptr{$elty}, Ptr{BlasInt}),
&n, &DA, DX, &incx)
DX
Expand All @@ -93,7 +93,7 @@ for (fname, elty, celty) in ((:sscal_, :Float32, :Complex64),
(:dscal_, :Float64, :Complex128))
@eval begin
function scal!(n::Integer, DA::$elty, DX::Union(Ptr{$celty},StridedArray{$celty}), incx::Integer)
ccall(($(string(fname)),libblas), Void,
ccall(($(blasfunc(fname)), libblas), Void,
(Ptr{BlasInt}, Ptr{$elty}, Ptr{$celty}, Ptr{BlasInt}),
&(2*n), &DA, DX, &incx)
DX
Expand All @@ -102,7 +102,7 @@ for (fname, elty, celty) in ((:sscal_, :Float32, :Complex64),
end

## dot
for (fname, elty) in ((:ddot_,:Float64),
for (fname, elty) in ((:ddot_,:Float64),
(:sdot_,:Float32))
@eval begin
# DOUBLE PRECISION FUNCTION DDOT(N,DX,INCX,DY,INCY)
Expand All @@ -112,7 +112,7 @@ for (fname, elty) in ((:ddot_,:Float64),
# * .. Array Arguments ..
# DOUBLE PRECISION DX(*),DY(*)
function dot(n::Integer, DX::Union(Ptr{$elty},StridedArray{$elty}), incx::Integer, DY::Union(Ptr{$elty},StridedArray{$elty}), incy::Integer)
ccall(($(string(fname)),libblas), $elty,
ccall(($(blasfunc(fname)), libblas), $elty,
(Ptr{BlasInt}, Ptr{$elty}, Ptr{BlasInt}, Ptr{$elty}, Ptr{BlasInt}),
&n, DX, &incx, DY, &incy)
end
Expand All @@ -129,7 +129,7 @@ for (fname, elty) in ((:cblas_zdotc_sub,:Complex128),
# DOUBLE PRECISION DX(*),DY(*)
function dotc(n::Integer, DX::Union(Ptr{$elty},StridedArray{$elty}), incx::Integer, DY::Union(Ptr{$elty},StridedArray{$elty}), incy::Integer)
result = Array($elty, 1)
ccall(($(string(fname)),libblas), $elty,
ccall(($(blasfunc(fname)), libblas), $elty,
(BlasInt, Ptr{$elty}, BlasInt, Ptr{$elty}, BlasInt, Ptr{$elty}),
n, DX, incx, DY, incy, result)
result[1]
Expand All @@ -147,7 +147,7 @@ for (fname, elty) in ((:cblas_zdotu_sub,:Complex128),
# DOUBLE PRECISION DX(*),DY(*)
function dotu(n::Integer, DX::Union(Ptr{$elty},StridedArray{$elty}), incx::Integer, DY::Union(Ptr{$elty},StridedArray{$elty}), incy::Integer)
result = Array($elty, 1)
ccall(($(string(fname)),libblas), $elty,
ccall(($(blasfunc(fname)), libblas), $elty,
(BlasInt, Ptr{$elty}, BlasInt, Ptr{$elty}, BlasInt, Ptr{$elty}),
n, DX, incx, DY, incy, result)
result[1]
Expand Down Expand Up @@ -178,7 +178,7 @@ for (fname, elty, ret_type) in ((:dnrm2_,:Float64,:Float64),
@eval begin
# SUBROUTINE DNRM2(N,X,INCX)
function nrm2(n::Integer, X::Union(Ptr{$elty},StridedVector{$elty}), incx::Integer)
ccall(($(string(fname)),libblas), $ret_type,
ccall(($(blasfunc(fname)), libblas), $ret_type,
(Ptr{BlasInt}, Ptr{$elty}, Ptr{BlasInt}),
&n, X, &incx)
end
Expand All @@ -195,7 +195,7 @@ for (fname, elty, ret_type) in ((:dasum_,:Float64,:Float64),
@eval begin
# SUBROUTINE ASUM(N, X, INCX)
function asum(n::Integer, X::Union(Ptr{$elty},StridedVector{$elty}), incx::Integer)
ccall(($(string(fname)),libblas), $ret_type,
ccall(($(blasfunc(fname)), libblas), $ret_type,
(Ptr{BlasInt}, Ptr{$elty}, Ptr{BlasInt}),
&n, X, &incx)
end
Expand All @@ -218,7 +218,7 @@ for (fname, elty) in ((:daxpy_,:Float64),
#* .. Array Arguments ..
# DOUBLE PRECISION DX(*),DY(*)
function axpy!(n::Integer, alpha::($elty), dx::Union(Ptr{$elty}, StridedArray{$elty}), incx::Integer, dy::Union(Ptr{$elty}, StridedArray{$elty}), incy::Integer)
ccall(($(string(fname)),libblas), Void,
ccall(($(blasfunc(fname)), libblas), Void,
(Ptr{BlasInt}, Ptr{$elty}, Ptr{$elty}, Ptr{BlasInt}, Ptr{$elty}, Ptr{BlasInt}),
&n, &alpha, dx, &incx, dy, &incy)
dy
Expand Down Expand Up @@ -249,7 +249,7 @@ for (fname, elty) in ((:idamax_,:Float64),
(:icamax_,:Complex64))
@eval begin
function iamax(n::BlasInt, dx::Union(StridedVector{$elty}, Ptr{$elty}), incx::BlasInt)
ccall(($(string(fname)), libblas),BlasInt,
ccall(($(blasfunc(fname)), libblas),BlasInt,
(Ptr{BlasInt}, Ptr{$elty}, Ptr{BlasInt}),
&n, dx, &incx)
end
Expand All @@ -275,7 +275,7 @@ for (fname, elty) in ((:dgemv_,:Float64),
function gemv!(trans::BlasChar, alpha::($elty), A::StridedVecOrMat{$elty}, X::StridedVector{$elty}, beta::($elty), Y::StridedVector{$elty})
m,n = size(A,1),size(A,2)
length(X) == (trans == 'N' ? n : m) && length(Y) == (trans == 'N' ? m : n) || throw(DimensionMismatch(""))
ccall(($(string(fname)),libblas), Void,
ccall(($(blasfunc(fname)), libblas), Void,
(Ptr{Uint8}, Ptr{BlasInt}, Ptr{BlasInt}, Ptr{$elty},
Ptr{$elty}, Ptr{BlasInt}, Ptr{$elty}, Ptr{BlasInt},
Ptr{$elty}, Ptr{$elty}, Ptr{BlasInt}),
Expand All @@ -294,9 +294,9 @@ for (fname, elty) in ((:dgemv_,:Float64),
end

### (GB) general banded matrix-vector multiplication
for (fname, elty) in ((:dgbmv_,:Float64),
for (fname, elty) in ((:dgbmv_,:Float64),
(:sgbmv_,:Float32),
(:zgbmv_,:Complex128),
(:zgbmv_,:Complex128),
(:cgbmv_,:Complex64))
@eval begin
# SUBROUTINE DGBMV(TRANS,M,N,KL,KU,ALPHA,A,LDA,X,INCX,BETA,Y,INCY)
Expand All @@ -307,7 +307,7 @@ for (fname, elty) in ((:dgbmv_,:Float64),
# * .. Array Arguments ..
# DOUBLE PRECISION A(LDA,*),X(*),Y(*)
function gbmv!(trans::BlasChar, m::Integer, kl::Integer, ku::Integer, alpha::($elty), A::StridedMatrix{$elty}, x::StridedVector{$elty}, beta::($elty), y::StridedVector{$elty})
ccall(($(string(fname)),libblas), Void,
ccall(($(blasfunc(fname)), libblas), Void,
(Ptr{Uint8}, Ptr{BlasInt}, Ptr{BlasInt}, Ptr{BlasInt},
Ptr{BlasInt}, Ptr{$elty}, Ptr{$elty}, Ptr{BlasInt},
Ptr{$elty}, Ptr{BlasInt}, Ptr{$elty}, Ptr{$elty},
Expand Down Expand Up @@ -346,7 +346,7 @@ for (fname, elty) in ((:dsymv_,:Float64),
m, n = size(A)
if m != n throw(DimensionMismatch("Matrix A is $m by $n but must be square")) end
if m != length(x) || m != length(y) throw(DimensionMismatch("")) end
ccall(($(string(fname)),libblas), Void,
ccall(($(blasfunc(fname)), libblas), Void,
(Ptr{Uint8}, Ptr{BlasInt}, Ptr{$elty}, Ptr{$elty},
Ptr{BlasInt}, Ptr{$elty}, Ptr{BlasInt}, Ptr{$elty},
Ptr{$elty}, Ptr{BlasInt}),
Expand Down Expand Up @@ -375,7 +375,7 @@ for (fname, elty) in ((:zhemv_,:Complex128),
lda = max(1, stride(A, 2))
incx = stride(x, 1)
incy = stride(y, 1)
ccall(($fname, libblas), Void,
ccall(($(blasfunc(fname)), libblas), Void,
(Ptr{Uint8}, Ptr{BlasInt}, Ptr{$elty}, Ptr{$elty},
Ptr{BlasInt}, Ptr{$elty}, Ptr{BlasInt}, Ptr{$elty},
Ptr{$elty}, Ptr{BlasInt}),
Expand All @@ -394,9 +394,9 @@ for (fname, elty) in ((:zhemv_,:Complex128),
end

### sbmv, (SB) symmetric banded matrix-vector multiplication
for (fname, elty) in ((:dsbmv_,:Float64),
for (fname, elty) in ((:dsbmv_,:Float64),
(:ssbmv_,:Float32),
(:zsbmv_,:Complex128),
(:zsbmv_,:Complex128),
(:csbmv_,:Complex64))
@eval begin
# SUBROUTINE DSBMV(UPLO,N,K,ALPHA,A,LDA,X,INCX,BETA,Y,INCY)
Expand All @@ -407,7 +407,7 @@ for (fname, elty) in ((:dsbmv_,:Float64),
# * .. Array Arguments ..
# DOUBLE PRECISION A(LDA,*),X(*),Y(*)
function sbmv!(uplo::BlasChar, k::Integer, alpha::($elty), A::StridedMatrix{$elty}, x::StridedVector{$elty}, beta::($elty), y::StridedVector{$elty})
ccall(($(string(fname)),libblas), Void,
ccall(($(blasfunc(fname)), libblas), Void,
(Ptr{Uint8}, Ptr{BlasInt}, Ptr{BlasInt}, Ptr{$elty},
Ptr{$elty}, Ptr{BlasInt}, Ptr{$elty}, Ptr{BlasInt},
Ptr{$elty}, Ptr{$elty}, Ptr{BlasInt}),
Expand Down Expand Up @@ -443,7 +443,7 @@ for (fname, elty) in ((:dtrmv_,:Float64),
if n != length(x)
throw(DimensionMismatch("length(x)=$(length(x))does not match size(A)=$(size(A))"))
end
ccall(($(string(fname)), libblas), Void,
ccall(($(blasfunc(fname)), libblas), Void,
(Ptr{Uint8}, Ptr{Uint8}, Ptr{Uint8}, Ptr{BlasInt},
Ptr{$elty}, Ptr{BlasInt}, Ptr{$elty}, Ptr{BlasInt}),
&uplo, &trans, &diag, &n,
Expand All @@ -470,7 +470,7 @@ for (fname, elty) in ((:dtrsv_,:Float64),
function trsv!(uplo::Char, trans::Char, diag::Char, A::StridedMatrix{$elty}, x::StridedVector{$elty})
n = chksquare(A)
n==length(x) || throw(DimensionMismatch("size of A is $n != length(x) = $(length(x))"))
ccall(($(string(fname)), libblas), Void,
ccall(($(blasfunc(fname)), libblas), Void,
(Ptr{Uint8}, Ptr{Uint8}, Ptr{Uint8}, Ptr{BlasInt},
Ptr{$elty}, Ptr{BlasInt}, Ptr{$elty}, Ptr{BlasInt}),
&uplo, &trans, &diag, &n,
Expand All @@ -493,7 +493,7 @@ for (fname, elty) in ((:dger_,:Float64),
m, n = size(A)
m == length(x) || throw(DimensionMismatch(""))
n == length(y) || throw(DimensionMismatch(""))
ccall(($(string(fname)), libblas), Void,
ccall(($(blasfunc(fname)), libblas), Void,
(Ptr{BlasInt}, Ptr{BlasInt}, Ptr{$elty}, Ptr{$elty},
Ptr{BlasInt}, Ptr{$elty}, Ptr{BlasInt}, Ptr{$elty},
Ptr{BlasInt}),
Expand All @@ -514,7 +514,7 @@ for (fname, elty) in ((:dsyr_,:Float64),
function syr!(uplo::Char, α::$elty, x::StridedVector{$elty}, A::StridedMatrix{$elty})
n = chksquare(A)
length(x) == n || throw(DimensionMismatch("Length of vector must be the same as the matrix dimensions"))
ccall(($(string(fname)), libblas), Void,
ccall(($(blasfunc(fname)), libblas), Void,
(Ptr{Uint8}, Ptr{BlasInt}, Ptr{$elty}, Ptr{$elty},
Ptr{BlasInt}, Ptr{$elty}, Ptr{BlasInt}),
&uplo, &n, &α, x,
Expand All @@ -531,7 +531,7 @@ for (fname, elty) in ((:zher_,:Complex128),
function her!(uplo::Char, α::$elty, x::StridedVector{$elty}, A::StridedMatrix{$elty})
n = chksquare(A)
length(x) == A || throw(DimensionMismatch("Length of vector must be the same as the matrix dimensions"))
ccall(($(string(fname)), libblas), Void,
ccall(($(blasfunc(fname)), libblas), Void,
(Ptr{Uint8}, Ptr{BlasInt}, Ptr{$elty}, Ptr{$elty},
Ptr{BlasInt}, Ptr{$elty}, Ptr{BlasInt}),
&uplo, &n, &α, x,
Expand Down Expand Up @@ -566,7 +566,7 @@ for (gemm, elty) in
if m != size(C,1) || n != size(C,2)
throw(DimensionMismatch(""))
end
ccall(($(string(gemm)),libblas), Void,
ccall(($(blasfunc(gemm)), libblas), Void,
(Ptr{Uint8}, Ptr{Uint8}, Ptr{BlasInt}, Ptr{BlasInt},
Ptr{BlasInt}, Ptr{$elty}, Ptr{$elty}, Ptr{BlasInt},
Ptr{$elty}, Ptr{BlasInt}, Ptr{$elty}, Ptr{$elty},
Expand Down Expand Up @@ -603,7 +603,7 @@ for (mfname, elty) in ((:dsymm_,:Float64),
m, n = size(C)
j = chksquare(A)
if j != (side == 'L' ? m : n) || size(B,2) != n throw(DimensionMismatch("")) end
ccall(($(string(mfname)),libblas), Void,
ccall(($(blasfunc(mfname)), libblas), Void,
(Ptr{Uint8}, Ptr{Uint8}, Ptr{BlasInt}, Ptr{BlasInt},
Ptr{$elty}, Ptr{$elty}, Ptr{BlasInt}, Ptr{$elty},
Ptr{BlasInt}, Ptr{$elty}, Ptr{$elty}, Ptr{BlasInt}),
Expand Down Expand Up @@ -641,7 +641,7 @@ for (fname, elty) in ((:dsyrk_,:Float64),
nn = size(A, trans == 'N' ? 1 : 2)
if nn != n throw(DimensionMismatch("syrk!")) end
k = size(A, trans == 'N' ? 2 : 1)
ccall(($(string(fname)),libblas), Void,
ccall(($(blasfunc(fname)), libblas), Void,
(Ptr{Uint8}, Ptr{Uint8}, Ptr{BlasInt}, Ptr{BlasInt},
Ptr{$elty}, Ptr{$elty}, Ptr{BlasInt}, Ptr{$elty},
Ptr{$elty}, Ptr{BlasInt}),
Expand Down Expand Up @@ -674,7 +674,7 @@ for (fname, elty) in ((:zherk_,:Complex128), (:cherk_,:Complex64))
n = chksquare(C)
n == size(A, trans == 'N' ? 1 : 2) || throw(DimensionMismatch("herk!"))
k = size(A, trans == 'N' ? 2 : 1)
ccall(($(string(fname)),libblas), Void,
ccall(($(blasfunc(fname)), libblas), Void,
(Ptr{Uint8}, Ptr{Uint8}, Ptr{BlasInt}, Ptr{BlasInt},
Ptr{$elty}, Ptr{$elty}, Ptr{BlasInt}, Ptr{$elty},
Ptr{$elty}, Ptr{BlasInt}),
Expand Down Expand Up @@ -713,7 +713,7 @@ for (fname, elty) in ((:dsyr2k_,:Float64),
nn = size(A, trans == 'N' ? 1 : 2)
if nn != n throw(DimensionMismatch("syr2k!")) end
k = size(A, trans == 'N' ? 2 : 1)
ccall(($(string(fname)),libblas), Void,
ccall(($(blasfunc(fname)), libblas), Void,
(Ptr{Uint8}, Ptr{Uint8}, Ptr{BlasInt}, Ptr{BlasInt},
Ptr{$elty}, Ptr{$elty}, Ptr{BlasInt}, Ptr{$elty}, Ptr{BlasInt}, Ptr{$elty},
Ptr{$elty}, Ptr{BlasInt}),
Expand Down Expand Up @@ -749,7 +749,7 @@ for (fname, elty1, elty2) in ((:zher2k_,:Complex128,:Float64), (:cher2k_,:Comple
n = chksquare(C)
n == size(A, trans == 'N' ? 1 : 2) || throw(DimensionMismatch("her2k!"))
k = size(A, trans == 'N' ? 2 : 1)
ccall(($(string(fname)),libblas), Void,
ccall(($(blasfunc(fname)), libblas), Void,
(Ptr{Uint8}, Ptr{Uint8}, Ptr{BlasInt}, Ptr{BlasInt},
Ptr{$elty1}, Ptr{$elty1}, Ptr{BlasInt}, Ptr{$elty1}, Ptr{BlasInt},
Ptr{$elty2}, Ptr{$elty1}, Ptr{BlasInt}),
Expand Down Expand Up @@ -785,7 +785,7 @@ for (mmname, smname, elty) in
m, n = size(B)
nA = chksquare(A)
if nA != (side == 'L' ? m : n) throw(DimensionMismatch("trmm!")) end
ccall(($(string(mmname)), libblas), Void,
ccall(($(blasfunc(mmname)), libblas), Void,
(Ptr{Uint8}, Ptr{Uint8}, Ptr{Uint8}, Ptr{Uint8}, Ptr{BlasInt}, Ptr{BlasInt},
Ptr{$elty}, Ptr{$elty}, Ptr{BlasInt}, Ptr{$elty}, Ptr{BlasInt}),
&side, &uplo, &transa, &diag, &m, &n,
Expand All @@ -808,7 +808,7 @@ for (mmname, smname, elty) in
m, n = size(B)
k = chksquare(A)
k==(side == 'L' ? m : n) || throw(DimensionMismatch("size of A is $n, size(B)=($m,$n) and transa='$transa'"))
ccall(($(string(smname)), libblas), Void,
ccall(($(blasfunc(smname)), libblas), Void,
(Ptr{Uint8}, Ptr{Uint8}, Ptr{Uint8}, Ptr{Uint8},
Ptr{BlasInt}, Ptr{BlasInt}, Ptr{$elty}, Ptr{$elty},
Ptr{BlasInt}, Ptr{$elty}, Ptr{BlasInt}),
Expand Down
Loading

0 comments on commit 2af8d20

Please sign in to comment.