Skip to content

Commit

Permalink
Feature/generic solver callback (#200)
Browse files Browse the repository at this point in the history
* prototype callback interface

* added api documentation to callback classes

* doc formatting

* doc formatting

* doc formatting

* converted SimpleSOL mass conservation test to use callbacks rather than files

* altered SimpleSOL integration test inputs to not write mass data to file

* updated doxygen param name
  • Loading branch information
will-saunders-ukaea authored Aug 3, 2023
1 parent bd4f314 commit 622bf54
Show file tree
Hide file tree
Showing 12 changed files with 514 additions and 29 deletions.
2 changes: 2 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,8 @@ set(HEADER_FILES
${INC_DIR}/run_info.hpp
${INC_DIR}/simulation.hpp
${INC_DIR}/species.hpp
${INC_DIR}/solvers/solver_callback_handler.hpp
${INC_DIR}/solvers/solver_runner.hpp
${INC_DIR}/velocity.hpp)

# Create library
Expand Down
21 changes: 21 additions & 0 deletions include/nektar_interface/utilities.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,17 +31,38 @@ class NektarFieldIndexMap {
std::map<std::string, int> field_to_index;

public:
/**
* Create map from field names to indices. It is assumed that the field
* index is the position in the input vector.
*
* @param field_names Vector of field names.
*/
NektarFieldIndexMap(std::vector<std::string> field_names) {
int index = 0;
for (auto field_name : field_names) {
this->field_to_index[field_name] = index++;
}
}
/**
* Get the index of a field by name.
*
* @param field_name Name of field to get index for.
* @returns Non-negative integer if field exists -1 otherwise.
*/
int get_idx(std::string field_name) {
return (this->field_to_index.count(field_name) > 0)
? this->field_to_index[field_name]
: -1;
}

/**
* Identical to get_idx except this method mirrors the std library behaviour
* and is fatal if the named field does not exist in the map.
*
* @param field_name Name of field to get index for.
* @returns Non-negative integer if field exists.
*/
int at(std::string field_name) { return this->field_to_index.at(field_name); }
};

/**
Expand Down
227 changes: 227 additions & 0 deletions include/solvers/solver_callback_handler.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,227 @@
#ifndef __SOLVER_CALLBACK_HANDLER_H_
#define __SOLVER_CALLBACK_HANDLER_H_
#include <functional>
#include <vector>

namespace NESO {

/**
* Base class which can be inherited from to create a callback for a solver
* class called NameOfSolver.
*
* class Foo: public SolverCallback<NameOfSolver> {
* void call(NameOfSolver * state){
* // Do something with state
* }
* }
*
* Deriving from this class is not compulsory to create a callback.
*/
template <typename SOLVER> struct SolverCallback {

/**
* Call the callback function with the current state passed as a pointer. The
* callback may modify the solver (at the callers peril). Note the order in
* which callbacks are called is undefined.
*
* @param[in, out] state Pointer to solver instance.
*/
virtual void call(SOLVER *state) = 0;
};

/**
* Class to handle calling callbacks from within a solver. This class can be a
* member variable of a solver or inherited from by the solver. The class is
* templated on the solver type which defines the pointer type passed to the
* callback functions.
*/
template <typename SOLVER> class SolverCallbackHandler {
protected:
/// Functions to be typically called before an integration step.
std::vector<std::function<void(SOLVER *)>> pre_integrate_funcs;
/// Functions to be typically called after an integration step.
std::vector<std::function<void(SOLVER *)>> post_integrate_funcs;

/**
* Helper function to convert an input function handle to a object which can
* be stored on the vector of function handles.
*
* @param[in] func Function handle to process.
* @returns standardised function handle.
*/
inline std::function<void(SOLVER *)>
get_as_consistent_type(std::function<void(SOLVER *)> func) {
std::function<void(SOLVER *)> f = std::bind(func, std::placeholders::_1);
return f;
}

/**
* Helper function to convert an input function handle to a object which can
* be stored on the vector of function handles.
*
* @param[in] func Class::method_name to call as function handle.
* @param[in] inst object on which to call method.
* @returns standardised function handle.
*/
template <typename T, typename U>
inline std::function<void(SOLVER *)> get_as_consistent_type(T func, U &inst) {
std::function<void(SOLVER *)> f =
std::bind(func, std::ref(inst), std::placeholders::_1);
return f;
}

public:
/**
* Register a function to be called before each time integration step. e.g.
*
* SolverCallbackHandler<NameOfSolver> solver_callback_handler;
* solver_callback_handler.register_pre_integrate(
* std::function<void(NameOfSolver *)>{
* [&](NameOfSolver *state) {
* // use state
* }
* }
* );
* }
*
* @param[in] func Function handle to add to callbacks.
*/
inline void register_pre_integrate(std::function<void(SOLVER *)> func) {
this->pre_integrate_funcs.push_back(this->get_as_consistent_type(func));
}

/**
* Register a class method to be called before each time integration step.
* e.g.
*
* struct TestInterface {
* void call(NameOfSolver *state) {
* // use state
* }
* };
*
* TestInterface ti;
* SolverCallbackHandler<NameOfSolver> solver_callback_handler;
* solver_callback_handler.register_pre_integrate(&TestInterface::call,
* ti);
*
* @param[in] func Function handle to add to callbacks in the form of
* &ClassName::method_name
* @param[in] inst Instance of ClassName object on which to call method_name.
*/
template <typename T, typename U>
inline void register_pre_integrate(T func, U &inst) {
this->pre_integrate_funcs.push_back(
this->get_as_consistent_type(func, inst));
}

/**
* Register a type derived of SolverCallback as a callback. e.g.
*
* struct TestInterface : public SolverCallback<NameOfSolver> {
* void call(NameOfSolver *state) {
* // use state
* }
* };
*
* TestInterface ti;
* SolverCallbackHandler<NameOfSolver> solver_callback_handler;
* solver_callback_handler.register_pre_integrate(ti);
*
* @param[in] obj Derived type to add as callback object.
*/
inline void register_pre_integrate(SolverCallback<SOLVER> &obj) {
this->pre_integrate_funcs.push_back(
this->get_as_consistent_type(&SolverCallback<SOLVER>::call, obj));
}

/**
* Register a function to be called after each time integration step. e.g.
*
* SolverCallbackHandler<NameOfSolver> solver_callback_handler;
* solver_callback_handler.register_post_integrate(
* std::function<void(NameOfSolver *)>{
* [&](NameOfSolver *state) {
* // use state
* }
* }
* );
* }
*
* @param[in] func Function handle to add to callbacks.
*/
inline void register_post_integrate(std::function<void(SOLVER *)> func) {
this->post_integrate_funcs.push_back(this->get_as_consistent_type(func));
}

/**
* Register a class method to be called after each time integration step. e.g.
*
* struct TestInterface {
* void call(NameOfSolver *state) {
* // use state
* }
* };
*
* TestInterface ti;
* SolverCallbackHandler<NameOfSolver> solver_callback_handler;
* solver_callback_handler.register_post_integrate(&TestInterface::call,
* ti);
*
* @param[in] func Function handle to add to callbacks in the form of
* &ClassName::method_name
* @param[in] inst Instance of ClassName object on which to call method_name.
*/
template <typename T, typename U>
inline void register_post_integrate(T func, U &inst) {
this->post_integrate_funcs.push_back(
this->get_as_consistent_type(func, inst));
}

/**
* Register a type derived of SolverCallback as a callback. e.g.
*
* struct TestInterface : public SolverCallback<NameOfSolver> {
* void call(NameOfSolver *state) {
* // use state
* }
* };
*
* TestInterface ti;
* SolverCallbackHandler<NameOfSolver> solver_callback_handler;
* solver_callback_handler.register_post_integrate(ti);
*
* @param[in] obj Derived type to add as callback object.
*/
inline void register_post_integrate(SolverCallback<SOLVER> &obj) {
this->post_integrate_funcs.push_back(
this->get_as_consistent_type(&SolverCallback<SOLVER>::call, obj));
}

/**
* Call all function handles which were registered as pre-integration
* callbacks.
*
* @param[in, out] state Solver state used to call each function handle.
*/
inline void call_pre_integrate(SOLVER *state) {
for (auto &func : this->pre_integrate_funcs) {
func(state);
}
}

/**
* Call all function handles which were registered as post-integration
* callbacks.
*
* @param[in, out] state Solver state used to call each function handle.
*/
inline void call_post_integrate(SOLVER *state) {
for (auto &func : this->post_integrate_funcs) {
func(state);
}
}
};

} // namespace NESO
#endif
51 changes: 51 additions & 0 deletions include/solvers/solver_runner.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
#ifndef __SOLVER_RUNNER_H_
#define __SOLVER_RUNNER_H_

