Skip to content

Commit

Permalink
get_bins working update
Browse files Browse the repository at this point in the history
  • Loading branch information
imperorrp committed Sep 2, 2024
1 parent 00c4996 commit 006010d
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 36 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ classifiers = [
dynamic = ["version", "description"]
dependencies = [
"arviz-base==0.2",
"arviz-stats[xarray]==0.2",
"arviz-stats[xarray] @ git+https://github.com/arviz-devs/arviz-stats@get_bins",
]

[tool.flit.module]
Expand Down
40 changes: 5 additions & 35 deletions src/arviz_plots/plots/rootogramplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
import xarray as xr
from arviz_base import rcParams
from arviz_base.labels import BaseLabeller
from arviz_stats.base import array_stats

from arviz_plots.plot_collection import PlotCollection
from arviz_plots.plots.utils import filter_aes, process_group_variables_coords
Expand Down Expand Up @@ -176,7 +175,6 @@ def plot_rootogram(
filter_vars=filter_vars,
coords=coords,
)
# print(f"\npp_distribution = {pp_distribution}")

total_pp_samples = np.prod(
[pp_distribution.sizes[dim] for dim in sample_dims if dim in pp_distribution.dims]
Expand Down Expand Up @@ -218,8 +216,6 @@ def plot_rootogram(
# if "remove_axis" in plot_kwargs:
plot_kwargs_dist["remove_axis"] = False # plot_kwargs["remove_axis"]

# print(f"\n aes_map = {aes_map}")

# obs distribution calculated outside `if observed` since plotting predictive bars requires it
observed_data_group = "observed_data"
obs_distribution = process_group_variables_coords(
Expand All @@ -229,37 +225,29 @@ def plot_rootogram(
filter_vars=filter_vars,
coords=coords,
)
# print(f"\n obs_distribution = {obs_distribution}")

# ---------(observed data)-----------
# observed data calculations are made outside of and before 'if observed' since predictive also
# depends on this computed data (number of bins and top of predictive bars for rootograms)

# use get_bins func from arviz-stats on observed data and then use those bins for
# computing histograms for predictive data as well
# WIP: currently only the bins for one variable (without any facetting) is retrieved and used
bins = array_stats.get_bins(obs_distribution["home_points"].values)
print(f"\n bins = {bins}")

# this portion is situated in an out of convention spot becuse obs_hist_dims is required
obs_hist_dims, obs_hist_aes, obs_hist_ignore = filter_aes(
plot_collection, aes_map, "observed", reduce_dims
)

# use get_bins func from arviz-stats on observed data and then use those bins for
# computing histograms for predictive data as well
bins = obs_distribution.azstats.get_bins(dims=list(obs_hist_dims) + list(sample_dims))

obs_stats_kwargs = copy(stats_kwargs.get("observed", {}))
obs_stats_kwargs.setdefault("bins", bins)

obs_hist = obs_distribution.azstats.histogram(
dims=list(obs_hist_dims) + list(sample_dims), **obs_stats_kwargs
)

# print(f"\n obs_hist = {obs_hist}")

obs_hist.loc[{"plot_axis": "histogram"}] = (obs_hist.sel(plot_axis="histogram")) ** 0.5

# print(f"\n obs_density.data_vars = {obs_density.data_vars}")
# print(f"\n obs_density.keys() = {obs_density.keys()}")

# new_obs_hist with histogram->y and left_edge/right_edge midpoint->x
new_obs_hist = xr.Dataset()

Expand All @@ -270,24 +258,16 @@ def plot_rootogram(
left_edges = np.array(left_edges)
right_edges = np.array(right_edges)

# print(f"\n left_edges = {left_edges}")
# print(f"\n right_edges = {right_edges}")

x = (left_edges + right_edges) / 2
y = obs_hist[var_name].sel(plot_axis="histogram").values

# print(f"\n new_obs_hist y= {y}")
# print(f"x = {x} | y = {y}")

stacked_data = np.stack((x, y), axis=-1)
new_var = xr.DataArray(
stacked_data, dims=["hist_dim", "plot_axis"], coords={"plot_axis": ["x", "y"]}
)

new_obs_hist[var_name] = new_var

print(f"\n new_obs_hist = {new_obs_hist}")

# ---------(PPC data)-------------

min_bottom = xr.Dataset() # minimum value of the histogram 'bottoms, for observed_rug 'y'
Expand All @@ -313,8 +293,6 @@ def plot_rootogram(
dims=list(pp_hist_dims) + list(sample_dims), **pp_stats_kwargs
)

# print(f"\n pp_density histogram form pp_hist = {pp_hist}")

# the top of the predictive bars height = the observed height for that bin
# the bottom = difference between observed and predictive height for that bin

Expand All @@ -332,10 +310,8 @@ def plot_rootogram(

# getting top of histogram (observed values dataset's 'y' coord)
new_histogram = new_obs_hist[var_name].sel(plot_axis="y").values
# print(f"\n new_pp_hist (new_histogram) observed heights= {new_histogram}")

histogram_bottom = new_histogram - pp_hist[var_name].sel(plot_axis="histogram").values
# print(f"\n new_pp_hist histogram_bottom= {histogram_bottom}")

stacked_data = np.stack(
(new_histogram, left_edges, right_edges, histogram_bottom), axis=-1
Expand All @@ -352,8 +328,6 @@ def plot_rootogram(
min_histogram_bottom = min(histogram_bottom)
min_bottom[var_name] = min_histogram_bottom - (0.2 * (0 - min_histogram_bottom))

print(f"\n new_pp_hist = {new_pp_hist}")

plot_collection.map(
hist, "predictive", data=new_pp_hist, ignore_aes=pp_hist_ignore, **pp_kwargs
)
Expand Down Expand Up @@ -411,18 +385,14 @@ def plot_rootogram(
if "size" not in rug_aes:
rug_kwargs.setdefault("size", 30)

# rug_kwargs.setdefault("y", min_bottom)
print(f"\n min_bottom = {min_bottom}")

# print(f"\nobs_distribution = {obs_distribution}")
rug_kwargs.setdefault("y", min_bottom)

plot_collection.map(
trace_rug,
"observed_rug",
data=obs_distribution,
ignore_aes=rug_ignore,
xname=False,
y=min_bottom,
**rug_kwargs,
)

Expand Down

0 comments on commit 006010d

Please sign in to comment.