Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Expand data validation to also cover the scikit-learn experiment classes #290

Merged
merged 3 commits into from
Feb 23, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
135 changes: 135 additions & 0 deletions causalpy/data_validation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
import warnings # noqa: I001

import pandas as pd
import numpy as np
from causalpy.custom_exceptions import (
BadIndexException, # NOQA
DataException,
FormulaException,
)
from causalpy.utils import _is_variable_dummy_coded


class PrePostFitDataValidator:
    """Mixin class for validating the input data and model formula for PrePostFit"""

    def _input_validation(self, data, treatment_time):
        """Check that ``treatment_time`` is compatible with ``data.index``.

        A ``pd.DatetimeIndex`` must be paired with a ``pd.Timestamp`` treatment
        time, and any other index type must be paired with a treatment time
        that is not a ``pd.Timestamp``.

        Parameters
        ----------
        data
            Object exposing an ``index`` attribute (e.g. a ``pd.DataFrame``).
        treatment_time
            The treatment time to compare against the index type.

        Raises
        ------
        BadIndexException
            If the index type and the type of ``treatment_time`` disagree.
        """
        has_datetime_index = isinstance(data.index, pd.DatetimeIndex)
        is_timestamp = isinstance(treatment_time, pd.Timestamp)

        if has_datetime_index and not is_timestamp:
            raise BadIndexException(
                "If data.index is DatetimeIndex, treatment_time must be pd.Timestamp."
            )
        if not has_datetime_index and is_timestamp:
            # The message must describe the violated rule: a Timestamp is only
            # valid alongside a DatetimeIndex.
            raise BadIndexException(
                "If data.index is not DatetimeIndex, "
                "treatment_time must not be pd.Timestamp."
            )


class DiDDataValidator:
    """Mixin class for validating the input data and model formula for
    Difference in Differences experiments.

    The host class is expected to provide ``self.formula``, ``self.data`` and
    ``self.group_variable_name`` before ``_input_validation`` is called.
    """

    def _input_validation(self):
        """Validate the input data and model formula for correctness.

        Raises
        ------
        FormulaException
            If the text ``post_treatment`` does not occur in ``self.formula``.
        DataException
            If ``self.data`` has no ``post_treatment`` column, has no ``unit``
            column, or if the grouping column is not dummy coded.
        """
        if "post_treatment" not in self.formula:
            raise FormulaException(
                "A predictor called `post_treatment` should be in the formula"
            )

        if "post_treatment" not in self.data.columns:
            # Name the column that is actually being checked for so the user
            # knows what to add.
            raise DataException(
                "Require a boolean column called `post_treatment` labelling "
                "observations which are `treated`"
            )

        if "unit" not in self.data.columns:
            raise DataException(
                "Require a `unit` column to label unique units. "
                "This is used for plotting purposes"
            )

        # Truthiness test rather than ``is False``: an identity comparison would
        # silently pass for any falsy result that is not the ``False`` singleton.
        if not _is_variable_dummy_coded(self.data[self.group_variable_name]):
            raise DataException(
                f"The grouping variable {self.group_variable_name} should be dummy\n"
                "coded. Consisting of 0's and 1's only."
            )


class RDDataValidator:
    """Mixin class for validating the input data and model formula for
    Regression Discontinuity experiments.

    The host class is expected to provide ``self.formula`` and ``self.data``
    before ``_input_validation`` is called.
    """

    def _input_validation(self):
        """Validate the input data and model formula for correctness.

        Raises
        ------
        FormulaException
            If the text ``treated`` does not occur in ``self.formula``.
        DataException
            If the ``treated`` column of ``self.data`` is not dummy coded.
        """
        if "treated" not in self.formula:
            raise FormulaException(
                "A predictor called `treated` should be in the formula"
            )

        # Truthiness test rather than ``is False``: an identity comparison would
        # silently pass for any falsy result that is not the ``False`` singleton.
        if not _is_variable_dummy_coded(self.data["treated"]):
            raise DataException(
                "The treated variable should be dummy coded. "
                "Consisting of 0's and 1's only."
            )