#include <LibUtilities/BasicUtils/SessionReader.h>
#include <SolverUtils/Driver.h>
#include <SpatialDomains/MeshGraph.h>

using namespace Nektar;

/**
* Class to abstract setting up sessions and drivers for Nektar++ solvers.
*/
class SolverRunner {
public:
/// Nektar++ session object.
LibUtilities::SessionReaderSharedPtr session;
/// MeshGraph instance for solver.
SpatialDomains::MeshGraphSharedPtr graph;
/// The Driver created for the solver.
SolverUtils::DriverSharedPtr driver;

/**
* Create session, graph and driver from files.
*
* @param argc Number of arguments (like for main).
* @param argv Array of char* filenames (like for main).
*/
SolverRunner(int argc, char **argv) {
// Create session reader.
this->session = LibUtilities::SessionReader::CreateInstance(argc, argv);
// Read the mesh and create a MeshGraph object.
this->graph = SpatialDomains::MeshGraph::Read(this->session);
// Create driver.
std::string driverName;
session->LoadSolverInfo("Driver", driverName, "Standard");
this->driver = SolverUtils::GetDriverFactory().CreateInstance(
driverName, session, graph);
}

/**
* Calls Execute on the underlying driver object to run the solver.
*/
inline void execute() { this->driver->Execute(); }

/**
* Calls Finalise on the underlying session object.
*/
inline void finalise() { this->session->Finalise(); }
};

