diff --git a/setup.cfg b/setup.cfg index a768604..191d6b9 100644 --- a/setup.cfg +++ b/setup.cfg @@ -35,7 +35,8 @@ install_requires = magicgui qtpy scikit-image - motile_toolbox + motile @ git+https://github.com/funkelab/motile.git@5fdc0247a13a58b5cb49f808daf9d37719481467 + motile_toolbox @ git+https://github.com/funkelab/motile_toolbox.git@4221b06eb1bd1485f709ca62c45482847e5488f7 pydantic python_requires = >=3.8 diff --git a/src/motile_plugin/backend/motile_run.py b/src/motile_plugin/backend/motile_run.py index cabbc75..45e2cea 100644 --- a/src/motile_plugin/backend/motile_run.py +++ b/src/motile_plugin/backend/motile_run.py @@ -13,21 +13,19 @@ IN_SEG_FILEANME = "input_segmentation.npy" OUT_SEG_FILEANME = "output_segmentation.npy" TRACKS_FILENAME = "tracks.json" +GAPS_FILENAME = "gaps.txt" class MotileRun(BaseModel): - """An object representing a motile tracking run. It is frozen because a completed - run cannot be mutated. - TODO: lazy loading from zarr, don't re-save existing input zarr - (e.g. if its a chunk from a bigger zarr) - TODO: Do we need BaseModel? It requires kwargs which is mildly annoying + """An object representing a motile tracking run. """ - run_name: str = Field() - solver_params: SolverParams = Field() - input_segmentation: np.ndarray | None = Field(None) - output_segmentation: np.ndarray | None = Field(None) - tracks: nx.DiGraph | None = Field(None) - time: datetime = Field(datetime.now()) + run_name: str + solver_params: SolverParams + input_segmentation: np.ndarray | None = None + output_segmentation: np.ndarray | None = None + tracks: nx.DiGraph | None = None + time: datetime = datetime.now() + gaps: list[float] = [] class Config: allow_mutation = False @@ -57,6 +55,8 @@ def save(self, base_path: str | Path): self._save_segmentation(run_dir, IN_SEG_FILEANME, self.input_segmentation) self._save_segmentation(run_dir, OUT_SEG_FILEANME, self.output_segmentation) self._save_tracks(run_dir) + if self.gaps is not None: + self._save_gaps(run_dir) @classmethod def load(cls, run_dir: Path | str): @@ -67,6 +67,7 @@ def load(cls, run_dir: Path | str): input_segmentation = cls._load_segmentation(run_dir, IN_SEG_FILEANME) output_segmentation = cls._load_segmentation(run_dir, OUT_SEG_FILEANME) tracks = cls._load_tracks(run_dir) + gaps = cls._load_gaps(run_dir) return cls( run_name=run_name, solver_params=params, @@ -74,6 +75,7 @@ def load(cls, run_dir: Path | str): output_segmentation=output_segmentation, tracks=tracks, time=time, + gaps=gaps, ) def _save_params(self, run_dir): @@ -123,6 +125,25 @@ def _load_tracks(run_dir: Path, required: bool = True) -> nx.DiGraph: else: return None + def _save_gaps(self, run_dir: Path): + gaps_file = run_dir / GAPS_FILENAME + with open(gaps_file, 'w') as f: + f.write(','.join(map(str, self.gaps))) + + @staticmethod + def _load_gaps(run_dir, required: bool = True) -> list[float]: + gaps_file = run_dir / GAPS_FILENAME + if gaps_file.is_file(): + with open(gaps_file) as f: + gaps = list(map(float, f.read().split(","))) + return gaps + elif required: + raise FileNotFoundError(f"No gaps found at {gaps_file}") + else: + return None + + + def delete(self, base_path: str | Path): base_path = Path(base_path) run_dir = base_path / self._make_directory(self.time, self.run_name) @@ -132,5 +153,6 @@ def delete(self, base_path: str | Path): (run_dir / IN_SEG_FILEANME).unlink() (run_dir / OUT_SEG_FILEANME).unlink() (run_dir / TRACKS_FILENAME).unlink() + (run_dir / GAPS_FILENAME).unlink() run_dir.rmdir() diff --git a/src/motile_plugin/backend/solve.py b/src/motile_plugin/backend/solve.py index a2f9bef..f3215db 100644 --- a/src/motile_plugin/backend/solve.py +++ b/src/motile_plugin/backend/solve.py @@ -1,17 +1,16 @@ - import logging import time +import numpy as np from motile import Solver, TrackGraph -from motile.constraints import MaxChildren, MaxParents, ExclusiveNodes -from motile.costs import Appear, Disappear, EdgeSelection, Split +from motile.constraints import ExclusiveNodes, MaxChildren, MaxParents +from motile.costs import Appear, Disappear, EdgeDistance, EdgeSelection, Split from motile_toolbox.candidate_graph import ( EdgeAttr, NodeAttr, get_candidate_graph, graph_to_nx, ) -import numpy as np from .solver_params import SolverParams @@ -22,17 +21,17 @@ def solve( solver_params: SolverParams, segmentation: np.ndarray, on_solver_update=None, - ): +): cand_graph, conflict_sets = get_candidate_graph( segmentation, solver_params.max_edge_distance, iou=solver_params.iou_weight is not None, ) - logger.debug(f"Cand graph has {cand_graph.number_of_nodes()} nodes") + logger.debug("Cand graph has %d nodes", cand_graph.number_of_nodes()) solver = construct_solver(cand_graph, solver_params, conflict_sets) start_time = time.time() solution = solver.solve(verbose=False, on_event=on_solver_update) - logger.info(f"Solution took {time.time() - start_time} seconds") + logger.info("Solution took %.2f seconds", time.time() - start_time) solution_graph = solver.get_selected_subgraph(solution=solution) solution_nx_graph = graph_to_nx(solution_graph) @@ -41,9 +40,11 @@ def solve( def construct_solver(cand_graph, solver_params, exclusive_sets): - solver = Solver(TrackGraph(cand_graph, frame_attribute=NodeAttr.TIME.value)) + solver = Solver( + TrackGraph(cand_graph, frame_attribute=NodeAttr.TIME.value) + ) solver.add_constraints(MaxChildren(solver_params.max_children)) - solver.add_constraints(MaxParents(solver_params.max_parents)) + solver.add_constraints(MaxParents(1)) if exclusive_sets is None or len(exclusive_sets) > 0: solver.add_constraints(ExclusiveNodes(exclusive_sets)) @@ -53,20 +54,23 @@ def construct_solver(cand_graph, solver_params, exclusive_sets): solver.add_costs(Disappear(solver_params.disappear_cost)) if solver_params.division_cost is not None: solver.add_costs(Split(constant=solver_params.division_cost)) - if solver_params.merge_cost is not None: - from motile.costs import Merge - solver.add_costs(Merge(constant=solver_params.merge_cost)) if solver_params.distance_weight is not None: - solver.add_costs(EdgeSelection( - solver_params.distance_weight, - attribute=EdgeAttr.DISTANCE.value, - constant=solver_params.distance_offset), name="distance") + solver.add_costs( + EdgeDistance( + position_attribute=NodeAttr.POS.value, + weight=solver_params.distance_weight, + constant=solver_params.distance_offset, + ), + name="distance", + ) if solver_params.iou_weight is not None: - solver.add_costs(EdgeSelection( - weight=solver_params.iou_weight, - attribute=EdgeAttr.IOU.value, - constant=solver_params.iou_offset), name="iou") + solver.add_costs( + EdgeSelection( + weight=solver_params.iou_weight, + attribute=EdgeAttr.IOU.value, + constant=solver_params.iou_offset, + ), + name="iou", + ) return solver - - diff --git a/src/motile_plugin/backend/solver_params.py b/src/motile_plugin/backend/solver_params.py index 8b66bb9..7419fe8 100644 --- a/src/motile_plugin/backend/solver_params.py +++ b/src/motile_plugin/backend/solver_params.py @@ -1,72 +1,63 @@ - from pydantic import BaseModel, Field class SolverParams(BaseModel): + """The set of solver parameters supported in the motile plugin. + Used to build the UI as well as store parameters for runs. + """ + max_edge_distance: float = Field( 50.0, title="Max Move Distance", description=r"""The maximum distance an object center can move between time frames. -Objects further than this cannot be matched, but making this value larger will increase solving time.""" +Objects further than this cannot be matched, but making this value larger will increase solving time.""", ) max_children: int = Field( 2, title="Max Children", - description="The maximum number of object in time t+1 that can be linked to an item in time t.\nIf no division, set to 1." - ) - max_parents: int = Field( - 1, - title="Max Parent", - description=r"""The maximum number of object in time t that can be linked to an item in time t+1. -If no merging, set to 1.""" + description="The maximum number of object in time t+1 that can be linked to an item in time t.\nIf no division, set to 1.", ) appear_cost: float | None = Field( 30, title="Appear Cost", - description=r"""Cost for starting a new track. A higher value means fewer selected tracks and fewer merges.""" + description=r"""Cost for starting a new track. A higher value means fewer selected tracks and fewer merges.""", ) disappear_cost: float | None = Field( 30, title="Disappear Cost", - description=r"""Cost for ending a track. A higher value means fewer selected tracks and fewer divisions.""" + description=r"""Cost for ending a track. A higher value means fewer selected tracks and fewer divisions.""", ) division_cost: float | None = Field( 20, title="Division Cost", description=r"""Cost for a track dividing. A higher value means fewer divisions. -If this cost is higher than the appear cost, tracks will likely never divide.""" - ) - merge_cost: float | None = Field( - 20, - title="Merge Cost", - description=r"""Cost for a track merging. A higher value means fewer merges. -If this cost is higher than the disappear cost, tracks will likely never merge.""" +If this cost is higher than the appear cost, tracks will likely never divide.""", ) distance_weight: float | None = Field( 1, title="Distance Weight", description=r"""Value multiplied by distance to create the cost for selecting an edge. The weight should generally be positive, because a higher distance should be a higher cost. -The magnitude of the weight helps balance distance with other costs.""" +The magnitude of the weight helps balance distance with other costs.""", ) distance_offset: float | None = Field( -20, title="Distance Offset", description=r"""Value added to (distance * weight) to create the cost for selecting an edge. Usually should be negative to encourage anything being selected. -(If all costs are positive, the optimal solution is selecting nothing.)""" +(If all costs are positive, the optimal solution is selecting nothing.)""", ) iou_weight: float | None = Field( -5, title="IOU Weight", description=r"""Value multiplied by IOU to create cost. The weight should generally be negative, because a higher IOU should be a lower cost. -The magnitude of the weight helps balance IOU with other costs.""" +The magnitude of the weight helps balance IOU with other costs.""", ) iou_offset: float | None = Field( 0, title="IOU Offset", description=r"""Value added to (IOU * weight) to create cost. Zero is a sensible default with a negative weight, and will have the cost range from -weight to 0 -(In this case, any IOU will never hurt selection).""" +(In this case, any IOU will never hurt selection).""", ) diff --git a/src/motile_plugin/widgets/motile_widget.py b/src/motile_plugin/widgets/motile_widget.py index ceb6c9d..43e08b8 100644 --- a/src/motile_plugin/widgets/motile_widget.py +++ b/src/motile_plugin/widgets/motile_widget.py @@ -19,13 +19,14 @@ from .runs_list import RunsList from .run_editor import RunEditor from .run_viewer import RunViewer -from .solver_status import SolverStatus -from .solver_params import SolverParamsWidget +from qtpy.QtCore import Signal logger = logging.getLogger(__name__) class MotileWidget(QWidget): + solver_event = Signal(dict) + def __init__(self, viewer, graph_layer=False, multiseg=False): super().__init__() self.viewer: Viewer = viewer @@ -42,17 +43,14 @@ def __init__(self, viewer, graph_layer=False, multiseg=False): self.view_run_widget = RunViewer(None) self.view_run_widget.hide() + self.solver_event.connect(self.view_run_widget.solver_event_update) self.run_list_widget = RunsList() self.run_list_widget.view_run.connect(self.view_run) self.run_list_widget.edit_run.connect(self.edit_run) - self.solver_status_widget = SolverStatus() - self.solver_status_widget.hide() - main_layout.addWidget(self.view_run_widget) main_layout.addWidget(self.edit_run_widget) - main_layout.addWidget(self.solver_status_widget) main_layout.addWidget(self.run_list_widget) self.setLayout(main_layout) @@ -93,52 +91,48 @@ def add_napari_layers(self): def view_run(self, run: MotileRun): - # TODO: remove old layers from napari and replace with new + self.view_run_widget.reset_progress() self.view_run_widget.update_run(run) self.edit_run_widget.hide() self.view_run_widget.show() - # Add layers to Napari self.remove_napari_layers() self.update_napari_layers(run) self.add_napari_layers() - - + def edit_run(self, run: MotileRun | None): self.view_run_widget.hide() self.edit_run_widget.show() if run: self.edit_run_widget.new_run(run) - self.remove_napari_layers() def _generate_tracks(self, run): # Logic for generating tracks - logger.debug("Segmentation shape: %s", run.input_segmentation.shape) # do this in a separate thread so we can parse stdout and not block - self.solver_status_widget.show() + self.view_run_widget.reset_progress() + self.edit_run_widget.hide() + self.view_run_widget.update_run(run) + self.view_run_widget.set_solver_label("initializing") + self.view_run_widget.show() worker = self.solve_with_motile(run) worker.returned.connect(self._on_solve_complete) worker.start() - - def on_solver_update(self, event_data): - self.solver_status_widget.update(event_data) @thread_worker def solve_with_motile( self, run: MotileRun ): - run.tracks = solve(run.solver_params, run.input_segmentation, self.on_solver_update) + run.tracks = solve(run.solver_params, run.input_segmentation, self.solver_event.emit) run.output_segmentation = relabel_segmentation(run.tracks, run.input_segmentation) return run def _on_solve_complete(self, run: MotileRun): - self.solver_status_widget.hide() self.run_list_widget.add_run(run.copy(), select=True) - self.view_run(run) - self.solver_status_widget.reset() + self.view_run_widget.set_solver_label("done") + diff --git a/src/motile_plugin/widgets/run_viewer.py b/src/motile_plugin/widgets/run_viewer.py index 5c91f2a..d0b7817 100644 --- a/src/motile_plugin/widgets/run_viewer.py +++ b/src/motile_plugin/widgets/run_viewer.py @@ -16,6 +16,9 @@ import networkx as nx import json from pathlib import Path +import pyqtgraph as pg +import numpy as np + class RunViewer(QWidget): def __init__(self, run: MotileRun): @@ -28,34 +31,28 @@ def __init__(self, run: MotileRun): self.run_name_widget = QLabel("temp") self.params_widget = SolverParamsWidget(SolverParams(), editable=False) - self.save_run_dialog = QFileDialog() - self.save_run_dialog.setFileMode(QFileDialog.Directory) - self.save_run_dialog.setOption(QFileDialog.ShowDirsOnly, True) - - self.export_tracks_dialog = QFileDialog() - self.export_tracks_dialog.setFileMode(QFileDialog.AnyFile) - self.export_tracks_dialog.setAcceptMode(QFileDialog.AcceptSave) - self.export_tracks_dialog.setDefaultSuffix("json") - - title_widget = QWidget() - title_layout = QHBoxLayout() - title_layout.addWidget(self.run_name_widget) - title_layout.addWidget(self._save_widget()) - title_widget.setLayout(title_layout) + # Define persistent file dialogs for saving and exporting + self.save_run_dialog = self._save_dialog() + self.export_tracks_dialog = self._export_tracks_dialog() export_tracks_btn = QPushButton("Export tracks") export_tracks_btn.clicked.connect(self.export_tracks) + self.solver_label: QLabel + self.gap_plot = pg.PlotWidget + main_layout = QVBoxLayout() - main_layout.addWidget(title_widget) + main_layout.addWidget(self._title_widget()) main_layout.addWidget(export_tracks_btn) + main_layout.addWidget(self._progress_widget()) main_layout.addWidget(self.params_widget) self.setLayout(main_layout) def update_run(self, run: MotileRun): - print(f"Updating run with params {run.solver_params}") self.run = run self.run_name_widget.setText(self._run_name_view(self.run)) + self.plot_gaps() + self.set_solver_label("done") self.params_widget.new_params.emit(run.solver_params) def _run_name_view(self, run: MotileRun) -> str: @@ -69,10 +66,76 @@ def _save_widget(self): save_run_button.clicked.connect(self.save_run) return save_run_button + def _title_widget(self): + title_widget = QWidget() + title_layout = QHBoxLayout() + title_layout.addWidget(self.run_name_widget) + title_layout.addWidget(self._save_widget()) + title_layout.setContentsMargins(0, 0, 0, 0) + title_widget.setLayout(title_layout) + return title_widget + + def _progress_widget(self): + progress_widget = QWidget() + layout = QVBoxLayout() + self.solver_label = QLabel("") + self.gap_plot = self._plot_widget() + layout.addWidget(self.solver_label) + layout.addWidget(self.gap_plot) + layout.setContentsMargins(0, 0, 0, 0) + progress_widget.setLayout(layout) + return progress_widget + + def set_solver_label(self, status: str): + message = "Solver status: " + status + self.solver_label.setText(message) + + def _plot_widget(self) -> pg.PlotWidget: + class CustomAxisItem(pg.AxisItem): + + def tickStrings(self, values, scale, spacing): + print(f"{values=}, {scale=}, {spacing=}") + if self.logMode: + strings = self.logTickStrings(values, scale, spacing) + print(f"log {strings=}") + return strings + + places = max(0, np.ceil(-np.log10(spacing*scale))) + strings = [] + for v in values: + vs = v * scale + vstr = ("%%0.%df" % places) % vs + strings.append(vstr) + return strings + + gap_plot = pg.PlotWidget() + gap_plot.setBackground((37, 41, 49)) + styles = {"color": "white",} + gap_plot.plotItem.setLogMode(x=False, y=True) + gap_plot.plotItem.setLabel("left", "Gap", **styles) + gap_plot.plotItem.setLabel("bottom", "Solver round", **styles) + return gap_plot + + def plot_gaps(self): + gaps = self.run.gaps + self.gap_plot.getPlotItem().plot(range(len(gaps)), gaps) + + def _save_dialog(self) -> QFileDialog: + save_run_dialog = QFileDialog() + save_run_dialog.setFileMode(QFileDialog.Directory) + save_run_dialog.setOption(QFileDialog.ShowDirsOnly, True) + return save_run_dialog + + def _export_tracks_dialog(self) -> QFileDialog: + export_tracks_dialog = QFileDialog() + export_tracks_dialog.setFileMode(QFileDialog.AnyFile) + export_tracks_dialog.setAcceptMode(QFileDialog.AcceptSave) + export_tracks_dialog.setDefaultSuffix("json") + return export_tracks_dialog + def save_run(self): if self.save_run_dialog.exec_(): directory = self.save_run_dialog.selectedFiles()[0] - print(directory) self.run.save(directory) else: warn("Saving aborted") @@ -89,3 +152,19 @@ def export_tracks(self): else: warn("Exporting aborted") + def solver_event_update(self, event_data): + event_type = event_data["event_type"] + if event_type in ["PRESOLVE", "PRESOLVEROUND"]: + self.set_solver_label("presolving") + self.run.gaps = [] # try this to remove the weird initial gap for gurobi + elif event_type in ["MIPSOL", "BESTSOLFOUND"]: + self.set_solver_label("solving") + gap = event_data["gap"] + print(f"{gap=}") + self.run.gaps.append(event_data["gap"]) + self.plot_gaps() + + def reset_progress(self): + self.set_solver_label("not running") + self.gap_plot.getPlotItem().clear() + diff --git a/src/motile_plugin/widgets/solver_params.py b/src/motile_plugin/widgets/solver_params.py index 9880927..8142790 100644 --- a/src/motile_plugin/widgets/solver_params.py +++ b/src/motile_plugin/widgets/solver_params.py @@ -1,5 +1,6 @@ from functools import partial +from motile_plugin.backend.solver_params import SolverParams from qtpy.QtCore import Signal from qtpy.QtWidgets import ( QCheckBox, @@ -7,33 +8,37 @@ QFormLayout, QGroupBox, QHBoxLayout, + QLabel, QSpinBox, QVBoxLayout, QWidget, - QLabel, ) -from motile_plugin.backend.solver_params import SolverParams - class ParamLabel(QLabel): send_value = Signal(object) + def __init__(self, param_name, *args, **kwargs): super().__init__(*args, **kwargs) self.param_name = param_name - + def update_from_params(self, params: SolverParams): param_val = params.__getattribute__(self.param_name) if param_val is not None: - text = str(param_val) if isinstance(param_val, int) else f"{param_val:.1f}" + text = ( + str(param_val) + if isinstance(param_val, int) + else f"{param_val:.1f}" + ) self.setText(text) - + def toggle_enable(self, checked: bool): self.setEnabled(checked) - + class ParamSpinBox(QSpinBox): send_value = Signal(object) + def __init__(self, param_name, *args, **kwargs): super().__init__(*args, **kwargs) self.param_name = param_name @@ -41,12 +46,12 @@ def __init__(self, param_name, *args, **kwargs): # necessary to have custom signal that is also emitted when checkboxes are # checked, without changing the spinbox value self.valueChanged.connect(self.send_value.emit) - + def update_from_params(self, params: SolverParams): param_val = params.__getattribute__(self.param_name) if param_val is not None: self.setValue(param_val) - + def toggle_enable(self, checked: bool): if checked: self.enable() @@ -56,6 +61,7 @@ def toggle_enable(self, checked: bool): class ParamDoubleSpinBox(QDoubleSpinBox): send_value = Signal(object) + def __init__(self, param_name, *args, **kwargs): super().__init__(*args, **kwargs) self.param_name = param_name @@ -68,7 +74,7 @@ def update_from_params(self, params: SolverParams): param_val = params.__getattribute__(self.param_name) if param_val is not None: self.setValue(param_val) - + def toggle_enable(self, checked: bool): if checked: self.setEnabled(True) @@ -77,18 +83,19 @@ def toggle_enable(self, checked: bool): self.setEnabled(False) self.send_value.emit(None) + class ParamCheckBox(QCheckBox): def __init__(self, param_name, *args, **kwargs): super().__init__(*args, **kwargs) self.param_name = param_name - + def update_from_params(self, params: SolverParams): param_val = params.__getattribute__(self.param_name) if param_val is None: self.setChecked(False) else: self.setChecked(True) - + class ParamCheckGroup(QGroupBox): def __init__(self, param_names, *args, **kwargs): @@ -96,8 +103,10 @@ def __init__(self, param_names, *args, **kwargs): self.param_names = param_names def update_from_params(self, params: SolverParams): - param_vals = [params.__getattribute__(name) for name in self.param_names] - if all([v is None for v in param_vals]): + param_vals = [ + params.__getattribute__(name) for name in self.param_names + ] + if all(v is None for v in param_vals): self.setChecked(False) else: self.setChecked(True) @@ -105,32 +114,38 @@ def update_from_params(self, params: SolverParams): class SolverParamsWidget(QWidget): new_params = Signal(SolverParams) - """ Widget for viewing and editing SolverParams. - Spinboxes will be created for each parameter in SolverParams and linked such that - editing the value in the spinbox will change the corresponding parameter. + """ Widget for viewing and editing SolverParams. + Spinboxes will be created for each parameter in SolverParams and linked such that + editing the value in the spinbox will change the corresponding parameter. Checkboxes will also be created for each optional parameter (group) and linked such - that unchecking the box will update the parameter value to None, and checking will + that unchecking the box will update the parameter value to None, and checking will update the parameter to the current spinbox value. If editable is false, the whole widget will be disabled. To update for a backend change to SolverParams, emit the new_params signal, which the spinboxes and checkboxes will connect to and use to update the UI and thus the stored solver params. """ + def __init__(self, solver_params: SolverParams, editable=False): super().__init__() self.solver_params = solver_params self.editable = editable self.param_categories = { - "data_params": ["max_edge_distance", "max_children", "max_parents"], - "constant_costs": ["appear_cost", "division_cost", "disappear_cost"], - "variable_costs": ["distance", "iou",], - "fixed": [("merge_cost", None)] + "data_params": ["max_edge_distance", "max_children"], + "constant_costs": [ + "appear_cost", + "division_cost", + "disappear_cost", + ], + "variable_costs": [ + "distance", + "iou", + ], } - for fixed_param, val in self.param_categories["fixed"]: - solver_params.__setattr__(fixed_param, val) main_layout = QVBoxLayout() main_layout.addWidget(self._ui_data_specific_hyperparameters()) main_layout.addWidget(self._ui_constant_costs()) + main_layout.setContentsMargins(0, 0, 0, 0) for group in self._ui_variable_costs(): main_layout.addWidget(group) self.setLayout(main_layout) @@ -146,7 +161,12 @@ def _ui_data_specific_hyperparameters(self) -> QGroupBox: spinbox = self._param_spinbox(param_name, negative=False) else: spinbox = self._param_label(param_name) - self._add_form_row(hyperparameters_layout, field.title, spinbox, tooltip=field.description) + self._add_form_row( + hyperparameters_layout, + field.title, + spinbox, + tooltip=field.description, + ) hyperparameters_group.setLayout(hyperparameters_layout) return hyperparameters_group @@ -179,12 +199,14 @@ def _ui_variable_costs(self) -> list[QGroupBox]: title = f"{param_type.title()} Cost" group_tooltip = f"Use the {param_type.title()} between objects as a linking feature." param_names = [f"{param_type}_weight", f"{param_type}_offset"] - groups.append(self._create_feature_cost_group( - title, - param_names=param_names, - checked=True, - group_tooltip=group_tooltip, - )) + groups.append( + self._create_feature_cost_group( + title, + param_names=param_names, + checked=True, + group_tooltip=group_tooltip, + ) + ) return groups def _create_feature_cost_group( @@ -207,7 +229,9 @@ def _create_feature_cost_group( else: spinbox = self._param_label(param_name) feature_cost.toggled.connect(spinbox.toggle_enable) - self._add_form_row(layout, field.title, spinbox, tooltip=field.description) + self._add_form_row( + layout, field.title, spinbox, tooltip=field.description + ) feature_cost.setLayout(layout) return feature_cost @@ -232,12 +256,11 @@ def _param_spinbox(self, param_name, negative=False) -> QWidget: spinbox = ParamDoubleSpinBox(param_name) spinbox.setDecimals(1) else: - raise ValueError(f"Expected dtype int or float, got {field.annotation}") + raise ValueError( + f"Expected dtype int or float, got {field.annotation}" + ) max_val = 10000 - if negative: - min_val = -1 * max_val - else: - min_val = 0 + min_val = -1 * max_val if negative else 0 spinbox.setRange(min_val, max_val) curr_val = self.solver_params.__getattribute__(param_name) if curr_val is None: @@ -259,8 +282,9 @@ def _param_label(self, param_name) -> ParamLabel: self.new_params.connect(param_label.update_from_params) return param_label - def _add_form_row(self, layout: QFormLayout, label, value, tooltip=None): layout.addRow(label, value) - row_widget = layout.itemAt(layout.rowCount() - 1, QFormLayout.LabelRole).widget() + row_widget = layout.itemAt( + layout.rowCount() - 1, QFormLayout.LabelRole + ).widget() row_widget.setToolTip(tooltip) diff --git a/src/motile_plugin/widgets/solver_status.py b/src/motile_plugin/widgets/solver_status.py deleted file mode 100644 index 7fcfd48..0000000 --- a/src/motile_plugin/widgets/solver_status.py +++ /dev/null @@ -1,29 +0,0 @@ -from qtpy.QtWidgets import ( - QWidget, - QVBoxLayout, - QLabel -) - -class SolverStatus(QWidget): - def __init__(self): - super().__init__() - main_layout = QVBoxLayout() - main_layout.addWidget(QLabel("Solver is running")) - self.setLayout(main_layout) - self.presolving = False - self.round = 0 - self.gap = None - - def update(self, event_data): - event_type = event_data["event_type"] - if event_type in ["PRESOLVE", "PRESOLVEROUND"]: - self.presolving = True - elif event_type in ["MIPSOL", "BESTSOLFOUND"]: - self.presolving = False - self.round += 1 - self.gap = event_data["gap"] - - def reset(self): - self.presolving = False - self.round = 0 - self.gap = None \ No newline at end of file