Skip to content

Commit

Permalink
some updates; general svd rrule
Browse files Browse the repository at this point in the history
  • Loading branch information
Jutho committed May 23, 2024
1 parent c4f6a48 commit 6213881
Show file tree
Hide file tree
Showing 5 changed files with 237 additions and 173 deletions.
2 changes: 2 additions & 0 deletions ext/KrylovKitChainRulesCoreExt/KrylovKitChainRulesCoreExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ using ChainRulesCore
using LinearAlgebra
using VectorInterface

using KrylovKit: apply_normal, apply_adjoint

include("utilities.jl")
include("linsolve.jl")
include("eigsolve.jl")
Expand Down
73 changes: 47 additions & 26 deletions ext/KrylovKitChainRulesCoreExt/eigsolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -57,13 +57,6 @@ function ChainRulesCore.rrule(config::RuleConfig,

ws = compute_eigsolve_pullback_data(Δvals, Δvecs, view(vals, 1:n), view(vecs, 1:n),
info, which, fᴴ, T, alg_primal, alg_rrule)
# alg_rrule2 = Arnoldi(; tol=alg_rrule.tol, krylovdim=alg_rrule.krylovdim, maxiter=alg_rrule.maxiter, orth=alg_rrule.orth)
# ws2 = compute_eigsolve_pullback_data(Δvals, Δvecs, view(vals, 1:n), view(vecs, 1:n), info, which, fᴴ, T, alg_primal, alg_rrule2)
# for i = 1:n
# @show ws[i]
# @show ws2[i]
# end

∂f = construct∂f(ws)
return ∂self, ∂f, ∂x₀, ∂howmany, ∂which, ∂alg
end
Expand Down Expand Up @@ -106,22 +99,23 @@ function compute_eigsolve_pullback_data(Δvals, Δvecs, vals, vecs, info, which,
else
vdΔv = inner(v, Δv)
gaugeᵢ = abs(imag(vdΔv))
gaugeᵢ < alg_primal.tol ||
if gaugeᵢ > alg_primal.tol && alg_rrule.verbosity >= 1
@warn "`eigsolve` cotangent for eigenvector $i is sensitive to gauge choice: (|gaugeᵢ| = $gaugeᵢ)"

Check warning on line 103 in ext/KrylovKitChainRulesCoreExt/eigsolve.jl

View check run for this annotation

Codecov / codecov/patch

ext/KrylovKitChainRulesCoreExt/eigsolve.jl#L100-L103

Added lines #L100 - L103 were not covered by tests
end
Δv = add(Δv, v, -vdΔv)
b = (Δv, convert(T, Δλ))

Check warning on line 106 in ext/KrylovKitChainRulesCoreExt/eigsolve.jl

View check run for this annotation

Codecov / codecov/patch

ext/KrylovKitChainRulesCoreExt/eigsolve.jl#L105-L106

Added lines #L105 - L106 were not covered by tests
end
w, reverse_info = let λ = λ, v = v
linsolve(b, zerovector(b), alg_rrule) do x
x1, x2 = x
y1 = add!(add!(fᴴ(x1), x1, conj(λ), -1), v, x2)
y1 = VectorInterface.add!!(VectorInterface.add!!(fᴴ(x1), x1, conj(λ), -1), v, x2)
y2 = inner(v, x1)
return (y1, y2)

Check warning on line 113 in ext/KrylovKitChainRulesCoreExt/eigsolve.jl

View check run for this annotation

Codecov / codecov/patch

ext/KrylovKitChainRulesCoreExt/eigsolve.jl#L108-L113

Added lines #L108 - L113 were not covered by tests
end
end
if info.converged >= i && reverse_info.converged == 0
@warn "`eigsolve` cotangent linear problem ($i) did not converge, whereas the primal eigenvalue problem did"
elseif abs(w[2]) > alg_rrule.tol
if info.converged >= i && reverse_info.converged == 0 && alg_rrule.verbosity >= 0
@warn "`eigsolve` cotangent linear problem ($i) did not converge, whereas the primal eigenvalue problem did: normres = $(reverse_info.normres)"
elseif abs(w[2]) > alg_rrule.tol && alg_rrule.verbosity >= 0
@warn "`eigsolve` cotangent linear problem ($i) returns unexpected result: error = $(w[2])"

Check warning on line 119 in ext/KrylovKitChainRulesCoreExt/eigsolve.jl

View check run for this annotation

Codecov / codecov/patch

ext/KrylovKitChainRulesCoreExt/eigsolve.jl#L116-L119

Added lines #L116 - L119 were not covered by tests
end
ws[i] = w[1]
Expand Down Expand Up @@ -154,8 +148,9 @@ function compute_eigsolve_pullback_data(Δvals, Δvecs, vals, vecs, info, which,
mask = abs.(transpose(vals) .- vals) .< tol
gaugepart = VdΔV[mask] - Diagonal(real(diag(VdΔV)))[mask]
Δgauge = norm(gaugepart, Inf)
Δgauge < tol ||
if Δgauge > tol && alg_rrule.verbosity >= 1
@warn "`eigsolve` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)"

Check warning on line 152 in ext/KrylovKitChainRulesCoreExt/eigsolve.jl

View check run for this annotation

Codecov / codecov/patch

ext/KrylovKitChainRulesCoreExt/eigsolve.jl#L152

Added line #L152 was not covered by tests
end
VdΔV′ = VdΔV - G * Diagonal(diag(VdΔV) ./ diag(G))
aVdΔV = VdΔV′ .* conj.(safe_inv.(transpose(vals) .- vals, tol))
for i in 1:n
Expand Down Expand Up @@ -188,33 +183,45 @@ function compute_eigsolve_pullback_data(Δvals, Δvecs, vals, vecs, info, which,
end

W₀ = (zerovector(vecs[1]), one.(vals))
P = orthogonalcomplementprojector(vecs, n, Gc)
rvals, Ws, reverse_info = let P = P, ΔV = sylvesterarg
P = orthogonalprojector(vecs, n, Gc)
by, rev = KrylovKit.eigsort(which)
if (rev ? (by(vals[n]) < by(zero(vals[n]))) : (by(vals[n]) > by(zero(vals[n]))))
shift = 2*conj(vals[n])

Check warning on line 189 in ext/KrylovKitChainRulesCoreExt/eigsolve.jl

View check run for this annotation

Codecov / codecov/patch

ext/KrylovKitChainRulesCoreExt/eigsolve.jl#L189

Added line #L189 was not covered by tests
else
shift = zero(vals[n])
end
rvals, Ws, reverse_info = let P = P, ΔV = sylvesterarg, shift = shift
eigsolve(W₀, n, reverse_wich(which), alg_rrule) do W
w, x = W
w′ = fᴴ(P(w))
w₀ = P(w)
w′ = fᴴ(add(w, w₀, -1))
if !iszero(shift)
w′ = VectorInterface.add!!(w′, w₀, shift)

Check warning on line 199 in ext/KrylovKitChainRulesCoreExt/eigsolve.jl

View check run for this annotation

Codecov / codecov/patch

ext/KrylovKitChainRulesCoreExt/eigsolve.jl#L199

Added line #L199 was not covered by tests
end
@inbounds for i in 1:length(x) # length(x) = n but let us not use outer variables
w′ = VectorInterface.add!!(w′, ΔV[i], -x[i])
end
return (w′, conj.(vals) .* x)
end
end
if info.converged >= n && reverse_info.converged < n
if info.converged >= n && reverse_info.converged < n && alg_rrule.verbosity >= 0
@warn "`eigsolve` cotangent problem did not converge, whereas the primal eigenvalue problem did"

Check warning on line 208 in ext/KrylovKitChainRulesCoreExt/eigsolve.jl

View check run for this annotation

Codecov / codecov/patch

ext/KrylovKitChainRulesCoreExt/eigsolve.jl#L208

Added line #L208 was not covered by tests
end

# cleanup and construct final result
ws = zs
tol = alg_rrule.tol
Q = orthogonalcomplementprojector(vecs, n, Gc)
for i in 1:n
w, x = Ws[i]
_, ic = findmax(abs, x)
factor = 1 / x[ic]
x[ic] = zero(x[ic])
error = max(norm(x, Inf), abs(rvals[i] - conj(vals[ic])))
error < tol ||
if error > 5*tol && alg_rrule.verbosity >= 0
@warn "`eigsolve` cotangent linear problem ($ic) returns unexpected result: error = $error"

Check warning on line 222 in ext/KrylovKitChainRulesCoreExt/eigsolve.jl

View check run for this annotation

Codecov / codecov/patch

ext/KrylovKitChainRulesCoreExt/eigsolve.jl#L222

Added line #L222 was not covered by tests
ws[ic] = VectorInterface.add!!(zs[ic], P(w), -factor)
end
ws[ic] = VectorInterface.add!!(zs[ic], Q(w), -factor)
end
return ws
end
Expand All @@ -238,8 +245,9 @@ function compute_eigsolve_pullback_data(Δvals, Δvecs, vals, vecs, info, which,
mask = abs.(transpose(vals) .- vals) .< tol
gaugepart = view(aVdΔV, mask)
Δgauge = norm(gaugepart, Inf)
Δgauge < tol ||
if Δgauge > tol && alg_rrule.verbosity >= 1
@warn "`eigsolve` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)"

Check warning on line 249 in ext/KrylovKitChainRulesCoreExt/eigsolve.jl

View check run for this annotation

Codecov / codecov/patch

ext/KrylovKitChainRulesCoreExt/eigsolve.jl#L243-L249

Added lines #L243 - L249 were not covered by tests
end
aVdΔV .= aVdΔV .* safe_inv.(transpose(vals) .- vals, tol)
for i in 1:n
aVdΔV[i, i] += real(Δvals[i])
Expand Down Expand Up @@ -267,34 +275,47 @@ function compute_eigsolve_pullback_data(Δvals, Δvecs, vals, vecs, info, which,
sylvesterarg[i] = y
end

Check warning on line 276 in ext/KrylovKitChainRulesCoreExt/eigsolve.jl

View check run for this annotation

Codecov / codecov/patch

ext/KrylovKitChainRulesCoreExt/eigsolve.jl#L275-L276

Added lines #L275 - L276 were not covered by tests


W₀ = (zerovector(vecs[1]), one.(vals))
P = orthogonalcomplementprojector(vecs, n)
rvals, Ws, reverse_info = let P = P, ΔV = sylvesterarg
P = orthogonalprojector(vecs, n)
by, rev = KrylovKit.eigsort(which)
if (rev ? (by(vals[n]) < by(zero(vals[n]))) : (by(vals[n]) > by(zero(vals[n]))))
shift = 2*conj(vals[n])

Check warning on line 283 in ext/KrylovKitChainRulesCoreExt/eigsolve.jl

View check run for this annotation

Codecov / codecov/patch

ext/KrylovKitChainRulesCoreExt/eigsolve.jl#L279-L283

Added lines #L279 - L283 were not covered by tests
else
shift = zero(vals[n])

Check warning on line 285 in ext/KrylovKitChainRulesCoreExt/eigsolve.jl

View check run for this annotation

Codecov / codecov/patch

ext/KrylovKitChainRulesCoreExt/eigsolve.jl#L285

Added line #L285 was not covered by tests
end
rvals, Ws, reverse_info = let P = P, ΔV = sylvesterarg, shift = shift
eigsolve(W₀, n, reverse_wich(which), alg_rrule) do W
w, x = W
w′ = fᴴ(P(w))
w₀ = P(w)
w′ = fᴴ(add(w, w₀, -1))
if !iszero(shift)
w′ = VectorInterface.add!!(w′, w₀, shift)

Check warning on line 293 in ext/KrylovKitChainRulesCoreExt/eigsolve.jl

View check run for this annotation

Codecov / codecov/patch

ext/KrylovKitChainRulesCoreExt/eigsolve.jl#L287-L293

Added lines #L287 - L293 were not covered by tests
end
@inbounds for i in 1:length(x) # length(x) = n but let us not use outer variables
w′ = VectorInterface.add!!(w′, ΔV[i], -x[i])
end
return (w′, vals .* x)

Check warning on line 298 in ext/KrylovKitChainRulesCoreExt/eigsolve.jl

View check run for this annotation

Codecov / codecov/patch

ext/KrylovKitChainRulesCoreExt/eigsolve.jl#L295-L298

Added lines #L295 - L298 were not covered by tests
end
end
if info.converged >= n && reverse_info.converged < n
if info.converged >= n && reverse_info.converged < n && alg_rrule.verbosity >= 0
@warn "`eigsolve` cotangent problem did not converge, whereas the primal eigenvalue problem did"

Check warning on line 302 in ext/KrylovKitChainRulesCoreExt/eigsolve.jl

View check run for this annotation

Codecov / codecov/patch

ext/KrylovKitChainRulesCoreExt/eigsolve.jl#L301-L302

Added lines #L301 - L302 were not covered by tests
end

# cleanup and construct final result
ws = zs
tol = alg_rrule.tol
Q = orthogonalcomplementprojector(vecs, n)
for i in 1:n
w, x = Ws[i]
_, ic = findmax(abs, x)
factor = 1 / x[ic]
x[ic] = zero(x[ic])
error = max(norm(x, Inf), abs(rvals[i] - conj(vals[ic])))
error < tol ||
if error > 5*tol && alg_rrule.verbosity >= 0
@warn "`eigsolve` cotangent linear problem ($ic) returns unexpected result: error = $error"

Check warning on line 316 in ext/KrylovKitChainRulesCoreExt/eigsolve.jl

View check run for this annotation

Codecov / codecov/patch

ext/KrylovKitChainRulesCoreExt/eigsolve.jl#L306-L316

Added lines #L306 - L316 were not covered by tests
ws[ic] = VectorInterface.add!!(zs[ic], P(w), -factor)
end
ws[ic] = VectorInterface.add!!(zs[ic], Q(w), -factor)
end
return ws

Check warning on line 320 in ext/KrylovKitChainRulesCoreExt/eigsolve.jl

View check run for this annotation

Codecov / codecov/patch

ext/KrylovKitChainRulesCoreExt/eigsolve.jl#L318-L320

Added lines #L318 - L320 were not covered by tests
end
81 changes: 6 additions & 75 deletions ext/KrylovKitChainRulesCoreExt/linsolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ function ChainRulesCore.rrule(config::RuleConfig,
alg_primal,
a₀,
a₁; alg_rrule=alg_primal)

(x, info) = linsolve(f, b, x₀, alg_primal, a₀, a₁)
T, fᴴ, construct∂f = _prepare_inputs(config, f, (x,), alg_primal)

Expand All @@ -16,8 +17,9 @@ function ChainRulesCore.rrule(config::RuleConfig,
∂algorithm = NoTangent()
∂b, reverse_info = linsolve(fᴴ, x̄, (zero(a₀) * zero(a₁)) * x̄, alg_rrule, conj(a₀),
conj(a₁))
info.converged > 0 && reverse_info.converged == 0 &&
@warn "`linsolve` cotangent problem did not converge, whereas the primal linear problem did"
if info.converged > 0 && reverse_info.converged == 0 && alg_rrule.verbosity >= 0
@warn "`linsolve` cotangent problem did not converge, whereas the primal linear problem di: normres = $(reverse_info.normres)"

Check warning on line 21 in ext/KrylovKitChainRulesCoreExt/linsolve.jl

View check run for this annotation

Codecov / codecov/patch

ext/KrylovKitChainRulesCoreExt/linsolve.jl#L21

Added line #L21 was not covered by tests
end

∂f = construct∂f((scale(∂b, -conj(a₁)),))
∂a₀ = @thunk(-inner(x, ∂b))
Expand All @@ -31,77 +33,6 @@ function ChainRulesCore.rrule(config::RuleConfig,
return (x, info), linsolve_pullback
end

# function generate_linsolve_pullback(alg_rrule, A::AbstractMatrix, b::AbstractVector, a₀, a₁,
# x, info, alg_primal)
# project_A = ProjectTo(A)

# function linsolve_pullback(X̄)
# x̄ = unthunk(X̄[1])
# ∂self = NoTangent()
# ∂x₀ = ZeroTangent()
# ∂algorithm = NoTangent()
# ∂b, reverse_info = linsolve(A', x̄, (zero(a₀) * zero(a₁)) * x̄, alg_rrule, conj(a₀),
# conj(a₁))
# if info.converged > 0 && reverse_info.converged == 0
# @warn "The cotangent linear problem did not converge, whereas the primal linear problem did."
# end
# if A isa StridedMatrix
# ∂A = InplaceableThunk(Ā -> mul!(Ā, ∂b, x', -conj(a₁), true),
# @thunk(-conj(a₁) * ∂b * x'))
# else
# ∂A = @thunk(project_A(-conj(a₁) * ∂b * x'))
# end
# ∂a₀ = @thunk(-dot(x, ∂b))
# if a₀ == zero(a₀) && a₁ == one(a₁)
# ∂a₁ = @thunk(-dot(b, ∂b))
# else
# ∂a₁ = @thunk(-dot((b - a₀ * x) / a₁, ∂b))
# end
# return ∂self, ∂A, ∂b, ∂x₀, ∂algorithm, ∂a₀, ∂a₁
# end
# return linsolve_pullback
# end

# function generate_linsolve_pullback(config::RuleConfig{>:HasReverseMode},
# alg_rrule,
# f,
# b,
# a₀,
# a₁,
# x,
# info,
# alg_primal)

# # f defines a linear map => pullback defines action of the adjoint
# (y, f_pullback) = rrule_via_ad(config, f, x)
# fᴴ(xᴴ) = f_pullback(xᴴ)[2]
# # TODO can we avoid computing f_pullback if algorithm isa Union{CG,MINRES}?

# function linsolve_pullback(X̄)
# x̄ = unthunk(X̄[1])
# ∂self = NoTangent()
# ∂x₀ = ZeroTangent()
# ∂algorithm = NoTangent()
# T = VectorInterface.promote_scale(VectorInterface.promote_scale(x̄, a₀),
# scalartype(a₁))
# ∂b, reverse_info = linsolve(fᴴ, x̄, zerovector(x̄, T), alg_rrule, conj(a₀),
# conj(a₁))
# if reverse_info.converged == 0
# @warn "Linear problem for reverse rule did not converge." reverse_info
# end
# ∂f = @thunk(f_pullback(scale(∂b, -conj(a₁)))[1])
# ∂a₀ = @thunk(-inner(x, ∂b))
# # ∂a₁ = @thunk(-dot(f(x), ∂b))
# if a₀ == zero(a₀) && a₁ == one(a₁)
# ∂a₁ = @thunk(-inner(b, ∂b))
# else
# ∂a₁ = @thunk(-inner(scale!!(add(b, x, -a₀), inv(a₁)), ∂b))
# end
# return ∂self, ∂f, ∂b, ∂x₀, ∂algorithm, ∂a₀, ∂a₁
# end
# return linsolve_pullback
# end

# frule - currently untested

function ChainRulesCore.frule((_, ΔA, Δb, Δx₀, _, Δa₀, Δa₁)::Tuple, ::typeof(linsolve),

Check warning on line 38 in ext/KrylovKitChainRulesCoreExt/linsolve.jl

View check run for this annotation

Codecov / codecov/patch

ext/KrylovKitChainRulesCoreExt/linsolve.jl#L38

Added line #L38 was not covered by tests
Expand All @@ -120,7 +51,7 @@ function ChainRulesCore.frule((_, ΔA, Δb, Δx₀, _, Δa₀, Δa₁)::Tuple, :
rhs = mul!(rhs, ΔA, x, -a₁, true)

Check warning on line 51 in ext/KrylovKitChainRulesCoreExt/linsolve.jl

View check run for this annotation

Codecov / codecov/patch

ext/KrylovKitChainRulesCoreExt/linsolve.jl#L50-L51

Added lines #L50 - L51 were not covered by tests
end
(Δx, forward_info) = linsolve(A, rhs, zerovector(rhs), algorithm, a₀, a₁)
if info.converged > 0 && forward_info.converged == 0
if info.converged > 0 && forward_info.converged == 0 && alg_rrule.verbosity >= 0
@warn "The tangent linear problem did not converge, whereas the primal linear problem did."

Check warning on line 55 in ext/KrylovKitChainRulesCoreExt/linsolve.jl

View check run for this annotation

Codecov / codecov/patch

ext/KrylovKitChainRulesCoreExt/linsolve.jl#L53-L55

Added lines #L53 - L55 were not covered by tests
end
return (x, info), (Δx, NoTangent())

Check warning on line 57 in ext/KrylovKitChainRulesCoreExt/linsolve.jl

View check run for this annotation

Codecov / codecov/patch

ext/KrylovKitChainRulesCoreExt/linsolve.jl#L57

Added line #L57 was not covered by tests
Expand Down Expand Up @@ -150,7 +81,7 @@ function ChainRulesCore.frule(config::RuleConfig{>:HasForwardsMode},
rhs = add!!(rhs, frule_via_ad(config, (Δf, ZeroTangent()), f, x), -a₀)

Check warning on line 81 in ext/KrylovKitChainRulesCoreExt/linsolve.jl

View check run for this annotation

Codecov / codecov/patch

ext/KrylovKitChainRulesCoreExt/linsolve.jl#L80-L81

Added lines #L80 - L81 were not covered by tests
end
(Δx, forward_info) = linsolve(f, rhs, zerovector(rhs), algorithm, a₀, a₁)
if info.converged > 0 && forward_info.converged == 0
if info.converged > 0 && forward_info.converged == 0 && alg_rrule.verbosity >= 0
@warn "The tangent linear problem did not converge, whereas the primal linear problem did."

Check warning on line 85 in ext/KrylovKitChainRulesCoreExt/linsolve.jl

View check run for this annotation

Codecov / codecov/patch

ext/KrylovKitChainRulesCoreExt/linsolve.jl#L83-L85

Added lines #L83 - L85 were not covered by tests
end
return (x, info), (Δx, NoTangent())

Check warning on line 87 in ext/KrylovKitChainRulesCoreExt/linsolve.jl

View check run for this annotation

Codecov / codecov/patch

ext/KrylovKitChainRulesCoreExt/linsolve.jl#L87

Added line #L87 was not covered by tests
Expand Down
Loading

0 comments on commit 6213881

Please sign in to comment.