Skip to content

Commit

Permalink
Merge pull request #504 from starsimhub/demographic-results-v2-497
Browse files Browse the repository at this point in the history
Demographic update_results normalization
  • Loading branch information
cliffckerr authored May 14, 2024
2 parents 71472d5 + 4a6d512 commit 86e7615
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 15 deletions.
34 changes: 23 additions & 11 deletions starsim/demographics.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,12 @@ def initialize(self, sim):
self.init_results()
return

def init_results(self):
pass

def update_results(self):
pass

def update(self):
# Note that for demographic modules, any result updates should be
# carried out inside this function
Expand All @@ -45,6 +51,7 @@ def __init__(self, pars=None, metadata=None, **kwargs):
# Process data, which may be provided as a number, dict, dataframe, or series
# If it's a number it's left as-is; otherwise it's converted to a dataframe
self.pars.birth_rate = self.standardize_birth_data()
self.n_births = 0 # For results tracking
return

def initialize(self, sim):
Expand Down Expand Up @@ -73,7 +80,7 @@ def init_results(self):

def update(self):
new_uids = self.add_births()
self.update_results(len(new_uids))
self.n_births = len(new_uids)
return new_uids

def get_births(self):
Expand All @@ -100,8 +107,8 @@ def add_births(self):
people.age[new_uids] = 0
return new_uids

def update_results(self, n_new):
self.results['new'][self.sim.ti] = n_new
def update_results(self):
self.results['new'][self.sim.ti] = self.n_births
return

def finalize(self):
Expand Down Expand Up @@ -162,6 +169,7 @@ def __init__(self, pars=None, metadata=None, **kwargs):
# If it's a number it's left as-is; otherwise it's converted to a dataframe
self.death_rate_data = self.standardize_death_data() # TODO: refactor
self.pars.death_rate = ss.bernoulli(p=self.make_death_prob_fn)
self.n_deaths = 0 # For results tracking
return

def standardize_death_data(self):
Expand Down Expand Up @@ -225,8 +233,7 @@ def init_results(self):
return

def update(self):
n_deaths = self.apply_deaths()
self.update_results(n_deaths)
self.n_deaths = self.apply_deaths()
return

def apply_deaths(self):
Expand All @@ -235,8 +242,8 @@ def apply_deaths(self):
self.sim.people.request_death(death_uids)
return len(death_uids)

def update_results(self, n_deaths):
self.results['new'][self.sim.ti] = n_deaths
def update_results(self):
self.results['new'][self.sim.ti] = self.n_deaths
return

def finalize(self):
Expand Down Expand Up @@ -285,6 +292,10 @@ def __init__(self, pars=None, metadata=None, **kwargs):
# If it's a number it's left as-is; otherwise it's converted to a dataframe
self.fertility_rate_data = self.standardize_fertility_data()
self.pars.fertility_rate = ss.bernoulli(self.make_fertility_prob_fn)

# For results tracking
self.n_pregnancies = 0
self.n_births = 0
return

@staticmethod
Expand Down Expand Up @@ -367,15 +378,16 @@ def update(self):
""" Perform all updates """
self.update_states()
conceive_uids = self.make_pregnancies()
self.n_pregnancies = len(conceive_uids)
self.make_embryos(conceive_uids)
self.update_results()
return

def update_states(self):
""" Update states """
# Check for new deliveries
ti = self.sim.ti
deliveries = self.pregnant & (self.ti_delivery <= ti)
self.n_births = np.count_nonzero(deliveries)
self.pregnant[deliveries] = False
self.postpartum[deliveries] = True
self.fecund[deliveries] = False
Expand Down Expand Up @@ -417,7 +429,7 @@ def make_embryos(self, conceive_uids):
new_uids = people.grow(len(new_slots), new_slots)
people.age[new_uids] = -self.pars.dur_pregnancy
people.slot[new_uids] = new_slots # Before sampling female_dist
people.female[new_uids] = self.pars.sex_ratio.rvs(new_uids)
people.female[new_uids] = self.pars.sex_ratio.rvs(conceive_uids)

# Add connections to any vertical transmission layers
# Placeholder code to be moved / refactored. The maternal network may need to be
Expand Down Expand Up @@ -457,8 +469,8 @@ def set_prognoses(self, uids):

def update_results(self):
ti = self.sim.ti
self.results['pregnancies'][ti] = np.count_nonzero(self.ti_pregnant == ti)
self.results['births'][ti] = np.count_nonzero(self.ti_delivery == ti)
self.results['pregnancies'][ti] = self.n_pregnancies
self.results['births'][ti] = self.n_births
return

def finalize(self):
Expand Down
3 changes: 3 additions & 0 deletions starsim/sim.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,9 @@ def step(self):
# Update results
self.people.update_results()

for dem_mod in self.demographics():
dem_mod.update_results()

for disease in self.diseases():
disease.update_results()

Expand Down
8 changes: 4 additions & 4 deletions tests/benchmark.json
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
{
"time": {
"people": 0.001,
"initialize": 0.013,
"run": 0.383
"people": 0.002,
"initialize": 0.014,
"run": 0.358
},
"parameters": {
"n_agents": 10000,
"n_years": 20,
"dt": 0.2
},
"cpu_performance": 1.0009447278438974
"cpu_performance": 0.8271107100782993
}

0 comments on commit 86e7615

Please sign in to comment.