diff --git a/base/linalg/diagonal.jl b/base/linalg/diagonal.jl index 173a2236cf0e9..93533a4b02f64 100644 --- a/base/linalg/diagonal.jl +++ b/base/linalg/diagonal.jl @@ -345,9 +345,13 @@ end eye(::Type{Diagonal{T}}, n::Int) where {T} = Diagonal(ones(T,n)) # Matrix functions -exp(D::Diagonal) = Diagonal(exp.(D.diag)) -log(D::Diagonal) = Diagonal(log.(D.diag)) -sqrt(D::Diagonal) = Diagonal(sqrt.(D.diag)) +for f in (:exp, :log, :sqrt, + :cos, :sin, :tan, :csc, :sec, :cot, + :cosh, :sinh, :tanh, :csch, :sech, :coth, + :acos, :asin, :atan, :acsc, :asec, :acot, + :acosh, :asinh, :atanh, :acsch, :asech, :acoth) + @eval $f(D::Diagonal) = Diagonal($f.(D.diag)) +end #Linear solver function A_ldiv_B!(D::Diagonal, B::StridedVecOrMat) diff --git a/test/linalg/diagonal.jl b/test/linalg/diagonal.jl index 3ba9c7c757099..d57807c9eaf6c 100644 --- a/test/linalg/diagonal.jl +++ b/test/linalg/diagonal.jl @@ -71,13 +71,15 @@ srand(1) @test func(D) ≈ func(DM) atol=n^2*eps(relty)*(1+(elty<:Complex)) end if relty <: BlasFloat - for func in (exp,) + for func in (exp, sinh, cosh, tanh, sech, csch, coth) @test func(D) ≈ func(DM) atol=n^3*eps(relty) end @test log(Diagonal(abs.(D.diag))) ≈ log(abs.(DM)) atol=n^3*eps(relty) end if elty <: BlasComplex - for func in (logdet, sqrt) + for func in (logdet, sqrt, sin, cos, tan, sec, csc, cot, + asin, acos, atan, asec, acsc, acot, + asinh, acosh, atanh, asech, acsch, acoth) @test func(D) ≈ func(DM) atol=n^2*eps(relty)*2 end end