diff --git a/CHANGELOG.md b/CHANGELOG.md
index bcc8030..95b8748 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -2,6 +2,50 @@
All notable changes to this project will be documented in this file.
+
+
+## [1.2.1] - 2024-05-28
+
+### Bug Fixes
+
+- (**uni**) `load_patient_data` should accept `None`.
+- (**mid**) Correct type hint of `marginalize`.
+- (**graph**) Wrong dict when trinary.\
+ The `to_dict()` method returned a wrong graph dictionary when trinary
+ due to growth edges. This is fixed now.
+- Skip `marginalize` only when safe.\
+ The marginalization should only be skipped (and 1 returned), when the
+ entire disease state of interest is `None`. In the midline case, this
+ disease state includes the midline extension.\
+ Previously, only the involvement pattern was checked. Now, the model is
+ more careful about when to take shortcuts.
+
+
+### Features
+
+- (**graph**) Modify mermaid graph.\
+ The `get_mermaid()` and `get_mermaid_url()` methods now accept arguments
+ that allow some modifications of the output.
+- (**uni**) Add `__repr__()`.
+
+### Refactor
+
+- (**uni**) Use pandas `map` instead of `apply`.\
+ This saves us a couple of lines in the `load_patient_data` method and is
+ more readable.
+
+
+### Merge
+
+- Branch 'main' into 'dev'.
+
+### Remove
+
+- Remains of callbacks.\
+ Some callback functionality that was tested in a pre-release has been
+ forgotten in the code base and is now deleted.
+
+
## [1.2.0] - 2024-03-29
@@ -668,7 +712,8 @@ Almost the entire API has changed. I'd therefore recommend to have a look at the
- add pre-commit hook to check commit msg
-[Unreleased]: https://github.com/rmnldwg/lymph/compare/1.2.0...HEAD
+[Unreleased]: https://github.com/rmnldwg/lymph/compare/1.2.1...HEAD
+[1.2.1]: https://github.com/rmnldwg/lymph/compare/1.1.0...1.2.1
[1.2.0]: https://github.com/rmnldwg/lymph/compare/1.1.0...1.2.0
[1.1.0]: https://github.com/rmnldwg/lymph/compare/1.0.0...1.1.0
[1.0.0]: https://github.com/rmnldwg/lymph/compare/1.0.0.rc2...1.0.0
diff --git a/lymph/graph.py b/lymph/graph.py
index ffb79ac..a21a88e 100644
--- a/lymph/graph.py
+++ b/lymph/graph.py
@@ -14,16 +14,17 @@
import base64
import warnings
from itertools import product
+from typing import Literal
import numpy as np
+from lymph import types
from lymph.utils import (
check_unique_names,
comp_transition_tensor,
flatten,
popfirst,
set_params_for,
- trigger,
)
@@ -224,7 +225,6 @@ def __init__(
child: LymphNodeLevel,
spread_prob: float = 0.,
micro_mod: float = 1.,
- callbacks: list[callable] | None = None,
) -> None:
"""Create a new edge between two nodes.
@@ -235,10 +235,6 @@ def __init__(
spread to the next LNL. The ``micro_mod`` parameter is a modifier for the spread
probability in case of only a microscopic node involvement.
"""
- self.trigger_callbacks = []
- if callbacks is not None:
- self.trigger_callbacks += callbacks
-
self.parent: Tumor | LymphNodeLevel = parent
self.child: LymphNodeLevel = child
@@ -353,7 +349,6 @@ def get_micro_mod(self) -> float:
self._micro_mod = 1.
return self._micro_mod
- @trigger
def set_micro_mod(self, new_micro_mod: float | None) -> None:
"""Set the spread modifier for LNLs with microscopic involvement."""
if new_micro_mod is None:
@@ -380,7 +375,6 @@ def get_spread_prob(self) -> float:
self._spread_prob = 0.
return self._spread_prob
- @trigger
def set_spread_prob(self, new_spread_prob: float | None) -> None:
"""Set the spread probability of the edge."""
if new_spread_prob is None:
@@ -493,7 +487,6 @@ def __init__(
graph_dict: dict[tuple[str], list[str]],
tumor_state: int | None = None,
allowed_states: list[int] | None = None,
- on_edge_change: list[callable] | None = None,
) -> None:
"""Create a new graph representation of nodes and edges.
@@ -512,7 +505,7 @@ def __init__(
check_unique_names(graph_dict)
self._init_nodes(graph_dict, tumor_state, allowed_states)
- self._init_edges(graph_dict, on_edge_change)
+ self._init_edges(graph_dict)
def _init_nodes(self, graph, tumor_state, allowed_lnl_states):
@@ -585,7 +578,6 @@ def is_trinary(self) -> bool:
def _init_edges(
self,
graph: dict[tuple[str, str], list[str]],
- on_edge_change: list[callable]
) -> None:
"""Initialize the edges of the ``graph``.
@@ -602,12 +594,12 @@ def _init_edges(
for (_, start_name), end_names in graph.items():
start = self.nodes[start_name]
if isinstance(start, LymphNodeLevel) and start.is_trinary:
- growth_edge = Edge(parent=start, child=start, callbacks=on_edge_change)
+ growth_edge = Edge(parent=start, child=start)
self._edges[growth_edge.get_name()] = growth_edge
for end_name in end_names:
end = self.nodes[end_name]
- new_edge = Edge(parent=start, child=end, callbacks=on_edge_change)
+ new_edge = Edge(parent=start, child=end)
self._edges[new_edge.get_name()] = new_edge
@@ -669,11 +661,19 @@ def to_dict(self) -> dict[tuple[str, str], set[str]]:
res = {}
for node in self.nodes.values():
node_type = "tumor" if isinstance(node, Tumor) else "lnl"
- res[(node_type, node.name)] = [o.child.name for o in node.out]
+ res[(node_type, node.name)] = [
+ o.child.name
+ for o in node.out
+ if not o.is_growth
+ ]
return res
- def get_mermaid(self) -> str:
+ def get_mermaid(
+ self,
+ with_params: bool = True,
+ direction: Literal["TD", "LR"] = "TD",
+ ) -> str:
"""Prints the graph in mermaid format.
>>> graph_dict = {
@@ -691,19 +691,29 @@ def get_mermaid(self) -> str:
T-->|20%| III
II-->|30%| III
+ >>> print(graph.get_mermaid(with_params=False)) # doctest: +NORMALIZE_WHITESPACE
+ flowchart TD
+ T--> II
+ T--> III
+ II--> III
+
"""
- mermaid_graph = "flowchart TD\n"
+ mermaid_graph = f"flowchart {direction}\n"
for node in self.nodes.values():
for edge in node.out:
- mermaid_graph += f"\t{node.name}-->|{edge.spread_prob:.0%}| {edge.child.name}\n"
+ param_str = f"|{edge.spread_prob:.0%}|" if with_params else ""
+ mermaid_graph += f"\t{node.name}-->{param_str} {edge.child.name}\n"
return mermaid_graph
- def get_mermaid_url(self) -> str:
- """Returns the URL to the rendered graph."""
- mermaid_graph = self.get_mermaid()
+ def get_mermaid_url(self, **mermaid_kwargs) -> str:
+ """Returns the URL to the rendered graph.
+
+ Keyword arguments are passed to :py:meth:`~Representation.get_mermaid`.
+ """
+ mermaid_graph = self.get_mermaid(**mermaid_kwargs)
graphbytes = mermaid_graph.encode("ascii")
base64_bytes = base64.b64encode(graphbytes)
base64_string = base64_bytes.decode("ascii")
diff --git a/lymph/models/bilateral.py b/lymph/models/bilateral.py
index eef5174..2ecaa15 100644
--- a/lymph/models/bilateral.py
+++ b/lymph/models/bilateral.py
@@ -626,7 +626,7 @@ def marginalize(
are ignored if ``given_state_dist`` is provided.
"""
if involvement is None:
- return 1.
+ involvement = {}
if given_state_dist is None:
given_state_dist = self.state_dist(t_stage=t_stage, mode=mode)
diff --git a/lymph/models/midline.py b/lymph/models/midline.py
index 0a31a0c..74bc43b 100644
--- a/lymph/models/midline.py
+++ b/lymph/models/midline.py
@@ -753,7 +753,7 @@ def posterior_state_dist(
def marginalize(
self,
- involvement: types.PatternType | None = None,
+ involvement: dict[str, types.PatternType] | None = None,
given_state_dist: np.ndarray | None = None,
t_stage: str = "early",
mode: Literal["HMM", "BN"] = "HMM",
@@ -770,7 +770,7 @@ def marginalize(
:py:meth:`.state_dist` method.
"""
if involvement is None:
- return 1.
+ involvement = {}
if given_state_dist is None:
given_state_dist = self.state_dist(t_stage=t_stage, mode=mode, central=central)
@@ -787,7 +787,7 @@ def marginalize(
given_state_dist = given_state_dist[int(midext)]
# I think I don't need to normalize here, since I am not computing a
# probability of something *given* midext, but only sum up all states that
- # match the involvement pattern (which includes the midext status).
+ # match the disease state of interest (which includes the midext status).
return self.ext.marginalize(
involvement=involvement,
diff --git a/lymph/models/unilateral.py b/lymph/models/unilateral.py
index 2d9662e..56644fd 100644
--- a/lymph/models/unilateral.py
+++ b/lymph/models/unilateral.py
@@ -2,7 +2,7 @@
import warnings
from itertools import product
-from typing import Any, Iterable, Literal
+from typing import Any, Callable, Iterable, Literal
import numpy as np
import pandas as pd
@@ -117,6 +117,17 @@ def trinary(cls, graph_dict: types.GraphDictType, **kwargs) -> Unilateral:
return cls(graph_dict, allowed_states=[0, 1, 2], **kwargs)
+ def __repr__(self) -> str:
+ """Return a string representation of the instance."""
+ return (
+ f"{type(self).__name__}("
+ f"graph_dict={self.graph.to_dict()}, "
+ f"tumor_state={list(self.graph.tumors.values())[0].state}, "
+ f"allowed_states={self.graph.allowed_states}, "
+ f"max_time={self.max_time})"
+ )
+
+
def __str__(self) -> str:
"""Print info about the instance."""
return f"Unilateral with {len(self.graph.tumors)} tumors and {len(self.graph.lnls)} LNLs"
@@ -489,7 +500,7 @@ def load_patient_data(
self,
patient_data: pd.DataFrame,
side: str = "ipsi",
- mapping: callable | dict[int, Any] | None = None,
+ mapping: Callable[[int], Any] | dict[int, Any] | None = None,
) -> None:
"""Load patient data in `LyProX`_ format into the model.
@@ -512,7 +523,6 @@ def load_patient_data(
if mapping is None:
mapping = early_late_mapping
- # pylint: disable=unnecessary-lambda-assignment
patient_data = (
patient_data
.copy()
@@ -545,15 +555,7 @@ def load_patient_data(
patient_data["_model", modality, lnl] = column
- if len(patient_data) == 0:
- patient_data[MAP_T_COL] = None
- else:
- mapping = dict_to_func(mapping) if isinstance(mapping, dict) else mapping
- patient_data[MAP_T_COL] = patient_data.apply(
- lambda row: mapping(row[RAW_T_COL]),
- axis=1,
- )
-
+ patient_data[MAP_T_COL] = patient_data[RAW_T_COL].map(mapping)
self._patient_data = patient_data
self._cache_version += 1
@@ -833,7 +835,11 @@ def marginalize(
:py:meth:`.state_dist` with the given ``t_stage`` and ``mode``. These arguments
are ignored if ``given_state_dist`` is provided.
"""
- if involvement is None:
+ if (
+ involvement is None
+ or not involvement # empty dict is falsey
+ or all(value is None for value in involvement.values())
+ ):
return 1.
if given_state_dist is None:
diff --git a/tests/edge_test.py b/tests/edge_test.py
index bb09155..482afe6 100644
--- a/tests/edge_test.py
+++ b/tests/edge_test.py
@@ -13,12 +13,7 @@ def setUp(self) -> None:
super().setUp()
parent = graph.LymphNodeLevel("parent")
child = graph.LymphNodeLevel("child")
- self.was_called = False
- self.edge = graph.Edge(parent, child, callbacks=[self.callback])
-
- def callback(self) -> None:
- """Callback function for the edge."""
- self.was_called = True
+ self.edge = graph.Edge(parent, child)
def test_str(self) -> None:
"""Test the string representation of the edge."""
@@ -41,17 +36,11 @@ def test_repr(self) -> None:
self.assertEqual(self.edge.spread_prob, recreated_edge.spread_prob)
self.assertEqual(self.edge.micro_mod, recreated_edge.micro_mod)
- def test_callback_on_param_change(self) -> None:
- """Test if the callback function is called."""
- self.edge.spread_prob = 0.5
- self.assertTrue(self.was_called)
-
def test_graph_change(self) -> None:
"""Check if the callback also works when parent/child nodes are changed."""
old_child = self.edge.child
new_child = graph.LymphNodeLevel("new_child")
self.edge.child = new_child
- self.assertTrue(self.was_called)
self.assertNotIn(self.edge, old_child.inc)
def test_transition_tensor_row_sums(self) -> None:
diff --git a/tests/graph_representation_test.py b/tests/graph_representation_test.py
index 5eb55df..4a8010c 100644
--- a/tests/graph_representation_test.py
+++ b/tests/graph_representation_test.py
@@ -34,15 +34,10 @@ def setUp(self) -> None:
self.graph_repr = graph.Representation(
graph_dict=self.graph_dict,
allowed_states=[0, 1],
- on_edge_change=[self.callback],
)
self.was_called = False
self.rng = np.random.default_rng(42)
- def callback(self) -> None:
- """Callback function for the graph."""
- self.was_called = True
-
def test_nodes(self) -> None:
"""Test the number of nodes."""
self.assertEqual(len(self.graph_repr.nodes), len(self.graph_dict))