Skip to content

Commit

Permalink
Merge pull request #138 from ckmah/lp-fix
Browse files Browse the repository at this point in the history
LP Fix
  • Loading branch information
ckmah authored Aug 20, 2024
2 parents 21a6f2c + b3e38d4 commit 7fdb42c
Show file tree
Hide file tree
Showing 21 changed files with 317 additions and 206 deletions.
4 changes: 2 additions & 2 deletions .devcontainer/devcontainer.json
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
"context": "..",
"args": {
// Update 'VARIANT' to pick a Python version: 3, 3.6, 3.7, 3.8, 3.9
"VARIANT": "3.8",
"VARIANT": "3.11",
// Options
"INSTALL_NODE": "true",
"NODE_VERSION": "lts/*"
Expand Down Expand Up @@ -41,7 +41,7 @@
// "forwardPorts": [],

// Use 'postCreateCommand' to run commands after the container is created.
"postCreateCommand": "pip3 install poetry==1.2.0; pip3 install -e .",
"postCreateCommand": "curl -sSf https://rye.astral.sh/get | RYE_VERSION='0.38.0' RYE_INSTALL_OPTION='--yes' bash",

// Comment out connect as root instead. More info: https://aka.ms/vscode-remote/containers/non-root.
"remoteUser": "vscode"
Expand Down
21 changes: 13 additions & 8 deletions bento/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,10 @@
import geopandas as gpd
import numpy as np
import pandas as pd
from dask import dataframe as dd
import dask

dask.config.set({"dataframe.query-planning": False})
import dask.dataframe as dd
from spatialdata import SpatialData
from spatialdata.models import PointsModel, ShapesModel, TableModel

Expand Down Expand Up @@ -34,12 +37,14 @@ def filter_by_gene(
-------
sdata : SpatialData
.points[points_key] is updated to remove genes with low expression.
.table is updated to remove genes with low expression.
.tables["table"] is updated to remove genes with low expression.
"""
gene_filter = (sdata.table.X >= min_count).sum(axis=0) > 0
filtered_table = sdata.table[:, gene_filter]
gene_filter = (sdata.tables["table"].X >= min_count).sum(axis=0) > 0
filtered_table = sdata.tables["table"][:, gene_filter]

filtered_genes = list(sdata.table.var_names.difference(filtered_table.var_names))
filtered_genes = list(
sdata.tables["table"].var_names.difference(filtered_table.var_names)
)
points = get_points(sdata, points_key=points_key, astype="pandas", sync=False)
points = points[~points[feature_key].isin(filtered_genes)]
points[feature_key] = points[feature_key].cat.remove_unused_categories()
Expand All @@ -52,10 +57,10 @@ def filter_by_gene(
sdata.points[points_key] = points

try:
del sdata.table
del sdata.tables["table"]
except KeyError:
pass
sdata.table = TableModel.parse(filtered_table)
sdata.tables["table"] = TableModel.parse(filtered_table)

return sdata

Expand Down Expand Up @@ -126,7 +131,7 @@ def get_shape(sdata: SpatialData, shape_key: str, sync: bool = True) -> gpd.GeoS
GeoSeries
GeoSeries of Polygon objects
"""
instance_key = sdata.table.uns["spatialdata_attrs"]["instance_key"]
instance_key = sdata.tables["table"].uns["spatialdata_attrs"]["instance_key"]

# Make sure shape exists in sdata.shapes
if shape_key not in sdata.shapes.keys():
Expand Down
6 changes: 3 additions & 3 deletions bento/io/_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,17 +101,17 @@ def prep(
by=instance_key,
value_key=feature_key,
aggfunc="count",
).table
).tables["table"]
)

pbar.update()

try:
del sdata.table
del sdata.tables["table"]
except KeyError:
pass

