diff --git a/docs/examples/advanced_examples/plot_4_misc_chain_visuals.py b/docs/examples/advanced_examples/plot_4_misc_chain_visuals.py index dff18149..cc5a8eb5 100644 --- a/docs/examples/advanced_examples/plot_4_misc_chain_visuals.py +++ b/docs/examples/advanced_examples/plot_4_misc_chain_visuals.py @@ -112,3 +112,19 @@ # It's beautiful. And it's hard to find a nice balance. c.set_override(ChainConfig(smooth=0, bins=100)) fig = c.plotter.plot() + + +# %% +# Controlling the legend +# ---------------------- +# +# Sometimes we have a bunch of things we want to show, whether they are +# chains, markers, truth values, or something else. By default, ChainConsumer +# will try to show everything, but you specifically tell it to hide items on +# the legend. + + +c = ChainConsumer() +c.add_chain(Chain(samples=df1, name="I'm in the legend!")) +c.add_chain(Chain(samples=df2, name="I'm not!", show_label_in_legend=False)) +fig = c.plotter.plot() diff --git a/src/chainconsumer/chain.py b/src/chainconsumer/chain.py index 7632b21b..25aab13a 100644 --- a/src/chainconsumer/chain.py +++ b/src/chainconsumer/chain.py @@ -137,6 +137,10 @@ class Chain(ChainConfig): default=1.0, description="Raise the posterior surface to this. Useful for inflating or deflating uncertainty for debugging.", ) + show_label_in_legend: bool = Field( + default=True, + description="Whether to show the label in the legend", + ) @property def data_columns(self) -> list[str]: @@ -218,7 +222,7 @@ def _validate_color(cls, v: str | np.ndarray | list[float] | None) -> str | None def _copy_df(cls, v: pd.DataFrame) -> pd.DataFrame: return v.copy() - @model_validator(mode="after") + @model_validator(mode="after") # type: ignore def _validate_model(self) -> Chain: assert not self.samples.empty, "Your chain is empty. This is not ideal." diff --git a/src/chainconsumer/plotter.py b/src/chainconsumer/plotter.py index f814bd47..25325b4f 100644 --- a/src/chainconsumer/plotter.py +++ b/src/chainconsumer/plotter.py @@ -63,6 +63,8 @@ def get_size( def get_artists_from_chains(chains: list[Chain]) -> list[Artist]: artists: list[Artist] = [] for chain in chains: + if not chain.show_label_in_legend: + continue if chain.plot_contour and not chain.plot_point: artists.append( Line2D( @@ -241,10 +243,11 @@ def plot( if "markerfirst" not in legend_kwargs: legend_kwargs["markerfirst"] = legend_outside or not self.config.legend_artists - artists = get_artists_from_chains(base.chains) + chains_to_show_on_legend = [c for c in base.chains if c.show_label_in_legend] + artists = get_artists_from_chains(chains_to_show_on_legend) leg = ax.legend(handles=artists, **legend_kwargs) if self.config.legend_color_text: - for text, chain in zip(leg.get_texts(), base.chains): + for text, chain in zip(leg.get_texts(), chains_to_show_on_legend): text.set_fontweight("medium") text.set_color(colors.format(chain.color)) fig.canvas.draw()