Skip to content

Commit

Permalink
Add a logging function that gets called at the start of the iteration…
Browse files Browse the repository at this point in the history
…s. This commit can be rolled back if desired
  • Loading branch information
shivupa committed May 31, 2024
1 parent e3e0869 commit 7bdec45
Show file tree
Hide file tree
Showing 7 changed files with 129 additions and 22 deletions.
19 changes: 15 additions & 4 deletions include/Spectra/GenEigsBase.h
Original file line number Diff line number Diff line change
Expand Up @@ -441,18 +441,29 @@ class GenEigsBase
Index i, nconv = 0, nev_adj;
for (i = 0; i < maxit; i++)
{
nconv = num_converged(tol);
if (m_logger)
{
const ComplexVector eigs = m_ritz_val.head(m_nev);
const IterationData<Scalar, ComplexVector> data(i, nconv, m_ncv, eigs, m_resid, m_ritz_conv);
m_logger->iteration_log(data);
m_logger->call_iteration_start();
}
nconv = num_converged(tol);
const ComplexVector eigs = m_ritz_val.head(m_nev);
const IterationData<Scalar, ComplexVector> data(i, nconv, m_ncv, eigs, m_resid, m_ritz_conv);

if (nconv >= m_nev)
{
if (m_logger)
{
m_logger->call_iteration_end(data);
}
break;
}

nev_adj = nev_adjusted(nconv);
restart(nev_adj, selection);
if (m_logger)
{
m_logger->call_iteration_end(data);
}
}
// Sorting results
sort_ritzpair(sorting);
Expand Down
34 changes: 24 additions & 10 deletions include/Spectra/JDSymEigsBase.h
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,11 @@ class JDSymEigsBase
niter_ = 0;
for (niter_ = 0; niter_ < maxit; niter_++)
{
if (m_logger)
{
m_logger->call_iteration_start();
}

bool do_restart = (m_search_space.size() > m_max_search_space_size);

if (do_restart)
Expand All @@ -189,30 +194,39 @@ class JDSymEigsBase
m_ritz_pairs.sort(selection);

bool converged = m_ritz_pairs.check_convergence(tol, m_number_eigenvalues);
if (m_logger)
{
const Eigen::Array<bool, Eigen::Dynamic, 1> conv_eig = m_ritz_pairs.converged_eigenvalues().head(m_number_eigenvalues);
const Index num_conv = conv_eig.count();
const Vector evals = eigenvalues();
const Vector res = m_ritz_pairs.residues().colwise().norm().head(m_number_eigenvalues);
const Index search_space_size = m_search_space.size();
const IterationData<Scalar, Vector> data(niter_, num_conv, search_space_size, evals, res, conv_eig);
m_logger->iteration_log(data);
}
const Eigen::Array<bool, Eigen::Dynamic, 1> conv_eig = m_ritz_pairs.converged_eigenvalues().head(m_number_eigenvalues);
const Index num_conv = conv_eig.count();
const Vector evals = eigenvalues();
const Vector res = m_ritz_pairs.residues().colwise().norm().head(m_number_eigenvalues);
const Index search_space_size = m_search_space.size();
const IterationData<Scalar, Vector> data(niter_, num_conv, search_space_size, evals, res, conv_eig);

if (converged)
{
m_info = CompInfo::Successful;
if (m_logger)
{
m_logger->call_iteration_end(data);
}
break;
}
else if (niter_ == maxit - 1)
{
m_info = CompInfo::NotConverging;
if (m_logger)
{
m_logger->call_iteration_end(data);
}
break;
}
Derived& derived = static_cast<Derived&>(*this);
Matrix corr_vect = derived.calculate_correction_vector();

m_search_space.extend_basis(corr_vect);
if (m_logger)
{
m_logger->call_iteration_end(data);
}
}
return (m_ritz_pairs.converged_eigenvalues()).template cast<Index>().head(m_number_eigenvalues).sum();
}
Expand Down
3 changes: 2 additions & 1 deletion include/Spectra/LoggerBase.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,10 +49,11 @@ class LoggerBase
///
virtual ~LoggerBase() {}

