Skip to content

Commit

Permalink
make lu! allocate less if possible
Browse files Browse the repository at this point in the history
  • Loading branch information
SobhanMP committed Jun 5, 2022
1 parent 890eb43 commit 23bc7c9
Showing 1 changed file with 34 additions and 11 deletions.
45 changes: 34 additions & 11 deletions src/solvers/umfpack.jl
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,9 @@ function show_umf_info(control::Vector{Float64}, info::Vector{Float64}=umf_info,
end


# TODO, this actually doesn't need to be this big if iter refinement is off
worspace_W_size(S::SparseMatrixCSC{Float64}) = 5 * size(S, 2)
worspace_W_size(S::SparseMatrixCSC{ComplexF64}) = 10 * size(S, 2)


"""
Expand All @@ -152,9 +155,16 @@ struct UmfpackWS{T<:UMFITypes}
W::Vector{Float64}
end

# TODO, this actually doesn't need to be this big if iter refinement is off
UmfpackWS(S::SparseMatrixCSC{Float64, T}) where T = UmfpackWS{T}(Vector{T}(undef, size(S, 2)), Vector{Float64}(undef, 5 * size(S, 2)))
UmfpackWS(S::SparseMatrixCSC{ComplexF64, T}) where T = UmfpackWS{T}(Vector{T}(undef, size(S, 2)), Vector{Float64}(undef, 10 * size(S, 2)))
function Base.resize!(W::UmfpackWS, S::SparseMatrixCSC)
resize!(W.Wi, size(S, 2))
resize!(W.W, worspace_W_size(S))
end


UmfpackWS(S::SparseMatrixCSC{Tv, Ti}) where {Tv, Ti} = UmfpackWS{Ti}(
Vector{Ti}(undef, size(S, 2)),
Vector{Float64}(undef, worspace_W_size(S)))


Base.similar(w::UmfpackWS) = UmfpackWS(similar(w.Wi), similar(w.W))

Expand Down Expand Up @@ -338,19 +348,32 @@ julia> F \\ ones(2)
1.0
```
"""
function lu!(F::UmfpackLU, S::SparseMatrixCSC{<:UMFVTypes,<:UMFITypes}; check::Bool=true)
function lu!(F::UmfpackLU, S::SparseMatrixCSC{<:UMFVTypes,Ti}; check::Bool=true) where {Ti<:UMFITypes}
zerobased = getcolptr(S)[1] == 0
# resize workspace if needed
if F.n < size(S, 2)
F.workspace = UmfpackWS(S)
end

F.m = size(S, 1)
F.n = size(S, 2)
F.colptr = zerobased ? copy(getcolptr(S)) : decrement(getcolptr(S))
F.rowval = zerobased ? copy(rowvals(S)) : decrement(rowvals(S))
F.nzval = copy(nonzeros(S))

# resize workspace if needed
resize!(F.workspace, S)

resize!(F.colptr, length(getcolptr(S)))
if zerobased
copy!(F.colptr, getcolptr(S))
else
F.colptr .= getcolptr(S) .- one(Ti)
end

resize!(F.rowval, length(rowvals(S)))
if zerobased
copy!(F.rowval, rowvals(S))
else
F.rowval .= rowvals(S) .- one(Ti)
end

resize!(F.nzval, length(nonzeros(S)))
copy!(F.nzval, nonzeros(S))

umfpack_numeric!(F, reuse_numeric = false)
check && (issuccess(F) || throw(LinearAlgebra.SingularException(0)))
return F
Expand Down

0 comments on commit 23bc7c9

Please sign in to comment.