From 1d288b229e759d4cb5c9dcfbcc2687cc4a83789f Mon Sep 17 00:00:00 2001 From: Michael Osthege Date: Mon, 25 May 2020 10:50:45 +0200 Subject: [PATCH 1/6] add regression test for #1208 --- arviz/tests/external_tests/test_data_pymc.py | 26 ++++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/arviz/tests/external_tests/test_data_pymc.py b/arviz/tests/external_tests/test_data_pymc.py index acbe05205d..8ad00de679 100644 --- a/arviz/tests/external_tests/test_data_pymc.py +++ b/arviz/tests/external_tests/test_data_pymc.py @@ -434,3 +434,29 @@ def test_save_warmup(self, save_warmup): if save_warmup: assert idata.warmup_posterior.dims["chain"] == 2 assert idata.warmup_posterior.dims["draw"] == 100 + + def test_save_warmup_issue_1208(self): + with pm.Model(): + pm.Uniform("u1") + pm.Normal("n1") + trace = pm.sample( + tune=100, + draws=200, + chains=2, + cores=1, + step=pm.Metropolis(), + discard_tuned_samples=False, + ) + assert isinstance(trace, pm.backends.base.MultiTrace) + idata = from_pymc3(trace, save_warmup=True) + assert idata.posterior.dims["chain"] == 2 + if pm.__version__ <= '3.8': + # <=3.8 did not track n_draws in the sampler report + assert idata.posterior.dims["draw"] == 300 + else: + assert idata.posterior.dims["draw"] == 200 + # test with manually sliced trace + with pytest.warns(UserWarning): + idata = from_pymc3(trace[-30:], save_warmup=True) + assert idata.posterior.dims["chain"] == 2 + assert idata.posterior.dims["draw"] == 30 From fe0a1757ad070705ba806e6396d21ccf52e17971 Mon Sep 17 00:00:00 2001 From: Michael Osthege Date: Mon, 25 May 2020 10:56:17 +0200 Subject: [PATCH 2/6] check if n_draws is available and informative before using it + also warn the user about manually slicing + closes #1208 + bumps to 0.8.2 because of the hotfix --- arviz/__init__.py | 2 +- arviz/data/io_pymc3.py | 5 +++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/arviz/__init__.py b/arviz/__init__.py index b75be6327f..7ed98d5c65 100644 --- a/arviz/__init__.py +++ b/arviz/__init__.py @@ -1,6 +1,6 @@ # pylint: disable=wildcard-import,invalid-name,wrong-import-position """ArviZ is a library for exploratory analysis of Bayesian models.""" -__version__ = "0.8.1" +__version__ = "0.8.2" import os import logging diff --git a/arviz/data/io_pymc3.py b/arviz/data/io_pymc3.py index 749d697453..982773aa17 100644 --- a/arviz/data/io_pymc3.py +++ b/arviz/data/io_pymc3.py @@ -98,7 +98,7 @@ def __init__( 0 ].model self.nchains = trace.nchains if hasattr(trace, "nchains") else 1 - if hasattr(trace.report, "n_tune"): + if hasattr(trace.report, "n_draws") and trace.report.n_draws is not None: self.ndraws = trace.report.n_draws self.attrs = { "sampling_time": trace.report.t_sampling, @@ -109,7 +109,8 @@ def __init__( if self.save_warmup: warnings.warn( "Warmup samples will be stored in posterior group and will not be" - " excluded from stats and diagnostics. Please consider using PyMC3>=3.9", + " excluded from stats and diagnostics." + " Please consider using PyMC3>=3.9 and do not slice the trace manually.", UserWarning, ) else: From 434fd1f394cf777e3d4bb192e459550f6ec9a48d Mon Sep 17 00:00:00 2001 From: Michael Osthege Date: Mon, 25 May 2020 12:36:51 +0200 Subject: [PATCH 3/6] split new test and adapt conditions for different pymc3 versions --- arviz/tests/external_tests/test_data_pymc.py | 46 ++++++++++++++++---- 1 file changed, 38 insertions(+), 8 deletions(-) diff --git a/arviz/tests/external_tests/test_data_pymc.py b/arviz/tests/external_tests/test_data_pymc.py index 8ad00de679..8b35af37e4 100644 --- a/arviz/tests/external_tests/test_data_pymc.py +++ b/arviz/tests/external_tests/test_data_pymc.py @@ -400,7 +400,7 @@ def test_no_model_deprecation(self): assert not fails @pytest.mark.skipif( - packaging.version.parse(pm.__version__) < packaging.version.parse("3.9"), + not hasattr(pm.backends.base.SamplerReport, 'n_draws'), reason="requires pymc3 3.9 or higher", ) @pytest.mark.parametrize("save_warmup", [False, True]) @@ -435,7 +435,37 @@ def test_save_warmup(self, save_warmup): assert idata.warmup_posterior.dims["chain"] == 2 assert idata.warmup_posterior.dims["draw"] == 100 - def test_save_warmup_issue_1208(self): + @pytest.mark.skipif( + hasattr(pm.backends.base.SamplerReport, 'n_draws'), + reason="requires pymc3 3.8 or lower", + ) + def test_save_warmup_issue_1208_before_3_9(self): + with pm.Model(): + pm.Uniform("u1") + pm.Normal("n1") + trace = pm.sample( + tune=100, + draws=200, + chains=2, + cores=1, + step=pm.Metropolis(), + discard_tuned_samples=False, + ) + assert isinstance(trace, pm.backends.base.MultiTrace) + assert len(trace) == 300 + + # <=3.8 did not track n_draws in the sampler report, + # making from_pymc3 fall back to len(trace) and triggering a warning + with pytest.warns(UserWarning): + idata = from_pymc3(trace, save_warmup=True) + assert idata.posterior.dims["draw"] == 300 + assert idata.posterior.dims["chain"] == 2 + + @pytest.mark.skipif( + not hasattr(pm.backends.base.SamplerReport, 'n_draws'), + reason="requires pymc3 3.9 or higher", + ) + def test_save_warmup_issue_1208_after_3_9(self): with pm.Model(): pm.Uniform("u1") pm.Normal("n1") @@ -448,14 +478,14 @@ def test_save_warmup_issue_1208(self): discard_tuned_samples=False, ) assert isinstance(trace, pm.backends.base.MultiTrace) + assert len(trace) == 300 + + # from original trace, warmup draws should be separated out idata = from_pymc3(trace, save_warmup=True) assert idata.posterior.dims["chain"] == 2 - if pm.__version__ <= '3.8': - # <=3.8 did not track n_draws in the sampler report - assert idata.posterior.dims["draw"] == 300 - else: - assert idata.posterior.dims["draw"] == 200 - # test with manually sliced trace + assert idata.posterior.dims["draw"] == 200 + + # manually sliced trace triggers the same warning as <=3.8 with pytest.warns(UserWarning): idata = from_pymc3(trace[-30:], save_warmup=True) assert idata.posterior.dims["chain"] == 2 From 5ef468e00130a543ba50693d149ec8d9aedd87f0 Mon Sep 17 00:00:00 2001 From: Michael Osthege Date: Mon, 25 May 2020 13:11:44 +0200 Subject: [PATCH 4/6] make pylint happy --- arviz/tests/external_tests/test_data_pymc.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/arviz/tests/external_tests/test_data_pymc.py b/arviz/tests/external_tests/test_data_pymc.py index 8b35af37e4..36fe4d0557 100644 --- a/arviz/tests/external_tests/test_data_pymc.py +++ b/arviz/tests/external_tests/test_data_pymc.py @@ -1,7 +1,6 @@ # pylint: disable=no-member, invalid-name, redefined-outer-name from sys import version_info from typing import Dict, Tuple -import packaging import numpy as np import pytest @@ -399,6 +398,8 @@ def test_no_model_deprecation(self): fails = check_multiple_attrs(test_dict, inference_data) assert not fails + +class TestPyMC3WarmupHandling: @pytest.mark.skipif( not hasattr(pm.backends.base.SamplerReport, 'n_draws'), reason="requires pymc3 3.9 or higher", From 7419b783d2af80b9174cde76a2d605c974c33912 Mon Sep 17 00:00:00 2001 From: Michael Osthege Date: Mon, 25 May 2020 15:42:58 +0200 Subject: [PATCH 5/6] make black happy --- arviz/tests/external_tests/test_data_pymc.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/arviz/tests/external_tests/test_data_pymc.py b/arviz/tests/external_tests/test_data_pymc.py index 36fe4d0557..5565bc23ac 100644 --- a/arviz/tests/external_tests/test_data_pymc.py +++ b/arviz/tests/external_tests/test_data_pymc.py @@ -401,7 +401,7 @@ def test_no_model_deprecation(self): class TestPyMC3WarmupHandling: @pytest.mark.skipif( - not hasattr(pm.backends.base.SamplerReport, 'n_draws'), + not hasattr(pm.backends.base.SamplerReport, "n_draws"), reason="requires pymc3 3.9 or higher", ) @pytest.mark.parametrize("save_warmup", [False, True]) @@ -437,8 +437,7 @@ def test_save_warmup(self, save_warmup): assert idata.warmup_posterior.dims["draw"] == 100 @pytest.mark.skipif( - hasattr(pm.backends.base.SamplerReport, 'n_draws'), - reason="requires pymc3 3.8 or lower", + hasattr(pm.backends.base.SamplerReport, "n_draws"), reason="requires pymc3 3.8 or lower", ) def test_save_warmup_issue_1208_before_3_9(self): with pm.Model(): @@ -463,7 +462,7 @@ def test_save_warmup_issue_1208_before_3_9(self): assert idata.posterior.dims["chain"] == 2 @pytest.mark.skipif( - not hasattr(pm.backends.base.SamplerReport, 'n_draws'), + not hasattr(pm.backends.base.SamplerReport, "n_draws"), reason="requires pymc3 3.9 or higher", ) def test_save_warmup_issue_1208_after_3_9(self): From f51cf4a26d7f3442668d6c4bf23d36a376cd4a3c Mon Sep 17 00:00:00 2001 From: Michael Osthege Date: Mon, 25 May 2020 16:26:06 +0200 Subject: [PATCH 6/6] address review feedback --- arviz/tests/external_tests/test_data_pymc.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/arviz/tests/external_tests/test_data_pymc.py b/arviz/tests/external_tests/test_data_pymc.py index 5565bc23ac..e817c09eb8 100644 --- a/arviz/tests/external_tests/test_data_pymc.py +++ b/arviz/tests/external_tests/test_data_pymc.py @@ -456,7 +456,7 @@ def test_save_warmup_issue_1208_before_3_9(self): # <=3.8 did not track n_draws in the sampler report, # making from_pymc3 fall back to len(trace) and triggering a warning - with pytest.warns(UserWarning): + with pytest.warns(UserWarning, match="Warmup samples"): idata = from_pymc3(trace, save_warmup=True) assert idata.posterior.dims["draw"] == 300 assert idata.posterior.dims["chain"] == 2 @@ -486,7 +486,7 @@ def test_save_warmup_issue_1208_after_3_9(self): assert idata.posterior.dims["draw"] == 200 # manually sliced trace triggers the same warning as <=3.8 - with pytest.warns(UserWarning): + with pytest.warns(UserWarning, match="Warmup samples"): idata = from_pymc3(trace[-30:], save_warmup=True) assert idata.posterior.dims["chain"] == 2 assert idata.posterior.dims["draw"] == 30