diff --git a/arviz/__init__.py b/arviz/__init__.py index b75be6327f..7ed98d5c65 100644 --- a/arviz/__init__.py +++ b/arviz/__init__.py @@ -1,6 +1,6 @@ # pylint: disable=wildcard-import,invalid-name,wrong-import-position """ArviZ is a library for exploratory analysis of Bayesian models.""" -__version__ = "0.8.1" +__version__ = "0.8.2" import os import logging diff --git a/arviz/data/io_pymc3.py b/arviz/data/io_pymc3.py index 749d697453..982773aa17 100644 --- a/arviz/data/io_pymc3.py +++ b/arviz/data/io_pymc3.py @@ -98,7 +98,7 @@ def __init__( 0 ].model self.nchains = trace.nchains if hasattr(trace, "nchains") else 1 - if hasattr(trace.report, "n_tune"): + if hasattr(trace.report, "n_draws") and trace.report.n_draws is not None: self.ndraws = trace.report.n_draws self.attrs = { "sampling_time": trace.report.t_sampling, @@ -109,7 +109,8 @@ def __init__( if self.save_warmup: warnings.warn( "Warmup samples will be stored in posterior group and will not be" - " excluded from stats and diagnostics. Please consider using PyMC3>=3.9", + " excluded from stats and diagnostics." + " Please consider using PyMC3>=3.9 and do not slice the trace manually.", UserWarning, ) else: diff --git a/arviz/tests/external_tests/test_data_pymc.py b/arviz/tests/external_tests/test_data_pymc.py index acbe05205d..e817c09eb8 100644 --- a/arviz/tests/external_tests/test_data_pymc.py +++ b/arviz/tests/external_tests/test_data_pymc.py @@ -1,7 +1,6 @@ # pylint: disable=no-member, invalid-name, redefined-outer-name from sys import version_info from typing import Dict, Tuple -import packaging import numpy as np import pytest @@ -399,8 +398,10 @@ def test_no_model_deprecation(self): fails = check_multiple_attrs(test_dict, inference_data) assert not fails + +class TestPyMC3WarmupHandling: @pytest.mark.skipif( - packaging.version.parse(pm.__version__) < packaging.version.parse("3.9"), + not hasattr(pm.backends.base.SamplerReport, "n_draws"), reason="requires pymc3 3.9 or higher", ) @pytest.mark.parametrize("save_warmup", [False, True]) @@ -434,3 +435,58 @@ def test_save_warmup(self, save_warmup): if save_warmup: assert idata.warmup_posterior.dims["chain"] == 2 assert idata.warmup_posterior.dims["draw"] == 100 + + @pytest.mark.skipif( + hasattr(pm.backends.base.SamplerReport, "n_draws"), reason="requires pymc3 3.8 or lower", + ) + def test_save_warmup_issue_1208_before_3_9(self): + with pm.Model(): + pm.Uniform("u1") + pm.Normal("n1") + trace = pm.sample( + tune=100, + draws=200, + chains=2, + cores=1, + step=pm.Metropolis(), + discard_tuned_samples=False, + ) + assert isinstance(trace, pm.backends.base.MultiTrace) + assert len(trace) == 300 + + # <=3.8 did not track n_draws in the sampler report, + # making from_pymc3 fall back to len(trace) and triggering a warning + with pytest.warns(UserWarning, match="Warmup samples"): + idata = from_pymc3(trace, save_warmup=True) + assert idata.posterior.dims["draw"] == 300 + assert idata.posterior.dims["chain"] == 2 + + @pytest.mark.skipif( + not hasattr(pm.backends.base.SamplerReport, "n_draws"), + reason="requires pymc3 3.9 or higher", + ) + def test_save_warmup_issue_1208_after_3_9(self): + with pm.Model(): + pm.Uniform("u1") + pm.Normal("n1") + trace = pm.sample( + tune=100, + draws=200, + chains=2, + cores=1, + step=pm.Metropolis(), + discard_tuned_samples=False, + ) + assert isinstance(trace, pm.backends.base.MultiTrace) + assert len(trace) == 300 + + # from original trace, warmup draws should be separated out + idata = from_pymc3(trace, save_warmup=True) + assert idata.posterior.dims["chain"] == 2 + assert idata.posterior.dims["draw"] == 200 + + # manually sliced trace triggers the same warning as <=3.8 + with pytest.warns(UserWarning, match="Warmup samples"): + idata = from_pymc3(trace[-30:], save_warmup=True) + assert idata.posterior.dims["chain"] == 2 + assert idata.posterior.dims["draw"] == 30