From 622bf54ba61eb48e6a6446ad1b0542b51d9f9d53 Mon Sep 17 00:00:00 2001 From: Will Saunders <77331509+will-saunders-ukaea@users.noreply.github.com> Date: Thu, 3 Aug 2023 16:24:19 +0100 Subject: [PATCH] Feature/generic solver callback (#200) * prototype callback interface * added api documentation to callback classes * doc formatting * doc formatting * doc formatting * converted SimpleSOL mass conservation test to use callbacks rather than files * altered SimpleSOL integration test inputs to not write mass data to file * updated doxygen param name --- CMakeLists.txt | 2 + include/nektar_interface/utilities.hpp | 21 ++ include/solvers/solver_callback_handler.hpp | 227 ++++++++++++++++++ include/solvers/solver_runner.hpp | 51 ++++ .../Diagnostics/mass_conservation.hpp | 37 +-- .../SOLWithParticlesSystem.cpp | 26 +- .../EquationSystems/SOLWithParticlesSystem.h | 28 ++- test/CMakeLists.txt | 3 +- .../2DWithParticles_config.xml | 4 +- .../solvers/SimpleSOL/test_SimpleSOL.cpp | 26 +- .../solvers/SimpleSOL/test_SimpleSOL.h | 27 ++- test/unit/test_solver_callback.cpp | 91 +++++++ 12 files changed, 514 insertions(+), 29 deletions(-) create mode 100644 include/solvers/solver_callback_handler.hpp create mode 100644 include/solvers/solver_runner.hpp create mode 100644 test/unit/test_solver_callback.cpp diff --git a/CMakeLists.txt b/CMakeLists.txt index 3f602f8b..e6cc0688 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -188,6 +188,8 @@ set(HEADER_FILES ${INC_DIR}/run_info.hpp ${INC_DIR}/simulation.hpp ${INC_DIR}/species.hpp + ${INC_DIR}/solvers/solver_callback_handler.hpp + ${INC_DIR}/solvers/solver_runner.hpp ${INC_DIR}/velocity.hpp) # Create library diff --git a/include/nektar_interface/utilities.hpp b/include/nektar_interface/utilities.hpp index ef39cd91..0437ac5e 100644 --- a/include/nektar_interface/utilities.hpp +++ b/include/nektar_interface/utilities.hpp @@ -31,17 +31,38 @@ class NektarFieldIndexMap { std::map field_to_index; public: + /** + * Create map from field names to indices. It is assumed that the field + * index is the position in the input vector. + * + * @param field_names Vector of field names. + */ NektarFieldIndexMap(std::vector field_names) { int index = 0; for (auto field_name : field_names) { this->field_to_index[field_name] = index++; } } + /** + * Get the index of a field by name. + * + * @param field_name Name of field to get index for. + * @returns Non-negative integer if field exists -1 otherwise. + */ int get_idx(std::string field_name) { return (this->field_to_index.count(field_name) > 0) ? this->field_to_index[field_name] : -1; } + + /** + * Identical to get_idx except this method mirrors the std library behaviour + * and is fatal if the named field does not exist in the map. + * + * @param field_name Name of field to get index for. + * @returns Non-negative integer if field exists. + */ + int at(std::string field_name) { return this->field_to_index.at(field_name); } }; /** diff --git a/include/solvers/solver_callback_handler.hpp b/include/solvers/solver_callback_handler.hpp new file mode 100644 index 00000000..5628a4fd --- /dev/null +++ b/include/solvers/solver_callback_handler.hpp @@ -0,0 +1,227 @@ +#ifndef __SOLVER_CALLBACK_HANDLER_H_ +#define __SOLVER_CALLBACK_HANDLER_H_ +#include +#include + +namespace NESO { + +/** + * Base class which can be inherited from to create a callback for a solver + * class called NameOfSolver. + * + * class Foo: public SolverCallback { + * void call(NameOfSolver * state){ + * // Do something with state + * } + * } + * + * Deriving from this class is not compulsory to create a callback. + */ +template struct SolverCallback { + + /** + * Call the callback function with the current state passed as a pointer. The + * callback may modify the solver (at the callers peril). Note the order in + * which callbacks are called is undefined. + * + * @param[in, out] state Pointer to solver instance. + */ + virtual void call(SOLVER *state) = 0; +}; + +/** + * Class to handle calling callbacks from within a solver. This class can be a + * member variable of a solver or inherited from by the solver. The class is + * templated on the solver type which defines the pointer type passed to the + * callback functions. + */ +template class SolverCallbackHandler { +protected: + /// Functions to be typically called before an integration step. + std::vector> pre_integrate_funcs; + /// Functions to be typically called after an integration step. + std::vector> post_integrate_funcs; + + /** + * Helper function to convert an input function handle to a object which can + * be stored on the vector of function handles. + * + * @param[in] func Function handle to process. + * @returns standardised function handle. + */ + inline std::function + get_as_consistent_type(std::function func) { + std::function f = std::bind(func, std::placeholders::_1); + return f; + } + + /** + * Helper function to convert an input function handle to a object which can + * be stored on the vector of function handles. + * + * @param[in] func Class::method_name to call as function handle. + * @param[in] inst object on which to call method. + * @returns standardised function handle. + */ + template + inline std::function get_as_consistent_type(T func, U &inst) { + std::function f = + std::bind(func, std::ref(inst), std::placeholders::_1); + return f; + } + +public: + /** + * Register a function to be called before each time integration step. e.g. + * + * SolverCallbackHandler solver_callback_handler; + * solver_callback_handler.register_pre_integrate( + * std::function{ + * [&](NameOfSolver *state) { + * // use state + * } + * } + * ); + * } + * + * @param[in] func Function handle to add to callbacks. + */ + inline void register_pre_integrate(std::function func) { + this->pre_integrate_funcs.push_back(this->get_as_consistent_type(func)); + } + + /** + * Register a class method to be called before each time integration step. + * e.g. + * + * struct TestInterface { + * void call(NameOfSolver *state) { + * // use state + * } + * }; + * + * TestInterface ti; + * SolverCallbackHandler solver_callback_handler; + * solver_callback_handler.register_pre_integrate(&TestInterface::call, + * ti); + * + * @param[in] func Function handle to add to callbacks in the form of + * &ClassName::method_name + * @param[in] inst Instance of ClassName object on which to call method_name. + */ + template + inline void register_pre_integrate(T func, U &inst) { + this->pre_integrate_funcs.push_back( + this->get_as_consistent_type(func, inst)); + } + + /** + * Register a type derived of SolverCallback as a callback. e.g. + * + * struct TestInterface : public SolverCallback { + * void call(NameOfSolver *state) { + * // use state + * } + * }; + * + * TestInterface ti; + * SolverCallbackHandler solver_callback_handler; + * solver_callback_handler.register_pre_integrate(ti); + * + * @param[in] obj Derived type to add as callback object. + */ + inline void register_pre_integrate(SolverCallback &obj) { + this->pre_integrate_funcs.push_back( + this->get_as_consistent_type(&SolverCallback::call, obj)); + } + + /** + * Register a function to be called after each time integration step. e.g. + * + * SolverCallbackHandler solver_callback_handler; + * solver_callback_handler.register_post_integrate( + * std::function{ + * [&](NameOfSolver *state) { + * // use state + * } + * } + * ); + * } + * + * @param[in] func Function handle to add to callbacks. + */ + inline void register_post_integrate(std::function func) { + this->post_integrate_funcs.push_back(this->get_as_consistent_type(func)); + } + + /** + * Register a class method to be called after each time integration step. e.g. + * + * struct TestInterface { + * void call(NameOfSolver *state) { + * // use state + * } + * }; + * + * TestInterface ti; + * SolverCallbackHandler solver_callback_handler; + * solver_callback_handler.register_post_integrate(&TestInterface::call, + * ti); + * + * @param[in] func Function handle to add to callbacks in the form of + * &ClassName::method_name + * @param[in] inst Instance of ClassName object on which to call method_name. + */ + template + inline void register_post_integrate(T func, U &inst) { + this->post_integrate_funcs.push_back( + this->get_as_consistent_type(func, inst)); + } + + /** + * Register a type derived of SolverCallback as a callback. e.g. + * + * struct TestInterface : public SolverCallback { + * void call(NameOfSolver *state) { + * // use state + * } + * }; + * + * TestInterface ti; + * SolverCallbackHandler solver_callback_handler; + * solver_callback_handler.register_post_integrate(ti); + * + * @param[in] obj Derived type to add as callback object. + */ + inline void register_post_integrate(SolverCallback &obj) { + this->post_integrate_funcs.push_back( + this->get_as_consistent_type(&SolverCallback::call, obj)); + } + + /** + * Call all function handles which were registered as pre-integration + * callbacks. + * + * @param[in, out] state Solver state used to call each function handle. + */ + inline void call_pre_integrate(SOLVER *state) { + for (auto &func : this->pre_integrate_funcs) { + func(state); + } + } + + /** + * Call all function handles which were registered as post-integration + * callbacks. + * + * @param[in, out] state Solver state used to call each function handle. + */ + inline void call_post_integrate(SOLVER *state) { + for (auto &func : this->post_integrate_funcs) { + func(state); + } + } +}; + +} // namespace NESO +#endif diff --git a/include/solvers/solver_runner.hpp b/include/solvers/solver_runner.hpp new file mode 100644 index 00000000..ff636f29 --- /dev/null +++ b/include/solvers/solver_runner.hpp @@ -0,0 +1,51 @@ +#ifndef __SOLVER_RUNNER_H_ +#define __SOLVER_RUNNER_H_ + +#include +#include +#include + +using namespace Nektar; + +/** + * Class to abstract setting up sessions and drivers for Nektar++ solvers. + */ +class SolverRunner { +public: + /// Nektar++ session object. + LibUtilities::SessionReaderSharedPtr session; + /// MeshGraph instance for solver. + SpatialDomains::MeshGraphSharedPtr graph; + /// The Driver created for the solver. + SolverUtils::DriverSharedPtr driver; + + /** + * Create session, graph and driver from files. + * + * @param argc Number of arguments (like for main). + * @param argv Array of char* filenames (like for main). + */ + SolverRunner(int argc, char **argv) { + // Create session reader. + this->session = LibUtilities::SessionReader::CreateInstance(argc, argv); + // Read the mesh and create a MeshGraph object. + this->graph = SpatialDomains::MeshGraph::Read(this->session); + // Create driver. + std::string driverName; + session->LoadSolverInfo("Driver", driverName, "Standard"); + this->driver = SolverUtils::GetDriverFactory().CreateInstance( + driverName, session, graph); + } + + /** + * Calls Execute on the underlying driver object to run the solver. + */ + inline void execute() { this->driver->Execute(); } + + /** + * Calls Finalise on the underlying session object. + */ + inline void finalise() { this->session->Finalise(); } +}; + +#endif diff --git a/solvers/SimpleSOL/Diagnostics/mass_conservation.hpp b/solvers/SimpleSOL/Diagnostics/mass_conservation.hpp index 4c2d6520..f09740d6 100644 --- a/solvers/SimpleSOL/Diagnostics/mass_conservation.hpp +++ b/solvers/SimpleSOL/Diagnostics/mass_conservation.hpp @@ -122,21 +122,30 @@ template class MassRecording { } } + inline double get_initial_mass() { + NESOASSERT(this->initial_mass_computed == true, + "initial mass not computed"); + return this->initial_mass_fluid; + } + inline void compute(int step) { - if (step % mass_recording_step == 0) { - const double mass_particles = this->compute_particle_mass(); - const double mass_fluid = this->compute_fluid_mass(); - const double mass_total = mass_particles + mass_fluid; - const double mass_added = this->compute_total_added_mass(); - const double correct_total = mass_added + this->initial_mass_fluid; - - // Write values to file - if (rank == 0) { - nprint(step, ",", abs(correct_total - mass_total) / abs(correct_total), - ",", mass_particles, ",", mass_fluid, ","); - fh << step << "," - << abs(correct_total - mass_total) / abs(correct_total) << "," - << mass_particles << "," << mass_fluid << "\n"; + if (mass_recording_step > 0) { + if (step % mass_recording_step == 0) { + const double mass_particles = this->compute_particle_mass(); + const double mass_fluid = this->compute_fluid_mass(); + const double mass_total = mass_particles + mass_fluid; + const double mass_added = this->compute_total_added_mass(); + const double correct_total = mass_added + this->initial_mass_fluid; + + // Write values to file + if (rank == 0) { + nprint(step, ",", + abs(correct_total - mass_total) / abs(correct_total), ",", + mass_particles, ",", mass_fluid, ","); + fh << step << "," + << abs(correct_total - mass_total) / abs(correct_total) << "," + << mass_particles << "," << mass_fluid << "\n"; + } } } }; diff --git a/solvers/SimpleSOL/EquationSystems/SOLWithParticlesSystem.cpp b/solvers/SimpleSOL/EquationSystems/SOLWithParticlesSystem.cpp index 4228b447..6d781bdb 100644 --- a/solvers/SimpleSOL/EquationSystems/SOLWithParticlesSystem.cpp +++ b/solvers/SimpleSOL/EquationSystems/SOLWithParticlesSystem.cpp @@ -100,11 +100,9 @@ void SOLWithParticlesSystem::v_InitObject(bool DeclareField) { m_particle_sys->setup_evaluate_n(m_discont_fields["rho"]); m_particle_sys->setup_evaluate_T(m_discont_fields["T"]); - if (m_diag_mass_recording_enabled) { - m_diag_mass_recording = - std::make_shared>( - m_session, m_particle_sys, m_discont_fields["rho"]); - } + m_diag_mass_recording = + std::make_shared>( + m_session, m_particle_sys, m_discont_fields["rho"]); } /** @@ -124,10 +122,13 @@ bool SOLWithParticlesSystem::v_PostIntegrate(int step) { if (m_diag_mass_recording_enabled) { m_diag_mass_recording->compute(step); } + + m_solver_callback_handler.call_post_integrate(this); return SOLSystem::v_PostIntegrate(step); } bool SOLWithParticlesSystem::v_PreIntegrate(int step) { + m_solver_callback_handler.call_pre_integrate(this); if (m_diag_mass_recording_enabled) { m_diag_mass_recording->compute_initial_fluid_mass(); @@ -142,4 +143,19 @@ bool SOLWithParticlesSystem::v_PreIntegrate(int step) { return SOLSystem::v_PreIntegrate(step); } +ExpListSharedPtr +SOLWithParticlesSystem::GetField(const std::string field_name) { + ExpListSharedPtr ptr(nullptr); + int idx = m_field_to_index.get_idx(field_name); + if (idx > -1) { + ptr = m_fields[idx]; + } + return ptr; +} + +std::shared_ptr +SOLWithParticlesSystem::GetNeutralParticleSystem() { + return m_particle_sys; +} + } // namespace Nektar diff --git a/solvers/SimpleSOL/EquationSystems/SOLWithParticlesSystem.h b/solvers/SimpleSOL/EquationSystems/SOLWithParticlesSystem.h index 19832769..df0fdc44 100644 --- a/solvers/SimpleSOL/EquationSystems/SOLWithParticlesSystem.h +++ b/solvers/SimpleSOL/EquationSystems/SOLWithParticlesSystem.h @@ -39,6 +39,7 @@ #include "../ParticleSystems/neutral_particles.hpp" #include "SOLSystem.h" #include +#include namespace Nektar { /** @@ -53,6 +54,13 @@ class SOLWithParticlesSystem : public SOLSystem, /// Name of class. static std::string className; + /// Callback handler to call user defined callbacks. + SolverCallbackHandler m_solver_callback_handler; + + // Object that allows optional recording of stats related to mass conservation + std::shared_ptr> + m_diag_mass_recording; + /// Creates an instance of this class. static SolverUtils::EquationSystemSharedPtr create(const LibUtilities::SessionReaderSharedPtr &pSession, @@ -69,10 +77,24 @@ class SOLWithParticlesSystem : public SOLSystem, virtual ~SOLWithParticlesSystem(); + /** + * Get a field in the equation system by specifiying the field name. + * + * @param field_name Name of field to extract. + * @returns Requested field if it exists otherwise nullptr + */ + ExpListSharedPtr GetField(const std::string field_name); + + /** + * Get a shared pointer to the neutral particle system. + * + * @returns Pointer to neutral particle system. + */ + std::shared_ptr GetNeutralParticleSystem(); + + protected: - // Object that allows optional recording of stats related to mass conservation - std::shared_ptr> - m_diag_mass_recording; + // Flag to toggle mass conservation checking bool m_diag_mass_recording_enabled; // Map of field name to field index diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index 760b9c03..72d33b91 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -46,7 +46,8 @@ set(UNIT_SRC_FILES ${UNIT_SRC}/nektar_interface/test_particle_geometry_interface_3d.cpp ${UNIT_SRC}/nektar_interface/test_basis_evaluation.cpp ${UNIT_SRC}/nektar_interface/test_particle_mapping.cpp - ${UNIT_SRC}/nektar_interface/test_utility_cartesian_mesh.cpp) + ${UNIT_SRC}/nektar_interface/test_utility_cartesian_mesh.cpp + ${UNIT_SRC}/test_solver_callback.cpp) set(INTEGRATION_SRC ${CMAKE_CURRENT_SOURCE_DIR}/integration) set(INTEGRATION_SRC_FILES diff --git a/test/integration/solvers/SimpleSOL/2DWithParticles/2DWithParticles_config.xml b/test/integration/solvers/SimpleSOL/2DWithParticles/2DWithParticles_config.xml index 089cde91..f0332ffc 100755 --- a/test/integration/solvers/SimpleSOL/2DWithParticles/2DWithParticles_config.xml +++ b/test/integration/solvers/SimpleSOL/2DWithParticles/2DWithParticles_config.xml @@ -28,7 +28,7 @@

unrotated_x_max = 110.0

unrotated_y_max = 1.0

srcs_mask = 0.0

-

mass_recording_step = 1

+

mass_recording_step = 0

@@ -114,4 +114,4 @@ - \ No newline at end of file + diff --git a/test/integration/solvers/SimpleSOL/test_SimpleSOL.cpp b/test/integration/solvers/SimpleSOL/test_SimpleSOL.cpp index 61cfd75d..722ef37b 100644 --- a/test/integration/solvers/SimpleSOL/test_SimpleSOL.cpp +++ b/test/integration/solvers/SimpleSOL/test_SimpleSOL.cpp @@ -40,7 +40,27 @@ TEST_F(SimpleSOLTest, 2Drot45) { } TEST_F(SimpleSOLTest, 2DWithParticles) { - int ret_code = run({NESO::Solvers::run_SimpleSOL}); + + SOLWithParticlesMassConservationPre callback_pre; + SOLWithParticlesMassConservationPost callback_post; + + MainFuncType runner = [&](int argc, char **argv) { + SolverRunner solver_runner(argc, argv); + auto equation_system = std::dynamic_pointer_cast( + solver_runner.driver->GetEqu()[0]); + + equation_system->m_solver_callback_handler.register_pre_integrate( + callback_pre); + equation_system->m_solver_callback_handler.register_post_integrate( + callback_post); + + solver_runner.execute(); + solver_runner.finalise(); + return 0; + }; + + int ret_code = run(runner); EXPECT_EQ(ret_code, 0); - check_mass_conservation(mass_cons_tolerance); -} \ No newline at end of file + ASSERT_THAT(callback_post.mass_error, + testing::Each(testing::Le(mass_cons_tolerance))); +} diff --git a/test/integration/solvers/SimpleSOL/test_SimpleSOL.h b/test/integration/solvers/SimpleSOL/test_SimpleSOL.h index d09d67a4..6970ebe1 100644 --- a/test/integration/solvers/SimpleSOL/test_SimpleSOL.h +++ b/test/integration/solvers/SimpleSOL/test_SimpleSOL.h @@ -4,6 +4,7 @@ #include #include #include +#include #include #include @@ -13,6 +14,9 @@ #include "SimpleSOL.h" #include "solver_test_utils.h" +#include "solvers/solver_runner.hpp" +#include "solvers/solver_callback_handler.hpp" +#include "EquationSystems/SOLWithParticlesSystem.h" namespace LU = Nektar::LibUtilities; namespace FU = Nektar::FieldUtils; @@ -182,4 +186,25 @@ class SimpleSOLTest : public NektarSolverTest { } }; -#endif // SIMPLESOL_TESTS_COMMON \ No newline at end of file + +struct SOLWithParticlesMassConservationPre : public NESO::SolverCallback { + void call(SOLWithParticlesSystem *state) { + state->m_diag_mass_recording->compute_initial_fluid_mass(); + } +}; + +struct SOLWithParticlesMassConservationPost : public NESO::SolverCallback { + std::vector mass_error; + void call(SOLWithParticlesSystem *state) { + auto md = state->m_diag_mass_recording; + const double mass_particles = md->compute_particle_mass(); + const double mass_fluid = md->compute_fluid_mass(); + const double mass_total = mass_particles + mass_fluid; + const double mass_added = md->compute_total_added_mass(); + const double correct_total = mass_added + md->get_initial_mass(); + this->mass_error.push_back(std::fabs(correct_total - mass_total)/std::fabs(correct_total)); + } +}; + + +#endif // SIMPLESOL_TESTS_COMMON diff --git a/test/unit/test_solver_callback.cpp b/test/unit/test_solver_callback.cpp new file mode 100644 index 00000000..c22ef2f2 --- /dev/null +++ b/test/unit/test_solver_callback.cpp @@ -0,0 +1,91 @@ +#include +#include +#include + +using namespace NESO; +using namespace NESO::Particles; + +namespace { + +struct TestClass { + int ia; + double da; +}; + +struct TestFunc { + double da; + void operator()(TestClass *test_class) { this->da = test_class->da; } + void call(TestClass *test_class) { this->da = test_class->da; } +}; + +struct TestInterface : public SolverCallback { + double da; + double ia; + void call(TestClass *state) { + this->da = state->da; + this->ia = state->ia; + } +}; + +} // namespace + +TEST(SolverCallback, Base) { + + SolverCallbackHandler sc; + + TestFunc test_func_0; + sc.register_pre_integrate(&TestFunc::call, test_func_0); + + int tia = -1; + std::function lambda_func_0 = [&](TestClass *test_class) { + tia = test_class->ia; + }; + + sc.register_post_integrate(lambda_func_0); + + TestClass tc{}; + tc.ia = 42; + tc.da = 3.1415; + sc.call_pre_integrate(&tc); + ASSERT_EQ(test_func_0.da, 3.1415); + + sc.call_post_integrate(&tc); + ASSERT_EQ(tia, 42); +} + +TEST(SolverCallback, Scope) { + + SolverCallbackHandler sc; + + int tia = -1; + + { + sc.register_post_integrate(std::function{ + [&](TestClass *test_class) { tia = test_class->ia; }}); + } + + TestClass tc{}; + tc.ia = 42; + tc.da = 3.1415; + + sc.call_post_integrate(&tc); + ASSERT_EQ(tia, 42); +} + +TEST(SolverCallback, Class) { + + SolverCallbackHandler sc; + TestClass tc{}; + tc.ia = 42; + tc.da = 3.1415; + TestInterface test_func_0; + sc.register_pre_integrate(test_func_0); + sc.call_pre_integrate(&tc); + ASSERT_EQ(test_func_0.da, 3.1415); + + TestInterface test_func_1; + sc.register_post_integrate(test_func_1); + sc.call_post_integrate(&tc); + ASSERT_EQ(test_func_1.da, 3.1415); + ASSERT_EQ(test_func_1.ia, 42); +}