Skip to content

Commit

Permalink
LAPACK: This commit adds the ordschur functionality for generalized s…
Browse files Browse the repository at this point in the history
…chur decompositons by plugging into LAPACK's tgsen function.
  • Loading branch information
cc7768 committed Jan 12, 2015
1 parent dfbcf50 commit 493e681
Show file tree
Hide file tree
Showing 3 changed files with 196 additions and 45 deletions.
5 changes: 5 additions & 0 deletions base/linalg/factorization.jl
Original file line number Diff line number Diff line change
Expand Up @@ -721,6 +721,11 @@ schurfact!{T<:BlasFloat}(A::StridedMatrix{T}, B::StridedMatrix{T}) = Generalized
schurfact{T<:BlasFloat}(A::StridedMatrix{T},B::StridedMatrix{T}) = schurfact!(copy(A),copy(B))
schurfact{TA,TB}(A::StridedMatrix{TA}, B::StridedMatrix{TB}) = (S = promote_type(Float32,typeof(one(TA)/norm(one(TA))),TB); schurfact!(S != TA ? convert(AbstractMatrix{S},A) : copy(A), S != TB ? convert(AbstractMatrix{S},B) : copy(B)))

ordschur!{Ty<:BlasFloat}(S::StridedMatrix{Ty}, T::StridedMatrix{Ty}, Q::StridedMatrix{Ty}, Z::StridedMatrix{Ty}, select::Array{Int}) = GeneralizedSchur(LinAlg.LAPACK.tgsen!(select, S, T, Q, Z)...)
ordschur{Ty<:BlasFloat}(S::StridedMatrix{Ty}, T::StridedMatrix{Ty}, Q::StridedMatrix{Ty}, Z::StridedMatrix{Ty}, select::Array{Int}) = ordschur!(copy(S), copy(T), copy(Q), copy(Z), select)
ordschur!{Ty<:BlasFloat}(gschur::GeneralizedSchur{Ty}, select::Array{Int}) = (res=ordschur!(gschur.S, gschur.T, gschur.Q, gschur.Z, select); gschur[:alpha][:]=res[:alpha]; gschur[:beta][:]=res[:beta]; res)
ordschur{Ty<:BlasFloat}(gschur::GeneralizedSchur{Ty}, select::Array{Int}) = ordschur(gschur.S, gschur.T, gschur.Q, gschur.Z, select)

function getindex(F::GeneralizedSchur, d::Symbol)
d == :S && return F.S
d == :T && return F.T
Expand Down
141 changes: 130 additions & 11 deletions base/linalg/lapack.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3540,9 +3540,9 @@ for (gees, gges, elty, relty) in
end
end
# Reorder Schur forms
for (trsen, elty) in
((:dtrsen_,:Float64),
(:strsen_,:Float32))
for (trsen, tgsen, elty) in
((:dtrsen_, :dtgsen_, :Float64),
(:strsen_, :stgsen_, :Float32))
@eval begin
function trsen!(select::Array{Int}, T::StridedMatrix{$elty}, Q::StridedMatrix{$elty})
# * .. Scalar Arguments ..
Expand All @@ -3556,7 +3556,8 @@ for (trsen, elty) in
# DOUBLE PRECISION Q( LDQ, * ), T( LDT, * ), WI( * ), WORK( * ), WR( * )
chkstride1(T, Q)
n = chksquare(T)
ld = max(1, n)
ldt = max(1, stride(T, 2))
ldq = max(1, stride(Q, 2))
wr = similar(T, $elty, n)
wi = similar(T, $elty, n)
m = sum(select)
Expand All @@ -3572,10 +3573,10 @@ for (trsen, elty) in
(Ptr{BlasChar}, Ptr{BlasChar}, Ptr{BlasInt}, Ptr{BlasInt},
Ptr{$elty}, Ptr{BlasInt}, Ptr{$elty}, Ptr{BlasInt},
Ptr{$elty}, Ptr{$elty}, Ptr{BlasInt}, Ptr{Void}, Ptr{Void},
Ptr{$elty}, Ptr {BlasInt}, Ptr{BlasInt}, Ptr{BlasInt},
Ptr{$elty}, Ptr{BlasInt}, Ptr{BlasInt}, Ptr{BlasInt},
Ptr{BlasInt}),
&'N', &'V', select, &n,
T, &ld, Q, &ld,
T, &ldt, Q, &ldq,
wr, wi, &m, C_NULL, C_NULL,
work, &lwork, iwork, &liwork,
info)
Expand All @@ -3589,12 +3590,71 @@ for (trsen, elty) in
end
T, Q, all(wi .== 0) ? wr : complex(wr, wi)
end
function tgsen!(select::Array{Int}, S::StridedMatrix{$elty}, T::StridedMatrix{$elty},
Q::StridedMatrix{$elty}, Z::StridedMatrix{$elty})
# * .. Scalar Arguments ..
# * LOGICAL WANTQ, WANTZ
# * INTEGER IJOB, INFO, LDA, LDB, LDQ, LDZ, LIWORK, LWORK,
# * $ M, N
# * DOUBLE PRECISION PL, PR
# * ..
# * .. Array Arguments ..
# * LOGICAL SELECT( * )
# * INTEGER IWORK( * )
# * DOUBLE PRECISION A( LDA, * ), ALPHAI( * ), ALPHAR( * ),
# * $ B( LDB, * ), BETA( * ), DIF( * ), Q( LDQ, * ),
# * $ WORK( * ), Z( LDZ, * )
# * ..
chkstride1(S, T, Q, Z)
n, nt, nq, nz = chksquare(S, T, Q, Z)
n==nt==nq==nz || throw(DimensionMismatch("matrices are not of same size"))
lds = max(1, stride(S, 2))
ldt = max(1, stride(T, 2))
ldq = max(1, stride(Q, 2))
ldz = max(1, stride(Z, 2))
m = sum(select)
alphai = similar(T, $elty, n)
alphar = similar(T, $elty, n)
beta = similar(T, $elty, n)
lwork = blas_int(-1)
work = Array($elty, 1)
liwork = blas_int(-1)
iwork = Array(BlasInt, 1)
info = Array(BlasInt, 1)
select = convert(Array{BlasInt}, select)

