Skip to content

Commit

Permalink
move out new_ds to arviz-stats
Browse files Browse the repository at this point in the history
  • Loading branch information
aloctavodia committed Nov 12, 2024
1 parent 3498759 commit 9892b0b
Showing 1 changed file with 5 additions and 39 deletions.
44 changes: 5 additions & 39 deletions src/arviz_plots/plots/psensedistplot.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
"""PsenseDist plot code."""
from importlib import import_module

from arviz_base import extract, rcParams
from arviz_base import rcParams
from arviz_base.labels import BaseLabeller
from arviz_stats.psense import _get_power_scale_weights
from arviz_stats.psense import power_scale_dataset
from xarray import concat

from arviz_plots.plot_collection import PlotCollection, process_facet_dims
Expand Down Expand Up @@ -136,16 +136,16 @@ def plot_psense_dist(
# Here we are generating new datasets for the prior and likelihood
# by resampling the original dataset with the power scale weights
# Instead we could have weighted KDEs/ecdfs/etc
ds_prior = new_ds(dt, "prior", alphas, sample_dims=sample_dims)
ds_likelihood = new_ds(dt, "likelihood", alphas, sample_dims=sample_dims)
ds_prior = power_scale_dataset(dt, "prior", alphas, sample_dims=sample_dims)
ds_likelihood = power_scale_dataset(dt, "likelihood", alphas, sample_dims=sample_dims)
distribution = concat([ds_prior, ds_likelihood], dim="component_group").assign_coords(
{"component_group": ["prior", "likelihood"]}
)
distribution = process_group_variables_coords(
distribution, group=None, var_names=var_names, filter_vars=filter_vars, coords=coords
)
if len(sample_dims) > 1:
# sample dims will have been stacked and renamed by `new_ds`
# sample dims will have been stacked and renamed by `power_scale_dataset`
sample_dims = ["sample"]

if backend is None:
Expand Down Expand Up @@ -250,37 +250,3 @@ def plot_psense_dist(
)

return plot_collection


def new_ds(dt, group, alphas, sample_dims):
"""Resample the dataset with the power scale weights."""
lower_w, upper_w = _get_power_scale_weights(dt, alphas, group=group, sample_dims=sample_dims)
lower_w = lower_w.values.flatten()
upper_w = upper_w.values.flatten()
s_size = len(lower_w)

idxs_to_drop = sample_dims if len(sample_dims) == 1 else ["sample"] + sample_dims
idxs_to_drop = set(idxs_to_drop).union(
[
idx
for idx in dt["posterior"].xindexes
if any(dim in dt["posterior"][idx].dims for dim in sample_dims)
]
)
resampled = [
extract(
dt,
group="posterior",
sample_dims=sample_dims,
num_samples=s_size,
weights=weights,
random_seed=42,
resampling_method="stratified",
).drop_indexes(idxs_to_drop)
for weights in (lower_w, upper_w)
]
resampled.insert(
1, extract(dt, group="posterior", sample_dims=sample_dims).drop_indexes(idxs_to_drop)
)

return concat(resampled, dim="alpha").assign_coords(alpha=[alphas[0], 1, alphas[1]])

0 comments on commit 9892b0b

Please sign in to comment.