[WIP] Add new groups to io_pyro #1090
Works fine when trace is passed.
arviz/data/io_pyro.py
Outdated
idata_origin=None,
inplace=False
I would not add this to `from_pyro`. They could go to a `from_pyro_predictions` (to mimic the pymc3 pattern), but I am not sure it is worth including such a function yet. It would basically be only a `concat` call, so users can call `az.concat` directly.
arviz/data/io_pyro.py
Outdated
import pyro

if self.predictions is not None and self.pred_dims is None:
    raise ValueError("Prediction dims are needed for predictions group.")
dims are not needed; if they are not present, ArviZ internal functions will assign default names.
But the thing is, the `predictions` variables have the same names as the `posterior_predictive` variables, so it uses the `posterior_predictive` coords, which have different values, and that causes an error.
I am not sure I follow. `predictions` should not have access to `dims` at any point, only to `pred_dims`. Then the possible cases are:
- `pred_dims` is None -> default dim and coord names
- `pred_dims` is a dict -> use `pred_dims`
Yes, both `predictions` and `predictions_constant_data` use `pred_dims` instead of `dims`. I raised that error just in case the user passes the predictions but does not pass `pred_dims`.
All other groups use the default `dims`. Only the `predictions` and `predictions_constant_data` groups need `pred_dims`, because their variable names are the same as those in `posterior_predictive` and `constant_data` respectively, but their data actually has different dimensions. So they cannot use the default `dims` in any case. You can see an example here.
Are we on the same page now? Other groups will always use the default `dims`, but `predictions` and `predictions_constant_data` must always use `pred_dims`, right?
Yes. My point is that `dims` can be None, which makes ArviZ use some defaults. `pred_dims` should also have the option of being None and generating default values: not by falling back to `dims`, but by generating the defaults that correspond to its own dataset. I am not sure what I am missing; I thought default dims were generated on a dataset basis, not on an inference data basis.
Ohh yes, dims are generated on a dataset basis.
> - pred dims is None -> default dim and coord names
Sorry, I got lost here; I thought that you were talking about `self.dims` :)
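The rule settled in this thread (each dataset resolves its own dims, and a missing mapping falls back to generated defaults) can be sketched in plain Python. This is only an illustration, not ArviZ's implementation: `resolve_dims` is a hypothetical helper, and the `<var>_dim_<i>` default naming is an assumption about how the defaults look.

```python
def resolve_dims(var_name, shape, dims=None):
    """Return one dimension name per non-sample axis of a variable.

    `shape` excludes the leading (chain, draw) axes. `dims` is the mapping
    that belongs to this dataset only: `dims` for posterior_predictive,
    `pred_dims` for predictions. (Hypothetical helper, not an ArviZ function.)
    """
    given = list((dims or {}).get(var_name, []))
    if len(given) > len(shape):
        raise ValueError("more dims than axes for {!r}".format(var_name))
    # Unnamed axes get per-variable default names (assumed naming scheme).
    defaults = ["{}_dim_{}".format(var_name, i) for i in range(len(given), len(shape))]
    return given + defaults


# Same variable name in two datasets: each one resolves its own dims.
post_pred_dims = resolve_dims("obs", (8,), dims={"obs": ["school"]})
pred_dims_named = resolve_dims("obs", (8,), dims={"obs": ["school_pred"]})
pred_dims_default = resolve_dims("obs", (8,), dims=None)
```

Because the predictions dataset never looks at the `dims` mapping, a missing `pred_dims` yields defaults instead of reusing `school` coords of the wrong length.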
Maybe an nchains argument? It should be enough with only one, as the number of samples is already a dimension.
arviz/data/io_pyro.py
Outdated
@@ -45,10 +66,18 @@ def __init__(
self.nchains = self.ndraws = 0
I have just seen that the `if posterior is not None` check and the chain/draw definition are repeated below. This one can be removed.
arviz/data/io_pyro.py
Outdated
self.coords = coords
self.dims = dims
self.pred_dims = pred_dims
self.num_chains = num_chains
I would handle this in the `else: self.nchains = self.ndraws = 0` below, to set `self.nchains` to `num_chains` and then get the number of draws from the predictions, posterior predictive or prior (I think io_pymc does something similar).
Also, as it is only used in init, it does not need to be saved in self
arviz/data/io_pyro.py
Outdated
@requires("predictions")
def predictions_to_xarray(self):
    """Convert predictions (out of sample predictions) to xarray."""
I would either remove "out of sample predictions" or change it to "out of sample posterior predictive".
Can I now add the tests?
arviz/data/io_pyro.py
Outdated
pred_dims: dict
    Dims for predictions data. Map variable names to their coordinates.
num_chains: int
    Number of chains used for sampling. Only needed when posterior is not provided.
Maybe "ignored if posterior is present" instead.
arviz/data/io_pyro.py
Outdated
else:
    self.nchains = self.ndraws = 0
    raise ValueError("`num_chains` is needed if trace is not given.")
I would not raise an error. For example, prior sampling is generally not performed with multiple chains, so `from_pyro` with only prior or predictions and no `num_chains` should work (even if merging afterwards does not work, or does not work properly).
Then should I set chains and draws to 0 here? Or change the `elif` above to `else` and within that:
if num_chains is not None:
    self.nchains = num_chains
else:
    self.nchains = 1
. . .
In the interest of not sending warnings when they are not needed, maybe it would be best to set `num_chains=1` in the function definitions. I think it will do the same as the `else` you are proposing, given that `num_chains` is ignored if posterior is present.
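A minimal sketch of the behaviour discussed here, assuming a default of `num_chains=1` that is ignored whenever a posterior is available. `infer_chains_draws` is a hypothetical helper written for illustration; it is not the code in `io_pyro.py`, and raising on a non-divisible sample count is an assumption of this sketch.

```python
def infer_chains_draws(posterior_shape=None, total_samples=None, num_chains=1):
    """Return (nchains, ndraws). Hypothetical helper, not part of ArviZ.

    posterior_shape: (chain, draw, ...) shape of a posterior variable, if any.
    total_samples: length of the flattened sample axis of a prior or
        predictions variable, used only when there is no posterior.
    """
    if posterior_shape is not None:
        # num_chains is ignored when the posterior is present.
        return posterior_shape[0], posterior_shape[1]
    if total_samples is None:
        # Nothing to infer from.
        return 0, 0
    if total_samples % num_chains != 0:
        # Assumption of this sketch: refuse to split samples unevenly.
        raise ValueError("num_chains must divide the number of samples")
    return num_chains, total_samples // num_chains


with_posterior = infer_chains_draws(posterior_shape=(2, 500, 8), num_chains=4)
prior_only = infer_chains_draws(total_samples=1000)
predictions_two_chains = infer_chains_draws(total_samples=1000, num_chains=2)
```

With the default of 1, calling the converter with only a prior or only predictions needs no extra argument and triggers no error or warning.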
Definitely, I think you have already used |
@OriolAbril, do you think these tests are enough? Any tests for |
> do you think these tests are enough? Any tests for pred_dims or num_chains or any changes to be made in test_inference_data_no_posterior?
I would add some more tests. As you say, one checking that `dims` and `pred_dims` work properly would be great, and one checking that `num_chains` works too. Also, either in the no posterior test, the constant data test or a new one, it would be great to check what happens when predictions is alone (maybe test `num_chains` and predictions alone at the same time in a new test?).
Notes: by predictions alone I mean any of predictions, constant_data_predictions and pred_dims. Also, https://github.com/arviz-devs/arviz/blob/master/arviz/tests/external_tests/test_data_pymc.py#L322 can help with testing some combinations in the same test.
Another idea, please say if you think it would be useful or not. I have seen that constant data (either for the model or for predictions) is generally already in a dict; however, it is a dict of pytorch tensors. Do you think it would be useful to allow a dict of tensors as the constant data argument? It could be handled with a try/except or by checking `if hasattr(value, "detach")`.
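The duck-typing check mentioned above could look roughly like the sketch below. `to_plain` and `FakeTensor` are hypothetical names introduced only so the example runs without torch installed; for a real `torch.Tensor`, the same `detach().cpu().numpy()` chain returns a NumPy array.

```python
def to_plain(value):
    """Convert tensor-like values via duck typing; pass others through."""
    if hasattr(value, "detach"):
        # For a torch.Tensor this yields a numpy.ndarray.
        return value.detach().cpu().numpy()
    return value


class FakeTensor:
    """Hypothetical stand-in for torch.Tensor, used only in this sketch."""

    def __init__(self, data):
        self.data = list(data)

    def detach(self):
        return self

    def cpu(self):
        return self

    def numpy(self):
        # A real tensor returns a numpy.ndarray; a list keeps this stdlib-only.
        return self.data


constant_data = {"J": 8, "sigma": FakeTensor([5.0, 7.0, 12.0])}
converted = {name: to_plain(value) for name, value in constant_data.items()}
```

Non-tensor entries such as the integer `J` pass through unchanged, so a mixed dict needs no special casing by the caller.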
I think constant data already works with a dict of tensors. Is there any incompatibility?
Not sure; there should probably be a test for this (or we should just say it is not supported). Like numpy, xarray hardly ever raises an error when creating arrays/datasets; however, the result can be unexpected in many cases:
We should make sure either to say in the docstring that tensors must be converted to arrays, or to add a test checking that the generated dataset has the right shape, not length one and dtype object or something similar.
We could also make predictions data a fixture as it is used several times (I think)
inference_data = from_pyro(prior=prior)
test_dict = {"prior": ["mu", "tau", "eta"]}
fails = check_multiple_attrs(test_dict, inference_data)
assert not fails
I think using
assert not fails, "only prior: {}".format(fails)
will yield more informative error messages, not sure about formatting working though. I'll try to check this
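As a quick check that the formatting works, the sketch below uses the same assert pattern in plain Python. The `check` wrapper and the `"prior.mu"` entry are made up for illustration; in the real test, `fails` comes from `check_multiple_attrs`.

```python
def check(fails):
    # Same pattern as suggested: the message is only built when the assert fails.
    assert not fails, "only prior: {}".format(fails)


check([])  # an empty list of failures passes silently

try:
    check(["prior.mu"])  # hypothetical failure entry
except AssertionError as err:
    message = str(err)
```

The failing attributes end up in the assertion message, so the test report shows which group or variable was missing.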
)

inference_data = from_pyro(predictions=predictions, num_chains=2)
nchains = inference_data.predictions["obs"].shape[0]
`inference_data.predictions.dims["chain"]` should return the length of the chain dim.
Sorry, I don't understand. The incorrect dims case would only be there when pyro samples data from the model without chains (like in case of |
I have played a little with tensors and it looks like xarray converts them automatically to arrays. My concern came from the fact that when prompted with the object:
it can either be converted to an array of dtype float and length 5 or to an array of dtype object and length 1 whose first position contains the whole tensor. Everything seems to work properly, even test error messages. Thanks!
modify changelog
Minor nits
arviz/data/io_pyro.py
Outdated
"""When constructing InferenceData must have at least
one of trace, prior, posterior_predictive or predictions."""
"... one of posterior, prior, ..."; `trace` is a pymc-specific argument.
@pytest.fixture(scope="class")
def predictions_data(self, data):
    posterior_samples = data.obj.get_samples()
    model = data.obj.kernel.model
    pred_data = {"J": 8, "sigma": np.array([5.0, 7.0, 12.0, 4.0, 6.0, 10.0, 3.0, 9.0])}
    predictions = Predictive(model, posterior_samples)(
        pred_data["J"], torch.from_numpy(pred_data["sigma"]).float()
    )
    return predictions
How about making a `prediction_params` fixture (like `eight_school_params`, with scope class), so that the signature becomes `predictions_data(self, data, prediction_params)`? This `prediction_params` fixture would then also be used in `get_inference_data` and in `test_inference_data_no_posterior`.
dims = inference_data.posterior_predictive["obs"].shape[2:]
pred_dims = inference_data.predictions["obs"].shape[2:]
assert dims == (8,)
assert pred_dims == (8,)
Use `inference_data.posterior_predictive.dims["school"]` and `inference_data.predictions.dims["school_pred"]`; then the assert will be `dims == 8`.
This can be moved to the `test_inference_data` test above too; both start with the same code.
def test_inference_data_num_chains(self, predictions_data):
    predictions = predictions_data
    inference_data = from_pyro(predictions=predictions, num_chains=2)
`num_chains=chains` (use the `chains` fixture imported from helpers to make sure `chains` is actually the number of chains in the posterior).
# test dims
dims = inference_data.posterior_predictive.dims["school"]
pred_dims = inference_data.predictions.dims["school_pred"]
assert dims == 8, pred_dims == 8
These should still be one per line. The expression after the comma is the assertion message, so `pred_dims == 8` will only be evaluated if `assert dims == 8` fails, and it is never checked.
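The pitfall can be reproduced without ArviZ. In the sketch below, the values 8 and 7 and the `message` helper are made up for illustration; the point is that the part after the comma is never evaluated while the first condition holds.

```python
calls = []


def message():
    # Records whether Python ever evaluates the message expression.
    calls.append("evaluated")
    return "dims mismatch"


dims, pred_dims = 8, 7  # pred_dims is deliberately wrong

# The expression after the comma is the failure message, not a second check,
# so this passes even though pred_dims != 8.
assert dims == 8, pred_dims == 8
assert dims == 8, message()
message_evaluated = bool(calls)

# One assertion per line checks both values.
try:
    assert dims == 8
    assert pred_dims == 8
    second_check_failed = False
except AssertionError:
    second_check_failed = True
```

Only the two-line form catches the wrong `pred_dims` value.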
LGTM. I think it is ready to merge
I'll add the groups to numpyro soon.
fix changelog merge issue
Codecov Report
@@ Coverage Diff @@
## master #1090 +/- ##
==========================================
- Coverage 92.68% 92.67% -0.01%
==========================================
Files 93 93
Lines 9032 9069 +37
==========================================
+ Hits 8371 8405 +34
- Misses 661 664 +3
Continue to review full report at Codecov.
Thanks!
* add predictions
* add remaining groups
* black changes
* modify chains
* remove repeated lines
* minor changes done
* fix pred_dims
* add tests
* add more tests
* modified tests
* update changelog modify changelog
* minor changes
* correct test
Description
Add the following groups to `io_pyro`:
- predictions
- constant_data
- predictions_constant_data
Tests:
- predictions
- constant_data
- predictions_constant_data
Checklist