Skip to content

Commit

Permalink
added functionality to draw the network in pl.scatter_labels()
Browse files Browse the repository at this point in the history
  • Loading branch information
MeyerBender committed Oct 10, 2024
1 parent f6c1add commit eee8a1b
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 8 deletions.
10 changes: 9 additions & 1 deletion spatialproteomics/la/label.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,9 @@ def _relabel_dict(self, dictionary: dict):
_, fw, _ = relabel_sequential(self._obj.coords[Dims.LABELS].values)
return {fw[k]: v for k, v in dictionary.items()}

def _label_to_dict(self, prop: str, reverse: bool = False, relabel: bool = False) -> dict:
def _label_to_dict(
self, prop: str, reverse: bool = False, relabel: bool = False, keys_as_str: bool = False
) -> dict:
"""
Returns a dictionary that maps each label to a list to their property.
Expand All @@ -173,6 +175,8 @@ def _label_to_dict(self, prop: str, reverse: bool = False, relabel: bool = False
If True, the dictionary will be reversed.
relabel : bool
If True, the dictionary keys will be relabeled.
keys_as_str : bool
If True, the dictionary keys will be converted to the cell type labels instead of the numeric keys.
Returns
-------
Expand All @@ -194,6 +198,10 @@ def _label_to_dict(self, prop: str, reverse: bool = False, relabel: bool = False
if reverse:
label_dict = {v: k for k, v in label_dict.items()}

if keys_as_str:
labels = dict(zip(labels.values, self._obj.pp.get_layer_as_df(Layers.LA_PROPERTIES)[Props.NAME].values))
label_dict = {labels[k]: v for k, v in label_dict.items()}

return label_dict

def _cells_to_label(self, relabel: bool = False, include_unlabeled: bool = False) -> dict:
Expand Down
59 changes: 52 additions & 7 deletions spatialproteomics/pl/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -716,7 +716,7 @@ def render_neighborhoods(
else:
attrs = {}
rendered = _render_labels(mask, cmap, alpha=alpha, alpha_boundary=alpha_boundary, mode="inner")
else:
elif style == "neighborhoods":
cells_dict = self._obj.nh._cells_to_neighborhood(relabel=True)

# step 1: apply a Voronoi tesselation to the neighborhoods
Expand Down Expand Up @@ -939,6 +939,7 @@ def scatter_labels(
size: float = 1.0,
alpha: float = 0.9,
zorder: int = 10,
render_edges: bool = False,
ax=None,
legend_kwargs: dict = {"framealpha": 1},
scatter_kwargs: dict = {},
Expand All @@ -956,6 +957,8 @@ def scatter_labels(
Transparency of the scatter markers. Default is 0.9.
zorder : int, optional
The z-order of the scatter markers. Default is 10.
render_edges : bool, optional
Whether to render the edges between cells within the same neighborhood. Default is False.
ax : matplotlib.axes.Axes, optional
The axes on which to plot the scatter. If not provided, the current axes will be used.
legend_kwargs : dict, optional
Expand All @@ -974,12 +977,54 @@ def scatter_labels(
color_dict = self._obj.la._label_to_dict(Props.COLOR)
label_dict = self._obj.la._cells_to_label()

for celltype in label_dict.keys():
label_subset = self._obj.la[celltype]
obs_layer = label_subset[Layers.OBS]
x = obs_layer.loc[:, Features.X]
y = obs_layer.loc[:, Features.Y]
ax.scatter(x.values, y.values, s=size, c=color_dict[celltype], alpha=alpha, zorder=zorder, **scatter_kwargs)
if not render_edges:
for celltype in label_dict.keys():
label_subset = self._obj.la[celltype]
obs_layer = label_subset[Layers.OBS]
x = obs_layer.loc[:, Features.X]
y = obs_layer.loc[:, Features.Y]
ax.scatter(
x.values, y.values, s=size, c=color_dict[celltype], alpha=alpha, zorder=zorder, **scatter_kwargs
)
else:
# if we want to render the edges, we need to use the adjacency matrix and networkx
assert (
Layers.ADJACENCY_MATRIX in self._obj
), "No adjacency matrix found in the object. Please compute the adjacency matrix first by running either of the methods contained in the nh module (e. g. nh.compute_neighborhoods_radius())."

try:
import networkx as nx

adjacency_matrix = self._obj[Layers.ADJACENCY_MATRIX].values
G = nx.from_numpy_array(adjacency_matrix)
spatial_df = self._obj.pp.get_layer_as_df(Layers.OBS)
assert Features.X in spatial_df.columns, f"Feature {Features.X} not found in the observation layer."
assert Features.Y in spatial_df.columns, f"Feature {Features.Y} not found in the observation layer."
assert (
Features.LABELS in spatial_df.columns
), f"Feature {Features.LABELS} not found in the observation layer."
spatial_df = spatial_df[[Features.X, Features.Y, Features.LABELS]].reset_index(drop=True)
# Create node positions based on the centroid coordinates
positions = {
i: (spatial_df.loc[i, Features.X], spatial_df.loc[i, Features.Y]) for i in range(len(spatial_df))
}
color_dict = self._obj.la._label_to_dict(Props.COLOR, keys_as_str=True)

# Assign node colors based on the label
node_colors = [color_dict[spatial_df.loc[i, Features.LABELS]] for i in range(len(spatial_df))]
# Use networkx to draw the graph
nx.draw(
G,
pos=positions,
node_color=node_colors,
with_labels=False,
node_size=size,
edge_color="gray",
ax=ax,
**scatter_kwargs,
)
except ImportError:
raise ImportError("Please install networkx to render edges between cells.")

xmin, xmax, ymin, ymax = self._obj.pl._get_bounds()
ax.set_ylim([ymin, ymax])
Expand Down

0 comments on commit eee8a1b

Please sign in to comment.