@@ -238,6 +238,8 @@ def __init__(
238238
239239 # TODO: check that data in column self.group_variable_name has TWO levels
240240
241+ # TODO: check we have `unit` as a predictor column which is an vector of labels of unique units
242+
241243 # TODO: `treated` is a deterministic function of group and time, so this should be a function rather than supplied data
242244
243245 # DEVIATION FROM SKL EXPERIMENT CODE =============================
@@ -303,18 +305,17 @@ def plot(self):
303305
304306 # Plot raw data
305307 # NOTE: This will not work when there is just ONE unit in each group
306- # sns.lineplot(
307- # self.data,
308- # x=self.time_variable_name,
309- # y=self.outcome_variable_name,
310- # hue=self.group_variable_name,
311- # # units="unit",
312- # estimator=None,
313- # alpha=0.25 ,
314- # ax=ax,
315- # )
308+ sns .lineplot (
309+ self .data ,
310+ x = self .time_variable_name ,
311+ y = self .outcome_variable_name ,
312+ hue = self .group_variable_name ,
313+ units = "unit" , # NOTE: assumes we have a `unit` predictor variable
314+ estimator = None ,
315+ alpha = 0.5 ,
316+ ax = ax ,
317+ )
316318 # Plot model fit to control group
317- # NOTE: This will not work when there is just ONE unit in each group
318319 parts = ax .violinplot (
319320 az .extract (
320321 self .y_pred_control , group = "posterior_predictive" , var_names = "mu"
@@ -330,7 +331,6 @@ def plot(self):
330331 pc .set_alpha (0.5 )
331332
332333 # Plot model fit to treatment group
333- # NOTE: This will not work when there is just ONE unit in each group
334334 parts = ax .violinplot (
335335 az .extract (
336336 self .y_pred_treatment , group = "posterior_predictive" , var_names = "mu"
@@ -340,20 +340,41 @@ def plot(self):
340340 showmedians = False ,
341341 widths = 0.2 ,
342342 )
343- # # Plot counterfactual - post-test for treatment group IF no treatment had occurred.
344- # # NOTE: This will not work when there is just ONE unit in each group
345- # parts = ax.violinplot(
346- # az.extract(
347- # self.y_pred_counterfactual,
348- # group="posterior_predictive",
349- # var_names="mu",
350- # ).values.T,
351- # positions=self.x_pred_counterfactual[self.time_variable_name].values,
352- # showmeans=False,
353- # showmedians=False,
354- # widths=0.2,
355- # )
343+ for pc in parts ["bodies" ]:
344+ pc .set_facecolor ("C1" )
345+ pc .set_edgecolor ("None" )
346+ pc .set_alpha (0.5 )
347+ # Plot counterfactual - post-test for treatment group IF no treatment had occurred.
348+ parts = ax .violinplot (
349+ az .extract (
350+ self .y_pred_counterfactual ,
351+ group = "posterior_predictive" ,
352+ var_names = "mu" ,
353+ ).values .T ,
354+ positions = self .x_pred_counterfactual [self .time_variable_name ].values ,
355+ showmeans = False ,
356+ showmedians = False ,
357+ widths = 0.2 ,
358+ )
359+ for pc in parts ["bodies" ]:
360+ pc .set_facecolor ("C2" )
361+ pc .set_edgecolor ("None" )
362+ pc .set_alpha (0.5 )
356363 # arrow to label the causal impact
364+ self ._plot_causal_impact_arrow (ax )
365+ # formatting
366+ ax .set (
367+ xticks = self .x_pred_treatment [self .time_variable_name ].values ,
368+ title = self ._causal_impact_summary_stat (),
369+ )
370+ ax .legend (fontsize = LEGEND_FONT_SIZE )
371+ return (fig , ax )
372+
373+ def _plot_causal_impact_arrow (self , ax ):
374+ """
375+ draw a vertical arrow between `y_pred_counterfactual` and `y_pred_counterfactual`
376+ """
377+ # Calculate y values to plot the arrow between
357378 y_pred_treatment = (
358379 self .y_pred_treatment ["posterior_predictive" ]
359380 .mu .isel ({"obs_ind" : 1 })
@@ -363,32 +384,28 @@ def plot(self):
363384 y_pred_counterfactual = (
364385 self .y_pred_counterfactual ["posterior_predictive" ].mu .mean ().data
365386 )
387+ # Calculate the x position to plot at
388+ diff = np .ptp (self .x_pred_treatment [self .time_variable_name ].values )
389+ x = np .max (self .x_pred_treatment [self .time_variable_name ].values ) + 0.1 * diff
390+ # Plot the arrow
366391 ax .annotate (
367392 "" ,
368- xy = (1.15 , y_pred_counterfactual ),
393+ xy = (x , y_pred_counterfactual ),
369394 xycoords = "data" ,
370- xytext = (1.15 , y_pred_treatment ),
395+ xytext = (x , y_pred_treatment ),
371396 textcoords = "data" ,
372- arrowprops = {"arrowstyle" : "<-> " , "color" : "green" , "lw" : 3 },
397+ arrowprops = {"arrowstyle" : "<-" , "color" : "green" , "lw" : 3 },
373398 )
399+ # Plot text annotation next to arrow
374400 ax .annotate (
375401 "causal\n impact" ,
376- xy = (1.15 , np .mean ([y_pred_counterfactual , y_pred_treatment ])),
402+ xy = (x , np .mean ([y_pred_counterfactual , y_pred_treatment ])),
377403 xycoords = "data" ,
378404 xytext = (5 , 0 ),
379405 textcoords = "offset points" ,
380406 color = "green" ,
381407 va = "center" ,
382408 )
383- # formatting
384- ax .set (
385- # xlim=[-0.15, 1.25],
386- xticks = self .x_pred_treatment [self .time_variable_name ].values ,
387- # xticklabels=["pre", "post"],
388- title = self ._causal_impact_summary_stat (),
389- )
390- ax .legend (fontsize = LEGEND_FONT_SIZE )
391- return (fig , ax )
392409
393410 def _causal_impact_summary_stat (self ):
394411 percentiles = self .causal_impact .quantile ([0.03 , 1 - 0.03 ]).values
0 commit comments