for i = 1:2
ccall(($(blasfunc(tgsen)), liblapack), Void,
(Ptr{BlasInt}, Ptr{BlasInt}, Ptr{BlasInt}, Ptr{BlasInt},
Ptr{BlasInt}, Ptr{$elty}, Ptr{BlasInt}, Ptr{$elty},
Ptr{BlasInt}, Ptr{$elty}, Ptr{$elty}, Ptr{$elty},
Ptr{$elty}, Ptr{BlasInt}, Ptr{$elty}, Ptr{BlasInt},
Ptr{BlasInt}, Ptr{Void}, Ptr{Void}, Ptr{Void},
Ptr{$elty}, Ptr{BlasInt}, Ptr{BlasInt}, Ptr{BlasInt},
Ptr{BlasInt}),
&0, &1, &1, select,
&n, S, &lds, T,
&ldt, alphar, alphai, beta,
Q, &ldq, Z, &ldz,
&m, C_NULL, C_NULL, C_NULL,
work, &lwork, iwork, &liwork,
info)
@lapackerror
if i == 1 # only estimated optimal lwork, liwork
lwork = blas_int(real(work[1]))
work = Array($elty, lwork)
liwork = blas_int(real(iwork[1]))
iwork = Array(BlasInt, liwork)
end
end
S, T, complex(alphar, alphai), beta, Q, Z
end
end
end

