diff --git a/src/hist/basehist.py b/src/hist/basehist.py index a5c3a93d..626e3f0a 100644 --- a/src/hist/basehist.py +++ b/src/hist/basehist.py @@ -297,12 +297,16 @@ def show(self, **kwargs: Any) -> Any: return histoprint.print_hist(self, **kwargs) - def plot(self, *args: Any, **kwargs: Any) -> "Union[Hist1DArtists, Hist2DArtists]": + def plot( + self, *args: Any, overlay: "Optional[str]" = None, **kwargs: Any + ) -> "Union[Hist1DArtists, Hist2DArtists]": """ Plot method for BaseHist object. """ - if self.ndim == 1: - return self.plot1d(*args, **kwargs) + _has_categorical = np.sum([ax.traits.discrete for ax in self.axes]) == 1 + _project = _has_categorical or overlay is not None + if self.ndim == 1 or (self.ndim == 2 and _project): + return self.plot1d(*args, overlay=overlay, **kwargs) elif self.ndim == 2: return self.plot2d(*args, **kwargs) else: @@ -312,6 +316,7 @@ def plot1d( self, *, ax: "Optional[matplotlib.axes.Axes]" = None, + overlay: "Optional[Union[str, int]]" = None, **kwargs: Any, ) -> "Hist1DArtists": """ @@ -320,7 +325,15 @@ def plot1d( import hist.plot - return hist.plot.histplot(self, ax=ax, **_proc_kw_for_lw(kwargs)) + if self.ndim == 1: + return hist.plot.histplot(self, ax=ax, **_proc_kw_for_lw(kwargs)) + if overlay is None: + (overlay,) = (i for i, ax in enumerate(self.axes) if ax.traits.discrete) + assert overlay is not None + cat_ax = self.axes[overlay] + cats = cat_ax if cat_ax.traits.discrete else np.arange(len(cat_ax.centers)) + d1hists = [self[{overlay: cat}] for cat in cats] + return hist.plot.histplot(d1hists, ax=ax, label=cats, **_proc_kw_for_lw(kwargs)) def plot2d( self, diff --git a/tests/baseline/test_plot1d_auto_handling.png b/tests/baseline/test_plot1d_auto_handling.png new file mode 100644 index 00000000..0c8726f4 Binary files /dev/null and b/tests/baseline/test_plot1d_auto_handling.png differ diff --git a/tests/test_plot.py b/tests/test_plot.py index f3b2a4e3..1b6d55c2 100644 --- a/tests/test_plot.py +++ b/tests/test_plot.py @@ -623,3 +623,42 @@ def pdf(x, a=1 / np.sqrt(2 * np.pi), x0=0, sigma=1, offset=0): assert h.plot_pull(pdf, fit_fmt=r"{name} = {value:.3g} $\pm$ {error:.3g}") return fig + + +@pytest.mark.mpl_image_compare(baseline_dir="baseline", savefig_kwargs={"dpi": 50}) +def test_plot1d_auto_handling(): + """ + Test plot() by comparing against a reference image generated via + `pytest --mpl-generate-path=tests/baseline` + """ + + np.random.seed(42) + + h = Hist( + axis.Regular(10, 0, 10, name="variable", label="variable"), + axis.StrCategory("", name="dataset", growth=True), + ) + + h_nameless = Hist( + axis.Regular(10, 0, 10), + axis.StrCategory("", growth=True), + ) + + h.fill(dataset="A", variable=np.random.normal(3, 2, 100)) + h.fill(dataset="B", variable=np.random.normal(5, 2, 100)) + h.fill(dataset="C", variable=np.random.normal(7, 2, 100)) + + h_nameless.fill(np.random.normal(3, 2, 1000), "A") + h_nameless.fill(np.random.normal(5, 2, 1000), "B") + h_nameless.fill(np.random.normal(7, 2, 1000), "C") + + fig, (ax1, ax2) = plt.subplots(2, 2, figsize=(14, 10)) + + assert h.plot(ax=ax1[0]) + assert h_nameless.plot(ax=ax2[0]) + + # Discrete axis plotting not yet implemented + # assert h.plot(ax=ax1[1], overlay='variable') + # assert h.plot(ax=ax2[1], overlay=1) + + return fig