Skip to content

Commit

Permalink
Merge pull request #39 from gridap/krylov-solvers
Browse files Browse the repository at this point in the history
Krylov solvers
  • Loading branch information
JordiManyer authored Oct 2, 2023
2 parents 55acbe9 + 1ed6a39 commit 744896d
Show file tree
Hide file tree
Showing 8 changed files with 467 additions and 52 deletions.
92 changes: 92 additions & 0 deletions src/LinearSolvers/Krylov/CGSolvers.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@

struct CGSolver <: Gridap.Algebra.LinearSolver
Pl :: Gridap.Algebra.LinearSolver
maxiter :: Int64
atol :: Float64
rtol :: Float64
flexible :: Bool
verbose :: Bool
end

function CGSolver(Pl;maxiter=10000,atol=1e-12,rtol=1.e-6,flexible=false,verbose=false)
return CGSolver(Pl,maxiter,atol,rtol,flexible,verbose)
end

struct CGSymbolicSetup <: Gridap.Algebra.SymbolicSetup
solver
end

function Gridap.Algebra.symbolic_setup(solver::CGSolver, A::AbstractMatrix)
return CGSymbolicSetup(solver)
end

mutable struct CGNumericalSetup <: Gridap.Algebra.NumericalSetup
solver
A
Pl_ns
caches
end

function get_solver_caches(solver::CGSolver,A)
w = allocate_col_vector(A)
p = allocate_col_vector(A)
z = allocate_col_vector(A)
r = allocate_col_vector(A)
return (w,p,z,r)
end

function Gridap.Algebra.numerical_setup(ss::CGSymbolicSetup, A::AbstractMatrix)
solver = ss.solver
Pl_ns = numerical_setup(symbolic_setup(solver.Pl,A),A)
caches = get_solver_caches(solver,A)
return CGNumericalSetup(solver,A,Pl_ns,caches)
end

function Gridap.Algebra.numerical_setup!(ns::CGNumericalSetup, A::AbstractMatrix)
numerical_setup!(ns.Pl_ns,A)
ns.A = A
end

function Gridap.Algebra.solve!(x::AbstractVector,ns::CGNumericalSetup,b::AbstractVector)
solver, A, Pl, caches = ns.solver, ns.A, ns.Pl_ns, ns.caches
maxiter, atol, rtol = solver.maxiter, solver.atol, solver.rtol
flexible, verbose = solver.flexible, solver.verbose
w,p,z,r = caches
verbose && println(" > Starting CG solver: ")

# Initial residual
mul!(w,A,x); r .= b .- w
fill!(p,0.0); γ = 1.0

res = norm(r); res_0 = res
iter = 0; converged = false
while !converged && (iter < maxiter)
verbose && println(" > Iteration ", iter," - Residual: ", res)

if !flexible # β = (zₖ₊₁ ⋅ rₖ₊₁)/(zₖ ⋅ rₖ)
solve!(z, Pl, r)
β = γ; γ = dot(z, r); β = γ / β
else # β = (zₖ₊₁ ⋅ (rₖ₊₁-rₖ))/(zₖ ⋅ rₖ)
β = γ; γ = dot(z, r)
solve!(z, Pl, r)
γ = dot(z, r) - γ; β = γ / β
end
p .= z .+ β .* p

# w = A⋅p
mul!(w,A,p)
α = γ / dot(p, w)

# Update solution and residual
x .+= α .* p
r .-= α .* w

res = norm(r)
converged = (res < atol || res < rtol*res_0)
iter += 1
end
verbose && println(" > Num Iter: ", iter," - Final residual: ", res)
verbose && println(" Exiting CG solver.")

return x
end
127 changes: 127 additions & 0 deletions src/LinearSolvers/Krylov/FGMRESSolvers.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@

# FGMRES Solver
struct FGMRESSolver <: Gridap.Algebra.LinearSolver
m :: Int
Pr :: Gridap.Algebra.LinearSolver
Pl :: Union{Gridap.Algebra.LinearSolver,Nothing}
atol :: Float64
rtol :: Float64
verbose :: Bool
end

function FGMRESSolver(m,Pr;Pl=nothing,atol=1e-12,rtol=1.e-6,verbose=false)
return FGMRESSolver(m,Pr,Pl,atol,rtol,verbose)
end

struct FGMRESSymbolicSetup <: Gridap.Algebra.SymbolicSetup
solver
end

function Gridap.Algebra.symbolic_setup(solver::FGMRESSolver, A::AbstractMatrix)
return FGMRESSymbolicSetup(solver)
end

mutable struct FGMRESNumericalSetup <: Gridap.Algebra.NumericalSetup
solver
A
Pr_ns
Pl_ns
caches
end

function get_solver_caches(solver::FGMRESSolver,A)
m = solver.m; Pl = solver.Pl

