Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve latency of matrix-exp #998

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
93 changes: 47 additions & 46 deletions src/expm.jl
Original file line number Diff line number Diff line change
Expand Up @@ -77,53 +77,54 @@ function _exp(::Size, _A::StaticMatrix{<:Any,<:Any,T}) where T
A = S.(_A)
# omitted: matrix balancing, i.e., LAPACK.gebal!
nA = maximum(sum(abs.(A); dims=Val(1))) # marginally more performant than norm(A, 1)
## For sufficiently small nA, use lower order Padé-Approximations
if (nA <= 2.1)
A2 = A*A
if nA > 0.95
U = @evalpoly(A2, S(8821612800)*I, S(302702400)*I, S(2162160)*I, S(3960)*I, S(1)*I)
U = A*U
V = @evalpoly(A2, S(17643225600)*I, S(2075673600)*I, S(30270240)*I, S(110880)*I, S(90)*I)
elseif nA > 0.25
U = @evalpoly(A2, S(8648640)*I, S(277200)*I, S(1512)*I, S(1)*I)
U = A*U
V = @evalpoly(A2, S(17297280)*I, S(1995840)*I, S(25200)*I, S(56)*I)
elseif nA > 0.015
U = @evalpoly(A2, S(15120)*I, S(420)*I, S(1)*I)
U = A*U
V = @evalpoly(A2, S(30240)*I, S(3360)*I, S(30)*I)
else
U = @evalpoly(A2, S(60)*I, S(1)*I)
U = A*U
V = @evalpoly(A2, S(120)*I, S(12)*I)
end
expA = (V - U) \ (V + U)

if (nA ≤ 2.1) # for sufficiently small nA, use lower order Padé-Approximations
return _pade_exp(S, A, nA)
else
s = log2(nA/5.4) # power of 2 later reversed by squaring
if s > 0
si = ceil(Int,s)
A = A / S(2^si)
end

A2 = A*A
A4 = A2*A2
A6 = A2*A4

U = A6*(S(1)*A6 + S(16380)*A4 + S(40840800)*A2) +
(S(33522128640)*A6 + S(10559470521600)*A4 + S(1187353796428800)*A2) +
S(32382376266240000)*I
U = A*U
V = A6*(S(182)*A6 + S(960960)*A4 + S(1323241920)*A2) +
(S(670442572800)*A6 + S(129060195264000)*A4 + S(7771770303897600)*A2) +
S(64764752532480000)*I
expA = (V - U) \ (V + U)

if s > 0 # squaring to reverse dividing by power of 2
for t=1:si
expA = expA*expA
end
end
return _rescaled_exp(S, A, nA)
end
end

expA
# Rational approximation of the matrix exponential for inputs with a small norm.
#
# `S`  : type used to convert every numeric coefficient (called as `S(c)`).
# `A`  : the matrix to exponentiate.
# `nA` : norm estimate of `A`; it only selects which coefficient set is used.
#
# Two polynomials in `A*A` are evaluated, `p` and `q`. The result is
# `(q - A*p) \ (q + A*p)`. Larger `nA` selects longer coefficient lists.
function _pade_exp(S, A, nA)
    Asq = A*A
    if nA > 0.95
        # five coefficients per polynomial
        p = @evalpoly(Asq, S(8821612800)*I, S(302702400)*I, S(2162160)*I, S(3960)*I, S(1)*I)
        q = @evalpoly(Asq, S(17643225600)*I, S(2075673600)*I, S(30270240)*I, S(110880)*I, S(90)*I)
    elseif nA > 0.25
        # four coefficients per polynomial
        p = @evalpoly(Asq, S(8648640)*I, S(277200)*I, S(1512)*I, S(1)*I)
        q = @evalpoly(Asq, S(17297280)*I, S(1995840)*I, S(25200)*I, S(56)*I)
    elseif nA > 0.015
        # three coefficients per polynomial
        p = @evalpoly(Asq, S(15120)*I, S(420)*I, S(1)*I)
        q = @evalpoly(Asq, S(30240)*I, S(3360)*I, S(30)*I)
    else
        # two coefficients per polynomial
        p = @evalpoly(Asq, S(60)*I, S(1)*I)
        q = @evalpoly(Asq, S(120)*I, S(12)*I)
    end
    # Multiplying by A turns the even polynomial `p` into the odd part.
    oddpart = A*p
    return (q - oddpart) \ (q + oddpart)
end

# Matrix exponential by scaling and squaring, used when the norm of `A` is not small.
#
# `S`  : type used to convert every numeric coefficient (called as `S(c)`).
# `A`  : the matrix to exponentiate.
# `nA` : norm estimate of `A`; it determines how many times `A` is halved.
#
# `A` is divided by `2^si` so that `nA / 2^si` is at most 5.4. A rational
# approximation (polynomials up to `A^13`) is evaluated on the scaled matrix,
# and the result is squared `si` times to undo the scaling.
function _rescaled_exp(S, A, nA)
    s = log2(nA/5.4) # power of 2 later reversed by squaring
    # Only round when `s` is positive. This keeps `si` non-negative and avoids
    # calling `ceil(Int, NaN)`, which throws, when the norm is NaN.
    si = s > 0 ? ceil(Int, s) : 0
    if si > 0
        # Compute the power in `S`, not in `Int`: `2^si` silently wraps around
        # for si >= 63 (e.g. `2^65 == 0`), which would divide by zero.
        A /= S(2)^si
    end

    A2 = A*A
    A4 = A2*A2
    A6 = A2*A4

    # U is the odd part of the numerator (multiplied by A below), V the even part.
    U = A6*(S(1)*A6 + S(16380)*A4 + S(40840800)*A2) +
        (S(33522128640)*A6 + S(10559470521600)*A4 + S(1187353796428800)*A2) +
        S(32382376266240000)*I
    U = A*U
    V = A6*(S(182)*A6 + S(960960)*A4 + S(1323241920)*A2) +
        (S(670442572800)*A6 + S(129060195264000)*A4 + S(7771770303897600)*A2) +
        S(64764752532480000)*I
    expA = (V - U) \ (V + U)

    for _ in 1:si # squaring to reverse dividing by power of 2
        expA *= expA
    end
    return expA
end