Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fixes to matplotlib subplot handling #1205

Merged
merged 6 commits into from
May 24, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 26 additions & 21 deletions arviz/plots/backends/matplotlib/pairplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,49 +199,52 @@ def plot_pair(
ax.tick_params(labelsize=xt_labelsize)

else:
not_marginals = int(not marginals)
num_subplot_cols = numvars - not_marginals
max_plots = (
numvars ** 2 if rcParams["plot.max_subplots"] is None else rcParams["plot.max_subplots"]
num_subplot_cols ** 2
if rcParams["plot.max_subplots"] is None
else rcParams["plot.max_subplots"]
)
vars_to_plot = np.sum(np.arange(numvars).cumsum() < max_plots)
if vars_to_plot < numvars:
cols_to_plot = np.sum(np.arange(1, num_subplot_cols + 1).cumsum() <= max_plots)
if cols_to_plot < num_subplot_cols:
vars_to_plot = cols_to_plot
warnings.warn(
"rcParams['plot.max_subplots'] ({max_plots}) is smaller than the number "
"of resulting pair plots with these variables, generating only a "
"{side}x{side} grid".format(max_plots=max_plots, side=vars_to_plot),
UserWarning,
)
numvars = vars_to_plot
else:
vars_to_plot = numvars - not_marginals

(figsize, ax_labelsize, _, xt_labelsize, _, markersize) = _scale_fig_size(
figsize, textsize, numvars - 2, numvars - 2
figsize, textsize, vars_to_plot, vars_to_plot
)

point_estimate_marker_kwargs.setdefault("s", markersize + 50)

if ax is None:
fig, ax = plt.subplots(numvars, numvars, figsize=figsize, **backend_kwargs)
fig, ax = plt.subplots(vars_to_plot, vars_to_plot, figsize=figsize, **backend_kwargs,)
hexbin_values = []
for i in range(0, numvars):
for i in range(0, vars_to_plot):
var1 = infdata_group[i]

for j in range(0, numvars):
var2 = infdata_group[j]
for j in range(0, vars_to_plot):
var2 = infdata_group[j + not_marginals]
if i > j:
if ax[j, i].get_figure() is not None:
ax[j, i].remove()
continue

elif i == j:
if marginals:
loc = "right"
plot_dist(var1, ax=ax[i, j], **marginal_kwargs)
else:
loc = "left"
if ax[j, i].get_figure() is not None:
ax[j, i].remove()
continue
elif i == j and marginals:
loc = "right"
plot_dist(var1, ax=ax[i, j], **marginal_kwargs)

else:
if i == j:
loc = "left"

if "scatter" in kind:
ax[j, i].plot(var1, var2, **scatter_kwargs)

Expand Down Expand Up @@ -285,15 +288,15 @@ def plot_pair(

if reference_values:
x_name = flat_var_names[i]
y_name = flat_var_names[j]
y_name = flat_var_names[j + not_marginals]
if x_name and y_name not in difference:
ax[j, i].plot(
reference_values_copy[x_name],
reference_values_copy[y_name],
**reference_values_kwargs,
)

if j != numvars - 1:
if j != vars_to_plot - 1:
ax[j, i].axes.get_xaxis().set_major_formatter(NullFormatter())
else:
ax[j, i].set_xlabel(
Expand All @@ -303,7 +306,9 @@ def plot_pair(
ax[j, i].axes.get_yaxis().set_major_formatter(NullFormatter())
else:
ax[j, i].set_ylabel(
"{}".format(flat_var_names[j]), fontsize=ax_labelsize, wrap=True
"{}".format(flat_var_names[j + not_marginals]),
fontsize=ax_labelsize,
wrap=True,
)
ax[j, i].tick_params(labelsize=xt_labelsize)

Expand Down
2 changes: 1 addition & 1 deletion arviz/plots/traceplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,7 @@ def plot_trace(

plotters = list(xarray_var_iter(data, var_names=var_names, combined=True, skip_dims=skip_dims))
max_plots = rcParams["plot.max_subplots"]
max_plots = len(plotters) if max_plots is None else max_plots
max_plots = len(plotters) if max_plots is None else max(max_plots // 2, 1)
if len(plotters) > max_plots:
warnings.warn(
"rcParams['plot.max_subplots'] ({max_plots}) is smaller than the number "
Expand Down
2 changes: 1 addition & 1 deletion arviz/tests/base_tests/test_plots_bokeh.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ def test_plot_trace_discrete(discrete_model):

def test_plot_trace_max_subplots_warning(models):
with pytest.warns(UserWarning):
with rc_context(rc={"plot.max_subplots": 1}):
with rc_context(rc={"plot.max_subplots": 2}):
axes = plot_trace(models.model_1, backend="bokeh", show=False)
assert axes.shape

Expand Down
19 changes: 17 additions & 2 deletions arviz/tests/base_tests/test_plots_matplotlib.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,9 +188,9 @@ def test_plot_trace_discrete(discrete_model):

def test_plot_trace_max_subplots_warning(models):
with pytest.warns(UserWarning):
with rc_context(rc={"plot.max_subplots": 1}):
with rc_context(rc={"plot.max_subplots": 6}):
axes = plot_trace(models.model_1)
assert axes.shape
assert axes.shape == (3, 2)


@pytest.mark.parametrize("kwargs", [{"var_names": ["mu", "tau"], "lines": [("hey", {}, [1])]}])
Expand Down Expand Up @@ -478,6 +478,21 @@ def test_plot_pair_overlaid(models, kwargs):
assert ax.shape


@pytest.mark.parametrize("marginals", [True, False])
@pytest.mark.parametrize("max_subplots", [True, False])
def test_plot_pair_shapes(marginals, max_subplots):
rng = np.random.default_rng()
idata = from_dict({"a": rng.standard_normal((4, 500, 5))})
if max_subplots:
with rc_context({"plot.max_subplots": 6}):
with pytest.warns(UserWarning, match="3x3 grid"):
ax = plot_pair(idata, marginals=marginals)
else:
ax = plot_pair(idata, marginals=marginals)
side = 3 if max_subplots else (4 + marginals)
assert ax.shape == (side, side)


@pytest.mark.parametrize("kind", ["kde", "cumulative", "scatter"])
@pytest.mark.parametrize("alpha", [None, 0.2, 1])
@pytest.mark.parametrize("animated", [False, True])
Expand Down