Skip to content

Commit

Permalink
Fixes #124
Browse files Browse the repository at this point in the history
  • Loading branch information
Samreay committed Apr 17, 2024
1 parent c13e31b commit 9fd2dcd
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 3 deletions.
16 changes: 16 additions & 0 deletions docs/examples/advanced_examples/plot_4_misc_chain_visuals.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,3 +112,19 @@
# It's beautiful. And it's hard to find a nice balance.
c.set_override(ChainConfig(smooth=0, bins=100))
fig = c.plotter.plot()


# %%
# Controlling the legend
# ----------------------
#
# Sometimes we have a bunch of things we want to show, whether they are
# chains, markers, truth values, or something else. By default, ChainConsumer
# will try to show everything, but you specifically tell it to hide items on
# the legend.


c = ChainConsumer()
c.add_chain(Chain(samples=df1, name="I'm in the legend!"))
c.add_chain(Chain(samples=df2, name="I'm not!", show_label_in_legend=False))
fig = c.plotter.plot()
6 changes: 5 additions & 1 deletion src/chainconsumer/chain.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,10 @@ class Chain(ChainConfig):
default=1.0,
description="Raise the posterior surface to this. Useful for inflating or deflating uncertainty for debugging.",
)
show_label_in_legend: bool = Field(
default=True,
description="Whether to show the label in the legend",
)

@property
def data_columns(self) -> list[str]:
Expand Down Expand Up @@ -218,7 +222,7 @@ def _validate_color(cls, v: str | np.ndarray | list[float] | None) -> str | None
def _copy_df(cls, v: pd.DataFrame) -> pd.DataFrame:
return v.copy()

@model_validator(mode="after")
@model_validator(mode="after") # type: ignore
def _validate_model(self) -> Chain:
assert not self.samples.empty, "Your chain is empty. This is not ideal."

Expand Down
7 changes: 5 additions & 2 deletions src/chainconsumer/plotter.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,8 @@ def get_size(
def get_artists_from_chains(chains: list[Chain]) -> list[Artist]:
artists: list[Artist] = []
for chain in chains:
if not chain.show_label_in_legend:
continue
if chain.plot_contour and not chain.plot_point:
artists.append(
Line2D(
Expand Down Expand Up @@ -241,10 +243,11 @@ def plot(
if "markerfirst" not in legend_kwargs:
legend_kwargs["markerfirst"] = legend_outside or not self.config.legend_artists

artists = get_artists_from_chains(base.chains)
chains_to_show_on_legend = [c for c in base.chains if c.show_label_in_legend]
artists = get_artists_from_chains(chains_to_show_on_legend)
leg = ax.legend(handles=artists, **legend_kwargs)
if self.config.legend_color_text:
for text, chain in zip(leg.get_texts(), base.chains):
for text, chain in zip(leg.get_texts(), chains_to_show_on_legend):
text.set_fontweight("medium")
text.set_color(colors.format(chain.color))
fig.canvas.draw()
Expand Down

0 comments on commit 9fd2dcd

Please sign in to comment.