Skip to content

Commit

Permalink
Merge pull request #205 from amath-idm/age-distributions
Browse files Browse the repository at this point in the history
Age distributions
  • Loading branch information
cliffckerr authored Jan 21, 2024
2 parents 97ec1d8 + 164c99c commit 5d4182b
Show file tree
Hide file tree
Showing 16 changed files with 679 additions and 227 deletions.
5 changes: 5 additions & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,11 @@ All notable changes to the codebase are documented in this file. Changes that ma
:local:
:depth: 1

Version 0.1.3 (2024-01-22)
--------------------------
- Read in age distributions for people initializations
- *GitHub info*: PR `205 <https://github.com/amath-idm/stisim/pull/205>`_


Version 0.1.2 (2024-01-19)
--------------------------
Expand Down
78 changes: 45 additions & 33 deletions stisim/demographics.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,9 +65,9 @@ def standardize_birth_data(self):
return birth_rate

def init_results(self, sim):
self.results += ss.Result(self.name, 'new', sim.npts, dtype=int)
self.results += ss.Result(self.name, 'cumulative', sim.npts, dtype=int)
self.results += ss.Result(self.name, 'cbr', sim.npts, dtype=int)
self.results += ss.Result(self.name, 'new', sim.npts, dtype=int, scale=True)
self.results += ss.Result(self.name, 'cumulative', sim.npts, dtype=int, scale=True)
self.results += ss.Result(self.name, 'cbr', sim.npts, dtype=int, scale=False)
return

def update(self, sim):
Expand All @@ -94,6 +94,7 @@ def add_births(self, sim):
# Add n_new births to each state in the sim
n_new = self.get_births(sim)
new_uids = sim.people.grow(n_new)
sim.people.age[new_uids] = 0
return new_uids

def update_results(self, n_new, sim):
Expand All @@ -102,7 +103,7 @@ def update_results(self, n_new, sim):
def finalize(self, sim):
super().finalize(sim)
self.results['cumulative'] = np.cumsum(self.results['new'])
self.results['cbr'] = np.divide(self.results['new'], sim.results['n_alive'], where=sim.results['n_alive']>0)
self.results['cbr'] = 1/self.pars.units*np.divide(self.results['new'], sim.results['n_alive'], where=sim.results['n_alive']>0)


class background_deaths(DemographicModule):
Expand Down Expand Up @@ -202,9 +203,9 @@ def standardize_death_data(self):
return death_rate

def init_results(self, sim):
self.results += ss.Result(self.name, 'new', sim.npts, dtype=int)
self.results += ss.Result(self.name, 'cumulative', sim.npts, dtype=int)
self.results += ss.Result(self.name, 'cmr', sim.npts, dtype=int)
self.results += ss.Result(self.name, 'new', sim.npts, dtype=int, scale=True)
self.results += ss.Result(self.name, 'cumulative', sim.npts, dtype=int, scale=True)
self.results += ss.Result(self.name, 'cmr', sim.npts, dtype=int, scale=False)
return

def update(self, sim):
Expand All @@ -223,8 +224,9 @@ def update_results(self, n_deaths, sim):
self.results['new'][sim.ti] = n_deaths

def finalize(self, sim):
super().finalize(sim)
self.results['cumulative'] = np.cumsum(self.results['new'])
self.results['cmr'] = np.divide(self.results['new'], sim.results['n_alive'], where=sim.results['n_alive']>0)
self.results['cmr'] = 1/self.pars.units*np.divide(self.results['new'], sim.results['n_alive'], where=sim.results['n_alive']>0)


class Pregnancy(DemographicModule):
Expand All @@ -247,7 +249,7 @@ def __init__(self, pars=None, metadata=None):
'dur_postpartum': 0.5, # Make this a distribution?
'fertility_rate': 0, # Usually this will be provided in CSV format
'rel_fertility': 1,
'maternal_death_rate': 0.15,
'maternal_death_rate': 0,
'sex_ratio': 0.5, # Ratio of babies born female
'units': 1e-3, # Assumes fertility rates are per 1000. If using percentages, switch this to 1
}, self.pars)
Expand Down Expand Up @@ -287,25 +289,34 @@ def make_fertility_prob_fn(module, sim, uids):
val_label = module.metadata.data_cols['value']

