Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve plot clarity with combined mean and HDI legend elements #145

Merged
merged 5 commits into from
Jan 4, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 20 additions & 11 deletions causalpy/plot_utils.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
from typing import Any, Dict, Optional, Union
from typing import Any, Dict, Optional, Tuple, Union

import arviz as az
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import xarray as xr
from matplotlib.collections import PolyCollection
from matplotlib.lines import Line2D


def plot_xY(
Expand All @@ -13,28 +15,35 @@ def plot_xY(
ax: plt.Axes,
plot_hdi_kwargs: Optional[Dict[str, Any]] = None,
hdi_prob: float = 0.94,
include_label: bool = True,
) -> None:
label: Union[str, None] = None,
) -> Tuple[Line2D, PolyCollection]:
"""Utility function to plot HDI intervals."""

if plot_hdi_kwargs is None:
plot_hdi_kwargs = {}

az.plot_hdi(
(h_line,) = ax.plot(
x,
Y.mean(dim=["chain", "draw"]),
ls="-",
**plot_hdi_kwargs,
label=f"{label}",
)
ax_hdi = az.plot_hdi(
x,
Y,
hdi_prob=hdi_prob,
fill_kwargs={
"alpha": 0.25,
"label": f"{hdi_prob*100}% HDI" if include_label else None,
"label": " ",
},
smooth=False,
ax=ax,
**plot_hdi_kwargs,
)
ax.plot(
x,
Y.mean(dim=["chain", "draw"]),
color="k",
label="Posterior mean" if include_label else None,
)
# Return handle to patch. We get a list of the children of the axis. Filter for just
# the PolyCollection objects. Take the last one.
h_patch = list(
filter(lambda x: isinstance(x, PolyCollection), ax_hdi.get_children())
)[-1]
return (h_line, h_patch)
125 changes: 88 additions & 37 deletions causalpy/pymc_experiments.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,13 +99,13 @@ def __init__(
# causal impact pre (ie the residuals of the model fit to observed)
pre_data = xr.DataArray(self.pre_y[:, 0], dims=["obs_ind"])
self.pre_impact = (
pre_data - self.pre_pred["posterior_predictive"].y_hat
pre_data - self.pre_pred["posterior_predictive"].mu
).transpose(..., "obs_ind")

# causal impact post (ie the residuals of the model fit to observed)
post_data = xr.DataArray(self.post_y[:, 0], dims=["obs_ind"])
self.post_impact = (
post_data - self.post_pred["posterior_predictive"].y_hat
post_data - self.post_pred["posterior_predictive"].mu
).transpose(..., "obs_ind")

# cumulative impact post
Expand All @@ -118,31 +118,43 @@ def plot(self):

# TOP PLOT --------------------------------------------------
# pre-intervention period
plot_xY(
h_line, h_patch = plot_xY(
self.datapre.index,
self.pre_pred["posterior_predictive"].y_hat,
self.pre_pred["posterior_predictive"].mu,
ax=ax[0],
plot_hdi_kwargs={"color": "C0"},
)
ax[0].plot(self.datapre.index, self.pre_y, "k.", label="Observations")
handles = [(h_line, h_patch)]
labels = ["Pre-intervention period"]

(h,) = ax[0].plot(self.datapre.index, self.pre_y, "k.", label="Observations")
handles.append(h)
labels.append("Observations")

# post intervention period
plot_xY(
h_line, h_patch = plot_xY(
self.datapost.index,
self.post_pred["posterior_predictive"].y_hat,
self.post_pred["posterior_predictive"].mu,
ax=ax[0],
include_label=False,
plot_hdi_kwargs={"color": "C1"},
)
handles.append((h_line, h_patch))
labels.append("Synthetic control")

ax[0].plot(self.datapost.index, self.post_y, "k.")
# Shaded causal effect
ax[0].fill_between(
h = ax[0].fill_between(
self.datapost.index,
y1=az.extract(
self.post_pred, group="posterior_predictive", var_names="y_hat"
self.post_pred, group="posterior_predictive", var_names="mu"
).mean("sample"),
y2=np.squeeze(self.post_y),
color="C0",
alpha=0.25,
label="Causal impact",
)
handles.append(h)
labels.append("Causal impact")

ax[0].set(
title=f"""
Pre-intervention Bayesian $R^2$: {self.score.r2:.3f}
Expand All @@ -155,12 +167,13 @@ def plot(self):
self.datapre.index,
self.pre_impact,
ax=ax[1],
plot_hdi_kwargs={"color": "C0"},
)
plot_xY(
self.datapost.index,
self.post_impact,
ax=ax[1],
include_label=False,
plot_hdi_kwargs={"color": "C1"},
)
ax[1].axhline(y=0, c="k")
ax[1].fill_between(
Expand All @@ -173,12 +186,12 @@ def plot(self):
ax[1].set(title="Causal Impact")

# BOTTOM PLOT -----------------------------------------------

ax[2].set(title="Cumulative Causal Impact")
plot_xY(
self.datapost.index,
self.post_impact_cumulative,
ax=ax[2],
plot_hdi_kwargs={"color": "C1"},
)
ax[2].axhline(y=0, c="k")

Expand All @@ -189,10 +202,13 @@ def plot(self):
ls="-",
lw=3,
color="r",
label="Treatment time",
)

ax[0].legend(fontsize=LEGEND_FONT_SIZE)
ax[0].legend(
handles=(h_tuple for h_tuple in handles),
labels=labels,
fontsize=LEGEND_FONT_SIZE,
)

return (fig, ax)

Expand Down Expand Up @@ -353,39 +369,46 @@ def __init__(
)

def plot(self):
"""Plot the results"""
"""Plot the results.
Creating the combined mean + HDI legend entries is a bit involved.
"""
fig, ax = plt.subplots()

# Plot raw data
# NOTE: This will not work when there is just ONE unit in each group
sns.lineplot(
sns.scatterplot(
self.data,
x=self.time_variable_name,
y=self.outcome_variable_name,
hue=self.group_variable_name,
units="unit", # NOTE: assumes we have a `unit` predictor variable
estimator=None,
alpha=0.5,
alpha=1,
legend=False,
markers=True,
ax=ax,
)

# Plot model fit to control group
time_points = self.x_pred_control[self.time_variable_name].values
plot_xY(
h_line, h_patch = plot_xY(
time_points,
self.y_pred_control.posterior_predictive.y_hat,
self.y_pred_control.posterior_predictive.mu,
ax=ax,
plot_hdi_kwargs={"color": "C0"},
label="Control group",
)
handles = [(h_line, h_patch)]
labels = ["Control group"]

# Plot model fit to treatment group
time_points = self.x_pred_control[self.time_variable_name].values
plot_xY(
h_line, h_patch = plot_xY(
time_points,
self.y_pred_treatment.posterior_predictive.y_hat,
self.y_pred_treatment.posterior_predictive.mu,
ax=ax,
plot_hdi_kwargs={"color": "C1"},
label="Treatment group",
)
handles.append((h_line, h_patch))
labels.append("Treatment group")

# Plot counterfactual - post-test for treatment group IF no treatment
# had occurred.
Expand All @@ -403,26 +426,34 @@ def plot(self):
widths=0.2,
)
for pc in parts["bodies"]:
pc.set_facecolor("C2")
pc.set_facecolor("C0")
pc.set_edgecolor("None")
pc.set_alpha(0.5)
else:
plot_xY(
h_line, h_patch = plot_xY(
time_points,
self.y_pred_counterfactual.posterior_predictive.y_hat,
self.y_pred_counterfactual.posterior_predictive.mu,
ax=ax,
plot_hdi_kwargs={"color": "C2"},
label="Counterfactual",
)
handles.append((h_line, h_patch))
labels.append("Counterfactual")

# arrow to label the causal impact
self._plot_causal_impact_arrow(ax)

# formatting
ax.set(
xticks=self.x_pred_treatment[self.time_variable_name].values,
title=self._causal_impact_summary_stat(),
)
ax.legend(fontsize=LEGEND_FONT_SIZE)
return (fig, ax)
ax.legend(
handles=(h_tuple for h_tuple in handles),
labels=labels,
fontsize=LEGEND_FONT_SIZE,
)
return fig, ax

def _plot_causal_impact_arrow(self, ax):
"""
Expand Down Expand Up @@ -582,12 +613,17 @@ def plot(self):
c="k", # hue="treated",
ax=ax,
)

# Plot model fit to data
plot_xY(
h_line, h_patch = plot_xY(
self.x_pred[self.running_variable_name],
self.pred["posterior_predictive"].mu,
ax=ax,
plot_hdi_kwargs={"color": "C1"},
)
handles = [(h_line, h_patch)]
labels = ["Posterior mean"]

# create strings to compose title
title_info = f"{self.score.r2:.3f} (std = {self.score.r2_std:.3f})"
r2 = f"Bayesian $R^2$ on all data = {title_info}"
Expand All @@ -605,7 +641,11 @@ def plot(self):
color="r",
label="treatment threshold",
)
ax.legend(fontsize=LEGEND_FONT_SIZE)
ax.legend(
handles=(h_tuple for h_tuple in handles),
labels=labels,
fontsize=LEGEND_FONT_SIZE,
)
return (fig, ax)

def summary(self):
Expand Down Expand Up @@ -710,27 +750,38 @@ def plot(self):
hue="group",
alpha=0.5,
data=self.data,
legend=True,
ax=ax[0],
)
ax[0].set(xlabel="Pretest", ylabel="Posttest")

# plot posterior predictive of untreated
plot_xY(
h_line, h_patch = plot_xY(
self.pred_xi,
self.pred_untreated["posterior_predictive"].y_hat,
self.pred_untreated["posterior_predictive"].mu,
ax=ax[0],
plot_hdi_kwargs={"color": "C0"},
label="Control group",
)
handles = [(h_line, h_patch)]
labels = ["Control group"]

# plot posterior predictive of treated
plot_xY(
h_line, h_patch = plot_xY(
self.pred_xi,
self.pred_treated["posterior_predictive"].y_hat,
self.pred_treated["posterior_predictive"].mu,
ax=ax[0],
plot_hdi_kwargs={"color": "C1"},
label="Treatment group",
)
handles.append((h_line, h_patch))
labels.append("Treatment group")

ax[0].legend(fontsize=LEGEND_FONT_SIZE)
ax[0].legend(
handles=(h_tuple for h_tuple in handles),
labels=labels,
fontsize=LEGEND_FONT_SIZE,
)

# Plot estimated causal impact / treatment effect
az.plot_posterior(self.causal_impact, ref_val=0, ax=ax[1])
Expand Down
8 changes: 4 additions & 4 deletions docs/notebooks/ancova_pymc.ipynb

Large diffs are not rendered by default.

15 changes: 11 additions & 4 deletions docs/notebooks/did_pymc.ipynb

Large diffs are not rendered by default.

58 changes: 29 additions & 29 deletions docs/notebooks/did_pymc_banks.ipynb

Large diffs are not rendered by default.

270 changes: 246 additions & 24 deletions docs/notebooks/generate_plots.ipynb

Large diffs are not rendered by default.

30 changes: 14 additions & 16 deletions docs/notebooks/geolift1.ipynb

Large diffs are not rendered by default.

Loading