virtual void call_iteration_start();
///
/// Virtual logging function
///
virtual void iteration_log(const IterationData<Scalar, Vector>& data) = 0;
virtual void call_iteration_end(const IterationData<Scalar, Vector>& data) = 0;
};

} // namespace Spectra
Expand Down
20 changes: 16 additions & 4 deletions include/Spectra/SymEigsBase.h
Original file line number Diff line number Diff line change
Expand Up @@ -378,18 +378,30 @@ class SymEigsBase
Index i, nconv = 0, nev_adj;
for (i = 0; i < maxit; i++)
{
nconv = num_converged(tol);
if (m_logger)
{
const Vector eigs = m_ritz_val.head(m_nev);
const IterationData<Scalar, Vector> data(i, nconv, m_ncv, eigs, m_resid, m_ritz_conv);
m_logger->iteration_log(data);
m_logger->call_iteration_start();
}
nconv = num_converged(tol);
const Vector eigs = m_ritz_val.head(m_nev);
const IterationData<Scalar, Vector> data(i, nconv, m_ncv, eigs, m_resid, m_ritz_conv);

if (nconv >= m_nev)
{
if (m_logger)
{
m_logger->call_iteration_end(data);
}
break;
}

nev_adj = nev_adjusted(nconv);
restart(nev_adj, selection);

if (m_logger)
{
m_logger->call_iteration_end(data);
}
}
// Sorting results
sort_ritzpair(sorting);
Expand Down
25 changes: 24 additions & 1 deletion test/LoggingDavidsonSymEigs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,13 +44,36 @@ class DerivedLogger : public LoggerBase<Scalar, Vector>
{
// This derived logging class could have some reference to an ostream or call to another class that wraps ostreams etc.
public:
std::chrono::time_point<std::chrono::high_resolution_clock> start;
std::chrono::time_point<std::chrono::high_resolution_clock> end;
DerivedLogger(){};
void iteration_log(const IterationData<Scalar, Vector>& data) override
void call_iteration_start() override
{
this->start = std::chrono::high_resolution_clock::now();
}
void call_iteration_end(const IterationData<Scalar, Vector>& data) override
{
this->end = std::chrono::high_resolution_clock::now();
auto duration = this->end - this->start;
auto h = std::chrono::duration_cast<std::chrono::hours>(duration);
duration -= h;
auto m = std::chrono::duration_cast<std::chrono::minutes>(duration);
duration -= m;
auto s = std::chrono::duration_cast<std::chrono::seconds>(duration);
duration -= s;
auto ms = std::chrono::duration_cast<std::chrono::milliseconds>(duration);
duration -= ms;
auto us = std::chrono::duration_cast<std::chrono::microseconds>(duration);
duration -= us;
auto ns = std::chrono::duration_cast<std::chrono::nanoseconds>(duration);
duration -= ns;
std::stringstream buffer;
std::cout << "--------------------------------------------------------------------------------------------" << std::endl;
std::cout << " Iteration : " << data.iteration << std::endl;
std::cout << " Number of converged eigenvalues : " << data.number_of_converged << std::endl;
std::cout << " Size of subspace : " << data.subspace_size << std::endl;
std::cout << " Iteration Time " << std::endl;
std::cout << " " << h.count() << "h:" << m.count() << "m:" << s.count() << "s:" << ms.count() << "ms:" << us.count() << "us:" << ns.count() << "ns" << std::endl;
std::cout << " ------------------------------------------------------------------------ " << std::endl;
REQUIRE(data.residues.size() == data.current_eigenvalues.size());
REQUIRE(data.residues.size() == data.current_eig_converged.size());
Expand Down
25 changes: 24 additions & 1 deletion test/LoggingGenEigs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,13 +29,36 @@ class DerivedLogger : public LoggerBase<Scalar, Vector>
{
// This derived logging class could have some reference to an ostream or call to another class that wraps ostreams etc.
public:
std::chrono::time_point<std::chrono::high_resolution_clock> start;
std::chrono::time_point<std::chrono::high_resolution_clock> end;
DerivedLogger(){};
void iteration_log(const IterationData<Scalar, Vector>& data) override
void call_iteration_start() override
{
this->start = std::chrono::high_resolution_clock::now();
}
void call_iteration_end(const IterationData<Scalar, Vector>& data) override
{
this->end = std::chrono::high_resolution_clock::now();
auto duration = this->end - this->start;
auto h = std::chrono::duration_cast<std::chrono::hours>(duration);
duration -= h;
auto m = std::chrono::duration_cast<std::chrono::minutes>(duration);
duration -= m;
auto s = std::chrono::duration_cast<std::chrono::seconds>(duration);
duration -= s;
auto ms = std::chrono::duration_cast<std::chrono::milliseconds>(duration);
duration -= ms;
auto us = std::chrono::duration_cast<std::chrono::microseconds>(duration);
duration -= us;
auto ns = std::chrono::duration_cast<std::chrono::nanoseconds>(duration);
duration -= ns;
std::stringstream buffer;
std::cout << "--------------------------------------------------------------------------------------------" << std::endl;
std::cout << " Iteration : " << data.iteration << std::endl;
std::cout << " Number of converged eigenvalues : " << data.number_of_converged << std::endl;
std::cout << " Size of subspace : " << data.subspace_size << std::endl;
std::cout << " Iteration Time " << std::endl;
std::cout << " " << h.count() << "h:" << m.count() << "m:" << s.count() << "s:" << ms.count() << "ms:" << us.count() << "us:" << ns.count() << "ns" << std::endl;
std::cout << " ------------------------------------------------------------------------ " << std::endl;
REQUIRE(data.residues.size() == data.current_eigenvalues.size());
REQUIRE(data.residues.size() == data.current_eig_converged.size());
Expand Down
25 changes: 24 additions & 1 deletion test/LoggingSymEigs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,36 @@ class DerivedLogger : public LoggerBase<Scalar, Vector>
{
// This derived logging class could have some reference to an ostream or call to another class that wraps ostreams etc.
public:
std::chrono::time_point<std::chrono::high_resolution_clock> start;
std::chrono::time_point<std::chrono::high_resolution_clock> end;
DerivedLogger(){};
void iteration_log(const IterationData<Scalar, Vector>& data) override
void call_iteration_start() override
{
this->start = std::chrono::high_resolution_clock::now();
}
void call_iteration_end(const IterationData<Scalar, Vector>& data) override
{
this->end = std::chrono::high_resolution_clock::now();
auto duration = this->end - this->start;
auto h = std::chrono::duration_cast<std::chrono::hours>(duration);
duration -= h;
auto m = std::chrono::duration_cast<std::chrono::minutes>(duration);
duration -= m;
auto s = std::chrono::duration_cast<std::chrono::seconds>(duration);
duration -= s;
auto ms = std::chrono::duration_cast<std::chrono::milliseconds>(duration);
duration -= ms;
auto us = std::chrono::duration_cast<std::chrono::microseconds>(duration);
duration -= us;
auto ns = std::chrono::duration_cast<std::chrono::nanoseconds>(duration);
duration -= ns;
std::stringstream buffer;
std::cout << "--------------------------------------------------------------------------------------------" << std::endl;
std::cout << " Iteration : " << data.iteration << std::endl;
std::cout << " Number of converged eigenvalues : " << data.number_of_converged << std::endl;
std::cout << " Size of subspace : " << data.subspace_size << std::endl;
std::cout << " Iteration Time " << std::endl;
std::cout << " " << h.count() << "h:" << m.count() << "m:" << s.count() << "s:" << ms.count() << "ms:" << us.count() << "us:" << ns.count() << "ns" << std::endl;
std::cout << " ------------------------------------------------------------------------ " << std::endl;
REQUIRE(data.residues.size() == data.current_eigenvalues.size());
REQUIRE(data.residues.size() == data.current_eig_converged.size());
Expand Down

0 comments on commit 7bdec45

Please sign in to comment.