Skip to content

Commit

Permalink
Reinstate all the commits from #2069.
Browse files Browse the repository at this point in the history
Relevant comments are in #2062.
  • Loading branch information
ViralBShah committed Feb 4, 2013
1 parent 3f92b13 commit df8f86e
Show file tree
Hide file tree
Showing 5 changed files with 252 additions and 186 deletions.
254 changes: 124 additions & 130 deletions base/blas.jl
Original file line number Diff line number Diff line change
Expand Up @@ -166,18 +166,16 @@ function axpy!{T,Ta<:Number,Ti<:Integer}(alpha::Ta, x::Array{T}, rx::Union(Range
return axpy!(length(rx), convert(T, alpha), pointer(x)+(first(rx)-1)*sizeof(T), step(rx), pointer(y)+(first(ry)-1)*sizeof(T), step(ry))
end


# SUBROUTINE DSYRK(UPLO,TRANS,N,K,ALPHA,A,LDA,BETA,C,LDC)
# * .. Scalar Arguments ..
# REAL ALPHA,BETA
# INTEGER K,LDA,LDC,N
# CHARACTER TRANS,UPLO
# * ..
# * .. Array Arguments ..
# REAL A(LDA,*),C(LDC,*)
for (fname, elty) in ((:dsyrk_,:Float64), (:ssyrk_,:Float32),
(:zsyrk_,:Complex128), (:csyrk_,:Complex64))
@eval begin
# SUBROUTINE DSYRK(UPLO,TRANS,N,K,ALPHA,A,LDA,BETA,C,LDC)
# * .. Scalar Arguments ..
# REAL ALPHA,BETA
# INTEGER K,LDA,LDC,N
# CHARACTER TRANS,UPLO
# * .. Array Arguments ..
# REAL A(LDA,*),C(LDC,*)
function syrk!(uplo::BlasChar, trans::BlasChar, alpha::($elty), A::StridedVecOrMat{$elty},
beta::($elty), C::StridedMatrix{$elty})
m, n = size(C)
Expand All @@ -193,14 +191,9 @@ for (fname, elty) in ((:dsyrk_,:Float64), (:ssyrk_,:Float32),
end
function syrk(uplo::BlasChar, trans::BlasChar, alpha::($elty), A::StridedVecOrMat{$elty})
n = size(A, trans == 'N' ? 1 : 2)
k = size(A, trans == 'N' ? 2 : 1)
C = Array($elty, (n, n))
ccall(($(string(fname)),libblas), Void,
(Ptr{Uint8}, Ptr{Uint8}, Ptr{BlasInt}, Ptr{BlasInt}, Ptr{$elty},
Ptr{$elty}, Ptr{BlasInt}, Ptr{$elty}, Ptr{$elty}, Ptr{BlasInt}),
&uplo, &trans, &n, &k, &alpha, A, &stride(A,2), &0., C, &stride(C,2))
C
syrk!(uplo, trans, alpha, A, zero($elty), Array($elty, (n, n)))
end
syrk(uplo::BlasChar, trans::BlasChar, A::StridedVecOrMat{$elty}) = syrk(uplo, trans, one($elty), A)
end
end

Expand All @@ -214,7 +207,7 @@ end
# COMPLEX A(LDA,*),C(LDC,*)
for (fname, elty) in ((:zherk_,:Complex128), (:cherk_,:Complex64))
@eval begin
function herk!(uplo::BlasChar, trans, alpha::($elty), A::StridedVecOrMat{$elty},
function herk!(uplo::BlasChar, trans::BlasChar, alpha::($elty), A::StridedVecOrMat{$elty},
beta::($elty), C::StridedMatrix{$elty})
m, n = size(C)
if m != n error("syrk!: matrix C must be square") end
Expand All @@ -227,16 +220,11 @@ for (fname, elty) in ((:zherk_,:Complex128), (:cherk_,:Complex64))
&uplo, &trans, &n, &k, &alpha, A, &stride(A,2), &beta, C, &stride(C,2))
C
end
function herk(uplo::BlasChar, trans, alpha::($elty), A::StridedVecOrMat{$elty})
function herk(uplo::BlasChar, trans::BlasChar, alpha::($elty), A::StridedVecOrMat{$elty})
n = size(A, trans == 'N' ? 1 : 2)
k = size(A, trans == 'N' ? 2 : 1)
C = Array($elty, (n, n))
ccall(($(string(fname)),libblas), Void,
(Ptr{Uint8}, Ptr{Uint8}, Ptr{BlasInt}, Ptr{BlasInt}, Ptr{$elty},
Ptr{$elty}, Ptr{BlasInt}, Ptr{$elty}, Ptr{$elty}, Ptr{BlasInt}),
&uplo, &trans, &n, &k, &alpha, A, &stride(A,2), &0., C, &stride(C,2))
C
herk!(uplo, trans, alpha, A, zero($elty), Array($elty, (n,n)))
end
herk(uplo::BlasChar, trans::BlasChar, A::StridedVecOrMat{$elty}) = herk(uplo, trans, one($elty), A)
end
end

Expand Down Expand Up @@ -266,16 +254,12 @@ for (fname, elty) in ((:dgbmv_,:Float64), (:sgbmv_,:Float32),
function gbmv(trans::BlasChar, m::Integer, kl::Integer, ku::Integer,
alpha::($elty), A::StridedMatrix{$elty}, x::StridedVector{$elty})
n = stride(A,2)
y = Array($elty, n)
ccall(($(string(fname)),libblas), Void,
(Ptr{Uint8}, Ptr{BlasInt}, Ptr{BlasInt}, Ptr{BlasInt}, Ptr{BlasInt},
Ptr{$elty}, Ptr{$elty}, Ptr{BlasInt}, Ptr{$elty}, Ptr{BlasInt},
Ptr{$elty}, Ptr{$elty}, Ptr{BlasInt}),
&trans, &m, &n, &kl, &ku, &alpha, A, &stride(A,2),
x, &stride(x,1), &0., y, &1)
y
gbmv!(trans, m, kl, ku, alpha, A, x, zero($elty), Array($elty, n))
end
function gbmv(trans::BlasChar, m::Integer, kl::Integer, ku::Integer,
A::StridedMatrix{$elty}, x::StridedVector{$elty})
gbmv(trans, m, kl, ku, one($elty), A, x)
end

end
end

Expand Down Expand Up @@ -303,127 +287,109 @@ for (fname, elty) in ((:dsbmv_,:Float64), (:ssbmv_,:Float32),
function sbmv(uplo::BlasChar, k::Integer, alpha::($elty), A::StridedMatrix{$elty},
x::StridedVector{$elty})
n = size(A,2)
y = Array($elty, n)
ccall(($(string(fname)),libblas), Void,
(Ptr{Uint8}, Ptr{BlasInt}, Ptr{BlasInt}, Ptr{$elty}, Ptr{$elty}, Ptr{BlasInt},
Ptr{$elty}, Ptr{BlasInt}, Ptr{$elty}, Ptr{$elty}, Ptr{BlasInt}),
&uplo, &size(A,2), &k, &alpha, A, &stride(A,2), x, &stride(x,1), &0., y, &1)
y
sbmv!(uplo, k, alpha, A, x, zero($elty), Array($elty, n))
end
function sbmv(uplo::BlasChar, k::Integer, A::StridedMatrix{$elty}, x::StridedVector{$elty})
sbmv(uplo, k, one($elty), A, x)
end
end
end

# (GE) general matrix-matrix multiplication
# SUBROUTINE DGEMM(TRANSA,TRANSB,M,N,K,ALPHA,A,LDA,B,LDB,BETA,C,LDC)
# * .. Scalar Arguments ..
# DOUBLE PRECISION ALPHA,BETA
# INTEGER K,LDA,LDB,LDC,M,N
# CHARACTER TRANSA,TRANSB
# * .. Array Arguments ..
# DOUBLE PRECISION A(LDA,*),B(LDB,*),C(LDC,*)
for (fname, elty) in ((:dgemm_,:Float64), (:sgemm_,:Float32),
(:zgemm_,:Complex128), (:cgemm_,:Complex64))
# (GE) general matrix-matrix and matrix-vector multiplication
for (gemm, gemv, elty) in
((:dgemm_,:dgemv_,:Float64),
(:sgemm_,:sgemv_,:Float32),
(:zgemm_,:zgemv_,:Complex128),
(:cgemm_,:cgemv_,:Complex64))
@eval begin
function gemm!(transA::BlasChar, transB::BlasChar, alpha::($elty), A::StridedMatrix{$elty},
B::StridedMatrix{$elty}, beta::($elty), C::StridedMatrix{$elty})
# SUBROUTINE DGEMM(TRANSA,TRANSB,M,N,K,ALPHA,A,LDA,B,LDB,BETA,C,LDC)
# * .. Scalar Arguments ..
# DOUBLE PRECISION ALPHA,BETA
# INTEGER K,LDA,LDB,LDC,M,N
# CHARACTER TRANSA,TRANSB
# * .. Array Arguments ..
# DOUBLE PRECISION A(LDA,*),B(LDB,*),C(LDC,*)
function gemm!(transA::BlasChar, transB::BlasChar,
alpha::($elty), A::StridedMatrix{$elty},
B::StridedMatrix{$elty},
beta::($elty), C::StridedMatrix{$elty})
# if any([stride(A,1), stride(B,1), stride(C,1)] .!= 1)
# error("gemm!: BLAS module requires contiguous matrix columns")
# end # should this be checked on every call?
m = size(A, transA == 'N' ? 1 : 2)
k = size(A, transA == 'N' ? 2 : 1)
n = size(B, transB == 'N' ? 2 : 1)
if m != size(C,1) || n != size(C,2) error("gemm!: mismatched dimensions") end
ccall(($(string(fname)),libblas), Void,
if m != size(C,1) || n != size(C,2)
error("gemm!: mismatched dimensions")
end
ccall(($(string(gemm)),libblas), Void,
(Ptr{Uint8}, Ptr{Uint8}, Ptr{BlasInt}, Ptr{BlasInt}, Ptr{BlasInt},
Ptr{$elty}, Ptr{$elty}, Ptr{BlasInt}, Ptr{$elty}, Ptr{BlasInt},
Ptr{$elty}, Ptr{$elty}, Ptr{BlasInt}),
&transA, &transB, &m, &n, &k, &alpha, A, &stride(A,2),
B, &stride(B,2), &beta, C, &stride(C,2))
C
end
function gemm(transA::BlasChar, transB::BlasChar, alpha::($elty), A::StridedMatrix{$elty}, B::StridedMatrix{$elty})
m = size(A, transA == 'N' ? 1 : 2)
k = size(A, transA == 'N' ? 2 : 1)
if k != size(B, transB == 'N' ? 1 : 2) error("gemm!: mismatched dimensions") end
n = size(B, transB == 'N' ? 2 : 1)
C = Array($elty, (m, n))
ccall(($(string(fname)),libblas), Void,
(Ptr{Uint8}, Ptr{Uint8}, Ptr{BlasInt}, Ptr{BlasInt}, Ptr{BlasInt},
Ptr{$elty}, Ptr{$elty}, Ptr{BlasInt}, Ptr{$elty}, Ptr{BlasInt},
Ptr{$elty}, Ptr{$elty}, Ptr{BlasInt}),
&transA, &transB, &m, &n, &k, &alpha, A, &stride(A,2),
B, &stride(B,2), &0., C, &stride(C,2))
C
function gemm(transA::BlasChar, transB::BlasChar,
alpha::($elty), A::StridedMatrix{$elty},
B::StridedMatrix{$elty})
gemm!(transA, transB, alpha, A, B, zero($elty),
Array($elty, (size(A, transA == 'N' ? 1 : 2),
size(B, transB == 'N' ? 2 : 1))))
end
end
end

#SUBROUTINE DGEMV(TRANS,M,N,ALPHA,A,LDA,X,INCX,BETA,Y,INCY)
#* .. Scalar Arguments ..
# DOUBLE PRECISION ALPHA,BETA
# INTEGER INCX,INCY,LDA,M,N
# CHARACTER TRANS
#* .. Array Arguments ..
# DOUBLE PRECISION A(LDA,*),X(*),Y(*)

for (fname, elty) in ((:dgemv_,:Float64), (:sgemv_,:Float32),
(:zgemv_,:Complex128), (:cgemv_,:Complex64))
@eval begin
function gemv!(trans::BlasChar, alpha::($elty), A::StridedMatrix{$elty},
X::StridedVector{$elty}, beta::($elty), Y::StridedVector{$elty})
ccall(($(string(fname)),libblas), Void,
(Ptr{Uint8}, Ptr{BlasInt}, Ptr{BlasInt}, Ptr{$elty}, Ptr{$elty}, Ptr{BlasInt},
Ptr{$elty}, Ptr{BlasInt}, Ptr{$elty}, Ptr{$elty}, Ptr{BlasInt}),
function gemm(transA::BlasChar, transB::BlasChar,
A::StridedMatrix{$elty}, B::StridedMatrix{$elty})
gemm(transA, transB, one($elty), A, B)
end
#SUBROUTINE DGEMV(TRANS,M,N,ALPHA,A,LDA,X,INCX,BETA,Y,INCY)
#* .. Scalar Arguments ..
# DOUBLE PRECISION ALPHA,BETA
# INTEGER INCX,INCY,LDA,M,N
# CHARACTER TRANS
#* .. Array Arguments ..
# DOUBLE PRECISION A(LDA,*),X(*),Y(*)
function gemv!(trans::BlasChar,
alpha::($elty), A::StridedMatrix{$elty},
X::StridedVector{$elty},
beta::($elty), Y::StridedVector{$elty})
ccall(($(string(gemv)),libblas), Void,
(Ptr{Uint8}, Ptr{BlasInt}, Ptr{BlasInt},
Ptr{$elty}, Ptr{$elty}, Ptr{BlasInt},
Ptr{$elty}, Ptr{BlasInt},
Ptr{$elty}, Ptr{$elty}, Ptr{BlasInt}),
&trans, &size(A,1), &size(A,2), &alpha, A, &stride(A,2),
X, &stride(X,1), &beta, Y, &stride(Y,1))
Y
end
function gemv(trans::BlasChar, alpha::($elty), A::StridedMatrix{$elty}, X::StridedVector{$elty})
Y = Array($elty, size(A,1))
gemv!(trans, alpha, A, X, zero($elty), Y)
Y
function gemv(trans::BlasChar,
alpha::($elty), A::StridedMatrix{$elty},
X::StridedVector{$elty})
gemv!(trans, alpha, A, X, zero($elty),
Array($elty, size(A, (trans == 'N' ? 1 : 2))))
end
function gemv(trans::BlasChar, A::StridedMatrix{$elty}, X::StridedVector{$elty})
gemv!(trans, one($elty), A, X, zero($elty),
Array($elty, size(A, (trans == 'N' ? 1 : 2))))
end
end
end

# (SY) symmetric matrix-matrix and matrix-vector multiplication

# SUBROUTINE DSYMM(SIDE,UPLO,M,N,ALPHA,A,LDA,B,LDB,BETA,C,LDC)
# .. Scalar Arguments ..
# DOUBLE PRECISION ALPHA,BETA
# INTEGER LDA,LDB,LDC,M,N
# CHARACTER SIDE,UPLO
# .. Array Arguments ..
# DOUBLE PRECISION A(LDA,*),B(LDB,*),C(LDC,*)

# SUBROUTINE DSYMV(UPLO,N,ALPHA,A,LDA,X,INCX,BETA,Y,INCY)
# .. Scalar Arguments ..
# DOUBLE PRECISION ALPHA,BETA
# INTEGER INCX,INCY,LDA,N
# CHARACTER UPLO
# .. Array Arguments ..
# DOUBLE PRECISION A(LDA,*),X(*),Y(*)

for (vfname, mfname, elty) in
((:dsymv_,:dsymm_,:Float64),
(:ssymv_,:ssymm_,:Float32),
(:zsymv_,:zsymm_,:Complex128),
(:csymv_,:csymm_,:Complex64))
for (mfname, vfname, elty) in
((:dsymm_,:dsymv_,:Float64),
(:ssymm_,:ssymv_,:Float32),
(:zsymm_,:zsymv_,:Complex128),
(:csymm_,:csymv_,:Complex64))
@eval begin
function symv!(uplo::BlasChar, alpha::($elty), A::StridedMatrix{$elty}, X::StridedVector{$elty},
beta::($elty), Y::StridedVector{$elty})
m, n = size(A)
if m != n error("symm!: matrix A is $m by $n but must be square") end
if m != length(X) || m != length(Y) error("symm!: dimension mismatch") end
ccall(($(string(vfname)),libblas), Void,
(Ptr{Uint8}, Ptr{BlasInt}, Ptr{$elty}, Ptr{$elty}, Ptr{BlasInt},
Ptr{$elty}, Ptr{BlasInt}, Ptr{$elty}, Ptr{$elty}, Ptr{BlasInt}),
&uplo, &n, &alpha, A, &stride(A,2), X, &stride(X,1), &beta, Y, &stride(Y,1))
Y
end
function symv(uplo::BlasChar, alpha::($elty), A::StridedMatrix{$elty}, X::StridedVector{$elty})
symv!(uplo, alpha, A, X, zero($elty), similar(X))
end
function symm!(side::BlasChar, uplo::BlasChar, alpha::($elty), A::StridedMatrix{$elty}, B::StridedMatrix{$elty},
beta::($elty), C::StridedMatrix{$elty})
side = uppercase(convert(Char, side))
# SUBROUTINE DSYMM(SIDE,UPLO,M,N,ALPHA,A,LDA,B,LDB,BETA,C,LDC)
# .. Scalar Arguments ..
# DOUBLE PRECISION ALPHA,BETA
# INTEGER LDA,LDB,LDC,M,N
# CHARACTER SIDE,UPLO
# .. Array Arguments ..
# DOUBLE PRECISION A(LDA,*),B(LDB,*),C(LDC,*)
function symm!(side::BlasChar, uplo::BlasChar, alpha::($elty), A::StridedMatrix{$elty},
B::StridedMatrix{$elty}, beta::($elty), C::StridedMatrix{$elty})
m, n = size(C)
k, j = size(A)
if k != j error("symm!: matrix A is $k by $j but must be square") end
Expand All @@ -435,9 +401,37 @@ for (vfname, mfname, elty) in
&beta, C, &stride(C,2))
C
end
function symm(side::BlasChar, uplo::BlasChar, alpha::($elty), A::StridedMatrix{$elty}, B::StridedMatrix{$elty})
function symm(side::BlasChar, uplo::BlasChar, alpha::($elty), A::StridedMatrix{$elty},
B::StridedMatrix{$elty})
symm!(side, uplo, alpha, A, B, zero($elty), similar(B))
end
function symm(side::BlasChar, uplo::BlasChar, A::StridedMatrix{$elty}, B::StridedMatrix{$elty})
symm(side, uplo, one($elty), A, B)
end
# SUBROUTINE DSYMV(UPLO,N,ALPHA,A,LDA,X,INCX,BETA,Y,INCY)
# .. Scalar Arguments ..
# DOUBLE PRECISION ALPHA,BETA
# INTEGER INCX,INCY,LDA,N
# CHARACTER UPLO
# .. Array Arguments ..
# DOUBLE PRECISION A(LDA,*),X(*),Y(*)
function symv!(uplo::BlasChar, alpha::($elty), A::StridedMatrix{$elty}, x::StridedVector{$elty},
beta::($elty), y::StridedVector{$elty})
m, n = size(A)
if m != n error("symm!: matrix A is $m by $n but must be square") end
if m != length(x) || m != length(y) error("symm!: dimension mismatch") end
ccall(($(string(vfname)),libblas), Void,
(Ptr{Uint8}, Ptr{BlasInt}, Ptr{$elty}, Ptr{$elty}, Ptr{BlasInt},
Ptr{$elty}, Ptr{BlasInt}, Ptr{$elty}, Ptr{$elty}, Ptr{BlasInt}),
&uplo, &n, &alpha, A, &stride(A,2), x, &stride(x,1), &beta, y, &stride(y,1))
Y
end
function symv(uplo::BlasChar, alpha::($elty), A::StridedMatrix{$elty}, x::StridedVector{$elty})
symv!(uplo, alpha, A, x, zero($elty), similar(x))
end
function symv(uplo::BlasChar, A::StridedMatrix{$elty}, x::StridedVector{$elty})
symv(uplo, one($elty), A, x)
end
end
end

Expand Down
2 changes: 1 addition & 1 deletion test/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ core numbers strings unicode corelib hashing remote \
arrayops linalg fft dct sparse bitarray suitesparse arpack \
random math functional bigint bigfloat sorting \
statistics glpk linprog poly file Rmath remote zlib image \
iostring gzip integers spawn ccall parallel
iostring gzip integers spawn ccall blas parallel

$(TESTS) ::
$(QUIET_JULIA) $(JULIA_EXECUTABLE) ./runtests.jl $@
Expand Down
Loading

0 comments on commit df8f86e

Please sign in to comment.