From 387d9dd6dce7d3db98deb0acc072e4d6d6901284 Mon Sep 17 00:00:00 2001 From: Jutho Date: Thu, 30 May 2024 00:38:17 +0200 Subject: [PATCH] add Hermitian test and fix implementation --- ext/KrylovKitChainRulesCoreExt/eigsolve.jl | 3 +- test/ad.jl | 108 ++++++++++++++++++--- 2 files changed, 99 insertions(+), 12 deletions(-) diff --git a/ext/KrylovKitChainRulesCoreExt/eigsolve.jl b/ext/KrylovKitChainRulesCoreExt/eigsolve.jl index 89ead6d..7f62a54 100644 --- a/ext/KrylovKitChainRulesCoreExt/eigsolve.jl +++ b/ext/KrylovKitChainRulesCoreExt/eigsolve.jl @@ -280,9 +280,10 @@ function compute_eigsolve_pullback_data(Δvals, Δvecs, vals, vecs, info, which, end # several simplications happen in the case of a Hermitian eigenvalue problem -function compute_eigsolve_pullback_data(Δvals, Δvecs, vals, vecs, info, which, fᴴ, T, +function compute_eigsolve_pullback_data(Δvals, Δvecs, vals, vecs, info, which, fᴴ, alg_primal::Lanczos, alg_rrule::Arnoldi) n = length(Δvecs) + T = scalartype(vecs[1]) VdΔV = zeros(T, n, n) for j in 1:n for i in 1:n diff --git a/test/ad.jl b/test/ad.jl index 8fcbf92..e72c202 100644 --- a/test/ad.jl +++ b/test/ad.jl @@ -89,7 +89,7 @@ end (JA2, Jb2, Jx2) = Zygote.jacobian(mat_example_fun, Avec, bvec, xvec) @test isapprox(JA, JA1; rtol=condA * sqrt(eps(real(T)))) - @test all(isapprox.(JA1, JA2; atol=3 * eps(real(T)))) + @test all(isapprox.(JA1, JA2; atol=n * eps(real(T)))) # factor 2 is minimally necessary for complex case, but 3 is more robust @test norm(Jx, Inf) < condA * sqrt(eps(real(T))) @test all(iszero, Jx1) @@ -263,6 +263,60 @@ function build_fun_example(A, x, c, d, howmany::Int, which, alg, alg_rrule) return fun_example_ad, fun_example_fd, Avec, xvec, cvec, dvec, vals, vecs, howmany end +function build_hermitianfun_example(A, x, c, howmany::Int, which, alg, alg_rrule) + Avec, matfromvec = to_vec(A) + xvec, xvecfromvec = to_vec(x) + cvec, cvecfromvec = to_vec(c) + + vals, vecs, info = eigsolve(x, howmany, which, alg) do y + return Hermitian(A) * y + c * dot(c, y) + end + info.converged < howmany && @warn "eigsolve did not converge" + + function fun_example(Av, xv, cv) + Ã = matfromvec(Av) + x̃ = xvecfromvec(xv) + c̃ = cvecfromvec(cv) + + vals′, vecs′, info′ = eigsolve(x̃, howmany, which, alg; + alg_rrule=alg_rrule) do y + return Hermitian(Ã) * y + c̃ * dot(c̃, y) + end + info′.converged < howmany && @warn "eigsolve did not converge" + catresults = vcat(vals′[1:howmany], vecs′[1:howmany]...) + if eltype(catresults) <: Complex + return vcat(real(catresults), imag(catresults)) + else + return catresults + end + end + + function fun_example_fd(Av, xv, cv) + Ã = matfromvec(Av) + x̃ = xvecfromvec(xv) + c̃ = cvecfromvec(cv) + + vals′, vecs′, info′ = eigsolve(x̃, howmany, which, alg) do y + return Hermitian(Ã) * y + c̃ * dot(c̃, y) + end + info′.converged < howmany && @warn "eigsolve did not converge" + for i in 1:howmany + d = dot(vecs[i], vecs′[i]) + @assert abs(d) > sqrt(eps(real(eltype(A)))) + phasefix = abs(d) / d + vecs′[i] = vecs′[i] * phasefix + end + catresults = vcat(vals′[1:howmany], vecs′[1:howmany]...) + if eltype(catresults) <: Complex + return vcat(real(catresults), imag(catresults)) + else + return catresults + end + end + + return fun_example, fun_example_fd, Avec, xvec, cvec, vals, vecs, howmany +end + @timedtestset "Small eigsolve AD test for eltype=$T" for T in (Float32, Float64, ComplexF32, ComplexF64) @@ -313,7 +367,7 @@ end # finite difference comparison using some kind of tolerance heuristic @test isapprox(JA, JA1; rtol=condA * sqrt(eps(real(T)))) - @test all(isapprox.(JA1, JA2; atol=3 * eps(real(T)))) + @test all(isapprox.(JA1, JA2; atol=n * eps(real(T)))) @test norm(Jx, Inf) < condA * sqrt(eps(real(T))) @test all(iszero, Jx1) @test all(iszero, Jx2) @@ -361,24 +415,57 @@ end alg_rrule1 = Arnoldi(; tol=tol, krylovdim=2n) alg_rrule2 = GMRES(; tol=tol, krylovdim=2n) @testset for alg_rrule in (alg_rrule1, alg_rrule2) - fun_example_ad, fun_example_fd, Avec, xvec, cvec, dvec, vals, vecs, howmany = build_fun_example(A, + fun_example, fun_example_fd, Avec, xvec, cvec, dvec, vals, vecs, howmany = build_fun_example(A, + x, + c, + d, + howmany, + which, + alg, + alg_rrule) + + (JA, Jx, Jc, Jd) = FiniteDifferences.jacobian(fdm, fun_example_fd, Avec, xvec, + cvec, dvec) + (JA′, Jx′, Jc′, Jd′) = Zygote.jacobian(fun_example, Avec, xvec, cvec, dvec) + @test JA ≈ JA′ + @test Jc ≈ Jc′ + @test Jd ≈ Jd′ + end + end +end +@timedtestset "Large Hermitian eigsolve AD test with eltype=$T" for T in + (Float64, ComplexF64) + whichlist = (:LR, :SR) + @testset for which in whichlist + A = rand(T, (N, N)) .- one(T) / 2 + A = I - (9 // 10) * A / maximum(abs, eigvals(A)) + x = 2 * (rand(T, N) .- one(T) / 2) + x /= norm(x) + c = 2 * (rand(T, N) .- one(T) / 2) + + howmany = 2 + tol = 2 * N^2 * eps(real(T)) + alg = Lanczos(; tol=tol, krylovdim=2n) + alg_rrule1 = Arnoldi(; tol=tol, krylovdim=2n) + alg_rrule2 = GMRES(; tol=tol, krylovdim=2n) + @testset for alg_rrule in (alg_rrule1, alg_rrule2) + fun_example, fun_example_fd, Avec, xvec, cvec, vals, vecs, howmany = build_hermitianfun_example(A, x, c, - d, howmany, which, alg, alg_rrule) - (JA, Jx, Jc, Jd) = FiniteDifferences.jacobian(fdm, fun_example_fd, Avec, xvec, - cvec, dvec) - (JA′, Jx′, Jc′, Jd′) = Zygote.jacobian(fun_example_ad, Avec, xvec, cvec, dvec) + (JA, Jx, Jc) = FiniteDifferences.jacobian(fdm, fun_example_fd, Avec, xvec, + cvec) + (JA′, Jx′, Jc′) = Zygote.jacobian(fun_example, Avec, xvec, cvec) @test JA ≈ JA′ @test Jc ≈ Jc′ - @test Jd ≈ Jd′ end end end + end module SvdsolveAD @@ -522,7 +609,6 @@ function build_fun_example(A, x, c, d, howmany::Int, alg, alg_rrule) return catresults end end - return fun_example_ad, fun_example_fd, Avec, xvec, cvec, dvec, vals, lvecs, rvecs end @@ -564,8 +650,8 @@ end # finite difference comparison using some kind of tolerance heuristic @test isapprox(JA, JA1; rtol=3 * n * n * condA * sqrt(eps(real(T)))) - @test all(isapprox.(JA1, JA2; atol=3 * eps(real(T)))) - @test all(isapprox.(JA1, JA3; atol=3 * eps(real(T)))) + @test all(isapprox.(JA1, JA2; atol=n * eps(real(T)))) + @test all(isapprox.(JA1, JA3; atol=n * eps(real(T)))) @test norm(Jx, Inf) < 4 * condA * sqrt(eps(real(T))) @test all(iszero, Jx1) @test all(iszero, Jx2)