Skip to content

Commit

Permalink
Add methods for retreiving number of iterations and residual norm fro…
Browse files Browse the repository at this point in the history
…m KSM linear solvers (#185)

* Add methods for retreiving number of iterations and residual norm from KSM solvers

* PCG fix

* clang-format

* Test `getResidualNorm` and `getIterCount` in `StaticTestCase`

* Address Graeme's comments

* Treat resNorm consistently in all solvers

* typo

* Remove unnecessary abs

* Make resNorm and iterCount getters virtual
  • Loading branch information
A-CGray authored Mar 7, 2023
1 parent 401e975 commit 7c231b1
Show file tree
Hide file tree
Showing 5 changed files with 72 additions and 26 deletions.
44 changes: 28 additions & 16 deletions src/bpmat/KSM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -438,6 +438,7 @@ void PCG::setMonitor(KSMPrint *_monitor) {
*/
int PCG::solve(TACSVec *b, TACSVec *x, int zero_guess) {
int solve_flag = 0;
iterCount = 0;
TacsScalar rhs_norm = 0.0;
// R, Z and P are work-vectors
// R == the residual
Expand All @@ -455,6 +456,7 @@ int PCG::solve(TACSVec *b, TACSVec *x, int zero_guess) {

if (count == 0) {
rhs_norm = R->norm();
resNorm = rhs_norm;
}

if (monitor && count == 0) {
Expand All @@ -477,15 +479,16 @@ int PCG::solve(TACSVec *b, TACSVec *x, int zero_guess) {
pc->applyFactor(R, Z); // Z' = M^{-1} R
TacsScalar beta = R->dot(Z) / temp; // beta = (R',Z')/(R,Z)
P->axpby(1.0, beta, Z); // P' = Z' + beta*P
iterCount++;

TacsScalar norm = R->norm();
resNorm = R->norm();

if (monitor) {
monitor->printResidual(i + 1, norm);
monitor->printResidual(i + 1, resNorm);
}

if (TacsRealPart(norm) < atol ||
TacsRealPart(norm) < rtol * TacsRealPart(rhs_norm)) {
if (TacsRealPart(resNorm) < atol ||
TacsRealPart(resNorm) < rtol * TacsRealPart(rhs_norm)) {
solve_flag = 1;
break;
}
Expand Down Expand Up @@ -782,6 +785,7 @@ const char *GMRES::gmresName = "GMRES";
int GMRES::solve(TACSVec *b, TACSVec *x, int zero_guess) {
TacsScalar rhs_norm = 0.0;
int solve_flag = 0;
iterCount = 0;

double t_pc = 0.0, t_ortho = 0.0;
double t_total = 0.0;
Expand Down Expand Up @@ -814,6 +818,7 @@ int GMRES::solve(TACSVec *b, TACSVec *x, int zero_guess) {

if (count == 0) {
rhs_norm = res[0]; // The initial residual
resNorm = rhs_norm;
}

int niters = 0; // Keep track of the size of the Hessenberg matrix
Expand Down Expand Up @@ -882,21 +887,24 @@ int GMRES::solve(TACSVec *b, TACSVec *x, int zero_guess) {
res[i] = h1 * Qcos[i];
res[i + 1] = -h1 * Qsin[i];

niters++;
resNorm = fabs(res[i + 1]);

if (monitor) {
monitor->printResidual(i + 1, fabs(TacsRealPart(res[i + 1])));
monitor->printResidual(i + 1, resNorm);
}

niters++;

if (fabs(TacsRealPart(res[i + 1])) < atol ||
fabs(TacsRealPart(res[i + 1])) < rtol * TacsRealPart(rhs_norm)) {
if (TacsRealPart(resNorm) < atol ||
TacsRealPart(resNorm) < rtol * TacsRealPart(rhs_norm)) {
// Set the solve flag
solve_flag = 1;

break;
}
}

iterCount += niters;

// Now, compute the solution - the linear combination of the
// Arnoldi vectors. H is upper triangular

Expand Down Expand Up @@ -1189,6 +1197,7 @@ int GCROT::solve(TACSVec *b, TACSVec *x, int zero_guess) {
TacsScalar rhs_norm = 0.0;
int solve_flag = 0;
int mat_iters = 0;
iterCount = 0;

// Compute the residual
if (zero_guess) {
Expand All @@ -1204,6 +1213,7 @@ int GCROT::solve(TACSVec *b, TACSVec *x, int zero_guess) {
}

rhs_norm = R->norm(); // The initial residual
resNorm = rhs_norm;

if (TacsRealPart(rhs_norm) < atol) {
solve_flag = 1;
Expand All @@ -1220,7 +1230,7 @@ int GCROT::solve(TACSVec *b, TACSVec *x, int zero_guess) {
W[0]->scale(1.0 / res[0]); // W[0] = b/|| b ||

if (monitor) {
monitor->printResidual(mat_iters, fabs(TacsRealPart(res[0])));
monitor->printResidual(mat_iters, resNorm);
}

// The inner F/GMRES loop
Expand Down Expand Up @@ -1285,20 +1295,22 @@ int GCROT::solve(TACSVec *b, TACSVec *x, int zero_guess) {
res[i] = h1 * Qcos[i];
res[i + 1] = -h1 * Qsin[i];

if (monitor) {
monitor->printResidual(mat_iters, fabs(TacsRealPart(res[i + 1])));
}

niters++;

if (fabs(TacsRealPart(res[i + 1])) < atol ||
fabs(TacsRealPart(res[i + 1])) < rtol * TacsRealPart(rhs_norm)) {
resNorm = fabs(res[i + 1]);

if (monitor) {
monitor->printResidual(mat_iters, resNorm);
}
if (TacsRealPart(resNorm) < atol ||
TacsRealPart(resNorm) < rtol * TacsRealPart(rhs_norm)) {
// Set the solve flag
solve_flag = 1;

break;
}
}
iterCount += niters;

// Now, compute the solution - the linear combination of the
// Arnoldi vectors
Expand Down
11 changes: 11 additions & 0 deletions src/bpmat/KSM.h
Original file line number Diff line number Diff line change
Expand Up @@ -252,9 +252,14 @@ class KSMPrint : public TACSObject {
tolerances for the method
setMonitor(): Set the monitor - possibly NULL - that will be used
getIterCount(): Return the number of iterations taken during the last solve
getResidualNorm(): Return the residual norm from the end of the last solve
*/
class TACSKsm : public TACSObject {
public:
TACSKsm() : iterCount(0), resNorm(0.0) {}
virtual ~TACSKsm() {}

virtual TACSVec *createVec() = 0;
Expand All @@ -263,10 +268,16 @@ class TACSKsm : public TACSObject {
virtual int solve(TACSVec *b, TACSVec *x, int zero_guess = 1) = 0;
virtual void setTolerances(double _rtol, double _atol) = 0;
virtual void setMonitor(KSMPrint *_monitor) = 0;
virtual int getIterCount() { return iterCount; }
virtual TacsScalar getResidualNorm() { return resNorm; }
const char *getObjectName();

private:
static const char *ksmName;

protected:
int iterCount; ///< Number of iterations taken during the last solve
TacsScalar resNorm; ///< The residual norm at the end of the last solve
};

/*
Expand Down
2 changes: 2 additions & 0 deletions tacs/TACS.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,8 @@ cdef extern from "KSM.h":
int solve(TACSVec *b, TACSVec *x, int zero_guess)
void setTolerances(double _rtol, double _atol)
void setMonitor(KSMPrint *_monitor)
int getIterCount()
TacsScalar getResidualNorm()

cdef cppclass GMRES(TACSKsm):
GMRES(TACSMat *_mat, TACSPc *_pc, int _m,
Expand Down
8 changes: 8 additions & 0 deletions tacs/TACS.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -1141,6 +1141,14 @@ cdef class KSM:
self.ptr.incref()
return

def getIterCount(self):
"""Get the number of iterations performed in the last solve"""
return self.ptr.getIterCount()

def getResidualNorm(self):
"""Get the residual norm of the last solve"""
return self.ptr.getResidualNorm()

def __dealloc__(self):
if self.ptr:
self.ptr.decref()
Expand Down
33 changes: 23 additions & 10 deletions tests/integration_tests/static_analysis_base_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,20 +7,20 @@

"""
This is a base class for static problem unit test cases.
This base class will test function evaluations and total
and partial sensitivities for the user-specified problem
This base class will test function evaluations and total
and partial sensitivities for the user-specified problem
that inherits from it.
When the user creates a new test based on this class three
methods are required to be defined in the child class.
When the user creates a new test based on this class three
methods are required to be defined in the child class.
1. setup_assembler
2. setup_tacs_vecs
3. setup_funcs
See the virtual method implementations for each method
See the virtual method implementations for each method
below for more details.
NOTE: The child class must NOT implement its own setUp method
NOTE: The child class must NOT implement its own setUp method
for the unittest class. This is handled in the base class.
"""

Expand Down Expand Up @@ -89,10 +89,11 @@ def setUp(self):
# Create GMRES solver object
subspace = 100
restarts = 2
atol = 1e-30
rtol = 1e-12
self.linSolveIterLimit = subspace * restarts
self.linSolveAtol = 1e-30
self.linSolveRtol = 1e-12
self.gmres = TACS.KSM(self.mat, self.pc, subspace, restarts)
self.gmres.setTolerances(rtol, atol)
self.gmres.setTolerances(self.linSolveRtol, self.linSolveAtol)

# Create the function list
self.func_list, self.func_ref = self.setup_funcs(self.assembler)
Expand Down Expand Up @@ -149,6 +150,18 @@ def test_solve(self):
# solve
func_vals = self.run_solve()

# Test that linear solver residual is sufficiently small
linSolveRes = np.real(self.gmres.getResidualNorm())
converged = (
linSolveRes < self.linSolveAtol
or linSolveRes < self.linSolveRtol * np.real(self.res0.norm())
)
self.assertTrue(converged, "Linear solver did not converge")

# Test that linear solver took between 1 and subspce * restarts iterations
numIters = self.gmres.getIterCount()
self.assertTrue(numIters > 0 and numIters <= self.linSolveIterLimit)

# Test functions values against historical values
np.testing.assert_allclose(
func_vals, self.func_ref, rtol=self.rtol, atol=self.atol
Expand Down

0 comments on commit 7c231b1

Please sign in to comment.