Skip to content

Commit

Permalink
Merge pull request #97 from aertslab/dev
Browse files Browse the repository at this point in the history
pattern clustering plot with logos
  • Loading branch information
nkempynck authored Jan 27, 2025
2 parents e5ecc0c + f20b0a4 commit 77b57a4
Show file tree
Hide file tree
Showing 5 changed files with 238 additions and 18 deletions.
1 change: 1 addition & 0 deletions docs/api/plotting/patterns.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ Plot contribution scores and analyze them using tfmodisco.
selected_instances
class_instances
clustermap
clustermap_with_pwm_logos
clustermap_tf_motif
tf_expression_per_cell_type
similarity_heatmap
Expand Down
3 changes: 3 additions & 0 deletions src/crested/pl/patterns/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ def _optional_function_warning(*args, **kwargs):
class_instances,
clustermap,
clustermap_tf_motif,
clustermap_with_pwm_logos,
modisco_results,
selected_instances,
similarity_heatmap,
Expand All @@ -48,6 +49,7 @@ def _optional_function_warning(*args, **kwargs):
class_instances = _optional_function_warning
clustermap_tf_motif = _optional_function_warning
tf_expression_per_cell_type = _optional_function_warning
clustermap_with_pwm_logos= _optional_function_warning

# Export these functions for public use
__all__ = [
Expand All @@ -66,5 +68,6 @@ def _optional_function_warning(*args, **kwargs):
"selected_instances",
"clustermap_tf_motif",
"tf_expression_per_cell_type",
"clustermap_with_pwm_logos",
]
)
171 changes: 166 additions & 5 deletions src/crested/pl/patterns/_modisco_results.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ def modisco_results(
y_min: float = -0.05,
y_max: float = 0.25,
background: list[float] = None,
trim_pattern: bool = True,
trim_ic_threshold: float = 0.1,
**kwargs,
) -> None:
"""
Expand Down Expand Up @@ -62,6 +64,10 @@ def modisco_results(
Maximum y-axis limit for the plot if viz is "contrib".
background
Background probabilities for each nucleotide. Default is [0.27, 0.23, 0.23, 0.27].
trim_pattern
Boolean for trimming modisco patterns.
trim_ic_threshold
If trimming patterns, indicate threshold.
kwargs
Additional keyword arguments for the plot.
Expand Down Expand Up @@ -152,7 +158,7 @@ def modisco_results(
logger.info("total seqlets:", num_seqlets)
if num_seqlets < min_seqlets:
break
pattern_trimmed = _trim_pattern_by_ic(pattern, pos_pat, 0.1)
pattern_trimmed = _trim_pattern_by_ic(pattern, pos_pat, trim_ic_threshold) if trim_pattern else pattern
if viz == "contrib":
ax = _plot_attribution_map(
ax=ax,
Expand All @@ -165,8 +171,7 @@ def modisco_results(
f"{cell_type}: {np.around(num_seqlets / num_seq * 100, 2)}% seqlet frequency"
)
elif viz == "pwm":
pattern = _trim_pattern_by_ic(pattern, pos_pat, 0.1)
ppm = _pattern_to_ppm(pattern)
ppm = _pattern_to_ppm(pattern_trimmed)
ic, ic_pos, ic_mat = compute_ic(ppm)
pwm = np.array(ic_mat)
rounded_mean = np.around(np.mean(pwm), 2)
Expand Down Expand Up @@ -196,7 +201,7 @@ def modisco_results(
def clustermap(
pattern_matrix: np.ndarray,
classes: list[str],
subset: list[str] | None = None, # Subset option
subset: list[str] | None = None,
figsize: tuple[int, int] = (25, 8),
grid: bool = False,
cmap: str = "coolwarm",
Expand Down Expand Up @@ -359,8 +364,164 @@ def clustermap(

plt.show()

def clustermap_with_pwm_logos(
pattern_matrix: np.ndarray,
classes: list[str],
pattern_dict: dict,
subset: list[str] | None = None,
figsize: tuple[int, int] = (25, 8),
grid: bool = False,
cmap: str = "coolwarm",
center: float = 0,
method: str = "average",
fig_path: str | None = None,
dendrogram_ratio: tuple[float, float] = (0.05, 0.2),
importance_threshold: float = 0,
logo_height_fraction: float = 0.35,
logo_y_padding: float = 0.3,
) -> sns.matrix.ClusterGrid:
"""
Create a clustermap with additional PWM logo plots below the heatmap.
Parameters
----------
pattern_matrix:
A 2D array representing the data matrix for clustering.
classes:
The class labels for the rows of the matrix.
pattern_dict:
A dictionary containing PWM patterns for x-tick plots.
subset
List of class labels to subset the matrix.
figsize:
Size of the clustermap figure (width, height). Default is (25, 8).
grid:
Whether to overlay grid lines on the heatmap. Default is False.
cmap:
Colormap for the heatmap. Default is "coolwarm".
center:
The value at which to center the colormap. Default is 0.
method:
Linkage method for hierarchical clustering. Default is "average".
fig_path:
Path to save the final figure. If None, the figure is not saved. Default is None.
dendrogram_ratio:
Ratios for the size of row and column dendrograms. Default is (0.05, 0.2).
importance_threshold:
Threshold for filtering columns based on maximum absolute importance. Default is 0.
logo_height_fraction:
Fraction of clustermap height to allocate for PWM logos. Default is 0.35.
logo_y_padding:
Vertical padding for the PWM logos relative to the heatmap. Default is 0.3.
Returns
-------
sns.matrix.ClusterGrid: A seaborn ClusterGrid object containing the clustermap with the PWM logos.
"""
# Subset the pattern_matrix and classes if subset is provided
if subset is not None:
subset_indices = [
i for i, class_label in enumerate(classes) if class_label in subset
]
pattern_matrix = pattern_matrix[subset_indices, :]
classes = [classes[i] for i in subset_indices]

# Filter columns based on importance threshold
max_importance = np.max(np.abs(pattern_matrix), axis=0)
above_threshold = max_importance > importance_threshold
pattern_matrix = pattern_matrix[:, above_threshold]

# Subset the pattern_dict to match filtered columns
selected_patterns = [pattern_dict[str(i)] for i in np.where(above_threshold)[0]]

data = pd.DataFrame(pattern_matrix)

# Generate the clustermap with the specified figsize
g = sns.clustermap(
data,
cmap=cmap,
figsize=figsize,
row_colors=None,
yticklabels=classes,
center=center,
xticklabels=False,
method=method,
dendrogram_ratio=dendrogram_ratio,
cbar_pos=(1.05, 0.4, 0.01, 0.3),
)

col_order = g.dendrogram_col.reordered_ind
cbar = g.ax_heatmap.collections[0].colorbar
cbar.set_label("Motif importance", rotation=270, labelpad=20)

# Reorder selected_patterns based on clustering
reordered_patterns = [selected_patterns[i] for i in col_order]

# Compute space for x-tick images
original_height = figsize[1]
extra_height = logo_height_fraction * original_height
total_height = original_height + extra_height

# Update the figure size to accommodate the logos
fig = g.fig
fig.set_size_inches(figsize[0], total_height)

# Adjust width and height of logos
logo_width = g.ax_heatmap.get_position().width / len(reordered_patterns) * 2.5
logo_height = logo_height_fraction * g.ax_heatmap.get_position().height
ratio = logo_height / logo_width

for i, pattern in enumerate(reordered_patterns):
plot_start_x = g.ax_heatmap.get_position().x0 + ((i - 0.75) / len(reordered_patterns)) * g.ax_heatmap.get_position().width
plot_start_y = g.ax_heatmap.get_position().y0 - logo_height - logo_height * logo_y_padding
pwm_ax = fig.add_axes([plot_start_x, plot_start_y, logo_width, logo_height])
pwm_ax.clear()

# Plot the PWM logo with dynamic figsize
ppm = _pattern_to_ppm(pattern["pattern"])
ic, ic_pos, ic_mat = compute_ic(ppm)
pwm = np.array(ic_mat)
pwm_ax = _plot_attribution_map(
ax=pwm_ax,
saliency_df=pwm,
return_ax=True,
figsize=(8 * ratio, 8),
rotate=True,
)
pwm_ax.axis("off")

if grid:
ax = g.ax_heatmap
x_positions = np.arange(pattern_matrix.shape[1] + 1)
y_positions = np.arange(len(pattern_matrix) + 1)

# Add horizontal grid lines
for y in y_positions:
ax.hlines(y, *ax.get_xlim(), color="grey", linewidth=0.25)

# Add vertical grid lines
for x in x_positions:
ax.vlines(x, *ax.get_ylim(), color="grey", linewidth=0.25)

g.fig.canvas.draw()

ax = g.ax_heatmap
ax.xaxis.tick_bottom()
ax.set_xticks(np.arange(pattern_matrix.shape[1]) + 0.5)
ax.set_xticklabels([f"{i}" for i in col_order], rotation=90)
for tick in ax.get_xticklabels():
tick.set_verticalalignment("top")

if fig_path is not None:
plt.savefig(fig_path, bbox_inches="tight", dpi=600)

plt.show()
return g

def selected_instances(pattern_dict: dict, idcs: list[int]) -> None:
def selected_instances(
pattern_dict: dict,
idcs: list[int],
)-> None:
"""
Plot the patterns specified by the indices in `idcs` from the `pattern_dict`.
Expand Down
67 changes: 56 additions & 11 deletions src/crested/pl/patterns/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from PIL import Image


