Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

829 optimize testingstrategy #830

Merged
merged 6 commits into from
Dec 20, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions cpp/examples/abm_minimal.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,12 @@ int main()
world.parameters.get<mio::abm::IncubationPeriod>() = 4.;

// Set the age group the can go to school is AgeGroup(1) (i.e. 5-14)
world.parameters.get<mio::abm::AgeGroupGotoSchool>() = {age_group_5_to_14};
world.parameters.get<mio::abm::AgeGroupGotoSchool>() = false;
world.parameters.get<mio::abm::AgeGroupGotoSchool>()[age_group_5_to_14] = true;
// Set the age group the can go to work is AgeGroup(2) and AgeGroup(3) (i.e. 15-34 and 35-59)
world.parameters.get<mio::abm::AgeGroupGotoWork>() = {age_group_15_to_34, age_group_35_to_59};
world.parameters.get<mio::abm::AgeGroupGotoWork>() = false;
world.parameters.get<mio::abm::AgeGroupGotoWork>()[age_group_15_to_34] = true;
world.parameters.get<mio::abm::AgeGroupGotoWork>()[age_group_35_to_59] = true;

// Check if the parameters satisfy their contraints.
world.parameters.check_constraints();
Expand Down Expand Up @@ -169,4 +172,4 @@ int main()
std::cout << "Results written to abm_minimal.txt" << std::endl;

return 0;
}
}
36 changes: 36 additions & 0 deletions cpp/models/abm/config.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
/*
* Copyright (C) 2020-2024 MEmilio
*
* Authors: Daniel Abele
*
* Contact: Martin J. Kuehn <Martin.Kuehn@DLR.de>
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MIO_ABM_CONFIG_H
#define MIO_ABM_CONFIG_H

namespace mio
{
namespace abm
{

/**
* Maximum number of age groups allowed in the model.
*/
const constexpr int MAX_NUM_AGE_GROUPS = 64;

}
} // namespace mio

