Skip to content

Commit

Permalink
perf: bitset of age groups
Browse files Browse the repository at this point in the history
check num age groups when creating the world
  • Loading branch information
dabele committed Dec 8, 2023
1 parent 38d609c commit 4a6073c
Show file tree
Hide file tree
Showing 4 changed files with 45 additions and 6 deletions.
36 changes: 36 additions & 0 deletions cpp/models/abm/config.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
/*
* Copyright (C) 2020-2024 MEmilio
*
* Authors: Daniel Abele
*
* Contact: Martin J. Kuehn <Martin.Kuehn@DLR.de>
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MIO_ABM_CONFIG_H
#define MIO_ABM_CONFIG_H

namespace mio
{
namespace abm
{

/**
* Maximum number of age groups allowed in the model.
*/
const constexpr int MAX_NUM_AGE_GROUPS = 64;

}
} // namespace mio

#endif
8 changes: 4 additions & 4 deletions cpp/models/abm/testing_strategy.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ namespace abm
TestingCriteria::TestingCriteria(const std::vector<AgeGroup>& ages, const std::vector<InfectionState>& infection_states)
{
for (auto age : ages) {
m_ages.insert(static_cast<size_t>(age));
m_ages.set(static_cast<size_t>(age), true);
}
for (auto infection_state : infection_states) {
m_infection_states.set(static_cast<size_t>(infection_state), true);
Expand All @@ -43,12 +43,12 @@ bool TestingCriteria::operator==(const TestingCriteria& other) const

void TestingCriteria::add_age_group(const AgeGroup age_group)
{
m_ages.insert(static_cast<size_t>(age_group));
m_ages.set(static_cast<size_t>(age_group), true);
}

void TestingCriteria::remove_age_group(const AgeGroup age_group)
{
m_ages.erase(static_cast<size_t>(age_group));
m_ages.set(static_cast<size_t>(age_group), false);
}

void TestingCriteria::add_infection_state(const InfectionState infection_state)
Expand All @@ -64,7 +64,7 @@ void TestingCriteria::remove_infection_state(const InfectionState infection_stat
bool TestingCriteria::evaluate(const Person& p, TimePoint t) const
{
// An empty vector of ages or none bitset of #InfectionStates% means that no condition on the corresponding property is set.
return (m_ages.empty() || m_ages.count(static_cast<size_t>(p.get_age()))) &&
return (m_ages.none() || m_ages[static_cast<size_t>(p.get_age())]) &&
(m_infection_states.none() || m_infection_states[static_cast<size_t>(p.get_infection_state(t))]);
}

Expand Down
3 changes: 2 additions & 1 deletion cpp/models/abm/testing_strategy.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#ifndef EPI_ABM_TESTING_SCHEME_H
#define EPI_ABM_TESTING_SCHEME_H

#include "abm/config.h"
#include "abm/parameters.h"
#include "abm/person.h"
#include "abm/location.h"
Expand Down Expand Up @@ -91,7 +92,7 @@ class TestingCriteria
bool evaluate(const Person& p, TimePoint t) const;

private:
std::unordered_set<size_t> m_ages; ///< Set of #AgeGroup%s that are either allowed or required to be tested.
std::bitset<MAX_NUM_AGE_GROUPS> m_ages; ///< Set of #AgeGroup%s that are either allowed or required to be tested.
std::bitset<(size_t)InfectionState::Count>
m_infection_states; /**< BitSet of #InfectionState%s that are either allowed or required to
be tested.*/
Expand Down
4 changes: 3 additions & 1 deletion cpp/models/abm/world.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#ifndef EPI_ABM_WORLD_H
#define EPI_ABM_WORLD_H

#include "abm/config.h"
#include "abm/location_type.h"
#include "abm/parameters.h"
#include "abm/location.h"
Expand Down Expand Up @@ -55,14 +56,15 @@ class World

/**
* @brief Create a World.
* @param[in] num_agegroups The number of AgeGroup%s in the simulated World.
* @param[in] num_agegroups The number of AgeGroup%s in the simulated World. Must be less than MAX_NUM_AGE_GROUPS.
*/
World(size_t num_agegroups)
: parameters(num_agegroups)
, m_trip_list()
, m_use_migration_rules(true)
, m_cemetery_id(add_location(LocationType::Cemetery))
{
assert(num_agegroups < MAX_NUM_AGE_GROUPS && "MAX_NUM_AGE_GROUPS exceeded.");
}

/**
Expand Down

0 comments on commit 4a6073c

Please sign in to comment.