def grad_times_input_to_df(x, grad, alphabet="ACGT"):
Expand Down Expand Up @@ -55,24 +56,68 @@ def _plot_attribution_map(
ax=None,
return_ax: bool = True,
spines: bool = True,
figsize: tuple | None = (20, 1),
figsize: tuple[int, int] = (20, 1),
rotate: bool = False,
):
"""Plot an attribution map using logomaker."""
if type(saliency_df) is not pd.DataFrame:
"""
Plot an attribution map (PWM logo) and optionally rotate it by 90 degrees.
Parameters
----------
saliency_df (pd.DataFrame or np.ndarray): A DataFrame or array with attribution scores,
where columns are nucleotide bases (A, C, G, T).
ax (matplotlib.axes.Axes, optional): Axes object to plot on. Default is None,
which creates a new Axes.
return_ax (bool, optional): Whether to return the Axes object. Default is True.
spines (bool, optional): Whether to display spines (axes borders). Default is True.
figsize (tuple[int, int], optional): Figure size for temporary rendering. Default is (20, 1).
rotate (bool, optional): Whether to rotate the resulting plot by 90 degrees. Default is False.
Returns
-------
matplotlib.axes.Axes: The Axes object with the plotted attribution map, if `return_ax` is True.
"""
# Convert input to DataFrame if needed
if not isinstance(saliency_df, pd.DataFrame):
saliency_df = pd.DataFrame(saliency_df, columns=["A", "C", "G", "T"])
if figsize is not None:
logomaker.Logo(saliency_df, figsize=figsize, ax=ax)
else:

