Better API for obtaining posterior point estimates & more #1899

Open
ccaprani opened this issue Oct 19, 2021 · 1 comment

Comments

@ccaprani

As part of a notebook for pymc (pymc-devs/pymc-examples#241, seen here) to support the addition of the Generalized Extreme Value distribution (pymc-devs/pymc#5085), I had a few complexities processing results with the current API.

I'm raising this issue to see if there is appetite for a PR along the lines proposed here.

  1. This code snippet is used to compare a posterior result with the maximum likelihood estimate from the reference book:
_, vals = az.sel_utils.xarray_to_ndarray(trace["posterior"], var_names=["μ", "σ", "ξ"])
mle = [az.plots.plot_utils.calculate_point_estimate("mode", val) for val in vals]

As can be seen, this uses (or abuses) a few back-end ArviZ functions. A cleaner API would be preferable for accessing the point estimates that the HDI plots already offer through their point_estimate argument, such as mean, mode and median. Something like: az.get_point_estimate(point_estimate='mode', var_names=["μ", "σ", "ξ"]).

  2. Getting the variance-covariance matrix of the estimates requires a pandas interface:
trace["posterior"].drop_vars("z_p").to_dataframe().cov().round(6)

Again, this is a bit non-Bayesian, but it is useful for comparison with results from other sources. So something like: az.get_var_covar(var_names=["μ", "σ", "ξ"]).

  3. Again, looking at that use of xarray's drop_vars on the InferenceData posterior group, it would be neat to have a comparable get_vars that returns the results for the selected variables. This functionality is of course already built in, as it is used through the arguments of many of the plot functions. But something direct like trace["posterior"].get_vars(["μ", "σ", "ξ"]) would be helpful.

  4. More minor: it seems that to examine the prior predictive checks, we should now use the plot_posterior function. I suspect a plot_prior wrapper would be more logical and give more readable code.

az.plot_posterior(
    prior_pc, group="prior", var_names=["μ", "σ", "ξ"], hdi_prob="hide", point_estimate=None
);
@ccaprani ccaprani changed the title Better API for obtaining posterior point estimates Better API for obtaining posterior point estimates & more Oct 19, 2021
@OriolAbril
Member

OriolAbril commented Oct 22, 2021

Intro

> such as mean, mode, median

mean and median are already available via xarray, so I think we should not reimplement those. Doing idata.posterior[["subset", "of vars", "if desired"]].median() already works, and you can use dim=("chain", "draw") to specify which dimensions to reduce by name. Same for mean. Reference: https://xarray.pydata.org/en/stable/generated/xarray.Dataset.median.html, https://xarray.pydata.org/en/stable/generated/xarray.Dataset.mean.html.
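As a minimal sketch of the xarray route described above, the snippet below reduces a subset of variables over the chain and draw dimensions. The Dataset is synthetic and only stands in for idata.posterior; the values, shapes and seed are assumptions made for illustration.

```python
import numpy as np
import xarray as xr

rng = np.random.default_rng(0)

# Synthetic stand-in for idata.posterior: 4 chains x 500 draws per variable.
dims = ("chain", "draw")
posterior = xr.Dataset(
    {
        "μ": (dims, rng.normal(3.0, 0.1, size=(4, 500))),
        "σ": (dims, rng.normal(1.0, 0.05, size=(4, 500))),
        "ξ": (dims, rng.normal(-0.1, 0.02, size=(4, 500))),
    },
    coords={"chain": np.arange(4), "draw": np.arange(500)},
)

# Select a subset of variables, then reduce over the sampling dimensions by name.
subset = posterior[["μ", "σ"]]
post_median = subset.median(dim=("chain", "draw"))
post_mean = subset.mean(dim=("chain", "draw"))

print(post_median)
print(post_mean)
```

Each result is again an xarray Dataset, so any other dimensions of a variable (none in this toy example) would be preserved rather than flattened.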

It might be interesting to try to make the mode we use in some plots available with a similar API. It should not be too difficult, given that it is already implemented, as long as apply_ufunc or wrap_xarray_ufunc is used carefully.
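One way this could look is sketched below with xarray.apply_ufunc. Note the assumptions: kde_mode and posterior_mode are hypothetical helpers written for this example, the mode is estimated with a simple NumPy Gaussian KDE rather than ArviZ's own KDE, and the data are synthetic and continuous with non-zero spread.

```python
import numpy as np
import xarray as xr


def kde_mode(values, grid_size=256):
    """Hypothetical helper: grid point where a simple Gaussian KDE peaks.

    Stand-in estimator, not ArviZ's implementation. Assumes continuous
    samples with non-zero standard deviation.
    """
    values = np.asarray(values, dtype=float)
    # Rule-of-thumb bandwidth (Silverman-style).
    bw = 1.06 * values.std() * values.size ** (-1 / 5)
    grid = np.linspace(values.min(), values.max(), grid_size)
    # Unnormalized density: the normalizing constant does not move the argmax.
    density = np.exp(-0.5 * ((grid[:, None] - values[None, :]) / bw) ** 2).sum(axis=1)
    return grid[np.argmax(density)]


def posterior_mode(ds, sample_dims=("chain", "draw")):
    """Hypothetical helper: apply kde_mode over the stacked sample dimension."""
    stacked = ds.stack(sample=sample_dims)
    return xr.apply_ufunc(
        kde_mode,
        stacked,
        input_core_dims=[["sample"]],  # reduce over "sample", keep other dims
        vectorize=True,  # loop kde_mode over any remaining dimensions
    )


rng = np.random.default_rng(0)
posterior = xr.Dataset(
    {"μ": (("chain", "draw"), rng.normal(3.0, 0.1, size=(4, 500)))},
    coords={"chain": np.arange(4), "draw": np.arange(500)},
)
modes = posterior_mode(posterior)
print(modes)
```

Stacking chain and draw into a single core dimension keeps the reduction function one-dimensional, while vectorize=True takes care of variables that carry extra dimensions.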

> Getting the variance-covariance matrix of the estimates requires a pandas interface:

I have never used that, but it looks general enough to live in xarray directly, and it seems to be somewhat available already: https://xarray.pydata.org/en/stable/generated/xarray.cov.html. If this is not good enough, we should try to push those improvements directly to xarray.
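Since xarray.cov works on a pair of DataArrays, a full variance-covariance matrix can be assembled by looping over variable pairs. The following is only a sketch under stated assumptions: the posterior is synthetic (drawn from a multivariate normal with a made-up covariance), all variables are scalar per draw, and the var_a/var_b dimension names are invented for the example.

```python
import numpy as np
import xarray as xr

rng = np.random.default_rng(1)
names = ["μ", "σ", "ξ"]

# Made-up covariance used only to generate correlated toy draws.
true_cov = np.array(
    [[0.04, 0.01, 0.0], [0.01, 0.02, -0.005], [0.0, -0.005, 0.01]]
)
draws = rng.multivariate_normal([3.0, 1.0, -0.1], true_cov, size=(4, 500))

# Synthetic stand-in for idata.posterior with scalar variables.
posterior = xr.Dataset(
    {name: (("chain", "draw"), draws[..., i]) for i, name in enumerate(names)},
    coords={"chain": np.arange(4), "draw": np.arange(500)},
)

# Collapse chain and draw into one sample dimension, then take pairwise covariances.
stacked = posterior.stack(sample=("chain", "draw"))
cov_matrix = xr.DataArray(
    [[float(xr.cov(stacked[a], stacked[b], dim="sample")) for b in names] for a in names],
    dims=("var_a", "var_b"),
    coords={"var_a": names, "var_b": names},
)
print(cov_matrix.round(6))
```

With the default ddof=1 this should agree with the pandas to_dataframe().cov() route for scalar variables; variables with extra dimensions would need to be selected or flattened first.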

> Again, looking at that InferenceData accessor to the xarray drop_vars, it would be neat if there was a comparable get_vars which returned the results for the selected variables - this functionality is already built-in of course, as is used through the arguments to many of the plot functions. But something directly like: trace["posterior"].get_vars(["μ", "σ", "ξ"]) would be helpful.

Is this the same as #1725 (which fixes #1469), or would it be solved satisfactorily by it?

> It seems that to examine the prior predictive checks, we should now use the plot_posterior function. I suspect a plot_prior wrapper would be more logical and more readable code.

The name might not have been the best choice, but plot_posterior, like plot_density and several other functions, can be used on any group, not only the posterior or the prior groups. Adding plot_prior would probably mean open season for plot_posterior_predictive, plot_sample_stats, and so on, which I believe would end up being even more confusing.

Also note that plot_ppc can be used for either posterior or prior predictive checks, in the sense of comparing generated distributions to observed data. plot_posterior plots distributions, so it can be used to plot the prior distributions and consequently also for prior predictive checks, but it is not necessarily tied to prior predictive checks, nor does it include the observed data in the plot.

Practical remarks

  1. We should consider making the mode computation available. The main drawback I can see is that it will probably sit towards the low-priority end and might take time to land, even if it is not many lines of code. I don't think there would be opposition to the change, but I might be wrong.
  2. Can you try the xarray.cov thing and let us know how it goes?
  3. I have not had much time lately and have not gone back to finish the extract_dataset PR. If you feel up for it, feel free to take over the work there and push it over the finish line.
  4. Maybe we should consider renaming plot_posterior?
