-
-
Notifications
You must be signed in to change notification settings - Fork 403
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
pointwise elpd diagnostics (text formatting and plot) #678
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Added some comments
Update tests and add azure pipeline to branch
Do not require "samples" dim to be last.
log_likelihood = log_likelihood.stack(samples=("chain", "draw")) | ||
shape = log_likelihood.shape | ||
n_samples = shape[-1] | ||
n_data_points = np.product(shape[:-1]) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What if loglikelihood is nD?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It still works (i checked locally with the same model as the first plot and the test for this passed too). I decided to not stack the observations dimensions in order to keep this information also in the pointwise loo and waic.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think that the only braking change is the shape of psislw input. Which now, in the 2D (sample, and obs) is transposed from what it used to be. In the case of dataarray inputs, this should not affect (as long as the dimension is calles "samples") but for array inputs it does. It could be changed to the original shape and then call np.rollaxis
to get the n_samples
axis to the last position (as expected in xarray apply ufunc).
Ready for review! I can't decide on the yaxis ticks in the case of more than 2 models. Right now the axis are not shared, thus, the ticks on the left do not correspond to the values in the other plots. I came across with two options, but none of them really convince me. The first is to share the axis (or force the ylims to be common throughout each row) and the second is to put yticklabels in all plots. |
PR summaryThis PR has adressed elpd statistics, loo and waic, and created a plot to compare various models according to pointwise elpd statistics. It adds a completely new plot to ArviZ, and modifies various stats functions. The main goal of all these changes in the stats functions is to use xarray for the computation. This allows to work easily with nd objects while maintaining the coords information. Eventually, plots like It also adds the class Modified functions
|
On the testing side, tests have already been included, covering nearly 100% of the lines and taking into account multidimensional objects. |
ufunc_kwargs=ufunc_kwargs, | ||
**kwargs | ||
).values | ||
) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Any speedup obtained via this ufunc?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
None. The computations performed are the same, I have not modified how logsumexp works. I did this so that there is no need to convert to arrays and the named dims and coords info is not lost.
LGTM I checked the code, and everything looks good. |
Thanks @OriolAbril this looks great. Looking forward to try these changes! |
PR to improve the output of waic and loo to make it more verbose and to create the pointwise elpd comparison plot. Will fix #496 and will fix #660.
This first version of the plot already works, but there is still much work needed.
Plots
Data from the toy model in this notebook. This was generated with
color="river_distance"
andlegend=True
.tick labels?
highlight worst points?
Added a
threshold
argument in order to show labels of point further away than threshold times elpd_i.std(). In the example,threshold=1
to force showing the labels.move common coloring/labeling data functions to
plot_utils.py
to have the same functionality inplot_elpd
andplot_khat
Stats functions
improve verbosity
Use xarray for all computations.
See Diagnostics + Stats to use xarray #501 for some more detail.
Both
Everything working after some corrections, still debating between reshaping pointwise loo/waic to original shape or work with multi-index objects from there on (eventually the scatter plot is flattened, thus at some point they are needed)
Idea: Store pointwise loo, waic and pareto_k as dataarrays in the
ELPDData
object instead of as a flattened array and callplot_elpd
with a dict ofELPDData
andplot_khat
with anELPDData
object instead of the array of pareto shape values. Using the dataarray will allow coloring, ticklabels, selection of a subset of observations and so on, and thanks to the overwritten__str__
method, including this extra information it theELPDData
object won't clutter the relevant info when printed.