#endif
4 changes: 2 additions & 2 deletions cpp/models/abm/migration_rules.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ LocationType go_to_school(Person::RandomNumberGenerator& /*rng*/, const Person&
if (current_loc == LocationType::Home && t < params.get<LockdownDate>() && t.day_of_week() < 5 &&
person.get_go_to_school_time(params) >= t.time_since_midnight() &&
person.get_go_to_school_time(params) < t.time_since_midnight() + dt &&
params.get<mio::abm::AgeGroupGotoSchool>().count(person.get_age()) && person.goes_to_school(t, params) &&
params.get<mio::abm::AgeGroupGotoSchool>()[person.get_age()] && person.goes_to_school(t, params) &&
!person.is_in_quarantine()) {
return LocationType::School;
}
Expand All @@ -73,7 +73,7 @@ LocationType go_to_work(Person::RandomNumberGenerator& /*rng*/, const Person& pe
auto current_loc = person.get_location().get_type();

if (current_loc == LocationType::Home && t < params.get<LockdownDate>() &&
params.get<mio::abm::AgeGroupGotoWork>().count(person.get_age()) && t.day_of_week() < 5 &&
params.get<mio::abm::AgeGroupGotoWork>()[person.get_age()] && t.day_of_week() < 5 &&
t.time_since_midnight() + dt > person.get_go_to_work_time(params) &&
t.time_since_midnight() <= person.get_go_to_work_time(params) && person.goes_to_work(t, params) &&
!person.is_in_quarantine()) {
Expand Down
17 changes: 11 additions & 6 deletions cpp/models/abm/parameters.h
Original file line number Diff line number Diff line change
Expand Up @@ -497,10 +497,12 @@ struct GotoSchoolTimeMaximum {
* @brief The set of AgeGroups that can go to school.
*/
struct AgeGroupGotoSchool {
using Type = std::set<AgeGroup>;
static Type get_default(AgeGroup /*size*/)
using Type = CustomIndexArray<bool, AgeGroup>;
static Type get_default(AgeGroup num_agegroups)
{
return std::set<AgeGroup>{AgeGroup(1)};
auto a = Type(num_agegroups, false);
a[AgeGroup(1)] = true;
return a;
}
static std::string name()
{
Expand All @@ -512,10 +514,13 @@ struct AgeGroupGotoSchool {
* @brief The set of AgeGroups that can go to work.
*/
struct AgeGroupGotoWork {
using Type = std::set<AgeGroup>;
static Type get_default(AgeGroup /*size*/)
using Type = CustomIndexArray<bool, AgeGroup>;
static Type get_default(AgeGroup num_agegroups)
{
return std::set<AgeGroup>{AgeGroup(2), AgeGroup(3)};
auto a = Type(num_agegroups, false);
a[AgeGroup(2)] = true;
a[AgeGroup(3)] = true;
DavidKerkmann marked this conversation as resolved.
Show resolved Hide resolved
return a;
}
static std::string name()
{
Expand Down
77 changes: 52 additions & 25 deletions cpp/models/abm/testing_strategy.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ namespace abm
TestingCriteria::TestingCriteria(const std::vector<AgeGroup>& ages, const std::vector<InfectionState>& infection_states)
{
for (auto age : ages) {
m_ages.insert(static_cast<size_t>(age));
m_ages.set(static_cast<size_t>(age), true);
}
for (auto infection_state : infection_states) {
m_infection_states.set(static_cast<size_t>(infection_state), true);
Expand All @@ -43,13 +43,12 @@ bool TestingCriteria::operator==(const TestingCriteria& other) const

void TestingCriteria::add_age_group(const AgeGroup age_group)
{

m_ages.insert(static_cast<size_t>(age_group));
m_ages.set(static_cast<size_t>(age_group), true);
}

void TestingCriteria::remove_age_group(const AgeGroup age_group)
{
m_ages.erase(static_cast<size_t>(age_group));
m_ages.set(static_cast<size_t>(age_group), false);
}

void TestingCriteria::add_infection_state(const InfectionState infection_state)
Expand All @@ -65,7 +64,7 @@ void TestingCriteria::remove_infection_state(const InfectionState infection_stat
bool TestingCriteria::evaluate(const Person& p, TimePoint t) const
{
// An empty vector of ages or none bitset of #InfectionStates% means that no condition on the corresponding property is set.
return (m_ages.empty() || m_ages.count(static_cast<size_t>(p.get_age()))) &&
return (m_ages.none() || m_ages[static_cast<size_t>(p.get_age())]) &&
(m_infection_states.none() || m_infection_states[static_cast<size_t>(p.get_infection_state(t))]);
}

Expand Down Expand Up @@ -104,9 +103,9 @@ void TestingScheme::update_activity_status(TimePoint t)
bool TestingScheme::run_scheme(Person::RandomNumberGenerator& rng, Person& person, TimePoint t) const
{
if (person.get_time_since_negative_test() > m_minimal_time_since_last_test) {
double random = UniformDistribution<double>::get_instance()(rng);
if (random < m_probability) {
if (m_testing_criteria.evaluate(person, t)) {
if (m_testing_criteria.evaluate(person, t)) {
double random = UniformDistribution<double>::get_instance()(rng);
if (random < m_probability) {
return !person.get_tested(rng, t, m_test_type.get_default());
}
}
Expand All @@ -116,23 +115,45 @@ bool TestingScheme::run_scheme(Person::RandomNumberGenerator& rng, Person& perso

TestingStrategy::TestingStrategy(
const std::unordered_map<LocationId, std::vector<TestingScheme>>& location_to_schemes_map)
: m_location_to_schemes_map(location_to_schemes_map)
: m_location_to_schemes_map(location_to_schemes_map.begin(), location_to_schemes_map.end())
{
}

void TestingStrategy::add_testing_scheme(const LocationId& loc_id, const TestingScheme& scheme)
{
auto& schemes_vector = m_location_to_schemes_map[loc_id];
if (std::find(schemes_vector.begin(), schemes_vector.end(), scheme) == schemes_vector.end()) {
schemes_vector.emplace_back(scheme);
auto iter_schemes =
std::find_if(m_location_to_schemes_map.begin(), m_location_to_schemes_map.end(), [loc_id](auto& p) {
return p.first == loc_id;
});
if (iter_schemes == m_location_to_schemes_map.end()) {
//no schemes for this location yet, add a new list with one scheme
m_location_to_schemes_map.emplace_back(loc_id, std::vector<TestingScheme>(1, scheme));
}
else {
//add scheme to existing vector if the scheme doesn't exist yet
auto& schemes = iter_schemes->second;
if (std::find(schemes.begin(), schemes.end(), scheme) == schemes.end()) {
schemes.push_back(scheme);
}
}
}

void TestingStrategy::remove_testing_scheme(const LocationId& loc_id, const TestingScheme& scheme)
{
auto& schemes_vector = m_location_to_schemes_map[loc_id];
auto last = std::remove(schemes_vector.begin(), schemes_vector.end(), scheme);
schemes_vector.erase(last, schemes_vector.end());
auto iter_schemes =
std::find_if(m_location_to_schemes_map.begin(), m_location_to_schemes_map.end(), [loc_id](auto& p) {
return p.first == loc_id;
});
if (iter_schemes != m_location_to_schemes_map.end()) {
//remove the scheme from the list
auto& schemes_vector = iter_schemes->second;
auto last = std::remove(schemes_vector.begin(), schemes_vector.end(), scheme);
schemes_vector.erase(last, schemes_vector.end());
//delete the list of schemes for this location if no schemes left
if (schemes_vector.empty()) {
m_location_to_schemes_map.erase(iter_schemes);
}
}
}

void TestingStrategy::update_activity_status(TimePoint t)
Expand All @@ -152,16 +173,22 @@ bool TestingStrategy::run_strategy(Person::RandomNumberGenerator& rng, Person& p
return true;
}

// Combine two vectors of schemes at corresponding location and location stype
std::vector<TestingScheme>* schemes_vector[] = {
&m_location_to_schemes_map[LocationId{location.get_index(), location.get_type()}],
&m_location_to_schemes_map[LocationId{INVALID_LOCATION_INDEX, location.get_type()}]};

for (auto vec_ptr : schemes_vector) {
if (!std::all_of(vec_ptr->begin(), vec_ptr->end(), [&rng, &person, t](TestingScheme& ts) {
return !ts.is_active() || ts.run_scheme(rng, person, t);
})) {
return false;
//lookup schemes for this specific location as well as the location type
//lookup in std::vector instead of std::map should be much faster unless for large numbers of schemes
for (auto loc_key : {LocationId{location.get_index(), location.get_type()},
LocationId{INVALID_LOCATION_INDEX, location.get_type()}}) {
auto iter_schemes =
std::find_if(m_location_to_schemes_map.begin(), m_location_to_schemes_map.end(), [loc_key](auto& p) {
return p.first == loc_key;
});
if (iter_schemes != m_location_to_schemes_map.end()) {
//apply all testing schemes that are found
auto& schemes = iter_schemes->second;
if (!std::all_of(schemes.begin(), schemes.end(), [&rng, &person, t](TestingScheme& ts) {
return !ts.is_active() || ts.run_scheme(rng, person, t);
})) {
return false;
}
}
}
return true;
Expand Down
5 changes: 3 additions & 2 deletions cpp/models/abm/testing_strategy.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#ifndef EPI_ABM_TESTING_SCHEME_H
#define EPI_ABM_TESTING_SCHEME_H

#include "abm/config.h"
#include "abm/parameters.h"
#include "abm/person.h"
#include "abm/location.h"
Expand Down Expand Up @@ -91,7 +92,7 @@ class TestingCriteria
bool evaluate(const Person& p, TimePoint t) const;

private:
std::unordered_set<size_t> m_ages; ///< Set of #AgeGroup%s that are either allowed or required to be tested.
std::bitset<MAX_NUM_AGE_GROUPS> m_ages; ///< Set of #AgeGroup%s that are either allowed or required to be tested.
std::bitset<(size_t)InfectionState::Count>
m_infection_states; /**< BitSet of #InfectionState%s that are either allowed or required to
be tested.*/
Expand Down Expand Up @@ -221,7 +222,7 @@ class TestingStrategy
bool run_strategy(Person::RandomNumberGenerator& rng, Person& person, const Location& location, TimePoint t);

private:
std::unordered_map<LocationId, std::vector<TestingScheme>>
std::vector<std::pair<LocationId, std::vector<TestingScheme>>>
m_location_to_schemes_map; ///< Set of schemes that are checked for testing.
};

Expand Down
4 changes: 3 additions & 1 deletion cpp/models/abm/world.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#ifndef EPI_ABM_WORLD_H
#define EPI_ABM_WORLD_H

#include "abm/config.h"
#include "abm/location_type.h"
#include "abm/parameters.h"
#include "abm/location.h"
Expand Down Expand Up @@ -55,14 +56,15 @@ class World

/**
* @brief Create a World.
* @param[in] num_agegroups The number of AgeGroup%s in the simulated World.
* @param[in] num_agegroups The number of AgeGroup%s in the simulated World. Must be less than MAX_NUM_AGE_GROUPS.
*/
World(size_t num_agegroups)
: parameters(num_agegroups)
, m_trip_list()
, m_use_migration_rules(true)
, m_cemetery_id(add_location(LocationType::Cemetery))
{
assert(num_agegroups < MAX_NUM_AGE_GROUPS && "MAX_NUM_AGE_GROUPS exceeded.");
}

/**
Expand Down
28 changes: 20 additions & 8 deletions cpp/tests/test_abm_lockdown_rules.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,9 +53,12 @@ TEST(TestLockdownRules, school_closure)
p2.set_assigned_location(school);
mio::abm::Parameters params = mio::abm::Parameters(num_age_groups);
// Set the age group the can go to school is AgeGroup(1) (i.e. 5-14)
params.get<mio::abm::AgeGroupGotoSchool>() = {age_group_5_to_14};
params.get<mio::abm::AgeGroupGotoSchool>() = false;
params.get<mio::abm::AgeGroupGotoSchool>()[age_group_5_to_14] = true;
// Set the age group the can go to work is AgeGroup(2) and AgeGroup(3) (i.e. 15-34 or 35-59)
params.get<mio::abm::AgeGroupGotoWork>() = {age_group_15_to_34, age_group_35_to_59};
params.get<mio::abm::AgeGroupGotoWork>() = false;
params.get<mio::abm::AgeGroupGotoWork>()[age_group_15_to_34] = true;
params.get<mio::abm::AgeGroupGotoWork>()[age_group_35_to_59] = true;
mio::abm::set_school_closure(t, 0.7, params);

auto p1_rng = mio::abm::Person::RandomNumberGenerator(rng, p1);
Expand Down Expand Up @@ -88,9 +91,12 @@ TEST(TestLockdownRules, school_opening)
p.set_assigned_location(school);
mio::abm::Parameters params = mio::abm::Parameters(num_age_groups);
// Set the age group the can go to school is AgeGroup(1) (i.e. 5-14)
params.get<mio::abm::AgeGroupGotoSchool>() = {age_group_5_to_14};
params.get<mio::abm::AgeGroupGotoSchool>() = false;
params.get<mio::abm::AgeGroupGotoSchool>()[age_group_5_to_14] = true;
// Set the age group the can go to work is AgeGroup(2) and AgeGroup(3) (i.e. 15-34 or 35-59)
params.get<mio::abm::AgeGroupGotoWork>() = {age_group_15_to_34, age_group_35_to_59};
params.get<mio::abm::AgeGroupGotoWork>() = false;
params.get<mio::abm::AgeGroupGotoWork>()[age_group_15_to_34] = true;
params.get<mio::abm::AgeGroupGotoWork>()[age_group_35_to_59] = true;
mio::abm::set_school_closure(t_closing, 1., params);
mio::abm::set_school_closure(t_opening, 0., params);

Expand All @@ -110,9 +116,12 @@ TEST(TestLockdownRules, home_office)
mio::abm::Parameters params(num_age_groups);

// Set the age group the can go to school is AgeGroup(1) (i.e. 5-14)
params.get<mio::abm::AgeGroupGotoSchool>() = {age_group_5_to_14};
params.get<mio::abm::AgeGroupGotoSchool>() = false;
params.get<mio::abm::AgeGroupGotoSchool>()[age_group_5_to_14] = true;
// Set the age group the can go to work is AgeGroup(2) and AgeGroup(3) (i.e. 15-34 or 35-59)
params.get<mio::abm::AgeGroupGotoWork>() = {age_group_15_to_34, age_group_35_to_59};
params.get<mio::abm::AgeGroupGotoWork>() = false;
params.get<mio::abm::AgeGroupGotoWork>()[age_group_15_to_34] = true;
params.get<mio::abm::AgeGroupGotoWork>()[age_group_35_to_59] = true;

mio::abm::set_home_office(t, 0.4, params);

Expand Down Expand Up @@ -164,9 +173,12 @@ TEST(TestLockdownRules, no_home_office)
p.set_assigned_location(work);
mio::abm::Parameters params = mio::abm::Parameters(num_age_groups);
// Set the age group the can go to school is AgeGroup(1) (i.e. 5-14)
params.get<mio::abm::AgeGroupGotoSchool>() = {age_group_5_to_14};
params.get<mio::abm::AgeGroupGotoSchool>() = false;
params.get<mio::abm::AgeGroupGotoSchool>()[age_group_5_to_14] = true;
// Set the age group the can go to work is AgeGroup(2) and AgeGroup(3) (i.e. 15-34 or 35-59)
params.get<mio::abm::AgeGroupGotoWork>() = {age_group_15_to_34, age_group_35_to_59};
params.get<mio::abm::AgeGroupGotoWork>() = false;
params.get<mio::abm::AgeGroupGotoWork>()[age_group_15_to_34] = true;
params.get<mio::abm::AgeGroupGotoWork>()[age_group_35_to_59] = true;

mio::abm::set_home_office(t_closing, 0.5, params);
mio::abm::set_home_office(t_opening, 0., params);
Expand Down
Loading