diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 2e57e94b..e5bc96fd 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -7,9 +7,25 @@ What's new All notable changes to the codebase are documented in this file. Changes that may result in differences in model output, or are required in order to run an old parameter set with the current version, are flagged with the term "Regression information". +Version 0.5.2 (2024-06-04) +-------------------------- +- Renames ``network.contacts`` to ``network.edges``. +- For modules (including diseases, networks, etc.), renames ``initialize()`` to ``init_pre()`` and ``init_vals()`` to ``init_post()``. +- Renames ``ss.delta()`` to ``ss.constant()``. +- Allows ``Arr`` objects to be indexed by integer (which are assumed to be UIDs). +- Fixes bug when using callable parameters with ``ss.lognorm_ex()`` and ``ss.lognorm_im()``. +- Fixes bug when initializing ``ss.StaticNet()``. +- Updates default birth rate from 0 to 30 (so ``demographics=True`` is meaningful). +- Adds ``min_age`` and ``max_age`` parameters to the ``Pregnancy`` module (with defaults 15 and 50 years). +- Adds an option for the ``sir_vaccine`` to be all-or-nothing instead of leaky. +- Updates baseline test from HIV to SIR + SIS. +- Fixes issue with infection log not being populated. +- *GitHub info*: PR `527 `_ + + Version 0.5.1 (2024-05-15) -------------------------- -- Separates maternal transmission into prenatal and postnatal modules +- Separates maternal transmission into prenatal and postnatal modules. - *GitHub info*: PR `509 `_ diff --git a/README.rst b/README.rst index a19cefdf..e8e1caf3 100644 --- a/README.rst +++ b/README.rst @@ -3,9 +3,19 @@ Starsim **Warning! Starsim is still in the early stages of development. It is being shared solely for transparency and to facilitate collaborative development. It is not ready to be used for real research or policy questions.** -Starsim is an agent-based disease modeling framework in which users can design and configure simulations of pathogens that progress over time within each agent and pass from one agent to the next along dynamic transmission networks. The framework explicitly supports co-transmission of multiple pathogens, allowing users to concurrently simulate several diseases while capturing behavioral and biological interactions. Non-communicable diseases can easily be included as well, either as a co-factor for transmissible pathogens or as an independent exploration. Detailed modeling of mother-child relationships can be simulated from the timepoint of conception, enabling study of congenital diseases and associated birth outcomes. Finally, Starsim facilitates the comparison of one or more intervention scenarios to a baseline scenario in evaluating the impact of various products like vaccines, therapeutics, and novel diagnostics delivered via flexible routes including mass campaigns, screen and treat, and targeted outreach. +Starsim is an agent-based modeling framework in which users can design and configure simulations of diseases (or other health states) that progress over time within each agent and pass from one agent to the next along dynamic transmission networks. The framework explicitly supports co-transmission of multiple pathogens, allowing users to concurrently simulate several diseases while capturing behavioral and biological interactions. Non-communicable diseases can be included as well, either as a co-factor for transmissible pathogens or as an independent exploration. Detailed modeling of mother-child relationships can be simulated from the timepoint of conception, enabling study of congenital diseases and associated birth outcomes. Finally, Starsim facilitates the comparison of one or more intervention scenarios to a baseline scenario in evaluating the impact of various products like vaccines, therapeutics, and novel diagnostics delivered via flexible routes including mass campaigns, screen and treat, and targeted outreach. -The framework is appropriate for simulating one or more sexually transmitted infections (including syphilis, gonorrhea, chlamydia, HPV, and HIV), respiratory infections (like RSV and tuberculosis), and other diseases and underlying determinants (such as Ebola, diabetes, and malnutrition). +The framework is appropriate for simulating sexually transmitted infections (including syphilis, gonorrhea, chlamydia, HPV, and HIV, including co-transmission), respiratory infections (like RSV and tuberculosis), and other diseases and underlying determinants (such as Ebola, diabetes, and malnutrition). + +Starsim is a general-purpose modeling framework that is part of the same suite of tools as `Covasim `_, `HPVsim `_, and `FPsim `_. + + +Requirements +------------ + +Python 3.9-3.12. + +We recommend, but do not require, installing Starsim in a virtual environment, such as `Anaconda `__. Installation @@ -19,15 +29,68 @@ Starsim can also be installed locally. To do this, clone first this repository, Usage and documentation ----------------------- -Documentation is available at https://docs.starsim.org. - -Usage examples are available in the ``tests`` folder. +Documentation, including tutorials and an API reference, is available at https://docs.starsim.org. + +If everything is working, the following Python commands will run a simulation with the simplest version of a Starsim model. We'll make a version of a classic SIR model:: + + import starsim as ss + + # Define the parameters + pars = dict( + n_agents = 5_000, # Number of agents to simulate + networks = dict( # *Networks* add detail on how agents interact w/ each other + type = 'random', # Here, we use a 'random' network + n_contacts = 10 # Each person has an average of 10 contacts w/ other people + ), + diseases = dict( # *Diseases* add detail on what diseases to model + type = 'sir', # Here, we're creating an SIR disease + init_prev = 0.1, # Proportion of the population initially infected + beta = 0.5, # Probability of transmission between contacts + ) + ) + + # Make the sim, run and plot + sim = ss.Sim(pars) + sim.run() + sim.plot() + +More usage examples are available in the ``tests`` folder. + + +Model structure +--------------- + +All core model code is located in the ``starsim`` subfolder; standard usage is ``import starsim as ss``. + +The model consists of core classes including Sim, Run, People, State, Network, Connectors, Analyzers, Interventions, Results, and more. These classes contain methods for running, building simple or dynamic networks, generating random numbers, calculating results, plotting, etc. + +The structure of the starsim folder is as follows, roughly in the order in which the modules are imported, building from most fundamental to most complex: + +• ``demographics.py``: Classes to transform initial condition input parameters for use in building and utilizing networks. +• ``disease.py``: Classes to manage infection rate of spread, prevalence, waning effects, and other parameters for specific diseases. +• ``distributions.py``: Classes that handle statistical distributions used throughout Starsim. +• ``interventions.py``: The Intervention class, for adding interventions and dynamically modifying parameters, and classes for each of the specific interventions derived from it. The Analyzers class (for performing analyses on the sim while it's running), and other classes and functions for analyzing simulations. +• ``modules.py``: Class to handle "module" logic, such as updates (diseases, networks, etc). +• ``network.py``: Classes for creating simple and dynamic networks of people based on input parameters. +• ``parameters.py``: Classes for creating the simulation parameters. +• ``people.py``: The People class, for handling updates of state for each person. +• ``products.py``: Classes to manage the deployment of vaccines and treatments. +• ``results.py``: Classes to analyze and save results from simulations. +• ``run.py``: Classes for running simulations (e.g. parallel runs and the Scenarios and MultiSim classes). +• ``samples.py``: Class to store data from a large number of simulations. +• ``settings.py``: User-customizable options for Starsim (e.g. default font size). +• ``sim.py``: The Sim class, which performs most of the heavy lifting: initializing the model, running, and plotting. +• ``states.py``: Classes to handle store and update states for people in networks in the simulation including living, mother, child, susceptible, infected, inoculated, recovered, etc. +• ``utils.py``: Helper functions. +• ``version.py``: Version, date, and license information. + +The ``diseases`` folder within the Starsim package contains loading scripts for the epidemiological data specific to each respective disease. Contributing ------------ -If you wish to contribute, please see the code of conduct and contributing documents. +Questions or comments can be directed to `info@starsim.org `__ , or on this project’s `GitHub `__ page. Full information about Starsim is provided in the `documentation `__. Disclaimer diff --git a/starsim/__init__.py b/starsim/__init__.py index 78e308d4..1e873df3 100644 --- a/starsim/__init__.py +++ b/starsim/__init__.py @@ -27,5 +27,7 @@ print(__license__) # Double-check key requirements -- should match setup.py -sc.require(['sciris>=3.1.6', 'pandas>=2.0.0', 'scipy', 'numba', 'networkx'], message=f'The following dependencies for Starsim {__version__} were not met: .') -del sc # Don't keep this in the module \ No newline at end of file +reqs = ['sciris>=3.1.6', 'pandas>=2.0.0', 'scipy', 'numba', 'networkx'] +msg = f'The following dependencies for Starsim {__version__} were not met: .' +sc.require(reqs, message=msg) +del sc, reqs, msg # Don't keep this in the module \ No newline at end of file diff --git a/starsim/demographics.py b/starsim/demographics.py index 5de00cbd..56af1f83 100644 --- a/starsim/demographics.py +++ b/starsim/demographics.py @@ -16,8 +16,8 @@ class Demographics(ss.Module): place at the start of the timestep, before networks are updated and before any disease modules are executed. """ - def initialize(self, sim): - super().initialize(sim) + def init_pre(self, sim): + super().init_pre(sim) self.init_results() return @@ -36,7 +36,7 @@ class Births(Demographics): def __init__(self, pars=None, metadata=None, **kwargs): super().__init__() self.default_pars( - birth_rate = 0, + birth_rate = 30, rel_birth = 1, units = 1e-3, # assumes birth rates are per 1000. If using percentages, switch this to 1 ) @@ -54,9 +54,9 @@ def __init__(self, pars=None, metadata=None, **kwargs): self.n_births = 0 # For results tracking return - def initialize(self, sim): + def init_pre(self, sim): """ Initialize with sim information """ - super().initialize(sim) + super().init_pre(sim) if isinstance(self.pars.birth_rate, pd.DataFrame): br_year = self.pars.birth_rate[self.metadata.data_cols['year']] br_val = self.pars.birth_rate[self.metadata.data_cols['cbr']] @@ -265,6 +265,8 @@ def __init__(self, pars=None, metadata=None, **kwargs): rel_fertility = 1, maternal_death_prob = ss.bernoulli(0), sex_ratio = ss.bernoulli(0.5), # Ratio of babies born female + min_age = 15, # Minimum age to become pregnant + max_age = 50, # Maximum age to become pregnant units = 1e-3, # Assumes fertility rates are per 1000. If using percentages, switch this to 1 ) self.update_pars(pars, **kwargs) @@ -342,9 +344,12 @@ def make_fertility_prob_fn(self, sim, uids): fertility_rate = pd.Series(index=uids) fertility_rate[uids] = new_percent[age_inds] - # Scale from rate to probability. Consider an exponential here. + # Scale from rate to probability + age = self.sim.people.age[uids] + invalid_age = (age < self.pars.min_age) | (age > self.pars.max_age) fertility_prob = fertility_rate * (self.pars.units * self.pars.rel_fertility * sim.pars.dt) fertility_prob[self.pregnant.uids] = 0 # Currently pregnant women cannot become pregnant again + fertility_prob[uids[invalid_age]] = 0 # Women too young or old cannot become pregnant fertility_prob = np.clip(fertility_prob, a_min=0, a_max=1) return fertility_prob @@ -354,8 +359,8 @@ def standardize_fertility_data(self): fertility_rate = ss.standardize_data(data=self.pars.fertility_rate, metadata=self.metadata) return fertility_rate - def initialize(self, sim): - super().initialize(sim) + def init_pre(self, sim): + super().init_pre(sim) low = sim.pars.n_agents + 1 high = int(sim.pars.slot_scale*sim.pars.n_agents) self.choose_slots = ss.randint(low=low, high=high, sim=sim, module=self) @@ -402,9 +407,9 @@ def update_states(self): prenatalnet = [nw for nw in self.sim.networks.values() if nw.prenatal][0] # Find the prenatal connections that are ending - prenatal_ending = prenatalnet.contacts.end<=self.sim.ti - new_mother_uids = prenatalnet.contacts.p1[prenatal_ending] - new_infant_uids = prenatalnet.contacts.p2[prenatal_ending] + prenatal_ending = prenatalnet.edges.end<=self.sim.ti + new_mother_uids = prenatalnet.edges.p1[prenatal_ending] + new_infant_uids = prenatalnet.edges.p2[prenatal_ending] # Validation if not np.array_equal(new_mother_uids, deliveries.uids): diff --git a/starsim/disease.py b/starsim/disease.py index aa71a612..21fd504b 100644 --- a/starsim/disease.py +++ b/starsim/disease.py @@ -15,10 +15,10 @@ class Disease(ss.Module): """ Base module class for diseases """ - def __init__(self, *args, **kwargs): + def __init__(self, log=True, *args, **kwargs): super().__init__(*args, **kwargs) self.results = ss.Results(self.name) - self.log = InfectionLog() # See below for definition + self.log = InfectionLog() if log else None # See below for definition return @property @@ -34,8 +34,9 @@ def _boolean_states(self): yield state return - def initialize(self, sim): - super().initialize(sim) + def init_pre(self, sim): + """ Link the disease to the sim, create objects, and initialize results; see Module.init_pre() for details """ + super().init_pre(sim) self.init_results() return @@ -88,7 +89,7 @@ def make_new_cases(self): """ pass - def set_prognoses(self, target_uids, source_uids=None): + def set_prognoses(self, uids, source_uids=None): """ Set prognoses upon infection/acquisition @@ -105,13 +106,14 @@ def set_prognoses(self, target_uids, source_uids=None): uids (array): UIDs for agents to assign disease progoses to from_uids (array): Optionally specify the infecting agent """ - sim = self.sim - if source_uids is None: - for target in target_uids: - self.log.append(np.nan, target, sim.year) - else: - for target, source in zip(target_uids, source_uids): - self.log.append(source, target, sim.year) + if self.log is not None: + sim = self.sim + if source_uids is None: + for target in uids: + self.log.append(np.nan, target, sim.year) + else: + for target, source in zip(uids, source_uids): + self.log.append(source, target, sim.year) return def update_results(self): @@ -152,8 +154,8 @@ def __init__(self, *args, **kwargs): self.rng_source = ss.random(name='source') return - def initialize(self, sim): - super().initialize(sim) + def init_pre(self, sim): + super().init_pre(sim) self.validate_beta() return @@ -188,7 +190,7 @@ def infectious(self): """ return self.infected - def init_vals(self): + def init_post(self): """ Set initial values for states. This could involve passing in a full set of initial conditions, or using init_prev, or other. Note that this is different to initialization of the Arr objects @@ -256,12 +258,12 @@ def make_new_cases(self): break nbetas = betamap[nkey] - contacts = net.contacts + edges = net.edges rel_trans = self.rel_trans.asnew(self.infectious * self.rel_trans) rel_sus = self.rel_sus.asnew(self.susceptible * self.rel_sus) - p1p2b0 = [contacts.p1, contacts.p2, nbetas[0]] - p2p1b1 = [contacts.p2, contacts.p1, nbetas[1]] + p1p2b0 = [edges.p1, edges.p2, nbetas[0]] + p2p1b1 = [edges.p2, edges.p1, nbetas[1]] for src, trg, beta in [p1p2b0, p2p1b1]: # Skip networks with no transmission diff --git a/starsim/diseases/gonorrhea.py b/starsim/diseases/gonorrhea.py index f3dba8a7..fb63f9e6 100644 --- a/starsim/diseases/gonorrhea.py +++ b/starsim/diseases/gonorrhea.py @@ -34,9 +34,7 @@ def __init__(self, pars=None, *args, **kwargs): return def init_results(self): - """ - Initialize results - """ + """ Initialize results """ super().init_results() self.results += ss.Result(self.name, 'new_clearances', self.sim.npts, dtype=int) return @@ -49,10 +47,7 @@ def update_results(self): return def update_pre(self): - # What if something in here should depend on another module? - # I guess we could just check for it e.g., 'if HIV in sim.modules' or - # 'if 'hiv' in sim.people' or something - # Natural clearance + """ Natural clearance """ clearances = self.ti_clearance <= self.sim.ti self.susceptible[clearances] = True self.infected[clearances] = False @@ -61,24 +56,22 @@ def update_pre(self): return - def set_prognoses(self, target_uids, source_uids=None): - """ - Natural history of gonorrhea for adult infection - """ - super().set_prognoses(target_uids, source_uids) + def set_prognoses(self, uids, source_uids=None): + """ Natural history of gonorrhea for adult infection """ + super().set_prognoses(uids, source_uids) ti = self.sim.ti # Set infection status - self.susceptible[target_uids] = False - self.infected[target_uids] = True - self.ti_infected[target_uids] = ti + self.susceptible[uids] = False + self.infected[uids] = True + self.ti_infected[uids] = ti # Set infection status - symp_uids = self.pars.p_symp.filter(target_uids) + symp_uids = self.pars.p_symp.filter(uids) self.symptomatic[symp_uids] = True # Set natural clearance - clear_uids = self.pars.p_clear.filter(target_uids) + clear_uids = self.pars.p_clear.filter(uids) dur = ti + self.pars.dur_inf_in_days.rvs(clear_uids)/365/self.sim.dt # Convert from days to years and then adjust for dt self.ti_clearance[clear_uids] = dur return \ No newline at end of file diff --git a/starsim/diseases/hiv.py b/starsim/diseases/hiv.py index 65e8976f..bf451971 100644 --- a/starsim/diseases/hiv.py +++ b/starsim/diseases/hiv.py @@ -76,8 +76,8 @@ def set_prognoses(self, uids, source_uids=None): self.ti_infected[uids] = self.sim.ti return - def set_congenital(self, target_uids, source_uids): - return self.set_prognoses(target_uids, source_uids) + def set_congenital(self, uids, source_uids): + return self.set_prognoses(uids, source_uids) # %% HIV-related interventions @@ -94,8 +94,8 @@ def __init__(self, year: np.array, coverage: np.array, **kwargs): self.prob_art_at_infection = ss.bernoulli(p=lambda self, sim, uids: np.interp(sim.year, self.year, self.coverage)) return - def initialize(self, sim): - super().initialize(sim) + def init_pre(self, sim): + super().init_pre(sim) self.results += ss.Result(self.name, 'n_art', sim.npts, dtype=int) self.initialized = True return @@ -129,8 +129,8 @@ def __init__(self): self.cd4 = None return - def initialize(self, sim): - super().initialize(sim) + def init_pre(self, sim): + super().init_pre(sim) self.cd4 = np.zeros((sim.npts, sim.people.n), dtype=int) return diff --git a/starsim/diseases/measles.py b/starsim/diseases/measles.py index 669fcaba..e3f058b7 100644 --- a/starsim/diseases/measles.py +++ b/starsim/diseases/measles.py @@ -59,8 +59,7 @@ def update_pre(self): def set_prognoses(self, uids, source_uids=None): """ Set prognoses for those who get infected """ - # Do not call set_prognosis on parent - # super().set_prognoses(sim, uids, source_uids) + super().set_prognoses(uids, source_uids) ti = self.sim.ti dt = self.sim.dt diff --git a/starsim/diseases/ncd.py b/starsim/diseases/ncd.py index a4cf2151..daa9bd2a 100644 --- a/starsim/diseases/ncd.py +++ b/starsim/diseases/ncd.py @@ -38,14 +38,14 @@ def __init__(self, pars=None, **kwargs): def not_at_risk(self): return ~self.at_risk - def init_vals(self): + def init_post(self): """ Set initial values for states. This could involve passing in a full set of initial conditions, or using init_prev, or other. Note that this is different to initialization of the State objects i.e., creating their dynamic array, linking them to a People instance. That should have already taken place by the time this method is called. """ - super().init_vals() + super().init_post() initial_risk = self.pars['initial_risk'].filter() self.at_risk[initial_risk] = True self.ti_affected[initial_risk] = self.sim.ti + sc.randround(self.pars['dur_risk'].rvs(initial_risk) / self.sim.dt) diff --git a/starsim/diseases/sir.py b/starsim/diseases/sir.py index 991845ea..a21c40c0 100644 --- a/starsim/diseases/sir.py +++ b/starsim/diseases/sir.py @@ -20,7 +20,7 @@ class SIR(ss.Infection): def __init__(self, pars=None, **kwargs): super().__init__() self.default_pars( - beta = 0.5, + beta = 0.1, init_prev = ss.bernoulli(p=0.01), dur_inf = ss.lognorm_ex(mean=6), p_death = ss.bernoulli(p=0.01), @@ -49,6 +49,7 @@ def update_pre(self): def set_prognoses(self, uids, source_uids=None): """ Set prognoses """ + super().set_prognoses(uids, source_uids) ti = self.sim.ti dt = self.sim.dt self.susceptible[uids] = False @@ -67,7 +68,6 @@ def set_prognoses(self, uids, source_uids=None): rec_uids = uids[~will_die] self.ti_dead[dead_uids] = ti + dur_inf[will_die] / dt # Consider rand round, but not CRN safe self.ti_recovered[rec_uids] = ti + dur_inf[~will_die] / dt - return def update_death(self, uids): @@ -127,6 +127,7 @@ def update_immunity(self): def set_prognoses(self, uids, source_uids=None): """ Set prognoses """ + super().set_prognoses(uids, source_uids) self.susceptible[uids] = False self.infected[uids] = True self.ti_infected[uids] = self.sim.ti @@ -167,14 +168,30 @@ def plot(self): class sir_vaccine(ss.Vx): """ - Create a vaccine product that changes susceptible people to recovered (i.e., perfect immunity) + Create a vaccine product that affects the probability of infection. + + The vaccine can be either "leaky", in which everyone who receives the vaccine + receives the same amount of protection (specified by the efficacy parameter) + each time they are exposed to an infection. The alternative (leaky=False) is + that the efficacy is the probability that the vaccine "takes", in which case + that person is 100% protected (and the remaining people are 0% protected). + + Args: + efficacy (float): efficacy of the vaccine (0<=efficacy<=1) + leaky (bool): see above """ def __init__(self, pars=None, *args, **kwargs): super().__init__() - self.default_pars(efficacy=0.9) + self.default_pars( + efficacy = 0.9, + leaky = True + ) self.update_pars(pars, **kwargs) return - def administer(self, people, uids): - people.sir.rel_sus[uids] *= 1-self.pars.efficacy + def administer(self, people, uids): + if self.pars.leaky: + people.sir.rel_sus[uids] *= 1-self.pars.efficacy + else: + people.sir.rel_sus[uids] *= np.random.binomial(1, 1-self.pars.efficacy, len(uids)) return diff --git a/starsim/diseases/syphilis.py b/starsim/diseases/syphilis.py index 27bc1ab4..b75b80b7 100644 --- a/starsim/diseases/syphilis.py +++ b/starsim/diseases/syphilis.py @@ -195,6 +195,7 @@ def set_prognoses(self, uids, source_uids=None): """ Set initial prognoses for adults newly infected with syphilis """ + super().set_prognoses(uids, source_uids) ti = self.sim.ti dt = self.sim.dt @@ -256,7 +257,7 @@ def set_latent_long_prognoses(self, uids): return - def set_congenital(self, target_uids, source_uids=None): + def set_congenital(self, uids, source_uids=None): """ Natural history of syphilis for congenital infection """ sim = self.sim @@ -264,18 +265,18 @@ def set_congenital(self, target_uids, source_uids=None): for state in ['active', 'latent']: source_state_inds = getattr(self, state)[source_uids].nonzero()[0] - uids = target_uids[source_state_inds] + state_uids = uids[source_state_inds] - if len(uids) > 0: + if len(state_uids) > 0: # Birth outcomes must be modified to add probability of susceptible birth birth_outcomes = self.pars.birth_outcomes[state] - assigned_outcomes = birth_outcomes.rvs(len(uids)) + assigned_outcomes = birth_outcomes.rvs(len(state_uids)) time_to_birth = -sim.people.age.raw # TODO: make nicer # Schedule events for oi, outcome in enumerate(self.pars.birth_outcome_keys): - o_uids = uids[assigned_outcomes == oi] + o_uids = state_uids[assigned_outcomes == oi] if len(o_uids) > 0: ti_outcome = f'ti_{outcome}' vals = getattr(self, ti_outcome) @@ -344,8 +345,8 @@ def check_eligibility(self, sim): is_eligible = sim.people.auids # Probably not required return is_eligible - def initialize(self, sim): - super().initialize(sim) + def init_pre(self, sim): + super().init_pre(sim) self.results += [ ss.Result('syphilis', 'n_screened', sim.npts, dtype=int, scale=True), ss.Result('syphilis', 'n_dx', sim.npts, dtype=int, scale=True), @@ -368,8 +369,8 @@ def _parse_product_str(self, product): else: return products[product] - def initialize(self, sim): - super().initialize(sim) + def init_pre(self, sim): + super().init_pre(sim) self.results += ss.Result('syphilis', 'n_tx', sim.npts, dtype=int, scale=True) return diff --git a/starsim/distributions.py b/starsim/distributions.py index 41bc0d7b..2d65c0aa 100644 --- a/starsim/distributions.py +++ b/starsim/distributions.py @@ -450,8 +450,8 @@ def process_pars(self, call=True): """ Ensure the supplied dist and parameters are valid, and initialize them; called automatically """ self._pars = sc.cp(self.pars) # The actual keywords; shallow copy, modified below for special cases if call: - self.call_pars() - spars = self.sync_pars() + self.call_pars() # Convert from function to values if needed + spars = self.sync_pars() # Synchronize parameters between the NumPy and SciPy distributions return spars def call_pars(self): @@ -587,7 +587,7 @@ def plot_hist(self, n=1000, bins=None, fig_kw=None, hist_kw=None): # Add common distributions so they can be imported directly; assigned to a variable since used in help messages dist_list = ['random', 'uniform', 'normal', 'lognorm_ex', 'lognorm_im', 'expon', - 'poisson', 'weibull', 'delta', 'randint', 'bernoulli', 'choice'] + 'poisson', 'weibull', 'constant', 'randint', 'bernoulli', 'choice'] __all__ += dist_list @@ -668,10 +668,11 @@ def convert_ex_to_im(self): parameters of the underlying (implicit) distribution, which are the form expected by NumPy's and SciPy's lognorm() distributions. """ + self.call_pars() # Since can't work with functions p = self._pars mean = p.pop('mean') stdev = p.pop('stdev') - if mean <= 0: + if np.isscalar(mean) and mean <= 0: errormsg = f'Cannot create a lognorm_ex distribution with mean≤0 (mean={mean}); did you mean to use lognorm_im instead?' raise ValueError(errormsg) std2 = stdev**2 @@ -728,10 +729,10 @@ def make_rvs(self): return rvs -class delta(Dist): - """ Delta distribution: equivalent to np.full() """ +class constant(Dist): + """ Constant (delta) distribution: equivalent to np.full() """ def __init__(self, v=0, **kwargs): - super().__init__(distname='delta', v=v, **kwargs) + super().__init__(distname='const', v=v, **kwargs) return def make_rvs(self): diff --git a/starsim/interventions.py b/starsim/interventions.py index eb14b253..4c405002 100644 --- a/starsim/interventions.py +++ b/starsim/interventions.py @@ -18,8 +18,8 @@ def __init__(self, *args, **kwargs): def __call__(self, *args, **kwargs): return self.apply(*args, **kwargs) - def initialize(self, sim): - return super().initialize(sim) + def init_pre(self, sim): + return super().init_pre(sim) def apply(self, sim): pass @@ -115,7 +115,7 @@ def __init__(self, years=None, start_year=None, end_year=None, prob=None, annual self.coverage_dist = ss.bernoulli(p=0) # Placeholder - initialize delivery return - def initialize(self, sim): + def init_pre(self, sim): # Validate inputs if (self.years is not None) and (self.start_year is not None or self.end_year is not None): @@ -173,7 +173,7 @@ def __init__(self, years, interpolate=None, prob=None, *args, **kwargs): self.prob = sc.promotetoarray(prob) return - def initialize(self, sim): + def init_pre(self, sim): # Decide whether to apply the intervention at every timepoint throughout the year, or just once. self.timepoints = sc.findnearest(sim.yearvec, self.years) @@ -213,8 +213,8 @@ def __init__(self, product=None, prob=None, eligibility=None, **kwargs): self.ti_screened = ss.FloatArr('ti_screened') return - def initialize(self, sim): - Intervention.initialize(self, sim) + def init_pre(self, sim): + Intervention.init_pre(self, sim) self.outcomes = {k: np.array([], dtype=int) for k in self.product.hierarchy} return @@ -296,9 +296,9 @@ def __init__(self, product=None, prob=None, eligibility=None, RoutineDelivery.__init__(self, prob=prob, start_year=start_year, end_year=end_year, years=years) return - def initialize(self, sim): - RoutineDelivery.initialize(self, sim) # Initialize this first, as it ensures that prob is interpolated properly - BaseScreening.initialize(self, sim) # Initialize this next + def init_pre(self, sim): + RoutineDelivery.init_pre(self, sim) # Initialize this first, as it ensures that prob is interpolated properly + BaseScreening.init_pre(self, sim) # Initialize this next return @@ -319,9 +319,9 @@ def __init__(self, product=None, sex=None, eligibility=None, CampaignDelivery.__init__(self, prob=prob, years=years, interpolate=interpolate) return - def initialize(self, sim): - CampaignDelivery.initialize(self, sim) - BaseScreening.initialize(self, sim) # Initialize this next + def init_pre(self, sim): + CampaignDelivery.init_pre(self, sim) + BaseScreening.init_pre(self, sim) # Initialize this next return @@ -343,9 +343,9 @@ def __init__(self, product=None, prob=None, eligibility=None, annual_prob=annual_prob) return - def initialize(self, sim): - RoutineDelivery.initialize(self, sim) # Initialize this first, as it ensures that prob is interpolated properly - BaseTriage.initialize(self, sim) # Initialize this next + def init_pre(self, sim): + RoutineDelivery.init_pre(self, sim) # Initialize this first, as it ensures that prob is interpolated properly + BaseTriage.init_pre(self, sim) # Initialize this next return @@ -366,9 +366,9 @@ def __init__(self, product=None, sex=None, eligibility=None, CampaignDelivery.__init__(self, prob=prob, years=years, interpolate=interpolate, annual_prob=annual_prob) return - def initialize(self, sim): - CampaignDelivery.initialize(self, sim) - BaseTriage.initialize(self, sim) + def init_pre(self, sim): + CampaignDelivery.init_pre(self, sim) + BaseTriage.init_pre(self, sim) return @@ -394,8 +394,8 @@ def __init__(self, product=None, prob=None, eligibility=None, **kwargs): self.coverage_dist = ss.bernoulli(p=0) # Placeholder return - def initialize(self, sim): - Intervention.initialize(self, sim) + def init_pre(self, sim): + Intervention.init_pre(self, sim) self.outcomes = {k: np.array([], dtype=int) for k in ['unsuccessful', 'successful']} # Store outcomes on each timestep return @@ -536,9 +536,9 @@ def __init__(self, product=None, prob=None, eligibility=None, RoutineDelivery.__init__(self, prob=prob, start_year=start_year, end_year=end_year, years=years) return - def initialize(self, sim): - RoutineDelivery.initialize(self, sim) # Initialize this first, as it ensures that prob is interpolated properly - BaseVaccination.initialize(self, sim) # Initialize this next + def init_pre(self, sim): + RoutineDelivery.init_pre(self, sim) # Initialize this first, as it ensures that prob is interpolated properly + BaseVaccination.init_pre(self, sim) # Initialize this next return @@ -555,7 +555,7 @@ def __init__(self, product=None, prob=None, eligibility=None, CampaignDelivery.__init__(self, prob=prob, years=years, interpolate=interpolate) return - def initialize(self, sim): - CampaignDelivery.initialize(self, sim) # Initialize this first, as it ensures that prob is interpolated properly - BaseVaccination.initialize(self, sim) # Initialize this next + def init_pre(self, sim): + CampaignDelivery.init_pre(self, sim) # Initialize this first, as it ensures that prob is interpolated properly + BaseVaccination.init_pre(self, sim) # Initialize this next return diff --git a/starsim/modules.py b/starsim/modules.py index 5d816c1e..b8e2a82b 100644 --- a/starsim/modules.py +++ b/starsim/modules.py @@ -104,13 +104,16 @@ def check_requires(self, sim): raise Exception(errormsg) return - def initialize(self, sim): + def init_pre(self, sim): """ Perform initialization steps This method is called once, as part of initializing a Sim. Note: after - initialization, initialized=False until init_vals() is called (which is after + initialization, initialized=False until init_post() is called (which is after distributions are initialized). + + Note: distributions cannot be used here because they aren't initialized + until after init_pre() is called. Use init_post() instead. """ self.check_requires(sim) self.sim = sim # Link back to the sim object @@ -120,8 +123,8 @@ def initialize(self, sim): sim.people.add_module(self) # Connect the states to the people return - def init_vals(self): - """ Initialize the values of the states; the last step of initialization """ + def init_post(self): + """ Initialize the values of the states, including calling distributions; the last step of initialization """ for state in self.states: if not state.initialized: state.init_vals() diff --git a/starsim/network.py b/starsim/network.py index d090fc3a..a2315831 100644 --- a/starsim/network.py +++ b/starsim/network.py @@ -75,13 +75,13 @@ def __init__(self, key_dict=None, prenatal=False, postnatal=False, name=None, la self.postnatal = postnatal # Postnatal connections are added at the time of delivery. Requires ss.Pregnancy() # Initialize the keys of the network - self.contacts = sc.objdict() + self.edges = sc.objdict() for key, dtype in self.meta.items(): - self.contacts[key] = np.empty((0,), dtype=dtype) + self.edges[key] = np.empty((0,), dtype=dtype) # Set data, if provided for key, value in kwargs.items(): - self.contacts[key] = np.array(value, dtype=self.meta.get(key)) # Overwrite dtype if supplied, else keep original + self.edges[key] = np.array(value, dtype=self.meta.get(key)) # Overwrite dtype if supplied, else keep original self.initialized = True # Define states using placeholder values @@ -91,24 +91,24 @@ def __init__(self, key_dict=None, prenatal=False, postnatal=False, name=None, la @property def p1(self): - return self.contacts['p1'] if 'p1' in self.contacts else None + return self.edges['p1'] if 'p1' in self.edges else None @property def p2(self): - return self.contacts['p2'] if 'p2' in self.contacts else None + return self.edges['p2'] if 'p2' in self.edges else None @property def beta(self): - return self.contacts['beta'] if 'beta' in self.contacts else None + return self.edges['beta'] if 'beta' in self.edges else None - def init_vals(self, add_pairs=True): - super().init_vals() + def init_post(self, add_pairs=True): + super().init_post() if add_pairs: self.add_pairs() return def __len__(self): try: - return len(self.contacts.p1) + return len(self.edges.p1) except: # pragma: no cover return 0 @@ -116,7 +116,7 @@ def __repr__(self, **kwargs): """ Convert to a dataframe for printing """ namestr = self.name labelstr = f'"{self.label}"' if self.label else '' - keys_str = ', '.join(self.contacts.keys()) + keys_str = ', '.join(self.edges.keys()) output = f'{namestr}({labelstr}, {keys_str})\n' # e.g. Network("r", p1, p2, beta) output += self.to_df().__repr__() return output @@ -130,12 +130,12 @@ def __contains__(self, item): Returns: True if person index appears in any interactions """ - return (item in self.contacts.p1) or (item in self.contacts.p2) # TODO: chek if (item in self.members) is faster + return (item in self.edges.p1) or (item in self.edges.p2) # TODO: chek if (item in self.members) is faster @property def members(self): """ Return sorted array of all members """ - return np.unique([self.contacts.p1, self.contacts.p2]).view(ss.uids) + return np.unique([self.edges.p1, self.edges.p2]).view(ss.uids) def meta_keys(self): """ Return the keys for the network's meta information """ @@ -152,12 +152,12 @@ def set_network_states(self, people): def validate_uids(self): """ Ensure that p1, p2 are both UID arrays """ - contacts = self.contacts + edges = self.edges for key in ['p1', 'p2']: - if key in contacts: - arr = contacts[key] + if key in edges: + arr = edges[key] if not isinstance(arr, ss.uids): - self.contacts[key] = ss.uids(arr) + self.edges[key] = ss.uids(arr) return def validate(self, force=True): @@ -167,14 +167,14 @@ def validate(self, force=True): If dtype is incorrect, try to convert automatically; if length is incorrect, do not. """ - n = len(self.contacts.p1) + n = len(self.edges.p1) for key, dtype in self.meta.items(): if dtype: - actual = self.contacts[key].dtype + actual = self.edges[key].dtype expected = dtype if actual != expected: - self.contacts[key] = np.array(self.contacts[key], dtype=expected) # Try to convert to correct type - actual_n = len(self.contacts[key]) + self.edges[key] = np.array(self.edges[key], dtype=expected) # Try to convert to correct type + actual_n = len(self.edges[key]) if n != actual_n: errormsg = f'Expecting length {n} for network key "{key}"; got {actual_n}' # Report length mismatches raise TypeError(errormsg) @@ -190,9 +190,9 @@ def get_inds(self, inds, remove=False): """ output = {} for key in self.meta_keys(): - output[key] = self.contacts[key][inds] # Copy to the output object + output[key] = self.edges[key][inds] # Copy to the output object if remove: - self.contacts[key] = np.delete(self.contacts[key], inds) # Remove from the original + self.edges[key] = np.delete(self.edges[key], inds) # Remove from the original self.validate_uids() return output @@ -216,19 +216,19 @@ def append(self, contacts=None, **kwargs): """ contacts = sc.mergedicts(contacts, kwargs) for key in self.meta_keys(): - curr_arr = self.contacts[key] + curr_arr = self.edges[key] try: new_arr = contacts[key] except KeyError: errormsg = f'Cannot append contacts since required key "{key}" is missing' raise KeyError(errormsg) - self.contacts[key] = np.concatenate([curr_arr, new_arr]) # Resize to make room, preserving dtype + self.edges[key] = np.concatenate([curr_arr, new_arr]) # Resize to make room, preserving dtype self.validate_uids() return def to_dict(self): """ Convert to dictionary """ - d = {k: self.contacts[k] for k in self.meta_keys()} + d = {k: self.edges[k] for k in self.meta_keys()} return d def to_df(self): @@ -241,7 +241,7 @@ def from_df(self, df, keys=None): if keys is None: keys = self.meta_keys() for key in keys: - self.contacts[key] = df[key].to_numpy() + self.edges[key] = df[key].to_numpy() return self def find_contacts(self, inds, as_array=True): @@ -277,7 +277,7 @@ def find_contacts(self, inds, as_array=True): inds = np.array(inds, dtype=np.int64) # Find the contacts - contact_inds = ss.find_contacts(self.contacts.p1, self.contacts.p2, inds) + contact_inds = ss.find_contacts(self.edges.p1, self.edges.p2, inds) if as_array: contact_inds = np.fromiter(contact_inds, dtype=ss_int_) contact_inds.sort() @@ -298,15 +298,15 @@ def remove_uids(self, uids): This method is typically called via `People.remove()` and is specifically used when removing agents from the simulation. """ - keep = ~(np.isin(self.contacts.p1, uids) | np.isin(self.contacts.p2, uids)) + keep = ~(np.isin(self.edges.p1, uids) | np.isin(self.edges.p2, uids)) for k in self.meta_keys(): - self.contacts[k] = self.contacts[k][keep] + self.edges[k] = self.edges[k][keep] return def beta_per_dt(self, disease_beta=None, dt=None, uids=None): if uids is None: uids = Ellipsis - return self.contacts.beta[uids] * disease_beta * dt + return self.edges.beta[uids] * disease_beta * dt class DynamicNetwork(Network): @@ -318,12 +318,12 @@ def __init__(self, key_dict=None, **kwargs): def end_pairs(self): people = self.sim.people - self.contacts.dur = self.contacts.dur - self.sim.dt + self.edges.dur = self.edges.dur - self.sim.dt # Non-alive agents are removed - active = (self.contacts.dur > 0) & people.alive[self.contacts.p1] & people.alive[self.contacts.p2] + active = (self.edges.dur > 0) & people.alive[self.edges.p1] & people.alive[self.edges.p2] for k in self.meta_keys(): - self.contacts[k] = self.contacts[k][active] + self.edges[k] = self.edges[k][active] return len(active) @@ -355,7 +355,7 @@ def available(self, people, sex): def beta_per_dt(self, disease_beta=None, dt=None, uids=None): if uids is None: uids = Ellipsis - return self.contacts.beta[uids] * (1 - (1 - disease_beta) ** (self.contacts.acts[uids] * dt)) + return self.edges.beta[uids] * (1 - (1 - disease_beta) ** (self.edges.acts[uids] * dt)) # %% Specific instances of networks @@ -387,24 +387,23 @@ class StaticNet(Network): def __init__(self, graph=None, pars=None, **kwargs): super().__init__() self.graph = graph - self.default_pars(seed=True) + self.default_pars(seed=True, p=None, n_contacts=10) self.update_pars(pars, **kwargs) self.dist = ss.Dist(name='StaticNet') return - def initialize(self, sim): - super().initialize(sim) + def init_pre(self, sim): + super().init_pre(sim) self.n_agents = sim.pars.n_agents if self.graph is None: self.graph = nx.fast_gnp_random_graph # Fast random (Erdos-Renyi) graph creator - if 'p' not in self.pars and 'n_contacts' not in self.pars: # TODO: refactor - self.pars.n_contacts = 10 - if 'n_contacts' in self.pars: # Convert from n_contacts to probability - self.pars.p = self.pars.pop('n_contacts')/self.n_agents + n_contacts = self.pars.pop('n_contacts') # Remove from pars dict, but use only if p is not supplied + if self.pars.p is None: # Convert from n_contacts to probability + self.pars.p = n_contacts/self.n_agents return - def init_vals(self): - super().init_vals() + def init_post(self): + super().init_post() if 'seed' in self.pars and self.pars.seed is True: self.pars.seed = self.dist.rng if callable(self.graph): @@ -447,14 +446,14 @@ def __init__(self, pars=None, key_dict=None, **kwargs): """ Initialize """ super().__init__(key_dict=key_dict) self.default_pars( - n_contacts = ss.delta(10), + n_contacts = ss.constant(10), dur = 0, ) self.update_pars(pars, **kwargs) self.dist = ss.Dist(distname='RandomNet') # Default RNG return - def init_vals(self): + def init_post(self): self.add_pairs() return @@ -546,8 +545,8 @@ def __init__(self, n_people=None, **kwargs): super().__init__(**kwargs) return - def initialize(self, sim): - super().initialize(sim) + def init_pre(self, sim): + super().init_pre(sim) popsize = sim.pars['n_agents'] if self.n is None: self.n = popsize @@ -584,7 +583,7 @@ def __init__(self, pars=None, key_dict=None, **kwargs): self.dist = ss.choice(name='MFNet', replace=False) # Set the array later return - def init_vals(self): + def init_post(self): self.set_network_states() self.add_pairs() return @@ -671,11 +670,11 @@ def __init__(self, pars=None, key_dict=None, **kwargs): self.update_pars(pars, **kwargs) return - def initialize(self, sim): + def init_pre(self, sim): # Add more here in line with MF network, e.g. age of debut # Or if too much replication then perhaps both these networks # should be subclasss of a specific network type (ask LY/DK) - super().initialize(sim) + super().init_pre(sim) self.set_network_states(sim.people) self.add_pairs(sim.people, ti=0) return @@ -795,15 +794,15 @@ def update(self): Set beta to 0 for women who complete duration of transmission Keep connections for now, might want to consider removing """ - inactive = self.contacts.end <= self.sim.ti - self.contacts.beta[inactive] = 0 + inactive = self.edges.end <= self.sim.ti + self.edges.beta[inactive] = 0 return def end_pairs(self): people = self.sim.people - active = (self.contacts.end > self.sim.ti) & people.alive[self.contacts.p1] & people.alive[self.contacts.p2] + active = (self.edges.end > self.sim.ti) & people.alive[self.edges.p1] & people.alive[self.edges.p2] for k in self.meta_keys(): - self.contacts[k] = self.contacts[k][active] + self.edges[k] = self.edges[k][active] return len(active) def add_pairs(self, mother_inds=None, unborn_inds=None, dur=None, start=None): @@ -811,7 +810,8 @@ def add_pairs(self, mother_inds=None, unborn_inds=None, dur=None, start=None): if mother_inds is None: return 0 else: - if start is None: start = np.full_like(dur, fill_value=self.sim.ti) + if start is None: + start = np.full_like(dur, fill_value=self.sim.ti) n = len(mother_inds) beta = np.ones(n) end = start + sc.promotetoarray(dur) / self.sim.dt diff --git a/starsim/products.py b/starsim/products.py index f5e7a6b2..bcaea4f7 100644 --- a/starsim/products.py +++ b/starsim/products.py @@ -12,9 +12,9 @@ class Product(ss.Module): """ Generic product implementation """ - def initialize(self, sim): + def init_pre(self, sim): if not self.initialized: - super().initialize(sim) + super().init_pre(sim) else: return diff --git a/starsim/sim.py b/starsim/sim.py index 8771a69e..a717f7a2 100644 --- a/starsim/sim.py +++ b/starsim/sim.py @@ -68,12 +68,12 @@ def initialize(self, **kwargs): # Initialize all the modules with the sim for mod in self.modules: - mod.initialize(self) + mod.init_pre(self) # Initialize products # TODO: think about simplifying for mod in self.interventions: if hasattr(mod, 'product') and isinstance(mod.product, ss.Product): - mod.product.initialize(self) + mod.product.init_pre(self) # Initialize all distributions now that everything else is in place, then set states self.dists.initialize(obj=self, base_seed=self.pars.rand_seed, force=True) @@ -133,7 +133,7 @@ def init_vals(self): # Initialize values in other modules, including networks for mod in self.modules: - mod.init_vals() + mod.init_post() return def init_results(self): diff --git a/starsim/states.py b/starsim/states.py index 845a089e..4c864ec8 100644 --- a/starsim/states.py +++ b/starsim/states.py @@ -107,35 +107,26 @@ def _convert_key(self, key): the raw array (``raw``) or the active agents (``values``), and to convert the key to array indices if needed. """ - if isinstance(key, uids): - use_raw = True + if isinstance(key, (uids, int)): + return key elif isinstance(key, (BoolArr, IndexArr)): - use_raw = True - key = key.uids + return key.uids elif isinstance(key, slice): - use_raw = False + return self.auids[key] elif not np.isscalar(key) and len(key) == 0: # Handle [], np.array([]), etc. - use_raw = True # Doesn't matter since returning empty, but this is faster - key = uids() + return uids() else: errormsg = f'Indexing an Arr ({self.name}) by ({key}) is ambiguous or not supported. Use ss.uids() instead, or index Arr.raw or Arr.values.' raise Exception(errormsg) - - return key, use_raw def __getitem__(self, key): - key, use_raw = self._convert_key(key) - if use_raw: - return self.raw[key] - else: - return self.values[key] + key = self._convert_key(key) + return self.raw[key] def __setitem__(self, key, value): - key, use_raw = self._convert_key(key) - if use_raw: - self.raw[key] = value - else: - self.raw[self.auids[key]] = value + key = self._convert_key(key) + self.raw[key] = value + return def __getattr__(self, attr): """ Make it behave like a regular array mostly -- enables things like sum(), mean(), etc. """ diff --git a/starsim/version.py b/starsim/version.py index de39d231..11c9f8e5 100644 --- a/starsim/version.py +++ b/starsim/version.py @@ -4,6 +4,6 @@ __all__ = ['__version__', '__versiondate__', '__license__'] -__version__ = '0.5.1' -__versiondate__ = '2024-05-15' +__version__ = '0.5.2' +__versiondate__ = '2024-06-04' __license__ = f'Starsim {__version__} ({__versiondate__}) — © 2023-2024 by IDM' diff --git a/tests/baseline.json b/tests/baseline.json index 04479f95..43bc42f1 100644 --- a/tests/baseline.json +++ b/tests/baseline.json @@ -1,18 +1,26 @@ { "summary": { "yearvec": 2010.1000000000026, - "pregnancy_pregnancies": 0.0, - "pregnancy_births": 0.0, - "pregnancy_cbr": 0.0, - "hiv_n_susceptible": 8867.64705882353, - "hiv_n_infected": 794.6078431372549, - "hiv_n_on_art": 0.0, - "hiv_prevalence": 0.0823032321868553, - "hiv_new_infections": 15.323529411764707, - "hiv_cum_infections": 1563.0, - "hiv_new_deaths": 6.931372549019608, - "n_alive": 9662.254901960785, - "new_deaths": 6.931372549019608, - "cum_deaths": 692.0 + "births_new": 65.25490196078431, + "births_cumulative": 3262.4313725490197, + "births_cbr": 29.713410489238804, + "deaths_new": 44.68627450980392, + "deaths_cumulative": 2227.549019607843, + "deaths_cmr": 20.34315186671041, + "sir_n_susceptible": 2429.970588235294, + "sir_n_infected": 3579.8823529411766, + "sir_n_recovered": 4970.764705882353, + "sir_prevalence": 0.334889084274871, + "sir_new_infections": 128.2156862745098, + "sir_cum_infections": 13078.0, + "sis_n_susceptible": 3951.1960784313724, + "sis_n_infected": 7029.421568627451, + "sis_prevalence": 0.627871954232106, + "sis_new_infections": 209.7843137254902, + "sis_cum_infections": 21398.0, + "sis_rel_sus": 0.46047714966479397, + "n_alive": 10980.617647058823, + "new_deaths": 45.705882352941174, + "cum_deaths": 4625.0 } } \ No newline at end of file diff --git a/tests/benchmark.json b/tests/benchmark.json index 8078aac1..aefe3c26 100644 --- a/tests/benchmark.json +++ b/tests/benchmark.json @@ -1,13 +1,12 @@ { "time": { - "people": 0.002, - "initialize": 0.014, - "run": 0.358 + "initialize": 0.015, + "run": 1.082 }, "parameters": { "n_agents": 10000, "n_years": 20, "dt": 0.2 }, - "cpu_performance": 0.8271107100782993 + "cpu_performance": 0.9927627642937098 } \ No newline at end of file diff --git a/tests/devtests/devtest_art_impact_viz.py b/tests/devtests/devtest_art_impact_viz.py index b3c46ddf..fd38f3be 100644 --- a/tests/devtests/devtest_art_impact_viz.py +++ b/tests/devtests/devtest_art_impact_viz.py @@ -40,12 +40,12 @@ def __init__(self, **kwargs): super().__init__(**kwargs) return - def initialize(self, sim): + def init_pre(self, sim): n = len(sim.people._uid_map) n_edges = n//2 - self.contacts.p1 = np.arange(0, 2*n_edges, 2) # EVEN - self.contacts.p2 = np.arange(1, 2*n_edges, 2) # ODD - self.contacts.beta = np.ones(n_edges) + self.edges.p1 = np.arange(0, 2*n_edges, 2) # EVEN + self.edges.p2 = np.arange(1, 2*n_edges, 2) # ODD + self.edges.beta = np.ones(n_edges) return @@ -90,7 +90,7 @@ def __init__(self, **kwargs): self.graphs = {} return - def initialize(self, sim): + def init_pre(self, sim): self.initialized = True self.update_results(sim, init=True) return diff --git a/tests/devtests/devtest_remove_people.py b/tests/devtests/devtest_remove_people.py index 93e9d120..93d876b2 100644 --- a/tests/devtests/devtest_remove_people.py +++ b/tests/devtests/devtest_remove_people.py @@ -10,8 +10,8 @@ import sciris as sc class agent_analyzer(ss.Analyzer): - def initialize(self, sim): - super().initialize(sim) + def init_pre(self, sim): + super().init_pre(sim) self.n_agents = np.zeros(sim.npts) def update_results(self, sim): diff --git a/tests/test_baselines.py b/tests/test_baselines.py index 66796a57..45f8714c 100644 --- a/tests/test_baselines.py +++ b/tests/test_baselines.py @@ -7,8 +7,6 @@ import sciris as sc import starsim as ss -do_plot = True -do_save = False baseline_filename = sc.thisdir(__file__, 'baseline.json') benchmark_filename = sc.thisdir(__file__, 'benchmark.json') parameters_filename = sc.thisdir(ss.__file__, 'regression', f'pars_v{ss.__version__}.json') @@ -16,38 +14,26 @@ # Define the parameters pars = sc.objdict( - start = 2000, # Starting year - n_years = 20, # Number of years to simulate - dt = 0.2, # Timestep - verbose = 0, # Don't print details of the run - rand_seed = 2, # Set a non-default seed + n_agents = 10e3, # Number of agents + start = 2000, # Starting year + n_years = 20, # Number of years to simulate + dt = 0.2, # Timestep + verbose = 0, # Don't print details of the run + rand_seed = 2, # Set a non-default seed ) -def make_people(): - ss.set_seed(pars.rand_seed) - n_agents = int(10e3) - ppl = ss.People(n_agents=n_agents) - return ppl - -def make_sim(ppl=None, do_run=False, **kwargs): - ''' - Define a default simulation for testing the baseline, including - interventions to increase coverage. If run directly (not via pytest), also - plot the sim by default. - ''' - - if ppl is None: - ppl = make_people() - - # Make the sim - hiv = ss.HIV() - hiv.pars.beta = {'mf': [0.15, 0.10], 'maternal': [0.2, 0]} - networks = [ss.MFNet(), ss.MaternalNet()] - sim = ss.Sim(pars=pars, people=ppl, networks=networks, demographics=ss.Pregnancy(), diseases=hiv) +def make_sim(run=False): + """ + Define a default simulation for testing the baseline. If run directly (not + via pytest), also plot the sim by default. + """ + diseases = ['sir', 'sis'] + networks = ['random', 'mf', 'maternal'] + sim = ss.Sim(pars=pars, networks=networks, diseases=diseases, demographics=True) # Optionally run and plot - if do_run: + if run: sim.run() sim.plot() @@ -55,29 +41,28 @@ def make_sim(ppl=None, do_run=False, **kwargs): def save_baseline(): - ''' + """ Refresh the baseline results. This function is not called during standard testing, but instead is called by the update_baseline script. - ''' + """ + sc.heading('Updating baseline values...') - print('Updating baseline values...') - - # Export default parameters - s1 = make_sim(use_defaults=True) - s1.export_pars(filename=parameters_filename) # If not different from previous version, can safely delete + # Make and run sim + sim = make_sim() + sim.run() # Export results - s2 = make_sim(use_defaults=False) - s2.run() - s2.to_json(filename=baseline_filename, keys='summary') + sim.to_json(filename=baseline_filename, keys='summary') + + # CK: To restore once export_pars is fixed + # sim.export_pars(filename=parameters_filename) # If not different from previous version, can safely delete print('Done.') - return def test_baseline(): - ''' Compare the current default sim against the saved baseline ''' + """ Compare the current default sim against the saved baseline """ # Load existing baseline baseline = sc.loadjson(baseline_filename) @@ -93,8 +78,8 @@ def test_baseline(): return new -def test_benchmark(do_save=do_save, repeats=1, verbose=True): - ''' Compare benchmark performance ''' +def test_benchmark(do_save=False, repeats=1, verbose=True): + """ Compare benchmark performance """ if verbose: print('Running benchmark...') try: @@ -102,12 +87,11 @@ def test_benchmark(do_save=do_save, repeats=1, verbose=True): except FileNotFoundError: previous = None - t_peoples = [] - t_inits = [] - t_runs = [] + t_inits = [] + t_runs = [] def normalize_performance(): - ''' Normalize performance across CPUs ''' + """ Normalize performance across CPUs """ t_bls = [] bl_repeats = 3 n_outer = 10 @@ -132,16 +116,11 @@ def normalize_performance(): # Do the actual benchmarking for r in range(repeats): - print("Repeat ", r) - - # Time people - t0 = sc.tic() - ppl = make_people() - t_people = sc.toc(t0, output=True) + print(f'Repeat {r}') # Time initialization t0 = sc.tic() - sim = make_sim(ppl, verbose=0) + sim = make_sim() sim.initialize() t_init = sc.toc(t0, output=True) @@ -151,30 +130,25 @@ def normalize_performance(): t_run = sc.toc(t0, output=True) # Store results - t_peoples.append(t_people) t_inits.append(t_init) t_runs.append(t_run) - - # print(t_people, t_init, t_run) # Test CPU performance after the run r2 = normalize_performance() ratio = (r1+r2)/2 - t_people = min(t_peoples)*ratio - t_init = min(t_inits)*ratio - t_run = min(t_runs)*ratio + t_init = ratio*min(t_inits) + t_run = ratio*min(t_runs) # Construct json n_decimals = 3 json = {'time': { - 'people': round(t_people, n_decimals), 'initialize': round(t_init, n_decimals), 'run': round(t_run, n_decimals), }, 'parameters': { - 'n_agents': sim.pars['n_agents'], - 'n_years': sim.pars['n_years'], - 'dt': sim.pars['dt'], + 'n_agents': sim.pars.n_agents, + 'n_years': sim.pars.n_years, + 'dt': sim.pars.dt, }, 'cpu_performance': ratio, } @@ -200,17 +174,13 @@ def normalize_performance(): return json - if __name__ == '__main__': - - # Start timing and optionally enable interactive plotting + do_plot = True sc.options(interactive=do_plot) - T = sc.tic() + T = sc.timer() - json = test_benchmark(do_save=do_save, repeats=5) # Run this first so benchmarking is available even if results are different + json = test_benchmark() # Run this first so benchmarking is available even if results are different new = test_baseline() - sim = make_sim(do_run=do_plot) + sim = make_sim(run=do_plot) - print('\n'*2) - sc.toc(T) - print('Done.') + T.toc() diff --git a/tests/test_diseases.py b/tests/test_diseases.py index 429c0052..c88579fc 100644 --- a/tests/test_diseases.py +++ b/tests/test_diseases.py @@ -37,11 +37,6 @@ def test_sir(): sim = ss.Sim(people=ppl, diseases=sir, networks=networks) sim.run() - # CK: parameters changed - # assert len(sir.log.out_edges(np.nan)) == sir.pars.initial # Log should match initial infections - df = sir.log.line_list # Check generation of line-list # TODO: fix - # assert df.source.isna().sum() == sir.pars.initial # Check seed infections in line list - plt.figure() res = sim.results plt.stackplot( diff --git a/tests/test_dist.py b/tests/test_dist.py index c9a5d5ab..eaedd018 100644 --- a/tests/test_dist.py +++ b/tests/test_dist.py @@ -84,7 +84,7 @@ def test_dists(n=n, do_plot=False): obj = sc.prettyobj() obj.a = sc.objdict() obj.a.mylist = [ss.random(), ss.Dist(distname='uniform', low=2, high=3)] - obj.b = dict(d3=ss.weibull(c=2), d4=ss.delta(v=0.3)) + obj.b = dict(d3=ss.weibull(c=2), d4=ss.constant(v=0.3)) dists = ss.Dists(obj).initialize(sim=make_sim()) # Call each distribution twice @@ -200,16 +200,19 @@ def custom_loc(module, sim, uids): return out scale = 1 - d = ss.normal(name='callable', loc=custom_loc, scale=scale).initialize(sim=sim) + d1 = ss.normal(name='callable', loc=custom_loc, scale=scale).initialize(sim=sim) + d2 = ss.lognorm_ex(name='callable', mean=custom_loc, stdev=scale).initialize(sim=sim) uids = np.array([1, 3, 7, 9]) - draws = d.rvs(uids) + draws1 = d1.rvs(uids) + draws2 = d2.rvs(uids) print(f'Input ages were: {sim.people.age[uids]}') - print(f'Output samples were: {draws}') + print(f'Output samples were: {draws1}, {draws2}') - meandiff = np.abs(sim.people.age[uids] - draws).mean() - assert meandiff < scale*3 - return d + for draws in [draws1, draws2]: + meandiff = np.abs(sim.people.age[uids] - draws).mean() + assert meandiff < scale*3, 'Outputs should match ages' + return d1 def test_array(n=n): diff --git a/tests/test_interventions.py b/tests/test_interventions.py new file mode 100644 index 00000000..1f7f05e7 --- /dev/null +++ b/tests/test_interventions.py @@ -0,0 +1,103 @@ +""" +Run tests of vaccines +""" + +# %% Imports and settings +import sciris as sc +import numpy as np +import starsim as ss + + +def run_sir_vaccine(efficacy, leaky=True): + # parameters + v_frac = 0.5 # fraction of population vaccinated + total_cases = 500 # total cases at which point we check results + tol = 3 # tolerance in standard deviations for simulated checks + + # create a basic SIR sim + sim = ss.Sim( + n_agents = 1000, + pars = dict( + networks = dict( + type = 'random', + n_contacts = 4 + ), + diseases = dict( + type = 'sir', + init_prev = 0.01, + dur_inf = 0.1, + p_death = 0, + beta = 6, + ) + ), + n_years = 10, + dt = 0.01 + ) + sim.initialize(verbose=False) + + # work out who to vaccinate + in_trial = sim.people.sir.susceptible.uids + n_vac = round(len(in_trial) * v_frac) + in_vac = np.random.choice(in_trial, n_vac, replace=False) + in_pla = np.setdiff1d(in_trial, in_vac) + uids = ss.uids(in_vac) + + # create and apply the vaccination + vac = ss.sir_vaccine(efficacy=efficacy, leaky=leaky) + vac.init_pre(sim) + vac.administer(sim.people, uids) + + # check the relative susceptibility at the start of the simulation + rel_susc = sim.people.sir.rel_sus.values + assert min(rel_susc[in_pla]) == 1.0, 'Placebo arm is not fully susceptible' + if not leaky: + assert min(rel_susc[in_vac]) == 0.0, 'Nobody fully vaccinated (all_or_nothing)' + assert max(rel_susc[in_vac]) == 1.0, 'Vaccine effective in everyone (all_or_nothing)' + mean = n_vac * (1 - efficacy) + sd = np.sqrt(n_vac * efficacy * (1 - efficacy)) + assert (np.mean(rel_susc[in_vac]) - mean) / sd < tol, 'Incorrect mean susceptibility in vaccinated (all_or_nothing)' + else: + assert max(abs(rel_susc[in_vac] - (1 - efficacy))) < 0.0001, 'Relative susceptibility not 1-efficacy (leaky)' + + # run the simulation until sufficient cases + old_cases = [] + for idx in range(1000): + sim.step() + susc = sim.people.sir.susceptible.uids + cases = np.setdiff1d(in_trial, susc) + if len(cases) > total_cases: + break + old_cases = cases + + if len(cases) > total_cases: + cases = np.concatenate([old_cases, np.random.choice(np.setdiff1d(cases, old_cases), total_cases - len(old_cases), replace=False)]) + vac_cases = np.intersect1d(cases, in_vac) + + # check to see whether the number of cases are as expected + p = v_frac * (1 - efficacy) / (1 - efficacy * v_frac) + mean = total_cases * p + sd = np.sqrt(total_cases * p * (1 - p)) + assert (len(vac_cases) - mean) / sd < tol, 'Incorrect proportion of vaccincated infected' + + # for all or nothing check that fully vaccinated did not get infected + if not leaky: + assert len(np.intersect1d(vac_cases, in_vac[rel_susc[in_vac] == 1.0])) == len(vac_cases), 'Not all vaccine cases amongst vaccine failures (all or nothing)' + assert len(np.intersect1d(vac_cases, in_vac[rel_susc[in_vac] == 0.0])) == 0, 'Vaccine cases amongst fully vaccincated (all or nothing)' + + return sim + + +def test_sir_vaccine_leaky(): + return run_sir_vaccine(0.3, False) + +def test_sir_vaccine_all_or_nothing(): + return run_sir_vaccine(0.3, True) + + +if __name__ == '__main__': + T = sc.timer() + + sir_vaccine_leaky = test_sir_vaccine_leaky(leaky=True) + sir_vaccine_a_or_n = test_sir_vaccine_all_or_nothing(leaky=False) + + T.toc() diff --git a/tests/test_networks.py b/tests/test_networks.py new file mode 100644 index 00000000..085866e9 --- /dev/null +++ b/tests/test_networks.py @@ -0,0 +1,103 @@ +""" +Test networks +""" + +# %% Imports and settings +import sciris as sc +import numpy as np +import starsim as ss + +sc.options(interactive=False) # Assume not running interactively + +small = 100 +medium = 1000 + +# %% Define the tests + +def test_manual(): + sc.heading('Testing manual networks') + + # Make completely abstract layers + n_edges = 10_000 + n_agents = medium + p1 = np.random.randint(n_agents, size=n_edges) + p2 = np.random.randint(n_agents, size=n_edges) + beta = np.ones(n_edges) + nw1 = ss.Network(p1=p1, p2=p2, beta=beta, label='rand') + + # Create a maternal network + sim = ss.Sim(n_agents=n_agents) + sim.initialize() + nw2 = ss.MaternalNet() + nw2.init_pre(sim) + nw2.add_pairs(mother_inds=[1, 2, 3], unborn_inds=[100, 101, 102], dur=[1, 1, 1]) + + # Tidy + o = sc.objdict(nw1=nw1, nw2=nw2) + return o + + +def test_random(): + sc.heading('Testing random networks') + + # Manual creation + nw1 = ss.RandomNet() + ss.Sim(n_agents=small, networks=nw1, copy_inputs=False).initialize() # This initializes the network + + # Automatic creation as part of sim + s2 = ss.Sim(n_agents=small, networks='random').initialize() + nw2 = s2.networks[0] + + # Increase the number of contacts + nwdict = dict(type='random', n_contacts=20) + s3 = ss.Sim(n_agents=small, networks=nwdict).initialize() + nw3 = s3.networks[0] + + # Checks + assert np.array_equal(nw1.p2, nw2.p2), 'Implicit and explicit creation should give the same network' + assert len(nw3) == len(nw2)*2, 'Doubling n_contacts should produce twice as many contacts' + + # Tidy + o = sc.objdict(nw1=nw1, nw2=nw2, nw3=nw3) + return o + + +def test_static(): + sc.heading('Testing static networks') + + # Create with p + p = 0.2 + n = 100 + nc = p*n + nd1 = dict(type='static', p=p) + nw1 = ss.Sim(n_agents=n, networks=nd1).initialize().networks[0] + + # Create with n_contacts + nd2 = dict(type='static', n_contacts=nc) + nw2 = ss.Sim(n_agents=n, networks=nd2).initialize().networks[0] + + # Check + assert len(nw1) == len(nw2), 'Networks should be the same length' + target = n*n*p/2 + assert target/2 < len(nw1) < target*2, f'Network should be approximately length {target}' + + # Tidy + o = sc.objdict(nw1=nw1, nw2=nw2) + return o + + +# %% Run as a script +if __name__ == '__main__': + do_plot = True + sc.options(interactive=do_plot) + + # Start timing + T = sc.tic() + + # Run tests + man = test_manual() + rnd = test_random() + sta = test_static() + + sc.toc(T) + print('Done.') \ No newline at end of file diff --git a/tests/test_other.py b/tests/test_other.py index a9f0cbe5..06f673a6 100644 --- a/tests/test_other.py +++ b/tests/test_other.py @@ -1,5 +1,5 @@ """ -Test objects from base.py +Test Starsim features not covered by other test files """ # %% Imports and settings @@ -37,27 +37,6 @@ def geo_func(n): return ppl -def test_networks(): - sc.heading('Testing networks') - - # Make completely abstract layers - n_edges = 10_000 - n_people = medium - p1 = np.random.randint(n_people, size=n_edges) - p2 = np.random.randint(n_people, size=n_edges) - beta = np.ones(n_edges) - nw1 = ss.Network(p1=p1, p2=p2, beta=beta, label='rand') - - sim = ss.Sim() - sim.initialize() - - nw2 = ss.MaternalNet() - nw2.initialize(sim) - nw2.add_pairs(mother_inds=[1, 2, 3], unborn_inds=[100, 101, 102], dur=[1, 1, 1]) - - return nw1, nw2 - - def test_microsim(do_plot=False): sc.heading('Test small HIV simulation') @@ -183,7 +162,6 @@ def test_deepcopy_until(): # Run tests ppl = test_people() - nw1, nw2 = test_networks() sim1 = test_microsim(do_plot) sim2 = test_ppl_construction() sims = test_arrs() diff --git a/tests/test_randomness.py b/tests/test_randomness.py index 38f9b465..2b38f672 100644 --- a/tests/test_randomness.py +++ b/tests/test_randomness.py @@ -110,7 +110,7 @@ def test_order(n=n): class CountInf(ss.Intervention): """ Store every infection state in a timepoints x people array """ - def initialize(self, sim): + def init_pre(self, sim): n_agents = len(sim.people) self.arr = np.zeros((sim.npts, n_agents)) self.n_agents = n_agents @@ -123,9 +123,9 @@ def apply(self, sim): class OneMore(ss.Intervention): """ Add one additional agent and infection """ - def initialize(self, sim): + def init_pre(self, sim): one_birth = ss.Pregnancy(name='one_birth', rel_fertility=0) # Ensure no default births - one_birth.initialize(sim) + one_birth.init_pre(sim) self.one_birth = one_birth return @@ -238,7 +238,7 @@ def test_independence(do_plot=False, thresh=0.1): ], networks = [ dict(type='random', n_contacts=ss.poisson(8)), - dict(type='mf', debut=ss.delta(0), participation=0.5), # To avoid age correlations + dict(type='mf', debut=ss.constant(0), participation=0.5), # To avoid age correlations ] ) sim.initialize() @@ -253,7 +253,7 @@ def test_independence(do_plot=False, thresh=0.1): for key,network in sim.networks.items(): data = np.zeros(len(sim.people)) for p in ['p1', 'p2']: - for uid in network.contacts[p]: + for uid in network.edges[p]: data[uid] += 1 # Could also use a histogram arrs[key] = data diff --git a/tests/test_sim.py b/tests/test_sim.py index afe7bebe..ae5d04cf 100644 --- a/tests/test_sim.py +++ b/tests/test_sim.py @@ -1,5 +1,5 @@ """ -Test simple APIs +Test Sim API """ # %% Imports and settings