Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Module exports #332

Closed
wants to merge 8 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 43 additions & 2 deletions starsim/modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,15 @@
import sciris as sc
import starsim as ss
from scipy.stats._distn_infrastructure import rv_frozen
from inspect import signature, _empty
import inspect
import numpy as np

__all__ = ['Module']

class Module(sc.prettyobj):

def __init__(self, pars=None, par_dists=None, name=None, label=None, requires=None, **kwargs):
self._store_args() # Store the input arguments so the module can be recreated
self.pars = ss.omerge(pars, kwargs)
self.par_dists = ss.omerge(par_dists)
self.name = name if name else self.__class__.__name__.lower() # Default name is the class name
Expand Down Expand Up @@ -70,7 +71,7 @@ def initialize(self, sim):

# Otherwise, figure out the required arguments and assume the user is trying to set them
else:
rqrd_args = [x for x, p in signature(par_dist._parse_args).parameters.items() if p.default == _empty]
rqrd_args = [x for x, p in inspect.signature(par_dist._parse_args).parameters.items() if p.default == inspect._empty]
if len(rqrd_args) != 0:
par_dist_arg = rqrd_args[0]
else:
Expand Down Expand Up @@ -170,3 +171,43 @@ def create(cls, name, *args, **kwargs):
return subcls(*args, **kwargs)
else:
raise KeyError(f'Module "{name}" did not match any known Starsim Modules')

def _store_args(self):
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

adapting logic from HPVsim for exporting interventions, but perhaps there would be a nicer way?

"""
Store user-supplied arguments for later use in to_json
"""
f0 = inspect.currentframe()
f1 = inspect.getouterframes(f0)
if self.__class__.__init__ is Module.__init__:
parent = f1[1].frame
else:
parent = f1[2].frame
_,_,_,values = inspect.getargvalues(parent) # Get the values of the arguments
if values:
self.input_args = {}
for key,value in values.items():
if key == 'kwargs': # Store additional kwargs directly
for k2,v2 in value.items(): # pragma: no cover
self.input_args[k2] = v2 # These are already a dict
elif key not in ['self', '__class__']: # Everything else, but skip these
self.input_args[key] = value
return

def to_json(self):
'''
Return JSON-compatible representation

Custom classes can't be directly represented in JSON. This method is a
one-way export to produce a JSON-compatible representation of the
intervention. In the first instance, the object dict will be returned.
However, if an intervention itself contains non-standard variables as
attributes, then its ``to_json`` method will need to handle those.

Returns:
JSON-serializable representation (typically a dict, but could be anything else)
'''
which = self.__class__.__name__
pars = sc.jsonify(self.input_args)
output = dict(which=which, pars=pars)
return output

80 changes: 71 additions & 9 deletions starsim/parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ class Parameters(sc.objdict):
def __init__(self, **kwargs):

# Population parameters
self.people = None
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I can't remember, do we also support ss.Sim(pars=dict(people=dict(n_agents=10e3, ...))) -- which I'm not sure we do currently? (i.e. have a people pars dict in addition to being able to supply people pars directly in the sim pars)

self.n_agents = 10e3 # Number of agents
self.total_pop = None # If defined, used for calculating the scale factor
self.pop_scale = None # How much to scale the population
Expand All @@ -44,14 +45,6 @@ def __init__(self, **kwargs):
self.slot_scale = 5 # Random slots will be assigned to newborn agents between min=n_agents and max=slot_scale*n_agents. Choosing a larger value here will reduce the probability of two agents using the same slot (and hence random draws), but increase the number of random numbers that are required.
self.verbose = ss.options.verbose # Whether or not to display information during the run -- options are 0 (silent), 0.1 (some; default), 1 (default), 2 (everything)

# Plug-ins: demographics, diseases, connectors, networks, analyzers, and interventions
self.demographics = ss.ndict()
self.diseases = ss.ndict()
self.networks = ss.ndict()
self.connectors = ss.ndict()
self.interventions = ss.ndict()
self.analyzers = ss.ndict()

# Update with any supplied parameter values and generate things that need to be generated
self.update(kwargs)

Expand All @@ -60,18 +53,25 @@ def __init__(self, **kwargs):

return

def update_pars(self, pars=None, create=False, **kwargs):
def update_pars(self, pars=None, create=False, module_types=None, **kwargs):
"""
Update internal dict with new pars.
Args:
pars (dict): the parameters to update (if None, do nothing)
create (bool): if create is False, then raise a KeyNotFoundError if the key does not already exist
module_types (dict): types of parameters to convert to modules
"""
if pars is not None:
if not isinstance(pars, dict):
raise TypeError(f'The pars object must be a dict; you supplied a {type(pars)}')

pars = sc.mergedicts(pars, kwargs)

# Initialize modules here??
for mname, mtype in module_types.items():
if mname in pars.keys():
pars[mname] = self.init_module_pars(pars[mname], mtype)

if not create:
available_keys = list(self.keys())
mismatches = [key for key in pars.keys() if key not in available_keys]
Expand All @@ -81,6 +81,68 @@ def update_pars(self, pars=None, create=False, **kwargs):
self.update(pars)
return

def init_module_pars(self, mpar, mtype):
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

moved this from sim to parameters, it could also go in the modules themselves as a class method

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hm I kind of like that actually! (having it be in modules)

""" Initialize modules """
known_modules = {n.__name__.lower():n for n in ss.all_subclasses(mtype)}
processed_m = sc.autolist()

# Process boolean inputs for demographics
if isinstance(mpar, bool) and mtype==ss.BaseDemographics:
if mpar:
return ss.ndict([ss.Births(), ss.Deaths()], type=mtype)
else:
return ss.ndict(type=mtype)

# String: convert to a dict
if isinstance(mpar, str):
mpar = {'type': mpar}

# Dict: check the keys and convert to a class instance
if isinstance(mpar, dict):

# It might be an already-converted ndict
if all([isinstance(v, mtype) for v in mpar.values()]):
return mpar

ptype = (mpar.get('type') or mpar.get('name') or '').lower()
name = mpar.get('name') or ptype

if ptype in known_modules:
# Make an instance of the requested module
module_pars = {k: v for k, v in mpar.items() if k not in ['type', 'name']}
pclass = known_modules[ptype]
module = pclass(name=name, pars=module_pars) # TODO: does this handle par_dists, etc?
else:
errormsg = (f'Could not convert {mpar} to an instance of class {mtype}.'
f'Try specifying it directly rather than as a dictionary.')
raise ValueError(errormsg)
processed_m += module

# Class instance
elif isinstance(mpar, mtype):
processed_m += mpar

# Class
elif isinstance(mpar, type) and issubclass(mpar, mtype):
processed_m += mpar() # Convert from a class to an instance of a class

# Function - only for interventions
elif callable(mpar) and mtype==ss.Intervention:
processed_m += mpar

# It's a list - iterate
elif isinstance(mpar, list):
for mpar_val in mpar:
processed_m += self.init_module_pars(mpar_val, mtype)

else:
errormsg = (
f'{mpar.name.capitalize()} must be provided as either class instances or dictionaries with a '
f'"name" key corresponding to one of these known subclasses: {known_modules}.')
raise ValueError(errormsg)

return ss.ndict(processed_m, type=mtype)


def make_pars(**kwargs):
return Parameters(**kwargs)
Loading
Loading