Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP: linalg cleanups #2069

Closed
wants to merge 5 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
254 changes: 124 additions & 130 deletions base/blas.jl
Original file line number Diff line number Diff line change
Expand Up @@ -166,18 +166,16 @@ function axpy!{T,Ta<:Number,Ti<:Integer}(alpha::Ta, x::Array{T}, rx::Union(Range
return axpy!(length(rx), convert(T, alpha), pointer(x)+(first(rx)-1)*sizeof(T), step(rx), pointer(y)+(first(ry)-1)*sizeof(T), step(ry))
end


# SUBROUTINE DSYRK(UPLO,TRANS,N,K,ALPHA,A,LDA,BETA,C,LDC)
# * .. Scalar Arguments ..
# REAL ALPHA,BETA
# INTEGER K,LDA,LDC,N
# CHARACTER TRANS,UPLO
# * ..
# * .. Array Arguments ..
# REAL A(LDA,*),C(LDC,*)
for (fname, elty) in ((:dsyrk_,:Float64), (:ssyrk_,:Float32),
(:zsyrk_,:Complex128), (:csyrk_,:Complex64))
@eval begin
# SUBROUTINE DSYRK(UPLO,TRANS,N,K,ALPHA,A,LDA,BETA,C,LDC)
# * .. Scalar Arguments ..
# REAL ALPHA,BETA
# INTEGER K,LDA,LDC,N
# CHARACTER TRANS,UPLO
# * .. Array Arguments ..
# REAL A(LDA,*),C(LDC,*)
function syrk!(uplo::BlasChar, trans::BlasChar, alpha::($elty), A::StridedVecOrMat{$elty},
beta::($elty), C::StridedMatrix{$elty})
m, n = size(C)
Expand All @@ -193,14 +191,9 @@ for (fname, elty) in ((:dsyrk_,:Float64), (:ssyrk_,:Float32),
end
function syrk(uplo::BlasChar, trans::BlasChar, alpha::($elty), A::StridedVecOrMat{$elty})
n = size(A, trans == 'N' ? 1 : 2)
k = size(A, trans == 'N' ? 2 : 1)
C = Array($elty, (n, n))
ccall(($(string(fname)),libblas), Void,
(Ptr{Uint8}, Ptr{Uint8}, Ptr{BlasInt}, Ptr{BlasInt}, Ptr{$elty},
Ptr{$elty}, Ptr{BlasInt}, Ptr{$elty}, Ptr{$elty}, Ptr{BlasInt}),
&uplo, &trans, &n, &k, &alpha, A, &stride(A,2), &0., C, &stride(C,2))
C
syrk!(uplo, trans, alpha, A, zero($elty), Array($elty, (n, n)))
end
syrk(uplo::BlasChar, trans::BlasChar, A::StridedVecOrMat{$elty}) = syrk(uplo, trans, one($elty), A)
end
end

Expand All @@ -214,7 +207,7 @@ end
# COMPLEX A(LDA,*),C(LDC,*)
for (fname, elty) in ((:zherk_,:Complex128), (:cherk_,:Complex64))
@eval begin
function herk!(uplo::BlasChar, trans, alpha::($elty), A::StridedVecOrMat{$elty},
function herk!(uplo::BlasChar, trans::BlasChar, alpha::($elty), A::StridedVecOrMat{$elty},
beta::($elty), C::StridedMatrix{$elty})
m, n = size(C)
if m != n error("syrk!: matrix C must be square") end
Expand All @@ -227,16 +220,11 @@ for (fname, elty) in ((:zherk_,:Complex128), (:cherk_,:Complex64))
&uplo, &trans, &n, &k, &alpha, A, &stride(A,2), &beta, C, &stride(C,2))
C
end
function herk(uplo::BlasChar, trans, alpha::($elty), A::StridedVecOrMat{$elty})
function herk(uplo::BlasChar, trans::BlasChar, alpha::($elty), A::StridedVecOrMat{$elty})
n = size(A, trans == 'N' ? 1 : 2)
k = size(A, trans == 'N' ? 2 : 1)
C = Array($elty, (n, n))
ccall(($(string(fname)),libblas), Void,
(Ptr{Uint8}, Ptr{Uint8}, Ptr{BlasInt}, Ptr{BlasInt}, Ptr{$elty},
Ptr{$elty}, Ptr{BlasInt}, Ptr{$elty}, Ptr{$elty}, Ptr{BlasInt}),
&uplo, &trans, &n, &k, &alpha, A, &stride(A,2), &0., C, &stride(C,2))
C
herk!(uplo, trans, alpha, A, zero($elty), Array($elty, (n,n)))
end
herk(uplo::BlasChar, trans::BlasChar, A::StridedVecOrMat{$elty}) = herk(uplo, trans, one($elty), A)
end
end

Expand Down Expand Up @@ -266,16 +254,12 @@ for (fname, elty) in ((:dgbmv_,:Float64), (:sgbmv_,:Float32),
function gbmv(trans::BlasChar, m::Integer, kl::Integer, ku::Integer,
alpha::($elty), A::StridedMatrix{$elty}, x::StridedVector{$elty})
n = stride(A,2)
y = Array($elty, n)
ccall(($(string(fname)),libblas), Void,
(Ptr{Uint8}, Ptr{BlasInt}, Ptr{BlasInt}, Ptr{BlasInt}, Ptr{BlasInt},
Ptr{$elty}, Ptr{$elty}, Ptr{BlasInt}, Ptr{$elty}, Ptr{BlasInt},
Ptr{$elty}, Ptr{$elty}, Ptr{BlasInt}),
&trans, &m, &n, &kl, &ku, &alpha, A, &stride(A,2),
x, &stride(x,1), &0., y, &1)
y
gbmv!(trans, m, kl, ku, alpha, A, x, zero($elty), Array($elty, n))
end
function gbmv(trans::BlasChar, m::Integer, kl::Integer, ku::Integer,
A::StridedMatrix{$elty}, x::StridedVector{$elty})
gbmv(trans, m, kl, ku, one($elty), A, x)
end

end
end

Expand Down Expand Up @@ -303,127 +287,109 @@ for (fname, elty) in ((:dsbmv_,:Float64), (:ssbmv_,:Float32),
function sbmv(uplo::BlasChar, k::Integer, alpha::($elty), A::StridedMatrix{$elty},
x::StridedVector{$elty})
n = size(A,2)
y = Array($elty, n)
ccall(($(string(fname)),libblas), Void,
(Ptr{Uint8}, Ptr{BlasInt}, Ptr{BlasInt}, Ptr{$elty}, Ptr{$elty}, Ptr{BlasInt},
Ptr{$elty}, Ptr{BlasInt}, Ptr{$elty}, Ptr{$elty}, Ptr{BlasInt}),
&uplo, &size(A,2), &k, &alpha, A, &stride(A,2), x, &stride(x,1), &0., y, &1)
y
sbmv!(uplo, k, alpha, A, x, zero($elty), Array($elty, n))
end
function sbmv(uplo::BlasChar, k::Integer, A::StridedMatrix{$elty}, x::StridedVector{$elty})
sbmv(uplo, k, one($elty), A, x)
end
end
end

# (GE) general matrix-matrix multiplication
# SUBROUTINE DGEMM(TRANSA,TRANSB,M,N,K,ALPHA,A,LDA,B,LDB,BETA,C,LDC)
# * .. Scalar Arguments ..
# DOUBLE PRECISION ALPHA,BETA
# INTEGER K,LDA,LDB,LDC,M,N
# CHARACTER TRANSA,TRANSB
# * .. Array Arguments ..
# DOUBLE PRECISION A(LDA,*),B(LDB,*),C(LDC,*)
for (fname, elty) in ((:dgemm_,:Float64), (:sgemm_,:Float32),
(:zgemm_,:Complex128), (:cgemm_,:Complex64))
# (GE) general matrix-matrix and matrix-vector multiplication
for (gemm, gemv, elty) in
((:dgemm_,:dgemv_,:Float64),
(:sgemm_,:sgemv_,:Float32),
(:zgemm_,:zgemv_,:Complex128),
(:cgemm_,:cgemv_,:Complex64))
@eval begin
function gemm!(transA::BlasChar, transB::BlasChar, alpha::($elty), A::StridedMatrix{$elty},
B::StridedMatrix{$elty}, beta::($elty), C::StridedMatrix{$elty})
# SUBROUTINE DGEMM(TRANSA,TRANSB,M,N,K,ALPHA,A,LDA,B,LDB,BETA,C,LDC)
# * .. Scalar Arguments ..
# DOUBLE PRECISION ALPHA,BETA
# INTEGER K,LDA,LDB,LDC,M,N
# CHARACTER TRANSA,TRANSB
# * .. Array Arguments ..
# DOUBLE PRECISION A(LDA,*),B(LDB,*),C(LDC,*)
function gemm!(transA::BlasChar, transB::BlasChar,
alpha::($elty), A::StridedMatrix{$elty},
B::StridedMatrix{$elty},
beta::($elty), C::StridedMatrix{$elty})
# if any([stride(A,1), stride(B,1), stride(C,1)] .!= 1)
# error("gemm!: BLAS module requires contiguous matrix columns")
# end # should this be checked on every call?
m = size(A, transA == 'N' ? 1 : 2)
k = size(A, transA == 'N' ? 2 : 1)
n = size(B, transB == 'N' ? 2 : 1)
if m != size(C,1) || n != size(C,2) error("gemm!: mismatched dimensions") end
ccall(($(string(fname)),libblas), Void,
if m != size(C,1) || n != size(C,2)
error("gemm!: mismatched dimensions")
end
ccall(($(string(gemm)),libblas), Void,
(Ptr{Uint8}, Ptr{Uint8}, Ptr{BlasInt}, Ptr{BlasInt}, Ptr{BlasInt},
Ptr{$elty}, Ptr{$elty}, Ptr{BlasInt}, Ptr{$elty}, Ptr{BlasInt},
Ptr{$elty}, Ptr{$elty}, Ptr{BlasInt}),
&transA, &transB, &m, &n, &k, &alpha, A, &stride(A,2),
B, &stride(B,2), &beta, C, &stride(C,2))
C
end
function gemm(transA::BlasChar, transB::BlasChar, alpha::($elty), A::StridedMatrix{$elty}, B::StridedMatrix{$elty})
m = size(A, transA == 'N' ? 1 : 2)
k = size(A, transA == 'N' ? 2 : 1)
if k != size(B, transB == 'N' ? 1 : 2) error("gemm!: mismatched dimensions") end
n = size(B, transB == 'N' ? 2 : 1)
C = Array($elty, (m, n))
ccall(($(string(fname)),libblas), Void,
(Ptr{Uint8}, Ptr{Uint8}, Ptr{BlasInt}, Ptr{BlasInt}, Ptr{BlasInt},
Ptr{$elty}, Ptr{$elty}, Ptr{BlasInt}, Ptr{$elty}, Ptr{BlasInt},
Ptr{$elty}, Ptr{$elty}, Ptr{BlasInt}),
&transA, &transB, &m, &n, &k, &alpha, A, &stride(A,2),
B, &stride(B,2), &0., C, &stride(C,2))
C
function gemm(transA::BlasChar, transB::BlasChar,
alpha::($elty), A::StridedMatrix{$elty},
B::StridedMatrix{$elty})
gemm!(transA, transB, alpha, A, B, zero($elty),
Array($elty, (size(A, transA == 'N' ? 1 : 2),
size(B, transB == 'N' ? 2 : 1))))
end
end
end

#SUBROUTINE DGEMV(TRANS,M,N,ALPHA,A,LDA,X,INCX,BETA,Y,INCY)
#* .. Scalar Arguments ..
# DOUBLE PRECISION ALPHA,BETA
# INTEGER INCX,INCY,LDA,M,N
# CHARACTER TRANS
#* .. Array Arguments ..
# DOUBLE PRECISION A(LDA,*),X(*),Y(*)

for (fname, elty) in ((:dgemv_,:Float64), (:sgemv_,:Float32),
(:zgemv_,:Complex128), (:cgemv_,:Complex64))
@eval begin
function gemv!(trans::BlasChar, alpha::($elty), A::StridedMatrix{$elty},
X::StridedVector{$elty}, beta::($elty), Y::StridedVector{$elty})
ccall(($(string(fname)),libblas), Void,
(Ptr{Uint8}, Ptr{BlasInt}, Ptr{BlasInt}, Ptr{$elty}, Ptr{$elty}, Ptr{BlasInt},
Ptr{$elty}, Ptr{BlasInt}, Ptr{$elty}, Ptr{$elty}, Ptr{BlasInt}),
function gemm(transA::BlasChar, transB::BlasChar,
A::StridedMatrix{$elty}, B::StridedMatrix{$elty})
gemm(transA, transB, one($elty), A, B)
end
#SUBROUTINE DGEMV(TRANS,M,N,ALPHA,A,LDA,X,INCX,BETA,Y,INCY)
#* .. Scalar Arguments ..
# DOUBLE PRECISION ALPHA,BETA
# INTEGER INCX,INCY,LDA,M,N
# CHARACTER TRANS
#* .. Array Arguments ..
# DOUBLE PRECISION A(LDA,*),X(*),Y(*)
function gemv!(trans::BlasChar,
alpha::($elty), A::StridedMatrix{$elty},
X::StridedVector{$elty},
beta::($elty), Y::StridedVector{$elty})
ccall(($(string(gemv)),libblas), Void,
(Ptr{Uint8}, Ptr{BlasInt}, Ptr{BlasInt},
Ptr{$elty}, Ptr{$elty}, Ptr{BlasInt},
Ptr{$elty}, Ptr{BlasInt},
Ptr{$elty}, Ptr{$elty}, Ptr{BlasInt}),
&trans, &size(A,1), &size(A,2), &alpha, A, &stride(A,2),
X, &stride(X,1), &beta, Y, &stride(Y,1))
Y
end
function gemv(trans::BlasChar, alpha::($elty), A::StridedMatrix{$elty}, X::StridedVector{$elty})
Y = Array($elty, size(A,1))
gemv!(trans, alpha, A, X, zero($elty), Y)
Y
function gemv(trans::BlasChar,
alpha::($elty), A::StridedMatrix{$elty},
X::StridedVector{$elty})
gemv!(trans, alpha, A, X, zero($elty),
Array($elty, size(A, (trans == 'N' ? 1 : 2))))
end
function gemv(trans::BlasChar, A::StridedMatrix{$elty}, X::StridedVector{$elty})
gemv!(trans, one($elty), A, X, zero($elty),
Array($elty, size(A, (trans == 'N' ? 1 : 2))))
end
end
end

# (SY) symmetric matrix-matrix and matrix-vector multiplication

# SUBROUTINE DSYMM(SIDE,UPLO,M,N,ALPHA,A,LDA,B,LDB,BETA,C,LDC)
# .. Scalar Arguments ..
# DOUBLE PRECISION ALPHA,BETA
# INTEGER LDA,LDB,LDC,M,N
# CHARACTER SIDE,UPLO
# .. Array Arguments ..
# DOUBLE PRECISION A(LDA,*),B(LDB,*),C(LDC,*)

# SUBROUTINE DSYMV(UPLO,N,ALPHA,A,LDA,X,INCX,BETA,Y,INCY)
# .. Scalar Arguments ..
# DOUBLE PRECISION ALPHA,BETA
# INTEGER INCX,INCY,LDA,N
# CHARACTER UPLO
# .. Array Arguments ..
# DOUBLE PRECISION A(LDA,*),X(*),Y(*)

for (vfname, mfname, elty) in
((:dsymv_,:dsymm_,:Float64),
(:ssymv_,:ssymm_,:Float32),
(:zsymv_,:zsymm_,:Complex128),
(:csymv_,:csymm_,:Complex64))
for (mfname, vfname, elty) in
((:dsymm_,:dsymv_,:Float64),
(:ssymm_,:ssymv_,:Float32),
(:zsymm_,:zsymv_,:Complex128),
(:csymm_,:csymv_,:Complex64))
@eval begin
function symv!(uplo::BlasChar, alpha::($elty), A::StridedMatrix{$elty}, X::StridedVector{$elty},
beta::($elty), Y::StridedVector{$elty})
m, n = size(A)
if m != n error("symm!: matrix A is $m by $n but must be square") end
if m != length(X) || m != length(Y) error("symm!: dimension mismatch") end
ccall(($(string(vfname)),libblas), Void,
(Ptr{Uint8}, Ptr{BlasInt}, Ptr{$elty}, Ptr{$elty}, Ptr{BlasInt},
Ptr{$elty}, Ptr{BlasInt}, Ptr{$elty}, Ptr{$elty}, Ptr{BlasInt}),
&uplo, &n, &alpha, A, &stride(A,2), X, &stride(X,1), &beta, Y, &stride(Y,1))
Y
end
function symv(uplo::BlasChar, alpha::($elty), A::StridedMatrix{$elty}, X::StridedVector{$elty})
symv!(uplo, alpha, A, X, zero($elty), similar(X))
end
function symm!(side::BlasChar, uplo::BlasChar, alpha::($elty), A::StridedMatrix{$elty}, B::StridedMatrix{$elty},
beta::($elty), C::StridedMatrix{$elty})
side = uppercase(convert(Char, side))
# SUBROUTINE DSYMM(SIDE,UPLO,M,N,ALPHA,A,LDA,B,LDB,BETA,C,LDC)
# .. Scalar Arguments ..
# DOUBLE PRECISION ALPHA,BETA
# INTEGER LDA,LDB,LDC,M,N
# CHARACTER SIDE,UPLO
# .. Array Arguments ..
# DOUBLE PRECISION A(LDA,*),B(LDB,*),C(LDC,*)
function symm!(side::BlasChar, uplo::BlasChar, alpha::($elty), A::StridedMatrix{$elty},
B::StridedMatrix{$elty}, beta::($elty), C::StridedMatrix{$elty})
m, n = size(C)
k, j = size(A)
if k != j error("symm!: matrix A is $k by $j but must be square") end
Expand All @@ -435,9 +401,37 @@ for (vfname, mfname, elty) in
&beta, C, &stride(C,2))
C
end
function symm(side::BlasChar, uplo::BlasChar, alpha::($elty), A::StridedMatrix{$elty}, B::StridedMatrix{$elty})
function symm(side::BlasChar, uplo::BlasChar, alpha::($elty), A::StridedMatrix{$elty},
B::StridedMatrix{$elty})
symm!(side, uplo, alpha, A, B, zero($elty), similar(B))
end
function symm(side::BlasChar, uplo::BlasChar, A::StridedMatrix{$elty}, B::StridedMatrix{$elty})
symm(side, uplo, one($elty), A, B)
end
# SUBROUTINE DSYMV(UPLO,N,ALPHA,A,LDA,X,INCX,BETA,Y,INCY)
# .. Scalar Arguments ..
# DOUBLE PRECISION ALPHA,BETA
# INTEGER INCX,INCY,LDA,N
# CHARACTER UPLO
# .. Array Arguments ..
# DOUBLE PRECISION A(LDA,*),X(*),Y(*)
function symv!(uplo::BlasChar, alpha::($elty), A::StridedMatrix{$elty}, x::StridedVector{$elty},
beta::($elty), y::StridedVector{$elty})
m, n = size(A)
if m != n error("symm!: matrix A is $m by $n but must be square") end
if m != length(x) || m != length(y) error("symm!: dimension mismatch") end
ccall(($(string(vfname)),libblas), Void,
(Ptr{Uint8}, Ptr{BlasInt}, Ptr{$elty}, Ptr{$elty}, Ptr{BlasInt},
Ptr{$elty}, Ptr{BlasInt}, Ptr{$elty}, Ptr{$elty}, Ptr{BlasInt}),
&uplo, &n, &alpha, A, &stride(A,2), x, &stride(x,1), &beta, y, &stride(y,1))
Y
end
function symv(uplo::BlasChar, alpha::($elty), A::StridedMatrix{$elty}, x::StridedVector{$elty})
symv!(uplo, alpha, A, x, zero($elty), similar(x))
end
function symv(uplo::BlasChar, A::StridedMatrix{$elty}, x::StridedVector{$elty})
symv(uplo, one($elty), A, x)
end
end
end

Expand Down
2 changes: 1 addition & 1 deletion test/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ core numbers strings unicode corelib hashing remote \
arrayops linalg fft dct sparse bitarray suitesparse arpack \
random math functional bigint bigfloat sorting \
statistics glpk linprog poly file Rmath remote zlib image \
iostring gzip integers spawn ccall parallel
iostring gzip integers spawn ccall blas parallel

$(TESTS) ::
$(QUIET_JULIA) $(JULIA_EXECUTABLE) ./runtests.jl $@
Expand Down
Loading