Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Progress graph #15

Merged
merged 7 commits into from
Apr 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,8 @@ install_requires =
magicgui
qtpy
scikit-image
motile_toolbox
motile @ git+https://github.com/funkelab/motile.git@5fdc0247a13a58b5cb49f808daf9d37719481467
motile_toolbox @ git+https://github.com/funkelab/motile_toolbox.git@4221b06eb1bd1485f709ca62c45482847e5488f7
pydantic

python_requires = >=3.8
Expand Down
44 changes: 33 additions & 11 deletions src/motile_plugin/backend/motile_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,21 +13,19 @@
IN_SEG_FILEANME = "input_segmentation.npy"
OUT_SEG_FILEANME = "output_segmentation.npy"
TRACKS_FILENAME = "tracks.json"
GAPS_FILENAME = "gaps.txt"


class MotileRun(BaseModel):
"""An object representing a motile tracking run. It is frozen because a completed
run cannot be mutated.
TODO: lazy loading from zarr, don't re-save existing input zarr
(e.g. if its a chunk from a bigger zarr)
TODO: Do we need BaseModel? It requires kwargs which is mildly annoying
"""An object representing a motile tracking run.
"""
run_name: str = Field()
solver_params: SolverParams = Field()
input_segmentation: np.ndarray | None = Field(None)
output_segmentation: np.ndarray | None = Field(None)
tracks: nx.DiGraph | None = Field(None)
time: datetime = Field(datetime.now())
run_name: str
solver_params: SolverParams
input_segmentation: np.ndarray | None = None
output_segmentation: np.ndarray | None = None
tracks: nx.DiGraph | None = None
time: datetime = datetime.now()
gaps: list[float] = []

class Config:
allow_mutation = False
Expand Down Expand Up @@ -57,6 +55,8 @@ def save(self, base_path: str | Path):
self._save_segmentation(run_dir, IN_SEG_FILEANME, self.input_segmentation)
self._save_segmentation(run_dir, OUT_SEG_FILEANME, self.output_segmentation)
self._save_tracks(run_dir)
if self.gaps is not None:
self._save_gaps(run_dir)

@classmethod
def load(cls, run_dir: Path | str):
Expand All @@ -67,13 +67,15 @@ def load(cls, run_dir: Path | str):
input_segmentation = cls._load_segmentation(run_dir, IN_SEG_FILEANME)
output_segmentation = cls._load_segmentation(run_dir, OUT_SEG_FILEANME)
tracks = cls._load_tracks(run_dir)
gaps = cls._load_gaps(run_dir)
return cls(
run_name=run_name,
solver_params=params,
input_segmentation=input_segmentation,
output_segmentation=output_segmentation,
tracks=tracks,
time=time,
gaps=gaps,
)

def _save_params(self, run_dir):
Expand Down Expand Up @@ -123,6 +125,25 @@ def _load_tracks(run_dir: Path, required: bool = True) -> nx.DiGraph:
else:
return None

def _save_gaps(self, run_dir: Path):
gaps_file = run_dir / GAPS_FILENAME
with open(gaps_file, 'w') as f:
f.write(','.join(map(str, self.gaps)))

@staticmethod
def _load_gaps(run_dir, required: bool = True) -> list[float]:
gaps_file = run_dir / GAPS_FILENAME
if gaps_file.is_file():
with open(gaps_file) as f:
gaps = list(map(float, f.read().split(",")))
return gaps
elif required:
raise FileNotFoundError(f"No gaps found at {gaps_file}")
else:
return None



def delete(self, base_path: str | Path):
base_path = Path(base_path)
run_dir = base_path / self._make_directory(self.time, self.run_name)
Expand All @@ -132,5 +153,6 @@ def delete(self, base_path: str | Path):
(run_dir / IN_SEG_FILEANME).unlink()
(run_dir / OUT_SEG_FILEANME).unlink()
(run_dir / TRACKS_FILENAME).unlink()
(run_dir / GAPS_FILENAME).unlink()
run_dir.rmdir()

48 changes: 26 additions & 22 deletions src/motile_plugin/backend/solve.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,16 @@

import logging
import time

import numpy as np
from motile import Solver, TrackGraph
from motile.constraints import MaxChildren, MaxParents, ExclusiveNodes
from motile.costs import Appear, Disappear, EdgeSelection, Split
from motile.constraints import ExclusiveNodes, MaxChildren, MaxParents
from motile.costs import Appear, Disappear, EdgeDistance, EdgeSelection, Split
from motile_toolbox.candidate_graph import (
EdgeAttr,
NodeAttr,
get_candidate_graph,
graph_to_nx,
)
import numpy as np

from .solver_params import SolverParams

Expand All @@ -22,17 +21,17 @@ def solve(
solver_params: SolverParams,
segmentation: np.ndarray,
on_solver_update=None,
):
):
cand_graph, conflict_sets = get_candidate_graph(
segmentation,
solver_params.max_edge_distance,
iou=solver_params.iou_weight is not None,
)
logger.debug(f"Cand graph has {cand_graph.number_of_nodes()} nodes")
logger.debug("Cand graph has %d nodes", cand_graph.number_of_nodes())
solver = construct_solver(cand_graph, solver_params, conflict_sets)
start_time = time.time()
solution = solver.solve(verbose=False, on_event=on_solver_update)
logger.info(f"Solution took {time.time() - start_time} seconds")
logger.info("Solution took %.2f seconds", time.time() - start_time)

solution_graph = solver.get_selected_subgraph(solution=solution)
solution_nx_graph = graph_to_nx(solution_graph)
Expand All @@ -41,9 +40,11 @@ def solve(


def construct_solver(cand_graph, solver_params, exclusive_sets):
solver = Solver(TrackGraph(cand_graph, frame_attribute=NodeAttr.TIME.value))
solver = Solver(
TrackGraph(cand_graph, frame_attribute=NodeAttr.TIME.value)
)
solver.add_constraints(MaxChildren(solver_params.max_children))
solver.add_constraints(MaxParents(solver_params.max_parents))
solver.add_constraints(MaxParents(1))
if exclusive_sets is None or len(exclusive_sets) > 0:
solver.add_constraints(ExclusiveNodes(exclusive_sets))

Expand All @@ -53,20 +54,23 @@ def construct_solver(cand_graph, solver_params, exclusive_sets):
solver.add_costs(Disappear(solver_params.disappear_cost))
if solver_params.division_cost is not None:
solver.add_costs(Split(constant=solver_params.division_cost))
if solver_params.merge_cost is not None:
from motile.costs import Merge
solver.add_costs(Merge(constant=solver_params.merge_cost))

if solver_params.distance_weight is not None:
solver.add_costs(EdgeSelection(
solver_params.distance_weight,
attribute=EdgeAttr.DISTANCE.value,
constant=solver_params.distance_offset), name="distance")
solver.add_costs(
EdgeDistance(
position_attribute=NodeAttr.POS.value,
weight=solver_params.distance_weight,
constant=solver_params.distance_offset,
),
name="distance",
)
if solver_params.iou_weight is not None:
solver.add_costs(EdgeSelection(
weight=solver_params.iou_weight,
attribute=EdgeAttr.IOU.value,
constant=solver_params.iou_offset), name="iou")
solver.add_costs(
EdgeSelection(
weight=solver_params.iou_weight,
attribute=EdgeAttr.IOU.value,
constant=solver_params.iou_offset,
),
name="iou",
)
return solver


35 changes: 13 additions & 22 deletions src/motile_plugin/backend/solver_params.py
Original file line number Diff line number Diff line change
@@ -1,72 +1,63 @@

from pydantic import BaseModel, Field


class SolverParams(BaseModel):
"""The set of solver parameters supported in the motile plugin.
Used to build the UI as well as store parameters for runs.
"""

max_edge_distance: float = Field(
50.0,
title="Max Move Distance",
description=r"""The maximum distance an object center can move between time frames.
Objects further than this cannot be matched, but making this value larger will increase solving time."""
Objects further than this cannot be matched, but making this value larger will increase solving time.""",
)
max_children: int = Field(
2,
title="Max Children",
description="The maximum number of object in time t+1 that can be linked to an item in time t.\nIf no division, set to 1."
)
max_parents: int = Field(
1,
title="Max Parent",
description=r"""The maximum number of object in time t that can be linked to an item in time t+1.
If no merging, set to 1."""
description="The maximum number of object in time t+1 that can be linked to an item in time t.\nIf no division, set to 1.",
)
appear_cost: float | None = Field(
30,
title="Appear Cost",
description=r"""Cost for starting a new track. A higher value means fewer selected tracks and fewer merges."""
description=r"""Cost for starting a new track. A higher value means fewer selected tracks and fewer merges.""",
)
disappear_cost: float | None = Field(
30,
title="Disappear Cost",
description=r"""Cost for ending a track. A higher value means fewer selected tracks and fewer divisions."""
description=r"""Cost for ending a track. A higher value means fewer selected tracks and fewer divisions.""",
)
division_cost: float | None = Field(
20,
title="Division Cost",
description=r"""Cost for a track dividing. A higher value means fewer divisions.
If this cost is higher than the appear cost, tracks will likely never divide."""
)
merge_cost: float | None = Field(
20,
title="Merge Cost",
description=r"""Cost for a track merging. A higher value means fewer merges.
If this cost is higher than the disappear cost, tracks will likely never merge."""
If this cost is higher than the appear cost, tracks will likely never divide.""",
)
distance_weight: float | None = Field(
1,
title="Distance Weight",
description=r"""Value multiplied by distance to create the cost for selecting an edge.
The weight should generally be positive, because a higher distance should be a higher cost.
The magnitude of the weight helps balance distance with other costs."""
The magnitude of the weight helps balance distance with other costs.""",
)
distance_offset: float | None = Field(
-20,
title="Distance Offset",
description=r"""Value added to (distance * weight) to create the cost for selecting an edge.
Usually should be negative to encourage anything being selected.
(If all costs are positive, the optimal solution is selecting nothing.)"""
(If all costs are positive, the optimal solution is selecting nothing.)""",
)
iou_weight: float | None = Field(
-5,
title="IOU Weight",
description=r"""Value multiplied by IOU to create cost.
The weight should generally be negative, because a higher IOU should be a lower cost.
The magnitude of the weight helps balance IOU with other costs."""
The magnitude of the weight helps balance IOU with other costs.""",
)
iou_offset: float | None = Field(
0,
title="IOU Offset",
description=r"""Value added to (IOU * weight) to create cost.
Zero is a sensible default with a negative weight, and will have the cost range from -weight to 0
(In this case, any IOU will never hurt selection)."""
(In this case, any IOU will never hurt selection).""",
)
34 changes: 14 additions & 20 deletions src/motile_plugin/widgets/motile_widget.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,14 @@
from .runs_list import RunsList
from .run_editor import RunEditor
from .run_viewer import RunViewer
from .solver_status import SolverStatus
from .solver_params import SolverParamsWidget
from qtpy.QtCore import Signal

logger = logging.getLogger(__name__)


class MotileWidget(QWidget):
solver_event = Signal(dict)

def __init__(self, viewer, graph_layer=False, multiseg=False):
super().__init__()
self.viewer: Viewer = viewer
Expand All @@ -42,17 +43,14 @@ def __init__(self, viewer, graph_layer=False, multiseg=False):

self.view_run_widget = RunViewer(None)
self.view_run_widget.hide()
self.solver_event.connect(self.view_run_widget.solver_event_update)

self.run_list_widget = RunsList()
self.run_list_widget.view_run.connect(self.view_run)
self.run_list_widget.edit_run.connect(self.edit_run)

self.solver_status_widget = SolverStatus()
self.solver_status_widget.hide()

main_layout.addWidget(self.view_run_widget)
main_layout.addWidget(self.edit_run_widget)
main_layout.addWidget(self.solver_status_widget)
main_layout.addWidget(self.run_list_widget)
self.setLayout(main_layout)

Expand Down Expand Up @@ -93,52 +91,48 @@ def add_napari_layers(self):


def view_run(self, run: MotileRun):
# TODO: remove old layers from napari and replace with new
self.view_run_widget.reset_progress()
self.view_run_widget.update_run(run)
self.edit_run_widget.hide()
self.view_run_widget.show()

# Add layers to Napari
self.remove_napari_layers()
self.update_napari_layers(run)
self.add_napari_layers()



def edit_run(self, run: MotileRun | None):
self.view_run_widget.hide()
self.edit_run_widget.show()
if run:
self.edit_run_widget.new_run(run)

self.remove_napari_layers()


def _generate_tracks(self, run):
# Logic for generating tracks
logger.debug("Segmentation shape: %s", run.input_segmentation.shape)
# do this in a separate thread so we can parse stdout and not block
self.solver_status_widget.show()
self.view_run_widget.reset_progress()
self.edit_run_widget.hide()
self.view_run_widget.update_run(run)
self.view_run_widget.set_solver_label("initializing")
self.view_run_widget.show()
worker = self.solve_with_motile(run)
worker.returned.connect(self._on_solve_complete)
worker.start()

def on_solver_update(self, event_data):
self.solver_status_widget.update(event_data)

@thread_worker
def solve_with_motile(
self,
run: MotileRun
):
run.tracks = solve(run.solver_params, run.input_segmentation, self.on_solver_update)
run.tracks = solve(run.solver_params, run.input_segmentation, self.solver_event.emit)
run.output_segmentation = relabel_segmentation(run.tracks, run.input_segmentation)
return run

def _on_solve_complete(self, run: MotileRun):
self.solver_status_widget.hide()
self.run_list_widget.add_run(run.copy(), select=True)
self.view_run(run)
self.solver_status_widget.reset()
self.view_run_widget.set_solver_label("done")




Expand Down
Loading
Loading