Skip to content

Commit

Permalink
829 Optimization of TestingStrategy (#830)
Browse files Browse the repository at this point in the history
Performance optimization of TestingStrategies
- vector instead of unordered_map for testing schemes): ~25% decrease of run time
- switch of ifs in testing strategy: additional ~8 % decrease of run time
- CustomIndexArray for GoToWork/GoToSchool paramater: no measureable change
- bitset for agegroups in TestingCriteria with fixed number of age groups: ~7% decrease of run time

Co-authored-by: DavidKerkmann <44698825+DavidKerkmann@users.noreply.github.com>
  • Loading branch information
dabele and DavidKerkmann authored Dec 20, 2023
1 parent 8ed8364 commit 08c1753
Show file tree
Hide file tree
Showing 10 changed files with 165 additions and 62 deletions.
9 changes: 6 additions & 3 deletions cpp/examples/abm_minimal.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,12 @@ int main()
world.parameters.get<mio::abm::IncubationPeriod>() = 4.;

// Set the age group the can go to school is AgeGroup(1) (i.e. 5-14)
world.parameters.get<mio::abm::AgeGroupGotoSchool>() = {age_group_5_to_14};
world.parameters.get<mio::abm::AgeGroupGotoSchool>() = false;
world.parameters.get<mio::abm::AgeGroupGotoSchool>()[age_group_5_to_14] = true;
// Set the age group the can go to work is AgeGroup(2) and AgeGroup(3) (i.e. 15-34 and 35-59)
world.parameters.get<mio::abm::AgeGroupGotoWork>() = {age_group_15_to_34, age_group_35_to_59};
world.parameters.get<mio::abm::AgeGroupGotoWork>() = false;
world.parameters.get<mio::abm::AgeGroupGotoWork>()[age_group_15_to_34] = true;
world.parameters.get<mio::abm::AgeGroupGotoWork>()[age_group_35_to_59] = true;

// Check if the parameters satisfy their contraints.
world.parameters.check_constraints();
Expand Down Expand Up @@ -169,4 +172,4 @@ int main()
std::cout << "Results written to abm_minimal.txt" << std::endl;

return 0;
}
}
36 changes: 36 additions & 0 deletions cpp/models/abm/config.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
/*
* Copyright (C) 2020-2024 MEmilio
*
* Authors: Daniel Abele
*
* Contact: Martin J. Kuehn <Martin.Kuehn@DLR.de>
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MIO_ABM_CONFIG_H
#define MIO_ABM_CONFIG_H

namespace mio
{
namespace abm
{

/**
* Maximum number of age groups allowed in the model.
*/
const constexpr int MAX_NUM_AGE_GROUPS = 64;

}
} // namespace mio