for (trsen, elty) in
((:ztrsen_,:Complex128),
(:ctrsen_,:Complex64))
for (trsen, tgsen, elty) in
((:ztrsen_, :ztgsen_, :Complex128),
(:ctrsen_, :ctgsen_, :Complex64))
@eval begin
function trsen!(select::Array{Int}, T::StridedMatrix{$elty}, Q::StridedMatrix{$elty})
# * .. Scalar Arguments ..
Expand All @@ -3607,7 +3667,8 @@ for (trsen, elty) in
# COMPLEX Q( LDQ, * ), T( LDT, * ), W( * ), WORK( * )
chkstride1(T, Q)
n = chksquare(T)
ld = max(1, n)
ldt = max(1, stride(T, 2))
ldq = max(1, stride(Q, 2))
w = similar(T, $elty, n)
m = sum(select)
work = Array($elty, 1)
Expand All @@ -3623,7 +3684,7 @@ for (trsen, elty) in
Ptr{$elty}, Ptr {BlasInt},
Ptr{BlasInt}),
&'N', &'V', select, &n,
T, &ld, Q, &ld,
T, &ldt, Q, &ldq,
w, &m, C_NULL, C_NULL,
work, &lwork,
info)
Expand All @@ -3635,6 +3696,64 @@ for (trsen, elty) in
end
T, Q, w
end
function tgsen!(select::Array{Int}, S::StridedMatrix{$elty}, T::StridedMatrix{$elty},
Q::StridedMatrix{$elty}, Z::StridedMatrix{$elty})
# * .. Scalar Arguments ..
# * LOGICAL WANTQ, WANTZ
# * INTEGER IJOB, INFO, LDA, LDB, LDQ, LDZ, LIWORK, LWORK,
# * $ M, N
# * DOUBLE PRECISION PL, PR
# * ..
# * .. Array Arguments ..
# * LOGICAL SELECT( * )
# * INTEGER IWORK( * )
# * DOUBLE PRECISION DIF( * )
# * COMPLEX*16 A( LDA, * ), ALPHA( * ), B( LDB, * ),
# * $ BETA( * ), Q( LDQ, * ), WORK( * ), Z( LDZ, * )
# * ..
chkstride1(S, T, Q, Z)
n, nt, nq, nz = chksquare(S, T, Q, Z)
n==nt==nq==nz || throw(DimensionMismatch("matrices are not of same size"))
lds = max(1, stride(S, 2))
ldt = max(1, stride(T, 2))
ldq = max(1, stride(Q, 2))
ldz = max(1, stride(Z, 2))
m = sum(select)
alpha = similar(T, $elty, n)
beta = similar(T, $elty, n)
lwork = blas_int(-1)
work = Array($elty, 1)
liwork = blas_int(-1)
iwork = Array(BlasInt, 1)
info = Array(BlasInt, 1)
select = convert(Array{BlasInt}, select)

for i = 1:2
ccall(($(blasfunc(tgsen)), liblapack), Void,
(Ptr{BlasInt}, Ptr{BlasInt}, Ptr{BlasInt}, Ptr{BlasInt},
Ptr{BlasInt}, Ptr{$elty}, Ptr{BlasInt}, Ptr{$elty},
Ptr{BlasInt}, Ptr{$elty}, Ptr{$elty},
Ptr{$elty}, Ptr{BlasInt}, Ptr{$elty}, Ptr{BlasInt},
Ptr{BlasInt}, Ptr{Void}, Ptr{Void}, Ptr{Void},
Ptr{$elty}, Ptr{BlasInt}, Ptr{BlasInt}, Ptr{BlasInt},
Ptr{BlasInt}),
&0, &1, &1, select,
&n, S, &lds, T,
&ldt, alpha, beta,
Q, &ldq, Z, &ldz,
&m, C_NULL, C_NULL, C_NULL,
work, &lwork, iwork, &liwork,
info)
@lapackerror
if i == 1 # only estimated optimal lwork, liwork
lwork = blas_int(real(work[1]))
work = Array($elty, lwork)
liwork = blas_int(real(iwork[1]))
iwork = Array(BlasInt, liwork)
end
end
S, T, alpha, beta, Q, Z
end
end
end

Expand Down
95 changes: 61 additions & 34 deletions test/linalg1.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,11 @@ debug = false
import Base.LinAlg: BlasComplex, BlasFloat, BlasReal, QRPivoted

n = 10

# Split n into 2 parts for tests needing two matrices
n1 = div(n, 2)
n2 = 2*n1

srand(1234321)

a = rand(n,n)
Expand Down Expand Up @@ -110,12 +115,12 @@ debug && println("(Automatic) Square LU decomposition")
@test norm(a*(lua\b) - b, 1) < ε*κ*n*2 # Two because the right hand side has two columns

debug && println("Thin LU")
lua = lufact(a[:,1:5])
@test_approx_eq lua[:L]*lua[:U] lua[:P]*a[:,1:5]
lua = lufact(a[:,1:n1])
@test_approx_eq lua[:L]*lua[:U] lua[:P]*a[:,1:n1]

debug && println("Fat LU")
lua = lufact(a[1:5,:])
@test_approx_eq lua[:L]*lua[:U] lua[:P]*a[1:5,:]
lua = lufact(a[1:n1,:])
@test_approx_eq lua[:L]*lua[:U] lua[:P]*a[1:n1,:]

debug && println("QR decomposition (without pivoting)")
qra = qrfact(a, pivot=false)
Expand All @@ -126,23 +131,23 @@ debug && println("QR decomposition (without pivoting)")
@test_approx_eq_eps a*(qra\b) b 3000ε

