diff --git a/causallib/evaluation/plots/mixins.py b/causallib/evaluation/plots/mixins.py index a9a7524..e8deea2 100644 --- a/causallib/evaluation/plots/mixins.py +++ b/causallib/evaluation/plots/mixins.py @@ -22,6 +22,7 @@ def plot_covariate_balance( aggregate_folds=True, thresh=None, plot_semi_grid=True, + label_imbalanced=True, **kwargs, ): """Plot covariate balance before and after weighting. @@ -34,7 +35,8 @@ def plot_covariate_balance( aggregate_folds (bool, optional): Whether to aggregate folds. Defaults to True. Ignored when kind="slope". thresh (float, optional): Draw threshold line at value. Defaults to None. - plot_semi_grid (bool, optional): Defaults to True. Ignored when kind="slope". + plot_semi_grid (bool, optional): Defaults to True. only for kind="love". + label_imbalanced (bool): Label covariates that weren't properly balanced. Ignored when kind="love". Returns: matplotlib.axes.Axes: axis with plot @@ -56,19 +58,18 @@ def plot_covariate_balance( table1_folds=table1_folds, ax=ax, thresh=thresh, + label_imbalanced=label_imbalanced, **kwargs, ) - if kind == "scatter": return plots.plot_mean_features_imbalance_scatter_plot( table1_folds=table1_folds, ax=ax, thresh=thresh, + label_imbalanced=label_imbalanced, **kwargs, ) - - raise ValueError(f"Unsupported covariate balance plot kind {kind}") diff --git a/causallib/evaluation/plots/plots.py b/causallib/evaluation/plots/plots.py index 567f844..ea8d105 100644 --- a/causallib/evaluation/plots/plots.py +++ b/causallib/evaluation/plots/plots.py @@ -883,6 +883,7 @@ def plot_mean_features_imbalance_scatter_plot( table1_folds, aggregate_folds=True, thresh=None, + label_imbalanced=True, ax=None, ): # get current axes @@ -912,21 +913,21 @@ def plot_mean_features_imbalance_scatter_plot( violating = table1["weighted"] > thresh # determain color for dot on plot color = violating.replace({False: "C0", True: "C1"}) - - + ax.scatter( x=table1['unweighted'], y=table1['weighted'], marker=next(marker_cycle), color=color ) - for covariate_name, covariate_diff in table1.loc[violating].iterrows(): - ax.text( - x=covariate_diff["unweighted"], - y=covariate_diff["weighted"], - s=covariate_name, - horizontalalignment="left", - ) + if label_imbalanced: + for covariate_name, covariate_diff in table1.loc[violating].iterrows(): + ax.text( + x=covariate_diff["unweighted"], + y=covariate_diff["weighted"], + s=covariate_name, + horizontalalignment="left", + ) # Plot vertical and horizontal threshold line if thresh is not None: @@ -952,7 +953,7 @@ def plot_mean_features_imbalance_scatter_plot( def plot_mean_features_imbalance_slope_folds( - table1_folds, cv=None, thresh=None, ax=None + table1_folds, cv=None, thresh=None, label_imbalanced=True, ax=None ): method_pretty_name = { "smd": "Standard Mean Difference", @@ -975,6 +976,7 @@ def plot_mean_features_imbalance_slope_folds( left=aggregated_table1["unweighted"], right=aggregated_table1["weighted"], thresh=thresh, + label_imbalanced=label_imbalanced, ax=ax, ) @@ -988,7 +990,7 @@ def plot_mean_features_imbalance_slope_folds( def slope_graph( - left, right, thresh=None, color_below="C0", color_above="C1", marker="o", ax=None + left, right, thresh=None, label_imbalanced=True, color_below="C0", color_above="C1", marker="o", ax=None ): ax = ax or plt.gca() left_xtick = left.name or "unweighted" @@ -1015,7 +1017,7 @@ def slope_graph( color=cur_color, marker=marker, ) - if cur_right > thresh: + if label_imbalanced and cur_right > thresh: ax.text(x=1.01, y=cur_right, s=idx, horizontalalignment="left") # Place y-tick labels on both sides: diff --git a/causallib/tests/test_plots.py b/causallib/tests/test_plots.py index 0d14ee5..edf9a8b 100644 --- a/causallib/tests/test_plots.py +++ b/causallib/tests/test_plots.py @@ -122,7 +122,6 @@ def test_plot_covariate_balance_love_draws_thresh(self): self.assertEqual(thresh, axis.get_lines()[0].get_xdata()[0]) plt.close() - def test_plot_covariate_balance_scatter_draws_thresh(self): thresh = 0.1 f, ax = plt.subplots() @@ -140,6 +139,17 @@ def test_plot_covariate_balance_slope_labeled_correctly(self): self.assertEqual([x.get_xdata() for x in axis.get_lines()][1][0], "unweighted") plt.close() + def test_plot_covariate_balance_types_exchangeable_kwargs(self): + f, ax = plt.subplots(1, 3) + for i, kind in enumerate(["love", "slope", "scatter"]): + self.propensity_evaluation.plot_covariate_balance( + kind=kind, ax=ax[i], + plot_semi_grid=True, # A "love"-only kwarg + label_imbalanced=True, # A "slope" and "scatter" only kwarg + thresh=0.1, # So that there are imbalanced variables plotted + ) + plt.close(f) + def test_roc_curve_has_dashed_diag(self): self.ensure_roc_curve_has_dashed_diag(self.propensity_evaluation) self.ensure_roc_curve_has_dashed_diag(self.bin_outcome_evaluation)