diff --git a/causalpy/pymc_experiments.py b/causalpy/pymc_experiments.py index fb85b756..de84774d 100644 --- a/causalpy/pymc_experiments.py +++ b/causalpy/pymc_experiments.py @@ -302,17 +302,19 @@ def plot(self): fig, ax = plt.subplots() # Plot raw data - sns.lineplot( - self.data, - x=self.time_variable_name, - y=self.outcome_variable_name, - hue=self.group_variable_name, - units="unit", - estimator=None, - alpha=0.25, - ax=ax, - ) + # NOTE: This will not work when there is just ONE unit in each group + # sns.lineplot( + # self.data, + # x=self.time_variable_name, + # y=self.outcome_variable_name, + # hue=self.group_variable_name, + # # units="unit", + # estimator=None, + # alpha=0.25, + # ax=ax, + # ) # Plot model fit to control group + # NOTE: This will not work when there is just ONE unit in each group parts = ax.violinplot( az.extract( self.y_pred_control, group="posterior_predictive", var_names="mu" @@ -328,6 +330,7 @@ def plot(self): pc.set_alpha(0.5) # Plot model fit to treatment group + # NOTE: This will not work when there is just ONE unit in each group parts = ax.violinplot( az.extract( self.y_pred_treatment, group="posterior_predictive", var_names="mu" @@ -337,18 +340,19 @@ def plot(self): showmedians=False, widths=0.2, ) - # Plot counterfactual - post-test for treatment group IF no treatment had occurred. - parts = ax.violinplot( - az.extract( - self.y_pred_counterfactual, - group="posterior_predictive", - var_names="mu", - ).values.T, - positions=self.x_pred_counterfactual[self.time_variable_name].values, - showmeans=False, - showmedians=False, - widths=0.2, - ) + # # Plot counterfactual - post-test for treatment group IF no treatment had occurred. + # # NOTE: This will not work when there is just ONE unit in each group + # parts = ax.violinplot( + # az.extract( + # self.y_pred_counterfactual, + # group="posterior_predictive", + # var_names="mu", + # ).values.T, + # positions=self.x_pred_counterfactual[self.time_variable_name].values, + # showmeans=False, + # showmedians=False, + # widths=0.2, + # ) # arrow to label the causal impact y_pred_treatment = ( self.y_pred_treatment["posterior_predictive"] @@ -378,9 +382,9 @@ def plot(self): ) # formatting ax.set( - xlim=[-0.15, 1.25], - xticks=[0, 1], - xticklabels=["pre", "post"], + # xlim=[-0.15, 1.25], + xticks=self.x_pred_treatment[self.time_variable_name].values, + # xticklabels=["pre", "post"], title=self._causal_impact_summary_stat(), ) ax.legend(fontsize=LEGEND_FONT_SIZE)