Skip to content

Commit

Permalink
Fix behaviour of sqrtm on UpperTriangular matrices
Browse files Browse the repository at this point in the history
  • Loading branch information
mfasi committed Aug 12, 2015
1 parent 66f336d commit 8527eb6
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 14 deletions.
23 changes: 14 additions & 9 deletions base/linalg/dense.jl
Original file line number Diff line number Diff line change
Expand Up @@ -319,7 +319,6 @@ function logm(A::StridedMatrix)
return retmat
end
end

logm(a::Number) = (b = log(complex(a)); imag(b) == 0 ? real(b) : b)
logm(a::Complex) = log(a)

Expand All @@ -328,21 +327,27 @@ function sqrtm{T<:Real}(A::StridedMatrix{T})
return full(sqrtm(Symmetric(A)))
end
n = chksquare(A)
SchurF = schurfact(complex(A))
R = full(sqrtm(UpperTriangular(SchurF[:T])))
retmat = SchurF[:vectors]*R*SchurF[:vectors]'
all(imag(retmat) .== 0) ? real(retmat) : retmat
if istriu(A)
return full(sqrtm(UpperTriangular(A)))
else
SchurF = schurfact(complex(A))
R = full(sqrtm(UpperTriangular(SchurF[:T])))
return SchurF[:vectors] * R * SchurF[:vectors]'
end
end
function sqrtm{T<:Complex}(A::StridedMatrix{T})
if ishermitian(A)
return full(sqrtm(Hermitian(A)))
end
n = chksquare(A)
SchurF = schurfact(A)
R = full(sqrtm(UpperTriangular(SchurF[:T])))
SchurF[:vectors]*R*SchurF[:vectors]'
if istriu(A)
return full(sqrtm(UpperTriangular(A)))
else
SchurF = schurfact(A)
R = full(sqrtm(UpperTriangular(SchurF[:T])))
return SchurF[:vectors] * R * SchurF[:vectors]'
end
end

sqrtm(a::Number) = (b = sqrt(complex(a)); imag(b) == 0 ? real(b) : b)
sqrtm(a::Complex) = sqrt(a)

Expand Down
22 changes: 18 additions & 4 deletions base/linalg/triangular.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1108,11 +1108,25 @@ end
logm(A::LowerTriangular) = logm(A.').'

function sqrtm{T}(A::UpperTriangular{T})
n = size(A, 1)
TT = typeof(sqrt(zero(T)))
n = chksquare(A)
realmatrix = false
if isreal(A)
realmatrix = true
for i = 1:n
if real(A[i,i]) < 0
realmatrix = false
break
end
end
end
if realmatrix
TT = typeof(sqrt(zero(T)))
else
TT = typeof(sqrt(complex(-one(T))))
end
R = zeros(TT, n, n)
for j = 1:n
R[j,j] = sqrt(A[j,j])
R[j,j] = realmatrix?sqrt(A[j,j]):sqrt(complex(A[j,j]))
for i = j-1:-1:1
r = A[i,j]
for k = i+1:j-1
Expand All @@ -1124,7 +1138,7 @@ function sqrtm{T}(A::UpperTriangular{T})
return UpperTriangular(R)
end
function sqrtm{T}(A::UnitUpperTriangular{T})
n = size(A, 1)
n = chksquare(A)
TT = typeof(sqrt(zero(T)))
R = zeros(TT, n, n)
for j = 1:n
Expand Down
10 changes: 9 additions & 1 deletion test/linalg/triangular.jl
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ for elty1 in (Float32, Float64, Complex64, Complex128, BigFloat, Int)
@test_approx_eq_eps det(A1) det(lufact(full(A1))) sqrt(eps(real(float(one(elty1)))))*n*n

# Matrix square root
@test_approx_eq sqrtm(A1) |> t->t*t A1
@test sqrtm(A1) |> t->t*t A1

# naivesub errors
@test_throws DimensionMismatch naivesub!(A1,ones(elty1,n+1))
Expand Down Expand Up @@ -238,6 +238,14 @@ for elty1 in (Float32, Float64, Complex64, Complex128, BigFloat, Int)
end
end

# Matrix square root
Atn = UpperTriangular([-1 1 2; 0 -2 2; 0 0 -3])
Atp = UpperTriangular([1 1 2; 0 2 2; 0 0 3])
@test sqrtm(Atn) |> t->t*t Atn
@test typeof(sqrtm(Atn)[1,1]) <: Complex
@test sqrtm(Atp) |> t->t*t Atp
@test typeof(sqrtm(Atp)[1,1]) <: Real

Areal = randn(n, n)/2
Aimg = randn(n, n)/2
A2real = randn(n, n)/2
Expand Down

0 comments on commit 8527eb6

Please sign in to comment.