diff --git a/pymc/backends/arviz.py b/pymc/backends/arviz.py index 826f64d6f9e..7939378c97a 100644 --- a/pymc/backends/arviz.py +++ b/pymc/backends/arviz.py @@ -164,9 +164,11 @@ def __init__( dims: Optional[DimSpec] = None, model=None, save_warmup: Optional[bool] = None, + include_transformed: bool = False, ): self.save_warmup = rcParams["data.save_warmup"] if save_warmup is None else save_warmup + self.include_transformed = include_transformed self.trace = trace # this permits us to get the model from command-line argument or from with model: @@ -311,7 +313,9 @@ def _extract_log_likelihood(self, trace): @requires("trace") def posterior_to_xarray(self): """Convert the posterior to an xarray dataset.""" - var_names = get_default_varnames(self.trace.varnames, include_transformed=False) + var_names = get_default_varnames( + self.trace.varnames, include_transformed=self.include_transformed + ) data = {} data_warmup = {} for var_name in var_names: @@ -539,6 +543,7 @@ def to_inference_data( dims: Optional[DimSpec] = None, model: Optional["Model"] = None, save_warmup: Optional[bool] = None, + include_transformed: bool = False, ) -> InferenceData: """Convert pymc data into an InferenceData object. @@ -571,6 +576,9 @@ def to_inference_data( save_warmup : bool, optional Save warmup iterations InferenceData object. If not defined, use default defined by the rcParams. + include_transformed : bool, optional + Save the transformed parameters in the InferenceData object. By default, these are + not saved. Returns ------- @@ -588,6 +596,7 @@ def to_inference_data( dims=dims, model=model, save_warmup=save_warmup, + include_transformed=include_transformed, ).to_inference_data() diff --git a/pymc/tests/backends/test_arviz.py b/pymc/tests/backends/test_arviz.py index ce287d0302c..ec87728acc5 100644 --- a/pymc/tests/backends/test_arviz.py +++ b/pymc/tests/backends/test_arviz.py @@ -279,7 +279,7 @@ def test_autodetect_coords_from_model(self, use_context): np.testing.assert_array_equal(idata.observed_data.coords["date"], coords["date"]) np.testing.assert_array_equal(idata.observed_data.coords["city"], coords["city"]) - def test_ovewrite_model_coords_dims(self): + def test_overwrite_model_coords_dims(self): """Check coords and dims from model object can be partially overwritten.""" dim1 = ["a", "b"] new_dim1 = ["c", "d"] @@ -617,6 +617,23 @@ def test_variable_dimension_name_collision(self): var = at.as_tensor([1, 2, 3]) pmodel.register_rv(var, name="time", dims=("time",)) + def test_include_transformed(self): + with pm.Model(): + pm.Uniform("p", 0, 1) + + # First check that the default is to exclude the transformed variables + sample_kwargs = dict(tune=5, draws=7, chains=2, cores=1) + inference_data = pm.sample(**sample_kwargs, step=pm.Metropolis()) + assert "p_interval__" not in inference_data.posterior + + # Now check that they are included when requested + inference_data = pm.sample( + **sample_kwargs, + step=pm.Metropolis(), + idata_kwargs={"include_transformed": True}, + ) + assert "p_interval__" in inference_data.posterior + class TestPyMCWarmupHandling: @pytest.mark.parametrize("save_warmup", [False, True])