Skip to content

Commit

Permalink
Simplify code a bit by removing separate leaves accounting
Browse files Browse the repository at this point in the history
  • Loading branch information
aarmey committed Jan 31, 2024
1 parent 0a9893f commit 7b5fb01
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 31 deletions.
21 changes: 4 additions & 17 deletions lineage/LineageTree.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ class LineageTree:

pi: npt.NDArray[np.float64]
T: npt.NDArray[np.float64]
leaves_idx: npt.NDArray[np.uintp]
leaves_idx: np.ndarray
idx_by_gen: list[np.ndarray]
output_lineage: list[CellVar]
cell_to_parent: np.ndarray
Expand All @@ -26,10 +26,12 @@ def __init__(self, list_of_cells: list, E: list):
# assign times using the state distribution specific time model
E[0].assign_times(self.output_lineage)

self.leaves_idx = get_leaves_idx(self.output_lineage)
self.cell_to_parent = cell_to_parent(self.output_lineage)
self.cell_to_daughters = cell_to_daughters(self.output_lineage)

# Leaves have no daughters
self.leaves_idx = np.nonzero(np.all(self.cell_to_daughters == -1, axis=1))[0]

@classmethod
def rand_init(
cls,
Expand Down Expand Up @@ -227,18 +229,3 @@ def cell_to_daughters(lineage: list[CellVar]) -> np.ndarray:
output[ii, 1] = lineage.index(cell.right)

return output


def get_leaves_idx(lineage: list[CellVar]) -> npt.NDArray[np.uintp]:
"""
A function to find the leaves and their indexes in the lineage list.
:param lineage: The list of cells in a lineageTree object.
:return leaf_indices: The list of cell indexes.
:return leaves: The last cells in the lineage branch.
"""
leaf_indices = []
for index, cell in enumerate(lineage):
if cell.isLeaf():
assert cell.observed
leaf_indices.append(index) # appending the index of the cells
return np.array(leaf_indices, dtype=np.uintp)
4 changes: 2 additions & 2 deletions lineage/states/stateCommon.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import warnings
import numpy as np
from numba import njit
from numba import jit
from numba.typed import List
import numpy.typing as npt
from ctypes import CFUNCTYPE, c_double
Expand Down Expand Up @@ -51,7 +51,7 @@ def bern_estimator(bern_obs: np.ndarray, gammas: np.ndarray):
return numerator / denominator


@njit
@jit(nopython=True)
def gamma_LL(
logX: npt.NDArray[np.float64],
gamma_obs: List[npt.NDArray[np.float64]],
Expand Down
13 changes: 1 addition & 12 deletions lineage/tests/test_LineageTree.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import unittest
import numpy as np
from ..CellVar import CellVar as c
from ..LineageTree import LineageTree, max_gen, get_leaves_idx
from ..LineageTree import LineageTree, max_gen
from ..states.StateDistributionGamma import StateDistribution


Expand Down Expand Up @@ -125,14 +125,3 @@ def test_get_parent_for_level(self):
list_by_gen = max_gen(self.lineage1.output_lineage)
parent_ind_holder = np.unique(self.lineage1.cell_to_parent[list_by_gen[3]])
np.testing.assert_array_equal(parent_ind_holder, list_by_gen[2])

def test_get_leaves(self):
"""
A unittest fot get_leaves function.
"""
# getting the leaves and their indexes for lineage1
leaf_index = get_leaves_idx(self.lineage1.output_lineage)

# to check the indexes for leaf cells are true
for i in leaf_index:
self.assertTrue(self.lineage1.output_lineage[i].isLeaf())

0 comments on commit 7b5fb01

Please sign in to comment.