Skip to content

Commit

Permalink
Merge pull request #82 from rmnldwg/release-1.1.0
Browse files Browse the repository at this point in the history
Release 1.1.0
  • Loading branch information
rmnldwg authored Mar 20, 2024
2 parents 6afa3cb + 7ca44f9 commit 6559f11
Show file tree
Hide file tree
Showing 9 changed files with 340 additions and 260 deletions.
26 changes: 25 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,28 @@
All notable changes to this project will be documented in this file.


<a name="1.1.0"></a>
## [1.1.0] - 2024-03-20

### Features

- (**utils**) Add `safe_set_params()` function.\
This checks whether the params are a dict, list, or None and handles
them accordingly. Just a convencience method that helped refactor some methods.
- Allow to pass state distributions to `posterior_state_dist()` and `risk()` methds. Fixes [#80].\
With this, one can use precomputed state distributions to speed up
computing the posterior or risk for multiple scenarios.

### Refactor

- Use `safe_set_params()` across models.

### Testing

- Add checks for midline risk. Related [#80].
- (**mid**) Fix wrong assumption in risk test.


<a name="1.0.0"></a>
## [1.0.0] - 2024-03-18

Expand Down Expand Up @@ -604,7 +626,8 @@ Almost the entire API has changed. I'd therefore recommend to have a look at the
- add pre-commit hook to check commit msg


[Unreleased]: https://github.com/rmnldwg/lymph/compare/1.0.0...HEAD
[Unreleased]: https://github.com/rmnldwg/lymph/compare/1.1.0...HEAD
[1.1.0]: https://github.com/rmnldwg/lymph/compare/1.0.0...1.1.0
[1.0.0]: https://github.com/rmnldwg/lymph/compare/1.0.0.rc2...1.0.0
[1.0.0.rc2]: https://github.com/rmnldwg/lymph/compare/1.0.0.rc1...1.0.0.rc2
[1.0.0.rc1]: https://github.com/rmnldwg/lymph/compare/1.0.0.a6...1.0.0.rc1
Expand All @@ -620,6 +643,7 @@ Almost the entire API has changed. I'd therefore recommend to have a look at the
[0.4.1]: https://github.com/rmnldwg/lymph/compare/0.4.0...0.4.1
[0.4.0]: https://github.com/rmnldwg/lymph/compare/0.3.10...0.4.0

[#80]: https://github.com/rmnldwg/lymph/issues/80
[#79]: https://github.com/rmnldwg/lymph/issues/79
[#77]: https://github.com/rmnldwg/lymph/issues/77
[#74]: https://github.com/rmnldwg/lymph/issues/74
Expand Down
109 changes: 46 additions & 63 deletions lymph/models/bilateral.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,7 @@
import numpy as np
import pandas as pd

from lymph import diagnose_times, matrix, modalities, models, types
from lymph.utils import (
add_or_mult,
early_late_mapping,
flatten,
synchronize_params,
unflatten_and_split,
)
from lymph import diagnose_times, matrix, modalities, models, types, utils

warnings.filterwarnings("ignore", category=pd.errors.PerformanceWarning)
logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -189,7 +182,7 @@ def get_tumor_spread_params(
params = params["ipsi"]

if as_flat or not as_dict:
params = flatten(params)
params = utils.flatten(params)

return params if as_dict else params.values()

Expand Down Expand Up @@ -221,7 +214,7 @@ def get_lnl_spread_params(
params = params["ipsi"]

if as_flat or not as_dict:
params = flatten(params)
params = utils.flatten(params)

return params if as_dict else params.values()

Expand Down Expand Up @@ -276,7 +269,7 @@ def get_spread_params(
params.update(self.get_lnl_spread_params(as_flat=as_flat))

if as_flat or not as_dict:
params = flatten(params)
params = utils.flatten(params)

return params if as_dict else params.values()

Expand All @@ -300,14 +293,14 @@ def get_params(
params.update(self.get_distribution_params(as_flat=as_flat))

if as_flat or not as_dict:
params = flatten(params)
params = utils.flatten(params)

return params if as_dict else params.values()


def set_tumor_spread_params(self, *args: float, **kwargs: float) -> tuple[float]:
"""Set the parameters of the model's spread from tumor to LNLs."""
kwargs, global_kwargs = unflatten_and_split(kwargs, expected_keys=["ipsi", "contra"])
kwargs, global_kwargs = utils.unflatten_and_split(kwargs, expected_keys=["ipsi", "contra"])

ipsi_kwargs = global_kwargs.copy()
ipsi_kwargs.update(kwargs.get("ipsi", {}))
Expand All @@ -316,7 +309,7 @@ def set_tumor_spread_params(self, *args: float, **kwargs: float) -> tuple[float]

args = self.ipsi.set_tumor_spread_params(*args, **ipsi_kwargs)
if self.is_symmetric["tumor_spread"]:
synchronize_params(
utils.synchronize_params(
get_from=self.ipsi.graph.tumor_edges,
set_to=self.contra.graph.tumor_edges,
)
Expand All @@ -328,7 +321,9 @@ def set_tumor_spread_params(self, *args: float, **kwargs: float) -> tuple[float]

def set_lnl_spread_params(self, *args: float, **kwargs: float) -> tuple[float]:
"""Set the parameters of the model's spread from LNLs to tumor."""
kwargs, global_kwargs = unflatten_and_split(kwargs, expected_keys=["ipsi", "contra"])
kwargs, global_kwargs = utils.unflatten_and_split(
kwargs, expected_keys=["ipsi", "contra"],
)

ipsi_kwargs = global_kwargs.copy()
ipsi_kwargs.update(kwargs.get("ipsi", {}))
Expand All @@ -337,7 +332,7 @@ def set_lnl_spread_params(self, *args: float, **kwargs: float) -> tuple[float]:

args = self.ipsi.set_lnl_spread_params(*args, **ipsi_kwargs)
if self.is_symmetric["lnl_spread"]:
synchronize_params(
utils.synchronize_params(
get_from=self.ipsi.graph.lnl_edges,
set_to=self.contra.graph.lnl_edges,
)
Expand Down Expand Up @@ -394,7 +389,7 @@ def set_params(self, *args: float, **kwargs: float) -> tuple[float]:
def load_patient_data(
self,
patient_data: pd.DataFrame,
mapping: callable | dict[int, Any] = early_late_mapping,
mapping: callable | dict[int, Any] = utils.early_late_mapping,
) -> None:
"""Load patient data into the model.
Expand Down Expand Up @@ -427,17 +422,13 @@ def state_dist(
ipsi_state_evo = self.ipsi.state_dist_evo()
contra_state_evo = self.contra.state_dist_evo()
time_marg_matrix = np.diag(self.get_distribution(t_stage).pmf)
result = ipsi_state_evo.T @ time_marg_matrix @ contra_state_evo

result = (
ipsi_state_evo.T
@ time_marg_matrix
@ contra_state_evo
)
elif mode == "BN":
ipsi_state_dist = self.ipsi.state_dist(mode=mode)
contra_state_dist = self.contra.state_dist(mode=mode)

result = np.outer(ipsi_state_dist, contra_state_dist)

else:
raise ValueError(f"Unknown mode '{mode}'.")

Expand Down Expand Up @@ -515,7 +506,7 @@ def _hmm_likelihood(self, log: bool = True, t_stage: str | None = None) -> float
self.ipsi.diagnose_matrix(stage),
joint_state_dist @ self.contra.diagnose_matrix(stage).T,
)
llh = add_or_mult(llh, patient_llhs, log)
llh = utils.add_or_mult(llh, patient_llhs, log)

return llh

Expand Down Expand Up @@ -547,12 +538,7 @@ def likelihood(
try:
# all functions and methods called here should raise a ValueError if the
# given parameters are invalid...
if given_params is None:
pass
elif isinstance(given_params, dict):
self.set_params(**given_params)
else:
self.set_params(*given_params)
utils.safe_set_params(self, given_params)
except ValueError:
return -np.inf if log else 0.

Expand All @@ -564,32 +550,36 @@ def likelihood(
raise ValueError("Invalid mode. Must be either 'HMM' or 'BN'.")


def posterior_joint_state_dist(
def posterior_state_dist(
self,
given_params: types.ParamsType | None = None,
given_state_dist: np.ndarray | None = None,
given_diagnoses: dict[str, types.DiagnoseType] | None = None,
t_stage: str | int = "early",
mode: Literal["HMM", "BN"] = "HMM",
) -> np.ndarray:
"""Compute joint post. dist. over ipsi & contra states, ``given_diagnoses``.
The ``given_diagnoses`` is a dictionary storing a :py:obj:`.types.DiagnoseType`
for the ``"ipsi"`` and ``"contra"`` side of the neck.
The ``given_diagnoses`` is a dictionary storing one :py:obj:`.types.DiagnoseType`
each for the ``"ipsi"`` and ``"contra"`` side of the neck.
Essentially, this is the risk for any possible combination of ipsi- and
contralateral involvement, given the provided diagnoses.
Note:
The computation is much faster if no parameters are given, since then the
transition matrix does not need to be recomputed.
Warning:
As in the :py:meth:`.Unilateral.posterior_state_dist` method, one may
provide a precomputed (joint) state distribution via the ``given_state_dist``
argument (should be a square matric). In this case, the ``given_params``
are ignored and the model does not need to recompute e.g. the
:py:meth:`.transition_matrix` or :py:meth:`.state_dist`, making the
computation much faster.
See Also:
:py:meth:`.Unilateral.posterior_state_dist`
However, this will mean that ``t_stage`` and ``mode`` are also ignored,
since these are only used to compute the state distribution.
"""
if isinstance(given_params, dict):
self.set_params(**given_params)
else:
self.set_params(*given_params)
if given_state_dist is None:
utils.safe_set_params(self, given_params)
given_state_dist = self.state_dist(t_stage=t_stage, mode=mode)

if given_diagnoses is None:
given_diagnoses = {}
Expand All @@ -606,47 +596,40 @@ def posterior_joint_state_dist(
# vector with P(Z=z|X) for each state X. A data matrix for one "patient"
diagnose_given_state[side] = diagnose_encoding @ observation_matrix.T

joint_state_dist = self.state_dist(t_stage=t_stage, mode=mode)
# matrix with P(Zi=zi,Zc=zc|Xi,Xc) * P(Xi,Xc) for all states Xi,Xc.
joint_diagnose_and_state = np.outer(
diagnose_given_state["ipsi"],
diagnose_given_state["contra"],
) * joint_state_dist
) * given_state_dist
# Following Bayes' theorem, this is P(Xi,Xc|Zi=zi,Zc=zc) which is given by
# P(Zi=zi,Zc=zc|Xi,Xc) * P(Xi,Xc) / P(Zi=zi,Zc=zc)
return joint_diagnose_and_state / np.sum(joint_diagnose_and_state)


def risk(
self,
involvement: types.PatternType | None = None,
involvement: dict[str, types.PatternType] | None = None,
given_params: types.ParamsType | None = None,
given_state_dist: np.ndarray | None = None,
given_diagnoses: dict[str, types.DiagnoseType] | None = None,
t_stage: str = "early",
mode: Literal["HMM", "BN"] = "HMM",
) -> float:
"""Compute risk of an ``involvement`` pattern, given parameters and diagnoses.
"""Compute risk of the ``involvement`` patterns, given parameters and diagnoses.
The parameters can be set via the ``given_params`` and ``given_params``, both
of which are passed to the :py:meth:`.set_params` method. The
``given_diagnoses`` must be a dictionary mapping the side of the neck to a
:py:obj:`.types.DiagnoseType`.
The ``involvement`` of interest is expected to be a :py:obj:`.PatternType` for
each side of the neck (``"ipsi"`` and ``"contra"``). This method then
marginalizes over those posterior state probabilities that match the
``involvement`` patterns.
Note:
The computation is much faster if no parameters are given, since then the
transition matrix does not need to be recomputed.
See Also:
:py:meth:`.Unilateral.risk`
The unilateral method for computing the risk of an involvment pattern.
:py:meth:`.Bilateral.comp_posterior_joint_state_dist`
This method computes the joint distribution over ipsi- and
contralateral states, given the parameters and diagnoses. The risk then
only marginalizes over the states that match the involvement pattern.
If ``involvement`` is not provided, the method returns the posterior state
distribution as computed by the :py:meth:`.posterior_state_dist` method. See
its docstring for more details on the remaining arguments.
"""
# TODO: test this method
posterior_state_probs = self.posterior_joint_state_dist(
posterior_state_probs = self.posterior_state_dist(
given_params=given_params,
given_state_dist=given_state_dist,
given_diagnoses=given_diagnoses,
t_stage=t_stage,
mode=mode,
Expand All @@ -660,7 +643,7 @@ def risk(
side_graph = getattr(self, side).graph
marginalize_over_states[side] = matrix.compute_encoding(
lnls=side_graph.lnls.keys(),
pattern=involvement[side],
pattern=involvement.get(side, {}),
base=3 if self.is_trinary else 2,
)
return (
Expand Down
Loading

0 comments on commit 6559f11

Please sign in to comment.