diff --git a/base/linalg/dense.jl b/base/linalg/dense.jl index 3f32927d21a224..13522941689ef5 100644 --- a/base/linalg/dense.jl +++ b/base/linalg/dense.jl @@ -319,7 +319,6 @@ function logm(A::StridedMatrix) return retmat end end - logm(a::Number) = (b = log(complex(a)); imag(b) == 0 ? real(b) : b) logm(a::Complex) = log(a) @@ -328,21 +327,27 @@ function sqrtm{T<:Real}(A::StridedMatrix{T}) return full(sqrtm(Symmetric(A))) end n = chksquare(A) - SchurF = schurfact(complex(A)) - R = full(sqrtm(UpperTriangular(SchurF[:T]))) - retmat = SchurF[:vectors]*R*SchurF[:vectors]' - all(imag(retmat) .== 0) ? real(retmat) : retmat + if istriu(A) + return full(sqrtm(UpperTriangular(A))) + else + SchurF = schurfact(complex(A)) + R = full(sqrtm(UpperTriangular(SchurF[:T]))) + return SchurF[:vectors] * R * SchurF[:vectors]' + end end function sqrtm{T<:Complex}(A::StridedMatrix{T}) if ishermitian(A) return full(sqrtm(Hermitian(A))) end n = chksquare(A) - SchurF = schurfact(A) - R = full(sqrtm(UpperTriangular(SchurF[:T]))) - SchurF[:vectors]*R*SchurF[:vectors]' + if istriu(A) + return full(sqrtm(UpperTriangular(A))) + else + SchurF = schurfact(A) + R = full(sqrtm(UpperTriangular(SchurF[:T]))) + return SchurF[:vectors] * R * SchurF[:vectors]' + end end - sqrtm(a::Number) = (b = sqrt(complex(a)); imag(b) == 0 ? real(b) : b) sqrtm(a::Complex) = sqrt(a) diff --git a/base/linalg/triangular.jl b/base/linalg/triangular.jl index 7c8dd9f393d8ee..82dc33e7ca6bc6 100644 --- a/base/linalg/triangular.jl +++ b/base/linalg/triangular.jl @@ -1108,11 +1108,25 @@ end logm(A::LowerTriangular) = logm(A.').' function sqrtm{T}(A::UpperTriangular{T}) - n = size(A, 1) - TT = typeof(sqrt(zero(T))) + n = chksquare(A) + realmatrix = false + if isreal(A) + realmatrix = true + for i = 1:n + if real(A[i,i]) < 0 + realmatrix = false + break + end + end + end + if realmatrix + TT = typeof(sqrt(zero(T))) + else + TT = typeof(sqrt(complex(-one(T)))) + end R = zeros(TT, n, n) for j = 1:n - R[j,j] = sqrt(A[j,j]) + R[j,j] = realmatrix?sqrt(A[j,j]):sqrt(complex(A[j,j])) for i = j-1:-1:1 r = A[i,j] for k = i+1:j-1 @@ -1124,7 +1138,7 @@ function sqrtm{T}(A::UpperTriangular{T}) return UpperTriangular(R) end function sqrtm{T}(A::UnitUpperTriangular{T}) - n = size(A, 1) + n = chksquare(A) TT = typeof(sqrt(zero(T))) R = zeros(TT, n, n) for j = 1:n diff --git a/test/linalg/triangular.jl b/test/linalg/triangular.jl index 926d8f6c232666..d8c9dfb9413b0c 100644 --- a/test/linalg/triangular.jl +++ b/test/linalg/triangular.jl @@ -238,6 +238,14 @@ for elty1 in (Float32, Float64, Complex64, Complex128, BigFloat, Int) end end +# Matrix square root +Atn = UpperTriangular([-1 1 2; 0 -2 2; 0 0 -3]) +Atp = UpperTriangular([1 1 2; 0 2 2; 0 0 3]) +@test_approx_eq sqrtm(Atn) |> t->t*t Atn +@test typeof(sqrtm(Atn)[1,1]) <: Complex +@test_approx_eq sqrtm(Atp) |> t->t*t Atp +@test typeof(sqrtm(Atp)[1,1]) <: Real + Areal = randn(n, n)/2 Aimg = randn(n, n)/2 A2real = randn(n, n)/2