Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Combine crn #322

Closed
wants to merge 14 commits into from
167 changes: 39 additions & 128 deletions starsim/disease.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,154 +249,64 @@ def _check_betas(self, sim):
raise ValueError(errormsg)
return betamap

def _make_new_cases_singlerng(self, sim):
# Not common-random-number-safe, but more efficient for when not using the multirng feature
def make_new_cases(self, sim):
"""
Add new cases of module, through transmission, incidence, etc.

Common-random-number-safe transmission code works by mapping edges onto
slots.
"""
new_cases = []
sources = []
people = sim.people
betamap = self._check_betas(sim)

for nkey, net in sim.networks.items():
for i,nkey,net in sim.networks.enumitems():
if not len(net):
break

nbetas = betamap[nkey]
contacts = net.contacts
rel_trans = (self.infectious & people.alive) * self.rel_trans
rel_sus = (self.susceptible & people.alive) * self.rel_sus
for a, b, beta in [[contacts.p1, contacts.p2, nbetas[0]],
[contacts.p2, contacts.p1, nbetas[1]]]:
p1p2b0 = [contacts.p1, contacts.p2, nbetas[0]]
p2p1b1 = [contacts.p2, contacts.p1, nbetas[1]]
for src, trg, beta in [p1p2b0, p2p1b1]:

# Skip networks with no transmission
if beta == 0:
continue

# Calculate probability of a->b transmission.
beta_per_dt = net.beta_per_dt(disease_beta=beta, dt=people.dt)
p_transmit = rel_trans[a] * rel_sus[b] * beta_per_dt
p_transmit = rel_trans[src] * rel_sus[trg] * beta_per_dt

new_cases_bool = np.random.random(
len(a)) < p_transmit # As this class is not common-random-number safe anyway, calling np.random is perfectly fine!
new_cases.append(b[new_cases_bool])
sources.append(a[new_cases_bool])
if not ss.options.multirng:
rvs = np.random.rand(len(src))
else:
for rng in [self.rng_source, self.rng_acquisition]: # Reset the random states
rng.random_state.step(sim.ti+i)
slots_s = people.slot[src] # Slots for the possible source
slots_t = people.slot[trg] # Slots for the possible target
rvs_s = self.rng_source.rvs(size=np.max(slots_s) + 1)[slots_s]
rvs_t = self.rng_acquisition.rvs(size=np.max(slots_t) + 1)[slots_t]
rvs = np.remainder(rvs_s + rvs_t, 1) # Generate a new random number based on the two other random numbers
new_cases_bool = rvs < p_transmit
new_cases.append(trg[new_cases_bool])
sources.append(src[new_cases_bool])

# Tidy up
if len(new_cases) and len(sources):
return np.concatenate(new_cases), np.concatenate(sources)
return np.empty((0,), dtype=int), np.empty((0,), dtype=int)

def _make_new_cases_multirng(self, sim):
"""
Common-random-number-safe transmission code works by computing the
probability of each _node_ acquiring a case rather than checking if each
_edge_ transmits.
Subsequent step uses a roulette wheel with slotted RNG to determine
infection source.
"""
people = sim.people
n = len(people.uid) # TODO: possibly could be shortened to just the people who are alive
p_acq_node = np.zeros(n)
betamap = self._check_betas(sim)

avec = []
bvec = []
pvec = []
for nkey, net in sim.networks.items():
if not len(net):
break
nbetas = betamap[nkey]
contacts = net.contacts
rel_trans = self.rel_trans * (self.infectious & people.alive)
rel_sus = self.rel_sus * (self.susceptible & people.alive)

p1p2 = ['p1', 'p2', nbetas[0]]
p2p1 = ['p2', 'p1', nbetas[1]]
for source, target, beta in [p1p2, p2p1]: # Transmission from a --> b
if beta == 0:
continue

a, b, beta_arr = contacts[source], contacts[target], contacts.beta
nzi = (rel_trans[a] > 0) & (rel_sus[b] > 0) & (beta_arr > 0)
avec.append(a[nzi])
bvec.append(b[nzi])

beta_per_dt = net.beta_per_dt(disease_beta=beta, dt=people.dt, uids=nzi)
trans_arr = rel_trans[a[nzi]].__array__()
sus_arr = rel_sus[b[nzi]].__array__()
new_pvec = trans_arr * sus_arr * beta_per_dt
pvec.append(new_pvec)

if len(avec):
dfp1 = np.concatenate(avec)
dfp2 = np.concatenate(bvec)
dfp = np.concatenate(pvec)
new_cases = np.concatenate(new_cases)
sources = np.concatenate(sources)
else:
return np.empty((0,), dtype=int), np.empty((0,), dtype=int)

