diff --git a/doubleml/did/did_multi.py b/doubleml/did/did_multi.py index a9e9e790..94ef112c 100644 --- a/doubleml/did/did_multi.py +++ b/doubleml/did/did_multi.py @@ -979,12 +979,13 @@ def aggregate(self, aggregation="group"): def plot_effects( self, level=0.95, + result_type="effect", joint=True, figsize=(12, 8), color_palette="colorblind", date_format=None, - y_label="Effect", - title="Estimated ATTs by Group", + y_label=None, + title=None, jitter_value=None, default_jitter=0.1, ): @@ -996,6 +997,10 @@ def plot_effects( level : float The confidence level for the intervals. Default is ``0.95``. + result_type : str + Type of result to plot. Either ``'effect'`` for point estimates, ``'rv'`` for robustness values, + ``'est_bounds'`` for estimate bounds, or ``'ci_bounds'`` for confidence interval bounds. + Default is ``'effect'``. joint : bool Indicates whether joint confidence intervals are computed. Default is ``True``. @@ -1010,10 +1015,10 @@ def plot_effects( Default is ``None``. y_label : str Label for y-axis. - Default is ``"Effect"``. + Default is ``None``. title : str Title for the entire plot. - Default is ``"Estimated ATTs by Group"``. + Default is ``None``. jitter_value : float Amount of jitter to apply to points. Default is ``None``. @@ -1035,8 +1040,29 @@ def plot_effects( """ if self.framework is None: raise ValueError("Apply fit() before plot_effects().") + + if result_type not in ["effect", "rv", "est_bounds", "ci_bounds"]: + raise ValueError("result_type must be either 'effect', 'rv', 'est_bounds' or 'ci_bounds'.") + + if result_type != "effect" and self._framework.sensitivity_params is None: + raise ValueError( + f"result_type='{result_type}' requires sensitivity analysis. " "Please call sensitivity_analysis() first." + ) + df = self._create_ci_dataframe(level=level, joint=joint) + # Set default y_label and title based on result_type + label_configs = { + "effect": {"y_label": "Effect", "title": "Estimated ATTs by Group"}, + "rv": {"y_label": "Robustness Value", "title": "Robustness Values by Group"}, + "est_bounds": {"y_label": "Estimate Bounds", "title": "Estimate Bounds by Group"}, + "ci_bounds": {"y_label": "Confidence Interval Bounds", "title": "Confidence Interval Bounds by Group"}, + } + + config = label_configs[result_type] + y_label = y_label if y_label is not None else config["y_label"] + title = title if title is not None else config["title"] + # Sort time periods and treatment groups first_treated_periods = sorted(df["First Treated"].unique()) n_periods = len(first_treated_periods) @@ -1068,7 +1094,7 @@ def plot_effects( period_df = df[df["First Treated"] == period] ax = axes[idx] - self._plot_single_group(ax, period_df, period, colors, is_datetime, jitter_value) + self._plot_single_group(ax, period_df, period, result_type, colors, is_datetime, jitter_value) # Set axis labels if idx == n_periods - 1: # Only bottom plot gets x label @@ -1085,7 +1111,7 @@ def plot_effects( legend_ax.axis("off") legend_elements = [ Line2D([0], [0], color="red", linestyle=":", alpha=0.7, label="Treatment start"), - Line2D([0], [0], color="black", linestyle="--", alpha=0.5, label="Zero effect"), + Line2D([0], [0], color="black", linestyle="--", alpha=0.5, label=f"Zero {result_type}"), Line2D([0], [0], marker="o", color=colors["pre"], linestyle="None", label="Pre-treatment", markersize=5), ] @@ -1108,7 +1134,7 @@ def plot_effects( return fig, axes - def _plot_single_group(self, ax, period_df, period, colors, is_datetime, jitter_value): + def _plot_single_group(self, ax, period_df, period, result_type, colors, is_datetime, jitter_value): """ Plot estimates for a single treatment group on the given axis. @@ -1120,6 +1146,10 @@ def _plot_single_group(self, ax, period_df, period, colors, is_datetime, jitter_ DataFrame containing estimates for a specific time period. period : int or datetime Treatment period for this group. + result_type : str + Type of result to plot. Either ``'effect'`` for point estimates, ``'rv'`` for robustness values, + ``'est_bounds'`` for estimate bounds, or ``'ci_bounds'`` for confidence interval bounds. + Default is ``'effect'``. colors : dict Dictionary with 'pre', 'anticipation' (if applicable), and 'post' color values. is_datetime : bool @@ -1165,6 +1195,31 @@ def _plot_single_group(self, ax, period_df, period, colors, is_datetime, jitter_ # Define category mappings categories = [("pre", pre_treatment_mask), ("anticipation", anticipation_mask), ("post", post_treatment_mask)] + # Define plot configurations for each result type + plot_configs = { + "effect": {"plot_col": "Estimate", "err_col_upper": "CI Upper", "err_col_lower": "CI Lower", "s_val": 30}, + "rv": {"plot_col": "RV", "plot_col_2": "RVa", "s_val": 50}, + "est_bounds": { + "plot_col": "Estimate", + "err_col_upper": "Estimate Upper Bound", + "err_col_lower": "Estimate Lower Bound", + "s_val": 30, + }, + "ci_bounds": { + "plot_col": "Estimate", + "err_col_upper": "CI Upper Bound", + "err_col_lower": "CI Lower Bound", + "s_val": 30, + }, + } + + config = plot_configs[result_type] + plot_col = config["plot_col"] + plot_col_2 = config.get("plot_col_2") + err_col_upper = config.get("err_col_upper") + err_col_lower = config.get("err_col_lower") + s_val = config["s_val"] + # Plot each category for category_name, mask in categories: if not mask.any(): @@ -1179,22 +1234,33 @@ def _plot_single_group(self, ax, period_df, period, colors, is_datetime, jitter_ if not category_data.empty: ax.scatter( - category_data["jittered_x"], category_data["Estimate"], color=colors[category_name], alpha=0.8, s=30 - ) - ax.errorbar( - category_data["jittered_x"], - category_data["Estimate"], - yerr=[ - category_data["Estimate"] - category_data["CI Lower"], - category_data["CI Upper"] - category_data["Estimate"], - ], - fmt="o", - capsize=3, - color=colors[category_name], - markersize=4, - markeredgewidth=1, - linewidth=1, + category_data["jittered_x"], category_data[plot_col], color=colors[category_name], alpha=0.8, s=s_val ) + if result_type in ["effect", "est_bounds", "ci_bounds"]: + ax.errorbar( + category_data["jittered_x"], + category_data[plot_col], + yerr=[ + category_data[plot_col] - category_data[err_col_lower], + category_data[err_col_upper] - category_data[plot_col], + ], + fmt="o", + capsize=3, + color=colors[category_name], + markersize=4, + markeredgewidth=1, + linewidth=1, + ) + + elif result_type == "rv": + ax.scatter( + category_data["jittered_x"], + category_data[plot_col_2], + color=colors[category_name], + alpha=0.8, + s=s_val, + marker="s", + ) # Format axes if is_datetime: @@ -1431,6 +1497,8 @@ def _create_ci_dataframe(self, level=0.95, joint=True): - 'CI Lower': Lower bound of confidence intervals - 'CI Upper': Upper bound of confidence intervals - 'Pre-Treatment': Boolean indicating if evaluation period is before treatment + - 'RV': Robustness values (if sensitivity_analysis() has been called before) + - 'RVa': Robustness values for (1-a) confidence bounds (if sensitivity_analysis() has been called before) Notes ----- @@ -1459,5 +1527,11 @@ def _create_ci_dataframe(self, level=0.95, joint=True): "Pre-Treatment": [gt_combination[2] < gt_combination[0] for gt_combination in self.gt_combinations], } ) - + if self._framework.sensitivity_params is not None: + df["RV"] = self.framework.sensitivity_params["rv"] + df["RVa"] = self.framework.sensitivity_params["rva"] + df["CI Lower Bound"] = self.framework.sensitivity_params["ci"]["lower"] + df["CI Upper Bound"] = self.framework.sensitivity_params["ci"]["upper"] + df["Estimate Lower Bound"] = self.framework.sensitivity_params["theta"]["lower"] + df["Estimate Upper Bound"] = self.framework.sensitivity_params["theta"]["upper"] return df diff --git a/doubleml/did/tests/test_did_multi_plot.py b/doubleml/did/tests/test_did_multi_plot.py index 4a55449d..d4275cde 100644 --- a/doubleml/did/tests/test_did_multi_plot.py +++ b/doubleml/did/tests/test_did_multi_plot.py @@ -184,3 +184,109 @@ def test_plot_effects_jitter(doubleml_did_fixture): assert fig_default != fig plt.close("all") + + +@pytest.mark.ci +def test_plot_effects_result_types(doubleml_did_fixture): + """Test plot_effects with different result types.""" + dml_obj = doubleml_did_fixture["model"] + + # Test default result_type='effect' + fig_effect, axes_effect = dml_obj.plot_effects(result_type="effect") + assert isinstance(fig_effect, plt.Figure) + assert isinstance(axes_effect, list) + + # Check that the default y-label is set correctly + assert axes_effect[0].get_ylabel() == "Effect" + + plt.close("all") + + +@pytest.mark.ci +def test_plot_effects_result_type_rv(doubleml_did_fixture): + """Test plot_effects with result_type='rv' (requires sensitivity analysis).""" + dml_obj = doubleml_did_fixture["model"] + + # Perform sensitivity analysis first + dml_obj.sensitivity_analysis(cf_y=0.03, cf_d=0.03) + + # Test result_type='rv' + fig_rv, axes_rv = dml_obj.plot_effects(result_type="rv") + assert isinstance(fig_rv, plt.Figure) + assert isinstance(axes_rv, list) + + # Check that the y-label is set correctly + assert axes_rv[0].get_ylabel() == "Robustness Value" + + plt.close("all") + + +@pytest.mark.ci +def test_plot_effects_result_type_est_bounds(doubleml_did_fixture): + """Test plot_effects with result_type='est_bounds' (requires sensitivity analysis).""" + dml_obj = doubleml_did_fixture["model"] + + # Perform sensitivity analysis first + dml_obj.sensitivity_analysis(cf_y=0.03, cf_d=0.03) + + # Test result_type='est_bounds' + fig_est, axes_est = dml_obj.plot_effects(result_type="est_bounds") + assert isinstance(fig_est, plt.Figure) + assert isinstance(axes_est, list) + + # Check that the y-label is set correctly + assert axes_est[0].get_ylabel() == "Estimate Bounds" + + plt.close("all") + + +@pytest.mark.ci +def test_plot_effects_result_type_ci_bounds(doubleml_did_fixture): + """Test plot_effects with result_type='ci_bounds' (requires sensitivity analysis).""" + dml_obj = doubleml_did_fixture["model"] + + # Perform sensitivity analysis first + dml_obj.sensitivity_analysis(cf_y=0.03, cf_d=0.03) + + # Test result_type='ci_bounds' + fig_ci, axes_ci = dml_obj.plot_effects(result_type="ci_bounds") + assert isinstance(fig_ci, plt.Figure) + assert isinstance(axes_ci, list) + + # Check that the y-label is set correctly + assert axes_ci[0].get_ylabel() == "Confidence Interval Bounds" + + plt.close("all") + + +@pytest.mark.ci +def test_plot_effects_result_type_invalid(doubleml_did_fixture): + """Test plot_effects with invalid result_type.""" + dml_obj = doubleml_did_fixture["model"] + + # Test with invalid result_type + with pytest.raises(ValueError, match="result_type must be either"): + dml_obj.plot_effects(result_type="invalid_type") + + plt.close("all") + + +@pytest.mark.ci +def test_plot_effects_result_type_with_custom_labels(doubleml_did_fixture): + """Test plot_effects with result_type and custom labels.""" + dml_obj = doubleml_did_fixture["model"] + + # Perform sensitivity analysis first + dml_obj.sensitivity_analysis(cf_y=0.03, cf_d=0.03) + + # Test result_type with custom labels + custom_title = "Custom Sensitivity Plot" + custom_ylabel = "Custom Bounds Label" + + fig, axes = dml_obj.plot_effects(result_type="est_bounds", title=custom_title, y_label=custom_ylabel) + + assert isinstance(fig, plt.Figure) + assert fig._suptitle.get_text() == custom_title + assert axes[0].get_ylabel() == custom_ylabel + + plt.close("all")