diff --git a/base/linalg/dense.jl b/base/linalg/dense.jl index 3f32927d21a224..fa0326252398c2 100644 --- a/base/linalg/dense.jl +++ b/base/linalg/dense.jl @@ -331,7 +331,7 @@ function sqrtm{T<:Real}(A::StridedMatrix{T}) SchurF = schurfact(complex(A)) R = full(sqrtm(UpperTriangular(SchurF[:T]))) retmat = SchurF[:vectors]*R*SchurF[:vectors]' - all(imag(retmat) .== 0) ? real(retmat) : retmat + return retmat end function sqrtm{T<:Complex}(A::StridedMatrix{T}) if ishermitian(A) diff --git a/base/linalg/triangular.jl b/base/linalg/triangular.jl index 7c8dd9f393d8ee..82dc33e7ca6bc6 100644 --- a/base/linalg/triangular.jl +++ b/base/linalg/triangular.jl @@ -1108,11 +1108,25 @@ end logm(A::LowerTriangular) = logm(A.').' function sqrtm{T}(A::UpperTriangular{T}) - n = size(A, 1) - TT = typeof(sqrt(zero(T))) + n = chksquare(A) + realmatrix = false + if isreal(A) + realmatrix = true + for i = 1:n + if real(A[i,i]) < 0 + realmatrix = false + break + end + end + end + if realmatrix + TT = typeof(sqrt(zero(T))) + else + TT = typeof(sqrt(complex(-one(T)))) + end R = zeros(TT, n, n) for j = 1:n - R[j,j] = sqrt(A[j,j]) + R[j,j] = realmatrix?sqrt(A[j,j]):sqrt(complex(A[j,j])) for i = j-1:-1:1 r = A[i,j] for k = i+1:j-1 @@ -1124,7 +1138,7 @@ function sqrtm{T}(A::UpperTriangular{T}) return UpperTriangular(R) end function sqrtm{T}(A::UnitUpperTriangular{T}) - n = size(A, 1) + n = chksquare(A) TT = typeof(sqrt(zero(T))) R = zeros(TT, n, n) for j = 1:n