Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add/3d peakmap #20

Merged
merged 13 commits into from
Sep 3, 2024
602 changes: 580 additions & 22 deletions nbs/PeakMap.ipynb

Large diffs are not rendered by default.

87,403 changes: 87,280 additions & 123 deletions nbs/pyopenms_viz_tutorial.ipynb

Large diffs are not rendered by default.

70 changes: 38 additions & 32 deletions pyopenms_viz/_bokeh/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,7 @@ class BOKEHLinePlot(BOKEHPlot, LinePlot):

@classmethod
@APPEND_PLOT_DOC
def plot(cls, fig, data, x, y, by: str | None = None, **kwargs):
def plot(cls, fig, data, x, y, by: str | None = None, plot_3d=False, **kwargs):
"""
Plot a line plot
"""
Expand All @@ -219,7 +219,7 @@ def plot(cls, fig, data, x, y, by: str | None = None, **kwargs):
if by is None:
source = ColumnDataSource(data)
if color_gen is not None:
kwargs["line_color"] = next(color_gen)
kwargs["line_color"] = color_gen if isinstance(color_gen, str) else next(color_gen)
line = fig.line(x=x, y=y, source=source, **kwargs)

return fig, None
Expand All @@ -229,7 +229,7 @@ def plot(cls, fig, data, x, y, by: str | None = None, **kwargs):
for group, df in data.groupby(by):
source = ColumnDataSource(df)
if color_gen is not None:
kwargs["line_color"] = next(color_gen)
kwargs["line_color"] = color_gen if isinstance(color_gen, str) else next(color_gen)
line = fig.line(x=x, y=y, source=source, **kwargs)
legend_items.append((group, [line]))

Expand All @@ -245,30 +245,17 @@ class BOKEHVLinePlot(BOKEHPlot, VLinePlot):

@classmethod
@APPEND_PLOT_DOC
def plot(cls, fig, data, x, y, by: str | None = None, **kwargs):
def plot(cls, fig, data, x, y, by: str | None = None, plot_3d=False, **kwargs):
"""
Plot a set of vertical lines
"""
color_gen = kwargs.pop("line_color", None)
if color_gen is None:
color_gen = ColorGenerator()
data["line_color"] = [next(color_gen) for _ in range(len(data))]
if by is None:
source = ColumnDataSource(data)
line = fig.segment(
x0=x,
y0=0,
x1=x,
y1=y,
source=source,
line_color="line_color",
**kwargs,
)
return fig, None
else:
legend_items = []
for group, df in data.groupby(by):
source = ColumnDataSource(df)
if not plot_3d:
if by is None:
source = ColumnDataSource(data)
line = fig.segment(
x0=x,
y0=0,
Expand All @@ -278,11 +265,27 @@ def plot(cls, fig, data, x, y, by: str | None = None, **kwargs):
line_color="line_color",
**kwargs,
)
legend_items.append((group, [line]))

legend = Legend(items=legend_items)

return fig, legend
return fig, None
else:
legend_items = []
for group, df in data.groupby(by):
source = ColumnDataSource(df)
line = fig.segment(
x0=x,
y0=0,
x1=x,
y1=y,
source=source,
line_color="line_color",
**kwargs,
)
legend_items.append((group, [line]))

legend = Legend(items=legend_items)

return fig, legend
else:
raise NotImplementedError("3D Vline plots are not supported in Bokeh")

