diff --git a/lineage/LineageTree.py b/lineage/LineageTree.py index 57c6bc8e0..9ce659305 100644 --- a/lineage/LineageTree.py +++ b/lineage/LineageTree.py @@ -12,7 +12,7 @@ class LineageTree: pi: npt.NDArray[np.float64] T: npt.NDArray[np.float64] - leaves_idx: npt.NDArray[np.uintp] + leaves_idx: np.ndarray idx_by_gen: list[np.ndarray] output_lineage: list[CellVar] cell_to_parent: np.ndarray @@ -26,10 +26,12 @@ def __init__(self, list_of_cells: list, E: list): # assign times using the state distribution specific time model E[0].assign_times(self.output_lineage) - self.leaves_idx = get_leaves_idx(self.output_lineage) self.cell_to_parent = cell_to_parent(self.output_lineage) self.cell_to_daughters = cell_to_daughters(self.output_lineage) + # Leaves have no daughters + self.leaves_idx = np.nonzero(np.all(self.cell_to_daughters == -1, axis=1))[0] + @classmethod def rand_init( cls, @@ -227,18 +229,3 @@ def cell_to_daughters(lineage: list[CellVar]) -> np.ndarray: output[ii, 1] = lineage.index(cell.right) return output - - -def get_leaves_idx(lineage: list[CellVar]) -> npt.NDArray[np.uintp]: - """ - A function to find the leaves and their indexes in the lineage list. - :param lineage: The list of cells in a lineageTree object. - :return leaf_indices: The list of cell indexes. - :return leaves: The last cells in the lineage branch. - """ - leaf_indices = [] - for index, cell in enumerate(lineage): - if cell.isLeaf(): - assert cell.observed - leaf_indices.append(index) # appending the index of the cells - return np.array(leaf_indices, dtype=np.uintp) diff --git a/lineage/states/stateCommon.py b/lineage/states/stateCommon.py index ef410fc67..1e47eba45 100644 --- a/lineage/states/stateCommon.py +++ b/lineage/states/stateCommon.py @@ -2,7 +2,7 @@ import warnings import numpy as np -from numba import njit +from numba import jit from numba.typed import List import numpy.typing as npt from ctypes import CFUNCTYPE, c_double @@ -51,7 +51,7 @@ def bern_estimator(bern_obs: np.ndarray, gammas: np.ndarray): return numerator / denominator -@njit +@jit(nopython=True) def gamma_LL( logX: npt.NDArray[np.float64], gamma_obs: List[npt.NDArray[np.float64]], diff --git a/lineage/tests/test_LineageTree.py b/lineage/tests/test_LineageTree.py index 1a3b25440..395657ade 100644 --- a/lineage/tests/test_LineageTree.py +++ b/lineage/tests/test_LineageTree.py @@ -2,7 +2,7 @@ import unittest import numpy as np from ..CellVar import CellVar as c -from ..LineageTree import LineageTree, max_gen, get_leaves_idx +from ..LineageTree import LineageTree, max_gen from ..states.StateDistributionGamma import StateDistribution @@ -125,14 +125,3 @@ def test_get_parent_for_level(self): list_by_gen = max_gen(self.lineage1.output_lineage) parent_ind_holder = np.unique(self.lineage1.cell_to_parent[list_by_gen[3]]) np.testing.assert_array_equal(parent_ind_holder, list_by_gen[2]) - - def test_get_leaves(self): - """ - A unittest fot get_leaves function. - """ - # getting the leaves and their indexes for lineage1 - leaf_index = get_leaves_idx(self.lineage1.output_lineage) - - # to check the indexes for leaf cells are true - for i in leaf_index: - self.assertTrue(self.lineage1.output_lineage[i].isLeaf())