Skip to content

Commit

Permalink
adapted code for changes in lyscripts prevalence prediction (midext t…
Browse files Browse the repository at this point in the history
…imeevolution model)
  • Loading branch information
larstwi committed Oct 26, 2023
1 parent 74d1181 commit e02c47c
Showing 1 changed file with 11 additions and 10 deletions.
21 changes: 11 additions & 10 deletions lymph/midline.py
Original file line number Diff line number Diff line change
Expand Up @@ -501,6 +501,8 @@ def likelihood(
data: Optional[pd.DataFrame] = None,
given_params: Optional[np.ndarray] = None,
log: bool = True,
t_stages: Optional[List] = None,
prevalence_calc: bool = False
) -> float:
"""Compute the (log-)likelihood of data, using the stored spread probs and
fixed distributions for marginalizing over diagnose times.
Expand All @@ -510,23 +512,22 @@ def likelihood(
"""
if data is not None:
self.patient_data = data

try:
self.check_and_assign(given_params)
except ValueError:
return -np.inf if log else 0.

stored_t_stages = set(self.ext.ipsi.diagnose_matrices.keys())
provided_t_stages = set(self.ext.ipsi.diag_time_dists.keys())
t_stages = list(stored_t_stages.intersection(provided_t_stages))
if t_stages is None:
stored_t_stages = set(self.ext.ipsi.diagnose_matrices.keys())
provided_t_stages = set(self.ext.ipsi.diag_time_dists.keys())
t_stages = list(stored_t_stages.intersection(provided_t_stages))

max_t = self.diag_time_dists.max_t
llh = 0. if log else 1.

# state_probs_midext = self._evolve_midext()
state_probs_ipsi = self.ext.ipsi._evolve(t_last=max_t)
state_probs_contra_nox, state_probs_contra_ex = self._evolve_contra(t_last=max_t)

for stage in t_stages:
# the two `joint_state_probs` below together represent the joint probability
# of any ipsi- AND contralateral state AND any midline extension state
Expand All @@ -541,7 +542,6 @@ def likelihood(
@ np.diag(self.ext.ipsi.diag_time_dists[stage].pmf)
@ state_probs_contra_ex
)

joint_diag_probs_nox = np.sum(
self.noext.ipsi.diagnose_matrices[stage]
* (joint_state_probs_nox
Expand All @@ -559,12 +559,13 @@ def likelihood(
stage_llh = (
joint_diag_probs * self.diagnose_matrices_midext[stage].T
).sum(axis=0)

if log:
llh += np.sum(np.log(stage_llh))
else:
llh *= np.prod(stage_llh)

if prevalence_calc:
llh = stage_llh
else:
llh *= np.prod(stage_llh)
return llh


Expand Down

0 comments on commit e02c47c

Please sign in to comment.