Skip to content

Commit

Permalink
fix(plots): update vector field summary interface
Browse files Browse the repository at this point in the history
Signed-off-by: Cameron Smith <cameron.ray.smith@gmail.com>
  • Loading branch information
cameronraysmith committed Aug 27, 2024
1 parent fde8e62 commit e900db0
Showing 1 changed file with 41 additions and 17 deletions.
58 changes: 41 additions & 17 deletions src/pyrovelocity/plots/_vector_fields.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
from os import PathLike
from typing import Dict

import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scvelo as scv
import seaborn as sns
from anndata import AnnData
from beartype import beartype
from beartype.typing import Dict, Optional
from matplotlib import cm
from matplotlib.colors import Normalize
from matplotlib.figure import FigureBase
Expand Down Expand Up @@ -38,6 +39,9 @@ def plot_vector_field_summary(
vector_field_basis: str,
plot_name: PathLike | str,
cell_state: str = "cell_type",
state_color_dict: Optional[Dict[str, str]] = None,
default_fontsize: int = 7,
default_title_padding: int = 2,
) -> FigureBase:
# posterior_vector_field = posterior_samples["vector_field_posterior_samples"]
posterior_time = posterior_samples["cell_time"]
Expand All @@ -47,7 +51,6 @@ def plot_vector_field_summary(
embed_mean = posterior_samples["vector_field_posterior_mean"]

dot_size = 3.5
font_size = 6.5
scale = 0.35
scale_high = 7.8
scale_low = 7.8
Expand All @@ -71,25 +74,31 @@ def plot_vector_field_summary(
sns.scatterplot(
x="X1",
y="X2",
hue="cell_type",
s=dot_size,
palette=state_color_dict,
data=ress,
alpha=0.9,
s=dot_size,
linewidth=0,
edgecolor="none",
hue="cell_type",
ax=ax[0],
legend="brief",
)
ax[0].axis("off")
ax[0].set_title("Cell types\n", fontsize=font_size)
ax[0].set_title(
"Cell types",
fontsize=default_fontsize,
pad=default_title_padding,
)
ax[0].legend(
loc="lower left",
bbox_to_anchor=(0.5, -0.48),
bbox_to_anchor=(0.2, -0.4),
ncol=5,
fancybox=True,
prop={"size": font_size},
fontsize=font_size,
prop={"size": default_fontsize},
fontsize=default_fontsize,
frameon=False,
markerscale=3,
)
kwargs = dict(
color="gray",
Expand All @@ -106,22 +115,32 @@ def plot_vector_field_summary(
scv.pl.velocity_embedding_grid(
adata,
basis=vector_field_basis,
fontsize=font_size,
fontsize=default_fontsize,
ax=ax[1],
title="",
**kwargs,
)
ax[1].set_title("Scvelo\n", fontsize=7)
ax[1].set_title(
"scVelo",
fontsize=default_fontsize,
pad=default_title_padding,
)
scv.pl.velocity_embedding_grid(
adata,
fontsize=font_size,
fontsize=default_fontsize,
basis=vector_field_basis,
title="",
ax=ax[2],
vkey="velocity_pyro",
**kwargs,
)
ax[2].set_title("Pyro-Velocity\n", fontsize=7)
ax[2].set_title(
rf"Pyro\thinspace-Velocity"
if matplotlib.rcParams["text.usetex"]
else f"Pyro\u2009-Velocity",
fontsize=default_fontsize,
pad=default_title_padding,
)

pca_cell_angles = pca_embeds_angle / np.pi * 180 # degree
pca_angles_std = get_posterior_sample_angle_uncertainty(pca_cell_angles)
Expand Down Expand Up @@ -222,6 +241,7 @@ def plot_vector_field_uncertainty(
color_vector_field_by_measure=False,
dot_size=1,
show_titles: bool = True,
default_fontsize: int = 7,
):
if uncertain_measure == "angle":
adata.obs["uncertain"] = get_posterior_sample_angle_uncertainty(
Expand Down Expand Up @@ -255,8 +275,9 @@ def plot_vector_field_uncertainty(
ax[0].axis("off")
if show_titles:
ax[0].set_title(
f"Single-cell\n {uncertain_measure} uncertainty ",
fontsize=7,
# f"Single-cell\n {uncertain_measure} uncertainty ",
f"{uncertain_measure} uncertainty",
fontsize=default_fontsize,
)
ax = ax[1]

Expand Down Expand Up @@ -303,7 +324,8 @@ def plot_vector_field_uncertainty(
)
if show_titles:
ax.set_title(
f"Averaged\n {uncertain_measure} uncertainty ", fontsize=7
f"Averaged\n {uncertain_measure} uncertainty ",
fontsize=default_fontsize,
)
ax.axis("off")
else:
Expand All @@ -329,7 +351,9 @@ def plot_vector_field_uncertainty(
ax.axis("off")
if show_titles:
ax.set_title(
f"Single-cell\n {uncertain_measure} uncertainty ", fontsize=7
# f"Single-cell\n {uncertain_measure} uncertainty ", fontsize=default_fontsize
f"{uncertain_measure} uncertainty",
fontsize=default_fontsize,
)
if cbar:
# from mpl_toolkits.axes_grid1 import make_axes_locatable
Expand All @@ -350,7 +374,7 @@ def plot_vector_field_uncertainty(
) # fraction=0.046, pad=0.04
cbar.ax.tick_params(axis="x", labelsize=5.5)
cbar.ax.locator = MaxNLocator(nbins=2, integer=True)
# cbar.ax.set_xlabel(f"{uncertain_measure} uncertainty", fontsize=7)
# cbar.ax.set_xlabel(f"{uncertain_measure} uncertainty", fontsize=default_fontsize)


def plot_mean_vector_field(
Expand Down

0 comments on commit e900db0

Please sign in to comment.