From e6194547e8cc15b212f427ea5629b2a3e613da6b Mon Sep 17 00:00:00 2001 From: Jishnu Bhattacharya Date: Thu, 7 Nov 2024 13:19:24 +0530 Subject: [PATCH] Check isdiag in dense trig functions --- stdlib/LinearAlgebra/src/dense.jl | 65 +++++++++++++++++++++++------- stdlib/LinearAlgebra/test/dense.jl | 37 ++++++++--------- 2 files changed, 69 insertions(+), 33 deletions(-) diff --git a/stdlib/LinearAlgebra/src/dense.jl b/stdlib/LinearAlgebra/src/dense.jl index b8d5c84c3db53f..e8bf2b0d12ee11 100644 --- a/stdlib/LinearAlgebra/src/dense.jl +++ b/stdlib/LinearAlgebra/src/dense.jl @@ -1004,9 +1004,16 @@ end cbrt(A::AdjointAbsMat) = adjoint(cbrt(parent(A))) cbrt(A::TransposeAbsMat) = transpose(cbrt(parent(A))) +function applydiagonal(f, A) + dinv = f(Diagonal(diag(A))) + copyto!(similar(A, eltype(dinv)), dinv) +end + function inv(A::StridedMatrix{T}) where T checksquare(A) - if istriu(A) + if isdiag(A) + Ai = applydiagonal(inv, A) + elseif istriu(A) Ai = triu!(parent(inv(UpperTriangular(A)))) elseif istril(A) Ai = tril!(parent(inv(LowerTriangular(A)))) @@ -1034,14 +1041,18 @@ julia> cos(fill(1.0, (2,2))) ``` """ function cos(A::AbstractMatrix{<:Real}) - if issymmetric(A) + if isdiag(A) + return applydiagonal(cos, A) + elseif issymmetric(A) return copytri!(parent(cos(Symmetric(A))), 'U') end T = complex(float(eltype(A))) return real(exp!(T.(im .* A))) end function cos(A::AbstractMatrix{<:Complex}) - if ishermitian(A) + if isdiag(A) + return applydiagonal(cos, A) + elseif ishermitian(A) return copytri!(parent(cos(Hermitian(A))), 'U', true) end T = complex(float(eltype(A))) @@ -1067,14 +1078,18 @@ julia> sin(fill(1.0, (2,2))) ``` """ function sin(A::AbstractMatrix{<:Real}) - if issymmetric(A) + if isdiag(A) + return applydiagonal(sin, A) + elseif issymmetric(A) return copytri!(parent(sin(Symmetric(A))), 'U') end T = complex(float(eltype(A))) return imag(exp!(T.(im .* A))) end function sin(A::AbstractMatrix{<:Complex}) - if ishermitian(A) + if isdiag(A) + return applydiagonal(sin, A) + elseif ishermitian(A) return copytri!(parent(sin(Hermitian(A))), 'U', true) end T = complex(float(eltype(A))) @@ -1153,7 +1168,9 @@ julia> tan(fill(1.0, (2,2))) ``` """ function tan(A::AbstractMatrix) - if ishermitian(A) + if isdiag(A) + return applydiagonal(tan, A) + elseif ishermitian(A) return copytri!(parent(tan(Hermitian(A))), 'U', true) end S, C = sincos(A) @@ -1167,7 +1184,9 @@ end Compute the matrix hyperbolic cosine of a square matrix `A`. """ function cosh(A::AbstractMatrix) - if ishermitian(A) + if isdiag(A) + return applydiagonal(cosh, A) + elseif ishermitian(A) return copytri!(parent(cosh(Hermitian(A))), 'U', true) end X = exp(A) @@ -1181,7 +1200,9 @@ end Compute the matrix hyperbolic sine of a square matrix `A`. """ function sinh(A::AbstractMatrix) - if ishermitian(A) + if isdiag(A) + return applydiagonal(sinh, A) + elseif ishermitian(A) return copytri!(parent(sinh(Hermitian(A))), 'U', true) end X = exp(A) @@ -1195,7 +1216,9 @@ end Compute the matrix hyperbolic tangent of a square matrix `A`. """ function tanh(A::AbstractMatrix) - if ishermitian(A) + if isdiag(A) + return applydiagonal(tanh, A) + elseif ishermitian(A) return copytri!(parent(tanh(Hermitian(A))), 'U', true) end X = exp(A) @@ -1230,7 +1253,9 @@ julia> acos(cos([0.5 0.1; -0.2 0.3])) ``` """ function acos(A::AbstractMatrix) - if ishermitian(A) + if isdiag(A) + return applydiagonal(acos, A) + elseif ishermitian(A) acosHermA = acos(Hermitian(A)) return isa(acosHermA, Hermitian) ? copytri!(parent(acosHermA), 'U', true) : parent(acosHermA) end @@ -1261,7 +1286,9 @@ julia> asin(sin([0.5 0.1; -0.2 0.3])) ``` """ function asin(A::AbstractMatrix) - if ishermitian(A) + if isdiag(A) + return applydiagonal(asin, A) + elseif ishermitian(A) asinHermA = asin(Hermitian(A)) return isa(asinHermA, Hermitian) ? copytri!(parent(asinHermA), 'U', true) : parent(asinHermA) end @@ -1292,7 +1319,9 @@ julia> atan(tan([0.5 0.1; -0.2 0.3])) ``` """ function atan(A::AbstractMatrix) - if ishermitian(A) + if isdiag(A) + return applydiagonal(atan, A) + elseif ishermitian(A) return copytri!(parent(atan(Hermitian(A))), 'U', true) end SchurF = Schur{Complex}(schur(A)) @@ -1310,7 +1339,9 @@ logarithmic formulas used to compute this function, see [^AH16_4]. [^AH16_4]: Mary Aprahamian and Nicholas J. Higham, "Matrix Inverse Trigonometric and Inverse Hyperbolic Functions: Theory and Algorithms", MIMS EPrint: 2016.4. [https://doi.org/10.1137/16M1057577](https://doi.org/10.1137/16M1057577) """ function acosh(A::AbstractMatrix) - if ishermitian(A) + if isdiag(A) + return applydiagonal(acosh, A) + elseif ishermitian(A) acoshHermA = acosh(Hermitian(A)) return isa(acoshHermA, Hermitian) ? copytri!(parent(acoshHermA), 'U', true) : parent(acoshHermA) end @@ -1329,7 +1360,9 @@ logarithmic formulas used to compute this function, see [^AH16_5]. [^AH16_5]: Mary Aprahamian and Nicholas J. Higham, "Matrix Inverse Trigonometric and Inverse Hyperbolic Functions: Theory and Algorithms", MIMS EPrint: 2016.4. [https://doi.org/10.1137/16M1057577](https://doi.org/10.1137/16M1057577) """ function asinh(A::AbstractMatrix) - if ishermitian(A) + if isdiag(A) + return applydiagonal(asinh, A) + elseif ishermitian(A) return copytri!(parent(asinh(Hermitian(A))), 'U', true) end SchurF = Schur{Complex}(schur(A)) @@ -1347,7 +1380,9 @@ logarithmic formulas used to compute this function, see [^AH16_6]. [^AH16_6]: Mary Aprahamian and Nicholas J. Higham, "Matrix Inverse Trigonometric and Inverse Hyperbolic Functions: Theory and Algorithms", MIMS EPrint: 2016.4. [https://doi.org/10.1137/16M1057577](https://doi.org/10.1137/16M1057577) """ function atanh(A::AbstractMatrix) - if ishermitian(A) + if isdiag(A) + return applydiagonal(atanh, A) + elseif ishermitian(A) return copytri!(parent(atanh(Hermitian(A))), 'U', true) end SchurF = Schur{Complex}(schur(A)) diff --git a/stdlib/LinearAlgebra/test/dense.jl b/stdlib/LinearAlgebra/test/dense.jl index 1d43d768993921..10f50a80ab7fdf 100644 --- a/stdlib/LinearAlgebra/test/dense.jl +++ b/stdlib/LinearAlgebra/test/dense.jl @@ -607,6 +607,7 @@ end -0.4579038628067864 1.7361475641080275 6.478801851038108]) A3 = convert(Matrix{elty}, [0.25 0.25; 0 0]) A4 = convert(Matrix{elty}, [0 0.02; 0 0]) + A5 = convert(Matrix{elty}, [2.0 0; 0 3.0]) cosA1 = convert(Matrix{elty},[-0.18287716254368605 -0.29517205254584633 0.761711400552759; 0.23326967400345625 0.19797853773269333 -0.14758602627292305; @@ -614,8 +615,8 @@ end sinA1 = convert(Matrix{elty}, [0.2865568596627417 -1.107751980582015 -0.13772915374386513; -0.6227405671629401 0.2176922827908092 -0.5538759902910078; -0.6227405671629398 -0.6916051440348725 0.3554214365346742]) - @test cos(A1) ≈ cosA1 - @test sin(A1) ≈ sinA1 + @test @inferred(cos(A1)) ≈ cosA1 + @test @inferred(sin(A1)) ≈ sinA1 cosA2 = convert(Matrix{elty}, [-0.6331745163802187 0.12878366262380136 -0.17304181968301532; 0.12878366262380136 -0.5596234510748788 0.5210483146041339; @@ -637,36 +638,36 @@ end @test sin(A4) ≈ sinA4 # Identities - for (i, A) in enumerate((A1, A2, A3, A4)) - @test sincos(A) == (sin(A), cos(A)) + for (i, A) in enumerate((A1, A2, A3, A4, A5)) + @test @inferred(sincos(A)) == (sin(A), cos(A)) @test cos(A)^2 + sin(A)^2 ≈ Matrix(I, size(A)) @test cos(A) ≈ cos(-A) @test sin(A) ≈ -sin(-A) - @test tan(A) ≈ sin(A) / cos(A) + @test @inferred(tan(A)) ≈ sin(A) / cos(A) @test cos(A) ≈ real(exp(im*A)) @test sin(A) ≈ imag(exp(im*A)) @test cos(A) ≈ real(cis(A)) @test sin(A) ≈ imag(cis(A)) - @test cis(A) ≈ cos(A) + im * sin(A) + @test @inferred(cis(A)) ≈ cos(A) + im * sin(A) - @test cosh(A) ≈ 0.5 * (exp(A) + exp(-A)) - @test sinh(A) ≈ 0.5 * (exp(A) - exp(-A)) - @test cosh(A) ≈ cosh(-A) - @test sinh(A) ≈ -sinh(-A) + @test @inferred(cosh(A)) ≈ 0.5 * (exp(A) + exp(-A)) + @test @inferred(sinh(A)) ≈ 0.5 * (exp(A) - exp(-A)) + @test @inferred(cosh(A)) ≈ cosh(-A) + @test @inferred(sinh(A)) ≈ -sinh(-A) # Some of the following identities fail for A3, A4 because the matrices are singular - if i in (1, 2) - @test sec(A) ≈ inv(cos(A)) - @test csc(A) ≈ inv(sin(A)) - @test cot(A) ≈ inv(tan(A)) - @test sech(A) ≈ inv(cosh(A)) - @test csch(A) ≈ inv(sinh(A)) - @test coth(A) ≈ inv(tanh(A)) + if i in (1, 2, 5) + @test @inferred(sec(A)) ≈ inv(cos(A)) + @test @inferred(csc(A)) ≈ inv(sin(A)) + @test @inferred(cot(A)) ≈ inv(tan(A)) + @test @inferred(sech(A)) ≈ inv(cosh(A)) + @test @inferred(csch(A)) ≈ inv(sinh(A)) + @test @inferred(coth(A)) ≈ inv(@inferred tanh(A)) end # The following identities fail for A1, A2 due to rounding errors; # probably needs better algorithm for the general case - if i in (3, 4) + if i in (3, 4, 5) @test cosh(A)^2 - sinh(A)^2 ≈ Matrix(I, size(A)) @test tanh(A) ≈ sinh(A) / cosh(A) end