#endif
37 changes: 23 additions & 14 deletions solvers/SimpleSOL/Diagnostics/mass_conservation.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -122,21 +122,30 @@ template <typename T> class MassRecording {
}
}

inline double get_initial_mass() {
NESOASSERT(this->initial_mass_computed == true,
"initial mass not computed");
return this->initial_mass_fluid;
}

inline void compute(int step) {
if (step % mass_recording_step == 0) {
const double mass_particles = this->compute_particle_mass();
const double mass_fluid = this->compute_fluid_mass();
const double mass_total = mass_particles + mass_fluid;
const double mass_added = this->compute_total_added_mass();
const double correct_total = mass_added + this->initial_mass_fluid;

// Write values to file
if (rank == 0) {
nprint(step, ",", abs(correct_total - mass_total) / abs(correct_total),
",", mass_particles, ",", mass_fluid, ",");
fh << step << ","
<< abs(correct_total - mass_total) / abs(correct_total) << ","
<< mass_particles << "," << mass_fluid << "\n";
if (mass_recording_step > 0) {
if (step % mass_recording_step == 0) {
const double mass_particles = this->compute_particle_mass();
const double mass_fluid = this->compute_fluid_mass();
const double mass_total = mass_particles + mass_fluid;
const double mass_added = this->compute_total_added_mass();
const double correct_total = mass_added + this->initial_mass_fluid;

// Write values to file
if (rank == 0) {
nprint(step, ",",
abs(correct_total - mass_total) / abs(correct_total), ",",
mass_particles, ",", mass_fluid, ",");
fh << step << ","
<< abs(correct_total - mass_total) / abs(correct_total) << ","
<< mass_particles << "," << mass_fluid << "\n";
}
}
}
};
Expand Down
Loading

0 comments on commit 622bf54

Please sign in to comment.