class RegressionKinkDataValidator:
    """Mixin class for validating the input data and model formula for
    Regression Kink experiments.

    The host class is expected to provide ``self.formula``, ``self.data``,
    ``self.bandwidth`` and ``self.epsilon`` before ``_input_validation`` is
    called.
    """

    def _input_validation(self):
        """Validate the input data and model formula for correctness.

        Raises
        ------
        FormulaException
            If the text ``treated`` does not occur in ``self.formula``.
        DataException
            If the ``treated`` column of ``self.data`` is not dummy coded.
        ValueError
            If ``self.bandwidth`` or ``self.epsilon`` is not greater than zero.
        """
        if "treated" not in self.formula:
            raise FormulaException(
                "A predictor called `treated` should be in the formula"
            )

        # Truthiness test rather than ``is False``: an identity comparison would
        # silently pass for any falsy result that is not the ``False`` singleton.
        if not _is_variable_dummy_coded(self.data["treated"]):
            raise DataException(
                "The treated variable should be dummy coded. "
                "Consisting of 0's and 1's only."
            )

        if self.bandwidth <= 0:
            raise ValueError("The bandwidth must be greater than zero.")

        if self.epsilon <= 0:
            raise ValueError("Epsilon must be greater than zero.")


class PrePostNEGDDataValidator:
    """Mixin class for validating the input data and model formula for
    PrePostNEGD experiments.

    Reads ``self.data`` and ``self.group_variable_name`` from the host class.
    """

    def _input_validation(self) -> None:
        """Raise ``DataException`` unless the grouping column is dummy coded."""
        group_column = self.data[self.group_variable_name]
        if _is_variable_dummy_coded(group_column):
            return

        raise DataException(
            "\nThere must be 2 levels of the grouping variable\n"
            f"{self.group_variable_name}. I.e. the treated and untreated.\n"
        )


class IVDataValidator:
    """Mixin class for validating the input data and model formula for
    IV experiments.

    Reads ``self.instruments_formula``, ``self.instruments_data`` and
    ``self.data`` from the host class.
    """

    def _input_validation(self):
        """Validate the input data and model formula for correctness.

        The treatment variable is taken to be everything left of the first
        ``~`` in ``self.instruments_formula``. It must be a column of both
        ``self.instruments_data`` and ``self.data``.

        Raises
        ------
        DataException
            If the treatment variable is missing from either data object.

        Warns
        -----
        UserWarning
            If the treatment column of ``self.data`` holds more than two
            distinct values.
        """
        treatment = self.instruments_formula.split("~")[0]
        treatment_name = treatment.strip()

        # Evaluate both lookups up front so neither is skipped.
        in_instruments = treatment_name in self.instruments_data.columns
        in_data = treatment_name in self.data.columns
        if not (in_instruments and in_data):
            raise DataException(
                "\nThe treatment variable:\n"
                f"{treatment} must appear in the instrument_data to be used\n"
                "as an outcome variable and in the data object to be used"
                " as a covariate.\n"
            )

        distinct_values = np.unique(self.data[treatment_name])
        if len(distinct_values) > 2:
            warnings.warn(
                "Warning. The treatment variable is not Binary.\n"
                "This is not necessarily a problem but it violates\n"
                "the assumption of a simple IV experiment.\n"
                "The coefficients should be interpreted appropriately."
            )
129 changes: 14 additions & 115 deletions causalpy/pymc_experiments.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,16 @@
from patsy import build_design_matrices, dmatrices
from sklearn.linear_model import LinearRegression as sk_lin_reg

