Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding option to include transformed variables in InferenceData #6232

Merged
merged 2 commits into from
Oct 20, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 10 additions & 1 deletion pymc/backends/arviz.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,9 +164,11 @@ def __init__(
dims: Optional[DimSpec] = None,
model=None,
save_warmup: Optional[bool] = None,
include_transformed: bool = False,
):

self.save_warmup = rcParams["data.save_warmup"] if save_warmup is None else save_warmup
self.include_transformed = include_transformed
self.trace = trace

# this permits us to get the model from command-line argument or from with model:
Expand Down Expand Up @@ -311,7 +313,9 @@ def _extract_log_likelihood(self, trace):
@requires("trace")
def posterior_to_xarray(self):
"""Convert the posterior to an xarray dataset."""
var_names = get_default_varnames(self.trace.varnames, include_transformed=False)
var_names = get_default_varnames(
self.trace.varnames, include_transformed=self.include_transformed
)
data = {}
data_warmup = {}
for var_name in var_names:
Expand Down Expand Up @@ -539,6 +543,7 @@ def to_inference_data(
dims: Optional[DimSpec] = None,
model: Optional["Model"] = None,
save_warmup: Optional[bool] = None,
include_transformed: bool = False,
) -> InferenceData:
"""Convert pymc data into an InferenceData object.

Expand Down Expand Up @@ -571,6 +576,9 @@ def to_inference_data(
save_warmup : bool, optional
Save warmup iterations InferenceData object. If not defined, use default
defined by the rcParams.
include_transformed : bool, optional
Save the transformed parameters in the InferenceData object. By default, these are
not saved.

Returns
-------
Expand All @@ -588,6 +596,7 @@ def to_inference_data(
dims=dims,
model=model,
save_warmup=save_warmup,
include_transformed=include_transformed,
).to_inference_data()


Expand Down
19 changes: 18 additions & 1 deletion pymc/tests/backends/test_arviz.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,7 +279,7 @@ def test_autodetect_coords_from_model(self, use_context):
np.testing.assert_array_equal(idata.observed_data.coords["date"], coords["date"])
np.testing.assert_array_equal(idata.observed_data.coords["city"], coords["city"])

def test_ovewrite_model_coords_dims(self):
def test_overwrite_model_coords_dims(self):
"""Check coords and dims from model object can be partially overwritten."""
dim1 = ["a", "b"]
new_dim1 = ["c", "d"]
Expand Down Expand Up @@ -617,6 +617,23 @@ def test_variable_dimension_name_collision(self):
var = at.as_tensor([1, 2, 3])
pmodel.register_rv(var, name="time", dims=("time",))

def test_include_transformed(self):
with pm.Model():
pm.Uniform("p", 0, 1)

# First check that the default is to exclude the transformed variables
sample_kwargs = dict(tune=5, draws=7, chains=2, cores=1)
inference_data = pm.sample(**sample_kwargs, step=pm.Metropolis())
assert "p_interval__" not in inference_data.posterior

# Now check that they are included when requested
inference_data = pm.sample(
**sample_kwargs,
step=pm.Metropolis(),
idata_kwargs={"include_transformed": True},
)
assert "p_interval__" in inference_data.posterior


class TestPyMCWarmupHandling:
@pytest.mark.parametrize("save_warmup", [False, True])
Expand Down