From 08c1753a2680f5fe84fea4a49bc96d88ded62d13 Mon Sep 17 00:00:00 2001 From: dabele Date: Wed, 20 Dec 2023 14:48:36 +0100 Subject: [PATCH] 829 Optimization of TestingStrategy (#830) Performance optimization of TestingStrategies - vector instead of unordered_map for testing schemes): ~25% decrease of run time - switch of ifs in testing strategy: additional ~8 % decrease of run time - CustomIndexArray for GoToWork/GoToSchool paramater: no measureable change - bitset for agegroups in TestingCriteria with fixed number of age groups: ~7% decrease of run time Co-authored-by: DavidKerkmann <44698825+DavidKerkmann@users.noreply.github.com> --- cpp/examples/abm_minimal.cpp | 9 ++- cpp/models/abm/config.h | 36 ++++++++++++ cpp/models/abm/migration_rules.cpp | 4 +- cpp/models/abm/parameters.h | 17 ++++-- cpp/models/abm/testing_strategy.cpp | 77 +++++++++++++++++-------- cpp/models/abm/testing_strategy.h | 5 +- cpp/models/abm/world.h | 4 +- cpp/tests/test_abm_lockdown_rules.cpp | 28 ++++++--- cpp/tests/test_abm_migration_rules.cpp | 42 ++++++++++---- cpp/tests/test_abm_testing_strategy.cpp | 5 +- 10 files changed, 165 insertions(+), 62 deletions(-) create mode 100644 cpp/models/abm/config.h diff --git a/cpp/examples/abm_minimal.cpp b/cpp/examples/abm_minimal.cpp index 623b9153a1..e42b48b651 100644 --- a/cpp/examples/abm_minimal.cpp +++ b/cpp/examples/abm_minimal.cpp @@ -41,9 +41,12 @@ int main() world.parameters.get() = 4.; // Set the age group the can go to school is AgeGroup(1) (i.e. 5-14) - world.parameters.get() = {age_group_5_to_14}; + world.parameters.get() = false; + world.parameters.get()[age_group_5_to_14] = true; // Set the age group the can go to work is AgeGroup(2) and AgeGroup(3) (i.e. 15-34 and 35-59) - world.parameters.get() = {age_group_15_to_34, age_group_35_to_59}; + world.parameters.get() = false; + world.parameters.get()[age_group_15_to_34] = true; + world.parameters.get()[age_group_35_to_59] = true; // Check if the parameters satisfy their contraints. world.parameters.check_constraints(); @@ -169,4 +172,4 @@ int main() std::cout << "Results written to abm_minimal.txt" << std::endl; return 0; -} \ No newline at end of file +} diff --git a/cpp/models/abm/config.h b/cpp/models/abm/config.h new file mode 100644 index 0000000000..5f281331ec --- /dev/null +++ b/cpp/models/abm/config.h @@ -0,0 +1,36 @@ +/* +* Copyright (C) 2020-2024 MEmilio +* +* Authors: Daniel Abele +* +* Contact: Martin J. Kuehn +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ +#ifndef MIO_ABM_CONFIG_H +#define MIO_ABM_CONFIG_H + +namespace mio +{ +namespace abm +{ + +/** + * Maximum number of age groups allowed in the model. + */ +const constexpr int MAX_NUM_AGE_GROUPS = 64; + +} +} // namespace mio + +#endif diff --git a/cpp/models/abm/migration_rules.cpp b/cpp/models/abm/migration_rules.cpp index eb12e65ed4..1faad0b44f 100644 --- a/cpp/models/abm/migration_rules.cpp +++ b/cpp/models/abm/migration_rules.cpp @@ -56,7 +56,7 @@ LocationType go_to_school(Person::RandomNumberGenerator& /*rng*/, const Person& if (current_loc == LocationType::Home && t < params.get() && t.day_of_week() < 5 && person.get_go_to_school_time(params) >= t.time_since_midnight() && person.get_go_to_school_time(params) < t.time_since_midnight() + dt && - params.get().count(person.get_age()) && person.goes_to_school(t, params) && + params.get()[person.get_age()] && person.goes_to_school(t, params) && !person.is_in_quarantine()) { return LocationType::School; } @@ -73,7 +73,7 @@ LocationType go_to_work(Person::RandomNumberGenerator& /*rng*/, const Person& pe auto current_loc = person.get_location().get_type(); if (current_loc == LocationType::Home && t < params.get() && - params.get().count(person.get_age()) && t.day_of_week() < 5 && + params.get()[person.get_age()] && t.day_of_week() < 5 && t.time_since_midnight() + dt > person.get_go_to_work_time(params) && t.time_since_midnight() <= person.get_go_to_work_time(params) && person.goes_to_work(t, params) && !person.is_in_quarantine()) { diff --git a/cpp/models/abm/parameters.h b/cpp/models/abm/parameters.h index 52c121a1ce..ee3444d1bf 100644 --- a/cpp/models/abm/parameters.h +++ b/cpp/models/abm/parameters.h @@ -497,10 +497,12 @@ struct GotoSchoolTimeMaximum { * @brief The set of AgeGroups that can go to school. */ struct AgeGroupGotoSchool { - using Type = std::set; - static Type get_default(AgeGroup /*size*/) + using Type = CustomIndexArray; + static Type get_default(AgeGroup num_agegroups) { - return std::set{AgeGroup(1)}; + auto a = Type(num_agegroups, false); + a[AgeGroup(1)] = true; + return a; } static std::string name() { @@ -512,10 +514,13 @@ struct AgeGroupGotoSchool { * @brief The set of AgeGroups that can go to work. */ struct AgeGroupGotoWork { - using Type = std::set; - static Type get_default(AgeGroup /*size*/) + using Type = CustomIndexArray; + static Type get_default(AgeGroup num_agegroups) { - return std::set{AgeGroup(2), AgeGroup(3)}; + auto a = Type(num_agegroups, false); + a[AgeGroup(2)] = true; + a[AgeGroup(3)] = true; + return a; } static std::string name() { diff --git a/cpp/models/abm/testing_strategy.cpp b/cpp/models/abm/testing_strategy.cpp index 33bcf3ede5..2fce3d2b17 100644 --- a/cpp/models/abm/testing_strategy.cpp +++ b/cpp/models/abm/testing_strategy.cpp @@ -29,7 +29,7 @@ namespace abm TestingCriteria::TestingCriteria(const std::vector& ages, const std::vector& infection_states) { for (auto age : ages) { - m_ages.insert(static_cast(age)); + m_ages.set(static_cast(age), true); } for (auto infection_state : infection_states) { m_infection_states.set(static_cast(infection_state), true); @@ -43,13 +43,12 @@ bool TestingCriteria::operator==(const TestingCriteria& other) const void TestingCriteria::add_age_group(const AgeGroup age_group) { - - m_ages.insert(static_cast(age_group)); + m_ages.set(static_cast(age_group), true); } void TestingCriteria::remove_age_group(const AgeGroup age_group) { - m_ages.erase(static_cast(age_group)); + m_ages.set(static_cast(age_group), false); } void TestingCriteria::add_infection_state(const InfectionState infection_state) @@ -65,7 +64,7 @@ void TestingCriteria::remove_infection_state(const InfectionState infection_stat bool TestingCriteria::evaluate(const Person& p, TimePoint t) const { // An empty vector of ages or none bitset of #InfectionStates% means that no condition on the corresponding property is set. - return (m_ages.empty() || m_ages.count(static_cast(p.get_age()))) && + return (m_ages.none() || m_ages[static_cast(p.get_age())]) && (m_infection_states.none() || m_infection_states[static_cast(p.get_infection_state(t))]); } @@ -104,9 +103,9 @@ void TestingScheme::update_activity_status(TimePoint t) bool TestingScheme::run_scheme(Person::RandomNumberGenerator& rng, Person& person, TimePoint t) const { if (person.get_time_since_negative_test() > m_minimal_time_since_last_test) { - double random = UniformDistribution::get_instance()(rng); - if (random < m_probability) { - if (m_testing_criteria.evaluate(person, t)) { + if (m_testing_criteria.evaluate(person, t)) { + double random = UniformDistribution::get_instance()(rng); + if (random < m_probability) { return !person.get_tested(rng, t, m_test_type.get_default()); } } @@ -116,23 +115,45 @@ bool TestingScheme::run_scheme(Person::RandomNumberGenerator& rng, Person& perso TestingStrategy::TestingStrategy( const std::unordered_map>& location_to_schemes_map) - : m_location_to_schemes_map(location_to_schemes_map) + : m_location_to_schemes_map(location_to_schemes_map.begin(), location_to_schemes_map.end()) { } void TestingStrategy::add_testing_scheme(const LocationId& loc_id, const TestingScheme& scheme) { - auto& schemes_vector = m_location_to_schemes_map[loc_id]; - if (std::find(schemes_vector.begin(), schemes_vector.end(), scheme) == schemes_vector.end()) { - schemes_vector.emplace_back(scheme); + auto iter_schemes = + std::find_if(m_location_to_schemes_map.begin(), m_location_to_schemes_map.end(), [loc_id](auto& p) { + return p.first == loc_id; + }); + if (iter_schemes == m_location_to_schemes_map.end()) { + //no schemes for this location yet, add a new list with one scheme + m_location_to_schemes_map.emplace_back(loc_id, std::vector(1, scheme)); + } + else { + //add scheme to existing vector if the scheme doesn't exist yet + auto& schemes = iter_schemes->second; + if (std::find(schemes.begin(), schemes.end(), scheme) == schemes.end()) { + schemes.push_back(scheme); + } } } void TestingStrategy::remove_testing_scheme(const LocationId& loc_id, const TestingScheme& scheme) { - auto& schemes_vector = m_location_to_schemes_map[loc_id]; - auto last = std::remove(schemes_vector.begin(), schemes_vector.end(), scheme); - schemes_vector.erase(last, schemes_vector.end()); + auto iter_schemes = + std::find_if(m_location_to_schemes_map.begin(), m_location_to_schemes_map.end(), [loc_id](auto& p) { + return p.first == loc_id; + }); + if (iter_schemes != m_location_to_schemes_map.end()) { + //remove the scheme from the list + auto& schemes_vector = iter_schemes->second; + auto last = std::remove(schemes_vector.begin(), schemes_vector.end(), scheme); + schemes_vector.erase(last, schemes_vector.end()); + //delete the list of schemes for this location if no schemes left + if (schemes_vector.empty()) { + m_location_to_schemes_map.erase(iter_schemes); + } + } } void TestingStrategy::update_activity_status(TimePoint t) @@ -152,16 +173,22 @@ bool TestingStrategy::run_strategy(Person::RandomNumberGenerator& rng, Person& p return true; } - // Combine two vectors of schemes at corresponding location and location stype - std::vector* schemes_vector[] = { - &m_location_to_schemes_map[LocationId{location.get_index(), location.get_type()}], - &m_location_to_schemes_map[LocationId{INVALID_LOCATION_INDEX, location.get_type()}]}; - - for (auto vec_ptr : schemes_vector) { - if (!std::all_of(vec_ptr->begin(), vec_ptr->end(), [&rng, &person, t](TestingScheme& ts) { - return !ts.is_active() || ts.run_scheme(rng, person, t); - })) { - return false; + //lookup schemes for this specific location as well as the location type + //lookup in std::vector instead of std::map should be much faster unless for large numbers of schemes + for (auto loc_key : {LocationId{location.get_index(), location.get_type()}, + LocationId{INVALID_LOCATION_INDEX, location.get_type()}}) { + auto iter_schemes = + std::find_if(m_location_to_schemes_map.begin(), m_location_to_schemes_map.end(), [loc_key](auto& p) { + return p.first == loc_key; + }); + if (iter_schemes != m_location_to_schemes_map.end()) { + //apply all testing schemes that are found + auto& schemes = iter_schemes->second; + if (!std::all_of(schemes.begin(), schemes.end(), [&rng, &person, t](TestingScheme& ts) { + return !ts.is_active() || ts.run_scheme(rng, person, t); + })) { + return false; + } } } return true; diff --git a/cpp/models/abm/testing_strategy.h b/cpp/models/abm/testing_strategy.h index 0591e3c931..e947ec30c7 100644 --- a/cpp/models/abm/testing_strategy.h +++ b/cpp/models/abm/testing_strategy.h @@ -20,6 +20,7 @@ #ifndef EPI_ABM_TESTING_SCHEME_H #define EPI_ABM_TESTING_SCHEME_H +#include "abm/config.h" #include "abm/parameters.h" #include "abm/person.h" #include "abm/location.h" @@ -91,7 +92,7 @@ class TestingCriteria bool evaluate(const Person& p, TimePoint t) const; private: - std::unordered_set m_ages; ///< Set of #AgeGroup%s that are either allowed or required to be tested. + std::bitset m_ages; ///< Set of #AgeGroup%s that are either allowed or required to be tested. std::bitset<(size_t)InfectionState::Count> m_infection_states; /**< BitSet of #InfectionState%s that are either allowed or required to be tested.*/ @@ -221,7 +222,7 @@ class TestingStrategy bool run_strategy(Person::RandomNumberGenerator& rng, Person& person, const Location& location, TimePoint t); private: - std::unordered_map> + std::vector>> m_location_to_schemes_map; ///< Set of schemes that are checked for testing. }; diff --git a/cpp/models/abm/world.h b/cpp/models/abm/world.h index 52f41bad7f..8aa65ac5de 100644 --- a/cpp/models/abm/world.h +++ b/cpp/models/abm/world.h @@ -20,6 +20,7 @@ #ifndef EPI_ABM_WORLD_H #define EPI_ABM_WORLD_H +#include "abm/config.h" #include "abm/location_type.h" #include "abm/parameters.h" #include "abm/location.h" @@ -55,7 +56,7 @@ class World /** * @brief Create a World. - * @param[in] num_agegroups The number of AgeGroup%s in the simulated World. + * @param[in] num_agegroups The number of AgeGroup%s in the simulated World. Must be less than MAX_NUM_AGE_GROUPS. */ World(size_t num_agegroups) : parameters(num_agegroups) @@ -63,6 +64,7 @@ class World , m_use_migration_rules(true) , m_cemetery_id(add_location(LocationType::Cemetery)) { + assert(num_agegroups < MAX_NUM_AGE_GROUPS && "MAX_NUM_AGE_GROUPS exceeded."); } /** diff --git a/cpp/tests/test_abm_lockdown_rules.cpp b/cpp/tests/test_abm_lockdown_rules.cpp index 3f39dadc00..2124f5ac6d 100644 --- a/cpp/tests/test_abm_lockdown_rules.cpp +++ b/cpp/tests/test_abm_lockdown_rules.cpp @@ -53,9 +53,12 @@ TEST(TestLockdownRules, school_closure) p2.set_assigned_location(school); mio::abm::Parameters params = mio::abm::Parameters(num_age_groups); // Set the age group the can go to school is AgeGroup(1) (i.e. 5-14) - params.get() = {age_group_5_to_14}; + params.get() = false; + params.get()[age_group_5_to_14] = true; // Set the age group the can go to work is AgeGroup(2) and AgeGroup(3) (i.e. 15-34 or 35-59) - params.get() = {age_group_15_to_34, age_group_35_to_59}; + params.get() = false; + params.get()[age_group_15_to_34] = true; + params.get()[age_group_35_to_59] = true; mio::abm::set_school_closure(t, 0.7, params); auto p1_rng = mio::abm::Person::RandomNumberGenerator(rng, p1); @@ -88,9 +91,12 @@ TEST(TestLockdownRules, school_opening) p.set_assigned_location(school); mio::abm::Parameters params = mio::abm::Parameters(num_age_groups); // Set the age group the can go to school is AgeGroup(1) (i.e. 5-14) - params.get() = {age_group_5_to_14}; + params.get() = false; + params.get()[age_group_5_to_14] = true; // Set the age group the can go to work is AgeGroup(2) and AgeGroup(3) (i.e. 15-34 or 35-59) - params.get() = {age_group_15_to_34, age_group_35_to_59}; + params.get() = false; + params.get()[age_group_15_to_34] = true; + params.get()[age_group_35_to_59] = true; mio::abm::set_school_closure(t_closing, 1., params); mio::abm::set_school_closure(t_opening, 0., params); @@ -110,9 +116,12 @@ TEST(TestLockdownRules, home_office) mio::abm::Parameters params(num_age_groups); // Set the age group the can go to school is AgeGroup(1) (i.e. 5-14) - params.get() = {age_group_5_to_14}; + params.get() = false; + params.get()[age_group_5_to_14] = true; // Set the age group the can go to work is AgeGroup(2) and AgeGroup(3) (i.e. 15-34 or 35-59) - params.get() = {age_group_15_to_34, age_group_35_to_59}; + params.get() = false; + params.get()[age_group_15_to_34] = true; + params.get()[age_group_35_to_59] = true; mio::abm::set_home_office(t, 0.4, params); @@ -164,9 +173,12 @@ TEST(TestLockdownRules, no_home_office) p.set_assigned_location(work); mio::abm::Parameters params = mio::abm::Parameters(num_age_groups); // Set the age group the can go to school is AgeGroup(1) (i.e. 5-14) - params.get() = {age_group_5_to_14}; + params.get() = false; + params.get()[age_group_5_to_14] = true; // Set the age group the can go to work is AgeGroup(2) and AgeGroup(3) (i.e. 15-34 or 35-59) - params.get() = {age_group_15_to_34, age_group_35_to_59}; + params.get() = false; + params.get()[age_group_15_to_34] = true; + params.get()[age_group_35_to_59] = true; mio::abm::set_home_office(t_closing, 0.5, params); mio::abm::set_home_office(t_opening, 0., params); diff --git a/cpp/tests/test_abm_migration_rules.cpp b/cpp/tests/test_abm_migration_rules.cpp index 5d3a50d38e..f0320d3029 100644 --- a/cpp/tests/test_abm_migration_rules.cpp +++ b/cpp/tests/test_abm_migration_rules.cpp @@ -44,9 +44,12 @@ TEST(TestMigrationRules, student_goes_to_school) auto child_rng = mio::abm::Person::RandomNumberGenerator(rng, p_child); auto adult_rng = mio::abm::Person::RandomNumberGenerator(rng, p_child); // Set the age group the can go to school is AgeGroup(1) (i.e. 5-14) - params.get() = {age_group_5_to_14}; + params.get() = false; + params.get()[age_group_5_to_14] = true; // Set the age group the can go to work is AgeGroup(2) and AgeGroup(3) (i.e. 15-34 or 35-59) - params.get() = {age_group_15_to_34, age_group_35_to_59}; + params.get() = false; + params.get()[age_group_15_to_34] = true; + params.get()[age_group_35_to_59] = true; ASSERT_EQ(mio::abm::go_to_school(child_rng, p_child, t_morning, dt, params), mio::abm::LocationType::School); ASSERT_EQ(mio::abm::go_to_school(adult_rng, p_adult, t_morning, dt, params), mio::abm::LocationType::Home); @@ -85,9 +88,12 @@ TEST(TestMigrationRules, students_go_to_school_in_different_times) mio::abm::Parameters params = mio::abm::Parameters(num_age_groups); // Set the age group the can go to school is AgeGroup(1) (i.e. 5-14) - params.get() = {age_group_5_to_14}; + params.get() = false; + params.get()[age_group_5_to_14] = true; // Set the age group the can go to work is AgeGroup(2) and AgeGroup(3) (i.e. 15-34 or 35-59) - params.get() = {age_group_15_to_34, age_group_35_to_59}; + params.get() = false; + params.get()[age_group_15_to_34] = true; + params.get()[age_group_35_to_59] = true; ASSERT_EQ( mio::abm::go_to_school(rng_child_goes_to_school_at_6, p_child_goes_to_school_at_6, t_morning_6, dt, params), @@ -142,9 +148,12 @@ TEST(TestMigrationRules, students_go_to_school_in_different_times_with_smaller_t auto dt = mio::abm::seconds(1800); mio::abm::Parameters params = mio::abm::Parameters(num_age_groups); // Set the age group the can go to school is AgeGroup(1) (i.e. 5-14) - params.get() = {age_group_5_to_14}; + params.get() = false; + params.get()[age_group_5_to_14] = true; // Set the age group the can go to work is AgeGroup(2) and AgeGroup(3) (i.e. 15-34 or 35-59) - params.get() = {age_group_15_to_34, age_group_35_to_59}; + params.get() = false; + params.get()[age_group_15_to_34] = true; + params.get()[age_group_35_to_59] = true; ASSERT_EQ( mio::abm::go_to_school(rng_child_goes_to_school_at_6, p_child_goes_to_school_at_6, t_morning_6, dt, params), @@ -203,9 +212,12 @@ TEST(TestMigrationRules, worker_goes_to_work) mio::abm::Parameters params = mio::abm::Parameters(num_age_groups); // Set the age group the can go to school is AgeGroup(1) (i.e. 5-14) - params.get() = {age_group_5_to_14}; + params.get() = false; + params.get()[age_group_5_to_14] = true; // Set the age group the can go to work is AgeGroup(2) and AgeGroup(3) (i.e. 15-34 or 35-59) - params.get() = {age_group_15_to_34, age_group_35_to_59}; + params.get() = false; + params.get()[age_group_15_to_34] = true; + params.get()[age_group_35_to_59] = true; ASSERT_EQ(mio::abm::go_to_work(rng_retiree, p_retiree, t_morning, dt, params), mio::abm::LocationType::Home); ASSERT_EQ(mio::abm::go_to_work(rng_adult, p_adult, t_morning, dt, params), mio::abm::LocationType::Home); @@ -240,9 +252,12 @@ TEST(TestMigrationRules, worker_goes_to_work_with_non_dividable_timespan) mio::abm::Parameters params = mio::abm::Parameters(num_age_groups); // Set the age group the can go to school is AgeGroup(1) (i.e. 5-14) - params.get() = {age_group_5_to_14}; + params.get() = false; + params.get()[age_group_5_to_14] = true; // Set the age group the can go to work is AgeGroup(2) and AgeGroup(3) (i.e. 15-34 or 35-59) - params.get() = {age_group_15_to_34, age_group_35_to_59}; + params.get() = false; + params.get()[age_group_15_to_34] = true; + params.get()[age_group_35_to_59] = true; ASSERT_EQ(mio::abm::go_to_work(rng_retiree, p_retiree, t_morning, dt, params), mio::abm::LocationType::Home); ASSERT_EQ(mio::abm::go_to_work(rng_adult, p_adult, t_morning, dt, params), mio::abm::LocationType::Home); @@ -278,9 +293,12 @@ TEST(TestMigrationRules, workers_go_to_work_in_different_times) auto dt = mio::abm::hours(1); mio::abm::Parameters params = mio::abm::Parameters(num_age_groups); // Set the age group the can go to school is AgeGroup(1) (i.e. 5-14) - params.get() = {age_group_5_to_14}; + params.get() = false; + params.get()[age_group_5_to_14] = true; // Set the age group the can go to work is AgeGroup(2) and AgeGroup(3) (i.e. 15-34 or 35-59) - params.get() = {age_group_15_to_34, age_group_35_to_59}; + params.get() = false; + params.get()[age_group_15_to_34] = true; + params.get()[age_group_35_to_59] = true; ASSERT_EQ(mio::abm::go_to_work(rng_adult_goes_to_work_at_6, p_adult_goes_to_work_at_6, t_morning_6, dt, params), mio::abm::LocationType::Work); diff --git a/cpp/tests/test_abm_testing_strategy.cpp b/cpp/tests/test_abm_testing_strategy.cpp index b4714b24a9..523e4b0eb2 100644 --- a/cpp/tests/test_abm_testing_strategy.cpp +++ b/cpp/tests/test_abm_testing_strategy.cpp @@ -137,10 +137,9 @@ TEST(TestTestingScheme, initAndRunTestingStrategy) ScopedMockDistribution>>> mock_uniform_dist; EXPECT_CALL(mock_uniform_dist.get_mock(), invoke) - .Times(testing::Exactly(3)) + .Times(testing::Exactly(2)) //only sampled twice, testing criteria don't apply to third person .WillOnce(testing::Return(0.7)) - .WillOnce(testing::Return(0.5)) - .WillOnce(testing::Return(0.9)); + .WillOnce(testing::Return(0.5)); mio::abm::TestingStrategy test_strategy = mio::abm::TestingStrategy(std::unordered_map>());