from causalpy.custom_exceptions import (
BadIndexException, # NOQA
DataException,
FormulaException,
from causalpy.data_validation import (
PrePostFitDataValidator,
DiDDataValidator,
RDDataValidator,
RegressionKinkDataValidator,
PrePostNEGDDataValidator,
IVDataValidator,
)
from causalpy.plot_utils import plot_xY
from causalpy.utils import _is_variable_dummy_coded, round_num
from causalpy.utils import round_num

LEGEND_FONT_SIZE = 12
az.style.use("arviz-darkgrid")
Expand Down Expand Up @@ -108,7 +111,7 @@ def print_coefficients(self, round_to=None) -> None:
)


class PrePostFit(ExperimentalDesign):
class PrePostFit(ExperimentalDesign, PrePostFitDataValidator):
"""
A class to analyse quasi-experiments where parameter estimation is based on just
the pre-intervention data.
Expand Down Expand Up @@ -160,7 +163,6 @@ def __init__(
) -> None:
super().__init__(model=model, **kwargs)
self._input_validation(data, treatment_time)

self.treatment_time = treatment_time
# set experiment type - usually done in subclasses
self.expt_type = "Pre-Post Fit"
Expand Down Expand Up @@ -214,21 +216,6 @@ def __init__(
# cumulative impact post
self.post_impact_cumulative = self.post_impact.cumsum(dim="obs_ind")

def _input_validation(self, data, treatment_time):
"""Validate the input data and model formula for correctness"""
if isinstance(data.index, pd.DatetimeIndex) and not isinstance(
treatment_time, pd.Timestamp
):
raise BadIndexException(
"If data.index is DatetimeIndex, treatment_time must be pd.Timestamp."
)
if not isinstance(data.index, pd.DatetimeIndex) and isinstance(
treatment_time, pd.Timestamp
):
raise BadIndexException(
"If data.index is not DatetimeIndex, treatment_time must be pd.Timestamp." # noqa: E501
)

def plot(self, counterfactual_label="Counterfactual", round_to=None, **kwargs):
"""
Plot the results
Expand Down Expand Up @@ -438,7 +425,7 @@ def plot(self, plot_predictors=False, **kwargs):
return fig, ax


class DifferenceInDifferences(ExperimentalDesign):
class DifferenceInDifferences(ExperimentalDesign, DiDDataValidator):
"""A class to analyse data from Difference in Difference settings.

.. note::
Expand Down Expand Up @@ -568,29 +555,6 @@ def __init__(
if "post_treatment" in label and self.group_variable_name in label:
self.causal_impact = self.idata.posterior["beta"].isel({"coeffs": i})

def _input_validation(self):
"""Validate the input data and model formula for correctness"""
if "post_treatment" not in self.formula:
raise FormulaException(
"A predictor called `post_treatment` should be in the formula"
)

if "post_treatment" not in self.data.columns:
raise DataException(
"Require a boolean column labelling observations which are `treated`"
)

if "unit" not in self.data.columns:
raise DataException(
"Require a `unit` column to label unique units. This is used for plotting purposes" # noqa: E501
)

if _is_variable_dummy_coded(self.data[self.group_variable_name]) is False:
raise DataException(
f"""The grouping variable {self.group_variable_name} should be dummy
coded. Consisting of 0's and 1's only."""
)

def plot(self, round_to=None):
"""Plot the results.

Expand Down Expand Up @@ -749,7 +713,7 @@ def summary(self, round_to=None) -> None:
self.print_coefficients(round_to)


class RegressionDiscontinuity(ExperimentalDesign):
class RegressionDiscontinuity(ExperimentalDesign, RDDataValidator):
"""
A class to analyse sharp regression discontinuity experiments.

Expand Down Expand Up @@ -876,18 +840,6 @@ def __init__(
- self.pred_discon["posterior_predictive"].sel(obs_ind=0)["mu"]
)

def _input_validation(self):
"""Validate the input data and model formula for correctness"""
if "treated" not in self.formula:
raise FormulaException(
"A predictor called `treated` should be in the formula"
)

if _is_variable_dummy_coded(self.data["treated"]) is False:
raise DataException(
"""The treated variable should be dummy coded. Consisting of 0's and 1's only.""" # noqa: E501
)

def _is_treated(self, x):
"""Returns ``True`` if `x` is greater than or equal to the treatment threshold.

