Skip to content

Commit

Permalink
Added preconditioning to MINRES
Browse files Browse the repository at this point in the history
  • Loading branch information
JordiManyer committed Oct 2, 2023
1 parent ec56491 commit 1ed6a39
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 14 deletions.
4 changes: 2 additions & 2 deletions src/LinearSolvers/Krylov/GMRESSolvers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ function get_solver_caches(solver::GMRESSolver,A)

V = [allocate_col_vector(A) for i in 1:m+1]
zr = !isa(Pr,Nothing) ? allocate_col_vector(A) : nothing
zl = !isa(Pl,Nothing) ? allocate_col_vector(A) : nothing
zl = !isa(Pr,Nothing) ? allocate_col_vector(A) : nothing

H = zeros(m+1,m) # Hessenberg matrix
g = zeros(m+1) # Residual vector
Expand All @@ -44,7 +44,7 @@ end

function Gridap.Algebra.numerical_setup(ss::GMRESSymbolicSetup, A::AbstractMatrix)
solver = ss.solver
Pr_ns = isa(solver.Pl,Nothing) ? nothing : numerical_setup(symbolic_setup(solver.Pr,A),A)
Pr_ns = isa(solver.Pr,Nothing) ? nothing : numerical_setup(symbolic_setup(solver.Pr,A),A)
Pl_ns = isa(solver.Pl,Nothing) ? nothing : numerical_setup(symbolic_setup(solver.Pl,A),A)
caches = get_solver_caches(solver,A)
return GMRESNumericalSetup(solver,A,Pr_ns,Pl_ns,caches)
Expand Down
28 changes: 17 additions & 11 deletions src/LinearSolvers/Krylov/MINRESSolvers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,20 +31,20 @@ function get_solver_caches(solver::MINRESSolver,A)
Pl, Pr = solver.Pl, solver.Pr

V = [allocate_col_vector(A) for i in 1:3]
Z = [allocate_col_vector(A) for i in 1:3]
W = [allocate_col_vector(A) for i in 1:3]
zr = !isa(Pr,Nothing) ? allocate_col_vector(A) : nothing
zl = !isa(Pl,Nothing) ? allocate_col_vector(A) : nothing

H = zeros(4) # Hessenberg matrix
g = zeros(2) # Residual vector
c = zeros(2) # Gibens rotation cosines
s = zeros(2) # Gibens rotation sines
return (V,Z,zr,zl,H,g,c,s)
return (V,W,zr,zl,H,g,c,s)
end

function Gridap.Algebra.numerical_setup(ss::MINRESSymbolicSetup, A::AbstractMatrix)
solver = ss.solver
Pr_ns = isa(solver.Pl,Nothing) ? nothing : numerical_setup(symbolic_setup(solver.Pr,A),A)
Pr_ns = isa(solver.Pr,Nothing) ? nothing : numerical_setup(symbolic_setup(solver.Pr,A),A)
Pl_ns = isa(solver.Pl,Nothing) ? nothing : numerical_setup(symbolic_setup(solver.Pl,A),A)
caches = get_solver_caches(solver,A)
return MINRESNumericalSetup(solver,A,Pr_ns,Pl_ns,caches)
Expand All @@ -63,23 +63,24 @@ end
function Gridap.Algebra.solve!(x::AbstractVector,ns::MINRESNumericalSetup,b::AbstractVector)
solver, A, Pl, Pr, caches = ns.solver, ns.A, ns.Pl_ns, ns.Pr_ns, ns.caches
atol, rtol, verbose = solver.atol, solver.rtol, solver.verbose
V, Z, zr, zl, H, g, c, s = caches
V, W, zr, zl, H, g, c, s = caches
verbose && println(" > Starting MINRES solver: ")

Vjm1, Vj, Vjp1 = V
Zjm1, Zj, Zjp1 = Z
Wjm1, Wj, Wjp1 = W

fill!(Vjm1,0.0); fill!(Vjp1,0.0); copy!(Vj,b)
fill!(Wjm1,0.0); fill!(Wjp1,0.0); fill!(Wj,0.0)
fill!(H,0.0), fill!(c,1.0); fill!(s,0.0); fill!(g,0.0)

mul!(Vj,A,x,-1.0,1.0)
krylov_residual!(Vj,x,A,b,Pl,zl)
β = norm(Vj); β0 = β; Vj ./= β; g[1] = β
iter = 0
converged =< atol || β < rtol*β0)
while !converged
verbose && println(" > Iteration ", iter," - Residual: ", β)

mul!(Vjp1,A,Vj)
krylov_mul!(Vjp1,A,Vj,Pr,Pl,zr,zl)
H[3] = dot(Vjp1,Vj)
Vjp1 .= Vjp1 .- H[3] .* Vj .- H[2] .* Vjm1
H[4] = norm(Vjp1)
Expand All @@ -95,13 +96,18 @@ function Gridap.Algebra.solve!(x::AbstractVector,ns::MINRESNumericalSetup,b::Abs
g[2] = -s[2]*g[1]; g[1] = c[2]*g[1]

# Update solution
Zjp1 .= Vj .- H[2] .* Zj .- H[1] .* Zjm1
Zjp1 ./= H[3]
x .+= g[1] .* Zjp1
Wjp1 .= Vj .- H[2] .* Wj .- H[1] .* Wjm1
Wjp1 ./= H[3]
if isa(Pr,Nothing)
x .+= g[1] .* Wjp1
else
solve!(zr,Pr,Wjp1)
x .+= g[1] .* zr
end

β = abs(g[2]); converged =< atol || β < rtol*β0)
Vjm1, Vj, Vjp1 = Vj, Vjp1, Vjm1
Zjm1, Zj, Zjp1 = Zj, Zjp1, Zjm1
Wjm1, Wj, Wjp1 = Wj, Wjp1, Wjm1
g[1] = g[2]; H[2] = H[4];
iter += 1
end
Expand Down
2 changes: 1 addition & 1 deletion test/seq/KrylovSolversTests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ function main(model)
fpcg = LinearSolvers.CGSolver(P;flexible=true,rtol=1.e-8,verbose=true)
test_solver(fpcg,op,Uh,dΩ)

minres = LinearSolvers.MINRESSolver(;rtol=1.e-8,verbose=true)
minres = LinearSolvers.MINRESSolver(;Pl=P,Pr=P,rtol=1.e-8,verbose=true)
test_solver(minres,op,Uh,dΩ)
end

Expand Down

0 comments on commit 1ed6a39

Please sign in to comment.