Passing axes to plot_density fails with several datasets #1197

Closed
astoeriko opened this issue May 20, 2020 · 2 comments · Fixed by #1198

@astoeriko

Describe the bug
I would like to pass existing axes to plot_density with the ax keyword. This works fine when I plot data from a single dataset, but I get an error when I try to plot several datasets at once (I think the error arises when it tries to produce the legend).
Plotting several datasets without providing axes also works fine.

To Reproduce

import arviz
import matplotlib.pyplot as plt

test_data = arviz.load_arviz_data('centered_eight')

# This works
fig, ax = plt.subplots(3, 3)
ax1 = arviz.plot_density(data=[test_data.posterior], var_names=['mu', 'theta'], ax=ax);
# This works as well
ax2 = arviz.plot_density(data=[test_data.prior, test_data.posterior], var_names=['mu', 'theta']);
# This does not work
fig3, ax3 = plt.subplots(3, 3)
arviz.plot_density(data=[test_data.prior, test_data.posterior], var_names=['mu', 'theta'], ax=ax3);

This is the error I get:

---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
<ipython-input-133-7feead268671> in <module>
     11 # This does not work
     12 fig3, ax3 = plt.subplots(3, 3)
---> 13 arviz.plot_density(data=[test_data.prior, test_data.posterior], var_names=['mu', 'theta'], ax=ax3);

~/miniconda3/envs/sunode/lib/python3.7/site-packages/arviz/plots/densityplot.py in plot_density(data, group, data_labels, var_names, transform, credible_interval, point_estimate, colors, outline, hpd_markers, shade, bw, figsize, textsize, ax, backend, backend_kwargs, show)
    263     # TODO: Add backend kwargs
    264     plot = get_plotting_function("plot_density", "densityplot", backend)
--> 265     ax = plot(**plot_density_kwargs)
    266     return ax

~/miniconda3/envs/sunode/lib/python3.7/site-packages/arviz/plots/backends/matplotlib/densityplot.py in plot_density(ax, all_labels, to_plot, colors, bw, figsize, length_plotters, rows, cols, titlesize, xt_labelsize, linewidth, markersize, credible_interval, point_estimate, hpd_markers, outline, shade, n_data, data_labels, backend_kwargs, show)
     75     if n_data > 1:
     76         for m_idx, label in enumerate(data_labels):
---> 77             ax[0].plot([], label=label, c=colors[m_idx], markersize=markersize)
     78         ax[0].legend(fontsize=xt_labelsize)
     79 

AttributeError: 'numpy.ndarray' object has no attribute 'plot'
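
The traceback suggests that ax[0] is indexing into the 2-D array returned by plt.subplots(3, 3), which yields a whole row of axes rather than a single Axes object. The sketch below reproduces that behaviour with matplotlib and NumPy alone, then shows that flattening the array first makes index 0 a single Axes. It is only an illustration of the likely cause; the names first_row, flat_axes and first_axes are made up here, and this is not necessarily how the fix in #1198 is implemented.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs headless
import matplotlib.pyplot as plt
import numpy as np

fig, ax = plt.subplots(3, 3)

# plt.subplots(3, 3) returns a 2-D array of Axes, so ax[0] is a whole row.
first_row = ax[0]
print(type(first_row).__name__, first_row.shape)  # ndarray (3,)
print(hasattr(first_row, "plot"))  # False, hence the AttributeError

# Flattening first makes index 0 a single Axes, which does have .plot.
flat_axes = np.atleast_1d(ax).ravel()
first_axes = flat_axes[0]

# Same empty-line trick as in the traceback, used only to create legend entries.
first_axes.plot([], label="prior", c="C0")
first_axes.plot([], label="posterior", c="C1")
legend = first_axes.legend()
```

np.atleast_1d is used so that the same code would also accept a single Axes object instead of an array.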

Additional context
ArviZ version: 0.7.0
matplotlib version: 3.2.1

aloctavodia self-assigned this May 20, 2020
@aloctavodia
Contributor

Thanks for reporting the issue. I will take a look at it.

@astoeriko
Author

Thanks for the fix. :-)
