Skip to content

Commit

Permalink
Merge pull request #90 from funkelab/fix_graph_data_json_serializable
Browse files Browse the repository at this point in the history
convert all graph data where necessary before saving to make it json serializable
  • Loading branch information
cmalinmayor authored Oct 21, 2024
2 parents 8a1d1fb + d685b38 commit 240db32
Showing 1 changed file with 27 additions and 6 deletions.
33 changes: 27 additions & 6 deletions src/motile_plugin/data_model/tracks.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,7 @@ def get_time(self, node: Node) -> int:
return int(self.get_times([node])[0])

def set_times(self, nodes: Iterable[Node], times: Iterable[int]):
times = [int(t) for t in times]
self._remove_from_seg_time_to_node(nodes)
self._set_nodes_attr(nodes, self.time_attr, times)
self._add_to_seg_time_to_node(nodes)
Expand All @@ -176,7 +177,7 @@ def set_time(self, node: Any, time: int):
time (int): The time to set
"""
self.set_times([node], [time])
self.set_times([node], [int(time)])

def get_seg_ids(
self, nodes: Iterable[Node], required=False
Expand Down Expand Up @@ -205,12 +206,13 @@ def set_seg_ids(self, nodes: Iterable[Node], seg_ids: Iterable[int]):
node (Any): The node id to set the seg id of
seg_id (int): The segmentation id to set for the node
"""
seg_ids = [int(seg_id) for seg_id in seg_ids]
self._remove_from_seg_time_to_node(nodes)
self._set_nodes_attr(nodes, NodeAttr.SEG_ID.value, seg_ids)
self._add_to_seg_time_to_node(nodes)

def set_seg_id(self, node: Node, seg_id: int):
self.set_seg_ids([node], [seg_id])
self.set_seg_ids([node], [int(seg_id)])

def add_nodes(
self,
Expand Down Expand Up @@ -493,8 +495,29 @@ def _save_graph(self, directory: Path):
directory (Path): The directory in which to save the graph file.
"""
graph_file = directory / self.GRAPH_FILE
graph_data = nx.node_link_data(self.graph)

def convert_np_types(data):
"""Recursively convert numpy types to native Python types."""

if isinstance(data, dict):
return {key: convert_np_types(value) for key, value in data.items()}
elif isinstance(data, list):
return [convert_np_types(item) for item in data]
elif isinstance(data, np.ndarray):
return data.tolist() # Convert numpy arrays to Python lists
elif isinstance(data, np.integer):
return int(data) # Convert numpy integers to Python int
elif isinstance(data, np.floating):
return float(data) # Convert numpy floats to Python float
else:
return (
data # Return the data as-is if it's already a native Python type
)

graph_data = convert_np_types(graph_data)
with open(graph_file, "w") as f:
json.dump(nx.node_link_data(self.graph), f)
json.dump(graph_data, f)

def _save_seg(self, directory: Path):
"""Save a segmentation as a numpy array using np.save. In the future,
Expand All @@ -514,9 +537,7 @@ def _save_attrs(self, directory: Path):
"""
out_path = directory / self.ATTRS_FILE
attrs_dict = {
"time_attr": self.time_attr
if not isinstance(self.time_attr, np.ndarray)
else self.time_attr.tolist(),
"time_attr": self.time_attr,
"pos_attr": self.pos_attr
if not isinstance(self.pos_attr, np.ndarray)
else self.pos_attr.tolist(),
Expand Down

0 comments on commit 240db32

Please sign in to comment.