V = [allocate_col_vector(A) for i in 1:m+1]
Z = [allocate_col_vector(A) for i in 1:m]
zl = !isa(Pl,Nothing) ? allocate_col_vector(A) : nothing

H = zeros(m+1,m) # Hessenberg matrix
g = zeros(m+1) # Residual vector
c = zeros(m) # Gibens rotation cosines
s = zeros(m) # Gibens rotation sines
return (V,Z,zl,H,g,c,s)
end

function Gridap.Algebra.numerical_setup(ss::FGMRESSymbolicSetup, A::AbstractMatrix)
solver = ss.solver
Pr_ns = numerical_setup(symbolic_setup(solver.Pr,A),A)
Pl_ns = isa(solver.Pl,Nothing) ? nothing : numerical_setup(symbolic_setup(solver.Pl,A),A)
caches = get_solver_caches(solver,A)
return FGMRESNumericalSetup(solver,A,Pr_ns,Pl_ns,caches)
end

function Gridap.Algebra.numerical_setup!(ns::FGMRESNumericalSetup, A::AbstractMatrix)
numerical_setup!(ns.Pr_ns,A)
if !isa(ns.Pl_ns,Nothing)
numerical_setup!(ns.Pl_ns,A)
end
ns.A = A
end

function Gridap.Algebra.solve!(x::AbstractVector,ns::FGMRESNumericalSetup,b::AbstractVector)
solver, A, Pl, Pr, caches = ns.solver, ns.A, ns.Pl_ns, ns.Pr_ns, ns.caches
m, atol, rtol, verbose = solver.m, solver.atol, solver.rtol, solver.verbose
V, Z, zl, H, g, c, s = caches
verbose && println(" > Starting FGMRES solver: ")

# Initial residual
krylov_residual!(V[1],x,A,b,Pl,zl)

iter = 0
β = norm(V[1]); β0 = β
converged =< atol || β < rtol*β0)
while !converged
verbose && println(" > Iteration ", iter," - Residual: ", β)
fill!(H,0.0)

# Arnoldi process
j = 1
V[1] ./= β
fill!(g,0.0); g[1] = β
while ( j < m+1 && !converged )
verbose && println(" > Inner iteration ", j," - Residual: ", β)
# Arnoldi orthogonalization by Modified Gram-Schmidt
krylov_mul!(V[j+1],A,V[j],Pr,Pl,Z[j],zl)
for i in 1:j
H[i,j] = dot(V[j+1],V[i])
V[j+1] .= V[j+1] .- H[i,j] .* V[i]
end
H[j+1,j] = norm(V[j+1])
V[j+1] ./= H[j+1,j]

# Update QR
for i in 1:j-1
γ = c[i]*H[i,j] + s[i]*H[i+1,j]
H[i+1,j] = -s[i]*H[i,j] + c[i]*H[i+1,j]
H[i,j] = γ
end

# New Givens rotation, update QR and residual
c[j], s[j], _ = LinearAlgebra.givensAlgorithm(H[j,j],H[j+1,j])
H[j,j] = c[j]*H[j,j] + s[j]*H[j+1,j]; H[j+1,j] = 0.0
g[j+1] = -s[j]*g[j]; g[j] = c[j]*g[j]

β = abs(g[j+1]); converged =< atol || β < rtol*β0)
j += 1
end
j = j-1

# Solve least squares problem Hy = g by backward substitution
for i in j:-1:1
g[i] = (g[i] - dot(H[i,i+1:j],g[i+1:j])) / H[i,i]
end

# Update solution & residual
for i in 1:j
x .+= g[i] .* Z[i]
end
krylov_residual!(V[1],x,A,b,Pl,zl)

iter += 1
end
verbose && println(" > Num Iter: ", iter," - Final residual: ", β)
verbose && println(" Exiting FGMRES solver.")

return x
end
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@

# GMRES Solver
struct GMRESSolver <: Gridap.Algebra.LinearSolver
m ::Int
Pl ::Gridap.Algebra.LinearSolver
atol::Float64
rtol::Float64
verbose::Bool
m :: Int
Pr :: Union{Gridap.Algebra.LinearSolver,Nothing}
Pl :: Union{Gridap.Algebra.LinearSolver,Nothing}
atol :: Float64
rtol :: Float64
verbose :: Bool
end

function GMRESSolver(m,Pl;atol=1e-12,rtol=1.e-6,verbose=false)
return GMRESSolver(m,Pl,atol,rtol,verbose)
function GMRESSolver(m;Pr=nothing,Pl=nothing,atol=1e-12,rtol=1.e-6,verbose=false)
return GMRESSolver(m,Pr,Pl,atol,rtol,verbose)
end

struct GMRESSymbolicSetup <: Gridap.Algebra.SymbolicSetup
Expand All @@ -23,65 +23,72 @@ end
mutable struct GMRESNumericalSetup <: Gridap.Algebra.NumericalSetup
solver
A
Pr_ns
Pl_ns
caches
end

