From 6c879c1759ce255270dbc0f4f4c50e648a45ebeb Mon Sep 17 00:00:00 2001
From: Roman Ludwig <48687784+rmnldwg@users.noreply.github.com>
Date: Tue, 19 Mar 2024 10:38:27 +0100
Subject: [PATCH 1/9] fix(bi): align posterior method name
---
lymph/models/bilateral.py | 4 ++--
tests/binary_bilateral_test.py | 2 +-
2 files changed, 3 insertions(+), 3 deletions(-)
diff --git a/lymph/models/bilateral.py b/lymph/models/bilateral.py
index b642689..c286e4e 100644
--- a/lymph/models/bilateral.py
+++ b/lymph/models/bilateral.py
@@ -564,7 +564,7 @@ def likelihood(
raise ValueError("Invalid mode. Must be either 'HMM' or 'BN'.")
- def posterior_joint_state_dist(
+ def posterior_state_dist(
self,
given_params: types.ParamsType | None = None,
given_diagnoses: dict[str, types.DiagnoseType] | None = None,
@@ -645,7 +645,7 @@ def risk(
only marginalizes over the states that match the involvement pattern.
"""
# TODO: test this method
- posterior_state_probs = self.posterior_joint_state_dist(
+ posterior_state_probs = self.posterior_state_dist(
given_params=given_params,
given_diagnoses=given_diagnoses,
t_stage=t_stage,
diff --git a/tests/binary_bilateral_test.py b/tests/binary_bilateral_test.py
index c90239e..d18cfa0 100644
--- a/tests/binary_bilateral_test.py
+++ b/tests/binary_bilateral_test.py
@@ -318,7 +318,7 @@ def test_posterior_state_dist(self):
random_parameters = self.create_random_params()
random_diagnoses = self.create_random_diagnoses()
- posterior = self.model.posterior_joint_state_dist(
+ posterior = self.model.posterior_state_dist(
given_params=random_parameters,
given_diagnoses=random_diagnoses,
)
From 68e97e596f70460ef812499afae2491135203267 Mon Sep 17 00:00:00 2001
From: Roman Ludwig <48687784+rmnldwg@users.noreply.github.com>
Date: Tue, 19 Mar 2024 10:47:49 +0100
Subject: [PATCH 2/9] change(bi): import `utils`
---
lymph/models/bilateral.py | 35 ++++++++++++++++-------------------
1 file changed, 16 insertions(+), 19 deletions(-)
diff --git a/lymph/models/bilateral.py b/lymph/models/bilateral.py
index c286e4e..2f0a776 100644
--- a/lymph/models/bilateral.py
+++ b/lymph/models/bilateral.py
@@ -7,14 +7,7 @@
import numpy as np
import pandas as pd
-from lymph import diagnose_times, matrix, modalities, models, types
-from lymph.utils import (
- add_or_mult,
- early_late_mapping,
- flatten,
- synchronize_params,
- unflatten_and_split,
-)
+from lymph import diagnose_times, matrix, modalities, models, types, utils
warnings.filterwarnings("ignore", category=pd.errors.PerformanceWarning)
logger = logging.getLogger(__name__)
@@ -189,7 +182,7 @@ def get_tumor_spread_params(
params = params["ipsi"]
if as_flat or not as_dict:
- params = flatten(params)
+ params = utils.flatten(params)
return params if as_dict else params.values()
@@ -221,7 +214,7 @@ def get_lnl_spread_params(
params = params["ipsi"]
if as_flat or not as_dict:
- params = flatten(params)
+ params = utils.flatten(params)
return params if as_dict else params.values()
@@ -276,7 +269,7 @@ def get_spread_params(
params.update(self.get_lnl_spread_params(as_flat=as_flat))
if as_flat or not as_dict:
- params = flatten(params)
+ params = utils.flatten(params)
return params if as_dict else params.values()
@@ -300,14 +293,14 @@ def get_params(
params.update(self.get_distribution_params(as_flat=as_flat))
if as_flat or not as_dict:
- params = flatten(params)
+ params = utils.flatten(params)
return params if as_dict else params.values()
def set_tumor_spread_params(self, *args: float, **kwargs: float) -> tuple[float]:
"""Set the parameters of the model's spread from tumor to LNLs."""
- kwargs, global_kwargs = unflatten_and_split(kwargs, expected_keys=["ipsi", "contra"])
+ kwargs, global_kwargs = utils.unflatten_and_split(kwargs, expected_keys=["ipsi", "contra"])
ipsi_kwargs = global_kwargs.copy()
ipsi_kwargs.update(kwargs.get("ipsi", {}))
@@ -316,7 +309,7 @@ def set_tumor_spread_params(self, *args: float, **kwargs: float) -> tuple[float]
args = self.ipsi.set_tumor_spread_params(*args, **ipsi_kwargs)
if self.is_symmetric["tumor_spread"]:
- synchronize_params(
+ utils.synchronize_params(
get_from=self.ipsi.graph.tumor_edges,
set_to=self.contra.graph.tumor_edges,
)
@@ -328,7 +321,9 @@ def set_tumor_spread_params(self, *args: float, **kwargs: float) -> tuple[float]
def set_lnl_spread_params(self, *args: float, **kwargs: float) -> tuple[float]:
"""Set the parameters of the model's spread from LNLs to tumor."""
- kwargs, global_kwargs = unflatten_and_split(kwargs, expected_keys=["ipsi", "contra"])
+ kwargs, global_kwargs = utils.unflatten_and_split(
+ kwargs, expected_keys=["ipsi", "contra"],
+ )
ipsi_kwargs = global_kwargs.copy()
ipsi_kwargs.update(kwargs.get("ipsi", {}))
@@ -337,7 +332,7 @@ def set_lnl_spread_params(self, *args: float, **kwargs: float) -> tuple[float]:
args = self.ipsi.set_lnl_spread_params(*args, **ipsi_kwargs)
if self.is_symmetric["lnl_spread"]:
- synchronize_params(
+ utils.synchronize_params(
get_from=self.ipsi.graph.lnl_edges,
set_to=self.contra.graph.lnl_edges,
)
@@ -394,7 +389,7 @@ def set_params(self, *args: float, **kwargs: float) -> tuple[float]:
def load_patient_data(
self,
patient_data: pd.DataFrame,
- mapping: callable | dict[int, Any] = early_late_mapping,
+ mapping: callable | dict[int, Any] = utils.early_late_mapping,
) -> None:
"""Load patient data into the model.
@@ -515,7 +510,7 @@ def _hmm_likelihood(self, log: bool = True, t_stage: str | None = None) -> float
self.ipsi.diagnose_matrix(stage),
joint_state_dist @ self.contra.diagnose_matrix(stage).T,
)
- llh = add_or_mult(llh, patient_llhs, log)
+ llh = utils.add_or_mult(llh, patient_llhs, log)
return llh
@@ -586,7 +581,9 @@ def posterior_state_dist(
See Also:
:py:meth:`.Unilateral.posterior_state_dist`
"""
- if isinstance(given_params, dict):
+ if given_params is None:
+ pass
+ elif isinstance(given_params, dict):
self.set_params(**given_params)
else:
self.set_params(*given_params)
From 46dc125e5dba81dc8088d0a9ce8603d3a77d0a41 Mon Sep 17 00:00:00 2001
From: Roman Ludwig <48687784+rmnldwg@users.noreply.github.com>
Date: Tue, 19 Mar 2024 10:49:20 +0100
Subject: [PATCH 3/9] feat(utils): add `saf_set_params()` function
This checks whether the params are a dict, list, or None and handles
them accordingly.
---
lymph/utils.py | 18 ++++++++++++++++++
1 file changed, 18 insertions(+)
diff --git a/lymph/utils.py b/lymph/utils.py
index a0ebcf6..88ada41 100644
--- a/lymph/utils.py
+++ b/lymph/utils.py
@@ -347,6 +347,24 @@ def set_params_for(
return args
+def safe_set_params(
+ model: types.ModelT,
+ params: types.ParamsType | None = None,
+) -> None:
+ """Set the ``params`` of the ``model``.
+
+ This infers whether ``params`` is a dict or a list and calls the ``model``'s method
+ ``set_params()`` accordingly.
+ """
+ if params is None:
+ return
+
+ if isinstance(params, dict):
+ model.set_params(**params)
+ else:
+ model.set_params(*params)
+
+
def synchronize_params(
get_from: dict[str, types.HasGetParams],
set_to: dict[str, types.HasSetParams],
From f9292bd82240407aa5c0d65dd59ad4167c0f45a9 Mon Sep 17 00:00:00 2001
From: Roman Ludwig <48687784+rmnldwg@users.noreply.github.com>
Date: Tue, 19 Mar 2024 10:53:31 +0100
Subject: [PATCH 4/9] change(mid): import `utils`
---
lymph/models/midline.py | 48 +++++++++++++++--------------------------
1 file changed, 17 insertions(+), 31 deletions(-)
diff --git a/lymph/models/midline.py b/lymph/models/midline.py
index e942487..f27733d 100644
--- a/lymph/models/midline.py
+++ b/lymph/models/midline.py
@@ -7,16 +7,7 @@
import numpy as np
import pandas as pd
-from lymph import diagnose_times, matrix, modalities, models, types
-from lymph.types import DiagnoseType, PatternType
-from lymph.utils import (
- add_or_mult,
- draw_diagnoses,
- early_late_mapping,
- flatten,
- popfirst,
- unflatten_and_split,
-)
+from lymph import diagnose_times, matrix, modalities, models, types, utils
warnings.filterwarnings("ignore", category=pd.errors.PerformanceWarning)
logger = logging.getLogger(__name__)
@@ -277,7 +268,7 @@ def get_tumor_spread_params(
}
if as_flat or not as_dict:
- params = flatten(params)
+ params = utils.flatten(params)
return params if as_dict else params.values()
@@ -311,7 +302,7 @@ def get_lnl_spread_params(
)
if as_flat or not as_dict:
- ext_lnl_params = flatten(ext_lnl_params)
+ ext_lnl_params = utils.flatten(ext_lnl_params)
return ext_lnl_params if as_dict else ext_lnl_params.values()
@@ -338,7 +329,7 @@ def get_spread_params(
params["contra"].update(lnl_spread_params["contra"])
if as_flat or not as_dict:
- params = flatten(params)
+ params = utils.flatten(params)
return params if as_dict else params.values()
@@ -360,7 +351,7 @@ def get_params(
params.update(self.get_distribution_params(as_flat=as_flat))
if as_flat or not as_dict:
- params = flatten(params)
+ params = utils.flatten(params)
return params if as_dict else params.values()
@@ -377,7 +368,7 @@ def set_tumor_spread_params(
:py:meth:`get_tumor_spread_params` method returns for an insight in what you
can provide.
"""
- kwargs, global_kwargs = unflatten_and_split(
+ kwargs, global_kwargs = utils.unflatten_and_split(
kwargs, expected_keys=["ipsi", "noext", "ext", "contra"],
)
@@ -394,7 +385,7 @@ def set_tumor_spread_params(
contra_kwargs = global_kwargs.copy()
contra_kwargs.update(kwargs.get("contra", {}))
args = self.noext.contra.set_tumor_spread_params(*args, **contra_kwargs)
- mixing_param, args = popfirst(args)
+ mixing_param, args = utils.popfirst(args)
mixing_param = global_kwargs.get("mixing", mixing_param) or self.mixing_param
self.mixing_param = global_kwargs.get("mixing", mixing_param)
@@ -429,7 +420,7 @@ def set_lnl_spread_params(self, *args: float, **kwargs: float) -> Iterable[float
three instances of :py:class:`~.Bilateral` depending on the value of the
``use_central`` attribute.
"""
- kwargs, global_kwargs = unflatten_and_split(
+ kwargs, global_kwargs = utils.unflatten_and_split(
kwargs, expected_keys=["ipsi", "noext", "ext", "contra"],
)
ipsi_kwargs = global_kwargs.copy()
@@ -474,7 +465,7 @@ def set_params(
Combines the calls to :py:meth:`.set_spread_params` and
:py:meth:`.set_distribution_params`.
"""
- first, args = popfirst(args)
+ first, args = utils.popfirst(args)
self.midext_prob = kwargs.get("midext_prob", first) or self.midext_prob
args = self.set_spread_params(*args, **kwargs)
return self.set_distribution_params(*args, **kwargs)
@@ -483,7 +474,7 @@ def set_params(
def load_patient_data(
self,
patient_data: pd.DataFrame,
- mapping: callable = early_late_mapping,
+ mapping: callable = utils.early_late_mapping,
) -> None:
"""Load patient data into the model.
@@ -607,14 +598,14 @@ def _hmm_likelihood(self, log: bool = True, for_t_stage: str | None = None) -> f
_model.ipsi.diagnose_matrix(stage),
joint_state_dist @ _model.contra.diagnose_matrix(stage).T
)
- llh = add_or_mult(llh, patient_llhs, log=log)
+ llh = utils.add_or_mult(llh, patient_llhs, log=log)
try:
marg_patient_llhs = matrix.fast_trace(
self.unknown.ipsi.diagnose_matrix(stage),
marg_joint_state_dist @ self.unknown.contra.diagnose_matrix(stage).T
)
- llh = add_or_mult(llh, marg_patient_llhs, log=log)
+ llh = utils.add_or_mult(llh, marg_patient_llhs, log=log)
except AttributeError:
# an AttributeError is raised both when the model has no `unknown`
# attribute and when no data is loaded in the `unknown` model.
@@ -673,9 +664,9 @@ def likelihood(
def risk(
self,
- involvement: PatternType | None = None,
+ involvement: types.PatternType | None = None,
given_params: types.ParamsType | None = None,
- given_diagnoses: dict[str, DiagnoseType] | None = None,
+ given_diagnoses: dict[str, types.DiagnoseType] | None = None,
t_stage: str = "early",
midline_extension: bool = False,
central: bool = False,
@@ -691,12 +682,7 @@ def risk(
For logical reasons, ``midline_extension=False`` makes no sense if
``central=True`` and is thus ignored.
"""
- if given_params is None:
- pass
- elif isinstance(given_params, dict):
- self.set_params(**given_params)
- else:
- self.set_params(*given_params)
+ utils.safe_set_params(self, given_params)
if central:
return self.central.risk(
@@ -768,7 +754,7 @@ def draw_patients(
drawn_diags = np.empty(shape=(num, len(self.ext.ipsi.obs_list)))
for case in ["ext", "noext"]:
case_model = getattr(self, case)
- drawn_ipsi_diags = draw_diagnoses(
+ drawn_ipsi_diags = utils.draw_diagnoses(
diagnose_times=drawn_diag_times[drawn_midexts == (case == "ext")],
state_evolution=ipsi_evo,
observation_matrix=case_model.ipsi.observation_matrix(),
@@ -776,7 +762,7 @@ def draw_patients(
rng=rng,
seed=seed,
)
- drawn_contra_diags = draw_diagnoses(
+ drawn_contra_diags = utils.draw_diagnoses(
diagnose_times=drawn_diag_times[drawn_midexts == (case == "ext")],
state_evolution=case_model.contra.state_dist_evo(),
observation_matrix=case_model.contra.observation_matrix(),
From 09a03006124e63b863989e83f34154da9482e7ac Mon Sep 17 00:00:00 2001
From: Roman Ludwig <48687784+rmnldwg@users.noreply.github.com>
Date: Tue, 19 Mar 2024 10:54:21 +0100
Subject: [PATCH 5/9] refactor: use `safe_set_params()` across models
---
lymph/models/bilateral.py | 16 +++-------------
lymph/models/midline.py | 7 +------
lymph/models/unilateral.py | 21 +++++----------------
3 files changed, 9 insertions(+), 35 deletions(-)
diff --git a/lymph/models/bilateral.py b/lymph/models/bilateral.py
index 2f0a776..dbf9ad1 100644
--- a/lymph/models/bilateral.py
+++ b/lymph/models/bilateral.py
@@ -542,12 +542,7 @@ def likelihood(
try:
# all functions and methods called here should raise a ValueError if the
# given parameters are invalid...
- if given_params is None:
- pass
- elif isinstance(given_params, dict):
- self.set_params(**given_params)
- else:
- self.set_params(*given_params)
+ utils.safe_set_params(self, given_params)
except ValueError:
return -np.inf if log else 0.
@@ -581,12 +576,8 @@ def posterior_state_dist(
See Also:
:py:meth:`.Unilateral.posterior_state_dist`
"""
- if given_params is None:
- pass
- elif isinstance(given_params, dict):
- self.set_params(**given_params)
- else:
- self.set_params(*given_params)
+ utils.safe_set_params(self, given_params)
+ joint_state_dist = self.state_dist(t_stage=t_stage, mode=mode)
if given_diagnoses is None:
given_diagnoses = {}
@@ -603,7 +594,6 @@ def posterior_state_dist(
# vector with P(Z=z|X) for each state X. A data matrix for one "patient"
diagnose_given_state[side] = diagnose_encoding @ observation_matrix.T
- joint_state_dist = self.state_dist(t_stage=t_stage, mode=mode)
# matrix with P(Zi=zi,Zc=zc|Xi,Xc) * P(Xi,Xc) for all states Xi,Xc.
joint_diagnose_and_state = np.outer(
diagnose_given_state["ipsi"],
diff --git a/lymph/models/midline.py b/lymph/models/midline.py
index f27733d..6da5a86 100644
--- a/lymph/models/midline.py
+++ b/lymph/models/midline.py
@@ -647,12 +647,7 @@ def likelihood(
try:
# all functions and methods called here should raise a ValueError if the
# given parameters are invalid...
- if given_params is None:
- pass
- elif isinstance(given_params, dict):
- self.set_params(**given_params)
- else:
- self.set_params(*given_params)
+ utils.safe_set_params(self, given_params)
except ValueError:
return -np.inf if log else 0.
diff --git a/lymph/models/unilateral.py b/lymph/models/unilateral.py
index e0ea4d6..bc4adbe 100644
--- a/lymph/models/unilateral.py
+++ b/lymph/models/unilateral.py
@@ -8,7 +8,7 @@
import pandas as pd
from cachetools import LRUCache
-from lymph import diagnose_times, graph, matrix, modalities, types
+from lymph import diagnose_times, graph, matrix, modalities, types, utils
# pylint: disable=unused-import
from lymph.utils import ( # nopycln: import
@@ -728,12 +728,7 @@ def likelihood(
try:
# all functions and methods called here should raise a ValueError if the
# given parameters are invalid...
- if given_params is None:
- pass
- elif isinstance(given_params, dict):
- self.set_params(**given_params)
- else:
- self.set_params(*given_params)
+ utils.safe_set_params(self, given_params)
except ValueError:
return -np.inf if log else 0.
@@ -797,12 +792,9 @@ def posterior_state_dist(
# here if the parameters are invalid, since we want to know if the user
# provided invalid parameters. In the likelihood, we rather return a zero
# likelihood to tell the inference algorithm that the parameters are invalid.
- if given_params is None:
- pass
- elif isinstance(given_params, dict):
- self.set_params(**given_params)
- else:
- self.set_params(*given_params)
+ utils.safe_set_params(self, given_params)
+ # vector P(X=x) of probabilities of arriving in state x (marginalized over time)
+ state_dist = self.state_dist(t_stage, mode=mode)
if given_diagnoses is None:
given_diagnoses = {}
@@ -811,9 +803,6 @@ def posterior_state_dist(
# vector containing P(Z=z|X). Essentially a data matrix for one patient
diagnose_given_state = diagnose_encoding @ self.observation_matrix().T
- # vector P(X=x) of probabilities of arriving in state x (marginalized over time)
- state_dist = self.state_dist(t_stage, mode=mode)
-
# multiply P(Z=z|X) * P(X) elementwise to get vector of joint probs P(Z=z,X)
joint_diagnose_and_state = state_dist * diagnose_given_state
From 615c7cbcbcbdc0cc5c6efb06affe5f087145bb9a Mon Sep 17 00:00:00 2001
From: Roman Ludwig <48687784+rmnldwg@users.noreply.github.com>
Date: Tue, 19 Mar 2024 18:01:40 +0100
Subject: [PATCH 6/9] feat: allow to pass state dists to posterior/risk
With this, one can use precomputed state distributions to speed up
computing the posterior or risk for multiple scenarios.
Fixes: #80
---
lymph/models/bilateral.py | 64 +++++++++---------
lymph/models/midline.py | 133 ++++++++++++++++++++++---------------
lymph/models/unilateral.py | 48 ++++++-------
lymph/types.py | 3 +-
4 files changed, 137 insertions(+), 111 deletions(-)
diff --git a/lymph/models/bilateral.py b/lymph/models/bilateral.py
index dbf9ad1..899c5c2 100644
--- a/lymph/models/bilateral.py
+++ b/lymph/models/bilateral.py
@@ -422,17 +422,13 @@ def state_dist(
ipsi_state_evo = self.ipsi.state_dist_evo()
contra_state_evo = self.contra.state_dist_evo()
time_marg_matrix = np.diag(self.get_distribution(t_stage).pmf)
+ result = ipsi_state_evo.T @ time_marg_matrix @ contra_state_evo
- result = (
- ipsi_state_evo.T
- @ time_marg_matrix
- @ contra_state_evo
- )
elif mode == "BN":
ipsi_state_dist = self.ipsi.state_dist(mode=mode)
contra_state_dist = self.contra.state_dist(mode=mode)
-
result = np.outer(ipsi_state_dist, contra_state_dist)
+
else:
raise ValueError(f"Unknown mode '{mode}'.")
@@ -557,27 +553,33 @@ def likelihood(
def posterior_state_dist(
self,
given_params: types.ParamsType | None = None,
+ given_state_dist: np.ndarray | None = None,
given_diagnoses: dict[str, types.DiagnoseType] | None = None,
t_stage: str | int = "early",
mode: Literal["HMM", "BN"] = "HMM",
) -> np.ndarray:
"""Compute joint post. dist. over ipsi & contra states, ``given_diagnoses``.
- The ``given_diagnoses`` is a dictionary storing a :py:obj:`.types.DiagnoseType`
- for the ``"ipsi"`` and ``"contra"`` side of the neck.
+ The ``given_diagnoses`` is a dictionary storing one :py:obj:`.types.DiagnoseType`
+ each for the ``"ipsi"`` and ``"contra"`` side of the neck.
Essentially, this is the risk for any possible combination of ipsi- and
contralateral involvement, given the provided diagnoses.
- Note:
- The computation is much faster if no parameters are given, since then the
- transition matrix does not need to be recomputed.
+ Warning:
+ As in the :py:meth:`.Unilateral.posterior_state_dist` method, one may
+ provide a precomputed (joint) state distribution via the ``given_state_dist``
+ argument (should be a square matric). In this case, the ``given_params``
+ are ignored and the model does not need to recompute e.g. the
+ :py:meth:`.transition_matrix` or :py:meth:`.state_dist`, making the
+ computation much faster.
- See Also:
- :py:meth:`.Unilateral.posterior_state_dist`
+ However, this will mean that ``t_stage`` and ``mode`` are also ignored,
+ since these are only used to compute the state distribution.
"""
- utils.safe_set_params(self, given_params)
- joint_state_dist = self.state_dist(t_stage=t_stage, mode=mode)
+ if given_state_dist is None:
+ utils.safe_set_params(self, given_params)
+ given_state_dist = self.state_dist(t_stage=t_stage, mode=mode)
if given_diagnoses is None:
given_diagnoses = {}
@@ -598,7 +600,7 @@ def posterior_state_dist(
joint_diagnose_and_state = np.outer(
diagnose_given_state["ipsi"],
diagnose_given_state["contra"],
- ) * joint_state_dist
+ ) * given_state_dist
# Following Bayes' theorem, this is P(Xi,Xc|Zi=zi,Zc=zc) which is given by
# P(Zi=zi,Zc=zc|Xi,Xc) * P(Xi,Xc) / P(Zi=zi,Zc=zc)
return joint_diagnose_and_state / np.sum(joint_diagnose_and_state)
@@ -606,34 +608,28 @@ def posterior_state_dist(
def risk(
self,
- involvement: types.PatternType | None = None,
+ involvement: dict[str, types.PatternType] | None = None,
given_params: types.ParamsType | None = None,
+ given_state_dist: np.ndarray | None = None,
given_diagnoses: dict[str, types.DiagnoseType] | None = None,
t_stage: str = "early",
mode: Literal["HMM", "BN"] = "HMM",
) -> float:
- """Compute risk of an ``involvement`` pattern, given parameters and diagnoses.
+ """Compute risk of the ``involvement`` patterns, given parameters and diagnoses.
- The parameters can be set via the ``given_params`` and ``given_params``, both
- of which are passed to the :py:meth:`.set_params` method. The
- ``given_diagnoses`` must be a dictionary mapping the side of the neck to a
- :py:obj:`.types.DiagnoseType`.
+ The ``involvement`` of interest is expected to be a :py:obj:`.PatternType` for
+ each side of the neck (``"ipsi"`` and ``"contra"``). This method then
+ marginalizes over those posterior state probabilities that match the
+ ``involvement`` patterns.
- Note:
- The computation is much faster if no parameters are given, since then the
- transition matrix does not need to be recomputed.
-
- See Also:
- :py:meth:`.Unilateral.risk`
- The unilateral method for computing the risk of an involvment pattern.
- :py:meth:`.Bilateral.comp_posterior_joint_state_dist`
- This method computes the joint distribution over ipsi- and
- contralateral states, given the parameters and diagnoses. The risk then
- only marginalizes over the states that match the involvement pattern.
+ If ``involvement`` is not provided, the method returns the posterior state
+ distribution as computed by the :py:meth:`.posterior_state_dist` method. See
+ its docstring for more details on the remaining arguments.
"""
# TODO: test this method
posterior_state_probs = self.posterior_state_dist(
given_params=given_params,
+ given_state_dist=given_state_dist,
given_diagnoses=given_diagnoses,
t_stage=t_stage,
mode=mode,
@@ -647,7 +643,7 @@ def risk(
side_graph = getattr(self, side).graph
marginalize_over_states[side] = matrix.compute_encoding(
lnls=side_graph.lnls.keys(),
- pattern=involvement[side],
+ pattern=involvement.get(side, {}),
base=3 if self.is_trinary else 2,
)
return (
diff --git a/lymph/models/midline.py b/lymph/models/midline.py
index 6da5a86..dff1562 100644
--- a/lymph/models/midline.py
+++ b/lymph/models/midline.py
@@ -141,7 +141,7 @@ def __init__(
uni_kwargs=uni_kwargs,
is_symmetric=self.is_symmetric,
)
- other_children["unknown"] = self._unknown
+ other_children["unknown"] = self.unknown
if use_mixing:
self.mixing_param = 0.
@@ -534,44 +534,52 @@ def midext_evo(self) -> np.ndarray:
def contra_state_dist_evo(self) -> tuple[np.ndarray, np.ndarray]:
"""Evolve contra side as mixture of with & without midline extension."""
- noext_contra_dist_evo = np.zeros(
- shape=(self.max_time + 1, len(self.noext.contra.graph.state_list))
- )
- noext_contra_dist_evo[0,0] = 1.
+ noext_contra_dist_evo = self.noext.contra.state_dist_evo()
+ ext_contra_dist_evo = self.ext.contra.state_dist_evo()
- ext_contra_dist_evo = np.zeros(
- shape=(self.max_time + 1, len(self.ext.contra.graph.state_list))
- )
if not self.use_midext_evo:
- noext_contra_dist_evo[0,0] = 1. - self.midext_prob
- ext_contra_dist_evo[0,0] = self.midext_prob
-
- for t in range(self.max_time):
- # When evolving over the midline extension state, there's a chance at any
- # time step that the tumor grows over the midline and starts spreading to
- # the contralateral side more aggressively.
- if self.use_midext_evo:
- noext_contra_dist_evo[t+1] = (
- (1. - self.midext_prob) * noext_contra_dist_evo[t]
- ) @ self.noext.contra.transition_matrix()
- ext_contra_dist_evo[t+1] = (
- self.midext_prob * noext_contra_dist_evo[t]
- + ext_contra_dist_evo[t]
- ) @ self.ext.contra.transition_matrix()
-
- # When we do not evolve, the tumor is considered lateralized or extending
- # over the midline from the start.
- else:
- noext_contra_dist_evo[t+1] = (
- noext_contra_dist_evo[t] @ self.noext.contra.transition_matrix()
- )
- ext_contra_dist_evo[t+1] = (
- ext_contra_dist_evo[t] @ self.ext.contra.transition_matrix()
- )
+ noext_contra_dist_evo *= (1. - self.midext_prob)
+ ext_contra_dist_evo *= self.midext_prob
+
+ else:
+ midext_evo = self.midext_evo()
+ noext_contra_dist_evo *= midext_evo[:,0].reshape((-1, 1))
+ ext_contra_dist_evo *= midext_evo[:,1].reshape((-1, 1))
return noext_contra_dist_evo, ext_contra_dist_evo
+ def state_dist(
+ self,
+ t_stage: str = "early",
+ central: bool = False,
+ mode: Literal["HMM", "BN"] = "HMM",
+ ) -> np.ndarray:
+ """Compute the joint over ipsi- & contralaleral hidden states and midline ext.
+
+ If ``central=False``, the result has shape (2, num_states, num_states), where
+ the first axis is for the midline extension status, the second for the
+ ipsilateral state, and the third for the contralateral state.
+
+ If ``central=True``, the result will be the state distribution of the central
+ model's :py:meth:`.Bilateral.state_dist` method.
+ """
+ if central:
+ return self.central.state_dist(t_stage, mode)
+
+ ipsi_dist_evo = self.ext.ipsi.state_dist_evo()
+ noext_contra_dist_evo, ext_contra_dist_evo = self.contra_state_dist_evo()
+
+ if mode == "HMM":
+ result = np.empty(shape=(2, ipsi_dist_evo.shape[1], ipsi_dist_evo.shape[1]))
+ time_marg_matrix = np.diag(self.get_distribution(t_stage).pmf)
+ result[0] = ipsi_dist_evo.T @ time_marg_matrix @ noext_contra_dist_evo
+ result[1] = ipsi_dist_evo.T @ time_marg_matrix @ ext_contra_dist_evo
+ return result
+
+ raise NotImplementedError("Only HMM mode is supported as of now.")
+
+
def _hmm_likelihood(self, log: bool = True, for_t_stage: str | None = None) -> float:
"""Compute the likelihood of the stored data under the hidden Markov model."""
llh = 0. if log else 1.
@@ -661,9 +669,10 @@ def risk(
self,
involvement: types.PatternType | None = None,
given_params: types.ParamsType | None = None,
+ given_state_dist: np.ndarray | None = None,
given_diagnoses: dict[str, types.DiagnoseType] | None = None,
t_stage: str = "early",
- midline_extension: bool = False,
+ midext: bool | None = None,
central: bool = False,
mode: Literal["HMM", "BN"] = "HMM",
) -> float:
@@ -671,33 +680,51 @@ def risk(
In addition to the arguments of the :py:meth:`.Bilateral.risk` method, this
also allows specifying if the patient's tumor extended over the mid-sagittal
- line (``midline_extension=True``) or if it was even located right on that line
+ line (``midext=True``) or if it was even located right on that line
(``central=True``).
- For logical reasons, ``midline_extension=False`` makes no sense if
- ``central=True`` and is thus ignored.
+ For logical reasons, ``midext=False`` makes no sense if ``central=True`` and
+ is thus ignored.
+
+ Warning:
+ As in the :py:meth:`.Bilateral.posterior_state_dist` method, you may
+ provide a precomputed (joint) state distribution in the ``given_state_dist``
+ argument. Here, this ``given_state_dist`` may be a 2D array, in which case
+ it is assumed you know how it was computed and the arguments ``t_stage``,
+ ``midext``, ``central``, and ``mode`` are ignored. If it is 3D, it should
+ have the shape ``(2, num_states, num_states)`` and be the output of the
+ :py:meth:`.Midline.state_dist` method. In this case, the ``midext``
+ argument is *not* ignored: It may be used to select the correct state
+ distribution (when ``True`` or ``False``), or marginalize over the midline
+ extension status (when ``midext=None``).
"""
- utils.safe_set_params(self, given_params)
+ # NOTE: When given a 2D state distribution, it does not matter which of the
+ # Bilateral models is used to compute the risk, since the state dist is
+ # is the only thing that could differ between models.
+ if given_state_dist is None:
+ utils.safe_set_params(self, given_params)
+ given_state_dist = self.state_dist(t_stage, central, mode)
- if central:
- return self.central.risk(
- given_diagnoses=given_diagnoses,
- t_stage=t_stage,
- involvement=involvement,
- mode=mode,
- )
- if midline_extension:
+ if given_state_dist.ndim == 2:
return self.ext.risk(
- given_diagnoses=given_diagnoses,
- t_stage=t_stage,
involvement=involvement,
- mode=mode,
+ given_state_dist=given_state_dist,
+ given_diagnoses=given_diagnoses,
)
- return self.noext.risk(
- given_diagnoses=given_diagnoses,
- t_stage=t_stage,
+
+ if central:
+ raise ValueError("The `given_state_dist` must be 2D for the central model.")
+
+ if midext is None:
+ given_state_dist = np.sum(given_state_dist, axis=0)
+ else:
+ given_state_dist = given_state_dist[int(midext)]
+ given_state_dist = given_state_dist / given_state_dist.sum()
+
+ return self.ext.risk(
involvement=involvement,
- mode=mode,
+ given_state_dist=given_state_dist,
+ given_diagnoses=given_diagnoses,
)
diff --git a/lymph/models/unilateral.py b/lymph/models/unilateral.py
index bc4adbe..e1bfac1 100644
--- a/lymph/models/unilateral.py
+++ b/lymph/models/unilateral.py
@@ -763,6 +763,7 @@ def compute_encoding(
def posterior_state_dist(
self,
given_params: types.ParamsType | None = None,
+ given_state_dist: np.ndarray | None = None,
given_diagnoses: types.DiagnoseType | None = None,
t_stage: str | int = "early",
mode: Literal["HMM", "BN"] = "HMM",
@@ -784,17 +785,19 @@ def posterior_state_dist(
for the hidden Markov model (``"HMM"``) or the Bayesian network (``"BN"``).
In case of the Bayesian network mode, the ``t_stage`` parameter is ignored.
- Note:
- The computation is much faster if no parameters are given, since then the
- transition matrix does not need to be recomputed.
+ Warning:
+ To speed up repetitive computations, one can provide precomputed state
+ distributions via the ``given_state_dist`` parameter. When provided, the
+ method will ignore the ``given_params``, ``t_stage``, and ``mode``
+ arguments, but compute the posterior much quicker.
"""
- # in contrast to when computing the likelihood, we do want to raise an error
- # here if the parameters are invalid, since we want to know if the user
- # provided invalid parameters. In the likelihood, we rather return a zero
- # likelihood to tell the inference algorithm that the parameters are invalid.
- utils.safe_set_params(self, given_params)
- # vector P(X=x) of probabilities of arriving in state x (marginalized over time)
- state_dist = self.state_dist(t_stage, mode=mode)
+ if given_state_dist is None:
+ # in contrast to when computing the likelihood, we do want to raise an error
+ # here if the parameters are invalid, since we want to know if the user
+ # provided invalid parameters.
+ utils.safe_set_params(self, given_params)
+ # vector P(X=x) of probs of arriving in state x (marginalized over time)
+ given_state_dist = self.state_dist(t_stage, mode=mode)
if given_diagnoses is None:
given_diagnoses = {}
@@ -804,7 +807,7 @@ def posterior_state_dist(
diagnose_given_state = diagnose_encoding @ self.observation_matrix().T
# multiply P(Z=z|X) * P(X) elementwise to get vector of joint probs P(Z=z,X)
- joint_diagnose_and_state = state_dist * diagnose_given_state
+ joint_diagnose_and_state = given_state_dist * diagnose_given_state
# compute vector of probabilities for all possible involvements given the
# specified diagnosis P(X|Z=z) = P(Z=z,X) / P(X), where P(X) = sum_z P(Z=z,X)
@@ -815,29 +818,28 @@ def risk(
self,
involvement: types.PatternType | None = None,
given_params: types.ParamsType | None = None,
+ given_state_dist: np.ndarray | None = None,
given_diagnoses: dict[str, types.PatternType] | None = None,
t_stage: str = "early",
mode: Literal["HMM", "BN"] = "HMM",
) -> float | np.ndarray:
- """Compute risk of a certain involvement, given a patient's diagnosis.
+ """Compute risk of a certain ``involvement``, using the ``given_diagnoses``.
If an ``involvement`` pattern of interest is provided, this method computes
the risk of seeing just that pattern for the set of given parameters and a
dictionary of diagnoses for each modality.
- Using the ``mode`` parameter, the risk can be computed either for the hidden
- Markov model (``"HMM"``) or the Bayesian network (``"BN"``). In case of the
- Bayesian network mode, the ``t_stage`` parameter is ignored.
-
- Note:
- The computation is much faster if no parameters are given, since then the
- transition matrix does not need to be recomputed.
-
- See Also:
- :py:meth:`posterior_state_dist`
+ If no ``involvement`` is provided, this will simply return the posterior
+ distribution over hidden states, given the diagnoses, as computed by the
+ :py:meth:`.posterior_state_dist` method. See its documentaiton for more
+ details about the arguments and the return value.
"""
posterior_state_dist = self.posterior_state_dist(
- given_params, given_diagnoses, t_stage, mode,
+ given_params=given_params,
+ given_state_dist=given_state_dist,
+ given_diagnoses=given_diagnoses,
+ t_stage=t_stage,
+ mode=mode,
)
if involvement is None:
diff --git a/lymph/types.py b/lymph/types.py
index e07392a..5def3aa 100644
--- a/lymph/types.py
+++ b/lymph/types.py
@@ -149,7 +149,8 @@ def risk(
self,
involvement: PatternType | None = None,
given_params: ParamsType | None = None,
+ given_state_dist: np.ndarray | None = None,
given_diagnoses: dict[str, PatternType] | None = None,
) -> float | np.ndarray:
- """Return the risk of ``involvement``, given the parameters and diagnoses."""
+ """Return the risk of ``involvement``, given params/state_dist and diagnoses."""
raise NotImplementedError
From 64e5893324be400b30220406804709e416b78ff0 Mon Sep 17 00:00:00 2001
From: Roman Ludwig <48687784+rmnldwg@users.noreply.github.com>
Date: Tue, 19 Mar 2024 18:02:04 +0100
Subject: [PATCH 7/9] test: add checks for midline risk
Related: #80
---
tests/binary_midline_test.py | 151 ++++++++++++++++++++---------------
tests/fixtures.py | 39 ++++++++-
2 files changed, 122 insertions(+), 68 deletions(-)
diff --git a/tests/binary_midline_test.py b/tests/binary_midline_test.py
index 98caa80..72608d9 100644
--- a/tests/binary_midline_test.py
+++ b/tests/binary_midline_test.py
@@ -1,7 +1,6 @@
"""
Test the midline model for the binary case.
"""
-from typing import Literal
import numpy as np
import pandas as pd
@@ -11,32 +10,12 @@
from . import fixtures
-class MidlineSetParamsTestCase(fixtures.IgnoreWarningsTestCase):
+class MidlineSetParamsTestCase(
+ fixtures.MidlineFixtureMixin,
+ fixtures.IgnoreWarningsTestCase,
+):
"""Check that the complex parameter assignment works correctly."""
- def setUp(
- self,
- seed: int = 42,
- graph_size: Literal["small", "medium", "large"] = "small",
- use_mixing: bool = True,
- use_central: bool = True,
- is_symmetric: dict[str, bool] | None = None,
- ) -> None:
- super().setUp()
- self.rng = np.random.default_rng(seed)
- graph_dict = fixtures.get_graph(graph_size)
- if is_symmetric is None:
- is_symmetric = {"tumor_spread": False, "lnl_spread": True}
-
- self.model = models.Midline(
- graph_dict=graph_dict,
- is_symmetric=is_symmetric,
- use_mixing=use_mixing,
- use_central=use_central,
- use_midext_evo=False,
- )
-
-
def test_init(self) -> None:
"""Check some basic attributes."""
self.assertTrue(self.model.use_central)
@@ -82,49 +61,20 @@ def test_get_set_params_order(self) -> None:
self.assertEqual(params_to_set[:-1].tolist(), returned_params)
-class MidlineLikelihoodTestCase(fixtures.IgnoreWarningsTestCase):
+class MidlineLikelihoodTestCase(
+ fixtures.MidlineFixtureMixin,
+ fixtures.IgnoreWarningsTestCase,
+):
"""Check that the likelihood function works correctly."""
- def setUp(
- self,
- seed: int = 42,
- graph_size: Literal["small", "medium", "large"] = "small",
- use_mixing: bool = True,
- use_central: bool = False,
- use_midext_evo: bool = True,
- is_symmetric: dict[str, bool] | None = None,
- ) -> None:
+ def setUp(self) -> None:
+ """Set up the test case."""
super().setUp()
- self.rng = np.random.default_rng(seed)
- graph_dict = fixtures.get_graph(graph_size)
- if is_symmetric is None:
- is_symmetric = {"tumor_spread": False, "lnl_spread": True}
-
- self.model = models.Midline(
- graph_dict=graph_dict,
- is_symmetric=is_symmetric,
- use_mixing=use_mixing,
- use_central=use_central,
- use_midext_evo=use_midext_evo,
- )
- self.model.set_distribution(
- "early",
- fixtures.create_random_dist(
- type_="frozen",
- max_time=self.model.max_time,
- rng=self.rng,
- ),
- )
- self.model.set_distribution(
- "late",
- fixtures.create_random_dist(
- type_="parametric",
- max_time=self.model.max_time,
- rng=self.rng,
- ),
- )
+ self.init_diag_time_dists(early="frozen", late="parametric")
self.model.set_modality("pathology", spec=1., sens=1., kind="pathological")
- self.model.load_patient_data(pd.read_csv("./tests/data/2021-clb-oropharynx.csv", header=[0,1,2]))
+ self.model.load_patient_data(
+ pd.read_csv("./tests/data/2021-clb-oropharynx.csv", header=[0,1,2]),
+ )
def test_likelihood(self) -> None:
@@ -142,6 +92,79 @@ def test_likelihood(self) -> None:
self.assertLessEqual(self.model.likelihood(), 0)
+class MidlineRiskTestCase(
+ fixtures.MidlineFixtureMixin,
+ fixtures.IgnoreWarningsTestCase,
+):
+ """Check that the risk method works correctly."""
+
+ def setUp(self) -> None:
+ """Set up the test case."""
+ super().setUp()
+ self.init_diag_time_dists(early="frozen", late="parametric")
+ self.model.set_modality("pathology", spec=1., sens=1., kind="pathological")
+ self.model.set_params(
+ midext_prob=0.1,
+ ipsi_TtoII_spread=0.35,
+ ipsi_TtoIII_spread=0.0,
+ contra_TtoII_spread=0.05,
+ contra_TtoIII_spread=0.0,
+ mixing=0.5,
+ IItoIII_spread=0.1,
+ late_p=0.5,
+ )
+
+
+ def test_risk(self) -> None:
+ """Check that the risk method works correctly."""
+ plain_risk = self.model.risk()
+ self.assertEqual(plain_risk.shape, (4,4))
+ self.assertTrue(np.isclose(plain_risk.sum(), 1.0))
+ self.assertTrue(np.allclose(plain_risk[1,:], 0.))
+ self.assertTrue(np.allclose(plain_risk[:,1], 0.))
+
+ lnlIII_risk = self.model.risk(involvement={"ipsi": {"II": False, "III": True}})
+ self.assertTrue(np.isscalar(lnlIII_risk))
+ self.assertAlmostEqual(lnlIII_risk, 0.0)
+
+ ipsi_lnlII_risk = self.model.risk(involvement={"ipsi": {"II": True}})
+ contra_lnlII_risk = self.model.risk(involvement={"contra": {"II": True}})
+ self.assertGreater(ipsi_lnlII_risk, contra_lnlII_risk)
+ ext_contra_lnlII_risk = self.model.risk(
+ involvement={"contra": {"II": True}},
+ midext=True,
+ )
+ self.assertGreater(ipsi_lnlII_risk, ext_contra_lnlII_risk)
+ self.assertGreater(ext_contra_lnlII_risk, contra_lnlII_risk)
+ noext_contra_lnlII_risk = self.model.risk(
+ involvement={"contra": {"II": True}},
+ midext=False,
+ )
+ self.assertGreater(contra_lnlII_risk, noext_contra_lnlII_risk)
+ self.assertGreater(ext_contra_lnlII_risk, noext_contra_lnlII_risk)
+
+
+ def test_risk_given_state_dist(self) -> None:
+ """Check how providing a state distribution works correctly."""
+ state_dist_3d = self.model.state_dist(t_stage="early")
+ self.assertEqual(state_dist_3d.shape, (2, 4, 4))
+
+ risk_from_state_dist = self.model.risk(given_state_dist=state_dist_3d, midext=True)
+ risk_direct = self.model.risk(midext=True)
+ self.assertTrue(np.allclose(risk_from_state_dist, risk_direct))
+
+ state_dist_2d = state_dist_3d[0] / state_dist_3d[0].sum()
+ risk_from_state_dist = self.model.risk(given_state_dist=state_dist_2d)
+ risk_direct = self.model.risk(midext=False)
+ self.assertTrue(np.allclose(risk_from_state_dist, risk_direct))
+
+ state_dist_2d = state_dist_3d[1] / state_dist_3d[1].sum()
+ risk_from_state_dist = self.model.risk(given_state_dist=state_dist_2d)
+ risk_direct = self.model.risk(midext=True)
+ self.assertTrue(np.allclose(risk_from_state_dist, risk_direct))
+
+
+
class MidlineDrawPatientsTestCase(fixtures.IgnoreWarningsTestCase):
"""Check the data generation."""
diff --git a/tests/fixtures.py b/tests/fixtures.py
index 4b0903c..4273961 100644
--- a/tests/fixtures.py
+++ b/tests/fixtures.py
@@ -5,16 +5,15 @@
import unittest
import warnings
from pathlib import Path
-from typing import Any, Callable
+from typing import Any, Callable, Literal
import numpy as np
import pandas as pd
import scipy as sp
-import lymph
from lymph import diagnose_times
from lymph.modalities import Clinical, Modality, Pathological
-from lymph.models import Unilateral
+from lymph.models import Bilateral, Midline, Unilateral
from lymph.types import DataWarning, PatternType
MODALITIES = {
@@ -180,7 +179,7 @@ def setUp(self):
super().setUp()
self.rng = np.random.default_rng(42)
self.graph_dict = get_graph("large")
- self.model = lymph.models.Bilateral(graph_dict=self.graph_dict, **self.model_kwargs)
+ self.model = Bilateral(graph_dict=self.graph_dict, **self.model_kwargs)
self.init_diag_time_dists(early="frozen", late="parametric")
self.model.set_params(**self.create_random_params())
self.logger = get_logger(level=logging.INFO)
@@ -276,3 +275,35 @@ def load_patient_data(
filepath = Path(__file__).parent / "data" / filename
self.raw_data = pd.read_csv(filepath, header=[0,1,2])
self.model.load_patient_data(self.raw_data, side="ipsi")
+
+
+class MidlineFixtureMixin:
+ """Mixing for testing the Midline model."""
+
+ def setUp(
+ self,
+ seed: int = 42,
+ graph_size: Literal["small", "medium", "large"] = "small",
+ use_mixing: bool = True,
+ use_central: bool = False,
+ use_midext_evo: bool = True,
+ is_symmetric: dict[str, bool] | None = None,
+ ) -> None:
+ """Initialize rng and model."""
+ self.rng = np.random.default_rng(seed)
+ self.model = Midline(
+ graph_dict=get_graph(graph_size),
+ is_symmetric=is_symmetric or {"tumor_spread": False, "lnl_spread": True},
+ use_mixing=use_mixing,
+ use_central=use_central,
+ use_midext_evo=use_midext_evo,
+ )
+
+
+ def init_diag_time_dists(self, **dists) -> None:
+ """Init the diagnose time distributions."""
+ for t_stage, type_ in dists.items():
+ self.model.set_distribution(
+ t_stage,
+ create_random_dist(type_, self.model.max_time, self.rng),
+ )
From 07e608c5b5dce3bb24587a20a73b313362ebdbc8 Mon Sep 17 00:00:00 2001
From: Roman Ludwig <48687784+rmnldwg@users.noreply.github.com>
Date: Tue, 19 Mar 2024 18:09:33 +0100
Subject: [PATCH 8/9] test(mid): fix wrong assumption in risk test
---
tests/binary_midline_test.py | 3 ++-
1 file changed, 2 insertions(+), 1 deletion(-)
diff --git a/tests/binary_midline_test.py b/tests/binary_midline_test.py
index 72608d9..c949278 100644
--- a/tests/binary_midline_test.py
+++ b/tests/binary_midline_test.py
@@ -15,6 +15,8 @@ class MidlineSetParamsTestCase(
fixtures.IgnoreWarningsTestCase,
):
"""Check that the complex parameter assignment works correctly."""
+ def setUp(self):
+ return super().setUp(use_central=True, use_midext_evo=False)
def test_init(self) -> None:
"""Check some basic attributes."""
@@ -134,7 +136,6 @@ def test_risk(self) -> None:
involvement={"contra": {"II": True}},
midext=True,
)
- self.assertGreater(ipsi_lnlII_risk, ext_contra_lnlII_risk)
self.assertGreater(ext_contra_lnlII_risk, contra_lnlII_risk)
noext_contra_lnlII_risk = self.model.risk(
involvement={"contra": {"II": True}},
From 7ca44f9120a38f4504754e45c60256cfa41b7b58 Mon Sep 17 00:00:00 2001
From: Roman Ludwig <48687784+rmnldwg@users.noreply.github.com>
Date: Wed, 20 Mar 2024 18:06:10 +0100
Subject: [PATCH 9/9] chore: update changelog
---
CHANGELOG.md | 26 +++++++++++++++++++++++++-
1 file changed, 25 insertions(+), 1 deletion(-)
diff --git a/CHANGELOG.md b/CHANGELOG.md
index 160d36c..6d87b24 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -3,6 +3,28 @@
All notable changes to this project will be documented in this file.
+
+## [1.1.0] - 2024-03-20
+
+### Features
+
+- (**utils**) Add `safe_set_params()` function.\
+ This checks whether the params are a dict, list, or None and handles
+ them accordingly. Just a convencience method that helped refactor some methods.
+- Allow to pass state distributions to `posterior_state_dist()` and `risk()` methds. Fixes [#80].\
+ With this, one can use precomputed state distributions to speed up
+ computing the posterior or risk for multiple scenarios.
+
+### Refactor
+
+- Use `safe_set_params()` across models.
+
+### Testing
+
+- Add checks for midline risk. Related [#80].
+- (**mid**) Fix wrong assumption in risk test.
+
+
## [1.0.0] - 2024-03-18
@@ -604,7 +626,8 @@ Almost the entire API has changed. I'd therefore recommend to have a look at the
- add pre-commit hook to check commit msg
-[Unreleased]: https://github.com/rmnldwg/lymph/compare/1.0.0...HEAD
+[Unreleased]: https://github.com/rmnldwg/lymph/compare/1.1.0...HEAD
+[1.1.0]: https://github.com/rmnldwg/lymph/compare/1.0.0...1.1.0
[1.0.0]: https://github.com/rmnldwg/lymph/compare/1.0.0.rc2...1.0.0
[1.0.0.rc2]: https://github.com/rmnldwg/lymph/compare/1.0.0.rc1...1.0.0.rc2
[1.0.0.rc1]: https://github.com/rmnldwg/lymph/compare/1.0.0.a6...1.0.0.rc1
@@ -620,6 +643,7 @@ Almost the entire API has changed. I'd therefore recommend to have a look at the
[0.4.1]: https://github.com/rmnldwg/lymph/compare/0.4.0...0.4.1
[0.4.0]: https://github.com/rmnldwg/lymph/compare/0.3.10...0.4.0
+[#80]: https://github.com/rmnldwg/lymph/issues/80
[#79]: https://github.com/rmnldwg/lymph/issues/79
[#77]: https://github.com/rmnldwg/lymph/issues/77
[#74]: https://github.com/rmnldwg/lymph/issues/74