Skip to content

Commit

Permalink
#76 #44 DID now works and generalises to custom varnames + level values
Browse files Browse the repository at this point in the history
  • Loading branch information
drbenvincent committed Nov 19, 2022
1 parent d348ee8 commit 7840611
Show file tree
Hide file tree
Showing 3 changed files with 331 additions and 34 deletions.
36 changes: 32 additions & 4 deletions causalpy/pymc_experiments.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,9 @@ def __init__(
data: pd.DataFrame,
formula: str,
time_variable_name: str,
group_variable_name: str,
treated: str,
untreated: str,
prediction_model=None,
**kwargs,
):
Expand All @@ -217,13 +220,24 @@ def __init__(
self.expt_type = "Difference in Differences"
self.formula = formula
self.time_variable_name = time_variable_name
self.group_variable_name = group_variable_name
self.treated = treated # level of the group_variable_name that was treated
self.untreated = (
untreated # level of the group_variable_name that was untreated
)
y, X = dmatrices(formula, self.data)
self._y_design_info = y.design_info
self._x_design_info = X.design_info
self.labels = X.design_info.column_names
self.y, self.X = np.asarray(y), np.asarray(X)
self.outcome_variable_name = y.design_info.column_names[0]

assert (
"treated" in formula
), "A predictor column called `treated` should be in the provided dataframe"

# TODO: check that data in column self.group_variable_name has TWO levels

# TODO: `treated` is a deterministic function of group and time, so this should be a function rather than supplied data

# DEVIATION FROM SKL EXPERIMENT CODE =============================
Expand All @@ -232,23 +246,37 @@ def __init__(
self.prediction_model.fit(X=self.X, y=self.y, coords=COORDS)
# ================================================================

time_levels = self.data[self.time_variable_name].unique()

# predicted outcome for control group
self.x_pred_control = pd.DataFrame(
{"group": [0, 0], "t": [0.0, 1.0], "treated": [0, 0]}
{
self.group_variable_name: [self.untreated, self.untreated],
self.time_variable_name: time_levels,
"treated": [0, 0],
}
)
(new_x,) = build_design_matrices([self._x_design_info], self.x_pred_control)
self.y_pred_control = self.prediction_model.predict(np.asarray(new_x))

# predicted outcome for treatment group
self.x_pred_treatment = pd.DataFrame(
{"group": [1, 1], "t": [0.0, 1.0], "treated": [0, 1]}
{
self.group_variable_name: [self.treated, self.treated],
self.time_variable_name: time_levels,
"treated": [0, 1],
}
)
(new_x,) = build_design_matrices([self._x_design_info], self.x_pred_treatment)
self.y_pred_treatment = self.prediction_model.predict(np.asarray(new_x))

# predicted outcome for counterfactual
self.x_pred_counterfactual = pd.DataFrame(
{"group": [1], "t": [1.0], "treated": [0]}
{
self.group_variable_name: [self.treated],
self.time_variable_name: time_levels[1],
"treated": [0],
}
)
(new_x,) = build_design_matrices(
[self._x_design_info], self.x_pred_counterfactual
Expand Down Expand Up @@ -278,7 +306,7 @@ def plot(self):
self.data,
x=self.time_variable_name,
y=self.outcome_variable_name,
hue="group",
hue=self.group_variable_name,
units="unit",
estimator=None,
alpha=0.25,
Expand Down
38 changes: 25 additions & 13 deletions docs/notebooks/did_pymc.ipynb

Large diffs are not rendered by default.

291 changes: 274 additions & 17 deletions docs/notebooks/did_pymc_banks.ipynb

Large diffs are not rendered by default.

0 comments on commit 7840611

Please sign in to comment.