Skip to content

Commit

Permalink
Pass showlegend/template kwargs to plotly (#122)
Browse files Browse the repository at this point in the history
* Pass showlegend/template kwargs to plotly

* Update plotly.py

- minor fix

* fix dropped "

* Fix showlegend for plotly

* Fix showlegend for plotly

* Add handling of continuous labels

* Add tests for plotly integration

* fix typo

---------

Co-authored-by: Mackenzie Mathis <mathis@rowland.harvard.edu>
Co-authored-by: Anastasiia Filippova <filippova.av@phystech.edu>
  • Loading branch information
3 people authored Jan 7, 2024
1 parent 6a6a07d commit 45e2229
Show file tree
Hide file tree
Showing 2 changed files with 89 additions and 23 deletions.
77 changes: 54 additions & 23 deletions cebra/integrations/plotly.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,33 +94,55 @@ def _plot_3d(self, **kwargs) -> plotly.graph_objects.Figure:
Returns:
The axis :py:meth:`plotly.graph_objs._figure.Figure` of the plot.
"""

idx1, idx2, idx3 = self.idx_order
data = [
plotly.graph_objects.Scatter3d(
x=self.embedding[:, idx1],
y=self.embedding[:, idx2],
z=self.embedding[:, idx3],
mode="markers",
marker=dict(
size=self.markersize,
opacity=self.alpha,
color=self.embedding_labels,
colorscale=self.colorscale,
),
)
]
showlegend = kwargs.get("showlegend", False)
discrete = kwargs.get("discrete", False)
col = kwargs.get("col", None)
row = kwargs.get("row", None)
template = kwargs.get("template", "plotly_white")
data = []

if col is None or row is None:
self.axis.add_trace(data[0])
if not discrete and showlegend:
raise ValueError("Cannot show legend with continuous labels.")

idx1, idx2, idx3 = self.idx_order

if discrete:
unique_labels = np.unique(self.embedding_labels)
else:
self.axis.add_trace(data[0], row=row, col=col)
unique_labels = [self.embedding_labels]

for label in unique_labels:
if discrete:
filtered_idx = [
i for i, x in enumerate(self.embedding_labels) if x == label
]
else:
filtered_idx = np.arange(self.embedding.shape[0])
data.append(
plotly.graph_objects.Scatter3d(x=self.embedding[filtered_idx,
idx1],
y=self.embedding[filtered_idx,
idx2],
z=self.embedding[filtered_idx,
idx3],
mode="markers",
marker=dict(
size=self.markersize,
opacity=self.alpha,
color=label,
colorscale=self.colorscale,
),
name=str(label)))

for trace in data:
if col is None or row is None:
self.axis.add_trace(trace)
else:
self.axis.add_trace(trace, row=row, col=col)

self.axis.update_layout(
template="plotly_white",
showlegend=False,
template=template,
showlegend=showlegend,
title=self.title,
)

Expand Down Expand Up @@ -166,8 +188,17 @@ def plot_embedding_interactive(
title: The title on top of the embedding.
figsize: Figure width and height in inches.
dpi: Figure resolution.
kwargs: Optional arguments to customize the plots. See :py:class:`plotly.graph_objects.Scatter` documentation for more
details on which arguments to use.
kwargs: Optional arguments to customize the plots. This dictionary includes the following optional arguments:
-- showlegend: Whether to show the legend or not.
-- discrete: Whether the labels are discrete or not.
-- col: The column of the subplot to plot the embedding on.
-- row: The row of the subplot to plot the embedding on.
-- template: The template to use for the plot.
Note: showlegend can be True only if discrete is True.
See :py:class:`plotly.graph_objects.Scatter` documentation for more
details on which arguments to use.
Returns:
The plotly figure.
Expand Down
35 changes: 35 additions & 0 deletions tests/test_plotly.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,3 +84,38 @@ def test_plot_embedding(output_dimension, idx_order):

fig_subplots.data = []
fig_subplots.layout = {}


def test_discrete_with_legend():
embedding = np.random.uniform(0, 1, (1000, 3))
labels = np.random.randint(0, 10, (1000,))

fig = cebra_plotly.plot_embedding_interactive(embedding,
labels,
discrete=True,
showlegend=True)

assert len(fig._data_objs) == np.unique(labels).shape[0]
assert isinstance(fig, go.Figure)


def test_continuous_no_legend():
embedding = np.random.uniform(0, 1, (1000, 3))
labels = np.random.uniform(0, 1, (1000,))

fig = cebra_plotly.plot_embedding_interactive(embedding, labels)

assert len(fig._data_objs) == 1

assert isinstance(fig, go.Figure)


def test_continuous_with_legend_raises_error():
embedding = np.random.uniform(0, 1, (1000, 3))
labels = np.random.uniform(0, 1, (1000,))

with pytest.raises(ValueError):
cebra_plotly.plot_embedding_interactive(embedding,
labels,
discrete=False,
showlegend=True)

0 comments on commit 45e2229

Please sign in to comment.