Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Forest dashboard #16

Merged
merged 6 commits into from
Jun 22, 2023
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions arviz_dashboard/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,5 +3,6 @@

from .elpd import dashboard_elpd
from .ppc import dashboard_ppc
from .forest import *
yilinxia marked this conversation as resolved.
Show resolved Hide resolved
from arviz_dashboard.marginal.one_d import posterior_marginal1d
from arviz_dashboard.trace.trace import trace
217 changes: 217 additions & 0 deletions arviz_dashboard/forest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,217 @@
import arviz as az
import bokeh.io
import panel as pn
import param
from IPython.display import display

bokeh.io.reset_output()
bokeh.io.output_notebook()

pn.extension()


class ModelVar(param.Parameterized):
model = param.Selector("")
data_variable = param.Selector("")
coor_variable = param.Selector("")

def __init__(self, idatas_cmp, **params) -> None:
yilinxia marked this conversation as resolved.
Show resolved Hide resolved
self.idatas_cmp = idatas_cmp
yilinxia marked this conversation as resolved.
Show resolved Hide resolved
self.default_model = list(self.idatas_cmp.keys())[0]
self.param["model"].objects = list(self.idatas_cmp.keys())
yilinxia marked this conversation as resolved.
Show resolved Hide resolved
self.param["model"].default = self.default_model
self.param["data_variable"].objects = list(
self.idatas_cmp[self.default_model].posterior.data_vars.variables
)
super().__init__(**params)

@param.depends("model", watch=True)
def _update_data_variables(self):
yilinxia marked this conversation as resolved.
Show resolved Hide resolved
data_variables = list(
self.idatas_cmp[self.model].posterior.data_vars.variables
)
self.param["data_variable"].objects = data_variables
if self.data_variable not in data_variables:
self.data_variable = data_variables[0]

@param.depends("data_variable", watch=True)
def _update_coordinates(self):
yilinxia marked this conversation as resolved.
Show resolved Hide resolved
if (
self.idatas_cmp[self.model]
.posterior.data_vars.variables[self.data_variable][0][0]
yilinxia marked this conversation as resolved.
Show resolved Hide resolved
.size
> 1
):
coor_variables = list(
self.idatas_cmp[self.model].posterior.indexes["school"]
yilinxia marked this conversation as resolved.
Show resolved Hide resolved
)
else:
coor_variables = [""]
self.param["coor_variable"].objects = coor_variables
if self.coor_variable not in coor_variables:
self.coor_variable = coor_variables[0]
yilinxia marked this conversation as resolved.
Show resolved Hide resolved


class ForestDashboard(ModelVar):
def __init__(self, idatas_cmp) -> None:
self.idatas_cmp = idatas_cmp
super().__init__(self.idatas_cmp)

def dashboard_forest(self):
# define the widgets
multi_select = pn.widgets.MultiSelect(
name="ModelSelect",
options=list(self.idatas_cmp.keys()),
value=["mA"],
)
thre_slider = pn.widgets.FloatSlider(
yilinxia marked this conversation as resolved.
Show resolved Hide resolved
name="HDI Probability",
start=0,
end=1,
step=0.05,
value=0.7,
width=200,
)
truncate_checkbox = pn.widgets.Checkbox(name="Ridgeplot Truncate")
ridge_quant = pn.widgets.RangeSlider(
yilinxia marked this conversation as resolved.
Show resolved Hide resolved
name="Ridgeplot Quantiles",
start=0,
end=1,
value=(0.25, 0.75),
step=0.01,
width=200,
)
op_slider = pn.widgets.FloatSlider(
yilinxia marked this conversation as resolved.
Show resolved Hide resolved
name="Ridgeplot Overlap",
start=0,
end=1,
step=0.05,
value=0.7,
width=200,
)

rope_slider = pn.widgets.RangeSlider(
name="Rope Range",
start=-10,
end=10,
value=(2, 5),
step=1,
width=200,
)

# construct widget
@pn.depends(
multi_select.param.value,
thre_slider.param.value,
rope_slider.param.value,
self.param.data_variable,
self.param.coor_variable,
)
def get_forest_plot(
multi_select, thre_slider, rope_slider,
data_variable, coor_variable
yilinxia marked this conversation as resolved.
Show resolved Hide resolved
):
# generate graph
data = []
for model_ in multi_select:
yilinxia marked this conversation as resolved.
Show resolved Hide resolved
data.append(self.idatas_cmp[model_])
# add rope
rope = {}
school = {}
school["school"] = coor_variable
yilinxia marked this conversation as resolved.
Show resolved Hide resolved
school["rope"] = rope_slider
rope[data_variable] = [school]
# print(rope)
forest_plt = az.plot_forest(
data,
model_names=multi_select,
rope=rope,
kind="forestplot",
hdi_prob=thre_slider,
yilinxia marked this conversation as resolved.
Show resolved Hide resolved
backend="bokeh",
figsize=(9, 9),
show=False,
combined=True,
colors="cycle",
)
return forest_plt[0][0]
yilinxia marked this conversation as resolved.
Show resolved Hide resolved

@pn.depends(
multi_select.param.value,
thre_slider.param.value,
truncate_checkbox.param.value,
ridge_quant.param.value,
op_slider.param.value,
)
def get_ridge_plot(
multi_select,
thre_slider,
truncate_checkbox,
ridge_quant,
op_slider,
yilinxia marked this conversation as resolved.
Show resolved Hide resolved
):
# calculate the ridgeplot_quantiles
temp_quant = list(ridge_quant)
quant_ls = temp_quant
quant_ls.sort()
yilinxia marked this conversation as resolved.
Show resolved Hide resolved
avg_quant = sum(temp_quant) / 2
if quant_ls[0] < 0.5 and quant_ls[1] > 0.5:
quant_ls.append(0.5)
quant_ls.sort()
else:
quant_ls.append(avg_quant)
quant_ls.sort()

# generate graph
data = []
for model_ in multi_select:
data.append(self.idatas_cmp[model_])

ridge_plt = az.plot_forest(
data,
model_names=multi_select,
kind="ridgeplot",
hdi_prob=thre_slider,
ridgeplot_truncate=truncate_checkbox,
ridgeplot_quantiles=quant_ls,
ridgeplot_overlap=op_slider,
backend="bokeh",
figsize=(9, 9),
show=False,
combined=True,
colors="white",
yilinxia marked this conversation as resolved.
Show resolved Hide resolved
)
return ridge_plt[0][0]

plot_result_1 = pn.Column(
pn.WidgetBox(
"add rope",
pn.Row(
self.param.model,
self.param.data_variable,
self.param.coor_variable,
),
rope_slider,
),
get_forest_plot,
)
plot_result_2 = pn.Column(
pn.Row(truncate_checkbox),
pn.Row(ridge_quant, op_slider),
get_ridge_plot,
)
# show up
display(
yilinxia marked this conversation as resolved.
Show resolved Hide resolved
pn.Column(
pn.Row(multi_select),
thre_slider,
# pn.Row(self.param),
pn.Tabs(
("Forest_Plot", plot_result_1),
(
"Rdiget_Plot",
plot_result_2,
),
),
).servable(),
)
Loading