diff --git a/CHANGELOG.md b/CHANGELOG.md index bcc8030..95b8748 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,50 @@ All notable changes to this project will be documented in this file. + + +## [1.2.1] - 2024-05-28 + +### Bug Fixes + +- (**uni**) `load_patient_data` should accept `None`. +- (**mid**) Correct type hint of `marginalize`. +- (**graph**) Wrong dict when trinary.\ + The `to_dict()` method returned a wrong graph dictionary when trinary + due to growth edges. This is fixed now. +- Skip `marginalize` only when safe.\ + The marginalization should only be skipped (and 1 returned), when the + entire disease state of interest is `None`. In the midline case, this + disease state includes the midline extension.\ + Previously, only the involvement pattern was checked. Now, the model is + more careful about when to take shortcuts. + + +### Features + +- (**graph**) Modify mermaid graph.\ + The `get_mermaid()` and `get_mermaid_url()` methods now accept arguments + that allow some modifications of the output. +- (**uni**) Add `__repr__()`. + +### Refactor + +- (**uni**) Use pandas `map` instead of `apply`.\ + This saves us a couple of lines in the `load_patient_data` method and is + more readable. + + +### Merge + +- Branch 'main' into 'dev'. + +### Remove + +- Remains of callbacks.\ + Some callback functionality that was tested in a pre-release has been + forgotten in the code base and is now deleted. + + ## [1.2.0] - 2024-03-29 @@ -668,7 +712,8 @@ Almost the entire API has changed. I'd therefore recommend to have a look at the - add pre-commit hook to check commit msg -[Unreleased]: https://github.com/rmnldwg/lymph/compare/1.2.0...HEAD +[Unreleased]: https://github.com/rmnldwg/lymph/compare/1.2.1...HEAD +[1.2.1]: https://github.com/rmnldwg/lymph/compare/1.1.0...1.2.1 [1.2.0]: https://github.com/rmnldwg/lymph/compare/1.1.0...1.2.0 [1.1.0]: https://github.com/rmnldwg/lymph/compare/1.0.0...1.1.0 [1.0.0]: https://github.com/rmnldwg/lymph/compare/1.0.0.rc2...1.0.0 diff --git a/lymph/graph.py b/lymph/graph.py index ffb79ac..a21a88e 100644 --- a/lymph/graph.py +++ b/lymph/graph.py @@ -14,16 +14,17 @@ import base64 import warnings from itertools import product +from typing import Literal import numpy as np +from lymph import types from lymph.utils import ( check_unique_names, comp_transition_tensor, flatten, popfirst, set_params_for, - trigger, ) @@ -224,7 +225,6 @@ def __init__( child: LymphNodeLevel, spread_prob: float = 0., micro_mod: float = 1., - callbacks: list[callable] | None = None, ) -> None: """Create a new edge between two nodes. @@ -235,10 +235,6 @@ def __init__( spread to the next LNL. The ``micro_mod`` parameter is a modifier for the spread probability in case of only a microscopic node involvement. """ - self.trigger_callbacks = [] - if callbacks is not None: - self.trigger_callbacks += callbacks - self.parent: Tumor | LymphNodeLevel = parent self.child: LymphNodeLevel = child @@ -353,7 +349,6 @@ def get_micro_mod(self) -> float: self._micro_mod = 1. return self._micro_mod - @trigger def set_micro_mod(self, new_micro_mod: float | None) -> None: """Set the spread modifier for LNLs with microscopic involvement.""" if new_micro_mod is None: @@ -380,7 +375,6 @@ def get_spread_prob(self) -> float: self._spread_prob = 0. return self._spread_prob - @trigger def set_spread_prob(self, new_spread_prob: float | None) -> None: """Set the spread probability of the edge.""" if new_spread_prob is None: @@ -493,7 +487,6 @@ def __init__( graph_dict: dict[tuple[str], list[str]], tumor_state: int | None = None, allowed_states: list[int] | None = None, - on_edge_change: list[callable] | None = None, ) -> None: """Create a new graph representation of nodes and edges. @@ -512,7 +505,7 @@ def __init__( check_unique_names(graph_dict) self._init_nodes(graph_dict, tumor_state, allowed_states) - self._init_edges(graph_dict, on_edge_change) + self._init_edges(graph_dict) def _init_nodes(self, graph, tumor_state, allowed_lnl_states): @@ -585,7 +578,6 @@ def is_trinary(self) -> bool: def _init_edges( self, graph: dict[tuple[str, str], list[str]], - on_edge_change: list[callable] ) -> None: """Initialize the edges of the ``graph``. @@ -602,12 +594,12 @@ def _init_edges( for (_, start_name), end_names in graph.items(): start = self.nodes[start_name] if isinstance(start, LymphNodeLevel) and start.is_trinary: - growth_edge = Edge(parent=start, child=start, callbacks=on_edge_change) + growth_edge = Edge(parent=start, child=start) self._edges[growth_edge.get_name()] = growth_edge for end_name in end_names: end = self.nodes[end_name] - new_edge = Edge(parent=start, child=end, callbacks=on_edge_change) + new_edge = Edge(parent=start, child=end) self._edges[new_edge.get_name()] = new_edge @@ -669,11 +661,19 @@ def to_dict(self) -> dict[tuple[str, str], set[str]]: res = {} for node in self.nodes.values(): node_type = "tumor" if isinstance(node, Tumor) else "lnl" - res[(node_type, node.name)] = [o.child.name for o in node.out] + res[(node_type, node.name)] = [ + o.child.name + for o in node.out + if not o.is_growth + ] return res - def get_mermaid(self) -> str: + def get_mermaid( + self, + with_params: bool = True, + direction: Literal["TD", "LR"] = "TD", + ) -> str: """Prints the graph in mermaid format. >>> graph_dict = { @@ -691,19 +691,29 @@ def get_mermaid(self) -> str: T-->|20%| III II-->|30%| III + >>> print(graph.get_mermaid(with_params=False)) # doctest: +NORMALIZE_WHITESPACE + flowchart TD + T--> II + T--> III + II--> III + """ - mermaid_graph = "flowchart TD\n" + mermaid_graph = f"flowchart {direction}\n" for node in self.nodes.values(): for edge in node.out: - mermaid_graph += f"\t{node.name}-->|{edge.spread_prob:.0%}| {edge.child.name}\n" + param_str = f"|{edge.spread_prob:.0%}|" if with_params else "" + mermaid_graph += f"\t{node.name}-->{param_str} {edge.child.name}\n" return mermaid_graph - def get_mermaid_url(self) -> str: - """Returns the URL to the rendered graph.""" - mermaid_graph = self.get_mermaid() + def get_mermaid_url(self, **mermaid_kwargs) -> str: + """Returns the URL to the rendered graph. + + Keyword arguments are passed to :py:meth:`~Representation.get_mermaid`. + """ + mermaid_graph = self.get_mermaid(**mermaid_kwargs) graphbytes = mermaid_graph.encode("ascii") base64_bytes = base64.b64encode(graphbytes) base64_string = base64_bytes.decode("ascii") diff --git a/lymph/models/bilateral.py b/lymph/models/bilateral.py index eef5174..2ecaa15 100644 --- a/lymph/models/bilateral.py +++ b/lymph/models/bilateral.py @@ -626,7 +626,7 @@ def marginalize( are ignored if ``given_state_dist`` is provided. """ if involvement is None: - return 1. + involvement = {} if given_state_dist is None: given_state_dist = self.state_dist(t_stage=t_stage, mode=mode) diff --git a/lymph/models/midline.py b/lymph/models/midline.py index 0a31a0c..74bc43b 100644 --- a/lymph/models/midline.py +++ b/lymph/models/midline.py @@ -753,7 +753,7 @@ def posterior_state_dist( def marginalize( self, - involvement: types.PatternType | None = None, + involvement: dict[str, types.PatternType] | None = None, given_state_dist: np.ndarray | None = None, t_stage: str = "early", mode: Literal["HMM", "BN"] = "HMM", @@ -770,7 +770,7 @@ def marginalize( :py:meth:`.state_dist` method. """ if involvement is None: - return 1. + involvement = {} if given_state_dist is None: given_state_dist = self.state_dist(t_stage=t_stage, mode=mode, central=central) @@ -787,7 +787,7 @@ def marginalize( given_state_dist = given_state_dist[int(midext)] # I think I don't need to normalize here, since I am not computing a # probability of something *given* midext, but only sum up all states that - # match the involvement pattern (which includes the midext status). + # match the disease state of interest (which includes the midext status). return self.ext.marginalize( involvement=involvement, diff --git a/lymph/models/unilateral.py b/lymph/models/unilateral.py index 2d9662e..56644fd 100644 --- a/lymph/models/unilateral.py +++ b/lymph/models/unilateral.py @@ -2,7 +2,7 @@ import warnings from itertools import product -from typing import Any, Iterable, Literal +from typing import Any, Callable, Iterable, Literal import numpy as np import pandas as pd @@ -117,6 +117,17 @@ def trinary(cls, graph_dict: types.GraphDictType, **kwargs) -> Unilateral: return cls(graph_dict, allowed_states=[0, 1, 2], **kwargs) + def __repr__(self) -> str: + """Return a string representation of the instance.""" + return ( + f"{type(self).__name__}(" + f"graph_dict={self.graph.to_dict()}, " + f"tumor_state={list(self.graph.tumors.values())[0].state}, " + f"allowed_states={self.graph.allowed_states}, " + f"max_time={self.max_time})" + ) + + def __str__(self) -> str: """Print info about the instance.""" return f"Unilateral with {len(self.graph.tumors)} tumors and {len(self.graph.lnls)} LNLs" @@ -489,7 +500,7 @@ def load_patient_data( self, patient_data: pd.DataFrame, side: str = "ipsi", - mapping: callable | dict[int, Any] | None = None, + mapping: Callable[[int], Any] | dict[int, Any] | None = None, ) -> None: """Load patient data in `LyProX`_ format into the model. @@ -512,7 +523,6 @@ def load_patient_data( if mapping is None: mapping = early_late_mapping - # pylint: disable=unnecessary-lambda-assignment patient_data = ( patient_data .copy() @@ -545,15 +555,7 @@ def load_patient_data( patient_data["_model", modality, lnl] = column - if len(patient_data) == 0: - patient_data[MAP_T_COL] = None - else: - mapping = dict_to_func(mapping) if isinstance(mapping, dict) else mapping - patient_data[MAP_T_COL] = patient_data.apply( - lambda row: mapping(row[RAW_T_COL]), - axis=1, - ) - + patient_data[MAP_T_COL] = patient_data[RAW_T_COL].map(mapping) self._patient_data = patient_data self._cache_version += 1 @@ -833,7 +835,11 @@ def marginalize( :py:meth:`.state_dist` with the given ``t_stage`` and ``mode``. These arguments are ignored if ``given_state_dist`` is provided. """ - if involvement is None: + if ( + involvement is None + or not involvement # empty dict is falsey + or all(value is None for value in involvement.values()) + ): return 1. if given_state_dist is None: diff --git a/tests/edge_test.py b/tests/edge_test.py index bb09155..482afe6 100644 --- a/tests/edge_test.py +++ b/tests/edge_test.py @@ -13,12 +13,7 @@ def setUp(self) -> None: super().setUp() parent = graph.LymphNodeLevel("parent") child = graph.LymphNodeLevel("child") - self.was_called = False - self.edge = graph.Edge(parent, child, callbacks=[self.callback]) - - def callback(self) -> None: - """Callback function for the edge.""" - self.was_called = True + self.edge = graph.Edge(parent, child) def test_str(self) -> None: """Test the string representation of the edge.""" @@ -41,17 +36,11 @@ def test_repr(self) -> None: self.assertEqual(self.edge.spread_prob, recreated_edge.spread_prob) self.assertEqual(self.edge.micro_mod, recreated_edge.micro_mod) - def test_callback_on_param_change(self) -> None: - """Test if the callback function is called.""" - self.edge.spread_prob = 0.5 - self.assertTrue(self.was_called) - def test_graph_change(self) -> None: """Check if the callback also works when parent/child nodes are changed.""" old_child = self.edge.child new_child = graph.LymphNodeLevel("new_child") self.edge.child = new_child - self.assertTrue(self.was_called) self.assertNotIn(self.edge, old_child.inc) def test_transition_tensor_row_sums(self) -> None: diff --git a/tests/graph_representation_test.py b/tests/graph_representation_test.py index 5eb55df..4a8010c 100644 --- a/tests/graph_representation_test.py +++ b/tests/graph_representation_test.py @@ -34,15 +34,10 @@ def setUp(self) -> None: self.graph_repr = graph.Representation( graph_dict=self.graph_dict, allowed_states=[0, 1], - on_edge_change=[self.callback], ) self.was_called = False self.rng = np.random.default_rng(42) - def callback(self) -> None: - """Callback function for the graph.""" - self.was_called = True - def test_nodes(self) -> None: """Test the number of nodes.""" self.assertEqual(len(self.graph_repr.nodes), len(self.graph_dict))