Skip to content

Commit

Permalink
Merge pull request #66 from starsimhub/msm-example
Browse files Browse the repository at this point in the history
Add MSM examples
  • Loading branch information
robynstuart authored Dec 20, 2024
2 parents 75163d4 + 309530a commit 8600c84
Show file tree
Hide file tree
Showing 3 changed files with 115 additions and 30 deletions.
47 changes: 21 additions & 26 deletions stisim/diseases/hiv.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ class HIV(BaseSTI):

def __init__(self, pars=None, init_prev_data=None, **kwargs):
super().__init__()
self.requires = 'structuredsexual'
# self.requires = 'structuredsexual'

# Parameters
self.define_pars(
Expand Down Expand Up @@ -141,12 +141,13 @@ def init_results(self):
results += [ss.Result('n_on_art_pregnant', dtype=int)]

# Add FSW and clients to results:
for risk_group in range(self.sim.networks.structuredsexual.pars.n_risk_groups):
for sex in ['female', 'male']:
results += [
ss.Result('prevalence_risk_group_' + str(risk_group) + '_' + sex, scale=False),
ss.Result('new_infections_risk_group_' + str(risk_group) + '_' + sex, dtype=int),
]
if 'structuredsexual' in self.sim.networks.keys():
for risk_group in range(self.sim.networks.structuredsexual.pars.n_risk_groups):
for sex in ['female', 'male']:
results += [
ss.Result('prevalence_risk_group_' + str(risk_group) + '_' + sex, scale=False),
ss.Result('new_infections_risk_group_' + str(risk_group) + '_' + sex, dtype=int),
]

self.define_results(*results)

Expand Down Expand Up @@ -444,25 +445,19 @@ def update_results(self):
self.results['p_on_art'][ti] = sc.safedivide(self.results['n_on_art'][ti], self.results['n_infected'][ti])

# Subset by FSW and client:
fsw_infected = self.infected[self.sim.networks.structuredsexual.fsw]
client_infected = self.infected[self.sim.networks.structuredsexual.client]
# for risk_group in range(self.sim.networks.structuredsexual.pars.n_risk_groups):
# for sex in ['female', 'male']:
# risk_group_infected = self.infected[(self.sim.networks.structuredsexual.risk_group == risk_group) & (self.sim.people[sex])]
# risk_group_new_inf = ((self.ti_infected == ti) & (self.sim.networks.structuredsexual.risk_group == risk_group) & (self.sim.people[sex])).uids
# if len(risk_group_infected) > 0:
# self.results['prevalence_risk_group_' + str(risk_group) + '_' + sex][ti] = sum(risk_group_infected) / len(risk_group_infected)
# self.results['new_infections_risk_group_' + str(risk_group) + '_' + sex][ti] = len(risk_group_new_inf)

# Add FSW and clients to results:
if len(fsw_infected) > 0:
self.results['prevalence_sw'][ti] = sum(fsw_infected) / len(fsw_infected)
self.results['new_infections_sw'][ti] = len(((self.ti_infected == ti) & self.sim.networks.structuredsexual.fsw).uids)
self.results['new_infections_not_sw'][ti] = len(((self.ti_infected == ti) & ~self.sim.networks.structuredsexual.fsw).uids)
if len(client_infected) > 0:
self.results['prevalence_client'][ti] = sum(client_infected) / len(client_infected)
self.results['new_infections_client'][ti] = len(((self.ti_infected == ti) & self.sim.networks.structuredsexual.client).uids)
self.results['new_infections_not_client'][ti] = len(((self.ti_infected == ti) & ~self.sim.networks.structuredsexual.client).uids)
if 'structuredsexual' in self.sim.networks.keys():
fsw_infected = self.infected[self.sim.networks.structuredsexual.fsw]
client_infected = self.infected[self.sim.networks.structuredsexual.client]

# Add FSW and clients to results:
if len(fsw_infected) > 0:
self.results['prevalence_sw'][ti] = sum(fsw_infected) / len(fsw_infected)
self.results['new_infections_sw'][ti] = len(((self.ti_infected == ti) & self.sim.networks.structuredsexual.fsw).uids)
self.results['new_infections_not_sw'][ti] = len(((self.ti_infected == ti) & ~self.sim.networks.structuredsexual.fsw).uids)
if len(client_infected) > 0:
self.results['prevalence_client'][ti] = sum(client_infected) / len(client_infected)
self.results['new_infections_client'][ti] = len(((self.ti_infected == ti) & self.sim.networks.structuredsexual.client).uids)
self.results['new_infections_not_client'][ti] = len(((self.ti_infected == ti) & ~self.sim.networks.structuredsexual.client).uids)

return

Expand Down
73 changes: 71 additions & 2 deletions stisim/networks.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
ss_float_ = ss.dtypes.float

# Specify all externally visible functions this file defines; see also more definitions below
__all__ = ['StructuredSexual', 'FastStructuredSexual']
__all__ = ['StructuredSexual', 'FastStructuredSexual', 'AgeMatchedMSM', 'AgeApproxMSM']


class NoPartnersFound(Exception):
Expand Down Expand Up @@ -378,7 +378,7 @@ def add_pairs(self, ti=None):
self.append(p1=p1, p2=p2, beta=beta, condoms=condoms, dur=dur, acts=acts, sw=sw, age_p1=age_p1, age_p2=age_p2)

# Checks
if self.sim.people.female[p1].any() or self.sim.people.male[p2].any():
if (self.sim.people.female[p1].any() or self.sim.people.male[p2].any()) and (self.name == 'structuredsexual'):
errormsg = 'Same-sex pairings should not be possible in this network'
raise ValueError(errormsg)
if len(p1) != len(p2):
Expand Down Expand Up @@ -557,4 +557,73 @@ def match_pairs(self, ppl):
p1 = p1[:maxlen]
p2 = p2[:maxlen]

return p1, p2


class AgeMatchedMSM(StructuredSexual):

def __init__(self, **kwargs):
super().__init__(name='msm', **kwargs)

def match_pairs(self, ppl):
""" Match males by age using sorting """

# Find people eligible for a relationship
active = self.over_debut()
underpartnered = self.partners < self.concurrency
m_eligible = active & ppl.male & underpartnered
m_looking = self.pars.p_pair_form.filter(m_eligible.uids)

if len(m_looking) == 0:
raise NoPartnersFound()

# Match mairs by sorting the men looking for partners by age, then matching pairs by taking
# 2 people at a time from the sorted list
m_ages = ppl.age[m_looking]
ind_m = np.argsort(m_ages)
p1 = m_looking[ind_m][::2]
p2 = m_looking[ind_m][1::2]
maxlen = min(len(p1), len(p2))
p1 = p1[:maxlen]
p2 = p2[:maxlen]

# Make sure everyone only appears once (?)
if len(np.intersect1d(p1, p2)):
errormsg = 'Some people appear in both p1 and p2'
raise ValueError(errormsg)

return p1, p2


class AgeApproxMSM(StructuredSexual):

def __init__(self, **kwargs):
super().__init__(name='msm', **kwargs)

def match_pairs(self, ppl):
""" Match"""

# Find people eligible for a relationship
active = self.over_debut()
underpartnered = self.partners < self.concurrency
m_eligible = active & ppl.male & underpartnered
m_looking = self.pars.p_pair_form.filter(m_eligible.uids)

# Split the total number of males looking for partners into 2 groups
# The first group will be matched with the second group
group1 = m_looking[::2]
group2 = m_looking[1::2]
loc, scale = self.get_age_risk_pars(group1, self.pars.age_diff_pars)
self.pars.age_diffs.set(loc=loc, scale=scale)
age_gaps = self.pars.age_diffs.rvs(group1)
desired_ages = ppl.age[group1] + age_gaps
g2_ages = ppl.age[group2]
ind_p1 = np.argsort(g2_ages)
ind_p2 = np.argsort(desired_ages)
p1 = m_eligible.uids[ind_p1]
p2 = group2[ind_p2]
maxlen = min(len(p1), len(p2))
p1 = p1[:maxlen]
p2 = p2[:maxlen]

return p1, p2
25 changes: 23 additions & 2 deletions tests/test_sim.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def test_hiv_sim(n_agents=500):
)
pregnancy = ss.Pregnancy(fertility_rate=10)
death = ss.Deaths(death_rate=10)
sexual = sti.StructuredSexual()
sexual = sti.FastStructuredSexual()
maternal = ss.MaternalNet()
testing = sti.HIVTest(test_prob_data=0.2, start=2000)
art = sti.ART(coverage_data=pd.DataFrame(index=np.arange(2000, 2021), data={'p_art': np.linspace(0, 0.9, 21)}))
Expand All @@ -40,6 +40,26 @@ def test_hiv_sim(n_agents=500):
return sim


def test_msm_hiv(n_agents=500):
hiv = sti.HIV(beta={'msm': [0.1, 0.1]}, init_prev=0.05)
pregnancy = ss.Pregnancy(fertility_rate=10)
death = ss.Deaths(death_rate=10)
msm = sti.AgeMatchedMSM()
sim = ss.Sim(
dt=1/12,
start=1990,
dur=10,
n_agents=n_agents,
diseases=hiv,
networks=msm,
demographics=[pregnancy, death],
)
sim.run(verbose=1/12)

return sim



def test_bv(include_hiv=False, n_agents=500, start=2015, n_years=10):

class menstrual_hygiene(ss.Intervention):
Expand Down Expand Up @@ -133,7 +153,8 @@ def test_stis(which='discharging', n_agents=5e3, start=2010, stop=2020):
do_plot = True

s0 = test_hiv_sim()
s1 = test_stis(which='discharging')
s1 = test_msm_hiv()
s2 = test_stis(which='discharging')

if do_plot:
s1.plot("ng")
Expand Down

0 comments on commit 8600c84

Please sign in to comment.