Skip to content

Commit

Permalink
Update USYMLQR for thenew low-level API of Krylov.jl
Browse files Browse the repository at this point in the history
  • Loading branch information
amontoison committed Jan 5, 2025
1 parent 2c04c5d commit e167d3f
Showing 1 changed file with 49 additions and 49 deletions.
98 changes: 49 additions & 49 deletions src/usymlqr.jl
Original file line number Diff line number Diff line change
Expand Up @@ -174,43 +174,43 @@ kwargs_usymlqr = (:transfer_to_usymcg, :M, :N, :ldiv, :atol, :rtol, :itmax, :tim
itmax == 0 && (itmax = n+m)

# Initial solutions r₀, x₀, y₀ and z₀.
@kfill!(rₖ, zero(FC))
@kfill!(xₖ, zero(FC))
@kfill!(yₖ, zero(FC))
@kfill!(zₖ, zero(FC))
kfill!(rₖ, zero(FC))
kfill!(xₖ, zero(FC))
kfill!(yₖ, zero(FC))
kfill!(zₖ, zero(FC))

# Initialize preconditioned orthogonal tridiagonalization process.
@kfill!(M⁻¹uₖ₋₁, zero(FC)) # u₀ = 0
@kfill!(N⁻¹vₖ₋₁, zero(FC)) # v₀ = 0
kfill!(M⁻¹uₖ₋₁, zero(FC)) # u₀ = 0
kfill!(N⁻¹vₖ₋₁, zero(FC)) # v₀ = 0

# [ I A ] [ xₖ ] = [ b - Δx - AΔy ] = [ b₀ ]
# [ Aᴴ ] [ yₖ ] [ c - AᴴΔx ] [ c₀ ]
if warm_start
mul!(b₀, A, Δy)
@kaxpy!(m, one(T), Δx, b₀)
@kaxpby!(m, one(T), b, -one(T), b₀)
kaxpy!(m, one(T), Δx, b₀)
kaxpby!(m, one(T), b, -one(T), b₀)
mul!(c₀, Aᴴ, Δx)
@kaxpby!(n, one(T), c, -one(T), c₀)
kaxpby!(n, one(T), c, -one(T), c₀)
end

# β₁Eu₁ = b ↔ β₁u₁ = Mb
M⁻¹uₖ .= b₀
kcopy!(m, M⁻¹uₖ, b₀)
MisI || mul!(uₖ, M, M⁻¹uₖ)
βₖ = sqrt(@kdot(m, uₖ, M⁻¹uₖ)) # β₁ = ‖u₁‖_E
βₖ = knorm_elliptic(m, uₖ, M⁻¹uₖ) # β₁ = ‖u₁‖_E
if βₖ 0
@kscal!(m, 1 / βₖ, M⁻¹uₖ)
MisI || @kscal!(m, 1 / βₖ, uₖ)
kscal!(m, 1 / βₖ, M⁻¹uₖ)
MisI || kscal!(m, 1 / βₖ, uₖ)
else
error("b must be nonzero")
end

# γ₁Fv₁ = c ↔ γ₁v₁ = Nc
N⁻¹vₖ .= c₀
kcopy!(n, N⁻¹vₖ, c₀)
NisI || mul!(vₖ, N, N⁻¹vₖ)
γₖ = sqrt(@kdot(n, vₖ, N⁻¹vₖ)) # γ₁ = ‖v₁‖_F
γₖ = knorm_elliptic(n, vₖ, N⁻¹vₖ) # γ₁ = ‖v₁‖_F
if γₖ 0
@kscal!(n, 1 / γₖ, N⁻¹vₖ)
NisI || @kscal!(n, 1 / γₖ, vₖ)
kscal!(n, 1 / γₖ, N⁻¹vₖ)
NisI || kscal!(n, 1 / γₖ, vₖ)
else
error("c must be nonzero")
end
Expand All @@ -222,10 +222,10 @@ kwargs_usymlqr = (:transfer_to_usymcg, :M, :N, :ldiv, :atol, :rtol, :itmax, :tim

cₖ₋₂ = cₖ₋₁ = cₖ = one(T) # Givens cosines used for the QR factorization of Tₖ₊₁.ₖ
sₖ₋₂ = sₖ₋₁ = sₖ = zero(FC) # Givens sines used for the QR factorization of Tₖ₊₁.ₖ
@kfill!(wₖ₋₂, zero(FC)) # Column k-2 of Wₖ = Vₖ(Rₖ)⁻¹
@kfill!(wₖ₋₁, zero(FC)) # Column k-1 of Wₖ = Vₖ(Rₖ)⁻¹
kfill!(wₖ₋₂, zero(FC)) # Column k-2 of Wₖ = Vₖ(Rₖ)⁻¹
kfill!(wₖ₋₁, zero(FC)) # Column k-1 of Wₖ = Vₖ(Rₖ)⁻¹
ϕbarₖ = βₖ # ϕbarₖ is the last component of f̄ₖ = (Qₖ)ᴴβ₁e₁
@kfill!(d̅, zero(FC)) # Last column of D̅ₖ = UₖQₖ
kfill!(d̅, zero(FC)) # Last column of D̅ₖ = UₖQₖ
ηₖ₋₁ = ηbarₖ = zero(FC) # ηₖ₋₁ and ηbarₖ are the last components of h̄ₖ = (Rₖ)⁻ᵀγ₁e₁
ηₖ₋₂ = θₖ = zero(FC) # ζₖ₋₂ and θₖ are used to update ηₖ₋₁ and ηbarₖ
δbarₖ₋₁ = δbarₖ = zero(FC) # Coefficients of Rₖ₋₁ and Rₖ modified over the course of two iterations
Expand Down Expand Up @@ -255,21 +255,21 @@ kwargs_usymlqr = (:transfer_to_usymcg, :M, :N, :ldiv, :atol, :rtol, :itmax, :tim
mul!(p, Aᴴ, uₖ) # Forms Fvₖ₊₁ : p ← Aᴴuₖ

if iter 2
@kaxpy!(m, -γₖ, M⁻¹uₖ₋₁, q) # q ← q - γₖ * M⁻¹uₖ₋₁
@kaxpy!(n, -βₖ, N⁻¹vₖ₋₁, p) # p ← p - βₖ * N⁻¹vₖ₋₁
kaxpy!(m, -γₖ, M⁻¹uₖ₋₁, q) # q ← q - γₖ * M⁻¹uₖ₋₁
kaxpy!(n, -βₖ, N⁻¹vₖ₋₁, p) # p ← p - βₖ * N⁻¹vₖ₋₁
end

αₖ = @kdot(m, uₖ, q) # αₖ = ⟨uₖ,q⟩
αₖ = kdot(m, uₖ, q) # αₖ = ⟨uₖ,q⟩

@kaxpy!(m, - αₖ , M⁻¹uₖ, q) # q ← q - αₖ * M⁻¹uₖ
@kaxpy!(n, -conj(αₖ), N⁻¹vₖ, p) # p ← p - ᾱₖ * N⁻¹vₖ
kaxpy!(m, - αₖ , M⁻¹uₖ, q) # q ← q - αₖ * M⁻¹uₖ
kaxpy!(n, -conj(αₖ), N⁻¹vₖ, p) # p ← p - ᾱₖ * N⁻¹vₖ

# Compute vₖ₊₁ and uₖ₊₁
MisI || mulorldiv!(uₖ₊₁, M, q, ldiv) # βₖ₊₁uₖ₊₁ = MAvₖ - γₖuₖ₋₁ - αₖuₖ
NisI || mulorldiv!(vₖ₊₁, N, p, ldiv) # γₖ₊₁vₖ₊₁ = NAᴴuₖ - βₖvₖ₋₁ - ᾱₖvₖ

βₖ₊₁ = sqrt(@kdotr(m, uₖ₊₁, q)) # βₖ₊₁ = ‖uₖ₊₁‖_E
γₖ₊₁ = sqrt(@kdotr(n, vₖ₊₁, p)) # γₖ₊₁ = ‖vₖ₊₁‖_F
βₖ₊₁ = knorm_elliptic(m, uₖ₊₁, q) # βₖ₊₁ = ‖uₖ₊₁‖_E
γₖ₊₁ = knorm_elliptic(n, vₖ₊₁, p) # γₖ₊₁ = ‖vₖ₊₁‖_F

# Update M⁻¹uₖ₋₁ and N⁻¹vₖ₋₁
M⁻¹uₖ₋₁ .= M⁻¹uₖ
Expand Down Expand Up @@ -325,32 +325,32 @@ kwargs_usymlqr = (:transfer_to_usymcg, :M, :N, :ldiv, :atol, :rtol, :itmax, :tim
# w₁ = v₁ / δ₁
if iter == 1
wₖ = wₖ₋₁
@kaxpy!(n, one(FC), vₖ, wₖ)
kaxpy!(n, one(FC), vₖ, wₖ)
wₖ .= wₖ ./ δₖ
end
# w₂ = (v₂ - λ₁w₁) / δ₂
if iter == 2
wₖ = wₖ₋₂
@kaxpy!(n, -λₖ₋₁, wₖ₋₁, wₖ)
@kaxpy!(n, one(FC), vₖ, wₖ)
kaxpy!(n, -λₖ₋₁, wₖ₋₁, wₖ)
kaxpy!(n, one(FC), vₖ, wₖ)
wₖ .= wₖ ./ δₖ
end
# wₖ = (vₖ - λₖ₋₁wₖ₋₁ - ϵₖ₋₂wₖ₋₂) / δₖ
if iter 3
@kscal!(n, -ϵₖ₋₂, wₖ₋₂)
kscal!(n, -ϵₖ₋₂, wₖ₋₂)
wₖ = wₖ₋₂
@kaxpy!(n, -λₖ₋₁, wₖ₋₁, wₖ)
@kaxpy!(n, one(FC), vₖ, wₖ)
kaxpy!(n, -λₖ₋₁, wₖ₋₁, wₖ)
kaxpy!(n, one(FC), vₖ, wₖ)
wₖ .= wₖ ./ δₖ
end

# Update the solution xₖ.
# xₖ ← xₖ₋₁ + ϕₖ * wₖ
@kaxpy!(n, ϕₖ, wₖ, xₖ)
kaxpy!(n, ϕₖ, wₖ, xₖ)

# Update the residual rₖ.
# rₖ ← |sₖ|² * rₖ₋₁ - cₖ * ϕbarₖ₊₁ * uₖ₊₁
@kaxpby!(n, cₖ * ϕbarₖ₊₁, q, abs2(sₖ), rₖ)
kaxpby!(n, cₖ * ϕbarₖ₊₁, q, abs2(sₖ), rₖ)

# Compute ‖rₖ‖ = |ϕbarₖ₊₁|.
rNorm = abs(ϕbarₖ₊₁)
Expand Down Expand Up @@ -389,17 +389,17 @@ kwargs_usymlqr = (:transfer_to_usymcg, :M, :N, :ldiv, :atol, :rtol, :itmax, :tim
if iter 2
# Compute solution yₖ.
# (yᴸ)ₖ₋₁ ← (yᴸ)ₖ₋₂ + ηₖ₋₁ * dₖ₋₁
@kaxpy!(n, ηₖ₋₁ * cₖ, d̅, x)
@kaxpy!(n, ηₖ₋₁ * sₖ, uₖ, x)
kaxpy!(n, ηₖ₋₁ * cₖ, d̅, x)
kaxpy!(n, ηₖ₋₁ * sₖ, uₖ, x)
end

# Compute d̅ₖ.
if iter == 1
# d̅₁ = u₁
@kcopy!(n, uₖ, d̅) # d̅ ← vₖ
kcopy!(n, d̅, uₖ) # d̅ ← vₖ
else
# d̅ₖ = s̄ₖ * d̅ₖ₋₁ - cₖ * uₖ
@kaxpby!(n, -cₖ, uₖ, conj(sₖ), d̅)
kaxpby!(n, -cₖ, uₖ, conj(sₖ), d̅)
end

# Compute USYMLQ residual norm
Expand All @@ -422,11 +422,11 @@ kwargs_usymlqr = (:transfer_to_usymcg, :M, :N, :ldiv, :atol, :rtol, :itmax, :tim
end

# Compute zₖ.
@kaxpy!(n, -ηₖ, wₖ, zₖ)
kaxpy!(n, -ηₖ, wₖ, zₖ)

# Compute uₖ₊₁ and vₖ₊₁.
@kcopy!(m, uₖ, uₖ₋₁) # uₖ₋₁ ← uₖ
@kcopy!(n, vₖ, vₖ₋₁) # vₖ₋₁ ← vₖ
kcopy!(m, uₖ₋₁, uₖ) # uₖ₋₁ ← uₖ
kcopy!(n, vₖ₋₁, vₖ) # vₖ₋₁ ← vₖ

if βₖ₊₁ zero(T)
uₖ .= q ./ βₖ₊₁ # βₖ₊₁uₖ₊₁ = q
Expand All @@ -437,7 +437,7 @@ kwargs_usymlqr = (:transfer_to_usymcg, :M, :N, :ldiv, :atol, :rtol, :itmax, :tim

# Update directions for x.
if iter 2
@kswap(wₖ₋₂, wₖ₋₁)
@kswap!(wₖ₋₂, wₖ₋₁)
end

# Update sₖ₋₂, cₖ₋₂, sₖ₋₁, cₖ₋₁, ϕbarₖ, γₖ, βₖ.
Expand Down Expand Up @@ -480,8 +480,8 @@ kwargs_usymlqr = (:transfer_to_usymcg, :M, :N, :ldiv, :atol, :rtol, :itmax, :tim
# (yᶜ)ₖ ← (yᴸ)ₖ₋₁ + ηbarₖ * d̅ₖ
# (zᶜ)ₖ ← (zᴸ)ₖ₋₁ - ηbarₖ * w̄ₖ
if solved_cg
@kaxpy!(n, ηbarₖ, d̅, yₖ)
@kaxpy!(m, -ηbarₖ, w̄, zₖ)
kaxpy!(n, ηbarₖ, d̅, yₖ)
kaxpy!(m, -ηbarₖ, w̄, zₖ)
end

# Termination status
Expand All @@ -497,12 +497,12 @@ kwargs_usymlqr = (:transfer_to_usymcg, :M, :N, :ldiv, :atol, :rtol, :itmax, :tim
# Compute the solution the saddle point system
# xₖ ← xₖ + zₖ
# yₖ ← yₖ + rₖ
@kaxpy!(n, one(FC), zₖ, xₖ)
@kaxpy!(m, one(FC), rₖ, yₖ)
kaxpy!(n, one(FC), zₖ, xₖ)
kaxpy!(m, one(FC), rₖ, yₖ)

# Update xₖ and yₖ
warm_start && @kaxpy!(n, one(FC), Δxz, xₖ)
warm_start && @kaxpy!(m, one(FC), Δyr, yₖ)
warm_start && kaxpy!(n, one(FC), Δxz, xₖ)
warm_start && kaxpy!(m, one(FC), Δyr, yₖ)
solver.warm_start = false

# Update stats
Expand Down

0 comments on commit e167d3f

Please sign in to comment.