# Standard plotting (no rotation)
if not rotate:
if ax is None:
_, ax = plt.subplots(figsize=figsize)
logomaker.Logo(saliency_df, ax=ax)
if not spines:
ax.spines["right"].set_visible(False)
ax.spines["top"].set_visible(False)
if return_ax:
return ax
return

# Rotation case: render plot to an image
temp_fig, temp_ax = plt.subplots(figsize=figsize)
logomaker.Logo(saliency_df, ax=temp_ax)
temp_ax.axis("off") # Remove axes for clean rendering

# Render the plot as an image
temp_fig.canvas.draw()
width, height = map(int, temp_fig.get_size_inches() * temp_fig.get_dpi())
image = np.frombuffer(temp_fig.canvas.tostring_rgb(), dtype="uint8").reshape(height, width, 3)
plt.close(temp_fig) # Close the temporary figure to avoid memory leaks

# Rotate the rendered image
rotated_image = np.rot90(image)
rotated_image_pil = Image.fromarray(rotated_image)

# Display the rotated image on the given Axes
if ax is None:
ax = plt.gca()
if not spines:
ax.spines["right"].set_visible(False)
ax.spines["top"].set_visible(False)
_, ax = plt.subplots(figsize=figsize)
ax.clear()
ax.imshow(rotated_image_pil)
ax.axis("off") # Hide axes for a clean look

if return_ax:
return ax


def _plot_mutagenesis_map(mutagenesis_df, ax=None):
"""Plot an attribution map for mutagenesis using different colored dots, with adjusted x-axis limits."""
colors = {"A": "green", "C": "blue", "G": "orange", "T": "red"}
Expand Down
14 changes: 12 additions & 2 deletions src/crested/tl/modisco/_modisco_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,8 +93,18 @@ def _trim_pattern_by_ic(
v = (v - v.min()) / (v.max() - v.min() + 1e-9)

try:
start_idx = min(np.where(np.diff((v > min_v) * 1))[0])
end_idx = max(np.where(np.diff((v > min_v) * 1))[0]) + 1
if min_v>0:
start_idx = min(np.where(v > min_v)[0])
end_idx = max(np.where(v > min_v)[0])+1
else:
start_idx=0
end_idx=len(ppm)

if end_idx==start_idx:
end_idx=start_idx+1

if end_idx==len(v):
end_idx=len(v)-1
except ValueError:
logger.error("No valid pattern found. Aborting...")

Expand Down

0 comments on commit 77b57a4

Please sign in to comment.