diff --git a/causalpy/data_validation.py b/causalpy/data_validation.py new file mode 100644 index 00000000..c7988f3e --- /dev/null +++ b/causalpy/data_validation.py @@ -0,0 +1,135 @@ +import warnings # noqa: I001 + +import pandas as pd +import numpy as np +from causalpy.custom_exceptions import ( + BadIndexException, # NOQA + DataException, + FormulaException, +) +from causalpy.utils import _is_variable_dummy_coded + + +class PrePostFitDataValidator: + """Mixin class for validating the input data and model formula for PrePostFit""" + + def _input_validation(self, data, treatment_time): + """Validate the input data and model formula for correctness""" + if isinstance(data.index, pd.DatetimeIndex) and not isinstance( + treatment_time, pd.Timestamp + ): + raise BadIndexException( + "If data.index is DatetimeIndex, treatment_time must be pd.Timestamp." + ) + if not isinstance(data.index, pd.DatetimeIndex) and isinstance( + treatment_time, pd.Timestamp + ): + raise BadIndexException( + "If data.index is not DatetimeIndex, treatment_time must be pd.Timestamp." # noqa: E501 + ) + + +class DiDDataValidator: + """Mixin class for validating the input data and model formula for Difference in Differences experiments.""" + + def _input_validation(self): + """Validate the input data and model formula for correctness""" + if "post_treatment" not in self.formula: + raise FormulaException( + "A predictor called `post_treatment` should be in the formula" + ) + + if "post_treatment" not in self.data.columns: + raise DataException( + "Require a boolean column labelling observations which are `treated`" + ) + + if "unit" not in self.data.columns: + raise DataException( + "Require a `unit` column to label unique units. This is used for plotting purposes" # noqa: E501 + ) + + if _is_variable_dummy_coded(self.data[self.group_variable_name]) is False: + raise DataException( + f"""The grouping variable {self.group_variable_name} should be dummy + coded. Consisting of 0's and 1's only.""" + ) + + +class RDDataValidator: + """Mixin class for validating the input data and model formula for Regression Discontinuity experiments.""" + + def _input_validation(self): + """Validate the input data and model formula for correctness""" + if "treated" not in self.formula: + raise FormulaException( + "A predictor called `treated` should be in the formula" + ) + + if _is_variable_dummy_coded(self.data["treated"]) is False: + raise DataException( + """The treated variable should be dummy coded. Consisting of 0's and 1's only.""" # noqa: E501 + ) + + +class RegressionKinkDataValidator: + """Mixin class for validating the input data and model formula for Regression Kink experiments.""" + + def _input_validation(self): + """Validate the input data and model formula for correctness""" + if "treated" not in self.formula: + raise FormulaException( + "A predictor called `treated` should be in the formula" + ) + + if _is_variable_dummy_coded(self.data["treated"]) is False: + raise DataException( + """The treated variable should be dummy coded. Consisting of 0's and 1's only.""" # noqa: E501 + ) + + if self.bandwidth <= 0: + raise ValueError("The bandwidth must be greater than zero.") + + if self.epsilon <= 0: + raise ValueError("Epsilon must be greater than zero.") + + +class PrePostNEGDDataValidator: + """Mixin class for validating the input data and model formula for PrePostNEGD experiments.""" + + def _input_validation(self) -> None: + """Validate the input data and model formula for correctness""" + if not _is_variable_dummy_coded(self.data[self.group_variable_name]): + raise DataException( + f""" + There must be 2 levels of the grouping variable + {self.group_variable_name}. I.e. the treated and untreated. + """ + ) + + +class IVDataValidator: + """Mixin class for validating the input data and model formula for IV experiments.""" + + def _input_validation(self): + """Validate the input data and model formula for correctness""" + treatment = self.instruments_formula.split("~")[0] + test = treatment.strip() in self.instruments_data.columns + test = test & (treatment.strip() in self.data.columns) + if not test: + raise DataException( + f""" + The treatment variable: + {treatment} must appear in the instrument_data to be used + as an outcome variable and in the data object to be used as a covariate. + """ + ) + Z = self.data[treatment.strip()] + check_binary = len(np.unique(Z)) > 2 + if check_binary: + warnings.warn( + """Warning. The treatment variable is not Binary. + This is not necessarily a problem but it violates + the assumption of a simple IV experiment. + The coefficients should be interpreted appropriately.""" + ) diff --git a/causalpy/pymc_experiments.py b/causalpy/pymc_experiments.py index 9f385db2..fbb7e0f5 100644 --- a/causalpy/pymc_experiments.py +++ b/causalpy/pymc_experiments.py @@ -23,13 +23,16 @@ from patsy import build_design_matrices, dmatrices from sklearn.linear_model import LinearRegression as sk_lin_reg -from causalpy.custom_exceptions import ( - BadIndexException, # NOQA - DataException, - FormulaException, +from causalpy.data_validation import ( + PrePostFitDataValidator, + DiDDataValidator, + RDDataValidator, + RegressionKinkDataValidator, + PrePostNEGDDataValidator, + IVDataValidator, ) from causalpy.plot_utils import plot_xY -from causalpy.utils import _is_variable_dummy_coded, round_num +from causalpy.utils import round_num LEGEND_FONT_SIZE = 12 az.style.use("arviz-darkgrid") @@ -108,7 +111,7 @@ def print_coefficients(self, round_to=None) -> None: ) -class PrePostFit(ExperimentalDesign): +class PrePostFit(ExperimentalDesign, PrePostFitDataValidator): """ A class to analyse quasi-experiments where parameter estimation is based on just the pre-intervention data. @@ -160,7 +163,6 @@ def __init__( ) -> None: super().__init__(model=model, **kwargs) self._input_validation(data, treatment_time) - self.treatment_time = treatment_time # set experiment type - usually done in subclasses self.expt_type = "Pre-Post Fit" @@ -214,21 +216,6 @@ def __init__( # cumulative impact post self.post_impact_cumulative = self.post_impact.cumsum(dim="obs_ind") - def _input_validation(self, data, treatment_time): - """Validate the input data and model formula for correctness""" - if isinstance(data.index, pd.DatetimeIndex) and not isinstance( - treatment_time, pd.Timestamp - ): - raise BadIndexException( - "If data.index is DatetimeIndex, treatment_time must be pd.Timestamp." - ) - if not isinstance(data.index, pd.DatetimeIndex) and isinstance( - treatment_time, pd.Timestamp - ): - raise BadIndexException( - "If data.index is not DatetimeIndex, treatment_time must be pd.Timestamp." # noqa: E501 - ) - def plot(self, counterfactual_label="Counterfactual", round_to=None, **kwargs): """ Plot the results @@ -438,7 +425,7 @@ def plot(self, plot_predictors=False, **kwargs): return fig, ax -class DifferenceInDifferences(ExperimentalDesign): +class DifferenceInDifferences(ExperimentalDesign, DiDDataValidator): """A class to analyse data from Difference in Difference settings. .. note:: @@ -568,29 +555,6 @@ def __init__( if "post_treatment" in label and self.group_variable_name in label: self.causal_impact = self.idata.posterior["beta"].isel({"coeffs": i}) - def _input_validation(self): - """Validate the input data and model formula for correctness""" - if "post_treatment" not in self.formula: - raise FormulaException( - "A predictor called `post_treatment` should be in the formula" - ) - - if "post_treatment" not in self.data.columns: - raise DataException( - "Require a boolean column labelling observations which are `treated`" - ) - - if "unit" not in self.data.columns: - raise DataException( - "Require a `unit` column to label unique units. This is used for plotting purposes" # noqa: E501 - ) - - if _is_variable_dummy_coded(self.data[self.group_variable_name]) is False: - raise DataException( - f"""The grouping variable {self.group_variable_name} should be dummy - coded. Consisting of 0's and 1's only.""" - ) - def plot(self, round_to=None): """Plot the results. @@ -749,7 +713,7 @@ def summary(self, round_to=None) -> None: self.print_coefficients(round_to) -class RegressionDiscontinuity(ExperimentalDesign): +class RegressionDiscontinuity(ExperimentalDesign, RDDataValidator): """ A class to analyse sharp regression discontinuity experiments. @@ -876,18 +840,6 @@ def __init__( - self.pred_discon["posterior_predictive"].sel(obs_ind=0)["mu"] ) - def _input_validation(self): - """Validate the input data and model formula for correctness""" - if "treated" not in self.formula: - raise FormulaException( - "A predictor called `treated` should be in the formula" - ) - - if _is_variable_dummy_coded(self.data["treated"]) is False: - raise DataException( - """The treated variable should be dummy coded. Consisting of 0's and 1's only.""" # noqa: E501 - ) - def _is_treated(self, x): """Returns ``True`` if `x` is greater than or equal to the treatment threshold. @@ -970,7 +922,7 @@ def summary(self, round_to: None) -> None: self.print_coefficients(round_to) -class RegressionKink(ExperimentalDesign): +class RegressionKink(ExperimentalDesign, RegressionKinkDataValidator): """ A class to analyse sharp regression kink experiments. @@ -1095,24 +1047,6 @@ def _probe_kink_point(self): mu_kink_right = predicted["posterior_predictive"].sel(obs_ind=2)["mu"] return mu_kink_left, mu_kink, mu_kink_right - def _input_validation(self): - """Validate the input data and model formula for correctness""" - if "treated" not in self.formula: - raise FormulaException( - "A predictor called `treated` should be in the formula" - ) - - if _is_variable_dummy_coded(self.data["treated"]) is False: - raise DataException( - """The treated variable should be dummy coded. Consisting of 0's and 1's only.""" # noqa: E501 - ) - - if self.bandwidth <= 0: - raise ValueError("The bandwidth must be greater than zero.") - - if self.epsilon <= 0: - raise ValueError("Epsilon must be greater than zero.") - def _is_treated(self, x): """Returns ``True`` if `x` is greater than or equal to the treatment threshold.""" # noqa: E501 return np.greater_equal(x, self.kink_point) @@ -1193,7 +1127,7 @@ def summary(self, round_to=None) -> None: self.print_coefficients(round_to) -class PrePostNEGD(ExperimentalDesign): +class PrePostNEGD(ExperimentalDesign, PrePostNEGDDataValidator): """ A class to analyse data from pretest/posttest designs @@ -1302,18 +1236,6 @@ def __init__( {"coeffs": self._get_treatment_effect_coeff()} ) - # ================================================================ - - def _input_validation(self) -> None: - """Validate the input data and model formula for correctness""" - if not _is_variable_dummy_coded(self.data[self.group_variable_name]): - raise DataException( - f""" - There must be 2 levels of the grouping variable - {self.group_variable_name}. I.e. the treated and untreated. - """ - ) - def plot(self, round_to=None): """Plot the results @@ -1408,7 +1330,7 @@ def _get_treatment_effect_coeff(self) -> str: raise NameError("Unable to find coefficient name for the treatment effect") -class InstrumentalVariable(ExperimentalDesign): +class InstrumentalVariable(ExperimentalDesign, IVDataValidator): """ A class to analyse instrumental variable style experiments. @@ -1555,26 +1477,3 @@ def get_naive_OLS_fit(self): beta_params.insert(0, ols_reg.intercept_[0]) self.ols_beta_params = dict(zip(self._x_design_info.column_names, beta_params)) self.ols_reg = ols_reg - - def _input_validation(self): - """Validate the input data and model formula for correctness""" - treatment = self.instruments_formula.split("~")[0] - test = treatment.strip() in self.instruments_data.columns - test = test & (treatment.strip() in self.data.columns) - if not test: - raise DataException( - f""" - The treatment variable: - {treatment} must appear in the instrument_data to be used - as an outcome variable and in the data object to be used as a covariate. - """ - ) - Z = self.data[treatment.strip()] - check_binary = len(np.unique(Z)) > 2 - if check_binary: - warnings.warn( - """Warning. The treatment variable is not Binary. - This is not necessarily a problem but it violates - the assumption of a simple IV experiment. - The coefficients should be interpreted appropriately.""" - ) diff --git a/causalpy/skl_experiments.py b/causalpy/skl_experiments.py index 5c51d489..f9dce73b 100644 --- a/causalpy/skl_experiments.py +++ b/causalpy/skl_experiments.py @@ -17,6 +17,11 @@ import seaborn as sns from patsy import build_design_matrices, dmatrices +from causalpy.data_validation import ( + DiDDataValidator, + PrePostFitDataValidator, + RDDataValidator, +) from causalpy.utils import round_num LEGEND_FONT_SIZE = 12 @@ -35,7 +40,7 @@ def __init__(self, model=None, **kwargs): raise ValueError("fitting_model not set or passed.") -class PrePostFit(ExperimentalDesign): +class PrePostFit(ExperimentalDesign, PrePostFitDataValidator): """ A class to analyse quasi-experiments where parameter estimation is based on just the pre-intervention data. @@ -74,6 +79,7 @@ def __init__( **kwargs, ): super().__init__(model=model, **kwargs) + self._input_validation(data, treatment_time) self.treatment_time = treatment_time # split data in to pre and post intervention self.datapre = data[data.index < self.treatment_time] @@ -284,7 +290,7 @@ def plot(self, plot_predictors=False, round_to=None, **kwargs): return (fig, ax) -class DifferenceInDifferences(ExperimentalDesign): +class DifferenceInDifferences(ExperimentalDesign, DiDDataValidator): """ .. note:: @@ -334,6 +340,7 @@ def __init__( self.formula = formula self.time_variable_name = time_variable_name self.group_variable_name = group_variable_name + self._input_validation() self.treated = treated # level of the group_variable_name that was treated self.untreated = ( untreated # level of the group_variable_name that was untreated @@ -486,7 +493,7 @@ def plot(self, round_to=None): return (fig, ax) -class RegressionDiscontinuity(ExperimentalDesign): +class RegressionDiscontinuity(ExperimentalDesign, RDDataValidator): """ A class to analyse sharp regression discontinuity experiments. @@ -550,6 +557,7 @@ def __init__( self.treatment_threshold = treatment_threshold self.bandwidth = bandwidth self.epsilon = epsilon + self._input_validation() if self.bandwidth is not None: fmin = self.treatment_threshold - self.bandwidth diff --git a/causalpy/tests/test_input_validation.py b/causalpy/tests/test_input_validation.py index 9de46179..5c895b4d 100644 --- a/causalpy/tests/test_input_validation.py +++ b/causalpy/tests/test_input_validation.py @@ -35,6 +35,17 @@ def test_did_validation_post_treatment_formula(): model=cp.pymc_models.LinearRegression(sample_kwargs=sample_kwargs), ) + with pytest.raises(FormulaException): + _ = cp.skl_experiments.DifferenceInDifferences( + df, + formula="y ~ 1 + group*post_SOMETHING", + time_variable_name="t", + group_variable_name="group", + treated=1, + untreated=0, + model=cp.pymc_models.LinearRegression(sample_kwargs=sample_kwargs), + ) + def test_did_validation_post_treatment_data(): """Test that we get a DataException if do not include post_treatment in the data""" @@ -57,6 +68,17 @@ def test_did_validation_post_treatment_data(): model=cp.pymc_models.LinearRegression(sample_kwargs=sample_kwargs), ) + with pytest.raises(DataException): + _ = cp.skl_experiments.DifferenceInDifferences( + df, + formula="y ~ 1 + group*post_treatment", + time_variable_name="t", + group_variable_name="group", + treated=1, + untreated=0, + model=cp.pymc_models.LinearRegression(sample_kwargs=sample_kwargs), + ) + def test_did_validation_unit_data(): """Test that we get a DataException if do not include unit in the data""" @@ -79,6 +101,17 @@ def test_did_validation_unit_data(): model=cp.pymc_models.LinearRegression(sample_kwargs=sample_kwargs), ) + with pytest.raises(DataException): + _ = cp.skl_experiments.DifferenceInDifferences( + df, + formula="y ~ 1 + group*post_treatment", + time_variable_name="t", + group_variable_name="group", + treated=1, + untreated=0, + model=cp.pymc_models.LinearRegression(sample_kwargs=sample_kwargs), + ) + def test_did_validation_group_dummy_coded(): """Test that we get a DataException if the group variable is not dummy coded""" @@ -101,6 +134,17 @@ def test_did_validation_group_dummy_coded(): model=cp.pymc_models.LinearRegression(sample_kwargs=sample_kwargs), ) + with pytest.raises(DataException): + _ = cp.skl_experiments.DifferenceInDifferences( + df, + formula="y ~ 1 + group*post_treatment", + time_variable_name="t", + group_variable_name="group", + treated=1, + untreated=0, + model=cp.pymc_models.LinearRegression(sample_kwargs=sample_kwargs), + ) + # Synthetic Control @@ -118,6 +162,16 @@ def test_sc_input_error(): model=cp.pymc_models.WeightedSumFitter(sample_kwargs=sample_kwargs), ) + with pytest.raises(BadIndexException): + df = cp.load_data("sc") + treatment_time = pd.to_datetime("2016 June 24") + _ = cp.skl_experiments.SyntheticControl( + df, + treatment_time, + formula="actual ~ 0 + a + b + c + d + e + f + g", + model=cp.skl_models.WeightedProportion(), + ) + def test_sc_brexit_input_error(): """Confirm a BadIndexException is raised if the data index is datetime and the @@ -187,6 +241,16 @@ def test_rd_validation_treated_in_formula(): treatment_threshold=0.5, ) + with pytest.raises(FormulaException): + from sklearn.linear_model import LinearRegression + + _ = cp.skl_experiments.RegressionDiscontinuity( + df, + formula="y ~ 1 + x", + model=LinearRegression(), + treatment_threshold=0.5, + ) + def test_rd_validation_treated_is_dummy(): """Test that we get a DataException if treated is not dummy coded""" @@ -206,6 +270,16 @@ def test_rd_validation_treated_is_dummy(): treatment_threshold=0.5, ) + from sklearn.linear_model import LinearRegression + + with pytest.raises(DataException): + _ = cp.skl_experiments.RegressionDiscontinuity( + df, + formula="y ~ 1 + x + treated", + model=LinearRegression(), + treatment_threshold=0.5, + ) + def test_iv_treatment_var_is_present(): """Test the treatment variable is present for Instrumental Variable experiment""" diff --git a/docs/source/_static/classes.png b/docs/source/_static/classes.png index 35af55d7..5842bfff 100644 Binary files a/docs/source/_static/classes.png and b/docs/source/_static/classes.png differ diff --git a/docs/source/_static/interrogate_badge.svg b/docs/source/_static/interrogate_badge.svg index 2cf9af2b..9a7d6c00 100644 --- a/docs/source/_static/interrogate_badge.svg +++ b/docs/source/_static/interrogate_badge.svg @@ -1,5 +1,5 @@ - interrogate: 97.0% + interrogate: 96.6% @@ -12,8 +12,8 @@ interrogate interrogate - 97.0% - 97.0% + 96.6% + 96.6% diff --git a/docs/source/_static/packages.png b/docs/source/_static/packages.png index 5de25ef3..9713e0b7 100644 Binary files a/docs/source/_static/packages.png and b/docs/source/_static/packages.png differ