Skip to content

Commit

Permalink
optimize interface (#755)
Browse files Browse the repository at this point in the history
  • Loading branch information
wangzhishi authored May 20, 2022
1 parent f7303bc commit c251b70
Show file tree
Hide file tree
Showing 2 changed files with 70 additions and 68 deletions.
109 changes: 54 additions & 55 deletions docs/tutorials/model_diagnostics.ipynb

Large diffs are not rendered by default.

29 changes: 16 additions & 13 deletions orbit/diagnostics/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -642,9 +642,9 @@ def metric_horizon_barplot(

@orbit_style_decorator
def params_comparison_boxplot(
model_name_list,
data_list,
label_list,
data,
var_names,
model_names,
color_list=sns.color_palette(),
title="Params Comparison",
fig_size=(10, 6),
Expand All @@ -654,10 +654,9 @@ def params_comparison_boxplot(
):
"""compare the distribution of parameters from different models uisng a boxplot.
Parameters:
model_name_list : a list of strings, the names of models
data_list : a list of np.arrays, the distributions of parameters to compare
label_list : a list of strings, the labels of the parameters to compare
(the order of labels must match the order of the data in the data_list)
data : a list of dict with keys as the parameters of interest
var_names : a list of strings, the labels of the parameters to compare
model_names : a list of strings, the names of models to compare
color_list : a list of strings, the color to use for differentiating models
title : string
the title of the chart
Expand All @@ -676,7 +675,7 @@ def params_comparison_boxplot(

fig, ax = plt.subplots(1, 1, figsize=fig_size)
handles = []
n_models = len(model_name_list)
n_models = len(model_names)
pos = []

if n_models % 2 == 0:
Expand All @@ -692,10 +691,14 @@ def params_comparison_boxplot(

pos = sorted(pos)

for i in range(len(data_list)):
for i in range(len(model_names)):
plt_arr = []
for var in var_names:
plt_arr.append(data[i][var].flatten())
plt_arr = np.vstack(plt_arr).T
globals()[f"bp{i}"] = ax.boxplot(
data_list[i],
positions=np.arange(data_list[i].shape[1]) + pos[i],
plt_arr,
positions=np.arange(plt_arr.shape[1]) + pos[i],
widths=box_width,
patch_artist=True,
manage_ticks=False,
Expand All @@ -705,8 +708,8 @@ def params_comparison_boxplot(
)
handles.append(globals()[f"bp{i}"]["boxes"][0])

plt.xticks(np.arange(len(label_list)), label_list)
ax.legend(handles, model_name_list)
plt.xticks(np.arange(len(var_names)), var_names)
ax.legend(handles, model_names)
plt.xlabel("params")
plt.ylabel("value")
plt.title(title)
Expand Down

0 comments on commit c251b70

Please sign in to comment.