if len(dfp) == 0:
return np.empty((0,), dtype=int), np.empty((0,), dtype=int)

p2uniq, p2idx, p2inv, p2cnt = np.unique(dfp2, return_index=True, return_inverse=True, return_counts=True)

# Pre-draw random numbers
slots = people.slot[p2uniq] # Slots for the possible infectee
r = self.rng_acquisition.rvs(size=np.max(slots) + 1)
q = self.rng_source.rvs(size=np.max(slots) + 1)

# Now address nodes with multiple possible infectees
degrees = np.unique(p2cnt)
new_cases = []
sources = []
for deg in degrees:
if deg == 1:
# UIDs that only appear once
cnt1 = p2cnt == 1
uids = p2uniq[cnt1]
idx = p2idx[cnt1]
p_acq_node = dfp[idx]
cases = r[people.slot[uids]] < p_acq_node
if cases.any():
s = dfp1[idx][cases]
else:
dups = np.argwhere(p2cnt==deg).flatten()
uids = p2uniq[dups]
inds = [np.argwhere(np.isin(p2inv, d)).flatten() for d in dups]
probs = dfp[inds]
p_acq_node = 1-np.prod(1-probs, axis=1)

cases = r[people.slot[uids]] < p_acq_node
if cases.any():
# Vectorized roulette wheel
cumsum = probs[cases].cumsum(axis=1)
cumsum /= cumsum[:,-1][:,np.newaxis]
ix = np.argmax(cumsum >= q[people.slot[uids[cases]]][:,np.newaxis], axis=1)
s = np.take_along_axis(dfp1[inds][cases], ix[:,np.newaxis], axis=1).flatten()#dfp1[inds][cases][np.arange(len(cases)),ix]

if cases.any():
new_cases.append(uids[cases])
sources.append(s)

if len(new_cases) == 0:
return np.empty((0,), dtype=int), np.empty((0,), dtype=int)

new_cases = np.concatenate(new_cases)
sources = np.concatenate(sources)
return new_cases, sources

def make_new_cases(self, sim):
""" Add new cases of module, through transmission, incidence, etc. """
if not sim.networks:
warnmsg = f'Disease {self.name} does not transmit without a network.'
if sim.ti == 0: ss.warn(warnmsg, die=False)
return

if not ss.options.multirng:
# Determine new cases for singlerng
new_cases, sources = self._make_new_cases_singlerng(sim)
else:
# Determine new cases for multirng
new_cases, sources = self._make_new_cases_multirng(sim)

new_cases = np.empty(0, dtype=int)
sources = np.empty(0, dtype=int)

if len(new_cases):
self._set_cases(sim, new_cases, sources)

return new_cases, sources

def _set_cases(self, sim, target_uids, source_uids=None):
congenital = sim.people.age[target_uids] <= 0
Expand All @@ -410,13 +320,14 @@ def _set_cases(self, sim, target_uids, source_uids=None):
def set_congenital(self, sim, target_uids, source_uids=None):
pass


def update_results(self, sim):
super().update_results(sim)
res = self.results
res['prevalence'][sim.ti] = res.n_infected[sim.ti] / np.count_nonzero(sim.people.alive)
res['new_infections'][sim.ti] = np.count_nonzero(self.ti_infected == sim.ti)
res['cum_infections'][sim.ti] = np.sum(res['new_infections'][:sim.ti])
ti = sim.ti
res.prevalence[ti] = res.n_infected[ti] / np.count_nonzero(sim.people.alive)
res.new_infections[ti] = np.count_nonzero(self.ti_infected == ti)
res.cum_infections[ti] = np.sum(res['new_infections'][:ti])
return


class InfectionLog(nx.MultiDiGraph):
Expand Down
4 changes: 2 additions & 2 deletions tests/test_samples.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,9 @@ def get_outputs(p_death):
outputs = []
for i in range(3):
ppl = ss.People(1000)
ppl.networks = ss.ndict(ss.RandomNet(n_contacts=ss.poisson(5)))
network = ss.RandomNet(n_contacts=ss.poisson(mu=5))
sir = ss.SIR(pars={'p_death':p_death})
sim = ss.Sim(people=ppl, diseases=sir, rand_seed=0, n_years=5)
sim = ss.Sim(people=ppl, networks=network, diseases=sir, rand_seed=0, n_years=5)
sim.run(verbose=0)
df = sim.export_df()
summary = {}
Expand Down
Loading