debug && println("(Automatic) Fat (pivoted) QR decomposition") # Pivoting is only implemented for BlasFloats
qrpa = factorize(a[1:5,:])
qrpa = factorize(a[1:n1,:])
q,r = qrpa[:Q], qrpa[:R]
if isa(qrpa,QRPivoted) p = qrpa[:p] end # Reconsider if pivoted QR gets implemented in julia
@test_approx_eq q'*full(q, thin=false) eye(5)
@test_approx_eq q*full(q, thin=false)' eye(5)
@test_approx_eq q*r isa(qrpa,QRPivoted) ? a[1:5,p] : a[1:5,:]
@test_approx_eq isa(qrpa, QRPivoted) ? q*r[:,invperm(p)] : q*r a[1:5,:]
@test_approx_eq_eps a[1:5,:]*(qrpa\b[1:5]) b[1:5] 5000ε
@test_approx_eq q'*full(q, thin=false) eye(n1)
@test_approx_eq q*full(q, thin=false)' eye(n1)
@test_approx_eq q*r isa(qrpa,QRPivoted) ? a[1:n1,p] : a[1:n1,:]
@test_approx_eq isa(qrpa, QRPivoted) ? q*r[:,invperm(p)] : q*r a[1:n1,:]
@test_approx_eq_eps a[1:n1,:]*(qrpa\b[1:n1]) b[1:n1] 5000ε

debug && println("(Automatic) Thin (pivoted) QR decomposition") # Pivoting is only implemented for BlasFloats
qrpa = factorize(a[:,1:5])
qrpa = factorize(a[:,1:n1])
q,r = qrpa[:Q], qrpa[:R]
if isa(qrpa, QRPivoted) p = qrpa[:p] end # Reconsider if pivoted QR gets implemented in julia
@test_approx_eq q'*full(q, thin=false) eye(n)
@test_approx_eq q*full(q, thin=false)' eye(n)
@test_approx_eq q*r isa(qrpa, QRPivoted) ? a[:,p] : a[:,1:5]
@test_approx_eq isa(qrpa, QRPivoted) ? q*r[:,invperm(p)] : q*r a[:,1:5]
@test_approx_eq q*r isa(qrpa, QRPivoted) ? a[:,p] : a[:,1:n1]
@test_approx_eq isa(qrpa, QRPivoted) ? q*r[:,invperm(p)] : q*r a[:,1:n1]

debug && println("symmetric eigen-decomposition")
if eltya != BigFloat && eltyb != BigFloat # Revisit when implemented in julia
Expand All @@ -164,19 +169,22 @@ debug && println("non-symmetric eigen decomposition")

