From 006010deb3ba7e64195d0421cb6be0d779636c30 Mon Sep 17 00:00:00 2001 From: Ratish Panda Date: Mon, 2 Sep 2024 18:03:03 +0530 Subject: [PATCH] get_bins working update --- pyproject.toml | 2 +- src/arviz_plots/plots/rootogramplot.py | 40 ++++---------------------- 2 files changed, 6 insertions(+), 36 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 9116756..0539086 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -27,7 +27,7 @@ classifiers = [ dynamic = ["version", "description"] dependencies = [ "arviz-base==0.2", - "arviz-stats[xarray]==0.2", + "arviz-stats[xarray] @ git+https://github.com/arviz-devs/arviz-stats@get_bins", ] [tool.flit.module] diff --git a/src/arviz_plots/plots/rootogramplot.py b/src/arviz_plots/plots/rootogramplot.py index 8d9487f..e6ef20a 100644 --- a/src/arviz_plots/plots/rootogramplot.py +++ b/src/arviz_plots/plots/rootogramplot.py @@ -8,7 +8,6 @@ import xarray as xr from arviz_base import rcParams from arviz_base.labels import BaseLabeller -from arviz_stats.base import array_stats from arviz_plots.plot_collection import PlotCollection from arviz_plots.plots.utils import filter_aes, process_group_variables_coords @@ -176,7 +175,6 @@ def plot_rootogram( filter_vars=filter_vars, coords=coords, ) - # print(f"\npp_distribution = {pp_distribution}") total_pp_samples = np.prod( [pp_distribution.sizes[dim] for dim in sample_dims if dim in pp_distribution.dims] @@ -218,8 +216,6 @@ def plot_rootogram( # if "remove_axis" in plot_kwargs: plot_kwargs_dist["remove_axis"] = False # plot_kwargs["remove_axis"] - # print(f"\n aes_map = {aes_map}") - # obs distribution calculated outside `if observed` since plotting predictive bars requires it observed_data_group = "observed_data" obs_distribution = process_group_variables_coords( @@ -229,23 +225,20 @@ def plot_rootogram( filter_vars=filter_vars, coords=coords, ) - # print(f"\n obs_distribution = {obs_distribution}") # ---------(observed data)----------- # observed data calculations are made outside of and before 'if observed' since predictive also # depends on this computed data (number of bins and top of predictive bars for rootograms) - # use get_bins func from arviz-stats on observed data and then use those bins for - # computing histograms for predictive data as well - # WIP: currently only the bins for one variable (without any facetting) is retrieved and used - bins = array_stats.get_bins(obs_distribution["home_points"].values) - print(f"\n bins = {bins}") - # this portion is situated in an out of convention spot becuse obs_hist_dims is required obs_hist_dims, obs_hist_aes, obs_hist_ignore = filter_aes( plot_collection, aes_map, "observed", reduce_dims ) + # use get_bins func from arviz-stats on observed data and then use those bins for + # computing histograms for predictive data as well + bins = obs_distribution.azstats.get_bins(dims=list(obs_hist_dims) + list(sample_dims)) + obs_stats_kwargs = copy(stats_kwargs.get("observed", {})) obs_stats_kwargs.setdefault("bins", bins) @@ -253,13 +246,8 @@ def plot_rootogram( dims=list(obs_hist_dims) + list(sample_dims), **obs_stats_kwargs ) - # print(f"\n obs_hist = {obs_hist}") - obs_hist.loc[{"plot_axis": "histogram"}] = (obs_hist.sel(plot_axis="histogram")) ** 0.5 - # print(f"\n obs_density.data_vars = {obs_density.data_vars}") - # print(f"\n obs_density.keys() = {obs_density.keys()}") - # new_obs_hist with histogram->y and left_edge/right_edge midpoint->x new_obs_hist = xr.Dataset() @@ -270,15 +258,9 @@ def plot_rootogram( left_edges = np.array(left_edges) right_edges = np.array(right_edges) - # print(f"\n left_edges = {left_edges}") - # print(f"\n right_edges = {right_edges}") - x = (left_edges + right_edges) / 2 y = obs_hist[var_name].sel(plot_axis="histogram").values - # print(f"\n new_obs_hist y= {y}") - # print(f"x = {x} | y = {y}") - stacked_data = np.stack((x, y), axis=-1) new_var = xr.DataArray( stacked_data, dims=["hist_dim", "plot_axis"], coords={"plot_axis": ["x", "y"]} @@ -286,8 +268,6 @@ def plot_rootogram( new_obs_hist[var_name] = new_var - print(f"\n new_obs_hist = {new_obs_hist}") - # ---------(PPC data)------------- min_bottom = xr.Dataset() # minimum value of the histogram 'bottoms, for observed_rug 'y' @@ -313,8 +293,6 @@ def plot_rootogram( dims=list(pp_hist_dims) + list(sample_dims), **pp_stats_kwargs ) - # print(f"\n pp_density histogram form pp_hist = {pp_hist}") - # the top of the predictive bars height = the observed height for that bin # the bottom = difference between observed and predictive height for that bin @@ -332,10 +310,8 @@ def plot_rootogram( # getting top of histogram (observed values dataset's 'y' coord) new_histogram = new_obs_hist[var_name].sel(plot_axis="y").values - # print(f"\n new_pp_hist (new_histogram) observed heights= {new_histogram}") histogram_bottom = new_histogram - pp_hist[var_name].sel(plot_axis="histogram").values - # print(f"\n new_pp_hist histogram_bottom= {histogram_bottom}") stacked_data = np.stack( (new_histogram, left_edges, right_edges, histogram_bottom), axis=-1 @@ -352,8 +328,6 @@ def plot_rootogram( min_histogram_bottom = min(histogram_bottom) min_bottom[var_name] = min_histogram_bottom - (0.2 * (0 - min_histogram_bottom)) - print(f"\n new_pp_hist = {new_pp_hist}") - plot_collection.map( hist, "predictive", data=new_pp_hist, ignore_aes=pp_hist_ignore, **pp_kwargs ) @@ -411,10 +385,7 @@ def plot_rootogram( if "size" not in rug_aes: rug_kwargs.setdefault("size", 30) - # rug_kwargs.setdefault("y", min_bottom) - print(f"\n min_bottom = {min_bottom}") - - # print(f"\nobs_distribution = {obs_distribution}") + rug_kwargs.setdefault("y", min_bottom) plot_collection.map( trace_rug, @@ -422,7 +393,6 @@ def plot_rootogram( data=obs_distribution, ignore_aes=rug_ignore, xname=False, - y=min_bottom, **rug_kwargs, )