From 2ca7c0f68f68c6e0b2f0bea251c54c170e39c09d Mon Sep 17 00:00:00 2001 From: nkempynck Date: Mon, 27 Jan 2025 09:45:22 +0100 Subject: [PATCH 1/6] pattern clustering plot with logos --- src/crested/pl/patterns/__init__.py | 3 + src/crested/pl/patterns/_modisco_results.py | 170 +++++++++++++++++++- src/crested/pl/patterns/_utils.py | 72 +++++++-- src/crested/tl/modisco/_modisco_utils.py | 11 +- 4 files changed, 238 insertions(+), 18 deletions(-) diff --git a/src/crested/pl/patterns/__init__.py b/src/crested/pl/patterns/__init__.py index 028ec26d..c834c613 100644 --- a/src/crested/pl/patterns/__init__.py +++ b/src/crested/pl/patterns/__init__.py @@ -31,6 +31,7 @@ def _optional_function_warning(*args, **kwargs): from ._modisco_results import ( class_instances, clustermap, + clustermap_with_pwm_logos, clustermap_tf_motif, modisco_results, selected_instances, @@ -48,6 +49,7 @@ def _optional_function_warning(*args, **kwargs): class_instances = _optional_function_warning clustermap_tf_motif = _optional_function_warning tf_expression_per_cell_type = _optional_function_warning + clustermap_with_pwm_logos= _optional_function_warning # Export these functions for public use __all__ = [ @@ -66,5 +68,6 @@ def _optional_function_warning(*args, **kwargs): "selected_instances", "clustermap_tf_motif", "tf_expression_per_cell_type", + "clustermap_with_pwm_logos", ] ) diff --git a/src/crested/pl/patterns/_modisco_results.py b/src/crested/pl/patterns/_modisco_results.py index ddde9957..d74d78b7 100644 --- a/src/crested/pl/patterns/_modisco_results.py +++ b/src/crested/pl/patterns/_modisco_results.py @@ -30,6 +30,8 @@ def modisco_results( y_min: float = -0.05, y_max: float = 0.25, background: list[float] = None, + trim_pattern: bool = True, + trim_ic_threshold: float = 0.1, **kwargs, ) -> None: """ @@ -62,6 +64,10 @@ def modisco_results( Maximum y-axis limit for the plot if viz is "contrib". background Background probabilities for each nucleotide. Default is [0.27, 0.23, 0.23, 0.27]. + trim_pattern + Boolean for trimming modisco patterns. + trim_ic_threshold + If trimming patterns, indicate threshold. kwargs Additional keyword arguments for the plot. @@ -152,7 +158,7 @@ def modisco_results( logger.info("total seqlets:", num_seqlets) if num_seqlets < min_seqlets: break - pattern_trimmed = _trim_pattern_by_ic(pattern, pos_pat, 0.1) + pattern_trimmed = _trim_pattern_by_ic(pattern, pos_pat, trim_ic_threshold) if trim_pattern else pattern if viz == "contrib": ax = _plot_attribution_map( ax=ax, @@ -165,8 +171,7 @@ def modisco_results( f"{cell_type}: {np.around(num_seqlets / num_seq * 100, 2)}% seqlet frequency" ) elif viz == "pwm": - pattern = _trim_pattern_by_ic(pattern, pos_pat, 0.1) - ppm = _pattern_to_ppm(pattern) + ppm = _pattern_to_ppm(pattern_trimmed) ic, ic_pos, ic_mat = compute_ic(ppm) pwm = np.array(ic_mat) rounded_mean = np.around(np.mean(pwm), 2) @@ -196,7 +201,7 @@ def modisco_results( def clustermap( pattern_matrix: np.ndarray, classes: list[str], - subset: list[str] | None = None, # Subset option + subset: list[str] | None = None, figsize: tuple[int, int] = (25, 8), grid: bool = False, cmap: str = "coolwarm", @@ -359,8 +364,163 @@ def clustermap( plt.show() +def clustermap_with_pwm_logos( + pattern_matrix: np.ndarray, + classes: list[str], + pattern_dict: dict, + subset: list[str] | None = None, + figsize: tuple[int, int] = (25, 8), + grid: bool = False, + cmap: str = "coolwarm", + center: float = 0, + method: str = "average", + fig_path: str | None = None, + dendrogram_ratio: tuple[float, float] = (0.05, 0.2), + importance_threshold: float = 0, + logo_height_fraction: float = 0.35, + logo_y_padding: float = 0.3, +) -> sns.matrix.ClusterGrid: + """ + Create a clustermap with additional PWM logo plots below the heatmap. + + Parameters + ---------- + pattern_matrix: + A 2D array representing the data matrix for clustering. + classes: + The class labels for the rows of the matrix. + pattern_dict: + A dictionary containing PWM patterns for x-tick plots. + subset + List of class labels to subset the matrix. + figsize: + Size of the clustermap figure (width, height). Default is (25, 8). + grid: + Whether to overlay grid lines on the heatmap. Default is False. + cmap: + Colormap for the heatmap. Default is "coolwarm". + center: + The value at which to center the colormap. Default is 0. + method: + Linkage method for hierarchical clustering. Default is "average". + fig_path: + Path to save the final figure. If None, the figure is not saved. Default is None. + dendrogram_ratio: + Ratios for the size of row and column dendrograms. Default is (0.05, 0.2). + importance_threshold: + Threshold for filtering columns based on maximum absolute importance. Default is 0. + logo_height_fraction: + Fraction of clustermap height to allocate for PWM logos. Default is 0.35. + logo_y_padding: + Vertical padding for the PWM logos relative to the heatmap. Default is 0.3. + + Returns: + sns.matrix.ClusterGrid: A seaborn ClusterGrid object containing the clustermap with the PWM logos. + """ + # Subset the pattern_matrix and classes if subset is provided + if subset is not None: + subset_indices = [ + i for i, class_label in enumerate(classes) if class_label in subset + ] + pattern_matrix = pattern_matrix[subset_indices, :] + classes = [classes[i] for i in subset_indices] + + # Filter columns based on importance threshold + max_importance = np.max(np.abs(pattern_matrix), axis=0) + above_threshold = max_importance > importance_threshold + pattern_matrix = pattern_matrix[:, above_threshold] + + # Subset the pattern_dict to match filtered columns + selected_patterns = [pattern_dict[str(i)] for i in np.where(above_threshold)[0]] + + data = pd.DataFrame(pattern_matrix) + + # Generate the clustermap with the specified figsize + g = sns.clustermap( + data, + cmap=cmap, + figsize=figsize, + row_colors=None, + yticklabels=classes, + center=center, + xticklabels=False, + method=method, + dendrogram_ratio=dendrogram_ratio, + cbar_pos=(1.05, 0.4, 0.01, 0.3), + ) + + col_order = g.dendrogram_col.reordered_ind + cbar = g.ax_heatmap.collections[0].colorbar + cbar.set_label("Motif importance", rotation=270, labelpad=20) + + # Reorder selected_patterns based on clustering + reordered_patterns = [selected_patterns[i] for i in col_order] + + # Compute space for x-tick images + original_height = figsize[1] + extra_height = logo_height_fraction * original_height + total_height = original_height + extra_height + + # Update the figure size to accommodate the logos + fig = g.fig + fig.set_size_inches(figsize[0], total_height) + + # Adjust width and height of logos + logo_width = g.ax_heatmap.get_position().width / len(reordered_patterns) * 2.5 + logo_height = logo_height_fraction * g.ax_heatmap.get_position().height + ratio = logo_height / logo_width + + for i, pattern in enumerate(reordered_patterns): + plot_start_x = g.ax_heatmap.get_position().x0 + ((i - 0.75) / len(reordered_patterns)) * g.ax_heatmap.get_position().width + plot_start_y = g.ax_heatmap.get_position().y0 - logo_height - logo_height * logo_y_padding + pwm_ax = fig.add_axes([plot_start_x, plot_start_y, logo_width, logo_height]) + pwm_ax.clear() + + # Plot the PWM logo with dynamic figsize + ppm = _pattern_to_ppm(pattern["pattern"]) + ic, ic_pos, ic_mat = compute_ic(ppm) + pwm = np.array(ic_mat) + pwm_ax = _plot_attribution_map( + ax=pwm_ax, + saliency_df=pwm, + return_ax=True, + figsize=(8 * ratio, 8), + rotate=True, + ) + pwm_ax.axis("off") + + if grid: + ax = g.ax_heatmap + x_positions = np.arange(pattern_matrix.shape[1] + 1) + y_positions = np.arange(len(pattern_matrix) + 1) + + # Add horizontal grid lines + for y in y_positions: + ax.hlines(y, *ax.get_xlim(), color="grey", linewidth=0.25) + + # Add vertical grid lines + for x in x_positions: + ax.vlines(x, *ax.get_ylim(), color="grey", linewidth=0.25) + + g.fig.canvas.draw() + + ax = g.ax_heatmap + ax.xaxis.tick_bottom() + ax.set_xticks(np.arange(pattern_matrix.shape[1]) + 0.5) + ax.set_xticklabels([f"{i}" for i in col_order], rotation=90) + for tick in ax.get_xticklabels(): + tick.set_verticalalignment("top") + + if fig_path is not None: + plt.savefig(fig_path, bbox_inches="tight", dpi=600) + + plt.show() + return g -def selected_instances(pattern_dict: dict, idcs: list[int]) -> None: +def selected_instances( + pattern_dict: dict, + idcs: list[int], +)-> None: """ Plot the patterns specified by the indices in `idcs` from the `pattern_dict`. diff --git a/src/crested/pl/patterns/_utils.py b/src/crested/pl/patterns/_utils.py index 8b2fff2b..2c12cdff 100644 --- a/src/crested/pl/patterns/_utils.py +++ b/src/crested/pl/patterns/_utils.py @@ -4,6 +4,8 @@ import logomaker import matplotlib.pyplot as plt +from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas +from PIL import Image import numpy as np import pandas as pd @@ -55,24 +57,72 @@ def _plot_attribution_map( ax=None, return_ax: bool = True, spines: bool = True, - figsize: tuple | None = (20, 1), + figsize: tuple[int, int] = (20, 1), + rotate: bool = False, ): - """Plot an attribution map using logomaker.""" - if type(saliency_df) is not pd.DataFrame: + """ + Plots an attribution map (PWM logo) and optionally rotates it by 90 degrees. + + Parameters: + saliency_df (pd.DataFrame or np.ndarray): A DataFrame or array with attribution scores, + where columns are nucleotide bases (A, C, G, T). + ax (matplotlib.axes.Axes, optional): Axes object to plot on. Default is None, + which creates a new Axes. + return_ax (bool, optional): Whether to return the Axes object. Default is True. + spines (bool, optional): Whether to display spines (axes borders). Default is True. + figsize (tuple[int, int], optional): Figure size for temporary rendering. Default is (20, 1). + rotate (bool, optional): Whether to rotate the resulting plot by 90 degrees. Default is False. + + Returns: + matplotlib.axes.Axes: The Axes object with the plotted attribution map, if `return_ax` is True. + """ + import logomaker + import matplotlib.pyplot as plt + import pandas as pd + import numpy as np + from PIL import Image + + # Convert input to DataFrame if needed + if not isinstance(saliency_df, pd.DataFrame): saliency_df = pd.DataFrame(saliency_df, columns=["A", "C", "G", "T"]) - if figsize is not None: - logomaker.Logo(saliency_df, figsize=figsize, ax=ax) - else: + + # Standard plotting (no rotation) + if not rotate: + if ax is None: + _, ax = plt.subplots(figsize=figsize) logomaker.Logo(saliency_df, ax=ax) + if not spines: + ax.spines["right"].set_visible(False) + ax.spines["top"].set_visible(False) + if return_ax: + return ax + return + + # Rotation case: render plot to an image + temp_fig, temp_ax = plt.subplots(figsize=figsize) + logomaker.Logo(saliency_df, ax=temp_ax) + temp_ax.axis("off") # Remove axes for clean rendering + + # Render the plot as an image + temp_fig.canvas.draw() + width, height = map(int, temp_fig.get_size_inches() * temp_fig.get_dpi()) + image = np.frombuffer(temp_fig.canvas.tostring_rgb(), dtype="uint8").reshape(height, width, 3) + plt.close(temp_fig) # Close the temporary figure to avoid memory leaks + + # Rotate the rendered image + rotated_image = np.rot90(image) + rotated_image_pil = Image.fromarray(rotated_image) + + # Display the rotated image on the given Axes if ax is None: - ax = plt.gca() - if not spines: - ax.spines["right"].set_visible(False) - ax.spines["top"].set_visible(False) + _, ax = plt.subplots(figsize=figsize) + ax.clear() + ax.imshow(rotated_image_pil) + ax.axis("off") # Hide axes for a clean look + if return_ax: return ax - def _plot_mutagenesis_map(mutagenesis_df, ax=None): """Plot an attribution map for mutagenesis using different colored dots, with adjusted x-axis limits.""" colors = {"A": "green", "C": "blue", "G": "orange", "T": "red"} diff --git a/src/crested/tl/modisco/_modisco_utils.py b/src/crested/tl/modisco/_modisco_utils.py index 91609ab7..2428a78c 100644 --- a/src/crested/tl/modisco/_modisco_utils.py +++ b/src/crested/tl/modisco/_modisco_utils.py @@ -93,8 +93,15 @@ def _trim_pattern_by_ic( v = (v - v.min()) / (v.max() - v.min() + 1e-9) try: - start_idx = min(np.where(np.diff((v > min_v) * 1))[0]) - end_idx = max(np.where(np.diff((v > min_v) * 1))[0]) + 1 + if min_v>0: + start_idx = min(np.where(v > min_v)[0]) + end_idx = max(np.where(v > min_v)[0]) + else: + start_idx=0 + end_idx=len(ppm) + + if end_idx==start_idx: + end_idx=start_idx+1 except ValueError: logger.error("No valid pattern found. Aborting...") From e5dd0228687f553b218ae0dc36f81e8e78e01936 Mon Sep 17 00:00:00 2001 From: nkempynck Date: Mon, 27 Jan 2025 09:52:32 +0100 Subject: [PATCH 2/6] ruff --- src/crested/pl/patterns/__init__.py | 2 +- src/crested/pl/patterns/_modisco_results.py | 7 ++++--- src/crested/pl/patterns/_utils.py | 19 +++++++------------ 3 files changed, 12 insertions(+), 16 deletions(-) diff --git a/src/crested/pl/patterns/__init__.py b/src/crested/pl/patterns/__init__.py index c834c613..3718e611 100644 --- a/src/crested/pl/patterns/__init__.py +++ b/src/crested/pl/patterns/__init__.py @@ -31,8 +31,8 @@ def _optional_function_warning(*args, **kwargs): from ._modisco_results import ( class_instances, clustermap, - clustermap_with_pwm_logos, clustermap_tf_motif, + clustermap_with_pwm_logos, modisco_results, selected_instances, similarity_heatmap, diff --git a/src/crested/pl/patterns/_modisco_results.py b/src/crested/pl/patterns/_modisco_results.py index d74d78b7..d98ef146 100644 --- a/src/crested/pl/patterns/_modisco_results.py +++ b/src/crested/pl/patterns/_modisco_results.py @@ -385,7 +385,7 @@ def clustermap_with_pwm_logos( Parameters ---------- - pattern_matrix: + pattern_matrix: A 2D array representing the data matrix for clustering. classes: The class labels for the rows of the matrix. @@ -395,7 +395,7 @@ def clustermap_with_pwm_logos( List of class labels to subset the matrix. figsize: Size of the clustermap figure (width, height). Default is (25, 8). - grid: + grid: Whether to overlay grid lines on the heatmap. Default is False. cmap: Colormap for the heatmap. Default is "coolwarm". @@ -414,7 +414,8 @@ def clustermap_with_pwm_logos( logo_y_padding: Vertical padding for the PWM logos relative to the heatmap. Default is 0.3. - Returns: + Returns + ------- sns.matrix.ClusterGrid: A seaborn ClusterGrid object containing the clustermap with the PWM logos. """ # Subset the pattern_matrix and classes if subset is provided diff --git a/src/crested/pl/patterns/_utils.py b/src/crested/pl/patterns/_utils.py index 2c12cdff..225d2f70 100644 --- a/src/crested/pl/patterns/_utils.py +++ b/src/crested/pl/patterns/_utils.py @@ -4,10 +4,9 @@ import logomaker import matplotlib.pyplot as plt -from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas -from PIL import Image import numpy as np import pandas as pd +from PIL import Image def grad_times_input_to_df(x, grad, alphabet="ACGT"): @@ -63,25 +62,21 @@ def _plot_attribution_map( """ Plots an attribution map (PWM logo) and optionally rotates it by 90 degrees. - Parameters: - saliency_df (pd.DataFrame or np.ndarray): A DataFrame or array with attribution scores, + Parameters + ---------- + saliency_df (pd.DataFrame or np.ndarray): A DataFrame or array with attribution scores, where columns are nucleotide bases (A, C, G, T). - ax (matplotlib.axes.Axes, optional): Axes object to plot on. Default is None, + ax (matplotlib.axes.Axes, optional): Axes object to plot on. Default is None, which creates a new Axes. return_ax (bool, optional): Whether to return the Axes object. Default is True. spines (bool, optional): Whether to display spines (axes borders). Default is True. figsize (tuple[int, int], optional): Figure size for temporary rendering. Default is (20, 1). rotate (bool, optional): Whether to rotate the resulting plot by 90 degrees. Default is False. - Returns: + Returns + ------- matplotlib.axes.Axes: The Axes object with the plotted attribution map, if `return_ax` is True. """ - import logomaker - import matplotlib.pyplot as plt - import pandas as pd - import numpy as np - from PIL import Image - # Convert input to DataFrame if needed if not isinstance(saliency_df, pd.DataFrame): saliency_df = pd.DataFrame(saliency_df, columns=["A", "C", "G", "T"]) From 824bbd96c87544d3453edcb254fb333323eb5095 Mon Sep 17 00:00:00 2001 From: nkempynck Date: Mon, 27 Jan 2025 09:54:02 +0100 Subject: [PATCH 3/6] pydocstyle --- src/crested/pl/patterns/_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/crested/pl/patterns/_utils.py b/src/crested/pl/patterns/_utils.py index 225d2f70..05db7d62 100644 --- a/src/crested/pl/patterns/_utils.py +++ b/src/crested/pl/patterns/_utils.py @@ -60,7 +60,7 @@ def _plot_attribution_map( rotate: bool = False, ): """ - Plots an attribution map (PWM logo) and optionally rotates it by 90 degrees. + Plot an attribution map (PWM logo) and optionally rotate it by 90 degrees. Parameters ---------- From afce1b45b7534b941c84921f9ac580e67996748e Mon Sep 17 00:00:00 2001 From: nkempynck Date: Mon, 27 Jan 2025 10:22:43 +0100 Subject: [PATCH 4/6] clipping update --- src/crested/tl/modisco/_modisco_utils.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/crested/tl/modisco/_modisco_utils.py b/src/crested/tl/modisco/_modisco_utils.py index 2428a78c..b5642a71 100644 --- a/src/crested/tl/modisco/_modisco_utils.py +++ b/src/crested/tl/modisco/_modisco_utils.py @@ -95,13 +95,16 @@ def _trim_pattern_by_ic( try: if min_v>0: start_idx = min(np.where(v > min_v)[0]) - end_idx = max(np.where(v > min_v)[0]) + end_idx = max(np.where(v > min_v)[0])+1 else: start_idx=0 end_idx=len(ppm) if end_idx==start_idx: end_idx=start_idx+1 + + if end_idx==len(v): + end_idx=len(v)-1 except ValueError: logger.error("No valid pattern found. Aborting...") From 72d34e33f06f00679537aa3a511c68ff6e45578f Mon Sep 17 00:00:00 2001 From: nkempynck Date: Mon, 27 Jan 2025 10:24:51 +0100 Subject: [PATCH 5/6] ruff check --- src/crested/tl/modisco/_modisco_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/crested/tl/modisco/_modisco_utils.py b/src/crested/tl/modisco/_modisco_utils.py index b5642a71..669a7d14 100644 --- a/src/crested/tl/modisco/_modisco_utils.py +++ b/src/crested/tl/modisco/_modisco_utils.py @@ -102,7 +102,7 @@ def _trim_pattern_by_ic( if end_idx==start_idx: end_idx=start_idx+1 - + if end_idx==len(v): end_idx=len(v)-1 except ValueError: From f20b0a4513d6471e1a7c1afc135706106627a888 Mon Sep 17 00:00:00 2001 From: nkempynck Date: Mon, 27 Jan 2025 10:26:58 +0100 Subject: [PATCH 6/6] api update --- docs/api/plotting/patterns.md | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/api/plotting/patterns.md b/docs/api/plotting/patterns.md index 32bb6f16..bab30f86 100644 --- a/docs/api/plotting/patterns.md +++ b/docs/api/plotting/patterns.md @@ -17,6 +17,7 @@ Plot contribution scores and analyze them using tfmodisco. selected_instances class_instances clustermap + clustermap_with_pwm_logos clustermap_tf_motif tf_expression_per_cell_type similarity_heatmap