#endif
4 changes: 2 additions & 2 deletions cpp/models/abm/migration_rules.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ LocationType go_to_school(Person::RandomNumberGenerator& /*rng*/, const Person&
if (current_loc == LocationType::Home && t < params.get<LockdownDate>() && t.day_of_week() < 5 &&
person.get_go_to_school_time(params) >= t.time_since_midnight() &&
person.get_go_to_school_time(params) < t.time_since_midnight() + dt &&
params.get<mio::abm::AgeGroupGotoSchool>().count(person.get_age()) && person.goes_to_school(t, params) &&
params.get<mio::abm::AgeGroupGotoSchool>()[person.get_age()] && person.goes_to_school(t, params) &&
!person.is_in_quarantine()) {
return LocationType::School;
}
Expand All @@ -73,7 +73,7 @@ LocationType go_to_work(Person::RandomNumberGenerator& /*rng*/, const Person& pe
auto current_loc = person.get_location().get_type();

if (current_loc == LocationType::Home && t < params.get<LockdownDate>() &&
params.get<mio::abm::AgeGroupGotoWork>().count(person.get_age()) && t.day_of_week() < 5 &&
params.get<mio::abm::AgeGroupGotoWork>()[person.get_age()] && t.day_of_week() < 5 &&
t.time_since_midnight() + dt > person.get_go_to_work_time(params) &&
t.time_since_midnight() <= person.get_go_to_work_time(params) && person.goes_to_work(t, params) &&
!person.is_in_quarantine()) {
Expand Down
17 changes: 11 additions & 6 deletions cpp/models/abm/parameters.h
Original file line number Diff line number Diff line change
Expand Up @@ -497,10 +497,12 @@ struct GotoSchoolTimeMaximum {
* @brief The set of AgeGroups that can go to school.
*/
struct AgeGroupGotoSchool {
using Type = std::set<AgeGroup>;
static Type get_default(AgeGroup /*size*/)
using Type = CustomIndexArray<bool, AgeGroup>;
static Type get_default(AgeGroup num_agegroups)
{
return std::set<AgeGroup>{AgeGroup(1)};
auto a = Type(num_agegroups, false);
a[AgeGroup(1)] = true;
return a;
}
static std::string name()
{
Expand All @@ -512,10 +514,13 @@ struct AgeGroupGotoSchool {
* @brief The set of AgeGroups that can go to work.
*/
struct AgeGroupGotoWork {
using Type = std::set<AgeGroup>;
static Type get_default(AgeGroup /*size*/)
using Type = CustomIndexArray<bool, AgeGroup>;
static Type get_default(AgeGroup num_agegroups)
{
return std::set<AgeGroup>{AgeGroup(2), AgeGroup(3)};
auto a = Type(num_agegroups, false);
a[AgeGroup(2)] = true;
a[AgeGroup(3)] = true;
return a;
}
static std::string name()
{
Expand Down
77 changes: 52 additions & 25 deletions cpp/models/abm/testing_strategy.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ namespace abm
TestingCriteria::TestingCriteria(const std::vector<AgeGroup>& ages, const std::vector<InfectionState>& infection_states)
{
for (auto age : ages) {
m_ages.insert(static_cast<size_t>(age));
m_ages.set(static_cast<size_t>(age), true);
}
for (auto infection_state : infection_states) {
m_infection_states.set(static_cast<size_t>(infection_state), true);
Expand All @@ -43,13 +43,12 @@ bool TestingCriteria::operator==(const TestingCriteria& other) const

void TestingCriteria::add_age_group(const AgeGroup age_group)
{

m_ages.insert(static_cast<size_t>(age_group));
m_ages.set(static_cast<size_t>(age_group), true);
}

void TestingCriteria::remove_age_group(const AgeGroup age_group)
{
m_ages.erase(static_cast<size_t>(age_group));
m_ages.set(static_cast<size_t>(age_group), false);
}

void TestingCriteria::add_infection_state(const InfectionState infection_state)
Expand All @@ -65,7 +64,7 @@ void TestingCriteria::remove_infection_state(const InfectionState infection_stat
bool TestingCriteria::evaluate(const Person& p, TimePoint t) const
{
// An empty vector of ages or none bitset of #InfectionStates% means that no condition on the corresponding property is set.
return (m_ages.empty() || m_ages.count(static_cast<size_t>(p.get_age()))) &&
return (m_ages.none() || m_ages[static_cast<size_t>(p.get_age())]) &&
(m_infection_states.none() || m_infection_states[static_cast<size_t>(p.get_infection_state(t))]);
}

Expand Down Expand Up @@ -104,9 +103,9 @@ void TestingScheme::update_activity_status(TimePoint t)
bool TestingScheme::run_scheme(Person::RandomNumberGenerator& rng, Person& person, TimePoint t) const
{
if (person.get_time_since_negative_test() > m_minimal_time_since_last_test) {
double random = UniformDistribution<double>::get_instance()(rng);
if (random < m_probability) {
if (m_testing_criteria.evaluate(person, t)) {
if (m_testing_criteria.evaluate(person, t)) {
double random = UniformDistribution<double>::get_instance()(rng);
if (random < m_probability) {
return !person.get_tested(rng, t, m_test_type.get_default());
}
}
Expand All @@ -116,23 +115,45 @@ bool TestingScheme::run_scheme(Person::RandomNumberGenerator& rng, Person& perso

TestingStrategy::TestingStrategy(
const std::unordered_map<LocationId, std::vector<TestingScheme>>& location_to_schemes_map)
: m_location_to_schemes_map(location_to_schemes_map)
: m_location_to_schemes_map(location_to_schemes_map.begin(), location_to_schemes_map.end())
{
}

void TestingStrategy::add_testing_scheme(const LocationId& loc_id, const TestingScheme& scheme)
{
auto& schemes_vector = m_location_to_schemes_map[loc_id];
if (std::find(schemes_vector.begin(), schemes_vector.end(), scheme) == schemes_vector.end()) {
schemes_vector.emplace_back(scheme);
auto iter_schemes =
std::find_if(m_location_to_schemes_map.begin(), m_location_to_schemes_map.end(), [loc_id](auto& p) {
return p.first == loc_id;
});
if (iter_schemes == m_location_to_schemes_map.end()) {
//no schemes for this location yet, add a new list with one scheme
m_location_to_schemes_map.emplace_back(loc_id, std::vector<TestingScheme>(1, scheme));
}
else {
//add scheme to existing vector if the scheme doesn't exist yet
auto& schemes = iter_schemes->second;
if (std::find(schemes.begin(), schemes.end(), scheme) == schemes.end()) {
schemes.push_back(scheme);
}
}
}

void TestingStrategy::remove_testing_scheme(const LocationId& loc_id, const TestingScheme& scheme)
{
auto& schemes_vector = m_location_to_schemes_map[loc_id];
auto last = std::remove(schemes_vector.begin(), schemes_vector.end(), scheme);
schemes_vector.erase(last, schemes_vector.end());
auto iter_schemes =
std::find_if(m_location_to_schemes_map.begin(), m_location_to_schemes_map.end(), [loc_id](auto& p) {
return p.first == loc_id;
});
if (iter_schemes != m_location_to_schemes_map.end()) {
//remove the scheme from the list
auto& schemes_vector = iter_schemes->second;
auto last = std::remove(schemes_vector.begin(), schemes_vector.end(), scheme);
schemes_vector.erase(last, schemes_vector.end());
//delete the list of schemes for this location if no schemes left
if (schemes_vector.empty()) {
m_location_to_schemes_map.erase(iter_schemes);
}
}
}

void TestingStrategy::update_activity_status(TimePoint t)
Expand All @@ -152,16 +173,22 @@ bool TestingStrategy::run_strategy(Person::RandomNumberGenerator& rng, Person& p
return true;
}

// Combine two vectors of schemes at corresponding location and location stype
std::vector<TestingScheme>* schemes_vector[] = {
&m_location_to_schemes_map[LocationId{location.get_index(), location.get_type()}],
&m_location_to_schemes_map[LocationId{INVALID_LOCATION_INDEX, location.get_type()}]};

for (auto vec_ptr : schemes_vector) {
if (!std::all_of(vec_ptr->begin(), vec_ptr->end(), [&rng, &person, t](TestingScheme& ts) {
return !ts.is_active() || ts.run_scheme(rng, person, t);
})) {
return false;
//lookup schemes for this specific location as well as the location type
//lookup in std::vector instead of std::map should be much faster unless for large numbers of schemes
for (auto loc_key : {LocationId{location.get_index(), location.get_type()},
LocationId{INVALID_LOCATION_INDEX, location.get_type()}}) {
auto iter_schemes =
std::find_if(m_location_to_schemes_map.begin(), m_location_to_schemes_map.end(), [loc_key](auto& p) {
return p.first == loc_key;
});
if (iter_schemes != m_location_to_schemes_map.end()) {
//apply all testing schemes that are found
auto& schemes = iter_schemes->second;
if (!std::all_of(schemes.begin(), schemes.end(), [&rng, &person, t](TestingScheme& ts) {
return !ts.is_active() || ts.run_scheme(rng, person, t);
})) {
return false;
}
}
}
return true;
Expand Down
5 changes: 3 additions & 2 deletions cpp/models/abm/testing_strategy.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#ifndef EPI_ABM_TESTING_SCHEME_H
#define EPI_ABM_TESTING_SCHEME_H

#include "abm/config.h"
#include "abm/parameters.h"
#include "abm/person.h"
#include "abm/location.h"
Expand Down Expand Up @@ -91,7 +92,7 @@ class TestingCriteria
bool evaluate(const Person& p, TimePoint t) const;

private:
std::unordered_set<size_t> m_ages; ///< Set of #AgeGroup%s that are either allowed or required to be tested.
std::bitset<MAX_NUM_AGE_GROUPS> m_ages; ///< Set of #AgeGroup%s that are either allowed or required to be tested.
std::bitset<(size_t)InfectionState::Count>
m_infection_states; /**< BitSet of #InfectionState%s that are either allowed or required to
be tested.*/
Expand Down Expand Up @@ -221,7 +222,7 @@ class TestingStrategy
bool run_strategy(Person::RandomNumberGenerator& rng, Person& person, const Location& location, TimePoint t);

private:
std::unordered_map<LocationId, std::vector<TestingScheme>>
std::vector<std::pair<LocationId, std::vector<TestingScheme>>>
m_location_to_schemes_map; ///< Set of schemes that are checked for testing.
};

Expand Down
4 changes: 3 additions & 1 deletion cpp/models/abm/world.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#ifndef EPI_ABM_WORLD_H
#define EPI_ABM_WORLD_H

#include "abm/config.h"
#include "abm/location_type.h"
#include "abm/parameters.h"
#include "abm/location.h"
Expand Down Expand Up @@ -55,14 +56,15 @@ class World

/**
* @brief Create a World.
* @param[in] num_agegroups The number of AgeGroup%s in the simulated World.
* @param[in] num_agegroups The number of AgeGroup%s in the simulated World. Must be less than MAX_NUM_AGE_GROUPS.
*/
World(size_t num_agegroups)
: parameters(num_agegroups)
, m_trip_list()
, m_use_migration_rules(true)
, m_cemetery_id(add_location(LocationType::Cemetery))
{
assert(num_agegroups < MAX_NUM_AGE_GROUPS && "MAX_NUM_AGE_GROUPS exceeded.");
}

/**
Expand Down
28 changes: 20 additions & 8 deletions cpp/tests/test_abm_lockdown_rules.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,9 +53,12 @@ TEST(TestLockdownRules, school_closure)
p2.set_assigned_location(school);
mio::abm::Parameters params = mio::abm::Parameters(num_age_groups);
// Set the age group the can go to school is AgeGroup(1) (i.e. 5-14)
params.get<mio::abm::AgeGroupGotoSchool>() = {age_group_5_to_14};
params.get<mio::abm::AgeGroupGotoSchool>() = false;
params.get<mio::abm::AgeGroupGotoSchool>()[age_group_5_to_14] = true;
// Set the age group the can go to work is AgeGroup(2) and AgeGroup(3) (i.e. 15-34 or 35-59)
params.get<mio::abm::AgeGroupGotoWork>() = {age_group_15_to_34, age_group_35_to_59};
params.get<mio::abm::AgeGroupGotoWork>() = false;
params.get<mio::abm::AgeGroupGotoWork>()[age_group_15_to_34] = true;
params.get<mio::abm::AgeGroupGotoWork>()[age_group_35_to_59] = true;
mio::abm::set_school_closure(t, 0.7, params);

auto p1_rng = mio::abm::Person::RandomNumberGenerator(rng, p1);
Expand Down Expand Up @@ -88,9 +91,12 @@ TEST(TestLockdownRules, school_opening)
p.set_assigned_location(school);
mio::abm::Parameters params = mio::abm::Parameters(num_age_groups);
// Set the age group the can go to school is AgeGroup(1) (i.e. 5-14)
params.get<mio::abm::AgeGroupGotoSchool>() = {age_group_5_to_14};
params.get<mio::abm::AgeGroupGotoSchool>() = false;
params.get<mio::abm::AgeGroupGotoSchool>()[age_group_5_to_14] = true;
// Set the age group the can go to work is AgeGroup(2) and AgeGroup(3) (i.e. 15-34 or 35-59)
params.get<mio::abm::AgeGroupGotoWork>() = {age_group_15_to_34, age_group_35_to_59};
params.get<mio::abm::AgeGroupGotoWork>() = false;
params.get<mio::abm::AgeGroupGotoWork>()[age_group_15_to_34] = true;
params.get<mio::abm::AgeGroupGotoWork>()[age_group_35_to_59] = true;
mio::abm::set_school_closure(t_closing, 1., params);
mio::abm::set_school_closure(t_opening, 0., params);

Expand All @@ -110,9 +116,12 @@ TEST(TestLockdownRules, home_office)
mio::abm::Parameters params(num_age_groups);

// Set the age group the can go to school is AgeGroup(1) (i.e. 5-14)
params.get<mio::abm::AgeGroupGotoSchool>() = {age_group_5_to_14};
params.get<mio::abm::AgeGroupGotoSchool>() = false;
params.get<mio::abm::AgeGroupGotoSchool>()[age_group_5_to_14] = true;
// Set the age group the can go to work is AgeGroup(2) and AgeGroup(3) (i.e. 15-34 or 35-59)
params.get<mio::abm::AgeGroupGotoWork>() = {age_group_15_to_34, age_group_35_to_59};
params.get<mio::abm::AgeGroupGotoWork>() = false;
params.get<mio::abm::AgeGroupGotoWork>()[age_group_15_to_34] = true;
params.get<mio::abm::AgeGroupGotoWork>()[age_group_35_to_59] = true;

mio::abm::set_home_office(t, 0.4, params);

Expand Down Expand Up @@ -164,9 +173,12 @@ TEST(TestLockdownRules, no_home_office)
p.set_assigned_location(work);
mio::abm::Parameters params = mio::abm::Parameters(num_age_groups);
// Set the age group the can go to school is AgeGroup(1) (i.e. 5-14)
params.get<mio::abm::AgeGroupGotoSchool>() = {age_group_5_to_14};
params.get<mio::abm::AgeGroupGotoSchool>() = false;
params.get<mio::abm::AgeGroupGotoSchool>()[age_group_5_to_14] = true;
// Set the age group the can go to work is AgeGroup(2) and AgeGroup(3) (i.e. 15-34 or 35-59)
params.get<mio::abm::AgeGroupGotoWork>() = {age_group_15_to_34, age_group_35_to_59};
params.get<mio::abm::AgeGroupGotoWork>() = false;
params.get<mio::abm::AgeGroupGotoWork>()[age_group_15_to_34] = true;
params.get<mio::abm::AgeGroupGotoWork>()[age_group_35_to_59] = true;

mio::abm::set_home_office(t_closing, 0.5, params);
mio::abm::set_home_office(t_opening, 0., params);
Expand Down
Loading

0 comments on commit 08c1753

Please sign in to comment.