function get_gmres_caches(m,A)
w = allocate_col_vector(A)
V = [allocate_col_vector(A) for i in 1:m+1]
Z = [allocate_col_vector(A) for i in 1:m]
function get_solver_caches(solver::GMRESSolver,A)
m, Pl, Pr = solver.m, solver.Pl, solver.Pr

V = [allocate_col_vector(A) for i in 1:m+1]
zr = !isa(Pr,Nothing) ? allocate_col_vector(A) : nothing
zl = !isa(Pr,Nothing) ? allocate_col_vector(A) : nothing

H = zeros(m+1,m) # Hessenberg matrix
g = zeros(m+1) # Residual vector
c = zeros(m) # Gibens rotation cosines
s = zeros(m) # Gibens rotation sines
return (w,V,Z,H,g,c,s)
return (V,zr,zl,H,g,c,s)
end

function Gridap.Algebra.numerical_setup(ss::GMRESSymbolicSetup, A::AbstractMatrix)
solver = ss.solver
Pl_ns = numerical_setup(symbolic_setup(solver.Pl,A),A)
caches = get_gmres_caches(solver.m,A)
return GMRESNumericalSetup(solver,A,Pl_ns,caches)
Pr_ns = isa(solver.Pr,Nothing) ? nothing : numerical_setup(symbolic_setup(solver.Pr,A),A)
Pl_ns = isa(solver.Pl,Nothing) ? nothing : numerical_setup(symbolic_setup(solver.Pl,A),A)
caches = get_solver_caches(solver,A)
return GMRESNumericalSetup(solver,A,Pr_ns,Pl_ns,caches)
end

function Gridap.Algebra.numerical_setup!(ns::GMRESNumericalSetup, A::AbstractMatrix)
numerical_setup!(ns.Pl_ns,A)
if !isa(ns.Pr_ns,Nothing)
numerical_setup!(ns.Pr_ns,A)
end
if !isa(ns.Pl_ns,Nothing)
numerical_setup!(ns.Pl_ns,A)
end
ns.A = A
end

function Gridap.Algebra.solve!(x::AbstractVector,ns::GMRESNumericalSetup,b::AbstractVector)
solver, A, Pl, caches = ns.solver, ns.A, ns.Pl_ns, ns.caches
solver, A, Pl, Pr, caches = ns.solver, ns.A, ns.Pl_ns, ns.Pr_ns, ns.caches
m, atol, rtol, verbose = solver.m, solver.atol, solver.rtol, solver.verbose
w, V, Z, H, g, c, s = caches
V, zr, zl, H, g, c, s = caches
verbose && println(" > Starting GMRES solver: ")

# Initial residual
mul!(w,A,x); w .= b .- w

β = norm(w); β0 = β
converged =< atol || β < rtol*β0)
krylov_residual!(V[1],x,A,b,Pl,zl)
β = norm(V[1]); β0 = β
iter = 0
converged =< atol || β < rtol*β0)
while !converged
verbose && println(" > Iteration ", iter," - Residual: ", β)
fill!(H,0.0)

# Arnoldi process
fill!(g,0.0); g[1] = β
V[1] .= w ./ β
j = 1
V[1] ./= β
fill!(g,0.0); g[1] = β
while ( j < m+1 && !converged )
verbose && println(" > Inner iteration ", j," - Residual: ", β)
# Arnoldi orthogonalization by Modified Gram-Schmidt
solve!(Z[j],Pl,V[j])
mul!(w,A,Z[j])
krylov_mul!(V[j+1],A,V[j],Pr,Pl,zr,zl)
for i in 1:j
H[i,j] = dot(w,V[i])
w .= w .- H[i,j] .* V[i]
H[i,j] = dot(V[j+1],V[i])
V[j+1] .= V[j+1] .- H[i,j] .* V[i]
end
H[j+1,j] = norm(w)
V[j+1] = w ./ H[j+1,j]
H[j+1,j] = norm(V[j+1])
V[j+1] ./= H[j+1,j]

# Update QR
for i in 1:j-1
Expand All @@ -106,15 +113,24 @@ function Gridap.Algebra.solve!(x::AbstractVector,ns::GMRESNumericalSetup,b::Abst
end

# Update solution & residual
for i in 1:j
x .+= g[i] .* Z[i]
if isa(Pr,Nothing)
for i in 1:j
x .+= g[i] .* V[i]
end
else
fill!(zl,0.0)
for i in 1:j
zl .+= g[i] .* V[i]
end
solve!(zr,Pr,zl)
x .+= zr
end
mul!(w,A,x); w .= b .- w
krylov_residual!(V[1],x,A,b,Pl,zl)

iter += 1
end
verbose && println(" > Num Iter: ", iter," - Final residual: ", β)
verbose && println(" Exiting GMRES solver.")

return x
end
end
Loading

0 comments on commit 744896d

Please sign in to comment.