Skip to content

Commit

Permalink
add Hermitian test and fix implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
Jutho committed May 29, 2024
1 parent e703eef commit 387d9dd
Show file tree
Hide file tree
Showing 2 changed files with 99 additions and 12 deletions.
3 changes: 2 additions & 1 deletion ext/KrylovKitChainRulesCoreExt/eigsolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -280,9 +280,10 @@ function compute_eigsolve_pullback_data(Δvals, Δvecs, vals, vecs, info, which,
end

# several simplications happen in the case of a Hermitian eigenvalue problem
function compute_eigsolve_pullback_data(Δvals, Δvecs, vals, vecs, info, which, fᴴ, T,
function compute_eigsolve_pullback_data(Δvals, Δvecs, vals, vecs, info, which, fᴴ,
alg_primal::Lanczos, alg_rrule::Arnoldi)
n = length(Δvecs)
T = scalartype(vecs[1])
VdΔV = zeros(T, n, n)
for j in 1:n
for i in 1:n
Expand Down
108 changes: 97 additions & 11 deletions test/ad.jl
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ end
(JA2, Jb2, Jx2) = Zygote.jacobian(mat_example_fun, Avec, bvec, xvec)

@test isapprox(JA, JA1; rtol=condA * sqrt(eps(real(T))))
@test all(isapprox.(JA1, JA2; atol=3 * eps(real(T))))
@test all(isapprox.(JA1, JA2; atol=n * eps(real(T))))
# factor 2 is minimally necessary for complex case, but 3 is more robust
@test norm(Jx, Inf) < condA * sqrt(eps(real(T)))
@test all(iszero, Jx1)
Expand Down Expand Up @@ -263,6 +263,60 @@ function build_fun_example(A, x, c, d, howmany::Int, which, alg, alg_rrule)
return fun_example_ad, fun_example_fd, Avec, xvec, cvec, dvec, vals, vecs, howmany
end

function build_hermitianfun_example(A, x, c, howmany::Int, which, alg, alg_rrule)
Avec, matfromvec = to_vec(A)
xvec, xvecfromvec = to_vec(x)
cvec, cvecfromvec = to_vec(c)

vals, vecs, info = eigsolve(x, howmany, which, alg) do y
return Hermitian(A) * y + c * dot(c, y)
end
info.converged < howmany && @warn "eigsolve did not converge"

function fun_example(Av, xv, cv)
= matfromvec(Av)
= xvecfromvec(xv)
= cvecfromvec(cv)

vals′, vecs′, info′ = eigsolve(x̃, howmany, which, alg;
alg_rrule=alg_rrule) do y
return Hermitian(Ã) * y +* dot(c̃, y)
end
info′.converged < howmany && @warn "eigsolve did not converge"
catresults = vcat(vals′[1:howmany], vecs′[1:howmany]...)
if eltype(catresults) <: Complex
return vcat(real(catresults), imag(catresults))
else
return catresults
end
end

function fun_example_fd(Av, xv, cv)
= matfromvec(Av)
= xvecfromvec(xv)
= cvecfromvec(cv)

vals′, vecs′, info′ = eigsolve(x̃, howmany, which, alg) do y
return Hermitian(Ã) * y +* dot(c̃, y)
end
info′.converged < howmany && @warn "eigsolve did not converge"
for i in 1:howmany
d = dot(vecs[i], vecs′[i])
@assert abs(d) > sqrt(eps(real(eltype(A))))
phasefix = abs(d) / d
vecs′[i] = vecs′[i] * phasefix
end
catresults = vcat(vals′[1:howmany], vecs′[1:howmany]...)
if eltype(catresults) <: Complex
return vcat(real(catresults), imag(catresults))
else
return catresults
end
end

return fun_example, fun_example_fd, Avec, xvec, cvec, vals, vecs, howmany
end

@timedtestset "Small eigsolve AD test for eltype=$T" for T in
(Float32, Float64, ComplexF32,
ComplexF64)
Expand Down Expand Up @@ -313,7 +367,7 @@ end

# finite difference comparison using some kind of tolerance heuristic
@test isapprox(JA, JA1; rtol=condA * sqrt(eps(real(T))))
@test all(isapprox.(JA1, JA2; atol=3 * eps(real(T))))
@test all(isapprox.(JA1, JA2; atol=n * eps(real(T))))
@test norm(Jx, Inf) < condA * sqrt(eps(real(T)))
@test all(iszero, Jx1)
@test all(iszero, Jx2)
Expand Down Expand Up @@ -361,24 +415,57 @@ end
alg_rrule1 = Arnoldi(; tol=tol, krylovdim=2n)
alg_rrule2 = GMRES(; tol=tol, krylovdim=2n)
@testset for alg_rrule in (alg_rrule1, alg_rrule2)
fun_example_ad, fun_example_fd, Avec, xvec, cvec, dvec, vals, vecs, howmany = build_fun_example(A,
fun_example, fun_example_fd, Avec, xvec, cvec, dvec, vals, vecs, howmany = build_fun_example(A,
x,
c,
d,
howmany,
which,
alg,
alg_rrule)

(JA, Jx, Jc, Jd) = FiniteDifferences.jacobian(fdm, fun_example_fd, Avec, xvec,
cvec, dvec)
(JA′, Jx′, Jc′, Jd′) = Zygote.jacobian(fun_example, Avec, xvec, cvec, dvec)
@test JA JA′
@test Jc Jc′
@test Jd Jd′
end
end
end
@timedtestset "Large Hermitian eigsolve AD test with eltype=$T" for T in
(Float64, ComplexF64)
whichlist = (:LR, :SR)
@testset for which in whichlist
A = rand(T, (N, N)) .- one(T) / 2
A = I - (9 // 10) * A / maximum(abs, eigvals(A))
x = 2 * (rand(T, N) .- one(T) / 2)
x /= norm(x)
c = 2 * (rand(T, N) .- one(T) / 2)

howmany = 2
tol = 2 * N^2 * eps(real(T))
alg = Lanczos(; tol=tol, krylovdim=2n)
alg_rrule1 = Arnoldi(; tol=tol, krylovdim=2n)
alg_rrule2 = GMRES(; tol=tol, krylovdim=2n)
@testset for alg_rrule in (alg_rrule1, alg_rrule2)
fun_example, fun_example_fd, Avec, xvec, cvec, vals, vecs, howmany = build_hermitianfun_example(A,
x,
c,
d,
howmany,
which,
alg,
alg_rrule)

(JA, Jx, Jc, Jd) = FiniteDifferences.jacobian(fdm, fun_example_fd, Avec, xvec,
cvec, dvec)
(JA′, Jx′, Jc′, Jd′) = Zygote.jacobian(fun_example_ad, Avec, xvec, cvec, dvec)
(JA, Jx, Jc) = FiniteDifferences.jacobian(fdm, fun_example_fd, Avec, xvec,
cvec)
(JA′, Jx′, Jc′) = Zygote.jacobian(fun_example, Avec, xvec, cvec)
@test JA JA′
@test Jc Jc′
@test Jd Jd′
end
end
end

end

module SvdsolveAD
Expand Down Expand Up @@ -522,7 +609,6 @@ function build_fun_example(A, x, c, d, howmany::Int, alg, alg_rrule)
return catresults
end
end

return fun_example_ad, fun_example_fd, Avec, xvec, cvec, dvec, vals, lvecs, rvecs
end

Expand Down Expand Up @@ -564,8 +650,8 @@ end

# finite difference comparison using some kind of tolerance heuristic
@test isapprox(JA, JA1; rtol=3 * n * n * condA * sqrt(eps(real(T))))
@test all(isapprox.(JA1, JA2; atol=3 * eps(real(T))))
@test all(isapprox.(JA1, JA3; atol=3 * eps(real(T))))
@test all(isapprox.(JA1, JA2; atol=n * eps(real(T))))
@test all(isapprox.(JA1, JA3; atol=n * eps(real(T))))
@test norm(Jx, Inf) < 4 * condA * sqrt(eps(real(T)))
@test all(iszero, Jx1)
@test all(iszero, Jx2)
Expand Down

0 comments on commit 387d9dd

Please sign in to comment.