From 6c879c1759ce255270dbc0f4f4c50e648a45ebeb Mon Sep 17 00:00:00 2001 From: Roman Ludwig <48687784+rmnldwg@users.noreply.github.com> Date: Tue, 19 Mar 2024 10:38:27 +0100 Subject: [PATCH 1/9] fix(bi): align posterior method name --- lymph/models/bilateral.py | 4 ++-- tests/binary_bilateral_test.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/lymph/models/bilateral.py b/lymph/models/bilateral.py index b642689..c286e4e 100644 --- a/lymph/models/bilateral.py +++ b/lymph/models/bilateral.py @@ -564,7 +564,7 @@ def likelihood( raise ValueError("Invalid mode. Must be either 'HMM' or 'BN'.") - def posterior_joint_state_dist( + def posterior_state_dist( self, given_params: types.ParamsType | None = None, given_diagnoses: dict[str, types.DiagnoseType] | None = None, @@ -645,7 +645,7 @@ def risk( only marginalizes over the states that match the involvement pattern. """ # TODO: test this method - posterior_state_probs = self.posterior_joint_state_dist( + posterior_state_probs = self.posterior_state_dist( given_params=given_params, given_diagnoses=given_diagnoses, t_stage=t_stage, diff --git a/tests/binary_bilateral_test.py b/tests/binary_bilateral_test.py index c90239e..d18cfa0 100644 --- a/tests/binary_bilateral_test.py +++ b/tests/binary_bilateral_test.py @@ -318,7 +318,7 @@ def test_posterior_state_dist(self): random_parameters = self.create_random_params() random_diagnoses = self.create_random_diagnoses() - posterior = self.model.posterior_joint_state_dist( + posterior = self.model.posterior_state_dist( given_params=random_parameters, given_diagnoses=random_diagnoses, ) From 68e97e596f70460ef812499afae2491135203267 Mon Sep 17 00:00:00 2001 From: Roman Ludwig <48687784+rmnldwg@users.noreply.github.com> Date: Tue, 19 Mar 2024 10:47:49 +0100 Subject: [PATCH 2/9] change(bi): import `utils` --- lymph/models/bilateral.py | 35 ++++++++++++++++------------------- 1 file changed, 16 insertions(+), 19 deletions(-) diff --git a/lymph/models/bilateral.py 
b/lymph/models/bilateral.py index c286e4e..2f0a776 100644 --- a/lymph/models/bilateral.py +++ b/lymph/models/bilateral.py @@ -7,14 +7,7 @@ import numpy as np import pandas as pd -from lymph import diagnose_times, matrix, modalities, models, types -from lymph.utils import ( - add_or_mult, - early_late_mapping, - flatten, - synchronize_params, - unflatten_and_split, -) +from lymph import diagnose_times, matrix, modalities, models, types, utils warnings.filterwarnings("ignore", category=pd.errors.PerformanceWarning) logger = logging.getLogger(__name__) @@ -189,7 +182,7 @@ def get_tumor_spread_params( params = params["ipsi"] if as_flat or not as_dict: - params = flatten(params) + params = utils.flatten(params) return params if as_dict else params.values() @@ -221,7 +214,7 @@ def get_lnl_spread_params( params = params["ipsi"] if as_flat or not as_dict: - params = flatten(params) + params = utils.flatten(params) return params if as_dict else params.values() @@ -276,7 +269,7 @@ def get_spread_params( params.update(self.get_lnl_spread_params(as_flat=as_flat)) if as_flat or not as_dict: - params = flatten(params) + params = utils.flatten(params) return params if as_dict else params.values() @@ -300,14 +293,14 @@ def get_params( params.update(self.get_distribution_params(as_flat=as_flat)) if as_flat or not as_dict: - params = flatten(params) + params = utils.flatten(params) return params if as_dict else params.values() def set_tumor_spread_params(self, *args: float, **kwargs: float) -> tuple[float]: """Set the parameters of the model's spread from tumor to LNLs.""" - kwargs, global_kwargs = unflatten_and_split(kwargs, expected_keys=["ipsi", "contra"]) + kwargs, global_kwargs = utils.unflatten_and_split(kwargs, expected_keys=["ipsi", "contra"]) ipsi_kwargs = global_kwargs.copy() ipsi_kwargs.update(kwargs.get("ipsi", {})) @@ -316,7 +309,7 @@ def set_tumor_spread_params(self, *args: float, **kwargs: float) -> tuple[float] args = self.ipsi.set_tumor_spread_params(*args, 
**ipsi_kwargs) if self.is_symmetric["tumor_spread"]: - synchronize_params( + utils.synchronize_params( get_from=self.ipsi.graph.tumor_edges, set_to=self.contra.graph.tumor_edges, ) @@ -328,7 +321,9 @@ def set_tumor_spread_params(self, *args: float, **kwargs: float) -> tuple[float] def set_lnl_spread_params(self, *args: float, **kwargs: float) -> tuple[float]: """Set the parameters of the model's spread from LNLs to tumor.""" - kwargs, global_kwargs = unflatten_and_split(kwargs, expected_keys=["ipsi", "contra"]) + kwargs, global_kwargs = utils.unflatten_and_split( + kwargs, expected_keys=["ipsi", "contra"], + ) ipsi_kwargs = global_kwargs.copy() ipsi_kwargs.update(kwargs.get("ipsi", {})) @@ -337,7 +332,7 @@ def set_lnl_spread_params(self, *args: float, **kwargs: float) -> tuple[float]: args = self.ipsi.set_lnl_spread_params(*args, **ipsi_kwargs) if self.is_symmetric["lnl_spread"]: - synchronize_params( + utils.synchronize_params( get_from=self.ipsi.graph.lnl_edges, set_to=self.contra.graph.lnl_edges, ) @@ -394,7 +389,7 @@ def set_params(self, *args: float, **kwargs: float) -> tuple[float]: def load_patient_data( self, patient_data: pd.DataFrame, - mapping: callable | dict[int, Any] = early_late_mapping, + mapping: callable | dict[int, Any] = utils.early_late_mapping, ) -> None: """Load patient data into the model. 
@@ -515,7 +510,7 @@ def _hmm_likelihood(self, log: bool = True, t_stage: str | None = None) -> float self.ipsi.diagnose_matrix(stage), joint_state_dist @ self.contra.diagnose_matrix(stage).T, ) - llh = add_or_mult(llh, patient_llhs, log) + llh = utils.add_or_mult(llh, patient_llhs, log) return llh @@ -586,7 +581,9 @@ def posterior_state_dist( See Also: :py:meth:`.Unilateral.posterior_state_dist` """ - if isinstance(given_params, dict): + if given_params is None: + pass + elif isinstance(given_params, dict): self.set_params(**given_params) else: self.set_params(*given_params) From 46dc125e5dba81dc8088d0a9ce8603d3a77d0a41 Mon Sep 17 00:00:00 2001 From: Roman Ludwig <48687784+rmnldwg@users.noreply.github.com> Date: Tue, 19 Mar 2024 10:49:20 +0100 Subject: [PATCH 3/9] feat(utils): add `safe_set_params()` function This checks whether the params are a dict, list, or None and handles them accordingly. --- lymph/utils.py | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/lymph/utils.py b/lymph/utils.py index a0ebcf6..88ada41 100644 --- a/lymph/utils.py +++ b/lymph/utils.py @@ -347,6 +347,24 @@ def set_params_for( return args +def safe_set_params( + model: types.ModelT, + params: types.ParamsType | None = None, +) -> None: + """Set the ``params`` of the ``model``. + + This infers whether ``params`` is a dict or a list and calls the ``model``'s method + ``set_params()`` accordingly. 
+ """ + if params is None: + return + + if isinstance(params, dict): + model.set_params(**params) + else: + model.set_params(*params) + + def synchronize_params( get_from: dict[str, types.HasGetParams], set_to: dict[str, types.HasSetParams], From f9292bd82240407aa5c0d65dd59ad4167c0f45a9 Mon Sep 17 00:00:00 2001 From: Roman Ludwig <48687784+rmnldwg@users.noreply.github.com> Date: Tue, 19 Mar 2024 10:53:31 +0100 Subject: [PATCH 4/9] change(mid): import `utils` --- lymph/models/midline.py | 48 +++++++++++++++-------------------------- 1 file changed, 17 insertions(+), 31 deletions(-) diff --git a/lymph/models/midline.py b/lymph/models/midline.py index e942487..f27733d 100644 --- a/lymph/models/midline.py +++ b/lymph/models/midline.py @@ -7,16 +7,7 @@ import numpy as np import pandas as pd -from lymph import diagnose_times, matrix, modalities, models, types -from lymph.types import DiagnoseType, PatternType -from lymph.utils import ( - add_or_mult, - draw_diagnoses, - early_late_mapping, - flatten, - popfirst, - unflatten_and_split, -) +from lymph import diagnose_times, matrix, modalities, models, types, utils warnings.filterwarnings("ignore", category=pd.errors.PerformanceWarning) logger = logging.getLogger(__name__) @@ -277,7 +268,7 @@ def get_tumor_spread_params( } if as_flat or not as_dict: - params = flatten(params) + params = utils.flatten(params) return params if as_dict else params.values() @@ -311,7 +302,7 @@ def get_lnl_spread_params( ) if as_flat or not as_dict: - ext_lnl_params = flatten(ext_lnl_params) + ext_lnl_params = utils.flatten(ext_lnl_params) return ext_lnl_params if as_dict else ext_lnl_params.values() @@ -338,7 +329,7 @@ def get_spread_params( params["contra"].update(lnl_spread_params["contra"]) if as_flat or not as_dict: - params = flatten(params) + params = utils.flatten(params) return params if as_dict else params.values() @@ -360,7 +351,7 @@ def get_params( params.update(self.get_distribution_params(as_flat=as_flat)) if as_flat or not 
as_dict: - params = flatten(params) + params = utils.flatten(params) return params if as_dict else params.values() @@ -377,7 +368,7 @@ def set_tumor_spread_params( :py:meth:`get_tumor_spread_params` method returns for an insight in what you can provide. """ - kwargs, global_kwargs = unflatten_and_split( + kwargs, global_kwargs = utils.unflatten_and_split( kwargs, expected_keys=["ipsi", "noext", "ext", "contra"], ) @@ -394,7 +385,7 @@ def set_tumor_spread_params( contra_kwargs = global_kwargs.copy() contra_kwargs.update(kwargs.get("contra", {})) args = self.noext.contra.set_tumor_spread_params(*args, **contra_kwargs) - mixing_param, args = popfirst(args) + mixing_param, args = utils.popfirst(args) mixing_param = global_kwargs.get("mixing", mixing_param) or self.mixing_param self.mixing_param = global_kwargs.get("mixing", mixing_param) @@ -429,7 +420,7 @@ def set_lnl_spread_params(self, *args: float, **kwargs: float) -> Iterable[float three instances of :py:class:`~.Bilateral` depending on the value of the ``use_central`` attribute. """ - kwargs, global_kwargs = unflatten_and_split( + kwargs, global_kwargs = utils.unflatten_and_split( kwargs, expected_keys=["ipsi", "noext", "ext", "contra"], ) ipsi_kwargs = global_kwargs.copy() @@ -474,7 +465,7 @@ def set_params( Combines the calls to :py:meth:`.set_spread_params` and :py:meth:`.set_distribution_params`. """ - first, args = popfirst(args) + first, args = utils.popfirst(args) self.midext_prob = kwargs.get("midext_prob", first) or self.midext_prob args = self.set_spread_params(*args, **kwargs) return self.set_distribution_params(*args, **kwargs) @@ -483,7 +474,7 @@ def set_params( def load_patient_data( self, patient_data: pd.DataFrame, - mapping: callable = early_late_mapping, + mapping: callable = utils.early_late_mapping, ) -> None: """Load patient data into the model. 
@@ -607,14 +598,14 @@ def _hmm_likelihood(self, log: bool = True, for_t_stage: str | None = None) -> f _model.ipsi.diagnose_matrix(stage), joint_state_dist @ _model.contra.diagnose_matrix(stage).T ) - llh = add_or_mult(llh, patient_llhs, log=log) + llh = utils.add_or_mult(llh, patient_llhs, log=log) try: marg_patient_llhs = matrix.fast_trace( self.unknown.ipsi.diagnose_matrix(stage), marg_joint_state_dist @ self.unknown.contra.diagnose_matrix(stage).T ) - llh = add_or_mult(llh, marg_patient_llhs, log=log) + llh = utils.add_or_mult(llh, marg_patient_llhs, log=log) except AttributeError: # an AttributeError is raised both when the model has no `unknown` # attribute and when no data is loaded in the `unknown` model. @@ -673,9 +664,9 @@ def likelihood( def risk( self, - involvement: PatternType | None = None, + involvement: types.PatternType | None = None, given_params: types.ParamsType | None = None, - given_diagnoses: dict[str, DiagnoseType] | None = None, + given_diagnoses: dict[str, types.DiagnoseType] | None = None, t_stage: str = "early", midline_extension: bool = False, central: bool = False, @@ -691,12 +682,7 @@ def risk( For logical reasons, ``midline_extension=False`` makes no sense if ``central=True`` and is thus ignored. 
""" - if given_params is None: - pass - elif isinstance(given_params, dict): - self.set_params(**given_params) - else: - self.set_params(*given_params) + utils.safe_set_params(self, given_params) if central: return self.central.risk( @@ -768,7 +754,7 @@ def draw_patients( drawn_diags = np.empty(shape=(num, len(self.ext.ipsi.obs_list))) for case in ["ext", "noext"]: case_model = getattr(self, case) - drawn_ipsi_diags = draw_diagnoses( + drawn_ipsi_diags = utils.draw_diagnoses( diagnose_times=drawn_diag_times[drawn_midexts == (case == "ext")], state_evolution=ipsi_evo, observation_matrix=case_model.ipsi.observation_matrix(), @@ -776,7 +762,7 @@ def draw_patients( rng=rng, seed=seed, ) - drawn_contra_diags = draw_diagnoses( + drawn_contra_diags = utils.draw_diagnoses( diagnose_times=drawn_diag_times[drawn_midexts == (case == "ext")], state_evolution=case_model.contra.state_dist_evo(), observation_matrix=case_model.contra.observation_matrix(), From 09a03006124e63b863989e83f34154da9482e7ac Mon Sep 17 00:00:00 2001 From: Roman Ludwig <48687784+rmnldwg@users.noreply.github.com> Date: Tue, 19 Mar 2024 10:54:21 +0100 Subject: [PATCH 5/9] refactor: use `safe_set_params()` across models --- lymph/models/bilateral.py | 16 +++------------- lymph/models/midline.py | 7 +------ lymph/models/unilateral.py | 21 +++++---------------- 3 files changed, 9 insertions(+), 35 deletions(-) diff --git a/lymph/models/bilateral.py b/lymph/models/bilateral.py index 2f0a776..dbf9ad1 100644 --- a/lymph/models/bilateral.py +++ b/lymph/models/bilateral.py @@ -542,12 +542,7 @@ def likelihood( try: # all functions and methods called here should raise a ValueError if the # given parameters are invalid... - if given_params is None: - pass - elif isinstance(given_params, dict): - self.set_params(**given_params) - else: - self.set_params(*given_params) + utils.safe_set_params(self, given_params) except ValueError: return -np.inf if log else 0. 
@@ -581,12 +576,8 @@ def posterior_state_dist( See Also: :py:meth:`.Unilateral.posterior_state_dist` """ - if given_params is None: - pass - elif isinstance(given_params, dict): - self.set_params(**given_params) - else: - self.set_params(*given_params) + utils.safe_set_params(self, given_params) + joint_state_dist = self.state_dist(t_stage=t_stage, mode=mode) if given_diagnoses is None: given_diagnoses = {} @@ -603,7 +594,6 @@ def posterior_state_dist( # vector with P(Z=z|X) for each state X. A data matrix for one "patient" diagnose_given_state[side] = diagnose_encoding @ observation_matrix.T - joint_state_dist = self.state_dist(t_stage=t_stage, mode=mode) # matrix with P(Zi=zi,Zc=zc|Xi,Xc) * P(Xi,Xc) for all states Xi,Xc. joint_diagnose_and_state = np.outer( diagnose_given_state["ipsi"], diff --git a/lymph/models/midline.py b/lymph/models/midline.py index f27733d..6da5a86 100644 --- a/lymph/models/midline.py +++ b/lymph/models/midline.py @@ -647,12 +647,7 @@ def likelihood( try: # all functions and methods called here should raise a ValueError if the # given parameters are invalid... - if given_params is None: - pass - elif isinstance(given_params, dict): - self.set_params(**given_params) - else: - self.set_params(*given_params) + utils.safe_set_params(self, given_params) except ValueError: return -np.inf if log else 0. diff --git a/lymph/models/unilateral.py b/lymph/models/unilateral.py index e0ea4d6..bc4adbe 100644 --- a/lymph/models/unilateral.py +++ b/lymph/models/unilateral.py @@ -8,7 +8,7 @@ import pandas as pd from cachetools import LRUCache -from lymph import diagnose_times, graph, matrix, modalities, types +from lymph import diagnose_times, graph, matrix, modalities, types, utils # pylint: disable=unused-import from lymph.utils import ( # nopycln: import @@ -728,12 +728,7 @@ def likelihood( try: # all functions and methods called here should raise a ValueError if the # given parameters are invalid... 
- if given_params is None: - pass - elif isinstance(given_params, dict): - self.set_params(**given_params) - else: - self.set_params(*given_params) + utils.safe_set_params(self, given_params) except ValueError: return -np.inf if log else 0. @@ -797,12 +792,9 @@ def posterior_state_dist( # here if the parameters are invalid, since we want to know if the user # provided invalid parameters. In the likelihood, we rather return a zero # likelihood to tell the inference algorithm that the parameters are invalid. - if given_params is None: - pass - elif isinstance(given_params, dict): - self.set_params(**given_params) - else: - self.set_params(*given_params) + utils.safe_set_params(self, given_params) + # vector P(X=x) of probabilities of arriving in state x (marginalized over time) + state_dist = self.state_dist(t_stage, mode=mode) if given_diagnoses is None: given_diagnoses = {} @@ -811,9 +803,6 @@ def posterior_state_dist( # vector containing P(Z=z|X). Essentially a data matrix for one patient diagnose_given_state = diagnose_encoding @ self.observation_matrix().T - # vector P(X=x) of probabilities of arriving in state x (marginalized over time) - state_dist = self.state_dist(t_stage, mode=mode) - # multiply P(Z=z|X) * P(X) elementwise to get vector of joint probs P(Z=z,X) joint_diagnose_and_state = state_dist * diagnose_given_state From 615c7cbcbcbdc0cc5c6efb06affe5f087145bb9a Mon Sep 17 00:00:00 2001 From: Roman Ludwig <48687784+rmnldwg@users.noreply.github.com> Date: Tue, 19 Mar 2024 18:01:40 +0100 Subject: [PATCH 6/9] feat: allow to pass state dists to posterior/risk With this, one can use precomputed state distributions to speed up computing the posterior or risk for multiple scenarios. 
Fixes: #80 --- lymph/models/bilateral.py | 64 +++++++++--------- lymph/models/midline.py | 133 ++++++++++++++++++++++--------------- lymph/models/unilateral.py | 48 ++++++------- lymph/types.py | 3 +- 4 files changed, 137 insertions(+), 111 deletions(-) diff --git a/lymph/models/bilateral.py b/lymph/models/bilateral.py index dbf9ad1..899c5c2 100644 --- a/lymph/models/bilateral.py +++ b/lymph/models/bilateral.py @@ -422,17 +422,13 @@ def state_dist( ipsi_state_evo = self.ipsi.state_dist_evo() contra_state_evo = self.contra.state_dist_evo() time_marg_matrix = np.diag(self.get_distribution(t_stage).pmf) + result = ipsi_state_evo.T @ time_marg_matrix @ contra_state_evo - result = ( - ipsi_state_evo.T - @ time_marg_matrix - @ contra_state_evo - ) elif mode == "BN": ipsi_state_dist = self.ipsi.state_dist(mode=mode) contra_state_dist = self.contra.state_dist(mode=mode) - result = np.outer(ipsi_state_dist, contra_state_dist) + else: raise ValueError(f"Unknown mode '{mode}'.") @@ -557,27 +553,33 @@ def likelihood( def posterior_state_dist( self, given_params: types.ParamsType | None = None, + given_state_dist: np.ndarray | None = None, given_diagnoses: dict[str, types.DiagnoseType] | None = None, t_stage: str | int = "early", mode: Literal["HMM", "BN"] = "HMM", ) -> np.ndarray: """Compute joint post. dist. over ipsi & contra states, ``given_diagnoses``. - The ``given_diagnoses`` is a dictionary storing a :py:obj:`.types.DiagnoseType` - for the ``"ipsi"`` and ``"contra"`` side of the neck. + The ``given_diagnoses`` is a dictionary storing one :py:obj:`.types.DiagnoseType` + each for the ``"ipsi"`` and ``"contra"`` side of the neck. Essentially, this is the risk for any possible combination of ipsi- and contralateral involvement, given the provided diagnoses. - Note: - The computation is much faster if no parameters are given, since then the - transition matrix does not need to be recomputed. 
+ Warning: + As in the :py:meth:`.Unilateral.posterior_state_dist` method, one may + provide a precomputed (joint) state distribution via the ``given_state_dist`` + argument (should be a square matric). In this case, the ``given_params`` + are ignored and the model does not need to recompute e.g. the + :py:meth:`.transition_matrix` or :py:meth:`.state_dist`, making the + computation much faster. - See Also: - :py:meth:`.Unilateral.posterior_state_dist` + However, this will mean that ``t_stage`` and ``mode`` are also ignored, + since these are only used to compute the state distribution. """ - utils.safe_set_params(self, given_params) - joint_state_dist = self.state_dist(t_stage=t_stage, mode=mode) + if given_state_dist is None: + utils.safe_set_params(self, given_params) + given_state_dist = self.state_dist(t_stage=t_stage, mode=mode) if given_diagnoses is None: given_diagnoses = {} @@ -598,7 +600,7 @@ def posterior_state_dist( joint_diagnose_and_state = np.outer( diagnose_given_state["ipsi"], diagnose_given_state["contra"], - ) * joint_state_dist + ) * given_state_dist # Following Bayes' theorem, this is P(Xi,Xc|Zi=zi,Zc=zc) which is given by # P(Zi=zi,Zc=zc|Xi,Xc) * P(Xi,Xc) / P(Zi=zi,Zc=zc) return joint_diagnose_and_state / np.sum(joint_diagnose_and_state) @@ -606,34 +608,28 @@ def posterior_state_dist( def risk( self, - involvement: types.PatternType | None = None, + involvement: dict[str, types.PatternType] | None = None, given_params: types.ParamsType | None = None, + given_state_dist: np.ndarray | None = None, given_diagnoses: dict[str, types.DiagnoseType] | None = None, t_stage: str = "early", mode: Literal["HMM", "BN"] = "HMM", ) -> float: - """Compute risk of an ``involvement`` pattern, given parameters and diagnoses. + """Compute risk of the ``involvement`` patterns, given parameters and diagnoses. - The parameters can be set via the ``given_params`` and ``given_params``, both - of which are passed to the :py:meth:`.set_params` method. 
The - ``given_diagnoses`` must be a dictionary mapping the side of the neck to a - :py:obj:`.types.DiagnoseType`. + The ``involvement`` of interest is expected to be a :py:obj:`.PatternType` for + each side of the neck (``"ipsi"`` and ``"contra"``). This method then + marginalizes over those posterior state probabilities that match the + ``involvement`` patterns. - Note: - The computation is much faster if no parameters are given, since then the - transition matrix does not need to be recomputed. - - See Also: - :py:meth:`.Unilateral.risk` - The unilateral method for computing the risk of an involvment pattern. - :py:meth:`.Bilateral.comp_posterior_joint_state_dist` - This method computes the joint distribution over ipsi- and - contralateral states, given the parameters and diagnoses. The risk then - only marginalizes over the states that match the involvement pattern. + If ``involvement`` is not provided, the method returns the posterior state + distribution as computed by the :py:meth:`.posterior_state_dist` method. See + its docstring for more details on the remaining arguments. """ # TODO: test this method posterior_state_probs = self.posterior_state_dist( given_params=given_params, + given_state_dist=given_state_dist, given_diagnoses=given_diagnoses, t_stage=t_stage, mode=mode, @@ -647,7 +643,7 @@ def risk( side_graph = getattr(self, side).graph marginalize_over_states[side] = matrix.compute_encoding( lnls=side_graph.lnls.keys(), - pattern=involvement[side], + pattern=involvement.get(side, {}), base=3 if self.is_trinary else 2, ) return ( diff --git a/lymph/models/midline.py b/lymph/models/midline.py index 6da5a86..dff1562 100644 --- a/lymph/models/midline.py +++ b/lymph/models/midline.py @@ -141,7 +141,7 @@ def __init__( uni_kwargs=uni_kwargs, is_symmetric=self.is_symmetric, ) - other_children["unknown"] = self._unknown + other_children["unknown"] = self.unknown if use_mixing: self.mixing_param = 0. 
@@ -534,44 +534,52 @@ def midext_evo(self) -> np.ndarray: def contra_state_dist_evo(self) -> tuple[np.ndarray, np.ndarray]: """Evolve contra side as mixture of with & without midline extension.""" - noext_contra_dist_evo = np.zeros( - shape=(self.max_time + 1, len(self.noext.contra.graph.state_list)) - ) - noext_contra_dist_evo[0,0] = 1. + noext_contra_dist_evo = self.noext.contra.state_dist_evo() + ext_contra_dist_evo = self.ext.contra.state_dist_evo() - ext_contra_dist_evo = np.zeros( - shape=(self.max_time + 1, len(self.ext.contra.graph.state_list)) - ) if not self.use_midext_evo: - noext_contra_dist_evo[0,0] = 1. - self.midext_prob - ext_contra_dist_evo[0,0] = self.midext_prob - - for t in range(self.max_time): - # When evolving over the midline extension state, there's a chance at any - # time step that the tumor grows over the midline and starts spreading to - # the contralateral side more aggressively. - if self.use_midext_evo: - noext_contra_dist_evo[t+1] = ( - (1. - self.midext_prob) * noext_contra_dist_evo[t] - ) @ self.noext.contra.transition_matrix() - ext_contra_dist_evo[t+1] = ( - self.midext_prob * noext_contra_dist_evo[t] - + ext_contra_dist_evo[t] - ) @ self.ext.contra.transition_matrix() - - # When we do not evolve, the tumor is considered lateralized or extending - # over the midline from the start. - else: - noext_contra_dist_evo[t+1] = ( - noext_contra_dist_evo[t] @ self.noext.contra.transition_matrix() - ) - ext_contra_dist_evo[t+1] = ( - ext_contra_dist_evo[t] @ self.ext.contra.transition_matrix() - ) + noext_contra_dist_evo *= (1. 
- self.midext_prob) + ext_contra_dist_evo *= self.midext_prob + + else: + midext_evo = self.midext_evo() + noext_contra_dist_evo *= midext_evo[:,0].reshape((-1, 1)) + ext_contra_dist_evo *= midext_evo[:,1].reshape((-1, 1)) return noext_contra_dist_evo, ext_contra_dist_evo + def state_dist( + self, + t_stage: str = "early", + central: bool = False, + mode: Literal["HMM", "BN"] = "HMM", + ) -> np.ndarray: + """Compute the joint over ipsi- & contralaleral hidden states and midline ext. + + If ``central=False``, the result has shape (2, num_states, num_states), where + the first axis is for the midline extension status, the second for the + ipsilateral state, and the third for the contralateral state. + + If ``central=True``, the result will be the state distribution of the central + model's :py:meth:`.Bilateral.state_dist` method. + """ + if central: + return self.central.state_dist(t_stage, mode) + + ipsi_dist_evo = self.ext.ipsi.state_dist_evo() + noext_contra_dist_evo, ext_contra_dist_evo = self.contra_state_dist_evo() + + if mode == "HMM": + result = np.empty(shape=(2, ipsi_dist_evo.shape[1], ipsi_dist_evo.shape[1])) + time_marg_matrix = np.diag(self.get_distribution(t_stage).pmf) + result[0] = ipsi_dist_evo.T @ time_marg_matrix @ noext_contra_dist_evo + result[1] = ipsi_dist_evo.T @ time_marg_matrix @ ext_contra_dist_evo + return result + + raise NotImplementedError("Only HMM mode is supported as of now.") + + def _hmm_likelihood(self, log: bool = True, for_t_stage: str | None = None) -> float: """Compute the likelihood of the stored data under the hidden Markov model.""" llh = 0. if log else 1. 
@@ -661,9 +669,10 @@ def risk( self, involvement: types.PatternType | None = None, given_params: types.ParamsType | None = None, + given_state_dist: np.ndarray | None = None, given_diagnoses: dict[str, types.DiagnoseType] | None = None, t_stage: str = "early", - midline_extension: bool = False, + midext: bool | None = None, central: bool = False, mode: Literal["HMM", "BN"] = "HMM", ) -> float: @@ -671,33 +680,51 @@ def risk( In addition to the arguments of the :py:meth:`.Bilateral.risk` method, this also allows specifying if the patient's tumor extended over the mid-sagittal - line (``midline_extension=True``) or if it was even located right on that line + line (``midext=True``) or if it was even located right on that line (``central=True``). - For logical reasons, ``midline_extension=False`` makes no sense if - ``central=True`` and is thus ignored. + For logical reasons, ``midext=False`` makes no sense if ``central=True`` and + is thus ignored. + + Warning: + As in the :py:meth:`.Bilateral.posterior_state_dist` method, you may + provide a precomputed (joint) state distribution in the ``given_state_dist`` + argument. Here, this ``given_state_dist`` may be a 2D array, in which case + it is assumed you know how it was computed and the arguments ``t_stage``, + ``midext``, ``central``, and ``mode`` are ignored. If it is 3D, it should + have the shape ``(2, num_states, num_states)`` and be the output of the + :py:meth:`.Midline.state_dist` method. In this case, the ``midext`` + argument is *not* ignored: It may be used to select the correct state + distribution (when ``True`` or ``False``), or marginalize over the midline + extension status (when ``midext=None``). """ - utils.safe_set_params(self, given_params) + # NOTE: When given a 2D state distribution, it does not matter which of the + # Bilateral models is used to compute the risk, since the state dist is + # is the only thing that could differ between models. 
+ if given_state_dist is None: + utils.safe_set_params(self, given_params) + given_state_dist = self.state_dist(t_stage, central, mode) - if central: - return self.central.risk( - given_diagnoses=given_diagnoses, - t_stage=t_stage, - involvement=involvement, - mode=mode, - ) - if midline_extension: + if given_state_dist.ndim == 2: return self.ext.risk( - given_diagnoses=given_diagnoses, - t_stage=t_stage, involvement=involvement, - mode=mode, + given_state_dist=given_state_dist, + given_diagnoses=given_diagnoses, ) - return self.noext.risk( - given_diagnoses=given_diagnoses, - t_stage=t_stage, + + if central: + raise ValueError("The `given_state_dist` must be 2D for the central model.") + + if midext is None: + given_state_dist = np.sum(given_state_dist, axis=0) + else: + given_state_dist = given_state_dist[int(midext)] + given_state_dist = given_state_dist / given_state_dist.sum() + + return self.ext.risk( involvement=involvement, - mode=mode, + given_state_dist=given_state_dist, + given_diagnoses=given_diagnoses, ) diff --git a/lymph/models/unilateral.py b/lymph/models/unilateral.py index bc4adbe..e1bfac1 100644 --- a/lymph/models/unilateral.py +++ b/lymph/models/unilateral.py @@ -763,6 +763,7 @@ def compute_encoding( def posterior_state_dist( self, given_params: types.ParamsType | None = None, + given_state_dist: np.ndarray | None = None, given_diagnoses: types.DiagnoseType | None = None, t_stage: str | int = "early", mode: Literal["HMM", "BN"] = "HMM", @@ -784,17 +785,19 @@ def posterior_state_dist( for the hidden Markov model (``"HMM"``) or the Bayesian network (``"BN"``). In case of the Bayesian network mode, the ``t_stage`` parameter is ignored. - Note: - The computation is much faster if no parameters are given, since then the - transition matrix does not need to be recomputed. + Warning: + To speed up repetitive computations, one can provide precomputed state + distributions via the ``given_state_dist`` parameter. 
When provided, the + method will ignore the ``given_params``, ``t_stage``, and ``mode`` + arguments, but compute the posterior much quicker. """ - # in contrast to when computing the likelihood, we do want to raise an error - # here if the parameters are invalid, since we want to know if the user - # provided invalid parameters. In the likelihood, we rather return a zero - # likelihood to tell the inference algorithm that the parameters are invalid. - utils.safe_set_params(self, given_params) - # vector P(X=x) of probabilities of arriving in state x (marginalized over time) - state_dist = self.state_dist(t_stage, mode=mode) + if given_state_dist is None: + # in contrast to when computing the likelihood, we do want to raise an error + # here if the parameters are invalid, since we want to know if the user + # provided invalid parameters. + utils.safe_set_params(self, given_params) + # vector P(X=x) of probs of arriving in state x (marginalized over time) + given_state_dist = self.state_dist(t_stage, mode=mode) if given_diagnoses is None: given_diagnoses = {} @@ -804,7 +807,7 @@ def posterior_state_dist( diagnose_given_state = diagnose_encoding @ self.observation_matrix().T # multiply P(Z=z|X) * P(X) elementwise to get vector of joint probs P(Z=z,X) - joint_diagnose_and_state = state_dist * diagnose_given_state + joint_diagnose_and_state = given_state_dist * diagnose_given_state # compute vector of probabilities for all possible involvements given the # specified diagnosis P(X|Z=z) = P(Z=z,X) / P(X), where P(X) = sum_z P(Z=z,X) @@ -815,29 +818,28 @@ def risk( self, involvement: types.PatternType | None = None, given_params: types.ParamsType | None = None, + given_state_dist: np.ndarray | None = None, given_diagnoses: dict[str, types.PatternType] | None = None, t_stage: str = "early", mode: Literal["HMM", "BN"] = "HMM", ) -> float | np.ndarray: - """Compute risk of a certain involvement, given a patient's diagnosis. 
+ """Compute risk of a certain ``involvement``, using the ``given_diagnoses``. If an ``involvement`` pattern of interest is provided, this method computes the risk of seeing just that pattern for the set of given parameters and a dictionary of diagnoses for each modality. - Using the ``mode`` parameter, the risk can be computed either for the hidden - Markov model (``"HMM"``) or the Bayesian network (``"BN"``). In case of the - Bayesian network mode, the ``t_stage`` parameter is ignored. - - Note: - The computation is much faster if no parameters are given, since then the - transition matrix does not need to be recomputed. - - See Also: - :py:meth:`posterior_state_dist` + If no ``involvement`` is provided, this will simply return the posterior + distribution over hidden states, given the diagnoses, as computed by the + :py:meth:`.posterior_state_dist` method. See its documentaiton for more + details about the arguments and the return value. """ posterior_state_dist = self.posterior_state_dist( - given_params, given_diagnoses, t_stage, mode, + given_params=given_params, + given_state_dist=given_state_dist, + given_diagnoses=given_diagnoses, + t_stage=t_stage, + mode=mode, ) if involvement is None: diff --git a/lymph/types.py b/lymph/types.py index e07392a..5def3aa 100644 --- a/lymph/types.py +++ b/lymph/types.py @@ -149,7 +149,8 @@ def risk( self, involvement: PatternType | None = None, given_params: ParamsType | None = None, + given_state_dist: np.ndarray | None = None, given_diagnoses: dict[str, PatternType] | None = None, ) -> float | np.ndarray: - """Return the risk of ``involvement``, given the parameters and diagnoses.""" + """Return the risk of ``involvement``, given params/state_dist and diagnoses.""" raise NotImplementedError From 64e5893324be400b30220406804709e416b78ff0 Mon Sep 17 00:00:00 2001 From: Roman Ludwig <48687784+rmnldwg@users.noreply.github.com> Date: Tue, 19 Mar 2024 18:02:04 +0100 Subject: [PATCH 7/9] test: add checks for midline risk Related: 
#80 --- tests/binary_midline_test.py | 151 ++++++++++++++++++++--------------- tests/fixtures.py | 39 ++++++++- 2 files changed, 122 insertions(+), 68 deletions(-) diff --git a/tests/binary_midline_test.py b/tests/binary_midline_test.py index 98caa80..72608d9 100644 --- a/tests/binary_midline_test.py +++ b/tests/binary_midline_test.py @@ -1,7 +1,6 @@ """ Test the midline model for the binary case. """ -from typing import Literal import numpy as np import pandas as pd @@ -11,32 +10,12 @@ from . import fixtures -class MidlineSetParamsTestCase(fixtures.IgnoreWarningsTestCase): +class MidlineSetParamsTestCase( + fixtures.MidlineFixtureMixin, + fixtures.IgnoreWarningsTestCase, +): """Check that the complex parameter assignment works correctly.""" - def setUp( - self, - seed: int = 42, - graph_size: Literal["small", "medium", "large"] = "small", - use_mixing: bool = True, - use_central: bool = True, - is_symmetric: dict[str, bool] | None = None, - ) -> None: - super().setUp() - self.rng = np.random.default_rng(seed) - graph_dict = fixtures.get_graph(graph_size) - if is_symmetric is None: - is_symmetric = {"tumor_spread": False, "lnl_spread": True} - - self.model = models.Midline( - graph_dict=graph_dict, - is_symmetric=is_symmetric, - use_mixing=use_mixing, - use_central=use_central, - use_midext_evo=False, - ) - - def test_init(self) -> None: """Check some basic attributes.""" self.assertTrue(self.model.use_central) @@ -82,49 +61,20 @@ def test_get_set_params_order(self) -> None: self.assertEqual(params_to_set[:-1].tolist(), returned_params) -class MidlineLikelihoodTestCase(fixtures.IgnoreWarningsTestCase): +class MidlineLikelihoodTestCase( + fixtures.MidlineFixtureMixin, + fixtures.IgnoreWarningsTestCase, +): """Check that the likelihood function works correctly.""" - def setUp( - self, - seed: int = 42, - graph_size: Literal["small", "medium", "large"] = "small", - use_mixing: bool = True, - use_central: bool = False, - use_midext_evo: bool = True, - is_symmetric: 
dict[str, bool] | None = None, - ) -> None: + def setUp(self) -> None: + """Set up the test case.""" super().setUp() - self.rng = np.random.default_rng(seed) - graph_dict = fixtures.get_graph(graph_size) - if is_symmetric is None: - is_symmetric = {"tumor_spread": False, "lnl_spread": True} - - self.model = models.Midline( - graph_dict=graph_dict, - is_symmetric=is_symmetric, - use_mixing=use_mixing, - use_central=use_central, - use_midext_evo=use_midext_evo, - ) - self.model.set_distribution( - "early", - fixtures.create_random_dist( - type_="frozen", - max_time=self.model.max_time, - rng=self.rng, - ), - ) - self.model.set_distribution( - "late", - fixtures.create_random_dist( - type_="parametric", - max_time=self.model.max_time, - rng=self.rng, - ), - ) + self.init_diag_time_dists(early="frozen", late="parametric") self.model.set_modality("pathology", spec=1., sens=1., kind="pathological") - self.model.load_patient_data(pd.read_csv("./tests/data/2021-clb-oropharynx.csv", header=[0,1,2])) + self.model.load_patient_data( + pd.read_csv("./tests/data/2021-clb-oropharynx.csv", header=[0,1,2]), + ) def test_likelihood(self) -> None: @@ -142,6 +92,79 @@ def test_likelihood(self) -> None: self.assertLessEqual(self.model.likelihood(), 0) +class MidlineRiskTestCase( + fixtures.MidlineFixtureMixin, + fixtures.IgnoreWarningsTestCase, +): + """Check that the risk method works correctly.""" + + def setUp(self) -> None: + """Set up the test case.""" + super().setUp() + self.init_diag_time_dists(early="frozen", late="parametric") + self.model.set_modality("pathology", spec=1., sens=1., kind="pathological") + self.model.set_params( + midext_prob=0.1, + ipsi_TtoII_spread=0.35, + ipsi_TtoIII_spread=0.0, + contra_TtoII_spread=0.05, + contra_TtoIII_spread=0.0, + mixing=0.5, + IItoIII_spread=0.1, + late_p=0.5, + ) + + + def test_risk(self) -> None: + """Check that the risk method works correctly.""" + plain_risk = self.model.risk() + self.assertEqual(plain_risk.shape, (4,4)) + 
self.assertTrue(np.isclose(plain_risk.sum(), 1.0)) + self.assertTrue(np.allclose(plain_risk[1,:], 0.)) + self.assertTrue(np.allclose(plain_risk[:,1], 0.)) + + lnlIII_risk = self.model.risk(involvement={"ipsi": {"II": False, "III": True}}) + self.assertTrue(np.isscalar(lnlIII_risk)) + self.assertAlmostEqual(lnlIII_risk, 0.0) + + ipsi_lnlII_risk = self.model.risk(involvement={"ipsi": {"II": True}}) + contra_lnlII_risk = self.model.risk(involvement={"contra": {"II": True}}) + self.assertGreater(ipsi_lnlII_risk, contra_lnlII_risk) + ext_contra_lnlII_risk = self.model.risk( + involvement={"contra": {"II": True}}, + midext=True, + ) + self.assertGreater(ipsi_lnlII_risk, ext_contra_lnlII_risk) + self.assertGreater(ext_contra_lnlII_risk, contra_lnlII_risk) + noext_contra_lnlII_risk = self.model.risk( + involvement={"contra": {"II": True}}, + midext=False, + ) + self.assertGreater(contra_lnlII_risk, noext_contra_lnlII_risk) + self.assertGreater(ext_contra_lnlII_risk, noext_contra_lnlII_risk) + + + def test_risk_given_state_dist(self) -> None: + """Check how providing a state distribution works correctly.""" + state_dist_3d = self.model.state_dist(t_stage="early") + self.assertEqual(state_dist_3d.shape, (2, 4, 4)) + + risk_from_state_dist = self.model.risk(given_state_dist=state_dist_3d, midext=True) + risk_direct = self.model.risk(midext=True) + self.assertTrue(np.allclose(risk_from_state_dist, risk_direct)) + + state_dist_2d = state_dist_3d[0] / state_dist_3d[0].sum() + risk_from_state_dist = self.model.risk(given_state_dist=state_dist_2d) + risk_direct = self.model.risk(midext=False) + self.assertTrue(np.allclose(risk_from_state_dist, risk_direct)) + + state_dist_2d = state_dist_3d[1] / state_dist_3d[1].sum() + risk_from_state_dist = self.model.risk(given_state_dist=state_dist_2d) + risk_direct = self.model.risk(midext=True) + self.assertTrue(np.allclose(risk_from_state_dist, risk_direct)) + + + class MidlineDrawPatientsTestCase(fixtures.IgnoreWarningsTestCase): """Check 
the data generation.""" diff --git a/tests/fixtures.py b/tests/fixtures.py index 4b0903c..4273961 100644 --- a/tests/fixtures.py +++ b/tests/fixtures.py @@ -5,16 +5,15 @@ import unittest import warnings from pathlib import Path -from typing import Any, Callable +from typing import Any, Callable, Literal import numpy as np import pandas as pd import scipy as sp -import lymph from lymph import diagnose_times from lymph.modalities import Clinical, Modality, Pathological -from lymph.models import Unilateral +from lymph.models import Bilateral, Midline, Unilateral from lymph.types import DataWarning, PatternType MODALITIES = { @@ -180,7 +179,7 @@ def setUp(self): super().setUp() self.rng = np.random.default_rng(42) self.graph_dict = get_graph("large") - self.model = lymph.models.Bilateral(graph_dict=self.graph_dict, **self.model_kwargs) + self.model = Bilateral(graph_dict=self.graph_dict, **self.model_kwargs) self.init_diag_time_dists(early="frozen", late="parametric") self.model.set_params(**self.create_random_params()) self.logger = get_logger(level=logging.INFO) @@ -276,3 +275,35 @@ def load_patient_data( filepath = Path(__file__).parent / "data" / filename self.raw_data = pd.read_csv(filepath, header=[0,1,2]) self.model.load_patient_data(self.raw_data, side="ipsi") + + +class MidlineFixtureMixin: + """Mixin for testing the Midline model.""" + + def setUp( + self, + seed: int = 42, + graph_size: Literal["small", "medium", "large"] = "small", + use_mixing: bool = True, + use_central: bool = False, + use_midext_evo: bool = True, + is_symmetric: dict[str, bool] | None = None, + ) -> None: + """Initialize rng and model.""" + self.rng = np.random.default_rng(seed) + self.model = Midline( + graph_dict=get_graph(graph_size), + is_symmetric=is_symmetric or {"tumor_spread": False, "lnl_spread": True}, + use_mixing=use_mixing, + use_central=use_central, + use_midext_evo=use_midext_evo, + ) + + + def init_diag_time_dists(self, **dists) -> None: + """Init the diagnose time 
distributions.""" + for t_stage, type_ in dists.items(): + self.model.set_distribution( + t_stage, + create_random_dist(type_, self.model.max_time, self.rng), + ) From 07e608c5b5dce3bb24587a20a73b313362ebdbc8 Mon Sep 17 00:00:00 2001 From: Roman Ludwig <48687784+rmnldwg@users.noreply.github.com> Date: Tue, 19 Mar 2024 18:09:33 +0100 Subject: [PATCH 8/9] test(mid): fix wrong assumption in risk test --- tests/binary_midline_test.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/binary_midline_test.py b/tests/binary_midline_test.py index 72608d9..c949278 100644 --- a/tests/binary_midline_test.py +++ b/tests/binary_midline_test.py @@ -15,6 +15,8 @@ class MidlineSetParamsTestCase( fixtures.IgnoreWarningsTestCase, ): """Check that the complex parameter assignment works correctly.""" + def setUp(self): + return super().setUp(use_central=True, use_midext_evo=False) def test_init(self) -> None: """Check some basic attributes.""" @@ -134,7 +136,6 @@ def test_risk(self) -> None: involvement={"contra": {"II": True}}, midext=True, ) - self.assertGreater(ipsi_lnlII_risk, ext_contra_lnlII_risk) self.assertGreater(ext_contra_lnlII_risk, contra_lnlII_risk) noext_contra_lnlII_risk = self.model.risk( involvement={"contra": {"II": True}}, From 7ca44f9120a38f4504754e45c60256cfa41b7b58 Mon Sep 17 00:00:00 2001 From: Roman Ludwig <48687784+rmnldwg@users.noreply.github.com> Date: Wed, 20 Mar 2024 18:06:10 +0100 Subject: [PATCH 9/9] chore: update changelog --- CHANGELOG.md | 26 +++++++++++++++++++++++++- 1 file changed, 25 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 160d36c..6d87b24 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -3,6 +3,28 @@ All notable changes to this project will be documented in this file. + +## [1.1.0] - 2024-03-20 + +### Features + +- (**utils**) Add `safe_set_params()` function.\ + This checks whether the params are a dict, list, or None and handles + them accordingly. 
Just a convenience method that helped refactor some methods. - Allow to pass state distributions to `posterior_state_dist()` and `risk()` methods. Fixes [#80].\ With this, one can use precomputed state distributions to speed up computing the posterior or risk for multiple scenarios. ### Refactor - Use `safe_set_params()` across models. ### Testing - Add checks for midline risk. Related [#80]. - (**mid**) Fix wrong assumption in risk test. ## [1.0.0] - 2024-03-18 @@ -604,7 +626,8 @@ Almost the entire API has changed. I'd therefore recommend to have a look at the - add pre-commit hook to check commit msg -[Unreleased]: https://github.com/rmnldwg/lymph/compare/1.0.0...HEAD +[Unreleased]: https://github.com/rmnldwg/lymph/compare/1.1.0...HEAD +[1.1.0]: https://github.com/rmnldwg/lymph/compare/1.0.0...1.1.0 [1.0.0]: https://github.com/rmnldwg/lymph/compare/1.0.0.rc2...1.0.0 [1.0.0.rc2]: https://github.com/rmnldwg/lymph/compare/1.0.0.rc1...1.0.0.rc2 [1.0.0.rc1]: https://github.com/rmnldwg/lymph/compare/1.0.0.a6...1.0.0.rc1 @@ -620,6 +643,7 @@ Almost the entire API has changed. I'd therefore recommend to have a look at the [0.4.1]: https://github.com/rmnldwg/lymph/compare/0.4.0...0.4.1 [0.4.0]: https://github.com/rmnldwg/lymph/compare/0.3.10...0.4.0 +[#80]: https://github.com/rmnldwg/lymph/issues/80 [#79]: https://github.com/rmnldwg/lymph/issues/79 [#77]: https://github.com/rmnldwg/lymph/issues/77 [#74]: https://github.com/rmnldwg/lymph/issues/74