available_years = module.pars.fertility_rate[year_label].unique()
year_ind = sc.findnearest(available_years, sim.year)
year_ind = sc.findnearest(available_years, sim.year-module.pars.dur_pregnancy)
nearest_year = available_years[year_ind]

df = module.pars.fertility_rate.loc[module.pars.fertility_rate[year_label] == nearest_year]
conception_arr = df[val_label].values
conception_arr = np.append(conception_arr, 0) # Add zeros for those outside data range
df_arr = df[val_label].values # Pull out dataframe values
df_arr = np.append(df_arr, 0) # Add zeros for those outside data range

# Process age data
age_bins = df[age_label].unique()
age_bins = np.append(age_bins, 50)
age_inds = np.digitize(sim.people.age[uids], age_bins) - 1
age_inds[age_inds>=max(age_inds)] = -1 # This ensures women outside the data range will get a value of 0

# Make array of fertility rates - TODO, check indexing works
age_inds[age_inds >= max(age_inds)] = -1 # This ensures women outside the data range will get a value of 0

# Adjust rates: rates are based on the entire population, but we need to remove
# anyone already pregnant and then inflate the rates for the remainder
pregnant_uids = ss.true(module.pregnant[uids]) # Find agents who are already pregnant
pregnant_age_counts, _ = np.histogram(sim.people.age[pregnant_uids], age_bins) # Count them by age
age_counts, _ = np.histogram(sim.people.age[uids], age_bins) # Count overall number per age bin
new_denom = age_counts - pregnant_age_counts # New denominator for rates
num_to_make = df_arr[:-1]*age_counts # Number that we need to make pregnant
new_percent = sc.dcp(df_arr) # Initialize array with new rates
inds_to_rescale = new_denom > 0 # Rescale any non-zero age bins
new_percent[:-1][inds_to_rescale] = num_to_make[inds_to_rescale] / new_denom[inds_to_rescale] # New rates

# Make array of fertility rates
fertility_rate = pd.Series(index=uids)
fertility_rate[uids] = conception_arr[age_inds]
fertility_rate[uids[sim.people.male[uids]]] = 0
fertility_rate[uids[(sim.people.age < 0)[uids]]] = 0
fertility_rate[uids[(sim.people.age > max(age_inds))[uids]]] = 0
fertility_rate[uids] = new_percent[age_inds]
fertility_rate[pregnant_uids] = 0