Expand Down Expand Up @@ -970,7 +922,7 @@ def summary(self, round_to: None) -> None:
self.print_coefficients(round_to)


class RegressionKink(ExperimentalDesign):
class RegressionKink(ExperimentalDesign, RegressionKinkDataValidator):
"""
A class to analyse sharp regression kink experiments.

Expand Down Expand Up @@ -1095,24 +1047,6 @@ def _probe_kink_point(self):
mu_kink_right = predicted["posterior_predictive"].sel(obs_ind=2)["mu"]
return mu_kink_left, mu_kink, mu_kink_right

def _input_validation(self):
"""Validate the input data and model formula for correctness"""
if "treated" not in self.formula:
raise FormulaException(
"A predictor called `treated` should be in the formula"
)

if _is_variable_dummy_coded(self.data["treated"]) is False:
raise DataException(
"""The treated variable should be dummy coded. Consisting of 0's and 1's only.""" # noqa: E501
)

if self.bandwidth <= 0:
raise ValueError("The bandwidth must be greater than zero.")

if self.epsilon <= 0:
raise ValueError("Epsilon must be greater than zero.")

def _is_treated(self, x):
"""Returns ``True`` if `x` is greater than or equal to the treatment threshold.""" # noqa: E501
return np.greater_equal(x, self.kink_point)
Expand Down Expand Up @@ -1193,7 +1127,7 @@ def summary(self, round_to=None) -> None:
self.print_coefficients(round_to)


class PrePostNEGD(ExperimentalDesign):
class PrePostNEGD(ExperimentalDesign, PrePostNEGDDataValidator):
"""
A class to analyse data from pretest/posttest designs

Expand Down Expand Up @@ -1302,18 +1236,6 @@ def __init__(
{"coeffs": self._get_treatment_effect_coeff()}
)

# ================================================================

def _input_validation(self) -> None:
"""Validate the input data and model formula for correctness"""
if not _is_variable_dummy_coded(self.data[self.group_variable_name]):
raise DataException(
f"""
There must be 2 levels of the grouping variable
{self.group_variable_name}. I.e. the treated and untreated.
"""
)

def plot(self, round_to=None):
"""Plot the results

Expand Down Expand Up @@ -1408,7 +1330,7 @@ def _get_treatment_effect_coeff(self) -> str:
raise NameError("Unable to find coefficient name for the treatment effect")


class InstrumentalVariable(ExperimentalDesign):
class InstrumentalVariable(ExperimentalDesign, IVDataValidator):
"""
A class to analyse instrumental variable style experiments.

Expand Down Expand Up @@ -1555,26 +1477,3 @@ def get_naive_OLS_fit(self):
beta_params.insert(0, ols_reg.intercept_[0])
self.ols_beta_params = dict(zip(self._x_design_info.column_names, beta_params))
self.ols_reg = ols_reg

def _input_validation(self):
"""Validate the input data and model formula for correctness"""
treatment = self.instruments_formula.split("~")[0]
test = treatment.strip() in self.instruments_data.columns
test = test & (treatment.strip() in self.data.columns)
if not test:
raise DataException(
f"""
The treatment variable:
{treatment} must appear in the instrument_data to be used
as an outcome variable and in the data object to be used as a covariate.
"""
)
Z = self.data[treatment.strip()]
check_binary = len(np.unique(Z)) > 2
if check_binary:
warnings.warn(
"""Warning. The treatment variable is not Binary.
This is not necessarily a problem but it violates
the assumption of a simple IV experiment.
The coefficients should be interpreted appropriately."""
)
Loading
Loading