Skip to content

Commit

Permalink
Fix naivesub Bidigonal to work for matrix elements by avoiding (#27203)
Browse files Browse the repository at this point in the history
calling zero(T). Also a few minor performance improvements.
  • Loading branch information
andreasnoack authored May 24, 2018
1 parent 14dbdf5 commit 3fa83f9
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 12 deletions.
43 changes: 31 additions & 12 deletions stdlib/LinearAlgebra/src/bidiag.jl
Original file line number Diff line number Diff line change
Expand Up @@ -551,20 +551,39 @@ function naivesub!(A::Bidiagonal{T}, b::AbstractVector, x::AbstractVector = b) w
if N != length(b) || N != length(x)
throw(DimensionMismatch("second dimension of A, $N, does not match one of the lengths of x, $(length(x)), or b, $(length(b))"))
end
if A.uplo == 'L' #do forward substitution
for j = 1:N
x[j] = b[j]
j > 1 && (x[j] -= A.ev[j-1] * x[j-1])
x[j] /= A.dv[j] == zero(T) ? throw(SingularException(j)) : A.dv[j]
end
else #do backward substitution
for j = N:-1:1
x[j] = b[j]
j < N && (x[j] -= A.ev[j] * x[j+1])
x[j] /= A.dv[j] == zero(T) ? throw(SingularException(j)) : A.dv[j]

if N == 0
return x
end

@inbounds begin
if A.uplo == 'L' #do forward substitution
x[1] = xj1 = A.dv[1]\b[1]
for j = 2:N
xj = b[j]
xj -= A.ev[j - 1] * xj1
dvj = A.dv[j]
if iszero(dvj)
throw(SingularException(j))
end
xj = dvj\xj
x[j] = xj1 = xj
end
else #do backward substitution
x[N] = xj1 = A.dv[N]\b[N]
for j = (N - 1):-1:1
xj = b[j]
xj -= A.ev[j] * xj1
dvj = A.dv[j]
if iszero(dvj)
throw(SingularException(j))
end
xj = dvj\xj
x[j] = xj1 = xj
end
end
end
x
return x
end

### Generic promotion methods and fallbacks
Expand Down
12 changes: 12 additions & 0 deletions stdlib/LinearAlgebra/test/bidiag.jl
Original file line number Diff line number Diff line change
Expand Up @@ -365,4 +365,16 @@ end
@test promote_type(Tridiagonal{Tuple{T}} where T<:Integer, Bidiagonal{Int}) <: Tridiagonal
end

@testset "solve with matrix elements" begin
A = triu(tril(randn(9, 9), 3), -3)
b = randn(9)
Alb = Bidiagonal(Any[tril(A[1:3,1:3]), tril(A[4:6,4:6]), tril(A[7:9,7:9])],
Any[triu(A[4:6,1:3]), triu(A[7:9,4:6])], 'L')
Aub = Bidiagonal(Any[triu(A[1:3,1:3]), triu(A[4:6,4:6]), triu(A[7:9,7:9])],
Any[tril(A[1:3,4:6]), tril(A[4:6,7:9])], 'U')
bb = Any[b[1:3], b[4:6], b[7:9]]
@test vcat((Alb\bb)...) LowerTriangular(A)\b
@test vcat((Aub\bb)...) UpperTriangular(A)\b
end

end # module TestBidiagonal

0 comments on commit 3fa83f9

Please sign in to comment.