# Scale from rate to probability. Consider an exponential here.
fertility_prob = fertility_rate * (module.pars.units * module.pars.rel_fertility * sim.pars.dt)
Expand All @@ -330,16 +341,17 @@ def init_results(self, sim):
Still unclear whether this logic should live in the pregnancy module, the
individual disease modules, the connectors, or the sim.
"""
self.results += ss.Result(self.name, 'pregnancies', sim.npts, dtype=int)
self.results += ss.Result(self.name, 'births', sim.npts, dtype=int)
self.results += ss.Result(self.name, 'pregnancies', sim.npts, dtype=int, scale=True)
self.results += ss.Result(self.name, 'births', sim.npts, dtype=int, scale=True)
self.results += ss.Result(self.name, 'cbr', sim.npts, dtype=int, scale=False)
return

def update(self, sim):
"""
Perform all updates
"""
self.make_pregnancies(sim)
self.update_states(sim)
self.make_pregnancies(sim)
self.update_results(sim)
return

Expand All @@ -353,13 +365,11 @@ def update_states(self, sim):
self.pregnant[deliveries] = False
self.postpartum[deliveries] = True
self.susceptible[deliveries] = False
self.ti_delivery[deliveries] = sim.ti

# Check for new women emerging from post-partum
postpartum = ~self.pregnant & (self.ti_postpartum <= sim.ti)
self.postpartum[postpartum] = False
self.susceptible[postpartum] = True
self.ti_postpartum[postpartum] = sim.ti

# Maternal deaths
maternal_deaths = ss.true(self.ti_dead <= sim.ti)
Expand All @@ -370,14 +380,13 @@ def update_states(self, sim):
def make_pregnancies(self, sim):
"""
Select people to make pregnancy using incidence data
This should use ASFR data from https://population.un.org/wpp/Download/Standard/Fertility/
"""
# Abbreviate key variables
# Abbreviate
ppl = sim.people

# If incidence of pregnancy is non-zero, make some cases
# Think about how to deal with age/time-varying fertility
denom_conds = ppl.female & self.susceptible
# People eligible to become pregnant. We don't remove pregnant people here, these
# are instead handled in the fertility_dist logic as the rates need to be adjusted
denom_conds = ppl.female & ppl.alive
inds_to_choose_from = ss.true(denom_conds)
uids = self.fertility_dist.filter(inds_to_choose_from)

Expand All @@ -390,10 +399,9 @@ def make_pregnancies(self, sim):

# Grow the arrays and set properties for the unborn agents
new_uids = sim.people.grow(len(new_slots))

sim.people.age[new_uids] = -self.pars.dur_pregnancy
sim.people.slot[new_uids] = new_slots # Before sampling female_dist
sim.people.female[new_uids] = self.sex_dist.rvs(uids)
sim.people.slot[new_uids] = new_slots # Before sampling female_dist
sim.people.female[new_uids] = self.sex_dist.rvs(new_uids)

# Add connections to any vertical transmission layers
# Placeholder code to be moved / refactored. The maternal network may need to be
Expand All @@ -404,7 +412,7 @@ def make_pregnancies(self, sim):
layer.add_pairs(uids, new_uids, dur=durs)

# Set prognoses for the pregnancies
self.set_prognoses(sim, uids) # Could set from_uids to network partners?
self.set_prognoses(sim, uids) # Could set from_uids to network partners?

return

Expand Down Expand Up @@ -436,3 +444,7 @@ def update_results(self, sim):
self.results['pregnancies'][sim.ti] = np.count_nonzero(self.ti_pregnant == sim.ti)
self.results['births'][sim.ti] = np.count_nonzero(self.ti_delivery == sim.ti)
return

def finalize(self, sim):
super().finalize(sim)
self.results['cbr'] = 1/self.pars.units * np.divide(self.results['births'], sim.results['n_alive'], where=sim.results['n_alive']>0)
10 changes: 5 additions & 5 deletions stisim/diseases/disease.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,15 +155,15 @@ def init_results(self, sim):
Result for 'n_susceptible'
"""
for state in self._boolean_states:
self.results += ss.Result(self.name, f'n_{state.name}', sim.npts, dtype=int)
self.results += ss.Result(self.name, f'n_{state.name}', sim.npts, dtype=int, scale=True)
return

def finalize_results(self, sim):
"""
Finalize results
"""
# TODO - will probably need to account for rescaling outputs for the default results here
pass
super().finalize_results(sim)
return

def update_pre(self, sim):
"""
Expand Down Expand Up @@ -299,8 +299,8 @@ def init_results(self, sim):
Initialize results
"""
super().init_results(sim)
self.results += ss.Result(self.name, 'prevalence', sim.npts, dtype=float)
self.results += ss.Result(self.name, 'new_infections', sim.npts, dtype=int)
self.results += ss.Result(self.name, 'prevalence', sim.npts, dtype=float, scale=False)
self.results += ss.Result(self.name, 'new_infections', sim.npts, dtype=int, scale=True)
return

def update_pre(self, sim):
Expand Down
9 changes: 9 additions & 0 deletions stisim/modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,9 +67,18 @@ def initialize(self, sim):
return

def finalize(self, sim):
self.finalize_results(sim)
self.finalized = True
return

