diff --git a/pyproject.toml b/pyproject.toml index af5444f..6b273d1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -17,7 +17,7 @@ authors = [ ] dynamic = ["version"] dependencies = [ - "motile", + "motile>=0.3.0", "networkx", "numpy", "matplotlib", diff --git a/src/motile_toolbox/candidate_graph/graph_to_nx.py b/src/motile_toolbox/candidate_graph/graph_to_nx.py index ac722a3..5260c2a 100644 --- a/src/motile_toolbox/candidate_graph/graph_to_nx.py +++ b/src/motile_toolbox/candidate_graph/graph_to_nx.py @@ -2,11 +2,14 @@ from motile import TrackGraph -def graph_to_nx(graph: TrackGraph) -> nx.DiGraph: +def graph_to_nx(graph: TrackGraph, flatten_hyperedges=True) -> nx.DiGraph: """Convert a motile TrackGraph into a networkx DiGraph. Args: graph (TrackGraph): TrackGraph to be converted to networkx + flatten_hyperedges (bool, optional): If True, include one edge for each + (source, target) combo in a hyperedge. If False, introduce a new + hypernode to represent hyperedges. Defaults to True. Returns: nx.DiGraph: Directed networkx graph with same nodes, edges, and attributes. @@ -14,8 +17,25 @@ def graph_to_nx(graph: TrackGraph) -> nx.DiGraph: nx_graph = nx.DiGraph() nodes_list = list(graph.nodes.items()) nx_graph.add_nodes_from(nodes_list) - edges_list = [ - (edge_id[0], edge_id[1], data) for edge_id, data in graph.edges.items() - ] + edges_list = [] + for edge, data in graph.edges.items(): + if graph.is_hyperedge(edge): + us, vs = edge + if flatten_hyperedges: + # flatten the hyperedges into multiple normal edges + for u in us: + for v in vs: + edges_list.append((u, v, data)) + else: + # add a hypernode to connect all in nodes with all out nodes + hypernode_id = "_".join(list(map(str, us)) + list(map(str, vs))) + for u in us: + edges_list.append((u, hypernode_id, data)) + for v in vs: + edges_list.append((hypernode_id, v, data)) + else: + u, v = edge + edges_list.append((u, v, data)) + nx_graph.add_edges_from(edges_list) return nx_graph diff --git a/tests/visualization/test_plot.py b/tests/visualization/test_plot.py index 834a1be..5be8392 100644 --- a/tests/visualization/test_plot.py +++ b/tests/visualization/test_plot.py @@ -12,10 +12,10 @@ @pytest.fixture def solver(arlo_graph: motile.TrackGraph) -> motile.Solver: solver = motile.Solver(arlo_graph) - solver.add_costs(NodeSelection(weight=-1.0, attribute="score", constant=-100.0)) - solver.add_costs(EdgeSelection(weight=1.0, attribute="prediction_distance")) - solver.add_costs(Appear(constant=200.0)) - solver.add_costs(Split(constant=100.0)) + solver.add_cost(NodeSelection(weight=-1.0, attribute="score", constant=-100.0)) + solver.add_cost(EdgeSelection(weight=1.0, attribute="prediction_distance")) + solver.add_cost(Appear(constant=200.0)) + solver.add_cost(Split(constant=100.0)) return solver