def _add_annotations(
self,
Expand Down Expand Up @@ -312,7 +315,7 @@ class BOKEHScatterPlot(BOKEHPlot, ScatterPlot):

@classmethod
@APPEND_PLOT_DOC
def plot(cls, fig, data, x, y, by: str | None = None, **kwargs):
def plot(cls, fig, data, x, y, by: str | None = None, plot_3d=False, **kwargs):
"""
Plot a scatter plot
"""
Expand Down Expand Up @@ -466,16 +469,19 @@ class BOKEHPeakMapPlot(BOKEH_MSPlot, PeakMapPlot):
"""

def create_main_plot(self, x, y, z, class_kwargs, other_kwargs):
scatterPlot = self.get_scatter_renderer(self.data, x, y, **class_kwargs)
if not self.plot_3d:
scatterPlot = self.get_scatter_renderer(self.data, x, y, **class_kwargs)

self.fig = scatterPlot.generate(z=z, **other_kwargs)
self.fig = scatterPlot.generate(z=z, **other_kwargs)

if self.annotation_data is not None:
self._add_box_boundaries(self.annotation_data)
if self.annotation_data is not None:
self._add_box_boundaries(self.annotation_data)

tooltips, _ = self._create_tooltips({self.xlabel: x, self.ylabel: y, "intensity": z})
tooltips, _ = self._create_tooltips({self.xlabel: x, self.ylabel: y, "intensity": z})

self._add_tooltips(self.fig, tooltips)
self._add_tooltips(self.fig, tooltips)
else:
raise NotImplementedError("3D PeakMap plots are not supported in Bokeh")

def create_x_axis_plot(self, x, z, class_kwargs):
x_fig = super().create_x_axis_plot(x, z, class_kwargs)
Expand Down
4 changes: 4 additions & 0 deletions pyopenms_viz/_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,7 @@ def default_legend_factory():
title: str = "1D Plot"
xlabel: str = "X-axis"
ylabel: str = "Y-axis"
zlabel: str = "Z-axis"
x_axis_location: str = "below"
y_axis_location: str = "left"
min_border: str = 0
Expand Down Expand Up @@ -231,6 +232,7 @@ def set_plot_labels(self):
"title": "PeakMap",
"xlabel": "Retention Time",
"ylabel": "mass-to-charge",
"zlabel": "Intensity",
},
# Add more plot types as needed
}
Expand All @@ -239,6 +241,8 @@ def set_plot_labels(self):
self.title = plot_configs[self.kind]["title"]
self.xlabel = plot_configs[self.kind]["xlabel"]
self.ylabel = plot_configs[self.kind]["ylabel"]
if self.kind == "peakmap":
self.zlabel = plot_configs[self.kind]["zlabel"]

if self.relative_intensity and "Intensity" in self.ylabel:
self.ylabel = "Relative " + self.ylabel
54 changes: 44 additions & 10 deletions pyopenms_viz/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@ def __init__(
z: str | None = None,
kind=None,
by: str | None = None,
plot_3d: bool = False,
relative_intensity: bool = False,
subplots: bool | None = None,
sharex: bool | None = None,
Expand All @@ -120,6 +121,7 @@ def __init__(
title: str | None = None,
xlabel: str | None = None,
ylabel: str | None = None,
zlabel: str | None = None,
x_axis_location: str | None = None,
y_axis_location: str | None = None,
line_type: str | None = None,
Expand All @@ -136,6 +138,7 @@ def __init__(
self.data = data.copy()
self.kind = kind
self.by = by
self.plot_3d = plot_3d
self.relative_intensity = relative_intensity

# Plotting attributes
Expand All @@ -150,6 +153,7 @@ def __init__(
self.title = title
self.xlabel = xlabel
self.ylabel = ylabel
self.zlabel = zlabel
self.x_axis_location = x_axis_location
self.y_axis_location = y_axis_location
self.line_type = line_type
Expand Down Expand Up @@ -298,7 +302,7 @@ def _make_plot(self, fig, **kwargs) -> None:
tooltips = kwargs.pop("tooltips", None)
custom_hover_data = kwargs.pop("custom_hover_data", None)

newlines, legend = self.plot(fig, self.data, self.x, self.y, self.by, **kwargs)
newlines, legend = self.plot(fig, self.data, self.x, self.y, self.by, self.plot_3d, **kwargs)

if legend is not None:
self._add_legend(newlines, legend)
Expand All @@ -308,7 +312,7 @@ def _make_plot(self, fig, **kwargs) -> None:
self._add_tooltips(newlines, tooltips, custom_hover_data)

@abstractmethod
def plot(cls, fig, data, x, y, by: str | None = None, **kwargs):
def plot(cls, fig, data, x, y, by: str | None = None, plot_3d: bool = False, **kwargs):
"""
Create the plot
"""
Expand Down Expand Up @@ -491,7 +495,11 @@ def plot(self, data, x, y, **kwargs):
"""
Create the plot
"""
color_gen = ColorGenerator()
if 'line_color' not in kwargs:
color_gen = ColorGenerator()
else:
color_gen = kwargs['line_color']

tooltip_entries = {"retention time": x, "intensity": y}
if "Annotation" in self.data.columns:
tooltip_entries["annotation"] = "Annotation"
Expand Down Expand Up @@ -795,6 +803,7 @@ def __init__(
num_x_bins: int = 50,
num_y_bins: int = 50,
z_log_scale: bool = False,
# plot_3d: bool = False,
**kwargs,
) -> None:
# Copy data since it will be modified
Expand Down Expand Up @@ -826,13 +835,23 @@ def __init__(
):
data[x] = cut(data[x], bins=num_x_bins)
data[y] = cut(data[y], bins=num_y_bins)

# Group by x and y bins and calculate the mean intensity within each bin
data = (
data.groupby([x, y], observed=True)
.agg({z: "mean"})
.reset_index()
)
by = kwargs.pop("by", None)
if by is not None:
# Group by x, y and by columns and calculate the mean intensity within each bin
data = (
data.groupby([x, y, by], observed=True)
.agg({z: "mean"})
.reset_index()
)
# Add by back to kwargs
kwargs["by"] = by
else:
# Group by x and y bins and calculate the mean intensity within each bin
data = (
data.groupby([x, y], observed=True)
.agg({z: "mean"})
.reset_index()
)
data[x] = data[x].apply(lambda interval: interval.mid).astype(float)
data[y] = data[y].apply(lambda interval: interval.mid).astype(float)
data = data.fillna(0)
Expand All @@ -846,7 +865,11 @@ def __init__(

super().__init__(data, x, y, z=z, **kwargs)

# if not plot_3d:
self.plot(x, y, z, **kwargs)
# else:
# self.plot_3d(x, y, z, **kwargs)

if self.show_plot:
self.show()

Expand All @@ -873,6 +896,12 @@ def plot(self, x, y, z, **kwargs):
y_fig = self.create_y_axis_plot(y, z, class_kwargs_copy)

self.combine_plots(x_fig, y_fig)

# def plot_3d(self, x, y, z, **kwargs):
# class_kwargs, other_kwargs = self._separate_class_kwargs(**kwargs)

# self.create_main_plot_3d(x, y, z, class_kwargs, other_kwargs)
# pass

@staticmethod
def _integrate_data_along_dim(
Expand All @@ -896,6 +925,10 @@ def create_main_plot(self, x, y, z, class_kwargs, other_kwargs):
# by default the main plot with marginals is plotted the same way as the main plot unless otherwise specified
def create_main_plot_marginals(self, x, y, z, class_kwargs, other_kwargs):
self.create_main_plot(x, y, z, class_kwargs, other_kwargs)

# @abstractmethod
# def create_main_plot_3d(self, x, y, z, class_kwargs, other_kwargs):
# pass

@abstractmethod
def create_x_axis_plot(self, x, z, class_kwargs) -> "figure":
Expand Down Expand Up @@ -1006,6 +1039,7 @@ def __call__(self, *args: Any, **kwargs: Any) -> Any:
# Call the plot method of the selected backend
if "backend" in kwargs:
kwargs.pop("backend")

return plot_backend.plot(self._parent, x=x, y=y, kind=kind, **kwargs)

@staticmethod
Expand Down
Loading