def finalize_results(self, sim):
"""
Finalize results
"""
# Scale results
for reskey, res in self.results.items():
if isinstance(res, ss.Result) and res.scale:
self.results[reskey] = self.results[reskey]*sim.pars.pop_scale
@property
def states(self):
"""
Expand Down
33 changes: 20 additions & 13 deletions stisim/people.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import pandas as pd
import sciris as sc
import stisim as ss
from scipy.stats import bernoulli, uniform
import scipy.stats as sps

__all__ = ['BasePeople', 'People']

Expand Down Expand Up @@ -177,7 +177,7 @@ def __init__(self, n, age_data=None, extra_states=None, networks=None, rand_seed
# Handle states
states = [
ss.State('age', float, np.nan), # NaN until conceived
ss.State('female', bool, bernoulli(p=0.5)),
ss.State('female', bool, sps.bernoulli(p=0.5)),
ss.State('debut', float),
ss.State('ti_dead', int, ss.INT_NAN), # Time index for death
ss.State('alive', bool, True), # Time index for death
Expand All @@ -191,17 +191,24 @@ def __init__(self, n, age_data=None, extra_states=None, networks=None, rand_seed
self.networks = ss.Networks(networks)

# Set initial age distribution - likely move this somewhere else later
self.age_data_dist = self.get_age_dist(age_data)
self.age_data = age_data
self.age_dist_gen = sps.uniform() # Store a uniform distribution for generating ages

return

@staticmethod
def get_age_dist(age_data):
def get_age_dist(self):
""" Return an age distribution based on provided data """
if age_data is None:
return uniform(loc=0, scale=100) # low and width
if sc.checktype(age_data, pd.DataFrame):
return ss.data_dist(vals=age_data['value'].values, bins=age_data['age'].values)
age_draws = self.age_dist_gen.rvs(size=np.max(self.slot) + 1)
if self.age_data is None:
return age_draws * 100
if sc.checktype(self.age_data, pd.DataFrame):
bins = self.age_data['age'].values
vals = self.age_data['value'].values
bin_midpoints = bins[:-1] + np.diff(bins) / 2
cdf = np.cumsum(vals)
cdf = cdf / cdf[-1]
value_bins = np.searchsorted(cdf, age_draws)
return bin_midpoints[value_bins]

def _initialize_states(self, sim=None):
for state in self.states.values():
Expand All @@ -222,8 +229,8 @@ def initialize(self, sim):

# Define age (CK: why is age handled differently than sex?)
self._initialize_states(sim=sim) # Now initialize with the sim
self.age[:] = self.age_data_dist.rvs(len(self))
self.age[:] = self.get_age_dist()

self.initialized = True
return

Expand Down Expand Up @@ -322,8 +329,8 @@ def m(self):
return self.male

def init_results(self, sim):
sim.results += ss.Result(None, 'n_alive', sim.npts, ss.int_)
sim.results += ss.Result(None, 'new_deaths', sim.npts, ss.int_)
sim.results += ss.Result(None, 'n_alive', sim.npts, ss.int_, scale=True)
sim.results += ss.Result(None, 'new_deaths', sim.npts, ss.int_, scale=True)
return

def update_results(self, sim):
Expand Down
4 changes: 3 additions & 1 deletion stisim/results.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,11 @@

class Result(np.ndarray):

def __new__(cls, module=None, name=None, shape=None, dtype=None):
def __new__(cls, module=None, name=None, shape=None, dtype=None, scale=None):
arr = np.zeros(shape=shape, dtype=dtype).view(cls)
arr.name = name
arr.module = module
arr.scale = scale
return arr

def __repr__(self):
Expand All @@ -29,6 +30,7 @@ def __array_finalize__(self, obj):
return
self.name = getattr(obj, 'name', None)
self.module = getattr(obj, 'module', None)
self.scale = getattr(obj, 'scale', None)
return

def __array_wrap__(self, obj, **kwargs):
Expand Down
Loading

0 comments on commit 5d4182b

Please sign in to comment.