diff --git a/arviz/plots/__init__.py b/arviz/plots/__init__.py index 046150086b..653f680ccc 100644 --- a/arviz/plots/__init__.py +++ b/arviz/plots/__init__.py @@ -6,6 +6,7 @@ from .forestplot import plot_forest from .kdeplot import plot_kde, _fast_kde, _fast_kde_2d from .parallelplot import plot_parallel +from .elpdplot import plot_elpd from .posteriorplot import plot_posterior from .traceplot import plot_trace from .pairplot import plot_pair @@ -28,6 +29,7 @@ "_fast_kde", "_fast_kde_2d", "plot_parallel", + "plot_elpd", "plot_posterior", "plot_trace", "plot_pair", diff --git a/arviz/plots/elpdplot.py b/arviz/plots/elpdplot.py new file mode 100644 index 0000000000..b51b07e35d --- /dev/null +++ b/arviz/plots/elpdplot.py @@ -0,0 +1,276 @@ +"""Plot pointwise elpd estimations of inference data.""" +import numpy as np +import matplotlib.pyplot as plt +import matplotlib.cm as cm +from matplotlib.ticker import NullFormatter +from matplotlib.lines import Line2D + +from ..data import convert_to_inference_data +from .plot_utils import ( + _scale_fig_size, + get_coords, + color_from_dim, + format_coords_as_labels, + set_xticklabels, +) +from ..stats import waic, loo, ELPDData + + +def plot_elpd( + compare_dict, + color=None, + xlabels=False, + figsize=None, + textsize=None, + coords=None, + legend=False, + threshold=None, + ax=None, + ic="waic", + scale="deviance", + plot_kwargs=None, +): + """ + Plot a scatter or hexbin matrix of the sampled parameters. + + Parameters + ---------- + compare_dict : mapping, str -> ELPDData or InferenceData + A dictionary mapping the model name to the object containing its inference data or + the result of `waic`/`loo` functions. + Refer to az.convert_to_inference_data for details on possible dict items + color : str or array_like, optional + Colors of the scatter plot, if color is a str all dots will have the same color, + if it is the size of the observations, each dot will have the specified color, + otherwise, it will be interpreted as a list of the dims to be used for the color code + xlabels : bool, optional + Use coords as xticklabels + figsize : figure size tuple, optional + If None, size is (8 + numvars, 8 + numvars) + textsize: int, optional + Text size for labels. If None it will be autoscaled based on figsize. + coords : mapping, optional + Coordinates of points to plot. **All** values are used for computation, but only a + a subset can be plotted for convenience. + legend : bool, optional + Include a legend to the plot. Only taken into account when color argument is a dim name. + threshold : float + If some elpd difference is larger than `threshold * elpd.std()`, show its label. If + `None`, no observations will be highlighted. + ax: axes, optional + Matplotlib axes + ic : str, optional + Information Criterion (WAIC or LOO) used to compare models. Default WAIC. Only taken + into account when input is InferenceData. + scale : str, optional + scale argument passed to az.waic or az.loo, see their docs for details. Only taken + into account when input is InferenceData. + plot_kwargs : dicts, optional + Additional keywords passed to ax.plot + + Returns + ------- + ax : matplotlib axes + + Examples + -------- + Compare pointwise WAIC for centered and non centered models of the 8school problem + + .. 
plot:: + :context: close-figs + + >>> import arviz as az + >>> idata1 = az.load_arviz_data("centered_eight") + >>> idata2 = az.load_arviz_data("non_centered_eight") + >>> az.plot_elpd( + >>> {"centered model": idata1, "non centered model": idata2}, + >>> xlabels=True + >>> ) + + """ + valid_ics = ["waic", "loo"] + ic = ic.lower() + if ic not in valid_ics: + raise ValueError( + ("Information Criteria type {} not recognized." "IC must be in {}").format( + ic, valid_ics + ) + ) + ic_fun = waic if ic == "waic" else loo + + # Make sure all object are ELPDData + for k, item in compare_dict.items(): + if not isinstance(item, ELPDData): + compare_dict[k] = ic_fun(convert_to_inference_data(item), pointwise=True, scale=scale) + ics = [elpd_data.index[0] for elpd_data in compare_dict.values()] + if not all(x == ics[0] for x in ics): + raise SyntaxError( + "All Information Criteria must be of the same kind, but both loo and waic data present" + ) + ic = ics[0] + scales = [elpd_data["{}_scale".format(ic)] for elpd_data in compare_dict.values()] + if not all(x == scales[0] for x in scales): + raise SyntaxError( + "All Information Criteria must be on the same scale, but {} are present".format( + set(scales) + ) + ) + numvars = len(compare_dict) + models = list(compare_dict.keys()) + + if coords is None: + coords = {} + + if plot_kwargs is None: + plot_kwargs = {} + plot_kwargs.setdefault("marker", "+") + + pointwise_data = [ + get_coords(compare_dict[model]["{}_i".format(ic)], coords) for model in models + ] + xdata = np.arange(pointwise_data[0].size) + + if isinstance(color, str): + if color in pointwise_data[0].dims: + colors, color_mapping = color_from_dim(pointwise_data[0], color) + if legend: + cmap_name = plot_kwargs.pop("cmap", plt.rcParams["image.cmap"]) + markersize = plot_kwargs.pop("s", plt.rcParams["lines.markersize"]) + cmap = getattr(cm, cmap_name) + handles = [ + Line2D( + [], + [], + color=cmap(float_color), + label=coord, + ms=markersize, + lw=0, + **plot_kwargs + ) + for coord, float_color in color_mapping.items() + ] + plot_kwargs.setdefault("cmap", cmap_name) + plot_kwargs.setdefault("s", markersize ** 2) + plot_kwargs.setdefault("c", colors) + else: + plot_kwargs.setdefault("c", color) + legend = False + else: + legend = False + plot_kwargs.setdefault("c", color) + + if xlabels: + coord_labels = format_coords_as_labels(pointwise_data[0]) + + if numvars < 2: + raise Exception("Number of models to compare must be 2 or greater.") + + if numvars == 2: + (figsize, ax_labelsize, titlesize, xt_labelsize, _, markersize) = _scale_fig_size( + figsize, textsize, numvars - 1, numvars - 1 + ) + plot_kwargs.setdefault("s", markersize ** 2) + + if ax is None: + fig, ax = plt.subplots(figsize=figsize, constrained_layout=(not xlabels and not legend)) + + ydata = pointwise_data[0] - pointwise_data[1] + ax.scatter(xdata, ydata, **plot_kwargs) + if threshold is not None: + ydata = ydata.values.flatten() + diff_abs = np.abs(ydata - ydata.mean()) + bool_ary = diff_abs > threshold * ydata.std() + try: + coord_labels + except NameError: + coord_labels = xdata.astype(str) + outliers = np.argwhere(bool_ary).squeeze() + for outlier in outliers: + label = coord_labels[outlier] + ax.text( + outlier, + ydata[outlier], + label, + horizontalalignment="center", + verticalalignment="bottom" if ydata[outlier] > 0 else "top", + fontsize=0.8 * xt_labelsize, + ) + + ax.set_title("{} - {}".format(*models), fontsize=titlesize, wrap=True) + ax.set_ylabel("ELPD difference", fontsize=ax_labelsize, wrap=True) + 
ax.tick_params(labelsize=xt_labelsize) + if xlabels: + set_xticklabels(ax, coord_labels) + fig.autofmt_xdate() + if legend: + ncols = len(handles) // 6 + 1 + ax.legend(handles=handles, ncol=ncols, title=color) + + else: + (figsize, ax_labelsize, titlesize, xt_labelsize, _, markersize) = _scale_fig_size( + figsize, textsize, numvars - 2, numvars - 2 + ) + plot_kwargs.setdefault("s", markersize ** 2) + + if ax is None: + fig, ax = plt.subplots( + numvars - 1, + numvars - 1, + figsize=figsize, + constrained_layout=(not xlabels and not legend), + ) + + for i in range(0, numvars - 1): + var1 = pointwise_data[i] + + for j in range(0, numvars - 1): + if j < i: + ax[j, i].axis("off") + continue + + var2 = pointwise_data[j + 1] + ax[j, i].scatter(xdata, var1 - var2, **plot_kwargs) + if threshold is not None: + ydata = (var1 - var2).values.flatten() + diff_abs = np.abs(ydata - ydata.mean()) + bool_ary = diff_abs > threshold * ydata.std() + try: + coord_labels + except NameError: + coord_labels = xdata.astype(str) + outliers = np.argwhere(bool_ary).squeeze() + for outlier in outliers: + label = coord_labels[outlier] + ax[j, i].text( + outlier, + ydata[outlier], + label, + horizontalalignment="center", + verticalalignment="bottom" if ydata[outlier] > 0 else "top", + fontsize=0.8 * xt_labelsize, + ) + + if j + 1 != numvars - 1: + ax[j, i].axes.get_xaxis().set_major_formatter(NullFormatter()) + ax[j, i].set_xticks([]) + elif xlabels: + set_xticklabels(ax[j, i], coord_labels) + + if i != 0: + ax[j, i].axes.get_yaxis().set_major_formatter(NullFormatter()) + ax[j, i].set_yticks([]) + else: + ax[j, i].set_ylabel("ELPD difference", fontsize=ax_labelsize, wrap=True) + + ax[j, i].tick_params(labelsize=xt_labelsize) + ax[j, i].set_title( + "{} - {}".format(models[i], models[j + 1]), fontsize=titlesize, wrap=True + ) + if xlabels: + fig.autofmt_xdate() + if legend: + ncols = len(handles) // 6 + 1 + ax[0, 1].legend( + handles=handles, ncol=ncols, title=color, bbox_to_anchor=(0, 1), loc="upper left" + ) + return ax diff --git a/arviz/plots/plot_utils.py b/arviz/plots/plot_utils.py index 82c1648000..512a239dd0 100644 --- a/arviz/plots/plot_utils.py +++ b/arviz/plots/plot_utils.py @@ -375,3 +375,56 @@ def get_coords(data, coords): " dimensions are valid. {}" ).format(err) ) + + +def color_from_dim(dataarray, dim_name): + """Return colors and color mapping of a DataArray using coord values as color code. + + Parameters + ---------- + dataarray : xarray.DataArray + dim_name : str + dimension whose coordinates will be used as color code. + + Returns + ------- + colors : array of floats + Array of colors (as floats for use with a cmap) for each element in the dataarray. 
+ color_mapping : mapping coord_value -> float + Mapping from coord values to corresponding color + """ + present_dims = dataarray.dims + coord_values = dataarray[dim_name].values + unique_coords = set(coord_values) + color_mapping = {coord: num / len(unique_coords) for num, coord in enumerate(unique_coords)} + if len(present_dims) > 1: + multi_coords = dataarray.coords.to_index() + coord_idx = present_dims.index(dim_name) + colors = [color_mapping[coord[coord_idx]] for coord in multi_coords] + else: + colors = [color_mapping[coord] for coord in coord_values] + return colors, color_mapping + + +def format_coords_as_labels(dataarray): + """Format 1d or multi-d dataarray coords as strings.""" + coord_labels = dataarray.coords.to_index().values + if isinstance(coord_labels[0], tuple): + fmt = ", ".join(["{}" for _ in coord_labels[0]]) + coord_labels[:] = [fmt.format(*x) for x in coord_labels] + else: + coord_labels[:] = ["{}".format(s) for s in coord_labels] + return coord_labels + + +def set_xticklabels(ax, coord_labels): + """Set xticklabels to label list using Matplotlib default formatter.""" + ax.xaxis.get_major_locator().set_params(nbins=9, steps=[1, 2, 5, 10]) + xticks = ax.get_xticks().astype(np.int64) + xticks = xticks[(xticks >= 0) & (xticks < len(coord_labels))] + if len(xticks) > len(coord_labels): + ax.set_xticks(np.arange(len(coord_labels))) + ax.set_xticklabels(coord_labels) + else: + ax.set_xticks(xticks) + ax.set_xticklabels(coord_labels[xticks]) diff --git a/arviz/stats/__init__.py b/arviz/stats/__init__.py index 1fa258c6f8..b4951f4f75 100644 --- a/arviz/stats/__init__.py +++ b/arviz/stats/__init__.py @@ -15,6 +15,7 @@ "summary", "waic", "effective_sample_size", + "ELPDData", "ess", "rhat", "mcse", diff --git a/arviz/stats/stats.py b/arviz/stats/stats.py index 6ace9c706e..ba23ec2a31 100644 --- a/arviz/stats/stats.py +++ b/arviz/stats/stats.py @@ -12,7 +12,12 @@ from ..data import convert_to_inference_data, convert_to_dataset from .diagnostics import _multichain_statistics, _mc_error, ess -from .stats_utils import make_ufunc as _make_ufunc, logsumexp as _logsumexp +from .stats_utils import ( + make_ufunc as _make_ufunc, + wrap_xarray_ufunc as _wrap_xarray_ufunc, + logsumexp as _logsumexp, + ELPDData, +) from ..utils import _var_names _log = logging.getLogger(__name__) @@ -225,9 +230,9 @@ def gradient(weights): for idx, val in enumerate(ics.index): res = ics.loc[val] if scale_value < 0: - diff = res[ic_i] - min_ic_i_val + diff = (res[ic_i] - min_ic_i_val).values else: - diff = min_ic_i_val - res[ic_i] + diff = (min_ic_i_val - res[ic_i]).values d_ic = np.sum(diff) d_std_err = np.sqrt(len(diff) * np.var(diff)) std_err = ses.loc[val] @@ -360,7 +365,7 @@ def loo(data, pointwise=False, reff=None, scale="deviance"): Returns ------- - pandas.Series with the following columns: + pandas.Series with the following rows: loo : approximated Leave-one-out cross-validation loo_se : standard error of loo p_loo : effective number of parameters @@ -371,6 +376,9 @@ def loo(data, pointwise=False, reff=None, scale="deviance"): pareto_k : array of Pareto shape values, only if pointwise True loo_scale : scale of the loo results + The returned object has a custom print method that overrides pd.Series method. It is + specific to expected log pointwise predictive density (elpd) information criteria. 
+ Examples -------- Calculate the LOO-CV of a model: @@ -379,7 +387,15 @@ def loo(data, pointwise=False, reff=None, scale="deviance"): In [1]: import arviz as az ...: data = az.load_arviz_data("centered_eight") - ...: az.loo(data, pointwise=True) + ...: az.loo(data) + + The custom print method can be seen here, printing only the relevant information and + with a specific organization. ``IC_loo`` stands for information criteria, which is the + `deviance` scale, the `log` (and `negative_log`) correspond to ``elpd`` (and ``-elpd``) + + .. ipython:: + + In [2]: az.loo(data, pointwise=True, scale="log") """ inference_data = convert_to_inference_data(data) @@ -392,9 +408,10 @@ def loo(data, pointwise=False, reff=None, scale="deviance"): raise TypeError("Data must include log_likelihood in sample_stats") posterior = inference_data.posterior log_likelihood = inference_data.sample_stats.log_likelihood - n_samples = log_likelihood.chain.size * log_likelihood.draw.size - new_shape = (n_samples, np.product(log_likelihood.shape[2:])) - log_likelihood = log_likelihood.values.reshape(*new_shape) + log_likelihood = log_likelihood.stack(samples=("chain", "draw")) + shape = log_likelihood.shape + n_samples = shape[-1] + n_data_points = np.product(shape[:-1]) if scale.lower() == "deviance": scale_value = -2 @@ -430,29 +447,60 @@ def loo(data, pointwise=False, reff=None, scale="deviance"): ) warn_mg = True - loo_lppd_i = scale_value * _logsumexp(log_weights, axis=0) - loo_lppd = loo_lppd_i.sum() - loo_lppd_se = (len(loo_lppd_i) * np.var(loo_lppd_i)) ** 0.5 - - lppd = np.sum(_logsumexp(log_likelihood, axis=0, b_inv=log_likelihood.shape[0])) + ufunc_kwargs = {"n_dims": 1, "ravel": False} + kwargs = {"input_core_dims": [["samples"]]} + loo_lppd_i = scale_value * _wrap_xarray_ufunc( + _logsumexp, log_weights, ufunc_kwargs=ufunc_kwargs, **kwargs + ) + loo_lppd = loo_lppd_i.values.sum() + loo_lppd_se = (n_data_points * np.var(loo_lppd_i.values)) ** 0.5 + + lppd = np.sum( + _wrap_xarray_ufunc( + _logsumexp, + log_likelihood, + func_kwargs={"b_inv": n_samples}, + ufunc_kwargs=ufunc_kwargs, + **kwargs + ).values + ) p_loo = lppd - loo_lppd / scale_value if pointwise: if np.equal(loo_lppd, loo_lppd_i).all(): # pylint: disable=no-member warnings.warn( - """The point-wise LOO is the same with the sum LOO, please double check - the Observed RV in your model to make sure it returns element-wise logp. - """ + "The point-wise LOO is the same with the sum LOO, please double check " + "the Observed RV in your model to make sure it returns element-wise logp." 
) - return pd.Series( - data=[loo_lppd, loo_lppd_se, p_loo, warn_mg, loo_lppd_i, pareto_shape, scale], - index=["loo", "loo_se", "p_loo", "warning", "loo_i", "pareto_k", "loo_scale"], + return ELPDData( + data=[ + loo_lppd, + loo_lppd_se, + p_loo, + n_samples, + n_data_points, + warn_mg, + loo_lppd_i.rename("loo_i"), + pareto_shape, + scale, + ], + index=[ + "loo", + "loo_se", + "p_loo", + "n_samples", + "n_data_points", + "warning", + "loo_i", + "pareto_k", + "loo_scale", + ], ) else: - return pd.Series( - data=[loo_lppd, loo_lppd_se, p_loo, warn_mg, scale], - index=["loo", "loo_se", "p_loo", "warning", "loo_scale"], + return ELPDData( + data=[loo_lppd, loo_lppd_se, p_loo, n_samples, n_data_points, warn_mg, scale], + index=["loo", "loo_se", "p_loo", "n_samples", "n_data_points", "warning", "loo_scale"], ) @@ -463,7 +511,7 @@ def psislw(log_weights, reff=1.0): Parameters ---------- log_weights : array - Array of size (n_samples, n_observations) + Array of size (n_observations, n_samples) reff : float relative MCMC efficiency, `ess / n` @@ -474,57 +522,92 @@ def psislw(log_weights, reff=1.0): kss : array Pareto tail indices """ - rows, cols = log_weights.shape - - log_weights_out = np.copy(log_weights, order="F") - kss = np.empty(cols) - + if hasattr(log_weights, "samples"): + n_samples = len(log_weights.samples) + shape = [size for size, dim in zip(log_weights.shape, log_weights.dims) if dim != "samples"] + else: + n_samples = log_weights.shape[-1] + shape = log_weights.shape[:-1] # precalculate constants - cutoff_ind = -int(np.ceil(min(rows / 5.0, 3 * (rows / reff) ** 0.5))) - 1 + cutoff_ind = -int(np.ceil(min(n_samples / 5.0, 3 * (n_samples / reff) ** 0.5))) - 1 cutoffmin = np.log(np.finfo(float).tiny) # pylint: disable=no-member, assignment-from-no-return k_min = 1.0 / 3 - # loop over sets of log weights - for i, x in enumerate(log_weights_out.T): - # improve numerical accuracy - x -= np.max(x) - # sort the array - x_sort_ind = np.argsort(x) - # divide log weights into body and right tail - xcutoff = max(x[x_sort_ind[cutoff_ind]], cutoffmin) - - expxcutoff = np.exp(xcutoff) - tailinds, = np.where(x > xcutoff) # pylint: disable=unbalanced-tuple-unpacking - x_tail = x[tailinds] - tail_len = len(x_tail) - if tail_len <= 4: - # not enough tail samples for gpdfit - k = np.inf - else: - # order of tail samples - x_tail_si = np.argsort(x_tail) - # fit generalized Pareto distribution to the right tail samples - x_tail = np.exp(x_tail) - expxcutoff - k, sigma = _gpdfit(x_tail[x_tail_si]) - - if k >= k_min: - # no smoothing if short tail or GPD fit failed - # compute ordered statistic for the fit - sti = np.arange(0.5, tail_len) / tail_len - smoothed_tail = _gpinv(sti, k, sigma) - smoothed_tail = np.log( # pylint: disable=assignment-from-no-return - smoothed_tail + expxcutoff - ) - # place the smoothed tail into the output array - x[tailinds[x_tail_si]] = smoothed_tail - # truncate smoothed values to the largest raw weight 0 - x[x > 0] = 0 - # renormalize weights - x -= _logsumexp(x) - # store tail index k - kss[i] = k + # create output array with proper dimensions + out = tuple([np.empty_like(log_weights), np.empty(shape)]) + + # define kwargs + func_kwargs = {"cutoff_ind": cutoff_ind, "cutoffmin": cutoffmin, "k_min": k_min, "out": out} + ufunc_kwargs = {"n_dims": 1, "n_output": 2, "ravel": False, "check_shape": False} + kwargs = {"input_core_dims": [["samples"]], "output_core_dims": [["sample"], []]} + log_weights, pareto_shape = _wrap_xarray_ufunc( + _psislw, log_weights, 
ufunc_kwargs=ufunc_kwargs, func_kwargs=func_kwargs, **kwargs + ) + if isinstance(log_weights, xr.DataArray): + log_weights = log_weights.rename("log_weights").rename(sample="samples") + if isinstance(pareto_shape, xr.DataArray): + pareto_shape = pareto_shape.rename("pareto_shape") + return log_weights, pareto_shape + + +def _psislw(log_weights, cutoff_ind, cutoffmin, k_min=1.0 / 3): + """ + Pareto smoothed importance sampling (PSIS) for a 1D vector. - return log_weights_out, kss + Parameters + ---------- + log_weights : array + Array of length n_observations + cutoff_ind : int + cutoffmin : float + k_min : float + + Returns + ------- + lw_out : array + Smoothed log weights + kss : float + Pareto tail index + """ + x = np.asarray(log_weights) + + # improve numerical accuracy + x -= np.max(x) + # sort the array + x_sort_ind = np.argsort(x) + # divide log weights into body and right tail + xcutoff = max(x[x_sort_ind[cutoff_ind]], cutoffmin) + + expxcutoff = np.exp(xcutoff) + tailinds, = np.where(x > xcutoff) # pylint: disable=unbalanced-tuple-unpacking + x_tail = x[tailinds] + tail_len = len(x_tail) + if tail_len <= 4: + # not enough tail samples for gpdfit + k = np.inf + else: + # order of tail samples + x_tail_si = np.argsort(x_tail) + # fit generalized Pareto distribution to the right tail samples + x_tail = np.exp(x_tail) - expxcutoff + k, sigma = _gpdfit(x_tail[x_tail_si]) + + if k >= k_min: + # no smoothing if short tail or GPD fit failed + # compute ordered statistic for the fit + sti = np.arange(0.5, tail_len) / tail_len + smoothed_tail = _gpinv(sti, k, sigma) + smoothed_tail = np.log( # pylint: disable=assignment-from-no-return + smoothed_tail + expxcutoff + ) + # place the smoothed tail into the output array + x[tailinds[x_tail_si]] = smoothed_tail + # truncate smoothed values to the largest raw weight 0 + x[x > 0] = 0 + # renormalize weights + x -= _logsumexp(x) + + return x, k def _gpdfit(ary): @@ -924,7 +1007,7 @@ def waic(data, pointwise=False, scale="deviance"): Returns ------- - DataFrame with the following columns: + Series with the following rows: waic : widely available information criterion waic_se : standard error of waic p_waic : effective number parameters @@ -934,9 +1017,12 @@ def waic(data, pointwise=False, scale="deviance"): waic_i : and array of the pointwise predictive accuracy, only if pointwise True waic_scale : scale of the waic results + The returned object has a custom print method that overrides pd.Series method. It is + specific to expected log pointwise predictive density (elpd) information criteria. + Examples -------- - Calculate the LOO-CV of a model: + Calculate the WAIC of a model: .. ipython:: @@ -944,6 +1030,9 @@ def waic(data, pointwise=False, scale="deviance"): ...: data = az.load_arviz_data("centered_eight") ...: az.waic(data, pointwise=True) + The custom print method can be seen here, printing only the relevant information and + with a specific organization. 
``IC_loo`` stands for information criteria, which is the + `deviance` scale, the `log` (and `negative_log`) correspond to ``elpd`` (and ``-elpd``) """ inference_data = convert_to_inference_data(data) for group in ("sample_stats",): @@ -964,13 +1053,22 @@ def waic(data, pointwise=False, scale="deviance"): else: raise TypeError('Valid scale values are "deviance", "log", "negative_log"') - n_samples = log_likelihood.chain.size * log_likelihood.draw.size - new_shape = (n_samples, np.product(log_likelihood.shape[2:])) - log_likelihood = log_likelihood.values.reshape(*new_shape) - - lppd_i = _logsumexp(log_likelihood, axis=0, b_inv=log_likelihood.shape[0]) + log_likelihood = log_likelihood.stack(samples=("chain", "draw")) + shape = log_likelihood.shape + n_samples = shape[-1] + n_data_points = np.product(shape[:-1]) + + ufunc_kwargs = {"n_dims": 1, "ravel": False} + kwargs = {"input_core_dims": [["samples"]]} + lppd_i = _wrap_xarray_ufunc( + _logsumexp, + log_likelihood, + func_kwargs={"b_inv": n_samples}, + ufunc_kwargs=ufunc_kwargs, + **kwargs + ) - vars_lpd = np.var(log_likelihood, axis=0) + vars_lpd = log_likelihood.var(dim="samples") warn_mg = False if np.any(vars_lpd > 0.4): warnings.warn( @@ -982,9 +1080,9 @@ def waic(data, pointwise=False, scale="deviance"): warn_mg = True waic_i = scale_value * (lppd_i - vars_lpd) - waic_se = (len(waic_i) * np.var(waic_i)) ** 0.5 - waic_sum = np.sum(waic_i) - p_waic = np.sum(vars_lpd) + waic_se = (n_data_points * np.var(waic_i.values)) ** 0.5 + waic_sum = np.sum(waic_i.values) + p_waic = np.sum(vars_lpd.values) if pointwise: if np.equal(waic_sum, waic_i).all(): # pylint: disable=no-member @@ -993,12 +1091,38 @@ def waic(data, pointwise=False, scale="deviance"): the Observed RV in your model to make sure it returns element-wise logp. """ ) - return pd.Series( - data=[waic_sum, waic_se, p_waic, warn_mg, waic_i, scale], - index=["waic", "waic_se", "p_waic", "warning", "waic_i", "waic_scale"], + return ELPDData( + data=[ + waic_sum, + waic_se, + p_waic, + n_samples, + n_data_points, + warn_mg, + waic_i.rename("waic_i"), + scale, + ], + index=[ + "waic", + "waic_se", + "p_waic", + "n_samples", + "n_data_points", + "warning", + "waic_i", + "waic_scale", + ], ) else: - return pd.Series( - data=[waic_sum, waic_se, p_waic, warn_mg, scale], - index=["waic", "waic_se", "p_waic", "warning", "waic_scale"], + return ELPDData( + data=[waic_sum, waic_se, p_waic, n_samples, n_data_points, warn_mg, scale], + index=[ + "waic", + "waic_se", + "p_waic", + "n_samples", + "n_data_points", + "warning", + "waic_scale", + ], ) diff --git a/arviz/stats/stats_utils.py b/arviz/stats/stats_utils.py index 91602c4b84..dc8a1071d0 100644 --- a/arviz/stats/stats_utils.py +++ b/arviz/stats/stats_utils.py @@ -4,13 +4,14 @@ import warnings import numpy as np +import pandas as pd from scipy.fftpack import next_fast_len from scipy.stats.mstats import mquantiles from xarray import apply_ufunc _log = logging.getLogger(__name__) -__all__ = ["autocorr", "autocov", "make_ufunc", "wrap_xarray_ufunc"] +__all__ = ["autocorr", "autocov", "ELPDData", "make_ufunc", "wrap_xarray_ufunc"] def autocov(ary, axis=-1): @@ -71,7 +72,9 @@ def autocorr(ary, axis=-1): return corr -def make_ufunc(func, n_dims=2, n_output=1, index=Ellipsis, ravel=True): # noqa: D202 +def make_ufunc( + func, n_dims=2, n_output=1, index=Ellipsis, ravel=True, check_shape=True +): # noqa: D202 """Make ufunc from a function taking 1D array input. 
Parameters @@ -87,6 +90,9 @@ def make_ufunc(func, n_dims=2, n_output=1, index=Ellipsis, ravel=True): # noqa: Slice ndarray with `index`. Defaults to `Ellipsis`. ravel : bool, optional If true, ravel the ndarray before calling `func`. + check_shape: bool, optional + If false, do not check if the shape of the output is compatible with n_dims and + n_output. Returns ------- @@ -100,7 +106,7 @@ def _ufunc(ary, *args, out=None, **kwargs): """General ufunc for single-output function.""" if out is None: out = np.empty(ary.shape[:-n_dims]) - else: + elif check_shape: if out.shape != ary.shape[:-n_dims]: msg = "Shape incorrect for `out`: {}.".format(out.shape) msg += " Correct shape is {}".format(ary.shape[:-n_dims]) @@ -115,7 +121,7 @@ def _multi_ufunc(ary, *args, out=None, **kwargs): element_shape = ary.shape[:-n_dims] if out is None: out = tuple(np.empty(element_shape) for _ in range(n_output)) - else: + elif check_shape: raise_error = False correct_shape = tuple(element_shape for _ in range(n_output)) if isinstance(out, tuple): @@ -363,3 +369,59 @@ def not_valid(ary, check_nan=True, check_shape=True, nan_kwargs=None, shape_kwar _log.warning(error_msg) return nan_error | chain_error | draw_error + + +BASE_FMT = """Computed from {{n_samples}} by {{n_points}} log-likelihood matrix + +{{0:{0}}} Estimate SE +{{scale}}_{{kind}} {{1:8.2f}} {{2:7.2f}} +p_{{kind:{1}}} {{3:8.2f}} -""" +POINTWISE_LOO_FMT = """------ + +Pareto k diagnostic values: + {{0:>{0}}} {{1:>6}} +(-Inf, 0.5] (good) {{2:{0}d}} {{6:6.1f}}% + (0.5, 0.7] (ok) {{3:{0}d}} {{7:6.1f}}% + (0.7, 1] (bad) {{4:{0}d}} {{8:6.1f}}% + (1, Inf) (very bad) {{5:{0}d}} {{9:6.1f}}% +""" +SCALE_DICT = {"deviance": "IC", "log": "elpd", "negative_log": "-elpd"} + + +class ELPDData(pd.Series): # pylint: disable=too-many-ancestors + """Class to contain the data from elpd information criterion like waic or loo.""" + + def __str__(self): + """Print elpd data in a user friendly way.""" + kind = self.index[0] + + if kind not in ("waic", "loo"): + raise ValueError("Invalid ELPDData object") + + scale_str = SCALE_DICT[self["{}_scale".format(kind)]] + padding = len(scale_str) + len(kind) + 1 + base = BASE_FMT.format(padding, padding - 2) + base = base.format( + "", + kind=kind, + scale=scale_str, + n_samples=self.n_samples, + n_points=self.n_data_points, + *self.values + ) + + if self.warning: + base += "\n\nThere has been a warning during the calculation. Please check the results." 
+ + if kind == "loo" and "pareto_k" in self: + counts, _ = np.histogram(self.pareto_k, bins=[-np.inf, 0.5, 0.7, 1, np.inf]) + extended = POINTWISE_LOO_FMT.format(max(4, len(str(np.max(counts))))) + extended = extended.format( + "Count", "Pct.", *[*counts, *(counts / np.sum(counts) * 100)] + ) + base = "\n".join([base, extended]) + return base + + def __repr__(self): + """Alias to ``__str__``.""" + return self.__str__() diff --git a/arviz/tests/test_plots.py b/arviz/tests/test_plots.py index b99040874c..7ac9c9bf83 100644 --- a/arviz/tests/test_plots.py +++ b/arviz/tests/test_plots.py @@ -7,7 +7,7 @@ import pytest from ..data import from_dict, load_arviz_data -from ..stats import compare, psislw +from ..stats import compare, psislw, loo, waic from .helpers import eight_schools_params # pylint: disable=unused-import from ..plots import ( plot_density, @@ -28,6 +28,7 @@ plot_hpd, plot_dist, plot_rank, + plot_elpd, ) np.random.seed(0) @@ -74,7 +75,56 @@ def create_model(seed=10): prior_predictive=prior_predictive, sample_stats_prior=sample_stats_prior, observed_data={"y": data["y"]}, - dims={"y": ["obs_dim"]}, + dims={"y": ["obs_dim"], "log_likelihood": ["obs_dim"]}, + coords={"obs_dim": range(data["J"])}, + ) + return model + + +def create_multidimensional_model(seed=10): + """Create model with fake data.""" + np.random.seed(seed) + nchains = 4 + ndraws = 500 + ndim1 = 5 + ndim2 = 7 + data = { + "y": np.random.normal(size=(ndim1, ndim2)), + "sigma": np.random.normal(size=(ndim1, ndim2)), + } + posterior = { + "mu": np.random.randn(nchains, ndraws), + "tau": abs(np.random.randn(nchains, ndraws)), + "eta": np.random.randn(nchains, ndraws, ndim1, ndim2), + "theta": np.random.randn(nchains, ndraws, ndim1, ndim2), + } + posterior_predictive = {"y": np.random.randn(nchains, ndraws, ndim1, ndim2)} + sample_stats = { + "energy": np.random.randn(nchains, ndraws), + "diverging": np.random.randn(nchains, ndraws) > 0.90, + "log_likelihood": np.random.randn(nchains, ndraws, ndim1, ndim2), + } + prior = { + "mu": np.random.randn(nchains, ndraws) / 2, + "tau": abs(np.random.randn(nchains, ndraws)) / 2, + "eta": np.random.randn(nchains, ndraws, ndim1, ndim2) / 2, + "theta": np.random.randn(nchains, ndraws, ndim1, ndim2) / 2, + } + prior_predictive = {"y": np.random.randn(nchains, ndraws, ndim1, ndim2) / 2} + sample_stats_prior = { + "energy": np.random.randn(nchains, ndraws), + "diverging": (np.random.randn(nchains, ndraws) > 0.95).astype(int), + } + model = from_dict( + posterior=posterior, + posterior_predictive=posterior_predictive, + sample_stats=sample_stats, + prior=prior, + prior_predictive=prior_predictive, + sample_stats_prior=sample_stats_prior, + observed_data={"y": data["y"]}, + dims={"y": ["dim1", "dim2"], "log_likelihood": ["dim1", "dim2"]}, + coords={"dim1": range(ndim1), "dim2": range(ndim2)}, ) return model @@ -88,6 +138,15 @@ class Models: return Models() +@pytest.fixture(scope="module") +def multidim_models(): + class Models: + model_1 = create_multidimensional_model(seed=10) + model_2 = create_multidimensional_model(seed=11) + + return Models() + + @pytest.fixture(scope="function", autouse=True) def clean_plots(request, save_figs): """Close plots after each test, optionally save if --save is specified during test invocation""" @@ -763,3 +822,104 @@ def test_fast_kde_cumulative(limits): data = np.random.normal(0, 1, 1000) density_fast = _fast_kde(data, xmin=limits[0], xmax=limits[1], cumulative=True)[0] np.testing.assert_almost_equal(round(density_fast[-1], 3), 1) + + 
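As context for the tests added next, a minimal standalone sketch of the two input paths they parametrize over (raw InferenceData versus precomputed ELPDData); it assumes the create_model helper defined above and an installed arviz exposing the functions added in this diff, and is illustrative only, not part of the patch:

    import arviz as az

    idata1, idata2 = create_model(seed=10), create_model(seed=11)

    # Path 1: pass InferenceData directly; plot_elpd computes waic/loo internally.
    az.plot_elpd({"Model 1": idata1, "Model 2": idata2}, ic="loo", scale="log")

    # Path 2: precompute pointwise ELPDData (all entries must share ic and scale,
    # otherwise plot_elpd raises SyntaxError) and pass those instead.
    elpd_dict = {
        "Model 1": az.loo(idata1, pointwise=True, scale="log"),
        "Model 2": az.loo(idata2, pointwise=True, scale="log"),
    }
    az.plot_elpd(elpd_dict)

Both calls should produce the same scatter of pointwise ELPD differences, mirroring the use_elpddata parametrization below.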
+@pytest.mark.parametrize( + "kwargs", + [ + {}, + {"ic": "loo"}, + {"xlabels": True, "scale": "log"}, + {"color": "obs_dim", "xlabels": True}, + {"color": "obs_dim", "legend": True}, + {"ic": "loo", "color": "blue", "coords": {"obs_dim": slice(2, 5)}}, + {"color": np.random.uniform(size=8), "threshold": 0.1}, + ], +) +@pytest.mark.parametrize("add_model", [False, True]) +@pytest.mark.parametrize("use_elpddata", [False, True]) +def test_plot_elpd(models, add_model, use_elpddata, kwargs): + model_dict = {"Model 1": models.model_1, "Model 2": models.model_2} + if add_model: + model_dict["Model 3"] = create_model(seed=12) + + if use_elpddata: + ic = kwargs.get("ic", "waic") + scale = kwargs.get("scale", "deviance") + if ic == "waic": + model_dict = {k: waic(v, scale=scale, pointwise=True) for k, v in model_dict.items()} + else: + model_dict = {k: loo(v, scale=scale, pointwise=True) for k, v in model_dict.items()} + + axes = plot_elpd(model_dict, **kwargs) + assert np.all(axes) + if add_model: + assert axes.shape[0] == axes.shape[1] + assert axes.shape[0] == len(model_dict) - 1 + + +@pytest.mark.parametrize( + "kwargs", + [ + {}, + {"ic": "loo"}, + {"xlabels": True, "scale": "log"}, + {"color": "dim1", "xlabels": True}, + {"color": "dim2", "legend": True}, + {"ic": "loo", "color": "blue", "coords": {"dim2": slice(2, 4)}}, + {"color": np.random.uniform(size=35), "threshold": 0.1}, + ], +) +@pytest.mark.parametrize("add_model", [False, True]) +@pytest.mark.parametrize("use_elpddata", [False, True]) +def test_plot_elpd_multidim(multidim_models, add_model, use_elpddata, kwargs): + model_dict = {"Model 1": multidim_models.model_1, "Model 2": multidim_models.model_2} + if add_model: + model_dict["Model 3"] = create_multidimensional_model(seed=12) + + if use_elpddata: + ic = kwargs.get("ic", "waic") + scale = kwargs.get("scale", "deviance") + if ic == "waic": + model_dict = {k: waic(v, scale=scale, pointwise=True) for k, v in model_dict.items()} + else: + model_dict = {k: loo(v, scale=scale, pointwise=True) for k, v in model_dict.items()} + + axes = plot_elpd(model_dict, **kwargs) + assert np.all(axes) + if add_model: + assert axes.shape[0] == axes.shape[1] + assert axes.shape[0] == len(model_dict) - 1 + + +def test_plot_elpd_bad_ic(models): + model_dict = { + "Model 1": waic(models.model_1, pointwise=True), + "Model 2": loo(models.model_2, pointwise=True), + } + with pytest.raises(ValueError): + plot_elpd(model_dict, ic="bad_ic") + + +def test_plot_elpd_ic_error(models): + model_dict = { + "Model 1": waic(models.model_1, pointwise=True), + "Model 2": loo(models.model_2, pointwise=True), + } + with pytest.raises(SyntaxError): + plot_elpd(model_dict) + + +def test_plot_elpd_scale_error(models): + model_dict = { + "Model 1": waic(models.model_1, pointwise=True), + "Model 2": waic(models.model_2, pointwise=True, scale="log"), + } + with pytest.raises(SyntaxError): + plot_elpd(model_dict) + + +def test_plot_elpd_one_model(models): + model_dict = {"Model 1": models.model_1} + with pytest.raises(Exception): + plot_elpd(model_dict) diff --git a/arviz/tests/test_stats.py b/arviz/tests/test_stats.py index 165ec28dbe..24ec39dcdc 100644 --- a/arviz/tests/test_stats.py +++ b/arviz/tests/test_stats.py @@ -1,7 +1,7 @@ # pylint: disable=redefined-outer-name from copy import deepcopy import numpy as np -from numpy.testing import assert_almost_equal, assert_array_almost_equal +from numpy.testing import assert_allclose, assert_array_almost_equal import pytest from scipy.stats import linregress from xarray import 
Dataset, DataArray @@ -40,7 +40,7 @@ def test_r2_score(): x = np.linspace(0, 1, 100) y = np.random.normal(x, 1) res = linregress(x, y) - assert_almost_equal(res.rvalue ** 2, r2_score(y, res.intercept + res.slope * x).r2, 2) + assert_allclose(res.rvalue ** 2, r2_score(y, res.intercept + res.slope * x).r2, 2) def test_r2_score_multivariate(): @@ -57,8 +57,8 @@ def test_compare_same(centered_eight, method): data_dict = {"first": centered_eight, "second": centered_eight} weight = compare(data_dict, method=method)["weight"] - assert_almost_equal(weight[0], weight[1]) - assert_almost_equal(np.sum(weight), 1.0) + assert_allclose(weight[0], weight[1]) + assert_allclose(np.sum(weight), 1.0) def test_compare_unknown_ic_and_method(centered_eight, non_centered_eight): @@ -76,7 +76,7 @@ def test_compare_different(centered_eight, non_centered_eight, ic, method, scale model_dict = {"centered": centered_eight, "non_centered": non_centered_eight} weight = compare(model_dict, ic=ic, method=method, scale=scale)["weight"] assert weight["non_centered"] >= weight["centered"] - assert_almost_equal(np.sum(weight), 1.0) + assert_allclose(np.sum(weight), 1.0) def test_compare_different_size(centered_eight, non_centered_eight): @@ -216,6 +216,15 @@ def test_waic_warning(centered_eight): assert waic(centered_eight, pointwise=True) is not None +@pytest.mark.parametrize("scale", ["deviance", "log", "negative_log"]) +def test_waic_print(centered_eight, scale): + waic_data = waic(centered_eight, scale=scale).__repr__() + waic_pointwise = waic(centered_eight, scale=scale, pointwise=True).__repr__() + assert waic_data is not None + assert waic_pointwise is not None + assert waic_data == waic_pointwise + + def test_loo(centered_eight): assert loo(centered_eight) is not None @@ -265,14 +274,22 @@ def test_loo_warning(centered_eight): assert loo(centered_eight, pointwise=True) is not None +@pytest.mark.parametrize("scale", ["deviance", "log", "negative_log"]) +def test_loo_print(centered_eight, scale): + loo_data = loo(centered_eight, scale=scale).__repr__() + loo_pointwise = loo(centered_eight, scale=scale, pointwise=True).__repr__() + assert loo_data is not None + assert loo_pointwise is not None + assert len(loo_data) < len(loo_pointwise) + assert loo_data == loo_pointwise[: len(loo_data)] + + def test_psislw(): data = load_arviz_data("centered_eight") pareto_k = loo(data, pointwise=True, reff=0.7)["pareto_k"] log_likelihood = data.sample_stats.log_likelihood # pylint: disable=no-member - n_samples = log_likelihood.chain.size * log_likelihood.draw.size - new_shape = (n_samples,) + log_likelihood.shape[2:] - log_likelihood = log_likelihood.values.reshape(*new_shape) - assert_almost_equal(pareto_k, psislw(-log_likelihood, 0.7)[1]) + log_likelihood = log_likelihood.stack(samples=("chain", "draw")) + assert_allclose(pareto_k, psislw(-log_likelihood, 0.7)[1]) @pytest.mark.parametrize("probs", [True, False]) diff --git a/arviz/tests/test_stats_utils.py b/arviz/tests/test_stats_utils.py index 53c9e101b4..1a23d476a9 100644 --- a/arviz/tests/test_stats_utils.py +++ b/arviz/tests/test_stats_utils.py @@ -4,7 +4,13 @@ import pytest from scipy.special import logsumexp -from ..stats.stats_utils import logsumexp as _logsumexp, make_ufunc, wrap_xarray_ufunc, not_valid +from ..stats.stats_utils import ( + logsumexp as _logsumexp, + make_ufunc, + wrap_xarray_ufunc, + not_valid, + ELPDData, +) @pytest.mark.parametrize("ary_dtype", [np.float64, np.float32, np.int32, np.int64]) @@ -198,3 +204,8 @@ def test_valid_shape(): assert not_valid( 
np.ones((10, 10)), check_nan=False, shape_kwargs=dict(min_chains=100, min_draws=2) ) + + +def test_elpd_data_error(): + with pytest.raises(ValueError): + ELPDData(data=[0, 1, 2], index=["not IC", "se", "p"]).__repr__() diff --git a/doc/api.rst b/doc/api.rst index fd90fe266c..5d53110abe 100644 --- a/doc/api.rst +++ b/doc/api.rst @@ -24,6 +24,7 @@ Plots plot_khat plot_pair plot_parallel + plot_elpd plot_posterior plot_ppc plot_rank
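Taken together, a minimal end-to-end sketch of the workflow this diff enables, assuming only the example datasets shipped with ArviZ and the public functions touched above:

    import arviz as az

    idata_c = az.load_arviz_data("centered_eight")
    idata_nc = az.load_arviz_data("non_centered_eight")

    # loo/waic now return ELPDData (a pd.Series subclass) carrying n_samples and
    # n_data_points rows plus the custom __str__/__repr__ defined in stats_utils.
    elpd_c = az.loo(idata_c, pointwise=True, scale="log")
    print(elpd_c)
    print(elpd_c.n_samples, elpd_c.n_data_points)

    # psislw now expects log weights shaped (n_observations, n_samples); stacking
    # chain and draw gives that layout, as in the updated test_psislw.
    log_lik = idata_c.sample_stats.log_likelihood.stack(samples=("chain", "draw"))
    log_weights, pareto_k = az.psislw(-log_lik, reff=0.7)

    # plot_elpd accepts InferenceData or precomputed ELPDData and can label
    # observations whose ELPD difference deviates from the mean difference by
    # more than threshold times its standard deviation.
    az.plot_elpd(
        {"centered": idata_c, "non centered": idata_nc},
        xlabels=True,
        threshold=2,
    )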