Skip to content

Commit

Permalink
#76 #44 start to fix did plot
Browse files Browse the repository at this point in the history
  • Loading branch information
drbenvincent committed Nov 19, 2022
1 parent 7840611 commit 90cc898
Showing 1 changed file with 29 additions and 25 deletions.
54 changes: 29 additions & 25 deletions causalpy/pymc_experiments.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,17 +302,19 @@ def plot(self):
fig, ax = plt.subplots()

# Plot raw data
sns.lineplot(
self.data,
x=self.time_variable_name,
y=self.outcome_variable_name,
hue=self.group_variable_name,
units="unit",
estimator=None,
alpha=0.25,
ax=ax,
)
# NOTE: This will not work when there is just ONE unit in each group
# sns.lineplot(
# self.data,
# x=self.time_variable_name,
# y=self.outcome_variable_name,
# hue=self.group_variable_name,
# # units="unit",
# estimator=None,
# alpha=0.25,
# ax=ax,
# )
# Plot model fit to control group
# NOTE: This will not work when there is just ONE unit in each group
parts = ax.violinplot(
az.extract(
self.y_pred_control, group="posterior_predictive", var_names="mu"
Expand All @@ -328,6 +330,7 @@ def plot(self):
pc.set_alpha(0.5)

# Plot model fit to treatment group
# NOTE: This will not work when there is just ONE unit in each group
parts = ax.violinplot(
az.extract(
self.y_pred_treatment, group="posterior_predictive", var_names="mu"
Expand All @@ -337,18 +340,19 @@ def plot(self):
showmedians=False,
widths=0.2,
)
# Plot counterfactual - post-test for treatment group IF no treatment had occurred.
parts = ax.violinplot(
az.extract(
self.y_pred_counterfactual,
group="posterior_predictive",
var_names="mu",
).values.T,
positions=self.x_pred_counterfactual[self.time_variable_name].values,
showmeans=False,
showmedians=False,
widths=0.2,
)
# # Plot counterfactual - post-test for treatment group IF no treatment had occurred.
# # NOTE: This will not work when there is just ONE unit in each group
# parts = ax.violinplot(
# az.extract(
# self.y_pred_counterfactual,
# group="posterior_predictive",
# var_names="mu",
# ).values.T,
# positions=self.x_pred_counterfactual[self.time_variable_name].values,
# showmeans=False,
# showmedians=False,
# widths=0.2,
# )
# arrow to label the causal impact
y_pred_treatment = (
self.y_pred_treatment["posterior_predictive"]
Expand Down Expand Up @@ -378,9 +382,9 @@ def plot(self):
)
# formatting
ax.set(
xlim=[-0.15, 1.25],
xticks=[0, 1],
xticklabels=["pre", "post"],
# xlim=[-0.15, 1.25],
xticks=self.x_pred_treatment[self.time_variable_name].values,
# xticklabels=["pre", "post"],
title=self._causal_impact_summary_stat(),
)
ax.legend(fontsize=LEGEND_FONT_SIZE)
Expand Down

0 comments on commit 90cc898

Please sign in to comment.