Skip to content

Commit

Permalink
Support cis(A) for matrix A (JuliaLang#40194)
Browse files Browse the repository at this point in the history
  • Loading branch information
schneiderfelipe authored and antoine-levitt committed May 9, 2021
1 parent a0ac208 commit 22ce688
Show file tree
Hide file tree
Showing 8 changed files with 60 additions and 2 deletions.
1 change: 1 addition & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ Standard library changes
* On aarch64, OpenBLAS now uses an ILP64 BLAS like all other 64-bit platforms. ([#39436])
* OpenBLAS is updated to 0.3.13. ([#39216])
* SuiteSparse is updated to 5.8.1. ([#39455])
* `cis(A)` now supports matrix arguments ([#40194]).

#### Markdown

Expand Down
1 change: 1 addition & 0 deletions stdlib/LinearAlgebra/docs/src/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -410,6 +410,7 @@ LinearAlgebra.nullspace
Base.kron
Base.kron!
LinearAlgebra.exp(::StridedMatrix{<:LinearAlgebra.BlasFloat})
Base.cis(::AbstractMatrix)
Base.:^(::AbstractMatrix, ::Number)
Base.:^(::Number, ::AbstractMatrix)
LinearAlgebra.log(::StridedMatrix)
Expand Down
20 changes: 20 additions & 0 deletions stdlib/LinearAlgebra/src/dense.jl
Original file line number Diff line number Diff line change
Expand Up @@ -557,6 +557,26 @@ julia> exp(A)
exp(A::StridedMatrix{<:BlasFloat}) = exp!(copy(A))
exp(A::StridedMatrix{<:Union{Integer,Complex{<:Integer}}}) = exp!(float.(A))

"""
cis(A::AbstractMatrix)
Compute ``\\exp(i A)`` for a square matrix ``A``.
!!! compat "Julia 1.7"
Support for using `cis` with matrices was added in Julia 1.7.
# Examples
```jldoctest
julia> cis([π 0; 0 π]) ≈ -I
true
```
"""
Base.cis(A::AbstractMatrix) = exp(im * A) # fallback
Base.cis(A::AbstractMatrix{<:Base.HWNumber}) = exp_maybe_inplace(float.(im .* A))

exp_maybe_inplace(A::StridedMatrix{<:Union{ComplexF32, ComplexF64}}) = exp!(A)
exp_maybe_inplace(A) = exp(A)

"""
^(b::Number, A::AbstractMatrix)
Expand Down
2 changes: 1 addition & 1 deletion stdlib/LinearAlgebra/src/diagonal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -602,7 +602,7 @@ function logdet(D::Diagonal{<:Complex}) # make sure branch cut is correct
end

# Matrix functions
for f in (:exp, :log, :sqrt,
for f in (:exp, :cis, :log, :sqrt,
:cos, :sin, :tan, :csc, :sec, :cot,
:cosh, :sinh, :tanh, :csch, :sech, :coth,
:acos, :asin, :atan, :acsc, :asec, :acot,
Expand Down
6 changes: 6 additions & 0 deletions stdlib/LinearAlgebra/src/symmetric.jl
Original file line number Diff line number Diff line change
Expand Up @@ -770,6 +770,12 @@ for func in (:exp, :cos, :sin, :tan, :cosh, :sinh, :tanh, :atan, :asinh, :atanh)
end
end

function cis(A::Union{RealHermSymComplexHerm,SymTridiagonal{<:Real}})
F = eigen(A)
# The returned matrix is unitary, and is complex-symmetric for real A
return F.vectors .* cis.(F.values') * F.vectors'
end

for func in (:acos, :asin)
@eval begin
function ($func)(A::HermOrSym{<:Real})
Expand Down
19 changes: 19 additions & 0 deletions stdlib/LinearAlgebra/test/dense.jl
Original file line number Diff line number Diff line change
Expand Up @@ -477,6 +477,17 @@ end
[4.000000000000000 -1.414213562373094 -1.414213562373095
-1.414213562373095 4.999999999999996 -0.000000000000000
0 -0.000000000000002 3.000000000000000])

# cis always returns a complex matrix
if elty <: Real
eltyim = Complex{elty}
else
eltyim = elty
end

@test cis(A1) convert(Matrix{eltyim}, [-0.339938 + 0.000941506im 0.772659 - 0.8469im 0.52745 + 0.566543im;
0.650054 - 0.140179im -0.0762135 + 0.284213im 0.38633 - 0.42345im ;
0.650054 - 0.140179im 0.913779 + 0.143093im -0.603663 - 0.28233im ]) rtol=7e-7
end

@testset "Additional tests for $elty" for elty in (Float64, ComplexF64)
Expand Down Expand Up @@ -560,8 +571,13 @@ end
@test cos(A) cos(-A)
@test sin(A) -sin(-A)
@test tan(A) sin(A) / cos(A)

@test cos(A) real(exp(im*A))
@test sin(A) imag(exp(im*A))
@test cos(A) real(cis(A))
@test sin(A) imag(cis(A))
@test cis(A) cos(A) + im * sin(A)

@test cosh(A) 0.5 * (exp(A) + exp(-A))
@test sinh(A) 0.5 * (exp(A) - exp(-A))
@test cosh(A) cosh(-A)
Expand Down Expand Up @@ -605,6 +621,9 @@ end

@test cos(A5) 0.5 * (exp(im*A5) + exp(-im*A5))
@test sin(A5) -0.5im * (exp(im*A5) - exp(-im*A5))
@test cos(A5) 0.5 * (cis(A5) + cis(-A5))
@test sin(A5) -0.5im * (cis(A5) - cis(-A5))

@test cosh(A5) 0.5 * (exp(A5) + exp(-A5))
@test sinh(A5) 0.5 * (exp(A5) - exp(-A5))
end
Expand Down
7 changes: 6 additions & 1 deletion stdlib/LinearAlgebra/test/diagonal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ Random.seed!(1)
@test func(D) func(DM) atol=n^2*eps(relty)*(1+(elty<:Complex))
end
if relty <: BlasFloat
for func in (exp, sinh, cosh, tanh, sech, csch, coth)
for func in (exp, cis, sinh, cosh, tanh, sech, csch, coth)
@test func(D) func(DM) atol=n^3*eps(relty)
end
@test log(Diagonal(abs.(D.diag))) log(abs.(DM)) atol=n^3*eps(relty)
Expand All @@ -102,6 +102,10 @@ Random.seed!(1)
end
end

@testset "Two-dimensional Euler formula for Diagonal" begin
@test cis(Diagonal([π, π])) -I
end

@testset "Linear solve" begin
for (v, U) in ((vv, UU), (view(vv, 1:n), view(UU, 1:n, 1:2)))
@test D*v DM*v atol=n*eps(relty)*(1+(elty<:Complex))
Expand Down Expand Up @@ -568,6 +572,7 @@ end
@test ishermitian(Dsym) == false

@test exp(D) == Diagonal([exp([1 2; 3 4]), exp([1 2; 3 4])])
@test cis(D) == Diagonal([cis([1 2; 3 4]), cis([1 2; 3 4])])
@test log(D) == Diagonal([log([1 2; 3 4]), log([1 2; 3 4])])
@test sqrt(D) == Diagonal([sqrt([1 2; 3 4]), sqrt([1 2; 3 4])])

Expand Down
6 changes: 6 additions & 0 deletions stdlib/LinearAlgebra/test/symmetric.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,19 @@ Random.seed!(1010)
@test ishermitian(σ)
end

@testset "Two-dimensional Euler formula for Hermitian" begin
@test cis(Hermitian([π 0; 0 π])) -I
end

@testset "Hermitian matrix exponential/log" begin
A1 = randn(4,4) + im*randn(4,4)
A2 = A1 + A1'
@test exp(A2) exp(Hermitian(A2))
@test cis(A2) cis(Hermitian(A2))
@test log(A2) log(Hermitian(A2))
A3 = A1 * A1' # posdef
@test exp(A3) exp(Hermitian(A3))
@test cis(A3) cis(Hermitian(A3))
@test log(A3) log(Hermitian(A3))

A1 = randn(4,4)
Expand Down

0 comments on commit 22ce688

Please sign in to comment.