Add pair_plot_focus #90

Open
OriolAbril opened this issue Sep 11, 2024 · 1 comment

@OriolAbril
Member

Talking with @daniel-saunders-phil, he mentioned he often wants to generate pair plots of one specific variable against several others. I think we could have a function that makes that easy (and that paves the way, both in visuals and in computation, for plot_pair once we have the PlotMatrix manager). Because we always compare multiple variables against the same one, the faceting/mapping logic needed is the one PlotCollection already provides.

I was thinking the API would be similar to that of a regular plotting function, with dt, group, var_names... plus a target argument (or similarly named) that accepts either a string (the name of a variable in the dt input) or a DataArray.
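A minimal sketch of how the proposed target argument could be resolved and how the panels could be enumerated, written with plain xarray rather than any arviz_plots API. The names resolve_target, focus_pairs and SAMPLE_DIMS are hypothetical, and treating chain and draw as the sample dimensions is an assumption. Each yielded pair is one panel: a single subset of a variable (e.g. theta for one school) against the same target every time, which is why one level of faceting is enough.

```python
import itertools

import numpy as np
import xarray as xr

# Assumption: these dims index posterior samples; every other dim defines facets.
SAMPLE_DIMS = ("chain", "draw")


def resolve_target(ds, target):
    """Return the target as a DataArray, looked up by name or used as given."""
    if isinstance(target, str):
        return ds[target]
    if isinstance(target, xr.DataArray):
        return target
    raise TypeError("target must be a variable name or an xarray.DataArray")


def focus_pairs(ds, target, var_names=None):
    """Yield (var_name, selection, x, y): one subset of a variable vs the target."""
    y = resolve_target(ds, target)
    if var_names is None:
        # Skip the target itself when it was given by name.
        var_names = [
            name for name in ds.data_vars
            if not (isinstance(target, str) and name == target)
        ]
    for var_name in var_names:
        da = ds[var_name]
        facet_dims = [dim for dim in da.dims if dim not in SAMPLE_DIMS]
        # One panel per combination of coordinate values along the non-sample dims;
        # a variable with no facet dims yields a single panel (empty selection).
        for combo in itertools.product(*(da[dim].values for dim in facet_dims)):
            selection = dict(zip(facet_dims, combo))
            yield var_name, selection, da.sel(selection), y


# Toy data shaped like the centered_eight posterior (random values, not real draws).
rng = np.random.default_rng(0)
coords = {"chain": [0, 1], "draw": np.arange(50), "school": ["A", "B", "C"]}
posterior = xr.Dataset(
    {
        "theta": (("chain", "draw", "school"), rng.normal(size=(2, 50, 3))),
        "mu": (("chain", "draw"), rng.normal(size=(2, 50))),
    },
    coords=coords,
)
for var_name, selection, x, y in focus_pairs(posterior, "mu"):
    print(var_name, selection, x.dims, y.dims)
```

Passing the target as a DataArray instead of a name also works, but then nothing is excluded from var_names by default, so the target would be plotted against itself unless var_names is given.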


Pseudo example:

```python
import matplotlib.pyplot as plt
from arviz_base import load_arviz_data
from arviz_base.labels import NoVarLabeller
import arviz_plots as azp

azp.style.use("arviz-clean")
plt.rcParams["font.size"] = 8  # the clean theme sets the font size too big for this grid

idata = load_arviz_data("centered_eight")
target_name = "mu"
# One panel per school: each theta subset is compared against the same target
pc = azp.PlotCollection.wrap(
    idata.posterior.ds[["theta"]],
    cols=["school"],
    col_wrap=4,
    plot_grid_kws={"figsize": (8, 3), "sharey": True},
)
# Scatter the theta samples of each panel (x) against the target samples (y)
pc.map(
    azp.visuals.scatter_x,
    "samples",
    y=idata.posterior[target_name],
    marker=".",
    alpha=.4,
    edgecolor="none",
)
# Title each panel with its subset (the school) only, without the variable name
pc.map(
    azp.visuals.labelled_title,
    "title",
    subset_info=True,
    labeller=NoVarLabeller(),
    ignore_aes={"color"},
    size=8,
)

pc.viz["chart"].item().suptitle(f"Pair plots of theta subsets vs {target_name}")
```


@daniel-saunders-phil

daniel-saunders-phil commented Sep 30, 2024

I spent some time playing with this and it's really handy. I wanted to add two features but couldn't find a way forward: how might I color each chain differently? And how might I flag divergences?

Edit: the chains can be colored differently by mapping the color aesthetic to the chain dimension:

```python
pc = azp.PlotCollection.wrap(
    idata.posterior.ds[["theta"]],
    cols=["school"],
    col_wrap=4,
    plot_grid_kws={"figsize": (8, 3), "sharey": True},
    aes={"color": ["chain"]},
)
```
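The divergence question is not answered in the thread. One possible approach, sketched here with plain xarray and not with any arviz_plots API, is to split each panel's samples with the boolean diverging variable from the sample_stats group (assumed to be present, as it is in ArviZ example datasets such as centered_eight) and then draw the divergent subset as a separate, highlighted layer. split_by_divergence is a hypothetical helper, and it assumes x, y and the mask all share the chain and draw dimensions.

```python
import numpy as np
import xarray as xr


def split_by_divergence(x, y, diverging, sample_dims=("chain", "draw")):
    """Split paired samples into non-divergent and divergent draws.

    x, y and diverging must share the sample dims; diverging holds booleans.
    Returns ((x_ok, y_ok), (x_div, y_div)) as flat numpy arrays.
    """
    # Stacking every input over the same dims in the same order keeps the
    # flattened samples aligned, whatever each input's original dim order was.
    x_flat = x.stack(sample=sample_dims).values
    y_flat = y.stack(sample=sample_dims).values
    mask = diverging.stack(sample=sample_dims).values.astype(bool)
    return (x_flat[~mask], y_flat[~mask]), (x_flat[mask], y_flat[mask])


# Toy example: 2 chains x 5 draws with a single divergent draw (chain 1, draw 3).
coords = {"chain": [0, 1], "draw": np.arange(5)}
dims = ("chain", "draw")
x = xr.DataArray(np.linspace(0.0, 0.9, 10).reshape(2, 5), dims=dims, coords=coords)
y = xr.DataArray(np.arange(10.0).reshape(2, 5), dims=dims, coords=coords)
flags = np.zeros((2, 5), dtype=bool)
flags[1, 3] = True
diverging = xr.DataArray(flags, dims=dims, coords=coords)

# y is deliberately transposed to show that the split stays aligned.
(x_ok, y_ok), (x_div, y_div) = split_by_divergence(x, y.transpose("draw", "chain"), diverging)
print(len(x_ok), len(x_div), y_div)
```

With the real data, the mask would come from something like idata.sample_stats["diverging"], and the divergent points could be drawn on top of the regular scatter with a distinct marker or color.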