debug && println("symmetric generalized eigenproblem")
if eltya != BigFloat && eltyb != BigFloat # Revisit when implemented in julia
a610 = a[:,6:10]
f = eigfact(asym[1:5,1:5], a610'a610)
@test_approx_eq asym[1:5,1:5]*f[:vectors] scale(a610'a610*f[:vectors], f[:values])
@test_approx_eq f[:values] eigvals(asym[1:5,1:5], a610'a610)
@test_approx_eq_eps prod(f[:values]) prod(eigvals(asym[1:5,1:5]/(a610'a610))) 200ε
asym_sg = asym[1:n1, 1:n1]
a_sg = a[:,n1+1:n2]
f = eigfact(asym_sg, a_sg'a_sg)
@test_approx_eq asym_sg*f[:vectors] scale(a_sg'a_sg*f[:vectors], f[:values])
@test_approx_eq f[:values] eigvals(asym_sg, a_sg'a_sg)
@test_approx_eq_eps prod(f[:values]) prod(eigvals(asym_sg/(a_sg'a_sg))) 200ε
end

debug && println("Non-symmetric generalized eigenproblem")
if eltya != BigFloat && eltyb != BigFloat # Revisit when implemented in julia
f = eigfact(a[1:5,1:5], a[6:10,6:10])
@test_approx_eq a[1:5,1:5]*f[:vectors] scale(a[6:10,6:10]*f[:vectors], f[:values])
@test_approx_eq f[:values] eigvals(a[1:5,1:5], a[6:10,6:10])
@test_approx_eq_eps prod(f[:values]) prod(eigvals(a[1:5,1:5]/a[6:10,6:10])) 50000ε
a1_nsg = a[1:n1, 1:n1]
a2_nsg = a[n1+1:n2, n1+1:n2]
f = eigfact(a1_nsg, a2_nsg)
@test_approx_eq a1_nsg*f[:vectors] scale(a2_nsg*f[:vectors], f[:values])
@test_approx_eq f[:values] eigvals(a1_nsg, a2_nsg)
@test_approx_eq_eps prod(f[:values]) prod(eigvals(a1_nsg/a2_nsg)) 50000ε
end

debug && println("Schur")
Expand All @@ -202,13 +210,31 @@ debug && println("Reorder Schur")

debug && println("Generalized Schur")
if eltya != BigFloat && eltyb != BigFloat # Revisit when implemented in julia
f = schurfact(a[1:5,1:5], a[6:10,6:10])
@test_approx_eq f[:Q]*f[:S]*f[:Z]' a[1:5,1:5]
@test_approx_eq f[:Q]*f[:T]*f[:Z]' a[6:10,6:10]
a1_sf = a[1:n1, 1:n1]
a2_sf = a[n1+1:n2, n1+1:n2]
f = schurfact(a1_sf, a2_sf)
@test_approx_eq f[:Q]*f[:S]*f[:Z]' a1_sf
@test_approx_eq f[:Q]*f[:T]*f[:Z]' a2_sf
@test istriu(f[:S]) || iseltype(a,Real)
@test istriu(f[:T]) || iseltype(a,Real)
end

debug && println("Reorder Generalized Schur")
if eltya != BigFloat && eltyb != BigFloat # Revisit when implemented in Julia
a1_sf = a[1:n1, 1:n1]
a2_sf = a[n1+1:n2, n1+1:n2]
NS = schurfact(a1_sf, a2_sf)
# Currently just testing with selecting gen eig values < 1
select = int(real(NS[:values] .* conj(NS[:values])) .< 1)
m = sum(select)
S = ordschur(NS, select)
# Make sure that the new factorization stil factors matrix
@test_approx_eq S[:Q]*S[:S]*S[:Z]' a1_sf
@test_approx_eq S[:Q]*S[:T]*S[:Z]' a2_sf
# Make sure that we have sorted it correctly
@test_approx_eq NS[:values][find(select)] S[:values][1:m]
end

debug && println("singular value decomposition")
if eltya != BigFloat && eltyb != BigFloat # Revisit when implemented in julia
usv = svdfact(a)
Expand All @@ -217,9 +243,10 @@ debug && println("singular value decomposition")

debug && println("Generalized svd")
if eltya != BigFloat && eltyb != BigFloat # Revisit when implemented in julia
gsvd = svdfact(a,a[1:5,:])
a_svd = a[1:n1, :]
gsvd = svdfact(a,a_svd)
@test_approx_eq gsvd[:U]*gsvd[:D1]*gsvd[:R]*gsvd[:Q]' a
@test_approx_eq gsvd[:V]*gsvd[:D2]*gsvd[:R]*gsvd[:Q]' a[1:5,:]
@test_approx_eq gsvd[:V]*gsvd[:D2]*gsvd[:R]*gsvd[:Q]' a_svd
end

debug && println("Solve square general system of equations")
Expand All @@ -231,10 +258,10 @@ debug && println("Solve square general system of equations")

debug && println("Test nullspace")
if eltya != BigFloat && eltyb != BigFloat # Revisit when implemented in julia
a15null = nullspace(a[:,1:5]')
@test rank([a[:,1:5] a15null]) == 10
@test_approx_eq_eps norm(a[:,1:5]'a15null, Inf) zero(eltya) 300ε
@test_approx_eq_eps norm(a15null'a[:,1:5], Inf) zero(eltya) 400ε
a15null = nullspace(a[:,1:n1]')
@test rank([a[:,1:n1] a15null]) == 10
@test_approx_eq_eps norm(a[:,1:n1]'a15null, Inf) zero(eltya) 300ε
@test_approx_eq_eps norm(a15null'a[:,1:n1], Inf) zero(eltya) 400ε
@test size(nullspace(b), 2) == 0
end

Expand All @@ -244,9 +271,9 @@ debug && println("\ntype of a: ", eltya, "\n")

debug && println("Test pinv")
if eltya != BigFloat # Revisit when implemented in julia
pinva15 = pinv(a[:,1:5])
@test_approx_eq a[:,1:5]*pinva15*a[:,1:5] a[:,1:5]
@test_approx_eq pinva15*a[:,1:5]*pinva15 pinva15
pinva15 = pinv(a[:,1:n1])
@test_approx_eq a[:,1:n1]*pinva15*a[:,1:n1] a[:,1:n1]
@test_approx_eq pinva15*a[:,1:n1]*pinva15 pinva15
end

# if isreal(a)
Expand Down

0 comments on commit 493e681

Please sign in to comment.