From cad74f6165ecc503fb88adbb2f27391487baaebe Mon Sep 17 00:00:00 2001 From: Fredrik Ekre Date: Wed, 12 Jul 2017 01:48:52 +0200 Subject: [PATCH] use in-place inv, since lufact already made a copy --- base/linalg/dense.jl | 2 +- base/linalg/lu.jl | 6 ++++-- test/linalg/lu.jl | 1 + 3 files changed, 6 insertions(+), 3 deletions(-) diff --git a/base/linalg/dense.jl b/base/linalg/dense.jl index 46d78399e2061c..da370e26484985 100644 --- a/base/linalg/dense.jl +++ b/base/linalg/dense.jl @@ -653,7 +653,7 @@ function inv(A::StridedMatrix{T}) where T Ai = inv(LowerTriangular(AA)) Ai = convert(typeof(parent(Ai)), Ai) else - Ai = inv(lufact(AA)) + Ai = inv!(lufact(AA)) Ai = convert(typeof(parent(Ai)), Ai) end return Ai diff --git a/base/linalg/lu.jl b/base/linalg/lu.jl index 192900f92c69a8..be825281f0660a 100644 --- a/base/linalg/lu.jl +++ b/base/linalg/lu.jl @@ -202,6 +202,7 @@ convert(::Type{LU{T,S}}, F::LU) where {T,S} = LU{T,S}(convert(S, F.factors), F.i convert(::Type{Factorization{T}}, F::LU{T}) where {T} = F convert(::Type{Factorization{T}}, F::LU) where {T} = convert(LU{T}, F) +copy(A::LU{T,S}) where {T,S} = LU{T,S}(copy(A.factors), copy(A.ipiv), A.info) size(A::LU) = size(A.factors) size(A::LU,n) = size(A.factors,n) @@ -313,8 +314,9 @@ end inv!(A::LU{<:BlasFloat,<:StridedMatrix}) = @assertnonsingular LAPACK.getri!(A.factors, A.ipiv) A.info -inv(A::LU{<:BlasFloat,<:StridedMatrix}) = - inv!(LU(copy(A.factors), copy(A.ipiv), copy(A.info))) +inv!(A::LU{T,<:StridedMatrix}) where {T} = + @assertnonsingular A_ldiv_B!(A.factors, copy(A), eye(T, size(A, 1))) A.info +inv(A::LU{<:BlasFloat,<:StridedMatrix}) = inv!(copy(A)) function _cond1Inf(A::LU{<:BlasFloat,<:StridedMatrix}, p::Number, normA::Real) if p != 1 && p != Inf diff --git a/test/linalg/lu.jl b/test/linalg/lu.jl index 10c50463fa101a..541fb19ae24acd 100644 --- a/test/linalg/lu.jl +++ b/test/linalg/lu.jl @@ -64,6 +64,7 @@ debug && println("(Automatic) Square LU decomposition. eltya: $eltya, eltyb: $el @test l*u ≈ a[p,:] @test (l*u)[invperm(p),:] ≈ a @test a * inv(lua) ≈ eye(n) + @test copy(lua) == lua lstring = sprint(show,l) ustring = sprint(show,u)