sdata.table = table
sdata.tables["table"] = table
# Set instance key to cell_shape_key for all points and table
sdata.points[points_key].attrs["spatialdata_attrs"]["instance_key"] = instance_key
sdata.points[points_key].attrs["spatialdata_attrs"]["feature_key"] = feature_key
Expand Down
27 changes: 16 additions & 11 deletions bento/plotting/_lp.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,9 @@
from ._utils import savefig
from ._multidimensional import _radviz


@savefig
def lp_dist(sdata, show_counts="", show_percentages='{:.1%}', scale=1, fname=None):
def lp_dist(sdata, show_counts="", show_percentages="{:.1%}", scale=1, fname=None):
"""Plot pattern combination frequencies as an UpSet plot.
Parameters
Expand All @@ -32,7 +33,7 @@ def lp_dist(sdata, show_counts="", show_percentages='{:.1%}', scale=1, fname=Non
fname : str, optional
Save the figure to specified filename, by default None
"""
sample_labels = sdata.table.uns["lp"]
sample_labels = sdata.tables["table"].uns["lp"]
sample_labels = sample_labels == 1

# Sort by degree, then pattern name
Expand Down Expand Up @@ -60,12 +61,13 @@ def lp_dist(sdata, show_counts="", show_percentages='{:.1%}', scale=1, fname=Non
upset.plot()
plt.suptitle(f"Localization Patterns\n{sample_labels.shape[0]} samples")


@savefig
def lp_genes(
sdata: SpatialData,
groupby: str = "feature_name",
points_key = "transcripts",
instance_key = "cell_boundaries",
points_key="transcripts",
instance_key="cell_boundaries",
annotate: Union[int, List[str], None] = None,
sizes: Tuple[int] = (2, 100),
size_norm: Tuple[int] = (0, 100),
Expand Down Expand Up @@ -101,17 +103,19 @@ def lp_genes(

palette = dict(zip(PATTERN_NAMES, PATTERN_COLORS))

n_cells = sdata.table.n_obs
gene_frac = sdata.table.uns["lp_stats"][PATTERN_NAMES] / n_cells
n_cells = sdata.tables["table"].n_obs
gene_frac = sdata.tables["table"].uns["lp_stats"][PATTERN_NAMES] / n_cells
genes = gene_frac.index
gene_expression_array = sdata.table[:,genes].X.toarray()
gene_expression_array = sdata.tables["table"][:, genes].X.toarray()
gene_logcount = gene_expression_array.mean(axis=0, where=gene_expression_array > 0)
gene_logcount = np.log2(gene_logcount + 1)
gene_frac["logcounts"] = gene_logcount

cell_fraction = (
100
* get_points(sdata, points_key, astype="pandas", sync=True).groupby(groupby, observed=True)[instance_key].nunique()
* get_points(sdata, points_key, astype="pandas", sync=True)
.groupby(groupby, observed=True)[instance_key]
.nunique()
/ n_cells
)
gene_frac["cell_fraction"] = cell_fraction
Expand All @@ -120,6 +124,7 @@ def lp_genes(
scatter_kws.update(kwargs)
_radviz(gene_frac, annotate=annotate, ax=ax, **scatter_kws)


@savefig
def lp_diff_discrete(sdata: SpatialData, phenotype: str, fname: str = None):
"""Visualize gene pattern frequencies between groups of cells. Plots the
Expand All @@ -134,7 +139,7 @@ def lp_diff_discrete(sdata: SpatialData, phenotype: str, fname: str = None):
fname : str, optional
Save the figure to specified filename, by default None
"""
diff_stats = sdata.table.uns[f"diff_{phenotype}"]
diff_stats = sdata.tables["table"].uns[f"diff_{phenotype}"]

palette = dict(zip(PATTERN_NAMES, PATTERN_COLORS))
g = sns.relplot(
Expand Down Expand Up @@ -162,4 +167,4 @@ def lp_diff_discrete(sdata: SpatialData, phenotype: str, fname: str = None):
) # line where FDR = 0.05
sns.despine()

return g
return g
6 changes: 3 additions & 3 deletions bento/plotting/_multidimensional.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,8 +205,8 @@ def comp(
"""

comp_key = f"{groupby}_comp_stats"
if groupby and comp_key in sdata.table.uns.keys():
comp_stats = sdata.table.uns[comp_key]
if groupby and comp_key in sdata.tables["table"].uns.keys():
comp_stats = sdata.tables["table"].uns[comp_key]
if group_order is None:
groups = list(comp_stats.keys())
else:
Expand Down Expand Up @@ -240,7 +240,7 @@ def comp(
ax.set_title(group, fontsize=12)
else:
comp_key = "comp_stats"
comp_stats = sdata.table.uns[comp_key]
comp_stats = sdata.tables["table"].uns[comp_key]
return _radviz(
comp_stats,
annotate=annotate,
Expand Down
9 changes: 5 additions & 4 deletions bento/plotting/_signatures.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from ._colors import red2blue, red_light
from ._utils import savefig


def colocation(
sdata,
rank,
Expand Down Expand Up @@ -42,9 +43,9 @@ def colocation(
fname : str, optional
Path to save figure, by default None
"""
factors = sdata.table.uns["factors"][rank].copy()
labels = sdata.table.uns["tensor_labels"].copy()
names = sdata.table.uns["tensor_names"].copy()
factors = sdata.tables["table"].uns["factors"][rank].copy()
labels = sdata.tables["table"].uns["tensor_labels"].copy()
names = sdata.tables["table"].uns["tensor_names"].copy()

# Perform z-scaling upfront
for i in range(len(factors)):
Expand Down Expand Up @@ -250,4 +251,4 @@ def _plot_loading(df, name, n_top, cut, show_labels, cluster, ax, **kwargs):

ax.set_yticklabels(ax.get_yticklabels(), rotation=0)
ax.set_title(f"{name}: [{df.shape[0]} x {df.shape[1]}]")
sns.despine(ax=ax, right=False, top=False)
sns.despine(ax=ax, right=False, top=False)
22 changes: 11 additions & 11 deletions bento/tools/_colocation.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,14 +42,14 @@ def colocation(
Returns
-------
sdata : SpatialData
.table.uns['factors']: Decomposed tensor factors.
.table.uns['factors_error']: Decomposition error.
.tables["table"].uns['factors']: Decomposed tensor factors.
.tables["table"].uns['factors_error']: Decomposition error.
"""

print("Preparing tensor...")
_colocation_tensor(sdata, instance_key, feature_key)

tensor = sdata.table.uns["tensor"]
tensor = sdata.tables["table"].uns["tensor"]

print(emoji.emojize(":running: Decomposing tensor..."))
factors, errors = decompose(tensor, ranks, iterations=iterations)
Expand All @@ -61,8 +61,8 @@ def colocation(
kl.plot_knee()
sns.lineplot(data=errors, x="rank", y="rmse", ci=95, marker="o")

sdata.table.uns["factors"] = factors
sdata.table.uns["factors_error"] = errors
sdata.tables["table"].uns["factors"] = factors
sdata.tables["table"].uns["factors_error"] = errors

print(emoji.emojize(":heavy_check_mark: Done."))

Expand All @@ -81,7 +81,7 @@ def _colocation_tensor(sdata: SpatialData, instance_key: str, feature_key: str):
Key that specifies genes in sdata.
"""

clqs = sdata.table.uns["clq"]
clqs = sdata.tables["table"].uns["clq"]

clq_long = []
for shape, clq in clqs.items():
Expand All @@ -106,9 +106,9 @@ def _colocation_tensor(sdata: SpatialData, instance_key: str, feature_key: str):
s = sparse.COO(label_orders, data=clq_long["log_clq"].values)
tensor = s.todense()

sdata.table.uns["tensor"] = tensor
sdata.table.uns["tensor_labels"] = labels
sdata.table.uns["tensor_names"] = label_names
sdata.tables["table"].uns["tensor"] = tensor
sdata.tables["table"].uns["tensor_labels"] = labels
sdata.tables["table"].uns["tensor_names"] = label_names


def coloc_quotient(
Expand Down Expand Up @@ -145,7 +145,7 @@ def coloc_quotient(
Returns
-------
sdata : SpatialData
.table.uns['clq']: Pairwise gene colocalization similarity within each cell formatted as a long dataframe.
.tables["table"].uns['clq']: Pairwise gene colocalization similarity within each cell formatted as a long dataframe.
"""

all_clq = dict()
Expand Down Expand Up @@ -191,7 +191,7 @@ def coloc_quotient(
# Save to uns['clq'] as adjacency list
all_clq[shape] = cell_clqs

sdata.table.uns["clq"] = all_clq
sdata.tables["table"].uns["clq"] = all_clq


def _cell_clq(cell_points, radius, min_points, feature_key):
Expand Down
6 changes: 3 additions & 3 deletions bento/tools/_composition.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def comp(sdata: SpatialData, points_key: str, shape_names: list):
Returns
-------
sdata : spatialdata.SpatialData
Updates `sdata.table.uns` with average gene compositions for each shape.
Updates `sdata.tables["table"].uns` with average gene compositions for each shape.
"""
points = get_points(sdata, points_key=points_key, astype="pandas")

Expand All @@ -83,7 +83,7 @@ def comp(sdata: SpatialData, points_key: str, shape_names: list):
points, shape_names, instance_key=instance_key, feature_key=feature_key
)

sdata.table.uns["comp_stats"] = comp_stats
sdata.tables["table"].uns["comp_stats"] = comp_stats


def comp_diff(
Expand Down Expand Up @@ -132,4 +132,4 @@ def comp_diff(
index=ref_comp.index,
)

sdata.table.uns[f"{groupby}_comp_stats"] = comp_stats
sdata.tables["table"].uns[f"{groupby}_comp_stats"] = comp_stats
12 changes: 6 additions & 6 deletions bento/tools/_flux.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,9 +80,9 @@ def flux(
flux values: <gene_name> for each gene used in embedding.
embeddings: flux_embed_<i> for each component of the embedding.
colors: hex color codes for each pixel.
.table.uns["flux_genes"] : list
.tables["table"].uns["flux_genes"] : list
List of genes used for embedding.
.table.uns["flux_variance_ratio"] : np.ndarray
.tables["table"].uns["flux_variance_ratio"] : np.ndarray
[components] array of explained variance ratio for each component.
"""

Expand Down Expand Up @@ -145,7 +145,7 @@ def flux(
# points_grouped = dask.delayed(points_grouped)
# rpoints_grouped = dask.delayed(rpoints_grouped)

cell_composition = sdata.table[cells, gene_names].X.toarray()
cell_composition = sdata.tables["table"][cells, gene_names].X.toarray()

# Compute cell composition
cell_composition = cell_composition / (cell_composition.sum(axis=1).reshape(-1, 1))
Expand Down Expand Up @@ -268,8 +268,8 @@ def process_cell(bag):
columns=metadata.columns,
)

sdata.table.uns["flux_variance_ratio"] = variance_ratio
sdata.table.uns["flux_genes"] = gene_names # gene names
sdata.tables["table"].uns["flux_variance_ratio"] = variance_ratio
sdata.tables["table"].uns["flux_genes"] = gene_names # gene names

pbar.set_description(emoji.emojize("Done. :bento_box:"))
pbar.update()
Expand Down Expand Up @@ -523,7 +523,7 @@ def fluxmap(
del sdata.shapes[key]

sd_attrs = sdata.shapes[instance_key].attrs
fluxmap_df = fluxmap_df.reindex(sdata.table.obs_names).where(
fluxmap_df = fluxmap_df.reindex(sdata.tables["table"].obs_names).where(
fluxmap_df.notna(), other=Polygon()
)
fluxmap_names = fluxmap_df.columns.tolist()
Expand Down
Loading

0 comments on commit 7fdb42c

Please sign in to comment.